# Visualize functional score distributions for the selections used in the final analysis

In [None]:
# Imports
import os
import math
import numpy
import yaml
import scipy.stats
import matplotlib.colors
import altair as alt
import pandas as pd
import plotnine as p9
import dms_variants.codonvarianttable

# Create color palette
def color_gradient_hex(start, end, n):
    """Return ``n`` hex colors running from ``start`` to ``end``.

    Builds a two-color ``LinearSegmentedColormap`` with ``N=n`` entries and
    converts each of its ``n`` RGBA entries to a hex string.  Adapted from
    the color helper in ``polyclonal``.
    """
    colormap = matplotlib.colors.LinearSegmentedColormap.from_list(
        name="_", colors=[start, end], N=n
    )
    # integer inputs index the colormap's lookup table directly
    rgba_rows = colormap(list(range(n)))
    return list(map(matplotlib.colors.rgb2hex, rgba_rows))

# Diverging fill palette: 20 colors fading orange -> white, followed by
# 20 colors fading white -> blue (40 hex strings in total)
orangeblue = [
    *color_gradient_hex("#E69F00", "white", n=20),
    *color_gradient_hex("white", "#0072B2", n=20),
]

# Remove Altair's cap on the number of data rows embedded in a chart
_ = alt.data_transformers.disable_max_rows()

In [None]:
# this cell is tagged as `parameters` for papermill parameterization

# Library A selection names, e.g. "LibA-220823-293T-1"
libA_1 = None
libA_2 = None
libA_3 = None
libA_4 = None

# Library B selection names, e.g. "LibB-220823-293T-1"
libB_1 = None
libB_2 = None
libB_3 = None
libB_4 = None

# Directory holding `<selection>_count_summary.csv` files
summary_dir = None
# Directory holding `<selection>_func_scores.csv` files
score_dir = None

# Directory for the HTML chart (created if missing) and the HTML file path
html_dir = None
html_output = None

In [None]:
# # Uncomment for running interactive
# libA_1 = "LibA-220823-293T-1"
# libA_2 = "LibA-220823-293T-2"
# libA_3 = "LibA-220907-293T-1"
# libA_4 = "LibA-220907-293T-2"

# libB_1 = "LibB-220823-293T-1"
# libB_2 = "LibB-220823-293T-2"
# libB_3 = "LibB-220907-293T-1"
# libB_4 = "LibB-220907-293T-2"

# summary_dir = "../results/func_scores/"
# score_dir = "../results/func_scores/"

# html_dir = "../results/averaged_func_scores_ridgeplot/"
# html_output = "../results/averaged_func_scores_ridgeplot/averaged_func_scores_ridgeplot.html"

In [None]:
# Selection names: four per library, library A first then library B
selections = [
    libA_1,
    libA_2,
    libA_3,
    libA_4,
    libB_1,
    libB_2,
    libB_3,
    libB_4,
]

# Read and concat all count summary files.
# os.path.join works whether or not the directory ends with a separator.
count_summaries = pd.concat(
    [
        pd.read_csv(os.path.join(summary_dir, f"{s}_count_summary.csv"))
        for s in selections
    ],
    ignore_index=True,
)

# Read and concat all func scores files, tagging each row with its selection
func_scores = pd.concat(
    [
        pd.read_csv(os.path.join(score_dir, f"{s}_func_scores.csv")).assign(
            selection=s
        )
        for s in selections
    ],
    ignore_index=True,
)


def _make_dropdown_selector(column):
    """Return an Altair point selection bound to a dropdown over ``column``.

    The dropdown options are ``None`` (labelled "all") followed by the sorted
    distinct values of ``count_summaries[column]``.
    """
    values = sorted(count_summaries[column].unique())
    return alt.selection_point(
        fields=[column],
        bind=alt.binding_select(
            options=[None] + values,
            labels=["all"] + values,
            name=column,
        ),
    )


# Create selection filters for plots (created in the same order as before so
# any auto-generated parameter names are unchanged)
selectors = [
    _make_dropdown_selector(sel)
    for sel in ["library", "pre_selection_date", "post_selection_date"]
]

# Library-only filter, used by the per-library plots
reduced_selectors = [_make_dropdown_selector(sel) for sel in ["library"]]

## Distributions of functional scores
Plot the functional score distributions for each functional selection (4 per library) used to calculate the averages.

