<a href="https://colab.research.google.com/github/amirshane/protein-embedding-retrieval/blob/master/cnn_protein_landscapes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

CNN models are powerful tools for learning protein fitness landscapes. The code in this notebook applies them to three tasks: Fluorescence, Stability, and Variant Activity.

# Imports

In [None]:
# Install pinned versions of torch (needed by TAPE) plus jax, jaxlib and flax.
# The versions are pinned because later cells use APIs tied to them (e.g. `flax.nn`).
!pip install --upgrade -q torch==1.6.0 jax==0.1.75 jaxlib==0.1.52 flax==0.2.0

In [None]:
import requests
import os
# Ask the Colab TPU host (port 8475) for a specific TPU driver version, once per kernel:
# TPU_DRIVER_MODE acts as a "request already sent" guard when the cell is re-run.
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20200320'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# Connect Jax XLA backend to TPU
from jax.config import config
config.enable_omnistaging()
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

# Specifying TPU configuration
# Results are dependent on configuration since TPU computations are non-deterministic
# Reported results are from "tpu_driver_nightly" configuration
# See https://cloud.google.com/tpu/docs/version-switching and
# https://github.com/google/jax/issues/4408
# NOTE(review): the comment above names "tpu_driver_nightly" but the code below pins
# 'tpu_driver0.1-dev20200320' - confirm which configuration the reported results used.
from tensorflow.python.tpu.client.client import Client
c = Client()
c.configure_tpu_version('tpu_driver0.1-dev20200320', restart_type='ifNeeded')

In [3]:
import os
import sys

import jax.numpy as jnp

import flax.nn as nn

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

import scipy.stats

import sklearn
from sklearn.linear_model import Ridge
import sklearn.metrics as metrics
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split

import pprint

import collections

import copy

## clone and pip install tape

In [None]:
# Fetch TAPE, install its requirements, and move into the checkout so that the
# `import tape` in the next cell resolves against it.
!git clone https://github.com/songlab-cal/tape
!pip install -q -r tape/requirements.txt
os.chdir('tape/')

In [5]:
import tape
from tape.datasets import LMDBDataset

In [6]:
# Return to the directory we were in before entering tape/.
os.chdir('../')

## clone protein-embedding-retrieval repo and import

In [None]:
# Clone the repo together with its git submodules (--recurse-submodules).
!git clone --recurse-submodules https://github.com/googleinterns/protein-embedding-retrieval.git

In [8]:
# All later relative paths (google_research/, fluorescence/, stability/) are resolved from here.
os.chdir('protein-embedding-retrieval')

In [None]:
# Editable install of the cloned repo (the current working directory).
pip install -e .

In [10]:
# Make the google_research/ directory importable (inserted right after sys.path[0]).
sys.path.insert(1, 'google_research/')

In [11]:
from google_research.protein_lm import domains

# Utils

## Contextual lenses

In [12]:
from contextual_lenses.contextual_lenses import max_pool, linear_max_pool

In [13]:
def flatten(x, padding_mask=None):
  """Optionally mask `x`, then merge its second and third axes.

  Args:
    x: Array with at least three axes; it is reshaped to
      (x.shape[0], x.shape[1] * x.shape[2]).
    padding_mask: Optional array multiplied element-wise into `x` before
      reshaping (it must broadcast against `x`). Skipped when None.

  Returns:
    The (masked) input reshaped to two dimensions.
  """

  masked = x if padding_mask is None else x * padding_mask

  n_rows = masked.shape[0]
  n_cols = masked.shape[1] * masked.shape[2]

  return masked.reshape(n_rows, n_cols)

## Loss functions

In [14]:
from contextual_lenses.loss_fns import mse_loss

## Encoders

In [15]:
from contextual_lenses.encoders import one_hot_encoder, cnn_one_hot_encoder

In [16]:
def flattened_one_hot_encoder(batch_inds, num_categories):
  """One-hot encodes a batch of index sequences and flattens each sequence.

  Positions whose index is >= num_categories - 1 are treated as padding:
  their one-hot rows are multiplied by zero before flattening.

  Args:
    batch_inds: Array of integer indices; the mask is expanded on axis 2, so
      a 2-D (batch, position) layout is expected.
    num_categories: Number of one-hot categories passed to `one_hot_encoder`.

  Returns:
    2-D array with one flattened, masked one-hot row per sequence.
  """

  keep_position = jnp.where(batch_inds < num_categories-1, 1, 0)
  padding_mask = jnp.expand_dims(keep_position, axis=2)

  return flatten(one_hot_encoder(batch_inds, num_categories), padding_mask)

## Data batching

In [17]:
from contextual_lenses.train_utils import create_data_iterator

## Create optimizer



In [18]:
from contextual_lenses.train_utils import create_optimizer

## Model creation

In [19]:
from contextual_lenses.train_utils import create_representation_model

## Training methods

In [20]:
from contextual_lenses.train_utils import train

## Compute embeddings

In [21]:
from contextual_lenses.pfam_utils import compute_embeddings

## Compute number of parameters

In [22]:
def get_num_params(model):
  """Computes number of parameters in flax model.

  Walks `model.params` recursively and sums the number of scalar entries of
  every leaf.

  Args:
    model: Object with a `params` attribute holding a (possibly nested)
      mapping whose leaves expose a `.shape`.

  Returns:
    Total parameter count as a Python int (0 for an empty mapping).
  """

  # Imported locally so the cell does not depend on `collections.abc` having been
  # loaded elsewhere. The old `collections.MutableMapping` alias was removed in
  # Python 3.10.
  from collections.abc import Mapping

  def count_leaves(node):
    # Recursing (instead of flattening keys with a '_' separator) avoids silently
    # dropping parameters when two flattened key paths collide, e.g.
    # {'a': {'b_c': ...}} and {'a_b': {'c': ...}}.
    if isinstance(node, Mapping):
      return sum(count_leaves(child) for child in node.values())
    return int(np.prod(node.shape))

  return count_leaves(model.params)

# Fluorescence

The data originally comes from [Sarkisyan et al.](https://www.nature.com/articles/nature17995) and was formatted by [TAPE](https://github.com/songlab-cal/tape).

## Open data

In [None]:
# Download the TAPE-formatted fluorescence archive into the current directory.
!wget http://s3.amazonaws.com/proteindata/data_pytorch/fluorescence.tar.gz

In [24]:
# Unpack the archive; later cells read the LMDB files under fluorescence/.
!tar xzf fluorescence.tar.gz

In [25]:
def gfp_dataset_to_df(in_name):
  """Loads a fluorescence LMDB dataset into a DataFrame.

  Each `log_fluorescence` entry is replaced by its first element, so the
  column holds one value per row instead of an indexable container.

  Args:
    in_name: Path passed to `LMDBDataset`.

  Returns:
    DataFrame with one row per dataset record.
  """
  records = list(LMDBDataset(in_name))
  df = pd.DataFrame(records)
  df['log_fluorescence'] = df['log_fluorescence'].apply(lambda entry: entry[0])
  return df

### Padding and one-hot encoding

In [26]:
# Number of residues kept per GFP sequence: create_gfp_df truncates `primary` to this
# length and GFP_PROTEIN_DOMAIN is built with it.
GFP_SEQ_LEN = 237

In [27]:
# 20 amino-acid letters plus '-' (presumably a gap/padding symbol - TODO confirm).
# In the cells shown here only its length is used, as `num_categories` for the encoders.
GFP_AMINO_ACID_VOCABULARY = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',
    'S', 'T', 'V', 'W', 'Y', '-'
]

In [28]:
# Encoding domain used by gfp_seq_to_inds: a protein vocabulary without anomalous
# amino acids, with EOS and PAD tokens enabled, and sequence length GFP_SEQ_LEN.
GFP_PROTEIN_DOMAIN = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=False,
                               include_eos=True,
                               include_pad=True),
    length=GFP_SEQ_LEN)

In [29]:
def gfp_seq_to_inds(seq):
  """Encodes one GFP amino acid sequence with GFP_PROTEIN_DOMAIN.

  The domain encodes a batch, so the sequence is wrapped in a one-element
  list and the single encoded result is returned.
  """

  encoded_batch = GFP_PROTEIN_DOMAIN.encode([seq])
  return encoded_batch[0]

### Add one-hots and batch data

In [30]:
def create_gfp_df(test=False):
  """Builds the featurized GFP DataFrame for the train or test split.

  Args:
    test: When True load the test LMDB, otherwise the train LMDB.

  Returns:
    DataFrame from `gfp_dataset_to_df` with an added `one_hot_inds` column
    holding the encoding of each sequence truncated to GFP_SEQ_LEN residues.
  """

  lmdb_path = ('fluorescence/fluorescence_test.lmdb' if test
               else 'fluorescence/fluorescence_train.lmdb')
  gfp_df = gfp_dataset_to_df(lmdb_path)

  def encode_truncated(sequence):
    return gfp_seq_to_inds(sequence[:GFP_SEQ_LEN])

  gfp_df['one_hot_inds'] = gfp_df.primary.apply(encode_truncated)

  return gfp_df


def create_gfp_batches(batch_size, epochs=1, test=False, buffer_size=None, 
                       seed=0, drop_remainder=False):
  """Creates an iterator over GFP batches plus the split's target values.

  Args:
    batch_size: Batch size forwarded to `create_data_iterator`.
    epochs: Number of epochs forwarded to `create_data_iterator`.
    test: Use the test split; this also forces `buffer_size` to 1.
    buffer_size: Buffer size forwarded to `create_data_iterator`
      (ignored when `test` is True).
    seed: Seed forwarded to `create_data_iterator`.
    drop_remainder: Forwarded to `create_data_iterator`.

  Returns:
    Tuple (batches, fluorescences) where `fluorescences` is the
    `log_fluorescence` column of the split as an array.
  """

  effective_buffer_size = 1 if test else buffer_size

  gfp_df = create_gfp_df(test=test)
  fluorescences = gfp_df['log_fluorescence'].values

  iterator_kwargs = dict(df=gfp_df,
                         input_col='one_hot_inds',
                         output_col='log_fluorescence',
                         batch_size=batch_size,
                         epochs=epochs,
                         buffer_size=effective_buffer_size,
                         seed=seed,
                         drop_remainder=drop_remainder)

  return create_data_iterator(**iterator_kwargs), fluorescences

## Model evaluation

In [31]:
def gfp_evaluate(predict_fn, title, batch_size=256, 
                 test_data=None, pred_fluorescences=None,
                 clip_min=-999999, clip_max=999999):
  """Scores fluorescence predictions with Spearman correlation and MSE.

  Three scatter plots are shown and scored: all test points, the points
  whose true value is > 2.5 ("bright") and those whose true value is
  < 2.5 ("dark").

  Args:
    predict_fn: Callable mapping a batch of inputs to predictions; each
      prediction is indexed with [0]. Only used when `pred_fluorescences`
      is None.
    title: Title of the first plot; ' (Bright)' / ' (Dark)' are appended for
      the other two.
    batch_size: Batch size for the default test iterator.
    test_data: Optional iterable of (X, Y) batches used instead of the
      default test iterator. The true values always come from the default
      test split.
    pred_fluorescences: Optional precomputed predictions.
    clip_min: Lower clipping bound applied to the predictions.
    clip_max: Upper clipping bound applied to the predictions.

  Returns:
    Tuple (results, pred_fluorescences): a dict of metrics rounded to three
    decimals and the clipped predictions as a numpy array.
  """

  def score_and_plot(true_values, predicted_values, plot_title):
    # Metrics are computed before plotting, one figure per call.
    rho = scipy.stats.spearmanr(true_values, predicted_values).correlation
    error = sklearn.metrics.mean_squared_error(true_values, predicted_values)
    plt.scatter(true_values, predicted_values, s=1, alpha=0.5)
    plt.xlabel('True LogFluorescence')
    plt.ylabel('Predicted LogFluorescence')
    plt.title(plot_title)
    plt.show()
    return rho, error

  default_batches, test_fluorescences = create_gfp_batches(batch_size=batch_size, 
                                                           test=True, buffer_size=1)
  test_batches = default_batches if test_data is None else test_data

  if pred_fluorescences is None:
    pred_fluorescences = [pred[0]
                          for X, _ in iter(test_batches)
                          for pred in predict_fn(X)]

  pred_fluorescences = np.clip(np.array(pred_fluorescences), clip_min, clip_max)

  spearmanr, mse = score_and_plot(test_fluorescences, pred_fluorescences, title)

  bright_inds = np.where(test_fluorescences > 2.5)
  bright_spearmanr, bright_mse = score_and_plot(test_fluorescences[bright_inds],
                                                pred_fluorescences[bright_inds],
                                                title + ' (Bright)')

  dark_inds = np.where(test_fluorescences < 2.5)
  dark_spearmanr, dark_mse = score_and_plot(test_fluorescences[dark_inds],
                                            pred_fluorescences[dark_inds],
                                            title + ' (Dark)')

  results = {'title': title}
  for key, value in (('spearmanr', spearmanr),
                     ('mse', mse),
                     ('bright_spearmanr', bright_spearmanr),
                     ('bright_mse', bright_mse),
                     ('dark_spearmanr', dark_spearmanr),
                     ('dark_mse', dark_mse)):
    results[key] = round(value, 3)

  pprint.pprint(results)

  return results, pred_fluorescences

##  Experiments

