In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Seed every RNG the notebook draws from so Restart & Run All is reproducible.
SEED = 12345
torch.manual_seed(SEED)
np.random.seed(SEED)  # get_x / get_y / get_xy sample with NumPy's global RNG
random.seed(SEED)     # get_S_prime shuffles with the stdlib random module
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

## Preliminary

In [3]:
na_list = ['A', 'C', 'G', 'T']  # nucleic acids
# Amino acids, ordered by NNK codon multiplicity: 3 codons (R, L, S), 2 codons
# (A, G, P, T, V), then 1 codon each for the remaining 12.
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
# Frequencies of the 21 NNK codon classes out of 32 codons (3/32, 2/32, 1/32);
# the 21st entry is the stop codon.
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13
sum_20 = sum(NNK_freq[:20])  # total frequency without the stop codon (exactly 31/32)
# NNK distribution over the 20 amino acids (stop codon excluded). The last entry is
# 1 - sum(others) so rounding error is absorbed there; np.random.choice(p=pvals)
# rejects probabilities that do not sum to 1.
pvals = [freq / sum_20 for freq in NNK_freq[:19]]
pvals.append(1 - sum(pvals))
aa_dict = dict(zip(aa_list, pvals))

## Dataset

In [4]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset(path=None, seed=12345):
    """Load unique (aptamer, peptide) pairs from the aptamer JSON dataset.

    Assumes the JSON maps each aptamer sequence to a list of 2-element entries whose
    first element is the peptide (the second is ignored) — inferred from the
    unpacking below; confirm against the data file.

    Dropped: the control aptamer, peptides containing the "RRRRRR" control motif,
    aptamers that are not 40 nt, peptides that are not 8 residues after stop-codon
    ("_") removal, and duplicate pairs.

    Args:
        path: JSON file to read; defaults to ``aptamer_dataset_file``.
        seed: seed for the shuffle applied to the de-duplicated pairs, so the
            train/test split taken by slicing the result is reproducible.
            ``None`` returns the pairs in sorted order without shuffling.

    Returns:
        list of unique ``(aptamer, peptide)`` tuples.
    """
    if path is None:
        path = aptamer_dataset_file
    with open(path, 'r') as f:
        aptamer_data = json.load(f)
    pairs = set()
    for aptamer, peptides in aptamer_data.items():
        if aptamer == "CTTTGTAATTGGTTCTGAGTTCCGTTGTGGGAGGAACATG": #took out aptamer control
            continue
        if len(aptamer) != 40: #making sure right length
            continue
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "") #removed stop codons
            if "RRRRRR" in peptide: #took out peptide control
                continue
            if len(peptide) == 8: #making sure right length
                pairs.add((aptamer, peptide))
    # Iteration order of a set of strings depends on per-process hash randomisation,
    # so list(set(...)) produced a different train/test split on every kernel restart.
    full_dataset = sorted(pairs)
    if seed is not None:
        # Private RNG: a reproducible shuffle that leaves the global random state untouched.
        random.Random(seed).shuffle(full_dataset)
    return full_dataset

In [5]:
class TrainDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``(aptamer, peptide)`` training pairs.

    Items are returned as raw strings; encoding happens later in the training loop.
    """

    def __init__(self, training_set):
        super().__init__()
        self.training_set = training_set

    def __len__(self):
        return len(self.training_set)

    def __getitem__(self, idx):
        (aptamer, peptide) = self.training_set[idx]
        return (aptamer, peptide)
    
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``(aptamer, peptide)`` held-out pairs.

    Items are returned as raw strings; encoding happens later in the evaluation loop.
    """

    def __init__(self, test_set):
        super().__init__()
        self.test_set = test_set

    def __len__(self):
        return len(self.test_set)

    def __getitem__(self, idx):
        (aptamer, peptide) = self.test_set[idx]
        return (aptamer, peptide)

