# Analyze and plot correlations among fitness estimates and DMS measurements

Import Python modules:

In [1]:
import itertools

import altair as alt

import numpy

import pandas as pd

import yaml

# lift Altair's default limit on the number of data rows a chart may embed;
# the return value is assigned to `_` so it is not echoed as cell output
_ = alt.data_transformers.disable_max_rows()

Now get variables from `snakemake`:

In [2]:
# Resolve input paths and parameters: taken from the `snakemake` object when
# the pipeline runs this notebook, otherwise set by hand for interactive use.
if "snakemake" in globals() or "snakemake" in locals():
    # running inside the pipeline: everything comes from `snakemake`
    aa_fitness_csv = snakemake.input.aafitness
    neher_fitness_csv = snakemake.input.neher_fitness
    plotsdir = snakemake.output.plotsdir
    min_expected_count = snakemake.params.min_expected_count
    dms_datasets = snakemake.params.dms_datasets
    # each DMS dataset is a named input of the rule
    dms_datasets_csvs = dict(
        (dms_dataset, getattr(snakemake.input, dms_dataset))
        for dms_dataset in dms_datasets
    )
else:
    # interactive debugging: paths are relative to the notebook directory
    aa_fitness_csv = "../results/aa_fitness/aa_fitness.csv"
    neher_fitness_csv = "../data/Neher_aa_fitness.csv"
    plotsdir = "../results/fitness_dms_corr/plots"

    # parameters are read straight from the pipeline configuration file
    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    min_expected_count = config["min_expected_count"]
    dms_datasets = config["dms_datasets"]

    dms_datasets_csvs = dict(
        (dms_dataset, f"../results/dms/{dms_dataset}/processed.csv")
        for dms_dataset in dms_datasets
    )

Read the datasets:

In [3]:
# amino-acid fitness estimates from the current study and from Neher et al
aa_fitness = pd.read_csv(aa_fitness_csv)
neher_fitness = pd.read_csv(neher_fitness_csv)

# one data frame per deep mutational scanning dataset, keyed by dataset name
dms_dataset_dfs = dict(
    zip(dms_datasets_csvs, map(pd.read_csv, dms_datasets_csvs.values()))
)

## Correlation of fitness estimates with Neher fitness estimates
First examine the correlations between the amino-acid fitness values from the current approach and the Neher et al estimates (which are available only for spike).
To do this, the two sets of estimates must share a common "wildtype" identity, which we choose to be the wildtype identity used by Neher et al.
Note that this analysis is restricted to spike because the Neher et al estimates are only available for spike:

In [4]:
# per-amino-acid fitness estimates for spike from the current study
spike_fitness = (
    aa_fitness
    .query("gene == 'S'")
    .rename(columns={"aa_site": "site"})
    [["site", "aa", "fitness", "expected_count"]]
)

# Express the current-study fitness values as effects of mutating from the
# Neher et al wildtype identity to the mutant identity, so that they are
# directly comparable with the Neher et al mutation effects.
neher_corr_df = (
    neher_fitness
    .query("aa_fitness.notnull()")
    .assign(
        # mutation strings look like "A1015D": wildtype, site, mutant
        wildtype=lambda x: x["mutation"].str[0],
        site=lambda x: x["mutation"].str[1:-1].astype(int),
        mutant=lambda x: x["mutation"].str[-1],
    )
    .rename(columns={"aa_fitness": "Neher fitness effect"})
    [["wildtype", "site", "mutant", "Neher fitness effect"]]
    # attach the current-study fitness of the wildtype amino acid ...
    .merge(
        spike_fitness.rename(
            columns={
                "aa": "wildtype",
                "fitness": "wt_fitness",
                "expected_count": "expected_count_wt",
            }
        ),
        on=["site", "wildtype"],
        how="inner",
        validate="many_to_one",
    )
    # ... and of the mutant amino acid
    .merge(
        spike_fitness.rename(
            columns={
                "aa": "mutant",
                "fitness": "mut_fitness",
                "expected_count": "expected_count_mut",
            }
        ),
        on=["site", "mutant"],
        how="inner",
        validate="many_to_one",
    )
    .query("wildtype != mutant")
    .assign(
        # a mutation effect is only as well supported as the less well
        # supported of its two amino-acid fitness estimates
        expected_count=lambda x: numpy.minimum(
            x["expected_count_wt"], x["expected_count_mut"],
        ),
        fitness_effect=lambda x: x["mut_fitness"] - x["wt_fitness"],
        mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
    )
    [["mutation", "fitness_effect", "Neher fitness effect", "expected_count"]]
)