In [32]:
# Featurize the training split: stack the per-sequence index arrays into one array and
# turn them into flattened, padding-masked one-hot rows for the linear baseline.
gfp_train_df = create_gfp_df()
train_fluorescences = gfp_train_df['log_fluorescence']
gfp_train_one_hot_inds = np.array(list(gfp_train_df['one_hot_inds'].values))
gfp_train_one_hots = flattened_one_hot_encoder(
    gfp_train_one_hot_inds,
    num_categories=len(GFP_AMINO_ACID_VOCABULARY))

In [33]:
# Same featurization as for the training split, applied to the test split.
gfp_test_df = create_gfp_df(test=True)
gfp_test_one_hot_inds = np.array(list(gfp_test_df['one_hot_inds'].values))
gfp_test_one_hots = flattened_one_hot_encoder(
    gfp_test_one_hot_inds,
    num_categories=len(GFP_AMINO_ACID_VOCABULARY))

### Linear regression

In [34]:
# Ridge regression with scikit-learn's default settings, used as the linear baseline.
gfp_linear_model = Ridge()

In [35]:
# Fit on the flattened one-hot features of the training split.
gfp_linear_model.fit(X=gfp_train_one_hots, y=train_fluorescences)

Ridge(alpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,
      normalize=False, random_state=None, solver='auto', tol=0.001)

In [36]:
# Predict log-fluorescence for the test split; the predictions are scored in the next cell.
linear_model_pred_fluorescences = gfp_linear_model.predict(gfp_test_one_hots)

In [None]:
# Score the precomputed predictions (predict_fn is unused because pred_fluorescences is given).
# Predictions are clipped to the range of fluorescences observed in the training split.
gfp_linear_model_results, linear_model_pred_fluorescences = \
gfp_evaluate(predict_fn=None, 
             title='Linear Regression',
             pred_fluorescences=linear_model_pred_fluorescences,
             clip_min=min(train_fluorescences),
             clip_max=max(train_fluorescences))

In [None]:
# Counts the entries of coef_ only.
# NOTE(review): the fitted intercept is not included in this count - confirm that is intended.
print('Number of Parameters for Fluorescence Linear Regression: ' + str(len(gfp_linear_model.coef_)))

### CNN + MaxPool

In [None]:
# Train the CNN + MaxPool model on the GFP training split, then evaluate it on the test split.
epochs = 50
gfp_train_batches, train_fluorescences = create_gfp_batches(batch_size=256, epochs=epochs)

# One learning rate and one weight decay per entry of `layers`.
# NOTE(review): presumably matched to the layer names by position inside train() - confirm.
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 5e-6]
weight_decay = [0.0, 0.05]

# Encoder: a single convolutional layer with 1024 features and kernel size 5.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 1,
    'n_features': [1024],
    'n_kernel_sizes': [5],
    'n_kernel_dilations': None
}
# Reduction over the encoder output; max_pool takes no extra arguments here.
reduce_fn = max_pool
reduce_fn_kwargs = {

}
loss_fn_kwargs = {
    
}

gfp_model = create_representation_model(encoder_fn=encoder_fn,
                                        encoder_fn_kwargs=encoder_fn_kwargs,
                                        reduce_fn=reduce_fn,
                                        reduce_fn_kwargs=reduce_fn_kwargs,
                                        num_categories=len(GFP_AMINO_ACID_VOCABULARY),
                                        output_features=1)

# train() returns an optimizer; the trained model is read from its `.target` attribute below.
gfp_optimizer = train(model=gfp_model,
                      train_data=gfp_train_batches, 
                      loss_fn=mse_loss,
                      loss_fn_kwargs=loss_fn_kwargs,
                      learning_rate=learning_rate, 
                      weight_decay=weight_decay,
                      layers=layers)

# Predictions are clipped to the range of fluorescences seen in the training split.
gfp_results, pred_fluorescences = gfp_evaluate(predict_fn=gfp_optimizer.target,
                                               title='CNN + MaxPool',
                                               batch_size=256,
                                               clip_min=min(train_fluorescences),
                                               clip_max=max(train_fluorescences))

In [None]:
# Parameter count of the trained CNN + MaxPool model.
print('Number of Parameters for Fluorescence CNN + MaxPool: ' + str(get_num_params(gfp_optimizer.target)))

### CNN + LinearMaxPool

In [None]:
# Train the CNN + LinearMaxPool model on the GFP training split, then evaluate it on the test split.
epochs = 50
gfp_train_batches, train_fluorescences = create_gfp_batches(batch_size=256, epochs=epochs)

# One learning rate and one weight decay per entry of `layers`; this variant lists one
# more Dense layer than the CNN + MaxPool cell.
# NOTE(review): presumably matched to the layer names by position inside train() - confirm.
layers = ['CNN_0', 'Dense_1', 'Dense_2']                                 
learning_rate = [1e-3, 5e-5, 5e-6]
weight_decay = [0.0, 0.05, 0.05]

# Encoder: a single convolutional layer with 1024 features and kernel size 5.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 1,
    'n_features': [1024],
    'n_kernel_sizes': [5],
    'n_kernel_dilations': None
}
# Reduction over the encoder output, configured with a representation size of 2048.
reduce_fn = linear_max_pool
reduce_fn_kwargs = {
    'rep_size': 2048
}
loss_fn_kwargs = {
    
}

gfp_model_l = create_representation_model(encoder_fn=encoder_fn,
                                          encoder_fn_kwargs=encoder_fn_kwargs,
                                          reduce_fn=reduce_fn,
                                          reduce_fn_kwargs=reduce_fn_kwargs,
                                          num_categories=len(GFP_AMINO_ACID_VOCABULARY),
                                          output_features=1)

# train() returns an optimizer; the trained model is read from its `.target` attribute below.
gfp_optimizer_l = train(model=gfp_model_l,
                        train_data=gfp_train_batches, 
                        loss_fn=mse_loss,
                        loss_fn_kwargs=loss_fn_kwargs,
                        learning_rate=learning_rate, 
                        weight_decay=weight_decay,
                        layers=layers)

# Predictions are clipped to the range of fluorescences seen in the training split.
gfp_results_l, pred_fluorescences_l = gfp_evaluate(predict_fn=gfp_optimizer_l.target,
                                                   title='CNN + LinearMaxPool',
                                                   batch_size=256,
                                                   clip_min=min(train_fluorescences),
                                                   clip_max=max(train_fluorescences))

In [None]:
# Parameter count of the trained CNN + LinearMaxPool model.
print('Number of Parameters for Fluorescence CNN + LinearMaxPool: ' + str(get_num_params(gfp_optimizer_l.target)))

### TAPE (Best Performance)

In [43]:
# Reference metrics for TAPE, hard-coded here (not computed in this notebook).
# Keys mirror the results dict returned by gfp_evaluate so the numbers can be compared directly.
tape_fluorescence_results = {
    'title': 'TAPE',
    'spearmanr': 0.68,
    'mse': 0.19,
    'bright_spearmanr': 0.63,
    'bright_mse': 0.07,
    'dark_mse': 0.22,
    'dark_spearmanr': 0.05,
}

In [None]:
# Printed with pprint, the same way gfp_evaluate prints its results.
pprint.pprint(tape_fluorescence_results)

## Visualization

In [45]:
def gfp_embeddings(encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, optimizer,
                   learning_rate=None, weight_decay=None, layers=None):
  """Computes GFP embeddings from given optimizer.

  Builds a model with `output='embedding'`, copies the parameters of
  `optimizer.target` into it and embeds the train and test splits.

  Args:
    encoder_fn: Encoder passed to `create_representation_model`.
    encoder_fn_kwargs: Keyword arguments for `encoder_fn`.
    reduce_fn: Reduction passed to `create_representation_model`.
    reduce_fn_kwargs: Keyword arguments for `reduce_fn`.
    optimizer: Optimizer whose `target.params` provide the weights.
    learning_rate: Per-layer learning rates for the wrapping optimizer.
      Defaults to the notebook-level `learning_rate` variable.
    weight_decay: Per-layer weight decays for the wrapping optimizer.
      Defaults to the notebook-level `weight_decay` variable.
    layers: Layer names for the wrapping optimizer. Defaults to the
      notebook-level `layers` variable.

  Returns:
    Tuple (gfp_train_embeddings, gfp_test_embeddings).
  """

  # The optimizer settings used to be read implicitly from whatever cell last
  # assigned these globals. They can now be passed explicitly so they match the
  # architecture being rebuilt; the globals remain the fallback so existing
  # calls keep working.
  notebook_globals = globals()
  if learning_rate is None:
    learning_rate = notebook_globals['learning_rate']
  if weight_decay is None:
    weight_decay = notebook_globals['weight_decay']
  if layers is None:
    layers = notebook_globals['layers']

  gfp_encoding_model = create_representation_model(encoder_fn=encoder_fn,
                                                   encoder_fn_kwargs=encoder_fn_kwargs,
                                                   reduce_fn=reduce_fn,
                                                   reduce_fn_kwargs=reduce_fn_kwargs,
                                                   num_categories=len(GFP_AMINO_ACID_VOCABULARY),
                                                   output='embedding',
                                                   output_features=1)
  
  # Deep copy so the parameters held by the caller's optimizer are not shared.
  trained_params = copy.deepcopy(optimizer.target.params)

  gfp_encoding_optimizer = create_optimizer(gfp_encoding_model,
                                            learning_rate=learning_rate,
                                            weight_decay=weight_decay,
                                            layers=layers)
  
  # Overwrite only the layers that exist in the embedding model.
  for layer in gfp_encoding_optimizer.target.params.keys():
    gfp_encoding_optimizer.target.params[layer] = trained_params[layer]

  # buffer_size=1 for the train split (the test split forces it as well).
  # The target values returned alongside the batches are not needed here.
  gfp_train_batches, _ = create_gfp_batches(batch_size=256, buffer_size=1)
  gfp_train_embeddings = compute_embeddings(gfp_encoding_optimizer.target, gfp_train_batches)
  
  gfp_test_batches, _ = create_gfp_batches(batch_size=256, test=True)
  gfp_test_embeddings = compute_embeddings(gfp_encoding_optimizer.target, gfp_test_batches)

  return gfp_train_embeddings, gfp_test_embeddings

In [46]:
# CNN + MaxPool Optimizer Embeddings
# Rebuilds the CNN + MaxPool architecture and embeds both splits with the trained weights.
# NOTE(review): this cell does not assign learning_rate / weight_decay / layers, so
# gfp_embeddings picks up whatever an earlier cell left in those globals (in a top-to-bottom
# run, the three-layer CNN + LinearMaxPool settings) - confirm that is harmless.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 1,
    'n_features': [1024],
    'n_kernel_sizes': [5],
    'n_kernel_dilations': None
}
reduce_fn = max_pool
reduce_fn_kwargs = {
    
}
gfp_train_embeddings, gfp_test_embeddings = gfp_embeddings(encoder_fn, 
                                                           encoder_fn_kwargs,
                                                           reduce_fn,
                                                           reduce_fn_kwargs, 
                                                           gfp_optimizer)

In [47]:
# Random CNN + MaxPool Optimizer Embeddings
# Baseline: same architecture as CNN + MaxPool, but the model is wrapped in an optimizer
# without ever being passed to train().
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 5e-6]
weight_decay = [0.0, 0.05]

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 1,
    'n_features': [1024],
    'n_kernel_sizes': [5],
    'n_kernel_dilations': None
}
reduce_fn = max_pool
reduce_fn_kwargs = {

}
loss_fn_kwargs = {
    
}

# NOTE: this rebinds gfp_model, which previously referred to the model built in the
# CNN + MaxPool training cell.
gfp_model = create_representation_model(encoder_fn=encoder_fn,
                                        encoder_fn_kwargs=encoder_fn_kwargs,
                                        reduce_fn=reduce_fn,
                                        reduce_fn_kwargs=reduce_fn_kwargs,
                                        num_categories=len(GFP_AMINO_ACID_VOCABULARY),
                                        output_features=1)

gfp_optimizer_random = create_optimizer(gfp_model, learning_rate, weight_decay, layers)

gfp_train_embeddings_random, gfp_test_embeddings_random = gfp_embeddings(encoder_fn, 
                                                                         encoder_fn_kwargs,
                                                                         reduce_fn,
                                                                         reduce_fn_kwargs, 
                                                                         gfp_optimizer_random)

In [48]:
# CNN + LinearMaxPool Optimizer Embeddings
# Rebuilds the CNN + LinearMaxPool architecture and embeds both splits with the trained weights.
# NOTE(review): this cell does not assign learning_rate / weight_decay / layers, so
# gfp_embeddings picks up whatever an earlier cell left in those globals (in a top-to-bottom
# run, the two-layer settings of the random CNN + MaxPool cell) - confirm that is harmless.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 1,
    'n_features': [1024],
    'n_kernel_sizes': [5],
    'n_kernel_dilations': None
}
reduce_fn = linear_max_pool
reduce_fn_kwargs = {
    'rep_size': 2048
}
gfp_train_embeddings_l, gfp_test_embeddings_l = gfp_embeddings(encoder_fn,
                                                               encoder_fn_kwargs,
                                                               reduce_fn,
                                                               reduce_fn_kwargs,
                                                               gfp_optimizer_l)

### Train PCA

