This is a simple example of a variational autoencoder (VAE) for B-cell receptor (BCR) sequences, built with a bi-directional LSTM.
LSTMs are extremely slow to train.
The performance of this LSTM-based implementation (measured by out-of-sample cross-entropy) is inferior to that of a standard dense-layer network, so this notebook is intended only as an example of how to build such an architecture.

In [None]:
# Generic imports
from __future__ import print_function
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
#import pandas as pd
import math, random, re
import time
import pickle
from Bio import SeqIO

In [2]:
#Machine learning/Stats imports 
from scipy.stats import norm
from scipy.stats import spearmanr,pearsonr
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
import tensorflow as tf
import keras
from keras.layers import Input, Dense, Bidirectional, RepeatVector, Reshape
from keras.models import Model
from keras import regularizers
from keras.layers import LSTM, RepeatVector
from keras.layers import Input, Dense, Lambda, Dropout,Activation, TimeDistributed
from keras import backend as K
from keras import objectives
from keras.callbacks import EarlyStopping
from keras.layers.normalization import BatchNormalization
from keras import regularizers

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [3]:
# Amino acid alphabet: 20 residue letters plus '-'.
# The position of a character in AA_ORDER is its onehot index.
AA_ORDER = 'ACDEFGHIKLMNPQRSTVWY-'
AA_LIST = list(AA_ORDER)
# character -> index
AA_DICT = dict(zip(AA_LIST, range(len(AA_LIST))))
# index -> character
AA_DICT_REV = dict(enumerate(AA_LIST))
# set of all valid characters, used for membership tests
AA_SET = set(AA_ORDER)

In [4]:
def filter_seq(seq):
    '''
    Return `seq` unchanged if all of its characters are in AA_SET,
    otherwise return None (the sequence contains an ambiguous character).
    '''
    only_known_chars = AA_SET.issuperset(seq)
    return(seq if only_known_chars else None)

In [5]:
def seq2onehot(seq_list):
    '''
    Translate a list of amino acid sequences into a 3D tensor with onehot encodings.

    Parameters:
        seq_list -- list of strings made of characters found in AA_DICT.
                    NB. all sequences must be of equal length.

    Returns:
        numpy array of shape (len(seq_list), sequence length, len(AA_SET))
        holding a 1 at [i][j][AA_DICT[c]] for character c at position j of
        sequence i and 0 elsewhere. An empty `seq_list` gives an empty
        tensor of shape (0, 0, len(AA_SET)).

    Raises:
        AssertionError -- if the sequences differ in length.
        KeyError -- if a sequence contains a character missing from AA_DICT.
    '''
    # Without this guard seq_list[0] below raises IndexError on empty input.
    if len(seq_list) == 0:
        return(np.zeros((0, 0, len(AA_SET))))
    seqlen = len(seq_list[0])
    # all() stops at the first sequence of the wrong length.
    assert all(len(s) == seqlen for s in seq_list), 'All sequences must be of equal length.'
    onehot_tensor = np.zeros((len(seq_list), seqlen, len(AA_SET)))
    for i, seq in enumerate(seq_list):
        for j, a in enumerate(seq):
            onehot_tensor[i, j, AA_DICT[a]] = 1
    return(onehot_tensor)

In [6]:
def onehot2seq(onehot_tensor):
    '''
    Decode a 3D onehot tensor back into a list of amino acid sequences.

    For every position, the index of the largest value is looked up in
    AA_DICT_REV; one string is produced per entry along the first axis.
    '''
    decoded = []
    for encoded_seq in onehot_tensor:
        residues = [AA_DICT_REV[position.argmax()] for position in encoded_seq]
        decoded.append(''.join(residues))
    return(decoded)

In [9]:
# Read in at most MAX_SEQS sequences from the FASTA file:
MAX_SEQS = 10000
fnam = 'BCR_data/spurf_heavy_chain_AHo.fasta'
records = SeqIO.parse(fnam, 'fasta')
seq_list = []
# Pairing the parser with range(MAX_SEQS) ends the loop after MAX_SEQS records.
for _, record in zip(range(MAX_SEQS), records):
    seq_list.append(str(record.seq))
print('Input data has {} sequences.'.format(len(seq_list)))

