# Notebook for interactive plots of a.a. mutations

## Snakemake input

In [None]:
# Input CSV files handed over by the Snakemake rule
aamut_by_cluster_csv = snakemake.input.aamut_by_cluster
clade_founder_nts_csv = snakemake.input.clade_founder_nts

# Directory that receives the per-gene HTML charts
outdir = snakemake.output.outdir

# Rule parameters
min_expected_count = snakemake.params.min_expected_count
init_ref_clade = snakemake.params.init_ref_clade
clade_synonyms = snakemake.params.clade_synonyms
heatmap_minimal_domain = snakemake.params.heatmap_minimal_domain
orf1ab_to_nsps = snakemake.params.orf1ab_to_nsps
clade_cluster = snakemake.params.clade_cluster
cluster_founder = snakemake.params.cluster_founder

## Import packages

In [None]:
import pandas as pd
import altair as alt
import Bio.Seq
import os
import yaml

Some settings

In [None]:
# make sure the chart output directory exists before anything is saved there
os.makedirs(outdir, exist_ok=True)

# lift Altair's row limit so the full fitness tables can be embedded in charts
_ = alt.data_transformers.disable_max_rows()

Define a function that labels a clade, appending its synonym when one is defined

In [46]:
def clade_label(clade):
    """Return a display label for ``clade``.

    When ``clade`` has an entry in ``clade_synonyms`` the label is
    ``"<clade> (<synonym>)"``; otherwise ``clade`` is returned unchanged.
    """
    has_synonym = clade in clade_synonyms
    return f"{clade} ({clade_synonyms[clade]})" if has_synonym else clade


Map each clade-cluster name to its founder clade

In [1]:
def cluster_founder_map(cluster, mapping=None):
    """Return the founder clade of a cluster of clades.

    Parameters
    ----------
    cluster : hashable
        Cluster name to look up.
    mapping : dict, optional
        Cluster-name -> founder-clade mapping. Defaults to the
        ``cluster_founder`` Snakemake parameter.

    Returns
    -------
    The founder clade, or ``None`` (after printing a message naming the
    cluster) when ``cluster`` is not in the mapping.
    """
    if mapping is None:
        mapping = cluster_founder
    # direct dict lookup instead of scanning a list of keys; this function is
    # applied once per row of the mutation table
    try:
        return mapping[cluster]
    except KeyError:
        print(f"Cluster {cluster!r} not defined")
        return None

Data frame with the clade-founder amino acids

In [None]:
# clade-founder nucleotide table (clade, gene, codon, codon_site, ...) from the rule input
clade_founder_nts = pd.read_csv(filepath_or_buffer=clade_founder_nts_csv)

In [None]:
# codon translation table: each of the 64 codons spelled with A/C/G/T maps to
# its one-letter translation from Biopython; codons with any other character
# (e.g. N) are absent and will map to NaN below
codon_table = {
    f"{nt1}{nt2}{nt3}": str(Bio.Seq.Seq(f"{nt1}{nt2}{nt3}").translate())
    for nt1 in "ACGT" for nt2 in "ACGT" for nt3 in "ACGT"
}

# get clade founder amino-acids: one row per (clade, gene, site) holding the
# amino acid encoded by the founder codon at that site
clade_founder_aas = (
    clade_founder_nts
    .query("gene != 'noncoding'")  # drop nucleotides outside any gene
    [["clade", "gene", "codon", "codon_site"]]
    # nucleotides of the same codon share codon/codon_site, so this presumably
    # collapses them to one row per codon
    .drop_duplicates()
    .assign(
        # a nucleotide can lie in several overlapping genes; these columns are
        # ";"-joined parallel lists (one entry per gene). NOTE(review): relies on
        # codon_site being read as text, which holds while some values contain ";"
        gene=lambda x: x["gene"].str.split(";"),
        codon=lambda x: x["codon"].str.split(";"),
        codon_site=lambda x: x["codon_site"].str.split(";"),
    )
    # expand the parallel lists together: one row per (codon, gene)
    .explode(["gene", "codon", "codon_site"])
    .assign(
        aa=lambda x: x["codon"].map(codon_table),
        codon_site=lambda x: x["codon_site"].astype(int),
    )
    .rename(columns={"codon_site": "site", "aa": "amino acid"})
    .drop(columns="codon")
    .query("gene != 'ORF1a'")  # this is just subset of ORF1ab
)


