# Notebook for interactive plots of a.a. mutations

## Snakemake input

In [None]:
# --- input files ---
aamut_by_cluster_csv = snakemake.input.aamut_by_cluster
clade_founder_nts_csv = snakemake.input.clade_founder_nts

# --- output location (directory the charts are written to) ---
outdir = snakemake.output.outdir

# --- parameters ---
# initial value (and scale) of the count-cutoff sliders in the charts
min_predicted_count = snakemake.params.min_predicted_count
# clade -> synonym, used to build display labels
clade_synonyms = snakemake.params.clade_synonyms
# minimal color-scale domain of the fitness heatmaps
heatmap_minimal_domain = snakemake.params.heatmap_minimal_domain
clade_cluster = snakemake.params.clade_cluster
# cluster name -> founder clade of that cluster
cluster_founder = snakemake.params.cluster_founder
# clusters whose summed predicted counts fall below this are not plotted
cluster_corr_min_count = snakemake.params.cluster_corr_min_count

## Import packages

In [40]:
import pandas as pd
import altair as alt
import Bio.Seq
import os
import yaml
import itertools

Some settings

In [41]:
# lift Altair's cap on the number of rows a chart's data may contain
_ = alt.data_transformers.disable_max_rows()

# make sure the output directory exists before any chart is saved into it
os.makedirs(outdir, exist_ok=True)

Define function that gives clade labels

In [42]:
def clade_label(clade):
    """Return the display label for a clade.

    Clades present in the module-level ``clade_synonyms`` mapping are
    labelled ``"<clade> (<synonym>)"``; any other value is returned unchanged.
    """
    if clade not in clade_synonyms:
        return clade
    synonym = clade_synonyms[clade]
    return f"{clade} ({synonym})"


## Mapping each cluster of clades to its founder clade

In [43]:
def cluster_founder_map(cluster):
    """Return the founder clade configured for `cluster`.

    The lookup uses the module-level ``cluster_founder`` mapping
    (cluster name -> founder clade).

    Returns
    -------
    The founder clade, or ``None`` (after printing a notice) when the
    cluster is not a key of ``cluster_founder``. Callers that map a whole
    column therefore get a missing value rather than an exception.
    """
    # direct dict membership: avoids building a list of keys on every call,
    # which matters because this function is applied row by row via `.map`
    if cluster in cluster_founder:
        return cluster_founder[cluster]
    print("Cluster not defined\n")
    return None

Dataframe with the clade founder amino acids

In [44]:
# founder nucleotide at each site for each clade, with gene / codon annotations
clade_founder_nts = pd.read_csv(clade_founder_nts_csv)

In [45]:
clade_founder_nts

Unnamed: 0,clade,site,nt,gene,codon,codon_position,codon_site,four_fold_degenerate
0,root,1,G,noncoding,noncoding,noncoding,noncoding,False
1,root,2,T,noncoding,noncoding,noncoding,noncoding,False
2,root,3,A,noncoding,noncoding,noncoding,noncoding,False
3,root,4,A,noncoding,noncoding,noncoding,noncoding,False
4,root,5,A,noncoding,noncoding,noncoding,noncoding,False
...,...,...,...,...,...,...,...,...
15220,root,15221,T,noncoding,noncoding,noncoding,noncoding,False
15221,root,15222,A,noncoding,noncoding,noncoding,noncoding,False
15222,root,15223,T,noncoding,noncoding,noncoding,noncoding,False
15223,root,15224,T,noncoding,noncoding,noncoding,noncoding,False


In [46]:
# codon translation table
codon_table = {
    f"{nt1}{nt2}{nt3}": str(Bio.Seq.Seq(f"{nt1}{nt2}{nt3}").translate())
    for nt1 in "ACGT" for nt2 in "ACGT" for nt3 in "ACGT"
}

# get clade founder amino-acids
clade_founder_aas = (
    clade_founder_nts
    .query("gene != 'noncoding'")
    [["clade", "gene", "codon", "codon_site"]]
    .drop_duplicates()
    .assign(
        gene=lambda x: x["gene"].str.split(";"),
        codon=lambda x: x["codon"].str.split(";"),
        codon_site=lambda x: x["codon_site"].str.split(";"),
    )
    .explode(["gene", "codon", "codon_site"])
    .assign(
        aa=lambda x: x["codon"].map(codon_table),
        codon_site=lambda x: x["codon_site"].astype(int),
    )
    .rename(columns={"codon_site": "site", "aa": "amino acid"})
    .drop(columns="codon")
)


In [47]:
clade_founder_aas.head()

Unnamed: 0,clade,gene,site,amino acid
69,root,NS1,1,M
72,root,NS1,2,G
75,root,NS1,3,S
78,root,NS1,4,N
81,root,NS1,5,S


## Read-in input dataframes

In [48]:
# per-cluster amino-acid mutation counts and fitness estimates
aamut = pd.read_csv(aamut_by_cluster_csv)

