In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import itertools

## Preliminary

In [3]:
EPOCHS = 5
N_GRAM = 3
EMBED_DIM = 32
k = 10000 # number of samples used to calculate loss

In [4]:
torch.manual_seed(12345)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
na_list = ['A', 'C', 'G', 'T'] #nucleic acids for aptamer
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids for peptide
hydrophobicity = {'G': 0, 'A': 41, 'L':97, 'M': 74, 'F':100, 'W':97, 'K':-23, 'Q':-10, 'E':-31, 'S':-5, 'P':-46, 'V':76, 'I':99, 'C':49, 'Y':63, 'H':8, 'R':-14, 'N':-28, 'D':-55, 'T':13}
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13 #freq of 21 NNK codons including the stop codon
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12 #sum of freq without the stop codon
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*12 #normalize freq for 20 codons
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11 + \
        [1- sum([0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11)] 
        #adjust sum to 1 due to numerical issue
aa_dict = dict(zip(aa_list, pvals))

## Dataset, Sampling & Vocab

In [21]:
def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    ds = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        if aptamer == "CTTTGTAATTGGTTCTGAGTTCCGTTGTGGGAGGAACATG": #took out aptamer control
            continue
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "") #removed stop codons
            if "RRRRRR" in peptide: #took out peptide control
                continue
            if len(aptamer) == 40 and len(peptide) == 8: #making sure right length
                ds.append((aptamer, peptide))
    ds = list(set(ds)) #removed duplicates
    return ds


# Get dictionary of {token: index}
def get_vocab():
    def generator(letters):
        letters = "".join(letters)
        for n in range(1, N_GRAM+1):
            for item in itertools.product(letters, repeat=n):
                yield "".join(item)
    a = [i for i in generator(na_list)]
    p = [i for i in generator(aa_list)]
    vocab_apt = {a[i]: i+1 for i in range(len(a))}
    vocab_pep = {p[i]: i+1 for i in range(len(p))}
    return vocab_apt, vocab_pep


# Sample x from P_X (assume apatamers follow uniform)
def get_x():
    x_idx = np.random.randint(0, 4, 40)
    x = ""
    for i in x_idx:
        x += na_list[i]
    return x


# Sample y from P_y (assume peptides follow NNK)
def get_y():
    y_idx = np.random.choice(20, 7, p=pvals)
    y = "M"
    for i in y_idx:
        y += aa_list[i]
    return y


# S'(train/test) contains S_train/S_test with double the size of S_train/S_test
def get_S_prime(kind="train"):
    if kind == "train":
        dset = S_train
    else:
        dset = S_test
    k = len(dset)
    S_prime_dict = dict.fromkeys(dset, 0) #indicator 0 means in S
    for _ in range(k):
        pair = (get_x(), get_y())
        S_prime_dict[pair] = 1 #indicator 1 means not in S
    S_prime = [[k,int(v)] for k,v in S_prime_dict.items()] 
    np.random.shuffle(S_prime)
    return S_prime


# S new contains unseen new examples
def get_S_new(k):
    S_new = []
    for i in range(k):
        pair = (get_x(), get_y())
        S_new.append(pair)
    np.random.shuffle(S_new)
    return S_new
    
    
# Returns pmf of an aptamer
def get_x_pmf():
    return 0.25**40


# Returns pmf of a peptide
def get_y_pmf(y):
    pmf = 1
    for char in y[1:]: #skips first char "M"
        pmf *= aa_dict[char]
    return pmf    

In [22]:
vocab_apt, vocab_pep = get_vocab()
VOCAB_SIZE_APT = len(vocab_apt)
VOCAB_SIZE_PEP = len(vocab_pep)

