# `multidms` fitting pipeline

Here, we demonstrate the pipeline for fitting a `multidms` model on some data using the tools available in the package. Currently, we show how to pre-process data from various dms experiments targeting Delta and Omicron spike protein homologs.

In [79]:
# built-in libraries
import os
import sys
# import pickle
from itertools import combinations
import importlib
import math
import re
# from timeit import default_timer as timer
import time
import json

# external dependencies
import pandas as pd
# import jax
# import jax.numpy as jnp
# from jax.experimental import sparse
# from jaxopt import ProximalGradient
# import jaxopt
import numpy as onp
# from scipy.stats import pearsonr
# import binarymap as bmap
from tqdm.notebook import tqdm
# import matplotlib.pyplot as plt
# import seaborn as sns
# import plotnine

# local dependencies
# sys.path.append("..")
# import multidms
import multidms
# from multidms import Multidms, AAS_WITHSTOP

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


**First, we'll define a few global variables.**

In [2]:
# The substitution column to use from the functional score dataframes
substitution_column = 'aa_substitutions_reference'

# The column providing which individual conditions each variant observation was derived from.
experiment_column = 'homolog_exp'

**Define all the knobs which define the pipeline's behavior**

In [38]:
fit_params = {
    # pre-processing params
    "scale_counts" : False,
    "fs_scaling_group_column" : "homolog_exp",
    "min_pre_counts" : 100,
    "pseudocount" : 0.1,
    "agg_variants" : True,
    "sample" : False, #3000,
    "min_pre_counts" : 100,
    "clip_target" : (-3.5, 2.5),
    "include_stop_variants" : True,
    "shift_func_score_target_nonref" : False,
    # conditions to include in fit
    "reference_condition" : "Delta-2-1",
    "conditions" : ["Delta-2-1", "Omicron_BA.1-2-1"]
}

**Define the attributes you'll have access to after each fit**

In [39]:
# cols = list(fit_params.keys()) + [
#     "tuned_model_params", "mutation_effects_df", "variant_prediction_df", "site_map"
# ]

**uncomment if you want to reset all results, or if this is the first run**

In [40]:
# results = pd.DataFrame(columns = cols)

## Pre-Proccessing

**Read in the dms data and list all available experimental conditions**

In [41]:
func_score_data = pd.DataFrame()
sites = {}
wt_seqs = {}

for homolog in ["Delta", "Omicron_BA.1", "Omicron_BA.2"]:
    
    # functional scores
    func_sel = pd.read_csv(f"../results/{homolog}/functional_selections.csv")
    func_sel = func_sel.assign(
        filename = f"../results/{homolog}/" + 
        func_sel.library + "_" + 
        func_sel.preselection_sample + 
        "_vs_" + func_sel.postselection_sample + 
        "_func_scores.csv"
    )
    func_sel = func_sel.assign(
        func_sel_scores_df = func_sel.filename.apply(lambda f: pd.read_csv(f))
    )
    func_sel = func_sel.assign(
        len_func_sel_scores_df = func_sel.func_sel_scores_df.apply(lambda x: len(x))
    )
    fun_sel = func_sel.assign(homolog = homolog)
    func_score_data = pd.concat([func_score_data, fun_sel]).reset_index(drop=True)

# Add a column that gives a unique ID to each homolog/DMS experiment
func_score_data['homolog_exp'] = func_score_data.apply(
    lambda row: f"{row['homolog']}-{row['library']}-{row['replicate']}".replace('-Lib',''),
    axis=1
)
func_score_data