In [49]:
aamut

Unnamed: 0,cluster,gene,clade_founder_aa,mutant_aa,aa_site,aa_mutation,expected_count,predicted_count,actual_count,tau_squared,naive_delta_fitness,delta_fitness,uncertainty
0,all,NS1,M,I,1,M1I,0,27.308774,0,0.630213,0.000000,-3.517313,1.011497
1,all,NS1,M,K,1,M1K,0,3.143715,0,0.574855,0.000000,-2.643958,1.594999
2,all,NS1,M,L,1,M1L,0,1.954298,0,0.957700,0.000000,-1.783230,1.266954
3,all,NS1,M,R,1,M1R,0,0.376106,0,0.400043,0.000000,-1.529365,1.815631
4,all,NS1,M,T,1,M1T,0,4.165623,0,0.378555,0.000000,-2.862902,1.532561
...,...,...,...,...,...,...,...,...,...,...,...,...,...
30618,all,L,*,K,2166,*2166K,0,0.812895,0,0.574855,0.000000,-1.892742,1.741187
30619,all,L,*,L,2166,*2166L,0,0.908702,0,0.649366,0.000000,-1.940745,1.737454
30620,all,L,*,Q,2166,*2166Q,0,9.337971,9,0.424801,2.944439,-0.084789,0.705718
30621,all,L,*,S,2166,*2166S,0,0.371016,0,0.384165,0.000000,-1.524180,1.815781


In [50]:
# founder clades of all configured clusters
clust_fnd = [*cluster_founder.values()]

In [51]:
# founder amino acids restricted to the clades that found a cluster
clust_founder_aas = clade_founder_aas[clade_founder_aas["clade"].isin(clust_fnd)]

Adding a `clade` column (the founder clade of each cluster) to `aamut`

In [52]:
# founder clade of each row's cluster
aamut = aamut.assign(clade=aamut["cluster"].map(cluster_founder_map))

Retain only mutations from the cluster founder amino acids

In [53]:
# the inner merge keeps a mutation only when its starting amino acid equals
# the amino acid of its cluster's founder clade at that gene and site
aamut_cl_fnd = (
    aamut
    .rename(columns={"aa_site": "site", "clade_founder_aa": "amino acid"})
    .merge(
        clust_founder_aas,
        how="inner",
        on=["clade", "gene", "site", "amino acid"],
        # each (clade, gene, site, amino acid) must occur once among the founders
        validate="many_to_one",
    )
    .rename(columns={"amino acid": "ref_aa"})
    .drop(columns=["clade"])
)

In [54]:
# sanity check: every (cluster, gene, site) has exactly one reference amino acid.
# Counting distinct values on the selected column avoids `groupby.apply` on the
# whole frame; `dropna=False` counts a missing value as a value, as `unique()` did.
assert (
    aamut_cl_fnd
    .groupby(["cluster", "gene", "site"])["ref_aa"]
    .nunique(dropna=False)
    .eq(1)
    .all()
), "found a (cluster, gene, site) with more than one reference amino acid"

  assert sum(aamut_cl_fnd.groupby(['cluster', 'gene', 'site']).apply(lambda x: len(x.ref_aa.unique()) != 1)) == 0


Adding the clade synonym (e.g. Pango lineage) to the `clade` column in `clust_founder_aas`

In [55]:
# replace each clade name by its display label ("<clade> (<synonym>)")
clust_founder_aas = clust_founder_aas.assign(
    clade=clust_founder_aas["clade"].map(clade_label)
)
clust_founder_aas

Unnamed: 0,clade,gene,site,amino acid
69,root (root),NS1,1,M
72,root (root),NS1,2,G
75,root (root),NS1,3,S
78,root (root),NS1,4,N
81,root (root),NS1,5,S
...,...,...,...,...
15014,root (root),L,2162,F
15017,root (root),L,2163,Y
15020,root (root),L,2164,N
15023,root (root),L,2165,E


## Plotting

### Scatter plot of fitness effects

Plot the correlation in mutation fitness effects ($\Delta f_{xy}$ values) among clusters of clades with large numbers of counts. We only plot amino-acid mutations (not synonymous ones).

#### Comparison among clusters

Plotting function

In [61]:
# def plot_corr_scatters(corr_df_tidy, neutral_counts):
#     "Plot set of correlation scatters."""
    
#     nc_label = str.split(neutral_counts, '_')
#     nc_label = " ".join(nc_label)

#     subsets = corr_df_tidy["subset"].unique()
#     genes = corr_df_tidy["gene"].unique()
    
#     corr_df_wide = pd.merge(
#         *[
#             corr_df_tidy
#             .pivot_table(
#                 index=["gene", "aa_mutation"],
#                 values=prop,
#                 columns="subset",
#             )
#             .reset_index()
#             .rename(columns={subset: f"{prop} {subset}" for subset in subsets})
#             for prop in ["fitness", neutral_counts]
#         ]
#     )
    