In [7]:
aptamer_dataset_file = "../data/aptamer_dataset.json"
S = construct_dataset()
n = len(S)
m = int(0.8*n) #length of S_train
S_train = S[:m]
S_test = S[m:]
S_prime_train = get_S_prime("train") #use for sgd 
S_prime_test = get_S_prime("test") #use for sgd 
S_new = get_S_new(10*n) #use for eval
train_ds = np.hstack((S_train, S_prime_train[:len(S_prime_train)//2]))

## N-gram model

In [122]:
class NGram(nn.Module):
    def __init__(self, vocab_size_apt, vocab_size_pep, embed_dim):
        super().__init__()
        self.embedding_apt = nn.Embedding(vocab_size_apt, embed_dim, sparse=True)
        self.embedding_pep = nn.Embedding(vocab_size_pep, embed_dim, sparse=True)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(embed_dim, 128)
        self.fc2 = nn.Linear(256, 1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding_apt.weight.data.uniform_(-initrange, initrange)
        self.embedding_pep.weight.data.uniform_(-initrange, initrange)
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)
        self.fc1.bias.data.zero_()
        self.fc2.bias.data.zero_()

    def forward(self, apt, pep):
        embed_apt = self.embedding_apt(apt)
        fc_apt = self.fc1(embed_apt)
        embed_pep = self.embedding_pep(pep)
        fc_pep = self.fc1(embed_pep)
        apt = fc_apt.view(-1, 1).T
        pep = fc_pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = self.fc2(x)
        return torch.sigmoid(x)

## Helper methods

In [116]:
# Iterates the broken-down tokens of the given sequence with N_GRAM
def ngrams_iterator(seq):
    for char in seq:
        yield char
    for n in range(2, N_GRAM + 1):
        for char in zip(*[seq[i:] for i in range(n)]):
            yield ''.join(char)
            

# Convert a list of ngrams to a float tensor
def convert_ngram(apt, pep):
    apt = torch.Tensor([vocab_apt[i] for i in ngrams_iterator(apt)])
    pep = torch.Tensor([vocab_pep[i] for i in ngrams_iterator(pep)])
    return apt, pep


# Getting the output of the ngram model for a pair (aptamer, peptide)
def get_ngram_out(x, y):
    x.requires_grad=True
    y.requires_grad=True
    x = x.cuda()
    y = y.cuda()
    out = model(x, y)
    return out


# Generates the samples used to calculate loss
def loss_samples(k, ds='train'): # S_train/S_test
    if ds == 'train':
        dset = S_train
    else:
        dset = S_test
    pairs = []
    for (apt, pep) in dset[:k]:
        x, y = convert_ngram(apt, pep)
        pairs.append((x, y))
    return pairs


# Generates the samples used to calculate loss from S_prime_train/S_prime_test
def prime_loss_samples(k, ds='train'):
    if ds == "train":
        dset = S_prime_train[len(S_prime_train)//2:]    
    else:
        dset = S_prime_test[len(S_prime_test)//2:]
    pairs = []
    for (apt, pep), ind in dset[:k]:
        pmf = get_y_pmf(pep)
        x, y = convert_ngram(apt, pep)
        pairs.append((x, y, ind, pmf))
    return pairs


# First term of the loss
def get_log_out(dataset='train'):
    outs = []
    if dataset == 'train':
        dset = train_loss_samples
    else:
        dset = test_loss_samples
    for (apt, pep) in dset:
        out = get_ngram_out(apt, pep)
        outs.append(torch.log(out).cpu().detach().numpy().flatten()[0])
    return np.average(outs)


# Second term of loss
def get_out_prime(ds="train"):
    outs = []
    if ds == "train":
        dset = prime_train_loss_samples
        leng = m
    else:
        dset = prime_test_loss_samples
        leng = n-m
    for (apt, pep, ind, pmf) in dset:
        x = apt.cuda()
        y = pep.cuda()
        out = model(x, y)
        if ind == 0:
            factor = (2*leng*get_x_pmf()*pmf)/(1+leng*get_x_pmf()*pmf)
        else:
            factor = 2
        out_is = out.cpu().detach().numpy().flatten()[0] * factor
        outs.append(out_is)
    return np.average(outs)


## Plotting functions
def plot_loss(train_loss, test_loss, i, j, lamb, gamma):
    plt.plot(train_loss, 'b', label='Train loss')
    plt.plot(test_loss, 'y', label='Test loss')
    plt.ylabel("Loss")
    plt.xlabel("Number of iterations")
    plt.title('Loss after ' +  str(i) + " iterations, " + str(j) + " epochs, " + 'lambda =%.5f' % lamb  + ' gamma =%.5f' % gamma)
    plt.legend()
    plt.show()

def plot_recall(train_recall, test_recall, new_specificity, i, j, lamb, gamma):
    plt.plot(train_recall, 'b', label='Train recall')
    plt.plot(test_recall, 'y', label='Test recall')
    plt.plot(new_specificity, 'r', label='New specificity')
    plt.ylabel("Recall (%)")
    plt.xlabel("Number of iterations")
    plt.title('Recall/specificity after ' + str(i) + " iterations, " + str(j) + " epochs, " + 'lambda =%.5f' % lamb  + ' gamma =%.5f' % gamma)
    plt.legend()
    plt.show()
    
def plot_cdf(train_cdf, test_cdf, i, j, lamb, gamma):
    plt.plot(train_cdf, 'b', label='Train CDF')
    plt.plot(test_cdf, 'y', label='Test CDF')
    plt.ylabel("CDF")
    plt.xlabel("Most recent 10,000 samples")
    plt.title('CDF after ' + str(i) + " iterations, " + str(j) + " epochs, " + 'lambda =%.5f' % lamb  + ' gamma =%.5f' % gamma)
    plt.legend()
    plt.show()

def histogram(eval_scores, train_scores, test_scores):
    f, axes = plt.subplots(2, 2, figsize=(7, 7), sharex=True)
    plt.xlim(0, 1.1)
    plt.ylim(0,)
    sns.distplot(eval_scores , color="skyblue", label='New: not in dataset', ax=axes[0, 0])
    sns.distplot(train_scores , color="gold", label='Train: in dataset', ax=axes[1, 0])
    sns.distplot(test_scores, color='red', label='Test: in the dataset', ax=axes[0, 1])
    axes[0,0].set_title("New: not in dataset")
    axes[1,0].set_title("Train: in dataset")
    axes[0,1].set_title("Test: in dataset")
    plt.show()

In [117]:
train_loss_samples = loss_samples(k, 'train')
test_loss_samples = loss_samples(k, 'test')
prime_train_loss_samples = prime_loss_samples(k, 'train')
prime_test_loss_samples = prime_loss_samples(k, 'test')

## SGD

In [123]:
def sgd(model_name,
        lamb=10, #hyperparam
        gamma=1e-3, #step size
        save_checkpoints=False): #save checkpoints
    
    optim = SGD(model.parameters(), lr=gamma)
    #offsets = torch.LongTensor([0])
    for epoch in range(EPOCHS):
        train_losses = []
        train_recalls = []
        train_recall_outputs = [] 

        test_losses = []
        test_recalls = []
        test_recall_outputs = []

        new_outputs = []
        new_specificity = []

        train_correct = 0
        test_correct = 0
        new_correct = 0
        
        for i, (aptamer, peptide, (aptamer_prime, peptide_prime), indicator) in enumerate(tqdm.tqdm(train_ds)):
            if i == 0:
                continue
            model.train()
            optim.zero_grad() 
            apt, pep = convert_ngram(aptamer, peptide)
            out = get_ngram_out(apt, pep) #get S_train output/score
            log_out = torch.log(out) 
            
            train_score = out.cpu().detach().numpy().flatten()[0] 
            if train_score > 0.6:
                train_correct += 1 
            train_recall_outputs.append(train_score) 

            optim.zero_grad() 
            y_pmf = get_y_pmf(peptide_prime)
            apt_prime, pep_prime = convert_ngram(aptamer_prime, peptide_prime)
            out_prime = get_ngram_out(apt_prime, pep_prime) #get score from S_prime_train
            if indicator == 0:
                factor = (2*m*get_x_pmf()*y_pmf)/(1+m*get_x_pmf()*y_pmf)
            else:
                factor = 2
            out_prime = out_prime*factor #adjust for IS
            print("Obj first part: ", out_prime.cpu().detach().numpy().flatten()[0]*lamb*indicator)
            print("Obj second part: ", log_out.cpu().detach().numpy().flatten()[0])
            # Retain graph retains the graph for further operations
            (lamb*indicator*out_prime - log_out).backward(retain_graph=True) 
            optim.step()

            with torch.no_grad():
                model.eval()
            
            test_score = model(S_test[i%(n-m)][0], S_test[i%(n-m)][1]).cpu().detach().numpy().flatten()[0]
            test_recall_outputs.append(test_score) 
            if test_score > 0.6:
                test_correct += 1 

            #generate 10 unseen examples from S_new as compared 1 example from S_train/S_test for cdfs
            for x, y in S_new[10*i:10*(i+1)]:
                new_score = model(x, y).cpu().detach().numpy().flatten()[0] #get unknown score
                new_outputs.append(new_score)
                if new_score < 0.3:
                    new_correct += 1

            if i % 50 == 0:
                train_loss = lamb*get_out_prime("train") - get_log_out('train') #training loss
                print("Train loss first part: ", lamb*get_out_prime("train"))
                print("Train loss second part: ", get_log_out('train'))
                test_loss = (m/(n-m))*lamb*get_out_prime("test") - get_log_out('test') #test loss
                print("Test loss first part: ", lamb*get_out_prime("test"))
                print("Test loss second part: ", get_log_out('test'))
                train_losses.append(train_loss)
                test_losses.append(test_loss)

                train_recall = 100*train_correct/i #training recall
                train_recalls.append(train_recall) 
                test_recall = 100*test_correct/i #test recall
                test_recalls.append(test_recall)
                new_specificity = 100*new_correct/(i*10) #generated dataset specificity
                new_specificity.append(new_specificity)
                if i > 1000:
                    train_score = np.asarray(new_outputs[-10000:] + train_recall_outputs[-1000:]) 
                    test_score = np.asarray(new_outputs[-10000:] + test_recall_outputs[-1000:])
                else:
                    train_score = np.asarray(new_outputs + train_recall_outputs) #combine train and unknown scores
                    test_score = np.asarray(new_outputs + test_recall_outputs) #combibne test and unknown scores
                train_cdf = np.cumsum(train_score)/np.sum(train_score) #train cdf
                test_cdf = np.cumsum(test_score)/np.sum(test_score) #test cdf


            if i % 1000 == 0:
                plot_recall(train_recalls, test_recalls, new_specificity, i, epoch, lamb, gamma)
                plot_loss(train_losses, test_losses, i, epoch, lamb, gamma)
                plot_cdf(train_cdf, test_cdf, i, epoch, lamb, gamma)
                histogram(new_outputs[-1000:], train_recall_outputs[-1000:], test_recall_outputs[-1000:])
                print("New score: ", np.average(new_outputs[-1000:]))
                print("Train score: ", np.average(train_score[-1000:]))
                print("Test score: ", np.average(test_score[-1000:]))
        # Save after every epoch
        if save_checkpoints:
            checkpoint_name = '../models/model_checkpoints/' + str(model_name) + '_lambda=' + str(lamb) + '_gamma=' + str(gamma) + '.pth'
            torch.save({'epoch': epoch,'model_state_dict': model.state_dict(), 'optimizer_state_dict': optim.state_dict()}, checkpoint_name)   

## Hyperparameter tuning

In [124]:
# Hyperparameter search
gammas = [1e-3]
lambdas = [10, 5]

In [125]:
for g in gammas:
    for l in lambdas:
        model = NGram(VOCAB_SIZE_APT, VOCAB_SIZE_PEP, EMBED_DIM).to(device)
        sgd(model_name="ngram", gamma=g, lamb=l, save_checkpoints=True)

  0%|          | 0/473047 [00:00<?, ?it/s]


RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)