In [49]:
def gfp_train_pca_plot(train_embeddings, model_name):
  """Applies and plots PCA on GFP train embeddings.

  Two scatter plots of the first two principal components are shown: one
  colored by log-fluorescence and one colored by number of mutations.

  Args:
    train_embeddings: One embedding per row of the GFP training DataFrame,
      in the same order.
    model_name: Name inserted into both plot titles.
  """

  components = PCA(n_components=2).fit_transform(train_embeddings)
  
  X_train = components[:, 0]
  Y_train = components[:, 1]

  # Reloaded here to get the per-row colors (log_fluorescence, num_mutations).
  gfp_train_df = create_gfp_df()

  # Both figures share the same title (the second one used to read "Embeddings( name)").
  plot_title = 'PCA of Train GFP Embeddings (' + model_name + ')'

  plt.figure(figsize=(8, 6))
  plt.scatter(X_train, Y_train, c=gfp_train_df.log_fluorescence.values, s=1, alpha=0.5)
  plt.title(plot_title)
  plt.xlabel('Principal Component 1')
  plt.ylabel('Principal Component 2')
  plt.colorbar().set_label('LogFluorescence')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_train, Y_train, c=gfp_train_df.num_mutations.values, s=1, alpha=0.5)
  plt.title(plot_title)
  plt.xlabel('Principal Component 1')
  plt.ylabel('Principal Component 2')
  cbar = plt.colorbar()
  cbar.set_label('Mutations (Edit Distance)')
  # Ticks 0..3 on the mutation colorbar.
  cbar.set_ticks(range(4))
  plt.show()

In [None]:
# Baseline: train-split PCA for the model that was never passed to train().
gfp_train_pca_plot(gfp_train_embeddings_random, model_name='Random CNN + MaxPool')

In [None]:
# Train-split PCA for the trained CNN + MaxPool model.
gfp_train_pca_plot(gfp_train_embeddings, model_name='CNN + MaxPool')

In [None]:
# Train-split PCA for the trained CNN + LinearMaxPool model.
gfp_train_pca_plot(gfp_train_embeddings_l, model_name='CNN + LinearMaxPool')

### Test PCA

In [53]:
def gfp_test_pca_plot(test_embeddings, model_name, max_mut=8):
  """Plots a 2-component PCA of the GFP test embeddings.

  PCA is fitted on all test embeddings; only points with at most `max_mut`
  mutations are drawn. Two figures are shown, colored by log-fluorescence
  and by number of mutations.

  Args:
    test_embeddings: One embedding per row of the GFP test DataFrame, in the
      same order.
    model_name: Name inserted into both plot titles.
    max_mut: Largest number of mutations a point may have to be plotted.
  """

  components = PCA(n_components=2).fit_transform(test_embeddings)

  gfp_test_df = create_gfp_df(test=True)
  mutation_counts = gfp_test_df.num_mutations.values
  log_fluorescences = gfp_test_df.log_fluorescence.values

  kept = [row for row in range(len(components)) if mutation_counts[row]<=max_mut]

  first_component = [components[row][0] for row in kept]
  second_component = [components[row][1] for row in kept]

  def draw(colors, colorbar_label, ticks=None):
    # One figure per coloring; ticks are only overridden when given.
    plt.figure(figsize=(8, 6))
    plt.scatter(first_component, second_component, c=colors, s=1, alpha=0.25)
    plt.title('PCA of Test GFP Embeddings (' + model_name + ')')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    cbar = plt.colorbar()
    cbar.set_label(colorbar_label)
    if ticks is not None:
      cbar.set_ticks(ticks)
    plt.show()

  draw([log_fluorescences[row] for row in kept], 'LogFluorescence')
  draw([mutation_counts[row] for row in kept], 'Mutations (Edit Distance)',
       ticks=range(4, max_mut+1))

In [None]:
# Baseline: test-split PCA for the model that was never passed to train() (points with <= 9 mutations).
gfp_test_pca_plot(gfp_test_embeddings_random, model_name='Random CNN + MaxPool', max_mut=9)

In [None]:
# Test-split PCA for the trained CNN + MaxPool model (points with <= 9 mutations).
gfp_test_pca_plot(gfp_test_embeddings, model_name='CNN + MaxPool', max_mut=9)

In [None]:
# Test-split PCA for the trained CNN + LinearMaxPool model (points with <= 9 mutations).
gfp_test_pca_plot(gfp_test_embeddings_l, model_name='CNN + LinearMaxPool', max_mut=9)

### Train t-SNE

In [57]:
def gfp_train_tsne_plot(train_embeddings, model_name, num_samples=2500):
  """Applies and plots t-SNE on GFP train embeddings.

  A random subset of at most `num_samples` embeddings is drawn with NumPy's
  global random state seeded to 0, projected to two dimensions with t-SNE and
  shown twice: colored by log-fluorescence and by number of mutations.

  Args:
    train_embeddings: One embedding per row of the GFP training DataFrame,
      in the same order.
    model_name: Name inserted into both plot titles.
    num_samples: Maximum number of points to embed and plot.
  """

  gfp_train_df = create_gfp_df()
  log_fluorescences = gfp_train_df.log_fluorescence.values
  mutation_counts = gfp_train_df.num_mutations.values

  # Permute row indices rather than an array of (embedding, value, value) tuples:
  # building such a ragged array raises ValueError on NumPy >= 1.24.
  # NOTE(review): expected to select the same rows as the previous tuple-array
  # permutation for a given seed - confirm if exact figure reproduction matters.
  np.random.seed(0)
  chosen = np.random.permutation(len(train_embeddings))[:num_samples]

  sub_gfp_train_embeddings = np.array([train_embeddings[i] for i in chosen])
  sub_gfp_train_fluorescences = np.array([log_fluorescences[i] for i in chosen])
  sub_gfp_train_mutations = np.array([mutation_counts[i] for i in chosen])

  gfp_train_embeddings_tsne = TSNE(n_components=2).fit_transform(sub_gfp_train_embeddings)

  X_tsne_train = gfp_train_embeddings_tsne[:, 0]
  Y_tsne_train = gfp_train_embeddings_tsne[:, 1]

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=sub_gfp_train_fluorescences, s=5, alpha=0.5)
  plt.title('t-SNE of Train GFP Embeddings (' + model_name + ')')
  cbar = plt.colorbar()
  cbar.set_label('LogFluorescence')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=sub_gfp_train_mutations, s=5, alpha=0.5)
  plt.title('t-SNE of Train GFP Embeddings (' + model_name + ')')
  cbar = plt.colorbar()
  cbar.set_label('Mutations (Edit Distance)')
  # Ticks 0..3 on the mutation colorbar.
  cbar.set_ticks(range(4))
  plt.show()

In [None]:
# Baseline: train-split t-SNE for the model that was never passed to train().
gfp_train_tsne_plot(gfp_train_embeddings_random, model_name='Random CNN + MaxPool', num_samples=2500)

In [None]:
# Train-split t-SNE for the trained CNN + MaxPool model.
gfp_train_tsne_plot(gfp_train_embeddings, model_name='CNN + MaxPool', num_samples=2500)

In [None]:
# Train-split t-SNE for the trained CNN + LinearMaxPool model.
gfp_train_tsne_plot(gfp_train_embeddings_l, model_name='CNN + LinearMaxPool', num_samples=2500)

### Test t-SNE

In [61]:
def gfp_test_tsne_plot(test_embeddings, model_name, max_mut=8, num_samples=2500):
  """Applies and plots t-SNE on GFP test embeddings.

  Only test points with at most `max_mut` mutations are considered. A random
  subset of at most `num_samples` of them is drawn with NumPy's global random
  state seeded to 0, projected to two dimensions with t-SNE and shown twice:
  colored by log-fluorescence and by number of mutations.

  Args:
    test_embeddings: One embedding per row of the GFP test DataFrame, in the
      same order.
    model_name: Name inserted into both plot titles.
    max_mut: Largest number of mutations a point may have to be considered.
    num_samples: Maximum number of points to embed and plot.
  """

  gfp_test_df = create_gfp_df(test=True)
  log_fluorescences = gfp_test_df.log_fluorescence.values
  mutation_counts = gfp_test_df.num_mutations.values

  eligible = [i for i in range(len(test_embeddings)) if mutation_counts[i]<=max_mut]

  # Permute positions within the eligible rows rather than an array of
  # (embedding, value, value) tuples: building such a ragged array raises
  # ValueError on NumPy >= 1.24.
  # NOTE(review): expected to select the same rows as the previous tuple-array
  # permutation for a given seed - confirm if exact figure reproduction matters.
  np.random.seed(0)
  order = np.random.permutation(len(eligible))[:num_samples]
  chosen = [eligible[position] for position in order]

  sub_gfp_test_embeddings_mut_cap = np.array([test_embeddings[i] for i in chosen])
  sub_gfp_test_fluorescences_mut_cap = np.array([log_fluorescences[i] for i in chosen])
  sub_gfp_test_mutations_mut_cap = np.array([mutation_counts[i] for i in chosen])

  gfp_test_embeddings_tsne_mut_cap = TSNE(n_components=2).fit_transform(sub_gfp_test_embeddings_mut_cap)

  X_tsne_test_mut_cap = gfp_test_embeddings_tsne_mut_cap[:, 0]
  Y_tsne_test_mut_cap = gfp_test_embeddings_tsne_mut_cap[:, 1]

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test_mut_cap, Y_tsne_test_mut_cap, c=sub_gfp_test_fluorescences_mut_cap, s=5, alpha=0.5)
  plt.title('t-SNE of Test GFP Embeddings (' + model_name + ')')
  cbar = plt.colorbar()
  cbar.set_label('LogFluorescence')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test_mut_cap, Y_tsne_test_mut_cap, c=sub_gfp_test_mutations_mut_cap, s=5, alpha=0.5)
  plt.title('t-SNE of Test GFP Embeddings (' + model_name + ')')
  cbar = plt.colorbar()
  cbar.set_label('Mutations (Edit Distance)')
  # Ticks start at 4 on the mutation colorbar and run up to max_mut.
  cbar.set_ticks(range(4, max_mut+1))
  plt.show()

In [None]:
# Baseline: test-split t-SNE for the model that was never passed to train() (points with <= 9 mutations).
gfp_test_tsne_plot(gfp_test_embeddings_random, model_name='Random CNN + MaxPool', max_mut=9, num_samples=2500)

In [None]:
# Test-split t-SNE for the trained CNN + MaxPool model (points with <= 9 mutations).
gfp_test_tsne_plot(gfp_test_embeddings, model_name='CNN + MaxPool', max_mut=9, num_samples=2500)

In [None]:
# Test-split t-SNE for the trained CNN + LinearMaxPool model (points with <= 9 mutations).
gfp_test_tsne_plot(gfp_test_embeddings_l, model_name='CNN + LinearMaxPool', max_mut=9, num_samples=2500)

# Stability

The data originally comes from [Rocklin et al.](https://science.sciencemag.org/content/357/6347/168) and was formatted by [TAPE](https://github.com/songlab-cal/tape).

## Open data

In [None]:
# Download the TAPE-formatted stability archive into the current directory.
!wget http://s3.amazonaws.com/proteindata/data_pytorch/stability.tar.gz

In [66]:
# Unpack the archive; later cells read the LMDB files under stability/.
!tar xzf stability.tar.gz

In [67]:
def stability_dataset_to_df(in_name):
  """Loads a stability LMDB dataset into a DataFrame.

  Two columns are added: `stability`, the first element of each
  `stability_score` entry, and `id_str`, the UTF-8 decoded `id`.

  Args:
    in_name: Path passed to `LMDBDataset`.

  Returns:
    DataFrame with one row per dataset record.
  """
  records = list(LMDBDataset(in_name))
  df = pd.DataFrame(records)
  df['stability'] = df['stability_score'].apply(lambda score: score[0])
  df['id_str'] = df['id'].apply(lambda raw_id: raw_id.decode('utf-8'))
  return df

### Padding and one-hot encoding

In [68]:
# Number of residues kept per stability sequence: create_stability_df truncates `primary`
# to this length and STABILITY_PROTEIN_DOMAIN is built with it.
STABILITY_SEQ_LEN = 50

In [69]:
# Same 21 symbols as GFP_AMINO_ACID_VOCABULARY: 20 amino-acid letters plus '-'
# (presumably a gap/padding symbol - TODO confirm).
STABILITY_AMINO_ACID_VOCABULARY = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',
    'S', 'T', 'V', 'W', 'Y', '-'
]

In [70]:
# Encoding domain used by stability_seq_to_inds: a protein vocabulary without anomalous
# amino acids, with EOS and PAD tokens enabled, and sequence length STABILITY_SEQ_LEN.
STABILITY_PROTEIN_DOMAIN = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=False,
                               include_eos=True,
                               include_pad=True),
    length=STABILITY_SEQ_LEN)

In [71]:
def stability_seq_to_inds(seq):
  """Encodes one stability amino acid sequence with STABILITY_PROTEIN_DOMAIN.

  The domain encodes a batch, so the sequence is wrapped in a one-element
  list and the single encoded result is returned.
  """

  encoded_batch = STABILITY_PROTEIN_DOMAIN.encode([seq])
  return encoded_batch[0]

### Add one-hots and batch data

In [72]:
# Raw (not yet featurized) train/test frames, used below to look up each parent's stability.
stability_train_df = stability_dataset_to_df('stability/stability_train.lmdb')
stability_test_df = stability_dataset_to_df('stability/stability_test.lmdb')

In [73]:
# Maps a parent identifier (bytes, as stored in the `parent` column) to the stability
# measured for that parent, or None when no matching record is found.
parent_to_parent_stability = {}