#     fitness_min = corr_df_tidy["fitness"].min()
#     fitness_max = corr_df_tidy["fitness"].max()

#     gene_selection = alt.selection_multi(
#         fields=["gene"], bind="legend",
#     )

#     predicted_count_selection = alt.selection_single(
#         bind=alt.binding_range(
#             min=1,
#             max=min(5 * min_predicted_count, corr_df_tidy[neutral_counts].quantile(0.9)),
#             step=1,
#             name=f"minimum {nc_label}",
#         ),
#         fields=["cutoff"],
#         init={"cutoff": min_predicted_count},
#     )
    
#     highlight = alt.selection_single(
#         on="mouseover",
#         fields=["gene", "aa_mutation"],
#         empty="none",
#     )
    
#     corr_charts = []
#     base_chart = alt.Chart(corr_df_wide)
#     for subset1, subset2 in itertools.combinations(subsets, 2):
#         base = (
#             base_chart
#             .encode(
#                 x=alt.X(
#                     f"fitness {subset1}",
#                     title=f"{subset1} fitness effect",
#                     scale=alt.Scale(domain=(fitness_min, fitness_max), nice=False),
#                 ),
#                 y=alt.Y(
#                     f"fitness {subset2}",
#                     title=f"{subset2} fitness effect",
#                     scale=alt.Scale(domain=(fitness_min, fitness_max), nice=False),
#                 ),
#                 tooltip=[
#                     "gene",
#                     "aa_mutation",
#                     alt.Tooltip(
#                         f"fitness {subset1}", title=f"{subset1} fitness effect",
#                     ),
#                     alt.Tooltip(
#                         f"fitness {subset2}", title=f"{subset2} fitness effect",
#                     ),
#                     f"{neutral_counts} {subset1}",
#                     f"{neutral_counts} {subset2}",
#                 ],
#             )
#             .mark_circle(opacity=0.3)
#             .properties(width=200, height=200)
#             .transform_filter(gene_selection)
#             .transform_filter(
#                 (alt.datum[f"{neutral_counts} {subset1}"] >= predicted_count_selection["cutoff"] - 1e-6)
#                 & (alt.datum[f"{neutral_counts} {subset2}"] >= predicted_count_selection["cutoff"] - 1e-6)
#             )
#         )
    
#         scatter = (
#             base
#             .encode(
#                 color=alt.Color(
#                     "gene",
#                     scale=alt.Scale(
#                         domain=genes,
#                         range=["#5778a4"] * len(genes),
#                     ),
#                     legend=alt.Legend(
#                         symbolOpacity=1,
#                         orient="bottom",
#                         title="click / shift-click to select specific genes to show",
#                         titleLimit=500,
#                         columns=6,
#                     ),
#                 ),
#                 size=alt.condition(highlight, alt.value(85), alt.value(30)),
#                 opacity=alt.condition(highlight, alt.value(1), alt.value(0.3)),
#                 strokeWidth=alt.condition(highlight, alt.value(1.5), alt.value(0)),
#             )
#             .mark_circle(stroke="black")
#         )

#         line = alt.Chart(
#             pd.DataFrame({
#                 "x": [fitness_min, fitness_max],
#                 "y": [fitness_min, fitness_max]
#             })
#         ).mark_line(color="orange", clip=True).encode(
#             x="x:Q",
#             y="y:Q",
#         )
    
#         params_r = (
#             base
#             .transform_regression(
#                 f"fitness {subset1}",
#                 f"fitness {subset2}",
#                 params=True,
#             )
#             .transform_calculate(
#                 r=alt.expr.sqrt(alt.datum["rSquared"]),
#                 label='"r = " + format(datum.r, ".3f")',
#             )
#             .mark_text(align="left", color="orange", fontWeight="bold")
#             .encode(
#                 x=alt.value(5),
#                 y=alt.value(8),
#                 text=alt.Text("label:N"),
#             )
#         )
        
#         # show number of points
#         params_n = (
#             base
#             .transform_filter(
#                 (~alt.expr.isNaN(alt.datum[f"fitness {subset1}"]))
#                 & (~alt.expr.isNaN(alt.datum[f"fitness {subset2}"]))
#             )
#             .transform_calculate(dummy=alt.datum[f"fitness {subset1}"])
#             .transform_aggregate(n="valid(dummy)")
#             .transform_calculate(label='"n = " + datum.n')
#             .mark_text(align="left", color="orange", fontWeight="bold")
#             .encode(
#                 x=alt.value(5),
#                 y=alt.value(20),
#                 text=alt.Text("label:N"),
#             )
#         )
    
#         chart = (
#             (scatter + params_r + params_n)
#             .add_selection(gene_selection)
#             .add_selection(predicted_count_selection)
#             .add_selection(highlight)
#         )
    
#         corr_charts.append(chart + line)
    
