# multidms

## Overview of model

The `multidms` model applies to a case where you have DMS datasets for two or more homologs and are interested in identifying shifts in mutational effects between homologs.
To do so, the model defines one homolog as a reference homolog.
For each mutation, the model fits one parameter that quantifies the effect of the mutation in the reference homolog.
For each non-reference homolog, it also fits a shift parameter that quantifies the shift in the mutation's effect in the non-reference homolog relative to the reference.
Shift parameters can be regularized, encouraging most of them to be close to zero.
This regularization step is a useful way to eliminate the effects of experimental noise, and is most useful in cases where you expect most mutations to have the same effects between homologs, such as for homologs that are close relatives.

The model uses a global-epistasis function to disentangle the effects of multiple mutations on the same variant.
To do so, it assumes that mutational effects additively influence a latent biophysical property the protein (e.g., $\Delta G$ of folding).
The mutational-effect parameters described above operate at this latent level.
The global-epistasis function then assumes a sigmoidal relationship between a protein's latent property and its functional score measured in the experiment (e.g., log enrichment score).
Ultimately, mutational parameters, as well as ones controlling the shape of the sigmoid, are all jointly fit to maximize agreement between predicted and observed functional scores acorss all variants of all homologs.

## Detailed description of the model

For each variant $v$ from homolog $h$, we use a global-epistasis function $g$ to convert a latent phenotype $\phi$ to a functional score $f$:

$$f(v,h) = g_{\alpha}(\phi(v,h))$$

where $g$ is a sigmoid and $\alpha$ is a set of parameters encoding the sigmoid.

The latent phenotype is computed in the following way:

$$\phi(v,h) = c + \sum_{m \in v} (x_m + s_{m,h})$$

where:
* $c$ is the wildtype latent phenotype for the reference homolog.
* $x_m$ is the latent phenotypic effect of mutation $m$. See details below.
* $s_{m,h}$ is the shift of the effect of mutation $m$ in homolog $h$. These parameters are fixed to zero for the reference homolog. For non-reference homologs, they are defined in the same way as $x_m$ parameters.
* $v$ is the set of all mutations relative to the reference wildtype sequence (including all mutations that separate homolog $h$ from the reference homolog).

The $x_m$ variable is defined such that mutations are always relative to the reference homolog.
For example, if the wildtype amino acid at site 30 is an A in the reference homolog, and a G in a non-reference homolog, then a Y30G mutation in the non-reference homolog is recorded as an A30G mutation relative to the reference.
This way, each homolog informs the exact same parameters, even at sites that differ in wildtype amino acid.
These are encoded in a `BinaryMap` object, where all sites that are non-identical to the reference are 1's.

Ultimately, we fit parameters using a loss function with one term that scores differences between predicted and observed values and another that uses L1 regularization to penalize non-zero $s_{m,h}$ values:

$$ L_{\text{total}} = \sum_{h} \left[\sum_{v} L_{\text{fit}}(y_{v,h}, f(v,h)) + \lambda \sum_{m} |s_{m,h}|\right]$$

where:
* $L_{\text{total}}$ is the total loss function.
* $L_{\text{fit}}$ is a loss function that penalizes differences in predicted vs. observed functional scores.
* $y_{v,h}$ is the experimentally measured functional score of variant $v$ from homolog $h$.

## Model using matrix algebra

We compute a vector or predicted latent phenotypes $P_{h}$ as:

$$P_{h} = c + (X_h \cdot (β + S_h))$$

where:
* $β$ is a vector of all $β_m$ values.
* $S_h$ is a matrix of all $s_{m,h}$ values.
* $X_h$ is a sparse matrix, where rows are variants, columns are mutations (all defined relative to the reference homolog), and values are weights of 0's and 1's. These weights are used to compute the phenotype of each variant given the mutations present.
* $c$ is the same as above.

In the matrix algebra, the sum of $β_m$ and $S_{m,h}$ gives a vector of mutational effects, with one entry per mutation.
Multiplying the matrix $X_h$ by this vector gives a new vector with one entry per variant, where values are the sum of mutational effects, weighted by the variant-specific weights in $X_h$.
Adding the $c$ value to this vector will give a vector of predicted latent phenotypes for each variant.

