# `multidms` fitting pipeline

Here, we demonstrate the pipeline for fitting a `multidms` model on some data using the tools available in the package. Currently, we show how to pre-process data from various dms experiments targeting Delta and Omicron spike protein homologs.

In [29]:
# built-in libraries
import os
import sys
# import pickle
from itertools import combinations
import importlib
import math
import re
from timeit import default_timer as timer
import json

# external dependencies
import pandas as pd
# import jax
# import jax.numpy as jnp
# from jax.experimental import sparse
# from jaxopt import ProximalGradient
# import jaxopt
import numpy as onp
# from scipy.stats import pearsonr
# import binarymap as bmap
from tqdm.notebook import tqdm
# import matplotlib.pyplot as plt
# import seaborn as sns
# import plotnine

# local dependencies
# sys.path.append("..")
# import multidms
import multidms
# from multidms import Multidms, AAS_WITHSTOP

%load_ext autoreload
%autoreload 2

**First, we'll define a few global variables.**

In [30]:
# The substitution column to use from the functional score dataframes
substitution_column = 'aa_substitutions_reference'

# The column providing which individual conditions each variant observation was derived from.
experiment_column = 'homolog_exp'

**Define all the knobs which define the pipeline's behavior**

In [31]:
fit_params = {
    # pre-processing params
    "scale_counts" : False,
    "fs_scaling_group_column" : "homolog_exp",
    "min_pre_counts" : 100,
    "pseudocount" : 0.1,
    "agg_variants" : True,
    "sample" : 3000,
    "min_pre_counts" : 100,
    "clip_target" : (-3.5, 2.5),
    "include_stop_variants" : True,
    "shift_func_score_target_nonref" : False,
    # conditions to include in fit
    "reference_condition" : "Delta-2-1",
    "conditions" : ["Delta-2-1", "Omicron_BA.1-2-1"]
}

**Define the attributes you'll have access to after each fit**

In [32]:
# cols = list(fit_params.keys()) + [
#     "tuned_model_params", "mutation_effects_df", "variant_prediction_df", "site_map"
# ]

**uncomment if you want to reset all results, or if this is the first run**

In [33]:
# results = pd.DataFrame(columns = cols)

## Pre-Proccessing

**Read in the dms data and list all available experimental conditions**

In [34]:
func_score_data = pd.DataFrame()
sites = {}
wt_seqs = {}

for homolog in ["Delta", "Omicron_BA.1", "Omicron_BA.2"]:
    
    # functional scores
    func_sel = pd.read_csv(f"../results/{homolog}/functional_selections.csv")
    func_sel = func_sel.assign(
        filename = f"../results/{homolog}/" + 
        func_sel.library + "_" + 
        func_sel.preselection_sample + 
        "_vs_" + func_sel.postselection_sample + 
        "_func_scores.csv"
    )
    func_sel = func_sel.assign(
        func_sel_scores_df = func_sel.filename.apply(lambda f: pd.read_csv(f))
    )
    func_sel = func_sel.assign(
        len_func_sel_scores_df = func_sel.func_sel_scores_df.apply(lambda x: len(x))
    )
    fun_sel = func_sel.assign(homolog = homolog)
    func_score_data = pd.concat([func_score_data, fun_sel]).reset_index(drop=True)

# Add a column that gives a unique ID to each homolog/DMS experiment
func_score_data['homolog_exp'] = func_score_data.apply(
    lambda row: f"{row['homolog']}-{row['library']}-{row['replicate']}".replace('-Lib',''),
    axis=1
)
func_score_data

