In [None]:
%load_ext autoreload
%autoreload 2

# Journal Club: The simplicity of protein sequence-function relationships

I recently stumbled upon this paper by Yeonwoo Park that I wanted to explore in more depth. 

Here's why. 

This paper proposes a mathematical model that attempts to explain the effect of mutations on a protein's phenotype. The model is parsimonious: it proposes that the phenotype of a genotype is the sum of the following terms:

- A zeroth-order effect, which is the average phenotype value across all observed genotypes;
- A first-order effect, which is calculated by taking the average phenotype of all genotypes with a given state (i.e. a letter at a position) and calculating its difference from the global average;
- A second-order effect, which is calculated by taking the average phenotype of all genotypes containing a given pair of states and calculating its difference from the first-order prediction;
- Other $n$-th order effects, which are calculated using the same logic (the average phenotype of all genotypes containing a given set of $n$ states, minus the $(n-1)$-th order prediction).
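
Written out as equations, the averaging recipe above might look as follows. The angle-bracket notation for "average over all genotypes satisfying a condition" is my own shorthand here and not necessarily the paper's; $g_i = a$ means that the genotype carries state $a$ at position $i$:

$$e_0 = \langle y \rangle$$

$$e_i(a) = \langle y \mid g_i = a \rangle - e_0$$

$$e_{i,k}(a, b) = \langle y \mid g_i = a,\ g_k = b \rangle - \left( e_0 + e_i(a) + e_k(b) \right)$$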

This is mathematically elegant, and I want to see if I can re-implement the mathematical model here in Python, while also making it fully Bayesian in PyMC.

## Generative Model: Dissecting the Equations

Within Figure 1(B) of the paper, we see a phenotype function,
which begins by modeling the main phenotypic effects as being:

$$s(g_1,...,g_n)=e_0 + \sum_{i=1}^{n} e_i(g_i) + \sum_{i < k} e_{i, k}(g_i, g_k) + ...$$


Here, the $e$ symbols are:

- $e_0$: the zero-th order effect,
- $e_i$: the 1st order effect,
- $e_{i, k}$: the 2nd order effect,

Additionally,

- $i$ refers to a position out of $n$ positions (in Python, this would be 0...n-1)
- $g_i$ is the genotype at position $i$; in a binary genotype setting, this would be modeled as a scalar taking the value 0 or 1, while in a multinomial genotype setting, we might model this using a one-hot vector instead.
- $k$ is also a position out of $n$ positions; we restrict it to $i < k$ so that we do not double-count positions.
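
To make the one-hot idea concrete, here is a minimal sketch. The alphabet size, the genotype, and the `theta` table of per-position, per-state effects are all made up for illustration, and I use plain NumPy (imported as `onp` so that it does not clash with the `np` alias this notebook uses for `jax.numpy`):

```python
import numpy as onp

num_states = 4                       # hypothetical alphabet size
genotype = onp.array([2, 0, 3])      # state index carried at each of 3 positions

# Row i of the identity matrix is the one-hot vector for state i,
# so indexing with the genotype gives a (positions, states) array.
one_hot = onp.eye(num_states, dtype=int)[genotype]

# Hypothetical first-order effects, one row per position, one column per state.
theta = onp.arange(12, dtype=float).reshape(3, 4)

# Multiplying by the one-hot mask picks out exactly one effect per position.
first_order_sum = (one_hot * theta).sum()
```

With this encoding, the binary case is just the special case of two columns.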

The 1st, 2nd, and, more generally, higher-order effects have a special property: within this model, they must sum to 0. This is due to the _reference-free_ nature of this model: instead of estimating mutational effects relative to a baseline variant (commonly called the wild-type), mutational effects are measured relative to the mean phenotype value across all genotypes. By this definition, the 1st-order effects of all states at a position must sum to zero. A particular genotype, however, carries only one state at each position (a protein cannot have two different amino acids at the same position), and so its phenotype can be modeled as the average phenotype plus the sum of the effects of the particular states it carries.
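
A tiny worked example may help here. This is only a sketch with made-up phenotype values for a 2-position binary system, written in plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`):

```python
import numpy as onp

# All four genotypes of a 2-position binary system, with made-up phenotypes.
genos = onp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
phenos = onp.array([1.0, 2.0, 3.0, 6.0])

grand_mean = phenos.mean()

# effects[state, site]: mean phenotype of genotypes carrying `state` at `site`,
# measured relative to the grand mean rather than to a wild-type.
effects = onp.array([
    [phenos[genos[:, site] == state].mean() - grand_mean for site in range(2)]
    for state in range(2)
])

# Summing over the states at each site gives zero for both sites.
per_site_sums = effects.sum(axis=0)
```

Here `effects` comes out as `[[-1.5, -1.0], [1.5, 1.0]]`: at each site, the two states sit symmetrically around the grand mean of 3.0.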

Then, $s$ is transformed by a sigmoid function to model non-specific epistasis in the final observed phenotype:

$$y = L + \frac{U - L}{1 + e^{-s}}$$

where:

- $L$ is the lower-bound quantitative phenotype of interest, 
- $U$ is the upper-bound quantitative phenotype of interest, and
- $y$ is the final observed phenotype.
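
As a quick sanity check on the shape of this transformation, here is a minimal sketch in plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`). The function name `observed_phenotype` and the bounds of 0 and 1 are my own choices for illustration:

```python
import numpy as onp

def observed_phenotype(s, lower, upper):
    """Squash the latent phenotype s into the measurable range (lower, upper)."""
    return lower + (upper - lower) / (1.0 + onp.exp(-s))

# Very negative s approaches the lower bound, very positive s approaches the
# upper bound, and s = 0 lands exactly halfway between the two.
s_values = onp.array([-10.0, 0.0, 10.0])
y_values = observed_phenotype(s_values, lower=0.0, upper=1.0)
```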

The key thing to think about here is how each of the effects $e_i$, $e_{i, k}$, and so on for higher-order effects, is modeled. We can think of these as generative functions for the effects. If we think about the problem in a binary genotype setting, then $e_i(g_i)$ might be modeled as:

$$e_i(g_i) = \theta_{i} g_i$$

where $\theta_i$ is nothing more than a slope parameter that we have to estimate. When $g_i$ is 0, i.e. the genotype taken on at position $i$ is 0, then there is no effect for the genotype at that position; when $g_i$ is 1, i.e. the genotype at position $i$ is 1, then the effect for the genotype at that position is $\theta_i$.

This feels like a natural point to make things concrete by generating data. I will assume a 3-position, binary genotype situation for mental tractability, and then find ways to generalize to more positions and more than 2 states per position.

To start, we will need simulated data. Following the paper, we are going to simulate a genotype-phenotype system where there are:

1. Two possible states at each position, A and B (encoded as 0 and 1 in the code).
2. Three positions: 0, 1, and 2.
3. Linear coefficients for each of the positions.
4. An interaction between positions 0 and 2, such that when the logical XOR of positions 0 and 2 is true, the signs of the coefficients at those two positions are flipped.

This should serve as a minimally complex example for reference-free analysis.
On one hand, we should be able to enumerate every single genotype possible -- there are only 8 in total.
On the other hand, the interaction term should give us a curveball that makes it difficult to use just a linear model to estimate parameters.
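
To see the curveball concretely, here is a minimal sketch that fits an ordinary least-squares model with an intercept and one slope per position to such a system. The coefficient values are made up for illustration, and I use plain NumPy (imported as `onp` so that it does not clash with the `np` alias this notebook uses for `jax.numpy`):

```python
import numpy as onp
from itertools import product

coefs = onp.array([0.3, 0.5, 0.8])                    # hypothetical coefficients
genos = onp.array(list(product(range(2), repeat=3)))  # all 8 genotypes

def simulate(genotype):
    """Flip the signs at positions 0 and 2 when exactly one of them is 1."""
    flip = genotype[0] != genotype[2]
    signs = onp.array([-1.0, 1.0, -1.0]) if flip else onp.ones(3)
    return float(onp.dot(genotype, signs * coefs))

phenos = onp.array([simulate(g) for g in genos])

# Additive model: intercept plus one slope per position.
design = onp.column_stack([onp.ones(len(genos)), genos])
fit, *_ = onp.linalg.lstsq(design, phenos, rcond=None)
residuals = phenos - design @ fit
```

With these particular numbers, every residual has magnitude 0.55: the additive model cannot absorb the interaction between positions 0 and 2, no matter how the slopes are chosen.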

In [None]:
import jax.numpy as np
from jax import random

key = random.PRNGKey(1)

# Establish linear coefficients for each position - to start, they are drawn once from a uniform distribution using a fixed PRNG key.
coefficients = random.uniform(key, shape=(3,))
coefficients

In [None]:
# Enumerate all possible genotypes.
from itertools import product
genotypes = []
for i, j, k in product(range(2), range(2), range(2)):
    genotype = [i, j, k]
    genotypes.append(genotype)
genotypes = np.array(genotypes)
genotypes

In [None]:
# Simulate what the phenotypes will look like if we just had a linear combination of genotypes.
from jax import vmap 
def phenotype_without_interactions(genotype):
    return np.dot(genotype, coefficients)

vmap(phenotype_without_interactions)(genotypes)

In [None]:
# Simulate what would happen if we had the interaction terms.
from jax import vmap 
from jax import lax 
genotype = genotypes[4]

def phenotype_with_interactions(genotype):
    def true_fn(coefficients):
        return np.array([-1.0, 1.0, -1.0]) * coefficients
    
    def false_fn(coefficients):
        return coefficients
    
    interaction = np.logical_xor(genotype[0], genotype[2])
    
    coeff = lax.cond(interaction, true_fn, false_fn, coefficients)
    
    return np.dot(genotype, coeff)

phenotypes = vmap(phenotype_with_interactions)(genotypes)
# phenotypes = vmap(phenotype_without_interactions)(genotypes)
phenotypes

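With this, we need to note that the effects we are going to estimate should differ from the ground-truth coefficients that were used to generate the data: the coefficients are expressed relative to the all-zeros genotype, whereas reference-free effects are expressed relative to the global mean phenotype.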

Cool, we've done it! Let's now get RFA going.

First off, we need to estimate $e_0$, which is the global average phenotype value.
This is trivial.

In [None]:
def zeroth_order_effects(genotypes, phenotypes):
    return np.mean(phenotypes)

e_0 = zeroth_order_effects(genotypes, phenotypes)
e_0

Next up, we need to estimate the first-order effects. To do so, we first need to calculate the average values for each position's state.

What does this mean? It means we need to calculate a `(num_states, num_positions)` array where each entry is the average phenotype value when a sequence contains a particular `state` at a particular `position`.

In [None]:
phenotypes

In [None]:
genotypes

In [None]:
def get_indices_with_genotype(sequences, genotype, site):
    idxs = []
    for i, sequence in enumerate(sequences):
        if sequence[site] == genotype:
            idxs.append(i)
    return np.array(idxs)


get_indices_with_genotype(genotypes, 1, 0)

In [None]:
def calculate_single_genotype_averages(genotypes, phenotypes):
    """
    Calculate the average phenotype for each genotype at each site.

    Args:
        genotypes (ndarray): Array of shape (num_genotypes, num_sites) representing the binary genotypes.
        phenotypes (ndarray): Array of shape (num_genotypes,) representing the phenotypes.

    Returns:
        ndarray: Array of shape (num_states, num_sites) containing the average phenotype for each state (0 or 1) at each site.
    """
    num_sites = len(genotypes[0])
    num_states = np.max(genotypes) + 1
    state_averages = np.zeros(shape=(num_states, num_sites))

    for site in range(num_sites):
        for state in range(num_states):
            # Calculate the average phenotype over every genotype that carries `state` at `site`.
            idxs = get_indices_with_genotype(genotypes, state, site)
            phenotypes_of_interest = phenotypes[idxs]
            state_averages = state_averages.at[state, site].set(np.mean(phenotypes_of_interest))
    return state_averages

calculate_single_genotype_averages(genotypes, phenotypes)
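
The nested loop above is easy to read, but it scans all genotypes once per (state, site) pair. As a possible alternative, a one-hot mask can build the same `(num_states, num_sites)` table in a single pass. This is only a sketch, using plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`) and made-up phenotype values 0 through 7 so that the averages are easy to verify by hand:

```python
import numpy as onp
from itertools import product

genos = onp.array(list(product(range(2), repeat=3)))  # shape (8, 3)
phenos = onp.arange(8, dtype=float)                   # made-up phenotypes
num_states = int(genos.max()) + 1

# mask[g, state, site] is True when genotype g carries `state` at `site`.
mask = genos[:, None, :] == onp.arange(num_states)[None, :, None]

# Sum the phenotypes under each mask, then divide by how many genotypes matched.
sums = onp.einsum("gsp,g->sp", mask.astype(float), phenos)
counts = mask.sum(axis=0)
state_averages = sums / counts
```

For this toy input, `state_averages` is `[[1.5, 2.5, 3.0], [5.5, 4.5, 4.0]]`, with rows indexed by state and columns by site.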

In [None]:
def first_order_effects(genotypes, phenotypes):
    e_0 = zeroth_order_effects(genotypes, phenotypes)
    single_genotype_averages = calculate_single_genotype_averages(genotypes, phenotypes)
    return single_genotype_averages - e_0

e_1 = first_order_effects(genotypes, phenotypes)
e_1

The shape of this array is `(num_genotype_states, num_positions)`.

In [None]:
e_1.sum()

This sum must equal 0 (up to floating-point error) for a comprehensive genotype-phenotype enumeration; in fact, the effects at each individual position sum to 0 as well. We can check this by changing the random key above and re-running the notebook. It should also hold true whether we use `phenotype_with_interactions` or `phenotype_without_interactions`.

Now we're going to calculate the second-order effects.

Second-order effects are calculated as follows:

> For every _pair_ of sites, and for every possible _pair_ of states at those sites, we calculate the average phenotype value and subtract the prediction made from the zeroth- and first-order effects.

In [None]:
def get_indices_with_double_genotype(sequences, sites, genotypes):
    idxs = []
    for i, sequence in enumerate(sequences):
        if sequence[sites[0]] == genotypes[0] and sequence[sites[1]] == genotypes[1]:
            idxs.append(i)
    return np.array(idxs)


get_indices_with_double_genotype(genotypes, sites=[0, 1], genotypes=[1, 0])

In [None]:
# This means we have to have an array that is of shape (n_genotype_states, n_positions, n_genotype_states, n_positions).
from itertools import combinations

def calculate_double_genotype_averages(genotypes, phenotypes):
    """
    Calculate the average phenotype for each pair of states at each pair of sites.

    Args:
        genotypes (ndarray): Array of shape (num_genotypes, num_sites) representing the binary genotypes.
        phenotypes (ndarray): Array of shape (num_genotypes,) representing the phenotypes.

    Returns:
        ndarray: Array of shape (num_states, num_sites, num_states, num_sites), indexed as [state1, site1, state2, site2], containing the average phenotype for each pair of states at each pair of sites with site1 < site2; all other entries are left at 0.
    """
    num_sites = len(genotypes[0])
    num_states = np.max(genotypes) + 1
    state_averages = np.zeros(shape=(num_states, num_sites, num_states, num_sites))

    for site1, site2 in combinations(range(num_sites), 2):
        for state1, state2 in product(range(num_states), range(num_states)):
            # Calculate the average phenotype over every genotype that carries `state1` at `site1` and `state2` at `site2`.
            sites = np.array([site1, site2])
            states = np.array([state1, state2])
            idxs = get_indices_with_double_genotype(genotypes, sites, states)
            phenotypes_of_interest = phenotypes[idxs]
            state_averages = state_averages.at[state1, site1, state2, site2].set(np.mean(phenotypes_of_interest))
    return state_averages


double_genotype_averages = calculate_double_genotype_averages(genotypes, phenotypes)
double_genotype_averages

In [None]:
# Within the equation for 2nd order effects, we have to sum up e_0 + (e_1 for each position's particular genotype)
# This is how we do it:
site1 = 0
site2 = 2
state1 = 0
state2 = 0

e_1_s1 = e_1.at[state1, site1].get()
e_1_s2 = e_1.at[state2, site2].get()

e_1_s1, e_1_s2

e_0 + e_1_s1 + e_1_s2

In [None]:
len(genotypes[0])

In [None]:
# Made generalized:

def second_order_effects(genotypes, phenotypes):
    e_0 = zeroth_order_effects(genotypes, phenotypes)
    e_1 = first_order_effects(genotypes, phenotypes)
    double_genotype_averages = calculate_double_genotype_averages(genotypes, phenotypes)

    num_sites = len(genotypes[0])
    num_states = np.max(genotypes) + 1

    effects = np.zeros_like(double_genotype_averages)
    for site1, site2 in combinations(range(num_sites), 2):
        for state1, state2 in product(range(num_states), range(num_states)):
            phenotype_average = double_genotype_averages.at[state1, site1, state2, site2].get()
            effect = phenotype_average - (e_0 + e_1.at[state1, site1].get() + e_1.at[state2, site2].get())
            effects = effects.at[state1, site1, state2, site2].set(effect)
    return effects


e_2 = second_order_effects(genotypes, phenotypes)
e_2

In [None]:
# Also sums to 0! (up to floating-point error)
assert np.allclose(e_2.sum(), 0)
e_2.sum()

In [None]:
# Its shape should be:
num_sites = len(genotypes[0])
num_states = np.max(genotypes) + 1
assert e_2.shape == (num_states, num_sites, num_states, num_sites)

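Now, let's try to predict the phenotype value of a genotype given the zeroth, first, and second order effects.

Because the simulated interaction only involves positions 0 and 2, there are no third-order effects, so we should be able to reconstruct the phenotype exactly (up to floating-point error).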

In [None]:
# Pick a genotype by its index; i = 7 below is the genotype [1, 1, 1].
e_0 = zeroth_order_effects(genotypes, phenotypes)
e_1 = first_order_effects(genotypes, phenotypes)
e_2 = second_order_effects(genotypes, phenotypes)

i = 7

genotype = genotypes[i]

def get_first_order_effect(e_1, genotype):
    effects = []
    for site, state in enumerate(genotype):
        effects.append(e_1.at[state, site].get())
    return np.sum(np.array(effects))


def get_second_order_effect(e_2, genotype):
    effects = []
    num_sites = len(genotype)
    for site1, site2 in combinations(range(num_sites), 2):
        state1 = genotype.at[site1].get()
        state2 = genotype.at[site2].get()
        effects.append(e_2.at[state1, site1, state2, site2].get())
    return np.sum(np.array(effects))

get_second_order_effect(e_2, genotype)

total = e_0 + get_first_order_effect(e_1, genotype) + get_second_order_effect(e_2, genotype)
# Difference between the reconstruction and the simulated phenotype; should be ~0.
total - phenotypes[i]

Things work!
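
The check above only looks at one genotype. Here is a minimal, self-contained sketch that repeats it for all eight. It uses plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`), hypothetical coefficient values rather than the randomly drawn ones, and dictionaries keyed by `(site, state)` instead of the arrays used above:

```python
import numpy as onp
from itertools import combinations, product

coefs = onp.array([0.3, 0.5, 0.8])                    # hypothetical coefficients
genos = onp.array(list(product(range(2), repeat=3)))

# Same XOR rule as in the simulation: flip signs at positions 0 and 2.
flip = genos[:, 0] != genos[:, 2]
signs = onp.where(flip[:, None], onp.array([-1.0, 1.0, -1.0]), 1.0)
phenos = (genos * signs * coefs).sum(axis=1)

e0 = phenos.mean()
e1 = {
    (site, state): phenos[genos[:, site] == state].mean() - e0
    for site in range(3)
    for state in range(2)
}
e2 = {}
for (s1, s2), (a, b) in product(combinations(range(3), 2), product(range(2), repeat=2)):
    mask = (genos[:, s1] == a) & (genos[:, s2] == b)
    e2[(s1, a, s2, b)] = phenos[mask].mean() - (e0 + e1[(s1, a)] + e1[(s2, b)])

def predict(g):
    """Zeroth-, first- and second-order effects summed for one genotype."""
    first = sum(e1[(site, int(state))] for site, state in enumerate(g))
    second = sum(
        e2[(s1, int(g[s1]), s2, int(g[s2]))]
        for s1, s2 in combinations(range(len(g)), 2)
    )
    return e0 + first + second

predictions = onp.array([predict(g) for g in genos])
max_error = onp.max(onp.abs(predictions - phenos))
```

In this sketch `max_error` is zero up to floating-point error, which is what we expect when the true system has no third-order interaction.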

We've test-driven the ideas on a simple binary genotype system with 3 positions.
We should test-drive the same with a 5-state system with 3 positions,
just to make sure the results are robust.
To do so, however, we need to refactor the code and ensure its correctness with software tests.
This is totally in line with the idea of "Software Engineering as Research Practice",
a topic I've written about on my [blog](https://ericmjl.github.io/blog//2020/8/21/software-engineering-as-a-research-practice/index.html).
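
As a first taste of what such a test might look like, here is a sketch of a property-style check on a 5-state, 3-position system. The helper `first_order_effects_np` is a hypothetical plain-NumPy stand-in (with NumPy imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`), not the refactored code itself, and the phenotypes are random numbers rather than simulated biology:

```python
import numpy as onp
from itertools import product

def first_order_effects_np(genos, phenos, num_states):
    """Hypothetical helper: (num_states, num_sites) table of first-order effects."""
    num_sites = genos.shape[1]
    effects = onp.zeros((num_states, num_sites))
    for site, state in product(range(num_sites), range(num_states)):
        effects[state, site] = phenos[genos[:, site] == state].mean() - phenos.mean()
    return effects

def test_first_order_effects_sum_to_zero():
    rng = onp.random.default_rng(0)
    genos = onp.array(list(product(range(5), repeat=3)))  # all 125 genotypes
    phenos = rng.normal(size=len(genos))                  # arbitrary phenotypes
    effects = first_order_effects_np(genos, phenos, num_states=5)
    assert effects.shape == (5, 3)
    # On a complete enumeration, the effects at every site must sum to zero.
    assert onp.allclose(effects.sum(axis=0), 0.0)

test_first_order_effects_sum_to_zero()
```

The nice thing about testing the zero-sum property is that it must hold for any phenotype values on a complete enumeration, so the test does not depend on a particular simulation.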