# Analyze and plot the amino-acid fitness values

Get variables from `snakemake`, or from the config file when running interactively:

In [None]:
if "snakemake" in globals() or "snakemake" in locals():
    # Running as part of the Snakemake pipeline: every parameter and file path
    # comes from the `snakemake` object injected into the notebook.
    _smk_params = snakemake.params
    min_expected_count = _smk_params.min_expected_count
    clade_corr_min_count = _smk_params.clade_corr_min_count
    init_ref_clade = _smk_params.init_ref_clade
    clade_synonyms = _smk_params.clade_synonyms
    heatmap_minimal_domain = _smk_params.heatmap_minimal_domain
    orf1ab_to_nsps = _smk_params.orf1ab_to_nsps

    _smk_input = snakemake.input
    aamut_all_csv = _smk_input.aamut_all
    aamut_by_subset_csv = _smk_input.aamut_by_subset
    aamut_by_clade_csv = _smk_input.aamut_by_clade
    aafitness_csv = _smk_input.aafitness
    clade_founder_nts_csv = _smk_input.clade_founder_nts

    outdir = snakemake.output.outdir
else:
    # Running interactively (e.g. for debugging): read the parameters from the
    # pipeline config file and use the default result paths.
    import yaml

    with open("../config.yaml") as config_fh:
        config = yaml.safe_load(config_fh)

    min_expected_count = config["min_expected_count"]
    clade_corr_min_count = config["clade_corr_min_count"]
    init_ref_clade = config["aa_fitness_init_ref_clade"]
    clade_synonyms = config["clade_synonyms"]
    heatmap_minimal_domain = config["aa_fitness_heatmap_minimal_domain"]
    orf1ab_to_nsps = config["orf1ab_to_nsps"]

    _aa_fitness_dir = "../results/aa_fitness"
    aamut_all_csv = f"{_aa_fitness_dir}/aamut_fitness_all.csv"
    aamut_by_subset_csv = f"{_aa_fitness_dir}/aamut_fitness_by_subset.csv"
    aamut_by_clade_csv = f"{_aa_fitness_dir}/aamut_fitness_by_clade.csv"
    aafitness_csv = f"{_aa_fitness_dir}/aa_fitness.csv"
    clade_founder_nts_csv = "../results/clade_founder_nts/clade_founder_nts.csv"
    outdir = f"{_aa_fitness_dir}/analysis_plots"

Import Python modules:

In [None]:
import itertools
import os

import altair as alt

import Bio.Seq

import pandas as pd

Configure Altair and create the output directory:

In [None]:
# Create the output directory for the plots (no error if it already exists).
os.makedirs(outdir, exist_ok=True)

# Let Altair embed data sets of any size in the chart specifications; the
# return value is assigned to `_` to discard it.
_ = alt.data_transformers.disable_max_rows()

Read input data:

In [None]:
# Load the mutation-level and amino-acid-level fitness tables, in order.
aamut_all, aamut_by_subset, aamut_by_clade, aafitness = map(
    pd.read_csv,
    [aamut_all_csv, aamut_by_subset_csv, aamut_by_clade_csv, aafitness_csv],
)

Define a function that labels each clade, appending its synonym (if any) in parentheses:

In [None]:
def clade_label(clade):
    """Return the display label for `clade`.

    If `clade` is a key of the module-level `clade_synonyms` mapping, the
    label is `"<clade> (<synonym>)"`; otherwise `clade` is returned unchanged.
    """
    if clade not in clade_synonyms:
        return clade
    return f"{clade} ({clade_synonyms[clade]})"

## Correlation in fitness among subsets and clades
Plot the correlation in mutation fitness effects ($\Delta f_{xy}$ values) among subsets (geographic locations), and among clades with large numbers of counts.
We plot only ORF1ab, not its constituent nsps, to avoid double counting.
We do not include the "all" subset, as it contains all the other subsets:

