# Compare Cell Entry Effects

In this notebook, we'll investigate whether the same mutation affects cell entry differently across three cell types (**293T-Mxra8**, **293T-TIM1**, and **C6/36**). 

Although *Mxra8* serves as a receptor and *TIM1* as an entry factor for CHIKV in humans, the mosquito receptor remains unknown. By identifying sites where mutations affect cell entry differently in mosquito cells (**C6/36**) than in human cells (**293T-Mxra8** and **293T-TIM1**), we may uncover sites involved in binding to the unknown mosquito receptor. 


In [None]:
import pandas as pd
import numpy as np
import altair as alt
import itertools

# Remove Altair's default limit of ~5000 rows per chart dataset.
# Alternatives for large datasets are described at
# https://altair-viz.github.io/user_guide/large_datasets.html
# The result is bound to `_`, presumably to keep it out of the cell output.
_ = alt.data_transformers.disable_max_rows()

## Read the data

For this analysis, we'll need the effects of mutations on cell entry in each cell line and the annotations of each site in CHIKV E (E1, E2, E3, 6K).

We'll use cell entry data from selections in three cell lines:

1. **293T-Mxra8**: Human cells over-expressing the receptor Mxra8.
2. **293T-TIM1**: Human cells over-expressing the entry factor TIM1.
3. **C6/36**: Mosquito midgut cells.

These are pre-filtered (for QC metrics) values:

In [None]:
# Relative path of the CSV summarizing mutation effects for each cell line
mut_effects_csv = "../results/summaries/entry_293T-Mxra8_C636_293T-TIM1_Mxra8-binding.csv"

# `{name=}` in the f-string prints both the variable name and its value
print(f"Reading mutation effects from {mut_effects_csv=}")
mut_effects = pd.read_csv(mut_effects_csv)

# Show the wide-format data frame as the cell output
mut_effects

Get the data into tidy format:

In [None]:
# Display name of each cell line -> the label used for it in the input file's
# column names ("entry in <label> cells")
cells = {"293T-Mxra8": "293T_Mxra8", "C6/36": "C636", "293T-TIM1": "293T_TIM1"}

# Input column name -> display name of the cell line
col_to_cell = {f"entry in {label} cells": cell for (cell, label) in cells.items()}

# Fail loudly if an expected column is absent. An explicit exception is raised
# instead of using `assert`, which is skipped when Python runs with -O.
missing_cols = sorted(set(col_to_cell) - set(mut_effects.columns))
if missing_cols:
    raise ValueError(
        f"missing columns {missing_cols}: {col_to_cell=}, {mut_effects.columns=}"
    )

# Wide -> long: one row per (mutation, cell line), with the cell line's display
# name in `cell` and the measured cell entry effect in `effect`
mut_effects_tidy = (
    mut_effects.rename(columns=col_to_cell)
    .melt(
        id_vars=["site", "sequential_site", "wildtype", "mutant", "region"],
        value_vars=list(col_to_cell.values()),
        var_name="cell",
        value_name="effect",
    )
    .sort_values("sequential_site")
)

# Show the tidy data frame as the cell output
mut_effects_tidy

## Scatter plots of cell entry for each cell

How does the same mutation affect entry in each cell line? We'll plot the effect of each mutation between pairs of cell lines to determine if there are global differences.

