# Compare Cell Entry Effects

In this notebook, we'll investigate whether the same mutation affects cell entry differently across three cell types (**293T-Mxra8**, **293T-TIM1**, and **C6/36**). 

Although *Mxra8* serves as a receptor and *TIM1* as an entry factor for CHIKV in humans, the mosquito receptor remains unknown. By identifying sites where mutations affect cell entry differently in mosquito cells (**C6/36**) than in human cells (**293T-Mxra8** and **293T-TIM1**), we may uncover sites involved in binding to the unknown mosquito receptor. 


In [1]:
import itertools

import altair as alt

import dmslogo.colorschemes

import numpy

import pandas as pd

import polyclonal.alphabets

import scipy.spatial.distance

# Remove Altair's limit of ~5000 rows per chart dataset (the mutation-effect tables plotted below
# have far more rows). There may be better ways to handle this; see
# https://altair-viz.github.io/user_guide/large_datasets.html
# The result is assigned to `_` so the call's return value is not echoed as the cell output.
_ = alt.data_transformers.disable_max_rows()

## Read the data

For this analysis, we'll need the effects of mutations on cell entry in each cell line and the annotations of each site in CHIKV E (E1, E2, E3, 6K).

We'll use cell entry data from selections in three cell lines:

1. **293T-Mxra8**: Human cells over-expressing the receptor Mxra8.
2. **293T-TIM1**: Human cells over-expressing the entry factor TIM1.
3. **C6/36**: Mosquito midgut cells.

These are pre-filtered (for QC metrics) values:

In [2]:
# Summary CSV with the per-mutation effects; the path is relative to this notebook.
mut_effects_csv = "../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"

print(f"Reading mutation effects from {mut_effects_csv=}")
# Wide table: one row per (site, mutant), one "entry in ... cells" column per cell line.
mut_effects = pd.read_csv(mut_effects_csv)

# Last expression of the cell: display the loaded table.
mut_effects

Reading mutation effects from mut_effects_csv='../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv'


Unnamed: 0,site,wildtype,mutant,entry in 293T_Mxra8 cells,entry in C636 cells,entry in 293T_TIM1 cells,binding to mouse Mxra8,sequential_site,region
0,-1(E3),M,I,-7.5410,-7.514,-7.50200,,1,E3
1,-1(E3),M,M,0.0000,0.000,0.00000,0.00000,1,E3
2,-1(E3),M,T,-7.5630,-7.541,-7.57600,,1,E3
3,1(6K),A,A,0.0000,0.000,0.00000,0.00000,489,6K
4,1(6K),A,C,0.1786,0.035,0.02934,-0.02603,489,6K
...,...,...,...,...,...,...,...,...,...
19001,99(E2),H,S,-7.2690,-7.132,-6.60600,,164,E2
19002,99(E2),H,T,-7.4930,-6.834,-6.99100,,164,E2
19003,99(E2),H,V,-7.5370,-7.494,-7.41200,,164,E2
19004,99(E2),H,W,-7.0080,-6.427,-5.61900,,164,E2


Get the data into tidy format:

In [3]:
# cells and their names in input file
# Display name of each cell line -> label used for it in the input file's column names.
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}

# Input column holding the entry effects for a cell line -> display name of that cell line.
col_to_cell = {f"entry in {label} cells": cell for (cell, label) in cells.items()}

# Fail early (showing both sides) if the input file lacks any expected entry column.
assert set(col_to_cell).issubset(mut_effects.columns), f"{col_to_cell=}, {mut_effects.columns=}"

# Wide -> long: one row per (site, mutant, cell) with the entry effect in `effect`.
mut_effects_tidy = (
    mut_effects.rename(columns=col_to_cell)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
        # Materialize the dict view: `melt` documents list-like `value_vars`, not dict views.
        value_vars=list(col_to_cell.values()),
        var_name="cell",
        value_name="effect",
    )
    .sort_values("sequential_site")
)

mut_effects_tidy

Unnamed: 0,site,sequential_site,wildtype,mutant,region,cell,effect
0,-1(E3),1,M,I,E3,293T-Mxra8,-7.5410
19006,-1(E3),1,M,I,E3,C6/36,-7.5140
19007,-1(E3),1,M,M,E3,C6/36,0.0000
19008,-1(E3),1,M,T,E3,C6/36,-7.5410
38012,-1(E3),1,M,I,E3,293T-TIM1,-7.5020
...,...,...,...,...,...,...,...
15614,439(E1),988,H,E,E1,293T-Mxra8,-0.6632
15615,439(E1),988,H,F,E1,293T-Mxra8,-1.0950
34618,439(E1),988,H,C,E1,C6/36,-0.3651
34619,439(E1),988,H,D,E1,C6/36,-0.3749


## Scatter plots of cell entry for each cell

How does the same mutation affect entry in each cell line? We'll plot the effect of each mutation between pairs of cell lines to determine if there are global differences.