Next, the global-epistasis function can be used to convert a vector of predicted latent phenotypes to a vector of predicted functional scores.

$$F_{h,pred} = g_{\alpha}(P_h)$$

Finally, this vector could be fed into a loss function and compared with a vector of observed functional scores.

## Import `Python` modules

In [10]:
import pandas as pd
import numpy as np
import re
import binarymap as bmap

import matplotlib.pyplot as plt
import seaborn as sns
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental import sparse
import jaxopt
import numpy as onp
from scipy.stats import pearsonr

## Strategy for converting mutations to be relative to the reference homolog

As described above, our strategy involves converting all mutations in all homologs to be relative to the amino-acid sequence of the reference homolog. The below function performs this step.

Note 1: we will probably need some special-purpose code to handle gaps. This isn't done yet.

In [11]:
def variant_mutations_wrt_ref(
    func_score_df:pd.DataFrame, 
    homologs:dict,
    homolog_name_col,
    substitution_col
):
    """
    Takes a dataframe for making a `BinaryMap` object, and adds
    a column where each entry is a list of mutations in a variant
    relative to the amino-acid sequence of the reference homolog.
    
    Parameters
    ----------

    func_score_df : pandas.DataFrame
        This should be in the same format as described in BinaryMap.
        
    homologs : dict
        A dictionary containing all possible target homolog 
        names (keys) and sequences (values).
    
    homolog_name_col : str
        The name of the column in func_score_df that identifies the
        homolog for a given variant.
    
    substitution_col: The name of the column in func_score_df that
        lists mutations in each variant relative to the wildtype
        amino-acid sequence of the homolog in which they occur.
    
    Returns
    -------
        
    pd.DataFrame :
        A dataframe with a new column 'var_wrt_ref' that contains
        all variant mutations with respect to the wildtype.
    
    """
    
    def mutations_wrt_ref(mutations, hom_wtseq):
        """
        Takes a list of mutations for a given variant relative
        to its background homolog and returns a list of all
        mutations that separate the variant from the reference
        homolog.
        """
        
        # Compute the full amino-acid sequence of the
        # given variant
        mutated_homolog = list(hom_wtseq)
        for mutation in mutations.split():

            # TODO: Do we need to change the regex to allow
            # for gap '-' and stop '*' characters?
            pattern = r'(?P<aawt>\w)(?P<site>\d+)(?P<aamut>[\w\*])'
            match = re.search(pattern, mutation)
            assert match != None, mutation
            aawt = match.group('aawt')
            site = match.group('site')
            aamut = match.group('aamut')
            mutated_homolog[int(site)-1] = aamut
            
        hom_var_seq = ''.join(mutated_homolog)
        
        # Make a list of all mutations that separate the variant
        # from the reference homolog
        ref_muts = [
            f"{aaref}{i+1}{aavar}" 
            for i, (aaref, aavar) in enumerate(zip(homologs["reference"], hom_var_seq))
            if aaref != aavar
        ]
        
        return " ".join(ref_muts)

    # Duplicate the substitutions_col, then loop over homologs
    # and modify entries as needed
    func_score_df = func_score_df.assign(var_wrt_ref = func_score_df[substitution_col].values)
    for hom_name, hom_seq in homologs.items():
        if hom_name == "reference": continue
        hom_df = func_score_df.query(f"{homolog_name_col} == '{hom_name}'")
        hom_var_wrt_ref = [
            mutations_wrt_ref(muts, homologs[hom_name]) 
            for muts in hom_df[substitution_col]
        ]
        func_score_df.loc[hom_df.index.values, "var_wrt_ref"] = hom_var_wrt_ref
        
    return func_score_df

Next, we will test the above function with a small test case. Below, we define variants from two imaginary homologs: "reference" and "2".

In [12]:
# TODO: right now, the code requires that one of the homologs
# is called "reference". We need to add code to somehow make
# this work with arbitrary input.
homologs = {
    "reference" : "MG",
    "2" : "MP"
}