In [None]:
def plot_mutation_level_comparison(
    data,
    condition,
    value,
    groupby=None,
    color=None,
    label_suffix="",
    init_floor_value=-6,
):
    """
    Make an Altair scatter plot comparing mutant-level values for each condition.

    One scatter plot is drawn for every pair of conditions, and the plots are
    concatenated horizontally. All plots share an interval brush, a mouseover
    highlight, and a slider that floors the plotted values.

    Parameters
    ----------
    data : pd.DataFrame
        The long-form data to plot
    condition : str
        The column containing the condition labels (e.g. TIM1, MXRA8, C636)
    value : str
        The column containing the values to compare between conditions
    groupby : list of str or None
        The columns to group the data on; must contain 'site' and 'mutant'.
        Defaults to ['site', 'mutant', 'wildtype', 'sequential_site'].
        The list passed in is never modified.
    color : str
        The column to color the points and add an interactive legend for
    label_suffix : str
        Label suffixed to x- and y-axis labels.
    init_floor_value : float or None
        Initial value for floor slider for values.

    Returns
    -------
    alt.Chart
        The Altair chart object

    Raises
    ------
    ValueError
        If `groupby` lacks 'mutant' or 'site', or if any of the `condition`,
        `value`, `groupby`, or `color` columns are missing from `data`.
    """

    # Always work on a private copy: the color column is appended below, and
    # appending to a shared default (or to the caller's list) would leak that
    # column into later calls.
    if groupby is None:
        groupby = ['site', 'mutant', 'wildtype', 'sequential_site']
    else:
        groupby = list(groupby)

    if 'mutant' not in groupby or 'site' not in groupby:
        raise ValueError("groupby must contain 'mutant' and 'site'")

    missing_cols = [col for col in [condition, value] + groupby if col not in data.columns]
    if missing_cols:
        raise ValueError(f"Columns are missing from the data: {missing_cols}")

    if color is not None:
        if color not in data.columns:
            raise ValueError(f"Color column '{color}' not found in data")
        # The color column has to survive the pivot, so it joins the index
        if color not in groupby:
            groupby.append(color)

    conditions = data[condition].unique()

    # Series.min()/max() skip NaN, whereas the builtin min()/max() give
    # order-dependent (possibly NaN) results when missing values are present.
    value_min = float(data[value].min())
    value_max = float(data[value].max())

    # pivot the data: one row per group, one column per condition
    data_wide = (
        data
        .pivot_table(index=groupby, columns=condition, values=value)
        .reset_index()
    )

    tooltips = []
    for col in groupby:
        tooltips.append(alt.Tooltip(f'{col}:N'))
    for col in conditions:
        tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    brush = alt.selection_interval()

    mut_selection = alt.selection_point(on="mouseover", fields=groupby, empty=False)

    # Slider controlling the floor applied to plotted values; it starts at
    # `init_floor_value`, but never below the smallest value in the data.
    min_value_slider = alt.param(
        name="min_value_slider",
        bind=alt.binding_range(
            min=value_min,
            max=value_max,
            name="floor values at this number",
        ),
        value=(
            max(init_floor_value, value_min)
            if init_floor_value is not None
            else value_min
        ),
    )

    base = (
        alt.Chart(data_wide)
        .add_params(mut_selection, brush, min_value_slider)
        .transform_filter(brush)
    )

    scatters = []
    for condition_a, condition_b in itertools.combinations(conditions, 2):
        # Base data for the scatter plot: keep only mutations measured in both
        # conditions and floor their values at the slider value
        scatter = base.transform_filter(
            f'isValid(datum["{condition_a}"]) && isValid(datum["{condition_b}"])'
        ).transform_calculate(
            condition_a_floored=f'max(datum["{condition_a}"], min_value_slider)',
            condition_b_floored=f'max(datum["{condition_b}"], min_value_slider)',
        ).encode(
            x=alt.X(
                "condition_a_floored:Q",
                title=condition_a + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
            y=alt.Y(
                "condition_b_floored:Q",
                title=condition_b + label_suffix,
                scale=alt.Scale(padding=10, nice=False, zero=False),
                axis=alt.Axis(titleFontSize=14, labelFontSize=11, labelOverlap="greedy"),
            ),
        ).properties(
            title=alt.TitleParams(f'{condition_a} vs {condition_b}', fontSize=16),
            width=250,
            height=250
        )
        # Background points to show the full range of data when brushing
        background = scatter.mark_point(
            filled=True,
            size=25,
            color='lightgray',
            opacity=0.3,
        )
        # Foreground points have tooltips and respond to brushing (and legend selection)
        if color is not None:
            selection = alt.selection_point(fields=[color], bind='legend')
            foreground = scatter.mark_point(
                filled=True,
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                color=alt.Color(color, type='nominal').scale(domain=data[color].unique()),
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(80), alt.value(40)),
                tooltip=tooltips,
            ).add_params(
                selection
            ).transform_filter(selection)
        else:
            foreground = scatter.mark_point(
                filled=True,
                color='steelblue',
                fillOpacity=0.5,
                stroke="black",
                strokeOpacity=1,
            ).encode(
                tooltip=tooltips,
                strokeWidth=alt.condition(mut_selection, alt.value(3), alt.value(0)),
                size=alt.condition(mut_selection, alt.value(70), alt.value(35)),
            )

        scatters.append((background + foreground))

    chart = alt.hconcat(*scatters).configure_axis(grid=False).configure_legend(
        titleFontSize=14, labelFontSize=14
    )

    return chart