In [None]:
def plot_corr_scatters(corr_df_tidy):
    """Plot a grid of pairwise correlation scatters of mutation fitness effects.

    Parameters
    ----------
    corr_df_tidy : pandas.DataFrame
        Tidy data frame with columns "subset", "gene", "aa_mutation",
        "expected_count", and "delta_fitness". Every pair of distinct values
        in "subset" gets its own scatter plot.

    Returns
    -------
    altair.VConcatChart
        Scatter plots laid out four per row. Each has a regression line, the
        Pearson correlation `r`, and the number of plotted points `n`. A
        legend selects which genes to show, and a slider (initialized to the
        module-level `min_expected_count`) sets the minimum expected count a
        mutation must have in both subsets to be plotted.
    """

    subsets = corr_df_tidy["subset"].unique()
    genes = corr_df_tidy["gene"].unique()

    # wide data frame with one "<prop> <subset>" column per property and subset
    corr_df_wide = pd.merge(
        *[
            corr_df_tidy
            .pivot_table(
                index=["gene", "aa_mutation"],
                values=prop,
                columns="subset",
            )
            .reset_index()
            .rename(columns={subset: f"{prop} {subset}" for subset in subsets})
            for prop in ["delta_fitness", "expected_count"]
        ]
    )

    # shared axis limits so all scatters are directly comparable
    delta_fitness_min = corr_df_tidy["delta_fitness"].min()
    delta_fitness_max = corr_df_tidy["delta_fitness"].max()

    gene_selection = alt.selection_multi(
        fields=["gene"], bind="legend",
    )

    expected_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_expected_count, corr_df_tidy["expected_count"].quantile(0.9)),
            step=1,
            name="minimum expected count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_expected_count},
    )

    corr_charts = []
    base_chart = alt.Chart(corr_df_wide)
    for subset1, subset2 in itertools.combinations(subsets, 2):
        base = (
            base_chart
            .encode(
                x=alt.X(
                    f"delta_fitness {subset1}",
                    title=f"{subset1} fitness effect",
                    scale=alt.Scale(domain=(delta_fitness_min, delta_fitness_max), nice=False),
                ),
                y=alt.Y(
                    f"delta_fitness {subset2}",
                    title=f"{subset2} fitness effect",
                    scale=alt.Scale(domain=(delta_fitness_min, delta_fitness_max), nice=False),
                ),
                tooltip=[
                    "gene",
                    "aa_mutation",
                    alt.Tooltip(
                        f"delta_fitness {subset1}", title=f"{subset1} fitness effect",
                    ),
                    alt.Tooltip(
                        f"delta_fitness {subset2}", title=f"{subset2} fitness effect",
                    ),
                    f"expected_count {subset1}",
                    f"expected_count {subset2}",
                ],
            )
            .mark_circle(opacity=0.3)
            .properties(width=200, height=200)
            .transform_filter(gene_selection)
            # small tolerance so counts equal to the slider cutoff are kept
            .transform_filter(
                (alt.datum[f"expected_count {subset1}"] >= expected_count_selection["cutoff"] - 1e-6)
                & (alt.datum[f"expected_count {subset2}"] >= expected_count_selection["cutoff"] - 1e-6)
            )
        )

        # all genes share one color; the color encoding exists so the legend
        # can be used to select genes
        scatter = (
            base
            .encode(
                color=alt.Color(
                    "gene",
                    scale=alt.Scale(
                        domain=genes,
                        range=["#5778a4"] * len(genes),
                    ),
                    legend=alt.Legend(
                        symbolOpacity=1,
                        orient="bottom",
                        title="click / shift-click to select specific genes to show",
                        titleLimit=500,
                        columns=6,
                    ),
                ),
            )
            .mark_circle(opacity=0.3)
        )

        # regression line and correlation coefficient: https://stackoverflow.com/a/60239699
        line = (
            base
            .transform_regression(
                f"delta_fitness {subset1}",
                f"delta_fitness {subset2}",
                extent=(delta_fitness_min, delta_fitness_max),
            )
            .mark_line(color="orange", clip=True)
        )

        params_r = (
            base
            .transform_regression(
                f"delta_fitness {subset1}",
                f"delta_fitness {subset2}",
                params=True,
            )
            .transform_calculate(
                # sqrt(rSquared) is always >= 0, so take the sign of r from
                # the fitted slope (coef is [intercept, slope] for a linear fit)
                r="(datum.coef[1] < 0 ? -1 : 1) * sqrt(datum.rSquared)",
                label='"r = " + format(datum.r, ".3f")',
            )
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(8),
                text=alt.Text("label:N"),
            )
        )

        # show number of points
        params_n = (
            base
            .transform_filter(
                (~alt.expr.isNaN(alt.datum[f"delta_fitness {subset1}"]))
                & (~alt.expr.isNaN(alt.datum[f"delta_fitness {subset2}"]))
            )
            .transform_calculate(dummy=alt.datum[f"delta_fitness {subset1}"])
            .transform_aggregate(n="valid(dummy)")
            .transform_calculate(label='"n = " + datum.n')
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(20),
                text=alt.Text("label:N"),
            )
        )

        chart = (
            (scatter + line + params_r + params_n)
            .add_selection(gene_selection)
            .add_selection(expected_count_selection)
        )

        corr_charts.append(chart)

    ncols = 4
    rows = []
    for i in range(0, len(corr_charts), ncols):
        rows.append(alt.hconcat(*corr_charts[i: i + ncols]))
    corr_chart = alt.vconcat(*rows).configure_axis(grid=False)
    return corr_chart
    
