In [1]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm

## Datasets & setup

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
# Data directory. Set the APTAMER_DATA_DIR environment variable to use another
# copy instead of editing this machine-specific default. The join with '' keeps
# the trailing separator that the file names below are appended to.
FILE_PATH = os.path.join(os.environ.get('APTAMER_DATA_DIR', '/media/scratch/yuhaowan/data/'), '')
NGRAM = False  # True selects the n-gram variant of every file ("_n" suffix)
m = 473047  # size of training set
n = 591309  # size of full dataset
_SUFFIX = '_n' if NGRAM else ''


def _open_split(stem):
    """Open ``FILE_PATH + stem + suffix + '.pkl'`` for binary reading."""
    return open(FILE_PATH + stem + _SUFFIX + '.pkl', 'rb')


# Each handle is a stream of records written with pickle and read back one at
# a time with pickle.load. Only load these files from a trusted source:
# unpickling can execute arbitrary code.
S_train = _open_split('s_tr')  # (apt, pep, apt_prime, pep_prime, pep_prime_pmf, indicator)
S_test = _open_split('s_te')  # (apt, pep)
S_new = _open_split('s_new')  # (apt, pep)
train_loss_samples = _open_split('tr_loss')  # (apt, pep)
test_loss_samples = _open_split('te_loss')  # (apt, pep)
prime_train_loss_samples = _open_split('ptr_loss')  # (apt, pep, pep_pmf, indicator)
prime_test_loss_samples = _open_split('pte_loss')  # (apt, pep, pep_pmf, indicator)

## NN Model

In [None]:
class ConvNetSimple(nn.Module):
    """Two-branch 1-D CNN that scores an (aptamer, peptide) pair.

    Each input goes through its own small convolution stack. The two feature
    maps are flattened per sample and concatenated, then a linear layer and a
    sigmoid produce one score in [0, 1] per sample.

    ``fc1`` expects 275 concatenated features. From the layer sizes this
    presumably corresponds to aptamers of shape (40, 4) and peptides of shape
    (8, 20), i.e. 50 + 225 features. TODO: confirm against the data pipeline.
    Inputs may be batched ``(N, C, L)`` or a single unbatched ``(C, L)`` sample.
    """

    def __init__(self):
        # The original called super(Conv1dModelSimple, self); that class does
        # not exist in this notebook, so construction raised NameError.
        super().__init__()
        self.cnn_apt_1 = nn.Conv1d(40, 100, 3)
        self.cnn_apt_2 = nn.Conv1d(100, 50, 1)

        self.cnn_pep_1 = nn.Conv1d(8, 50, 3)
        self.cnn_pep_2 = nn.Conv1d(50, 25, 1)
        # Not used by forward(); kept so saved state-dict keys stay unchanged.
        self.cnn_pep_3 = nn.Conv1d(25, 10, 1)

        self.relu = nn.ReLU()
        self.name = "ConvNetSimple"
        self.maxpool = nn.MaxPool1d(2)

        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.maxpool, self.relu, self.cnn_apt_2, self.relu)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.maxpool, self.relu, self.cnn_pep_2, self.relu)

        self.fc1 = nn.Linear(275, 1)

    @staticmethod
    def _flatten(features):
        """Flatten conv output to (batch, features); a 2-D map is one sample."""
        if features.dim() >= 3:
            return torch.flatten(features, start_dim=1)
        return features.reshape(1, -1)

    def forward(self, apt, pep):
        """Return sigmoid scores of shape (batch, 1) for the given pairs."""
        # Both Sequential stacks already end in ReLU, so the extra ReLU the
        # original applied here was a no-op and has been dropped.
        apt = self._flatten(self.cnn_apt(apt))
        pep = self._flatten(self.cnn_pep(pep))
        # The original flattened with view(-1, 1).T, which merged the batch
        # axis into the features and therefore only worked for batch size 1.
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

