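Protein VAE: a convolutional variational autoencoder that reconstructs protein sequences (140 positions × 22 classes) conditioned on an 8-value code.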

In [1]:
# Mount Google Drive (forcing a fresh mount) so the dataset at DATA_PATH can be read.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# Fail fast when no GPU runtime is attached; the training cells move tensors with .cuda().
gpu_present = torch.cuda.is_available()
assert gpu_present
print("Torch version:", torch.__version__)

import numpy as np
import argparse
import os
import timeit
        
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split as tts

Torch version: 1.3.1


Change parameters: edit the hyperparameters and run flags below before running the rest of the notebook.

In [0]:
# Hyperparameters and run options read by the cells below.
args_dict = dict(
    lr=5e-4,               # Adam learning rate
    batch_size_train=201,  # batch size used by both the training and validation loops
    batch_size_test=25,    # not read by any cell visible in this notebook
    num_epochs=200,        # maximum number of training epochs
    latent_dim=16,         # size of the VAE latent space
    device=0,              # CUDA device index
    dataset="nostruc",     # "nostruc" = sequence + code; other values load the structure variant
)

In [0]:
# Run-mode flags:
#   cuda      - True to train on the GPU; False for CPU-only inference.
#   load      - True to load saved weights (the provided default is a
#               16-dimensional model for the 'nostruc' data).
#   train     - True to run the training loop (requires the dataset).
#   new_metal - True to produce 'batch_size' samples of a given protein; see the
#               docs on GitHub for how to do this. (Not read by any cell visible
#               in this notebook.)
cuda, load, train, new_metal = True, False, True, True

# CHANGE THIS: location of the assembled dataset on Google Drive.
DATA_PATH = '/content/drive/My Drive/protein VAE/assembled_data_mb.npy'

In [5]:
# Load the dataset and build train/test splits.
# The data is only loaded in GPU (training) mode; CPU mode is inference-only
# and just needs X_dim.
# Splits are seeded (args_dict["seed"], default 0) so that the same rows are
# held out in every session. Without a seed, a checkpoint reloaded with
# load=True could be scored on rows it was trained on.
split_seed = args_dict.get("seed", 0)

if cuda:
    if args_dict["dataset"]=="nostruc":
        DATA = np.load(DATA_PATH)
    else:
        DATA = np.load('/scratch0/DeNovo/assembled_data_mbflip_fold.npy') #IGNORE
    print("Full data set shape: {0}".format(DATA.shape))

    # Main split: 85% train / 15% test.
    train_set, test_set = tts(DATA, test_size=0.15, shuffle=True, random_state=split_seed)
    # Development split for quick experiments: 15% / 10% of the FULL data set.
    # It is drawn independently of the main split, so it can overlap test_set.
    dev_train, dev_test = tts(DATA, train_size=0.15, test_size=0.1, shuffle=True,
                              random_state=split_seed + 1)
    assert train_set.shape[0] + test_set.shape[0] == DATA.shape[0]

    print("training set size: {0}".format(train_set.shape[0]))
    print("test set size: {0}".format(test_set.shape[0]))
    print("development training set size: {0}".format(dev_train.shape[0]))
    print("development test set size: {0}".format(dev_test.shape[0]))

    # CHANGE THIS, PICK DATA
    #data = dev_train
    #data_test = dev_test
    data = train_set
    data_test = test_set

    n=data.shape[0]        # number of training rows
    X_dim = data.shape[1]  # row width (sequence columns + code, plus structure if present)
else:
    print("No DATA")
    # Row widths of the two dataset variants, needed to build the model without data.
    if args_dict["dataset"]=="nostruc":
        X_dim=3088
    else:
        X_dim=4353


Full data set shape: (147842, 3088)
training set size: 125665
test set size: 22177
development training set size: 22176
development test set size: 14785


In [0]:
# Select the GPU. torch.cuda.set_device is used here instead of setting
# CUDA_VISIBLE_DEVICES. That environment variable only takes effect if it is set
# before CUDA is initialised, and torch.cuda.is_available() in the imports cell
# has typically already initialised it, so setting it here would be ignored.
if cuda:
    torch.cuda.set_device(args_dict['device'])

# Batch size used by both the training and validation loops.
batch_size = args_dict['batch_size_train']
# Adam learning rate.
lr = args_dict['lr']
# Layer sizes. The model only reads the last entry (the latent dimension).
hidden_size = [512, 256, 128, args_dict['latent_dim']]
# Conv1d channels 1 -> 64 -> 128, followed by a 1024-unit dense layer.
conv_size = [1, 64, 128, 1024]

def convert(x):
    """Reshape a 2-D batch into Conv1d layout: (batch, 1 channel, length).

    Produces the same values as the original transpose round-trip, but with a
    single axis insertion. It also accepts an empty batch of known width,
    e.g. shape (0, k), which the original rejected with IndexError.

    Args:
        x: 2-D array-like of shape (m, n).

    Returns:
        C-contiguous float64 ndarray of shape (m, 1, n).

    Raises:
        ValueError: if ``x`` is not 2-D.
    """
    arr = np.asarray(x, dtype=np.float64)
    if arr.ndim != 2:
        raise ValueError("convert expects a 2-D array, got shape {0}".format(arr.shape))
    return np.ascontiguousarray(arr[:, np.newaxis, :])

class Flatten(nn.Module):
    """Collapse every axis after the batch axis: (batch, d1, d2, ...) -> (batch, d1*d2*...).

    Uses ``view``, so the input must be viewable without a copy.
    """

    def forward(self, input):
        batch = input.size(0)
        flat = input.view(batch, -1)
        return flat

class Unflatten(nn.Module):
    """Reshape (batch, channel * height) into (batch, channel, height) for Conv1d layers.

    Args:
        channel: number of channels in the output.
        height: sequence length in the output.
    """

    def __init__(self, channel, height):
        super().__init__()
        self.channel, self.height = channel, height

    def forward(self, input):
        target_shape = (input.size(0), self.channel, self.height)
        return input.view(*target_shape)

In [0]:
class feed_forward(torch.nn.Module):
    """Convolutional conditional VAE over single-channel sequences.

    Encoder: two stride-2 Conv1d layers, each halving the length, then a dense
    layer and two linear heads giving ``mu`` and ``log_var`` of the latent
    Gaussian.
    Decoder: takes a sampled latent concatenated with a conditioning ``code``
    and mirrors the encoder back to ``(batch, 1, seq_len)``, ending in a sigmoid.
    Channel widths come from the module-level ``conv_size`` list.

    Args:
        input_size: stored as ``self.input_size``; not used to build layers.
        hidden_sizes: layer sizes; only the last entry (latent dimension) is used.
        batch_size: stored as ``self.batch_size``. Sampling follows the actual
            batch of each call, so other batch sizes also work.
        seq_len: encoder input length; must be divisible by 4. The default 3080
            reproduces the original flat size 98560 = 128 * 770.
        code_dim: width of the conditioning code appended to the latent (default 8).

    Raises:
        ValueError: if ``seq_len`` is not divisible by 4.
    """

    def __init__(self, input_size, hidden_sizes, batch_size, seq_len=3080, code_dim=8):
        super().__init__()

        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.batch_size = batch_size

        if seq_len % 4 != 0:
            raise ValueError("seq_len must be divisible by 4, got {0}".format(seq_len))
        latent_dim = hidden_sizes[-1]
        reduced_len = seq_len // 4             # length after the two stride-2 convolutions
        flat_dim = conv_size[2] * reduced_len  # 98560 for the default seq_len

        self.encoder = nn.Sequential(
            nn.Conv1d(conv_size[0], conv_size[1], kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(conv_size[1], conv_size[2], kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            Flatten(),
            nn.Linear(flat_dim, conv_size[3]),
            nn.ReLU()
        )

        # Use this model's own hidden_sizes (not a notebook global) so the
        # latent heads always agree with the decoder input width.
        self.fc_mu = torch.nn.Linear(conv_size[3], latent_dim)
        self.fc_var = torch.nn.Linear(conv_size[3], latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + code_dim, conv_size[3]),
            nn.ReLU(),
            nn.Linear(conv_size[3], flat_dim),
            nn.ReLU(),
            Unflatten(conv_size[2], reduced_len),
            # Redundant after the previous ReLU, but kept so the Sequential
            # indices (and therefore saved state_dict keys) do not shift.
            nn.ReLU(),
            nn.ConvTranspose1d(conv_size[2], conv_size[1], kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(conv_size[1], conv_size[0], kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a (batch, 1, seq_len) input to the latent mean and log-variance."""
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_var(h)
        return mu, logvar

    def decode(self, z):
        """Map a (latent + code) vector back to a (batch, 1, seq_len) reconstruction in [0, 1]."""
        z = self.decoder(z)
        return z

    def sample_z(self, mu, log_var):
        """Reparameterization trick: z = mu + exp(log_var / 2) * eps, with eps ~ N(0, I).

        eps takes its shape, dtype and device from ``mu``, so this works for
        any batch size and on both CPU and GPU.
        """
        eps = torch.randn_like(mu)
        return mu + torch.exp(log_var / 2) * eps

    def forward(self, x, code):
        """Encode ``x``, sample a latent, append ``code`` and decode.

        Returns:
            (reconstruction, mu, log_var)
        """
        mu, log_var = self.encode(x)
        z = self.sample_z(mu, log_var)
        z = torch.cat((z, code), 1)
        return self.decode(z), mu, log_var



In [10]:

# Build the model, optionally load pretrained weights and, if train=True, run
# training with per-epoch validation, checkpointing and early stopping.

LEN_AA = 140 * 22   # sequence columns scored: 140 positions x 22 classes
N_CLASSES = 22      # classes per position (argmax is taken over this axis)
PAD_IDX = 21        # class index that fills positions after the sequence ends


def _to_float_tensor(array):
    """Convert a numpy array to a float32 tensor, on the GPU when cuda is True."""
    if cuda:
        return torch.from_numpy(array).cuda().type(torch.cuda.FloatTensor)
    return torch.from_numpy(array).type(torch.FloatTensor)


def _prepare_batch(batch):
    """Split raw data rows into model inputs.

    The first 3080 columns are the sequence, reshaped with convert() to
    (batch, 1, 3080). The last 8 are the conditioning code. For datasets other
    than 'nostruc', the columns in between are structure features. The same
    conversion is applied for training and validation; the original training
    loop skipped convert() for the structure dataset.

    Returns:
        (x_np, X, C, S): the converted numpy batch (used for accuracy), then
        the sequence, code and structure tensors. S is None for 'nostruc'.
    """
    x_np = convert(batch[:, :3080])
    X = _to_float_tensor(x_np)
    C = _to_float_tensor(batch[:, -8:])
    S = None
    if args_dict["dataset"] != "nostruc":
        S = _to_float_tensor(batch[:, 3080:-8])
    return x_np, X, C, S


def _forward(X, C, S):
    """Run the model; the 'struc' dataset also passes structure features."""
    if args_dict["dataset"] == "struc":
        # NOTE(review): feed_forward.forward accepts only (x, code), so this
        # call raises TypeError until the model supports structure input.
        return ff(X, C, S)
    return ff(X, C)


def _row_accuracies(x_np, x_sample):
    """Per-sequence reconstruction accuracy over the positions before padding.

    A sequence ends at its first PAD_IDX label. Rows with no padding are
    scored over all positions; the original argmax(row) clip cut such rows at
    their largest class index instead. Zero-length rows are skipped rather
    than contributing NaN.
    """
    n_rows = x_np.shape[0]
    y_label = np.argmax(x_np.reshape(n_rows, -1)[:, :LEN_AA].reshape(n_rows, -1, N_CLASSES), axis=2)
    recon = x_sample.detach().cpu().numpy()
    y_pred = np.argmax(recon.reshape(n_rows, -1)[:, :LEN_AA].reshape(n_rows, -1, N_CLASSES), axis=2)

    is_pad = y_label == PAD_IDX
    lengths = np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), y_label.shape[1])
    in_seq = np.arange(y_label.shape[1])[np.newaxis, :] < lengths[:, np.newaxis]
    correct = ((y_label == y_pred) & in_seq).sum(axis=1)
    valid = lengths > 0
    return (correct[valid] / lengths[valid]).tolist()


# init the networks
if cuda:
    ff = feed_forward(X_dim, hidden_size, batch_size).cuda()
else:
    ff = feed_forward(X_dim, hidden_size, batch_size)

# change the loading bit here
if load:
    ff.load_state_dict(torch.load("models/metal16_nostruc", map_location=lambda storage, loc: storage))

# Loss and Optimizer
solver = optim.Adam(ff.parameters(), lr=lr)
# Weight on the KL term: starts at 0 and grows by 0.003 per epoch once the
# epoch index exceeds KL_BURN_IN_START, stopping once it reaches 1.0.
# NOTE(review): with num_epochs=200 the ramp never starts, so the KL weight
# stays 0 and the model trains as a plain autoencoder.
burn_in_counter = 0
KL_BURN_IN_START = 300
tick = 0

# number of epochs
num_epochs = args_dict['num_epochs']

if train:

    patience = 100  # early stopping: epochs allowed without validation improvement
    patience_counter = patience
    best_val_acc = -np.inf
    checkpoint_filename = 'checkpoint.pt'  # best model by validation accuracy
    results_filename = 'latent_results_' + str(args_dict["dataset"]) + '.txt'

    for its in range(num_epochs):

        #############################
        # TRAINING
        #############################

        ff.train()
        scores = []
        data = shuffle(data)

        if its % 10 == 0:
            print("Epoch: {0}/{1}  Latent: {2}".format(its, num_epochs, hidden_size[-1]))

        start_time = timeit.default_timer()

        for it in range(n // batch_size):
            x_batch, X, C, S = _prepare_batch(data[it * batch_size: (it + 1) * batch_size])

            # turf last gradients
            solver.zero_grad()
            x_sample, z_mu, z_var = _forward(X, C, S)

            # Summed (not averaged) reconstruction loss; reduction='sum'
            # replaces the deprecated size_average=False.
            recon_loss = F.binary_cross_entropy(x_sample, X, reduction='sum')
            # KL(q(z|x) || N(0, I)); z_var holds the log-variance.
            kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu ** 2 - 1. - z_var)
            loss = recon_loss + kl_loss * burn_in_counter

            loss.backward()
            solver.step()

            scores.extend(_row_accuracies(x_batch, x_sample))

        elapsed = float(timeit.default_timer() - start_time)

        print("Time: %.2fs" % (elapsed))
        print("Patience: {0}".format(patience_counter))
        print("Tra Acc: {0}".format(np.mean(scores)))

        if its == (num_epochs - 1):
            with open(results_filename, 'a') as f:
                f.write(str(args_dict['latent_dim']) + ' train ' + str(np.mean(scores)) + '\n')

        if its > KL_BURN_IN_START and burn_in_counter < 1.0:
            burn_in_counter += 0.003

        #############################
        # Validation
        #############################

        scores = []

        ff.eval()
        # No gradients are needed for validation; this avoids building the autograd graph.
        with torch.no_grad():
            for it in range(data_test.shape[0] // batch_size):
                x_batch, X, C, S = _prepare_batch(data_test[it * batch_size: (it + 1) * batch_size])
                x_sample, z_mu, z_var = _forward(X, C, S)
                scores.extend(_row_accuracies(x_batch, x_sample))

        acc = np.mean(scores)
        print('Val Acc: {0}'.format(acc))
        print()

        if acc > best_val_acc:
            torch.save(ff.state_dict(), checkpoint_filename)
            best_val_acc = acc
            patience_counter = patience
        else:
            patience_counter -= 1
            if patience_counter <= 0:
                # Out of patience: restore the best checkpoint and stop.
                ff.load_state_dict(torch.load(checkpoint_filename))
                break

        if its == (num_epochs - 1):
            with open(results_filename, 'a') as f:
                f.write(str(args_dict['latent_dim']) + ' test ' + str(np.mean(scores)) + '\n')



# # saves if its running on gpu
# if cuda:
#     torch.save(ff.state_dict(), 'metal'+str(args_dict['latent_dim'])+"_"+str(args_dict['dataset']))




Epoch: 0/200  Latent: 16




Time: 83.82s
Patience: 100
Tra Acc: 0.17268636151794636
(201, 140)
[[ 2  8  8 ... 21 21 21]
 [ 1 19  1 ... 21 21 21]
 [18  7  9 ... 21 21 21]
 ...
 [ 3  1  2 ... 21 21 21]
 [ 3  6 11 ... 21 21 21]
 [ 0 18 11 ... 21 21 21]]
(201, 140)
[[ 1 14 14 ... 21 21 21]
 [ 3  9 19 ... 21 21 21]
 [ 1  1  1 ... 21 21 21]
 ...
 [ 9  9  6 ... 21 21 21]
 [11  6 11 ... 21 21 21]
 [ 3 16  6 ... 21 21 21]]
(201, 140)
[[18 10 11 ... 21 21 21]
 [ 0 18  1 ... 21 21 21]
 [ 3  9  6 ... 21 21 21]
 ...
 [ 3 16 11 ... 21 21 21]
 [ 3  6  8 ... 21 21 21]
 [12 10 11 ... 21 21 21]]
(201, 140)
[[17  2  8 ... 21 21 21]
 [17  9  8 ... 21 21 21]
 [ 3  6  6 ... 21 21 21]
 ...
 [ 3  9  8 ... 21 21 21]
 [12 19 13 ... 21 21 21]
 [11  7  1 ... 21 21 21]]
(201, 140)
[[ 3  9  8 ... 21 21 21]
 [12  8 11 ... 21 21 21]
 [17 16  9 ... 21 21 21]
 ...
 [ 9  9  6 ... 21 21 21]
 [ 3 19  4 ... 21 21 21]
 [16 12  1 ... 21 21 21]]
(201, 140)
[[18  1  1 ... 21 21 21]
 [ 3  1  8 ... 21 21 21]
 [ 1  0  1 ... 21 21 21]
 ...
 [ 3 16 18 ... 21 

KeyboardInterrupt: ignored

In [0]:

# # saves if its running on gpu          
# if cuda:
#     torch.save(ff.state_dict(), 'metal'+str(args_dict['latent_dim'])+"_"+str(args_dict['dataset']))


RESULT (development split: 15% train, 10% validation)

200 epochs (~50 min)

Training accuracy: 0.9990800993620171

Validation accuracy: 0.999… (remaining digits were not recorded)