In [4]:
def plot_mutation_level_comparison(
    data,
    condition,
    value,
    groupby=['site', 'mutant', 'wildtype', 'sequential_site'],
    color=None,
    label_suffix="",
    init_floor_value=-6,
):
    """
    Make an Altair scatter plot comparing mutant-level values for each condition.

    One scatter panel is drawn for every pair of conditions found in
    ``data[condition]``; the panels are concatenated horizontally.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot
    condition : str
        The column containing the condition labels (i.e. TIM1, MXRA8, C636)
    value : str
        The column containing the values to compare between conditions
    groupby : list of str
        The columns to group the data on (i.e. ['site', 'mutant', 'wildtype']).
        Must contain 'site' and 'mutant'. The passed list is never modified.
    color : str
        The column to color the points and add an interactive legend for
    label_suffix : str
        Label suffixed to x- and y-axis labels.
    init_floor_value : float or None
        Initial value for floor slider for values.

    Returns
    -------
    alt.Chart
        The Altair chart object

    Raises
    ------
    ValueError
        If `groupby` lacks 'site' or 'mutant', or if any of the `condition`,
        `value`, `groupby` or `color` columns is missing from `data`.
    """

    # Work on a private copy: appending the color column below must not leak into
    # the shared default list (or a caller's list) and accumulate across calls.
    groupby = list(groupby)

    if 'mutant' not in groupby or 'site' not in groupby:
        raise ValueError("groupby must contain 'mutant' and 'site'")

    missing_cols = [col for col in [condition, value] + groupby if col not in data.columns]
    if missing_cols:
        raise ValueError(f"Columns are missing from the data: {missing_cols}")

    if color is not None:
        if color not in data.columns:
            raise ValueError(f"Color column '{color}' not found in data")
        # Keep the color column through the pivot; avoid a duplicate index level
        # if the caller already listed it in `groupby`.
        if color not in groupby:
            groupby.append(color)

    conditions = data[condition].unique()

    # pivot the data: one row per `groupby` key, one column per condition
    data_wide = (
        data
        .pivot_table(index=groupby, columns=condition, values=value)
        .reset_index()
    )

    tooltips = []
    for col in groupby:
        tooltips.append(alt.Tooltip(f'{col}:N'))
    for col in conditions:
        tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    # Click-and-drag interval selection, used to filter the plotted points
    brush = alt.selection_interval()

    # Mouseover selection used to emphasize the hovered mutation
    mut_selection = alt.selection_point(on="mouseover", fields=groupby, empty=False)

    # Use the pandas reductions (which skip NaN) rather than the Python builtins,
    # whose result on data containing NaN depends on element order.
    value_min = data[value].min()
    value_max = data[value].max()

    min_value_slider = alt.param(
        name="min_value_slider",
        bind=alt.binding_range(
            min=value_min,
            max=value_max,
            name="floor values at this number",
        ),
        value=(
            max(init_floor_value, value_min)
            if init_floor_value is not None
            else value_min
        ),
    )

    base = (
        alt.Chart(data_wide)
        .add_params(mut_selection, brush, min_value_slider)
        .transform_filter(brush)
    )

    scatters = []
    for condition_a, condition_b in itertools.combinations(conditions, 2):
        # Base data for the scatter plot: keep mutations measured in both
        # conditions and floor each value at the slider setting
        scatter = base.transform_filter(
            f'isValid(datum["{condition_a}"]) && isValid(datum["{condition_b}"])'
        ).transform_calculate(
            condition_a_floored=f'max(datum["{condition_a}"], min_value_slider)',
            condition_b_floored=f'max(datum["{condition_b}"], min_value_slider)',
        ).encode(
            x=alt.X(
                "condition_a_floored:Q",
                title=condition_a + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
            y=alt.Y(
                "condition_b_floored:Q",
                title=condition_b + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
        ).properties(
            title=alt.TitleParams(f'{condition_a} vs {condition_b}', fontSize=16),
            width=250,
            height=250
        )
        # Background points to show the full range of data when brushing
        background = scatter.mark_point(
            filled=True,
            size=25,
            color='lightgray',
            opacity=0.3,
        )
        # Foreground points have tooltips and respond to brushing (and legend selection)
        if color is not None:
            selection = alt.selection_point(fields=[color], bind='legend')
            foreground = scatter.mark_point(
                filled=True,
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                color=alt.Color(color, type='nominal').scale(domain=data[color].unique()),
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
                tooltip=tooltips,
            ).add_params(
                selection
            ).transform_filter(selection)
        else:
            foreground = scatter.mark_point(
                filled=True,
                color='steelblue',
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                tooltip=tooltips,
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            )

        scatters.append((background + foreground))

    chart = alt.hconcat(*scatters).configure_axis(grid=False).configure_legend(
        titleFontSize=14, labelFontSize=14
    )

    return chart

In [5]:
# Pairwise scatter plots of per-mutation entry effects between the cell lines,
# with points colored by the `region` column and values initially floored at -6.
compare_mutation_effects = plot_mutation_level_comparison(
    mut_effects_tidy,
    "cell",
    "effect", 
    color="region",
    label_suffix=" cell entry",
    init_floor_value=-6,
)

# Last expression of the cell: render the chart.
compare_mutation_effects

- *Mouseover* on points to see a tooltip with information about that mutation.
- *Hold Click and Drag* over points to show only those mutations.
- *Click* on conditions in the legend to show only that condition (`region`).
- *Use the slider* to floor values at some minimum plot value.
- *Double Click* on the plot or legend to reset the plot.

Points with color show the active selection and gray points show total distribution of the data.

## Identify sites where mutations have different effects in each cell

### Compute site differences between conditions

We use three different site-level metrics for the differences between conditions:
 - **mean difference**: The mean difference in effect on cell entry for all non-wildtype amino acids at each site in *cell_1* minus *cell_2*.
 - **Jensen-Shannon divergence**: A "probability" is assigned to each amino acid at each site as proportional to `exp(effect)`, and then the Jensen-Shannon divergence is computed for the probabilities for *cell_1* versus *cell_2*.
 - **difference in constraint**: A "probability" is assigned to each amino acid as proportional to `exp(effect)`, and then the number of effective amino acids at each site is computed for each cell, and we report the number for *cell_1* minus *cell_2*.

In [6]:
# first get color to use for each amino-acid in scatter plot
# this also defines list of amino acids to keep
# Two-column frame (`mutant`, `color`) built from dmslogo's functional-group color scheme.
aa_color_df = (
    pd.Series(dmslogo.colorschemes.AA_FUNCTIONAL_GROUP)
    .rename_axis("mutant")
    .rename("color")
    .reset_index()
)
# Amino-acid alphabet, ordered by polyclonal's biochemical ordering.
aas = polyclonal.alphabets.biochem_order_aas(polyclonal.alphabets.AAS)
# The color scheme and the alphabet must cover exactly the same characters.
assert set(aa_color_df["mutant"]) == set(aas)

# get mutation level data, just for amino acids
# (long -> wide: one row per site/mutant, one column of effects per cell line)
assert set(cells) == set(mut_effects_tidy["cell"])
mut_data = (
    mut_effects_tidy
    .query("mutant in @aas")
    .pivot_table(
        index=["site", "sequential_site", "wildtype", "mutant", "region"],
        columns="cell",
        values="effect",
    )
    .sort_values("sequential_site")
    .reset_index()
)
# Every wildtype identity must itself be one of the retained amino acids.
assert set(mut_data["wildtype"]).issubset(aas)

# get site difference data
def get_site_diffs(df):
    """Summarize how mutation effects at one site differ between two cells.

    The columns of `df` are read by position: column 0 is a boolean flag that
    is True for the wildtype amino acid, and columns 1 and 2 hold the effects
    in the first and second cell.

    Returns a `pd.Series` with the entries "mean difference",
    "Jensen-Shannon divergence" and "difference in constraint" (each computed
    as first cell versus / minus second cell). When no amino acid has a value
    in both cells, the latter two entries are 0.
    """
    wt_flag, effects_1, effects_2 = (df.iloc[:, i] for i in range(3))

    # mean of (first - second) over the non-wildtype rows
    mean_diff = (effects_1 - effects_2)[~wt_flag].mean()

    # only amino acids with a value in both cells enter the probability-based metrics
    in_both = effects_1.notnull() & effects_2.notnull()
    probs_1 = numpy.exp(effects_1[in_both])
    probs_2 = numpy.exp(effects_2[in_both])
    assert len(probs_1) == len(probs_2)

    if len(probs_1) == 0:
        jsd = 0
        n_eff_diff = 0
    else:
        # normalize exp(effect) to probabilities over the shared amino acids
        probs_1 = probs_1 / probs_1.sum()
        probs_2 = probs_2 / probs_2.sum()
        # scipy returns the Jensen-Shannon *distance*; squaring gives the divergence
        jsd = scipy.spatial.distance.jensenshannon(probs_1, probs_2)**2
        # number of effective amino acids: len(aas) raised to the entropy in base len(aas)
        n_eff_1 = len(aas)**(-probs_1 * numpy.log(probs_1) / numpy.log(len(aas))).sum()
        n_eff_2 = len(aas)**(-probs_2 * numpy.log(probs_2) / numpy.log(len(aas))).sum()
        n_eff_diff = n_eff_1 - n_eff_2

    return pd.Series(
        {
            "mean difference": mean_diff,
            "Jensen-Shannon divergence": jsd,
            "difference in constraint": n_eff_diff,
        }
    )
    

# Names of the site-level metrics returned by `get_site_diffs`; the first one
# listed is used as the default metric by `plot_site_comparison`.
site_diff_metrics = [
    "difference in constraint", "mean difference", "Jensen-Shannon divergence"
]
site_diffs = []
# One frame of per-site metrics for every unordered pair of cell lines.
for cell_1, cell_2 in itertools.combinations(cells, 2):
    site_diffs.append(
        mut_data
        .assign(is_wildtype=lambda x: x["mutant"] == x["wildtype"])
        .groupby(["site", "sequential_site", "region"])
        # column order matters: `get_site_diffs` reads these three columns by position
        [["is_wildtype", cell_1, cell_2]]
        .apply(get_site_diffs)
        .assign(cell_1=cell_1, cell_2=cell_2)
        .sort_values("sequential_site")
        .reset_index()
    )
# Stack the per-pair frames; `cell_1` / `cell_2` record which pair each row belongs to.
site_diffs = pd.concat(site_diffs, ignore_index=True)
assert set(site_diff_metrics).issubset(site_diffs.columns)

### Plot sites with large differences

In [7]:
# Preview the two inputs to the site-comparison plot: the per-mutation (wide)
# table and the per-site difference metrics.
display(mut_data.head())
display(site_diffs.head())

cell,site,sequential_site,wildtype,mutant,region,293T-Mxra8,293T-TIM1,C6/36
0,-1(E3),1,M,I,E3,-7.541,-7.502,-7.514
1,-1(E3),1,M,M,E3,0.0,0.0,0.0
2,-1(E3),1,M,T,E3,-7.563,-7.576,-7.541
3,1(E3),2,S,D,E3,0.1852,0.1291,0.477
4,1(E3),2,S,E,E3,0.3038,0.08297,-0.0481


Unnamed: 0,site,sequential_site,region,mean difference,Jensen-Shannon divergence,difference in constraint,cell_1,cell_2
0,-1(E3),1,E3,-0.0245,8.062812e-08,-0.000198,293T-Mxra8,C6/36
1,1(E3),2,E3,0.118259,0.0112694,0.854814,293T-Mxra8,C6/36
2,2(E3),3,E3,-0.064232,0.02247468,0.370147,293T-Mxra8,C6/36
3,3(E3),4,E3,-0.068184,0.005355085,0.956139,293T-Mxra8,C6/36
4,4(E3),5,E3,-0.153826,0.00610687,-0.343112,293T-Mxra8,C6/36


In [20]:
# NOTE(review): `plot_site_comparison` is defined in the cell *below* this one, so on a
# fresh kernel (Restart & Run All) this call runs before the definition exists — the
# definition cell should be moved above this cell.
# The final positional argument (-6) is `init_floor_effect`, the initial floor for displayed effects.
plot_site_comparison(
    mut_data, site_diffs, cells, site_diff_metrics, aas, aa_color_df, -6
)

In [21]:
def plot_site_comparison(
    mut_data,
    site_diffs,
    cells,
    site_diff_metrics,
    aas,
    aa_color_df,
    init_floor_effect,
):
    """Plot (site-level) difference of entry effects between cells w mutation zooms.

    The chart has three linked parts: a line/point plot of the selected
    site-level metric for the selected pair of cells, a zoom bar underneath it
    for restricting the sites shown, and a scatter plot of the per-mutation
    effects at the clicked site in the two selected cells.

    Parameters
    ----------
    mut_data : pd.DataFrame
        Mutation-level data with columns `site`, `sequential_site`, `region`,
        `wildtype`, `mutant`, and one column of effects per entry in `cells`.
    site_diffs : pd.DataFrame
        Site-level data with columns `site`, `sequential_site`, `region`,
        `cell_1`, `cell_2`, and one column per entry in `site_diff_metrics`.
    cells : iterable of str
        Cell names; each must be a column of `mut_data`. Their order sets the
        order of the dropdown options.
    site_diff_metrics : list of str
        Columns of `site_diffs` selectable as the site-level metric; the first
        is the initial selection.
    aas : list of str
        Currently unused by this function; kept so existing calls keep working.
    aa_color_df : pd.DataFrame
        Lookup table with columns `mutant` and `color` giving the color of each
        amino-acid letter in the scatter plot.
    init_floor_effect : float or None
        Initial value of the slider that floors displayed mutation effects;
        if None the slider starts at the minimum effect in `mut_data`.

    Returns
    -------
    alt.HConcatChart
        The assembled Altair chart.
    """

    # some params
    site_chart_width = 600

    assert set(mut_data["site"]) == set(site_diffs["site"])
    assert set(site_diff_metrics).issubset(site_diffs.columns)

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )

    # zoom bar
    zoom_bar = (
        alt.Chart(mut_data[["site", "sequential_site", "region"]])
        .mark_rect()
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title="click and drag to zoom on sites",
                axis=alt.Axis(ticks=False, labels=False, titleFontSize=14, titleFontWeight="normal"),
            ),
            alt.Color("region", scale=alt.Scale(scheme="set3"), legend=None),
            tooltip=["site", "sequential_site", "region"],
        )
        .properties(width=site_chart_width, height=12)
        .add_params(zoom_selection)
    )

    # line plot
    metric_selection = alt.selection_point(
        fields=["metric"],
        name="metric_selection",
        value=site_diff_metrics[0],
        bind=alt.binding_select(
            options=site_diff_metrics,
            name="metric for site differences between cells",
        ),
    )

    # dropdown options: the cells that actually occur in each role, in `cells` order
    cell_1_options = [c for c in cells if c in set(site_diffs["cell_1"])]
    cell_1_selection = alt.param(
        name="cell_1",
        value=cell_1_options[0],
        bind=alt.binding_select(
            options=cell_1_options,
            name="comparator cell line",
        )
    )

    cell_2_options = [c for c in cells if c in set(site_diffs["cell_2"])]
    cell_2_selection = alt.param(
        name="cell_2",
        value=cell_2_options[0],
        bind=alt.binding_select(
            options=cell_2_options,
            name="reference cell line",
        )
    )

    # site w biggest effect (largest absolute value of the default metric for
    # the default pair of cells) is the initially selected site
    default_site = (
        site_diffs[
            (site_diffs["cell_1"] == cell_1_options[0])
            & (site_diffs["cell_2"] == cell_2_options[0])
        ]
        .set_index("site")
        [site_diff_metrics[0]]
        .abs()
        .sort_values(ascending=False)
        .index[0]
    )

    site_selection = alt.selection_point(
        fields=["site"], empty=False, value=default_site, on="click"
    )

    site_base = (
        alt.Chart(site_diffs)
        .transform_filter(zoom_selection)
        .transform_filter(
            (alt.datum["cell_1"] == cell_1_selection)
            & (alt.datum["cell_2"] == cell_2_selection)
        )
        # wide -> long over the metric columns so one dropdown can pick the metric
        .transform_fold(
            site_diff_metrics,
            ["metric", "difference"],
        )
        .transform_filter(metric_selection)
        .encode(
            alt.X(
                "site:N",
                sort=alt.SortField("sequential_site"),
                title=None,
                axis=alt.Axis(labelOverlap="greedy", ticks=False),
            ),
            alt.Y(
                "difference:Q",
                title="difference at site",
                scale=alt.Scale(nice=False, padding=7),
            ),
            tooltip=[
                "site", "sequential_site", "region", alt.Tooltip("difference:Q", format=".2f")
            ],
        )
    )

    site_lines = site_base.mark_line(color="gray", strokeWidth=1, opacity=0.7)

    site_points = site_base.mark_circle(filled=True, opacity=1).encode(
        color=alt.condition(site_selection, alt.value("black"), alt.value("gray")),
        size=alt.condition(site_selection, alt.value(120), alt.value(50)),
        opacity=alt.condition(site_selection, alt.value(1), alt.value(0.7)),
    )

    # Dynamic title for chart plot
    site_title = alt.TitleParams(
        alt.expr(
            f'"difference between mutation effects in " + {cell_1_selection.name} + " versus " + {cell_2_selection.name} + " cells"'
        ),
        anchor="middle",
        fontSize=16,
    )

    site_chart = (
        (site_lines + site_points)
        .properties(width=site_chart_width, height=180, title=site_title)
        .add_params(metric_selection, site_selection, cell_1_selection, cell_2_selection)
    )

    # amino-acid scatter plot for a single site
    min_effect = mut_data[list(cells)].min().min()
    max_effect = mut_data[list(cells)].max().max()
    min_effect_slider = alt.param(
        name="min_effect_slider",
        bind=alt.binding_range(
            min=min_effect, max=max_effect, name="floor display mutation effect at",
        ),
        value=max(init_floor_effect, min_effect) if init_floor_effect is not None else min_effect,
    )

    mut_base = alt.Chart(mut_data).add_params(min_effect_slider)

    mutant_selection = alt.selection_point(
        fields=["mutant"], on="mouseover", empty=False
    )

    mut_scatter = (
        mut_base
        .transform_filter(site_selection)
        .transform_lookup(
            lookup='mutant',
            from_=alt.LookupData(data=aa_color_df, key='mutant', fields=['color']),
        )
        # x / y are the effects in the currently selected cells; floor only valid values
        .transform_calculate(
            x=f"datum[{cell_1_selection.name}]",
            y=f"datum[{cell_2_selection.name}]",
            x_floored=f'isValid(datum.x) ? max(datum.x, {min_effect_slider.name}) : datum.x',
            y_floored=f'isValid(datum.y) ? max(datum.y, {min_effect_slider.name}) : datum.y',
        )
        .encode(
            alt.X("x_floored:Q", title="comparator cell line"),
            alt.Y("y_floored:Q", title="reference cell line"),
            alt.Text("mutant:N"),
            alt.Color("color:N", scale=None),
            size=alt.condition(mutant_selection, alt.value(20), alt.value(16)),
            strokeWidth=alt.condition(mutant_selection, alt.value(1), alt.value(0)),
            fillOpacity=alt.condition(mutant_selection, alt.value(1), alt.value(0.75)),
            tooltip=(
                ["mutant", "wildtype"] + [alt.Tooltip(c, format=".2f") for c in cells]
            )
        )
        .mark_text(stroke="black", strokeOpacity=1, fontWeight=600)
        .add_params(cell_1_selection, cell_2_selection, mutant_selection)
        .properties(
            title=alt.TitleParams(
                alt.expr(f'"mutation effects at site " + {site_selection.name}.site')
            ),
            width=180,
            height=180,
        )
    )

    return (
        alt.hconcat(alt.vconcat(site_chart & zoom_bar), mut_scatter)
    ).configure_axis(grid=False)

