# Analyze and plot correlations among fitness estimates and DMS measurements

Import Python modules:

In [None]:
import altair as alt

import numpy

import pandas as pd

import yaml

Now get variables from `snakemake`:

In [None]:
# Resolve input paths, output directory, and parameters. When the notebook is
# run by the pipeline these come from the `snakemake` object; otherwise they
# are set by hand so the notebook can be run interactively.
if "snakemake" in globals() or "snakemake" in locals():
    # running in the pipeline: take everything from `snakemake`
    aa_fitness_csv = snakemake.input.aafitness
    neher_fitness_csv = snakemake.input.neher_fitness
    plotsdir = snakemake.output.plotsdir
    min_expected_count = snakemake.params.min_expected_count
    dms_datasets = snakemake.params.dms_datasets
    # each DMS dataset is a named input whose attribute name is the dataset name
    dms_datasets_csvs = {
        name: getattr(snakemake.input, name) for name in dms_datasets
    }
else:
    # interactive debugging: hard-coded relative paths
    aa_fitness_csv = "../results/aa_fitness/aa_fitness.csv"
    neher_fitness_csv = "../data/Neher_aa_fitness.csv"
    plotsdir = "../results/fitness_dms_corr/plots"

    # parameters are read from the pipeline configuration file
    with open("../config.yaml") as f:
        config = yaml.safe_load(f)
    min_expected_count = config["min_expected_count"]
    dms_datasets = config["dms_datasets"]

    # one processed CSV per DMS dataset listed in the configuration
    dms_datasets_csvs = {
        name: f"../results/dms/{name}/processed.csv" for name in dms_datasets
    }

Read the datasets:

In [None]:
# Load the two fitness tables as data frames.
aa_fitness = pd.read_csv(aa_fitness_csv)
neher_fitness = pd.read_csv(neher_fitness_csv)

# Replace the collection of dataset names with a mapping from each dataset
# name to the data frame read from its CSV.
dms_datasets = {
    name: pd.read_csv(csv_path) for name, csv_path in dms_datasets_csvs.items()
}

First examine correlations between the amino-acid fitness values from the current approach and the Neher et al. estimates (which are just for spike).
To do this, the two sets of estimates must share a common "wildtype" identity, which we choose to be whatever the Neher et al. identity is:

In [None]:
# Fitness estimates for spike (gene S) only, keyed by site and amino acid.
spike_fitness = aa_fitness.query("gene == 'S'").rename(columns={"aa_site": "site"})
spike_fitness = spike_fitness[["site", "aa", "fitness", "expected_count"]]

# Split each Neher mutation string into its first character (wildtype), the
# integer in between (site), and its last character (mutant).
neher_corr_df = neher_fitness.query("aa_fitness.notnull()")
neher_corr_df = neher_corr_df.assign(
    wildtype=neher_corr_df["mutation"].str[0],
    site=neher_corr_df["mutation"].str[1: -1].astype(int),
    mutant=neher_corr_df["mutation"].str[-1],
)
neher_corr_df = neher_corr_df.rename(columns={"aa_fitness": "Neher fitness effect"})
neher_corr_df = neher_corr_df[["wildtype", "site", "mutant", "Neher fitness effect"]]

# Attach our fitness estimate for the amino acid that Neher treats as wildtype.
neher_corr_df = neher_corr_df.merge(
    spike_fitness.rename(
        columns={
            "aa": "wildtype",
            "fitness": "wt_fitness",
            "expected_count": "expected_count_wt",
        }
    ),
    on=["site", "wildtype"],
    how="inner",
    validate="many_to_one",
)

# Attach our fitness estimate for the mutant amino acid.
neher_corr_df = neher_corr_df.merge(
    spike_fitness.rename(
        columns={
            "aa": "mutant",
            "fitness": "mut_fitness",
            "expected_count": "expected_count_mut",
        }
    ),
    on=["site", "mutant"],
    how="inner",
    validate="many_to_one",
)

# The effect is the mutant fitness minus the wildtype fitness; the expected
# count kept for the mutation is the smaller of the two amino acids' counts.
neher_corr_df = neher_corr_df.assign(
    expected_count=numpy.minimum(
        neher_corr_df["expected_count_wt"],
        neher_corr_df["expected_count_mut"],
    ),
    fitness_effect=neher_corr_df["mut_fitness"] - neher_corr_df["wt_fitness"],
    mutation=(
        neher_corr_df["wildtype"]
        + neher_corr_df["site"].astype(str)
        + neher_corr_df["mutant"]
    ),
)
neher_corr_df = neher_corr_df[
    ["mutation", "fitness_effect", "Neher fitness effect", "expected_count"]
]

neher_corr_df

In [None]:
# Show rows where the same (site, mutant) pair occurs more than once.
# `neher_corr_df` only keeps the combined `mutation` string, so `site` and
# `mutant` are re-derived from it here; grouping on those columns without
# re-deriving them raises a `KeyError`.
(
    neher_corr_df
    .assign(
        # the mutation string is <wildtype><site><mutant>, e.g. one letter,
        # an integer site, then one letter
        site=lambda x: x["mutation"].str[1: -1].astype(int),
        mutant=lambda x: x["mutation"].str[-1],
        # number of non-null Neher values sharing this (site, mutant) pair
        n=lambda x: (
            x.groupby(["site", "mutant"])["Neher fitness effect"].transform("count")
        ),
    )
    .query("n > 1")
    .sort_values(["site", "mutant"])
)

In [None]:
# Scratch check of `numpy.minimum` on two short lists (the same function is
# used to compute `expected_count` from the wildtype and mutant counts).
numpy.minimum([1, 2], [0, 3])