# Synonymous mutation spectrum
Analyze the synonymous mutation spectrum.

Get input variables from [papermill](https://papermill.readthedocs.io/) parameterization (note next cell is tagged as `parameters`).
So when this notebook is run via `papermill`, those values will be replaced with whatever is in pipeline:

In [1]:
input_csv = "../results/mutation_counts/aggregated.csv"

synonymous_spectra_min_counts = 10000

subset_order = ["all", "USA", "England"]

Import Python modules:

In [2]:
import itertools
import os

import altair as alt

import numpy

import scipy

import pandas as pd

import sklearn.decomposition

Read the mutation counts and assign mutation types:

In [3]:
mutation_counts = pd.read_csv(input_csv).assign(
    mut_type=lambda x: x["nt_mutation"].map(lambda m: f"{m[0]}to{m[-1]}")
)

For each clade plot the top mutations as a fraction of all mutations in that clade, just using the "all" subset.

You can mouseover points to highlight mutations (which will highlight all mutations at that site on all facets), and click the legend to show/hide excluded or non-excluded mutations.

This plot is useful to look at to identifier apparent outlier sites with aberrantly high mutation counts that can then be specified for exclusion (note those specifications are done in the pipeline `config.yaml` file, and also all reversions from clade founder to reference may be excluded):

In [4]:
top_n = 100  # plot this many per clade

mutation_freqs = (
    mutation_counts
    .query("subset == 'all'")
    .sort_values(["clade", "count"], ascending=False)
    .groupby("clade")
    .head(n=top_n)
    .assign(
        freq=lambda x: x["count"] / x.groupby("clade")["count"].transform("sum"),
        rank=lambda x: x.groupby("clade")["freq"].rank(ascending=False, method="first"),
        exclude=lambda x: x["exclude"].map({True: "yes", False: "no"}),
    )
)

select_exclude = alt.selection_multi(
    fields=["exclude"], bind="legend", init=[{"exclude": "yes"}, {"exclude": "no"}],
)

select_site = alt.selection_single(
    fields=["nt_site"], on="mouseover", empty="none",
)

mutation_freqs_chart = (
    alt.Chart(mutation_freqs)
    .encode(
        x="rank",
        y="freq",
        strokeWidth=alt.condition(select_site, alt.value(2), alt.value(0)),
        color=alt.Color("exclude", scale=alt.Scale(domain=["yes", "no"])),
        shape=alt.Shape("synonymous"),
        size=alt.condition(select_site, alt.value(50), alt.value(25)),
        tooltip=["nt_site", "nt_mutation", "count", "freq"],
    )
    .mark_point(filled=True, stroke="black")
    .properties(width=200, height=100)
    .facet("clade", columns=4)
    .add_selection(select_exclude, select_site)
    .transform_filter(select_exclude)
)

mutation_freqs_chart

Tally mutation type counts among **only synonymous** mutations for each clade and subset, also removing any mutations specified for exclusion:

In [5]:
mut_type_counts = (
    mutation_counts
    .query("synonymous")
    .query("not exclude")
    .groupby(["clade", "subset", "mut_type"], as_index=False)
    .aggregate({"count": "sum"})
)

Now also repeat these mutation type counts tally, but any mutations in the top 10 most frequent observed mutation for any clade, not doing any subsetting (just taking subset "all"):

In [6]:
exclude_top_n = 10  # exclude mutations in this top rank for any clade

mut_type_counts_exclude_top = (
    mutation_counts
    .query("synonymous")
    .query("not exclude")
    .query("subset == 'all'")
    .assign(
        clade_rank=lambda x: x.groupby("clade")["count"].rank(ascending=False, method="min"),
        highest_rank=lambda x: x.groupby("nt_mutation")["clade_rank"].transform("min"),
    )
    .query("highest_rank > @exclude_top_n")
    .groupby(["clade", "mut_type"], as_index=False)
    .aggregate({"count": "sum"})
)

Plot total mutation counts for each clade and subset on a log scale.
Also draw a line at our minimum cutoff: we only keep subsets above this cutoff:

In [7]:
clade_counts = (
    mut_type_counts
    .groupby(["clade", "subset"], as_index=False)
    .aggregate({"count": "sum"})
)

clade_counts_chart = (
    alt.Chart(clade_counts)
    .encode(
        x="clade",
        y=alt.Y("count", title="total mutations", scale=alt.Scale(type="log")),
        tooltip=["clade", "subset", "count"],
        color="subset",
    )
    .mark_circle(size=50, opacity=0.7)
    .properties(width=alt.Step(18), height=175)
)

# draw cutoff line
cutoff = (
    alt.Chart(pd.DataFrame({"y": [synonymous_spectra_min_counts]}))
    .encode(y="y")
    .mark_rule(strokeDash=[2, 2])
)

(clade_counts_chart + cutoff).configure_axis(grid=False)

For genome partitioning, we subdivide the genome into halves based on the first and last site with an observed mutation:

In [8]:
n_partitions = 2

min_site = mutation_counts["nt_site"].min()
max_site = mutation_counts["nt_site"].max() + 1
partition_bounds = numpy.linspace(min_site, max_site, n_partitions + 1)

def assign_partition(r):
    """Assign nucleotide mutation to its partition."""
    for i in range(1, n_partitions + 1):
        if partition_bounds[i - 1] <= r < partition_bounds[i]:
            return f"partition {i}"

mutation_counts = (
    mutation_counts
    .assign(partition=lambda x: x["nt_site"].map(assign_partition))
)

Get PCA of mutation spectrum, using only filtered synonymous mutation counts for non-excluded mutations for clades/subsets/partitions with adequate counts.

We do the PCA on three different ways of partitioning the data:

 1. Just looking at the "all" subset for each clade across entire genome.
 2. Looking at all subsets for each clade across entire genome.
 3. Looking at the "all" subset along thirds of the genome.
 
In the plots below, you can mouseover the points for details and click on clades in legends (shift click for multiple clades) to highlight just points for the selected clade(s).
You can also use the scroll bar to only show points with at least the indicated number of total synonymous mutation counts (after filtering):

In [18]:
for title, subsets, partition, exclude_top in [
    ("all samples, whole genome", ["all"], False, False),
    ("all samples, whole genome, +/- top mutations", ["all"], False, False),
    ("by region, whole genome", subset_order, False, False),
    ("all samples, partitioned genome", ["all"], True, False),
]:
    
    filtered_mutation_counts = (
        mutation_counts
        .query("synonymous")
        .query("not exclude")
        .query("subset in @subsets")
    )
    
    if partition:
        filtered_mutation_counts = pd.concat(
            [
                filtered_mutation_counts.assign(partition="all"),
                filtered_mutation_counts,
            ]
        )
    else:
        filtered_mutation_counts = filtered_mutation_counts.assign(partition="all")
        
    mut_type_counts = (
        filtered_mutation_counts
        .groupby(["clade", "subset", "partition", "mut_type"], as_index=False)
        .aggregate({"count": "sum"})
    )
    
    if exclude_top:
        assert all(mut_type_counts["partition"] == "all")
        assert all(mut_type_counts["subset"] == "all")
        mut_type_counts = pd.concat(
            [
                mut_type_counts,
                mut_type_counts_exclude_top.assign(
                    partition="all", subset="all",
                ),
            ]
        )
   
    mut_type_freqs = (
        mut_type_counts
        .assign(
            total_count=lambda x: (
                x.groupby(["clade", "subset", "partition"])["count"].transform("sum")
            ),
            freq=lambda x: x["count"] / x["total_count"],
        )
        .query("total_count >= @synonymous_spectra_min_counts")
        .pivot_table(
            index=["clade", "subset", "partition", "total_count"],
            values="freq",
            columns="mut_type",
            fill_value=0,
        )
    )
    
    pca = sklearn.decomposition.PCA(n_components=2)
    pca_coords = pca.fit_transform(mut_type_freqs.values)
    assert len(pca_coords) == len(mut_type_freqs)

    mut_type_freqs_pca_coords = (
        mut_type_freqs
        .reset_index()
        .assign(
            principal_component_1=pca_coords[:, 0],
            principal_component_2=pca_coords[:, 1],
            log10_total_count=lambda x: numpy.log(x["total_count"]) / numpy.log(10),
        )
    )
    
    # percent variance explained by each component
    pca_var = 100 * pca.explained_variance_ratio_
    
    total_count_selection = alt.selection_single(
        fields=["log10_total_count"],
        init={"log10_total_count": 4},
        bind=alt.binding_range(
            name="minimum log10 total counts",
            min=int(mut_type_freqs_pca_coords["log10_total_count"].min()),
            max=mut_type_freqs_pca_coords["log10_total_count"].max(),
        )
    )
    
    clade_selection = alt.selection_multi(fields=["clade"], bind="legend")

    tooltip = ["clade", "total_count"]
    
    plot_size = 300  # scaled by component variance explained
    
    pca_chart = (
        alt.Chart(mut_type_freqs_pca_coords)
        .encode(
            y=alt.Y(
                "principal_component_1",
                title=f"PC1 ({pca_var[0]:.0f}% variance)",
                scale=alt.Scale(nice=False, padding=10),
                axis=alt.Axis(labels=False, ticks=False),
            ),
            x=alt.X(
                "principal_component_2",
                title=f"PC2 ({pca_var[1]:.0f}% variance)",
                scale=alt.Scale(nice=False, padding=10),
                axis=alt.Axis(labels=False, ticks=False),
            ),
            color=alt.Color("clade", scale=alt.Scale(scheme="viridis")),
            strokeWidth=alt.condition(clade_selection, alt.value(1.5), alt.value(0)),
            opacity=alt.condition(clade_selection, alt.value(0.9), alt.value(0.45)),
            size=alt.condition(clade_selection, alt.value(65), alt.value(45)),
        )
        .mark_point(filled=True, stroke="black")
        .add_selection(total_count_selection, clade_selection)
        .transform_filter(
            total_count_selection.log10_total_count <= alt.datum.log10_total_count
        )
        .configure_axis(grid=False)
        .configure_legend(columns=2)
        .properties(
            height=plot_size, width=plot_size * pca_var[1] / pca_var[0],
            title=title,
        )
    )
    
    if len(subsets) > 1:
        pca_chart = pca_chart.encode(shape=alt.Shape("subset", sort=subset_order))
        tooltip.append("subset")
        
    if partition:
        pca_chart = pca_chart.encode(shape="partition")
        tooltip.append("partition")
                       
    pca_chart = pca_chart.encode(tooltip=tooltip)

    display(pca_chart)
    print("\n\n")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mut_type,AtoC,AtoG,AtoT,CtoA,CtoG,CtoT,GtoA,GtoC,GtoT,TtoA,TtoC,TtoG
clade,subset,partition,total_count,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
20A,all,all,39420,0.008447,0.083257,0.013724,0.010198,0.003298,0.556215,0.053729,0.003577,0.067225,0.016514,0.173922,0.009893
20B,all,all,31773,0.006169,0.085135,0.013219,0.010669,0.002612,0.562805,0.052088,0.003084,0.070185,0.015863,0.167249,0.010921
20C,all,all,21585,0.006625,0.090341,0.012787,0.011582,0.001251,0.553811,0.057494,0.003011,0.063702,0.018022,0.171137,0.010239
20E,all,all,24019,0.007286,0.087348,0.013073,0.010408,0.004954,0.578042,0.047379,0.001998,0.054623,0.017445,0.167659,0.009784
20G,all,all,32854,0.00624,0.083278,0.014427,0.010745,0.001065,0.575577,0.052992,0.002131,0.057314,0.017532,0.168655,0.010044







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mut_type,AtoC,AtoG,AtoT,CtoA,CtoG,CtoT,GtoA,GtoC,GtoT,TtoA,TtoC,TtoG
clade,subset,partition,total_count,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
20A,all,all,39420,0.008447,0.083257,0.013724,0.010198,0.003298,0.556215,0.053729,0.003577,0.067225,0.016514,0.173922,0.009893
20B,all,all,31773,0.006169,0.085135,0.013219,0.010669,0.002612,0.562805,0.052088,0.003084,0.070185,0.015863,0.167249,0.010921
20C,all,all,21585,0.006625,0.090341,0.012787,0.011582,0.001251,0.553811,0.057494,0.003011,0.063702,0.018022,0.171137,0.010239
20E,all,all,24019,0.007286,0.087348,0.013073,0.010408,0.004954,0.578042,0.047379,0.001998,0.054623,0.017445,0.167659,0.009784
20G,all,all,32854,0.00624,0.083278,0.014427,0.010745,0.001065,0.575577,0.052992,0.002131,0.057314,0.017532,0.168655,0.010044







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mut_type,AtoC,AtoG,AtoT,CtoA,CtoG,CtoT,GtoA,GtoC,GtoT,TtoA,TtoC,TtoG
clade,subset,partition,total_count,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
20A,USA,all,19412,0.006697,0.082526,0.013806,0.009942,0.000979,0.569184,0.054708,0.003039,0.067226,0.016536,0.165568,0.009788
20A,all,all,39420,0.008447,0.083257,0.013724,0.010198,0.003298,0.556215,0.053729,0.003577,0.067225,0.016514,0.173922,0.009893
20B,England,all,10018,0.006788,0.085446,0.015772,0.011879,0.001298,0.558894,0.049611,0.003893,0.080755,0.014873,0.162408,0.008385
20B,USA,all,13884,0.005834,0.082829,0.01174,0.01138,0.000864,0.566551,0.055892,0.002593,0.065759,0.016062,0.168611,0.011884
20B,all,all,31773,0.006169,0.085135,0.013219,0.010669,0.002612,0.562805,0.052088,0.003084,0.070185,0.015863,0.167249,0.010921







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mut_type,AtoC,AtoG,AtoT,CtoA,CtoG,CtoT,GtoA,GtoC,GtoT,TtoA,TtoC,TtoG
clade,subset,partition,total_count,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
20A,all,all,39420,0.008447,0.083257,0.013724,0.010198,0.003298,0.556215,0.053729,0.003577,0.067225,0.016514,0.173922,0.009893
20A,all,partition 1,19122,0.005857,0.086654,0.009832,0.009361,0.000889,0.581425,0.054701,0.001621,0.058101,0.016944,0.16379,0.010825
20A,all,partition 2,20298,0.010888,0.080057,0.017391,0.010986,0.005567,0.532466,0.052813,0.005419,0.07582,0.01611,0.183466,0.009016
20B,all,all,31773,0.006169,0.085135,0.013219,0.010669,0.002612,0.562805,0.052088,0.003084,0.070185,0.015863,0.167249,0.010921
20B,all,partition 1,15510,0.005545,0.08717,0.009607,0.009865,0.000645,0.591941,0.054997,0.001161,0.054997,0.017408,0.155899,0.010767







Compute statistical significance of differences between clades.
We just do this on "all" sequences for a clade, not partitioning the genomes:

In [9]:
all_mut_type_counts = (
    mut_type_counts.query("subset == 'all'")
    .drop(columns="subset")
    .assign(total_count=lambda x: x.groupby("clade")["count"].transform("sum"))
    .query("total_count >= @synonymous_spectra_min_counts")
    .drop(columns="total_count")
)

wide_all_mut_type_counts = all_mut_type_counts.pivot_table(
    index="mut_type",
    columns="clade",
    values="count",
    fill_value=0,
)

Now run chi2 test.
Also, Bonferroni correct the P-values (this is conservative, but is fine as these P-values are so tiny):

In [10]:
min_p = 1e-20  # plot P-values less than this as this

records = []
for clade1, clade2 in itertools.combinations(wide_all_mut_type_counts.columns, 2):
    chi2, p, dof, _ = scipy.stats.chi2_contingency(
        wide_all_mut_type_counts[[clade1, clade2]]
    )
    records.append((clade1, clade2, p, chi2))
    
chi2_stats = (
    pd.DataFrame(records, columns=["clade_1", "clade_2", "p", "chi2"])
    .assign(
        p=lambda x: x["p"].clip(lower=min_p),
        bonferroni_p=lambda x: (x["p"] * len(x)).clip(upper=1),
    )
)

Plot the Bonferroni corrected P-values.
Note since counts are very large, many comparisons will be highly significant:

In [11]:
p_chart = (
    alt.Chart(chi2_stats)
    .encode(
        x=alt.X("clade_1", title=None),
        y=alt.Y("clade_2", title=None),
        fill=alt.Fill(
            "bonferroni_p",
            title="Bonferroni corrected P-value",
            scale=alt.Scale(type="log", scheme="yelloworangered", reverse=True),
            legend=alt.Legend(orient="top"),
        ),
        tooltip=[
            "clade_1",
            "clade_2",
            alt.Tooltip("p", format=".2g"),
            alt.Tooltip("bonferroni_p", format=".2g"),
            alt.Tooltip("chi2", format=".2g"),
        ],
    )
    .mark_rect(stroke="black")
    .properties(width=alt.Step(14), height=alt.Step(14))
)

p_chart