In [None]:
class LinearConv1d(nn.Module):
    """Single-convolution-per-branch model that scores an (aptamer, peptide) pair.

    Each branch applies one Conv1d, a MaxPool1d(2) and a ReLU. The flattened
    features are concatenated and mapped to one sigmoid score per sample.

    ``fc1`` expects 550 features. From the layer sizes this presumably means
    aptamers of shape (40, 4) and peptides of shape (8, 20), i.e. 100 + 450
    features. TODO: confirm. Inputs may be batched ``(N, C, L)`` or a single
    unbatched ``(C, L)`` sample.
    """

    def __init__(self):
        # The original referenced the undefined class LinearConv1dModel in
        # super(), which raised NameError on construction.
        super().__init__()
        self.cnn_apt_1 = nn.Conv1d(40, 100, 3)
        self.cnn_pep_1 = nn.Conv1d(8, 50, 3)

        self.relu = nn.ReLU()
        self.name = "LinearConv1d"
        self.maxpool = nn.MaxPool1d(2)

        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.maxpool, self.relu)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.maxpool, self.relu)

        self.fc1 = nn.Linear(550, 1)

    @staticmethod
    def _flatten(features):
        """Flatten conv output to (batch, features); a 2-D map is one sample."""
        if features.dim() >= 3:
            return torch.flatten(features, start_dim=1)
        return features.reshape(1, -1)

    def forward(self, apt, pep):
        """Return sigmoid scores of shape (batch, 1) for the given pairs."""
        # Per-sample flatten: the original view(-1, 1).T merged the batch axis
        # into the features, so only batch size 1 worked.
        apt = self._flatten(self.cnn_apt(apt))
        pep = self._flatten(self.cnn_pep(pep))
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

In [None]:
class LinearNet(nn.Module):
    """Fully-connected baseline that scores an (aptamer, peptide) pair.

    Each input is flattened to 160 features: presumably 40x4 for the aptamer
    and 8x20 for the peptide (TODO confirm). Each is passed through a stack of
    Linear layers with no activation between them (``self.relu`` is registered
    but never applied). The two outputs are concatenated into a sigmoid score
    of shape (batch, 1).

    Inputs with 3 or more dimensions are treated as batched ``(N, ...)``.
    Anything lower-dimensional is treated as a single sample, as before.
    """

    def __init__(self):
        # The original called super(TrueLinearNet, self); that class does not
        # exist in this notebook, so construction raised NameError.
        super().__init__()
        self.lin_apt_1 = nn.Linear(160, 100)
        self.lin_apt_2 = nn.Linear(100, 50)
        self.lin_apt_3 = nn.Linear(50, 10)

        self.lin_pep_1 = nn.Linear(160, 50)
        self.lin_pep_2 = nn.Linear(50, 10)

        # Registered but unused in forward().
        self.relu = nn.ReLU()

        self.name = "LinearNet"

        self.lin_apt = nn.Sequential(self.lin_apt_1, self.lin_apt_2, self.lin_apt_3)
        self.lin_pep = nn.Sequential(self.lin_pep_1, self.lin_pep_2)

        self.fc1 = nn.Linear(20, 1)

    @staticmethod
    def _flatten(tensor):
        """Flatten to (batch, features); inputs below 3-D are one sample."""
        if tensor.dim() >= 3:
            return torch.flatten(tensor, start_dim=1)
        return tensor.reshape(1, -1)

    def forward(self, apt, pep):
        """Return sigmoid scores of shape (batch, 1) for the given pairs."""
        # The Linear stacks already emit (batch, 10), so no reshape is needed
        # after them. The original's view(-1, 1).T merged the batch axis into
        # the features and only worked for batch size 1.
        apt = self.lin_apt(self._flatten(apt))
        pep = self.lin_pep(self._flatten(pep))
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