In [74]:
# For every parent in the training split, look up the parent's own stability:
# first under the id '<parent>.pdb', then under the bare '<parent>'; None if neither exists.
for parent in set(stability_train_df.parent.values): 
  parent_to_parent_stability[parent] = None
  for candidate_id in (parent.decode('utf-8') + '.pdb', parent.decode('utf-8')):
    stabilities = stability_train_df[stability_train_df['id_str']==candidate_id].stability.values
    if len(stabilities) != 0:
      # The first matching record wins.
      parent_to_parent_stability[parent] = stabilities[0]
      break

In [75]:
# Same lookup as for the training split, but against the test split.
# NOTE(review): a parent that also occurs in the training split is overwritten here,
# possibly with None if its own record is missing from the test split - confirm intended.
for parent in set(stability_test_df.parent.values): 
  parent_to_parent_stability[parent] = None
  for candidate_id in (parent.decode('utf-8') + '.pdb', parent.decode('utf-8')):
    stabilities = stability_test_df[stability_test_df['id_str']==candidate_id].stability.values
    if len(stabilities) != 0:
      # The first matching record wins.
      parent_to_parent_stability[parent] = stabilities[0]
      break

In [76]:
# Known topologies and their categorical codes; code 2 is reserved for everything else.
topology_to_ind = {'HHH': 0, 'HEEH': 1, 'EHEE': 3, 'EEHEE': 4}
def topology_to_index(top):
  """Maps a UTF-8 encoded topology name to its categorical code.

  Args:
    top: Topology name as bytes.

  Returns:
    The code from `topology_to_ind`, or 2 for any topology not listed there.
  """
  
  return topology_to_ind.get(top.decode('utf-8'), 2)

In [77]:
def create_stability_df(test=False):
  """Builds the featurized stability DataFrame for the train or test split.

  Args:
    test: When True load the test LMDB, otherwise the train LMDB.

  Returns:
    DataFrame from `stability_dataset_to_df` with three added columns:
    `one_hot_inds` (encoding of each sequence truncated to
    STABILITY_SEQ_LEN residues), `parent_stability` (looked up in
    `parent_to_parent_stability`) and `topology_ind` (categorical topology).
  """

  lmdb_path = ('stability/stability_test.lmdb' if test
               else 'stability/stability_train.lmdb')
  stability_df = stability_dataset_to_df(lmdb_path)

  def encode_truncated(sequence):
    return stability_seq_to_inds(sequence[:STABILITY_SEQ_LEN])

  def lookup_parent_stability(parent):
    # Raises KeyError for a parent that was never registered in the mapping.
    return parent_to_parent_stability[parent]

  stability_df['one_hot_inds'] = stability_df.primary.apply(encode_truncated)
  stability_df['parent_stability'] = stability_df.parent.apply(lookup_parent_stability)
  stability_df['topology_ind'] = stability_df.topology.apply(topology_to_index)

  return stability_df


def create_stability_batches(batch_size, epochs=1, test=False, buffer_size=None,
                             seed=0, drop_remainder=False):
  """Creates an iterator over stability batches plus the split's target values.

  Args:
    batch_size: Batch size forwarded to `create_data_iterator`.
    epochs: Number of epochs forwarded to `create_data_iterator`.
    test: Use the test split; this also forces `buffer_size` to 1.
    buffer_size: Buffer size forwarded to `create_data_iterator`
      (ignored when `test` is True).
    seed: Seed forwarded to `create_data_iterator`.
    drop_remainder: Forwarded to `create_data_iterator`.

  Returns:
    Tuple (batches, stabilities) where `stabilities` is the `stability`
    column of the split as an array.
  """

  effective_buffer_size = 1 if test else buffer_size

  stability_df = create_stability_df(test=test)
  stabilities = stability_df['stability'].values

  iterator_kwargs = dict(df=stability_df,
                         input_col='one_hot_inds',
                         output_col='stability',
                         batch_size=batch_size,
                         epochs=epochs,
                         buffer_size=effective_buffer_size,
                         seed=seed,
                         drop_remainder=drop_remainder)

  return create_data_iterator(**iterator_kwargs), stabilities

## Model evaluation

In [78]:
def _stability_direction_accuracy(stabilities, pred_stabilities,
                                  parent_stabilities, pred_parent_stabilities):
  """Fraction of rows whose predicted change relative to the parent has the true direction.

  Only rows with a known parent stability are scored. A missing value may be
  None or NaN (pandas turns None into NaN in numeric columns), so both are
  treated as missing. A scored row whose predicted parent stability is missing
  counts as incorrect. Returns NaN when no row can be scored.
  """

  num_scored = 0
  num_correct = 0
  for stability, pred, parent, pred_parent in zip(stabilities, pred_stabilities,
                                                  parent_stabilities, pred_parent_stabilities):
    if pd.isnull(parent):
      continue
    num_scored += 1
    if pd.isnull(pred_parent):
      continue
    both_up = stability >= parent and pred >= pred_parent
    both_down = stability <= parent and pred <= pred_parent
    if both_up or both_down:
      num_correct += 1

  return num_correct / num_scored if num_scored else float('nan')


def stability_evaluate(predict_fn, title, batch_size=256, 
                       test_data=None, pred_stabilities=None,
                       clip_min=-999999, clip_max=999999):
  """Scores stability predictions on the test split and plots them.

  Args:
    predict_fn: callable mapping a batch of inputs to per-example predictions;
      the first entry of each prediction is used. Ignored when
      `pred_stabilities` is given.
    title: plot title, also stored in the results.
    batch_size: batch size of the default test iterator.
    test_data: optional iterable of (X, Y) batches used instead of the default
      test iterator. The true stabilities still come from the default split.
    pred_stabilities: optional precomputed predictions, one per test row.
    clip_min: lower bound applied to the predictions.
    clip_max: upper bound applied to the predictions.

  Returns:
    (results, pred_stabilities). `results` holds 'title' plus the overall and
    per-topology Spearman correlation and parent-direction accuracy, rounded to
    3 decimals; `pred_stabilities` is the clipped prediction array.
  """
  
  test_batches, test_stabilities = create_stability_batches(batch_size=batch_size, test=True, buffer_size=1)
  
  if test_data is not None:
    test_batches = test_data

  if pred_stabilities is None:
    pred_stabilities = []
    for X, _ in iter(test_batches):
      for pred in predict_fn(X):
        pred_stabilities.append(pred[0])
  
  pred_stabilities = np.clip(np.array(pred_stabilities), clip_min, clip_max)
    
  spearmanr = scipy.stats.spearmanr(test_stabilities, pred_stabilities).correlation
  plt.scatter(test_stabilities, pred_stabilities, s=1, alpha=0.5)
  plt.xlabel('True Stability')
  plt.ylabel('Predicted Stability')
  plt.title(title)
  plt.show()
  
  stability_test_df = create_stability_df(test=True)
  stability_test_df['pred_stability'] = pred_stabilities
  stability_test_df['topology_str'] = stability_test_df.topology.apply(lambda x: x.decode('utf-8'))

  # First prediction seen for every id, built in one pass over the test rows.
  pred_by_id = {}
  for id_str, pred in zip(stability_test_df['id_str'].values, pred_stabilities):
    pred_by_id.setdefault(id_str, pred)

  # A parent's own row is looked up as '<parent>.pdb' first, then '<parent>';
  # parents without a row in the test split map to None.
  parent_to_pred_parent_stability = {}
  for parent in set(stability_test_df.parent.values):
    parent_name = parent.decode('utf-8')
    parent_to_pred_parent_stability[parent] = pred_by_id.get(
        parent_name + '.pdb', pred_by_id.get(parent_name))
  
  stability_test_df['pred_parent_stability'] = stability_test_df.parent.apply(lambda x: parent_to_pred_parent_stability[x])
  
  accuracy = _stability_direction_accuracy(
      test_stabilities,
      pred_stabilities,
      stability_test_df['parent_stability'].values,
      stability_test_df['pred_parent_stability'].values)
  
  topologies = [('EEHEE', 'BBABB'), ('EHEE', 'BABB'), ('HEEH', 'ABBA'), ('HHH', 'AAA')]
  topology_results = {}
  
  for topology, topology_name in topologies:
    topology_test_df = stability_test_df[stability_test_df['topology_str']==topology]

    topology_test_stabilities = topology_test_df['stability'].values
    topology_pred_stabilities = topology_test_df['pred_stability'].values
      
    topology_spearmanr = scipy.stats.spearmanr(topology_test_stabilities, topology_pred_stabilities).correlation
    plt.scatter(topology_test_stabilities, topology_pred_stabilities, s=1, alpha=0.5)
    plt.xlabel('True Stability')
    plt.ylabel('Predicted Stability')
    plt.title(title + ' (' + topology_name + ')')
    plt.show()
    
    topology_accuracy = _stability_direction_accuracy(
        topology_test_stabilities,
        topology_pred_stabilities,
        topology_test_df['parent_stability'].values,
        topology_test_df['pred_parent_stability'].values)
    
    topology_results[topology_name] = (topology_spearmanr, topology_accuracy)

  results = {
      'title': title,
      'spearmanr': round(spearmanr, 3),
      'accuracy': round(accuracy, 3)
  }
  for topology_name, (topology_spearmanr, topology_accuracy) in topology_results.items():
    results[topology_name + '_spearmanr'] = round(topology_spearmanr, 3)
    results[topology_name + '_accuracy'] = round(topology_accuracy, 3)
  
  pprint.pprint(results)

  return results, pred_stabilities

## Experiments

In [79]:
# Featurized train split, its labels, and the flattened one-hot inputs that
# the linear regression baseline is fitted on.
stability_train_df = create_stability_df()
train_stabilities = stability_train_df['stability']
stability_train_one_hot_inds = np.array(list(stability_train_df['one_hot_inds'].values))
stability_train_one_hots = flattened_one_hot_encoder(
    stability_train_one_hot_inds,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY))

In [80]:
# Featurized test split and its flattened one-hot inputs.
stability_test_df = create_stability_df(test=True)
stability_test_one_hot_inds = np.array(list(stability_test_df['one_hot_inds'].values))
stability_test_one_hots = flattened_one_hot_encoder(
    stability_test_one_hot_inds,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY))

### Linear regression

In [81]:
# Ridge regression with scikit-learn's default hyperparameters, used as the linear baseline.
stability_linear_model = Ridge()

In [None]:
# Fit the baseline on the flattened one-hot train sequences against the measured stabilities.
stability_linear_model.fit(X=stability_train_one_hots, y=train_stabilities)

In [83]:
# Predict test-set stabilities from the flattened one-hot test sequences.
linear_model_pred_stabilities = stability_linear_model.predict(stability_test_one_hots)

In [None]:
# Score the precomputed linear-regression predictions; no predict_fn is needed.
# Predictions are clipped to the range of stabilities seen in training.
stability_linear_model_results, linear_model_pred_stabilities = stability_evaluate(
    predict_fn=None,
    title='Linear Regression',
    pred_stabilities=linear_model_pred_stabilities,
    clip_min=min(train_stabilities),
    clip_max=max(train_stabilities),
)

In [None]:
# Reports len(coef_) of the fitted ridge model; the intercept is not included in this count.
print('Number of Parameters for Stability Linear Regression: ' + str(len(stability_linear_model.coef_)))

### CNN + MaxPool

In [None]:
# CNN + MaxPool: 3-layer CNN encoder (no dilations given) reduced with
# `max_pool`, trained with MSE loss and scored on the stability test split.
# NOTE(review): `learning_rate` / `weight_decay` appear to hold one value per
# entry of `layers` — confirm against `train`.
layers = ['CNN_0', 'Dense_1']
learning_rate = [5e-4, 5e-5]
weight_decay = [0.025] * 2

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=3,
    n_features=[1024] * 3,
    n_kernel_sizes=[5] * 3,
    n_kernel_dilations=None,
)
reduce_fn = max_pool
reduce_fn_kwargs = dict()
loss_fn_kwargs = dict()

epochs = 5
stability_train_batches, train_stabilities = create_stability_batches(
    batch_size=16, epochs=epochs)

stability_model = create_representation_model(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY),
    output_features=1,
)

stability_optimizer = train(
    model=stability_model,
    train_data=stability_train_batches,
    loss_fn=mse_loss,
    loss_fn_kwargs=loss_fn_kwargs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    layers=layers,
)

# Predictions are clipped to the range of stabilities seen in training.
stability_results, pred_stabilities = stability_evaluate(
    predict_fn=stability_optimizer.target,
    title='CNN + MaxPool',
    batch_size=1024,
    clip_min=min(train_stabilities),
    clip_max=max(train_stabilities),
)

In [None]:
# Parameter count of the trained CNN + MaxPool model, as reported by get_num_params.
print('Number of Parameters for Stability CNN + MaxPool: ' + str(get_num_params(stability_optimizer.target)))

In [None]:
# Dilated CNN + MaxPool: same setup as the plain CNN + MaxPool experiment
# except for kernel dilations of [1, 2, 1] across the three layers.
# NOTE(review): `learning_rate` / `weight_decay` appear to hold one value per
# entry of `layers` — confirm against `train`.
layers = ['CNN_0', 'Dense_1']
learning_rate = [5e-4, 5e-5]
weight_decay = [0.025] * 2

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=3,
    n_features=[1024] * 3,
    n_kernel_sizes=[5] * 3,
    n_kernel_dilations=[1, 2, 1],
)
reduce_fn = max_pool
reduce_fn_kwargs = dict()
loss_fn_kwargs = dict()

