In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam, lr_scheduler
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import datetime
import seaborn as sns


## Setup: random seeds, sequence alphabets and NNK sampling priors

In [2]:
SEED = 12345
# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# torch (weight init), numpy (get_x / get_y sampling) and `random` (dataset shuffle).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

na_list = ['A', 'C', 'G', 'T']  # nucleic acids
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']  # amino acids

# NNK codon frequencies in aa_list order, followed by the stop codon: R/L/S are encoded
# by 3 of the 32 NNK codons, A/G/P/T/V by 2, the other 12 amino acids and the stop by 1.
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13  # 21 entries (20 amino acids + stop)
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12  # total frequency without the stop codon

# Amino-acid sampling probabilities conditioned on "no stop codon", aligned with aa_list.
# The last entry absorbs floating-point rounding so the list sums to 1, as
# np.random.choice requires.
pvals = [freq / sum_20 for freq in NNK_freq[:19]]
pvals.append(1 - sum(pvals))
aa_dict = dict(zip(aa_list, pvals))
k = 1000  # number of pairs per split used for loss estimation (generate_loss_samples)
BATCH_SIZE = 100

## Dataset

In [3]:
def construct_dataset(path=None):
    """Load the positive aptamer-peptide pairs from the JSON dataset file.

    The JSON must map each aptamer string to a list of ``[peptide, <ignored>]`` pairs;
    only the first element of each pair is used. The control aptamer and the control
    peptide (anything containing "RRRRRR") are dropped, stop-codon markers ("_") are
    stripped, and only 40-nt aptamers with 8-residue peptides are kept.

    Args:
        path: JSON file to read. Defaults to the module-level ``aptamer_dataset_file``
            (looked up at call time), so existing ``construct_dataset()`` calls still work.

    Returns:
        A sorted list of unique ``(aptamer_bytes, peptide_bytes, 1)`` tuples.
    """
    if path is None:
        path = aptamer_dataset_file
    with open(path, 'r') as f:
        aptamer_data = json.load(f)

    control_aptamer = "CTTTGTAATTGGTTCTGAGTTCCGTTGTGGGAGGAACATG"
    pairs = set()  # de-duplicates repeated (aptamer, peptide) pairs
    for aptamer, peptides in aptamer_data.items():
        if aptamer == control_aptamer or len(aptamer) != 40:
            continue
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "")  # remove stop codons
            if "RRRRRR" in peptide:  # control peptide
                continue
            if len(peptide) == 8:
                pairs.add((aptamer.encode("utf-8"), peptide.encode("utf-8"), 1))
    # Iterating a set of bytes follows the per-process hash seed under Python 3, which
    # made the later shuffle / train-val-test split irreproducible even with a fixed
    # RNG seed. Sorting gives the same order on every run.
    return sorted(pairs)

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``(aptamer, peptide, label)`` triples.

    Trailing items are discarded so the length is a whole multiple of ``batch_size``;
    every batch a DataLoader draws with that batch size is therefore full.
    """

    def __init__(self, dataset, batch_size):
        super(Dataset, self).__init__()
        leftover = len(dataset) % batch_size
        self.dataset = dataset[:len(dataset) - leftover]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Unpacking enforces that each entry is a 3-item record.
        record = self.dataset[idx]
        aptamer, peptide, label = record
        return aptamer, peptide, label
    

# JSON mapping each aptamer to its peptide hits; construct_dataset() reads this path.
aptamer_dataset_file = "../data/aptamer_dataset.json"
positive_dataset = construct_dataset()
n = len(positive_dataset)  # number of positive pairs; get_S_new(n) later draws as many random pairs

In [4]:
class LossDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(aptamer, peptide, label)`` triples for loss estimation.

    As with ``Dataset``, trailing items are dropped so the length is a whole multiple
    of ``batch_size``. Labels are returned as 0-dim float tensors.

    Args:
        dataset: sequence of ``(aptamer, peptide, label)`` triples.
        batch_size: the length is truncated to a multiple of this.
        device: optional device for the label tensor (e.g. ``torch.device("cuda")``);
            ``None`` keeps it on the CPU.
    """

    def __init__(self, dataset, batch_size, device=None):
        super(LossDataset, self).__init__()
        num_batches = len(dataset) // batch_size
        self.dataset = dataset[:num_batches * batch_size]
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # This was misspelled ``__getitem`` (name-mangled and never called), so indexing
        # fell through to the base class. ``torch.FloatTensor(label)`` with an int label
        # also built an *uninitialised* tensor of length ``label``; torch.tensor wraps
        # the value itself.
        aptamer, peptide, label = self.dataset[idx]
        label = torch.tensor(float(label))
        if self.device is not None:
            label = label.to(self.device)
        return aptamer, peptide, label