These are plotted as ridgeplots.

In [None]:
# Classify variants with dms_variants; this adds the `variant_class` column
# (e.g. "synonymous", "1 nonsynonymous", "stop") that the plots group by
func_scores = dms_variants.codonvarianttable.CodonVariantTable.classifyVariants(
    func_scores
)


def ridgeplot(df):
    """Build a ridgeline chart of functional scores, one facet per selection.

    Rows whose ``variant_class`` is "deletion" are dropped.  Within each
    (selection, variant class) group a Gaussian KDE of ``func_score`` is
    evaluated on 50 evenly spaced points spanning the overall (post-filter)
    score range, and each ridge is filled by the group's mean score.  Each
    ridge is annotated with library and pre/post selection dates taken from
    the module-level ``count_summaries``.  The chart is filtered by the
    module-level ``selectors`` dropdowns.

    Args:
        df: Functional scores with ``selection``, ``variant_class`` and
            ``func_score`` columns.

    Returns:
        The configured ``alt.Chart``.
    """
    canonical_order = [
        "wildtype",
        "synonymous",
        "1 nonsynonymous",
        ">1 nonsynonymous",
        "stop",
    ]
    observed = set(df["variant_class"])
    # a ridge's baseline y is its index here, so "wildtype" is drawn on top
    variant_classes = [c for c in reversed(canonical_order) if c in observed]

    # deletion variants are not plotted
    df = df[df["variant_class"] != "deletion"]
    assert set(df["variant_class"]) == set(variant_classes)

    # one shared grid of 50 evaluation points across the full score range
    grid = numpy.linspace(df["func_score"].min(), df["func_score"].max(), num=50)

    kde_frames = []
    for (selection, variant_class), group in df.groupby(
        ["selection", "variant_class"]
    ):
        group_scores = group["func_score"]
        kde_frames.append(
            pd.DataFrame(
                {
                    "selection": selection,
                    "variant_class": variant_class,
                    "func_score": grid,
                    "count": scipy.stats.gaussian_kde(group_scores)(grid),
                    "mean_func_score": group_scores.mean(),
                    "number of variants": len(group),
                }
            )
        )
    selection_metadata = count_summaries[
        ["selection", "library", "pre_selection_date", "post_selection_date"]
    ]
    smoothed_dist = pd.concat(kde_frames).merge(
        selection_metadata, on="selection", validate="many_to_one"
    )

    # stack ridges: the tallest KDE peak rises 1 / facet_overlap rows above
    # its baseline (maximal facet overlap)
    facet_overlap = 0.7
    peak = smoothed_dist["count"].max()
    smoothed_dist["y"] = smoothed_dist["variant_class"].map(variant_classes.index)
    smoothed_dist["y2"] = (
        smoothed_dist["y"] + smoothed_dist["count"] / peak / facet_overlap
    )

    # label each ridge manually, see https://stackoverflow.com/a/64106056
    y_axis = alt.Axis(
        ticks=False,
        domain=False,
        values=[row + 0.5 for row in range(len(variant_classes))],
        labelExpr=f"{str(variant_classes)}[round(datum.value - 0.5)]",
    )
    fill = alt.Fill(
        "mean_func_score:Q",
        title="mean effect on cell entry",
        legend=alt.Legend(direction="horizontal"),
        scale=alt.Scale(
            domainMid=0,
            range=orangeblue,
            type="linear",
            domain=[-6, 1],
        ),
    )
    facet = alt.Facet(
        "selection",
        columns=4,
        title=None,
        header=alt.Header(labelFontWeight="bold", labelPadding=0),
    )
    tooltip = [
        "selection",
        "variant_class",
        alt.Tooltip(
            "mean_func_score", format=".2f", title="mean effect on cell entry"
        ),
        "number of variants",
        "library",
        "pre_selection_date",
        "post_selection_date",
    ]

    # ridgeline plot adapted from
    # https://altair-viz.github.io/gallery/ridgeline_plot.html,
    # but positioned with y / y2 instead of a row facet
    chart = (
        alt.Chart(smoothed_dist)
        .encode(
            x=alt.X(
                "func_score", title="effect on cell entry", scale=alt.Scale(nice=False)
            ),
            y=alt.Y("y", scale=alt.Scale(nice=False), title=None, axis=y_axis),
            y2=alt.Y2("y2"),
            fill=fill,
            facet=facet,
            tooltip=tooltip,
        )
        .mark_area(
            interpolate="monotone",
            smooth=True,
            fillOpacity=0.8,
            stroke="black",
            strokeWidth=0.5,
        )
        .configure_view(stroke=None)
        .configure_axis(grid=False)
        .properties(width=180, height=22 * len(variant_classes))
    )

    for selector in selectors:
        chart = chart.add_params(selector).transform_filter(selector)

    return chart.properties(autosize=alt.AutoSizeParams(resize=True))


