# Analyze and plot the amino-acid fitness values

Get variables from `snakemake`:

In [None]:
# Paths to the two input CSVs and the directory that receives the saved charts,
# all taken from the `snakemake` object available to this notebook.
aamut_all_csv = snakemake.input.aamut_all
aamut_by_subset_csv = snakemake.input.aamut_by_subset
outdir = snakemake.output.outdir

# Expected-count threshold; used below as the starting position of the
# "minimum expected count" slider.
min_expected_count = snakemake.params.min_expected_count

In [None]:
# this cell used for interactive development, commented out for pipeline
#min_expected_count = 20
#aamut_all_csv = "../results/aa_fitness/aamut_fitness_all.csv"
#aamut_by_subset_csv = "../results/aa_fitness/aamut_fitness_by_subset.csv"
#outdir = "../results/aa_fitness/analysis_plots"

Import Python modules:

In [None]:
import itertools
import os

import altair as alt

import pandas as pd

import polyclonal.plot

Some settings:

In [None]:
# Make sure the directory for saved charts exists before anything is written.
os.makedirs(outdir, exist_ok=True)

# Turn off Altair's max-rows check so large data frames can be plotted; the
# result is bound to `_` so the notebook cell shows no output.
_ = alt.data_transformers.disable_max_rows()

Read input data:

In [None]:
# Load the two fitness CSVs (all data, and split by subset) into data frames.
aamut_all, aamut_by_subset = map(pd.read_csv, (aamut_all_csv, aamut_by_subset_csv))

## Correlation in fitness among subsets
Plot correlation in fitness values among subsets, just plotting ORF1ab and not its constituent nsps to avoid double counting:

In [None]:
# Per-subset rows used for the correlation plots. Rows flagged as
# `subset_of_ORF1ab` are dropped so that only ORF1ab itself is plotted, and
# only the columns needed downstream are kept.
corr_df = aamut_by_subset.query("not subset_of_ORF1ab").loc[
    :,
    [
        "subset",
        "gene",
        "aa_mutation",
        "expected_count",
        "actual_count",
        "delta_fitness",
    ],
]

# Overall range of fitness effects; reused as the common axis domain so that
# all panels are drawn on the same scale.
delta_fitness_min, delta_fitness_max = (
    corr_df["delta_fitness"].min(),
    corr_df["delta_fitness"].max(),
)

# Selection bound to the chart legend: clicking (or shift-clicking) gene names
# in the legend picks which genes the charts show.
gene_selection = alt.selection_point(
    fields=["gene"],
    bind="legend",
)

# Slider-driven cutoff on the expected count. The slider runs from 1 up to the
# smaller of 5x the configured threshold and the 80th percentile of the
# expected counts, and starts at the configured threshold.
expected_count_selection = alt.selection_point(
    bind=alt.binding_range(
        min=1,
        max=min(5 * min_expected_count, corr_df["expected_count"].quantile(0.8)),
        step=1,
        name="minimum expected count",
    ),
    fields=["cutoff"],
    # `value` is the Altair 5 name for the initial selection; the older `init`
    # keyword is deprecated there.
    value=[{"cutoff": min_expected_count}],
)

# Build one scatter panel (points + regression line + R^2 label) for every
# unordered pair of subsets, then lay the panels out side by side.
corr_charts = []
for subset1, subset2 in itertools.combinations(corr_df["subset"].unique(), 2):
    # Put the two subsets' values for each mutation on the same row; the
    # subset name is appended to every non-key column as a suffix.
    first = corr_df.query("subset == @subset1").drop(columns="subset")
    second = corr_df.query("subset == @subset2").drop(columns="subset")
    pair_df = first.merge(
        second,
        on=["gene", "aa_mutation"],
        validate="one_to_one",
        suffixes=[f" {subset1}", f" {subset2}"],
    )

    x_field = f"delta_fitness {subset1}"
    y_field = f"delta_fitness {subset2}"
    axis_scale = alt.Scale(domain=(delta_fitness_min, delta_fitness_max))

    # A mutation is shown only if its expected count reaches the slider cutoff
    # in BOTH subsets; the small tolerance guards against float round-off.
    min_count = expected_count_selection["cutoff"] - 1e-6
    enough_counts = (alt.datum[f"expected_count {subset1}"] >= min_count) & (
        alt.datum[f"expected_count {subset2}"] >= min_count
    )

    # Shared encodings and filters; each layer below supplies its own mark.
    base = (
        alt.Chart(pair_df)
        .encode(
            x=alt.X(x_field, scale=axis_scale),
            y=alt.Y(y_field, scale=axis_scale),
            tooltip=pair_df.columns.tolist(),
        )
        .properties(width=200, height=200)
        .transform_filter(gene_selection)
        .transform_filter(enough_counts)
    )

    # Every gene is mapped to the same color: the color channel exists only so
    # that a clickable legend is drawn for `gene_selection`.
    gene_color = alt.Color(
        "gene",
        scale=alt.Scale(
            domain=corr_df["gene"].unique(),
            range=["#5778a4"] * corr_df["gene"].nunique(),
        ),
        legend=alt.Legend(
            symbolOpacity=1,
            orient="bottom",
            title="click / shift-click to select specific genes to show",
            titleLimit=500,
        ),
    )
    scatter = base.mark_circle(opacity=0.3).encode(color=gene_color)

    # regression line and correlation coefficient: https://stackoverflow.com/a/60239699
    line = base.transform_regression(
        x_field,
        y_field,
        extent=(delta_fitness_min, delta_fitness_max),
    ).mark_line(color="orange", clip=True)

    # With params=True the regression transform yields fit statistics instead
    # of line points; the R^2 value is drawn as text at a fixed pixel position.
    params = (
        base.transform_regression(x_field, y_field, params=True)
        .mark_text(align="left", color="orange")
        .encode(
            x=alt.value(10),
            y=alt.value(10),
            text=alt.Text("rSquared:Q", format=".3f"),
        )
    )

    layered = scatter + line + params
    corr_charts.append(
        layered.add_parameter(gene_selection).add_parameter(expected_count_selection)
    )