epochs = 5
stability_train_batches, train_stabilities = create_stability_batches(
    batch_size=16, epochs=epochs)

stability_model_d = create_representation_model(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY),
    output_features=1,
)

stability_optimizer_d = train(
    model=stability_model_d,
    train_data=stability_train_batches,
    loss_fn=mse_loss,
    loss_fn_kwargs=loss_fn_kwargs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    layers=layers,
)

# Predictions are clipped to the range of stabilities seen in training.
stability_results_d, pred_stabilities_d = stability_evaluate(
    predict_fn=stability_optimizer_d.target,
    title='Dilated CNN + MaxPool',
    batch_size=1024,
    clip_min=min(train_stabilities),
    clip_max=max(train_stabilities),
)

In [None]:
# Parameter count of the trained dilated CNN + MaxPool model, as reported by get_num_params.
print('Number of Parameters for Dilated Stability CNN + MaxPool: ' + str(get_num_params(stability_optimizer_d.target)))

### CNN + LinearMaxPool

In [None]:
# CNN + LinearMaxPool: 3-layer CNN encoder reduced with `linear_max_pool`
# (rep_size 2048), trained without weight decay for 10 epochs.
# NOTE(review): `learning_rate` / `weight_decay` appear to hold one value per
# entry of `layers` — confirm against `train`.
layers = ['CNN_0', 'Dense_1', 'Dense_2']
learning_rate = [1e-5, 5e-5, 5e-6]
weight_decay = [0.0] * 3

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=3,
    n_features=[1024] * 3,
    n_kernel_sizes=[5] * 3,
    n_kernel_dilations=None,
)
reduce_fn = linear_max_pool
reduce_fn_kwargs = dict(rep_size=2048)
loss_fn_kwargs = dict()

epochs = 10
stability_train_batches, train_stabilities = create_stability_batches(
    batch_size=256, epochs=epochs)

stability_model_l = create_representation_model(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY),
    output_features=1,
)

stability_optimizer_l = train(
    model=stability_model_l,
    train_data=stability_train_batches,
    loss_fn=mse_loss,
    loss_fn_kwargs=loss_fn_kwargs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    layers=layers,
)

# Predictions are clipped to the range of stabilities seen in training.
stability_results_l, pred_stabilities_l = stability_evaluate(
    predict_fn=stability_optimizer_l.target,
    title='CNN + LinearMaxPool',
    batch_size=1024,
    clip_min=min(train_stabilities),
    clip_max=max(train_stabilities),
)

In [None]:
# Parameter count of the trained CNN + LinearMaxPool model, as reported by get_num_params.
print('Number of Parameters for Stability CNN + LinearMaxPool: ' + str(get_num_params(stability_optimizer_l.target)))

### Ensemble of CNNs

In [None]:
# Ensemble prediction: per-example mean of the plain, dilated and
# LinearMaxPool CNN predictions, scored like any other precomputed prediction.
ensemble_pred_stabilities = [
    (plain_pred + dilated_pred + linear_pred) / 3
    for plain_pred, dilated_pred, linear_pred
    in zip(pred_stabilities, pred_stabilities_d, pred_stabilities_l)
]
ensemble_results, ensemble_pred_stabilities = stability_evaluate(
    predict_fn=None,
    title='Ensemble of CNNs',
    pred_stabilities=ensemble_pred_stabilities,
    clip_min=min(train_stabilities),
    clip_max=max(train_stabilities),
)

In [None]:
# The ensemble's size is the sum of the parameter counts of its three member models.
print('Number of Parameters for Ensemble of CNNs: ' + str(sum(
    get_num_params(member.target)
    for member in (stability_optimizer, stability_optimizer_d, stability_optimizer_l))))

### TAPE (Best Performance)

In [94]:
# Hard-coded reference numbers for TAPE, using the same keys as the dict
# returned by stability_evaluate so the two can be compared side by side.
# NOTE(review): presumably copied from the TAPE benchmark results — cite the source.
tape_stability_results = {
    'title': 'TAPE',
    'spearmanr': 0.73,
    'accuracy': 0.70,
    'AAA_spearmanr': 0.72,
    'AAA_accuracy': 0.70,
    'ABBA_spearmanr': 0.48,
    'ABBA_accuracy': 0.79,
    'BABB_spearmanr': 0.68,
    'BABB_accuracy': 0.71,
    'BBABB_spearmanr': 0.67,
    'BBABB_accuracy': 0.70
}

In [None]:
# Print the TAPE reference results in the same format stability_evaluate uses for model results.
pprint.pprint(tape_stability_results)

## Visualization

In [96]:
def stability_embeddings(encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, optimizer):
  """Embeds the stability train and test splits with parameters taken from `optimizer`.

  A representation model is built with output='embedding', and each of its
  parameter layers is overwritten with a deep copy of the same-named layer of
  `optimizer.target.params`.

  NOTE(review): the optimizer wrapping the embedding model is created from the
  module-level `learning_rate`, `weight_decay` and `layers`, i.e. whatever the
  most recently executed cell assigned to them.

  Args:
    encoder_fn: encoder passed to `create_representation_model`.
    encoder_fn_kwargs: keyword arguments for `encoder_fn`.
    reduce_fn: reduction passed to `create_representation_model`.
    reduce_fn_kwargs: keyword arguments for `reduce_fn`.
    optimizer: object whose `target.params` supplies the layer parameters.

  Returns:
    (stability_train_embeddings, stability_test_embeddings) as returned by
    `compute_embeddings` for the train and test batches.
  """

  embedding_model = create_representation_model(
      encoder_fn=encoder_fn,
      encoder_fn_kwargs=encoder_fn_kwargs,
      reduce_fn=reduce_fn,
      reduce_fn_kwargs=reduce_fn_kwargs,
      num_categories=len(STABILITY_AMINO_ACID_VOCABULARY),
      output='embedding',
      output_features=1)

  source_params = copy.deepcopy(optimizer.target.params)

  embedding_optimizer = create_optimizer(
      embedding_model,
      learning_rate=learning_rate,
      weight_decay=weight_decay,
      layers=layers)

  # Only layers present in the embedding model are copied over.
  for layer_name in embedding_optimizer.target.params.keys():
    embedding_optimizer.target.params[layer_name] = source_params[layer_name]

  embedder = embedding_optimizer.target

  train_batches, _ = create_stability_batches(batch_size=256, buffer_size=1)
  stability_train_embeddings = compute_embeddings(embedder, train_batches)

  test_batches, _ = create_stability_batches(batch_size=256, test=True)
  stability_test_embeddings = compute_embeddings(embedder, test_batches)

  return stability_train_embeddings, stability_test_embeddings

In [97]:
# CNN + MaxPool Optimizer Embeddings
# Embeddings from the trained CNN + MaxPool optimizer, using a single-layer
# encoder configuration for the embedding model.
# NOTE(review): stability_embeddings also reads the module-level
# `learning_rate`, `weight_decay` and `layers`, which this cell does not set.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=1,
    n_features=[1024],
    n_kernel_sizes=[5],
    n_kernel_dilations=None,
)
reduce_fn = max_pool
reduce_fn_kwargs = dict()
stability_train_embeddings, stability_test_embeddings = stability_embeddings(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    optimizer=stability_optimizer,
)

In [98]:
# Random CNN + MaxPool Optimizer Embeddings
# Embeddings from a CNN + MaxPool optimizer that is created but never passed
# through `train`, as a baseline for the trained embeddings.
# NOTE(review): this cell rebinds `stability_model`.
layers = ['CNN_0', 'Dense_1']
learning_rate = [5e-4, 5e-5]
weight_decay = [0.025] * 2

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=3,
    n_features=[1024] * 3,
    n_kernel_sizes=[5] * 3,
    n_kernel_dilations=None,
)
reduce_fn = max_pool
reduce_fn_kwargs = dict()
loss_fn_kwargs = dict()

stability_model = create_representation_model(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    num_categories=len(STABILITY_AMINO_ACID_VOCABULARY),
    output_features=1,
)

stability_optimizer_random = create_optimizer(stability_model, learning_rate, weight_decay, layers)

stability_train_embeddings_random, stability_test_embeddings_random = stability_embeddings(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    optimizer=stability_optimizer_random,
)

In [99]:
# CNN + LinearMaxPool Optimizer Embeddings
# Embeddings from the trained CNN + LinearMaxPool optimizer, using a
# single-layer encoder configuration for the embedding model.
# NOTE(review): stability_embeddings also reads the module-level
# `learning_rate`, `weight_decay` and `layers`, which this cell does not set;
# unlike the other configurations, no 'n_kernel_dilations' entry is passed.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = dict(
    n_layers=1,
    n_features=[1024],
    n_kernel_sizes=[5],
)
reduce_fn = linear_max_pool
reduce_fn_kwargs = dict(rep_size=2048)
stability_train_embeddings_l, stability_test_embeddings_l = stability_embeddings(
    encoder_fn=encoder_fn,
    encoder_fn_kwargs=encoder_fn_kwargs,
    reduce_fn=reduce_fn,
    reduce_fn_kwargs=reduce_fn_kwargs,
    optimizer=stability_optimizer_l,
)

In [100]:
# Ensemble Embeddings (Concatention of Above Embeddings)
# Ensemble embeddings: the CNN + MaxPool and CNN + LinearMaxPool embeddings
# concatenated along the feature axis (axis=1).
stability_train_embeddings_e = np.concatenate((stability_train_embeddings, stability_train_embeddings_l), axis=1)
stability_test_embeddings_e = np.concatenate((stability_test_embeddings, stability_test_embeddings_l), axis=1)

### Train PCA

In [101]:
def stability_train_pca_plot(train_embeddings, model_name):
  """Projects train embeddings onto two principal components and plots them.

  Three scatter plots of the same projection are shown, colored by stability,
  by parent stability and by topology category.

  Args:
    train_embeddings: one embedding per row of the stability train dataframe.
    model_name: appended in parentheses to every plot title.
  """

  stability_train_df = create_stability_df()

  projected = PCA(n_components=2).fit_transform(train_embeddings)
  pc1 = projected[:, 0]
  pc2 = projected[:, 1]
  plot_title = 'PCA of Train Stability Embeddings (' + model_name + ')'

  def start_scatter(colors, alpha):
    # Opens a new figure with the shared scatter, title and axis labels.
    plt.figure(figsize=(8, 6))
    plt.scatter(pc1, pc2, c=colors, s=1, alpha=alpha)
    plt.title(plot_title)
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')

  start_scatter(stability_train_df.stability.values, alpha=0.25)
  plt.colorbar().set_label('Stability')
  plt.show()

  start_scatter(stability_train_df.parent_stability.values, alpha=0.5)
  plt.colorbar().set_label('Parent Stability')
  plt.show()

  start_scatter(stability_train_df.topology_ind.values, alpha=0.5)
  # Colorbar ticks are the categorical indices; label them with topology names.
  topology_names = ['AAA', 'ABBA', 'Other', 'BABB', 'BBABB']
  formatter = plt.FuncFormatter(lambda val, loc: topology_names[val])
  cbar = plt.colorbar(ticks=range(5), format=formatter)
  cbar.set_label('Topology')
  plt.show()

In [None]:
# Train-split PCA for the CNN + MaxPool model that was never passed through `train`.
stability_train_pca_plot(stability_train_embeddings_random, model_name='Random CNN + MaxPool')

In [None]:
# Train-split PCA for the trained CNN + MaxPool embeddings.
stability_train_pca_plot(stability_train_embeddings, model_name='CNN + MaxPool')

In [None]:
# Train-split PCA for the trained CNN + LinearMaxPool embeddings.
stability_train_pca_plot(stability_train_embeddings_l, model_name='CNN + LinearMaxPool')

In [None]:
# Train-split PCA for the concatenated (ensemble) embeddings.
stability_train_pca_plot(stability_train_embeddings_e, model_name='Ensemble of CNNs')

### Test PCA

In [106]:
def stability_test_pca_plot(test_embeddings, model_name):
  """Projects test embeddings onto two principal components and plots them.

  Three scatter plots of the same projection are shown, colored by stability,
  by parent stability and by topology category.

  Args:
    test_embeddings: one embedding per row of the stability test dataframe.
    model_name: appended in parentheses to every plot title.
  """

  stability_test_df = create_stability_df(test=True)

  projected = PCA(n_components=2).fit_transform(test_embeddings)
  pc1 = projected[:, 0]
  pc2 = projected[:, 1]
  plot_title = 'PCA of Test Stability Embeddings (' + model_name + ')'

  def start_scatter(colors, alpha):
    # Opens a new figure with the shared scatter, title and axis labels.
    plt.figure(figsize=(8, 6))
    plt.scatter(pc1, pc2, c=colors, s=1, alpha=alpha)
    plt.title(plot_title)
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')

  start_scatter(stability_test_df.stability.values, alpha=0.25)
  plt.colorbar().set_label('Stability')
  plt.show()

  start_scatter(stability_test_df.parent_stability.values, alpha=0.5)
  plt.colorbar().set_label('Parent Stability')
  plt.show()

  start_scatter(stability_test_df.topology_ind.values, alpha=0.5)
  # Colorbar ticks are the categorical indices; label them with topology names.
  topology_names = ['AAA', 'ABBA', 'Other', 'BABB', 'BBABB']
  formatter = plt.FuncFormatter(lambda val, loc: topology_names[val])
  cbar = plt.colorbar(ticks=range(5), format=formatter)
  cbar.set_label('Topology')
  plt.show()