In [13]:
test_dict = {
    'homolog' : ["reference","reference","reference","2","2","2","2","2"],
    'variant' : ['M1E', 'G2R', 'G2P', 'M1E', 'P2R', 'P2G', 'M1E P2G', 'M1E P2R'],
    'log2E' : [2, -7, -0.5, 2.3, -5, 0.4, 2.7, -2.7],
}
test_df = pd.DataFrame(test_dict)
test_df

Unnamed: 0,homolog,variant,log2E
0,reference,M1E,2.0
1,reference,G2R,-7.0
2,reference,G2P,-0.5
3,2,M1E,2.3
4,2,P2R,-5.0
5,2,P2G,0.4
6,2,M1E P2G,2.7
7,2,M1E P2R,-2.7


In [14]:
test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8 entries, 0 to 7
Data columns (total 3 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   homolog  8 non-null      object 
 1   variant  8 non-null      object 
 2   log2E    8 non-null      float64
dtypes: float64(1), object(2)
memory usage: 320.0+ bytes


In [15]:
func_score_df = variant_mutations_wrt_ref(test_df, homologs, "homolog", "variant")
func_score_df

Unnamed: 0,homolog,variant,log2E,var_wrt_ref
0,reference,M1E,2.0,M1E
1,reference,G2R,-7.0,G2R
2,reference,G2P,-0.5,G2P
3,2,M1E,2.3,M1E G2P
4,2,P2R,-5.0,G2R
5,2,P2G,0.4,
6,2,M1E P2G,2.7,M1E
7,2,M1E P2R,-2.7,M1E G2R


Next, we will use the modified dataframe to create a `BinaryMap` object for each homolog. Each object needs to have the exact same columns in the same order. To achieve this, we will first make a list of all unique mutations observed in the above dataframe.

In [16]:
allowed_subs = {
    s for subs in func_score_df.var_wrt_ref
    for s in subs.split()
}
allowed_subs

{'G2P', 'G2R', 'M1E'}

Then, we will make a `BinaryMap` for each homolog, feeding the above list to the function that makes the maps, which ensures the maps will have identical columns.

In [17]:
X = {}
for homolog, homolog_func_score_df in func_score_df.groupby("homolog"):
    ref_bmap = bmap.BinaryMap(
        homolog_func_score_df,
        substitutions_col="var_wrt_ref",
        allowed_subs=allowed_subs
    )
    
    # TODO use sparse arrays
#     X[homolog] = jnp.array(ref_bmap.binary_variants.toarray())
    X[homolog] = sparse.BCOO.fromdense(ref_bmap.binary_variants.toarray())
    
    print(homolog)
    print(ref_bmap.binary_variants.toarray())
    print({mut:ref_bmap.sub_to_i(mut) for mut in allowed_subs})



2
[[1 1 0]
 [0 0 1]
 [0 0 0]
 [1 0 0]
 [1 0 1]]
{'M1E': 0, 'G2P': 1, 'G2R': 2}
reference
[[1 0 0]
 [0 0 1]
 [0 1 0]]
{'M1E': 0, 'G2P': 1, 'G2R': 2}


## Implementing the full model with JAX

Below, we implement the $\phi$ and $g$ functions from above.

Next, we initialize values of free parameters for downstream optimization by JAX.

In [18]:
# all model params to be tuned will be stored in a dictionary
params = {}
seed = 0
key = jax.random.PRNGKey(seed)

# initialize beta parameters from normal distribution.
all_mutations = list(allowed_subs)
params["β"] = jax.random.normal(shape=(len(all_mutations),), key=key)

# initialize shift parameters
# TODO no need for subkeys
key, *subkeys = jax.random.split(key, len(homologs)+1)
for homolog, subkey in zip(homologs, subkeys):
    
    # Do not create shift parameters for the reference
#     if homolog == "reference": continue
    
    # We expect most shift parameters to be close to zero
    params[f"S_{homolog}"] = jnp.zeros(shape=(len(all_mutations),))

params["C_ref"] = jnp.zeros(shape=(1, ))

# Number of free parameters for a linear transformation on the latent phenotype
# that results in the shape of the sigmoid.
n_units = 1

# TODO ??: Are we parameterizing out GE function correctly?
key, *subkeys = jax.random.split(key, num=5)
params["α"]=dict(sig_stretch_x = jax.random.normal(shape=(n_units,), key=subkeys[0]), # 'stretch' in x direction
                 sig_shift_x = jax.random.normal(shape=(1,), key=subkeys[1]),         # 'shift' in the x direction
                 sig_stretch_y = jax.random.normal(shape=(n_units,), key=subkeys[2]),       # 'stretch' in the y direction
                 sig_shift_y = jax.random.normal(shape=(1,), key=subkeys[3]),         # 'shift' in the y directions
)

params

{'β': DeviceArray([ 0.18784401, -1.28334229,  0.6494182 ], dtype=float64),
 'S_reference': DeviceArray([0., 0., 0.], dtype=float64),
 'S_2': DeviceArray([0., 0., 0.], dtype=float64),
 'C_ref': DeviceArray([0.], dtype=float64),
 'α': {'sig_stretch_x': DeviceArray([0.55893636], dtype=float64),
  'sig_shift_x': DeviceArray([1.70016778], dtype=float64),
  'sig_stretch_y': DeviceArray([1.31402346], dtype=float64),
  'sig_shift_y': DeviceArray([-0.45297973], dtype=float64)}}

Format the data to be fed into the model

In [19]:
y = {h:jnp.array(hfdf.log2E.values) for h, hfdf in func_score_df.groupby("homolog")}
data = (X, y)
data

({'2': BCOO(int8[5, 3], nse=6), 'reference': BCOO(int8[3, 3], nse=3)},
 {'2': DeviceArray([ 2.3, -5. ,  0.4,  2.7, -2.7], dtype=float64),
  'reference': DeviceArray([ 2. , -7. , -0.5], dtype=float64)})

Define an objective function to optimize

In [20]:
# TODO centering wildtype -
# TODO static args for reference -
# TODO regularization -

@jax.jit
def ϕ(params:dict, X_h:jnp.array):
    """Model for predicting latent space"""
    
    return (X_h @ (params["β"] + params[f"S"])) + params["C_ref"]


@jax.jit
def g(α:dict, z_h:jnp.array):
    """Model for global epistasis as 'flexible' sigmoid."""
     
    activations = jax.nn.sigmoid(α["sig_stretch_x"] * z_h[:, None] + α["sig_shift_x"])
#     wt_activations = jax.nn.sigmoid(α["sig_stretch_x"] * α["C_ref"] + α["sig_shift_x"])
#     activations -= wt_activations
    return (α["sig_stretch_y"] * activations) + α["sig_shift_y"] # sig shift y needs to be 0 for the reference


@jax.jit
def cost_smooth(params, data, δ=1):
    """Cost (Objective) function as a sum of huber loss across all homologs"""

    
#     print("Calling Cost Function", flush=True)
    X, y = data
    loss = 0
    
    
    # Sum the huber loss across all homologs
    for homolog, X_h in X.items():
        
        # Fix the shift parameters for reference to 0
        # Static arguments for reference params?
        S_h = jnp.where(
            homolog=="reference", 
            jnp.zeros(len(params['β'])), 
            params[f"S_{homolog}"]
        )
        
        # Subset the params being passed into latent prediction, ϕ
        h_params = {"β":params["β"], "S":S_h, "C_ref":params["C_ref"]}
        
#         print(f"Calling Phi Function {homolog}", flush=True)
        z_h = ϕ(h_params, X_h)
        
        # Pass the latent predictions through GE model prediction
        # all GE specific parameters are stored in α
#         print(f"Calling G Function {homolog}", flush=True)
        y_h_predicted = g(params["α"], z_h)
        
        # compute loss at current parameter state.
        loss += jaxopt.loss.huber_loss(y[homolog], y_h_predicted, δ).mean()

#     print(f"{loss}", flush=True)
    return loss

Calculate initial loss with random parameters

In [21]:
cost_smooth(params, data)

DeviceArray(5.02411651, dtype=float64)

In [22]:
tol = 1e-6
maxiter = 1000
solver = jaxopt.GradientDescent(cost_smooth, tol=tol, maxiter=maxiter)

In [23]:
params, state = solver.run(params, data=data)

Optimize the entire model at once with gradient decent

Peek at tuned parameters

In [24]:
params

{'C_ref': DeviceArray([-0.72320785], dtype=float64),
 'S_2': DeviceArray([2.63420862, 2.63523184, 2.59795679], dtype=float64),
 'S_reference': DeviceArray([0., 0., 0.], dtype=float64),
 'α': {'sig_shift_x': DeviceArray([0.69421427], dtype=float64),
  'sig_shift_y': DeviceArray([-0.5059075], dtype=float64),
  'sig_stretch_x': DeviceArray([2.06062493], dtype=float64),
  'sig_stretch_y': DeviceArray([2.91412947], dtype=float64)},
 'β': DeviceArray([-2.63421577, -2.63523528, -2.59796787], dtype=float64)}

## Simulation Fit

* TODO J: Cleaup prep
* TODO J: Try removing for loop in cost function (only compute cost on homolog)

* TODO H: Get will's notebook running in our environment (positive control for env)

* TODO J: jnp where culprit:
    * Try removing the shift parameters & update Phi accordingly
    * static args for reference

* TODO H: Run the Ab-CGGnaive_DMS with our model
* TODO H: Smaller simulated sequence

In [60]:
simulated_dataset = pd.read_csv("results/simulated_dataset_v1.csv")

In [61]:
import json
homologs = json.load(open("results/homolog_aa_seqs.json", "r"))

In [62]:
homologs["reference"] = homologs['1']
homologs["H2"] = homologs['2']
del homologs['1']
del homologs['2']
homologs

{'reference': 'RSVILRAYTNSRVKRVILCNNDLPIRNIRLMMILHNSDASFSTPVGLRSG',
 'H2': 'RVVILRAYTNSRVKRIKLCNNDRPIRNIRTMMIEHNSDAKFHTPYGLDSG'}

In [75]:
simulated_dataset_lib1 = simulated_dataset.query("library == 'lib_1'").copy()
simulated_dataset_lib1.aa_substitutions.fillna("", inplace=True)
# simulated_dataset_lib1 = simulated_dataset_lib1.sample(n=2000)

In [76]:
func_score_df = variant_mutations_wrt_ref(simulated_dataset_lib1, homologs, "homolog", "aa_substitutions")
func_score_df

Unnamed: 0,library,barcode,variant_call_support,codon_substitutions,aa_substitutions,n_codon_substitutions,n_aa_substitutions,latent_phenotype,observed_phenotype,observed_enrichment,homolog,var_wrt_ref
0,lib_1,AAAAAAATTTACGCGA,1,TTA18GTC TTA23TCA TTT41AAC,L18V L23S F41N,3,3,-15.608947,-9.965543,0.001000,reference,L18V L23S F41N
1,lib_1,AAAAAACATAGGAGTA,3,TGC19AAG CGG29TCC,C19K R29S,2,2,2.516825,-0.102192,0.931616,reference,C19K R29S
2,lib_1,AAAAAAGAGGTTAAAC,1,ATG32TTC,M32F,1,1,3.980341,-0.016995,0.988289,reference,M32F
3,lib_1,AAAAAAGGCTTATACT,1,TCA11TCG CGG12GGT GGT46AAA,R12G G46K,3,2,-21.146126,-9.965783,0.001000,reference,R12G G46K
4,lib_1,AAAAAATCACTAATAT,3,AGA1ACA CGT15CCG TCC37GCT AGT40TAA TCG42AGC,R1T R15P S37A S40*,5,4,-21.623836,-9.965784,0.001000,reference,R1T R15P S37A S40*
...,...,...,...,...,...,...,...,...,...,...,...,...
74995,lib_1,TTTTTGTAAGGCCTCT,1,AGA1CTT CTG5TCT ATC25GAT GGA50ACA,R1L L5S I25D G50T,4,4,-17.215931,-9.965736,0.001000,H2,R1L S2V L5S V16I I17K L23R I25D L30T L34E S40K...
74996,lib_1,TTTTTTAATCAGTTAG,3,CCA44ACC,P44T,1,1,0.242479,-0.829286,0.562808,H2,S2V V16I I17K L23R L30T L34E S40K S42H P44T V4...
74997,lib_1,TTTTTTCCTAGGAGAT,2,AAC10ACA ATC28TTC,N10T I28F,2,2,-6.999078,-9.028964,0.001914,H2,S2V N10T V16I I17K L23R I28F L30T L34E S40K S4...
74998,lib_1,TTTTTTCTACAGAGGT,2,CGC23GGT,R23G,1,1,-0.720320,-1.603013,0.329189,H2,S2V V16I I17K L23G L30T L34E S40K S42H V45Y R48D


Next, we will use the modified dataframe to create a `BinaryMap` object for each homolog. Each object needs to have the exact same columns in the same order. To achieve this, we will first make a list of all unique mutations observed in the above dataframe.

In [77]:
allowed_subs = {
    s for subs in func_score_df.var_wrt_ref
    for s in subs.split()
}
# allowed_subs

Then, we will make a `BinaryMap` for each homolog, feeding the above list to the function that makes the maps, which ensures the maps will have identical columns.

In [78]:
X = {}
for homolog, homolog_func_score_df in func_score_df.groupby("homolog"):
    ref_bmap = bmap.BinaryMap(
        homolog_func_score_df,
        substitutions_col="var_wrt_ref",
        allowed_subs=allowed_subs
    )
    
    # TODO use sparse arrays
#     X[homolog] = jnp.array(ref_bmap.binary_variants.toarray())
    X[homolog] = sparse.BCOO.fromdense(ref_bmap.binary_variants.toarray())
    
    
    print(homolog)
    print(ref_bmap.binary_variants.toarray())
#     print({mut:ref_bmap.sub_to_i(mut) for mut in allowed_subs})

H2
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
reference
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]