display(neher_corr_df.head())

# Restrict to numeric columns explicitly: the string `mutation` column makes
# `DataFrame.corr()` raise on pandas >= 2.0, where non-numeric columns are no
# longer dropped silently.
neher_corr_df.select_dtypes(include="number").corr()

Unnamed: 0,mutation,fitness_effect,Neher fitness effect,expected_count
0,A1015D,-4.241,3.777431,34.24
1,A1015S,-2.2782,5.823637,316.67
2,A1015T,-3.7103,3.633602,142.53
3,A1015V,-4.7398,2.729965,514.34
4,A1015G,-2.07,6.00177,11.387


Unnamed: 0,fitness_effect,Neher fitness effect,expected_count
fitness_effect,1.0,0.869017,-0.276794
Neher fitness effect,0.869017,1.0,-0.150465
expected_count,-0.276794,-0.150465,1.0


Plot correlation.
First define a function to do this, then make the plots:

In [5]:
def plot_corr_scatters(corr_df_tidy, ncols=4, ignore_pairs=frozenset()):
    """Plot a grid of pairwise correlation scatter plots.

    One scatter plot is drawn for each pair of values in the "subset" column,
    with a regression line, the Pearson correlation coefficient and the number
    of points. A slider filters the points by minimum expected count.

    Parameters
    ----------
    corr_df_tidy : pandas.DataFrame
        Tidy data frame with columns "mutation", "expected_count", "subset"
        and "delta_fitness".
    ncols : int
        Number of scatter plots per row of the returned chart.
    ignore_pairs : collection of set or frozenset
        Pairs of subset names (each a two-element set) that are not plotted.

    Returns
    -------
    altair chart
        Vertical concatenation of rows of scatter plots.

    Notes
    -----
    Reads the notebook-level variable `min_expected_count` to set the initial
    value and the upper limit of the slider.
    """

    subsets = corr_df_tidy["subset"].unique()

    # one row per mutation, one "delta_fitness <subset>" column per subset
    corr_df_wide = (
        corr_df_tidy
        .pivot_table(
            index=["mutation", "expected_count"],
            values="delta_fitness",
            columns="subset",
        )
        .reset_index()
        .rename(columns={subset: f"delta_fitness {subset}" for subset in subsets})
    )

    # slider shared by all panels that sets the minimum expected count
    expected_count_selection = alt.selection_single(
        bind=alt.binding_range(
            min=1,
            max=min(5 * min_expected_count, corr_df_tidy["expected_count"].quantile(0.9)),
            step=1,
            name="minimum expected count",
        ),
        fields=["cutoff"],
        init={"cutoff": min_expected_count},
    )

    corr_charts = []
    base_chart = alt.Chart(corr_df_wide)
    for subset2, subset1 in itertools.combinations(subsets, 2):
        if {subset1, subset2} in ignore_pairs:
            continue

        x_col = f"delta_fitness {subset1}"
        y_col = f"delta_fitness {subset2}"

        # fix the axis ranges to the full data so they do not rescale when
        # the slider filters out points
        x_extent = (corr_df_wide[x_col].min(), corr_df_wide[x_col].max())
        y_extent = (corr_df_wide[y_col].min(), corr_df_wide[y_col].max())

        # encodings and filter shared by all layers of this panel
        base = (
            base_chart
            .encode(
                x=alt.X(
                    x_col,
                    title=subset1,
                    scale=alt.Scale(zero=False, domain=x_extent, nice=False),
                ),
                y=alt.Y(
                    y_col,
                    title=subset2,
                    scale=alt.Scale(zero=False, domain=y_extent, nice=False),
                ),
                tooltip=[
                    "mutation",
                    alt.Tooltip(x_col, title=subset1, format=".4g"),
                    alt.Tooltip(y_col, title=subset2, format=".4g"),
                    alt.Tooltip("expected_count"),
                ],
            )
            .properties(width=200, height=200)
            .transform_filter(
                alt.datum["expected_count"] >= expected_count_selection["cutoff"] - 1e-6
            )
        )

        scatter = base.mark_circle(opacity=0.3)

        # regression line and correlation coefficient: https://stackoverflow.com/a/60239699
        line = (
            base
            .transform_regression(x_col, y_col, extent=x_extent)
            .mark_line(color="orange", clip=True)
        )

        # show correlation coefficient
        params_r = (
            base
            .transform_regression(x_col, y_col, params=True)
            .transform_calculate(
                # `rSquared` carries no sign, so take the sign of r from the
                # fitted slope (`coef` is [intercept, slope]); otherwise a
                # negative correlation would be reported as positive
                r="(datum.coef[1] < 0 ? -1 : 1) * sqrt(datum.rSquared)",
                label='"r = " + format(datum.r, ".3f")',
            )
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(8),
                text=alt.Text("label:N"),
            )
        )

        # show number of points
        params_n = (
            base
            .transform_calculate(
                # add together to get dummy variable so null if either null
                dummy=alt.datum[x_col] + alt.datum[y_col],
            )
            .transform_aggregate(
                n="valid(dummy)",
            )
            .transform_calculate(
                label='"n = " + datum.n',
            )
            .mark_text(align="left", color="orange", fontWeight="bold")
            .encode(
                x=alt.value(5),
                y=alt.value(20),
                text=alt.Text("label:N"),
            )
        )

        chart = (
            (scatter + line + params_r + params_n)
            .add_selection(expected_count_selection)
        )

        corr_charts.append(chart)

    # lay the panels out in rows of at most `ncols` charts
    rows = [
        alt.hconcat(*corr_charts[i: i + ncols])
        for i in range(0, len(corr_charts), ncols)
    ]
    corr_chart = alt.vconcat(*rows).configure_axis(grid=False)
    return corr_chart