Unnamed: 0,preselection_sample,library,virus_batch,replicate,postselection_sample,preselection_library_sample,postselection_library_sample,selection_name,filename,func_sel_scores_df,len_func_sel_scores_df,homolog,homolog_exp
0,2021-10-28_thaw-1_VSVG_control_1,Lib-1,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-1_2021-10-28_thaw-1_VSVG_control_1,Lib-1_2021-12-14_thaw-1_no-antibody_control_1,Lib-1_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-1_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,45131,Delta,Delta-1-1
1,2021-10-28_thaw-1_VSVG_control_2,Lib-1,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-1_2021-10-28_thaw-1_VSVG_control_2,Lib-1_2021-12-14_thaw-1_no-antibody_control_2,Lib-1_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-1_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,45131,Delta,Delta-1-2
2,2021-10-28_thaw-1_VSVG_control_1,Lib-3,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-3_2021-10-28_thaw-1_VSVG_control_1,Lib-3_2021-12-14_thaw-1_no-antibody_control_1,Lib-3_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-3_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,37703,Delta,Delta-3-1
3,2021-10-28_thaw-1_VSVG_control_2,Lib-3,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-3_2021-10-28_thaw-1_VSVG_control_2,Lib-3_2021-12-14_thaw-1_no-antibody_control_2,Lib-3_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-3_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,37703,Delta,Delta-3-2
4,2021-10-28_thaw-1_VSVG_control_1,Lib-4,thaw-1,1,2021-12-14_thaw-1_no-antibody_control_1,Lib-4_2021-10-28_thaw-1_VSVG_control_1,Lib-4_2021-12-14_thaw-1_no-antibody_control_1,Lib-4_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-4_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,39694,Delta,Delta-4-1
5,2021-10-28_thaw-1_VSVG_control_2,Lib-4,thaw-1,2,2021-12-14_thaw-1_no-antibody_control_2,Lib-4_2021-10-28_thaw-1_VSVG_control_2,Lib-4_2021-12-14_thaw-1_no-antibody_control_2,Lib-4_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-4_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,39694,Delta,Delta-4-2
6,2021-10-28_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1,Lib-2_2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-1
7,2021-10-28_thaw-1_VSVG_control_2,Lib-2,thaw-1,2,2021-11-28_thaw-1_no-antibody_control_2,Lib-2_2021-10-28_thaw-1_VSVG_control_2,Lib-2_2021-11-28_thaw-1_no-antibody_control_2,Lib-2_2021-10-28_thaw-1_VSVG_control_2_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-2
8,2022-03-25_thaw-1_VSVG_control_1,Lib-1,thaw-1,1,2022-04-13_thaw-1_no-antibody_control_1,Lib-1_2022-03-25_thaw-1_VSVG_control_1,Lib-1_2022-04-13_thaw-1_no-antibody_control_1,Lib-1_2022-03-25_thaw-1_VSVG_control_1_vs_2022...,../results/Omicron_BA.1/Lib-1_2022-03-25_thaw-...,library pre_sampl...,94347,Omicron_BA.1,Omicron_BA.1-1-1
9,2022-03-25_thaw-1_VSVG_control_2,Lib-1,thaw-1,2,2022-04-13_thaw-1_no-antibody_control_2,Lib-1_2022-03-25_thaw-1_VSVG_control_2,Lib-1_2022-04-13_thaw-1_no-antibody_control_2,Lib-1_2022-03-25_thaw-1_VSVG_control_2_vs_2022...,../results/Omicron_BA.1/Lib-1_2022-03-25_thaw-...,library pre_sampl...,94347,Omicron_BA.1,Omicron_BA.1-1-2


**Query the conditions to be included in a fit**

In [42]:
func_score_data = func_score_data.query(
    f"{experiment_column}.isin({fit_params['conditions']})"
)
func_score_data

Unnamed: 0,preselection_sample,library,virus_batch,replicate,postselection_sample,preselection_library_sample,postselection_library_sample,selection_name,filename,func_sel_scores_df,len_func_sel_scores_df,homolog,homolog_exp
6,2021-10-28_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1,Lib-2_2021-11-28_thaw-1_no-antibody_control_1,Lib-2_2021-10-28_thaw-1_VSVG_control_1_vs_2021...,../results/Delta/Lib-2_2021-10-28_thaw-1_VSVG_...,library pre_sampl...,41586,Delta,Delta-2-1
10,2022-06-22_thaw-1_VSVG_control_1,Lib-2,thaw-1,1,2022-06-22_thaw-1_no-antibody_control_1,Lib-2_2022-06-22_thaw-1_VSVG_control_1,Lib-2_2022-06-22_thaw-1_no-antibody_control_1,Lib-2_2022-06-22_thaw-1_VSVG_control_1_vs_2022...,../results/Omicron_BA.1/Lib-2_2022-06-22_thaw-...,library pre_samp...,140643,Omicron_BA.1,Omicron_BA.1-2-1


In [43]:
func_score_df = pd.DataFrame()
for idx, row in tqdm(func_score_data.iterrows(), total=len(func_score_data)):
    df = row.func_sel_scores_df.assign(homolog=row.homolog)
    df = df.assign(library = row.library)
    df = df.assign(replicate = row.replicate)
    exp_func_score_df = df.assign(homolog_exp=row.homolog_exp)
    func_score_df = pd.concat([func_score_df, exp_func_score_df])

  0%|          | 0/2 [00:00<?, ?it/s]

