# `multidms` fitting pipeline

Here, we demonstrate the pipeline for fitting a `multidms` model to data from [six deep mutational scanning experiments](https://github.com/dms-vep) across three homologs (Delta, Omicron BA.1, and Omicron BA.2) of the SARS-CoV-2 spike protein.

In [1]:
import os
import sys
from collections import defaultdict
import time

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as onp
from tqdm.notebook import tqdm
import jax.numpy as jnp

import multidms
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [12]:
# import jax
# from jax import grad, jit
# def f(params, data):
#     """ret = transfors, ifs, loops"""
#     return (data @ params['weights']) + params['bias']

# # f_prime = df/d(params): jax.grad differentiates w.r.t. the first argument (params), not data
# f_prime = grad(f)

# # Jit-compiled
# f_compiled = jit(f)

# %timeit f(init_params, random_feature_array)
# %timeit f_compiled(init_params, random_feature_array)

21.8 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
4.08 µs ± 459 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [25]:
# # import time
# seed = 0
# key = jax.random.PRNGKey(seed)
# n_features = 10000
# n_samples = 10

# key1, key2, key3 = jax.random.split(key, num=3)
# data = jax.random.normal(shape=(n_samples, n_features), key=key1)
# params=dict(
#     weights=jax.random.normal(shape=(n_features,), key=key2),
#     bias=jax.random.normal(shape=(1,), key=key3)
# )

# print("Benchmarking non jit-compiled")
# %timeit f(params, data).block_until_ready()

# print("Benchmarking jit-compiled")
# %timeit f_compiled(params, data).block_until_ready()

Benchmarking non jit-compiled
69.9 µs ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Benchmarking jit-compiled
30 µs ± 548 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
# import numpy as np
# from jax.nn import selu
# from jax import numpy as jnp
# from jax import random as jran
# from jax import jit

# def feedforward_prediction(params, abscissa):

#     activations = abscissa
    
#     #  Loop over every dense layer except the last
#     for w, b in params[:-1]:
#         outputs = jnp.dot(w, activations) + b  # apply affine transformation
#         activations = selu(outputs)  #  apply nonlinear activation
        
#     #  Now for the final layer
#     w_final, b_final = params[-1]
#     final_outputs = jnp.dot(w_final, activations) + b_final 
#     return final_outputs  # Final layer is just w*x + b with no activation


# def get_random_layer_params(m, n, ran_key, scale=0.01):
#     """Helper function to randomly initialize 
#     weights and biases using the JAX-defined randoms."""
#     w_key, b_key = jran.split(ran_key)
#     ran_weights = scale * jran.normal(w_key, (n, m))
#     ran_biases = scale * jran.normal(b_key, (n,)) 
#     return ran_weights, ran_biases


# def get_init_network_params(sizes, ran_key):
#     """Initialize all layers for a fully-connected neural network."""
#     keys = jran.split(ran_key, len(sizes))
#     return [get_random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


# def get_network_layer_sizes(n_features, n_targets, n_layers, n_neurons_per_layer):
#     dense_layer_sizes = [n_neurons_per_layer]*n_layers
#     layer_sizes = [n_features, *dense_layer_sizes, n_targets]
#     return layer_sizes


In [9]:
# SEED = 0
# ran_key = jran.PRNGKey(SEED)
# num_random_samples = 10
# num_features, num_targets = 1, 1
# num_layers, num_neurons_per_layer = 2, 8

# layer_sizes = get_network_layer_sizes(
#     num_features, num_targets, num_layers, num_neurons_per_layer)
# init_params = get_init_network_params(layer_sizes, ran_key)
# ran_key, func_key = jran.split(ran_key)
# random_feature_array = jran.uniform(func_key, minval=0, maxval=1, shape=(num_features, ))

In [13]:
# %timeit feedforward_prediction(init_params, random_feature_array)
# jit_feedforward = jit(feedforward_prediction)
# %timeit jit_feedforward(init_params, random_feature_array)

20.1 µs ± 2.15 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
3.48 µs ± 501 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [3]:
# @Partial(jax.jit,static_argnums=(0,1,),)
# def f(ϕ, g, d_params, v_d):
#     return g(d_params["α"], ϕ(d_params, v_d))

# @jax.jit
# def g(α_params, z_d):
#     activations = jax.nn.sigmoid(z_d[:, None])
#     return (
#         (α_params["ge_scale"] @ activations.T)
#         + α_params["ge_bias"]
#     )

# @jax.jit
# def ϕ(d_params, v_d):
#     return (
#         d_params["C_ref"]
#         + d_params["C_d"]
#         + (v_d @ (d_params["β_m"] + d_params["s_md"]))
#     )

In [35]:
import multidms.model as model
# `model` is now the `multidms.model` module alias; later cells use it to reach
# `perceptron_global_epistasis`, so avoid rebinding the name `model`.
# Keep a handle on the softplus activation for the identity check in the next cell.
af = model.softplus_activation

In [36]:
# Sanity check: `af` should compare equal to `model.softplus_activation` (displayed True in this run).
af == model.softplus_activation

True

In [66]:
# Build a MultiDmsModel over `data`, using the perceptron global-epistasis function.
# `data` is only assigned in the commented-out data-loading cell of this notebook.
# On a fresh kernel this cell would otherwise die with a bare NameError, so fail
# with a message that says how to fix it.
if "data" not in globals():
    raise NameError(
        "`data` is not defined; uncomment and run the multidms.MultiDmsData(...) "
        "data-loading cell before building the model."
    )
imodel = multidms.MultiDmsModel(data, epistatic_model=model.perceptron_global_epistasis)
imodel

<multidms.multidms.MultiDmsModel at 0x7f6a6881cdc0>

In [75]:
# Condition names in the dataset attached to the model (this run: 'Omicron_BA.2-1', 'Delta-3').
imodel.data.conditions

('Omicron_BA.2-1', 'Delta-3')

In [80]:
# Inspect the initial parameter dict. Keys in this run: 'β', per-condition
# 'S_<cond>', 'C_<cond>' and 'γ_<cond>', 'C_ref', and 'α' holding the perceptron's
# 'p_weights_1', 'p_weights_2' and 'p_biases'.
imodel.params

{'β': DeviceArray([-1.19485962, -1.46153493, -0.30373774, ...,  0.63124044,
              -0.40762012,  0.39962111], dtype=float64),
 'S_Omicron_BA.2-1': DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float64),
 'C_Omicron_BA.2-1': DeviceArray([0.], dtype=float64),
 'S_Delta-3': DeviceArray([0., 0., 0., ..., 0., 0., 0.], dtype=float64),
 'C_Delta-3': DeviceArray([0.], dtype=float64),
 'C_ref': DeviceArray([5.], dtype=float64),
 'α': {'p_weights_1': DeviceArray([ 1.33073845,  0.06860843, -0.93994404, -0.14075135,
                1.03805961], dtype=float64),
  'p_weights_2': DeviceArray([-0.8621453 ,  0.12101575,  0.38603345, -0.53469579,
                0.07089547], dtype=float64),
  'p_biases': DeviceArray([-0.80525932,  1.56270888,  0.14595164, -0.2757112 ,
                0.93999988], dtype=float64)},
 'γ_Omicron_BA.2-1': DeviceArray([0.], dtype=float64),
 'γ_Delta-3': DeviceArray([0.], dtype=float64)}