Unnamed: 0,preselection_sample,library,virus_batch,replicate,postselection_sample,preselection_library_sample,postselection_library_sample,selection_name,filename,func_sel_scores_df,len_func_sel_scores_df,homolog,homolog_exp
0,2021-10-28_thaw-1_VSVG_control_1,Lib-1,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-1_2021-10-28_thaw-1_VSVG_control_1,Lib-1_2021-12-14_thaw-1_no-antibody_control_1,Lib-1_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-1_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,45131,Delta,Delta-1-1
1,2021-10-28_thaw-1_VSVG_control_2,Lib-1,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-1_2021-10-28_thaw-1_VSVG_control_2,Lib-1_2021-12-14_thaw-1_no-antibody_control_2,Lib-1_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-1_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,45131,Delta,Delta-1-2
2,2021-10-28_thaw-1_VSVG_control_1,Lib-3,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-3_2021-10-28_thaw-1_VSVG_control_1,Lib-3_2021-12-14_thaw-1_no-antibody_control_1,Lib-3_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-3_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,37703,Delta,Delta-3-1
3,2021-10-28_thaw-1_VSVG_control_2,Lib-3,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-3_2021-10-28_thaw-1_VSVG_control_2,Lib-3_2021-12-14_thaw-1_no-antibody_control_2,Lib-3_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-3_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,37703,Delta,Delta-3-2
4,2021-10-28_thaw-1_VSVG_control_1,Lib-4,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-4_2021-10-28_thaw-1_VSVG_control_1,Lib-4_2021-12-14_thaw-1_no-antibody_control_1,Lib-4_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-4_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,39694,Delta,Delta-4-1
5,2021-10-28_thaw-1_VSVG_control_2,Lib-4,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-4_2021-10-28_thaw-1_VSVG_control_2,Lib-4_2021-12-14_thaw-1_no-antibody_control_2,Lib-4_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-4_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,39694,Delta,Delta-4-2
6,2021-10-28_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1,Lib-2_2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-1
7,2021-10-28_thaw-1_VSVG_control_2,Lib-2,thaw-1,2,2021-11-28_thaw-1_no-antibody_control_2,Lib-2_2021-10-28_thaw-1_VSVG_control_2,Lib-2_2021-11-28_thaw-1_no-antibody_control_2,Lib-2_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-2
8,2022-03-25_thaw-1_VSVG_control_1,Lib-1,thaw-1,1,2022-04-13_thaw-1_no-antibody_control_1,Lib-1_2022-03-25_thaw-1_VSVG_control_1,Lib-1_2022-04-13_thaw-1_no-antibody_control_1,Lib-1_2022-03-25_thaw-1_VSVG_control_1_vs_2022...,../results/Omicron_BA.1/Lib-1_2022-03-25_thaw-...,library pre_sampl...,94347,Omicron_BA.1,Omicron_BA.1-1-1
9,2022-03-25_thaw-1_VSVG_control_2,Lib-1,thaw-1,2,2022-04-13_thaw-1_no-antibody_control_2,Lib-1_2022-03-25_thaw-1_VSVG_control_2,Lib-1_2022-04-13_thaw-1_no-antibody_control_2,Lib-1_2022-03-25_thaw-1_VSVG_control_2_vs_2022...,../results/Omicron_BA.1/Lib-1_2022-03-25_thaw-...,library pre_sampl...,94347,Omicron_BA.1,Omicron_BA.1-1-2


**Query the conditions to be included in a fit**

In [35]:
func_score_data = func_score_data.query(
    f"{experiment_column}.isin({fit_params['conditions']})"
)
func_score_data

Unnamed: 0,preselection_sample,library,virus_batch,replicate,postselection_sample,preselection_library_sample,postselection_library_sample,selection_name,filename,func_sel_scores_df,len_func_sel_scores_df,homolog,homolog_exp
6,2021-10-28_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1,Lib-2_2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-1
10,2022-06-22_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2022-06-22_thaw-1_no-antibody_control_1,Lib-2_2022-06-22_thaw-1_VSVG_control_1,Lib-2_2022-06-22_thaw-1_no-antibody_control_1,Lib-2_2022-06-22_thaw-1_VSVG_control_1_vs_2022...,../results/Omicron_BA.1/Lib-2_2022-06-22_thaw-...,library pre_samp...,140643,Omicron_BA.1,Omicron_BA.1-2-1


In [36]:
func_score_df = pd.DataFrame()
for idx, row in tqdm(func_score_data.iterrows(), total=len(func_score_data)):
    df = row.func_sel_scores_df.assign(homolog=row.homolog)
    df = df.assign(library = row.library)
    df = df.assign(replicate = row.replicate)
    exp_func_score_df = df.assign(homolog_exp=row.homolog_exp)
    func_score_df = pd.concat([func_score_df, exp_func_score_df])

  0%|          | 0/2 [00:00<?, ?it/s]

**Optionally subset the variants**

In [37]:
if fit_params["sample"]:
    func_score_df = func_score_df.sample(fit_params["sample"])

**remove all variants with string sites and stop codon wildtypes**