Input data has 10000 sequences.


In [10]:
# Transform to onehot:
# seq2onehot requires every sequence in seq_list to have the same length and
# returns an array indexed as [sequence][position][alphabet index].
onehot_tensor = seq2onehot(seq_list)
print('Onehot encoded tensor has this shape: {}'.format(onehot_tensor.shape))

Onehot encoded tensor has this shape: (10000, 149, 21)


In [11]:
# Various network definitions:
batch_size = 100
# Shape of a single encoded sequence: onehot_tensor without its first (sample) axis.
input_shape = onehot_tensor.shape[1:]
# Number of values in one encoded sequence (product of the entries of input_shape).
input_total_dim = np.prod(input_shape)

# Size of the latent vector z.
latent_dim = 10
# Units per LSTM direction. The original notebook also kept 30 as a
# commented-out alternative value.
lstm_nodes = 149

# Standard deviation of the Gaussian noise used when sampling z.
epsilon_std = 1.0
def sampling(args):
    '''
    Draw a sample z from the Gaussian defined by the latent variables.

    Parameters:
        args -- pair (z_mean, z_log_var) of tensors.

    Returns:
        z_mean + exp(z_log_var / 2) * epsilon, where epsilon is normal noise
        with mean 0 and standard deviation epsilon_std.

    The number of noise rows is taken from z_mean at run time rather than from
    the global batch_size, so the layer also works when a batch has a
    different size (e.g. a partial last batch or a small predict() call).
    '''
    z_mean, z_log_var = args
    n_samples = K.shape(z_mean)[0]
    epsilon = K.random_normal(shape=(n_samples, latent_dim), mean=0.0, stddev=epsilon_std)
    return(z_mean + K.exp(z_log_var / 2) * epsilon)

def vae_loss(io_encoder, io_decoder):
    '''
    VAE loss: scaled reconstruction cross-entropy plus KL divergence.

    Parameters:
        io_encoder -- target tensor (the onehot encoded input).
        io_decoder -- reconstructed tensor produced by the decoder.

    z_mean and z_log_var are not arguments; they are read from the enclosing
    (notebook) scope.
    '''
    # The cross-entropy comes back with more than one element, so K.mean
    # (no axis given) averages over all of them before scaling by input_total_dim.
    xent_per_element = objectives.categorical_crossentropy(io_encoder, io_decoder)
    xent_loss = input_total_dim * K.mean(xent_per_element)
    # KL term, summed over the last (latent) axis.
    kl_terms = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = - 0.5 * K.sum(kl_terms, axis=-1)
    return(xent_loss + kl_loss)

In [12]:
# Encoding layers:
# Two stacked bidirectional LSTMs. The first returns the full sequence so the
# second can consume it; the second returns only its final output. With
# merge_mode='concat' the forward and backward outputs are concatenated
# (2 * lstm_nodes features).
io_encoder = Input(shape=input_shape)
lstm_encoder1 = Bidirectional(LSTM(lstm_nodes, return_sequences=True, recurrent_dropout=0.2), merge_mode='concat')(io_encoder)
lstm_encoder2 = Bidirectional(LSTM(lstm_nodes, return_sequences=False, recurrent_dropout=0.2), merge_mode='concat')(lstm_encoder1)


# Latent layers:
# Two separate Dense heads on the same LSTM output give the mean and the log
# variance of the latent distribution; z is drawn from them by sampling().
z_mean = Dense(latent_dim)(lstm_encoder2)
z_log_var = Dense(latent_dim)(lstm_encoder2)
z = Lambda(sampling, output_shape=(latent_dim, ))([z_mean, z_log_var])