In [5]:
# Shuffle a copy so `positive_dataset` keeps its original order. Shuffling through the
# alias mutated it too, so re-running this cell reshuffled an already-shuffled list.
full_dataset = list(positive_dataset)
random.shuffle(full_dataset)

In [6]:
n_total = len(full_dataset)
train_end, val_end = int(0.8*n_total), int(0.9*n_total)
training_set = full_dataset[:train_end]            # first 80%
validation_set = full_dataset[train_end:val_end]   # next 10%
test_set = full_dataset[val_end:]                  # last 10%

## NN Model

In [7]:
class SimpleConvNet(nn.Module):
    """Two-branch CNN that scores an (aptamer, peptide) pair with one sigmoid output.

    Inputs are one-hot tensors shaped ``(batch, 1, 40, 4)`` (aptamers) and
    ``(batch, 1, 8, 20)`` (peptides). Each branch applies a single convolution whose
    kernel spans the full alphabet width, so it slides only along the sequence. Per
    sample this gives 120*37 + 50*5 = 4690 features for ``fc1``.
    """

    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 120, (4, 4))  # 4-nt window, similar to a 4-gram
        self.cnn_pep_1 = nn.Conv2d(1, 50, (4, 20))  # 4-residue window
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(4690, 1)

    def forward(self, apt, pep):
        """Return a ``(batch, 1)`` tensor of pair scores in [0, 1]."""
        apt = self.relu(self.cnn_apt_1(apt))
        pep = self.relu(self.cnn_pep_1(pep))
        # Flatten each sample separately. ``view(-1, 1).T`` merged the whole batch into
        # one row, so only batch size 1 matched fc1's 4690 inputs.
        apt = apt.view(apt.size(0), -1)
        pep = pep.view(pep.size(0), -1)
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    Conv2d weights get Xavier-uniform init and Linear weights get Kaiming-uniform init
    (ReLU gain); their biases, when present, are zeroed. Other modules are untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
    else:
        return
    # Layers built with bias=False have ``bias is None``; zeroing it would raise.
    if m.bias is not None:
        nn.init.zeros_(m.bias.data)