In [65]:
# now convert ORF1ab numbers to nsp numbers
# orf1ab_to_nsps maps nsp name -> (start, end) in ORF1ab site numbering, both
# ends inclusive; build one row per ORF1ab site with its 1-based nsp site
orf1ab_to_nsps_df = pd.concat(
    [
        pd.DataFrame(
            [(i, i - start + 1) for i in range(start, end + 1)],
            columns=["ORF1ab_site", "nsp_site"],
        ).assign(nsp=nsp).drop_duplicates()
        for nsp, (start, end) in orf1ab_to_nsps.items()
    ],
    ignore_index=True,
)

# append a copy of the ORF1ab founder amino acids renumbered per nsp, with the
# nsp name as the gene; the inner merge drops ORF1ab sites not covered by any nsp
clade_founder_aas = pd.concat(
    [
        clade_founder_aas,
        (
            clade_founder_aas
            .query("gene == 'ORF1ab'")
            .merge(
                orf1ab_to_nsps_df,
                left_on="site",
                right_on="ORF1ab_site",
                # raises if an ORF1ab site appears in more than one nsp range
                validate="many_to_one",
            )
            .drop(columns=["gene", "ORF1ab_site", "site"])
            .rename(columns={"nsp_site": "site", "nsp": "gene"})
        ),
    ],
    # original index labels are kept; the appended nsp rows carry the merge's
    # new index, so labels may repeat (the index is not used downstream here)
    ignore_index=False,
).drop_duplicates()

In [None]:
clade_founder_aas.head(n=5)  # peek at the first rows of the founder amino-acid table

Unnamed: 0,clade,gene,site,amino acid
265,19A,ORF1ab,1,M
268,19A,ORF1ab,2,E
271,19A,ORF1ab,3,S
274,19A,ORF1ab,4,L
277,19A,ORF1ab,5,V
...,...,...,...,...
213055,23F,nsp16,298,N
213056,24A,nsp16,298,N
213057,24B,nsp16,298,N
213058,24C,nsp16,298,N


Read input data

In [None]:
# amino-acid mutation fitness estimates per cluster of clades, from the rule input
aamut = pd.read_csv(filepath_or_buffer=aamut_by_cluster_csv)

In [None]:
# founder clade of every cluster (may repeat if clusters share a founder)
clust_fnd = [*cluster_founder.values()]

In [None]:
# founder amino acids restricted to clades that found some cluster
clust_founder_aas = clade_founder_aas[clade_founder_aas["clade"].isin(clust_fnd)]

Unnamed: 0,clade,gene,site,amino acid
30168,20A,ORF1ab,1,M
30171,20A,ORF1ab,2,E
30174,20A,ORF1ab,3,S
30177,20A,ORF1ab,4,L
30180,20A,ORF1ab,5,V
...,...,...,...,...
212982,21K,nsp16,296,V
213001,20A,nsp16,297,N
213012,21K,nsp16,297,N
213031,20A,nsp16,298,N