In [None]:
# Test-split PCA for the CNN + MaxPool model that was never passed through `train`.
stability_test_pca_plot(stability_test_embeddings_random, model_name='Random CNN + MaxPool')

In [None]:
# Test-split PCA for the trained CNN + MaxPool embeddings.
stability_test_pca_plot(stability_test_embeddings, model_name='CNN + MaxPool')

In [None]:
# Test-split PCA for the trained CNN + LinearMaxPool embeddings.
stability_test_pca_plot(stability_test_embeddings_l, model_name='CNN + LinearMaxPool')

In [None]:
# Test-split PCA for the concatenated (ensemble) embeddings.
stability_test_pca_plot(stability_test_embeddings_e, model_name='Ensemble of CNNs')

### Train t-SNE

In [111]:
def stability_train_tsne_plot(train_embeddings, model_name, num_samples=5000):
  """Runs t-SNE on a random subsample of train embeddings and plots it.

  Three scatter plots of the same 2-D t-SNE layout are shown, colored by
  stability, by parent stability and by topology category.

  Args:
    train_embeddings: one embedding per row of the stability train dataframe.
    model_name: appended in parentheses to every plot title.
    num_samples: maximum number of rows to embed.
  """

  stability_train_df = create_stability_df()

  # Subsample by permuting row indices under a fixed seed. Packing embeddings
  # and scalar labels into a single array would be ragged, which recent NumPy
  # versions reject.
  np.random.seed(0)
  sample_inds = np.random.permutation(len(train_embeddings))[:num_samples]

  sub_stability_train_embeddings = np.asarray(train_embeddings)[sample_inds]
  sub_stability_train_stabilities = stability_train_df.stability.values[sample_inds]
  sub_stability_train_parent_stabilities = stability_train_df.parent_stability.values[sample_inds]
  sub_stability_train_topologies = stability_train_df.topology_ind.values[sample_inds]

  stability_train_embeddings_tsne = TSNE(n_components=2).fit_transform(sub_stability_train_embeddings)

  X_tsne_train = stability_train_embeddings_tsne[:, 0]
  Y_tsne_train = stability_train_embeddings_tsne[:, 1]
  plot_title = 't-SNE of Train Stability Embeddings (' + model_name + ')'

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=sub_stability_train_stabilities, s=5, alpha=0.5)
  plt.title(plot_title)
  cbar = plt.colorbar()
  cbar.set_label('Stability')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=sub_stability_train_parent_stabilities, s=5, alpha=0.5)
  plt.title(plot_title)
  cbar = plt.colorbar()
  cbar.set_label('Parent Stability')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=sub_stability_train_topologies, s=5, alpha=0.5)
  plt.title(plot_title)
  # Colorbar ticks are the categorical indices; label them with topology names.
  topology_names = ['AAA', 'ABBA', 'Other', 'BABB', 'BBABB']
  formatter = plt.FuncFormatter(lambda val, loc: topology_names[int(val)])
  cbar = plt.colorbar(ticks=range(5), format=formatter)
  cbar.set_label('Topology')
  plt.show()

In [None]:
# Train-split t-SNE for the CNN + MaxPool model that was never passed through `train`.
stability_train_tsne_plot(stability_train_embeddings_random, model_name='Random CNN + MaxPool', num_samples=5000)

In [None]:
# Train-split t-SNE for the trained CNN + MaxPool embeddings.
stability_train_tsne_plot(stability_train_embeddings, model_name='CNN + MaxPool', num_samples=5000)

In [None]:
# Train-split t-SNE for the trained CNN + LinearMaxPool embeddings.
stability_train_tsne_plot(stability_train_embeddings_l, model_name='CNN + LinearMaxPool', num_samples=5000)

In [None]:
# Train-split t-SNE for the concatenated (ensemble) embeddings.
stability_train_tsne_plot(stability_train_embeddings_e, model_name='Ensemble of CNNs', num_samples=5000)

### Test t-SNE

In [116]:
def stability_test_tsne_plot(test_embeddings, model_name, num_samples=5000):
  """Runs t-SNE on a random subsample of test embeddings and plots it.

  Three scatter plots of the same 2-D t-SNE layout are shown, colored by
  stability, by parent stability and by topology category.

  Args:
    test_embeddings: one embedding per row of the stability test dataframe.
    model_name: appended in parentheses to every plot title.
    num_samples: maximum number of rows to embed.
  """

  stability_test_df = create_stability_df(test=True)

  # Subsample by permuting row indices under a fixed seed. Packing embeddings
  # and scalar labels into a single array would be ragged, which recent NumPy
  # versions reject.
  np.random.seed(0)
  sample_inds = np.random.permutation(len(test_embeddings))[:num_samples]

  sub_stability_test_embeddings = np.asarray(test_embeddings)[sample_inds]
  sub_stability_test_stabilities = stability_test_df.stability.values[sample_inds]
  sub_stability_test_parent_stabilities = stability_test_df.parent_stability.values[sample_inds]
  sub_stability_test_topologies = stability_test_df.topology_ind.values[sample_inds]

  stability_test_embeddings_tsne = TSNE(n_components=2).fit_transform(sub_stability_test_embeddings)

  X_tsne_test = stability_test_embeddings_tsne[:, 0]
  Y_tsne_test = stability_test_embeddings_tsne[:, 1]
  plot_title = 't-SNE of Test Stability Embeddings (' + model_name + ')'

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test, Y_tsne_test, c=sub_stability_test_stabilities, s=5, alpha=0.5)
  plt.title(plot_title)
  cbar = plt.colorbar()
  cbar.set_label('Stability')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test, Y_tsne_test, c=sub_stability_test_parent_stabilities, s=5, alpha=0.5)
  plt.title(plot_title)
  cbar = plt.colorbar()
  cbar.set_label('Parent Stability')
  plt.show()

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test, Y_tsne_test, c=sub_stability_test_topologies, s=5, alpha=0.5)
  plt.title(plot_title)
  # Colorbar ticks are the categorical indices; label them with topology names.
  topology_names = ['AAA', 'ABBA', 'Other', 'BABB', 'BBABB']
  formatter = plt.FuncFormatter(lambda val, loc: topology_names[int(val)])
  cbar = plt.colorbar(ticks=range(5), format=formatter)
  cbar.set_label('Topology')
  plt.show()

In [None]:
# Test-split t-SNE for the CNN + MaxPool model that was never passed through `train`.
stability_test_tsne_plot(stability_test_embeddings_random, model_name='Random CNN + MaxPool', num_samples=5000)

In [None]:
# Test-split t-SNE for the trained CNN + MaxPool embeddings.
stability_test_tsne_plot(stability_test_embeddings, model_name='CNN + MaxPool', num_samples=5000)

In [None]:
# Test-split t-SNE for the trained CNN + LinearMaxPool embeddings.
stability_test_tsne_plot(stability_test_embeddings_l, model_name='CNN + LinearMaxPool', num_samples=5000)

In [None]:
# Test-split t-SNE for the concatenated (ensemble) embeddings.
stability_test_tsne_plot(stability_test_embeddings_e, model_name='Ensemble of CNNs', num_samples=5000)

# Variant prediction

The data originally comes from [Envision (Gray et al.)](https://pubmed.ncbi.nlm.nih.gov/29226803/) and was formatted by [FAIR ESM](https://github.com/facebookresearch/esm).

## Open data



### Install FAIR esm

In [None]:
# Install FAIR ESM from GitHub and download the example FASTA file that
# create_variant_df reads from the working directory.
# NOTE(review): the install is not pinned to a tag or commit, so reruns may
# pick up a different ESM version — consider pinning.
!pip install git+https://github.com/facebookresearch/esm.git
!curl -O https://dl.fbaipublicfiles.com/fair-esm/examples/P62593.fasta
# Show the working directory and its contents to confirm the download.
!pwd
!ls

In [122]:
import esm

### Padding and one-hot encoding

In [123]:
# Length given to VARIANT_PROTEIN_DOMAIN; variant sequences are truncated to
# this many residues before being encoded.
VARIANT_SEQ_LEN = 286

In [124]:
# The 20 standard amino-acid letters plus '-' as the last entry.
# NOTE(review): '-' is presumably the padding/gap symbol — confirm against the domain's vocabulary.
VARIANT_AMINO_ACID_VOCABULARY = [
    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',
    'S', 'T', 'V', 'W', 'Y', '-'
]

In [125]:
# Variable-length discrete domain of length VARIANT_SEQ_LEN over a protein
# vocabulary that excludes anomalous amino acids and includes EOS and PAD
# tokens; variant_seq_to_inds encodes sequences with it.
VARIANT_PROTEIN_DOMAIN = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=False,
                               include_eos=True,
                               include_pad=True),
    length=VARIANT_SEQ_LEN)

In [126]:
def variant_seq_to_inds(seq):
  """Encodes one variant amino acid sequence with VARIANT_PROTEIN_DOMAIN."""

  # The domain encodes a batch of sequences; wrap the single sequence and
  # unwrap the single result.
  encoded_batch = VARIANT_PROTEIN_DOMAIN.encode([seq])
  return encoded_batch[0]

### Add one-hots and batch data

In [127]:
def create_variant_df(test=False, train_size=0.8, fasta_path="./P62593.fasta"):
  """Processes variant data into a featurized dataframe.

  Every FASTA record contributes its sequence and an activity parsed from the
  last '|'-separated field of its header. Records are split into train and
  test with a fixed random_state of 42, so repeated calls with the same
  `train_size` produce the same split.

  Args:
    test: return the test part of the split when True, else the train part.
    train_size: forwarded to `train_test_split`.
    fasta_path: FASTA file to read.

  Returns:
    Dataframe with columns 'primary' (sequence), 'activity' (float) and
    'one_hot_inds' (the sequence truncated to VARIANT_SEQ_LEN and encoded).
  """

  sequences = []
  activities = []
  for header, sequence in esm.data.read_fasta(fasta_path):
    sequences.append(sequence)
    activities.append(float(header.split('|')[-1]))

  seqs_train, seqs_test, activities_train, activities_test = \
    train_test_split(sequences, activities, train_size=train_size, random_state=42)

  # Build the frame from typed columns so activities stay floats instead of
  # being coerced to strings alongside the sequences.
  if test:
    variant_df = pd.DataFrame({'primary': seqs_test, 'activity': activities_test})
  else:
    variant_df = pd.DataFrame({'primary': seqs_train, 'activity': activities_train})

  variant_df['one_hot_inds'] = variant_df.primary.apply(
      lambda seq: variant_seq_to_inds(seq[:VARIANT_SEQ_LEN]))

  return variant_df


def create_variant_batches(batch_size, epochs=1, test=False, buffer_size=None, 
                           seed=0, drop_remainder=False, train_size=0.8):
  """Builds a batch iterator over one split of the variant data.

  Args:
    batch_size, epochs, seed, drop_remainder: forwarded unchanged to
      `create_data_iterator`.
    test: use the test split; this also forces `buffer_size` to 1.
    buffer_size: forwarded to `create_data_iterator` for the train split.
    train_size: forwarded to `create_variant_df`.

  Returns:
    (variant_batches, activities): the iterator returned by
    `create_data_iterator` and the split's 'activity' column as an array.
  """

  variant_df = create_variant_df(test=test, train_size=train_size)
  activities = variant_df['activity'].values

  iterator_kwargs = dict(
      df=variant_df,
      input_col='one_hot_inds',
      output_col='activity',
      batch_size=batch_size,
      epochs=epochs,
      buffer_size=1 if test else buffer_size,
      seed=seed,
      drop_remainder=drop_remainder,
  )

  return create_data_iterator(**iterator_kwargs), activities

## Model evaluation

In [128]:
def variant_evaluate(predict_fn, title, batch_size=256, test_data=None, 
                     pred_activities=None, clip_min=-999999, clip_max=999999,
                     train_size=0.8):
  """Evaluates variant activity predictions on the held-out split.

  Predictions are taken from `pred_activities` when provided; otherwise
  `predict_fn` is applied to each batch of `test_data` (or, when that is None,
  of the test iterator built here) and the first entry of every output row is
  kept. True activities always come from the test split for `train_size`.
  Predictions are clipped to [clip_min, clip_max], scatter-plotted against the
  true activities and scored with Spearman correlation and MSE. Spearman
  correlation is also computed on 5 chunks of a seeded shuffle and summarised
  by its mean and standard deviation.

  Side effect: reseeds NumPy's global RNG with 0.

  Returns:
    Tuple (results, all_pred_activities): a dict with the title and the
    metrics rounded to 3 decimals, and the clipped predictions as an array.
  """

  n_chunks = 5

  default_batches, true_activities = create_variant_batches(
      batch_size=batch_size, test=True, buffer_size=1, train_size=train_size)
  batches = default_batches if test_data is None else test_data

  if pred_activities is None:
    # Keep the first output entry of every predicted row, batch by batch.
    pred_activities = [row[0] for X, _ in batches for row in predict_fn(X)]

  clipped_preds = np.clip(np.array(pred_activities), clip_min, clip_max)

  spearmanr = scipy.stats.spearmanr(true_activities, clipped_preds).correlation
  mse = sklearn.metrics.mean_squared_error(true_activities, clipped_preds)

  plt.scatter(true_activities, clipped_preds, s=1, alpha=0.5)
  plt.xlabel('True Activity')
  plt.ylabel('Predicted Activity')
  plt.title(title)
  plt.show()

  all_pred_activities = copy.deepcopy(clipped_preds)

  # Seeding identically before each permutation gives both (equal-length)
  # arrays the same shuffle, so true/predicted pairs stay aligned.
  np.random.seed(0)
  shuffled_true = np.random.permutation(true_activities)
  np.random.seed(0)
  shuffled_preds = np.random.permutation(clipped_preds)

  # array_split yields n_chunks nearly equal chunks, larger chunks first.
  true_chunks = np.array_split(shuffled_true, n_chunks)
  pred_chunks = np.array_split(shuffled_preds, n_chunks)
  chunk_spearmanrs = [
      scipy.stats.spearmanr(true_chunk, pred_chunk).correlation
      for true_chunk, pred_chunk in zip(true_chunks, pred_chunks)
  ]

  results = {
      'title': title,
      'spearmanr': round(spearmanr, 3),
      'mse': round(mse, 3),
      'mean_spearmanr': round(np.mean(chunk_spearmanrs), 3),
      'std_spearmanr': round(np.std(chunk_spearmanrs), 3)
  }

  pprint.pprint(results)

  return results, all_pred_activities