In [8]:
class DoubleConvNet(nn.Module):
    """Two-branch CNN with a 1x1 channel-reduction conv per branch; sigmoid output.

    Inputs are one-hot tensors shaped ``(batch, 1, 40, 4)`` (aptamers) and
    ``(batch, 1, 8, 20)`` (peptides). Per sample: 100*36 = 3600 aptamer features plus
    10*4 = 40 peptide features = 3640 inputs to ``fc1``.
    """

    def __init__(self):
        super(DoubleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 1000, (5, 4))  # 5-nt window, similar to a 5-gram
        self.cnn_apt_2 = nn.Conv2d(1000, 100, 1)
        self.cnn_pep_1 = nn.Conv2d(1, 500, (5, 20))
        self.cnn_pep_2 = nn.Conv2d(500, 10, 1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(3640, 1)

    def forward(self, apt, pep):
        """Return a ``(batch, 1)`` tensor of pair scores in [0, 1]."""
        # NOTE(review): there is no ReLU between the two convs of a branch, so each pair
        # composes into a single linear map; confirm whether a nonlinearity was intended.
        apt = self.relu(self.cnn_apt_2(self.cnn_apt_1(apt)))
        pep = self.relu(self.cnn_pep_2(self.cnn_pep_1(pep)))
        # Flatten each sample separately (``view(-1, 1).T`` merged the batch into one
        # row, so only batch size 1 matched fc1).
        apt = apt.view(apt.size(0), -1)
        pep = pep.view(pep.size(0), -1)
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

In [9]:
class MoreComplexNet(nn.Module):
    """Deeper two-branch CNN producing 2-class logits for each (aptamer, peptide) pair.

    Inputs are one-hot tensors shaped ``(batch, 1, 40, 4)`` (aptamers) and
    ``(batch, 1, 8, 20)`` (peptides). Each branch stacks two 1x1 convolutions
    (per-position channel mixing) and then three small spatial convolutions. Per sample
    the branches leave 10*37*1 = 370 aptamer features and 10*4*16 = 640 peptide
    features, which ``fc1`` maps to two logits.

    The previous head flattened the *whole batch* into a single row and applied
    ``Linear(101000, BATCH_SIZE*2)`` (101000 = 100 * 1010). Each sample's logits
    therefore depended on every other sample in the batch, and only batch size 100
    worked. Samples are now scored independently. ``forward`` returns ``(batch, 2)``,
    which the existing ``.view((BATCH_SIZE, 2))`` call sites accept unchanged.
    """

    def __init__(self):
        super(MoreComplexNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 500, 1)  # 1x1: per-position channel mixing
        self.cnn_apt_2 = nn.Conv2d(500, 1000, 1)
        self.cnn_apt_3 = nn.Conv2d(1000, 500, 2)
        self.cnn_apt_4 = nn.Conv2d(500, 100, 2)
        self.cnn_apt_5 = nn.Conv2d(100, 10, 2)

        self.cnn_pep_1 = nn.Conv2d(1, 250, 1)
        self.cnn_pep_2 = nn.Conv2d(250, 500, 1)
        self.cnn_pep_3 = nn.Conv2d(500, 250, 3)
        self.cnn_pep_4 = nn.Conv2d(250, 100, 2)
        self.cnn_pep_5 = nn.Conv2d(100, 10, 2)

        self.relu = nn.ReLU()

        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.relu, self.cnn_apt_2, self.relu, self.cnn_apt_3, self.relu, self.cnn_apt_4, self.relu, self.cnn_apt_5)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.relu, self.cnn_pep_2, self.relu, self.cnn_pep_3, self.relu, self.cnn_pep_4, self.relu, self.cnn_pep_5)

        self.fc1 = nn.Linear(370 + 640, 2)  # per-sample aptamer + peptide features -> 2 logits

    def forward(self, apt, pep):
        """Return ``(batch, 2)`` class logits (no softmax applied)."""
        apt = self.relu(self.cnn_apt(apt))
        pep = self.relu(self.cnn_pep(pep))
        apt = apt.view(apt.size(0), -1)
        pep = pep.view(pep.size(0), -1)
        x = torch.cat((apt, pep), 1)
        return self.fc1(x)

    def loss(self, prediction, label):
        """Cross-entropy between ``(batch, 2)`` logits and integer class labels."""
        criterion = nn.CrossEntropyLoss()
        return criterion(prediction, label)

## Sampling methods

In [10]:
# Sample x from P_X (aptamers are assumed uniform: each of the 40 positions is A/C/G/T
# with equal probability)
def get_x():
    """Draw one random 40-nt aptamer string, uniformly over na_list at every position."""
    positions = np.random.randint(0, 4, 40)
    return "".join([na_list[pos] for pos in positions])

# Sample y from P_Y (peptides are assumed to follow the NNK codon distribution)
def get_y():
    """Draw one random 8-residue peptide: a fixed leading 'M' plus 7 NNK-weighted residues."""
    residues = [aa_list[idx] for idx in np.random.choice(20, 7, p=pvals)]
    return "M" + "".join(residues)

# Build k random (aptamer, peptide) pairs from the independent priors P_X and P_Y.
def get_S_new(k):
    """Return ``k`` tuples ``(get_x(), get_y(), 0)``.

    Every pair is labelled 0. The pairs are not checked against the positive dataset,
    so a random pair could in principle coincide with a real one.
    """
    return [(get_x(), get_y(), 0) for _ in range(k)]

# Returns pmf of an aptamer
def get_x_pmf(x=None):
    """Probability of an aptamer under the uniform prior P_X.

    Each position is uniform over the 4 nucleotides, so the probability depends only
    on the length: ``0.25 ** len(x)``. Called without an argument it keeps the original
    behaviour and assumes the 40-nt aptamers used throughout; the optional argument
    mirrors ``get_y_pmf(y)``.
    """
    length = 40 if x is None else len(x)
    return 0.25 ** length