**Optionally subset the variants**

In [44]:
if fit_params["sample"]:
    func_score_df = func_score_df.sample(fit_params["sample"])

**remove all variants with string sites and stop codon wildtypes**

In [45]:
func_score_df.aa_substitutions_reference.fillna("", inplace=True)
gapped_sub_vars = []
stop_wt_vars = []
non_numeric_sites = []
for idx, row in tqdm(func_score_df.iterrows(), total=len(func_score_df)):
    if "-" in row[substitution_column]:
        gapped_sub_vars.append(idx)
    for sub in row[substitution_column].split():
        if sub[0] == "*":
            stop_wt_vars.append(idx)
        if not sub[-2].isnumeric():
            non_numeric_sites.append(idx)

to_drop = set.union(set(gapped_sub_vars), set(stop_wt_vars), set(non_numeric_sites))
func_score_df.drop(to_drop, inplace=True)

  0%|          | 0/182229 [00:00<?, ?it/s]

**Optionally, scale the counts**

In [46]:
# TODO re-write and make function
# def normalize_by_freq()?
if fit_params['scale_counts']:
    dfs = []
    for (h, hdf) in func_score_df.groupby(fit_params["fs_scaling_group_column"]):
        n_post_counts = sum(hdf['post_count'])
        if 'Delta' in h:
            bottleneck = 1e5
            scaling_factor = bottleneck / n_post_counts # scaling_factor = 0.05
        else:
            bottleneck = 1e5
            scaling_factor = bottleneck / n_post_counts # scaling_factor = 0.05
        hdf['orig_post_count'] = hdf['post_count']
        hdf['post_count'] *= scaling_factor
        hdf['post_count_wt'] *= scaling_factor
        print(h, n_post_counts, round(scaling_factor, 2), round(sum(hdf['post_count']),2))

        # Recompute enrichment ratios with new counts
        hdf['pre_count_ps'] = hdf['pre_count'] + fit_params["pseudocount"]
        hdf['post_count_ps'] = hdf['post_count'] + fit_params["pseudocount"]
        hdf['pre_count_wt_ps'] = hdf['pre_count_wt'] + fit_params["pseudocount"]
        hdf['post_count_wt_ps'] = hdf['post_count_wt'] + fit_params["pseudocount"]

        total_pre_count = sum(hdf['pre_count_ps'])
        total_post_count = sum(hdf['post_count_ps'])

        hdf['pre_freq'] = hdf['pre_count_ps'] / total_pre_count
        hdf['post_freq'] = hdf['post_count_ps'] / total_post_count
        hdf['pre_freq_wt'] = hdf['pre_count_wt_ps'] / total_pre_count
        hdf['post_freq_wt'] = hdf['post_count_wt_ps'] / total_post_count

        hdf['wt_e'] = hdf['post_freq_wt'] / hdf['pre_freq_wt']
        hdf['var_e'] = hdf['post_freq'] / hdf['pre_freq']
        hdf['e'] = hdf['var_e'] / hdf['wt_e']
        #hdf.dropna(subset=['e'], inplace=True)
        hdf['func_score'] = hdf['e'].apply(lambda x: math.log(x, 2))
        dfs.append(hdf)
    func_score_df = pd.concat(dfs)

**Drop all variants with pre-counts below a threshold.**

In [47]:
# Drop barcoded variants with pre-counts below a threshold
n_pre_threshold = len(func_score_df)
func_score_df = func_score_df[func_score_df['pre_count'] >= fit_params["min_pre_counts"]]
print(f"Of {n_pre_threshold} variants, {n_pre_threshold - len(func_score_df)} had fewer than {fit_params['min_pre_counts']} counts before selection, and were filtered out")

Of 169856 variants, 16326 had fewer than 100 counts before selection, and were filtered out


**Optionally throw all variants with stop codons.**

In [48]:
# results["include_stop_variants"] = True
if not fit_params["include_stop_variants"]:
    stop_variants = [idx for idx, row in func_score_df.iterrows() if "*" in row[substitution_column]]
    func_score_df = func_score_df.drop(stop_variants)

**Optionally, aggregate variant functional scores across barcode replicates**