In [None]:
class ConvNetComplex(nn.Module):
    """Deep two-branch 1-D CNN that scores an (aptamer, peptide) pair.

    This model is too complex for our input sequence size. Each branch has
    four Conv1d + MaxPool1d(2) stages, so every input must still have length
    >= 2 before each pool. ``fc1`` expects 180 features: for example, aptamer
    length 274 (10 channels x 17) plus peptide length 20 (10 channels x 1).
    A length-4 aptamer fails at the second pool, which sees a length-1
    sequence. Inputs may be batched ``(N, C, L)`` or a single ``(C, L)``.
    """

    def __init__(self):
        # The original named the undefined class Conv1dModel in super(),
        # which raised NameError on construction.
        super().__init__()
        self.cnn_apt_1 = nn.Conv1d(40, 500, 3)
        self.cnn_apt_2 = nn.Conv1d(500, 300, 1)
        self.cnn_apt_3 = nn.Conv1d(300, 150, 1)
        self.cnn_apt_4 = nn.Conv1d(150, 75, 1)
        # cnn_apt_4 emits 75 channels. The original declared 25 input channels
        # here, so every forward pass failed with a channel mismatch.
        self.cnn_apt_5 = nn.Conv1d(75, 10, 1)

        self.cnn_pep_1 = nn.Conv1d(8, 250, 3)
        self.cnn_pep_2 = nn.Conv1d(250, 500, 1)
        self.cnn_pep_3 = nn.Conv1d(500, 250, 1)
        self.cnn_pep_4 = nn.Conv1d(250, 100, 1)
        self.cnn_pep_5 = nn.Conv1d(100, 10, 1)

        self.relu = nn.ReLU()
        self.name = "ConvNetComplex"
        self.maxpool = nn.MaxPool1d(2)

        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.maxpool, self.relu, self.cnn_apt_2, self.maxpool, self.relu, self.cnn_apt_3, self.maxpool, self.relu, self.cnn_apt_4, self.maxpool, self.relu, self.cnn_apt_5, self.relu)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.maxpool, self.relu, self.cnn_pep_2, self.maxpool, self.relu, self.cnn_pep_3, self.maxpool, self.relu, self.cnn_pep_4, self.maxpool, self.relu, self.cnn_pep_5, self.relu)

        self.fc1 = nn.Linear(180, 1)

    @staticmethod
    def _flatten(features):
        """Flatten conv output to (batch, features); a 2-D map is one sample."""
        if features.dim() >= 3:
            return torch.flatten(features, start_dim=1)
        return features.reshape(1, -1)

    def forward(self, apt, pep):
        """Return sigmoid scores of shape (batch, 1) for the given pairs."""
        # The Sequential stacks end in ReLU already, so the original's extra
        # ReLU here was a no-op. The per-sample flatten replaces view(-1, 1).T,
        # which only supported batch size 1.
        apt = self._flatten(self.cnn_apt(apt))
        pep = self._flatten(self.cnn_pep(pep))
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

## Helper methods

In [None]:
def get_x_pmf(length=40, alphabet_size=4):
    """Return the pmf of an aptamer: ``(1 / alphabet_size) ** length``.

    The defaults reproduce the original hard-coded ``0.25 ** 40``. That value
    presumably corresponds to 40 positions, each drawn uniformly from a
    4-letter nucleotide alphabet. The parameters let other sequence lengths or
    alphabets reuse the same formula.
    """
    return (1.0 / alphabet_size) ** length

def update(x, y):
    """Return the global ``model``'s output for an (aptamer, peptide) pair.

    Both inputs are moved to the global ``device`` first. Autograd records the
    forward pass unless the caller is inside ``torch.no_grad()``, so the result
    can be back-propagated into the model parameters.

    The original set ``requires_grad=True`` on ``x`` and ``y``. Parameter
    gradients do not need that, and it mutated the caller's tensors. It also
    made every backward pass compute input gradients that nothing reads.
    """
    return model(x.to(device), y.to(device))

def get_log_out(dataset='train', num_samples=10000):
    """First term of the loss: mean ``log(model(apt, pep))`` over sampled pairs.

    Parameters
    ----------
    dataset : str
        ``'train'`` reads ``train_loss_samples``; any other value reads
        ``test_loss_samples``.
    num_samples : int
        Number of records to average over (default 10000, as before).

    Returns
    -------
    float
        The average log score.
    """
    dset = train_loss_samples if dataset == 'train' else test_loss_samples
    log_scores = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for _ in range(num_samples):
            # The original called pickle.load() without the file handle
            # (TypeError). Wrapping on EOF lets the function be called
            # repeatedly during training instead of exhausting the file.
            try:
                record = pickle.load(dset)
            except EOFError:
                dset.seek(0)
                record = pickle.load(dset)
            # Records are laid out as (apt, pep) but the original unpacked three
            # names; taking the first two fields works for either layout.
            apt, pep = record[0], record[1]
            out = model(apt.to(device), pep.to(device))
            log_scores.append(torch.log(out).cpu().numpy().flatten()[0])
    return np.average(log_scores)

