
# Biophysical model

The `multidms` joint model applies primarily to a case where you have deep mutational scanning (DMS) datasets
for two or more experimental conditions and are interested in identifying differences (i.e. _shifts_)
in mutational effects between conditions. 
Here we describe the biophysical model of conditional shifts that motivates the approach in this package.

We suggest reading the [Otwinowski et al. 2018](https://www.pnas.org/doi/10.1073/pnas.1804015115) paper to understand the approach for modeling global epistasis before reading the rest of the documentation.

## Model overview

`multidms` extends the traditional global epistasis model by informing the parameters with multiple DMS experiments under differing experimental _conditions_ (indexed by $h$). Distinct conditions may include DMS experiments performed on non-identical (homologous) wildtype sequences (e.g. SARS-CoV-2 Delta vs. Omicron spike). However, this approach could also be used for experimental conditions that share the same wildtype sequence but are run with different selection steps. The model assumes that differing conditions should result in _mostly_ the same mutational effects, with some _shifts_ in mutational effect due to biological mechanisms (e.g. epistasis). Ultimately, this model was designed to identify those shifts relative to a user-chosen _reference_ condition, using feature selection via a lasso ($L_{1}$) regularization of the shift parameters described in the additive latent phenotype section below.

At a high level, the model is a composition of three
functions that describe the expected biophysical interactions underlying a given phenotype:
(1) $\phi$, an additive model describing a variant's _latent_ phenotype under a given condition,
(2) $g$, a global epistasis model shared by all conditions to disentangle the effects of multiple mutations on the same variant, and
(3) $t$, a final output activation function accounting for the expected _lower bound_ on any given variant's phenotype.

Concretely, the predicted phenotype for a given variant $v$ under condition $h$ is given by 

$$
\hat{y}_{v,h} = t_{\gamma}(g_{\alpha}(\phi_{\beta, S, c_{r}}(v,h)))
$$

where
$\gamma$, $\alpha$, $\beta$, $S$, and $c_{r}$
are _free_ parameters inferred from experimental observations during the fitting process. Here, $S$ collects the shift parameters $s_{m,h}$. We describe the individual components and their associated parameters in more detail below.
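To make the composition concrete, here is a minimal sketch in plain Python. It is not the `multidms` implementation: `compose_model` and the three toy components are hypothetical stand-ins, chosen only to show how $\phi$, $g$, and $t$ chain together and how any one of them could be swapped out.

```python
def compose_model(phi, g, t):
    """Return a predictor implementing y_hat(v, h) = t(g(phi(v, h)), h)."""

    def predict(variant, condition):
        latent = phi(variant, condition)  # additive latent phenotype
        epistatic = g(latent)  # global epistasis, shared by all conditions
        return t(epistatic, condition)  # output activation (lower bound)

    return predict


# Toy components (hypothetical, for illustration only).
def toy_phi(variant, condition):
    # Every mutation lowers the latent phenotype by 0.5, in every condition.
    return 1.0 - 0.5 * len(variant)


def toy_g(z):
    return 2.0 * z


def toy_t(x, condition):
    # Hard floor at -3, standing in for a smooth lower bound.
    return max(x, -3.0)


predict = compose_model(toy_phi, toy_g, toy_t)
print(predict(("M1A", "K2R"), "h1"))  # prints 0.0
```

Because each component is passed in as a function, replacing the toy `g` or `t` with a different choice leaves the rest of the pipeline untouched.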

**Note** Defining the model abstractly in terms of its components provides (1) modularity for method testing and development, and (2) multiple options for each model component. While there is only a single option for the latent prediction, $\phi$, we offer a few options for _post-latent_ modeling, $g$ and $t$, that cover the needs of differing research goals and experimental techniques. Generally speaking, the package defaults for these components should be sufficient for most purposes. In that case, feel free to ignore the `multidms.biophysical` module altogether, as this functionality is hidden unless explicitly specified when instantiating a `MultiDmsModel` object.

## Normalizing observed functional scores using $\gamma_{h}$ 

The first consideration is whether observed functional scores are directly comparable between conditions.
In the context of the model, _predicted_ functional scores are all directly comparable since all are generated from the same latent space via the same global-epistasis function.
However, observed scores may not be comparable in the same way.
For instance, a common way to compute functional scores is with log enrichment ratios, normalized so that the wildtype sequence has a value of zero.
If the conditions being compared are DMS experiments conducted in the background of different homologs, then each homolog will necessarily have a functional score of zero within its experiment, even if those same homologs would have different functional scores when measured in the same experiment.
Thus, log enrichment ratios are not always directly comparable between experiments.
Ideally, one or more of the same sequences would be included in the experimental design of DMS libraries to be compared and normalized to the same sequence.
However, if there are no such sequences, then it may be possible to computationally estimate how to renormalize scores.

To this end, the model includes an additional parameter $\gamma_h$ for each non-reference condition that allows the observed functional scores, $y_{v,h}$, from that condition to be renormalized as follows:

$$
y'_{v,h} = y_{v,h} + \gamma_h
$$

where $\gamma_h$ for the reference condition is locked at zero.
There is a theoretical basis for adding $\gamma_h$ to $y_{v,h}$ if functional scores are log enrichment ratios.
As mentioned above, log enrichment ratios are normalized so that the wildtype sequence from a given experiment has a value of zero, according to the formula:

$$
y_{v,h} = \log(E_{v,h}) - \log(E_{\text{wt},h})
$$

Here, $E_{v,h}$ and $E_{\text{wt},h}$ denote the enrichment of variant $v$ and of the wildtype sequence in condition $h$.
Thus, adding $\gamma_h$ to $y_{v,h}$ is akin to renormalizing the log enrichment ratios so that a different sequence has a functional score of zero.
In theory, for each non-reference condition, there is a $\gamma_h$ that renormalizes functional scores to be relative to the wildtype sequence of the reference condition.
If these values are not experimentally measured, the model allows the $\gamma_h$ parameters to be fit during optimization, which assumes that the correct $\gamma_h$ values will give the best model fit.
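As a small illustration of the renormalization $y'_{v,h} = y_{v,h} + \gamma_h$, the sketch below applies a per-condition offset to observed scores. The function name, the condition labels, and the $\gamma$ value are all hypothetical; in practice $\gamma_h$ would be measured or fit rather than chosen by hand.

```python
def renormalize_scores(scores, gamma, reference):
    """Return y' = y + gamma_h per condition; gamma is fixed at 0 for the reference."""
    renormalized = {}
    for condition, condition_scores in scores.items():
        offset = 0.0 if condition == reference else gamma.get(condition, 0.0)
        renormalized[condition] = [y + offset for y in condition_scores]
    return renormalized


# Illustrative values only: each wildtype scores 0 within its own experiment.
observed = {"delta": [0.0, -1.25, -3.0], "omicron": [0.0, -0.5, -3.0]}
gamma = {"omicron": -1.0}  # assumed offset for the non-reference condition

shifted = renormalize_scores(observed, gamma, reference="delta")
print(shifted["omicron"])  # prints [-1.0, -1.5, -4.0]
```

After the shift, the non-reference wildtype no longer sits at zero: its score is now expressed relative to the reference wildtype.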

## Additive latent phenotype, $\phi$

For each mutation $m$, the model defines a single mutation effect parameter, $\beta_{m}$, shared by all conditions.
Additionally, the model defines a set of shift parameters, $s_{m,h}$,
that quantify the shift in a given mutation's effect relative to the reference condition.
Each mutation, for each non-reference condition, is associated with an independent shift parameter.
For example, if there are three experimental conditions, $h \in \{h_{1}, h_{2}, h_{3}\}$,
and we define $h_{1}$ to be the _reference_ condition,
the model defines two sets of non-reference _shift_ parameters, $s_{m, h_{2}}$ and $s_{m, h_{3}}$, that may be fit to non-zero values. Note that while $s_{m, h_{1}}$ does exist for computational and mathematical coherence, it is locked to $0$ during the fitting procedure and is functionally ignored.

Concretely, the latent phenotype of any variant, $v$, from the experimental condition, $h$,
is computed like so:

$$
\phi(v,h) = c_{r} + \sum_{m \in v} (\beta_{m} + s_{m,h})
$$

Where:

* $c_{r}$ is the wildtype latent phenotype for the reference condition,
* $\beta_{m}$ is the latent phenotypic effect of mutation $m$ (see the note below),
* $s_{m,h}$ is the shift of the effect of mutation $m$ in condition $h$, and
* $v$ is the set of all mutations relative to the reference wildtype sequence, including any non-identical wildtype sites that separate condition $h$ from the reference condition.
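The additive latent phenotype can be sketched in a few lines of plain Python. The dictionaries, mutation names, and numeric values below are hypothetical and only illustrate the bookkeeping; the package itself works on encoded variant matrices rather than Python dictionaries.

```python
def latent_phenotype(variant, condition, beta, shifts, c_r):
    """phi(v, h) = c_r + sum over m in v of (beta_m + s_{m,h})."""
    # Conditions with no entry (e.g. the reference) have all shifts equal to zero.
    condition_shifts = shifts.get(condition, {})
    return c_r + sum(beta[m] + condition_shifts.get(m, 0.0) for m in variant)


# Illustrative parameter values (not fitted to any data).
beta = {"A30G": -0.5, "K417N": -1.0}
shifts = {"h2": {"K417N": 0.75}}  # only one mutation is shifted in condition h2
c_r = 2.0

variant = ["A30G", "K417N"]
print(latent_phenotype(variant, "h1", beta, shifts, c_r))  # prints 0.5
print(latent_phenotype(variant, "h2", beta, shifts, c_r))  # prints 1.25
```

The same variant gets a different latent phenotype in `h2` only because of the single non-zero shift; all other parameters are shared between conditions.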

As stated in the overview above, we expect the mutation effects among conditions to largely be the same. By applying a _lasso_ $L_{1}$ regularization term to the shift parameters, we encourage most $s_{m,h}$ to be zero, while identifying non-zero shifts with confidence. We find that the model is robust for most reasonable choices of regularization strength. See the fitting procedure below for more on this.

**Note** The $\beta_m$ variable is defined such that mutations are always relative to the
reference condition. For example, if the wild type amino acid at site 30 is an
A in the reference condition, and a G in a non-reference condition,
then a G30Y mutation in the non-reference condition is recorded as an A30Y
mutation relative to the reference. This way, each condition informs
the exact same parameters, even at sites that differ in wild type amino acid.
These are encoded in a `BinaryMap` object, where all sites that are non-identical
to the reference are 1's.

## Global epistasis, $g$

Latent phenotypes as described above give rise to functional scores according to a global-epistasis function.
If this function is non-linear, then the model allows mutations to affect functional scores non-additively, helping to account for global epistasis.
Here, we'll go into more detail about the options available in `multidms`.

### Sigmoidal:

By default, the global-epistasis function here assumes a sigmoidal relationship between
a protein's latent property and its functional score measured in the experiment
(e.g., log enrichment score). Using free parameters, the sigmoid
can flexibly conform to an optimal shape informed by the data.
Note that this function is independent of the
experimental condition in which a variant is observed.

The sigmoidal function that relates a given _latent phenotype_, $z$, to its functional score is given by:

$$
g(z) =  \frac{\alpha_{scale}}{1 + e^{-z}} + \alpha_{bias}
$$

Where $\alpha_{scale}$ and $\alpha_{bias}$ are free parameters defining the range and lower bound of the sigmoid, respectively.
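For readers who prefer code to notation, the following sketch evaluates the sigmoid at a few latent values. The function names are hypothetical and the parameter values are arbitrary; the numerically stable form of the logistic function is a choice made here and is not taken from the package.

```python
import math


def logistic(z):
    """Numerically stable 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    exp_z = math.exp(z)
    return exp_z / (1.0 + exp_z)


def sigmoidal_epistasis(z, alpha_scale, alpha_bias):
    """g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias."""
    return alpha_scale * logistic(z) + alpha_bias


# With alpha_scale=5 and alpha_bias=-5, predictions range from -5 up to 0.
for z in (-1000.0, 0.0, 1000.0):
    print(z, sigmoidal_epistasis(z, alpha_scale=5.0, alpha_bias=-5.0))
```

The extreme inputs show the two plateaus: the lower one sits at $\alpha_{bias}$ and the upper one at $\alpha_{bias} + \alpha_{scale}$.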

Below is an interactive plot showing the effect of the sigmoidal global epistasis as a function of an adjustable $\alpha_{scale}$, and $\alpha_{bias}$:

```python
import altair as alt
import numpy
import pandas as pd

# Grid of latent phenotype values on which to evaluate the sigmoid.
df = pd.DataFrame({"latent": numpy.linspace(-10, 10, 100)})

# Sliders bound to the two free parameters of the sigmoid.
slider_s = alt.binding_range(min=0.1, max=10)
var_s = alt.param(bind=slider_s, value=5, name="alpha_scale")

slider_b = alt.binding_range(min=-10, max=5)
var_b = alt.param(bind=slider_b, value=0, name="alpha_bias")

(
    alt.Chart(df)
    .transform_calculate(
        # g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias
        phenotype=(1 / (1 + alt.expr.exp(-1 * alt.datum["latent"])))
        * var_s
        + var_b
    )
    .encode(
        x=alt.X("latent", title="latent phenotype", scale=alt.Scale(domain=[-10, 10])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 10])),
    )
    .mark_line()
    .add_params(var_s, var_b)
)
```


### Single-layer neural network epistasis:

If you prefer a less constrained shape for global epistasis, we also offer the ability to learn the shape of global epistasis using a single-layer neural network, sometimes referred to as a [multi-layer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron). For this option, the user defines the number of units in the single hidden layer of the model. For each hidden unit, we introduce three parameters (two weights and a bias) to be inferred. All weights are clipped at zero to maintain the assumption of monotonicity in the resulting epistasis function shape. The network applies a sigmoid activation to each internal unit; a final weighted sum and the addition of a constant then give the predicted functional score.

Given a latent phenotype, $z$, the neural network function can then be defined:

$$
g(z) = b^{o} + \sum_{i=1}^{n} \frac{w^{o}_{i}}{1 + e^{-(w^{l}_{i} z + b^{l}_{i})}}
$$

Where: 

* $n$ is the number of units in the hidden layer.
* $w^{l}_{i}$ and $w^{o}_{i}$ are free parameters representing latent and output transformations, respectively, associated with unit $i$ in the hidden layer of the network.
* $b^{l}_{i}$ is a free parameter, as an added bias term to unit $i$.
* $b^{o}$ is a constant, singular free parameter.
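The sketch below is one way to write such a network in plain Python, applying a sigmoid activation to each hidden unit as described in the text. It is an illustration under stated assumptions, not the package's implementation: the function names are hypothetical, and clipping the weights inside the forward pass (rather than during optimization) is a simplification made here.

```python
import math


def logistic(x):
    """Numerically stable sigmoid activation."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    exp_x = math.exp(x)
    return exp_x / (1.0 + exp_x)


def single_layer_epistasis(z, w_latent, b_latent, w_output, b_output):
    """g(z) = b_output + sum_i w_output[i] * sigmoid(w_latent[i] * z + b_latent[i])."""
    total = b_output
    for w_l, b_l, w_o in zip(w_latent, b_latent, w_output):
        # Clip both weights at zero so that g never decreases as z grows.
        w_l, w_o = max(w_l, 0.0), max(w_o, 0.0)
        total += w_o * logistic(w_l * z + b_l)
    return total


# Two hidden units with arbitrary illustrative parameters.
params = dict(w_latent=[1.0, 3.0], b_latent=[0.0, -6.0], w_output=[2.0, 1.5], b_output=-4.0)
curve = [single_layer_epistasis(z / 2.0, **params) for z in range(-20, 21)]
print(all(a <= b for a, b in zip(curve, curve[1:])))  # prints True
```

With a single hidden unit, unit latent weight, and zero latent bias, this reduces to the sigmoidal function of the previous section, with $w^{o}$ playing the role of $\alpha_{scale}$ and $b^{o}$ that of $\alpha_{bias}$.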

**Note** In the `multidms` API, this behavior is accomplished by setting the "output_activation" parameter in the constructor for `MultiDmsModel` to the function `multidms.biophysical.single_layer_nn`.

### Identity (no epistasis):

In some scenarios, deep mutational scanning data contains only single-mutation variants. It then becomes hopeless to learn anything useful about epistatic effects. In this scenario, we provide the option to functionally disable any global epistasis modeling by setting $g(z) = z$.

**Note** In the `multidms` API, this is accomplished by setting the "output_activation" parameter in the constructor for `MultiDmsModel` to the function `multidms.biophysical.identity_activation`.

## Output activation, $t$

The commonly reported fold-change metric for functional scores described above is sensitive to outliers. This is particularly problematic when the starting frequency of a barcoded variant is low. Common solutions to this sensitivity problem include the use of highly multiplexed assays with variant-level barcode replicates, as well as simply filtering out low-frequency variants. Unfortunately, even modest filtering thresholds for starting frequency and number of barcode replicates may cut out informative signal for the model. Thus, it's also common for researchers to simply _truncate_ (i.e. _clip_) the functional scores at some lower bound $l$, where observations below that bound are assumed to be largely outliers.

This type of truncation in the data leads to unwanted behavior in our global epistasis model. Put simply, $g$ _learns_ this lower bound, and is thus encouraged to limit its shape to conform to it during the fitting process. This behavior is particularly problematic when normalizing the observed functional scores with $\gamma_{h}$, described above. To provide intuition, consider a scenario where observed functional scores are truncated at a lower bound of -3 across all conditions. If $\gamma_h$ is fit to -1.0 for one of the non-reference conditions, then the new floor of functional scores will be -4 for that condition, while the floor for the reference condition would still be -3. In this case, the global epistasis function could find itself in a pickle: if it allows predictions to go below -3, it could help model the floor of points in the non-reference condition, but hurt the modeling of the floor of points in the reference condition.

### Softplus

By default, we avoid this unwanted behavior by applying a final activation on the output of the global epistasis model. In essence, this is a modified _softplus_ activation ($\text{softplus}(x)=\log(1 + e^{x})$) with a _lower bound_ at $l + \gamma_{h}$, as well as a _ramping_ coefficient, $\lambda_{\text{sp}}$.

Concretely, if we let $z' = g(\phi(v,h))$, then the predicted functional score of our model is given by:

$$
t(z') = \lambda_{sp}\log\left(1 + e^{\frac{z' - (l + \gamma_{h})}{\lambda_{sp}}}\right) + l + \gamma_{h}
$$

Functionally speaking, this truncates scores below a lower bound, while leaving scores above (mostly) unaltered. There is a small range of input values where the function smoothly transitions between a flat regime (where data is truncated) and a linear regime (where data is not truncated). 

**Note** We recommend leaving the $\lambda_{sp}$ parameter at its default value of $0.1$. This ensures a sharp transition between regimes, similar to a [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) function, while retaining the differentiability needed for gradient-based optimization.
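A minimal numeric sketch of this activation is given below. The function name and signature are hypothetical, and the default bound of -3.5 is only an illustrative value. The optional `gamma` argument shifts the floor to $l + \gamma_{h}$, matching the description of the lower bound above; leaving it at zero gives a single shared floor.

```python
import math


def modified_softplus(z_prime, lower_bound=-3.5, lambda_sp=0.1, gamma=0.0):
    """Smoothly clip z_prime from below at (lower_bound + gamma)."""
    floor = lower_bound + gamma
    x = (z_prime - floor) / lambda_sp
    # Stable evaluation of log(1 + exp(x)) as max(x, 0) + log1p(exp(-|x|)).
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return lambda_sp * softplus + floor


print(modified_softplus(0.0))    # far above the floor: approximately 0.0
print(modified_softplus(-10.0))  # far below the floor: approximately -3.5
# Floor shifted by gamma, as in the truncation example above (-3 + -1 = -4).
print(modified_softplus(-10.0, lower_bound=-3.0, gamma=-1.0))
```

Inputs well above the floor pass through almost unchanged, inputs well below it are mapped to the floor, and only values within a few multiples of $\lambda_{sp}$ of the floor are visibly bent.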

Below is an interactive plot showing the effect of the modified softplus activation as a function of an adjustable $\lambda_{sp}$ scaling parameter, and lower bound, $l$:

```python
import altair as alt
import numpy
import pandas as pd

# Grid of global-epistasis outputs (z') on which to evaluate the activation.
df = pd.DataFrame({"latent": numpy.linspace(-10, 5, 100)})

# Sliders bound to the ramping coefficient and the lower bound.
slider_lsp = alt.binding_range(min=0.1, max=2)
var_lambda_sp = alt.param(bind=slider_lsp, value=1, name="lambda_sp")

slider_lb = alt.binding_range(min=-10, max=0)
var_lower_bound = alt.param(bind=slider_lb, value=-3.5, name="lower_bound")

(
    alt.Chart(df)
    .transform_calculate(
        # t(z') = lambda_sp * log(1 + exp((z' - l) / lambda_sp)) + l
        phenotype=alt.expr.log(1 + alt.expr.exp((alt.datum["latent"] - var_lower_bound) / var_lambda_sp))
        * var_lambda_sp
        + var_lower_bound
    )
    .encode(
        x=alt.X("latent", title="global epistasis prediction (z')", scale=alt.Scale(domain=[-10, 5])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 5])),
    )
    .mark_line()
    .add_params(var_lambda_sp, var_lower_bound)
)
```

### Identity (no activation):

We offer the option to functionally _disable_ this feature in the case where there's no expected floor on the range of fitting data.

**Note** In the `multidms` API, this is accomplished by setting the "output_activation" parameter in the constructor for `MultiDmsModel` to the function `multidms.biophysical.identity_activation`.

## Fitting procedure

`multidms` is implemented in _Python_ using the [JAX](https://github.com/google/jax) library.
This allows for the direct differentiation of the composed model.
The [jaxopt](https://jaxopt.github.io/stable/index.html) package is used for efficient optimization of all parameters with [proximal gradient descent](https://en.wikipedia.org/wiki/Proximal_gradient_method).
Given DMS training data, the objective is to minimize the difference between _normalized observed_ and _predicted_ functional scores ($y'_{v,h} - \hat{y}_{v,h}$). We apply a lasso $L_{1}$ penalty to all $s_{m,h}$ parameters, which selects high-confidence shifts and encourages the rest to be $0$.
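To give some intuition for how proximal gradient descent produces exact zeros, the sketch below shows the standard proximal operator of an $L_{1}$ penalty (soft-thresholding) applied after an ordinary gradient step. This is a hand-written illustration with hypothetical names and values; the package delegates the actual optimization to jaxopt.

```python
def soft_threshold(value, threshold):
    """Proximal operator of threshold * |x|: shrink toward zero, snapping small values to 0."""
    if value > threshold:
        return value - threshold
    if value < -threshold:
        return value + threshold
    return 0.0


def proximal_step(shifts, gradients, step_size, lasso_strength):
    """One proximal gradient update of the shift parameters."""
    return [
        soft_threshold(s - step_size * g, step_size * lasso_strength)
        for s, g in zip(shifts, gradients)
    ]


# With zero gradient from the fit term, only the lasso shrinkage acts.
updated = proximal_step([0.5, 0.01, -0.3], [0.0, 0.0, 0.0], step_size=0.1, lasso_strength=1.0)
print(updated)  # the middle entry is exactly 0.0
```

Small shifts that the data does not support are set exactly to zero rather than merely made small, which is what makes the lasso act as a feature selector.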

The total loss on a dataset is computed as the sum of the individual losses for each condition in the dataset. 

$$
L_{\text{total}} = \sum_{h} [L_{\text{fit},h} + L_{\text{reg},h}]
$$

The loss for each condition is then given by:

$$
\begin{align}
L_{\text{fit},h} &= \frac{\sum_{v} L_{\text{Huber}}(y'_{v,h}, \hat{y}_{v,h})}{n_h} \\
L_{\text{reg},h} &= \lambda \sum_{m} |s_{m,h}|
\end{align}
$$

where $L_{\text{Huber}}$ is a Huber loss function and $n_{h}$ is the number of variants in condition $h$.

Dividing the summed Huber loss by $n_h$ makes $L_{\text{fit},h}$ the average loss across all variants in the condition. Ultimately, this ensures that each condition contributes equally to the fit term of $L_\text{total}$, regardless of the number of variants in that condition.
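The loss can be written out directly as a short, dependency-free sketch. The names and toy numbers are hypothetical, and the Huber threshold `delta=1.0` is an assumption made for this example rather than a documented package setting.

```python
def huber(residual, delta=1.0):
    """Huber loss: quadratic for small residuals, linear for large ones."""
    magnitude = abs(residual)
    if magnitude <= delta:
        return 0.5 * magnitude * magnitude
    return delta * (magnitude - 0.5 * delta)


def total_loss(observed, predicted, shifts, lasso_strength, delta=1.0):
    """Sum over conditions of (mean Huber loss + lasso penalty on that condition's shifts)."""
    total = 0.0
    for condition, y_obs in observed.items():
        y_hat = predicted[condition]
        fit = sum(huber(y - p, delta) for y, p in zip(y_obs, y_hat)) / len(y_obs)
        reg = lasso_strength * sum(abs(s) for s in shifts.get(condition, {}).values())
        total += fit + reg
    return total


# Toy data: condition h2 has twice as many variants as the reference h1.
observed = {"h1": [0.0, -1.0], "h2": [0.0, -3.0, -2.0, -1.0]}
predicted = {"h1": [0.0, -1.5], "h2": [0.0, -1.0, -2.0, -1.0]}
shifts = {"h2": {"A30G": 0.5, "K417N": -0.25}}

print(total_loss(observed, predicted, shifts, lasso_strength=0.1))  # approximately 0.5125
```

Here the fit terms are 0.0625 for `h1` and 0.375 for `h2`, each an average over its own variants, and the lasso term adds 0.075 for the two non-zero shifts in `h2`.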

**Note** We find the qualitative results of model fits and their respective shift parameters are quite robust for a reasonable choice of $\lambda$ applied to the lasso. In other words, this parameter can be set to reflect the user's own preference for [accuracy-simplicity](https://en.wikipedia.org/wiki/Lasso_(statistics)) tradeoff.