In [49]:
# if fit_params["agg_variants"]:
#     func_score_df = func_score_df.groupby([substitution_column, experiment_column]).mean().reset_index()
# func_score_df["pre_count"] = func_score_df["pre_count"].astype(int)
# func_score_df["post_count"] = func_score_df["post_count"].astype(int)

**Optionally, clip the target functional scores**

In [50]:
if fit_params['clip_target']:
    func_score_df["func_score"] = func_score_df["func_score"].clip(*fit_params["clip_target"])

In [51]:
func_score_df.rename({"aa_substitutions_reference":"aa_substitutions"}, axis=1, inplace=True)
func_score_df.rename({"homolog_exp":"condition"}, axis=1, inplace=True)

In [52]:
func_score_df.head()

Unnamed: 0,library,pre_sample,post_sample,barcode,func_score,func_score_var,pre_count,post_count,pre_count_wt,post_count_wt,pseudocount,n_codon_substitutions,aa_substitutions_sequential,n_aa_substitutions,aa_substitutions,pre_count_threshold,homolog,replicate,condition
0,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,CCGAGTAATCCAATAA,-0.4039,0.0005,10237,7909,1909460,1951855,0.5,3,D251A M1235V,2,D253A M1237V,20,Delta,1,Delta-2-1
1,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,TAAGGACACTCACAAA,-0.3398,0.0005,8528,6888,1909460,1951855,0.5,3,H623L A1014V T1115M,3,H625L A1016V T1117M,20,Delta,1,Delta-2-1
2,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,ATATTAGTTTATGCCT,-0.1499,0.0005,7979,7351,1909460,1951855,0.5,2,G767H,1,G769H,20,Delta,1,Delta-2-1
3,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,GAAGATCACTGGTCGA,-0.1554,0.0005,7956,7302,1909460,1951855,0.5,3,P26T S475D,2,P26T S477D,20,Delta,1,Delta-2-1
4,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,CAAGTGGGTAAATGAT,-1.053,0.0009,7340,3616,1909460,1951855,0.5,4,G75W D176E P320Q H1099R,4,G75W D178E P322Q H1101R,20,Delta,1,Delta-2-1


## `MultiDmsData`

In [98]:
data = multidms.MultiDmsData(
    func_score_df,
    alphabet= multidms.AAS_WITHSTOP,
    reference="Delta-2-1"
)

In [118]:
data.mutations[:5]

('M1F', 'M1I', 'M1K', 'M1L', 'M1T')

In [119]:
data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen
0,M1F,M,1,F,2
1,M1I,M,1,I,8
2,M1K,M,1,K,3
3,M1L,M,1,L,4
4,M1T,M,1,T,7


In [120]:
data.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref
0,Delta-2-1,,4344,-0.158614,True,
1,Delta-2-1,A1015D,1,-1.6764,True,A1015D
2,Delta-2-1,A1015D E1188Q,1,-1.1184,True,A1015D E1188Q
3,Delta-2-1,A1015D T1027S,1,-3.5,True,A1015D T1027S
4,Delta-2-1,A1016S,3,-0.222967,True,A1016S


In [101]:
data.conditions

('Delta-2-1', 'Omicron_BA.1-2-1')

In [102]:
data.reference

'Delta-2-1'

In [103]:
data.binarymaps['X'] # this is poorly named, 'y' also exists here

{'Delta-2-1': BCOO(int8[24271, 9655], nse=65004),
 'Omicron_BA.1-2-1': BCOO(int8[66785, 9655], nse=2225339)}

## `MultiDmsModel`

In short the `MultiDmsModel` object initialization includes:
   1. initialize attributes
   2. initialize params depending on model composits chosen
   3. compile model composition

In [104]:
model = multidms.MultiDmsModel(
    data,
    latent_model="phi",
    epistatic_model="identity",
    output_activation="identity"
)

In [105]:
model.__dict__.keys()

dict_keys(['gamma_corrected', 'conditional_shifts', 'conditional_c', '_data', 'params', '_model'])

In [106]:
print(f"Before fitting, the loss of this model (with initial params) is {round(model.loss, 3)}")
start = time.time()
model.fit(lasso_shift=1e-5, maxiter=1000)
end = time.time()
print(f"Fitting took {end - start}")
print(f"After fitting, the loss of this model (with initial params) is {round(model.loss, 3)}")

Before fitting, the loss of this model (with initial params) is 11.408
Fitting took 18.7742919921875
After fitting, the loss of this model (with initial params) is 0.873