In [38]:
func_score_df.aa_substitutions_reference.fillna("", inplace=True)
gapped_sub_vars = []
stop_wt_vars = []
non_numeric_sites = []
for idx, row in tqdm(func_score_df.iterrows(), total=len(func_score_df)):
    if "-" in row[substitution_column]:
        gapped_sub_vars.append(idx)
    for sub in row[substitution_column].split():
        if sub[0] == "*":
            stop_wt_vars.append(idx)
        if not sub[-2].isnumeric():
            non_numeric_sites.append(idx)

to_drop = set.union(set(gapped_sub_vars), set(stop_wt_vars), set(non_numeric_sites))
func_score_df.drop(to_drop, inplace=True)

  0%|          | 0/3000 [00:00<?, ?it/s]

**Optionally, scale the counts**

In [39]:
# TODO re-write and make function
# def normalize_by_freq()?
if fit_params['scale_counts']:
    dfs = []
    for (h, hdf) in func_score_df.groupby(fit_params["fs_scaling_group_column"]):
        n_post_counts = sum(hdf['post_count'])
        if 'Delta' in h:
            bottleneck = 1e5
            scaling_factor = bottleneck / n_post_counts # scaling_factor = 0.05
        else:
            bottleneck = 1e5
            scaling_factor = bottleneck / n_post_counts # scaling_factor = 0.05
        hdf['orig_post_count'] = hdf['post_count']
        hdf['post_count'] *= scaling_factor
        hdf['post_count_wt'] *= scaling_factor
        print(h, n_post_counts, round(scaling_factor, 2), round(sum(hdf['post_count']),2))

        # Recompute enrichment ratios with new counts
        hdf['pre_count_ps'] = hdf['pre_count'] + fit_params["pseudocount"]
        hdf['post_count_ps'] = hdf['post_count'] + fit_params["pseudocount"]
        hdf['pre_count_wt_ps'] = hdf['pre_count_wt'] + fit_params["pseudocount"]
        hdf['post_count_wt_ps'] = hdf['post_count_wt'] + fit_params["pseudocount"]

        total_pre_count = sum(hdf['pre_count_ps'])
        total_post_count = sum(hdf['post_count_ps'])

        hdf['pre_freq'] = hdf['pre_count_ps'] / total_pre_count
        hdf['post_freq'] = hdf['post_count_ps'] / total_post_count
        hdf['pre_freq_wt'] = hdf['pre_count_wt_ps'] / total_pre_count
        hdf['post_freq_wt'] = hdf['post_count_wt_ps'] / total_post_count

        hdf['wt_e'] = hdf['post_freq_wt'] / hdf['pre_freq_wt']
        hdf['var_e'] = hdf['post_freq'] / hdf['pre_freq']
        hdf['e'] = hdf['var_e'] / hdf['wt_e']
        #hdf.dropna(subset=['e'], inplace=True)
        hdf['func_score'] = hdf['e'].apply(lambda x: math.log(x, 2))
        dfs.append(hdf)
    func_score_df = pd.concat(dfs)

**Drop all variants with pre-counts below a threshold.**

In [40]:
# Drop barcoded variants with pre-counts below a threshold
n_pre_threshold = len(func_score_df)
func_score_df = func_score_df[func_score_df['pre_count'] >= fit_params["min_pre_counts"]]
print(f"Of {n_pre_threshold} variants, {n_pre_threshold - len(func_score_df)} had fewer than {fit_params['min_pre_counts']} counts before selection, and were filtered out")

Of 2849 variants, 280 had fewer than 100 counts before selection, and were filtered out


**Optionally throw all variants with stop codons.**

In [41]:
# results["include_stop_variants"] = True
if not fit_params["include_stop_variants"]:
    stop_variants = [idx for idx, row in func_score_df.iterrows() if "*" in row[substitution_column]]
    func_score_df = func_score_df.drop(stop_variants)

**Optionally, aggregate variant functional scores across barcode replicates**

In [42]:
# if fit_params["agg_variants"]:
#     func_score_df = func_score_df.groupby([substitution_column, experiment_column]).mean().reset_index()
# func_score_df["pre_count"] = func_score_df["pre_count"].astype(int)
# func_score_df["post_count"] = func_score_df["post_count"].astype(int)

**Optionally, clip the target functional scores**

In [43]:
if fit_params['clip_target']:
    func_score_df["func_score"] = func_score_df["func_score"].clip(*fit_params["clip_target"])

In [44]:
func_score_df.rename({"aa_substitutions_reference":"aa_substitutions"}, axis=1, inplace=True)
func_score_df.rename({"homolog_exp":"condition"}, axis=1, inplace=True)

In [45]:
func_score_df.head()

