In [1]:
%load_ext nb_black
# %matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import set_matplotlib_formats

set_matplotlib_formats("svg")

import numpy as np
import pandas as pd

import jax.numpy as jnp
import jax
from jax import vmap
from jax.scipy.special import xlogy

jax.config.update("jax_enable_x64", True)
import scipy.linalg

from collections import defaultdict
import gzip

from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from ete3 import Tree
from datetime import datetime, MINYEAR

from vbsky.fasta import SeqData
from vbsky.bdsky import _lognorm_logpdf
from vbsky.prob import VF
from vbsky.prob.distribution import PointMass
from vbsky.prob.transform import (
    Transform,
    Compose,
    Affine,
    Blockwise,
    Positive,
    ZeroOne,
    DiagonalAffine,
    Householder,
    Shift,
    Scale,
    Bounded,
    Exp,
    Softplus,
    Concat,
)
from vbsky.prob.distribution import Constant
from vbsky.prob import arf

from vbsky.plot import *

# Shared transform stacks used when building the variational flows below: each
# composes DiagonalAffine with a second transform.
# NOTE(review): presumably Exp/Positive constrain values to be positive and
# ZeroOne constrains them to (0, 1) — confirm against vbsky.prob.transform.
pos = Compose(DiagonalAffine, Exp)
plus = Compose(DiagonalAffine, Positive)
z1 = Compose(DiagonalAffine, ZeroOne)

  set_matplotlib_formats("svg")


<IPython.core.display.Javascript object>

## Helper Functions

In [4]:
def _params_prior_loglik(params):
    """Log prior density of the global parameters (default prior).

    The returned value is the sum, accumulated in this order, of:

    * a Gamma(a=0.001, scale=1 / 0.001) log-density for ``precision_R[0]``
      and for ``precision_s[0]``;
    * a Beta(0.02, 0.98) log-density summed over the entries of ``s``;
    * ``_lognorm_logpdf`` evaluated at ``log(R)`` (mu=1.0, sigma=1) and at
      ``log(origin)`` (mu=-1.2, sigma=0.1), each summed over entries;
    * for ``R`` and then ``s``: minus ``tau / 2`` times the sum of squared
      first differences of the log values, plus
      ``xlogy((m - 1) / 2, tau / (2 * pi))`` where ``m`` is the vector length
      and ``tau`` the matching precision.

    ``params`` is indexed by the keys "R", "s", "origin", "precision_R" and
    "precision_s".
    """
    precision = {"R": params["precision_R"][0], "s": params["precision_s"][0]}

    total = 0

    # Hyperprior on the two smoothing precisions.
    for name in ("R", "s"):
        total += jax.scipy.stats.gamma.logpdf(
            precision[name], a=0.001, scale=1 / 0.001
        )

    total += jax.scipy.stats.beta.logpdf(params["s"], 0.02, 0.98).sum()

    # Marginal terms on the log scale: (parameter, mu, sigma).
    for name, mu, sigma in (("R", 1.0, 1), ("origin", -1.2, 0.1)):
        total += _lognorm_logpdf(jnp.log(params[name]), mu=mu, sigma=sigma).sum()

    # Smoothing penalty on successive differences of the log values.
    for name in ("R", "s"):
        log_values = jnp.log(params[name])
        tau = precision[name]
        total -= (tau / 2) * (jnp.diff(log_values) ** 2).sum()
        total += xlogy((len(log_values) - 1) / 2, tau / (2 * jnp.pi))

    return total


def _params_prior_loglik_less_smooth(params):
    """Log prior density of the global parameters ("less smooth" variant).

    Same structure as the default prior in this notebook, except that both
    precisions (``precision_R[0]`` and ``precision_s[0]``) receive a
    Gamma(a=10, scale=0.1 / 10) log-density.

    The remaining terms, in accumulation order, are:

    * a Beta(0.02, 0.98) log-density summed over the entries of ``s``;
    * ``_lognorm_logpdf`` at ``log(R)`` (mu=1.0, sigma=1) and at
      ``log(origin)`` (mu=-1.2, sigma=0.1), each summed over entries;
    * for ``R`` and then ``s``: minus ``tau / 2`` times the sum of squared
      first differences of the log values, plus
      ``xlogy((m - 1) / 2, tau / (2 * pi))``.
    """
    precision = {"R": params["precision_R"][0], "s": params["precision_s"][0]}

    total = 0

    # Hyperprior on the two smoothing precisions.
    for name in ("R", "s"):
        total += jax.scipy.stats.gamma.logpdf(precision[name], a=10, scale=0.1 / 10)

    total += jax.scipy.stats.beta.logpdf(params["s"], 0.02, 0.98).sum()

    # Marginal terms on the log scale: (parameter, mu, sigma).
    for name, mu, sigma in (("R", 1.0, 1), ("origin", -1.2, 0.1)):
        total += _lognorm_logpdf(jnp.log(params[name]), mu=mu, sigma=sigma).sum()

    # Smoothing penalty on successive differences of the log values.
    for name in ("R", "s"):
        log_values = jnp.log(params[name])
        tau = precision[name]
        total -= (tau / 2) * (jnp.diff(log_values) ** 2).sum()
        total += xlogy((len(log_values) - 1) / 2, tau / (2 * jnp.pi))

    return total


