# Analyze comparator studies versus fitness estimates and deep mutational scanning for spike

In [None]:
import itertools

import altair as alt

import numpy

import pandas as pd

import scipy.stats.mstats

# Remove Altair's cap on the number of data rows embedded in a chart
# (presumably so the full per-mutation tables below can be plotted);
# the return value is assigned to `_` so it is discarded.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# Pull inputs, outputs, and parameters from the `snakemake` object (not
# defined in this file; presumably supplied by Snakemake when it runs it).
aa_fitness_csv = snakemake.input.aa_fitness
corr_html = snakemake.output.corr_html
min_expected_count = snakemake.params.min_expected_count
comparator_studies = snakemake.params.comparator_studies

# Each DMS dataset config gains a "data" entry holding its table, read from
# the Snakemake input named after the dataset.
dms_datasets = snakemake.params.dms_datasets
for dataset_name, dataset_config in dms_datasets.items():
    dataset_input = getattr(snakemake.input, dataset_name)
    dataset_config["data"] = pd.read_csv(dataset_input)

Read data from comparator studies:

In [None]:
def _read_rodriguez_rivas_dca(datafile):
    """Read per-site DCA mutability scores for spike.

    Returns a data frame with columns gene, site, and "DCA mutability score"
    (score kept as the last column), dropping sites with no score.
    """
    return (
        pd.read_csv(datafile)
        .rename(
            columns={
                "position_protein": "site",
                "mutability_score(DCA)": "DCA mutability score",
            }
        )
        .query("protein == 'Spike'")
        .assign(gene="S")
        [["gene", "site", "DCA mutability score"]]
        .query("`DCA mutability score`.notnull()")
        .reset_index(drop=True)
    )


def _read_maher_drivers(datafile):
    """Read per-mutation EpiScores for spike amino-acid substitutions.

    The first (unnamed) column is split at its first "_" into a protein name
    and a mutation string; only spike single substitutions (e.g. "N501Y") are
    kept. Rows sharing the same site and amino acid are averaged. Returns
    columns gene, site, aa, EpiScore (score kept as the last column).
    """
    return (
        pd.read_csv(datafile)
        .assign(
            gene=lambda x: x["Unnamed: 0"].str.split("_").str[0],
            # `n` is keyword-only in pandas >= 2.0
            mutation=lambda x: x["Unnamed: 0"].str.split("_", n=1).str[1],
        )
        .query("gene == 'Spike'")
        .assign(gene="S")
        # raw string avoids the invalid "\d" escape; na=False drops rows with
        # no mutation part instead of producing a non-boolean mask
        .loc[lambda x: x["mutation"].str.fullmatch(r"[A-Z]\d+[A-Z]", na=False)]
        .assign(
            site=lambda x: x["mutation"].str[1:-1].astype(int),
            aa=lambda x: x["mutation"].str[-1],
        )
        .rename(columns={"EpiScore_AllLineages": "EpiScore"})
        .groupby(["gene", "site", "aa"], as_index=False)
        .aggregate({"EpiScore": "mean"})
    )


def _read_thadani_learning(datafile):
    """Read per-mutation EVE scores; every row is labeled as spike (gene "S").

    Returns columns gene, site, aa, "EVE score" (score kept as the last column).
    """
    return (
        pd.read_csv(datafile)
        .rename(columns={"i": "site", "mut": "aa", "eve": "EVE score"})
        .assign(gene="S")
        [["gene", "site", "aa", "EVE score"]]
    )


# Dispatch table from comparator-study key to the function that reads it.
_COMPARATOR_READERS = {
    "rodriguez_rivas_dca": _read_rodriguez_rivas_dca,
    "maher_drivers": _read_maher_drivers,
    "thadani_learning": _read_thadani_learning,
}

for study in comparator_studies:
    # Always read the Snakemake-declared input for the study (previously two
    # studies printed this path but then read a hard-coded one instead).
    datafile = getattr(snakemake.input, study)
    print(f"Reading data from {datafile}")
    if study not in _COMPARATOR_READERS:
        raise ValueError(f"invalid {study=}")
    comparator_studies[study]["df"] = _COMPARATOR_READERS[study](datafile)
    