# correlations for subsets
subset_corr_df = (
    aamut_by_subset
    .query("not subset_of_ORF1ab")
    [["subset", "gene", "aa_mutation", "expected_count", "delta_fitness"]]
    .query("subset != 'all'")
)
subset_corr_chart = plot_corr_scatters(subset_corr_df)
subset_corr_chart_file = os.path.join(outdir, "subset_corr_chart.html")
print(f"Saving to {subset_corr_chart_file}\n\n")
subset_corr_chart.save(subset_corr_chart_file)
display(subset_corr_chart)

# correlations for clades
clade_corr_df = (
    aamut_by_clade
    .query("not subset_of_ORF1ab")
    [["clade", "gene", "aa_mutation", "expected_count", "delta_fitness"]]
    .assign(
        clade_counts=lambda x: x.groupby("clade")["expected_count"].transform("sum"),
        clade=lambda x: x["clade"].map(clade_label).str.replace(".", "_", regex=False),
    )
    .query("clade_counts >= @clade_corr_min_count")
    .drop(columns="clade_counts")
    .rename(columns={"clade": "subset"})
)
clade_corr_chart = plot_corr_scatters(clade_corr_df)
clade_corr_chart_file = os.path.join(outdir, "clade_corr_chart.html")
print(f"Saving to {clade_corr_chart_file}")
clade_corr_chart.save(clade_corr_chart_file)
display(clade_corr_chart)

## Histograms of mutation effects
Plot histograms of the fitness effects of mutations.
We make two versions: one with ORF1ab treated as a single gene, and one with ORF1ab split into its constituent nsps:

In [None]:
def _mutation_type(aa_mutation):
    """Classify an amino-acid mutation string by its first and last characters.

    Returns "synonymous" if the first (wildtype) and last (mutant) characters
    match, otherwise "stop" if the mutant is "*", otherwise "nonsynonymous".
    """
    wildtype, mutant = aa_mutation[0], aa_mutation[-1]
    if wildtype == mutant:
        return "synonymous"
    if mutant == "*":
        return "stop"
    return "nonsynonymous"


# legend-bound gene selection and expected-count slider shared by both charts
gene_selection = alt.selection_multi(
    fields=["gene"], bind="legend",
)

expected_count_selection = alt.selection_single(
    bind=alt.binding_range(
        min=1,
        max=min(5 * min_expected_count, aamut_all["expected_count"].quantile(0.9)),
        step=1,
        name="minimum expected count",
    ),
    fields=["cutoff"],
    init={"cutoff": min_expected_count},
)

# Row filter for each gene-naming scheme: "ORF1ab" keeps ORF1ab as one gene and
# drops its constituent nsps; "nsp" keeps the nsps and drops ORF1ab itself.
hist_queries = {
    "ORF1ab": "not subset_of_ORF1ab",
    "nsp": "gene != 'ORF1ab'",
}

for orf1ab_nsp, query_str in hist_queries.items():
    hist_rows = aamut_all.query(query_str)
    hist_df = hist_rows.assign(
        mut_type=hist_rows["aa_mutation"].map(_mutation_type)
    )[["gene", "expected_count", "delta_fitness", "mut_type"]]

    delta_fitness_max = hist_df["delta_fitness"].max()
    delta_fitness_min = hist_df["delta_fitness"].min()

    # 35 equal-width bins spanning the observed fitness effects
    hist_x = alt.X(
        "delta_fitness",
        bin=alt.Bin(step=(delta_fitness_max - delta_fitness_min) / 35),
        scale=alt.Scale(domain=(delta_fitness_min, delta_fitness_max)),
        title="fitness effect of mutation",
    )
    # one shared color for all genes; the encoding exists to drive the legend
    hist_color = alt.Color(
        "gene",
        scale=alt.Scale(
            domain=hist_df["gene"].unique(),
            range=["#5778a4"] * hist_df["gene"].nunique(),
        ),
        legend=alt.Legend(
            symbolOpacity=1,
            orient="bottom",
            title="click / shift-click to select specific genes to show",
            titleLimit=500,
            columns=5,
            padding=5,
        ),
    )
    hist_facet = alt.Facet(
        "mut_type",
        title=None,
        columns=1,
        header=alt.Header(labelFontSize=12, labelFontWeight="bold"),
    )

    hist_chart = (
        alt.Chart(hist_df)
        .encode(
            x=hist_x,
            y=alt.Y("count()", title="number of mutations"),
            color=hist_color,
            facet=hist_facet,
        )
        .mark_bar(clip=True, stroke="#5778a4")
        .transform_filter(gene_selection)
        .transform_filter(
            alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
        )
        .add_selection(gene_selection)
        .add_selection(expected_count_selection)
        .properties(width=250, height=120)
        .resolve_scale(y="independent")
    )

    chartfile = os.path.join(outdir, f"histogram_{orf1ab_nsp}_naming.html")
    print(f"Saving to {chartfile}")
    hist_chart.save(chartfile)

    display(hist_chart)