def _params_prior_loglik_bias(params):
    """Log prior density of the global parameters ("bias" variant).

    Terms, in accumulation order:

    * Gamma(a=0.001, scale=1 / 0.001) log-density of ``precision_R[0]``;
    * Gamma(a=10, scale=150 / 10) log-density of ``precision_s[0]``;
    * Beta(0.02, 0.98) log-density summed over the entries of ``s``;
    * ``_lognorm_logpdf`` at ``log(R)`` with mu=1.0, sigma=1 (summed);
    * for ``R`` and then ``s``: minus ``tau / 2`` times the sum of squared
      first differences of the log values, plus
      ``xlogy((m - 1) / 2, tau / (2 * pi))``.

    Unlike the other priors in this notebook, no term is added for
    ``origin``.
    """
    precision = {"R": params["precision_R"][0], "s": params["precision_s"][0]}

    total = 0
    total += jax.scipy.stats.gamma.logpdf(precision["R"], a=0.001, scale=1 / 0.001)
    total += jax.scipy.stats.gamma.logpdf(precision["s"], a=10, scale=150 / 10)

    total += jax.scipy.stats.beta.logpdf(params["s"], 0.02, 0.98).sum()

    # Only R gets a marginal term here.
    total += _lognorm_logpdf(jnp.log(params["R"]), mu=1.0, sigma=1).sum()

    # Smoothing penalty on successive differences of the log values.
    for name in ("R", "s"):
        log_values = jnp.log(params[name])
        tau = precision[name]
        total -= (tau / 2) * (jnp.diff(log_values) ** 2).sum()
        total += xlogy((len(log_values) - 1) / 2, tau / (2 * jnp.pi))

    return total


# Registry of the prior log-density functions defined above, keyed by a short
# label so a prior can be selected by name.
priors = {
    "original": _params_prior_loglik,
    "less": _params_prior_loglik_less_smooth,
    "bias": _params_prior_loglik_bias,
}


def default_flows(data, m, rate, delta=36.5):
    """Build the variational flows with a learnable ``origin``.

    Parameters
    ----------
    data
        Object exposing ``tds`` (an iterable whose items have an ``n``
        attribute) and ``earliest``.
    m : int
        Size passed to the ``R`` and ``s`` transforms and the number of times
        ``delta`` is repeated.
    rate
        Value held constant as ``clock_rate``.
    delta : scalar, optional
        Constant repeated ``m`` times for the ``delta`` entry. Defaults to
        36.5, the value that used to be hard-coded here.
        NOTE(review): presumably a per-year rate — confirm the units.

    Returns
    -------
    tuple
        ``(global_flows, local_flows)``: a ``VF`` instance and a list with one
        dict per entry of ``data.tds``.
    """
    # One pair of local flows per entry of data.tds.
    local_flows = [
        {"proportions": Transform(td.n - 2, z1), "root_proportion": Transform(1, z1)}
        for td in data.tds
    ]

    # Learned entries use Transform; everything else is pinned with Constant.
    global_flows = VF(
        origin=Transform(1, pos),
        origin_start=Constant(data.earliest),
        delta=Constant(np.repeat(delta, m)),
        R=Transform(m, pos),
        rho_m=Constant(0),
        s=Transform(m, z1),
        precision_R=Transform(1, pos),
        precision_s=Transform(1, pos),
        clock_rate=Constant(rate),
    )
    return global_flows, local_flows