Unnamed: 0,library,pre_sample,post_sample,barcode,func_score,func_score_var,pre_count,post_count,pre_count_wt,post_count_wt,pseudocount,n_codon_substitutions,aa_substitutions_sequential,n_aa_substitutions,aa_substitutions,pre_count_threshold,homolog,replicate,condition
4951,Lib-2,2022-06-22_thaw-1_VSVG_control_1,2022-06-22_thaw-1_no-antibody_control_1,ACCGGTAACAGGTCAC,-3.5,0.0294,1469,74,11338665,14040452,0.5,4,D136A R1011G C1233Y G1243S,4,D138A R1014G C1236Y G1246S,66,Omicron_BA.1,1,Omicron_BA.1-2-1
2815,Lib-2,2022-06-22_thaw-1_VSVG_control_1,2022-06-22_thaw-1_no-antibody_control_1,CGGGGACATAACGACA,-0.1798,0.0023,1738,1900,11338665,14040452,0.5,0,,0,,66,Omicron_BA.1,1,Omicron_BA.1-2-1
36799,Lib-2,2022-06-22_thaw-1_VSVG_control_1,2022-06-22_thaw-1_no-antibody_control_1,CCTGAAAGTTAGCATA,0.2163,0.006,590,849,11338665,14040452,0.5,1,,0,,66,Omicron_BA.1,1,Omicron_BA.1-2-1
63538,Lib-2,2022-06-22_thaw-1_VSVG_control_1,2022-06-22_thaw-1_no-antibody_control_1,TGCGACCACACACGAG,0.3322,0.0091,377,588,11338665,14040452,0.5,0,,0,,66,Omicron_BA.1,1,Omicron_BA.1-2-1
55782,Lib-2,2022-06-22_thaw-1_VSVG_control_1,2022-06-22_thaw-1_no-antibody_control_1,TGCTCCATTAATGAAA,-0.1248,0.0092,427,485,11338665,14040452,0.5,4,V67N S146K E658K,3,V67N S151K E661K,66,Omicron_BA.1,1,Omicron_BA.1-2-1


## `MultiDmsData`

In [63]:
data = multidms.MultiDmsData(
    func_score_df,
    alphabet= multidms.AAS_WITHSTOP,
    reference="Delta-2-1"
)

Found 495 site(s) lacking data in at least one condition.
901 of the 2076 variants were removed because they had mutations at the above sites, leaving 1175 variants.


100%|███████████████████████████████████████████████████████████████████████████████████████| 799/799 [00:02<00:00, 272.52it/s]


There were 0 cache hits in total for condition Omicron_BA.1-2-1.


In [64]:
data.mutations[:5]

('M1I', 'M1T', 'V3A', 'V3F', 'V3G')

In [65]:
data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen
0,M1I,M,1,I,2
1,M1T,M,1,T,1
2,V3A,V,3,A,1
3,V3F,V,3,F,3
4,V3G,V,3,G,2


In [67]:
data.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref
0,Delta-2-1,,66,-0.201139,True,
1,Delta-2-1,A1020Y N1135S,1,0.0294,True,A1020Y N1135S
3,Delta-2-1,A1080S,1,-0.5419,True,A1080S
7,Delta-2-1,A222S,1,-0.65,True,A222S
8,Delta-2-1,A222V,1,0.1309,True,A222V


In [68]:
data.conditions

('Omicron_BA.1-2-1', 'Delta-2-1')

In [69]:
data.reference

'Delta-2-1'

In [70]:
data.binarymaps['X'] # 'y' also exists here

{'Delta-2-1': BCOO(int8[376, 1584], nse=836),
 'Omicron_BA.1-2-1': BCOO(int8[799, 1584], nse=20461)}

## `MultiDmsModel`

In short the `MultiDmsModel` object initialization includes:
   1. initialize attributes
   2. initialize params depending on model composits chosen
   3. compile model composition

In [71]:
model = multidms.MultiDmsModel(
    data,
    latent_model="phi",
    epistatic_model="identity",
    output_activation="identity"
)

In [72]:
model.__dict__.keys()

dict_keys(['gamma_corrected', 'conditional_shifts', 'conditional_c', '_data', 'params', '_model'])

In [73]:
print(f"Before fitting, the loss of this model (with initial params) is {round(model.loss, 3)}")
model.fit(lasso_shift=1e-5, maxiter=10)
print(f"After fitting, the loss of this model (with initial params) is {round(model.loss, 3)}")