#     ncols = 4
#     rows = []
#     for i in range(0, len(corr_charts), ncols):
#         rows.append(alt.hconcat(*corr_charts[i: i + ncols]))
#     corr_chart = alt.vconcat(*rows).configure_axis(grid=False)
#     return corr_chart

Scatter for probabilistic fitness

In [62]:
clade_corr_df = (
    aamut_cl_fnd
    # keep mutations whose first and last characters differ (amino-acid changes)
    .query("aa_mutation.str[0] != aa_mutation.str[-1]")
    # .query("cluster != 'all'")
    [["cluster", "gene", "aa_mutation", "predicted_count", "delta_fitness"]]
    .assign(
        # total predicted count of the cluster each row belongs to
        cluster_counts=lambda df: (
            df.groupby(["cluster"])["predicted_count"].transform("sum")
        ),
        # label of the cluster's founder clade; must be computed before the
        # `cluster` column itself is rewritten below
        clade=lambda df: (
            df["cluster"]
            .map(lambda name: clade_label(cluster_founder_map(name)))
            .str.replace(".", "_", regex=False)
        ),
        cluster=lambda df: df["cluster"].str.replace(".", "_", regex=False),
    )
    # drop clusters with too few counts overall
    .query("cluster_counts >= @cluster_corr_min_count")
    .drop(columns="cluster_counts")
    .rename(columns={"cluster": "subset", "delta_fitness": "fitness"})
)
clade_corr_df

Unnamed: 0,subset,gene,aa_mutation,predicted_count,fitness,clade
0,all,NS1,M1I,27.308774,-3.517313,root (root)
1,all,NS1,M1K,3.143715,-2.643958,root (root)
2,all,NS1,M1L,1.954298,-1.783230,root (root)
3,all,NS1,M1R,0.376106,-1.529365,root (root)
4,all,NS1,M1T,4.165623,-2.862902,root (root)
...,...,...,...,...,...,...
30507,all,L,*2166K,0.812895,-1.892742,root (root)
30508,all,L,*2166L,0.908702,-1.940745,root (root)
30509,all,L,*2166Q,9.337971,-0.084789,root (root)
30510,all,L,*2166S,0.371016,-1.524180,root (root)


In [63]:
# cluster_corr_chart = plot_corr_scatters(clade_corr_df, 'predicted_count')
# cluster_corr_chart_file = os.path.join(outdir, "cluster_corr_chart.html")
# print(f"Saving to {cluster_corr_chart_file}")
# cluster_corr_chart.save(cluster_corr_chart_file)
# display(cluster_corr_chart)

Scatter for naive fitness

In [64]:
clade_corr_naive_df = (
    aamut_cl_fnd
    # keep mutations whose first and last characters differ (amino-acid changes)
    .query("aa_mutation.str[0] != aa_mutation.str[-1]")
    # .query("cluster != 'all'")
    [['cluster', 'gene', 'aa_mutation', 'expected_count', 'predicted_count', 'naive_delta_fitness']]
    .assign(
        # total predicted count of the cluster each row belongs to
        cluster_counts=lambda df: df.groupby(['cluster'])['predicted_count'].transform('sum'),
        # label of the cluster's founder clade; must be computed before the
        # `cluster` column itself is rewritten below
        clade=lambda df: (
            df['cluster']
            .map(lambda name: clade_label(cluster_founder_map(name)))
            .str.replace(".", "_", regex=False)
        ),
        cluster=lambda df: df['cluster'].str.replace(".", "_", regex=False),
    )
    # drop clusters with too few counts overall
    .query("cluster_counts >= @cluster_corr_min_count")
    # a list (not a set) of labels: ordered and the documented list-like form
    .drop(columns=['cluster_counts', 'predicted_count'])
    .rename(columns={'cluster': 'subset', 'naive_delta_fitness': 'fitness'})
)
clade_corr_naive_df

Unnamed: 0,subset,gene,aa_mutation,expected_count,fitness,clade
0,all,NS1,M1I,0,0.000000,root (root)
1,all,NS1,M1K,0,0.000000,root (root)
2,all,NS1,M1L,0,0.000000,root (root)
3,all,NS1,M1R,0,0.000000,root (root)
4,all,NS1,M1T,0,0.000000,root (root)
...,...,...,...,...,...,...
30507,all,L,*2166K,0,0.000000,root (root)
30508,all,L,*2166L,0,0.000000,root (root)
30509,all,L,*2166Q,0,2.944439,root (root)
30510,all,L,*2166S,0,0.000000,root (root)


In [66]:
# cluster_corr_chart_naive = plot_corr_scatters(clade_corr_naive_df, 'expected_count')
# cluster_corr_chart_naive_file = os.path.join(outdir, "cluster_corr_chart_naive.html")
# print(f"Saving to {cluster_corr_chart_naive_file}")
# cluster_corr_chart_naive.save(cluster_corr_chart_naive_file)
# display(cluster_corr_chart_naive)