# The stand-alone encoder maps an input to z_mean only (no sampling step).
encoder = Model(io_encoder, z_mean)
encoder.summary()
#SVG(model_to_dot(encoder, show_shapes=True).create(prog='dot', format='svg'))

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 149, 21)           0         
_________________________________________________________________
bidirectional_1 (Bidirection (None, 149, 298)          203832    
_________________________________________________________________
bidirectional_2 (Bidirection (None, 298)               534016    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2990      
Total params: 740,838
Trainable params: 740,838
Non-trainable params: 0
_________________________________________________________________


In [14]:
# Decoding layers:
# The layers are created once, unapplied, so the same layer objects (and hence
# the same weights) can be applied to both z and the stand-alone input io_z.
# RepeatVector copies the latent vector once per sequence position.
rep_decoder = RepeatVector(input_shape[0])
lstm_decoder1 = Bidirectional(LSTM(lstm_nodes, return_sequences=True, recurrent_dropout=0.2), merge_mode='concat')
lstm_decoder2 = Bidirectional(LSTM(lstm_nodes, return_sequences=False, recurrent_dropout=0.2), merge_mode='concat')
# A single Dense layer emits all input_total_dim values at once; Reshape then
# restores the (positions, alphabet) layout of the input.
# NOTE(review): the output activation is 'sigmoid' while vae_loss uses
# categorical cross-entropy -- confirm a per-position softmax is not intended.
decoder_out = Dense(input_total_dim, activation='sigmoid')
reshape2input = Reshape(input_shape)#, input_shape=(None, decoder_architecture[1]))
# Decoding path of the VAE: fed by the sampled latent variable z.
io_decoder = reshape2input(decoder_out(lstm_decoder2(lstm_decoder1(rep_decoder(z)))))


# Stand-alone decoder: same layers applied to a fresh latent input.
io_z = Input(shape=(latent_dim,))
io_decoder_means = reshape2input(decoder_out(lstm_decoder2(lstm_decoder1(rep_decoder(io_z)))))
decoder = Model(io_z, io_decoder_means)
decoder.summary()
#SVG(model_to_dot(decoder, show_shapes=True).create(prog='dot', format='svg'))

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 10)                0         
_________________________________________________________________
repeat_vector_2 (RepeatVecto (None, 149, 10)           0         
_________________________________________________________________
bidirectional_5 (Bidirection (None, 149, 298)          190720    
_________________________________________________________________
bidirectional_6 (Bidirection (None, 298)               534016    
_________________________________________________________________
dense_4 (Dense)              (None, 3129)              935571    
_________________________________________________________________
reshape_1 (Reshape)          (None, 149, 21)           0         
Total params: 1,660,307
Trainable params: 1,660,307
Non-trainable params: 0
_________________________________________________________________


In [15]:
# End-to-end VAE: encoder input -> sampled z -> decoder output, trained with
# the custom vae_loss using the Adam optimizer.
vae = Model(io_encoder, io_decoder)
vae.compile(optimizer="adam", loss=vae_loss)
vae.summary()
#SVG(model_to_dot(vae, show_shapes=True).create(prog='dot', format='svg'))

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 149, 21)       0                                            
____________________________________________________________________________________________________
bidirectional_1 (Bidirectional)  (None, 149, 298)      203832      input_1[0][0]                    
____________________________________________________________________________________________________
bidirectional_2 (Bidirectional)  (None, 298)           534016      bidirectional_1[0][0]            
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 10)            2990        bidirectional_2[0][0]            
___________________________________________________________________________________________

In [18]:
# Split dataset into train/test:
# A fixed random_state makes the shuffled split identical on every run, so
# results obtained from it can be reproduced.
x_train, x_test = train_test_split(onehot_tensor, test_size=0.1, shuffle=True, random_state=42)
# Trim the training set to a whole multiple of 10 * batch_size, presumably so
# that a later 90/10 validation split still leaves whole batches on both sides.
sl = len(x_train) // (batch_size*10)
x_train = x_train[:(sl*batch_size*10)]

In [19]:
# Number of training sequences left after trimming x_train.
len(x_train)

9000

In [20]:
# Training run. The VAE is unsupervised, so x_train is both input and target.
nb_epoch = 2
# NOTE(review): patience=5 is larger than nb_epoch=2, so early stopping
# presumably cannot trigger unless nb_epoch is raised -- confirm intent.
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
fit_options = dict(
    shuffle=True,
    epochs=nb_epoch,
    batch_size=batch_size,
    validation_split=0.1,
    callbacks=[early_stopping],
)
vae_log = vae.fit(x_train, x_train, **fit_options)

Train on 8100 samples, validate on 900 samples
Epoch 1/2
Epoch 2/2
 300/8100 [>.............................] - ETA: 474s - loss: 2719.9393

KeyboardInterrupt: 