# `multidms` fitting pipeline

Here, we demonstrate the pipeline for fitting a `multidms` model to data from [six deep mutational scanning experiments](https://github.com/dms-vep) across three homologs of the Spike protein.

In [69]:
import os
import sys
from collections import defaultdict
import time

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as onp
from tqdm.notebook import tqdm
import jax.numpy as jnp

import multidms
%load_ext autoreload
%autoreload 2
%matplotlib inline



Read in the DMS data and list all available experimental conditions.

In [72]:
func_score_df = pd.read_csv("../data/Delta_BA1_BA2_func_score_df.csv", index_col=0)
print(sorted(list(func_score_df.condition.unique())))

['Delta-1', 'Delta-2', 'Delta-3', 'Delta-4', 'Omicron_BA.1-1', 'Omicron_BA.1-2', 'Omicron_BA.1-3', 'Omicron_BA.2-1', 'Omicron_BA.2-2']


Choose a reference condition and the full set of conditions to include in the fit. Note that the reference must be one of the available conditions.

In [73]:
reference_condition = "Delta-3"
fit_included_conditions = [
    'Delta-3', 'Delta-4', #'Delta-1', 'Delta-2', 'Delta-3', 'Delta-4', 
    'Omicron_BA.1-2', 'Omicron_BA.1-3', #'Omicron_BA.1-1', 'Omicron_BA.1-2', 'Omicron_BA.1-3', 
    'Omicron_BA.2-1', 'Omicron_BA.2-2'
]
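Because the reference must be among the available conditions, a quick sanity check before building the dataset can catch typos early. The helper below is hypothetical (it is not part of `multidms`), and it additionally assumes that the reference should be one of the conditions kept for the fit:

```python
def check_conditions(available, reference, included):
    """Raise ValueError if the chosen conditions are inconsistent.

    Hypothetical helper for illustration only; not a multidms function.
    """
    available = set(available)
    missing = [c for c in included if c not in available]
    if missing:
        raise ValueError(f"conditions not found in the data: {missing}")
    if reference not in included:
        raise ValueError(f"reference {reference!r} is not an included condition")
    return True


available = ["Delta-3", "Delta-4", "Omicron_BA.1-2", "Omicron_BA.2-1"]
check_conditions(available, "Delta-3", ["Delta-3", "Omicron_BA.1-2"])
```

In the notebook, `available` would come from `func_score_df.condition.unique()`.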

Subset the data to the conditions included in the fit.

In [74]:
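# Take an explicit copy so later assignments do not act on a slice of the original frame.
func_score_df = func_score_df.query("condition in @fit_included_conditions").copy()
# Wildtype variants have no substitutions; represent them as an empty string.
func_score_df["aa_substitutions"] = func_score_df["aa_substitutions"].fillna("")
func_score_df = func_score_df.sort_values("condition")
func_score_df.head()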

Unnamed: 0,library,pre_sample,post_sample,barcode,func_score,func_score_var,pre_count,post_count,pre_count_wt,post_count_wt,pseudocount,n_codon_substitutions,aa_substitutions_sequential,n_aa_substitutions,aa_substitutions,pre_count_threshold,homolog,replicate,condition
1,Lib-3,2021-10-28_thaw-1_VSVG_control_1,2021-12-14_thaw-1_no-antibody_control_1,CAACCGTCACCACCAG,0.1042,0.001,8201,2734,1154203,358015,0.5,3,T22I R401K E1090D,3,T22I R403K E1092D,20,Delta,1,Delta-3
10942,Lib-3,2021-10-28_thaw-1_VSVG_control_2,2021-12-14_thaw-1_no-antibody_control_2,ATGATTAGAAATGGAA,-3.5,0.1381,545,15,1058992,550512,0.5,2,C669Y G1165S,2,C671Y G1167S,20,Delta,2,Delta-3
10940,Lib-3,2021-10-28_thaw-1_VSVG_control_2,2021-12-14_thaw-1_no-antibody_control_2,AGGTACACAGCATAAC,-0.8302,0.0169,545,159,1058992,550512,0.5,5,E617D I995V,2,E619D I997V,20,Delta,2,Delta-3
10939,Lib-3,2021-10-28_thaw-1_VSVG_control_2,2021-12-14_thaw-1_no-antibody_control_2,AGATGATCCAGAGTAA,-2.0923,0.0351,545,66,1058992,550512,0.5,2,M900I G1244C,2,M902I G1246C,20,Delta,2,Delta-3
10935,Lib-3,2021-10-28_thaw-1_VSVG_control_2,2021-12-14_thaw-1_no-antibody_control_2,ACAGTCACAACCACCG,0.1122,0.0106,545,306,1058992,550512,0.5,0,,0,,20,Delta,2,Delta-3
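The same subsetting steps can be tried on a small, made-up frame. This sketch uses only the three columns needed for the illustration (the toy scores are invented), and shows that a wildtype variant with a missing `aa_substitutions` entry ends up as an empty string:

```python
import pandas as pd

# Toy stand-in for func_score_df; values are invented for illustration.
toy = pd.DataFrame(
    {
        "condition": ["Omicron_BA.2-1", "Delta-3", "Delta-1", "Delta-3"],
        "aa_substitutions": ["Y91T K129N", None, "M1I", "T22I R403K"],
        "func_score": [0.3, 0.1, -1.2, -0.5],
    }
)
keep = ["Delta-3", "Omicron_BA.2-1"]

filtered = (
    toy[toy["condition"].isin(keep)]
    .assign(aa_substitutions=lambda d: d["aa_substitutions"].fillna(""))
    .sort_values("condition", kind="stable")  # stable sort keeps row order within a condition
    .reset_index(drop=True)
)
print(filtered)
```

Chaining the steps returns a new frame at each stage, which sidesteps the ambiguity of in-place edits on a filtered slice.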


## `MultiDmsData`

After the functional score dataframe for all variants has been prepared, you can initialize a `multidms.MultiDmsData` object. This will (1) convert substitution strings to be with respect to the reference (if necessary), (2) set static attributes, and (3) prepare model training data that can be shared by multiple `multidms.MultiDmsModel` objects.
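To make step (1) concrete, here is a minimal sketch of what re-expressing a variant relative to a reference could look like. It is not the `multidms` implementation: the function names are hypothetical, the output is sorted by site for readability, and edge cases in the real package may be handled differently.

```python
def parse(sub):
    """Split a substitution such as 'R452L' into (wildtype, site, mutant)."""
    return sub[0], int(sub[1:-1]), sub[-1]


def wrt_reference(variant_subs, homolog_vs_ref):
    """Re-express a homolog variant's substitutions relative to the reference.

    variant_subs:   substitutions relative to the homolog's own wildtype.
    homolog_vs_ref: substitutions separating the homolog wildtype from the reference.
    """
    # Start from the differences between the homolog wildtype and the reference.
    by_site = {}
    for sub in homolog_vs_ref.split():
        ref_aa, site, homolog_aa = parse(sub)
        by_site[site] = (ref_aa, homolog_aa)

    # Layer the variant's own substitutions on top.
    for sub in variant_subs.split():
        wt_aa, site, mut_aa = parse(sub)
        ref_aa = by_site[site][0] if site in by_site else wt_aa
        by_site[site] = (ref_aa, mut_aa)

    # Drop sites where the variant ends up identical to the reference.
    return " ".join(
        f"{ref}{site}{aa}" for site, (ref, aa) in sorted(by_site.items()) if ref != aa
    )


print(wrt_reference("Y91V A1078S", "R19I R452L"))  # R19I Y91V R452L A1078S
print(wrt_reference("L452R", "R19I R452L"))        # R19I
print(wrt_reference("L452Q", "R19I R452L"))        # R19I R452Q
```

The second call shows a variant that reverts a homolog difference back to the reference amino acid, so that site disappears from the converted string.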

In [75]:
# uncomment to see docs
# help(multidms.MultiDmsData)

In [76]:
data = multidms.MultiDmsData(
    func_score_df,
    alphabet=multidms.AAS_WITHSTOP,
    condition_colors=sns.color_palette("Paired"),
    reference=reference_condition,
)

100%|█████████████████████████████████████████████████████████████████████████████████████| 51274/51274 [02:16<00:00, 375.36it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 43835/43835 [01:57<00:00, 372.59it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 45998/45998 [02:21<00:00, 325.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 44021/44021 [02:10<00:00, 337.77it/s]


We can now view a few useful attributes of the data object.

In [77]:
data.site_map.head()

Unnamed: 0,Delta-3,Delta-4,Omicron_BA.1-2,Omicron_BA.1-3,Omicron_BA.2-1,Omicron_BA.2-2
1015,A,A,A,A,A,A
1204,G,G,G,G,G,G
1245,K,K,K,K,K,K
1016,A,A,A,A,A,A
1176,V,V,V,V,V,V


In [78]:
data.mutations[:5]

('M1F', 'M1I', 'M1K', 'M1L', 'M1N')

In [79]:
data.mutations_df.head()

Unnamed: 0,mutation,wts,sites,muts,times_seen_Delta-3,times_seen_Delta-4,times_seen_Omicron_BA.1-2,times_seen_Omicron_BA.1-3,times_seen_Omicron_BA.2-1,times_seen_Omicron_BA.2-2
0,M1F,M,1,F,5.0,2.0,0.0,0.0,0.0,0.0
1,M1I,M,1,I,6.0,6.0,4.0,6.0,6.0,7.0
2,M1K,M,1,K,3.0,5.0,0.0,0.0,0.0,0.0
3,M1L,M,1,L,2.0,6.0,0.0,0.0,0.0,1.0
4,M1N,M,1,N,3.0,1.0,0.0,0.0,0.0,0.0


In [80]:
data.conditions

('Delta-3',
 'Delta-4',
 'Omicron_BA.1-2',
 'Omicron_BA.1-3',
 'Omicron_BA.2-1',
 'Omicron_BA.2-2')

In [81]:
data.reference

'Delta-3'

In [82]:
data.binarymaps

{'Delta-3': <binarymap.binarymap.BinaryMap at 0x7f4899ee6bf0>,
 'Delta-4': <binarymap.binarymap.BinaryMap at 0x7f489af4f880>,
 'Omicron_BA.1-2': <binarymap.binarymap.BinaryMap at 0x7f48871cbf40>,
 'Omicron_BA.1-3': <binarymap.binarymap.BinaryMap at 0x7f348bc33af0>,
 'Omicron_BA.2-1': <binarymap.binarymap.BinaryMap at 0x7f4899ee6a10>,
 'Omicron_BA.2-2': <binarymap.binarymap.BinaryMap at 0x7f348ba9ecb0>}
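Conceptually, a binary map encodes each variant as a row of zeros and ones, with one column per mutation. The sketch below builds such a matrix by hand with NumPy on made-up variants; it is only an illustration of the encoding and does not use or reproduce the `binarymap` API.

```python
import numpy as np

mutations = ("M1I", "T22I", "R403K")   # one column per mutation
variants = ["", "M1I", "T22I R403K"]   # "" is the wildtype

col = {m: j for j, m in enumerate(mutations)}
X = np.zeros((len(variants), len(mutations)), dtype=np.int8)
for i, variant in enumerate(variants):
    for sub in variant.split():
        X[i, col[sub]] = 1  # mark each substitution carried by this variant

print(X)
```

The wildtype row is all zeros, and a multi-mutant row has a one in each of its mutation columns.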

In [83]:
data.non_identical_sites

frozendict.frozendict({'Delta-3': [], 'Delta-4': array([], dtype=int64), 'Omicron_BA.1-2': array([547, 440, 950, 681, 417, 375, 212, 452, 655, 477, 501, 679, 796,
       484, 954, 496, 371, 764, 493,  67,  95, 446, 498, 981, 969, 156,
       505, 856, 339, 373,  19]), 'Omicron_BA.1-3': array([547, 440, 950, 681, 417, 375, 212, 452, 655, 477, 501, 679, 796,
       484, 954, 496, 371, 764, 493,  67,  95, 446, 498, 981, 969, 156,
       505, 856, 339, 373,  19]), 'Omicron_BA.2-1': array([440, 950, 681, 417, 375, 452, 655, 477, 501, 679, 408, 376, 796,
       484, 954, 405, 371, 764, 493,  27, 498, 969, 156, 505, 213, 339,
       373,  19]), 'Omicron_BA.2-2': array([440, 950, 681, 417, 375, 452, 655, 477, 501, 679, 408, 376, 796,
       484, 954, 405, 371, 764, 493,  27, 498, 969, 156, 505, 213, 339,
       373,  19])})

In [84]:
data.non_identical_mutations

frozendict.frozendict({'Delta-3': '', 'Delta-4': '', 'Omicron_BA.1-2': 'T547K N440K N950D R681H K417N S375F L212I R452L H655Y S477N N501Y N679K D796Y E484A Q954H G496S S371L N764K Q493R A67V T95I G446S Q498R L981F N969K G156E Y505H N856K G339D S373P R19T', 'Omicron_BA.1-3': 'T547K N440K N950D R681H K417N S375F L212I R452L H655Y S477N N501Y N679K D796Y E484A Q954H G496S S371L N764K Q493R A67V T95I G446S Q498R L981F N969K G156E Y505H N856K G339D S373P R19T', 'Omicron_BA.2-1': 'N440K N950D R681H K417N S375F R452L H655Y S477N N501Y N679K R408S T376A D796Y E484A Q954H D405N S371F N764K Q493R A27S Q498R N969K G156E Y505H V213G G339D S373P R19I', 'Omicron_BA.2-2': 'N440K N950D R681H K417N S375F R452L H655Y S477N N501Y N679K R408S T376A D796Y E484A Q954H D405N S371F N764K Q493R A27S Q498R N969K G156E Y505H V213G G339D S373P R19I'})
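The last two attributes can be understood as a column-wise comparison of the site map against the reference column. The following sketch reproduces that idea on a three-site toy site map; the amino acids at sites 19 and 452 are taken from the output above, while the computation itself is our own illustration rather than the package's code.

```python
import pandas as pd

# Toy site map: rows are sites, columns are conditions.
site_map = pd.DataFrame(
    {
        "Delta-3": ["R", "R", "A"],
        "Omicron_BA.1-2": ["T", "L", "A"],
        "Omicron_BA.2-1": ["I", "L", "A"],
    },
    index=[19, 452, 1015],
)
reference = "Delta-3"

# Sites where each condition's wildtype differs from the reference wildtype.
non_identical_sites = {
    cond: site_map.index[site_map[cond] != site_map[reference]].tolist()
    for cond in site_map.columns
}

# The same differences written as substitution strings relative to the reference.
non_identical_mutations = {
    cond: " ".join(
        f"{site_map.at[s, reference]}{s}{site_map.at[s, cond]}" for s in sites
    )
    for cond, sites in non_identical_sites.items()
}

print(non_identical_sites)
print(non_identical_mutations)
```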

## `MultiDmsModel`

Next, we iterate through the datasets and model definitions of interest and (1) initialize a model object with the respective dataset, (2) fit the model to that data with the given hyperparameters, and (3) save the fitted models in a dataframe for comparison and plotting.

In [85]:
models = defaultdict(list)
for ge_func in ["sigmoid"]: #, "softplus", "identity"]:
    for output_act in ["softplus"]:

        imodel = multidms.MultiDmsModel(
            data,
            epistatic_model=ge_func,
            output_activation=output_act,
        )

        start = time.time()
        imodel.fit(lasso_shift=1e-5, maxiter=5000, tol=1e-6)
        end = time.time()
        print(f"{ge_func},{output_act} done. fitting time = {round(end - start, 2)} seconds")

        models["ge_func"].append(ge_func)
        models["output_act"].append(output_act)
        models["model"].append(imodel)
            
models_df = pd.DataFrame(models)
models_df

sigmoid,softplus done. fitting time = 165.93 seconds


Unnamed: 0,ge_func,output_act,model
0,sigmoid,softplus,<multidms.model.MultiDmsModel object at 0x7f34...
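To give some intuition for the `lasso_shift` hyperparameter, here is a toy objective with an additive latent phenotype, condition-specific shifts, a sigmoid, and an L1 penalty on the shifts. Everything in this sketch is an assumption made for illustration: the squared-error loss, the missing output activation, and the parameter names are ours, and the actual `multidms` objective may differ in all of these respects.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def toy_objective(beta, shifts, X_by_cond, y_by_cond, lasso_shift):
    """Mean squared error summed over conditions, plus an L1 penalty on shifts."""
    loss = 0.0
    for cond, X in X_by_cond.items():
        latent = X @ (beta + shifts[cond])  # additive latent phenotype
        pred = sigmoid(latent)              # stand-in global-epistasis function
        loss += np.mean((pred - y_by_cond[cond]) ** 2)
    penalty = lasso_shift * sum(np.abs(s).sum() for s in shifts.values())
    return loss + penalty


X = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=float)
beta = np.array([1.0, -2.0])
# Assumption: the reference condition has no shifts; the homolog shifts one mutation.
shifts = {"reference": np.zeros(2), "homolog": np.array([0.5, 0.0])}
X_by_cond = {cond: X for cond in shifts}
# Simulate noise-free scores so that only the penalty contributes to the objective.
y_by_cond = {cond: sigmoid(X @ (beta + shifts[cond])) for cond in shifts}

value = toy_objective(beta, shifts, X_by_cond, y_by_cond, lasso_shift=1e-5)
print(value)
```

With noise-free data the loss term vanishes, so the printed value is just `lasso_shift` times the total absolute shift. A larger `lasso_shift` pushes more shifts toward exactly zero.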


In [86]:
a_model_object = models_df.loc[0, "model"]
a_model_object

<multidms.model.MultiDmsModel at 0x7f34885a6470>

Models built from the same dataset hold a reference to the same data object rather than a copy of it, which saves memory.

In [87]:
# Requires a second fitted model in `models_df`, e.g. from another `ge_func` in the loop above.
# model_w_same_dataset = models_df.loc[1, "model"]
# model_w_same_dataset.data is a_model_object.data
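The sharing behavior is ordinary Python object identity. The toy classes below are hypothetical stand-ins (not the `multidms` classes) that show two models pointing at one data object:

```python
class ToyData:
    def __init__(self, values):
        self.values = values


class ToyModel:
    def __init__(self, data):
        self.data = data  # stores a reference to the object, not a copy


shared = ToyData(list(range(1000)))
model_a = ToyModel(shared)
model_b = ToyModel(shared)

print(model_a.data is model_b.data)  # True
```

Because both models see the same object, any change to the shared data is visible from every model that holds it.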

The model object shares many properties with the data object, such as the mutations and variants dataframes, but its getters add model-specific columns such as parameters and predictions.

In [88]:
a_model_object.mutations_df

Unnamed: 0,mutation,wts,sites,muts,times_seen_Delta-3,times_seen_Delta-4,times_seen_Omicron_BA.1-2,times_seen_Omicron_BA.1-3,times_seen_Omicron_BA.2-1,times_seen_Omicron_BA.2-2,...,S_Delta-4,F_Delta-4,S_Omicron_BA.1-2,F_Omicron_BA.1-2,S_Omicron_BA.1-3,F_Omicron_BA.1-3,S_Omicron_BA.2-1,F_Omicron_BA.2-1,S_Omicron_BA.2-2,F_Omicron_BA.2-2
0,M1F,M,1,F,5.0,2.0,0.0,0.0,0.0,0.0,...,0.540857,-0.953732,0.000000,-1.543076,0.000000,-1.543076,0.000000,-1.543076,0.000000,-1.543076
1,M1I,M,1,I,6.0,6.0,4.0,6.0,6.0,7.0,...,0.078443,-2.817999,-0.006960,-2.810931,-0.035774,-2.770219,-0.000115,-2.832660,0.125620,-2.804900
2,M1K,M,1,K,3.0,5.0,0.0,0.0,0.0,0.0,...,-1.712971,-2.736362,0.000000,-1.750442,0.000000,-1.750441,0.000000,-1.750442,0.000000,-1.750442
3,M1L,M,1,L,2.0,6.0,0.0,0.0,0.0,1.0,...,0.658943,-2.298255,0.000000,-2.619586,0.000000,-2.611861,0.000000,-2.622669,-0.270900,-2.712950
4,M1N,M,1,N,3.0,1.0,0.0,0.0,0.0,0.0,...,-2.234444,-2.208945,0.000000,0.066338,0.000000,0.066338,0.000000,0.066338,0.000000,0.066338
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10556,S1252T,S,1252,T,25.0,18.0,92.0,64.0,78.0,72.0,...,-0.091087,-0.520807,0.326278,-0.056783,0.015032,-0.399641,0.139781,-0.259529,0.079899,-0.326413
10557,S1252V,S,1252,V,14.0,11.0,58.0,37.0,56.0,51.0,...,-0.042160,-0.169704,0.613244,0.472141,0.871512,0.676485,0.280615,0.166507,0.357464,0.241144
10558,S1252W,S,1252,W,4.0,5.0,20.0,24.0,22.0,24.0,...,-0.044388,0.622249,-0.007277,0.650368,-0.091502,0.585666,-0.112024,0.569423,0.304166,0.862469
10559,S1252Y,S,1252,Y,21.0,23.0,70.0,59.0,101.0,123.0,...,0.085663,0.650723,0.237913,0.759663,0.664426,1.012225,0.067456,0.637008,0.185379,0.723230
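A common next step is to restrict attention to mutations that were observed often enough in every condition and then rank them by shift. The sketch below does this on four rows copied from the table above; the `min_times_seen` threshold of 3 is an arbitrary choice for illustration.

```python
import pandas as pd

# A few rows and columns copied from the mutations dataframe shown above.
mut_df = pd.DataFrame(
    {
        "mutation": ["M1F", "M1I", "S1252T", "S1252V"],
        "times_seen_Delta-3": [5.0, 6.0, 25.0, 14.0],
        "times_seen_Omicron_BA.1-2": [0.0, 4.0, 92.0, 58.0],
        "S_Omicron_BA.1-2": [0.000000, -0.006960, 0.326278, 0.613244],
    }
)

times_seen_cols = [c for c in mut_df.columns if c.startswith("times_seen_")]
min_times_seen = 3  # hypothetical threshold

# Keep mutations seen at least `min_times_seen` times in every condition.
well_observed = mut_df[(mut_df[times_seen_cols] >= min_times_seen).all(axis=1)]

# Rank the remaining mutations by the magnitude of their shift.
top_shifts = well_observed.sort_values(
    "S_Omicron_BA.1-2", key=lambda s: s.abs(), ascending=False
)
print(top_shifts[["mutation", "S_Omicron_BA.1-2"]])
```

Here `M1F` is dropped because it was never seen in `Omicron_BA.1-2`, which is also why its shift is exactly zero.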


In [89]:
a_model_object.variants_df

Unnamed: 0,condition,aa_substitutions,weight,func_score,var_wrt_ref,predicted_latent,predicted_func_score,corrected_func_score
0,Delta-3,,2566,-0.193454,,0.725305,0.081274,-0.193454
1,Delta-3,A1015D,1,-3.500000,A1015D,-4.233753,-2.970265,-3.500000
2,Delta-3,A1015T G1204D K1245Y,2,-0.710050,A1015T G1204D K1245Y,0.248601,-0.436785,-0.710050
3,Delta-3,A1016D,1,-0.417800,A1016D,-0.319579,-1.089572,-0.417800
4,Delta-3,A1016D V1176R,1,-2.225200,A1016D V1176R,-2.600064,-2.716915,-2.225200
...,...,...,...,...,...,...,...,...
241696,Omicron_BA.2-2,Y91H T95R M731I I844V,2,-3.500000,N440K N950D R681H K417N S375F R452L H655Y S477...,-6.369990,-3.002489,-3.148235
241697,Omicron_BA.2-2,Y91T K129N,1,-3.500000,N440K N950D R681H K417N S375F R452L H655Y S477...,-3.948217,-2.936501,-3.148235
241698,Omicron_BA.2-2,Y91T T553I,1,-3.251900,N440K N950D R681H K417N S375F R452L H655Y S477...,-4.340445,-2.960728,-2.900135
241699,Omicron_BA.2-2,Y91V A1078S,1,-1.466500,A1078S N440K N950D R681H K417N S375F R452L H65...,-2.223884,-2.584739,-1.114735
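One simple way to assess a fit is to correlate observed and predicted functional scores within each condition. The sketch below uses the handful of rows visible in the table above, so the resulting numbers say nothing about the full fit; it only illustrates the computation.

```python
import pandas as pd

# Rows copied from the variants dataframe shown above.
variants = pd.DataFrame(
    {
        "condition": ["Delta-3"] * 5 + ["Omicron_BA.2-2"] * 4,
        "func_score": [
            -0.193454, -3.500000, -0.710050, -0.417800, -2.225200,
            -3.500000, -3.500000, -3.251900, -1.466500,
        ],
        "predicted_func_score": [
            0.081274, -2.970265, -0.436785, -1.089572, -2.716915,
            -3.002489, -2.936501, -2.960728, -2.584739,
        ],
    }
)

# Pearson correlation between observed and predicted scores, per condition.
corr = variants.groupby("condition")[["func_score", "predicted_func_score"]].apply(
    lambda g: g["func_score"].corr(g["predicted_func_score"])
)
print(corr)
```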


### Visualization

You can visualize a model's shift parameters with the `MultiDmsModel.mut_shift_plot()` method, which wraps the original plotting function `polyclonal.plot.lineplot_and_heatmap()` (source [here](https://github.com/jbloomlab/polyclonal/blob/92fee4badb14e1db719074f202b4fab374dd0613/polyclonal/plot.py#L263)).

**NOTE:** Currently, the heatmaps place an "X" at the wildtype for the reference sequence _only_. You can easily look up the wildtypes for any homolog at a given site using the `MultiDmsData.site_map` property.

In [91]:
# data.site_map.loc[, :]
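Since the site map is a dataframe indexed by site with one column per condition, standard `.loc` indexing is enough for these lookups. The example below runs on a toy site map (amino acids at sites 19 and 452 taken from the `non_identical_mutations` output earlier); the same calls should apply to `data.site_map`, assuming it keeps that layout.

```python
import pandas as pd

# Toy site map: rows are sites, columns are conditions.
site_map = pd.DataFrame(
    {
        "Delta-3": ["R", "R", "A"],
        "Omicron_BA.1-2": ["T", "L", "A"],
        "Omicron_BA.2-1": ["I", "L", "A"],
    },
    index=[19, 452, 1015],
)

# Wildtype amino acid at site 19 in every condition.
print(site_map.loc[19, :])

# Wildtypes at several sites for a single condition.
print(site_map.loc[[19, 452], "Omicron_BA.2-1"])

# Sites where at least two conditions disagree on the wildtype.
differs = site_map[site_map.nunique(axis=1) > 1]
print(differs)
```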


In [92]:
chart = a_model_object.mut_shift_plot()
chart



In [93]:
chart.save("ref-delta-3-sigmoid.html")