## Experiments

In [129]:
# Training split (default train_size) and its activity labels.
variant_train_df = create_variant_df()
train_activities = variant_train_df['activity']
# Stack the per-sequence index arrays into one array, then pass it through
# flattened_one_hot_encoder to obtain the linear model's input features.
variant_train_one_hot_inds = np.array([x for x in variant_train_df['one_hot_inds'].values])
variant_train_one_hots = flattened_one_hot_encoder(variant_train_one_hot_inds, num_categories=len(VARIANT_AMINO_ACID_VOCABULARY))

In [130]:
# Held-out split (default train_size), featurized the same way as the training split.
variant_test_df = create_variant_df(test=True)
variant_test_one_hot_inds = np.array([x for x in variant_test_df['one_hot_inds'].values])
variant_test_one_hots = flattened_one_hot_encoder(variant_test_one_hot_inds, num_categories=len(VARIANT_AMINO_ACID_VOCABULARY))

### Linear regression

In [131]:
# Ridge regression baseline, constructed with scikit-learn's default arguments.
variant_linear_model = Ridge()

In [None]:
# Fit the baseline on the training-split one-hot features and activities.
variant_linear_model.fit(X=variant_train_one_hots, y=train_activities)

In [133]:
# Raw (unclipped) baseline predictions for the held-out split.
linear_model_pred_activities = variant_linear_model.predict(variant_test_one_hots)

In [None]:
# Score the precomputed baseline predictions (no predict_fn needed), clipping
# them to the range of activities seen in training.
# Note: rebinds linear_model_pred_activities to the clipped predictions.
variant_linear_model_results, linear_model_pred_activities = \
variant_evaluate(predict_fn=None, 
                 title='Linear Regression',
                 pred_activities=linear_model_pred_activities,
                 clip_min=min(train_activities),
                 clip_max=max(train_activities))

In [None]:
# Counts the entries of coef_ only; any intercept term is not included.
print('Number of Parameters for Variant Linear Regression: ' + str(len(variant_linear_model.coef_)))

In [136]:
def fit_variant_linear_model(train_size=0.8):
  """Fits a Ridge baseline for the given train_size and evaluates it.

  Both splits are produced by create_variant_df with the same train_size and
  featurized with flattened_one_hot_encoder. Test predictions are clipped to
  the range of the training activities inside variant_evaluate.

  Returns:
    The (results, pred_activities) tuple returned by variant_evaluate.
  """

  num_categories = len(VARIANT_AMINO_ACID_VOCABULARY)

  def featurize(split_df):
    # Stack the per-sequence index arrays, then encode them for the regressor.
    one_hot_inds = np.array(list(split_df['one_hot_inds'].values))
    return flattened_one_hot_encoder(one_hot_inds, num_categories=num_categories)

  train_df = create_variant_df(train_size=train_size)
  train_activities = train_df['activity']
  train_features = featurize(train_df)

  test_df = create_variant_df(test=True, train_size=train_size)
  test_features = featurize(test_df)

  ridge = Ridge()
  ridge.fit(X=train_features, y=train_activities)
  raw_pred_activities = ridge.predict(test_features)

  results, clipped_pred_activities = variant_evaluate(
      predict_fn=None,
      title='Linear Regression (train_size=' + str(train_size) + ')',
      pred_activities=raw_pred_activities,
      clip_min=min(train_activities),
      clip_max=max(train_activities),
      train_size=train_size)

  return results, clipped_pred_activities

In [None]:
# Refit and evaluate the Ridge baseline at increasing training-set fractions.
# Note: rebinds the notebook-level results/predictions on every iteration, so
# after the loop they hold the train_size=0.8 run.
for train_size in [0.01, 0.1, 0.3, 0.5, 0.8]:
  variant_linear_model_results, linear_model_pred_activities = \
  fit_variant_linear_model(train_size)
  print()

### CNN + MaxPool

In [None]:
# Training iterator (batch size 32, 500 epochs, default train_size) together
# with the training-split activities used below for clipping.
epochs = 500
variant_train_batches, train_activities = create_variant_batches(batch_size=32, epochs=epochs)

# Optimizer settings passed to train(); presumably one learning rate and one
# weight decay per entry of `layers` -- confirm against train().
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 1e-3]
weight_decay = [0.1, 0.05]

# Two-layer CNN encoder with no dilations specified, reduced with max_pool.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': None
}
reduce_fn = max_pool
reduce_fn_kwargs = {

}
loss_fn_kwargs = {
    
}

# Model with a single output feature (the predicted activity).
variant_model = create_representation_model(encoder_fn=encoder_fn,
                                            encoder_fn_kwargs=encoder_fn_kwargs,
                                            reduce_fn=reduce_fn,
                                            reduce_fn_kwargs=reduce_fn_kwargs,
                                            num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                            output_features=1)

# Train with MSE loss; train() returns the optimizer whose .target is used for prediction.
variant_optimizer = train(model=variant_model,
                          train_data=variant_train_batches, 
                          loss_fn=mse_loss,
                          loss_fn_kwargs=loss_fn_kwargs,
                          learning_rate=learning_rate, 
                          weight_decay=weight_decay,
                          layers=layers)

# Evaluate on the held-out split, clipping predictions to the training activity range.
variant_results, pred_activities = variant_evaluate(predict_fn=variant_optimizer.target,
                                                    title='CNN + MaxPool',
                                                    batch_size=256,
                                                    clip_min=min(train_activities),
                                                    clip_max=max(train_activities))

In [None]:
# Parameter count of the trained CNN + MaxPool model, as reported by get_num_params.
print('Number of Parameters for Variant CNN + MaxPool: ' + str(get_num_params(variant_optimizer.target)))

In [None]:
# Same setup as the CNN + MaxPool cell, except the encoder uses dilations [2, 1].
# Note: rebinds the notebook-level variant_train_batches / train_activities.
epochs = 500
variant_train_batches, train_activities = create_variant_batches(batch_size=32, epochs=epochs)

# Optimizer settings passed to train(); presumably one learning rate and one
# weight decay per entry of `layers` -- confirm against train().
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 1e-3]
weight_decay = [0.1, 0.05]

# Two-layer CNN encoder with kernel dilations [2, 1], reduced with max_pool.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': [2, 1]
}
reduce_fn = max_pool
reduce_fn_kwargs = {

}
loss_fn_kwargs = {
    
}

# Model with a single output feature (the predicted activity).
variant_model_d = create_representation_model(encoder_fn=encoder_fn,
                                              encoder_fn_kwargs=encoder_fn_kwargs,
                                              reduce_fn=reduce_fn,
                                              reduce_fn_kwargs=reduce_fn_kwargs,
                                              num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                              output_features=1)

# Train with MSE loss; train() returns the optimizer whose .target is used for prediction.
variant_optimizer_d = train(model=variant_model_d,
                            train_data=variant_train_batches, 
                            loss_fn=mse_loss,
                            loss_fn_kwargs=loss_fn_kwargs,
                            learning_rate=learning_rate, 
                            weight_decay=weight_decay,
                            layers=layers)

# Evaluate on the held-out split, clipping predictions to the training activity range.
variant_results_d, pred_activities_d = variant_evaluate(predict_fn=variant_optimizer_d.target,
                                                        title='Dilated CNN + MaxPool',
                                                        batch_size=256,
                                                        clip_min=min(train_activities),
                                                        clip_max=max(train_activities))

In [None]:
# Parameter count of the trained dilated CNN + MaxPool model, as reported by get_num_params.
print('Number of Parameters for Variant Dilated CNN + MaxPool: ' + str(get_num_params(variant_optimizer_d.target)))

### Ensemble of CNNs

In [None]:
# Ensemble prediction: element-wise mean of the two CNNs' test predictions.
# Relies on pred_activities, pred_activities_d and train_activities as bound by
# the two training cells above.
ensemble_pred_activities = [(pred_activities[i]+pred_activities_d[i])/2 for i in range(len(pred_activities))]
ensemble_results, ensemble_pred_activities = variant_evaluate(predict_fn=None, 
                                                              title='Ensemble of CNNs',
                                                              pred_activities=ensemble_pred_activities,
                                                              clip_min=min(train_activities),
                                                              clip_max=max(train_activities))

In [None]:
# The ensemble's parameter count is the sum of the two member models' counts.
print('Number of Parameters for Variant Ensemble of CNNs: ' + str(get_num_params(variant_optimizer.target) + get_num_params(variant_optimizer_d.target)))

In [144]:
def _fit_and_evaluate_variant_cnn(n_kernel_dilations, title, train_size):
  """Trains one two-layer CNN + MaxPool model and evaluates it on the test split.

  Args:
    n_kernel_dilations: value for the encoder's 'n_kernel_dilations' setting.
    title: plot/result title forwarded to variant_evaluate.
    train_size: fraction used for both the training iterator and evaluation.

  Returns:
    Tuple (results, pred_activities, train_activities): the outputs of
    variant_evaluate plus the training activities used for clipping.
  """

  epochs = 500
  variant_train_batches, train_activities = create_variant_batches(
      batch_size=32, epochs=epochs, train_size=train_size)

  layers = ['CNN_0', 'Dense_1']
  learning_rate = [1e-3, 1e-3]
  weight_decay = [0.1, 0.05]

  encoder_fn_kwargs = {
      'n_layers': 2,
      'n_features': [1024, 512],
      'n_kernel_sizes': [9, 7],
      'n_kernel_dilations': n_kernel_dilations
  }

  model = create_representation_model(encoder_fn=cnn_one_hot_encoder,
                                      encoder_fn_kwargs=encoder_fn_kwargs,
                                      reduce_fn=max_pool,
                                      reduce_fn_kwargs={},
                                      num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                      output_features=1)

  optimizer = train(model=model,
                    train_data=variant_train_batches,
                    loss_fn=mse_loss,
                    loss_fn_kwargs={},
                    learning_rate=learning_rate,
                    weight_decay=weight_decay,
                    layers=layers)

  # Clip predictions to the range of activities seen in training.
  results, pred_activities = variant_evaluate(predict_fn=optimizer.target,
                                              title=title,
                                              batch_size=256,
                                              clip_min=min(train_activities),
                                              clip_max=max(train_activities),
                                              train_size=train_size)

  return results, pred_activities, train_activities


def fit_variant_cnn_models(train_size=0.8):
  """Fits and evaluates CNN models for given train_size.

  Trains a CNN + MaxPool model and a dilated CNN + MaxPool model with
  otherwise identical settings, evaluates each on the held-out split, then
  evaluates the element-wise mean of their predictions as an ensemble.

  Returns:
    Tuple (variant_results, pred_activities, variant_results_d,
    pred_activities_d, ensemble_results, ensemble_pred_activities).
  """

  suffix = ' (train_size=' + str(train_size) + ')'

  # CNN + MaxPool
  variant_results, pred_activities, _ = _fit_and_evaluate_variant_cnn(
      n_kernel_dilations=None,
      title='CNN + MaxPool' + suffix,
      train_size=train_size)

  # Dilated CNN + MaxPool
  variant_results_d, pred_activities_d, train_activities = _fit_and_evaluate_variant_cnn(
      n_kernel_dilations=[2, 1],
      title='Dilated CNN + MaxPool' + suffix,
      train_size=train_size)

  # Ensemble of CNNs: average the two models' predictions per test example.
  ensemble_pred_activities = (np.asarray(pred_activities) + np.asarray(pred_activities_d)) / 2
  ensemble_results, ensemble_pred_activities = variant_evaluate(
      predict_fn=None,
      title='Ensemble of CNNs' + suffix,
      pred_activities=ensemble_pred_activities,
      clip_min=min(train_activities),
      clip_max=max(train_activities),
      train_size=train_size)

  return variant_results, pred_activities, variant_results_d, pred_activities_d, ensemble_results, ensemble_pred_activities

In [None]:
# Train and evaluate both CNNs and their ensemble at increasing training-set
# fractions; the returned results are discarded (metrics are printed and
# plotted inside variant_evaluate).
for train_size in [0.01, 0.1, 0.3, 0.5, 0.8]:
  fit_variant_cnn_models(train_size)
  print()

## Visualization

