# `multidms` fitting pipeline

Here, we demonstrate the pipeline for fitting a `multidms` model to data from [six deep mutational scanning experiments](https://github.com/dms-vep) across three homologs of the SARS-CoV-2 spike protein.

In [1]:
import os
import sys
from collections import defaultdict
import time

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as onp
from tqdm.notebook import tqdm
import jax.numpy as jnp

import multidms
%load_ext autoreload
%autoreload 2
%matplotlib inline

Read in the DMS data and list all available experimental conditions.

In [2]:
func_score_df = pd.read_csv("Delta_BA1_BA2_func_score_df.csv", index_col=0)
print(sorted(list(func_score_df.condition.unique())))

['Delta-1', 'Delta-2', 'Delta-3', 'Delta-4', 'Omicron_BA.1-1', 'Omicron_BA.1-2', 'Omicron_BA.1-3', 'Omicron_BA.2-1', 'Omicron_BA.2-2']


Choose a reference condition and the other conditions to include in the fit. The reference must be one of the available conditions.

In [3]:
reference_condition = "Delta-2"
fit_included_conditions = [
    'Delta-2', 'Delta-3', #'Delta-1', 'Delta-2', 'Delta-3', 'Delta-4', 
    'Omicron_BA.1-2', 'Omicron_BA.1-3', #'Omicron_BA.1-1', 'Omicron_BA.1-2', 'Omicron_BA.1-3', 
    'Omicron_BA.2-1', 'Omicron_BA.2-2'
]
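Because the reference must be one of the available conditions, it can help to validate the choices before going further. The following is a minimal sketch on a toy dataframe; `check_conditions` is a hypothetical helper, not part of `multidms`, and the requirement that the reference also appear in the included list is an assumption based on the filtering step that follows.

```python
import pandas as pd

# Toy stand-in for func_score_df; only the `condition` column matters here.
toy_df = pd.DataFrame(
    {"condition": ["Delta-2", "Delta-3", "Omicron_BA.1-2", "Delta-2"]}
)


def check_conditions(df, reference, included):
    """Hypothetical helper: validate the reference and included conditions."""
    available = set(df["condition"].unique())
    missing = sorted(set(included) - available)
    if missing:
        raise ValueError(f"conditions not found in the data: {missing}")
    # Assumption: the reference should survive the later filtering step,
    # so it must be listed among the included conditions.
    if reference not in included:
        raise ValueError(
            f"reference {reference!r} is not among the included conditions"
        )
    return True


ok = check_conditions(toy_df, "Delta-2", ["Delta-2", "Delta-3"])
```

In the notebook itself, the same call would take `func_score_df`, `reference_condition`, and `fit_included_conditions` as arguments.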

Filter the dataframe down to the conditions included in the fit.

In [4]:
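# Keep only the chosen conditions; copy so later assignments modify this frame
# rather than a view of the original.
func_score_df = func_score_df.query(
    "condition in @fit_included_conditions"
).copy()
# Variants without substitutions are read in as NaN; represent them as empty strings.
func_score_df["aa_substitutions"] = func_score_df["aa_substitutions"].fillna("")
func_score_df = func_score_df.sort_values("condition")
func_score_df.head()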

Unnamed: 0,library,pre_sample,post_sample,barcode,func_score,func_score_var,pre_count,post_count,pre_count_wt,post_count_wt,pseudocount,n_codon_substitutions,aa_substitutions_sequential,n_aa_substitutions,aa_substitutions,pre_count_threshold,homolog,replicate,condition
10566,Lib-2,2021-10-28_thaw-1_VSVG_control_2,2021-11-28_thaw-1_no-antibody_control_2,AACTGCTTCATAAACC,-0.2306,0.0072,641,521,2009970,1917126,0.5,2,N17T L174T,2,N17T L176T,21,Delta,2,Delta-2
24964,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,GGACGGCAGTGGGACC,-1.1103,0.0237,273,129,1909460,1951855,0.5,3,P207H A781V G1244C,3,P209H A783V G1246C,20,Delta,1,Delta-2
24962,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,GGAATAAAACAGTGAA,-0.7305,0.02,273,168,1909460,1951855,0.5,1,D80Q,1,D80Q,20,Delta,1,Delta-2
24960,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,GAGCGATAGCCTAAGC,0.4256,0.0132,273,375,1909460,1951855,0.5,3,N437Q D1039G D1137A,3,N439Q D1041G D1139A,20,Delta,1,Delta-2
24959,Lib-2,2021-10-28_thaw-1_VSVG_control_1,2021-11-28_thaw-1_no-antibody_control_1,GACGATGAATTATACA,0.9906,0.0114,273,555,1909460,1951855,0.5,5,V16I D138C F454Y T676S E988D,5,V16I D138C F456Y T678S E990D,20,Delta,1,Delta-2


## `MultiDmsData`

After the functional score dataframe for all variants has been prepared, you can initialize a `multidms.MultiDmsData` object. This will (1) convert substitution strings to be with respect to the reference (if necessary), (2) set static attributes, and (3) prepare model training data that can be shared by multiple `multidms.MultiDmsModel` objects.

In [5]:
# uncomment to see docs
# help(multidms.MultiDmsData)

In [6]:
data = multidms.MultiDmsData(
    func_score_df,
    alphabet= multidms.AAS_WITHSTOP,
    condition_colors = sns.color_palette("Paired"),
    reference=reference_condition
)

100%|█████████████████████████████████████████████████████| 51258/51258 [02:14<00:00, 381.45it/s]
100%|█████████████████████████████████████████████████████| 43804/43804 [01:56<00:00, 375.47it/s]
100%|█████████████████████████████████████████████████████| 45965/45965 [02:20<00:00, 327.74it/s]
100%|█████████████████████████████████████████████████████| 44003/44003 [02:10<00:00, 336.20it/s]


We can now view a few useful attributes of the data object:

In [7]:
data.site_map.head()

Unnamed: 0,Delta-2,Delta-3,Omicron_BA.1-2,Omicron_BA.1-3,Omicron_BA.2-1,Omicron_BA.2-2
1015,A,A,A,A,A,A
1188,E,E,E,E,E,E
1027,T,T,T,T,T,T
1016,A,A,A,A,A,A
1175,S,S,S,S,S,S


In [8]:
data.mutations[:5]

('M1F', 'M1I', 'M1K', 'M1L', 'M1N')

In [9]:
data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen_Delta-2,times_seen_Delta-3,times_seen_Omicron_BA.1-2,times_seen_Omicron_BA.1-3,times_seen_Omicron_BA.2-1,times_seen_Omicron_BA.2-2
0,M1F,M,1,F,1.0,5.0,0.0,0.0,0.0,0.0
1,M1I,M,1,I,2.0,6.0,4.0,6.0,6.0,7.0
2,M1K,M,1,K,2.0,3.0,0.0,0.0,0.0,0.0
3,M1L,M,1,L,1.0,2.0,0.0,0.0,0.0,1.0
4,M1N,M,1,N,0.0,3.0,0.0,0.0,0.0,0.0


In [10]:
data.conditions

('Delta-2',
 'Delta-3',
 'Omicron_BA.1-2',
 'Omicron_BA.1-3',
 'Omicron_BA.2-1',
 'Omicron_BA.2-2')

In [11]:
data.reference

'Delta-2'

In [12]:
data.binarymaps

{'Delta-2': <binarymap.binarymap.BinaryMap at 0x7f488ed0fcd0>,
 'Delta-3': <binarymap.binarymap.BinaryMap at 0x7f4887fd32e0>,
 'Omicron_BA.1-2': <binarymap.binarymap.BinaryMap at 0x7f4887d4b250>,
 'Omicron_BA.1-3': <binarymap.binarymap.BinaryMap at 0x7f4887d4bbb0>,
 'Omicron_BA.2-1': <binarymap.binarymap.BinaryMap at 0x7f4886b00fd0>,
 'Omicron_BA.2-2': <binarymap.binarymap.BinaryMap at 0x7f4886b01a50>}

In [13]:
data.non_identical_sites

frozendict.frozendict({'Delta-2': [], 'Delta-3': array([], dtype=int64), 'Omicron_BA.1-2': array([655, 484, 440, 679, 681, 501, 950, 856, 417, 796, 764, 496, 452,
       373, 375, 477, 493, 954, 446,  67,  95, 981, 547, 212, 969, 339,
       156, 498, 371,  19, 505]), 'Omicron_BA.1-3': array([655, 484, 440, 679, 681, 501, 950, 856, 417, 796, 764, 496, 452,
       373, 375, 477, 493, 954, 446,  67,  95, 981, 547, 212, 969, 339,
       156, 498, 371,  19, 505]), 'Omicron_BA.2-1': array([655, 484, 440, 679, 681, 376, 501, 950, 408, 417, 796, 764, 452,
       373, 375, 477, 493, 954, 405,  27, 969, 339, 156, 498, 371, 213,
        19, 505]), 'Omicron_BA.2-2': array([655, 484, 440, 679, 681, 376, 501, 950, 408, 417, 796, 764, 452,
       373, 375, 477, 493, 954, 405,  27, 969, 339, 156, 498, 371, 213,
        19, 505])})

In [14]:
data.non_identical_mutations

frozendict.frozendict({'Delta-2': '', 'Delta-3': '', 'Omicron_BA.1-2': 'H655Y E484A N440K N679K R681H N501Y N950D N856K K417N D796Y N764K G496S R452L S373P S375F S477N Q493R Q954H G446S A67V T95I L981F T547K L212I N969K G339D G156E Q498R S371L R19T Y505H', 'Omicron_BA.1-3': 'H655Y E484A N440K N679K R681H N501Y N950D N856K K417N D796Y N764K G496S R452L S373P S375F S477N Q493R Q954H G446S A67V T95I L981F T547K L212I N969K G339D G156E Q498R S371L R19T Y505H', 'Omicron_BA.2-1': 'H655Y E484A N440K N679K R681H T376A N501Y N950D R408S K417N D796Y N764K R452L S373P S375F S477N Q493R Q954H D405N A27S N969K G339D G156E Q498R S371F V213G R19I Y505H', 'Omicron_BA.2-2': 'H655Y E484A N440K N679K R681H T376A N501Y N950D R408S K417N D796Y N764K R452L S373P S375F S477N Q493R Q954H D405N A27S N969K G339D G156E Q498R S371F V213G R19I Y505H'})
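The `non_identical_sites` and `non_identical_mutations` attributes list where each condition's wildtype differs from the reference. As a rough illustration of that idea, the sketch below derives such mutation strings from a small site map. The toy site map uses four sites whose wildtypes appear in the outputs above, restricted to three conditions; the function `non_identical` is hypothetical and is not claimed to be how `multidms` computes these attributes internally.

```python
import pandas as pd

# Toy site map: rows are sites, columns are conditions, values are wildtypes.
toy_site_map = pd.DataFrame(
    {
        "Delta-2": ["A", "T", "G", "A"],
        "Omicron_BA.1-2": ["V", "I", "D", "A"],
        "Omicron_BA.2-1": ["A", "T", "D", "A"],
    },
    index=[67, 95, 339, 1015],
)


def non_identical(site_map, reference):
    """Hypothetical helper: mutations separating each condition from the reference."""
    ref = site_map[reference]
    out = {}
    for cond in site_map.columns:
        # Boolean mask of sites where this condition's wildtype differs.
        differs = (site_map[cond] != ref).to_numpy()
        sites = site_map.index[differs]
        out[cond] = " ".join(
            f"{ref.loc[s]}{s}{site_map.loc[s, cond]}" for s in sites
        )
    return out


result = non_identical(toy_site_map, "Delta-2")
# {'Delta-2': '', 'Omicron_BA.1-2': 'A67V T95I G339D', 'Omicron_BA.2-1': 'G339D'}
```

The reference maps to an empty string because it has no differences from itself, matching the `'Delta-2': ''` entry in the output above.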

## `MultiDmsModel`

Next, we iterate through the model definitions of interest and (1) initialize a model object with the dataset, (2) fit the model to that data with the given hyperparameters, and (3) collect the fitted models in a dataframe for comparison and plotting.

In [25]:
models = defaultdict(list)
for ge_func in ["sigmoid", "softplus", "identity"]:
    for output_act in ["softplus"]:

        imodel = multidms.MultiDmsModel(
                data,
                epistatic_model=ge_func,
                output_activation=output_act
        )

        start = time.time()
        imodel.fit(lasso_shift=1e-5, maxiter=5000, tol=1e-6)
        end = time.time()
        print(f"{ge_func},{output_act} done. fitting time = {round(end - start, 2)} seconds")

        models["ge_func"].append(ge_func)
        models["output_act"].append(output_act)
        models["model"].append(imodel)
            
models_df = pd.DataFrame(models)
models_df

sigmoid,softplus done. fitting time = 162.41 seconds
softplus,softplus done. fitting time = 163.47 seconds
identity,softplus done. fitting time = 162.15 seconds


Unnamed: 0,ge_func,output_act,model
0,sigmoid,softplus,<multidms.model.MultiDmsModel object at 0x7f48...
1,softplus,softplus,<multidms.model.MultiDmsModel object at 0x7f48...
2,identity,softplus,<multidms.model.MultiDmsModel object at 0x7f48...
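The three `epistatic_model` names above correspond to standard functions. For orientation, here is a minimal NumPy sketch of their textbook forms. This only shows the basic shapes; any scaling or offset parameters that `multidms` applies around them are not shown in this notebook and are not assumed here.

```python
import numpy as np


def sigmoid(z):
    """Logistic function: saturates at 0 and 1."""
    return 1.0 / (1.0 + np.exp(-z))


def softplus(z):
    """Smooth approximation of max(0, z), computed stably as log(1 + e^z)."""
    return np.logaddexp(0.0, z)


def identity(z):
    """No nonlinearity: the latent value is passed through unchanged."""
    return z


z = np.linspace(-5.0, 5.0, 11)
shapes = {f.__name__: f(z) for f in (sigmoid, softplus, identity)}
```

The sigmoid is bounded on both sides, softplus is bounded below only, and identity is unbounded, which is the main qualitative difference between the three fits compared here.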


In [33]:
a_model_object = models_df.loc[0, "model"]
a_model_object

<multidms.model.MultiDmsModel at 0x7f4887535960>

Models fit to the same dataset hold a reference to the same `MultiDmsData` object rather than a copy, which saves memory.

In [34]:
model_w_same_dataset = models_df.loc[1, "model"]
model_w_same_dataset.data is a_model_object.data

True

The model object shares many properties with the data object, such as the mutations and variants dataframes, but its getters add model-specific columns such as parameters and predictions.

In [35]:
a_model_object.mutations_df

Unnamed: 0,mutation,wts,sites,muts,times_seen_Delta-2,times_seen_Delta-3,times_seen_Omicron_BA.1-2,times_seen_Omicron_BA.1-3,times_seen_Omicron_BA.2-1,times_seen_Omicron_BA.2-2,...,S_Delta-3,F_Delta-3,S_Omicron_BA.1-2,F_Omicron_BA.1-2,S_Omicron_BA.1-3,F_Omicron_BA.1-3,S_Omicron_BA.2-1,F_Omicron_BA.2-1,S_Omicron_BA.2-2,F_Omicron_BA.2-2
0,M1F,M,1,F,1.0,5.0,0.0,0.0,0.0,0.0,...,-1.114087,-2.316968,0.000000,-1.037506,0.000000,-1.037506,0.000000,-1.037506,0.000000,-1.037506
1,M1I,M,1,I,2.0,6.0,4.0,6.0,6.0,7.0,...,0.101774,-2.685297,0.050907,-2.751813,-0.517335,-3.454450,0.027630,-2.782116,-0.510907,-3.447103
2,M1K,M,1,K,2.0,3.0,0.0,0.0,0.0,0.0,...,0.948611,-1.014987,0.000000,-2.070066,0.000000,-2.070066,0.000000,-2.070066,0.000000,-2.070066
3,M1L,M,1,L,1.0,2.0,0.0,0.0,0.0,1.0,...,-0.565530,-3.229575,0.000000,-2.524612,0.000000,-2.524612,0.000000,-2.524612,-0.257047,-2.860454
4,M1N,M,1,N,0.0,3.0,0.0,0.0,0.0,0.0,...,0.261646,0.043491,0.000000,0.017676,0.000000,0.017676,0.000000,0.017676,0.000000,0.017676
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10525,S1252T,S,1252,T,13.0,25.0,92.0,64.0,78.0,72.0,...,-0.713572,-0.405463,0.252703,-0.086477,0.005444,-0.144471,0.055311,-0.131709,0.054863,-0.131821
10526,S1252V,S,1252,V,8.0,14.0,58.0,37.0,56.0,51.0,...,-0.065076,0.053516,-0.602393,-0.000719,-0.463518,0.016076,-0.895514,-0.044318,-0.845191,-0.035954
10527,S1252W,S,1252,W,1.0,4.0,20.0,24.0,22.0,24.0,...,-0.274269,-0.029075,0.007924,0.009542,0.051697,0.014658,-0.000528,0.008529,0.249411,0.035273
10528,S1252Y,S,1252,Y,25.0,21.0,70.0,59.0,101.0,122.0,...,0.178195,0.083179,-0.493674,0.037515,-0.154600,0.064258,-0.698847,0.016531,-0.632361,0.023788
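Parameters for mutations that were rarely observed rest on little data, so it is often useful to filter on the `times_seen_*` columns before interpreting them. The sketch below does this on a toy frame built from three rows and two conditions of the table above; the helper `seen_everywhere` and the threshold value are illustrative assumptions, not part of `multidms`.

```python
import pandas as pd

# Toy subset of mutations_df, using values from the table above.
toy_mutations = pd.DataFrame(
    {
        "mutation": ["M1F", "M1I", "S1252T"],
        "times_seen_Delta-2": [1.0, 2.0, 13.0],
        "times_seen_Omicron_BA.1-2": [0.0, 4.0, 92.0],
    }
)


def seen_everywhere(df, threshold):
    """Hypothetical helper: mutations seen >= threshold times in every condition."""
    cols = [c for c in df.columns if c.startswith("times_seen_")]
    keep = df[cols].min(axis=1) >= threshold
    return df.loc[keep, "mutation"].tolist()


well_sampled = seen_everywhere(toy_mutations, threshold=2)
# ['M1I', 'S1252T']
```

The same filter should apply to `a_model_object.mutations_df` directly, since it carries one `times_seen_*` column per condition.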


In [36]:
a_model_object.variants_df

Unnamed: 0,condition,aa_substitutions,weight,func_score,var_wrt_ref,predicted_latent,predicted_func_score,corrected_func_score
0,Delta-2,,4774,-0.139269,,3.279702,-0.059151,-0.139269
1,Delta-2,A1015D,1,-1.676400,A1015D,0.627221,-1.700153,-1.676400
2,Delta-2,A1015D E1188Q,2,-1.641850,A1015D E1188Q,0.865805,-1.426468,-1.641850
3,Delta-2,A1015D T1027S,1,-3.500000,A1015D T1027S,1.040749,-1.241658,-3.500000
4,Delta-2,A1016S,3,-0.410833,A1016S,3.268183,-0.061281,-0.410833
...,...,...,...,...,...,...,...,...
241254,Omicron_BA.2-2,Y91H T95R M731I I844V,2,-3.500000,H655Y E484A N440K N679K R681H T376A N501Y N950...,-5.106477,-5.064621,-5.189393
241255,Omicron_BA.2-2,Y91T K129N,1,-3.500000,H655Y E484A N440K N679K R681H T376A N501Y N950...,-3.766567,-4.995635,-5.189393
241256,Omicron_BA.2-2,Y91T T553I,1,-3.251900,H655Y E484A N440K N679K R681H T376A N501Y N950...,-4.132069,-5.025455,-4.941293
241257,Omicron_BA.2-2,Y91V A1078S,1,-1.466500,A1078S H655Y E484A N440K N679K R681H T376A N50...,-2.049817,-4.529789,-3.155893
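Since `variants_df` holds both the measured `func_score` and the model's `predicted_func_score`, a quick per-condition check of agreement is straightforward. The sketch below computes a Pearson correlation per condition on a toy frame built from the rows displayed above. These few rows are only for illustration and say nothing about the overall fit quality.

```python
import pandas as pd

# Toy subset of variants_df, using the rows displayed above.
toy_variants = pd.DataFrame(
    {
        "condition": ["Delta-2"] * 5 + ["Omicron_BA.2-2"] * 4,
        "func_score": [
            -0.139269, -1.676400, -1.641850, -3.500000, -0.410833,
            -3.500000, -3.500000, -3.251900, -1.466500,
        ],
        "predicted_func_score": [
            -0.059151, -1.700153, -1.426468, -1.241658, -0.061281,
            -5.064621, -4.995635, -5.025455, -4.529789,
        ],
    }
)

# Pearson correlation between measured and predicted scores, per condition.
corr_by_condition = {
    cond: grp["func_score"].corr(grp["predicted_func_score"])
    for cond, grp in toy_variants.groupby("condition")
}
```

On the full data, replacing `toy_variants` with `a_model_object.variants_df` would give one correlation per condition in the fit.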


### Visualization

You can visualize a model's shift parameters using the `MultiDmsModel.mut_shift_plot()` method, which wraps the original visualization function `polyclonal.plot.lineplot_and_heatmap()`, as seen [here](https://github.com/jbloomlab/polyclonal/blob/92fee4badb14e1db719074f202b4fab374dd0613/polyclonal/plot.py#L263).

**NOTE:** Currently, the heatmaps place an "X" at the wildtype for the reference sequence _only_. You can easily look up the wildtype of any homolog at a given site using the `MultiDmsData.site_map` property.

In [40]:
data.site_map.loc[67, :]


Delta-2           A
Delta-3           A
Omicron_BA.1-2    V
Omicron_BA.1-3    V
Omicron_BA.2-1    A
Omicron_BA.2-2    A
Name: 67, dtype: object

In [45]:
chart = a_model_object.mut_shift_plot()
chart



In [46]:
chart.save("ref-delta-sigmoid-example.html")