In [None]:
# Compare mutation effects between every pair of cell lines, coloring points
# by `region` and initially flooring plotted values at -6
compare_mutation_effects = plot_mutation_level_comparison(
    mut_effects_tidy,
    "cell",
    "effect", 
    color="region",
    label_suffix=" cell entry",
    init_floor_value=-6,
)

# Display the chart as the cell output
compare_mutation_effects

- *Mouseover* on points to see a tooltip with information about that mutation.
- *Hold Click and Drag* over points to show only those mutations.
- *Click* on conditions in the legend to show only that condition (`region`).
- *Use the slider* to floor values at some minimum plot value.
- *Double Click* on the plot or legend to reset the plot.

Points with color show the active selection, and gray points show the total distribution of the data.

## Compare cell entry for each site

While scanning all mutations for global patterns is useful, our primary focus is on *sites* involved in binding to the unidentified mosquito receptor. To identify these sites, we will compare the aggregated differences in mutation effects on cell entry at each site.

But how should we aggregate effects and calculate the difference to extract the most biologically meaningful information from the data?

### Aggregating cell entry effects

There are several ways to summarize mutation effects at each site. The mean (or median) is the most straightforward, but it's not a great approach if only a few mutations are expected to contribute to the largest phenotypic differences. You can choose how to summarize the data in the plot menu.

### Calculating the difference

In our assay, differences in cell entry closer to zero (wildtype) are more biologically meaningful than differences further from zero. This is especially true for negative values. 

Imagine three mutations—$A$, $B$, and $C$—where $A$ represents the wildtype ($0$), $B$ is a deleterious mutation with an effect of $-2$, and $C$ is a *more* deleterious mutation with an effect of $-4$. While the difference in effect size between these mutations is evenly spaced ($A \to B = B \to C = -2$), the **relative biological impact** of substituting $A$ with $B$ is greater than that of substituting $B$ with $C$.

I can think of two ways to emphasize **biologically** impactful differences:

1. **Transform** the data before taking the difference.
    - Apply a heuristic cutoff so that all values more negative than the cutoff are raised to the cutoff.
    - Apply a function to compress values with large absolute values or very negative values.

2. **Weight** the difference based on the midpoint of the comparison.
    - Apply a fractional weight pegged to the midpoint of the difference.

