# Summarize actual versus expected counts
Plot the average per-site mutation counts for each type of nucleotide mutation (e.g., A to C), with counts summed across all clades and then averaged across sites.

Import Python modules:

In [None]:
import altair as alt

import numpy

import pandas as pd

Now get variables from `snakemake`:

In [None]:
# Resolve the input CSV path and the output chart path. When a `snakemake`
# object is present in the namespace the paths come from it; otherwise fall
# back to hard-coded paths for interactive debugging.
if "snakemake" in globals() or "snakemake" in locals():
    # running as part of the pipeline
    counts_csv = snakemake.input.csv
    chartfile = snakemake.output.chart
else:
    # no `snakemake` object available: use manually set debugging paths
    counts_csv = "../results/expected_vs_actual_mut_counts/expected_vs_actual_mut_counts.csv"
    chartfile = "../results/expected_vs_actual_mut_counts/avg_counts.html"

Get average per-site expected, synonymous, nonsynonymous, and stop counts across all clades and sites:

In [None]:
# Load the expected-versus-actual counts table; `low_memory=False` makes
# pandas parse the file in one pass rather than in internal chunks.
counts = pd.read_csv(filepath_or_buffer=counts_csv, low_memory=False)

In [None]:
# Step 1: keep only rows for the "all" subset that are not flagged by the
# `exclude` column.
retained_counts = counts.query("subset == 'all'").query("not exclude")

# Step 2: label each row with the kind of amino-acid change and with the
# nucleotide change written as "<first char> to <last char>" of `nt_mutation`.
is_synonymous = retained_counts["clade_founder_aa"] == retained_counts["mutant_aa"]
is_stop = retained_counts["mutant_aa"] == "*"
labeled_counts = retained_counts.assign(
    mut_type=numpy.where(
        is_synonymous,
        "synonymous",
        numpy.where(is_stop, "stop", "nonsynonymous"),
    ),
    mut=(
        retained_counts["nt_mutation"].str[0]
        + " to "
        + retained_counts["nt_mutation"].str[-1]
    ),
)

# Step 3: for each (mut_type, mut) pair, total the expected and actual counts
# and count the distinct nucleotide mutations contributing to that pair.
totals_by_type = labeled_counts.groupby(["mut_type", "mut"], as_index=False).aggregate(
    total_expected_count=("expected_count", "sum"),
    total_actual_count=("actual_count", "sum"),
    n_mutations=("nt_mutation", "nunique"),
)

# Step 4: convert the totals to averages per distinct nucleotide mutation.
totals_by_type["expected"] = (
    totals_by_type["total_expected_count"] / totals_by_type["n_mutations"]
)
totals_by_type["actual"] = (
    totals_by_type["total_actual_count"] / totals_by_type["n_mutations"]
)

# Step 5: reshape to long form with one row per (mut_type, mut, count_type).
long_counts = totals_by_type.melt(
    id_vars=["mut_type", "mut"],
    value_vars=["expected", "actual"],
    var_name="count_type",
    value_name="counts",
)

# Step 6: relabel every "expected" row with a single shared category, so the
# final group-by averages the expected values over the original mutation
# types while the "actual" rows keep their own mutation type.
long_counts["mut_type"] = numpy.where(
    long_counts["count_type"] == "actual",
    long_counts["mut_type"],
    "expected (4-fold degenerate)",
)

avg_counts = long_counts.groupby(["mut_type", "mut"], as_index=False).aggregate(
    {"counts": "mean"}
)

avg_counts

In [None]:
# Faceted bar chart of average counts: one column facet per nucleotide
# mutation (`mut`), one bar per count type (`mut_type`) within each facet.
height = 120

# Order used for the x-axis domain and for the color legend.
mut_type_order = ["expected (4-fold degenerate)", "synonymous", "nonsynonymous", "stop"]

# Selection bound to the legend; unselected count types are drawn faded.
mut_type_selection = alt.selection_multi(fields=["mut_type"], bind="legend")

# x-axis: labels are hidden and a single tick is drawn at "synonymous".
# NOTE(review): the extra "dummy" domain entry presumably just adds padding
# on the left of each facet -- confirm before removing.
x_encoding = alt.X(
    "mut_type",
    title=None,
    axis=alt.Axis(labels=False, ticks=True, values=["synonymous"]),
    scale=alt.Scale(domain=["dummy", *mut_type_order]),
)

y_encoding = alt.Y("counts", title="average counts per site")

# Facet header: rotated labels placed at the bottom, padded by the chart
# height plus a small offset (presumably so they sit below the bars).
facet_header = alt.Header(
    labelOrient="bottom",
    labelAngle=-90,
    labelAlign="right",
    labelBaseline="middle",
    labelPadding=height + 8,
)
column_encoding = alt.Column("mut", title=None, header=facet_header, spacing=0)

# Colors are listed in the same order as `mut_type_order`.
color_encoding = alt.Color(
    "mut_type",
    title="type of count",
    sort=mut_type_order,
    scale=alt.Scale(range=["#999999", "#009E73", "#56B4E9", "#E69F00"]),
)

tooltip_encoding = [
    alt.Tooltip("mut_type", title="type of count"),
    alt.Tooltip("mut", title="mutation type"),
    alt.Tooltip("counts", title="average counts", format=".2f"),
]

opacity_encoding = alt.condition(mut_type_selection, alt.value(1), alt.value(0.25))

avg_count_chart = (
    alt.Chart(avg_counts)
    .mark_bar()
    .encode(
        x=x_encoding,
        y=y_encoding,
        column=column_encoding,
        color=color_encoding,
        tooltip=tooltip_encoding,
        opacity=opacity_encoding,
    )
    .configure_axis(grid=False)
    .configure_view(stroke=None)
    .properties(height=height, width=42)
    .add_selection(mut_type_selection)
)

# Write the chart to the output path, then display it as the cell result.
avg_count_chart.save(chartfile)

avg_count_chart