#### Naive vs. novel fitness effects

In [67]:
def plot_fit_scatters(corr_df_tidy):
    """Plot, for each subset, the naive versus the estimated fitness effects.

    Parameters
    ----------
    corr_df_tidy : pandas.DataFrame
        Tidy frame with the columns ``subset``, ``gene``, ``aa_mutation``,
        ``delta_fitness``, ``predicted_count``, ``naive_delta_fitness`` and
        ``expected_count`` (all of them are read below).

    Returns
    -------
    Altair chart
        One scatter panel per subset (x: naive fitness effect, y: fitness
        effect) with an orange diagonal, an ``r`` label and an ``n`` label;
        panels are laid out in rows of at most four.

    Notes
    -----
    Reads the module-level ``min_predicted_count`` for the slider range and
    the initial cutoff of both count sliders.
    """

    subsets = corr_df_tidy["subset"].unique()
    genes = corr_df_tidy["gene"].unique()

    # wide table of the estimated fitness and its predicted count: one row per
    # (gene, aa_mutation), one "<prop> <subset>" column per property and subset
    corr_ref = pd.merge(
        *[
            corr_df_tidy.pivot_table(
            index=["gene", "aa_mutation"],
            values=prop,
            columns="subset",
            )
            .reset_index()
            .rename(columns={subset: f"{prop} {subset}" for subset in subsets})
            for prop in ["delta_fitness", "predicted_count"]
        ]
    )

    # same layout for the naive fitness and its expected count
    corr_naive = pd.merge(
        *[
            corr_df_tidy.pivot_table(
            index=["gene", "aa_mutation"],
            values=prop,
            columns="subset",
            )
            .reset_index()
            .rename(columns={subset: f"{prop} {subset}" for subset in subsets})
            for prop in ["naive_delta_fitness", "expected_count"]
        ]
    )
    
    corr_df_wide = pd.merge(corr_ref, corr_naive)
    
    # common axis limits so that both axes of every panel share one range
    delta_fitness_min = corr_df_tidy["delta_fitness"].min()
    delta_fitness_max = corr_df_tidy["delta_fitness"].max()
    naive_fitness_min = corr_df_tidy["naive_delta_fitness"].min()
    naive_fitness_max = corr_df_tidy["naive_delta_fitness"].max()
    fitness_min = min(delta_fitness_min, naive_fitness_min)
    fitness_max = max(delta_fitness_max, naive_fitness_max)

    # clicking legend entries restricts the plotted genes
    gene_selection = alt.selection_multi(
        fields=["gene"], bind="legend",
    )

    # slider for the minimum predicted count
    # NOTE(review): the slider starts at 1; if the 0.9 quantile is below 1 the
    # computed `max` ends up below `min` — confirm this cannot happen
    predicted_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_predicted_count, corr_df_tidy["predicted_count"].quantile(0.9)),
            step=1,
            name="minimum predicted count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_predicted_count},
    )

    # slider for the minimum expected count; it reuses `min_predicted_count`
    # for its range and initial value (same NOTE(review) as above applies)
    expected_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_predicted_count, corr_df_tidy["expected_count"].quantile(0.9)),
            step=1,
            name="minimum expected count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_predicted_count},
    )
    
    # the point under the mouse is emphasised; nothing is selected otherwise
    highlight = alt.selection_single(
        on="mouseover",
        fields=["gene", "aa_mutation"],
        empty="none",
    )
    
    corr_charts = []
    base_chart = alt.Chart(corr_df_wide)
    for subset in subsets:
        # encodings and filters shared by the scatter and the text labels
        base = (
            base_chart
            .encode(
                x=alt.X(
                    f"naive_delta_fitness {subset}",
                    title=f"{subset} naive fitness effect",
                    scale=alt.Scale(domain=(fitness_min,fitness_max), nice=False),
                ),
                y=alt.Y(
                    f"delta_fitness {subset}",
                    title=f"{subset} fitness effect",
                    scale=alt.Scale(domain=(fitness_min, fitness_max), nice=False),
                ),
                tooltip=[
                    "gene",
                    "aa_mutation",
                    alt.Tooltip(
                        f"delta_fitness {subset}", title=f"{subset} fitness effect",
                    ),
                    alt.Tooltip(
                        f"naive_delta_fitness {subset}", title=f"{subset} naive fitness effect",
                    ),
                    f"predicted_count {subset}",
                    f"expected_count {subset}",
                ],
            )
            .mark_circle(opacity=0.3)
            .properties(width=200, height=200)
            .transform_filter(gene_selection)
            # a point must pass both count cutoffs; the 1e-6 slack guards the
            # comparison against floating-point rounding at the cutoff itself
            .transform_filter(
                (alt.datum[f"predicted_count {subset}"] >= predicted_count_selection["cutoff"] - 1e-6)
                & (alt.datum[f"expected_count {subset}"] >= expected_count_selection["cutoff"] - 1e-6)
            )
        )
    
        scatter = (
            base
            .encode(
                # every gene gets the same color: the color channel only
                # serves to create the clickable gene legend
                color=alt.Color(
                    "gene",
                    scale=alt.Scale(
                        domain=genes,
                        range=["#5778a4"] * len(genes),
                    ),
                    legend=alt.Legend(
                        symbolOpacity=1,
                        orient="bottom",
                        title="click / shift-click to select specific genes to show",
                        titleLimit=500,
                        columns=6,
                    ),
                ),
                size=alt.condition(highlight, alt.value(85), alt.value(30)),
                opacity=alt.condition(highlight, alt.value(1), alt.value(0.3)),
                strokeWidth=alt.condition(highlight, alt.value(1.5), alt.value(0)),
            )
            .mark_circle(stroke="black")
        )

        # diagonal (y = x) reference line
        line = alt.Chart(
            pd.DataFrame({
                "x": [fitness_min, fitness_max],
                "y": [fitness_min, fitness_max]
            })
        ).mark_line(color="orange", clip=True).encode(
            x="x:Q",
            y="y:Q",
        )
    
        # text label with the correlation coefficient of the filtered points
        # NOTE(review): r is the square root of rSquared, so the label is
        # never negative, even if the two quantities are anti-correlated
        params_r = (
            base
            .transform_regression(
                f"delta_fitness {subset}",
                f"naive_delta_fitness {subset}",
                params=True,
            )
            .transform_calculate(
                r=alt.expr.sqrt(alt.datum["rSquared"]),
                label='"r = " + format(datum.r, ".3f")',
            )
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(8),
                text=alt.Text("label:N"),
            )
        )
        
        # show number of points
        params_n = (
            base
            .transform_filter(
                (~alt.expr.isNaN(alt.datum[f"delta_fitness {subset}"]))
                & (~alt.expr.isNaN(alt.datum[f"naive_delta_fitness {subset}"]))
            )
            .transform_calculate(dummy=alt.datum[f"delta_fitness {subset}"])
            .transform_aggregate(n="valid(dummy)")
            .transform_calculate(label='"n = " + datum.n')
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(20),
                text=alt.Text("label:N"),
            )
        )
    
        chart = (
            (scatter + params_r + params_n)
            .add_selection(gene_selection)
            .add_selection(predicted_count_selection)
            .add_selection(expected_count_selection)
            .add_selection(highlight)
        )
    
        corr_charts.append(chart + line)
    
    # lay the panels out in rows of `ncols`
    ncols = 4
    rows = []
    for i in range(0, len(corr_charts), ncols):
        rows.append(alt.hconcat(*corr_charts[i: i + ncols]))
    corr_chart = alt.vconcat(*rows).configure_axis(grid=False)
    return corr_chart