In [79]:
y = {h:jnp.array(hfdf.observed_enrichment.values) for h, hfdf in func_score_df.groupby("homolog")}
data = (X, y)
data

({'H2': BCOO(int8[25000, 1000], nse=287350),
  'reference': BCOO(int8[25000, 1000], nse=47572)},
 {'H2': DeviceArray([0.00100008, 0.64589581, 0.99360193, ..., 0.0019143 ,
               0.32918888, 1.        ], dtype=float64),
  'reference': DeviceArray([0.00100017, 0.93161612, 0.98828937, ..., 0.00149706,
               0.50507318, 1.        ], dtype=float64)})

In [80]:
# all model params to be tuned will be stored in a dictionary
params = {}
seed = 0
key = jax.random.PRNGKey(seed)

# initialize beta parameters from normal distribution.
all_mutations = list(allowed_subs)
params["β"] = jax.random.normal(shape=(len(all_mutations),), key=key)

# initialize shift parameters
# TODO no need for subkeys
key, *subkeys = jax.random.split(key, len(homologs)+1)
for homolog, subkey in zip(homologs, subkeys):
    
    # Do not create shift parameters for the reference
#     if homolog == "reference": continue
    
    # We expect most shift parameters to be close to zero
    params[f"S_{homolog}"] = jnp.zeros(shape=(len(all_mutations),))

