In [2]:
import sys
sys.path.insert(0, "./deepsequence_model")

In [3]:
import numpy as np

import pickle
import copy

In [4]:
%load_ext autoreload

In [5]:
%autoreload 2

### MSA array 3d

In [6]:
# Load the pickled MSA array; the cells below index it as
# (sequence, position) -> integer residue code.
# NOTE: pickle.load can execute arbitrary code, so only open trusted files.
msa_arr_path = "data_pickle/msa_arr.pickle"
with open(msa_arr_path, mode='rb') as f:
    msa_arr = pickle.load(f)

In [7]:
# Sanity check of the loaded array's dimensions.
msa_arr.shape

(10813, 286)

In [8]:
# Zero-filled container for the one-hot encoding: the first two axes mirror
# msa_arr, the last axis has one slot per residue class (20).
msa_arr_3d = np.zeros((*msa_arr.shape[:2], 20))

In [9]:
# One-hot encode the MSA in a single vectorised assignment instead of a
# Python double loop over every (sequence, position) pair.
# Codes outside [0, n_classes) are left as an all-zero row. The lower bound
# matters: a negative code would otherwise be interpreted by NumPy as an index
# counted from the end and silently mark one of the last residue classes.
msa_codes = np.asarray(msa_arr)
n_classes = msa_arr_3d.shape[-1]
is_residue = (msa_codes >= 0) & (msa_codes < n_classes)
seq_idx, pos_idx = np.nonzero(is_residue)
msa_arr_3d[seq_idx, pos_idx, msa_codes[seq_idx, pos_idx]] = 1

In [10]:
# Display the encoded array (NumPy abbreviates the repr of large arrays).
msa_arr_3d

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [11]:
import os
import pickle

# Persist the one-hot MSA in the same directory the raw MSA pickle is read from,
# so that the DataHelper cells below can load it by path.
output_dir = "./data_pickle"
filepath = "msa_arr_3d.pickle"
msa_output_path = os.path.join(output_dir, filepath)

with open(msa_output_path, mode='wb') as f:
    pickle.dump(msa_arr_3d, f, protocol=pickle.HIGHEST_PROTOCOL)

### DataHelper

In [8]:
from helper import DataHelper

