# multidms

## Overview of model

The `multidms` model applies to a case where you have DMS datasets for two or more homologs and are interested in identifying shifts in mutational effects between homologs.
To do so, the model defines one homolog as a reference homolog.
For each mutation, the model fits one parameter that quantifies the effect of the mutation in the reference homolog.
For each non-reference homolog, it also fits a shift parameter that quantifies the shift in the mutation's effect in the non-reference homolog relative to the reference.
Shift parameters can be regularized, encouraging most of them to be close to zero.
This regularization step is a useful way to reduce the influence of experimental noise on the inferred shifts. It is most appropriate when you expect most mutations to have the same effects in all homologs, such as for homologs that are close relatives.

The model uses a global-epistasis function to disentangle the effects of multiple mutations on the same variant.
To do so, it assumes that mutational effects additively influence a latent biophysical property of the protein (e.g., $\Delta G$ of folding).
The mutational-effect parameters described above operate at this latent level.
The global-epistasis function then assumes a sigmoidal relationship between a protein's latent property and its functional score measured in the experiment (e.g., log enrichment score).
Ultimately, the mutational parameters and the parameters controlling the shape of the sigmoid are all jointly fit to maximize agreement between predicted and observed functional scores across all variants of all homologs.
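To see why a nonlinear global-epistasis function matters, here is a minimal NumPy sketch. It assumes a single logistic curve with made-up parameters (`lower`, `upper`, `midpoint` are hypothetical and are not the model's actual parameterization). Two mutations with identical latent effects add exactly on the latent scale, but their combined effect on the observed functional score is not the sum of their individual effects.

```python
import numpy as np


def sigmoid_ge(z, lower=-5.0, upper=0.0, midpoint=-1.0):
    """Hypothetical global-epistasis function: a logistic curve that maps a
    latent phenotype z onto a functional score between `lower` and `upper`."""
    return lower + (upper - lower) / (1.0 + np.exp(-(z - midpoint)))


wt_latent = 0.0
effect = -2.0  # latent effect of each of two mutations (made-up value)

# Functional effects measured relative to the wildtype score
single = sigmoid_ge(wt_latent + effect) - sigmoid_ge(wt_latent)
double = sigmoid_ge(wt_latent + 2 * effect) - sigmoid_ge(wt_latent)

# The latent effect of the double mutant is exactly 2 * effect, but after the
# sigmoid the double mutant's functional effect differs from 2 * single.
print(f"single-mutant functional effect: {single:.3f}")
print(f"double-mutant functional effect: {double:.3f}")
print(f"sum of two single effects:       {2 * single:.3f}")
```

Because the curve flattens as it approaches `lower`, the second mutation has less room to reduce the score, so the double mutant is less deleterious than additivity would predict.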

## Detailed description of the model

For each variant $v$ from homolog $h$, we use a global-epistasis function $g$ to convert a latent phenotype $\phi$ to a functional score $f$:

$$f(v,h) = g_{\alpha}(\phi(v,h))$$

where $g$ is a sigmoid and $\alpha$ is a set of parameters encoding the sigmoid.

The latent phenotype is computed in the following way:

$$\phi(v,h) = c + \sum_{m \in v} (x_m + s_{m,h})$$

where:
* $c$ is the wildtype latent phenotype for the reference homolog
* $x_m$ is the latent phenotypic effect of mutation $m$. See details below.
* $s_{m,h}$ is the shift of the effect of mutation $m$ in homolog $h$. These parameters are fixed to zero for the reference homolog. For non-reference homologs, they are defined in the same way as $x_m$ parameters.
* $v$ is the set of all mutations relative to the reference wildtype sequence (including all mutations that separate homolog $h$ from the reference homolog).
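The formula can be read as a plain loop over the mutations in a variant. The sketch below uses a hypothetical helper (`latent_phenotype` is not part of `multidms`) with made-up parameter values. It borrows the homolog names `"1"` (reference, wildtype `MG`) and `"2"` (wildtype `MP`) from the test case later in this notebook.

```python
def latent_phenotype(variant, homolog, c, x, s, reference="1"):
    """Compute phi(v, h) = c + sum_{m in v} (x_m + s_{m,h}).

    variant   -- iterable of mutations relative to the reference wildtype
    homolog   -- name of the homolog the variant comes from
    c         -- wildtype latent phenotype of the reference homolog
    x         -- dict: mutation -> latent effect in the reference homolog
    s         -- dict: homolog -> {mutation: shift}; unlisted shifts count as 0
    reference -- name of the reference homolog, whose shifts are fixed to 0
    """
    total = c
    for m in variant:
        shift = 0.0 if homolog == reference else s.get(homolog, {}).get(m, 0.0)
        total += x[m] + shift
    return total


# Made-up parameter values for illustration only
c = 2.0
x = {"M1E": 1.5, "G2P": -0.5}
s = {"2": {"M1E": 0.25}}

# The wildtype of homolog "2" differs from the reference by G2P, so its
# M1E variant is encoded as {M1E, G2P} relative to the reference.
phi_ref = latent_phenotype(["M1E"], "1", c, x, s)          # 2.0 + 1.5
phi_hom = latent_phenotype(["M1E", "G2P"], "2", c, x, s)   # 2.0 + (1.5 + 0.25) + (-0.5 + 0.0)
```

Note that the G2P term appears for every homolog-2 variant, so that homolog's own wildtype differences contribute to $\phi$ just like any other mutation.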

The $x_m$ variable is defined such that mutations are always relative to the reference homolog.
For example, if the reference has an A at site 30 and the wildtype of homolog 2 has a G, then we encode the homolog 2 sequence as carrying an A30G mutation.
This way, each homolog informs the exact same parameters, even at sites that differ in wildtype amino acid.
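These mutations are encoded in a `BinaryMap` object, where each column is a mutation relative to the reference and a variant has a 1 in the column of every mutation it carries, including mutations that distinguish its homolog's wildtype from the reference.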

Ultimately, we fit parameters using a loss function with one term that scores differences between predicted and observed values and another that uses L1 regularization to penalize non-zero $s_{m,h}$ values:

$$ L_{\text{total}} = \sum_{h} \left[\sum_{v} L_{\text{fit}}(y_{v,h}, f(v,h)) + \lambda \sum_{m} |s_{m,h}|\right]$$

where:
* $L_{\text{total}}$ is the total loss function
* $L_{\text{fit}}$ is a loss function that penalizes differences in predicted vs. observed functional scores
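* $y_{v,h}$ is the experimentally measured functional score of variant $v$ from homolog $h$
* $\lambda$ is the regularization strength, which controls how strongly non-zero shifts are penalized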
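The total loss can be computed term by term. This is a minimal sketch that assumes squared error for $L_{\text{fit}}$ (the JAX code later in this notebook uses a Huber loss instead), with $\lambda$ the regularization weight. `total_loss` and all numbers are hypothetical, for illustration only.

```python
import numpy as np


def total_loss(y_obs, y_pred, shifts, lam):
    """L_total = sum_h [ sum_v L_fit(y_vh, f(v,h)) + lam * sum_m |s_mh| ],
    with squared error standing in for L_fit (an assumption of this sketch).
    Each argument except lam is a dict keyed by homolog name."""
    loss = 0.0
    for h in y_obs:
        residuals = np.asarray(y_obs[h]) - np.asarray(y_pred[h])
        fit = np.sum(residuals ** 2)                # data-fit term for homolog h
        penalty = lam * np.sum(np.abs(shifts[h]))   # L1 penalty on its shifts
        loss += fit + penalty
    return loss


# Made-up observed/predicted scores and shifts; reference "1" has zero shifts.
y_obs = {"1": [2.0, -7.0], "2": [2.3]}
y_pred = {"1": [1.5, -7.0], "2": [2.0]}
shifts = {"1": np.zeros(2), "2": np.array([0.5, -0.25])}
loss = total_loss(y_obs, y_pred, shifts, lam=0.1)
```

Because the reference homolog's shifts are all zero, it contributes only a data-fit term. Raising `lam` makes any non-zero shift in the other homologs more costly.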

## Model using matrix algebra

We compute a vector of predicted latent phenotypes $P_{h}$ for homolog $h$ as:

$$P_{h} = c + W_h \cdot (\beta + S_h)$$

where:
* $\beta$ is a vector of all $x_m$ values, i.e., the mutational effects in the reference homolog (named `β` in the code below)
* $S_h$ is a vector of all $s_{m,h}$ values for homolog $h$ (all zeros for the reference homolog)
* $W_h$ is a sparse matrix, where rows are variants, columns are mutations (all defined relative to the reference homolog), and each value is 1 if the variant carries that mutation and 0 otherwise. These weights are used to compute the phenotype of each variant from the mutations it carries.
* $c$ is the same as above.

In this matrix form, the sum $\beta + S_h$ gives a vector of mutational effects in homolog $h$, with one entry per mutation.
Multiplying the matrix $W_h$ by this vector gives a new vector with one entry per variant, where each value is the sum of the effects of the mutations that variant carries.
Adding $c$ to every entry gives the vector $P_h$ of predicted latent phenotypes, one per variant.
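As a quick check that the matrix form matches the per-variant sum, here is a small NumPy sketch with a made-up 3-variant by 3-mutation weight matrix and made-up parameter values. Dense arrays stand in for the sparse $W_h$ purely for readability.

```python
import numpy as np

# Toy sizes: 3 variants (rows) x 3 mutations (columns); weights are 0/1.
W_h = np.array([[1, 1, 0],
                [0, 0, 1],
                [1, 0, 1]], dtype=float)
beta = np.array([1.5, -0.5, -3.0])   # reference-homolog effects, one per mutation
S_h = np.array([0.25, 0.0, 0.5])     # shifts for homolog h, one per mutation
c = 2.0                              # wildtype latent phenotype of the reference

# Matrix form: P_h = c + W_h (beta + S_h)
P_h = c + W_h @ (beta + S_h)

# Same quantity variant by variant: c plus the summed (beta + S_h) of the
# mutations whose weight is 1 in that variant's row.
P_loop = np.array([
    c + sum(beta[m] + S_h[m] for m in np.flatnonzero(row))
    for row in W_h
])
```

The matrix product does in one step what the loop does one variant at a time, which is what makes the vectorized JAX implementation below efficient.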

Next, the global-epistasis function can be used to convert a vector of predicted latent phenotypes to a vector of predicted functional scores.

$$F_{h,pred} = g_{\alpha}(P_h)$$

Finally, this vector could be fed into a loss function and compared with a vector of observed functional scores.

## BinaryMap augmentation

Rather than re-invent a binary encoding for amino-acid sequence data, we can simply reuse the `BinaryMap` class from the `binarymap` package after converting the mutations so that they are relative to the reference sequence, as described above.

Note: we will probably need some special-purpose code to handle gaps. This isn't implemented yet.

```python
import pandas as pd

test_dict = {
    'homolog' : [1,1,1,2,2,2,2,2],
    'variant' : ['M1E', 'G2R', 'G2P', 'M1E', 'P2R', 'P2G', 'M1E P2G', 'M1E P2R'],
    'log2E' : [2, -7, -0.5, 2.3, -5, 0.4, 2.7, 0.5],
}
test_df = pd.DataFrame(test_dict)
```

```python
import re

import binarymap as bmap
import numpy as np
import pandas as pd


def variant_mutations_wrt_ref(
    func_score_df: pd.DataFrame,
    homologs: dict,
    reference: str,
    homolog_name_col: str,
):
    """Convert the variants of every non-reference homolog so that their
    substitutions are expressed relative to the reference sequence.

    Parameters
    ----------
    func_score_df : pandas.DataFrame
        Same format as described in `BinaryMap`, with a `variant` column of
        space-delimited substitutions. For non-reference homologs these
        substitutions are relative to that homolog's own wildtype sequence;
        the conversion to the reference is handled here.
    homologs : dict
        Maps every homolog name (keys) to its wildtype sequence (values).
    reference : str
        The value of `homolog_name_col` that identifies the reference homolog.
    homolog_name_col : str
        Name of the column in `func_score_df` that identifies the homolog
        of each variant.

    Returns
    -------
    pandas.DataFrame
        A copy of `func_score_df` with a new column `var_wrt_ref` containing
        each variant's substitutions relative to the reference sequence.
    """
    ref_seq = homologs[reference]
    # Parses e.g. "P2G" into wildtype aa, 1-based site, and mutant aa
    pattern = re.compile(r"(?P<aawt>\w)(?P<site>\d+)(?P<aamut>\w)")

    def mutations_wrt_ref(mutations, hom_wtseq):
        # Apply the variant's mutations to the homolog's wildtype sequence...
        mutated_homolog = list(hom_wtseq)
        for mutation in mutations.split():
            match = pattern.fullmatch(mutation)
            assert match is not None, f"cannot parse mutation {mutation!r}"
            site = int(match.group("site"))
            # Sanity check: the stated wildtype must match the homolog sequence
            assert hom_wtseq[site - 1] == match.group("aawt"), (
                f"{mutation} does not match the homolog wildtype"
            )
            mutated_homolog[site - 1] = match.group("aamut")

        # ...then list every site where the result differs from the reference.
        hom_var_seq = "".join(mutated_homolog)
        ref_muts = [
            f"{aaref}{i + 1}{aavar}"
            for i, (aaref, aavar) in enumerate(zip(ref_seq, hom_var_seq))
            if aaref != aavar
        ]
        return " ".join(ref_muts)

    # Reference-homolog variants are already relative to the reference.
    func_score_df = func_score_df.assign(var_wrt_ref=func_score_df.variant.values)
    for hom_name, hom_seq in homologs.items():
        if hom_name == reference:
            continue
        print(f"{homolog_name_col} == '{hom_name}'")
        hom_df = func_score_df[func_score_df[homolog_name_col] == hom_name]
        hom_var_wrt_ref = [mutations_wrt_ref(muts, hom_seq) for muts in hom_df.variant]
        func_score_df.loc[hom_df.index, "var_wrt_ref"] = hom_var_wrt_ref

    return func_score_df
```

## Test Case

With some imaginary variants from two imaginary homologs (`"1"`, the reference, and `"2"`), we'll run the code above as a sanity check.

```python
homologs = {
    "1" : "MG",
    "2" : "MP"
}
```

```python
test_dict = {
    'homolog' : ["1","1","1","2","2","2","2","2"],
    'variant' : ['M1E', 'G2R', 'G2P', 'M1E', 'P2R', 'P2G', 'M1E P2G', 'M1E P2R'],
    'log2E' : [2, -7, -0.5, 2.3, -5, 0.4, 2.7, 0.5],
}
test_df = pd.DataFrame(test_dict)
test_df
```

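|   | homolog | variant | log2E |
|---|---------|---------|-------|
| 0 | 1 | M1E | 2.0 |
| 1 | 1 | G2R | -7.0 |
| 2 | 1 | G2P | -0.5 |
| 3 | 2 | M1E | 2.3 |
| 4 | 2 | P2R | -5.0 |
| 5 | 2 | P2G | 0.4 |
| 6 | 2 | M1E P2G | 2.7 |
| 7 | 2 | M1E P2R | 0.5 |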


```python
test_df.info()
```

```
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8 entries, 0 to 7
Data columns (total 3 columns):
 #   Column   Non-Null Count  Dtype
---  ------   --------------  -----
 0   homolog  8 non-null      object
 1   variant  8 non-null      object
 2   log2E    8 non-null      float64
dtypes: float64(1), object(2)
memory usage: 320.0+ bytes
```


```python
func_score_df = variant_mutations_wrt_ref(test_df, homologs, "1", "homolog")
func_score_df
```

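```
homolog == '2'
```

|   | homolog | variant | log2E | var_wrt_ref |
|---|---------|---------|-------|-------------|
| 0 | 1 | M1E | 2.0 | M1E |
| 1 | 1 | G2R | -7.0 | G2R |
| 2 | 1 | G2P | -0.5 | G2P |
| 3 | 2 | M1E | 2.3 | M1E G2P |
| 4 | 2 | P2R | -5.0 | G2R |
| 5 | 2 | P2G | 0.4 |  |
| 6 | 2 | M1E P2G | 2.7 | M1E |
| 7 | 2 | M1E P2R | 0.5 | M1E G2R |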


Next, we build binary maps from the converted variants.

Each homolog gets its own binary encoding, but all encodings must include the same set of possible substitutions. This way they have the same shape (the same columns) when informing the $x_m$ parameters in our model. So, we first collect the substitutions from all homologs that will be included in the model.
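The requirement is simply that every homolog's matrix uses one shared column order. This plain-NumPy sketch uses a hypothetical `encode` helper, not `BinaryMap` itself. It sorts substitutions alphabetically, which is an assumption of the sketch and is not necessarily the column order `BinaryMap` uses. The variants are the reference-relative ones from the table above.

```python
import numpy as np


def encode(variants, allowed_subs):
    """Hypothetical helper: one-hot encode space-delimited, reference-relative
    substitutions against a fixed column order shared by all homologs."""
    cols = sorted(allowed_subs)                  # shared column order
    index = {sub: j for j, sub in enumerate(cols)}
    X = np.zeros((len(variants), len(cols)), dtype=int)
    for i, muts in enumerate(variants):
        for sub in muts.split():
            X[i, index[sub]] = 1   # KeyError flags a substitution outside allowed_subs
    return cols, X


allowed = {"M1E", "G2P", "G2R"}
cols1, X1 = encode(["M1E", "G2R", "G2P"], allowed)
cols2, X2 = encode(["M1E G2P", "G2R", "", "M1E", "M1E G2R"], allowed)
```

Both matrices have three columns in the same order even though homolog 2 has more variants, so column $j$ refers to the same $x_m$ parameter in each.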

```python
# Every reference-relative substitution observed in any homolog
allowed_subs = {
    s for subs in func_score_df.var_wrt_ref
    for s in subs.split()
}
allowed_subs
```

```
{'G2P', 'G2R', 'M1E'}
```

Next, create one binary map per homolog:

```python
# One BinaryMap per homolog, all sharing the same set of allowed substitutions
for homolog, homolog_func_score_df in func_score_df.groupby("homolog"):
    hom_bmap = bmap.BinaryMap(
        homolog_func_score_df,
        substitutions_col="var_wrt_ref",
        allowed_subs=allowed_subs,
    )
    print(homolog)
    print(hom_bmap.binary_variants.toarray())
    print({mut: hom_bmap.sub_to_i(mut) for mut in allowed_subs})
```

```
1
[[1 0 0]
 [0 0 1]
 [0 1 0]]
{'M1E': 0, 'G2P': 1, 'G2R': 2}
2
[[1 1 0]
 [0 0 1]
 [0 0 0]
 [1 0 0]
 [1 0 1]]
{'M1E': 0, 'G2P': 1, 'G2R': 2}
```


## JAX Model

Here, we'll build a global-epistasis model for multiple homologs using JAX.

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental import sparse
import jaxopt
import numpy as onp
from scipy.stats import pearsonr
```

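```python
@jax.jit
def ϕ(X, β, S):
    """Latent phenotypes X @ (β + S), one per variant.
    The wildtype latent phenotype c from the equations is not included here."""
    return X @ (β + S)

@jax.jit
def g(z, α):
    """Global-epistasis function: a weighted sum of sigmoids, shifted so g(0) = 0."""
    activations = jax.nn.sigmoid(α["weights"] * z[:, None] + α["biases"])
    wt_activations = jax.nn.sigmoid(α["biases"])
    return α["a"] @ (activations - wt_activations).T

@jax.jit
def f(X, α, β, S):
    """Predicted functional scores for the variants encoded in X."""
    return g(ϕ(X, β, S), α)
```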
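To make the behavior of `g` concrete without JAX, here is a NumPy mirror of it with made-up parameters for a two-unit sigmoid sum (`g_numpy` and the values in `alpha` are hypothetical). Because `g` subtracts the activations at $z = 0$, a variant with latent phenotype 0 is predicted to have functional score exactly 0. With positive `weights` and `a`, the curve also increases monotonically.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def g_numpy(z, alpha):
    """NumPy mirror of g: a weighted sum of sigmoids, shifted so g(0) = 0."""
    activations = sigmoid(alpha["weights"] * z[:, None] + alpha["biases"])  # (n, k)
    wt_activations = sigmoid(alpha["biases"])                               # (k,)
    return alpha["a"] @ (activations - wt_activations).T                    # (n,)


alpha = {  # made-up parameters for a 2-unit sigmoid sum
    "weights": np.array([1.0, 0.5]),
    "biases": np.array([0.0, -1.0]),
    "a": np.array([2.0, 3.0]),
}
z = np.array([-3.0, 0.0, 3.0])
out = g_numpy(z, alpha)
```

Summing several sigmoids gives `g` more flexibility than a single logistic curve, at the cost of more parameters in `α`.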

```python
# Convergence settings for the iterative solvers
tol = 1e-6
maxiter = 1000
```

```python
all_mutations = list(allowed_subs)
```

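```python
seed = 0
key = jax.random.PRNGKey(seed)
β0 = jax.random.normal(shape=(len(all_mutations),), key=key)

key, *subkeys = jax.random.split(key, len(homologs)+1)
print(subkeys)
S_h = {}
for homolog, subkey in zip(homologs, subkeys):
    # Random initial shifts; the model fixes the reference homolog's shifts to zero
    S_h[homolog] = jax.random.normal(subkey, shape=(len(all_mutations),))
print(S_h)

# Design matrix and observed scores for the reference homolog ("1")
ref_df = func_score_df[func_score_df.homolog == "1"]
ref_bmap = bmap.BinaryMap(ref_df, substitutions_col="var_wrt_ref", allowed_subs=allowed_subs)
X = jnp.asarray(ref_bmap.binary_variants.toarray(), dtype=jnp.float64)
y = jnp.asarray(ref_df.log2E.to_numpy(), dtype=jnp.float64)

# Least-squares initialization of β on the latent scale, with shifts held at zero
zero_shift = jnp.zeros(len(all_mutations))
β0 = jaxopt.linear_solve.solve_normal_cg(
    lambda β: ϕ(X, β, zero_shift), y, init=β0, tol=tol, maxiter=maxiter
)
```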

```
[DeviceArray([3186719485, 3840466878], dtype=uint32), DeviceArray([2562233961, 1946702221], dtype=uint32)]
{'1': DeviceArray([-0.64391523,  0.07092413,  0.3602591 ], dtype=float64), '2': DeviceArray([ 1.55754172,  1.57821833, -0.63811247], dtype=float64)}
```
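`solve_normal_cg` finds the least-squares solution of $X\beta \approx y$ by solving the normal equations $X^\top X \beta = X^\top y$ with conjugate gradients, using only matrix-vector products. The dense NumPy sketch below solves the same system directly. It uses the homolog "1" binary map printed above (columns `M1E`, `G2P`, `G2R`) and its `log2E` scores. It ignores `g`, so it only provides a starting point for β.

```python
import numpy as np

# Homolog "1" design matrix from the BinaryMap output above
# (rows: M1E, G2R, G2P; columns: M1E, G2P, G2R) and its log2E scores
X = np.array([[1, 0, 0],
              [0, 0, 1],
              [0, 1, 0]], dtype=float)
y = np.array([2.0, -7.0, -0.5])

# Normal equations: (X^T X) beta = X^T y
beta_normal = np.linalg.solve(X.T @ X, X.T @ y)

# The same least-squares solution from NumPy's dedicated solver
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
```

With only single mutants, each variant informs exactly one parameter, so the solution just reads each score into its mutation's column. Conjugate gradients becomes worthwhile when $X$ is large and sparse.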


```python
@jax.jit
def cost_smooth(params, data, λ_ridge=0.0, λ_spline=0.0, δ=1):
    α = params["α"]
    β = params["β"]
    S = params["S"]

    X, y, sample_weights = data  # sample_weights is not used yet
    z = ϕ(X, β, S)
    y_predicted = g(z, α)

    # Robust fit term: mean Huber loss between observed and predicted scores
    loss = jaxopt.loss.huber_loss(y, y_predicted, δ).mean()

    ridge_penalty = λ_ridge * ((β ** 2).sum() + (α["weights"] ** 2).sum())

    # Curvature ("spline") penalty on g, evaluated on a fixed latent grid
    a = -10
    b = 10
    # assert(a < z.min() and b > z.max())

    z_grid, dz = jnp.linspace(a, b, num=100, retstep=True)
    spline_penalty = λ_spline * (dz * jnp.diff(g(z_grid, α), n=2) ** 2).sum()

    return loss + ridge_penalty + spline_penalty

@jax.jit
def cost_nonsmooth(params, data, λ_lasso=0.0):
    # L1 penalty on the shift parameters, matching L_total above
    return λ_lasso * jnp.linalg.norm(params["S"], 1)

@jax.jit
def cost(params, data, λ_ridge=0.0, λ_lasso=0.0, λ_spline=0.0, δ=1):
    return (
        cost_smooth(params, data, λ_ridge=λ_ridge, λ_spline=λ_spline, δ=δ)
        + cost_nonsmooth(params, data, λ_lasso=λ_lasso)
    )
```
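Splitting the cost into smooth and nonsmooth parts fits proximal-gradient methods. These methods take a gradient step on the smooth part and then apply the proximal operator of the L1 term, which is soft-thresholding. One option, not implemented in this notebook, would be jaxopt's `ProximalGradient` solver with `jaxopt.prox.prox_lasso`. The NumPy sketch below shows a single such step on made-up shift values. `soft_threshold` and `proximal_gradient_step` are hypothetical helpers.

```python
import numpy as np


def soft_threshold(v, threshold):
    """Proximal operator of threshold * ||v||_1: shrink every entry toward zero
    by `threshold`, setting entries with |v| <= threshold exactly to zero."""
    return np.sign(v) * np.maximum(np.abs(v) - threshold, 0.0)


def proximal_gradient_step(S, grad_S, step_size, lam):
    """One ISTA-style update: gradient step on the smooth cost, then the L1 prox."""
    return soft_threshold(S - step_size * grad_S, step_size * lam)


S = np.array([0.8, -0.05, 0.02, -1.2])   # made-up shift values
grad = np.zeros_like(S)                    # pretend the smooth cost is flat here
S_new = proximal_gradient_step(S, grad, step_size=0.5, lam=0.2)
```

Small shifts are set exactly to zero rather than merely shrunk, which is how the L1 penalty leaves most $s_{m,h}$ at zero while preserving large, well-supported shifts.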

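The `spline_penalty` term in `cost_smooth` penalizes curvature of `g` by summing squared second differences over a fixed grid on $[-10, 10]$. This NumPy mirror of that term (`curvature_penalty` is a hypothetical helper) shows that a straight line incurs essentially no penalty, while a steep sigmoid is penalized.

```python
import numpy as np


def curvature_penalty(func, lam, a=-10.0, b=10.0, num=100):
    """lam * sum(dz * (second difference of func on a grid)^2), as in cost_smooth."""
    z_grid, dz = np.linspace(a, b, num=num, retstep=True)
    return lam * np.sum(dz * np.diff(func(z_grid), n=2) ** 2)


linear_pen = curvature_penalty(lambda z: 2.0 * z + 1.0, lam=1.0)
sigmoid_pen = curvature_penalty(lambda z: 5.0 / (1.0 + np.exp(-3.0 * z)), lam=1.0)
```

Larger values of `λ_spline` therefore push `g` toward a gentler, closer-to-linear shape over the grid.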
## Simulation Fit

## Empirical Data