Read amino-acid fitnesses and calculate the mean fitness of mutations at each site.

In [None]:
# Per-mutation fitness for spike (gene "S"), excluding stop codons ("*") and
# mutations whose expected count is below `min_expected_count`.
_aa_fitness_all = pd.read_csv(aa_fitness_csv).rename(columns={"aa_site": "site"})
_keep_rows = (
    _aa_fitness_all["aa"].ne("*")
    & _aa_fitness_all["expected_count"].ge(min_expected_count)
    & _aa_fitness_all["gene"].eq("S")
)
aa_fitness = _aa_fitness_all.loc[_keep_rows, ["gene", "site", "aa", "fitness"]]

print("Fitnesses for amino acids:")
display(aa_fitness)

print("\nMean fitnesses for sites:")
# Average the retained mutation fitnesses within each (gene, site) pair.
site_fitness = aa_fitness.groupby(["gene", "site"], as_index=False).agg(
    fitness=("fitness", "mean")
)
display(site_fitness)

Now get the spike deep mutational scanning (DMS) dataset and compute the mean DMS-measured effect at each site:

In [None]:
dms_dfs = []

for dms_d in dms_datasets.values():

    # only the spike DMS of Dadonaite et al (2023) is used in this comparison
    if dms_d["gene"] != "S":
        continue
    if dms_d["description"] != "spike (Dadonaite et al, 2023)":
        continue

    # drop stop codons as either the wildtype or the mutant identity
    dms_df = dms_d["data"].query("(wildtype != '*') and (mutant != '*')")

    # If sites are not stored as integers, keep only purely numeric site labels
    # and convert them. `is_integer_dtype` accepts every integer width and
    # nullable ints, unlike comparing the dtype to `int`.
    if not pd.api.types.is_integer_dtype(dms_df["site"]):
        dms_df = (
            dms_df
            .query("site.str.isnumeric()")
            .assign(site=lambda x: x["site"].astype(int))
        )

    # Apply per-dataset minimum-value filters to the measured mutations BEFORE
    # adding synthetic wildtype rows: those rows only carry site/wildtype/
    # mutant/effect, so their filter columns are NaN and `NaN >= x` is False.
    # Filtering afterwards would silently discard every added wildtype.
    if "filter_cols" in dms_d:
        for c, x in dms_d["filter_cols"].items():
            dms_df = dms_df[dms_df[c] >= x]

    if not len(dms_df.query("wildtype == mutant")):
        # add wildtypes with effects of zero at every remaining site
        wildtypes = (
            dms_df[["site", "wildtype"]]
            .drop_duplicates()
            .assign(mutant=lambda x: x["wildtype"], effect=0)
        )
        dms_df = pd.concat([dms_df, wildtypes], ignore_index=True)

    # tag rows with gene and study so datasets can be concatenated below
    dms_dfs.append(dms_df.assign(gene=dms_d["gene"], study=dms_d["description"]))

if not dms_dfs:
    raise ValueError(
        "no DMS dataset with gene 'S' and description "
        "'spike (Dadonaite et al, 2023)' found in dms_datasets"
    )

dms_df = pd.concat(dms_dfs)[["study", "gene", "site", "mutant", "effect"]].rename(
    columns={"mutant": "aa", "study": "DMS study", "effect": "DMS effect"}
)

print("Deep mutational scanning data:")
display(dms_df)

# mean DMS effect over all amino acids (including wildtype) at each site
dms_site_df = (
    dms_df.groupby(["DMS study", "gene", "site"], as_index=False)
    .aggregate({"DMS effect": "mean"})
)

print("\nDeep mutational scanning data mean site values:")
dms_site_df

Now merge everything into one data frame containing the comparator-study scores, fitness effects, and DMS effects:

In [None]:
merged_dfs = []