# Returns pmf of a peptide
def get_y_pmf(y):
    """Probability of a peptide under the NNK prior P_Y.

    The first residue is skipped (get_y always emits 'M' there); every later residue
    contributes its aa_dict weight. A residue missing from aa_dict raises KeyError.
    """
    prob = 1
    for pos in range(1, len(y)):
        prob = prob * aa_dict[y[pos]]
    return prob


In [11]:
# Draw as many random label-0 pairs as there are positives; used for SGD and evaluation.
S_new = get_S_new(n)
print("Length of S_new: ", len(S_new))

('Length of S_new: ', 591309)


## Helper methods

In [12]:
# One-hot encode a single peptide or aptamer sequence
def one_hot(sequence, seq_type='peptide'):
    """One-hot encode a sequence against the peptide or nucleotide alphabet.

    Args:
        sequence: str, or bytes as stored by construct_dataset, of residue letters.
        seq_type: 'peptide' encodes against aa_list (20 columns); any other value
            encodes against na_list (4 columns).

    Returns:
        float ndarray of shape ``(len(sequence), alphabet_size)`` with one 1 per row.

    Raises:
        ValueError: if a character is not in the chosen alphabet.
    """
    letters = aa_list if seq_type == 'peptide' else na_list
    if isinstance(sequence, bytes):
        # Under Python 3, iterating bytes yields ints, which never match the letter lists.
        sequence = sequence.decode('ascii')
    encoded = np.zeros((len(sequence), len(letters)))
    # One assignment per position; the old inner loop repeated the same write
    # len(letters) times.
    for i, char in enumerate(sequence):
        encoded[i, letters.index(char)] = 1
    return encoded

# Convert a batch of (aptamer, peptide) sequences to one-hot tensors
def convert(apt, pep, ls, size=BATCH_SIZE, device=None):
    """One-hot encode the first ``size`` aptamer/peptide sequences of a batch.

    Args:
        apt: indexable collection of at least ``size`` aptamer sequences (40 nt each).
        pep: indexable collection of at least ``size`` peptide sequences (8 residues each).
        ls: iterable of labels (ints, numpy ints or 0-dim tensors), or None.
        size: number of samples to encode.
        device: target device; ``None`` keeps the original behaviour of ``.cuda()``.

    Returns:
        ``(aptamers, peptides, labels)``: float tensors shaped ``(size, 1, 40, 4)`` and
        ``(size, 1, 8, 20)``, and a long tensor of every label in ``ls`` (empty if None).
    """
    aptamers = np.zeros((size, 40, 4))
    peptides = np.zeros((size, 8, 20))
    for i in range(size):
        aptamers[i] = one_hot(apt[i], seq_type='aptamer')
        peptides[i] = one_hot(pep[i], seq_type='peptide')
    # Insert the single input channel expected by Conv2d.
    aptamers = torch.FloatTensor(aptamers.reshape(size, 1, 40, 4))
    peptides = torch.FloatTensor(peptides.reshape(size, 1, 8, 20))

    if ls is None:
        labels = torch.LongTensor([])
    else:
        # int() handles python ints, numpy ints and 0-dim tensors alike.
        labels = torch.LongTensor([int(l) for l in ls])

    if device is None:
        return aptamers.cuda(), peptides.cuda(), labels.cuda()
    return aptamers.to(device), peptides.to(device), labels.to(device)

def _peptide_letters(pep):
    """Return a peptide as a str of residue letters.

    Accepts a str, bytes (as stored by construct_dataset), or a one-hot tensor whose
    last dimension indexes aa_list (the layout produced by one_hot / convert).
    """
    if torch.is_tensor(pep):
        indices = pep.reshape(-1, len(aa_list)).argmax(1).tolist()
        return "".join(aa_list[i] for i in indices)
    if isinstance(pep, bytes):
        return pep.decode('ascii')
    return pep