def get_out_prime(ds="train", num_samples=10000):
    """Second term of the loss: importance-weighted mean model score.

    Reads ``(apt, pep, pmf, ind)`` records from ``prime_train_loss_samples``
    (``ds == "train"``, dataset size ``m``) or ``prime_test_loss_samples``
    (otherwise, size ``n - m``). Each score is multiplied by 2 when
    ``ind != 0``. Otherwise it is multiplied by
    ``2*L*px*pmf / (1 + L*px*pmf)``, where ``L`` is the dataset size and ``px``
    is ``get_x_pmf()``.

    Parameters
    ----------
    ds : str
        ``"train"`` or anything else for the test split.
    num_samples : int
        Number of records to average over (default 10000, as before).

    Returns
    -------
    float
        The average weighted score.
    """
    if ds == "train":
        dset, leng = prime_train_loss_samples, m
    else:
        dset, leng = prime_test_loss_samples, n - m
    x_pmf = get_x_pmf()  # constant; hoisted out of the loop
    weighted = []
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for _ in range(num_samples):
            # Wrap on EOF so repeated calls during training do not exhaust the file.
            try:
                apt, pep, pmf, ind = pickle.load(dset)
            except EOFError:
                dset.seek(0)
                apt, pep, pmf, ind = pickle.load(dset)
            out = model(apt.to(device), pep.to(device))
            if ind == 0:
                factor = (2*leng*x_pmf*pmf)/(1+leng*x_pmf*pmf)
            else:
                factor = 2
            weighted.append(out.cpu().numpy().flatten()[0] * factor)
    return np.average(weighted)

## Plotting functions

def plot_loss(train_loss, test_loss, i, j, lamb, gamma):
    """Draw the train (blue) and test (yellow) loss curves and show the figure.

    The title records the iteration ``i``, the epoch count ``j``, and the
    current ``lamb`` / ``gamma`` values to five decimals.
    """
    curves = ((train_loss, 'b', 'Train loss'),
              (test_loss, 'y', 'Test loss'))
    for values, colour, legend_name in curves:
        plt.plot(values, colour, label=legend_name)
    plt.ylabel("Loss")
    plt.xlabel("Number of iterations")
    title_parts = ['Loss after ', str(i), " iterations, ", str(j), " epochs, ",
                   'lambda =%.5f' % lamb, ' gamma =%.5f' % gamma]
    plt.title("".join(title_parts))
    plt.legend()
    plt.show()

def plot_recall(train_recall, test_recall, new_recall, i, j, lamb, gamma):
    """Draw train (blue), test (yellow) and unseen (red) recall curves in percent.

    The title records the iteration ``i``, the epoch count ``j``, and the
    current ``lamb`` / ``gamma`` values to five decimals.
    """
    curves = ((train_recall, 'b', 'Train recall'),
              (test_recall, 'y', 'Test recall'),
              (new_recall, 'r', 'New recall'))
    for values, colour, legend_name in curves:
        plt.plot(values, colour, label=legend_name)
    plt.ylabel("Recall (%)")
    plt.xlabel("Number of iterations")
    progress = 'Recall after ' + str(i) + " iterations, " + str(j) + " epochs, "
    settings = 'lambda =%.5f' % lamb + ' gamma =%.5f' % gamma
    plt.title(progress + settings)
    plt.legend()
    plt.show()

def ecdf_points(scores, num_new, new_colour, known_colour):
    """Return ``(cdf, sorted_scores, colours)`` for an empirical-CDF scatter.

    ``scores`` is ``[generated "new" scores] + [dataset scores]``. Points whose
    original index is ``>= num_new`` are dataset samples and get
    ``known_colour``; the rest get ``new_colour``. ``colours`` is a plain list
    with one colour per point, in sorted order.
    """
    scores = np.asarray(scores)
    order = np.argsort(scores)
    colours = np.where(order >= num_new, known_colour, new_colour).tolist()
    count = scores.size
    cdf = np.arange(1, count + 1) / count
    return cdf, scores[order], colours


def _plot_ecdf(scores, i, j, lamb, gamma, num_new, new_colour, known_colour, label):
    """Scatter sorted scores against their CDF position and show the figure."""
    cdf, ordered, colours = ecdf_points(scores, num_new, new_colour, known_colour)
    # A per-point colour list is used because recent matplotlib rejects a
    # single string of concatenated colour letters (e.g. "ggyy") for ``c``,
    # which is what the original built.
    plt.scatter(cdf, ordered, c=colours, label=label)
    plt.xlabel("CDF")
    plt.ylabel("Most recent 10,000 samples")
    plt.title('CDF after ' + str(i) + " iterations, " + str(j) + " epochs, " + 'lambda =%.5f' % lamb  + ' gamma =%.5f' % gamma)
    plt.legend()
    plt.show()


