# Visualize deep mutational scanning phenotypes

Get input variables from [papermill](https://papermill.readthedocs.io/) parameterization:

In [None]:
# this cell is tagged `parameters` and so is parameterized by `papermill`
analysis_filters = None  # dict: "frac_aligned" (minimum kept) and "frac_divergence" (maximum kept)
decimal_scale = None  # number of decimal places used when writing / showing floats
pheno_scatter_init_phenos = None  # dict with keys "x" and "y": initial scatter-plot phenotypes
site_numbering_schemes = None  # schemes; `input_tsv` must have a `mutations_{scheme}` column each
input_tsv = None  # TSV of per-strain data (name, filter statistics, mutations, phenotypes)
usher_tree = None  # tree file loaded with `bte.MATree`
filter_chart_html = None  # output HTML for the filtering histograms
pheno_scatter_chart_html = None  # output HTML for the phenotype scatter chart
haplotypes_tsv = None  # output TSV of the data grouped by haplotype
nextstrain_json = None  # output Nextstrain JSON of the tree of the filtered strains
title = None  # title of the scatter chart and of the Nextstrain JSON
description = None  # optional extra text for the chart subtitle / Nextstrain description

Import Python modules:

In [None]:
import json
import operator

import altair as alt

import bte

import pandas as pd

# Lift Altair's default cap on how many data rows may be embedded in a chart, since the
# charts below embed the full per-sequence / per-haplotype tables. Assigning to `_`
# keeps the returned object from being displayed as the cell output.
_ = alt.data_transformers.disable_max_rows()

## Read and filter the data
Read the DMS data and filter sequences on the fraction aligned and on the fractional divergence from the DMS strain (the filters are applied in that order):

In [None]:
dms_data = pd.read_csv(input_tsv, sep="\t")

# The filters are applied one after another: `dms_data` is narrowed inside the loop, so
# each later histogram only shows sequences that passed the earlier filter(s), and the
# `n_kept` of the last iteration is the number of sequences retained overall.
stat_hists = []
for stat, op in [
    ("frac_aligned", operator.ge),  # keep sequences with frac_aligned >= cutoff
    ("frac_divergence", operator.le),  # keep sequences with frac_divergence <= cutoff
]:
    cutoff = analysis_filters[stat]

    # one row per distinct (sequence, statistic) pair, flagged by whether it passes;
    # the comparison is vectorized, so it is always a boolean Series (even if empty)
    stat_filter_df = dms_data[["name", stat]].drop_duplicates()
    stat_filter_df = stat_filter_df.assign(retained=op(stat_filter_df[stat], cutoff))
    n = len(stat_filter_df)
    n_kept = int(stat_filter_df["retained"].sum())

    dms_data = dms_data[op(dms_data[stat], cutoff)]

    stat_hist = (
        alt.Chart(stat_filter_df)
        .encode(
            alt.X(stat, bin=alt.Bin(step=0.01), title=stat),
            alt.Y("count()", title="number of sequences"),
            alt.Color("retained"),
        )
        .mark_bar()
        .properties(
            width=220,
            height=150,
            title=alt.TitleParams(
                f"{n_kept} / {n} sequences meet {stat} cutoff of {cutoff}",
                fontSize=10,
                fontWeight="normal",
            ),
        )
    )
    stat_hists.append(stat_hist)

assert n_kept == dms_data["name"].nunique(), (n_kept, dms_data["name"].nunique())

filter_chart = alt.hconcat(*stat_hists).properties(
    title=alt.TitleParams(
        f"Filtering of sequences by identity to protein used in DMS (retained {n_kept})",
        anchor="middle",
    )
)

print(f"Saving to {filter_chart_html}")
filter_chart.save(filter_chart_html)

display(filter_chart)

## Make plots of phenotypes for each haplotype
We first group the strains by amino-acid haplotype (i.e., group together strains with the same mutations):

In [None]:
# `dms_data` has phenotypes by strain name; group all strains that share the same values
# in every other column (mutations and phenotypes) to create `dms_data_by_haplotype`
assert len(dms_data) == len(dms_data.drop_duplicates()), "duplicate rows in `dms_data`"
haplotype_cols = [c for c in dms_data.columns if c != "name"]
dms_data_by_haplotype = (
    dms_data
    # `dropna=False`: with the default, any strain with a missing value in a grouping
    # column (e.g. an empty mutations field, which `read_csv` parses as NaN, or a NaN
    # phenotype) would be silently dropped from the haplotype table
    .groupby(haplotype_cols, as_index=False, dropna=False)
    .aggregate(
        n_strains=pd.NamedAgg("name", "nunique"),
        strains=pd.NamedAgg("name", lambda s: ", ".join(s)),
    )
)
n_strains = dms_data["name"].nunique()
n_haplotypes = dms_data_by_haplotype["strains"].nunique()
# the grouping must account for every strain, and each haplotype must have its own strain list
n_grouped_strains = (
    dms_data_by_haplotype[["strains", "n_strains"]].drop_duplicates()["n_strains"].sum()
)
assert n_strains == n_grouped_strains, (n_strains, n_grouped_strains)
assert n_haplotypes == len(dms_data_by_haplotype), (n_haplotypes, len(dms_data_by_haplotype))
print(f"There are {n_strains=} corresponding to {n_haplotypes=}")

print(f"Writing the haplotype data to {haplotypes_tsv}")
dms_data_by_haplotype.to_csv(haplotypes_tsv, index=False, sep="\t", float_format=f"%.{decimal_scale}f")

Now make an interactive scatter plot of the haplotype phenotypes, where the size of each circle grows with the number of strains in that haplotype (up to an adjustable cap):

In [None]:
# Subtitle lines for the scatter chart (also reused later as the Nextstrain description).
# `frac_aligned` is a minimum and `frac_divergence` a maximum (see the filtering cell).
full_description = [
    (
        f"{n_strains} samples corresponding to {n_haplotypes} protein haplotypes "
        + f"with aligned fraction >= {analysis_filters['frac_aligned']} "
        + f"and divergence <= {analysis_filters['frac_divergence']} from DMS strain."
    ),
]
# only add the user-supplied description when one was given (it may be None)
if description:
    full_description.append(description)

# get the relevant columns
mutations_cols = [f"mutations_{scheme}" for scheme in site_numbering_schemes]
assert set(mutations_cols).issubset(dms_data_by_haplotype.columns), mutations_cols
# every column that is not a bookkeeping / mutation column is offered as a phenotype
# (note this includes `frac_divergence`)
pheno_cols = [
    c for c in dms_data_by_haplotype
    if c not in ["n_strains", "strains", "frac_aligned"] + mutations_cols
]

assert all(x in pheno_cols for x in pheno_scatter_init_phenos.values()), pheno_scatter_init_phenos
x_pheno_param = alt.param(
    value=pheno_scatter_init_phenos["x"],
    bind=alt.binding_select(
        options=pheno_cols,
        name="x-axis phenotype",
    )
)
y_pheno_param = alt.param(
    value=pheno_scatter_init_phenos["y"],
    bind=alt.binding_select(
        options=pheno_cols,
        name="y-axis phenotype",
    )
)

max_frac_divergence = alt.param(
    value=analysis_filters["frac_divergence"],
    bind=alt.binding_range(
        min=0,
        step=0.001,
        max=analysis_filters["frac_divergence"],
        name="max divergence from DMS protein",
    ),
)

scale_size_up_to = alt.param(
    value=25,
    bind=alt.binding_range(
        min=1,
        step=1,
        max=dms_data_by_haplotype["n_strains"].max(),
        name="scale point size by n_strains up to this size",
    ),
)

min_n_strains = alt.param(
    value=1,
    bind=alt.binding_range(
        min=1,
        max=20,
        step=1,
        name="only show points for haplotypes with at least this many strains",
    ),
)

haplotype_selection = alt.selection_point(on="mouseover", empty=False)

# `pheno_cols` already contains "frac_divergence"; de-duplicate (keeping order) so the
# tooltip does not list it twice
tooltip_value_cols = list(dict.fromkeys(["frac_divergence", "frac_aligned", *pheno_cols]))

haplotype_pheno_chart = (
    alt.Chart(dms_data_by_haplotype)
    .add_params(
        x_pheno_param,
        y_pheno_param,
        max_frac_divergence,
        scale_size_up_to,
        min_n_strains,
        haplotype_selection,
    )
    .transform_filter(alt.datum["frac_divergence"] <= max_frac_divergence)
    .transform_filter(alt.datum["n_strains"] >= min_n_strains)
    .transform_calculate(
        # index each row by the column name currently selected in the dropdowns
        x_pheno=f"datum[{x_pheno_param.name}]",
        y_pheno=f"datum[{y_pheno_param.name}]",
        # point size grows with n_strains but is capped by the slider value
        size=alt.expr.min(alt.datum["n_strains"], scale_size_up_to),
    )
    .encode(
        alt.X("x_pheno:Q", title="x-axis phenotype", scale=alt.Scale(nice=False, padding=10)),
        alt.Y("y_pheno:Q", title="y-axis phenotype", scale=alt.Scale(nice=False, padding=10)),
        alt.Size(
            "size:Q",
            legend=None,
            scale=alt.Scale(range=[20, 200]),
        ),
        stroke=alt.condition(haplotype_selection, alt.value("red"), alt.value("black")),
        strokeWidth=alt.condition(haplotype_selection, alt.value(2), alt.value(0.5)),
        tooltip=[
            "n_strains",
            *[alt.Tooltip(c, format=f".{decimal_scale}f") for c in tooltip_value_cols],
            *mutations_cols,
            "strains",
        ],
    )
    .mark_circle(strokeOpacity=1, fillOpacity=0.3)
    .properties(
        width=500,
        height=500,
        title=alt.TitleParams(
            title,
            subtitle=full_description,
        ),
    )
    .configure_axis(grid=False, labelFontSize=12, titleFontSize=14)
)

print(f"Saving to {pheno_scatter_chart_html}")
haplotype_pheno_chart.save(pheno_scatter_chart_html)

haplotype_pheno_chart

## Write Nextstrain JSON of tree of the filtered samples
Write JSON of the filtered samples:

In [None]:
# Load the tree and write it out as Nextstrain JSON, restricted to the strains that
# survived filtering and using the input TSV as the metadata file.
t = bte.MATree(usher_tree)
samples = dms_data["name"].tolist()
t.write_json(nextstrain_json, samples=samples, metafiles=[input_tsv])

Edit the JSON for better display: add the phenotypes as colorings, clear the filters, and set the title and description:

In [None]:
with open(nextstrain_json) as f:
    t_json = json.load(f)

# offer every phenotype as a continuous coloring, with the initial scatter-plot x-axis
# phenotype listed first
init_coloring = pheno_scatter_init_phenos["x"]
coloring_keys = [init_coloring] + [p for p in pheno_cols if p != init_coloring]
t_json["meta"]["colorings"] = [
    {"key": key, "title": key, "type": "continuous"} for key in coloring_keys
]

# clear the `filters` entry written by `write_json`
t_json["meta"]["filters"] = []

t_json["meta"]["title"] = title
# skip empty parts (e.g. a `description` left as None) so the join cannot raise TypeError
t_json["meta"]["description"] = " ".join(part for part in full_description if part)

with open(nextstrain_json, "w") as f:
    json.dump(t_json, f)