# Display the per-selection ridgeline chart (last expression of the cell)
ridgeplot(df=func_scores)

Now plot the distribution of functional scores as a ridgeplot using `altair`, this time faceting **just by library** (pooling the variants from all selections within a library into a single distribution) and clipping scores on the low end at the median of stop codons and on the upper end at 2. The chart is also saved as HTML.

In [None]:
def library_average_ridgeplot(df):
    """Plot per-library ridgelines of functional scores and save them as HTML.

    Rows whose ``variant_class`` is "deletion" are dropped.  Scores from all
    selections of a library are pooled: for each (library, variant class)
    group a Gaussian KDE of ``func_score`` is evaluated on 50 evenly spaced
    points spanning the overall (post-filter) score range.  The chart is
    filtered by the module-level ``reduced_selectors`` (library dropdown),
    saved to ``html_output`` (``html_dir`` is created if missing), and
    returned.

    Args:
        df: Functional scores with ``library``, ``selection``,
            ``variant_class`` and ``func_score`` columns.

    Returns:
        The configured ``alt.Chart``.
    """
    present_classes = set(df["variant_class"])
    # a ridge's baseline y is its index here, so "wildtype" is drawn on top
    variant_classes = [
        c
        for c in reversed(
            ["wildtype", "synonymous", "1 nonsynonymous", ">1 nonsynonymous", "stop"]
        )
        if c in present_classes
    ]

    # remove deletions
    df = df.loc[df["variant_class"] != "deletion"]

    assert set(df["variant_class"]) == set(variant_classes)

    # get smoothed distribution of functional scores on a shared 50-point grid
    bins = numpy.linspace(df["func_score"].min(), df["func_score"].max(), num=50)
    smoothed_dist = pd.concat(
        [
            pd.DataFrame(
                {
                    "library": lib,
                    "variant_class": var,
                    "func_score": bins,
                    "count": scipy.stats.gaussian_kde(grp["func_score"])(bins),
                    "mean_func_score": grp["func_score"].mean(),
                    "number of variants": len(grp),
                    # count the selections that actually contribute to this
                    # ridge rather than assuming four per library
                    "number of selections": grp["selection"].nunique(),
                }
            )
            for (lib, var), grp in df.groupby(["library", "variant_class"])
        ]
    )

    # assign y / y2 for plotting: the tallest KDE peak rises
    # 1 / facet_overlap rows above its baseline (maximal facet overlap)
    facet_overlap = 0.35
    max_count = smoothed_dist["count"].max()
    smoothed_dist = smoothed_dist.assign(
        y=lambda x: x["variant_class"].map(lambda v: variant_classes.index(v)),
        y2=lambda x: x["y"] + x["count"] / max_count / facet_overlap,
    )

    # ridgeline plot, based on this but using y / y2 rather than row:
    # https://altair-viz.github.io/gallery/ridgeline_plot.html
    ridgeline_chart = (
        alt.Chart(smoothed_dist)
        .encode(
            x=alt.X(
                "func_score",
                title="effect on cell entry",
                scale=alt.Scale(nice=False),
                axis=alt.Axis(
                    labelFontSize=12,
                    titleFontSize=12,
                    values=[-4, -2, 0, 2],
                    labelFontWeight="normal",
                    titleFontWeight="normal",
                ),
            ),
            y=alt.Y(
                "y",
                scale=alt.Scale(nice=False),
                title=None,
                axis=alt.Axis(
                    ticks=False,
                    domain=False,
                    # set manual labels https://stackoverflow.com/a/64106056
                    values=[v + 0.5 for v in range(len(variant_classes))],
                    labelExpr=f"{str(variant_classes)}[round(datum.value - 0.5)]",
                    labelFontSize=12,
                    labelFontWeight="normal",
                ),
            ),
            y2=alt.Y2("y2"),
            fill=alt.Fill(
                "mean_func_score:Q",
                title=["mean", "effect on", "cell entry"],
                legend=alt.Legend(
                    direction="vertical",
                    titleFontSize=12,
                    labelFontSize=12,
                    gradientLength=150,
                    gradientStrokeColor="black",
                    gradientStrokeWidth=0.5,
                    labelFontWeight="normal",
                    titleFontWeight="normal",
                ),
                scale=alt.Scale(
                    domainMid=0,
                    range=orangeblue,
                    type="linear",
                    domain=[-4.2, 1],
                ),
            ),
            facet=alt.Facet(
                "library",
                columns=4,
                title=None,
                header=alt.Header(
                    labelFontWeight="bold",
                    labelPadding=0,
                    labelFontSize=16,
                ),
            ),
            tooltip=[
                "library",
                "variant_class",
                alt.Tooltip(
                    "mean_func_score", format=".2f", title="mean effect on cell entry"
                ),
                "number of selections",
                "number of variants",
            ],
        )
        .mark_area(
            interpolate="monotone",
            smooth=True,
            fillOpacity=0.8,
            stroke="black",
            strokeWidth=0.5,
        )
        .configure_view(stroke=None)
        .configure_axis(grid=False)
        .properties(
            width=240,
            height=80 * len(variant_classes),
            title="Functional score distributions by variant type",
        )
        .configure_title(fontSize=24)
    )

    for sel in reduced_selectors:
        ridgeline_chart = ridgeline_chart.add_params(sel).transform_filter(sel)

    ridgeline_chart = ridgeline_chart.properties(
        autosize=alt.AutoSizeParams(resize=True),
    )

    # Make the output dir, including missing parents; exist_ok avoids both a
    # check-then-create race and an error when it already exists
    os.makedirs(html_dir, exist_ok=True)

    print(f"Saving to {html_output}")
    ridgeline_chart.save(html_output)

    return ridgeline_chart

