# multidms

## Overview of model

The `multidms` model applies to a case where you have DMS datasets for two or more homologs and are interested in identifying shifts in mutational effects between homologs.
To do so, the model defines one homolog as a reference homolog.
For each mutation, the model fits one parameter that quantifies the effect of the mutation in the reference homolog.
For each non-reference homolog, it also fits a shift parameter that quantifies the shift in the mutation's effect in the non-reference homolog relative to the reference.
Shift parameters can be regularized, encouraging most of them to be close to zero.
This regularization helps suppress apparent shifts that arise from experimental noise, and is most useful when you expect most mutations to have the same effects between homologs, such as for homologs that are close relatives.

The model uses a global-epistasis function to disentangle the effects of multiple mutations on the same variant.
To do so, it assumes that mutational effects additively influence a latent biophysical property of the protein (e.g., $\Delta G$ of folding).
The mutational-effect parameters described above operate at this latent level.
The global-epistasis function then assumes a sigmoidal relationship between a protein's latent property and its functional score measured in the experiment (e.g., log enrichment score).
Ultimately, the mutational parameters, as well as the ones controlling the shape of the sigmoid, are all jointly fit to maximize agreement between predicted and observed functional scores across all variants of all homologs.

## Detailed description of the model

For each variant $v$ from homolog $h$, we use a global-epistasis function $g$ to convert a latent phenotype $\phi$ to a functional score $f$:

$$f(v,h) = g_{\alpha}(\phi(v,h))$$

where $g$ is a sigmoid and $\alpha$ is a set of parameters encoding the sigmoid.
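
As a minimal sketch of what $g_{\alpha}$ could look like, the snippet below uses a four-parameter logistic curve. The parameter names (`lower`, `upper`, `midpoint`, `slope`) are assumptions made for illustration; the description above only specifies that $g$ is a sigmoid.

```python
import math

def g(phi, alpha):
    """Hypothetical sigmoidal global-epistasis function.

    The keys of `alpha` are assumed names; only the sigmoidal
    shape comes from the model description.
    """
    lower, upper = alpha["lower"], alpha["upper"]
    z = alpha["slope"] * (phi - alpha["midpoint"])
    return lower + (upper - lower) / (1.0 + math.exp(-z))

# Made-up parameter values, for illustration only
alpha = {"lower": -8.0, "upper": 2.0, "midpoint": 0.0, "slope": 1.0}
f_mid = g(0.0, alpha)    # halfway between the two plateaus
f_high = g(50.0, alpha)  # saturates near the upper plateau
f_low = g(-50.0, alpha)  # saturates near the lower plateau
```

The plateaus mean that very large or very small latent phenotypes map to nearly identical functional scores, which is what lets the model account for non-additive effects at the level of the measured score.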

The latent phenotype is computed in the following way:

$$\phi(v,h) = c_h + \sum_{m \in v} (x_m + s_{m,h})$$

where:
* $c_h$ is the wildtype latent phenotype for homolog $h$
* $x_m$ is the latent phenotypic effect of mutation $m$. See details below.
* $s_{m,h}$ is the shift of the effect of mutation $m$ in homolog $h$. These parameters are fixed to zero for the reference homolog. For non-reference homologs, they are defined in the same way as $x_m$ parameters.
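
The latent-phenotype sum can be sketched in a few lines of plain Python. The function name `latent_phenotype`, the dictionary layout, and all parameter values below are made up for illustration.

```python
def latent_phenotype(mutations, c_h, x, s_h):
    """Sum x_m + s_{m,h} over the mutations in a variant, then add c_h.

    Mutations missing from `s_h` get a shift of zero, which is how the
    reference homolog is handled in this sketch.
    """
    return c_h + sum(x[m] + s_h.get(m, 0.0) for m in mutations)

# Made-up parameter values, for illustration only
x = {"M1E": 1.5, "G2R": -0.5}
s_h2 = {"M1E": 0.25}  # shifts for a hypothetical non-reference homolog

# Reference homolog: no shifts, so phi = 0.0 + 1.5 + (-0.5)
phi_ref = latent_phenotype(["M1E", "G2R"], c_h=0.0, x=x, s_h={})
# Non-reference homolog: phi = 0.5 + (1.5 + 0.25) + (-0.5 + 0.0)
phi_h2 = latent_phenotype(["M1E", "G2R"], c_h=0.5, x=x, s_h=s_h2)
```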

The $x_m$ variable is defined such that mutations are always relative to the reference homolog.
This way, each homolog informs the exact same parameters, even at sites that differ in wildtype amino acid.
For instance, if the wildtype amino acid at site 30 is an A in the reference homolog, but is a Y in a second homolog, then a Y30G mutation in the second homolog is defined as the sum of the following two mutations relative to A: A30Y (negated) and A30G.
This approach assumes that mutational effects can be negated, such that A30Y has the opposite effect of Y30A.
It also assumes that mutational effects are additive, such that the effect of Y30G is the sum of the effects of Y30A and A30G.

The below expression defines this approach more explicitly for an arbitrary site. It uses the notation $x_{\mathtt{X},n,\mathtt{Z}}$ where $\mathtt{X}$ and $\mathtt{Z}$ are amino acids.
For a site $n$, if $aa_{\text{wt}}$ is the site's wildtype amino acid in a non-reference homolog, $aa_{\text{mut}}$ is a mutant amino acid in a variant of that homolog, and $aa_{\text{ref}}$ is the site's wildtype amino acid in the reference homolog, then $x_m$ is:

$$x_{aa_{\text{wt}},n,aa_{\text{mut}}} = \begin{cases}
      x_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{if } aa_{\text{wt}} = aa_{\text{ref}}\\
      -x_{aa_{\text{ref}},n,aa_{\text{wt}}} & \text{if } aa_{\text{mut}} = aa_{\text{ref}}\\
      - x_{aa_{\text{ref}},n,aa_{\text{wt}}} + x_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{otherwise}\\
\end{cases}$$

In this way, we can express the whole model using only coefficients for mutations defined relative to the reference homolog.
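
The three cases of the expression above can be sketched as a small helper that returns signed terms. The name `decompose` and the `(weight, mutation)` return format are hypothetical, chosen only to mirror the equation.

```python
def decompose(wt_aa, site, mut_aa, ref_aa):
    """Return (weight, mutation) terms relative to the reference homolog."""
    if wt_aa == ref_aa:
        # Case 1: the homolog already matches the reference at this site
        return [(1, f"{ref_aa}{site}{mut_aa}")]
    if mut_aa == ref_aa:
        # Case 2: the mutation reverts to the reference amino acid
        return [(-1, f"{ref_aa}{site}{wt_aa}")]
    # Case 3: undo ref->wt, then apply ref->mut
    return [(-1, f"{ref_aa}{site}{wt_aa}"), (1, f"{ref_aa}{site}{mut_aa}")]

terms_same = decompose("A", 30, "G", ref_aa="A")  # A30G in the reference
terms_back = decompose("Y", 30, "A", ref_aa="A")  # Y30A in a homolog with Y30
terms_both = decompose("Y", 30, "G", ref_aa="A")  # Y30G in a homolog with Y30
```

For the Y30G example from the text, `terms_both` contains A30Y with weight -1 and A30G with weight 1.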

Ultimately, we fit parameters using a loss function with one term that scores differences between predicted and observed values and another that uses L1 regularization to penalize non-zero $s_{m,h}$ values:

$$ L_{\text{total}} = \sum_{h} \left[\sum_{v} L_{\text{fit}}(y_{v,h}, f(v,h)) + \lambda \sum_{m} |s_{m,h}|\right]$$

where:
* $L_{\text{total}}$ is the total loss function
* $L_{\text{fit}}$ is a loss function that penalizes differences in predicted vs. observed functional scores
* $y_{v,h}$ is the experimentally measured functional score of variant $v$ from homolog $h$
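
As a rough sketch of this loss, the snippet below uses a Huber loss for $L_{\text{fit}}$. That choice is an assumption here (the equation leaves $L_{\text{fit}}$ unspecified), as are the function names and the toy values.

```python
import numpy as np

def huber(y, f, delta=1.0):
    """Elementwise Huber loss: quadratic for small residuals, linear for large."""
    r = np.abs(y - f)
    return np.where(r <= delta, 0.5 * r**2, delta * (r - 0.5 * delta))

def total_loss(y_by_h, f_by_h, s_by_h, lam):
    """Sum over homologs of the fit term plus lam * sum(|s_{m,h}|)."""
    total = 0.0
    for h in y_by_h:
        fit = huber(y_by_h[h], f_by_h[h]).sum()
        penalty = lam * np.abs(s_by_h[h]).sum()
        total += fit + penalty
    return total

# Toy observed scores, predicted scores, and shifts for two homologs
y = {1: np.array([2.0, 0.5]), 2: np.array([2.5, -5.0])}
f = {1: np.array([1.5, 0.5]), 2: np.array([2.5, -2.0])}
s = {1: np.zeros(3), 2: np.array([0.0, 0.5, -1.5])}
loss = total_loss(y, f, s, lam=0.5)
```

Here homolog 1 contributes 0.125 from the fit term, and homolog 2 contributes 2.5 from the fit term plus 1.0 from the L1 penalty.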

## Model using matrix algebra

We compute a vector of predicted latent phenotypes $P_{h}$ as:

$$P_{h} = C_h + (W_h \cdot (X + S_h))$$

where:
* $X$ is a vector of all $x_m$ values
* $S_h$ is a vector of all $s_{m,h}$ values for homolog $h$
* $W_h$ is a sparse matrix, where rows are variants, columns are mutations (all defined relative to the reference homolog), and values are weights of 0, 1, or -1. These weights are used to compute the phenotype of each variant given the mutations present.
* $C_h$ is a vector of the homolog's $c_h$ value, repeated $n$ times, where $n$ is the number of variants.

In the matrix algebra, the sum of $X$ and $S_h$ gives a vector of mutational effects, with one entry per mutation.
Multiplying the matrix $W_h$ by this vector gives a new vector with one entry per variant, where values are the sum of mutational effects, weighted by the variant-specific weights in $W_h$.
Adding $C_h$ values to this vector will give a vector of predicted latent phenotypes for each variant.
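
A small numeric sketch of this matrix form, using a dense NumPy array in place of a sparse matrix. The three columns are assumed to correspond to mutations M1E, G2R, and G2P in a reference homolog with G at site 2, the rows to variants of a homolog with P at site 2, and all parameter values are made up.

```python
import numpy as np

# Rows: variants M1E, P2R, P2G of a non-reference homolog
# Columns: mutations M1E, G2R, G2P relative to the reference
W_h = np.array([
    [1, 0,  0],   # M1E
    [0, 1, -1],   # P2R = -G2P + G2R
    [0, 0, -1],   # P2G = -G2P
])
X = np.array([1.0, -2.0, -0.5])    # x_m values
S_h = np.array([0.5, 0.0, 0.25])   # s_{m,h} values for this homolog
c_h = 0.25                         # wildtype latent phenotype

# C_h is c_h repeated once per variant; broadcasting does this implicitly
P_h = c_h + W_h @ (X + S_h)
```

Each entry of `P_h` is the weighted sum of per-mutation effects for one variant plus `c_h`, here `[1.75, -1.5, 0.5]`.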

Next, the global-epistasis function can be used to convert a vector of predicted latent phenotypes to a vector of predicted functional scores.

$$F_{h,pred} = g_{\alpha}(P_h)$$

Finally, this vector could be fed into a loss function and compared with a vector of observed functional scores.

Questions
* how to tell JAX to fix $s_{m,h}$ parameters at zero for the reference homolog?
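
One possible answer, sketched under assumptions: keep shift parameters only for non-reference homologs in the parameter pytree and substitute a constant zero vector for the reference, so there is nothing for the optimizer to update. The names `shifts` and `toy_loss`, the homolog labels, and the squared-error loss on latent phenotypes are all made up for illustration.

```python
import jax
import jax.numpy as jnp

n_mut = 3
# Trainable parameters: shifts exist only for the non-reference homolog "h2"
params = {"x": jnp.array([1.0, -2.0, 0.5]), "s": {"h2": jnp.zeros(n_mut)}}

def shifts(params, h, reference="h1"):
    """Return s_{m,h}; a constant zero vector for the reference homolog."""
    if h == reference:
        return jnp.zeros_like(params["x"])
    return params["s"][h]

def toy_loss(params, W_by_h, y_by_h):
    """Squared error on latent phenotypes (no sigmoid, for brevity)."""
    loss = 0.0
    for h, W in W_by_h.items():
        pred = W @ (params["x"] + shifts(params, h))
        loss = loss + jnp.sum((pred - y_by_h[h]) ** 2)
    return loss

W_by_h = {"h1": jnp.eye(n_mut), "h2": jnp.eye(n_mut)}
y_by_h = {"h1": jnp.array([1.0, -2.0, 0.5]), "h2": jnp.array([2.0, -2.0, 0.5])}

# Gradients have the same structure as `params`, so no reference shift appears
grads = jax.grad(toy_loss)(params, W_by_h, y_by_h)
```

Alternatives that keep a full set of shift parameters would be to multiply them by a 0/1 mask or to wrap the reference shifts in `jax.lax.stop_gradient`; both leave unused parameters in the pytree, which is why this sketch omits them instead.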

## Import Python modules

In [1]:
import pandas
import re
import jax
from jax.experimental import sparse

Make a test case

In [39]:
test_dict = {
    'homolog' : [1,1,1,1,2,2,2,2,2],
    'variant' : ['M1E','M1E G2A', 'G2R', 'G2P', 'M1E', 'P2R', 'P2G', 'M1E P2G', 'M1E P2R'],
    'log2E' : [2, -7, 0.6, -0.5, 2.3, -5, 0.4, 2.7, 0.5],
}
test_df = pandas.DataFrame(test_dict)
test_df

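| | homolog | variant | log2E |
|---|---|---|---|
| 0 | 1 | M1E | 2.0 |
| 1 | 1 | M1E G2A | -7.0 |
| 2 | 1 | G2R | 0.6 |
| 3 | 1 | G2P | -0.5 |
| 4 | 2 | M1E | 2.3 |
| 5 | 2 | P2R | -5.0 |
| 6 | 2 | P2G | 0.4 |
| 7 | 2 | M1E P2G | 2.7 |
| 8 | 2 | M1E P2R | 0.5 |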


In [40]:
def make_weight_matrix(variants, ref_aas):
    """
    Make a sparse matrix of weights associated with particular
    variants (rows) and mutations (columns).
    
    Args:
    `variants`: a list of variants, with string values of `XnY`,
        where `X` and `Y` are amino acids and `n` is the site number
    `ref_aas`: a dictionary with sites as keys and the wildtype
        amino acid of the reference homolog as values
    
    Returns: A tuple of: i) the sparse matrix and ii) a list of
        all unique mutations observed across all variants
    """
    
    # Loop over each variant and record mutations, as well as
    # weights and indices for making the sparse matrix
    all_mutations = []
    weights = []
    indices = []
    for (i, variant) in enumerate(variants):
        mutations = set(variant.split())
        for mutation in mutations:

            # Parse the wt_aa, site, and mut_aa
            pattern = r'(?P<wt_aa>\w)(?P<site>\d+)(?P<mut_aa>\w)'
            match = re.search(pattern, mutation)
            assert match is not None, mutation
            wt_aa = match.group('wt_aa')
            site = match.group('site')
            mut_aa = match.group('mut_aa')
            
            # Get the wildtype amino acid of the reference homolog
            ref_aa = ref_aas[site]

            # If the wildtype amino acid matches the reference, then
            # assign a weight of 1 to the mutation and record the
            # mutation's indices in a sparse matrix
            if wt_aa == ref_aa:
                if mutation not in all_mutations:
                    all_mutations.append(mutation)
                j = all_mutations.index(mutation)
                weights.append(1)
                indices.append([i, j])

            # Otherwise, decompose the mutation into 1-2 mutations
            # relative to the reference sequence, depending on the
            # context
            else:
                mut_to_ref = ref_aa + site + wt_aa
                if mut_to_ref not in all_mutations:
                    all_mutations.append(mut_to_ref)
                j = all_mutations.index(mut_to_ref)
                weights.append(-1)
                indices.append([i, j])
                if ref_aa != mut_aa:
                    mut_from_ref = ref_aa + site + mut_aa
                    if mut_from_ref not in all_mutations:
                        all_mutations.append(mut_from_ref)
                    j = all_mutations.index(mut_from_ref)
                    weights.append(1)
                    indices.append([i, j])

    # Make a sparse matrix from the above weights and indices
    #assert len(indices) == len(set(indices))
    Wm = jax.experimental.sparse.BCOO(
        (weights, indices),
        shape=(len(variants), len(all_mutations))
    )
    
    return (Wm, all_mutations)

In [41]:
variants = list(test_df['variant'])
ref_aas = {'1':'M', '2':'G'}
#all_mutations = ['M1E', 'G2R', 'G2P']
(Wm, all_mutations) = make_weight_matrix(variants, ref_aas)
print(all_mutations)
Wm.todense()

['M1E', 'G2A', 'G2R', 'G2P']


DeviceArray([[ 1,  0,  0,  0],
             [ 1,  1,  0,  0],
             [ 0,  0,  1,  0],
             [ 0,  0,  0,  1],
             [ 1,  0,  0,  0],
             [ 0,  0,  1, -1],
             [ 0,  0,  0, -1],
             [ 1,  0,  0, -1],
             [ 1,  0,  1, -1]], dtype=int32)

## Binarymap

Some toy code for making a "HomologBinaryMap" object. This should inform a child class of `BinaryMap` that will have pretty much these functions as methods, along with the `__init__()`.

In [42]:
import pandas as pd
import numpy as np
import re
import binarymap as bmap


def mut_wrt_ref(ref_wt_seq:str, aa_subs:list[str]):
    """
    Convert the given mutations in a homolog so that the returned
    list of aa_substitutions is with respect to the reference sequence,
    such that a given mutation, X_{aawt,n,aamut}, is equal to:

    1. X_{aaref,n,aamut}                            if aawt == aaref
    2. -1 * X_{aaref,n,aawt}                        if aamut == aaref
    3. (-1 * X_{aaref,n,aawt}) + X_{aaref,n,aamut}  else
    """

    # function for converting single mutations
    def convert_single_mut(mutation:str):

        # Parse the wt_aa, site, and mut_aa
        pattern = r'(?P<wt_aa>\w)(?P<site>\d+)(?P<mut_aa>\w)'
        match = re.search(pattern, mutation)
        assert match is not None, mutation
        wt_aa = match.group('wt_aa')
        site = match.group('site')
        mut_aa = match.group('mut_aa')

        # Get the wildtype amino acid of the reference homolog
        # TODO, we're sure that the sites are 1-indexed, yes?
        ref_aa = ref_wt_seq[int(site) - 1]

        ret = []
        if ref_aa == wt_aa:
            ret.append(mutation)
        elif mut_aa == ref_aa:
            ret.append(f"{ref_aa}{site}{wt_aa}")
        else:
            ret.append(f"{ref_aa}{site}{wt_aa}")
            ret.append(f"{ref_aa}{site}{mut_aa}")
        return ret

    def convert_string_muts(mutations:str):
        muts = [m for mut in mutations.split() for m in convert_single_mut(mut)]
        return " ".join(muts)


    return [convert_string_muts(muts) for muts in aa_subs]


def homolog_bmap(ref_wt_seq, hom_wt_seq, func_score_df, *args, **kwargs):
    """
    initialize a `HomologBinaryMap`, which essentially converts the mutations
    of the homolog to be wrt the reference sequence, then negates the necessary
    columns of the binary encoding.
    
    extra arguments are passed to the
    constructor of a regular BinaryMap

    This is essentially the stub for what could be in the __init__ of a child
    subclass of binarymap.
    """

    sub_col = "aa_substitutions" 
    if "substitutions_col" in kwargs:
        sub_col = kwargs["substitutions_col"]

    h_func_scores = func_score_df.copy()
    h_func_scores[sub_col] = mut_wrt_ref(
                ref_wt_seq,
                func_score_df[sub_col]
            )

    hbmap = bmap.BinaryMap(h_func_scores, *args, **kwargs)

    # get all subs in variants wrt reference sequence computed above
    subs_in_variants = {
            s for subs in h_func_scores[sub_col]
            for s in subs.split()
    }

    # all possible mutations that would need to be negated
    all_non_homologous_muts = {
        f"{ref_aa}{i+1}{hom_aa}" 
        for i, (ref_aa, hom_aa) in enumerate(zip(ref_wt_seq, hom_wt_seq))
        if ref_aa != hom_aa
    }

    # negate those that exist in our dataset and are non-homologous
    to_negate = subs_in_variants.intersection(all_non_homologous_muts)

    idx_to_negate = [hbmap.sub_to_i(s) for s in to_negate]
    hbmap.binary_variants[:, idx_to_negate] *= -1
    return hbmap

Let's test it out on the df created earlier

In [43]:
ref_wt_seq = "MG"
hom_wt_seq = "MP"

In [44]:
test_df

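| | homolog | variant | log2E |
|---|---|---|---|
| 0 | 1 | M1E | 2.0 |
| 1 | 1 | M1E G2A | -7.0 |
| 2 | 1 | G2R | 0.6 |
| 3 | 1 | G2P | -0.5 |
| 4 | 2 | M1E | 2.3 |
| 5 | 2 | P2R | -5.0 |
| 6 | 2 | P2G | 0.4 |
| 7 | 2 | M1E P2G | 2.7 |
| 8 | 2 | M1E P2R | 0.5 |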


The first thing our `HomologBinaryMap` will do after constructing a regular `BinaryMap` is to convert the mutations in the homolog to be with respect to the wildtype sequence of the reference homolog.

In [45]:
test_df_wrt_ref = test_df.assign(var_wrt_ref=mut_wrt_ref(ref_wt_seq, test_df.variant.values))
test_df_wrt_ref

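| | homolog | variant | log2E | var_wrt_ref |
|---|---|---|---|---|
| 0 | 1 | M1E | 2.0 | M1E |
| 1 | 1 | M1E G2A | -7.0 | M1E G2A |
| 2 | 1 | G2R | 0.6 | G2R |
| 3 | 1 | G2P | -0.5 | G2P |
| 4 | 2 | M1E | 2.3 | M1E |
| 5 | 2 | P2R | -5.0 | G2P G2R |
| 6 | 2 | P2G | 0.4 | G2P |
| 7 | 2 | M1E P2G | 2.7 | M1E G2P |
| 8 | 2 | M1E P2R | 0.5 | M1E G2P G2R |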


Each of the homologs will get its own binary encoding, but the encodings all need to include the same set of possible substitutions. This way, they will have the same shape when informing the $X$ parameters in our model.

In [46]:
allowed_subs = {
    s for subs in test_df_wrt_ref.var_wrt_ref
    for s in subs.split()
}
allowed_subs

{'G2A', 'G2P', 'G2R', 'M1E'}

In [47]:
test_df

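| | homolog | variant | log2E |
|---|---|---|---|
| 0 | 1 | M1E | 2.0 |
| 1 | 1 | M1E G2A | -7.0 |
| 2 | 1 | G2R | 0.6 |
| 3 | 1 | G2P | -0.5 |
| 4 | 2 | M1E | 2.3 |
| 5 | 2 | P2R | -5.0 |
| 6 | 2 | P2G | 0.4 |
| 7 | 2 | M1E P2G | 2.7 |
| 8 | 2 | M1E P2R | 0.5 |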


In [50]:
ref_bmap = bmap.BinaryMap(
    test_df.query('homolog == 1'),
    substitutions_col="variant",
    allowed_subs=allowed_subs
)
print(ref_bmap.binary_variants.toarray())
print({mut:ref_bmap.sub_to_i(mut) for mut in allowed_subs})

[[1 0 0 0]
 [1 1 0 0]
 [0 0 0 1]
 [0 0 1 0]]
{'M1E': 0, 'G2R': 3, 'G2P': 2, 'G2A': 1}
[0 0 1 3 2]


In [51]:
hom_bmap = homolog_bmap(
    ref_wt_seq,
    hom_wt_seq,
    test_df.query('homolog == 2'),
    substitutions_col="variant",
    allowed_subs=allowed_subs
)
print(hom_bmap.binary_variants.toarray())
print({mut:hom_bmap.sub_to_i(mut) for mut in allowed_subs})

[[ 1  0  0  0]
 [ 0  0 -1  1]
 [ 0  0 -1  0]
 [ 1  0 -1  0]
 [ 1  0 -1  1]]
{'M1E': 0, 'G2R': 3, 'G2P': 2, 'G2A': 1}


Okay, so basically, in order to allow for multiple mutations at the same site in the same variant, I had to remove a check in the binarymap source code that raised an error for such a case. Now it gives the correct binary encoding, but it is certainly not lining up the variants with the index into the array. But I think that's okay?

## JAX Model Fitting

Below cell is work in progress and may be outdated

Implementing the above equations in code

In [None]:
import jax.numpy as jnp
import jaxopt

# Make the weight matrix from a list of variants (one row per variant).
# The columns are all mutations observed in those variants, defined
# relative to the wildtype amino acid of the reference homolog
(Wm, all_mutations) = make_weight_matrix(variants, ref_aas)
assert len(all_mutations) == len(set(all_mutations))

def g(Lh, alpha):
    # Placeholder sigmoid: the parameterization of alpha is not yet decided
    return alpha['scale'] * jax.nn.sigmoid(Lh) + alpha['bias']

def homolog_loss(Ch, Wh, Xm, Smh, alpha, Fh_exp, lam, delta=1.0):
    # `Ch`: float of wildtype latent phenotype of homolog
    # `Wh`: sparse matrix, where rows are variants, columns are mutations, and values are weights
    # `Xm`: vector of x_m coefficients for all mutations (one vector for all homologs)
    # `Smh`: vector of s_m,h coefficients for homolog h and mutations m (one vector per homolog)
    # `Fh_exp`: vector of observed functional scores for homolog h
    # `lam`: strength of the L1 penalty; `delta`: Huber-loss threshold

    # sum (Xm + Smh) for all mutations m in each variant, using the Wh matrix to identify which
    # values to add for each variant
    # then add Ch to each value in the resulting vector
    # this gives back a vector of the predicted latent phenotype of each variant
    Lh_pred = Ch + (Wh @ (Xm + Smh))

    # Use the global-epistasis function to compute a predicted functional score for each variant
    # alpha contains the parameters that define the sigmoid
    Fh_pred = g(Lh_pred, alpha)

    # Compute the Huber loss for homolog h
    huber_loss_h = jaxopt.loss.huber_loss(Fh_exp, Fh_pred, delta).mean()

    # Compute the L1 loss for smh params for homolog h
    L1_loss_h = lam * jnp.abs(Smh).sum()

    # Loss for homolog h; summing this over homologs gives the total loss
    return huber_loss_h + L1_loss_h

# To do
# optimize params Ch, Xm, Smh, and alpha using the above loss function
# look up how to optimize parameters of a sigmoid in JAX
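
As a starting point for the to-do item on optimizing sigmoid parameters, the sketch below fits two parameters of a toy sigmoid by plain gradient descent with `jax.value_and_grad`. The parameterization (`scale`, `bias`), the mean-squared-error loss, the learning rate, and the synthetic data are all assumptions for illustration, not the model's final form.

```python
import jax
import jax.numpy as jnp

def g(phi, alpha):
    """Hypothetical sigmoid with assumed parameters `scale` and `bias`."""
    return alpha["scale"] * jax.nn.sigmoid(phi) + alpha["bias"]

def mse(alpha, phi, y):
    """Mean squared error between predicted and observed scores."""
    return jnp.mean((g(phi, alpha) - y) ** 2)

# Synthetic data generated from known sigmoid parameters
phi = jnp.linspace(-4.0, 4.0, 50)
true_alpha = {"scale": 6.0, "bias": -5.0}
y = g(phi, true_alpha)

# Start from a poor guess and take fixed-size gradient steps
alpha = {"scale": 1.0, "bias": 0.0}
loss_and_grad = jax.jit(jax.value_and_grad(mse))
initial_loss, _ = loss_and_grad(alpha, phi, y)
for _ in range(500):
    loss, grads = loss_and_grad(alpha, phi, y)
    alpha = jax.tree_util.tree_map(lambda p, dp: p - 0.1 * dp, alpha, grads)
final_loss = mse(alpha, phi, y)
```

Because `alpha` is an ordinary pytree, the same pattern extends to a larger dictionary that also holds the `Ch`, `Xm`, and `Smh` parameters. The L1 term is not smooth at zero, so a proximal or other sparsity-aware optimizer may be a better fit than plain gradient descent for the full model.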