In [None]:
def plot_site_level_comparison(
    data,
    default_comparison
):
    """
    Plot the aggregated (site-level) difference of cell entry effects between conditions.

    The chart has three linked parts: a line/point plot of the per-site
    difference in a summary statistic between two selectable conditions, a
    zoom bar colored by region underneath it, and a scatter plot of the
    per-mutation effects at the clicked site.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot. The code reads the columns `effect`,
        `cell`, `site`, `sequential_site`, `wildtype`, `mutant` and `region`.
    default_comparison : list
        Default conditions to compare; element 0 is the initial baseline and
        element 1 the initial comparison.

    Returns
    -------
    alt.HConcatChart
        The assembled Altair chart.
        
    """
    # Top-level variables for the chart
    metric = "effect"
    condition_col = "cell"
    default_summary = "mean"
    conditions = data[condition_col].unique()
    # Shared x/y domain of the mutant scatter: data range stretched by 20%
    mutant_domain = [1.2 * data[metric].min(), 1.2 * data[metric].max()]

    # Canonically color the amino acids by functional group
    # https://jbloomlab.github.io/dmslogo/dmslogo.colorschemes.html
    AA_FUNCTIONAL_GROUP = dmslogo.colorschemes.AA_FUNCTIONAL_GROUP
    aa_color_df = pd.DataFrame(
        {'mutant': AA_FUNCTIONAL_GROUP.keys(), 'color': AA_FUNCTIONAL_GROUP.values()}
    )

    # Selectable summary statistics, as Vega aggregate shorthands (e.g. "mean(effect)")
    stats = {
        stat: f'{stat}({metric})' 
        for stat in ["mean", "median", "min", "max", "sum"]
    }
    # Fixed parameters interpolated into the transform expressions below:
    # x is the lower bound used by 'clamp', k the constant used by 'arcsinh'
    x = -4
    k = 3 
    
    # Calculate a sensible default site to display
    # (the site with the largest absolute difference in the default summary
    # statistic between the two default conditions)
    default_site = (
        data 
            .query('mutant not in ["*", "-"]')
            .groupby([condition_col, 'site', 'wildtype'])
            .agg({metric: [default_summary]})
            .loc[:, metric]
            .loc[:, default_summary]
            .reset_index()
            .pivot(index=['site', 'wildtype'], columns=condition_col, values=default_summary)
            .loc[:, default_comparison]
            .assign(difference=lambda x: x[default_comparison[0]] - x[default_comparison[1]])
            .abs()
            # attribute access to the `difference` column created by the assign above
            .difference
            # index is (site, wildtype); [0] keeps the site
            .idxmax()[0]
    )

    # Top-level configuration for the chart
    primary_color = '#f8b196'
    active_color = '#c06c84'
    width = 1000
    height = 300
    # hard-coded: the mutant scatter below is always made pan/zoom-able
    interactive = True

    # Define the tooltips for the chart
    mutant_tooltips = [
        "wildtype",
        alt.Tooltip('x:Q', title="baseline", format=".2f"),
        alt.Tooltip('y:Q', title="comparison", format=".2f")
    ]
    summary_tooltips = ["site", "sequential_site"]
    for col in conditions:
        summary_tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    # The data is shared across all parts of the final charts
    base = alt.Chart(
        data.query('mutant not in ["*", "-"]')
    )

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )
    zoom = (
        base
            .mark_rect()
            # one rectangle per region, spanning its first to last sequential site
            .transform_aggregate(
                start='min(sequential_site)',
                stop='max(sequential_site)',
                groupby=['region']
            )
            .encode(
                x=alt.X('start:Q', title='click and drag to zoom'),
                x2=alt.X2('stop:Q'),
                color=alt.Color(
                    'region',
                    type="nominal",
                    scale=alt.Scale(scheme='set3'),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip('region', title='Region:'),
                ]
            )
            .properties(
                width=width,
                height=15
            )
            .add_params(
                zoom_selection
            )
    )

    # Dynamic selections to configure the chart 
    site_selection = alt.selection_point(fields=["site"], empty=False, value=default_site, on="click")
    stat_dropdown = alt.binding_select(options=list(stats.keys()), name='Summary Statistic: ')
    stat_selection = alt.param(value=default_summary, bind=stat_dropdown)
    baseline_dropdown = alt.binding_select(options=conditions, name='Select Baseline: ')
    baseline_selection = alt.param(value=default_comparison[0], bind=baseline_dropdown)
    comparison_dropdown = alt.binding_select(options=conditions, name='Select Comparison: ')
    comparison_selection = alt.param(value=default_comparison[1], bind=comparison_dropdown)

    # Dynamic titles for each chart
    summary_title = (
        alt.Title(
            alt.expr(
                f'"Difference in " + {stat_selection.name} + " effect " + "(" + {baseline_selection.name} + " - " + {comparison_selection.name} + ")"')
            )
    )
    mutant_title = (
        alt.Title(
            alt.expr(f'"Site: " + {site_selection.name}.site')
        )
    )

    # Transform the data for the summary plot
    transform_dropdown = alt.binding_select(options=['clamp', 'arcsinh', 'none'], name=f'Transform (x={x}, k={k}): ')
    transform = alt.param(value='none', bind=transform_dropdown)
    # One calculated field per transform name; the dropdown value picks which field is used
    # NOTE(review): in 'clamp' the upper bound is `max(datum.effect)` evaluated per datum —
    # confirm this is the intended way to leave values unbounded above.
    transforms = {
        'clamp': f'clamp(datum.effect, {x}, max(datum.effect))',
        'arcsinh': f'datum.effect < 0 ? log(datum.effect / {k} + sqrt(pow(datum.effect / {k}, 2) + 1)) * {k} : datum.effect',
        'none': 'datum.effect'
    }
    
    # Transform, aggregate, pivot, and calculate the difference between conditions
    # NOTE: `transform` is rebound here from the dropdown param to the chart; the
    # right-hand side is evaluated first, so `transform.name` and `add_params(transform)`
    # below still refer to the param.
    transform = (
        base
            .transform_calculate(
                **transforms
            )
            # overwrite `effect` with the field named by the dropdown's current value
            .transform_calculate(
                effect=f'datum[{transform.name}]'
            )
            .add_params(
                transform
            )
    )
    aggregate = (
        transform
            # compute every statistic per (site, condition) ...
            .transform_aggregate(
                **stats,
                groupby=['sequential_site', condition_col]
            )
            # ... and keep the one named by the statistic dropdown as `summary`
            .transform_calculate(
                summary=f'datum[{stat_selection.name}]'
            )
            .add_params(
                stat_selection
            )
    )
    pivot = (
        aggregate
            # long -> wide: one field per condition holding its `summary`
            .transform_pivot(
                condition_col,
                groupby=['sequential_site'],
                value="summary"
            )
    )
    difference = (
        pivot
            .transform_calculate(
                difference=f'datum[{baseline_selection.name}] - datum[{comparison_selection.name}]'
            )
            .add_params(
                baseline_selection,
                comparison_selection
            ).encode(
                # the zoom-bar selection drives the x-axis domain
                x=alt.X('sequential_site:Q', title="Site").scale(domain=zoom_selection),
                y=alt.Y("difference:Q", title="Δ(Baseline, Comparison)"),
                # NOTE(review): the pivot above groups by `sequential_site` only, so the
                # `site` field listed in these tooltips (and used by `site_selection`)
                # may not be present at this point — confirm.
                tooltip=summary_tooltips
            )
    )
    line = (
        difference
            .mark_line(
                point=False,
                color=primary_color,
            )
    )
    point = (
        difference
            .mark_point(
                filled=True,
                opacity=1,
            )
            .encode(
               size=alt.condition(site_selection, alt.value(200), alt.value(50)),
               color=alt.condition(site_selection, alt.value(active_color), alt.value(primary_color)), 
            )
    )
    summary = (
        (line + point)
            .properties(
                title=summary_title,
                width=width,
                height=height
            )
            .add_params(
                site_selection
            )
    )

    # Mutant-level effects for each site
    residues = (
        base
            # long -> wide: one field of effects per condition for each mutation
            .transform_pivot(
                condition_col,
                groupby=['site', 'mutant', 'wildtype'],
                value="effect"
            )
            .transform_lookup(
                lookup='mutant',
                from_=alt.LookupData(
                    data=aa_color_df,
                    key='mutant',
                    fields=['color']
                ),
            )
            .mark_text(
                size=16
            )
            .encode(
                x=alt.X(
                    'x:Q',
                    scale=alt.Scale(domain=mutant_domain),
                    title="Baseline"
                ),
                y=alt.Y(
                    'y:Q',
                    scale=alt.Scale(domain=mutant_domain),
                    title="Comparison"
                ),
                text="mutant:N",
                color=alt.Color('color:N').scale(None),
                tooltip=mutant_tooltips,
            )
            # x / y are the effects in the currently selected baseline / comparison
            .transform_calculate(
                x=f'datum[{baseline_selection.name}]',
                y=f'datum[{comparison_selection.name}]',
            )
            .transform_filter( 
                site_selection
            )
            .add_params(
                baseline_selection,
                comparison_selection,
                site_selection
            )
    )
    if interactive:
        residues = residues.interactive()
        
    # Dashed diagonal (y = x) across the full mutant domain
    rule = (
        alt.Chart()
            .mark_rule(
                color=primary_color,
                strokeWidth=3,
                strokeCap="round",
                strokeDash=[8, 8],
            )
            .encode(
                x=alt.datum(
                    mutant_domain[1],
                    type="quantitative",
                    scale=alt.Scale(domain=mutant_domain)
                ),
                y=alt.datum(
                    mutant_domain[1], 
                    type="quantitative", 
                    scale=alt.Scale(domain=mutant_domain)
                ),
                x2=alt.datum(mutant_domain[0]),
                y2=alt.datum(mutant_domain[0]),
            )
    )
    mutants = (
        (rule + residues)
            # square panel: both dimensions use `height`
            .properties(
                width=height,
                height=height,
                title=mutant_title
            )
    )
    
    # Assemble the final chart
    chart = (
        (summary & zoom) | mutants
    ).configure_point(
        size=50
    ).configure_axis(
        labelFontSize=13,
        titleFontSize=16
    ).configure_legend(
        labelFontSize=13,
        titleFontSize=16
    )

    return chart