In [6]:
# 80/20 train/test split by position in full_dataset, each wrapped in a
# batch-size-1 DataLoader (the model code indexes element [0] of each batch).
full_dataset = construct_dataset()
n = len(full_dataset)
split_idx = int(0.8 * n)
training_set, test_set = full_dataset[:split_idx], full_dataset[split_idx:]
train_dataset, test_dataset = TrainDataset(training_set), TestDataset(test_set)
train_loader = torch.utils.data.DataLoader(train_dataset)
test_loader = torch.utils.data.DataLoader(test_dataset)

## One-hot encoding

In [7]:
## Takes a peptide or aptamer sequence and converts it to a one-hot matrix
def one_hot(sequence, seq_type='peptide', alphabet=None):
    """One-hot encode ``sequence`` as a float array of shape ``(len(sequence), len(alphabet))``.

    Args:
        sequence: string (or sequence of symbols) to encode.
        seq_type: 'peptide' selects ``aa_list``; any other value selects ``na_list``.
        alphabet: optional explicit symbol list; overrides ``seq_type`` when given.

    Returns:
        numpy float64 array with a single 1 per row at the symbol's alphabet index.

    Raises:
        ValueError: if a symbol is not in the alphabet.
    """
    if alphabet is None:
        alphabet = aa_list if seq_type == 'peptide' else na_list
    # Dict lookup instead of list.index per character; setdefault keeps the first
    # occurrence, matching list.index semantics if the alphabet had duplicates.
    position = {}
    for i, letter in enumerate(alphabet):
        position.setdefault(letter, i)
    encoded = np.zeros((len(sequence), len(alphabet)))
    for row, char in enumerate(sequence):
        if char not in position:
            raise ValueError(
                f"symbol {char!r} at position {row} is not in the {seq_type} alphabet")
        encoded[row, position[char]] = 1
    return encoded

## NN Models

In [8]:
class SimpleConvNet(nn.Module):
    """Two-branch CNN that scores an (aptamer, peptide) pair with a value in (0, 1).

    Expected inputs are one-hot tensors of shape ``(batch, 1, 40, 4)`` for the
    aptamer and ``(batch, 1, 8, 20)`` for the peptide (as produced by ``convert``).
    Each branch is a single convolution whose kernel covers 3 consecutive positions
    and the whole alphabet, i.e. a learned 3-gram detector.
    """

    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 5, (3,4)) #similar to 3-gram
        self.cnn_pep_1 = nn.Conv2d(1, 5, (3,20))
        self.relu = nn.ReLU()
        # 5 channels * 38 aptamer positions (40-3+1) + 5 channels * 6 peptide positions (8-3+1)
        self.fc1 = nn.Linear(220, 1)

    def forward(self, apt, pep):
        """Return sigmoid scores of shape ``(batch, 1)``."""
        apt = self.relu(self.cnn_apt_1(apt))  # (batch, 5, 38, 1)
        pep = self.relu(self.cnn_pep_1(pep))  # (batch, 5, 6, 1)
        # Flatten per sample. The previous view(-1, 1).T merged the whole batch into a
        # single row, which was only correct for batch size 1.
        apt = torch.flatten(apt, start_dim=1)
        pep = torch.flatten(pep, start_dim=1)
        x = torch.cat((apt, pep), dim=1)
        return torch.sigmoid(self.fc1(x))

In [9]:
def weights_init(m):
    """Initialise one module; intended for ``model.apply(weights_init)``.

    Conv2d layers get Xavier-uniform weights and Linear layers get Kaiming-uniform
    weights (ReLU gain); both have their bias zeroed. Other modules are untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    else:
        return
    nn.init.zeros_(m.bias.data)

## Sampling

In [33]:
# Sample x from P_X (aptamers are assumed uniform over A/C/G/T at each of 40 positions)
def get_x():
    """Draw one random 40-nt aptamer string using NumPy's global RNG."""
    base_indices = np.random.randint(0, 4, 40)
    return "".join(na_list[b] for b in base_indices)

# Sample y from P_Y (peptides follow NNK): a fixed leading "M" plus 7 random residues
def get_y():
    """Draw one random 8-residue peptide; residues 2-8 are i.i.d. from ``pvals``."""
    residue_indices = np.random.choice(20, 7, p=pvals)
    return "M" + "".join(aa_list[r] for r in residue_indices)

