# Make some paper figures

The next cell is tagged `parameters` for papermill parameterization:

In [None]:
# tagged parameters for `papermill`
# `params` maps keys starting with "func_scores_" (e.g. "func_scores_293T-Mxra8-A-...")
# to paths of functional-score CSV files. The empty default lets this cell define the
# name when the notebook is run outside papermill; papermill injects the real values
# in a cell placed after this one.
params = {}

Python imports:

In [None]:
import os

import altair as alt

import dms_variants.codonvarianttable

import numpy

import pandas as pd

import scipy.stats

# Turn off Altair's maximum-rows check on data embedded in charts. Assigning the
# return value to `_` keeps it from being displayed as this cell's output.
_ = alt.data_transformers.disable_max_rows()

## Distribution of variant functional scores
We want the distribution of variant functional scores, similar to the one made by [this notebook](https://dms-vep.org/CHIKV_181-25_E_DMS/notebooks/analyze_func_scores.html), but including both the E3-E2 and 6K-E1 fragments and excluding deletions (since those are rare in our libraries).

First, read all the functional scores, ignoring those for deletions since they are rare and not reported in the paper:

In [None]:
def classify_selection(sel):
    """Convert a ``func_scores_...`` parameter key into a readable selection label.

    ``sel`` must contain exactly one known cell-line token and exactly one
    library token (``-A-`` or ``-B-``). The result is ``"<cell line> library <X>"``.

    Raises
    ------
    AssertionError
        If zero or several cell-line tokens, or zero or several library tokens,
        occur in ``sel``; the message is ``sel`` itself.
    """
    cell_line_names = {"293T-Mxra8": "293T-Mxra8", "293T-TIM1": "293T-TIM1", "C636": "C6/36"}
    library_names = {"-A-": "library A", "-B-": "library B"}

    matched_cells = [name for token, name in cell_line_names.items() if token in sel]
    assert len(matched_cells) == 1, sel

    matched_libraries = [name for token, name in library_names.items() if token in sel]
    assert len(matched_libraries) == 1, sel

    return f"{matched_cells[0]} {matched_libraries[0]}"


# Every papermill parameter whose key starts with "func_scores_" is a path to a CSV
# of per-variant functional scores; the key itself encodes the selection.
func_scores_files = {
    sel: f for (sel, f) in params.items() if sel.startswith("func_scores_")
}
if not func_scores_files:
    # pd.concat([]) would otherwise fail with the opaque "No objects to concatenate"
    raise ValueError("no `func_scores_*` entries found in `params`")

func_scores_df = (
    pd.concat(
        [
            pd.read_csv(f).assign(selection=sel)
            for (sel, f) in func_scores_files.items()
        ],
        # each CSV has its own 0..n-1 index; renumber so row labels stay unique
        ignore_index=True,
    )
    # replace the raw parameter key with a readable "<cell line> library <X>" label
    .assign(selection=lambda x: x["selection"].map(classify_selection))
    .pipe(dms_variants.codonvarianttable.CodonVariantTable.classifyVariants)
    # deletions are rare in these libraries and not reported
    .query("variant_class != 'deletion'")
)

# number of variants per selection and variant class (shown as cell output)
(
    func_scores_df
    .groupby(["selection", "variant_class"])
    .aggregate(n_variants=pd.NamedAgg("barcode", "count"))
)

Make the plot:

In [None]:
def _ordered_variant_classes(df):
    """Variant classes present in ``df``, in y-axis order for the ridgeline plot.

    Classes are taken from a fixed top-to-bottom list, then reversed. The first
    listed class ("wildtype") therefore gets the largest y position and is drawn on top.

    Raises
    ------
    AssertionError
        If ``df`` contains a variant class that is not in the fixed list.
    """
    top_to_bottom = [
        "wildtype",
        "synonymous",
        "1 nonsynonymous",
        ">1 nonsynonymous",
        "deletion",
        "stop",
    ]
    present = set(df["variant_class"])
    variant_classes = [c for c in top_to_bottom if c in present][::-1]
    assert present == set(variant_classes), (
        f"unexpected variant classes: {present - set(variant_classes)}"
    )
    return variant_classes


def _smoothed_distributions(df, n_bins):
    """Gaussian-KDE density of ``func_score`` for each (selection, variant_class).

    Every group is evaluated on the same ``n_bins`` evenly spaced points spanning
    the overall min-to-max functional score, so all ridges share one x grid.
    The density is stored in a column named ``count``. The per-group mean score
    and number of variants are also stored.

    Raises
    ------
    ValueError
        If a group has fewer than two distinct scores, so no Gaussian KDE can be fit.
    """
    bins = numpy.linspace(df["func_score"].min(), df["func_score"].max(), num=n_bins)
    frames = []
    for (sel, var), group in df.groupby(["selection", "variant_class"]):
        scores = group["func_score"]
        n_distinct = scores.nunique()
        if n_distinct < 2:
            raise ValueError(
                f"cannot estimate functional-score density for {sel!r}, {var!r}: "
                f"need at least two distinct scores, got {n_distinct}"
            )
        frames.append(
            pd.DataFrame(
                {
                    "selection": sel,
                    "variant_class": var,
                    "func_score": bins,
                    "count": scipy.stats.gaussian_kde(scores)(bins),
                    "mean_func_score": scores.mean(),
                    "number of variants": len(group),
                }
            )
        )
    return pd.concat(frames)


def ridgeplot(df, *, n_bins=50, facet_overlap=0.7):
    """Ridgeline chart of functional-score distributions, faceted by selection.

    Each variant class gets one ridge (a smoothed density area), and ridges are
    colored by the class's mean functional score.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-variant rows with ``selection``, ``variant_class`` and ``func_score``.
    n_bins : int, default 50
        Number of evenly spaced points at which each density is evaluated.
    facet_overlap : float, default 0.7
        The tallest density across all facets rises ``1 / facet_overlap`` rows
        above its baseline, so smaller values make ridges taller and overlap more.

    Returns
    -------
    altair.Chart
        The faceted ridgeline chart.
    """
    variant_classes = _ordered_variant_classes(df)
    smoothed_dist = _smoothed_distributions(df, n_bins)

    # Each class sits on an integer baseline ``y``; its area rises to ``y2``.
    # Densities are normalized by the single largest density across all facets.
    max_density = smoothed_dist["count"].max()
    smoothed_dist = smoothed_dist.assign(
        y=lambda x: x["variant_class"].map(variant_classes.index),
        y2=lambda x: x["y"] + x["count"] / max_density / facet_overlap,
    )

    # ridgeline plot, based on this but using y / y2 rather than row:
    # https://altair-viz.github.io/gallery/ridgeline_plot.html
    return (
        alt.Chart(smoothed_dist)
        .encode(
            x=alt.X(
                "func_score",
                title="functional score for cell entry",
                scale=alt.Scale(nice=False),
            ),
            y=alt.Y(
                "y",
                scale=alt.Scale(nice=False),
                title=None,
                axis=alt.Axis(
                    ticks=False,
                    domain=False,
                    # set manual labels https://stackoverflow.com/a/64106056
                    # (one label centered half a row above each ridge's baseline)
                    values=[v + 0.5 for v in range(len(variant_classes))],
                    labelExpr=f"{str(variant_classes)}[round(datum.value - 0.5)]",
                ),
            ),
            y2=alt.Y2("y2"),
            fill=alt.Fill(
                "mean_func_score:Q",
                title="mean functional score",
                legend=alt.Legend(direction="horizontal"),
                scale=alt.Scale(scheme="yellowgreenblue"),
            ),
            facet=alt.Facet(
                "selection",
                columns=2,
                title=None,
                header=alt.Header(
                    labelFontWeight="bold",
                    labelPadding=0,
                ),
            ),
            tooltip=[
                "selection",
                "variant_class",
                alt.Tooltip(
                    "mean_func_score", format=".2f", title="mean functional score"
                ),
                alt.Tooltip("number of variants", title="number of variants"),
            ],
        )
        .mark_area(
            interpolate="monotone",
            smooth=True,
            fillOpacity=0.8,
            stroke="lightgray",
            strokeWidth=0.5,
        )
        .configure_view(stroke=None)
        .configure_axis(grid=False)
        .properties(
            width=180,
            height=22 * len(variant_classes),
            autosize=alt.AutoSizeParams(resize=True),
        )
    )


# Build the ridgeline chart of functional scores. Ending the cell with the chart
# object renders it as the cell's output.
func_scores_chart = ridgeplot(df=func_scores_df)
func_scores_chart