# multidms

## Overview of model

The `multidms` model applies to a case where you have DMS datasets for two or more homologs and are interested in identifying shifts in mutational effects between homologs.
To do so, the model defines one homolog as a reference homolog.
For each mutation, the model fits one parameter that quantifies the effect of the mutation in the reference homolog.
For each non-reference homolog, it also fits a shift parameter that quantifies the shift in the mutation's effect in the non-reference homolog relative to the reference.
Shift parameters can be regularized, encouraging most of them to be close to zero.
This regularization step is a useful way to eliminate the effects of experimental noise, and is most useful in cases where you expect most mutations to have the same effects between homologs, such as for homologs that are close relatives.

The model uses a global-epistasis function to disentangle the effects of multiple mutations on the same variant.
To do so, it assumes that mutational effects additively influence a latent biophysical property of the protein (e.g., $\Delta G$ of folding).
The mutational-effect parameters described above operate at this latent level.
The global-epistasis function then assumes a sigmoidal relationship between a protein's latent property and its functional score measured in the experiment (e.g., log enrichment score).
Ultimately, mutational parameters, as well as the ones controlling the shape of the sigmoid, are all jointly fit to maximize agreement between predicted and observed functional scores across all variants of all homologs.

## Detailed description of the model

For each variant $v$ from homolog $h$, we use a global-epistasis function $g$ to convert a latent phenotype $\phi$ to a functional score $f$:

$$f(v,h) = g_{\alpha}(\phi(v,h))$$

where $g$ is a sigmoid and $\alpha$ is a set of parameters encoding the sigmoid.
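To show why the sigmoid matters, here is a minimal sketch of one possible $g_{\alpha}$. It assumes a logistic curve whose lower and upper asymptotes play the role of $\alpha$; the text only says $g$ is a sigmoid, so this parameterization and the name `g_alpha` are illustrative assumptions. The same latent effect $x_m$ changes the functional score a lot near the sigmoid's midpoint and only a little near saturation.

```python
import math

def g_alpha(phi, lower=-5.0, upper=1.0):
    """Hypothetical logistic g_alpha; alpha = (lower, upper) asymptotes."""
    return lower + (upper - lower) / (1.0 + math.exp(-phi))

x_m = -1.0  # latent effect of a single mutation (illustrative value)

# Change in functional score caused by the same latent effect,
# added to two different genetic backgrounds
deltas = {
    phi_bg: g_alpha(phi_bg + x_m) - g_alpha(phi_bg)
    for phi_bg in (0.0, 4.0)
}
for phi_bg, delta in deltas.items():
    print(f"background phi = {phi_bg}: change in functional score = {delta:.3f}")
```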

The latent phenotype is computed in the following way:

$$\phi(v,h) = c_h + \sum_{m \in v} (x_m + s_{m,h})$$

where:
* $c_h$ is the wildtype latent phenotype for homolog $h$
* $x_m$ is the latent phenotypic effect of mutation $m$. See details below.
* $s_{m,h}$ is the shift of the effect of mutation $m$ in homolog $h$. These parameters are fixed to zero for the reference homolog. For non-reference homologs, they are defined in the same way as $x_m$ parameters.
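As a minimal sketch (not the package's implementation), the latent phenotype of a single variant can be computed from plain dictionaries. The function name `latent_phenotype` and the dictionary layout are illustrative assumptions, and each $x_m$ is assumed to already be expressed relative to the reference homolog.

```python
def latent_phenotype(variant, c_h, x, s_h):
    """phi(v, h) = c_h + sum over m in v of (x_m + s_{m,h}).

    variant: iterable of mutation names, relative to the reference homolog.
    c_h: wildtype latent phenotype of homolog h.
    x: dict mapping mutation -> x_m.
    s_h: dict mapping mutation -> s_{m,h}; missing entries count as 0,
        which matches the reference homolog, whose shifts are fixed to 0.
    """
    return c_h + sum(x[m] + s_h.get(m, 0.0) for m in variant)

x = {"M1E": 1.5, "G2R": -3.0}  # illustrative mutational effects

# Reference homolog: no shifts
phi_ref = latent_phenotype(["M1E", "G2R"], c_h=2.0, x=x, s_h={})
# Non-reference homolog: its own c_h and a shift on G2R
phi_hom = latent_phenotype(["M1E", "G2R"], c_h=1.0, x=x, s_h={"G2R": 0.5})
```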

The $x_m$ variable is defined such that mutations are always relative to the reference homolog.
This way, each homolog informs the exact same parameters, even at sites that differ in wildtype amino acid.
For instance, if the wildtype amino acid at site 30 is an A in the reference homolog, but is a Y in a second homolog, then a Y30G mutation in the second homolog is defined as the sum of the following two mutations relative to A: A30Y (negated) and A30G.
This approach assumes that mutational effects can be negated, such that A30Y has the opposite effect of Y30A.
It also assumes that mutational effects are additive, such that the effect of Y30G is the sum of the effects of Y30A and A30G.

The below expression defines this approach more explicitly for an arbitrary site. It uses the notation $x_{\mathtt{X},n,\mathtt{Z}}$ where $\mathtt{X}$ and $\mathtt{Z}$ are amino acids.
For a site $n$, if $aa_{\text{wt}}$ is the site's wildtype amino acid in a non-reference homolog, $aa_{\text{mut}}$ is a mutant amino acid in a variant of that homolog, and $aa_{\text{ref}}$ is the site's wildtype amino acid in the reference homolog, then $x_m$ is:

$$x_{aa_{\text{wt}},n,aa_{\text{mut}}} = \begin{cases}
      x_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{if } aa_{\text{wt}} = aa_{\text{ref}}\\
      -x_{aa_{\text{ref}},n,aa_{\text{wt}}} & \text{if } aa_{\text{mut}} = aa_{\text{ref}}\\
      - x_{aa_{\text{ref}},n,aa_{\text{wt}}} + x_{aa_{\text{ref}},n,aa_{\text{mut}}} & \text{otherwise}\\
\end{cases}$$
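The three cases can be sketched as a small Python function. `x_homolog` and the dictionary `x`, keyed by reference-relative mutation strings, are hypothetical names, and the values are made up. With reference wildtype A at site 30, the `Y30G` example from the text resolves to $-x_{\mathtt{A},30,\mathtt{Y}} + x_{\mathtt{A},30,\mathtt{G}}$.

```python
def x_homolog(aa_wt, site, aa_mut, aa_ref, x):
    """Effect of aa_wt -> aa_mut at `site` in a non-reference homolog,
    written in terms of reference-homolog parameters.

    aa_ref: reference homolog's wildtype amino acid at `site`.
    x: dict of reference-relative effects keyed like 'A30G' (hypothetical layout).
    """
    if aa_wt == aa_ref:
        # Same wildtype as the reference: use the parameter directly
        return x[f"{aa_ref}{site}{aa_mut}"]
    if aa_mut == aa_ref:
        # Back-mutation to the reference amino acid: negate ref -> homolog wt
        return -x[f"{aa_ref}{site}{aa_wt}"]
    # Otherwise: undo ref -> homolog wt, then apply ref -> mutant
    return -x[f"{aa_ref}{site}{aa_wt}"] + x[f"{aa_ref}{site}{aa_mut}"]

x = {"A30Y": 0.8, "A30G": -1.2}  # illustrative values

effect_Y30G = x_homolog("Y", 30, "G", "A", x)  # -A30Y + A30G
effect_Y30A = x_homolog("Y", 30, "A", "A", x)  # -A30Y
effect_A30G = x_homolog("A", 30, "G", "A", x)  # A30G
```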

In this way, we can express the whole model using only mutational-effect parameters defined relative to the reference homolog's wildtype sequence.

Ultimately, we fit parameters using a loss function with one term that scores differences between predicted and observed values and another that uses L1 regularization to penalize non-zero $s_{m,h}$ values:

$$ L_{\text{total}} = \sum_{h} \left[\sum_{v} L_{\text{fit}}(y_{v,h}, f(v,h)) + \lambda \sum_{m} |s_{m,h}|\right]$$

where:
* $L_{\text{total}}$ is the total loss function
* $L_{\text{fit}}$ is a loss function that penalizes differences in predicted vs. observed functional scores
* $y_{v,h}$ is the experimentally measured functional score of variant $v$ from homolog $h$
* $\lambda$ is the regularization strength, which sets how strongly non-zero shifts are penalized
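A minimal sketch of $L_{\text{total}}$, assuming squared error for $L_{\text{fit}}$ (the text leaves it unspecified). The function name `total_loss`, the dict-of-arrays layout, and all numbers are illustrative assumptions.

```python
import numpy as np

def total_loss(y_obs, f_pred, shifts, lam):
    """Sketch of L_total = sum_h [sum_v L_fit(y, f) + lam * sum_m |s_{m,h}|].

    Assumes L_fit is squared error.
    y_obs, f_pred, shifts: dicts keyed by homolog, each value a 1-D sequence.
    The reference homolog's shifts are fixed to zero, so it adds no penalty.
    """
    total = 0.0
    for h in y_obs:
        residuals = np.asarray(y_obs[h], dtype=float) - np.asarray(f_pred[h], dtype=float)
        fit = np.sum(residuals ** 2)
        # L1 penalty encourages most shifts to be exactly zero
        penalty = lam * np.sum(np.abs(np.asarray(shifts[h], dtype=float)))
        total += fit + penalty
    return float(total)

y_obs = {1: [2.0, -7.0], 2: [2.3, -5.0]}   # observed scores (illustrative)
f_pred = {1: [1.0, -7.0], 2: [2.3, -4.0]}  # model predictions (illustrative)
shifts = {1: [0.0, 0.0], 2: [0.5, -1.5]}   # s_{m,h}; zero for the reference
loss = total_loss(y_obs, f_pred, shifts, lam=0.1)
```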

## Model using matrix algebra

We compute a vector of predicted latent phenotypes $P_{h}$ as:

$$P_{h} = C_h + (W_h \cdot (X + S_h))$$

where:
* $X$ is a vector of all $x_m$ values
* $S_h$ is a vector of all $s_{m,h}$ values for homolog $h$, with the same length and ordering as $X$
* $W_h$ is a sparse matrix, where rows are variants, columns are mutations (all defined relative to the reference homolog), and values are weights of 0, 1, or -1. These weights are used to compute the phenotype of each variant given the mutations present.
* $C_h$ is a vector of the homolog's $c_h$ value, repeated $n$ times, where $n$ is the number of variants.

In the matrix algebra, the sum of $X$ and $S_h$ gives a vector of mutational effects, with one entry per mutation.
Multiplying the matrix $W_h$ by this vector gives a new vector with one entry per variant, where values are the sum of mutational effects, weighted by the variant-specific weights in $W_h$.
Adding $C_h$ values to this vector will give a vector of predicted latent phenotypes for each variant.

Next, the global-epistasis function can be used to convert a vector of predicted latent phenotypes to a vector of predicted functional scores.

$$F_{h,pred} = g_{\alpha}(P_h)$$

Finally, this vector could be fed into a loss function and compared with a vector of observed functional scores.
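The matrix pipeline can be sketched with NumPy. $W_h$ below reproduces the signed encoding of the non-reference homolog from the test case later in this notebook (columns `M1E`, `G2P`, `G2R`), and `y_obs` holds that homolog's `log2E` values. The $X$, $S_h$, $c_h$ and sigmoid parameters are made-up illustrative values, the logistic form of $g_{\alpha}$ is an assumption, and mean squared error stands in for $L_{\text{fit}}$. $W_h$ is dense here for readability; a `scipy.sparse` matrix also supports the `@` product.

```python
import numpy as np

def logistic(z, lower, upper):
    # Assumed sigmoid g_alpha with alpha = (lower, upper) asymptotes
    return lower + (upper - lower) / (1.0 + np.exp(-z))

# Rows are variants; columns are reference-relative mutations M1E, G2P, G2R
W_h = np.array([
    [1,  0, 0],   # M1E
    [0, -1, 1],   # P2R -> -G2P + G2R
    [0, -1, 0],   # P2G -> -G2P
    [1, -1, 0],   # M1E P2G
    [1, -1, 1],   # M1E P2R
])
X = np.array([1.0, 0.5, -4.0])    # x_m (illustrative values)
S_h = np.array([0.0, 0.25, 0.0])  # s_{m,h} (illustrative values)
C_h = np.full(W_h.shape[0], 2.0)  # c_h = 2.0 repeated once per variant

# P_h = C_h + W_h (X + S_h): one latent phenotype per variant
P_h = C_h + W_h @ (X + S_h)

# F_{h,pred} = g_alpha(P_h)
F_pred = logistic(P_h, lower=-7.0, upper=3.0)

# Compare with observed scores using mean squared error
y_obs = np.array([2.3, -5.0, 0.4, 2.7, 0.5])
mse = np.mean((y_obs - F_pred) ** 2)
```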

## Binarymap augmentation

Some toy code for making a `HomologBinaryMap` class.

There is no need to reinvent a binary encoding for amino-acid sequence data: we can simply augment the `BinaryMap` object after converting the mutations to be with respect to the reference sequence, as in the definition of $x_{aa_{\text{wt}},n,aa_{\text{mut}}}$ above. As that definition shows, some of the binary columns (substitutions) that encode a homolog's substitutions must be negated. These are the substitutions from the reference wildtype amino acid to the homolog wildtype amino acid at sites where the two wildtype sequences differ.

This could eventually become a proper implementation of a `BinaryMap`-inherited class that overrides `__init__()` and any other relevant methods. It would take all the same parameters, add the homolog wildtype sequence (`hom_wtseq`, below), and make the reference wildtype sequence (`wtseq`, in `BinaryMap`) mandatory.

Note: we're still not entirely sure how to handle gapped alignments.

In [10]:
import pandas as pd
import numpy as np
import re
import binarymap as bmap


def mut_wrt_ref(wtseq: str, aa_subs: list[str]):
    """
    Convert a list of variants of a non-reference homolog
    to be with respect to the reference sequence,
    such that a given mutation X_{aawt,n,aamut} is equal to:

    1. X_{aaref,n,aamut}                            if aawt == aaref
    2. -1 * X_{aaref,n,aawt}                        if aamut == aaref
    3. (-1 * X_{aaref,n,aawt}) + X_{aaref,n,aamut}  otherwise

    aawt is the wildtype aa of the homolog
    aamut is the substitution aa of the homolog
    aaref is the wildtype aa of the reference sequence

    Parameters
    ----------
    wtseq : str
        Amino acid sequence of the reference homolog.

    aa_subs : list[str]
        A list of space-delimited <wt><site><sub> formatted substitution
        strings (one per variant) on a non-reference homolog.

    Returns
    -------
    list[str]
        Non-reference homolog substitutions converted to be with respect
        to the reference homolog. The -1 coefficients from the above
        expressions are not included; they are applied later by negating
        the corresponding columns of the binary encoding.
    """


    def convert_single_mut(mutation:str):
        """function for converting single mutations"""

        # Parse the aawt, site, and aamut
        pattern = r'(?P<aawt>\w)(?P<site>\d+)(?P<aamut>\w)'
        match = re.fullmatch(pattern, mutation)
        if match is None:
            raise ValueError(f"cannot parse mutation {mutation!r}")
        aawt = match.group('aawt')
        site = match.group('site')
        aamut = match.group('aamut')

        # Get the wildtype amino acid of the reference homolog
        ref_aa = wtseq[int(site) - 1]

        ret = []
        
        # If the site wt amino acid is homologous, no conversion needed
        if ref_aa == aawt:
            ret.append(mutation)
            
        # If the substitution is a back-mutation to the ref,
        # return the reference substitution to the non-ref 
        # wildtype sequence, this will be negated
        elif aamut == ref_aa:
            ret.append(f"{ref_aa}{site}{aawt}")
            
        # Otherwise, the site differs between homologs and the
        # mutation is to a new amino acid: the ref -> non-ref wildtype
        # substitution will be negated and added to the effect of the
        # ref wildtype amino acid mutating to the homolog's mutant
        else:
            ret.append(f"{ref_aa}{site}{aawt}")
            ret.append(f"{ref_aa}{site}{aamut}")
            
        return ret

    
    def convert_string_muts(mutations:str):
        """function for converting multiple mutations"""
        
        muts = [m for mut in mutations.split() for m in convert_single_mut(mut)]
        return " ".join(muts)

    
    return [convert_string_muts(muts) for muts in aa_subs]


def homolog_bmap(wtseq, hom_wtseq, func_score_df, *args, **kwargs):
    """
    Initialize a `HomologBinaryMap`, which converts the mutations
    of the homolog to be with respect to the reference sequence, then
    negates the necessary columns of the binary encoding.

    Extra arguments are passed to the constructor of a regular BinaryMap.
    For model fitting, it is suggested to collect all the
    reference-converted mutations across homologs and pass them as
    `allowed_subs`, so that all homologs share the same columns
    (axis 1) of `binary_variants`.

    Note: we're not handling gaps or insertions yet. The current code
    expects all homologs to be alignable without gaps.

    Parameters
    ----------
    wtseq : str
        The wildtype sequence of the reference homolog.

    hom_wtseq : str
        The wildtype sequence of the non-reference homolog being encoded.

    func_score_df : pandas.DataFrame
        This should be in the same format as described in BinaryMap.
        The aa substitution conversions are handled here, so the
        substitutions column should be with respect to the
        non-reference homolog.

    Returns
    -------
    binarymap.BinaryMap
        The non-reference homolog binary map.
    """

    sub_col = kwargs.get("substitutions_col", "aa_substitutions")

    h_func_scores = func_score_df.copy()
    h_func_scores[sub_col] = mut_wrt_ref(
        wtseq,
        func_score_df[sub_col]
    )

    hbmap = bmap.BinaryMap(h_func_scores, *args, **kwargs)

    # get all subs in variants wrt reference sequence computed above
    subs_in_variants = {
        s for subs in h_func_scores[sub_col]
        for s in subs.split()
    }

    # all possible mutations that would need to be negated
    all_non_homologous_muts = {
        f"{ref_aa}{i+1}{hom_aa}" 
        for i, (ref_aa, hom_aa) in enumerate(zip(wtseq, hom_wtseq))
        if ref_aa != hom_aa
    }

    # negate those that exist in our dataset and are non-homologous
    to_negate = subs_in_variants.intersection(all_non_homologous_muts)

    idx_to_negate = [hbmap.sub_to_i(s) for s in to_negate]
    hbmap.binary_variants[:, idx_to_negate] *= -1
    return hbmap

## Test Case

With some imaginary variants from two imaginary homologs (`ref` and `hom`), we'll sanity-check the code above.

In [11]:
ref_wt_seq = "MG"
hom_wt_seq = "MP"

In [12]:
test_dict = {
    'homolog' : [1,1,1,2,2,2,2,2],
    'variant' : ['M1E', 'G2R', 'G2P', 'M1E', 'P2R', 'P2G', 'M1E P2G', 'M1E P2R'],
    'log2E' : [2, -7, -0.5, 2.3, -5, 0.4, 2.7, 0.5],
}
test_df = pd.DataFrame(test_dict)
test_df

|   | homolog | variant | log2E |
|---|---|---|---|
| 0 | 1 | M1E | 2.0 |
| 1 | 1 | G2R | -7.0 |
| 2 | 1 | G2P | -0.5 |
| 3 | 2 | M1E | 2.3 |
| 4 | 2 | P2R | -5.0 |
| 5 | 2 | P2G | 0.4 |
| 6 | 2 | M1E P2G | 2.7 |
| 7 | 2 | M1E P2R | 0.5 |


Let's test it out on the DataFrame created above.

The first thing our `HomologBinaryMap` does is convert the homolog's mutations to be with respect to the reference wildtype sequence; it then constructs a regular `BinaryMap` and negates the necessary columns. The conversion is carried out inside the `homolog_bmap` function, but the cell below shows an example of the results.

In [5]:
test_df_wrt_ref = test_df.assign(var_wrt_ref=mut_wrt_ref(ref_wt_seq, test_df.variant.values))
test_df_wrt_ref

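|   | homolog | variant | log2E | var_wrt_ref |
|---|---|---|---|---|
| 0 | 1 | M1E | 2.0 | M1E |
| 1 | 1 | G2R | -7.0 | G2R |
| 2 | 1 | G2P | -0.5 | G2P |
| 3 | 2 | M1E | 2.3 | M1E |
| 4 | 2 | P2R | -5.0 | G2P G2R |
| 5 | 2 | P2G | 0.4 | G2P |
| 6 | 2 | M1E P2G | 2.7 | M1E G2P |
| 7 | 2 | M1E P2R | 0.5 | M1E G2P G2R |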


Each homolog gets its own binary encoding, but all encodings need to include the same possible substitutions so that they have the same shape when informing the $X$ parameters in our model. So, we extract all the allowed substitutions (relative to the reference) from all homologs that will be included in the model.

In [6]:
allowed_subs = {
    s for subs in test_df_wrt_ref.var_wrt_ref
    for s in subs.split()
}
allowed_subs

{'G2P', 'G2R', 'M1E'}

Using the set of common substitutions from above, the below cell makes a regular `BinaryMap` object for the _reference_ homolog.

In [8]:
ref_bmap = bmap.BinaryMap(
    test_df.query('homolog == 1'),
    substitutions_col="variant",
    allowed_subs=allowed_subs
)
print(ref_bmap.binary_variants.toarray())
print({mut:ref_bmap.sub_to_i(mut) for mut in allowed_subs})

[[1 0 0]
 [0 0 1]
 [0 1 0]]
{'G2R': 2, 'G2P': 1, 'M1E': 0}


Using the same set of substitutions, the next cell makes a homolog `BinaryMap` for the _non-reference_ homolog, which performs the conversions and negations.

In [9]:
hom_bmap = homolog_bmap(
    ref_wt_seq,
    hom_wt_seq,
    test_df.query('homolog == 2'),
    substitutions_col="variant",
    allowed_subs=allowed_subs
)
print(hom_bmap.binary_variants.toarray())
print({mut:hom_bmap.sub_to_i(mut) for mut in allowed_subs})

[[ 1  0  0]
 [ 0 -1  1]
 [ 0 -1  0]
 [ 1 -1  0]
 [ 1 -1  1]]
{'G2R': 2, 'G2P': 1, 'M1E': 0}


To allow multiple mutations at the same site in the same variant (which the conversion produces, e.g. `P2R` becomes `G2P G2R`), I had to remove a check in the binarymap source code that raised an error in that case. Currently, these are lines 525 and 526 in the binarymap.py file. Maybe I could submit a PR that makes this check optional?
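A sketch of what an optional version of that check could look like, written as a standalone helper rather than a patch. `sites_repeated`, `validate_variants`, and the `allow_repeated_sites` flag are hypothetical names, not binarymap's source or API, and the regular expression assumes single-character amino acids around an integer site.

```python
import re

def sites_repeated(subs: str) -> bool:
    """True if any site is mutated more than once in a space-delimited
    substitution string such as 'G2P G2R'."""
    sites = [int(re.fullmatch(r"\D(\d+)\D", s).group(1)) for s in subs.split()]
    return len(sites) != len(set(sites))

def validate_variants(variants, allow_repeated_sites=False):
    """Raise ValueError if a variant mutates a site more than once,
    unless the (hypothetical) allow_repeated_sites flag is set."""
    if allow_repeated_sites:
        return
    for subs in variants:
        if sites_repeated(subs):
            raise ValueError(f"multiple substitutions at the same site: {subs!r}")

# Converted homolog variants such as 'M1E G2P G2R' would need the flag
validate_variants(["M1E", "M1E G2P G2R"], allow_repeated_sites=True)
```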

## JAX Model

## Simulation Fit

## Empirical Data?