Before fitting, the loss of this model (with initial params) is 8.121
After fitting, the loss of this model (with initial params) is 4.761


In [74]:
model.data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen
0,M1I,M,1,I,2
1,M1T,M,1,T,1
2,V3A,V,3,A,1
3,V3F,V,3,F,3
4,V3G,V,3,G,2


In [75]:
model.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen,β,S_Omicron_BA.1-2-1,F_Omicron_BA.1-2-1,S_Delta-2-1,F_Delta-2-1
0,M1I,M,1,I,2,-0.824364,0.0,2.433046,0.0,2.433046
1,M1T,M,1,T,1,0.88223,-0.002143,4.137496,0.0,4.139639
2,V3A,V,3,A,1,-1.064535,-0.002143,2.190732,0.0,2.192874
3,V3F,V,3,F,3,0.886644,-0.002143,4.141911,0.0,4.144054
4,V3G,V,3,G,2,-1.660799,0.002002,1.598612,0.0,1.59661


In [76]:
model.data.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref
0,Delta-2-1,,66,-0.201139,True,
1,Delta-2-1,A1020Y N1135S,1,0.0294,True,A1020Y N1135S
3,Delta-2-1,A1080S,1,-0.5419,True,A1080S
7,Delta-2-1,A222S,1,-0.65,True,A222S
8,Delta-2-1,A222V,1,0.1309,True,A222V


In [77]:
model.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref,predicted_latent,predicted_func_score,corrected_func_score
0,Delta-2-1,,66,-0.201139,True,,3.257409,3.257409,-0.201139
1,Delta-2-1,A1020Y N1135S,1,0.0294,True,A1020Y N1135S,5.09066,5.09066,0.0294
3,Delta-2-1,A1080S,1,-0.5419,True,A1080S,4.396266,4.396266,-0.5419
7,Delta-2-1,A222S,1,-0.65,True,A222S,5.353698,5.353698,-0.65
8,Delta-2-1,A222V,1,0.1309,True,A222V,2.816859,2.816859,0.1309


In [87]:
sigmoid_model = multidms.MultiDmsModel(
    data,
    latent_model="phi",
    epistatic_model="sigmoid",
    output_activation="identity"
)

**The model classes share references to the same data to keep things effecient.** 

In [93]:
model.data is sigmoid_model.data

True

In [94]:
print(f"Before fitting, the loss of this model (with initial params) is {round(sigmoid_model.loss, 3)}")
sigmoid_model.fit(lasso_shift=1e-5, maxiter=10)
print(f"After fitting, the loss of this model (with initial params) is {round(sigmoid_model.loss, 3)}")

Before fitting, the loss of this model (with initial params) is 1.437
After fitting, the loss of this model (with initial params) is 1.437


In [95]:
sigmoid_softplus_model = multidms.MultiDmsModel(
    data,
    latent_model="phi",
    epistatic_model="sigmoid",
    output_activation="softplus"
)
print(f"Before fitting, the loss of this model (with initial params) is {round(sigmoid_softplus_model.loss, 3)}")
sigmoid_softplus_model.fit(lasso_shift=1e-5, maxiter=10)
print(f"After fitting, the loss of this model (with initial params) is {round(sigmoid_softplus_model.loss, 3)}")

Before fitting, the loss of this model (with initial params) is 1.691
After fitting, the loss of this model (with initial params) is 1.437


In [99]:
sigmoid_model.variants_df.predicted_func_score.sum()

0.0

In [100]:
sigmoid_model.params

{'C_Delta-2-1': DeviceArray([0.], dtype=float64),
 'C_Omicron_BA.1-2-1': DeviceArray([0.], dtype=float64),
 'C_ref': DeviceArray([4.96247717], dtype=float64),
 'S_Delta-2-1': DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float64),
 'S_Omicron_BA.1-2-1': DeviceArray([ 0.        , -0.00036382, -0.00036382, ...,  0.        ,
               0.        ,  0.00152727], dtype=float64),
 'α': {'ge_bias': DeviceArray([0.], dtype=float64),
  'ge_scale': DeviceArray([0.], dtype=float64)},
 'β': DeviceArray([-0.8346446 ,  0.88331895, -1.06344564, ..., -1.85269346,
               0.87187642, -0.19739826], dtype=float64),
 'γ_Delta-2-1': DeviceArray([0.], dtype=float64),
 'γ_Omicron_BA.1-2-1': DeviceArray([0.61121915], dtype=float64)}