def plot_ecdf_test(test_score, i, j, lamb, gamma, num_new=10000):
    """Empirical CDF of ``[new scores] + [test scores]``.

    New samples are drawn green and test samples yellow. ``num_new`` is the
    number of leading new scores (default 10000, the original hard-coded
    split point).
    """
    _plot_ecdf(test_score, i, j, lamb, gamma, num_new,
               new_colour="g", known_colour="y", label='Test CDF')


def plot_ecdf_train(train_score, i, j, lamb, gamma, num_new=10000):
    """Empirical CDF of ``[new scores] + [training scores]``.

    ``train_score`` consists of up to 10000 generated scores followed by up to
    1000 scores from the training set. New samples are drawn blue and
    training samples red. ``num_new`` is the number of leading new scores
    (default 10000, the original hard-coded split point).
    """
    _plot_ecdf(train_score, i, j, lamb, gamma, num_new,
               new_colour="b", known_colour="r", label='Train CDF')

def histogram(eval_scores, train_scores, test_scores):
    """Overlay seaborn distribution plots of new, train and test scores.

    The x-axis is fixed to [0, 1.1] and the figure is resized to 7x4 inches.
    """
    fig, ax = plt.subplots()
    ax.set_xlim(0, 1.1)
    series = (
        (eval_scores, "skyblue", 'New: not in dataset'),
        (train_scores, "gold", 'Train: in dataset'),
        (test_scores, 'red', 'Test: in the dataset'),
    )
    for values, colour, legend_name in series:
        sns.distplot(values, color=colour, label=legend_name, ax=ax)
    ax.set_title("Distribution of Scores")
    fig.set_size_inches(7, 4)
    ax.legend()
    plt.show()

## SGD

In [None]:
def _reset_weights(module):
    """Re-initialise ``module`` with PyTorch's default scheme if it defines one.

    Passed to ``model.apply`` for a fresh start. The original applied
    ``weights_init``, which is not defined anywhere in this notebook, so a run
    without a checkpoint raised NameError.
    """
    reset = getattr(module, 'reset_parameters', None)
    if callable(reset):
        reset()


def _load_cyclic(handle):
    """Unpickle the next record from ``handle``, rewinding to the start on EOF."""
    try:
        return pickle.load(handle)
    except EOFError:
        handle.seek(0)
        return pickle.load(handle)


def _first_value(tensor):
    """Return the first element of a model output as a NumPy scalar."""
    return tensor.detach().cpu().numpy().flatten()[0]