class DataHelperAugmented(DataHelper):
    """DataHelper whose training data is a pre-computed one-hot MSA read from a pickle.

    Attributes set by this class:
        x_train: array loaded from ``data_pickle_path``; used below as
            (sequence, position, residue class).
        seq_len: ``x_train.shape[1]``.
        alphabet_size: fixed to 20.
        theta: similarity cutoff used by the sequence re-weighting.
        weights: one weight per sequence (re-weighted, or all ones).
        n_eff: sum of ``weights``.
    """

    def __init__(self,
                 data_pickle_path,
                 calc_weights=True,
                 theta=0.2):
        """Load the pickled MSA and compute the per-sequence weights.

        Args:
            data_pickle_path: path of the pickle holding the one-hot MSA array.
            calc_weights: if True, down-weight similar sequences with
                ``_compute_weights``; otherwise every sequence gets weight 1.
            theta: cutoff of the re-weighting; two sequences count as
                neighbours when their normalised overlap exceeds ``1 - theta``.
        """
        # Pass to the parents the only arguments that can be changed
        super().__init__(calc_weights=calc_weights, theta=theta)

        # Stored explicitly so that _compute_weights does not depend on the
        # parent class keeping it (the parent is not visible from this notebook).
        self.theta = theta

        # x_train = msa_3d for consistency with parent class
        self.x_train = self._unpickle_data(data_pickle_path)
        self.seq_len = self.x_train.shape[1]
        self.alphabet_size = 20

        # Compute weights and Neff
        if calc_weights:
            self._compute_weights()
        else:
            self._isotropic_weights()
        self.n_eff = np.sum(self.weights)

    def _unpickle_data(self, data_pickle_path):
        """Return the object stored in the pickle at ``data_pickle_path``.

        NOTE: pickle.load can execute arbitrary code; only use trusted files.
        """
        with open(data_pickle_path, 'rb') as pickleFile:
            return pickle.load(pickleFile)

    def _compute_weights(self, chunk_size=256):
        """Set ``self.weights`` to the inverse neighbour count of every sequence.

        NumPy port of the Theano expression this method used to keep only as a
        disabled string (the method body was ``pass``, so ``calc_weights=True``
        never assigned ``self.weights``):

            weight_i = 1 / #{ j : (x_j . x_i) / (x_i . x_i) > 1 - theta }

        where ``x_i`` is the flattened one-hot encoding of sequence ``i``.

        Args:
            chunk_size: number of sequences compared against the whole MSA per
                matrix product; bounds the size of the temporary overlap matrix.

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        n_sequences = self.x_train.shape[0]
        # Explicit width instead of -1 so that an empty MSA reshapes cleanly.
        flat_width = int(np.prod(self.x_train.shape[1:]))
        x_flat = np.asarray(self.x_train).reshape(n_sequences, flat_width)

        # x_i . x_i for every sequence (number of non-gap positions for one-hot rows).
        self_overlap = np.einsum("ij,ij->i", x_flat, x_flat)

        weights = np.empty(n_sequences, dtype=float)
        # An all-zero row gives 0/0 = NaN below; NaN compares False, so such a
        # row ends up with zero neighbours and is clamped to 1 further down.
        with np.errstate(divide="ignore", invalid="ignore"):
            for start in range(0, n_sequences, chunk_size):
                stop = start + chunk_size
                # overlap[i, j] = x_i . x_j for the sequences of this chunk.
                overlap = x_flat[start:stop] @ x_flat.T
                normalised = overlap / self_overlap[start:stop, None]
                n_neighbours = np.sum(normalised > 1 - self.theta, axis=1)
                # Clamp to 1 so that an all-zero row gets weight 1 instead of inf.
                weights[start:stop] = 1.0 / np.maximum(n_neighbours, 1)

        self.weights = weights

    def _isotropic_weights(self):
        """Give every sequence in ``x_train`` the same weight of 1."""
        self.weights = np.ones(self.x_train.shape[0])

## Baseline

### Baseline DataHelper

In [9]:
# Path of the one-hot MSA pickle written in the "MSA array 3d" section.
path_msa_3d = 'data_pickle/msa_arr_3d.pickle'
# calc_weights=False -> every sequence gets weight 1 (isotropic weights).
baseline_data_helper = DataHelperAugmented(data_pickle_path=path_msa_3d, calc_weights=False)

In [10]:
# Sanity check: dimensions of the one-hot MSA held by the helper.
baseline_data_helper.x_train.shape

(10813, 286, 20)

### Baseline Model VAE

In [11]:
from model import VAE

# Freshly initialised network; later cells always train a deepcopy of it,
# so `model` itself is also used as the "randomly initialized" reference.
model = VAE(baseline_data_helper, bayesian=False)

In [12]:
# Show the layer structure of the network.
model

VAE(
  (fc1): Linear(in_features=5720, out_features=1500, bias=True)
  (fc2): Linear(in_features=1500, out_features=1500, bias=True)
  (fc3_mu): Linear(in_features=1500, out_features=40, bias=True)
  (fc3_logvar): Linear(in_features=1500, out_features=40, bias=True)
  (bfc4): Linear(in_features=40, out_features=100, bias=False)
  (bfc5): Linear(in_features=100, out_features=2000, bias=True)
  (bfc6): Linear(in_features=2000, out_features=5720, bias=True)
)

### Baseline Training

In [63]:
from train import train

# Hyper-parameters of the baseline run, gathered in one place.
baseline_train_kwargs = dict(
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.001,
    num_updates=3000,
    batch_size=100,
)

# train() receives a copy, so `model` keeps its initial weights.
model_trained = train(baseline_data_helper, copy.deepcopy(model), **baseline_train_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.25085827708244324
	KL latent loss: 0.01957143284380436
	KL weights loss: 0
	Total loss:0.2508602440357208
Epoch: 100
	MSE loss: 0.023568730801343918
	KL latent loss: 2.823861598968506
	KL weights loss: 0
	Total loss:0.023851117119193077
Epoch: 200
	MSE loss: 0.02263300120830536
	KL latent loss: 5.465763568878174
	KL weights loss: 0
	Total loss:0.02317957766354084
Epoch: 300
	MSE loss: 0.020691299811005592
	KL latent loss: 8.810708045959473
	KL weights loss: 0
	Total loss:0.02157237008213997
Epoch: 400
	MSE loss: 0.01925746724009514
	KL latent loss: 10.130838394165039
	KL weights loss: 0
	Total loss:0.02027055062353611
Epoch: 500
	MSE loss: 0.01954415999352932
	KL latent loss: 11.097723960876465
	KL weights loss: 0
	Total loss:0.02065393328666687
Epoch: 600
	MSE loss: 0.018548965454101562
	KL latent loss: 11.82666015625
	KL weights loss: 0
	Total loss:0.019731631502509117
Epoch: 700
	MSE loss: 0.018983354791998863
	KL latent loss: 12.3188018798

## Fine-tuning with MSA representation

OpenFold is trained by giving it as input the MSA of the protein family (which is the input of the baseline model as well). <br>
The MSA representation contains fewer sequences than the original MSA because only 508 sequences are sampled. <br>
The indices of the original MSA that were sampled to produce the MSA representation are stored in sel_seq.pickle

In [16]:
# Open output from OpenFold and extract the MSA representation

def _load_msa_representation(pickle_path):
    """Return the 'msa' entry of one pickled prediction result as a NumPy array.

    The stored value is moved to the CPU before the conversion.
    NOTE: pickle.load can execute arbitrary code; only use trusted files.
    """
    with open(pickle_path, "rb") as file:
        return pickle.load(file)['msa'].cpu().numpy()


# The first run is stored without a numeric suffix; runs 2..10 use "_<k>".
_prediction_suffixes = [""] + ["_{}".format(run) for run in range(2, 11)]
prediction_result_paths = [
    "data_pickle/prediction_result/prediction_result{}.pickle".format(suffix)
    for suffix in _prediction_suffixes
]

# Keep the per-run names: later cells refer to af_repr1 ... af_repr10 directly.
(af_repr1, af_repr2, af_repr3, af_repr4, af_repr5,
 af_repr6, af_repr7, af_repr8, af_repr9, af_repr10) = [
    _load_msa_representation(path) for path in prediction_result_paths
]

In [18]:
# Gather the ten OpenFold runs and merge them along the sequence axis:
# (10, 512, 286, 256) -> (10*512, 286, 256).
af_repr_runs = (af_repr1, af_repr2, af_repr3, af_repr4, af_repr5,
                af_repr6, af_repr7, af_repr8, af_repr9, af_repr10)
af_repr = np.stack(af_repr_runs, axis=0).reshape(len(af_repr_runs) * 512, 286, 256)
af_repr.shape

(5120, 286, 256)

In [19]:
# Check reshape correctness: rows 512..1023 of the merged array must be
# exactly the second run.

(af_repr[512:512*2] == af_repr2).all()

True

In [20]:
# Open sampled indices from the MSA to make the MSA representation

def _load_selected_indices(pickle_path):
    """Return the object stored in one sel_seq pickle (the sampled MSA indices).

    NOTE: pickle.load can execute arbitrary code; only use trusted files.
    """
    with open(pickle_path, "rb") as file:
        return pickle.load(file)


# The first run is stored without a numeric suffix; runs 2..10 use "_<k>".
_sel_seq_suffixes = [""] + ["_{}".format(run) for run in range(2, 11)]
sel_seq_paths = [
    "data_pickle/sel_seq/sel_seq{}.pickle".format(suffix)
    for suffix in _sel_seq_suffixes
]

# Keep the per-run names: later cells refer to sel_seq_1 ... sel_seq_10 directly.
(sel_seq_1, sel_seq_2, sel_seq_3, sel_seq_4, sel_seq_5,
 sel_seq_6, sel_seq_7, sel_seq_8, sel_seq_9, sel_seq_10) = [
    _load_selected_indices(path) for path in sel_seq_paths
]

In [21]:
# Merge the index lists of the ten runs into one flat vector, in the same
# run order used for the MSA representations.
sel_seq_runs = (sel_seq_1, sel_seq_2, sel_seq_3, sel_seq_4, sel_seq_5,
                sel_seq_6, sel_seq_7, sel_seq_8, sel_seq_9, sel_seq_10)
sel_seq = np.stack(sel_seq_runs).reshape(512 * len(sel_seq_runs))
sel_seq.shape

(5120,)

In [33]:
# Indices of the MSA that were NOT sampled (complement of sel_seq_1, see the check below).
# NOTE: pickle.load can execute arbitrary code, so only open trusted files.
not_sel_seq_path = "data_pickle/not_sel_seq.pickle"
with open(not_sel_seq_path, mode="rb") as file:
    not_sel_seq = pickle.load(file)

In [36]:
# Just to check:
# the numbers of selected and non-selected sequence indices sum up to the total depth of the MSA

len(sel_seq_1) + len(not_sel_seq)

10813

### PCA 

The MSA representation has 256 channels. In order to reproduce the same shape as the normal MSA, we perform PCA on the array

In [22]:
# Reshape the MSA representation into a 2-D array, with the 256 channels as second dimension

af_repr_2d = af_repr.reshape(-1, 256)

In [23]:
from sklearn.decomposition import PCA

In [24]:
# Reduce the 256 channels to 20 components, matching the 20-wide one-hot MSA.
# random_state is pinned because scikit-learn can choose a randomized SVD solver
# for large inputs, which would make the projection differ between runs.
pca = PCA(n_components=20, random_state=0)

In [25]:
# Fit the PCA on all rows of af_repr_2d and project them in one step.
res = pca.fit_transform(af_repr_2d)

In [26]:
# Fold the projected rows back into a 3-D array:
# (10 runs * 512 sequences, 286 positions, 20 components)

res_3d = res.reshape(512*10, 286, 20)

In [27]:
# Inspect the 20 PCA components of the first position of the first sequence.
res_3d[0][0]

array([ 82.41776  ,  -1.7027469,  30.132208 ,   2.3998735,  17.69152  ,
         8.770821 ,  -2.0832384,  18.866558 ,  15.341218 , -14.330871 ,
       -22.691286 ,  -9.659728 , -15.043725 ,  -5.754189 ,  27.090532 ,
       -16.156355 , -13.270078 ,  19.071564 , -29.013563 ,  42.09233  ],
      dtype=float32)

### Min-Max scaling

Scale the distribution between 0 and 1 like the MSA one-hot encoding

In [28]:
# Float64 container for the scaled values, same shape as the PCA output.
res_3d_scaled = np.zeros(shape=res_3d.shape, dtype=float)

In [29]:
# Min-max scale every (sequence, position) vector independently along the
# component axis, vectorised instead of a Python double loop:
#   scaled = (v - min(v)) / (max(v) - min(v))
row_min = res_3d.min(axis=-1, keepdims=True)
shifted = res_3d - row_min
row_span = shifted.max(axis=-1, keepdims=True)

# A constant vector has zero span; it is mapped to all zeros instead of the
# NaNs that 0/0 would produce. The division is done in res_3d's own dtype and
# only then written into the pre-allocated output array.
res_3d_scaled[...] = np.divide(
    shifted, row_span, out=np.zeros_like(shifted), where=row_span > 0
)

In [30]:
# All the numbers are now between 0 and 1:
# the largest component of the vector is mapped to 1 and the smallest to 0

res_3d_scaled[0][0]

array([1.        , 0.24509101, 0.53078222, 0.28190848, 0.41913784,
       0.33908224, 0.24167642, 0.42968276, 0.3980459 , 0.1317645 ,
       0.05673698, 0.17368397, 0.12536724, 0.20873281, 0.50348586,
       0.11538235, 0.1412842 , 0.43152252, 0.        , 0.63811404])

### Finetuning DataHelper

In [31]:
finetune_data_helper = DataHelperAugmented(data_pickle_path=path_msa_3d, calc_weights=False)

# x_train_finetune holds the MSA representation as the input for the finetune of the model
finetune_data_helper.x_train_finetune = res_3d_scaled

# x_train holds the output of the VAE,
# therefore during finetuning the model learns to reconstruct the MSA one-hot from the MSA representation
# IMPORTANT: the MSA one-hot is sliced with sel_seq, which contains the indices sampled to obtain the MSA representation
finetune_data_helper.x_train = finetune_data_helper.x_train[sel_seq]

# The constructor computed one weight per sequence of the FULL MSA. Slice the
# weights with the same indices and refresh n_eff so that both stay consistent
# with the reduced x_train (otherwise they still describe the full MSA).
finetune_data_helper.weights = finetune_data_helper.weights[sel_seq]
finetune_data_helper.n_eff = np.sum(finetune_data_helper.weights)

In [65]:
from train import train

# Fine-tuning stage: same loss scales and learning rate as the baseline run,
# with finetune=True and 1000 updates.
finetune_kwargs = dict(
    finetune=True,
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.001,
    num_updates=1000,
    batch_size=100,
)

# Start from a copy of the baseline weights; model_trained itself is not passed to train().
model_finetune = train(finetune_data_helper, copy.deepcopy(model_trained), **finetune_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.04015533626079559
	KL latent loss: 132.97535705566406
	KL weights loss: 0
	Total loss:0.05345287173986435
Epoch: 100
	MSE loss: 0.018478505313396454
	KL latent loss: 11.365095138549805
	KL weights loss: 0
	Total loss:0.019615015015006065
Epoch: 200
	MSE loss: 0.014805679209530354
	KL latent loss: 15.02550983428955
	KL weights loss: 0
	Total loss:0.01630822941660881
Epoch: 300
	MSE loss: 0.01240499783307314
	KL latent loss: 18.77210235595703
	KL weights loss: 0
	Total loss:0.014282207936048508
Epoch: 400
	MSE loss: 0.010501372627913952
	KL latent loss: 20.828603744506836
	KL weights loss: 0
	Total loss:0.012584232725203037
Epoch: 500
	MSE loss: 0.008410325273871422
	KL latent loss: 22.045391082763672
	KL weights loss: 0
	Total loss:0.010614864528179169
Epoch: 600
	MSE loss: 0.007220648694783449
	KL latent loss: 23.404272079467773
	KL weights loss: 0
	Total loss:0.009561075828969479
Epoch: 700
	MSE loss: 0.004951528739184141
	KL latent loss: 23.

In [66]:
# Second fine-tuning stage with a lower learning rate (0.0001 instead of 0.001),
# continuing from a copy of model_finetune.
finetune_lower_lr_kwargs = dict(
    finetune=True,
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.0001,
    batch_size=100,
    num_updates=1000,
)

model_finetune2 = train(finetune_data_helper, copy.deepcopy(model_finetune), **finetune_lower_lr_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.0031853094696998596
	KL latent loss: 23.8128604888916
	KL weights loss: 0
	Total loss:0.005566595122218132
Epoch: 100
	MSE loss: 0.002699552569538355
	KL latent loss: 22.800432205200195
	KL weights loss: 0
	Total loss:0.0049795955419540405
Epoch: 200
	MSE loss: 0.002934566466137767
	KL latent loss: 22.27725601196289
	KL weights loss: 0
	Total loss:0.005162292160093784
Epoch: 300
	MSE loss: 0.0024571844842284918
	KL latent loss: 22.631139755249023
	KL weights loss: 0
	Total loss:0.004720298573374748
Epoch: 400
	MSE loss: 0.00217860727570951
	KL latent loss: 22.12660026550293
	KL weights loss: 0
	Total loss:0.004391266964375973
Epoch: 500
	MSE loss: 0.002409042092040181
	KL latent loss: 22.482479095458984
	KL weights loss: 0
	Total loss:0.004657289944589138
Epoch: 600
	MSE loss: 0.0026889117434620857
	KL latent loss: 22.1121768951416
	KL weights loss: 0
	Total loss:0.0049001295119524
Epoch: 700
	MSE loss: 0.0023127186577767134
	KL latent loss: 2

## Training from scratch with AF representation

In [41]:
# Train on the MSA representation starting from a copy of the untrained `model`
# (no baseline pre-training), using the baseline's 3000 updates.
scratch_kwargs = dict(
    finetune=True,
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.001,
    batch_size=100,
    num_updates=3000,
)

model_finetune_scratch = train(finetune_data_helper, copy.deepcopy(model), **scratch_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.25083133578300476
	KL latent loss: 0.09518483281135559
	KL weights loss: 0
	Total loss:0.25084084272384644
Epoch: 100
	MSE loss: 0.02280583791434765
	KL latent loss: 2.9082722663879395
	KL weights loss: 0
	Total loss:0.023096665740013123
Epoch: 200
	MSE loss: 0.021986011415719986
	KL latent loss: 5.091205596923828
	KL weights loss: 0
	Total loss:0.022495131939649582
Epoch: 300
	MSE loss: 0.021461665630340576
	KL latent loss: 7.8939361572265625
	KL weights loss: 0
	Total loss:0.022251058369874954
Epoch: 400
	MSE loss: 0.0214969664812088
	KL latent loss: 6.69047737121582
	KL weights loss: 0
	Total loss:0.022166013717651367
Epoch: 500
	MSE loss: 0.01927795447409153
	KL latent loss: 10.101778030395508
	KL weights loss: 0
	Total loss:0.0202881321310997
Epoch: 600
	MSE loss: 0.018945792689919472
	KL latent loss: 9.728129386901855
	KL weights loss: 0
	Total loss:0.01991860568523407
Epoch: 700
	MSE loss: 0.017971545457839966
	KL latent loss: 11.455566

In [55]:
# Continue the from-scratch model with a lower learning rate (0.0001),
# again working on a copy of the previous stage.
scratch_lower_lr_kwargs = dict(
    finetune=True,
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.0001,
    batch_size=100,
    num_updates=1000,
)

model_finetune_scratch2 = train(finetune_data_helper, copy.deepcopy(model_finetune_scratch), **scratch_lower_lr_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.003354511922225356
	KL latent loss: 20.56233787536621
	KL weights loss: 0
	Total loss:0.005410745739936829
Epoch: 100
	MSE loss: 0.002729611238464713
	KL latent loss: 18.48223304748535
	KL weights loss: 0
	Total loss:0.004577834624797106
Epoch: 200
	MSE loss: 0.0028627924621105194
	KL latent loss: 18.56342887878418
	KL weights loss: 0
	Total loss:0.0047191353514790535
Epoch: 300
	MSE loss: 0.0030346328858286142
	KL latent loss: 18.31036949157715
	KL weights loss: 0
	Total loss:0.004865669645369053
Epoch: 400
	MSE loss: 0.00245704990811646
	KL latent loss: 18.10809326171875
	KL weights loss: 0
	Total loss:0.004267859272658825
Epoch: 500
	MSE loss: 0.0027011772617697716
	KL latent loss: 18.410919189453125
	KL weights loss: 0
	Total loss:0.00454226927831769
Epoch: 600
	MSE loss: 0.002803846262395382
	KL latent loss: 17.823972702026367
	KL weights loss: 0
	Total loss:0.004586243536323309
Epoch: 700
	MSE loss: 0.002847309224307537
	KL latent loss: 

In [56]:
# Third from-scratch stage with an even lower learning rate (0.00001),
# continuing from a copy of model_finetune_scratch2.
scratch_lowest_lr_kwargs = dict(
    finetune=True,
    kl_latent_scale=0.0001,
    kl_weights_scale=0,
    lr=0.00001,
    batch_size=100,
    num_updates=1000,
)

model_finetune_scratch3 = train(finetune_data_helper, copy.deepcopy(model_finetune_scratch2), **scratch_lowest_lr_kwargs)

Using device: cuda:0
Epoch: 0
	MSE loss: 0.0027777275536209345
	KL latent loss: 18.18413734436035
	KL weights loss: 0
	Total loss:0.004596141166985035
Epoch: 100
	MSE loss: 0.0027726776897907257
	KL latent loss: 18.381969451904297
	KL weights loss: 0
	Total loss:0.004610874690115452
Epoch: 200
	MSE loss: 0.002794278785586357
	KL latent loss: 17.877988815307617
	KL weights loss: 0
	Total loss:0.004582077730447054
Epoch: 300
	MSE loss: 0.0026489070151001215
	KL latent loss: 17.560789108276367
	KL weights loss: 0
	Total loss:0.004404985811561346
Epoch: 400
	MSE loss: 0.0025657638907432556
	KL latent loss: 17.6796817779541
	KL weights loss: 0
	Total loss:0.004333732184022665
Epoch: 500
	MSE loss: 0.0027674755547195673
	KL latent loss: 17.665796279907227
	KL weights loss: 0
	Total loss:0.004534055013209581
Epoch: 600
	MSE loss: 0.0024628043174743652
	KL latent loss: 17.439754486083984
	KL weights loss: 0
	Total loss:0.004206779878586531
Epoch: 700
	MSE loss: 0.0029619047418236732
	KL latent

### Testing the model

The actual accuracy is higher than the one reported below, because many sequences have missing amino acids (-) and in this simple evaluation we predict at each position with the argmax over all 20 amino acids. Therefore the prediction at a missing position is always wrong, unless it is 0 by chance.

In [67]:
from model_performance_tests import reconstruction_accuracy_per_aminoacid

In [68]:
# Performance of the randomly initialized (untrained) model, as a reference point
# NOTE(review): n_samples=10 here, while every other evaluation below uses 100 - confirm this is intended

reconstruction_accuracy_per_aminoacid(model, data_helper=baseline_data_helper, n_samples=10)

4.965034965034965

In [69]:
# Performance of the baseline model (trained on the one-hot MSA only)

reconstruction_accuracy_per_aminoacid(model_trained, data_helper=baseline_data_helper, n_samples=100)

59.65734265734269

In [70]:
# Performance of the finetuned model (baseline weights + first finetuning stage) - with MSA representation

reconstruction_accuracy_per_aminoacid(model_finetune, data_helper=finetune_data_helper, n_samples=100, finetune=True)

70.00349650349649

In [71]:
# Performance of the finetuned model after the lower-learning-rate stage - with MSA representation

reconstruction_accuracy_per_aminoacid(model_finetune2, data_helper=finetune_data_helper, n_samples=100, finetune=True)

70.38111888111887

In [72]:
# Performance of the model trained from scratch (first stage only) - with MSA representation

reconstruction_accuracy_per_aminoacid(model_finetune_scratch, data_helper=finetune_data_helper, n_samples=100, finetune=True)

66.7867132867133

In [62]:
# Performance of the model trained from scratch after both lower-learning-rate stages
# (model_finetune_scratch3) - with MSA representation

reconstruction_accuracy_per_aminoacid(model_finetune_scratch3, data_helper=finetune_data_helper, n_samples=100, finetune=True)

70.05244755244755

### Stability experiments

In [11]:
# TO BE DONE
#baseline_data_helper.delta_elbo(baseline_data_helper,[(175,"A","C")], N_pred_iterations=500)