In [107]:
model.data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen
0,M1F,M,1,F,2
1,M1I,M,1,I,8
2,M1K,M,1,K,3
3,M1L,M,1,L,4
4,M1T,M,1,T,7


In [108]:
model.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen,β,S_Delta-2-1,F_Delta-2-1,S_Omicron_BA.1-2-1,F_Omicron_BA.1-2-1
0,M1F,M,1,F,2,1.361793,0.0,1.00078,0.0,1.00078
1,M1I,M,1,I,8,-1.012618,0.0,-1.373631,-0.23532,-1.60895
2,M1K,M,1,K,3,0.092072,0.0,-0.268941,0.0,-0.268941
3,M1L,M,1,L,4,-0.864001,0.0,-1.225014,0.0,-1.225014
4,M1T,M,1,T,7,1.593224,0.0,1.232211,-0.454679,0.777532


In [121]:
model.data.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref
0,Delta-2-1,,4344,-0.158614,True,
1,Delta-2-1,A1015D,1,-1.6764,True,A1015D
2,Delta-2-1,A1015D E1188Q,1,-1.1184,True,A1015D E1188Q
3,Delta-2-1,A1015D T1027S,1,-3.5,True,A1015D T1027S
4,Delta-2-1,A1016S,3,-0.222967,True,A1016S


In [110]:
model.variants_df.head()

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref,predicted_latent,predicted_func_score,corrected_func_score
0,Delta-2-1,,4344,-0.158614,True,,-0.361013,-0.361013,-0.158614
1,Delta-2-1,A1015D,1,-1.6764,True,A1015D,-1.196098,-1.196098,-1.6764
2,Delta-2-1,A1015D E1188Q,1,-1.1184,True,A1015D E1188Q,-0.999916,-0.999916,-1.1184
3,Delta-2-1,A1015D T1027S,1,-3.5,True,A1015D T1027S,-1.297847,-1.297847,-3.5
4,Delta-2-1,A1016S,3,-0.222967,True,A1016S,-0.663131,-0.663131,-0.222967


In [122]:
sigmoid_model = multidms.MultiDmsModel(
    data,
    latent_model="phi",
    epistatic_model="sigmoid",
    output_activation="identity"
)

**The model classes share references to the same data to keep things effecient.** 

In [123]:
model.data is sigmoid_model.data

True

In [113]:
print(f"Before fitting, the loss of this model (with initial params) is {round(sigmoid_model.loss, 3)}")
sigmoid_model.fit(lasso_shift=1e-5, maxiter=1000)
print(f"After fitting, the loss of this model (with initial params) is {round(sigmoid_model.loss, 3)}")

Before fitting, the loss of this model (with initial params) is 1.8940000000000001
After fitting, the loss of this model (with initial params) is 0.9380000000000001


In [114]:
sigmoid_model.params

{'C_Delta-2-1': DeviceArray([0.], dtype=float64),
 'C_Omicron_BA.1-2-1': DeviceArray([0.], dtype=float64),
 'C_ref': DeviceArray([3.16808994], dtype=float64),
 'S_Delta-2-1': DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float64),
 'S_Omicron_BA.1-2-1': DeviceArray([ 0.        , -0.50970376,  0.        , ...,  0.0292292 ,
               0.64317001,  0.54683618], dtype=float64),
 'α': {'ge_bias': DeviceArray([-3.70341545], dtype=float64),
  'ge_scale': DeviceArray([3.43341845], dtype=float64)},
 'β': DeviceArray([ 1.76247481, -1.76742336,  0.69877739, ...,  0.00863932,
               1.3728258 ,  0.89143468], dtype=float64),
 'γ_Delta-2-1': DeviceArray([0.], dtype=float64),
 'γ_Omicron_BA.1-2-1': DeviceArray([-0.3866507], dtype=float64)}

In [117]:
for output_act in ["identity", "softplus", "gelu"]:
    imodel = multidms.MultiDmsModel(
        data,
        latent_model="phi",
        epistatic_model="sigmoid",
        output_activation=output_act
    )
    imodel.fit(lasso_shift=1e-5, maxiter=1000)
    print(f"loss w/ {output_act} output activation of this model (with initial params) is {round(imodel.loss, 3)}")

loss w/ identity output activation of this model (with initial params) is 0.9380000000000001
loss w/ softplus output activation of this model (with initial params) is 0.96
loss w/ gelu output activation of this model (with initial params) is 0.988