# Reshape to the tidy layout `plot_corr_scatters` expects: one row per
# mutation and fitness-effect source, with the source named in "subset".
neher_corr_df_tidy = pd.melt(
    neher_corr_df.rename(
        columns={"fitness_effect": "fitness_effect (current study)"}
    ),
    id_vars=["mutation", "expected_count"],
    var_name="subset",
    value_name="delta_fitness",
)

neher_corr_chart = plot_corr_scatters(neher_corr_df_tidy)
neher_corr_chart

## Correlation of current and Neher fitness estimates with each spike deep mutational scanning dataset
Do the current or Neher fitness estimates correlate better with deep mutational scanning?
For each deep mutational scanning dataset, correlate the fitness effects estimated from the current study and the Neher et al fitness estimates:

In [6]:
# For each spike DMS dataset, correlate its measurements with the fitness
# effects from the current study and from Neher et al.
for dms_dataset, dms_df in dms_dataset_dfs.items():

    dms_config = dms_datasets[dms_dataset]
    gene = dms_config["gene"]

    if gene != "S":
        continue  # we only have Neher estimates for spike

    desc = dms_config["description"]
    print(f"\nAnalyzing {dms_dataset=} for {gene=}")
    if "filter_cols" in dms_config:
        # keep rows meeting each minimum, then drop the filter column so it is
        # not treated as a DMS measurement below
        for filter_col, filter_val in dms_config["filter_cols"].items():
            print(f"Filtering for {filter_col} >= {filter_val}")
            dms_df = dms_df.query(f"{filter_col} >= {filter_val}").drop(columns=filter_col)

    # every remaining column other than the mutation identifiers is a DMS
    # measurement; give each a descriptive label for the plots
    dms_cols = {
        c: f"DMS {c.replace('_', ' ')} {desc}"
        for c in dms_df
        if c not in {"site", "wildtype", "mutant"}
    }

    # do not correlate total DMS effect with individual assays
    ignore_pairs = {frozenset({f"DMS effect {desc}", c}) for c in dms_cols.values()}

    dms_df = dms_df.rename(columns=dms_cols)
    if not pd.api.types.is_integer_dtype(dms_df["site"]):
        # only keep integer sites as they are only ones with fitness estimates
        is_int_site = dms_df["site"].astype(str).str.fullmatch(r"\d+")
        # `.assign` returns a new frame rather than writing into a filtered
        # slice, which avoids pandas' SettingWithCopyWarning
        dms_df = dms_df[is_int_site].assign(site=lambda x: x["site"].astype(int))

    gene_fitness = (
        aa_fitness
        .query("gene == @gene")
        .rename(columns={"aa_site": "site"})
        [["site", "aa", "fitness", "expected_count"]]
    )

    corr_df_wide = (
        dms_df
        # attach the current-study fitness of the wildtype amino acid ...
        .merge(
            gene_fitness.rename(
                columns={
                    "aa": "wildtype",
                    "fitness": "wt_fitness",
                    "expected_count": "expected_count_wt",
                }
            ),
            on=["site", "wildtype"],
            how="inner",
            validate="many_to_one",
        )
        # ... and of the mutant amino acid
        .merge(
            gene_fitness.rename(
                columns={
                    "aa": "mutant",
                    "fitness": "mut_fitness",
                    "expected_count": "expected_count_mut",
                }
            ),
            on=["site", "mutant"],
            how="inner",
            validate="many_to_one",
        )
        .query("wildtype != mutant")
        .assign(
            expected_count=lambda x: numpy.minimum(
                x["expected_count_wt"], x["expected_count_mut"],
            ),
            fitness_effect=lambda x: x["mut_fitness"] - x["wt_fitness"],
            mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"],
        )
        [["mutation", "fitness_effect", "expected_count", *dms_cols.values()]]
        .rename(columns={"fitness_effect": "fitness effect"})
    )

    # add the Neher et al effects; non-spike genes were skipped above, so
    # every dataset reaching this point is for spike
    corr_df_wide = corr_df_wide.merge(
        neher_fitness
        [["mutation", "aa_fitness"]]
        .rename(columns={"aa_fitness": "Neher fitness effect"}),
        on="mutation",
        how="inner",
    )

    # tidy layout expected by `plot_corr_scatters`
    corr_df = (
        corr_df_wide
        .melt(
            id_vars=["mutation", "expected_count"],
            var_name="subset",
            value_name="delta_fitness",
        )
    )

    corr_chart = plot_corr_scatters(corr_df, ignore_pairs=ignore_pairs)
    display(corr_chart)