In [77]:
# Training design matrix for the 'Delta-3' condition; shape (2224, 5807) in this run.
# NOTE(review): rows presumably are variants and columns mutation features — confirm.
X_h = data.training_data['X']['Delta-3']
X_h.shape

(2224, 5807)

In [79]:
# Latent-model output for the 'Delta-3' training rows, then report its shape both
# as-is and after adding a trailing axis (column-vector form used below).
z_h = imodel.latent_model(imodel.get_condition_params('Delta-3'), X_h)
print(z_h.shape, z_h[:, None].shape, sep="\n")

(2224,)
(2224, 1)


In [89]:
# Broadcast the column vector of latent values against the first-layer weights and
# biases: (n, 1) * (5,) + (5,) -> (n, 5), i.e. (2224, 5) in this run.
# NOTE(review): no nonlinearity is applied, so these are presumably hidden-layer
# pre-activations. Confirm against model.perceptron_global_epistasis before
# comparing with the model's own output.
act = z_h[:, None] * imodel.params['α']['p_weights_1'] + imodel.params['α']['p_biases']
act.shape

(2224, 5)

In [91]:
# Contract the hidden units with the second-layer weights: one value per row of `act`.
jnp.matmul(imodel.params['α']['p_weights_2'], act.T)

DeviceArray([-5.61115233, -8.48810302, -7.32794333, ..., -5.61907691,
             -2.90415721, -8.33511826], dtype=float64)

In [5]:
# func_score_df = pd.read_csv("Delta_BA1_BA2_func_score_df.csv")

# NOTE(review): later cells read `data` (e.g. `multidms.MultiDmsModel(data, ...)`
# and `data.training_data[...]`), but it is only assigned here. On a fresh
# kernel, uncomment and run this data-loading part before those cells.
# data = multidms.MultiDmsData(
#     func_score_df,
#     alphabet= multidms.AAS_WITHSTOP,
#     condition_colors = sns.color_palette("Paired"),
#     reference="Delta-3"
# )

# The fitted model is named `fit_model` (not `model`) so that uncommenting this
# block does not rebind `model`, the `multidms.model` module alias imported in
# an earlier cell.
# fit_model = multidms.MultiDmsModel(
#         data,
#         epistatic_model="sigmoid",
#         output_activation="softplus"
# )

# fit_model.fit(lasso_shift=1e-5, maxiter=5000, tol=1e-6)

In [16]:
import jax
import pickle

# @jax.jit  (left commented out: the function is jitted lazily inside bar.predict instead)
def foo(x):
    """Identity function used as the payload for the pickling experiment below."""
    return x

class bar:
    """Wrap a plain Python function and apply it through ``jax.jit`` on each call.

    The wrapped function is stored un-jitted. This is presumably so that instances
    stay picklable (the next cells pickle and unpickle one); compilation happens
    inside ``predict``.
    """

    def __init__(self, f):
        # f: the plain Python callable to wrap.
        self.f = f

    def predict(self, x):
        """Return ``jax.jit(self.f)(x)``."""
        # Use the stored function. Previously this hard-coded the global `foo`
        # and silently ignored whatever callable the instance was built with.
        return jax.jit(self.f)(x)

b = bar(foo)

In [17]:
# Round-trip experiment, step 1: serialise the wrapper instance `b` to disk.
# A context manager guarantees the file is flushed and closed before it is read back.
with open("b.pkl", "wb") as fh:
    pickle.dump(b, fh)

In [18]:
# Round-trip experiment, step 2: load the wrapper back, replacing `b`.
# NOTE: pickle.load can execute arbitrary code. Only load files this notebook wrote itself.
with open("b.pkl", "rb") as fh:
    b = pickle.load(fh)

In [19]:
# Call the unpickled wrapper; with `foo` as the identity this displays 5 as a JAX array.
b.predict(5)



DeviceArray(5, dtype=int32, weak_type=True)