In [None]:
# Add padding (in pixels) around the embedded chart; this changes the renderer's
# embed options for every chart rendered after this cell.
alt.renderers.set_embed_options(
    padding={"left": 25, "right": 100, "bottom": 25, "top": 25}
)

# Default comparison: the first two cell lines in `cells` (insertion order).
chart = plot_site_level_comparison(
    data=mut_effects_tidy,
    default_comparison=list(cells)[: 2],
)

chart

In [None]:
def plot_transform(data):
    """
    Chart raw `effect` values against their transformed values.

    Interactive widgets choose the transform ('clamp' or 'arcsinh') and its
    parameter (`x` for clamp, `k` for arcsinh). Dashed reference lines mark
    x = 0, y = 0 and the diagonal spanning the range of the data.

    Parameters
    ----------
    data : pd.DataFrame
        Data with an `effect` column.

    Returns
    -------
    alt.LayerChart
        The layered Altair chart.
    """
    effects = data['effect']
    lowest, highest = effects.min(), effects.max()

    base = alt.Chart(data)

    # widgets (created in the order x, k, transform)
    x = alt.param(
        value=-4,
        bind=alt.binding_range(min=lowest, max=0, step=0.5, name='x: '),
    )
    k = alt.param(
        value=3,
        bind=alt.binding_range(min=0, max=10, step=0.1, name='k:'),
    )
    transform = alt.param(
        value='clamp',
        bind=alt.binding_select(options=['clamp', 'arcsinh'], name='Transform: '),
    )

    # Vega expressions for the two transforms, parameterized by the widgets
    clamp_expr = f'clamp(datum.effect, {x.name}, max(datum.effect))'
    arcsinh_expr = f'datum.effect < 0 ? log(datum.effect / {k.name} + sqrt(pow(datum.effect / {k.name}, 2) + 1)) * {k.name} : datum.effect'

    # raw vs. transformed points; the dropdown decides which expression applies
    points = (
        base
        .transform_calculate(
            transformed=f"{transform.name} == 'clamp' ? {clamp_expr} : {arcsinh_expr}"
        )
        .mark_point()
        .encode(
            x=alt.X('effect:Q', title="Raw"),
            y=alt.Y('transformed:Q', title="Transformed"),
            tooltip=['effect:Q', 'transformed:Q'],
        )
        .add_params(x, k, transform)
    )

    # diagonal from (lowest, lowest) to (highest, highest)
    diagonal = base.mark_rule(
        strokeWidth=1, strokeDash=[4, 4], color='steelblue'
    ).encode(
        x=alt.datum(lowest),
        y=alt.datum(lowest),
        x2=alt.datum(highest),
        y2=alt.datum(highest),
    )

    # dashed axes through zero
    zero_dash = dict(strokeWidth=1, strokeDash=[8, 8])
    vertical_zero = base.mark_rule(**zero_dash).encode(x=alt.datum(0))
    horizontal_zero = base.mark_rule(**zero_dash).encode(y=alt.datum(0))

    return vertical_zero + horizontal_zero + diagonal + points
    

In [None]:
# NOTE(review): this re-displays the `chart` variable assigned two cells above;
# `plot_transform` (defined in the previous cell) is not called anywhere in the visible notebook.
chart