def update(x, y, net=None):
    """Run a forward pass on one batch; return each peptide's NNK prior and the output.

    Args:
        x: aptamer batch in the model's input format (e.g. a tensor from ``convert``).
        y: peptide batch in the model's input format. Iterating it must yield peptide
            sequences (str/bytes) or one-hot tensors laid out against ``aa_list``.
        net: model to evaluate; defaults to the module-level ``model``.

    Returns:
        ``(pmf_list, out)``: ``get_y_pmf`` of every peptide in ``y``, and ``net(x, y)``.
    """
    if net is None:
        net = model
    # get_y_pmf looks residues up in aa_dict, so turn each item into letters first;
    # indexing aa_dict with tensor rows (what iterating a convert() batch yields)
    # raised KeyError.
    pmf_list = [get_y_pmf(_peptide_letters(pep)) for pep in y]

    # No requires_grad on the inputs: input gradients are never used, and setting the
    # flag raises on non-float tensors. Inputs follow the model's device.
    net_device = next(net.parameters()).device
    out = net(x.to(net_device), y.to(net_device))
    return pmf_list, out

def generate_loss_samples(k, dataset='train'):
    """One-hot encode the first ``k`` pairs of a split for loss estimation.

    Args:
        k: number of pairs to take (fewer if the split is shorter).
        dataset: 'train' -> training_set, 'val' -> validation_set, anything else -> S_new.

    Returns:
        list of ``(x, y, label)`` where x and y are the size-1 batches produced by
        ``convert`` (shapes ``(1, 1, 40, 4)`` and ``(1, 1, 8, 20)``), and label is the
        pair's original label.
    """
    if dataset == 'train':
        dset = training_set
    elif dataset == 'val':
        dset = validation_set
    else:
        dset = S_new
    pairs = []
    for apt, pep, label in dset[:k]:
        # convert() indexes its inputs per sample, so wrap the single sequences in a
        # list. Passing the bare string encoded only its first character (broadcast
        # over every row), or failed outright for bytes under Python 3.
        x, y, _ = convert([apt], [pep], None, size=1)
        pairs.append((x, y, label))
    return pairs
    
def get_log_out(dataset='train', net=None):
    """Average log-probability the model assigns to each loss sample's own label.

    Args:
        dataset: 'train' -> train_loss_loader, 'val' -> val_loss_loader,
            anything else -> sprime_loss_loader.
        net: model to evaluate; defaults to the module-level ``model``.

    Returns:
        Mean over all samples of ``log_softmax(logits)[label]`` as a float;
        nan if the loader yields nothing.
    """
    if dataset == 'train':
        loader = train_loss_loader
    elif dataset == 'val':
        loader = val_loss_loader
    else:
        loader = sprime_loss_loader
    if net is None:
        net = model
    net_device = next(net.parameters()).device

    log_probs = []
    with torch.no_grad():
        for x, y, labels in loader:
            # Loss samples are stored as size-1 batches (1, 1, H, W); collation stacks
            # them to (B, 1, 1, H, W), so fold back to (B, 1, H, W).
            x = x.reshape(-1, 1, x.shape[-2], x.shape[-1]).to(net_device)
            y = y.reshape(-1, 1, y.shape[-2], y.shape[-1]).to(net_device)
            logits = net(x, y).view(-1, 2)
            # log_softmax yields real log-probabilities; log() of L2-normalised outputs
            # was NaN whenever an output was negative.
            log_p = F.log_softmax(logits, dim=1)
            idx = torch.as_tensor(labels).long().view(-1, 1).to(net_device)
            log_probs.extend(log_p.gather(1, idx).view(-1).tolist())
    return float(np.mean(log_probs)) if log_probs else float('nan')

## Plotting Functions

In [13]:
def cdf(scores1, scores2):
    """Overlay the empirical CDFs of two score samples (red: train, black: test)."""
    _, ax = plt.subplots()
    curves = ((scores1, 'red', 'train cdf'), (scores2, 'black', 'test cdf'))
    for scores, colour, name in curves:
        ax.hist(scores, 100, histtype='step', density=True, cumulative=True, color=colour, label=name)
    ax.legend()
    plt.show()

