# Biophysical model

The `multidms` joint model applies primarily to a case where you have DMS datasets
for two or more experimental conditions and are interested in identifying shifts
in mutational effects between conditions. 
Here we describe the biophysical model of conditional shifts that motivates the approach in this package.

We suggest reading the [Otwinowski et al. 2018](https://www.pnas.org/doi/10.1073/pnas.1804015115) paper to understand the approach for modeling global epistasis before reading the rest of the documentation.

## Model overview

`multidms` extends the traditional global epistasis model by informing its parameters with multiple DMS experiments performed under differing experimental conditions. For example, distinct experimental conditions (referred to as $h$ throughout the documentation) may be experiments performed in the background of different wildtype homologs (e.g., the SARS-CoV-2 Delta vs. Omicron spike). The approach could also be used for conditions that share the same wildtype but are run against different selection targets. The model assumes that differing conditions should result in _mostly_ the same mutational effects, with some _shifts_ in mutational effect due to biological mechanisms (e.g., epistasis). Ultimately, the model was designed to identify those shifts using feature selection via $L_1$ (lasso) regularization of the shift parameters described in the additive latent phenotype section below.

[//]: # "We find that the qualitative results are robust to the choice of lasso strength, and generally this lasso"
[//]: # "acts as a single dial to increase signal and noise in a linear fashion"


At a high level, the model is a composition of three
functions that describe the expected biophysical interactions underlying a given phenotype:
(1) an additive model, $\phi$, describing a variant's _latent_ phenotype under a given condition,
(2) a global epistasis model, $g$, to disentangle the effects of multiple mutations on the same variant, and
(3) a final output activation function, $t$, accounting for an expected _lower bound_ on the variant's phenotype,
below which observed functional scores are assumed to reflect limited experimental sensitivity.

Concretely, the predicted phenotype for a given variant $v$ under condition $h$ is given by

$$
\hat{y}_{v,h} = t_{\gamma}\left(g_{\alpha}\left(\phi_{\beta, S, c_{r}}(v,h)\right)\right)
$$

where
$\gamma$, $\alpha$, $\beta$, $S$, and $c_{r}$
are _free_ parameters inferred from experimental observations during the fitting process.

**Note** Defining the model in terms of its components offers (1) modularity for method testing and development, and (2) the flexibility to choose among component options that suit differing research goals and experimental techniques. While there is only a single option for the latent prediction $\phi$, we offer a few options for both the global epistasis ($g$) and output activation ($t$) functions. The package defaults for these components, described below, should be sufficient for most purposes. In that case, feel free to ignore the `multidms.biophysical` module altogether, as this functionality is hidden unless explicitly specified when instantiating a `MultiDmsModel` object.

Below, we'll describe the individual components in more detail.

## Normalizing observed functional scores using $\gamma_{h}$ 

The first consideration is whether observed functional scores are directly comparable between conditions.
In the context of the model, _predicted_ functional scores are all directly comparable, since they are all generated from the same latent space via the same global epistasis function.
However, observed scores may not be comparable in the same way.
For instance, a common way to compute functional scores is with log enrichment ratios, normalized so that the wildtype sequence has a value of zero.
If the conditions being compared are DMS experiments conducted in the background of different homologs, then each homolog will necessarily have a functional score of zero within its own experiment, even if those same homologs would have different functional scores when measured in the same experiment.
Thus, log enrichment ratios are not always directly comparable between experiments.
Ideally, one or more shared sequences would be included in each of the DMS libraries being compared, so that all scores can be normalized to the same sequence.
However, if there are no such sequences, it may be possible to computationally estimate how to renormalize the scores.

To this end, the model includes an additional parameter $\gamma_h$ for each non-reference condition that allows the observed functional scores $y_{v,h}$ from that condition to be renormalized to $y'_{v,h}$ as follows:

$$
y'_{v,h} = y_{v,h} + \gamma_h
$$

where $\gamma_h$ for the reference condition is locked at zero.
There is a theoretical basis for adding $\gamma_h$ to $y_{v,h}$ if functional scores are log enrichment ratios.
As mentioned above, log enrichment ratios are normalized so that the wildtype sequence from a given experiment has a value of zero, according to the formula:

$$
y_{v,h} = \log(E_{v,h}) - \log(E_{\text{wt},h})
$$

Thus, adding $\gamma_h$ to $y_{v,h}$ is akin to renormalizing the log enrichment ratios so that a different sequence has a functional score of zero.
In theory, for each non-reference condition, there is a $\gamma_h$ that renormalizes functional scores to be relative to the wildtype sequence of the reference condition.
If these values are not experimentally measured, the model allows the $\gamma_h$ parameters to be fit during optimization, which assumes that the correct $\gamma_h$ values will give the best model fit.
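
As a minimal numeric sketch of this reasoning, suppose (hypothetically) that the reference wildtype had also been measured in the non-reference experiment $h$. Choosing $\gamma_h = \log(E_{\text{wt},h}) - \log(E_{\text{ref wt},h})$ then moves the zero point from condition $h$'s wildtype onto the reference wildtype. The enrichment values and variable names below are made up for illustration and are not part of the `multidms` API.

```python
import numpy as np

# Hypothetical enrichments E_{v,h} measured within a single non-reference experiment h.
# "ref_wt" (the reference condition's wildtype) is assumed, for illustration only,
# to be present in this experiment's library.
enrichment = {
    "wt_h": 2.0,       # wildtype of condition h
    "ref_wt": 0.5,     # wildtype of the reference condition
    "variant_1": 1.0,
    "variant_2": 4.0,
}

# Log enrichment ratios normalized so that condition h's wildtype scores zero
y = {v: np.log(e) - np.log(enrichment["wt_h"]) for v, e in enrichment.items()}

# gamma_h that shifts the zero point onto the reference wildtype
gamma_h = np.log(enrichment["wt_h"]) - np.log(enrichment["ref_wt"])

# Renormalized scores: y'_{v,h} = y_{v,h} + gamma_h
y_prime = {v: score + gamma_h for v, score in y.items()}

print(f"gamma_h = {gamma_h:.3f}")
print({v: round(float(s), 3) for v, s in y_prime.items()})
```

In practice, the reference wildtype is usually absent from non-reference libraries, which is why $\gamma_h$ is instead fit during optimization.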

## Additive latent phenotype, $\phi$

The model defines one condition as the _reference_ condition.
For each mutation $m$, the model fits a single mutation effect parameter, $\beta_{m}$.
Additionally, the model fits a set of shift parameters, $S_{m,h}$,
that quantify the shift in a given mutation's
effect. Each mutation is associated with an independent shift parameter
for each non-reference condition.
For example, if there are 3 experimental conditions in total, $h_{1}$ (the reference), $h_{2}$, and $h_{3}$,
then each mutation, $m$, will be assigned a
single $\beta_{m}$ parameter
and two non-reference condition _shift_ parameters, $S_{m, h_{2}}$ and $S_{m, h_{3}}$, for the latent prediction.

Shift parameters can be regularized, encouraging most of them to be
close to zero. This regularization step is a useful way to eliminate
the effects of experimental noise, and is most useful in cases where
you expect most mutations to have the same effects between conditions,
such as for conditions that are close relatives. 

Concretely, the latent phenotype of any variant, $v$, from the experimental condition, $h$,
is computed like so:

$$
\phi(v,h) = c_{r} + \sum_{m \in v} (\beta_{m} + S_{m,h})
$$

where:

* $c_{r}$ is the wild type latent phenotype of the reference condition.
* $\beta_{m}$ is the latent phenotypic effect of mutation $m$ (see the note below).
* $S_{m,h}$ is the shift of the effect of mutation $m$ in condition $h$.
  These parameters are fixed to zero for the reference condition. For
  non-reference conditions, they are defined relative to the reference in the same way as the $\beta_m$ parameters.
* $v$ is the set of all mutations relative to the reference wild type sequence
  (including all mutations that separate condition $h$ from the reference condition).
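
A minimal NumPy sketch of this sum, encoding each variant as a binary vector over a hypothetical set of reference-relative mutations. The mutation names, parameter values, and function name are illustrative assumptions, not `multidms` internals.

```python
import numpy as np

# Hypothetical toy setup: three mutations (relative to the reference wild type)
# and two conditions, "h1" (reference) and "h2".
mutations = ["A30G", "K45R", "D60N"]
beta = np.array([0.5, -1.2, 0.3])         # beta_m, shared across conditions
shifts = {
    "h1": np.zeros(len(mutations)),       # reference: S_{m,h} fixed at zero
    "h2": np.array([0.0, 0.8, 0.0]),      # S_{m,h2}: only K45R is shifted
}
c_r = 2.0                                 # latent phenotype of the reference wild type


def latent_phenotype(x, condition):
    """phi(v, h) = c_r + sum over m in v of (beta_m + S_{m,h}).

    x is a 0/1 vector over `mutations`, with 1 where variant v carries mutation m.
    """
    return c_r + x @ (beta + shifts[condition])


v = np.array([0, 1, 1])  # a variant carrying K45R and D60N
print(latent_phenotype(v, "h1"), latent_phenotype(v, "h2"))
```

Because the shifts enter additively, the latent difference between the two conditions for this variant is just the summed shifts of its mutations ($0.8$ here).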

**Note** The $\beta_m$ variable is defined such that mutations are always relative to the
reference condition. For example, if the wild type amino acid at site 30 is an
A in the reference condition and a G in a non-reference condition,
then a G30Y mutation in the non-reference condition is recorded as an A30Y
mutation relative to the reference (and the non-reference wild type itself carries an A30G mutation). This way, each condition informs
the exact same parameters, even at sites that differ in wild type amino acid.
These mutations are encoded in a `BinaryMap` object, in which every mutation that makes a variant non-identical
to the reference wild type is encoded as a 1.
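
The conversion described in this note can be sketched with two small hypothetical helpers. They assume equal-length, ungapped sequences with 1-based site numbering, and stand in for (rather than reproduce) what `BinaryMap` does internally.

```python
def reference_relative_mutations(ref_wt, cond_wt, cond_mutations):
    """Express a variant as mutations relative to the reference wild type.

    ref_wt, cond_wt: wild type sequences of the reference and of the variant's
        own condition (equal length, no indels; sites numbered from 1).
    cond_mutations: {site: mutant amino acid} relative to cond_wt.
    """
    seq = list(cond_wt)
    for site, aa in cond_mutations.items():
        seq[site - 1] = aa
    # Every site that differs from the reference becomes a reference-relative mutation
    return [
        f"{ref_aa}{site}{aa}"
        for site, (ref_aa, aa) in enumerate(zip(ref_wt, seq), start=1)
        if ref_aa != aa
    ]


def one_hot(variant_mutations, all_mutations):
    """Binary encoding: 1 for each reference-relative mutation the variant carries."""
    present = set(variant_mutations)
    return [1 if m in present else 0 for m in all_mutations]


ref_wt = "MKAL"   # toy reference: A at site 3
cond_wt = "MKGL"  # toy non-reference condition: G at site 3
print(reference_relative_mutations(ref_wt, cond_wt, {}))        # ['A3G']
print(reference_relative_mutations(ref_wt, cond_wt, {3: "Y"}))  # ['A3Y']
print(one_hot(["A3Y"], ["A3G", "A3Y", "L4F"]))                  # [0, 1, 0]
```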

## Global epistasis, $g$

Latent phenotypes as described above give rise to functional scores according to a global epistasis function.
If this function is non-linear, then the model allows mutations to affect functional scores non-additively, helping to account for global epistasis.
Here, we'll go into more detail about the options available in `multidms`.

### Sigmoidal:

By default, the global epistasis function assumes a sigmoidal relationship between
a protein's latent property and its functional score measured in the experiment
(e.g., log enrichment score). Using free parameters, the sigmoid
can flexibly conform to an optimal shape informed by the data.
Note that this function is independent of the
experimental condition under which a variant is observed.

The sigmoidal function that relates a given _latent phenotype_, $z$, to its functional score is given by:

$$
g(z) =  \frac{\alpha_{scale}}{1 + e^{-z}} + \alpha_{bias}
$$

where:
* $\alpha_{scale}$ is a free parameter defining the range of the sigmoid
* $\alpha_{bias}$ is a free parameter defining the lower bound of the sigmoid.
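
For reference, the same function in NumPy. This is a plain stand-in written for this page (the name `sigmoid_g` is hypothetical, not a `multidms` function), useful for checking that outputs range from $\alpha_{bias}$ (as $z \to -\infty$) to $\alpha_{bias} + \alpha_{scale}$ (as $z \to \infty$).

```python
import numpy as np


def sigmoid_g(z, alpha_scale, alpha_bias):
    """Sigmoidal global epistasis: g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias."""
    return alpha_scale / (1.0 + np.exp(-np.asarray(z, dtype=float))) + alpha_bias


z = np.array([-10.0, 0.0, 10.0])
print(sigmoid_g(z, alpha_scale=4.0, alpha_bias=-3.5))
```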

Below is an interactive plot showing the shape of the sigmoidal global epistasis function for adjustable values of $\alpha_{scale}$ and $\alpha_{bias}$:

```python
import altair as alt
import numpy as np
import pandas as pd

# Grid of latent phenotypes, z, at which to evaluate g(z)
df = pd.DataFrame({"latent": np.linspace(-10, 10, 100)})

# Slider-bound parameters for the sigmoid's scale and bias
slider_s = alt.binding_range(min=0.1, max=10)
var_s = alt.param(bind=slider_s, value=1, name="alpha_scale")

slider_b = alt.binding_range(min=-10, max=5)
var_b = alt.param(bind=slider_b, value=0, name="alpha_bias")

# g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias, evaluated in the browser
(
    alt.Chart(df)
    .transform_calculate(
        phenotype=(1 / (1 + alt.expr.exp(-1 * alt.datum["latent"]))) * var_s + var_b
    )
    .encode(
        x=alt.X("latent", title="latent phenotype", scale=alt.Scale(domain=[-10, 10])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 10])),
    )
    .mark_line()
    .add_params(var_s, var_b)
)
```


### Single layer neural network epistasis:

If you prefer a less constrained shape for global epistasis, we also offer the option to learn the shape using a single layer neural network. For this option, the user defines the number of units in the single hidden layer of the model. For each hidden unit, we introduce three free parameters (two weights and a bias). All weights are clipped at zero (i.e., constrained to be non-negative) to maintain the assumption that the resulting epistasis function is monotonic. The network applies a sigmoid activation to each hidden unit before a final weighted sum and the addition of a constant give us the predicted functional score.

Given a latent phenotype, $z$, the neural network function can then be defined:

$$
g(z) = b^{o} + \sum_{i=1}^{n}\frac{w^{o}_{i}}{1 + e^{-(w^{l}_{i} z + b^{l}_{i})}}
$$

where:
* $n$ is the number of units in the hidden layer
* $w^{l}_{i}$ and $w^{o}_{i}$ are free parameters associated with the pre- and post-activation transformations of unit $i$ in the hidden layer of the network
* $b^{l}_{i}$ is the free bias parameter of unit $i$ (together forming a vector $b^{l}$ of length $n$)
* $b^{o}$ is a single, constant free parameter
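
A NumPy sketch of this network, using the standard logistic sigmoid $1/(1+e^{-x})$ in each hidden unit as described above. The function and argument names are hypothetical, and for simplicity the sketch clips the weights at evaluation time; an optimizer would more typically enforce the non-negativity constraint on the parameters themselves.

```python
import numpy as np


def single_layer_nn_g(z, w_l, b_l, w_o, b_o):
    """g(z) = b_o + sum_i w_o[i] * sigmoid(w_l[i] * z + b_l[i]) for one hidden layer.

    w_l, b_l, w_o: arrays of length n (one entry per hidden unit); b_o: scalar.
    """
    z = np.atleast_1d(np.asarray(z, dtype=float))
    # Clip weights at zero so that g is monotonically non-decreasing in z
    w_l = np.clip(np.asarray(w_l, dtype=float), 0.0, None)
    w_o = np.clip(np.asarray(w_o, dtype=float), 0.0, None)
    # Hidden-unit activations, shape (len(z), n)
    hidden = 1.0 / (1.0 + np.exp(-(np.outer(z, w_l) + b_l)))
    return b_o + hidden @ w_o


# Hypothetical parameters for n = 3 hidden units
w_l, b_l = np.array([1.0, 2.0, -0.5]), np.array([0.0, -1.0, 0.5])
w_o, b_o = np.array([1.0, 0.5, 2.0]), -2.0
z = np.linspace(-10, 10, 201)
g = single_layer_nn_g(z, w_l, b_l, w_o, b_o)
print(g.min(), g.max())
```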

**Note** In the `multidms` API, this is accomplished by setting the `epistatic_model` parameter of the `MultiDmsModel` constructor to the function `multidms.biophysical.single_layer_nn`.

### Identity (no epistasis):

In some scenarios, deep mutational scanning data contain only single-mutant variants, so there is no information from which to learn epistatic effects. In this scenario, we provide the option to disable global epistasis modeling by setting $g(z) = z$.

**Note** In the `multidms` API, this is accomplished by setting the `epistatic_model` parameter of the `MultiDmsModel` constructor to the function `multidms.biophysical.identity_activation`.

## Output activation, $t$

The commonly reported fold-change functional score metric is limited by each experiment's limit of detection: when the starting frequency of a barcoded variant is low, the measured ratio can be inaccurate. Common solutions to this sensitivity problem include highly multiplexed assays with variant-level barcode replicates, as well as simply filtering out low-frequency variants. Unfortunately, even modest filtering thresholds for starting frequency and number of barcode replicates may remove important data.

Thus, it is also common for researchers to simply _truncate_ (i.e., _clip_) functional scores at some lower bound $l$, where observations below the bound are assumed to be largely artifacts of experimental sensitivity. In practice, this truncation helps to reduce the impact of outliers on the model's performance.

This clipping leads to unwanted behavior in our global epistasis model. Put simply, the model $g(\phi(v,h))$ _learns_ this lower bound and is thus encouraged to constrain its shape to conform to it. This behavior is particularly problematic when normalizing the observed functional scores with $\gamma_{h}$, as described above. For intuition, consider a scenario where observed functional scores are truncated at a lower bound of -3 across all conditions. If $\gamma_h$ is fit to -1.0 for one of the non-reference conditions, then the new floor of functional scores will be -4 for that condition, while the floor for the reference condition will still be -3.
In this case, the global epistasis function finds itself in a pickle: if it allows predictions to go below -3, it can better model the floor of points in the non-reference condition, but at the cost of fitting the floor of points in the reference condition.

### Softplus

By default, we avoid this unwanted behavior by applying a final activation to the output of the global epistasis model, $g(\phi(v,h))$. In essence, this is a modified _softplus_ activation ($\text{softplus}(x)=\log(1 + e^{x})$) with a _lower bound_ at $l + \gamma_{h}$, as well as a _ramping_ coefficient, $\lambda_{sp}$.

Concretely, if we let $z' = g(\phi(v,h))$, then the predicted functional score of our model is given by:

$$
t(z') = \lambda_{sp}\log(1 + e^{\frac{z' - l}{\lambda_{sp}}}) + l
$$

Functionally speaking, this truncates scores below a lower bound, while leaving scores above (mostly) unaltered. There is a small range of input values where the function smoothly transitions between a flat regime (where data is truncated) and a linear regime (where data is not truncated). 

**Note** We recommend leaving the $\lambda_{sp}$ parameter at its default value of $0.1$. This ensures a sharp transition between regimes, similar to a [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) function, while retaining differentiability for gradient-based optimization.
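
A NumPy sketch of the modified softplus, under a hypothetical name (`softplus_t`, not a `multidms` function). It uses `np.logaddexp(0, x)`, which computes $\log(1 + e^{x})$ without overflow; this matters because dividing by a small $\lambda_{sp}$ such as $0.1$ makes the exponent large.

```python
import numpy as np


def softplus_t(z_prime, lower_bound, lambda_sp=0.1):
    """t(z') = lambda_sp * log(1 + exp((z' - l) / lambda_sp)) + l."""
    x = (np.asarray(z_prime, dtype=float) - lower_bound) / lambda_sp
    # np.logaddexp(0, x) == log(exp(0) + exp(x)) == log(1 + exp(x)), computed stably
    return lambda_sp * np.logaddexp(0.0, x) + lower_bound


z_prime = np.array([-8.0, -3.5, 0.0, 2.0])
print(softplus_t(z_prime, lower_bound=-3.5))
```

Inputs well below $l$ map to (almost exactly) $l$, inputs well above it pass through essentially unchanged, and only inputs within a few multiples of $\lambda_{sp}$ of the bound are noticeably smoothed.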

Below is an interactive plot showing the modified softplus activation for adjustable values of the $\lambda_{sp}$ scaling parameter and the lower bound, $l$:

```python
import altair as alt
import numpy as np
import pandas as pd

# Grid of global epistasis outputs, z', at which to evaluate t(z')
df = pd.DataFrame({"latent": np.linspace(-10, 10, 100)})

# Slider-bound parameters for the ramping coefficient and the lower bound
slider_lsp = alt.binding_range(min=0.1, max=10)
var_lambda_sp = alt.param(bind=slider_lsp, value=1, name="lambda_sp")

slider_lb = alt.binding_range(min=-10, max=0)
var_lower_bound = alt.param(bind=slider_lb, value=-3.5, name="lower_bound")

# t(z') = lambda_sp * log(1 + exp((z' - l) / lambda_sp)) + l, evaluated in the browser
(
    alt.Chart(df)
    .transform_calculate(
        phenotype=alt.expr.log(
            1 + alt.expr.exp((alt.datum["latent"] - var_lower_bound) / var_lambda_sp)
        )
        * var_lambda_sp
        + var_lower_bound
    )
    .encode(
        x=alt.X("latent", title="global epistasis prediction (z')", scale=alt.Scale(domain=[-10, 10])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 10])),
    )
    .mark_line()
    .add_params(var_lambda_sp, var_lower_bound)
)
```

### Identity (no activation):

We also offer the option to functionally _disable_ this feature in case there is no expected floor on the range of the fitting data.

**Note** In the `multidms` API, this is accomplished by setting the `output_activation` parameter of the `MultiDmsModel` constructor to the function `multidms.biophysical.identity_activation`.

## Fitting procedure

`multidms` is implemented in _Python_ using the [JAX](https://github.com/google/jax) library,
allowing for automatic differentiation of the model functions.
The [jaxopt](https://jaxopt.github.io/stable/index.html) package is used for efficient optimization of all parameters with proximal gradient descent.
Given DMS training data, the objective is to minimize the difference between the _normalized observed_ and _predicted_ functional scores ($y'_{v,h}$ and $\hat{y}_{v,h}$). We apply a lasso ($L_1$) penalty to all shift parameters $S_{m,h}$ to select high-confidence shifts while encouraging the rest to be exactly $0$.
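
The lasso term is not differentiable at zero, which is why a proximal method is used. The sketch below shows a generic ISTA-style proximal gradient step with a fixed step size: a gradient step on the smooth fit loss, followed by soft-thresholding, which is the proximal operator of the $L_1$ penalty. It is a textbook illustration with hypothetical names, not jaxopt's implementation or the exact update `multidms` performs.

```python
import numpy as np


def soft_threshold(x, threshold):
    """Proximal operator of threshold * ||x||_1: shrink toward zero, zeroing small entries."""
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)


def proximal_gradient_step(shift_params, grad_fit, step_size, lasso_strength):
    """One ISTA-style update: gradient step on the smooth loss, then the L1 prox."""
    return soft_threshold(shift_params - step_size * grad_fit, step_size * lasso_strength)


shift_params = np.array([0.05, -0.8, 0.02, 1.5])
grad = np.zeros_like(shift_params)  # pretend the fit loss is flat, to isolate the prox effect
print(proximal_gradient_step(shift_params, grad, step_size=0.1, lasso_strength=1.0))
```

Shifts whose magnitude stays below `step_size * lasso_strength` after the gradient step are set exactly to zero, which is what lets the lasso act as feature selection.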

The total loss on a dataset is computed as the sum of the individual losses for each condition in the dataset. 

$$
L_{\text{total}} = \sum_{h} [L_{\text{fit},h} + L_{\text{reg},h}]
$$

The per-condition loss terms are then given by:

$$
\begin{align}
L_{\text{fit},h} &= \frac{\sum_{v} L_{\text{Huber}}(y'_{v,h}, \hat{y}_{v,h})}{n_h} \\
L_{\text{reg},h} &= \lambda \sum_{m} |S_{m,h}|
\end{align}
$$

where
* $L_{\text{Huber}}$ is a Huber loss function, and
* $n_{h}$ is the number of variants in condition $h$.

Dividing the summed Huber loss by $n_h$ makes $L_{\text{fit},h}$ the average loss across all variants in that condition. Ultimately, this ensures that each condition contributes equally to $L_\text{total}$, regardless of the number of variants in that condition.
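
A NumPy sketch of $L_{\text{total}}$ for predictions that have already been computed. The Huber threshold $\delta = 1$ is an assumption (this page does not state the value `multidms` uses), and the dictionary-based layout and function names are purely illustrative.

```python
import numpy as np


def huber(residual, delta=1.0):
    """Elementwise Huber loss: quadratic for |r| <= delta, linear beyond."""
    abs_r = np.abs(residual)
    return np.where(abs_r <= delta, 0.5 * residual**2, delta * (abs_r - 0.5 * delta))


def total_loss(y_prime, y_hat, shifts, lasso_strength, delta=1.0):
    """Sum over conditions of (mean Huber loss + L1 penalty on that condition's shifts).

    y_prime, y_hat, shifts: dicts keyed by condition name. The reference
    condition's shifts are all zero, so its penalty term vanishes.
    """
    loss = 0.0
    for h in y_prime:
        fit = np.mean(huber(y_prime[h] - y_hat[h], delta))  # sum over v, divided by n_h
        reg = lasso_strength * np.sum(np.abs(shifts[h]))
        loss += fit + reg
    return loss


# Toy data: conditions with different numbers of variants still get equal weight
y_prime = {"ref": np.array([0.0, 1.0]), "h2": np.array([0.0, 0.0, 0.0, 3.0])}
y_hat = {"ref": np.zeros(2), "h2": np.zeros(4)}
shifts = {"ref": np.zeros(3), "h2": np.array([0.5, -0.25, 0.0])}
print(total_loss(y_prime, y_hat, shifts, lasso_strength=0.1))
```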

**Note** We find that the qualitative results of model fits, and their respective shift parameters, are quite robust across a reasonable range of lasso strengths, $\lambda$. In other words, this parameter can be set to reflect the user's own preference in the [accuracy-simplicity](https://en.wikipedia.org/wiki/Lasso_(statistics)) tradeoff.