In [68]:
fit_corr_df = (
    aamut_cl_fnd
    # keep mutations whose first and last characters differ (amino-acid changes)
    .query("aa_mutation.str[0] != aa_mutation.str[-1]")
    [[
        "cluster",
        "gene",
        "aa_mutation",
        "predicted_count",
        "delta_fitness",
        "expected_count",
        "naive_delta_fitness",
    ]]
    .assign(
        # total predicted count of the cluster each row belongs to
        cluster_counts=lambda df: (
            df.groupby(["cluster"])["predicted_count"].transform("sum")
        ),
        # label of the cluster's founder clade; must be computed before the
        # `cluster` column itself is rewritten below
        clade=lambda df: (
            df["cluster"]
            .map(lambda name: clade_label(cluster_founder_map(name)))
            .str.replace(".", "_", regex=False)
        ),
        cluster=lambda df: df["cluster"].str.replace(".", "_", regex=False),
    )
    # drop clusters with too few counts overall
    .query("cluster_counts >= @cluster_corr_min_count")
    .drop(columns="cluster_counts")
    .rename(columns={"cluster": "subset"})
)
fit_corr_df

Unnamed: 0,subset,gene,aa_mutation,predicted_count,delta_fitness,expected_count,naive_delta_fitness,clade
0,all,NS1,M1I,27.308774,-3.517313,0,0.000000,root (root)
1,all,NS1,M1K,3.143715,-2.643958,0,0.000000,root (root)
2,all,NS1,M1L,1.954298,-1.783230,0,0.000000,root (root)
3,all,NS1,M1R,0.376106,-1.529365,0,0.000000,root (root)
4,all,NS1,M1T,4.165623,-2.862902,0,0.000000,root (root)
...,...,...,...,...,...,...,...,...
30507,all,L,*2166K,0.812895,-1.892742,0,0.000000,root (root)
30508,all,L,*2166L,0.908702,-1.940745,0,0.000000,root (root)
30509,all,L,*2166Q,9.337971,-0.084789,0,2.944439,root (root)
30510,all,L,*2166S,0.371016,-1.524180,0,0.000000,root (root)