corr_chart = alt.hconcat(*corr_charts)

# Write the combined chart as HTML into the output directory.
corr_chart_file = os.path.join(outdir, "subset_corr_chart.html")
print(f"Saving to {corr_chart_file}")
corr_chart.save(corr_chart_file)

# Last expression of the cell, so the notebook renders the chart inline.
corr_chart

## Histograms of mutation effects
Histograms of mutation effects.
We make two versions: one in which ORF1ab is labeled as a single gene, and one in which its constituent nsps are labeled as genes:

In [None]:
def _mut_type(aa_mutation):
    """Classify a mutation string by comparing its first and last characters.

    Returns "synonymous" when the first and last characters are equal,
    "stop" when the last character is "*", and "nonsynonymous" otherwise.
    """
    if aa_mutation[0] == aa_mutation[-1]:
        return "synonymous"
    if aa_mutation[-1] == "*":
        return "stop"
    return "nonsynonymous"


# Draw the histograms twice, once per gene-naming scheme.
for orf1ab_nsp in ["ORF1ab", "nsp"]:
    # "ORF1ab" naming drops the rows flagged as subsets of ORF1ab; "nsp"
    # naming instead drops the ORF1ab rows themselves.
    query_str = (
        "not subset_of_ORF1ab" if orf1ab_nsp == "ORF1ab"
        else "gene != 'ORF1ab'"
    )

    hist_df = (
        aamut_all
        .query(query_str)
        .assign(mut_type=lambda x: x["aa_mutation"].map(_mut_type))
        [["gene", "expected_count", "delta_fitness", "mut_type"]]
    )

    hist_chart = (
        alt.Chart(hist_df)
        .encode(
            # The x-axis reuses the fitness range computed for the correlation
            # plots; bars outside that range are clipped by the mark.
            x=alt.X(
                "delta_fitness",
                bin=alt.Bin(step=(delta_fitness_max - delta_fitness_min) / 35),
                scale=alt.Scale(domain=(delta_fitness_min, delta_fitness_max)),
                title="fitness effect of mutation",
            ),
            y=alt.Y("count()", title="number of mutations"),
            # Every gene gets the same color; the color channel exists only so
            # that a clickable legend is drawn for `gene_selection`.
            color=alt.Color(
                "gene",
                scale=alt.Scale(
                    domain=hist_df["gene"].unique(),
                    # one color per gene actually in this histogram, so the
                    # range length always matches the domain length
                    range=["#5778a4"] * hist_df["gene"].nunique(),
                ),
                legend=alt.Legend(
                    symbolOpacity=1,
                    orient="bottom",
                    title="click / shift-click to select specific genes to show",
                    titleLimit=500,
                    columns=5,
                    padding=5,
                ),
            ),
            # One stacked panel per mutation type.
            facet=alt.Facet(
                "mut_type",
                title=None,
                columns=1,
                header=alt.Header(labelFontSize=12, labelFontWeight="bold"),
            ),
        )
        .mark_bar(clip=True, stroke="#5778a4")
        .transform_filter(gene_selection)
        # Keep mutations whose expected count reaches the slider cutoff; the
        # small tolerance guards against float round-off.
        .transform_filter(
            alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
        )
        .add_parameter(gene_selection)
        .add_parameter(expected_count_selection)
        .properties(width=250, height=120)
        # Each panel gets its own y-axis scale.
        .resolve_scale(y="independent")
    )

    chartfile = os.path.join(outdir, f"histogram_{orf1ab_nsp}_naming.html")
    print(f"Saving to {chartfile}")
    hist_chart.save(chartfile)

    display(hist_chart)

## Plot results for individual genes
We use the [lineplot_and_heatmap function from the polyclonal package](https://jbloomlab.github.io/polyclonal/polyclonal.plot.html#polyclonal.plot.lineplot_and_heatmap) to plot the results for each gene.

**This is still in progress.**