## Plot results for individual genes
Now we plot heatmaps of the amino-acid fitness estimates for each protein.

First, get the "wildtype" amino acid at each site in each clade founder sequence:

In [None]:
# codon translation table
codon_table = {
    f"{nt1}{nt2}{nt3}": str(Bio.Seq.Seq(f"{nt1}{nt2}{nt3}").translate())
    for nt1 in "ACGT" for nt2 in "ACGT" for nt3 in "ACGT"
}

# get clade founder amino-acids
clade_founder_aas = (
    pd.read_csv(clade_founder_nts_csv)
    [["clade", "gene", "codon", "codon_site"]]
    .drop_duplicates()
    .assign(
        gene=lambda x: x["gene"].str.split(";"),
        codon=lambda x: x["codon"].str.split(";"),
        codon_site=lambda x: x["codon_site"].str.split(";"),
    )
    .explode(["gene", "codon", "codon_site"])
    .assign(
        aa=lambda x: x["codon"].map(codon_table),
        codon_site=lambda x: x["codon_site"].astype(int),
        clade=lambda x: x["clade"].map(clade_label),
    )
    .rename(columns={"codon_site": "site", "aa": "amino acid"})
    .drop(columns="codon")
)

# now convert ORF1ab numbers to nsp numbers
orf1ab_to_nsps_df = pd.concat(
    [
        pd.DataFrame(
            [(i, i - start + 1) for i in range(start, end + 1)],
            columns=["ORF1ab_site", "nsp_site"],
        ).assign(nsp=nsp).drop_duplicates()
        for nsp, (start, end) in orf1ab_to_nsps.items()
    ],
    ignore_index=True,
)

clade_founder_aas = pd.concat(
    [
        clade_founder_aas,
        (
            clade_founder_aas
            .query("gene == 'ORF1ab'")
            .merge(
                orf1ab_to_nsps_df,
                left_on="site",
                right_on="ORF1ab_site",
                validate="many_to_one",
            )
            .drop(columns=["gene", "ORF1ab_site", "site"])
            .rename(columns={"nsp_site": "site", "nsp": "gene"})
        ),
    ],
    ignore_index=False,
)

Now plot the heatmaps:

In [None]:
def plot_aa_fitness(gene, fitness_df, clade_founder_df):
    """Plot an interactive heat map of amino-acid fitness values for one gene.

    Parameters
    ----------
    gene : str
        Gene name, used in the chart title.
    fitness_df : pandas.DataFrame
        Rows for this gene with columns "site", "amino acid", "fitness", and
        "expected_count"; every column is shown in the heat-map tooltips.
        "amino acid" may only contain the 20 amino acids or "*".
    clade_founder_df : pandas.DataFrame
        Clade founder amino acids with columns "clade", "site", and
        "amino acid". The founder amino acid of the clade chosen in a
        drop-down menu is marked with an "x" on the heat map.

    Returns
    -------
    altair.VConcatChart
        An area plot of a per-site fitness statistic (mean, max, or min;
        brushing it zooms the heat map) stacked above the heat map. A slider
        (initialized to the module-level `min_expected_count`) sets the
        minimum expected count for an amino acid to be shown.

    Raises
    ------
    ValueError
        If `fitness_df` contains amino acids outside the expected alphabet.
    """

    # biochemically ordered alphabet
    aas = tuple("RKHDEQNSTYWFAILMVGPC*")
    unexpected_aas = set(fitness_df["amino acid"]) - set(aas)
    if unexpected_aas:
        raise ValueError(f"unexpected amino acids for {gene}: {unexpected_aas}")

    expected_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_expected_count, fitness_df["expected_count"].quantile(0.9)),
            step=1,
            name="minimum expected count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_expected_count},
    )

    site_zoom_brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(
            stroke="gold", strokeWidth=1.5, fill="yellow", fillOpacity=0.3,
        ),
    )

    # small tolerance so counts equal to the slider cutoff are kept
    base = (
        alt.Chart(fitness_df)
        .encode(x=alt.X("site:O", axis=alt.Axis(labelOverlap="parity")))
        .transform_filter(
            alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
        )
    )

    heatmap_y = alt.Y("amino acid", sort=aas, scale=alt.Scale(domain=aas))
    heatmap_base = (
        base
        .encode(y=heatmap_y)
        .properties(width=alt.Step(12), height=alt.Step(12))
    )

    # background fill for missing values in heatmap, imputing dummy stat
    # to get all cells
    heatmap_bg = (
        heatmap_base
        .transform_impute(
            impute="_stat_dummy",
            key="amino acid",
            keyvals=aas,
            groupby=["site"],
            value=None,
        )
        .mark_rect(color="gray", opacity=0.25)
    )

    # heatmap showing non-filtered amino acids
    heatmap_aas = (
        heatmap_base
        .encode(
            color=alt.Color(
                "fitness:Q",
                legend=alt.Legend(
                    orient="bottom",
                    titleOrient="left",
                    gradientLength=150,
                    gradientStrokeColor="black",
                    gradientStrokeWidth=0.5,
                ),
                scale=alt.Scale(
                    zero=True,
                    nice=False,
                    type="linear",
                    domainMid=0,
                    domain=alt.DomainUnionWith(heatmap_minimal_domain),
                ),
            ),
            stroke=alt.value("black"),
            tooltip=[
                alt.Tooltip(c, format=".3g")
                if fitness_df[c].dtype == float
                else c
                for c in fitness_df.columns
            ],
        )
        .mark_rect()
    )

    # place X values at "wildtype"
    wildtype_clade_selection = alt.selection_single(
        fields=["clade"],
        bind=alt.binding_select(
            options=clade_founder_df["clade"].unique(),
            name="X denotes wildtype in",
        ),
        init={"clade": clade_label(init_ref_clade)},
    )
    heatmap_wildtype = (
        alt.Chart(clade_founder_df)
        .encode(
            x=alt.X("site:O"),
            y=heatmap_y,
        )
        .mark_text(text="x", color="black")
        .add_selection(wildtype_clade_selection)
        .transform_filter(wildtype_clade_selection)
        .transform_filter(site_zoom_brush)
    )

    heatmap = (
        (heatmap_bg + heatmap_aas + heatmap_wildtype)
        .add_selection(expected_count_selection)
        .transform_filter(site_zoom_brush)
    )

    # make lineplot of a per-site statistic chosen with radio buttons; it
    # also carries the brush used to zoom the heat map
    site_statistics = ["mean", "max", "min"]
    site_stat = alt.selection_single(
        bind=alt.binding_radio(
            options=site_statistics,
            name="site fitness statistic",
        ),
        fields=["site fitness statistic"],
        init={"site fitness statistic": site_statistics[0]},
    )

    lineplot = (
        base
        .transform_aggregate(
            **{stat: f"{stat}(fitness)" for stat in site_statistics},
            groupby=["site"],
        )
        .transform_fold(
            site_statistics,
            ["site fitness statistic", "site fitness"],
        )
        .add_selection(site_stat)
        .add_selection(site_zoom_brush)
        .transform_filter(site_stat)
        .encode(
            y=alt.Y("site fitness:Q", axis=alt.Axis(grid=False)),
            tooltip=[
                "site",
                alt.Tooltip("site fitness:Q", format=".3g"),
                "site fitness statistic:N",
            ],
        )
        .mark_area(color="black")
        .properties(
            height=75,
            width=min(750, 12 * fitness_df["site"].nunique()),
            title=alt.TitleParams(
                "use this site plot to zoom into regions on the heat map",
                anchor="start",
                fontWeight="normal",
                fontSize=11,
            ),
        )
    )

    return (
        (lineplot & heatmap)
        .properties(
            title=alt.TitleParams(
                f"estimated fitness of amino acids for SARS-CoV-2 {gene} protein",
                fontSize=15,
            ),
        )
    )


# Plot, save, and display the amino-acid fitness heat map for every gene
# except ORF1ab.
for gene, fitness_df in (
    aafitness
    [["gene", "aa_site", "aa", "fitness", "expected_count"]]
    .rename(columns={"aa_site": "site", "aa": "amino acid"})
    .groupby("gene")
):
    if gene == "ORF1ab":
        continue  # do not make a plot for ORF1ab, too big
    chart = plot_aa_fitness(gene, fitness_df, clade_founder_aas.query("gene == @gene"))
    # the file is named after the first whitespace-separated word of the gene
    chartfile = os.path.join(outdir, f"{gene.split()[0]}.html")
    print(f"\nSaving chart for {gene} to {chartfile}")
    chart.save(chartfile)
    display(chart)