params["C_ref"] = jnp.zeros(shape=(1, ))

# Number of free parameters for a linear transformation on the latent phenotype
# that results in the shape of the sigmoid.
n_units = 1

# TODO ??: Are we parameterizing out GE function correctly?
key, *subkeys = jax.random.split(key, num=5)
params["α"]=dict(sig_stretch_x = jax.random.normal(shape=(n_units,), key=subkeys[0]), # 'stretch' in x direction
                 sig_shift_x = jax.random.normal(shape=(1,), key=subkeys[1]),         # 'shift' in the x direction
                 sig_stretch_y = jax.random.normal(shape=(n_units,), key=subkeys[2]),       # 'stretch' in the y direction
                 sig_shift_y = jax.random.normal(shape=(1,), key=subkeys[3]),         # 'shift' in the y directions
)

params.keys()

dict_keys(['β', 'S_reference', 'S_H2', 'C_ref', 'α'])

In [69]:
params['β'].shape

(936,)

In [70]:
cost_smooth(params, data)

DeviceArray(0.27564962, dtype=float64)

In [71]:
# cost_deriv = jax.grad(cost_smooth)

In [72]:
tol = 1e-6
maxiter = 10
solver = jaxopt.GradientDescent(cost_smooth, tol=tol, maxiter=maxiter)