Analyzing dms_dataset='dadonaite_ba1_spike' for gene='S'
Filtering for times_seen >= 3



Analyzing dms_dataset='starr_rbd' for gene='S'


In [7]:
# NOTE(review): ad-hoc inspection of the raw Neher et al table read above;
# looks like a leftover debugging cell — confirm whether it should be kept
neher_fitness

Unnamed: 0.1,Unnamed: 0,mutation,aa_fitness
0,0,A1015D,3.777431
1,1,A1015S,5.823637
2,2,A1015T,3.633602
3,3,A1015V,2.729965
4,4,A1016E,3.777431
...,...,...,...
15253,15253,L1034H,
15254,15254,L1034K,
15255,15255,L1034T,
15256,15256,L1034Y,


In [8]:
# Ad-hoc check: `corr_df_wide` is whatever the final iteration of the loop
# above left behind, re-joined here to the Neher et al effects.
pd.merge(
    corr_df_wide,
    neher_fitness[["mutation", "aa_fitness"]].rename(
        columns={"aa_fitness": "Neher fitness effect"}
    ),
)

Unnamed: 0,mutation,fitness effect,expected_count,"DMS effect RBD (Starr et al, 2022)","DMS ACE2 affinity RBD (Starr et al, 2022)","DMS expression RBD (Starr et al, 2022)",Neher fitness effect
0,A344D,-1.10550,34.2400,-0.716370,-0.487320,-0.94542,5.836648
1,A344G,-0.97135,11.3870,-0.714570,-0.445410,-0.98374,6.693679
2,A344P,-1.36840,17.1820,-1.527700,-0.924900,-2.13040,5.349514
3,A344S,-0.10869,316.6700,-0.323250,-0.222570,-0.42392,7.158958
4,A344T,-0.46883,142.5300,-0.736950,-0.464930,-1.00900,6.628574
...,...,...,...,...,...,...,...
1149,Y508D,-1.72430,7.9128,-2.431500,-2.737100,-2.12580,4.806662
1150,Y508F,-2.55470,18.8010,-0.077242,-0.004363,-0.15012,4.487906
1151,Y508H,0.79801,58.2540,0.080949,0.048525,0.11337,7.057575
1152,Y508N,-1.63980,12.3850,-0.598350,-0.412150,-0.78454,5.481164
