# Analyze fitness effects of mutations versus terminal and non-terminal counts
This notebook examines whether the ratio of a mutation's counts on non-terminal versus terminal nodes scales with the fitness effect of that mutation.

Import Python modules:

In [None]:
import os

import altair as alt

import numpy

import pandas as pd

import yaml

# Lift Altair's default maximum-row limit on embedded data frames so the full
# mutation table can be plotted; assigning to `_` keeps the notebook from
# displaying the return value of this call.
_ = alt.data_transformers.disable_max_rows()

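Now get variables from `snakemake` when running in the pipeline, or set them manually (hard-coded paths plus thresholds from `config.yaml`) for interactive debugging: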

In [None]:
if "snakemake" in globals() or "snakemake" in locals():
    # running inside the pipeline: every input, output and parameter comes
    # from the `snakemake` object
    aamut_all_csv = snakemake.input.aamut_all
    chartfile = snakemake.output.chart
    min_expected_count = snakemake.params.min_expected_count
    min_actual_count = snakemake.params.min_actual_count
    pseudocount = snakemake.params.pseudocount
else:
    # interactive debugging: paths are relative to the notebook's directory and
    # thresholds are read from the pipeline's config file
    aamut_all_csv = "../results/aa_fitness/aamut_fitness_all.csv"
    chartfile = "../results/fitness_vs_terminal/fitness_vs_terminal.html"

    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    min_expected_count = config["min_expected_count"]
    min_actual_count = config["terminal_min_actual_count"]
    pseudocount = config["terminal_pseudocount"]

Read the amino-acid mutation data:

In [None]:
# load the amino-acid mutation table written by an upstream step
aamut_all = pd.read_csv(filepath_or_buffer=aamut_all_csv)

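Get a simplified data frame to plot: drop the duplicate `ORF1ab` entries, keep only the needed columns, and add the log ratio of non-terminal to terminal counts (with a pseudocount) and the mutation type (synonymous, nonsynonymous, or stop):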

In [None]:
def _classify_mutation_type(mutation):
    """Label each amino-acid mutation string as synonymous, stop, or nonsynonymous.

    The first character is compared with the last (presumably the wildtype and
    mutant amino acids), and a last character of "*" marks a stop mutation.
    Conditions are checked in order, so a mutation whose first and last
    characters match is "synonymous" even if that character is "*".

    Args:
        mutation: pandas Series of mutation strings.

    Returns:
        numpy array of labels, one per element of ``mutation``.
    """
    wildtype = mutation.str[0]
    mutant = mutation.str[-1]
    return numpy.select(
        [wildtype == mutant, mutant == "*"],
        ["synonymous", "stop"],
        default="nonsynonymous",
    )


df = (
    aamut_all
    .query("gene != 'ORF1ab'")
    .rename(columns={"aa_mutation": "mutation"})
    .loc[
        :,
        [
            "gene",
            "mutation",
            "delta_fitness",
            "expected_count",
            "actual_count",
            "count_terminal",
            "count_non_terminal",
        ],
    ]
    .assign(
        # natural log of (non-terminal + pseudocount) / (terminal + pseudocount)
        non_terminal_to_terminal=lambda frame: numpy.log(
            frame["count_non_terminal"].add(pseudocount)
            / frame["count_terminal"].add(pseudocount)
        ),
        mut_type=lambda frame: _classify_mutation_type(frame["mutation"]),
    )
)

df

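Now make an interactive scatter plot of each mutation's fitness effect versus its log ratio of non-terminal to terminal counts, faceted by mutation type, with a regression line, $r^2$, and the number of points shown in each panel: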

In [None]:
def _count_cutoff_selection(counts, init, max_multiple, name):
    """Make a slider-bound single selection holding a minimum-count cutoff.

    The slider runs from 1 up to the smaller of ``max_multiple * init`` and the
    80th percentile of ``counts``. That maximum is never allowed to drop below
    ``init`` (or below the slider minimum of 1), so the initial cutoff always
    lies inside the slider's range.

    Args:
        counts: pandas Series of counts whose distribution caps the slider.
        init: initial cutoff value (the configured minimum count).
        max_multiple: multiple of ``init`` used as the nominal slider maximum.
        name: label displayed next to the slider.

    Returns:
        An Altair single selection with one field, ``cutoff``.
    """
    slider_max = min(max_multiple * init, counts.quantile(0.8))
    # Without this clamp the configured cutoff could exceed the slider's max,
    # so the slider could neither show nor return to the cutoff actually used.
    slider_max = max(slider_max, init, 1)
    return alt.selection_single(
        bind=alt.binding_range(min=1, max=slider_max, step=1, name=name),
        fields=["cutoff"],
        init={"cutoff": init},
    )


expected_count_selection = _count_cutoff_selection(
    df["expected_count"], min_expected_count, 4, "minimum expected count"
)

actual_count_selection = _count_cutoff_selection(
    df["actual_count"], min_actual_count, 5, "minimum actual count"
)

# clicking / shift-clicking legend entries chooses which genes are shown
gene_selection = alt.selection_multi(bind="legend", fields=["gene"])

# hovering selects a single (gene, mutation) point; with empty="none" nothing
# counts as highlighted while no point is hovered
highlight = alt.selection_single(
    empty="none",
    fields=["gene", "mutation"],
    on="mouseover",
)

# Shared x/y encoding for every layer.
base = alt.Chart(df).encode(
    x=alt.X("non_terminal_to_terminal", title="log non-terminal to terminal counts"),
    y=alt.Y("delta_fitness", title="fitness effect of mutation"),
)

# Filters shared by every layer: the genes picked in the legend, then the two
# slider cutoffs. The 1e-6 slack presumably guards against floating-point
# error when a count equals the cutoff exactly.
for predicate in (
    gene_selection,
    alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6,
    alt.datum["actual_count"] >= actual_count_selection["cutoff"] - 1e-6,
):
    base = base.transform_filter(predicate)

# Every gene is mapped to the same color, so the color encoding presumably
# exists only to provide a clickable legend for `gene_selection`.
gene_legend = alt.Legend(
    title="click / shift-click to select specific genes to show",
    titleLimit=500,
    orient="bottom",
    columns=6,
    symbolOpacity=1,
)
gene_color = alt.Color(
    "gene",
    legend=gene_legend,
    scale=alt.Scale(
        domain=df["gene"].unique(),
        range=df["gene"].nunique() * ["#5778a4"],
    ),
)

scatter = base.mark_circle(stroke="black").encode(
    color=gene_color,
    # the hovered point is drawn larger, opaque, and outlined
    size=alt.condition(highlight, alt.value(85), alt.value(30)),
    opacity=alt.condition(highlight, alt.value(1), alt.value(0.3)),
    strokeWidth=alt.condition(highlight, alt.value(1.5), alt.value(0)),
    tooltip=df.columns.tolist(),
)

# regression line and correlation coefficient: https://stackoverflow.com/a/60239699
# Fit line of fitness effect (y) on the log count ratio (x), computed over the
# points that pass the filters on `base` and clipped to the plot area.
line = base.mark_line(color="orange", clip=True).transform_regression(
    "non_terminal_to_terminal", "delta_fitness"
)

def _stat_label(chart, y_pixel):
    """Draw ``datum.label`` as orange bold text near the top-left of each panel.

    Args:
        chart: chart whose transforms produce a ``label`` field.
        y_pixel: vertical pixel position of the text from the top of the panel.

    Returns:
        A text-mark chart positioned at x = 5 px, y = ``y_pixel``.
    """
    return (
        chart
        .mark_text(align="left", color="orange", fontWeight="bold")
        .encode(
            x=alt.value(5),
            y=alt.value(y_pixel),
            text=alt.Text("label:N"),
        )
    )


# r-squared of the regression drawn by `line`, using the same orientation
# (x = log count ratio, y = fitness effect). r2 of a simple linear fit is the
# same in either direction, but matching `line` keeps any other fitted params
# consistent with the plotted line.
params_r = _stat_label(
    base
    .transform_regression(
        "non_terminal_to_terminal",
        "delta_fitness",
        params=True,
    )
    .transform_calculate(label='"r2 = " + format(datum.rSquared, ".3f")'),
    y_pixel=8,
)

# number of filtered points with a valid fitness effect
params_n = _stat_label(
    base
    .transform_aggregate(n="valid(delta_fitness)")
    .transform_calculate(label='"n = " + datum.n'),
    y_pixel=20,
)

# Overlay the four layers and attach every interactive selection, in the same
# order as before.
layered = scatter + line + params_r + params_n
for selection in (
    gene_selection,
    expected_count_selection,
    actual_count_selection,
    highlight,
):
    layered = layered.add_selection(selection)

# one column of panels per mutation type, in a fixed order
mut_type_column = alt.Column(
    "mut_type",
    title=None,
    sort=["nonsynonymous", "synonymous", "stop"],
    header=alt.Header(labelFontStyle="bold", labelFontSize=12),
)

chart = (
    layered
    .properties(height=175, width=175)
    .facet(column=mut_type_column)
    .configure_axis(grid=False)
)

print(f"Saving to {chartfile}")
# `os.path.dirname` returns "" for a bare filename, and `os.makedirs("")`
# raises FileNotFoundError, so only create the directory when there is one.
chart_dir = os.path.dirname(chartfile)
if chart_dir:
    os.makedirs(chart_dir, exist_ok=True)
chart.save(chartfile)

chart