In [73]:
params, state = solver.run(params, data=data)

Optimize the entire model at once with gradient decent

Peek at tuned parameters

In [74]:
params

{'C_ref': DeviceArray([0.02381425], dtype=float64),
 'S_H2': DeviceArray([ 0.00000000e+00,  6.35462365e-05,  1.14176287e-04,
               0.00000000e+00,  7.47357580e-05,  1.58670127e-04,
              -1.17119055e-05,  2.78465527e-04,  1.02569191e-04,
               1.25591527e-04,  0.00000000e+00,  2.70561708e-04,
               2.45273046e-04,  6.48985540e-04,  1.40973150e-04,
               5.58800383e-04,  4.39097867e-04, -3.89111087e-05,
               2.80021226e-04,  1.16366579e-04,  2.46968562e-04,
               1.08271070e-04,  5.17715662e-04,  0.00000000e+00,
               0.00000000e+00,  0.00000000e+00,  2.51754165e-04,
              -2.19313510e-05,  0.00000000e+00, -4.69259122e-05,
               0.00000000e+00,  9.43620866e-04,  0.00000000e+00,
               4.04430183e-02,  2.38679588e-04,  1.80612477e-04,
               2.63488437e-04,  5.12737459e-05, -3.24199604e-05,
              -2.47057248e-05,  0.00000000e+00,  0.00000000e+00,
               3.18577468e-05,

## Empirical Data (TODO)

In [2]:
df = pd.read_csv("https://media.githubusercontent.com/media/jbloomlab/SARS-CoV-2-RBD_DMS_Omicron/main/results/binding_Kd/bc_binding.csv")

In [3]:
df

Unnamed: 0,library,barcode,target,variant_class,aa_substitutions,n_aa_substitutions,TiteSeq_avgcount,log10Ka
0,pool1A,AAAAAAAAAAACGCGA,BA2,1 nonsynonymous,I88V,1,3.317373,
1,pool1A,AAAAAAAAAAAGGAGA,Wuhan_Hu_1,1 nonsynonymous,G166M,1,59.823342,5.963056
2,pool1A,AAAAAAAAAAATTTAA,Wuhan_Hu_1,wildtype,,0,66.484153,8.913236
3,pool1A,AAAAAAAAAACAAAAA,BA1,1 nonsynonymous,A42N,1,17.062794,9.253828
4,pool1A,AAAAAAAAAACGCGTA,Wuhan_Hu_1,1 nonsynonymous,E154T,1,24.911126,9.152625
...,...,...,...,...,...,...,...,...
598389,pool2A,TTTTTTTGATATTGGA,Wuhan_Hu_1,1 nonsynonymous,C158G,1,47.587851,5.000000
598390,pool2A,TTTTTTTGCTCTTACC,BA1,1 nonsynonymous,S129P,1,63.491459,8.837559
598391,pool2A,TTTTTTTGTATAACAA,BA1,1 nonsynonymous,P191W,1,22.825324,8.992898
598392,pool2A,TTTTTTTTAGCCGATA,BA2,1 nonsynonymous,F17T,1,93.332943,7.592250


## Old code Attic

In [None]:
# @jax.jit
# def ϕ(params:dict, X_tuple:tuple): #X_h:jnp.array): #, homolog:str):#is_ref:bool):
#     """Phi function for predicting latent phenotype."""
    
    
    
# #     if is_ref:
#     homolog = X_tuple[0]
#     X_h = X_tuple[1]
#     if homolog == "reference":
        
#         return (X_h @ params["β"]) + params["C_ref"]

#     return (X_h @ (params["β"] + params[f"S_{homolog}"])) + params["C_ref"]

# @jax.jit
# def ϕ(params:dict, X_h:jnp.array):
#     return (X_h @ (params["β"] + params[f"S"])) + params["C_ref"]

# @jax.jit
# def g(α:dict, z_h:jnp.array):
#     """Global epistasis function as flexible sigmoid."""
#     activations = jax.nn.sigmoid(α["rate"] * z_h[:, None] + α["intercept"])
#     return (α["a"] * activations) + α["bias"] 