# Attach library / selection-date metadata from the count summaries to every
# functional score (each selection must appear once in count_summaries)
merged_df = func_scores.merge(count_summaries, on="selection", validate="many_to_one")

# Floor scores at the median stop-codon score and cap them at 2
lower_floor = merged_df.loc[merged_df["variant_class"] == "stop", "func_score"].median()
print(f"functional scores are clipped on the lower end at {lower_floor} (median of stop codons) and on the upper end at 2")
merged_df = merged_df.assign(
    func_score=merged_df["func_score"].clip(lower=lower_floor, upper=2)
)

# Plot (and save) the per-library chart
library_average_ridgeplot(merged_df)

Recreate the **same** plot as above, but with smaller fonts and dimensions for a manuscript figure (this version is not saved to HTML).

In [None]:
def library_average_ridgeplot(df):
    """Plot compact per-library ridgelines of functional scores for a figure.

    Same data processing as the HTML-saving version of this function.
    Rows whose ``variant_class`` is "deletion" are dropped, scores from all
    selections of a library are pooled, and for each (library, variant
    class) group a Gaussian KDE of ``func_score`` is evaluated on 50 evenly
    spaced points spanning the overall (post-filter) score range.  Fonts and
    facet sizes are reduced.  The chart is filtered by the module-level
    ``reduced_selectors`` (library dropdown); nothing is written to disk.

    Args:
        df: Functional scores with ``library``, ``selection``,
            ``variant_class`` and ``func_score`` columns.

    Returns:
        The configured ``alt.Chart``.
    """
    present_classes = set(df["variant_class"])
    # a ridge's baseline y is its index here, so "wildtype" is drawn on top
    variant_classes = [
        c
        for c in reversed(
            ["wildtype", "synonymous", "1 nonsynonymous", ">1 nonsynonymous", "stop"]
        )
        if c in present_classes
    ]

    # remove deletions
    df = df.loc[df["variant_class"] != "deletion"]

    assert set(df["variant_class"]) == set(variant_classes)

    # get smoothed distribution of functional scores on a shared 50-point grid
    bins = numpy.linspace(df["func_score"].min(), df["func_score"].max(), num=50)
    smoothed_dist = pd.concat(
        [
            pd.DataFrame(
                {
                    "library": lib,
                    "variant_class": var,
                    "func_score": bins,
                    "count": scipy.stats.gaussian_kde(grp["func_score"])(bins),
                    "mean_func_score": grp["func_score"].mean(),
                    "number of variants": len(grp),
                    # count the selections that actually contribute to this
                    # ridge rather than assuming four per library
                    "number of selections": grp["selection"].nunique(),
                }
            )
            for (lib, var), grp in df.groupby(["library", "variant_class"])
        ]
    )

    # assign y / y2 for plotting: the tallest KDE peak rises
    # 1 / facet_overlap rows above its baseline (maximal facet overlap)
    facet_overlap = 0.35
    max_count = smoothed_dist["count"].max()
    smoothed_dist = smoothed_dist.assign(
        y=lambda x: x["variant_class"].map(lambda v: variant_classes.index(v)),
        y2=lambda x: x["y"] + x["count"] / max_count / facet_overlap,
    )

    # ridgeline plot, based on this but using y / y2 rather than row:
    # https://altair-viz.github.io/gallery/ridgeline_plot.html
    ridgeline_chart = (
        alt.Chart(smoothed_dist)
        .encode(
            x=alt.X(
                "func_score",
                title="effect on cell entry",
                scale=alt.Scale(nice=False),
                axis=alt.Axis(
                    labelFontSize=8,
                    titleFontSize=8,
                    values=[-4, -2, 0, 2],
                    labelFontWeight="normal",
                    titleFontWeight="normal",
                ),
            ),
            y=alt.Y(
                "y",
                scale=alt.Scale(nice=False),
                title=None,
                axis=alt.Axis(
                    ticks=False,
                    domain=False,
                    # set manual labels https://stackoverflow.com/a/64106056
                    values=[v + 0.5 for v in range(len(variant_classes))],
                    labelExpr=f"{str(variant_classes)}[round(datum.value - 0.5)]",
                    labelFontSize=8,
                    labelFontWeight="normal",
                    titleFontWeight="normal",
                ),
            ),
            y2=alt.Y2("y2"),
            fill=alt.Fill(
                "mean_func_score:Q",
                title=["mean", "effect on", "cell entry"],
                legend=alt.Legend(
                    direction="vertical",
                    titleFontSize=8,
                    labelFontSize=8,
                    gradientLength=60,
                    gradientThickness=10,
                    gradientStrokeColor="black",
                    gradientStrokeWidth=0.5,
                    labelFontWeight="normal",
                    titleFontWeight="normal",
                ),
                scale=alt.Scale(
                    domainMid=0,
                    range=orangeblue,
                    type="linear",
                    domain=[-4.2, 1],
                ),
            ),
            facet=alt.Facet(
                "library",
                columns=4,
                title=None,
                header=alt.Header(
                    labelFontWeight="bold",
                    labelPadding=0,
                    labelFontSize=8,
                ),
                spacing=5,
            ),
            tooltip=[
                "library",
                "variant_class",
                alt.Tooltip(
                    "mean_func_score", format=".2f", title="mean effect on cell entry"
                ),
                "number of selections",
                "number of variants",
            ],
        )
        .mark_area(
            interpolate="monotone",
            smooth=True,
            fillOpacity=0.8,
            stroke="black",
            strokeWidth=0.5,
        )
        .configure_view(stroke=None)
        .configure_axis(grid=False)
        .properties(width=60, height=20 * len(variant_classes))
    )

    for sel in reduced_selectors:
        ridgeline_chart = ridgeline_chart.add_params(sel).transform_filter(sel)

    ridgeline_chart = ridgeline_chart.properties(
        autosize=alt.AutoSizeParams(resize=True),
    )

    return ridgeline_chart

# Rebuild the merged table: functional scores plus library / selection-date
# metadata (each selection must appear once in count_summaries)
merged_df = func_scores.merge(
    count_summaries, on="selection", validate="many_to_one"
)

# Floor at the median stop-codon score, cap at 2
stop_scores = merged_df.loc[merged_df["variant_class"] == "stop", "func_score"]
lower_floor = stop_scores.median()
del stop_scores
print(f"functional scores are clipped on the lower end at {lower_floor} (median of stop codons) and on the upper end at 2")
merged_df = merged_df.assign(
    func_score=merged_df["func_score"].clip(lower=lower_floor, upper=2)
)

# Display the compact manuscript version of the per-library chart
library_average_ridgeplot(merged_df)