In [69]:
(fit_corr_df.delta_fitness, fit_corr_df.naive_delta_fitness)

(0       -3.517313
 1       -2.643958
 2       -1.783230
 3       -1.529365
 4       -2.862902
            ...   
 30507   -1.892742
 30508   -1.940745
 30509   -0.084789
 30510   -1.524180
 30511   -1.671538
 Name: delta_fitness, Length: 26370, dtype: float64,
 0        0.000000
 1        0.000000
 2        0.000000
 3        0.000000
 4        0.000000
            ...   
 30507    0.000000
 30508    0.000000
 30509    2.944439
 30510    0.000000
 30511    0.000000
 Name: naive_delta_fitness, Length: 26370, dtype: float64)

In [70]:
# build the naive-vs-estimated fitness scatter panels
fit_corr_chart = plot_fit_scatters(fit_corr_df)
# write the chart as a standalone HTML file into the output directory
fit_corr_chart_file = os.path.join(outdir, "fit_corr_chart.html")
print(f"Saving to {fit_corr_chart_file}")
fit_corr_chart.save(fit_corr_chart_file)
# also render the chart inline in the notebook
display(fit_corr_chart)

Saving to ../results/aamut_fitness/plots/fit_corr_chart.html


  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


### Heatmaps of mutational effects

Plotting function

