# Plot substitution counts

Import Python modules:

In [None]:
import os

import altair as alt

import Bio.SeqIO

import pandas as pd

Get variables from `snakemake`:

In [None]:
# Input/output paths injected by the Snakemake rule that runs this notebook.
rule_input = snakemake.input
sub_counts_csv = rule_input.sub_counts_csv
ref_prots_fasta = rule_input.ref_prots

sub_count_plotsdir = snakemake.output.sub_count_plotsdir

Read substitution counts:

In [None]:
sub_counts = pd.read_csv(filepath_or_buffer=sub_counts_csv)

sub_counts  # last expression: rendered as the cell's output table

Get each reference protein's sites and the reference amino acid at each site:

In [None]:
def _site_table(record):
    """Return one row per residue of a parsed FASTA record.

    Columns are `site` (1-based position), `reference_aa` (the residue at
    that position) and `protein` (the record's id).
    """
    return pd.DataFrame(
        enumerate(str(record.seq), start=1),
        columns=["site", "reference_aa"],
    ).assign(protein=record.id)


# Stack the per-protein tables into one frame with a fresh 0..n-1 index.
ref_sites = pd.concat(
    map(_site_table, Bio.SeqIO.parse(ref_prots_fasta, "fasta")),
    ignore_index=True,
)

Merge the substitution counts onto the reference sites, giving each site its total count and a list of the substitutions observed there:

In [None]:
# Outer merge of the reference sites with the substitution counts.
# `validate="one_to_many"` checks that each (site, protein) occurs only once in
# the reference. `indicator=True` lets us find count rows with no matching
# reference site.
merged = ref_sites.merge(
    sub_counts,
    how="outer",
    validate="one_to_many",
    on=["site", "protein"],
    indicator=True,
)

# Count rows with no reference site have a NaN `reference_aa`. The groupby
# below drops NaN keys by default, so those rows would vanish without notice.
# Report them explicitly instead.
n_unmatched = int((merged["_merge"] == "right_only").sum())
if n_unmatched:
    print(
        f"Warning: ignoring {n_unmatched} substitution-count rows at sites "
        "not found in the reference proteins"
    )

df = (
    merged
    .drop(columns="_merge")
    .assign(
        # Reference sites with no observed substitution get a count of 0.
        count=lambda x: x["count"].fillna(0).astype(int),
        # The last character of the substitution string (e.g. "A23V" -> "V").
        sub_to=lambda x: x["substitution"].str[-1],
        # "V (12)"-style label for nonzero counts, "" otherwise (vectorized).
        substitution_to=lambda x: (
            (x["sub_to"].astype(str) + " (" + x["count"].astype(str) + ")")
            .where(x["count"] != 0, "")
        ),
    )
    # Sort first so the per-site joined labels list the largest counts first.
    .sort_values("count", ascending=False)
    .groupby(["protein", "site", "reference_aa"], as_index=False, sort=True)
    .aggregate(
        count=pd.NamedAgg("count", "sum"),
        substitution_to=pd.NamedAgg("substitution_to", lambda s: ", ".join(s.values)),
    )
)

df

Now make the plots:

In [None]:
def _protein_chart(prot, prot_df):
    """Build the interactive substitution-count chart for one protein.

    The result has two vertically stacked panels that share one x-interval
    brush. The top panel is a line plot of per-site counts on a symlog y axis,
    filtered to the brushed sites. The bottom panel is a thin gray "site zoom
    bar" of the protein's sites, used to select a range.
    """
    n_rows = len(prot_df)
    brush = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke="black", strokeWidth=2),
    )

    site_bar = (
        alt.Chart(prot_df[["site"]].drop_duplicates())
        .mark_rect(color="gray")
        .encode(x=alt.X("site:O"))
        .add_selection(brush)
        .properties(height=15, width=min(400, 3 * n_rows), title="site zoom bar")
    )

    count_line = (
        alt.Chart(prot_df)
        .encode(
            x=alt.X("site:O"),
            y=alt.Y("count:Q", scale=alt.Scale(type="symlog", constant=50)),
            tooltip=prot_df.columns.tolist(),
        )
        .mark_line(point=True, size=1)
        .properties(width=3 * n_rows, height=200, title=f"{prot} protein")
        .add_selection(brush)
        .transform_filter(brush)
    )

    stacked = alt.vconcat(count_line, site_bar)
    return stacked.configure_title(anchor="start").configure_axis(labelOverlap="parity")


# Per-site data can exceed Altair's default 5000-row embedding limit.
alt.data_transformers.disable_max_rows()

os.makedirs(sub_count_plotsdir, exist_ok=True)

for prot, prot_df in df.groupby("protein"):
    chart = _protein_chart(prot, prot_df)
    display(chart)

    outfile = os.path.join(sub_count_plotsdir, f"{prot}.html")
    print(f"Saving to {outfile}")
    chart.save(outfile)