Add a `clade` column (the cluster's founder clade) to `aamut`

In [97]:
# add each cluster's founder clade. The raw clade name (no synonym label) is kept
# on purpose: it must match the "clade" column of clust_founder_aas in the merge
# below. (Wrapping the mapper in clade_label was a no-op on the function object.)
aamut = aamut.assign(clade=lambda x: x["cluster"].map(cluster_founder_map))

Retain only mutations away from the cluster-founder amino acid at each site

In [None]:
# keep only mutations whose reference amino acid equals the cluster founder's
# amino acid at that site (inner merge), then rename columns for plotting
aamut_cl_fnd = (
    aamut
    .rename(columns={'aa_site':'site', 'clade_founder_aa':'amino acid'})
    # validate raises if a (clade, gene, site, amino acid) key repeats in clust_founder_aas
    .merge(clust_founder_aas, on=['clade', 'gene', 'site', 'amino acid'], how='inner', validate='many_to_one')
    # the founder amino acid becomes "ref_aa" and the mutant takes over "amino acid"
    .rename(columns={'amino acid': 'ref_aa', 'delta_fitness': 'fitness', 'mutant_aa': 'amino acid'})
    .drop(columns=['clade'])  # "cluster" still identifies the group
)

In [None]:
# sanity check: every (cluster, gene, site) has exactly one reference amino acid.
# Vectorized nunique replaces a per-group Python apply; dropna=False counts NaN
# as a value, matching the previous len(unique()) logic.
assert (
    aamut_cl_fnd.groupby(["cluster", "gene", "site"])["ref_aa"].nunique(dropna=False) == 1
).all(), "some (cluster, gene, site) has more than one reference amino acid"

Plotting function

In [None]:
def plot_aa_fitness(gene, fitness_df, clade_founder_df):
    """Plot of amino-acid fitness values.

    Builds an interactive Altair chart for one gene: a per-site fitness
    summary (area plot, optionally overlaid with stop-codon fitness) stacked
    above a site x amino-acid heat map.

    Parameters
    ----------
    gene : str
        Gene name; used only in the chart title.
    fitness_df : pandas.DataFrame
        Must have columns ``"cluster"``, ``"site"``, ``"amino acid"``,
        ``"fitness"`` and ``"predicted_count"``; every column is shown in
        the heat-map tooltip.
    clade_founder_df : pandas.DataFrame
        Founder amino acids with columns ``"clade"``, ``"site"`` and
        ``"amino acid"``; drawn as "x" marks on the heat map.

    Returns
    -------
    altair chart
        The site plot vertically concatenated with the heat map.

    Notes
    -----
    Also reads the module-level ``min_expected_count``,
    ``heatmap_minimal_domain``, ``init_ref_clade`` and ``cluster_founder``
    parameters and the ``clade_label`` helper.
    """

    # biochemically ordered alphabet
    aas = tuple("RKHDEQNSTYWFAILMVGPC*")
    assert set(fitness_df["amino acid"]).issubset(aas)

    sites = fitness_df["site"].unique().tolist()
    clusters = fitness_df["cluster"].unique().tolist()

    # slider for the minimum expected count; the upper bound is kept >= the
    # lower bound (1) so the range is never inverted when counts are small
    expected_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=max(
                1,
                min(5 * min_expected_count, fitness_df["predicted_count"].quantile(0.9)),
            ),
            step=1,
            name="minimum expected count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_expected_count},
    )

    site_zoom_brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(
            stroke="gold", strokeWidth=1.5, fill="yellow", fillOpacity=0.3,
        ),
    )

    # Select fitness for clades cluster. Start on the cluster founded by
    # init_ref_clade (else the first cluster): without an initial value the
    # selection is empty, every cluster passes the filter, and several
    # clusters' cells are drawn on top of each other.
    init_cluster = next(
        (c for c in clusters if cluster_founder.get(c) == init_ref_clade),
        clusters[0],
    )
    cluster_selection = alt.selection_single(
        fields=["cluster"],
        bind=alt.binding_select(options=clusters, name="Cluster of clades"),
        init={"cluster": init_cluster},
    )

    # shared base: filtered by the expected-count cutoff AND the selected
    # cluster, so the heat map and the site plots describe the same data
    base = (
        alt.Chart(fitness_df)
        .encode(x=alt.X("site:O", axis=alt.Axis(labelOverlap="parity")))
        .transform_filter(
            alt.datum["predicted_count"] >= expected_count_selection["cutoff"] - 1e-6
        )
        .transform_filter(cluster_selection)
    )

    heatmap_y = alt.Y("amino acid", sort=aas, scale=alt.Scale(domain=aas))
    heatmap_base = (
        base
        .encode(y=heatmap_y)
        .properties(width=alt.Step(12), height=alt.Step(12))
    )

    # background fill for missing values in heatmap, imputing dummy stat
    # to get all cells
    heatmap_bg = (
        heatmap_base
        .transform_impute(
            impute="_stat_dummy",
            key="amino acid",
            keyvals=aas,
            groupby=["site"],
            value=None,
        )
        .mark_rect(color="gray", opacity=0.25)
    )

    # heatmap showing non-filtered amino acids (cluster filter comes from base)
    heatmap_aas = (
        heatmap_base
        .encode(
            color=alt.Color(
                "fitness:Q",
                legend=alt.Legend(
                    orient="bottom",
                    titleOrient="left",
                    gradientLength=150,
                    gradientStrokeColor="black",
                    gradientStrokeWidth=0.5,
                ),
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    domainMid=0,
                    domain=alt.DomainUnionWith(heatmap_minimal_domain),
                ),
            ),
            stroke=alt.value("black"),
            tooltip=[
                alt.Tooltip(c, format=".3g")
                if fitness_df[c].dtype == float
                else c
                for c in fitness_df.columns
            ],
        )
        .mark_rect()
        .add_selection(cluster_selection)
    )

    # place X values at "wildtype". Clades are labeled with clade_label so the
    # dropdown options use the same naming as the init value below.
    wildtype_df = (
        clade_founder_df
        .query("site in @sites")
        .assign(clade=lambda x: x["clade"].map(clade_label))
    )
    wildtype_clade_selection = alt.selection_single(
        fields=["clade"],
        bind=alt.binding_select(
            options=clade_founder_df["clade"].map(clade_label).unique().tolist(),
            name="X denotes wildtype in",
        ),
        init={"clade": clade_label(init_ref_clade)},
    )
    heatmap_wildtype = (
        alt.Chart(wildtype_df)
        .encode(
            x=alt.X("site:O"),
            y=heatmap_y,
        )
        .mark_text(text="x", color="black")
        .add_selection(wildtype_clade_selection)
        .transform_filter(wildtype_clade_selection)
        .transform_filter(site_zoom_brush)
    )

    heatmap = (
        (heatmap_bg + heatmap_aas + heatmap_wildtype)
        .add_selection(expected_count_selection)
        .transform_filter(site_zoom_brush)
    )

    # make lineplot
    site_statistics = ["mean", "max", "min"]
    site_stat = alt.selection_single(
        bind=alt.binding_radio(
            options=site_statistics,
            name="site fitness statistic",
        ),
        fields=["site fitness statistic"],
        init={"site fitness statistic": site_statistics[0]},
    )

    lineplot = (
        base
        .transform_filter(alt.datum["amino acid"] != "*")
        .transform_aggregate(
            **{stat: f"{stat}(fitness)" for stat in site_statistics},
            groupby=["site"],
        )
        .transform_fold(
            site_statistics,
            ["site fitness statistic", "site fitness"],
        )
        .add_selection(site_stat)
        .add_selection(site_zoom_brush)
        .transform_filter(site_stat)
        .encode(
            y=alt.Y("site fitness:Q", axis=alt.Axis(grid=False)),
            tooltip=[
                "site",
                alt.Tooltip("site fitness:Q", format=".3g"),
                "site fitness statistic:N",
            ],
        )
        .mark_area(color="black", opacity=0.7)
        .properties(
            height=75,
            width=min(750, 12 * fitness_df["site"].nunique()),
            title=alt.TitleParams(
                "use this site plot to zoom into regions on the heat map",
                anchor="start",
                fontWeight="normal",
                fontSize=11,
            ),
        )
    )

    # radio toggle: rows get _dummy="yes", so selecting "no" filters them all out
    show_stop = alt.selection_single(
        fields=["_dummy"],
        bind=alt.binding_radio(
            options=["yes", "no"],
            name="show stop in magenta on top site plot",
        ),
        init={"_dummy": "no"},
    )

    stopplot = (
        base
        .add_selection(show_stop)
        .transform_filter(alt.datum["amino acid"] == "*")
        .transform_calculate(_dummy="'yes'")
        .transform_filter(show_stop)
        .encode(
            y=alt.Y("fitness", title="site fitness"),
            color=alt.value("#CC79A7"),
            tooltip=["site", alt.Tooltip("fitness", format=".3g", title="stop fitness")],
        )
        .mark_line(point=True, strokeWidth=0.5, strokeDash=[2, 2])
    )

    return (
        (alt.layer(lineplot, stopplot) & heatmap)
        .properties(
            title=alt.TitleParams(
                f"estimated fitness of amino acids for SARS-CoV-2 {gene} protein",
                fontSize=15,
            ),
        )
        .resolve_scale(color="independent")
    )

In [None]:
# build, save (as <outdir>/<gene>.html) and display one chart per gene
for gene, fitness_df in (
    aamut_cl_fnd
    [['cluster', 'gene', 'site', 'amino acid', 'fitness', 'predicted_count']]
    .groupby("gene")
):
    if gene == "ORF1ab":
        continue  # do not make a plot for ORF1ab, too big
    chart = plot_aa_fitness(gene, fitness_df, clust_founder_aas.query("gene == @gene"))
    # the file is named after the first whitespace-separated word of the gene name
    chartfile = os.path.join(outdir, f"{gene.split()[0]}.html")
    print(f"\nSaving chart for {gene} to {chartfile}")
    chart.save(chartfile)
    display(chart)