In [None]:
def plot_site_level_comparison(data):
    """
    Plot the aggregated (site-level) difference of cell entry effects between conditions.

    The chart has three linked parts: a site-level line/point plot of the
    difference in the summarized effect between a baseline and a comparison
    condition, a zoom bar colored by region beneath it, and a mutant-level
    scatter plot for the selected site beside it.

    Parameters
    ----------
    data : pd.DataFrame
        Long-form data with the columns 'effect', 'condition', 'site',
        'wildtype', 'mutant', 'region', 'times_seen', and 'n_selections'.
        The 'condition' column must contain the conditions "MXRA8" and
        "TIM1", which are used as the default baseline and comparison.

    Returns
    -------
    alt.Chart
        The assembled (concatenated) Altair chart object
    """
    # Top-level variables for the chart
    metric = "effect"
    effect_min, effect_max = data[metric].min(), data[metric].max()
    default_summary = "mean"
    default_comparison = ["MXRA8", "TIM1"]
    conditions = data['condition'].unique()
    # Pad the axis domain outward by 20% on both ends. Using the absolute value
    # keeps the padding pointing outward even when the minimum is positive or
    # the maximum is negative.
    domain_padding = 0.2
    mutant_domain = [
        effect_min - abs(effect_min) * domain_padding,
        effect_max + abs(effect_max) * domain_padding,
    ]
    # Column -> initial slider value; rows with smaller values are filtered out
    filters = {
        "times_seen": 2,
        "n_selections": 2
    }
    # Name of each summary statistic -> its Vega-Lite aggregate expression
    stats = {
        stat: f'{stat}({metric})' 
        for stat in ["mean", "median", "min", "max", "sum"]
    }
    clamp_x = -4     # lower bound used by the 'clamp' transform
    arcsinh_k = 3    # scaling factor used by the 'arcsinh' transform

    # Calculate a sensible default site to display: the site with the largest
    # absolute difference in the default summary between the default conditions
    default_site = (
        data 
            .query('mutant not in ["*", "-"]')
            .groupby(['condition', 'site', 'wildtype'])
            .agg({metric: [default_summary]})
            .loc[:, metric]
            .loc[:, default_summary]
            .reset_index()
            .pivot(index=['site', 'wildtype'], columns='condition', values=default_summary)
            .loc[:, default_comparison]
            .assign(difference=lambda df: df[default_comparison[0]] - df[default_comparison[1]])
            .abs()
            .difference
            .idxmax()[0]  # index entries are (site, wildtype); keep the site
    )

    # Top-level configuration for the chart
    primary_color = '#f8b196'
    active_color = '#c06c84'
    width = 1000
    height = 300
    interactive = True

    # Canonically color the amino acids by functional group
    # https://jbloomlab.github.io/dmslogo/dmslogo.colorschemes.html
    AA_FUNCTIONAL_GROUP = {
        "G": "#f76ab4",
        "A": "#f76ab4",
        "S": "#ff7f00",
        "T": "#ff7f00",
        "C": "#ff7f00",
        "V": "#12ab0d",
        "L": "#12ab0d",
        "I": "#12ab0d",
        "M": "#12ab0d",
        "P": "#12ab0d",
        "F": "#84380b",
        "Y": "#84380b",
        "W": "#84380b",
        "D": "#e41a1c",
        "E": "#e41a1c",
        "H": "#3c58e5",
        "K": "#3c58e5",
        "R": "#3c58e5",
        "N": "#972aa8",
        "Q": "#972aa8",
        "*": "#A9A9A9",
        "-": "#A9A9A9",
    }
    aa_color_df = pd.DataFrame({'mutant': AA_FUNCTIONAL_GROUP.keys(), 'color': AA_FUNCTIONAL_GROUP.values()})

    # Define the tooltips for the chart
    mutant_tooltips = [
        alt.Tooltip('wildtype:N', title="Wildtype"),
        alt.Tooltip('x:Q', title="Baseline", format=".2f"),
        alt.Tooltip('y:Q', title="Comparison", format=".2f")
    ]
    summary_tooltips = [alt.Tooltip('site:N')]
    for col in conditions:
        summary_tooltips.append(alt.Tooltip(f'{col}:Q', format=".2f"))

    # The data is shared across all parts of the final charts
    base = alt.Chart(
        data.query('mutant not in ["*", "-"]')
    )

    # Drag to zoom into sites on the x-axis colored by region
    zoom_selection = alt.selection_interval(
        encodings=["x"],
        mark=alt.BrushConfig(stroke='black', strokeWidth=2)
    )
    zoom = (
        base
            .mark_rect()
            .transform_aggregate(
                start='min(site)',
                stop='max(site)',
                groupby=['region']
            )
            .encode(
                x=alt.X('start:Q', title='click and drag to zoom'),
                x2=alt.X2('stop:Q'),
                color=alt.Color(
                    'region',
                    type="nominal",
                    scale=alt.Scale(scheme='set3'),
                    legend=None,
                ),
                tooltip=[
                    alt.Tooltip('region', title='Region:'),
                ]
            )
            .properties(
                width=width,
                height=15
            )
            .add_params(
                zoom_selection
            )
    )

    # Add filters to the data *after* making the zoom bar
    exprs = []
    selectors = []
    for col, val in filters.items():
        selector = alt.selection_point(
            name=col,
            fields=[col],
            bind=alt.binding_range(min=data[col].min(), max=data[col].max(), step=1, name=col),
            value=val,
        )
        selectors.append(selector)
        exprs.append(f"datum.{col} >= {selector.name}.{col}")
    # Combine the predicates with the logical AND of the Vega expression language
    filtered = base.add_params(
        *selectors
    ).transform_filter(
        ' && '.join(exprs)
    )

    # Dynamic selections to configure the chart 
    site_selection = alt.selection_point(fields=["site"], empty=False, value=default_site, on="click")
    stat_dropdown = alt.binding_select(options=list(stats.keys()), name='Summary Statistic: ')
    stat_selection = alt.param(value=default_summary, bind=stat_dropdown)
    baseline_dropdown = alt.binding_select(options=conditions, name='Select Baseline: ')
    baseline_selection = alt.param(value=default_comparison[0], bind=baseline_dropdown)
    comparison_dropdown = alt.binding_select(options=conditions, name='Select Comparison: ')
    comparison_selection = alt.param(value=default_comparison[1], bind=comparison_dropdown)

    # Dynamic titles for each chart
    summary_title = (
        alt.Title(
            alt.expr(
                f'"Difference in " + {stat_selection.name} + " effect " + "(" + {baseline_selection.name} + " - " + {comparison_selection.name} + ")"')
            )
    )
    mutant_title = (
        alt.Title(
            alt.expr(f'"Site: " + {site_selection.name}.site')
        )
    )

    # Transform the filter data for the summary plot
    transform_dropdown = alt.binding_select(
        options=['clamp', 'arcsinh', 'none'],
        name=f'Transform (x={clamp_x}, k={arcsinh_k}): ',
    )
    transform_selection = alt.param(value='none', bind=transform_dropdown)
    # Each transform is computed into a field named after it; the dropdown then
    # picks which of those fields replaces `effect`.
    transforms = {
        # max() of a single argument is that argument, so this expression
        # only floors the effect at x
        'clamp': f'clamp(datum.effect, {clamp_x}, max(datum.effect))',
        # k * asinh(effect / k), applied to negative effects only
        'arcsinh': f'datum.effect < 0 ? log(datum.effect / {arcsinh_k} + sqrt(pow(datum.effect / {arcsinh_k}, 2) + 1)) * {arcsinh_k} : datum.effect',
        'none': 'datum.effect'
    }

    # Transform, aggregate, pivot, and calculate the difference between conditions
    transformed = (
        filtered
            .transform_calculate(
                **transforms
            )
            .transform_calculate(
                effect=f'datum[{transform_selection.name}]'
            )
            .add_params(
                transform_selection
            )
    )
    aggregate = (
        transformed
            .transform_aggregate(
                **stats,
                groupby=['site', 'condition']
            )
            .transform_calculate(
                summary=f'datum[{stat_selection.name}]'
            )
            .add_params(
                stat_selection
            )
    )
    pivot = (
        aggregate
            .transform_pivot(
                'condition',
                groupby=['site'],
                value="summary"
            )
    )
    difference = (
        pivot
            .transform_calculate(
                difference=f'datum[{baseline_selection.name}] - datum[{comparison_selection.name}]'
            )
            .add_params(
                baseline_selection,
                comparison_selection
            ).encode(
                x=alt.X('site:Q', title="Site").scale(domain=zoom_selection),
                y=alt.Y("difference:Q", title="Δ(Baseline, Comparison)"),
                tooltip=summary_tooltips
            )
    )
    line = (
        difference
            .mark_line(
                point=False,
                color=primary_color,
            )
    )
    point = (
        difference
            .mark_point(
                filled=True,
                opacity=1,
            )
            .encode(
               size=alt.condition(site_selection, alt.value(200), alt.value(50)),
               color=alt.condition(site_selection, alt.value(active_color), alt.value(primary_color)), 
            )
    )
    summary = (
        (line + point)
            .properties(
                title=summary_title,
                width=width,
                height=height
            )
            .add_params(
                site_selection
            )
    )

    # Mutant-level effects for each site
    residues = (
        filtered
            .transform_pivot(
                'condition',
                groupby=['site', 'mutant', 'wildtype'],
                value="effect"
            )
            .transform_lookup(
                lookup='mutant',
                from_=alt.LookupData(
                    data=aa_color_df,
                    key='mutant',
                    fields=['color']
                ),
            )
            .mark_text(
                size=16
            )
            .encode(
                x=alt.X(
                    'x:Q',
                    scale=alt.Scale(domain=mutant_domain),
                    title="Baseline"
                ),
                y=alt.Y(
                    'y:Q',
                    scale=alt.Scale(domain=mutant_domain),
                    title="Comparison"
                ),
                text="mutant:N",
                color=alt.Color('color:N').scale(None),
                tooltip=mutant_tooltips,
            )
            .transform_calculate(
                x=f'datum[{baseline_selection.name}]',
                y=f'datum[{comparison_selection.name}]',
            )
            .transform_filter( 
                site_selection
            )
            .add_params(
                baseline_selection,
                comparison_selection,
                site_selection
            )
    )
    if interactive:
        residues = residues.interactive()

    # Dashed diagonal (x = y) spanning the whole axis domain
    rule = (
        alt.Chart()
            .mark_rule(
                color=primary_color,
                strokeWidth=3,
                strokeCap="round",
                strokeDash=[8, 8],
            )
            .encode(
                x=alt.datum(
                    mutant_domain[1],
                    type="quantitative",
                    scale=alt.Scale(domain=mutant_domain)
                ),
                y=alt.datum(
                    mutant_domain[1], 
                    type="quantitative", 
                    scale=alt.Scale(domain=mutant_domain)
                ),
                x2=alt.datum(mutant_domain[0]),
                y2=alt.datum(mutant_domain[0]),
            )
    )
    mutants = (
        (rule + residues)
            .properties(
                width=height,
                height=height,
                title=mutant_title
            )
    )

    # Assemble the final chart
    chart = (
        (summary & zoom) | mutants
    ).configure_point(
        size=50
    ).configure_axis(
        labelFontSize=13,
        titleFontSize=16
    ).configure_legend(
        labelFontSize=13,
        titleFontSize=16
    )

    return chart