In [None]:
def plot_aa_fitness(gene, fitness_df, clade_founder_df):
    """Plot of amino-acid fitness values for one gene.

    Parameters
    ----------
    gene : str
        Gene name; only used in the chart title.
    fitness_df : pandas.DataFrame
        Fitness values with the columns ``cluster``, ``site``,
        ``amino acid``, ``fitness`` and ``predicted_count``.
    clade_founder_df : pandas.DataFrame
        Founder amino acids with the columns ``clade``, ``site`` and
        ``amino acid``; used to mark the wildtype residue with an "x".

    Returns
    -------
    Altair chart
        A site-level summary plot (with zoom brush) stacked on top of a
        site-by-amino-acid heatmap.

    Raises
    ------
    AssertionError
        If ``fitness_df`` contains an amino acid outside the plotted alphabet.

    Notes
    -----
    Reads the module-level ``min_predicted_count`` and
    ``heatmap_minimal_domain`` and calls ``clade_label`` and
    ``cluster_founder_map``.
    """
    
    # biochemically ordered alphabet
    aas = tuple("RKHDEQNSTYWFAILMVGPC*")
    assert set(fitness_df["amino acid"]).issubset(aas)
    
    sites = fitness_df["site"].unique().tolist()

    # the last cluster in the data is the one selected initially, both for
    # the heatmap and (through its founder clade) for the wildtype marks
    default_cluster = fitness_df["cluster"].unique().tolist()[-1]
    
    # slider for the minimum predicted count of the shown values
    predicted_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_predicted_count, fitness_df["predicted_count"].quantile(0.9)),
            step=1,
            name="minimum predicted count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_predicted_count},
    )
   
    # brush drawn on the site plot; it selects the sites shown in the heatmap
    site_zoom_brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(
            stroke="gold", strokeWidth=1.5, fill="yellow", fillOpacity=0.3,
        ),
    )
        
    # shared x encoding and count filter; the 1e-6 slack guards the
    # comparison against floating-point rounding at the cutoff itself
    base = (
        alt.Chart(fitness_df)
        .encode(x=alt.X("site:O", axis=alt.Axis(labelOverlap="parity")))
        .transform_filter(
            alt.datum["predicted_count"] >= predicted_count_selection["cutoff"] - 1e-6
        )
    )
    
    heatmap_y = alt.Y("amino acid", sort=aas, scale=alt.Scale(domain=aas))
    heatmap_base = (
        base
        .encode(y=heatmap_y)
        .properties(width=alt.Step(12), height=alt.Step(12))
    )
    
    # background fill for missing values in heatmap, imputing dummy stat
    # to get all cells
    heatmap_bg = (
        heatmap_base
        .transform_impute(
            impute="_stat_dummy",
            key="amino acid",
            keyvals=aas,
            groupby=["site"],
            value=None,
        )
        .mark_rect(color="gray", opacity=0.25)
    )

    # Select fitness for clades cluster
    cluster_selection = alt.selection_single(
        fields=["cluster"],
        bind=alt.binding_select(
            options=fitness_df["cluster"].unique(),
            name="Cluster of clades",
        ),
        init={"cluster": default_cluster},
    )

    # place X values at "wildtype"
    wildtype_clade_selection = alt.selection_single(
        fields=["clade"],
        bind=alt.binding_select(
            options=clade_founder_df["clade"].unique(),
            name="X denotes wildtype in",
        ),
        init={"clade": clade_label(cluster_founder_map(default_cluster))},
    )
    heatmap_wildtype = (
        alt.Chart(clade_founder_df.query("site in @sites"))
        .encode(
            x=alt.X("site:O"),
            y=heatmap_y,
        )
        .mark_text(text="x", color="black")
        .add_selection(wildtype_clade_selection)
        .transform_filter(wildtype_clade_selection)
        .transform_filter(site_zoom_brush)
    )
    
    # heatmap showing non-filtered amino acids
    heatmap_aas = (
        heatmap_base
        .encode(
            color=alt.Color(
                "fitness:Q",
                legend=alt.Legend(
                    orient="bottom",
                    titleOrient="left",
                    gradientLength=150,
                    gradientStrokeColor="black",
                    gradientStrokeWidth=0.5,
                ),
                # diverging scale centred on 0 that always spans at least
                # `heatmap_minimal_domain`
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    domainMid=0,
                    domain=alt.DomainUnionWith(heatmap_minimal_domain),
                ),
            ),
            stroke=alt.value("black"),
            # every column goes into the tooltip; floats get 3 significant digits
            tooltip=[
                alt.Tooltip(c, format=".3g")
                if fitness_df[c].dtype == float
                else c
                for c in fitness_df.columns
            ],
        )
        .mark_rect()
        .add_selection(cluster_selection)
        .transform_filter(cluster_selection)
        .transform_filter(site_zoom_brush)
    )

    heatmap = (
        (heatmap_bg + heatmap_aas + heatmap_wildtype)
        .add_selection(predicted_count_selection)
        .transform_filter(site_zoom_brush)
    )
    
    # make lineplot
    site_statistics = ["mean", "max", "min"]
    site_stat = alt.selection_single(
        bind=alt.binding_radio(
            options=site_statistics,
            name="site fitness statistic",
        ),
        fields=["site fitness statistic"],
        init={"site fitness statistic": site_statistics[0]},
    )
    
    lineplot = (
        base
        # stop codons are left out of the per-site statistics
        .transform_filter(alt.datum["amino acid"] != "*")
        .transform_filter(cluster_selection)
        .transform_aggregate(
            **{stat: f"{stat}(fitness)" for stat in site_statistics},
            groupby=["cluster", "site"],
        )
        # fold the statistics into long form so the radio button can pick one
        .transform_fold(
            site_statistics,
            ["site fitness statistic", "site fitness"],
        )
        .add_selection(site_stat)
        .add_selection(site_zoom_brush)
        .transform_filter(site_stat)
        .encode(
            y=alt.Y("site fitness:Q", axis=alt.Axis(grid=False)),
            tooltip=[
                "site",
                alt.Tooltip("site fitness:Q", format=".3g"),
                "site fitness statistic:N",
            ],
        )
        .mark_area(color="black", opacity=0.7)
        .properties(
            height=75,
            width=min(750, 12 * fitness_df["site"].nunique()),
            title=alt.TitleParams(
                "use this site plot to zoom into regions on the heat map",
                anchor="start",
                fontWeight="normal",
                fontSize=11,
            ),
        )
    )
    
    # radio button toggling the stop-codon overlay; it filters on a constant
    # `_dummy` field that is "yes" for every overlay row
    show_stop = alt.selection_single(
        fields=["_dummy"],
        bind=alt.binding_radio(
            options=["yes", "no"],
            name="show stop in magenta on top site plot",
        ),
        init={"_dummy": "no"},
    )
    
    stopplot = (
        base
        .transform_filter(cluster_selection)
        .add_selection(show_stop)
        .transform_filter(alt.datum["amino acid"] == "*")
        .transform_calculate(_dummy="'yes'")
        .transform_filter(show_stop)
        .encode(
            y=alt.Y("fitness", title="site fitness"),
            color=alt.value("#CC79A7"),
            tooltip=["site", alt.Tooltip("fitness", format=".3g", title="stop fitness")],
        )
        .mark_line(point=True, strokeWidth=0.5, strokeDash=[2, 2])
    )
    
    return (
        (alt.layer(lineplot, stopplot) & heatmap)
        .properties(
            # the title names only the gene: the function is not tied to a
            # particular virus, so none is hard-coded here
            title=alt.TitleParams(
                f"estimated fitness of amino acids for {gene} protein",
                fontSize=15,
            ),
        )
        .resolve_scale(color="independent")
    )

In [None]:
# one heatmap chart per gene, saved as "<first word of gene name>.html"
for gene, fitness_df in (
    aamut_cl_fnd
    .rename(columns={'delta_fitness': 'fitness', 'mutant_aa': 'amino acid'})
    [['cluster', 'gene', 'site', 'amino acid', 'fitness', 'predicted_count']]
    .groupby("gene")
):
    chart = plot_aa_fitness(gene, fitness_df, clust_founder_aas.query("gene == @gene"))
    # the file is named after the first whitespace-delimited token of the gene
    chartfile = os.path.join(outdir, f"{gene.split()[0]}.html")
    print(f"\nSaving chart for {gene} to {chartfile}")
    chart.save(chartfile)
    display(chart)