# Biophysical model

The goal of `multidms` is to jointly infer mutational effects across multiple deep mutational scanning (DMS) experiments, including how much each mutation's effect differs between experiments.
We refer to these differences as *shifts*.
If the experiments were performed with different homologs of the same protein, a shift would indicate epistasis between the shifted mutation and the amino-acid mutations that separate the homologs.
Or, if they were performed with the same wildtype protein under different selective conditions (e.g., selection for viral entry using different but related host receptors), shifts would indicate condition-specific effects.
The model is designed to identify these shifts relative to a user-chosen _reference condition_, performing feature selection through lasso ($L_1$) regularization of the shift parameters, which are defined in the latent phenotype section below.

`multidms` is compatible with DMS data that have the following characteristics.
First, the data must report a functional score (e.g., log enrichment ratio) for each variant from each DMS experiment, where a variant constitutes a unique gene sequence covering the entire mutagenized region.
Thus, the deep-sequencing data must resolve complete haplotypes.
Second, most mutations must be seen in multiple unique variants per experiment, since the number of variants carrying a mutation determines how sensitive its shift estimates are to regularization (see below).
This second requirement is often met in DMS libraries with multiple mutations per variant, as long as each mutation occurs in multiple genetic backgrounds.
It can also be met in libraries with one mutation per variant, as long as each variant is uniquely barcoded and individual mutations are found in the background of multiple barcodes.

Given input data from one or more experiments, `multidms` fits to all the data and provides:

1. mutational effects across all experiments, $\beta_m$;
2. how effects are shifted between experiments, $\Delta_{d, m}$; and
3. a model to predict the functional score of any given variant.

**Note**: We suggest reading the [Otwinowski et al. 2018](https://www.pnas.org/doi/10.1073/pnas.1804015115) paper to understand the approach for modeling global epistasis before reading the rest of the documentation.

## Joint-model composition

At a high level, the model is a composition of three functions that describe the expected biophysical interactions underlying a given phenotype:
(1) $\phi$, an additive model describing a variant's _latent_ phenotype under a given condition;
(2) $g$, a global-epistasis model, shared by all conditions, that disentangles the effects of multiple mutations in the same variant; and
(3) $t$, an _optional_ output activation function that accounts for functional scores that were clipped at some lower bound _prior_ to using `multidms`.

Concretely, the predicted phenotype for a given variant $v$ under condition $d$ is given by 

$$
\hat{y}_{v, d} = t_{\gamma_{d}}(g_{\alpha}(\phi_{d, \beta, \Delta}(v)))
$$

where $\gamma$, $\alpha$, $\beta$, and $\Delta$ are _free_ parameters inferred from experimental observations during fitting. We describe the individual components and their associated parameters in more detail below.

**Note** The model is defined abstractly in terms of its components for two reasons: (1) modularity for method testing and development, and (2) the ability to offer multiple options for each component. While there is only a single option for the latent prediction, $\phi$, we offer several options for _post-latent_ modeling, $g$ and $t$, to suit differing research goals and experimental techniques. Generally speaking, the package defaults for these components should be sufficient for most purposes; in that case, feel free to ignore the `multidms.biophysical` module altogether, as the alternative components are only used if explicitly specified when instantiating a `MultiDmsModel` object.

## Latent phenotype

For each mutation $m$, the model defines a single mutation-effect parameter, $\beta_{m}$, shared by all conditions.
Additionally, the model defines a set of shift parameters, $\Delta_{d,m}$,
that quantify the shift in a given mutation's effect relative to the reference condition.
Each mutation, for each non-reference condition, is associated with an independent shift parameter.
For example, if there are three experimental conditions, $d \in \{d_{1}, d_{2}, d_{3}\}$,
and we define $d_{1}$ to be the _reference_ condition,
the model defines two sets of (non-reference) _shift_ parameters, $\Delta_{d_{2}, m}$ and $\Delta_{d_{3}, m}$, that may be fit to non-zero values. Note that while $\Delta_{d_{1}, m}$ does exist for computational and mathematical coherence, it is locked to $0$ during fitting and is functionally ignored.

Concretely, the latent phenotype of any variant, $v$, from experimental condition $d$ is computed as follows:

$$
\phi_d(v) = \beta_0 + \beta^{(d)}_0 + \sum_{m \in v} (\beta_{m} + \Delta_{d, m})
$$

Where:

* $\beta_0$ is a bias parameter applied to the latent prediction of all experiments.
* $\beta^{(d)}_0$ is a per-condition bias parameter, optionally applied to the latent phenotype of _non-reference_ experiments.
* $\beta_{m}$ is the latent phenotypic effect of mutation $m$, shared by all experiments (see below for how mutations are encoded).
* $\Delta_{d, m}$ is the shift of the effect of mutation $m$ in condition $d$.
* $v$ is the set of all mutations relative to the reference wildtype sequence, including any mutations at non-identical sites that separate the wildtype of condition $d$ from the reference wildtype.
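To make the summation concrete, here is a minimal NumPy sketch of $\phi_d(v)$. It assumes variants are one-hot encoded as a binary variant-by-mutation matrix (with mutations already expressed relative to the reference wildtype), so the sum over $m \in v$ becomes a matrix-vector product. The function and argument names are hypothetical and are not part of the `multidms` API.

```python
import numpy as np

def latent_phenotype(X, beta_0, beta_0_d, beta, delta_d):
    """Sketch of phi_d(v) for every variant observed in one condition d.

    X        : (n_variants, n_mutations) binary matrix; X[v, m] = 1 if variant v
               carries mutation m (encoded relative to the reference wildtype)
    beta_0   : global bias shared by all conditions
    beta_0_d : bias for condition d (0 for the reference condition)
    beta     : (n_mutations,) mutation effects shared by all conditions
    delta_d  : (n_mutations,) shifts for condition d (all 0 for the reference)
    """
    # The sum over m in v of (beta_m + Delta_{d,m}) is a matrix-vector product
    return beta_0 + beta_0_d + X @ (beta + delta_d)

# Toy example: 3 mutations, 3 variants
X = np.array([[0, 0, 0],   # wildtype
              [1, 0, 0],   # single mutant
              [1, 0, 1]])  # double mutant
beta = np.array([-1.0, 0.5, -2.0])
delta_d2 = np.array([0.0, 0.0, 1.5])  # only the third mutation is shifted in d2

phi_ref = latent_phenotype(X, beta_0=0.2, beta_0_d=0.0, beta=beta, delta_d=np.zeros(3))
phi_d2 = latent_phenotype(X, beta_0=0.2, beta_0_d=0.1, beta=beta, delta_d=delta_d2)
```

With these toy values, the wildtype and single mutant differ between conditions only through the bias $\beta^{(d_2)}_0$, while the shift on the third mutation changes only the variant that carries it.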

Importantly, if a _non-reference_ experiment has a different wildtype sequence from the _reference_ (e.g., the wildtype sequences are homologs), then the `multidms.MultiDmsData` object encodes non-reference genotypes relative to the reference wildtype.
Thus, the summation term includes all mutations at non-identical sites that convert between the two sequences.
Likewise, if a mutation occurs at a non-identical site, then the mutation is encoded relative to the reference.
For example, consider a protein where site $30$ is a Y in the non-reference experiment's wildtype sequence and an A in the reference experiment's wildtype sequence.
If a variant from the non-reference experiment had a Y30G mutation, then the mutation would be defined as A30G in the summation term.
This does not assume that Y30G has the same effect as A30G.
It merely follows the strategy to define all sequences relative to the reference, which is practical because it ensures that each experiment informs the exact same set of $\beta_m$ and $\Delta_{d,m}$ parameters.
If a variant from a non-reference experiment had a reversion mutation at a non-identical site (e.g., Y30A), then the mutation would not be included in the summation term for that variant since the variant would have the reference amino-acid identity at that site.
In comparison, A30Y would be included in the summation term of all variants from the non-reference experiment that lack a mutation at site 30.
If mutations at non-identical sites are not sampled in the DMS libraries (e.g., Y30A is missing from the non-reference experiment and A30Y is missing from the reference experiment), then $\beta^{(d)}_0$ can be used to capture the combined effects of these missing mutations.
Otherwise, if all such mutations are sampled, $\beta^{(d)}_0$ can be locked at zero.
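The encoding rules above can be illustrated with a small, hypothetical helper that compares a variant's full sequence to the reference wildtype site by site. Sequences are represented here as site-to-amino-acid dictionaries, an assumption made for brevity; this is not how `multidms.MultiDmsData` is implemented internally.

```python
def mutations_relative_to_reference(variant_seq, reference_wt):
    """Hypothetical helper: list a variant's mutations relative to the reference wildtype.

    Both arguments map site number -> amino acid over the same set of sites.
    The variant's full sequence (its own wildtype plus any mutations) is
    compared site by site against the reference wildtype.
    """
    return [
        f"{reference_wt[site]}{site}{variant_seq[site]}"
        for site in sorted(reference_wt)
        if variant_seq[site] != reference_wt[site]
    ]

# Toy 3-site protein: the two wildtypes differ only at site 30
reference_wt = {29: "K", 30: "A", 31: "L"}
nonref_wt = {29: "K", 30: "Y", 31: "L"}

y30g = {**nonref_wt, 30: "G"}  # Y30G in the non-reference background -> encoded as A30G
y30a = {**nonref_wt, 30: "A"}  # reversion to the reference amino acid -> nothing at site 30
l31p = {**nonref_wt, 31: "P"}  # mutation elsewhere -> still carries A30Y

for name, seq in [("Y30G", y30g), ("Y30A", y30a), ("L31P", l31p)]:
    print(name, mutations_relative_to_reference(seq, reference_wt))
```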

## Global epistasis

Latent phenotypes as described above give rise to functional scores according to a global-epistasis function.
If this function is non-linear, then the model allows mutations to non-additively affect functional scores, helping to account for global epistasis. This type of model is useful for inferring the effects of individual mutations from variants with more than one mutation. Below, we describe the available options for modeling global epistasis in `multidms`.

**Note**: When analyzing DMS libraries in which variants have at most one mutation, the model cannot learn the shape of global epistasis. For this case, we provide an `identity` global-epistasis function (described below), which assumes no global epistasis but still allows the user to take advantage of the rest of the `multidms` approach.

### Sigmoidal model (default)

By default, the global-epistasis function here assumes a sigmoidal relationship between
a protein's latent property and its functional score measured in the experiment
(e.g., log enrichment score). Using free parameters, the sigmoid
can flexibly conform to an optimal shape informed by the data. 
Note that this function does not depend on the
experimental condition in which a variant is observed.

Given _latent phenotype_, $\phi_d(v) = z$, let

$$
g(z) =  \frac{\alpha_{scale}}{1 + e^{-z}} + \alpha_{bias}
$$

Where $\alpha_{scale}$ and $\alpha_{bias}$ are free parameters defining the range and lower bound of the sigmoid, respectively.
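A minimal NumPy sketch of the sigmoidal function (the function name is hypothetical, not the `multidms` implementation). It shows that outputs are bounded below by $\alpha_{bias}$ and above by $\alpha_{bias} + \alpha_{scale}$, with the midpoint at $z = 0$:

```python
import numpy as np

def sigmoid_global_epistasis(z, alpha_scale, alpha_bias):
    """g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias."""
    return alpha_scale / (1.0 + np.exp(-z)) + alpha_bias

z = np.array([-50.0, 0.0, 50.0])
g = sigmoid_global_epistasis(z, alpha_scale=5.0, alpha_bias=-4.0)
print(g)  # approximately [-4.0, -1.5, 1.0]: lower bound, midpoint, upper bound
```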

Below is an interactive plot showing the sigmoidal global-epistasis function for adjustable values of $\alpha_{scale}$ and $\alpha_{bias}$:

```python
import altair as alt
import numpy
import pandas as pd

# Grid of latent phenotypes to plot
df = pd.DataFrame({"latent": numpy.linspace(-10, 10, 100)})

# Sliders bound to the two free parameters of the sigmoid
slider_s = alt.binding_range(min=0.1, max=10)
var_s = alt.param(bind=slider_s, value=5, name="alpha_scale")

slider_b = alt.binding_range(min=-10, max=5)
var_b = alt.param(bind=slider_b, value=0, name="alpha_bias")

(
    alt.Chart(df)
    # g(z) = alpha_scale / (1 + exp(-z)) + alpha_bias
    .transform_calculate(
        phenotype=(1 / (1 + alt.expr.exp(-1 * alt.datum["latent"])))
        * var_s
        + var_b
    )
    .encode(
        x=alt.X("latent", title="latent phenotype", scale=alt.Scale(domain=[-10, 10])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 10])),
    )
    .mark_line()
    .add_params(var_s, var_b)
)
```


### Softplus model

This function is a log-transformed version of the sigmoid above, with $\alpha_{scale}$ and $\alpha_{bias}$ parameters serving similar roles.
Its shape mimics a Hill equation whose y-axis is the log of the fraction of protein molecules that are folded, rather than the fraction itself.
Such a function could be more appropriate for modeling functional scores in log space.

Given _latent phenotype_, $\phi_d(v) = z$, let

$$
g(z) =  -\alpha_\text{scale}\log\left(1+e^{-z}\right) + \alpha_\text{bias}
$$
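A NumPy sketch of the softplus form (the function name is hypothetical). `np.logaddexp(0, -z)` computes $\log(1 + e^{-z})$ without overflow. The assertion at the end checks numerically that this is $\alpha_{scale}\log\sigma(z) + \alpha_{bias}$, i.e. the log of the sigmoid, scaled and shifted: it saturates at $\alpha_{bias}$ for large $z$ and becomes approximately linear, $\alpha_{scale} z + \alpha_{bias}$, for very negative $z$.

```python
import numpy as np

def softplus_global_epistasis(z, alpha_scale, alpha_bias):
    """g(z) = -alpha_scale * log(1 + exp(-z)) + alpha_bias."""
    # np.logaddexp(0, -z) = log(exp(0) + exp(-z)) = log(1 + exp(-z)), computed stably
    return -alpha_scale * np.logaddexp(0.0, -z) + alpha_bias

z = np.linspace(-10, 10, 101)
g = softplus_global_epistasis(z, alpha_scale=2.0, alpha_bias=1.0)

# Equivalent form: alpha_scale * log(sigmoid(z)) + alpha_bias
log_sigmoid = np.log(1.0 / (1.0 + np.exp(-z)))
assert np.allclose(g, 2.0 * log_sigmoid + 1.0)
```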

**Note** In the `multidms` API, this behavior is selected by setting the "output_activation" parameter in the `MultiDmsModel` constructor to the function `multidms.biophysical.softplus_activation`.

### Single-layer neural network model

If you prefer a less constrained shape for global epistasis, we also offer the option to learn it using a neural network with a single hidden layer, sometimes referred to as a [multi-layer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron). This is similar to the approach presented by [Zhou et al. 2022](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9522415/).

For this option, the user defines the number of units in the model's single hidden layer. For each hidden unit, we introduce three parameters (two weights and a bias) to be inferred. All weights are clipped at zero (i.e., constrained to be non-negative) so that the resulting global-epistasis function is monotonic. The network applies a sigmoid activation to each hidden unit; a weighted sum of these activations plus a constant bias gives the predicted functional score.

Given _latent phenotype_, $\phi_d(v) = z$, let

$$
g(z) = b^{o} + \sum_{i=1}^{n} \frac{w^{o}_{i}}{1 + e^{-(w^{l}_{i} z + b^{l}_{i})}}
$$

Where:

* $n$ is the number of units in the hidden layer.
* $w^{l}_{i}$ and $w^{o}_{i}$ are free parameters representing the latent and output transformations, respectively, associated with unit $i$ in the hidden layer of the network.
* $b^{l}_{i}$ is a free bias parameter added to unit $i$.
* $b^{o}$ is a single free bias parameter added to the output.
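A NumPy sketch of the single-hidden-layer network, assuming for illustration that weights are clipped with `np.clip` inside the forward pass and that each hidden unit uses the standard logistic sigmoid $1/(1 + e^{-x})$ (function and argument names are hypothetical). With a single unit, $w^{l} = 1$ and $b^{l} = 0$, it reduces to the sigmoidal model above with $\alpha_{scale} = w^{o}$ and $\alpha_{bias} = b^{o}$.

```python
import numpy as np

def single_layer_nn_global_epistasis(z, w_l, b_l, w_o, b_o):
    """g(z) = b_o + sum_i w_o[i] / (1 + exp(-(w_l[i] * z + b_l[i]))).

    z        : (n_variants,) latent phenotypes
    w_l, b_l : (n_units,) latent-side weights and biases of the hidden units
    w_o      : (n_units,) output weights
    b_o      : scalar output bias
    """
    # Clip weights at zero so that every term is non-decreasing in z
    w_l = np.clip(w_l, 0.0, None)
    w_o = np.clip(w_o, 0.0, None)
    # (n_variants, n_units) pre-activations -> logistic sigmoid -> weighted sum over units
    hidden = 1.0 / (1.0 + np.exp(-(np.outer(z, w_l) + b_l)))
    return b_o + hidden @ w_o

rng = np.random.default_rng(0)
n_units = 3
w_l, b_l, w_o = rng.normal(size=(3, n_units))  # some weights may be negative before clipping
z = np.linspace(-5, 5, 200)
g = single_layer_nn_global_epistasis(z, w_l, b_l, w_o, b_o=-1.0)
```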

**Note** This is an advanced feature, and we advise against its use unless the other options are not sufficiently parameterized for particularly complex experimental conditions.

**Note** In the `multidms` API, this behavior is selected by setting the "output_activation" parameter in the `MultiDmsModel` constructor to the function `multidms.biophysical.single_layer_nn`.

### Identity (no epistasis)

Given _latent phenotype_, $\phi_d(v) = z$, let

$$
g(z) = z
$$

In this functional form, there is no global epistasis: latent phenotypes are identical to functional scores.
We recommend this functional form if none of the variants in the DMS experiment have more than one mutation, since multi-mutant variants are needed for the model to accurately infer a non-linear global-epistasis function.
It can also be used as a baseline to determine if a non-linear global-epistasis function leads to better model fit.

**Note** In the `multidms` API, this is selected by setting the "output_activation" parameter in the `MultiDmsModel` constructor to the function `multidms.biophysical.identity_activation`.

## Optional normalization of observed functional scores

A common way to compute functional scores is with log enrichment ratios, where all scores from a given condition are normalized so that the wildtype sequence from that condition has a value of zero.
If wildtype sequences differ between conditions, then log enrichment ratios may not be directly comparable between conditions, as they are normalized to different reference points.
This breaks the assumption of our joint-modeling scheme: that all functional scores are directly comparable.

Ideally, one or more of the same sequences would be included in the DMS library of every condition, so that all scores could be normalized to the same sequence.
However, if there are no common sequences, then it may be possible to computationally estimate how to normalize scores so that they are more directly comparable.

To this end, the model includes an additional parameter $\gamma_d$ for each non-reference condition that allows functional scores from that condition to be normalized as follows:

$$
y_{v,d}^{\text{norm}} = y_{v,d} + \gamma_d
$$

where $\gamma_d$ for the reference condition is locked at zero.
There is a theoretical basis for adding $\gamma_d$ to $y_{v,d}$ if functional scores are log enrichment ratios.
As mentioned above, log enrichment ratios are normalized so that the wildtype sequence from a given experiment has a value of zero, according to the formula:

$$
y_{v,d} = \log(E_{v,d}) - \log(E_{\text{wt},d})
$$

Thus, adding $\gamma_d$ to $y_{v,d}$ is akin to renormalizing the log enrichment ratios so that a different sequence has a functional score of zero.
In theory, for each non-reference condition, there is a $\gamma_d$ that normalizes functional scores to be relative to the wildtype sequence of the reference condition.
If these values are not experimentally measured, the model allows $\gamma_d$ parameters to be fit during optimization, which assumes that the correct $\gamma_d$ values will give the best model fit.
Alternatively, $\gamma_d$ values can be locked at zero if the user does not wish to implement this optional feature.
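A small worked example of why adding $\gamma_d$ amounts to renormalizing, using made-up enrichment values and natural logarithms. If the reference wildtype sequence had been measured in condition $d$, choosing $\gamma_d = \log(E_{\text{wt},d}) - \log(E_{\text{ref wt},d})$ would give that sequence a normalized score of zero, and every other score in condition $d$ would then be relative to it:

```python
import numpy as np

# Made-up enrichments measured in one non-reference condition d.
# "wt_d" is that condition's own wildtype; "ref_wt" is the reference wildtype
# sequence, which here we pretend was also measured in condition d.
E = {"wt_d": 2.0, "ref_wt": 0.5, "variant": 1.0}

# Log enrichment ratios normalized to the condition's own wildtype:
# y_{v,d} = log(E_{v,d}) - log(E_{wt,d})
y = {name: np.log(e) - np.log(E["wt_d"]) for name, e in E.items()}

# gamma_d = log(E_{wt,d}) - log(E_{ref_wt,d}), which equals -y["ref_wt"]
gamma_d = np.log(E["wt_d"]) - np.log(E["ref_wt"])
y_norm = {name: score + gamma_d for name, score in y.items()}
```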

## Optional truncation of predicted functional scores

By default, the output of the model is equal to the output of the chosen global-epistasis model, $\hat{y}_{v,d} = g(\phi_d(v))$. However, we provide infrastructure to clip predicted functional scores at some lower bound, for users who have clipped their input data. Here, we describe the motivation and approach.

The log fold-change metric commonly reported for functional scores (described above) suffers from a limit-of-detection problem for deleterious mutations: once a mutation is sufficiently deleterious, further differences (e.g., between $-3$ and $-5$) become largely meaningless because the variant is already essentially nonfunctional.
Thus, it's common for researchers to simply _truncate_ (i.e., _clip_) functional scores at some lower bound $l$, on the assumption that observations below it are largely outliers.

When fitting the model to these data, it may be desirable to truncate _predicted_ functional scores in a similar way.
This is especially relevant when allowing $\gamma_d$ values to be non-zero.
For instance, say the user has truncated observed functional scores at a lower bound of $-3$ across all conditions.
If $\gamma_d$ is fit to $-1.0$ for one of the non-reference conditions, then the new floor of normalized functional scores for that condition would be $-4$, while the floor for the reference condition would still be $-3$.
In this case, the global-epistasis function could find itself in a pickle: if it allowed predictions to go below $-3$, it could help model the floor of points in the non-reference condition, but it would hurt the modeling of the floor of points in the reference condition.
However, if it was able to truncate predicted scores for a given condition at the floor for that condition, then this tension would be relieved: predicted scores for the reference condition could be truncated at $-3$, while (non-truncated) predicted scores for the non-reference condition could go as low as $-4$.

To this end, the software package provides an option to truncate predicted functional scores which should be used if (and only if) the user has clipped the functional scores in their data to some lower bound.
To enable this, the mathematical model passes predicted functional scores through an activation function, $t$.

In essence, this is a modified _softplus_ activation ($\text{softplus}(x)=\log(1 + e^{x})$) with a _lower bound_ at $l + \gamma_{d}$ and a _ramping_ coefficient, $\lambda_{\text{sp}}$.

Concretely, if we let $z' = g(\phi_d(v))$, then the predicted functional score of our model is given by:

$$
t_{\gamma_d}(z') = \lambda_{sp}\log\left(1 + e^{\frac{z' - (l + \gamma_d)}{\lambda_{sp}}}\right) + l + \gamma_d
$$

Functionally speaking, this truncates scores below a lower bound, while leaving scores above (mostly) unaltered. There is a small range of input values where the function smoothly transitions between a flat regime (where data is truncated) and a linear regime (where data is not truncated). 
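A NumPy sketch of the truncation function (the function name is hypothetical), using `np.logaddexp` for numerical stability. It reproduces the scenario described earlier in this section: with $l = -3$ and $\gamma_d = -1$, predictions for the reference condition are floored at $-3$, while predictions for the non-reference condition can reach $-4$.

```python
import numpy as np

def truncate_predictions(z_prime, lower_bound, lambda_sp=0.1):
    """Smoothly clip predictions at lower_bound:
    t(z') = lambda_sp * log(1 + exp((z' - lower_bound) / lambda_sp)) + lower_bound
    """
    # np.logaddexp(0, x) computes log(1 + exp(x)) without overflow
    return lambda_sp * np.logaddexp(0.0, (z_prime - lower_bound) / lambda_sp) + lower_bound

z_prime = np.array([-8.0, -3.5, 0.0, 2.0])

# Reference condition: observed scores clipped at l = -3
t_ref = truncate_predictions(z_prime, lower_bound=-3.0)
# Non-reference condition with gamma_d = -1.0: the floor becomes l + gamma_d = -4
t_nonref = truncate_predictions(z_prime, lower_bound=-3.0 + -1.0)
```

Values well above the floor pass through nearly unchanged; with $\lambda_{sp} = 0.1$, inputs of $0$ and $2$ come out as approximately $0$ and $2$ in both conditions.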

Below is an interactive plot showing the effect of the modified softplus activation as a function of an adjustable $\lambda_{sp}$ scaling parameter, and lower bound, $l$:

**Note** We recommend leaving the $\lambda_{sp}$ parameter at its default value of $0.1$. This ensures a sharp transition between regimes, similar to a [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) function, while retaining differentiability for gradient-based optimization.

```python
import altair as alt
import numpy
import pandas as pd

# Grid of global-epistasis outputs, z', to pass through the truncation
df = pd.DataFrame({"latent": numpy.linspace(-10, 5, 100)})

# Sliders for the ramping coefficient and the lower bound
slider_lsp = alt.binding_range(min=0.1, max=2)
var_lambda_sp = alt.param(bind=slider_lsp, value=1, name="lambda_sp")

slider_lb = alt.binding_range(min=-10, max=0)
var_lower_bound = alt.param(bind=slider_lb, value=-3.5, name="lower_bound")

(
    alt.Chart(df)
    # t(z') = lambda_sp * log(1 + exp((z' - lower_bound) / lambda_sp)) + lower_bound
    .transform_calculate(
        phenotype=alt.expr.log(1 + alt.expr.exp((alt.datum["latent"] - var_lower_bound) / var_lambda_sp))
        * var_lambda_sp
        + var_lower_bound
    )
    .encode(
        x=alt.X("latent", title="global epistasis prediction (z')", scale=alt.Scale(domain=[-10, 5])),
        y=alt.Y("phenotype:Q", title="predicted phenotype", scale=alt.Scale(domain=[-10, 5])),
    )
    .mark_line()
    .add_params(var_lambda_sp, var_lower_bound)
)
```

## Fitting procedure

The `multidms` software package provides a framework for fitting the model's free parameters to input DMS data.
The mathematical model is coded in [Python](https://www.python.org/) using [JAX](https://github.com/google/jax), which provides automatic differentiation and XLA compilation for high-performance optimization.
Fitting is posed as a non-smooth optimization problem and solved with proximal gradient descent, which accommodates several simultaneous constraints: L1 regularization, optionally locking parameters, and clipping parameters at pre-specified lower bounds.
Specifically, we use [JAXopt](https://jaxopt.github.io/stable/index.html) to optimize parameters via full-batch proximal gradient descent with the [`jaxopt.ProximalGradient`](https://jaxopt.github.io/stable/_autosummary/jaxopt.ProximalGradient.html) solver.

The proximal gradient method also leaves room for more general non-smooth penalties in the future. The objective is split into a smooth part, which is handled through its gradient, and a non-smooth part, which is handled through its proximal operator.

The full objective function we use to optimize parameters involves summing two terms for each experiment $d$:

$$
L_{\text{total}} = \sum_{d} [L_{\text{fit},d} + L_{\text{reg},d}]
$$

The first term, $L_{\text{fit},d}$, computes the loss between observed and predicted functional scores for a given condition, after applying the two optional modifications described above for renormalizing observed scores with $\gamma_d$ and truncating predicted scores with $t$:

$$
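    L_{\text{fit},d} = \frac{\sum_{v} L_{h_{\delta}}\left[(y_{v,d}+\gamma_d) - \hat{y}_{v,d}\right]}{n_d}
$$

where $\hat{y}_{v,d} = t_{\gamma_d}(g_{\alpha}(\phi_d(v)))$ is the predicted functional score (which already includes the optional truncation $t$), $L_{h_{\delta}}$ is a Huber loss function with $\delta=1$ by default, and $n_d$ is the number of variants in the condition.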
Dividing the numerator by $n_d$ makes it so that $L_{\text{fit},d}$ returns the average loss across all variants.
This ensures that each condition contributes equally to $L_\text{total}$, regardless of the number of variants in that condition.
The second term, $L_{\text{reg},d}$, uses L1 regularization to penalize non-zero $\Delta_{d,m}$ values:

$$
L_{\text{reg},d} = \lambda \sum_{m} |\Delta_{d,m}|
$$

where $\lambda$ controls the strength of regularization.
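A NumPy sketch of $L_{\text{total}}$, assuming the standard Huber definition ($\frac{1}{2}r^2$ for $|r| \le \delta$, $\delta(|r| - \frac{1}{2}\delta)$ otherwise) and a hypothetical dictionary layout for each condition's data. The two toy conditions have different numbers of variants but the same mix of per-variant residuals, so their fit terms contribute equally, as the averaging by $n_d$ intends:

```python
import numpy as np

def huber(residual, delta=1.0):
    """Standard Huber loss: quadratic for |r| <= delta, linear beyond."""
    abs_r = np.abs(residual)
    return np.where(abs_r <= delta, 0.5 * residual**2, delta * (abs_r - 0.5 * delta))

def total_loss(conditions, lam, delta=1.0):
    """L_total = sum_d [ mean_v Huber((y + gamma_d) - y_hat) + lam * sum_m |Delta_{d,m}| ].

    `conditions` maps a condition name to a dict with (hypothetical) keys:
      y      : observed functional scores, shape (n_d,)
      y_hat  : model predictions t(g(phi_d(v))), shape (n_d,)
      gamma  : scalar gamma_d (0 for the reference condition)
      shifts : Delta_{d,m} for that condition, shape (n_mutations,)
    """
    total = 0.0
    for c in conditions.values():
        fit = np.mean(huber((c["y"] + c["gamma"]) - c["y_hat"], delta))  # average over n_d
        reg = lam * np.sum(np.abs(c["shifts"]))
        total += fit + reg
    return total

conditions = {
    "reference": {"y": np.array([0.0, -1.0]), "y_hat": np.array([0.5, -1.0]),
                  "gamma": 0.0, "shifts": np.zeros(3)},
    "d2": {"y": np.array([0.0, -1.5, 0.0, -1.5]), "y_hat": np.array([1.0, -1.0, 1.0, -1.0]),
           "gamma": 0.5, "shifts": np.array([0.0, 2.0, -1.0])},
}
loss = total_loss(conditions, lam=0.1)  # 0.0625 + 0.0625 + 0.1 * 3
```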

The goal of the regularization term is to drive $\Delta_{d,m}$ parameters to zero unless they are strongly supported by the data.
In the absence of regularization, $\Delta_{d,m}$ parameters will be fit to optimize $L_{\text{fit},d}$.
In the presence of regularization, the sensitivity of $\Delta_{d,m}$ parameters to regularization will depend on both the magnitude of the shift and how many variants have the mutation $m$ in each experiment.
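To illustrate the proximal step, here is a toy proximal gradient descent in NumPy under simplifying assumptions: identity global epistasis, no bias or $\gamma_d$ parameters, squared error in place of the Huber loss, and a single non-reference condition. The proximal operator of $\lambda\sum_m|\Delta_{d,m}|$ is elementwise soft-thresholding, applied only to the shifts. This is a sketch of the general technique, not the `jaxopt.ProximalGradient` code path used by `multidms`; all names are hypothetical.

```python
import numpy as np

def soft_threshold(x, thresh):
    """Proximal operator of thresh * ||x||_1 (elementwise soft-thresholding)."""
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

def fit_shifts(X_ref, y_ref, X_d, y_d, lam, step=0.1, n_iter=2000):
    """Toy proximal gradient descent with identity global epistasis and no biases.

    Smooth part: mean squared error in each condition (a stand-in for Huber).
    Non-smooth part: lam * ||Delta||_1, handled by soft-thresholding.
    """
    n_mut = X_ref.shape[1]
    beta, delta = np.zeros(n_mut), np.zeros(n_mut)
    for _ in range(n_iter):
        r_ref = X_ref @ beta - y_ref            # residuals, reference condition
        r_d = X_d @ (beta + delta) - y_d        # residuals, non-reference condition
        g_beta = 2 * X_ref.T @ r_ref / len(y_ref) + 2 * X_d.T @ r_d / len(y_d)
        g_delta = 2 * X_d.T @ r_d / len(y_d)
        beta = beta - step * g_beta             # plain gradient step (beta is unpenalized)
        delta = soft_threshold(delta - step * g_delta, step * lam)  # proximal step
    return beta, delta

# Noise-free toy data: single and double mutants, one true shift on mutation 3
X = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [0, 1, 1]], dtype=float)
beta_true = np.array([-1.0, 0.5, -2.0])
delta_true = np.array([0.0, 0.0, 1.5])
y_ref = X @ beta_true
y_d = X @ (beta_true + delta_true)

beta_l1, delta_l1 = fit_shifts(X, y_ref, X, y_d, lam=0.1)
beta_l0, delta_l0 = fit_shifts(X, y_ref, X, y_d, lam=0.0)
```

In this example, the two mutations without a true shift end up with $\Delta_{d,m}$ values of exactly zero under regularization, while the true shift of $1.5$ survives but is shrunk to $1.25$; without regularization, all three shifts are recovered.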