def sgd(epochs=(1, 2, 3),
        lamb=(10, 10, 10),
        gamma=(1e-3, 1e-4, 1e-5),
        run_from_checkpoint=None,
        save_checkpoints=None):
    """Train the global ``model`` with plain SGD over one or more phases.

    Phase ``k`` runs ``epochs[k]`` epochs with step size ``gamma[k]`` and
    hyperparameter ``lamb[k]``; ``lamb`` and ``gamma`` must be at least as long
    as ``epochs``. Each epoch is one pass over the ``m`` records of
    ``S_train``. Every 10 steps the train/test losses and recalls are
    recorded; every 200 steps they are plotted.

    Parameters
    ----------
    epochs, lamb, gamma : sequences
        Epoch count, lambda hyperparameter and step size for each phase.
    run_from_checkpoint : str or None
        Path to a checkpoint written by this function. ``None`` re-initialises
        the model weights instead.
    save_checkpoints : str or None
        File to overwrite with a checkpoint after every epoch. ``None``
        disables saving.
    """
    if run_from_checkpoint is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only box.
        checkpoint = torch.load(run_from_checkpoint, map_location=device)
        optim = SGD(model.parameters(), lr=gamma[0])
        model.load_state_dict(checkpoint['model_state_dict'])
        optim.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Reloading model: " + str(model.name) + " at epoch: " + str(checkpoint['epoch']))
        # NOTE: a fresh optimizer is built for every phase below, so this
        # reloaded optimizer state is not reused (plain SGD keeps no
        # per-parameter state anyway).
    else:
        model.apply(_reset_weights)

    model.to(device)

    train_losses = []
    test_losses = []
    total_epochs = 0
    for phase, phase_epochs in enumerate(epochs):
        step_size = gamma[phase]
        weight = lamb[phase]
        optim = SGD(model.parameters(), lr=step_size)
        for epoch in range(1, phase_epochs + 1):
            train_recalls, test_recalls, new_recalls = [], [], []
            train_recall_outputs, test_recall_outputs, new_outputs = [], [], []
            train_correct = test_correct = new_correct = 0
            print("Training Epoch: ", total_epochs)
            # Each epoch is a fresh pass over the m training records. Without
            # the rewind, the second epoch hit EOF immediately.
            S_train.seek(0)
            x_pmf = get_x_pmf()
            for step in range(1, m + 1):
                apt, pep, apt_prime, pep_prime, pep_prime_pmf, indicator = pickle.load(S_train)

                # --- one SGD step on the pair and its importance-sampled partner
                model.train()
                optim.zero_grad()
                out = update(apt, pep)  # S_train output/score
                log_out = torch.log(out)

                train_score = _first_value(out)
                if train_score > 0.6:
                    train_correct += 1
                train_recall_outputs.append(train_score)

                out_prime = update(apt_prime, pep_prime)  # score on the S_prime_train sample
                if indicator == 0:
                    factor = (2*m*x_pmf*pep_prime_pmf)/(1+m*x_pmf*pep_prime_pmf)
                else:
                    factor = 2
                out_prime = out_prime*factor  # adjust for importance sampling
                # NOTE(review): multiplying by `indicator` zeroes this term for
                # indicator == 0, even though a factor is computed for that
                # case. get_out_prime() weights such records by the factor
                # instead. Confirm which objective is intended.
                (weight*indicator*out_prime - log_out).backward()
                optim.step()

                # --- evaluation: one test pair and ten unseen pairs. The
                # original wrapped only model.eval() in no_grad, left these
                # inputs on the CPU, and rebinding S_test here made it a local
                # name (UnboundLocalError).
                model.eval()
                with torch.no_grad():
                    apt_test, pep_test = _load_cyclic(S_test)
                    test_score = _first_value(model(apt_test.to(device), pep_test.to(device)))
                    test_recall_outputs.append(test_score)
                    if test_score > 0.6:
                        test_correct += 1

                    # 10 unseen examples from S_new per training example, for the CDFs
                    for _ in range(10):
                        apt_new, pep_new = _load_cyclic(S_new)
                        new_score = _first_value(model(apt_new.to(device), pep_new.to(device)))
                        new_outputs.append(new_score)
                        if new_score < 0.3:
                            new_correct += 1

                if step % 10 == 0:
                    with torch.no_grad():
                        train_loss = weight*get_out_prime("train") - get_log_out('train')
                        test_loss = (m/(n-m))*weight*get_out_prime("test") - get_log_out('test')
                    train_losses.append(train_loss)
                    test_losses.append(test_loss)

                    # The hit counters restart every epoch, so divide by this
                    # epoch's step count. The original divided by
                    # total_epochs*m + step, which deflated recall after epoch 1.
                    train_recalls.append(100*train_correct/step)
                    test_recalls.append(100*test_correct/step)
                    new_recalls.append(100*new_correct/(step*10))

                    # CDF inputs: up to 10000 recent unseen scores followed by
                    # up to 1000 recent dataset scores.
                    recent_new = new_outputs[-10000:]
                    train_cdf_scores = np.asarray(recent_new + train_recall_outputs[-1000:])
                    test_cdf_scores = np.asarray(recent_new + test_recall_outputs[-1000:])

                if step % 200 == 0:
                    plot_recall(train_recalls, test_recalls, new_recalls, step, total_epochs, weight, step_size)
                    plot_loss(train_losses, test_losses, step, total_epochs, weight, step_size)
                    # NOTE: plot_ecdf_* treats indices >= 10000 as dataset
                    # samples. Before step 1000 fewer than 10000 unseen scores
                    # exist, so those early plots colour every point as unseen.
                    plot_ecdf_train(train_cdf_scores, step, total_epochs, weight, step_size)
                    plot_ecdf_test(test_cdf_scores, step, total_epochs, weight, step_size)
                    histogram(new_outputs[-1000:], train_recall_outputs[-1000:], test_recall_outputs[-1000:])
                    print("New score: ", np.average(new_outputs[-100:]))
                    print("Train score: ", np.average(train_recall_outputs[-100:]))
                    print("Test score: ", np.average(test_recall_outputs[-100:]))

            # Save after every epoch
            total_epochs += 1
            if save_checkpoints is not None:
                print("Saving to: ", save_checkpoints)
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optim.state_dict()}, save_checkpoints)