In [None]:
# Set the padding embed option applied by the Altair renderer to displayed charts
alt.renderers.set_embed_options(
    padding={"left": 25, "right": 100, "bottom": 25, "top": 25}
)

In [None]:
# Build the site-level comparison chart and save it as a standalone HTML file.
# NOTE(review): `combined_func_effects` is not defined in the visible part of
# this notebook, and plot_site_level_comparison reads a `condition` column
# whereas `mut_effects_tidy` has a `cell` column — confirm the intended input.
# NOTE(review): presumably the ./plots directory must already exist — confirm.
chart = plot_site_level_comparison(combined_func_effects)
chart.save('./plots/compare_site_effects.html')

## Testing different scaling approaches

Two approaches I can think of to compress the data are to `clamp` the data at some value $x$, or to apply a function like `arcsinh` to compress the data with a scaling factor $k$.

You can test these methods with different parameters below.

In [None]:
def plot_transform(data):
    """
    Plot different transformation functions on the data 
    to see their effect on raw vs. transformed values.

    A dropdown switches between the `clamp` and `arcsinh` transforms, and two
    sliders set the clamp bound `x` and the arcsinh scaling factor `k`.

    Parameters
    ----------
    data : pd.DataFrame
        Data with an 'effect' column holding the raw values.

    Returns
    -------
    alt.LayerChart
        Layered chart of transformed vs. raw values with reference rules.
    """

    effect_min, effect_max = data['effect'].min(), data['effect'].max()

    base = alt.Chart(data)

    x_slider = alt.binding_range(min=effect_min, max=0, step=0.5, name='x: ')
    x = alt.param(value=-4, bind=x_slider)
    # k divides the effect in the arcsinh expression, so the slider starts
    # above zero to avoid a division by zero (which yields NaN)
    k_slider = alt.binding_range(min=0.1, max=10, step=0.1, name='k:')
    k = alt.param(value=3, bind=k_slider)
    transform_dropdown = alt.binding_select(options=['clamp', 'arcsinh'], name='Transform: ')
    transform = alt.param(value='clamp', bind=transform_dropdown)

    transform_expr = {
        # max() of a single argument is that argument, so this expression
        # only floors the effect at x
        'clamp': f'clamp(datum.effect, {x.name}, max(datum.effect))',
        # k * asinh(effect / k), applied to negative effects only
        'arcsinh': f'datum.effect < 0 ? log(datum.effect / {k.name} + sqrt(pow(datum.effect / {k.name}, 2) + 1)) * {k.name} : datum.effect',
    }

    points = (
        base
            .mark_point()
            .encode(
                x=alt.X('effect:Q', title="Raw"),
                y=alt.Y('transformed:Q', title="Transformed"),
                tooltip=['effect:Q', 'transformed:Q']
            )
            .transform_calculate(
                transformed = f"{transform.name} == 'clamp' ? {transform_expr['clamp']} : {transform_expr['arcsinh']}"
            ).add_params(
                x, k, transform
            )
    )
    # Diagonal (raw == transformed) spanning the range of the data
    rule = (
        base
            .mark_rule(strokeWidth=1, strokeDash=[4, 4], color='steelblue')
            .encode(
                x=alt.datum(effect_min),
                y=alt.datum(effect_min),
                x2=alt.datum(effect_max),
                y2=alt.datum(effect_max),
            )
    )
    # Dashed reference lines through zero on each axis
    xrule = (
        base
            .mark_rule(strokeWidth=1, strokeDash=[8, 8])
            .encode(
                x=alt.datum(0),
            )
    )
    yrule = (
        base
            .mark_rule(strokeWidth=1, strokeDash=[8, 8])
            .encode(
                y=alt.datum(0),
            )
    )

    chart = xrule + yrule + rule + points

    return chart
    

In [None]:
# Build the raw-vs-transformed chart and save it as a standalone HTML file.
# NOTE(review): `combined_func_effects` is not defined in the visible part of
# this notebook — confirm where it comes from.
chart = plot_transform(combined_func_effects)
chart.save('./plots/transform_comparison.html')

In [None]:
# Display the most recently built chart as the cell output
chart