for study_config in comparator_studies.values():
    # Each comparator table ends with its score column: record that column's
    # name as the metric, then rename the score to a name shared by all studies.
    scores = study_config["df"]
    score_col = scores.columns[-1]
    comparator_df = scores.assign(
        comparator_study=study_config["name"],
        comparator_metric=score_col,
    ).rename(columns={score_col: "comparator score"})

    if "aa" not in comparator_df.columns:
        # Site-level scores: pair with per-site DMS means and per-site mean fitness.
        comparator_df = (
            comparator_df.merge(dms_site_df, validate="one_to_many")
            .merge(site_fitness, validate="many_to_one")
            .rename(columns={"site": "aa_site"})
        )
    else:
        # Mutation-level scores: pair with per-mutation DMS effects and fitness,
        # then label each row by amino acid followed by site (e.g. "Y501").
        comparator_df = comparator_df.merge(dms_df, validate="one_to_many").merge(
            aa_fitness, validate="many_to_one"
        )
        comparator_df["aa_site"] = comparator_df["aa"] + comparator_df["site"].astype(str)
        comparator_df = comparator_df.drop(columns=["site", "aa"])

    merged_dfs.append(comparator_df)

merged_df = pd.concat(merged_dfs, ignore_index=True)
merged_df = merged_df.rename(columns={"fitness": "fitness (current study)"})

merged_df

Now make plots:

In [None]:
charts = []
for study, study_df in merged_df.groupby("comparator_study"):

    # Each comparator study should contribute one metric and one DMS study.
    comparator_metrics = study_df["comparator_metric"].unique()
    if len(comparator_metrics) != 1:
        raise ValueError(
            f"expected one comparator metric for {study}, got {comparator_metrics}"
        )
    comparator_metric = comparator_metrics[0]

    dms_studies = study_df["DMS study"].unique()
    if len(dms_studies) != 1:
        raise ValueError(f"expected one DMS study for {study}, got {dms_studies}")
    # drop the first word of the DMS description (e.g. "spike") for the label
    dms_metric = f"DMS effect {dms_studies[0].split(None, maxsplit=1)[1]}"

    study_df = (
        study_df
        .rename(
            columns={"comparator score": comparator_metric, "DMS effect": dms_metric}
        )
        .drop(columns=["comparator_metric", "DMS study"])
    )

    base = alt.Chart(study_df)
    # tooltips show every column, with float columns formatted to 2 decimals
    tooltip = [
        alt.Tooltip(c, format=".2f") if study_df[c].dtype == float else c
        for c in study_df.columns
    ]

    study_charts = []
    for col1, col2 in itertools.combinations(
        [comparator_metric, "fitness (current study)", dms_metric], 2
    ):
        scatter = (
            base
            .encode(x=alt.X(col1), y=alt.Y(col2), tooltip=tooltip)
            .mark_circle(opacity=0.4)
            .properties(width=175, height=175)
        )

        # regression line and correlation coefficient: https://stackoverflow.com/a/60239699
        line = (
            scatter
            .transform_regression(col1, col2)
            .mark_line(color="orange", size=4, opacity=0.5, clip=True)
        )

        params_r = (
            scatter
            .transform_regression(col1, col2, params=True)
            .transform_calculate(
                # sqrt(rSquared) alone is never negative. For a linear fit,
                # Pearson r = sign(slope) * sqrt(rSquared), and the regression
                # params expose coef = [intercept, slope].
                r="(datum.coef[1] < 0 ? -1 : 1) * sqrt(datum.rSquared)",
                label='"r = " + format(datum.r, ".2f")',
            )
            .mark_text(align="left", color="orange", fontWeight="bold", fontSize=13)
            .encode(x=alt.value(5), y=alt.value(8), text=alt.Text("label:N"))
        )

        study_charts.append(scatter + line + params_r)

    charts.append(
        alt.hconcat(*study_charts, title=alt.TitleParams(study, anchor="middle"))
    )

chart = alt.vconcat(*charts).configure_axis(grid=False)

chart.save(corr_html)

chart