# `multidms` example

Here, we demonstrate the pipeline for fitting a `multidms` model to deep mutational scanning (DMS) data using the tools available in the package. Currently, we show how to pre-process data from several DMS experiments targeting the Delta and Omicron spike protein homologs.

In [1]:
# built-in libraries
import os
import sys
import pickle
from itertools import combinations
import importlib
import math
import re
from timeit import default_timer as timer
import json
from typing import Callable

# external dependencies
import pandas as pd
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jaxopt import ProximalGradient
import jaxopt
import numpy as onp
from scipy.stats import pearsonr
import binarymap as bmap
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
# import plotnine

# local dependencies
# sys.path.append("..")
# import multidms
import multidms
from multidms import Multidms, AAS_WITHSTOP

**First, we'll define a small example dataset of functional scores.**

## Initialize `Multidms` Object

In [2]:
# Toy functional-score table: one row per (homolog, variant) measurement, with
# the variant's log2 enrichment in the "log2E" column. Multi-mutation variants
# are space-separated substitutions (e.g. "M1E P3G").
# NOTE(review): the renames and the Multidms(...) call below use
# 'condition' / 'aa_substitutions' columns and a "Delta-2-1" reference, which
# this toy frame does not provide -- confirm which dataset is intended.
func_score_dict = dict(
    homolog=["1"] * 4 + ["2"] * 6,
    variant=[
        "M1E", "G3R", "G3P", "M1W",
        "M1E", "P3R", "P3G", "M1E P3G", "M1E P3R", "P2T",
    ],
    log2E=[2, -7, -0.5, 2.3, 1, -5, 0.4, 2.7, -2.7, 0.3],
)
func_score_df = pd.DataFrame(func_score_dict)
func_score_df

NameError: name 'test_dict' is not defined

In [None]:
# Harmonize input column names with the names used downstream
# ('aa_substitutions', 'condition'). DataFrame.rename ignores mapping keys
# that are not present, so frames that already use these names pass through
# unchanged. Reassigning instead of `inplace=True` keeps this cell idempotent
# and avoids mutating a frame that an earlier cell displayed.
func_score_df = func_score_df.rename(
    columns={
        "aa_substitutions_reference": "aa_substitutions",
        "homolog_exp": "condition",
    }
)

## `Multidms` Object

In [None]:
# Preview the first five rows of the input functional-score table.
func_score_df.head(5)

In [None]:
# def f(x,y,z):
#     return x + y + z
# ans = partial()

In [20]:
from jax.tree_util import Partial as jax_partial
import jaxlib
from functools import partial as functools_partial
from multidms.model import identity, scaled_shifted_softplus, ϕ, g

# JAX Engine
@functools_partial(jax.jit, static_argnums=(0, 1, 2))
def model_predict(
    ϕ: Callable,
    g: Callable,
    t: Callable,
    h_params: dict,
    X_h,
    **kwargs
):
    """Biophysical model prediction for one condition, compiled with ``jax.jit``.

    ``ϕ``, ``g`` and ``t`` are static arguments, so the compiled code is reused
    for repeated calls and recompiled only when different model functions are
    passed.

    Parameters
    ----------
    ϕ : callable
        Called as ``ϕ(h_params, X_h)``; presumably produces the latent
        phenotype of each variant -- confirm against ``multidms.model``.
    g : callable
        Called as ``g(h_params['α'], latent)``; presumably the nonlinearity
        mapping latent phenotype to functional score.
    t : callable
        Final output transform; receives the result of ``g`` plus ``**kwargs``.
    h_params : dict
        Condition-specific parameters; must contain the key ``'α'``.
    X_h
        Variant design matrix for the condition (a BCOO sparse matrix when
        called with ``mdms.binarymaps['X']`` entries in this notebook).
    **kwargs
        Forwarded unchanged to ``t``.

    Returns
    -------
    ``t(g(h_params['α'], ϕ(h_params, X_h)), **kwargs)``.
    """
    # Annotations use typing.Callable instead of
    # jaxlib.xla_extension.CompiledFunction: annotations are evaluated when the
    # function is defined, and that jaxlib attribute is absent from newer
    # releases, which made this cell fail with AttributeError.
    latent = ϕ(h_params, X_h)
    return t(g(h_params['α'], latent), **kwargs)

In [21]:
# Freeze the model functions (ϕ, g, identity output transform) into the jitted
# predictor; the resulting callable only needs (h_params, X_h, **kwargs).
compiled_predict = jax_partial(model_predict, ϕ, g, identity)