# Generate uniformly from S (or any given pool) without replacement
def get_xy(k, pool=None):
    """Return ``k`` distinct items drawn uniformly without replacement.

    Args:
        k: number of samples; np.random.choice raises ValueError if it exceeds ``len(pool)``.
        pool: sequence to sample from. Defaults to ``full_dataset`` (S); pass
            ``training_set`` to sample only from the training split.

    Returns:
        list of ``k`` items from ``pool``.
    """
    if pool is None:
        pool = full_dataset
    picks = np.random.choice(len(pool), k, replace=False)
    return [pool[i] for i in picks]

# S' = S plus k random (aptamer, peptide) pairs: the domain for importance sampling
def get_S_prime(k):
    """Build the importance-sampling domain S'.

    Args:
        k: number of new random pairs (from get_x / get_y) to add to ``full_dataset``.

    Returns:
        (S_prime, S_new):
        S_prime is a shuffled list of ``[pair, indicator]``; the indicator is 0 for
        pairs from ``full_dataset`` and 1 for generated pairs not in it.
        S_new is the list of the k generated pairs, in generation order.
    """
    S_prime_dict = dict.fromkeys(full_dataset, 0)  # indicator 0: in the original dataset
    S_new = []
    while len(S_new) < k:
        pair = (get_x(), get_y())
        # Skip collisions: assigning unconditionally relabelled a real binder (or a
        # previously generated pair) as new and duplicated it in S_new.
        if pair in S_prime_dict:
            continue
        S_prime_dict[pair] = 1  # indicator 1: not in the original dataset
        S_new.append(pair)
    S_prime = [[pair, int(flag)] for pair, flag in S_prime_dict.items()]
    random.shuffle(S_prime)
    return S_prime, S_new

# pmf of an aptamer under the uniform P_X (the same for every 40-nt aptamer)
def get_x_pmf():
    """Return the probability of any single 40-nt aptamer, (1/4) ** 40."""
    n_positions, n_bases = 40, 4
    return 1.0 / n_bases ** n_positions

# pmf of a peptide under the NNK P_Y; the fixed leading "M" contributes no factor
def get_y_pmf(y):
    """Return the product of ``aa_dict`` probabilities over every residue of ``y`` after the first."""
    prob = 1
    for pos in range(1, len(y)):
        prob *= aa_dict[y[pos]]
    return prob

# Build S' once: the dataset pairs plus n random pairs (n = len(full_dataset)).
# sgd() pairs its i-th training step with S_prime[i]; eval_unknown() scores S_new.
S_prime, S_new = get_S_prime(n)

## SGD

In [31]:
# Convert a pair to one-hot tensors
def convert(apt, pep):
    """One-hot encode an (aptamer, peptide) pair as model-ready float32 tensors.

    Returns:
        apt: tensor of shape (1, 1, len(apt), 4) on ``device``.
        pep: tensor of shape (1, 1, len(pep), 20) on ``device``.
    """
    apt = one_hot(apt, seq_type='aptamer')  # (40, 4)
    pep = one_hot(pep, seq_type='peptide')  # (8, 20)
    # Follow the configured device instead of hard-coding .cuda(), which fails on
    # CPU-only machines even though `device` falls back to CPU.
    apt = torch.as_tensor(apt[None, None], dtype=torch.float32, device=device)  # (1, 1, 40, 4)
    pep = torch.as_tensor(pep[None, None], dtype=torch.float32, device=device)  # (1, 1, 8, 20)
    return apt, pep