def histogram(eval_scores, train_scores, test_scores):
    """Plot the eval, train and test score distributions on a 2x2 grid.

    Panels share an x-axis limited to [0, 1.1]; the bottom-right panel stays empty.
    """
    _, axes = plt.subplots(2, 2, figsize=(7, 7), sharex=True)
    plt.xlim(0, 1.1)
    panels = [
        (eval_scores, "skyblue", 'Eval: not in dataset', axes[0, 0], "Eval: not in dataset"),
        (train_scores, "gold", 'Train: in dataset', axes[1, 0], "Train: in dataset"),
        (test_scores, 'red', 'Test: in the dataset', axes[0, 1], "Test: in dataset"),
    ]
    for scores, colour, label, ax, _title in panels:
        sns.distplot(scores, color=colour, label=label, ax=ax)
    for _scores, _colour, _label, ax, title in panels:
        ax.set_title(title)
    plt.show()

def _plot_curves(train_values, test_values, train_style, test_style, metric, i, lamb, gamma):
    """Plot a train/test curve pair titled '<metric> after <i> iterations, lambda =..., gamma =...'."""
    _, ax = plt.subplots()
    ax.plot(train_values, train_style, label='Train ' + metric.lower())
    ax.plot(test_values, test_style, label='Test ' + metric.lower())
    ax.set_title(metric + ' after ' + str(i) + " iterations, " + 'lambda =%.5f' % lamb + ' gamma =%.5f' % gamma)
    ax.legend()
    plt.show()


def plot_loss(train_loss, test_loss, i, lamb, gamma):
    """Plot train (green) and test (red) loss curves."""
    # The test style used to be 'p', a pentagon *marker* format with no line and no
    # colour, so the test curve appeared as scattered markers; 'r' draws a red line.
    _plot_curves(train_loss, test_loss, 'g', 'r', 'Loss', i, lamb, gamma)


def plot_recall(train_recall, test_recall, i, lamb, gamma):
    """Plot train (blue) and test (yellow) recall curves."""
    _plot_curves(train_recall, test_recall, 'b', 'y', 'Recall', i, lamb, gamma)

## Construct train/validation/test loaders

In [14]:
train_dataset = Dataset(training_set, BATCH_SIZE) # 80%
validation_dataset = Dataset(validation_set, BATCH_SIZE) # 10%
test_dataset = Dataset(test_set, BATCH_SIZE) # 10%

# Reshuffle the training data every epoch; without shuffle=True every epoch saw the
# batches in the same order. Evaluation loaders keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
print("Length of train dataset: ", len(train_dataset))
print("Length of val dataset: ", len(validation_dataset))
print("Length of test dataset: ", len(test_dataset))

('Length of train dataset: ', 473000)
('Length of val dataset: ', 59100)
('Length of test dataset: ', 59100)


In [15]:
# One-hot encoded loss-estimation samples: the first k pairs of each split.
train_loss_samples, val_loss_samples, sprime_loss_samples = (
    LossDataset(generate_loss_samples(k, split), batch_size=BATCH_SIZE)
    for split in ('train', 'val', 's_prime')
)

In [16]:
# Fixed-order loaders over the loss-estimation samples (consumed by get_log_out).
train_loss_loader, val_loss_loader, sprime_loss_loader = (
    torch.utils.data.DataLoader(samples, batch_size=BATCH_SIZE)
    for samples in (train_loss_samples, val_loss_samples, sprime_loss_samples)
)

## Binary Classification

In [None]:
def _bc_batch_recall(logits, labels):
    """Count correct predictions in a batch and collect each sample's true-class score.

    Scores are the row-wise L2-normalised logits. A sample is correct when its
    true-class score strictly exceeds the other class's score. Returns
    ``(n_correct, list_of_true_class_scores)``.
    """
    scores = F.normalize(logits.detach(), dim=1).cpu().numpy()
    label_arr = labels.cpu().numpy()
    n_correct = 0
    true_scores = []
    for row, label in zip(scores, label_arr):
        if row[label] > row[1 - label]:
            n_correct += 1
        true_scores.append(row[label])
    return n_correct, true_scores