# Build the Multidms object from the functional-score table.
# NOTE(review): `reference` must name one of the conditions present in
# `func_score_df`; the toy frame defined earlier labels homologs "1"/"2",
# so confirm which dataset this cell is meant to run on.
mdms = Multidms(
    func_score_df,
    compiled_predict,
    reference="Delta-2-1",
    alphabet=AAS_WITHSTOP,
)

Found 501 site(s) lacking data in at least one condition.
882 of the 2058 variants were removed because they had mutations at the above sites, leaving 1176 variants.


100%|███████████████████████████████████████████████████████| 783/783 [00:04<00:00, 179.69it/s]


There were 0 cache hits in total for condition Omicron_BA.1-2-1.


In [22]:
# First five rows of the per-mutation table, using the initial (pre-fit) parameters.
mdms.mut_df.head(5)

Unnamed: 0,mutation,β,wts,sites,muts,times_seen,S_Delta-2-1,F_Delta-2-1,S_Omicron_BA.1-2-1,F_Omicron_BA.1-2-1
0,V3A,-1.60456,V,3,A,1,0.0,-0.324383,0.0,-0.324383
1,V3I,0.434577,V,3,I,1,0.0,-0.043441,0.0,-0.043441
2,L5A,-1.267456,L,5,A,2,0.0,-0.233725,0.0,-0.233725
3,L5H,0.295835,L,5,H,1,0.0,-0.049874,0.0,-0.049874
4,L5I,-1.908145,L,5,I,1,0.0,-0.434445,0.0,-0.434445


In [23]:
# First five rows of the per-variant table the model is fit to.
mdms.data_to_fit.head(5)

Unnamed: 0,condition,aa_substitutions,weight,func_score,allowed_variant,var_wrt_ref,predicted_latent_phenotype,predicted_func_score,corrected_func_score
0,Delta-2-1,,80,-0.110441,True,,,-0.066929,-0.110441
1,Delta-2-1,A1020T,1,0.7215,True,A1020T,,-0.017982,0.7215
2,Delta-2-1,A1222V,1,-0.2743,True,A1222V,,-0.201145,-0.2743
5,Delta-2-1,A222E G614N P1079H K1191*,1,-0.3595,True,A222E G614N P1079H K1191*,,-0.069151,-0.3595
6,Delta-2-1,A222M,1,-0.6864,True,A222M,,-0.0515,-0.6864


In [24]:
# Tuple of condition names known to the model.
mdms.conditions

('Omicron_BA.1-2-1', 'Delta-2-1')

In [25]:
# The reference condition passed to the Multidms constructor.
mdms.reference

'Delta-2-1'

In [26]:
# Per-condition design matrices (int8 BCOO sparse), keyed by condition name.
mdms.binarymaps['X']

{'Delta-2-1': BCOO(int8[393, 1582], nse=856),
 'Omicron_BA.1-2-1': BCOO(int8[783, 1582], nse=16155)}

## JAX Engine: Objective Function

In [27]:
@jax.jit
def prox(
    params,
    hyperparams_prox=None,
    scaling=1.0
):
    """Proximal operator applied after each gradient step of the fit.

    Steps, in order:
      1. If an ``"α"`` entry exists, clip ``α["ge_scale"]`` to be >= 0 so the
         nonlinearity stays monotonic.
      2. Apply a lasso proximal step (``jaxopt.prox.prox_lasso``) to every
         parameter named in ``lasso_params``, using the given strength and
         ``scaling``.
      3. Overwrite every parameter named in ``lock_params`` with its fixed value.

    Parameters
    ----------
    params : dict
        Parameter pytree. It is not modified; a new dict is returned.
    hyperparams_prox : dict or None
        May contain ``"lasso_params"`` (name -> lasso strength) and/or
        ``"lock_params"`` (name -> fixed value). Missing keys or ``None``
        values disable the corresponding step. ``None`` disables both.
    scaling : float
        Step-size scaling forwarded to ``prox_lasso``.

    Returns
    -------
    dict
        The updated parameters.
    """
    # None is an empty pytree node, so these checks are resolved at trace time
    # and each distinct hyperparameter structure gets its own compiled version.
    if hyperparams_prox is None:
        hyperparams_prox = {}

    # Copy rather than mutate: without jit (e.g. under jax.disable_jit) the
    # original assignments changed the caller's dicts in place.
    params = dict(params)

    # Monotonic non-linearity, if non-linear model
    if "α" in params:
        alpha = dict(params["α"])
        alpha["ge_scale"] = alpha["ge_scale"].clip(0)
        params["α"] = alpha

    lasso_params = hyperparams_prox.get("lasso_params")
    if lasso_params is not None:
        for key, value in lasso_params.items():
            params[key] = jaxopt.prox.prox_lasso(params[key], value, scaling)

    # Any params to constrain during fit
    lock_params = hyperparams_prox.get("lock_params")
    if lock_params is not None:
        for key, value in lock_params.items():
            params[key] = value

    return params