def update(x, y):
    """Run the global ``model`` on one encoded (aptamer, peptide) pair.

    Args:
        x: one-hot aptamer tensor from ``convert``.
        y: one-hot peptide tensor from ``convert``; its last dim indexes ``aa_list``.

    Returns:
        (pmf, out): the NNK probability of the peptide (``get_y_pmf``) and the model output.
    """
    # get_y_pmf expects a peptide string. Passing the one-hot tensor made y[1:] an
    # empty slice of the batch dimension, so the pmf silently came back as 1.
    residue_idx = y.reshape(-1, len(aa_list)).argmax(dim=1).tolist()
    pmf = get_y_pmf("".join(aa_list[j] for j in residue_idx))
    # No requires_grad on the inputs: callers only use parameter gradients.
    out = model(x.to(device), y.to(device))
    return pmf, out

def apply_param_grad(grads1, grads2, fn):
    """Combine two per-parameter gradient lists element-wise with ``fn(g1, g2)``.

    Pairs are formed with zip, so the result is as long as the shorter list.
    """
    return [fn(a, b) for a, b in zip(grads1, grads2)]

def sgd(t=1, #num of iter over the training set
        lamb=1e-1, #hyperparam
        gamma=1e-2): #step size
    """Train the global ``model`` with plain SGD on the importance-sampled objective.

    For the i-th training pair (x, y), the parameter update direction is
    ``lamb * indicator * grad(w * f(x', y')) - grad(log f(x, y))``. Here
    ``([x', y'], indicator) = S_prime[i]``, the indicator is 1 for pairs not in the
    dataset, and ``w = P_X(x') * P_Y(y') * 2n`` is the importance weight over S'.

    Args:
        t: number of passes over ``train_loader``.
        lamb: weight of the S' term.
        gamma: SGD learning rate.
    """
    optim = SGD(model.parameters(), lr=gamma)
    params = list(model.parameters())
    model.train()
    for _ in range(t):  # t was previously ignored (always a single pass)
        for i, (apt, pep) in enumerate(tqdm.tqdm(train_loader)):
            # Observed pair: gradient of log f(x, y). autograd.grad returns fresh
            # tensors; collecting param.grad references let a later zero_grad/backward
            # zero or overwrite them (g1 and g2 could alias the same tensors).
            x, y = convert(apt[0], pep[0])
            g1 = torch.autograd.grad(torch.log(model(x, y)).sum(), params)

            (apt_prime, pep_prime), const = S_prime[i]  # const: indicator, 1 = not in dataset
            if const:
                x_prime, y_prime = convert(apt_prime, pep_prime)
                # Importance weight from the peptide *string*; the one-hot tensor
                # previously made the peptide pmf silently equal 1.
                weight = get_y_pmf(pep_prime) * get_x_pmf() * 2 * n
                g2 = torch.autograd.grad((model(x_prime, y_prime) * weight).sum(), params)
                grads = [lamb * const * b - a for a, b in zip(g1, g2)]
            else:
                # The S' term vanishes when the indicator is 0; skip its forward/backward.
                grads = [-a for a in g1]
            for param, g in zip(params, grads):  # update params
                param.grad = g
            optim.step()