def fixed_origin_flows(data, m, rate, origin=0.3, delta=36.5):
    """Build the variational flows with ``origin`` held constant.

    Parameters
    ----------
    data
        Object exposing ``tds`` (an iterable whose items have an ``n``
        attribute) and ``earliest``.
    m : int
        Size passed to the ``R`` and ``s`` transforms and the number of times
        ``delta`` is repeated.
    rate
        Value held constant as ``clock_rate``.
    origin : scalar, optional
        Constant used for the ``origin`` entry. Defaults to 0.3, the value
        that used to be hard-coded here.
    delta : scalar, optional
        Constant repeated ``m`` times for the ``delta`` entry. Defaults to
        36.5, the value that used to be hard-coded here.

    Returns
    -------
    tuple
        ``(global_flows, local_flows)``: a ``VF`` instance and a list with one
        dict per entry of ``data.tds``.
    """
    # One pair of local flows per entry of data.tds.
    local_flows = [
        {"proportions": Transform(td.n - 2, z1), "root_proportion": Transform(1, z1)}
        for td in data.tds
    ]

    # Only R, s and the two precisions are learned; the rest is pinned.
    global_flows = VF(
        origin=Constant(origin),
        origin_start=Constant(data.earliest),
        delta=Constant(np.repeat(delta, m)),
        R=Transform(m, pos),
        rho_m=Constant(0),
        s=Transform(m, z1),
        precision_R=Transform(1, pos),
        precision_s=Transform(1, pos),
        clock_rate=Constant(rate),
    )
    return global_flows, local_flows

<IPython.core.display.Javascript object>

## Import Covid sequence data

In [5]:
# Read one FASTA alignment per region; the keys are reused as region labels by
# the cells below.
fasta = {}
for region, path in (
    ("florida", "covid/audacity_fl.fa"),
    ("michigan", "covid/audacity_mi.fa"),
    ("usa", "covid/audacity_usa.fa"),
):
    fasta[region] = AlignIO.read(path, format="fasta")

KeyboardInterrupt: 

<IPython.core.display.Javascript object>

In [None]:
# Wrap every regional alignment in a SeqData object, keeping the region keys.
data = {region: SeqData(alignment) for region, alignment in fasta.items()}

In [None]:
# Load the variant surveillance table (tab-separated). The cell that builds
# `strains` reads its "Accession ID" and "Variant" columns.
df_variant = pd.read_csv("covid/variant_surveillance.tsv", sep="\t")

In [None]:
# For every region, split its sequences into Delta and Alpha subsets using the
# "Variant" column of the surveillance table, and wrap each subset in SeqData.
# Each spec is (key in the result, substring matched in "Variant", left_end).
variant_specs = (
    ("delta", "Delta", datetime(2021, 2, 23)),
    ("alpha", "Alpha", datetime(2020, 11, 15)),
)

strains = {}
for region, seq_data in data.items():
    # Surveillance rows whose accession appears in this region's sequences.
    region_rows = df_variant.loc[df_variant["Accession ID"].isin(seq_data.seqs)]

    by_variant = {}
    for key, pattern, left_end in variant_specs:
        is_variant = region_rows["Variant"].str.contains(pattern, na=False)
        records = (
            region_rows.loc[is_variant, "Accession ID"].map(seq_data.seqs).tolist()
        )
        by_variant[key] = SeqData(MultipleSeqAlignment(records), left_end=left_end)
    strains[region] = by_variant

In [None]:
# def filter_audacity_tree(subset, name):
#     global_tree = Tree("covid/global.tree", format=1)
#     leaves = set([leaf.name for leaf in global_tree])

#     to_prune = []

#     for s in subset:
#         desc = s.description.split("|")
#         to_prune.append(desc[1])

#     global_tree.prune(to_prune, preserve_branch_length=True)
#     global_tree.write(outfile=f"covid/global_{name}.tree")


# for k1 in strains.keys():
#     for k2, v in strains[k1].items():
#         filter_audacity_tree(v.aln, f"{k1}_{k2}")

## Run Analysis

In [None]:
n_tips = 200
temp_folder = "covid/temp"
tree_path = "covid/temp/subsample.trees"
audacity = True
stratified = False
stratify_by = None