## Hyperparameter tuning

In [None]:
# Hyperparameter search: entry k of each list configures training phase k of sgd()
gamma = [1e-2, 1e-3, 1e-4, 1e-5]  # SGD step size per phase
lamb = [10] * len(gamma)  # lambda: weight on the importance-sampled out_prime term
EPOCHS = [2] + [3] * 3  # number of epochs per phase

In [None]:
# ConvNetComplex is marked as too complex for our input sequence size. With
# presumably (40, 4) aptamers (LinearNet flattens 160 features), its second
# MaxPool1d receives a length-1 sequence and the first forward pass fails.
# Train the simple CNN instead; change the class here to compare models.
model = ConvNetSimple()
checkpoint = None  # path of a checkpoint to resume from; None starts from fresh weights
SAVE_CHECKPOINTS = False  # the original computed save_path but passed save_checkpoints=None
save_path = 'model_checkpoints/' + model.name + '/04202020.pth'
if SAVE_CHECKPOINTS:
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

sgd(epochs=EPOCHS, lamb=lamb, gamma=gamma, run_from_checkpoint=checkpoint,
    save_checkpoints=save_path if SAVE_CHECKPOINTS else None)

## Relevance of learned motifs

In [None]:
# checkpointed_model = '../models/model_checkpoints/mle_model.pth'
# checkpoint = torch.load(checkpointed_model)
# model = Conv1dModelSimple()
# optim = SGD(model.parameters(), lr=1e-3)
# model.load_state_dict(checkpoint['model_state_dict'])
# optim.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# model.to(device)
# print(str(epoch))

In [None]:
# print(str(len(S_prime_test)))
# print(str(len(S_test)))

In [None]:
# Validation set is S_prime_test and S_test
# validation_set = []
# for (apt, pep), label in S_prime_test[:118262]:
#     validation_set.append((apt, pep, label))

# for (apt, pep) in S_test[:4000]:
#     validation_set.append((apt, pep, 0))

# np.random.shuffle(validation_set)

In [None]:
# validation_set[0]

In [None]:
# correct = 0
# hydrophobicity_binding = []
# hydrophobicity_free = []
# arginine_binding = []
# arginine_free = []

# for (apt, pep, label) in validation_set:
#     if 'Conv1' in model.name:
#         conv_type='1d'
#     else:
#         conv_type='2d'
#     x, y = convert(apt, pep, conv_type=conv_type)
#     score = model(x, y).cpu().detach().numpy().flatten()[0]
#     hp = 0
#     for aa in pep:
#         hp += hydrophobicity[aa]
    
#     if score < 0.3:
#         hydrophobicity_free.append(hp)
#         arginine_free.append(pep.count('R'))
#     elif score > 0.6:
#         hydrophobicity_binding.append(hp)
#         arginine_binding.append(pep.count('R'))


In [None]:
# print("Average Hydrophobicity of binding peptides: ", np.mean(np.asarray(hydrophobicity_binding)))
# print("Average Hydrophobicity of non-binding peptides: ", np.mean(np.asarray(hydrophobicity_free)))
# print("Average Number of Arginines in binding peptides: ", np.mean(np.asarray(arginine_binding)))
# print("Average Number of Arginines in non-binding peptides: ", np.mean(np.asarray(arginine_free)))

In [None]:
# plt.hist(hydrophobicity_binding, bins=10, label='Hydrophobicity of Binding Peptides')
# plt.hist(hydrophobicity_free, bins=10 , label='Hydrophobicity of Non-Binding Peptides')
# plt.ylabel("Density")
# plt.xlabel("Hydrophobicity Score")
# plt.title('Hydrophobicity of Test Set Outputs')
# plt.legend()
# plt.show()

In [None]:
# Arginine content
# plt.hist(arginine_binding, bins=8, label='Number of arginines in binding peptides')
# plt.hist(arginine_free, bins=8 , label='Number of arginines in non-binding peptides')
# plt.ylabel("Density")
# plt.xlabel("Number of Arginines")
# plt.title('Arginine Count in Peptides')
# plt.legend()
# plt.show()