In [32]:
# Fresh model with custom init, trained for one pass with the default hyper-parameters.
model = SimpleConvNet()
model.apply(weights_init)
model.to(device)  # follow the configured device instead of assuming CUDA
sgd()



  0%|          | 0/473047 [00:00<?, ?it/s][A[A

  0%|          | 17/473047 [00:00<47:19, 166.57it/s][A[A

  0%|          | 34/473047 [00:00<47:20, 166.51it/s][A[A

  0%|          | 51/473047 [00:00<47:10, 167.08it/s][A[A

  0%|          | 69/473047 [00:00<46:42, 168.77it/s][A[A

  0%|          | 87/473047 [00:00<46:25, 169.80it/s][A[A

  0%|          | 104/473047 [00:00<46:30, 169.48it/s][A[A

  0%|          | 121/473047 [00:00<46:48, 168.41it/s][A[A

  0%|          | 138/473047 [00:00<46:59, 167.73it/s][A[A

  0%|          | 154/473047 [00:00<47:57, 164.37it/s][A[A

  0%|          | 171/473047 [00:01<47:42, 165.21it/s][A[A

  0%|          | 189/473047 [00:01<47:30, 165.91it/s][A[A

  0%|          | 207/473047 [00:01<46:52, 168.14it/s][A[A

  0%|          | 225/473047 [00:01<46:28, 169.56it/s][A[A

  0%|          | 243/473047 [00:01<45:55, 171.57it/s][A[A

  0%|          | 261/473047 [00:01<46:02, 171.12it/s][A[A

  0%|          | 279/473047 [00:01<45:

  1%|          | 2512/473047 [00:13<39:57, 196.23it/s][A[A

  1%|          | 2532/473047 [00:13<39:53, 196.55it/s][A[A

  1%|          | 2552/473047 [00:14<40:11, 195.10it/s][A[A

  1%|          | 2573/473047 [00:14<39:54, 196.48it/s][A[A

  1%|          | 2594/473047 [00:14<39:39, 197.67it/s][A[A

  1%|          | 2614/473047 [00:14<41:03, 190.99it/s][A[A

  1%|          | 2634/473047 [00:14<43:12, 181.46it/s][A[A

  1%|          | 2653/473047 [00:14<44:27, 176.36it/s][A[A

  1%|          | 2671/473047 [00:14<45:44, 171.38it/s][A[A

  1%|          | 2689/473047 [00:14<46:00, 170.40it/s][A[A

  1%|          | 2707/473047 [00:14<46:24, 168.94it/s][A[A

  1%|          | 2724/473047 [00:15<46:27, 168.72it/s][A[A

  1%|          | 2741/473047 [00:15<46:29, 168.57it/s][A[A

  1%|          | 2758/473047 [00:15<47:03, 166.56it/s][A[A

  1%|          | 2775/473047 [00:15<47:32, 164.87it/s][A[A

  1%|          | 2792/473047 [00:15<48:23, 161.98it/s][A[A

  1%|   

  1%|          | 5030/473047 [00:27<40:55, 190.59it/s][A[A

  1%|          | 5050/473047 [00:27<40:40, 191.76it/s][A[A

  1%|          | 5070/473047 [00:27<40:59, 190.26it/s][A[A

  1%|          | 5090/473047 [00:27<41:12, 189.23it/s][A[A

  1%|          | 5109/473047 [00:28<41:16, 188.97it/s][A[A

  1%|          | 5128/473047 [00:28<41:41, 187.07it/s][A[A

  1%|          | 5148/473047 [00:28<41:26, 188.17it/s][A[A

  1%|          | 5168/473047 [00:28<40:57, 190.35it/s][A[A

  1%|          | 5188/473047 [00:28<40:33, 192.23it/s][A[A

  1%|          | 5208/473047 [00:28<40:53, 190.72it/s][A[A

  1%|          | 5228/473047 [00:28<41:14, 189.02it/s][A[A

  1%|          | 5247/473047 [00:28<41:16, 188.93it/s][A[A

  1%|          | 5266/473047 [00:28<41:15, 188.96it/s][A[A

  1%|          | 5286/473047 [00:29<40:51, 190.83it/s][A[A

  1%|          | 5306/473047 [00:29<40:48, 191.05it/s][A[A

  1%|          | 5326/473047 [00:29<41:07, 189.53it/s][A[A

  1%|   

KeyboardInterrupt: 

In [None]:
# Recall on up to k training pairs, to test for overfitting
def recall_train(k, threshold=0.75):
    """Score the first ``k`` pairs of the global ``recall_train_samples``.

    A pair counts as recalled when the model output exceeds ``threshold``.

    Args:
        k: maximum number of samples to evaluate.
        threshold: decision threshold on the sigmoid output.

    Returns:
        (recall in percent over the pairs actually evaluated, list of raw outputs).
        Recall is NaN when no pairs are evaluated.
    """
    model.eval()
    correct = 0
    train_recall_outputs = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for apt, pep in recall_train_samples[:k]:
            x, y = convert(apt, pep)
            out = model(x, y).cpu().numpy().flatten()[0]
            train_recall_outputs.append(out)
            if out > threshold:
                correct += 1
    # Divide by the number evaluated, not k: the two differ when recall_train_samples
    # is shorter than k, or longer (previously recall could then exceed 100%).
    evaluated = len(train_recall_outputs)
    train_recall = 100 * correct / evaluated if evaluated else float('nan')
    return train_recall, train_recall_outputs


# Recall on up to k test pairs
def recall_test(k, threshold=0.75):
    """Score the first ``k`` pairs from ``test_loader``.

    A pair counts as recalled when the model output exceeds ``threshold``.

    Args:
        k: maximum number of test pairs to evaluate.
        threshold: decision threshold on the sigmoid output.

    Returns:
        (recall in percent over the pairs actually evaluated, list of raw outputs).
        Recall is NaN when no pairs are evaluated.
    """
    model.eval()
    correct = 0
    test_recall_outputs = []
    with torch.no_grad():  # inference only; no autograd graph needed
        progress = tqdm.tqdm(test_loader, total=min(k, len(test_loader)))
        for count, (aptamer, peptide) in enumerate(progress):
            if count >= k:  # was `count > k`, which evaluated k + 1 pairs
                break
            apt, pep = convert(aptamer[0], peptide[0])
            output = model(apt, pep).cpu().numpy().flatten()[0]
            test_recall_outputs.append(output)
            if output > threshold:
                correct += 1
    evaluated = len(test_recall_outputs)  # < k when the test set is smaller than k
    test_recall = 100 * correct / evaluated if evaluated else float('nan')
    return test_recall, test_recall_outputs  # was `test_recal` -> NameError


# Eval on m new unseen pairs in S_new (not in our dataset)
def eval_unknown(m):
    """Score the first ``m`` generated pairs of ``S_new``.

    Returns:
        list of model outputs, one per evaluated pair (at most ``m``).
    """
    model.eval()
    eval_unknown_outputs = []
    with torch.no_grad():  # inference only; avoids building a graph for up to m passes
        for x, y in S_new[:m]:
            apt, pep = convert(x, y)
            eval_unknown_outputs.append(model(apt, pep).cpu().numpy().flatten()[0])
    return eval_unknown_outputs

In [None]:
gammas = [1e-2]
lambdas = [1e-1, 1e-3, 1e-5]
train_recalls = []
test_recalls = []
train_scores = []
test_scores = []
train_cdfs = []
test_cdfs = []

m = int(1e6) # number of unknown samples
k = m//10 # number of binding samples (test set size is 118262, k is just some limit we set)
# Draw the "train recall" pairs from the training split only: full_dataset also
# contains the test pairs, which would inflate the overfitting check.
recall_train_samples = [training_set[i] for i in np.random.choice(len(training_set), k, replace=False)]

for g in range(len(gammas)):
    for l in range(len(lambdas)):
        model = SimpleConvNet()
        model.apply(weights_init)
        model.to(device)

        print("=============Training=======================")
        sgd(t=1, gamma=gammas[g], lamb=lambdas[l])

        print("=============Evaluating train===============")
        train_recall, train_recall_outputs = recall_train(k)
        print("Gamma: ", "%.5f" % gammas[g], "Lambda: ", "%.5f" % lambdas[l],
              "Train recall: ", "%.2f" % train_recall)

        print("=============Evaluating test================")
        test_recall, test_recall_outputs = recall_test(k)
        print("Gamma: ", "%.5f" % gammas[g], "Lambda: ", "%.5f" % lambdas[l],
              "Test recall: ", "%.2f" % test_recall)

        print("=============Evaluating unknown=============")
        # eval_unknown returns a single list; unpacking it into two names raised ValueError.
        eval_unknown_outputs = eval_unknown(m)

        train_score = np.asarray(eval_unknown_outputs + train_recall_outputs)
        train_scores.append(train_score)

        test_score = np.asarray(eval_unknown_outputs + test_recall_outputs)
        test_scores.append(test_score)

        # NOTE(review): normalised area under the running sum of scores in list order
        # (unknown pairs first, then binders) — confirm this is the intended metric.
        train_cdf = np.sum(np.cumsum(train_score), dtype=float)/(np.sum(train_score)*len(train_score))
        test_cdf = np.sum(np.cumsum(test_score), dtype=float)/(np.sum(test_score)*len(test_score))
        print("G: ", "%.5f" % gammas[g], "L: ", "%.5f" % lambdas[l],
              "Train CDF: ", "%.3f" % train_cdf, "Test CDF: ", "%.3f" % test_cdf)

        train_recalls.append(train_recall)
        test_recalls.append(test_recall)
        train_cdfs.append((gammas[g], lambdas[l], train_cdf))
        test_cdfs.append((gammas[g], lambdas[l], test_cdf))

print("Train CDFs: ", train_cdfs)
print("Test CDFs: ", test_cdfs)

In [None]:
# AUC Plot
def cdf(scores, i): # i is the index
    """Plot the empirical CDF of ``scores`` for the i-th (gamma, lambda) configuration.

    ``i`` indexes the flattened gammas x lambdas grid in the order the grid-search
    loop visits it (gamma-major), i.e. the order of ``train_scores``/``test_scores``.
    """
    g = gammas[i // len(lambdas)]
    l = lambdas[i % len(lambdas)]
    fig, ax = plt.subplots()
    ax.hist(scores, 100, histtype='step', density=True, cumulative=True,
            label='lambda =%.5f' % l + ' gamma =%.5f' % g)
    ax.set(xlabel='model score', ylabel='cumulative fraction',
           title='Empirical CDF of model scores')
    ax.legend()
    plt.show()

In [None]:
### Test for new config of the NN model
# Hyper-parameters for this run; also used when reporting, instead of the loop
# indices g / l that leaked from the grid-search cell.
gamma, lamb = 0.1, 0.01
model = SimpleConvNet()
model.apply(weights_init)
model.to(device)
print("Training...")

# NOTE: t passes over train_loader; at the rates logged above one pass is ~40 min.
sgd(t=100, gamma=gamma, lamb=lamb)
m = 100
k = 10

# recall_train() scores the global recall_train_samples; resample k training pairs
# so this run does not reuse the grid-search cell's samples.
recall_train_samples = [training_set[i] for i in np.random.choice(len(training_set), k, replace=False)]
train_recall, train_recall_outputs = recall_train(k)
print(train_recall)

test_recall, test_recall_outputs = recall_test(k)
print(test_recall)

train_scores = []
test_scores = []
train_cdfs = []
test_cdfs = []

# eval_unknown returns a single list; unpacking it into two names raised ValueError.
eval_unknown_outputs = eval_unknown(m)

train_score = np.asarray(eval_unknown_outputs + train_recall_outputs)
train_scores.append(train_score)

test_score = np.asarray(eval_unknown_outputs + test_recall_outputs)
test_scores.append(test_score)

train_cdf = np.sum(np.cumsum(train_score), dtype=float)/(np.sum(train_score)*len(train_score))
test_cdf = np.sum(np.cumsum(test_score), dtype=float)/(np.sum(test_score)*len(test_score))

train_cdfs.append((gamma, lamb, train_cdf))
test_cdfs.append((gamma, lamb, test_cdf))

print("G: ", "%.5f" % gamma, "L: ", "%.5f" % lamb,
      "Train CDF: ", "%.3f" % train_cdf, "Test CDF: ", "%.3f" % test_cdf)

In [None]:
### Write to file when output limit exceeded
# with open("./scores.txt", "a") as f:
#     f.write(str(scores[-1]))
#     f.close()

### Table of recalls with different params
# idx = sorted(range(len(recalls)), key=lambda k: recalls[k])
# for i in idx:
#     g = gammas[i//len(lambdas)]
#     l = lambdas[i%len(lambdas)]
#     print("Gamma: ", "%.5f" % g, "Lambda: ", "%.5f" % l, \
#           "Recall: ", "%.2f" % recalls[i], "Precision: ", "%.2f" % precisions[i])

### Heatmap of recalls
# mat = sns.heatmap(M, vmin=0, vmax=100)
# plt.show()