@functools_partial(jax.jit, static_argnums=(0,))
def cost_smooth(f, params, data, δ=1, λ_ridge=0, **kwargs):
    """Smooth part of the objective, accumulated over all conditions.

    For each condition ``c`` this adds:
      * the mean Huber loss (threshold ``δ``) between the observed scores
        offset by ``params["γ_c"]`` and the predictions of ``f``, and
      * a ridge penalty ``λ_ridge * sum(params["S_c"] ** 2)`` on the shift
        parameters.

    Parameters
    ----------
    f : callable
        Static predictor, called as ``f(h_params, X_c, **kwargs)``.
    params : dict
        Must hold ``"α"``, ``"β"``, ``"C_ref"`` and, per condition ``c``,
        ``"S_c"``, ``"C_c"`` and ``"γ_c"``.
    data : tuple
        ``(X, y)``, two dicts keyed by condition name.
    δ : float
        Huber loss threshold.
    λ_ridge : float
        Ridge penalty weight on the shift parameters.

    Returns
    -------
    Scalar total loss.
    """
    X, y = data
    total = 0

    for condition in X:
        X_c = X[condition]
        shift = params[f"S_{condition}"]

        # Shared parameters plus this condition's shift (S) and C term.
        condition_params = {
            "α": params["α"],
            "β": params["β"],
            "C_ref": params["C_ref"],
            "S": shift,
            "C": params[f"C_{condition}"],
        }

        predicted = f(condition_params, X_c, **kwargs)
        observed = y[condition] + params[f"γ_{condition}"]

        # Same accumulation order as before: data term, then ridge term.
        total += jaxopt.loss.huber_loss(observed, predicted, δ).mean()
        total += λ_ridge * (shift ** 2).sum()

    return total

In [28]:
# Bind the predictor into the cost so only (params, data) remain to be supplied.
compiled_cost = jax_partial(cost_smooth, compiled_predict)

# Design matrices and observed scores, each a dict keyed by condition.
data = (mdms.binarymaps["X"], mdms.binarymaps["y"])

initial_loss = compiled_cost(mdms.params, data)
print(initial_loss)

1.4350489447079537


In [29]:
# Fit the model (this updates `mdms.params`), then re-evaluate the smooth cost
# with the fitted parameters.
mdms.fit(compiled_cost, prox)
final_loss = compiled_cost(mdms.params, data)
print(final_loss)

0.09388914704164272


In [30]:
# Per-mutation parameters after fitting; pandas truncates the middle rows in the display.
mdms.mut_df

Unnamed: 0,mutation,β,wts,sites,muts,times_seen,S_Delta-2-1,F_Delta-2-1,S_Omicron_BA.1-2-1,F_Omicron_BA.1-2-1
0,V3A,-0.645865,V,3,A,1,0.0,-0.933695,0.000000,-0.933695
1,V3I,-0.218521,V,3,I,1,0.0,-0.539380,-0.440457,-0.946310
2,L5A,0.493328,L,5,A,2,0.0,-0.003065,0.140090,0.080669
3,L5H,1.488686,L,5,H,1,0.0,0.446821,1.229238,0.673836
4,L5I,-0.521430,L,5,I,1,0.0,-0.815158,1.166527,0.087321
...,...,...,...,...,...,...,...,...,...,...
1577,S1252Q,-2.050120,S,1252,Q,1,0.0,-2.192824,-1.558008,-2.885626
1578,S1252R,-0.171221,S,1252,R,1,0.0,-0.498438,-0.189141,-0.665837
1579,S1252T,-2.029232,S,1252,T,2,0.0,-2.177880,0.000000,-2.177880
1580,S1252V,-0.498258,S,1252,V,1,0.0,-0.793358,0.759045,-0.158280