def _bc_validate(model):
    """Evaluate ``model`` on ``val_loader``; return ``(accuracy, mean cross-entropy loss)``.

    Leaves the model in eval mode. Returns ``(nan, nan)`` when the loader is empty.
    """
    model.eval()
    correct, total, loss_sum, n_batches = 0, 0, 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            x, y, labels = convert(list(batch[0]), list(batch[1]), list(batch[2]))
            logits = model(x, y).view(-1, 2)
            loss_sum += model.loss(logits, labels).item()
            n_batches += 1
            correct += int((torch.argmax(logits, dim=1) == labels).sum().item())
            total += labels.shape[0]
    if total == 0:
        return float('nan'), float('nan')
    return float(correct) / total, loss_sum / n_batches


def _bc_plot_losses(history, i):
    """Plot the logged train and validation losses against the global training step."""
    _, ax = plt.subplots()
    ax.plot(history['train_loss_steps'], history['train_losses'], 'g', label='Train loss')
    ax.plot(history['val_steps'], history['val_losses'], 'b', label='Val loss')
    ax.set_title('Loss after ' + str(i) + " iterations ")
    ax.legend()
    plt.show()


def binary_classification(t=5, lr=1e-5, lamb=1, decay=0.5):
    """Train a MoreComplexNet pair classifier with SGD, validating periodically.

    Training and validation losses are both ``model.loss`` (cross-entropy) on the raw
    logits, so the two curves are comparable. Every 500 steps the mean training loss
    since the last log is recorded. Every 1000 steps the model is evaluated on
    ``val_loader`` and the loss curves are plotted.

    Args:
        t: number of epochs over ``train_loader``.
        lr: SGD learning rate (momentum 0.9).
        lamb: weight of the S' term in the PU-style loss estimate an earlier draft logged
            via ``get_log_out``. This loop does not use it; the parameter is kept so
            existing calls still work.
        decay: SGD weight decay.

    Returns:
        ``(model, history)``. ``history`` maps 'train_losses', 'train_loss_steps',
        'train_recalls', 'train_recall_outputs', 'val_losses', 'val_accuracies' and
        'val_steps' to lists.

    NOTE(review): as the dataset cells are written, ``training_set`` comes only from
    ``positive_dataset`` (all label 1). Unless label-0 pairs such as ``S_new`` are mixed
    in, the classifier only ever sees one class.
    """
    model = MoreComplexNet()
    model.apply(weights_init)
    model.cuda()

    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=decay)

    history = {
        'train_losses': [], 'train_loss_steps': [],
        'train_recalls': [], 'train_recall_outputs': [],
        'val_losses': [], 'val_accuracies': [], 'val_steps': [],
    }
    running_loss, running_batches = 0.0, 0
    step = 0  # global step across epochs (i restarts every epoch)

    for epoch in range(t):
        model.train()
        print("Epoch: ", epoch)
        for i, batch in enumerate(tqdm.tqdm(train_loader)):
            x, y, labels = convert(list(batch[0]), list(batch[1]), list(batch[2]))

            optimizer.zero_grad()
            logits = model(x, y).view(-1, 2)
            # Build the loss from the live logits before anything is detached; previously
            # `loss` was never defined, so loss.backward() raised NameError.
            loss = model.loss(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_batches += 1
            n_correct, true_scores = _bc_batch_recall(logits, labels)
            # float() avoids Python 2 integer division, which truncated recall to 0 or 1.
            history['train_recalls'].append(float(n_correct) / labels.shape[0])
            history['train_recall_outputs'].extend(true_scores)

            if i % 500 == 0:
                history['train_losses'].append(running_loss / running_batches)
                history['train_loss_steps'].append(step)
                running_loss, running_batches = 0.0, 0

            if i % 1000 == 0:
                accuracy, val_loss = _bc_validate(model)
                history['val_accuracies'].append(accuracy)
                history['val_losses'].append(val_loss)
                history['val_steps'].append(step)
                model.train()  # _bc_validate switched the model to eval mode
                if not np.isnan(accuracy):
                    print('Accuracy of the network after ' + str(i) + ' iterations on the validation samples: %d %%' % (100 * accuracy))
                _bc_plot_losses(history, i)
            step += 1

    return model, history


In [None]:
# Learning-rate sweep (currently a single value): 4 epochs each, weight decay disabled.
lrs = [1e-3]
for lr in lrs:
    print("Initial learning rate: ", lr)
    binary_classification(t=4, lr=lr, decay=0)