# Prepare the subsampled trees for every region/strain dataset.
# NOTE: with audacity=True this expects covid/global_{region}_{strain}.tree to
# exist already; the only code in this notebook that writes those files is the
# commented-out `filter_audacity_tree` cell above.
for k1 in strains.keys():
    for k2, v in strains[k1].items():
        print(k1, k2)
        audacity_tree_path = f"covid/global_{k1}_{k2}.tree"

        if stratified:
            # Group sequence ids by (year, zero-based quarter) of their date.
            stratify_by = defaultdict(list)
            for s, d in zip(v.sids, v.dates):
                stratify_by[(d.year, (d.month - 1) // 3)].append(s)

        # One tree per n_tips sequences (rounded up), capped at 50 trees.
        n_trees = min(int(np.ceil(v.n / n_tips)), 50)

        v.prep_data(
            n_tips,
            n_trees,
            temp_folder,
            tree_path,
            audacity=audacity,
            audacity_tree_path=audacity_tree_path,
            stratified=stratified,
            stratify_by=stratify_by,
        )

In [None]:
rate = 1.10e-3
m = 50

# Attach a set of flows to every region/strain dataset before any fitting.
for by_strain in strains.values():
    for seq_data in by_strain.values():
        global_flows, local_flows = default_flows(seq_data, m, rate)
        seq_data.setup_flows(global_flows, local_flows)

rng = jax.random.PRNGKey(6)
n_iter = 10
step_size = 1.0
threshold = 0.001

# Fit each dataset with the default prior; the same PRNG key is passed to
# every fit. Results are stored as res[region][strain].
res = {}
for region, by_strain in strains.items():
    fits = {}
    for strain, seq_data in by_strain.items():
        print(region, strain)
        fits[strain] = seq_data.loop(
            _params_prior_loglik, rng, n_iter, step_size=step_size, threshold=threshold
        )
    res[region] = fits

In [None]:
fig, axs = plt.subplots(3)
fig.set_size_inches(10, 15)

# One panel per region, with both strains drawn on the same axes.
for ax, region in zip(axs, ["florida", "michigan", "usa"]):
    panel_title = region.upper() if region == "usa" else region.title()
    for strain in ["alpha", "delta"]:
        fit = res[region][strain]
        start, top, end, x0 = plot_helper(fit, strains[region][strain], 200)
        plot_one(
            fit, ax, "R", m, start, top, end, x0, strain.title(), "fill", panel_title
        )

# Common axis limits for all panels.
for ax in axs:
    ax.set_xlim(2020.8, 2021.9)
    ax.set_ylim(0, 2)
axs[0].legend(loc="lower left")

fig.savefig("covid/figures/strains/strain_R.pdf", format="pdf")

In [None]:
fig, axs = plt.subplots(3)
fig.set_size_inches(10, 15)

# One panel per region, with both strains drawn on the same axes.
for ax, region in zip(axs, ["florida", "michigan", "usa"]):
    panel_title = region.upper() if region == "usa" else region.title()
    for strain in ["alpha", "delta"]:
        fit = res[region][strain]
        start, top, end, x0 = plot_helper(fit, strains[region][strain], 200)
        plot_one(
            fit, ax, "s", m, start, top, end, x0, strain.title(), "fill", panel_title
        )

# Common x-range for all panels; the y-range is left to matplotlib.
for ax in axs:
    ax.set_xlim(2020.8, 2021.9)
axs[0].legend(loc="lower left")

fig.savefig("covid/figures/strains/strain_s.pdf", format="pdf")

## BEAST

In [42]:
def sample_beast_tips(
    k1, k2, n=200, stratified=False, stratify_by=None, stratified_handle="", rng=None
):
    """Subsample ``n`` sequences of ``strains[k1][k2]`` and write them as FASTA.

    Each sampled record is written with an id of the form
    ``<first "|" field of its description>_<YYYY-MM-DD>``, where the date is
    looked up by record name among the dataset's ``sids``/``dates``. The
    output goes to ``covid/beast/multistrain/{k1}_{k2}_beast[<handle>].fa``.

    Parameters
    ----------
    k1, k2
        Keys into the module-level ``strains`` mapping.
    n : int
        Total number of sequences to draw (without replacement).
    stratified : bool
        If true, spread the ``n`` draws as evenly as possible over the strata
        in ``stratify_by`` that hold more than 50 sequences; the first
        ``n % n_strata`` strata get one extra draw.
    stratify_by : mapping, optional
        Stratum key -> list of sequence ids. Defaults to the dataset's
        ``sample_months`` when ``stratified`` is true.
    stratified_handle : str
        Suffix inserted before ``.fa`` in the output name when stratified.
    rng : optional
        Object with a NumPy-style ``choice(a, size=..., replace=...)`` method
        (e.g. ``np.random.default_rng(seed)``), for reproducible sampling.
        Defaults to NumPy's global random state, as before.

    Raises
    ------
    ValueError
        If ``stratified`` is true and no stratum holds more than 50 sequences
        (this used to surface as a ZeroDivisionError), or if NumPy cannot
        draw the requested number of sequences without replacement.
    """
    import copy

    d = strains[k1][k2]
    dates_dict = dict(zip(d.sids, d.dates))
    sampler = np.random if rng is None else rng

    if stratified:
        if stratify_by is None:
            stratify_by = d.sample_months
        key_list = [k for k, v in stratify_by.items() if len(v) > 50]
        if not key_list:
            raise ValueError(
                "stratified sampling needs at least one stratum with more than "
                "50 sequences"
            )

        # Even split of n over the strata; the first `mod` strata get one more.
        floor, mod = divmod(n, len(key_list))
        seqs = np.concatenate(
            [
                sampler.choice(
                    stratify_by[key],
                    size=floor + 1 if i < mod else floor,
                    replace=False,
                )
                for i, key in enumerate(key_list)
            ]
        )
        records = [d.seqs[s] for s in seqs]
    else:
        inds = sampler.choice(len(d.aln), size=n, replace=False)
        records = [d.aln[int(i)] for i in inds]

    # Rename shallow copies, so the records still held by `d` keep their
    # original ids and later cells see unmodified data.
    renamed = []
    for r in records:
        r = copy.copy(r)
        sp = r.description.split("|")
        r.id = sp[0] + "_" + dates_dict[r.name].strftime("%Y-%m-%d")
        renamed.append(r)
    aln_sample = MultipleSeqAlignment(renamed)

    handle = f"covid/beast/multistrain/{k1}_{k2}_beast"
    if stratified:
        handle += stratified_handle
    handle += ".fa"
    with open(handle, "w") as output_handle:
        AlignIO.write(aln_sample, output_handle, "fasta")

<IPython.core.display.Javascript object>

In [43]:
stratified = False
stratified_handle = ""  # e.g. "_quarters" for a quarter-stratified sample
n = 200

# Write one unstratified BEAST subsample per region/strain. For a stratified
# run, pass a `stratify_by` mapping and `stratified_handle` as well.
for k1, by_strain in strains.items():
    for k2 in by_strain:
        sample_beast_tips(k1, k2, n, stratified)

<IPython.core.display.Javascript object>

In [45]:
import xml.etree.ElementTree as ET

<IPython.core.display.Javascript object>

In [48]:
def edit_template(k1, k2):
    """Fill the BEAST XML template with the sampled alignment for (k1, k2).

    Reads ``covid/beast/multistrain/{k1}_{k2}_beast.fa`` and
    ``covid/beast/multistrain/template.xml``, then:

    * overwrites the ``id``, ``taxon`` and ``value`` attributes of the
      children of the template's ``data`` element with the alignment records,
      pairing them in order;
    * sets the ``run/state/tree/trait`` ``value`` attribute to a
      comma-separated list of ``<name>=<date>`` entries;
    * points the ``tracelog`` and ``treelog.t:template`` loggers at
      per-dataset file names;
    * writes the result to ``covid/beast/multistrain/{k1}_{k2}_beast.xml``.

    Record names are expected to end in ``_<date>``, as produced by
    ``sample_beast_tips``. A name without any underscore raises IndexError.
    """
    aln = AlignIO.read(f"covid/beast/multistrain/{k1}_{k2}_beast.fa", "fasta")

    tree = ET.parse("covid/beast/multistrain/template.xml")
    root = tree.getroot()

    # NOTE(review): zip stops at the shorter of the alignment and the
    # template's <data> children, so a size mismatch is silently ignored —
    # confirm the template has exactly one child per sampled sequence.
    trait_entries = []
    for rec, child in zip(aln, root.find("data")):
        child.set("id", f"seq_{rec.name}")
        child.set("taxon", rec.name)
        child.set("value", str(rec.seq))
        # The date is whatever follows the LAST underscore: the strain name
        # itself may contain underscores, the date never does.
        sample_date = rec.name.rsplit("_", 1)[1]
        trait_entries.append(f"{rec.name}={sample_date}")
    trait = root.find("run").find("state").find("tree").find("trait")
    trait.set("value", ",".join(trait_entries))

    # Give each dataset its own trace and tree log files.
    for log in root.find("run").findall("logger"):
        if log.get("id") == "tracelog":
            log.set("fileName", f"covid/beast/multistrain/logs/{k1}_{k2}_beast.log")
        if log.get("id") == "treelog.t:template":
            log.set("fileName", f"covid/beast/multistrain/trees/{k1}_{k2}_beast.log")

    xml_str = ET.tostring(root, encoding="unicode")
    with open(f"covid/beast/multistrain/{k1}_{k2}_beast.xml", "w") as xml_file:
        xml_file.write(xml_str)

<IPython.core.display.Javascript object>

In [49]:
# Generate one BEAST XML file per region/strain from the shared template.
for k1, by_strain in strains.items():
    for k2 in by_strain:
        edit_template(k1, k2)

<IPython.core.display.Javascript object>