In [146]:
def variant_embeddings(encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, optimizer, train_size=0.8,
                       learning_rate=None, weight_decay=None, layers=None):
  """Computes variant embeddings from given optimizer.

  Builds a model with output='embedding', wraps it in a fresh optimizer, and
  overwrites every parameter layer of that optimizer's target with a copy of
  the same-named layer from `optimizer.target.params`. The resulting model is
  then run over the train and test splits.

  Args:
    encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs: forwarded to
      create_representation_model; they should describe the architecture that
      `optimizer` holds parameters for.
    optimizer: optimizer whose `target.params` supplies the weights.
    train_size: fraction forwarded to create_variant_batches for both splits.
    learning_rate, weight_decay, layers: settings forwarded to
      create_optimizer. When None they default to [1e-3, 1e-3], [0.1, 0.05]
      and ['CNN_0', 'Dense_1'] respectively.

  Returns:
    Tuple (variant_train_embeddings, variant_test_embeddings).
  """

  # Optimizer settings are explicit arguments rather than notebook globals, so
  # the result does not depend on which cell last assigned those names.
  if layers is None:
    layers = ['CNN_0', 'Dense_1']
  if learning_rate is None:
    learning_rate = [1e-3, 1e-3]
  if weight_decay is None:
    weight_decay = [0.1, 0.05]

  variant_encoding_model = create_representation_model(encoder_fn=encoder_fn,
                                                       encoder_fn_kwargs=encoder_fn_kwargs,
                                                       reduce_fn=reduce_fn,
                                                       reduce_fn_kwargs=reduce_fn_kwargs,
                                                       num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                                       output='embedding',
                                                       output_features=1)

  # Copy so that writing into the new optimizer cannot alias the source params.
  trained_params = copy.deepcopy(optimizer.target.params)

  variant_encoding_optimizer = create_optimizer(variant_encoding_model,
                                                learning_rate=learning_rate,
                                                weight_decay=weight_decay,
                                                layers=layers)

  # Only layers present in the embedding model are copied over.
  for layer in variant_encoding_optimizer.target.params.keys():
    variant_encoding_optimizer.target.params[layer] = trained_params[layer]

  # buffer_size=1 for both splits (test=True forces it) -- presumably this
  # keeps rows in dataframe order; confirm against create_data_iterator.
  variant_train_batches, _ = create_variant_batches(batch_size=256, buffer_size=1, train_size=train_size)
  variant_train_embeddings = compute_embeddings(variant_encoding_optimizer.target, variant_train_batches)

  variant_test_batches, _ = create_variant_batches(batch_size=256, test=True, train_size=train_size)
  variant_test_embeddings = compute_embeddings(variant_encoding_optimizer.target, variant_test_batches)

  return variant_train_embeddings, variant_test_embeddings

In [147]:
# CNN + MaxPool Optimizer Embeddings
# Embeds both splits with the parameters of the trained `variant_optimizer`.
# NOTE(review): the training cell built this model with n_kernel_dilations=None;
# [1, 1] is presumably equivalent -- confirm against cnn_one_hot_encoder.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': [1, 1]
}
reduce_fn = max_pool
reduce_fn_kwargs = {

}
variant_train_embeddings, variant_test_embeddings = variant_embeddings(encoder_fn,
                                                                       encoder_fn_kwargs,
                                                                       reduce_fn,
                                                                       reduce_fn_kwargs,
                                                                       variant_optimizer)

In [148]:
# Dilated CNN + MaxPool Optimizer Embeddings
# Embeds both splits with the parameters of the trained `variant_optimizer_d`,
# using the same encoder settings as the dilated training cell.
encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': [2, 1]
}
reduce_fn = max_pool
reduce_fn_kwargs = {

}
variant_train_embeddings_d, variant_test_embeddings_d = variant_embeddings(encoder_fn,
                                                                           encoder_fn_kwargs,
                                                                           reduce_fn,
                                                                           reduce_fn_kwargs,
                                                                           variant_optimizer_d)

In [149]:
# Ensemble Embeddings (Concatenation of Above Embeddings)
# Joined along axis 1, i.e. each example's two embeddings are placed side by side.
variant_train_embeddings_e = np.concatenate((variant_train_embeddings, variant_train_embeddings_d), axis=1)
variant_test_embeddings_e = np.concatenate((variant_test_embeddings, variant_test_embeddings_d), axis=1)

In [150]:
# Random CNN + MaxPool Optimizer Embeddings
# Baseline: the optimizer below is created but never passed to train(), so the
# embeddings come from untrained parameters.
# Note: rebinds the notebook-level `variant_model` name.
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 1e-3]
weight_decay = [0.1, 0.05]

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': [1, 1]
}
reduce_fn = max_pool
reduce_fn_kwargs = {
    
}
loss_fn_kwargs = {
    
}

variant_model = create_representation_model(encoder_fn=encoder_fn,
                                            encoder_fn_kwargs=encoder_fn_kwargs,
                                            reduce_fn=reduce_fn,
                                            reduce_fn_kwargs=reduce_fn_kwargs,
                                            num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                            output_features=1)

# Optimizer wrapping the freshly created (untrained) model.
variant_optimizer_random = create_optimizer(variant_model, learning_rate, weight_decay, layers)

variant_train_embeddings_random, variant_test_embeddings_random = variant_embeddings(encoder_fn,
                                                                                     encoder_fn_kwargs,
                                                                                     reduce_fn,
                                                                                     reduce_fn_kwargs,
                                                                                     variant_optimizer_random)

In [151]:
# Random Dilated CNN + MaxPool Optimizer Embeddings
# Baseline: the optimizer below is created but never passed to train(), so the
# embeddings come from untrained parameters.
# Note: rebinds the notebook-level `variant_model_d` name.
layers = ['CNN_0', 'Dense_1']                                 
learning_rate = [1e-3, 1e-3]
weight_decay = [0.1, 0.05]

encoder_fn = cnn_one_hot_encoder
encoder_fn_kwargs = {
    'n_layers': 2,
    'n_features': [1024, 512],
    'n_kernel_sizes': [9, 7],
    'n_kernel_dilations': [2, 1]
}
reduce_fn = max_pool
reduce_fn_kwargs = {
    
}
loss_fn_kwargs = {
    
}

variant_model_d = create_representation_model(encoder_fn=encoder_fn,
                                              encoder_fn_kwargs=encoder_fn_kwargs,
                                              reduce_fn=reduce_fn,
                                              reduce_fn_kwargs=reduce_fn_kwargs,
                                              num_categories=len(VARIANT_AMINO_ACID_VOCABULARY),
                                              output_features=1)

# Optimizer wrapping the freshly created (untrained) dilated model.
variant_optimizer_d_random = create_optimizer(variant_model_d, learning_rate, weight_decay, layers)

variant_train_embeddings_d_random, variant_test_embeddings_d_random = variant_embeddings(encoder_fn,
                                                                                         encoder_fn_kwargs,
                                                                                         reduce_fn,
                                                                                         reduce_fn_kwargs,
                                                                                         variant_optimizer_d_random)

In [152]:
# Random Ensemble Embeddings (Concatenation of Above Embeddings)
# Joined along axis 1, i.e. each example's two embeddings are placed side by side.
variant_train_embeddings_e_random = np.concatenate((variant_train_embeddings_random, variant_train_embeddings_d_random), axis=1)
variant_test_embeddings_e_random = np.concatenate((variant_test_embeddings_random, variant_test_embeddings_d_random), axis=1)

### Train PCA

In [153]:
def variant_train_pca_plot(train_embeddings, model_name, train_size=0.8):
  """Plots a 2-component PCA of variant train embeddings, colored by activity.

  Args:
    train_embeddings: embeddings to project; rows are assumed to line up with
      the rows of create_variant_df(train_size=train_size).
    model_name: label inserted into the plot title.
    train_size: fraction used to rebuild the training dataframe for coloring.
  """

  activities = create_variant_df(train_size=train_size).activity.values

  # PCA is fitted on the passed embeddings themselves.
  projected = PCA(n_components=2).fit_transform(train_embeddings)
  pc1, pc2 = projected[:, 0], projected[:, 1]

  plt.figure(figsize=(8, 6))
  plt.scatter(pc1, pc2, c=activities, s=10, alpha=0.5)
  plt.title('PCA of Train Variant Embeddings (' + model_name + ')')
  plt.xlabel('Principal Component 1')
  plt.ylabel('Principal Component 2')
  cbar = plt.colorbar()
  cbar.set_label('Activity')
  plt.show()

In [None]:
# Train-split PCA for the concatenated embeddings of the two trained CNNs.
variant_train_pca_plot(variant_train_embeddings_e, model_name='Ensemble of CNNs')

In [None]:
# Same plot for the untrained (random) ensemble, as a baseline for comparison.
variant_train_pca_plot(variant_train_embeddings_e_random, model_name='Random Ensemble of CNNs')

### Test PCA

In [156]:
def variant_test_pca_plot(test_embeddings, model_name, train_size=0.8):
  """Plots a 2-component PCA of variant test embeddings, colored by activity.

  Args:
    test_embeddings: embeddings to project; rows are assumed to line up with
      the rows of create_variant_df(test=True, train_size=train_size).
    model_name: label inserted into the plot title.
    train_size: fraction used to rebuild the test dataframe for coloring.
  """

  activities = create_variant_df(test=True, train_size=train_size).activity.values

  # PCA is fitted on the passed test embeddings themselves.
  projected = PCA(n_components=2).fit_transform(test_embeddings)
  pc1, pc2 = projected[:, 0], projected[:, 1]

  plt.figure(figsize=(8, 6))
  plt.scatter(pc1, pc2, c=activities, s=10, alpha=0.5)
  plt.title('PCA of Test Variant Embeddings (' + model_name + ')')
  plt.xlabel('Principal Component 1')
  plt.ylabel('Principal Component 2')
  cbar = plt.colorbar()
  cbar.set_label('Activity')
  plt.show()

In [None]:
# Test-split PCA for the concatenated embeddings of the two trained CNNs.
variant_test_pca_plot(variant_test_embeddings_e, model_name='Ensemble of CNNs')

In [None]:
# Same plot for the untrained (random) ensemble, as a baseline for comparison.
variant_test_pca_plot(variant_test_embeddings_e_random, model_name='Random Ensemble of CNNs')

### Train t-SNE

In [159]:
def variant_train_tsne_plot(train_embeddings, model_name, train_size=0.8):
  """Applies and plots t-SNE on variant train embeddings.

  Points are colored by the activities of create_variant_df(train_size=...),
  so the rows of `train_embeddings` are assumed to be in that dataframe's
  order. TSNE is constructed without a random_state, so the layout may differ
  between runs.

  Args:
    train_embeddings: embeddings to project to 2 dimensions.
    model_name: label inserted into the plot title.
    train_size: fraction used to rebuild the training dataframe for coloring.
  """

  variant_train_df = create_variant_df(train_size=train_size)

  variant_train_embeddings_tsne = TSNE(n_components=2).fit_transform(train_embeddings)

  X_tsne_train = [v[0] for v in variant_train_embeddings_tsne]
  Y_tsne_train = [v[1] for v in variant_train_embeddings_tsne]

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_train, Y_tsne_train, c=variant_train_df.activity.values, s=10, alpha=0.5)
  # Title names the variant dataset (not GFP), matching the PCA plots.
  plt.title('t-SNE of Train Variant Embeddings (' + model_name + ')')
  plt.xlabel('t-SNE Component 1')
  plt.ylabel('t-SNE Component 2')
  cbar = plt.colorbar()
  cbar.set_label('Activity')
  plt.show()

In [None]:
# Train-split t-SNE for the concatenated embeddings of the two trained CNNs.
variant_train_tsne_plot(variant_train_embeddings_e, model_name='Ensemble of CNNs')

In [None]:
# Same plot for the untrained (random) ensemble, as a baseline for comparison.
variant_train_tsne_plot(variant_train_embeddings_e_random, model_name='Random Ensemble of CNNs')

### Test t-SNE

In [162]:
def variant_test_tsne_plot(test_embeddings, model_name, train_size=0.8):
  """Applies and plots t-SNE on variant test embeddings.

  Points are colored by the activities of
  create_variant_df(test=True, train_size=...), so the rows of
  `test_embeddings` are assumed to be in that dataframe's order. TSNE is
  constructed without a random_state, so the layout may differ between runs.

  Args:
    test_embeddings: embeddings to project to 2 dimensions.
    model_name: label inserted into the plot title.
    train_size: fraction used to rebuild the test dataframe for coloring.
  """

  variant_test_df = create_variant_df(test=True, train_size=train_size)

  variant_test_embeddings_tsne = TSNE(n_components=2).fit_transform(test_embeddings)

  X_tsne_test = [v[0] for v in variant_test_embeddings_tsne]
  Y_tsne_test = [v[1] for v in variant_test_embeddings_tsne]

  plt.figure(figsize=(8, 6))
  plt.scatter(X_tsne_test, Y_tsne_test, c=variant_test_df.activity.values, s=10, alpha=0.5)
  # Title names the variant dataset (not GFP), matching the PCA plots.
  plt.title('t-SNE of Test Variant Embeddings (' + model_name + ')')
  plt.xlabel('t-SNE Component 1')
  plt.ylabel('t-SNE Component 2')
  cbar = plt.colorbar()
  cbar.set_label('Activity')
  plt.show()

In [None]:
# Test-split t-SNE for the concatenated embeddings of the two trained CNNs.
variant_test_tsne_plot(variant_test_embeddings_e, model_name='Ensemble of CNNs')

In [None]:
# Same plot for the untrained (random) ensemble, as a baseline for comparison.
variant_test_tsne_plot(variant_test_embeddings_e_random, model_name='Random Ensemble of CNNs')