In [1]:
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

## Preliminary

In [2]:
torch.manual_seed(12345)
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

na_list = ['A', 'C', 'G', 'T'] #nucleic acids
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13 #freq of 21 NNK codons including the stop codon
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12 #sum of freq without the stop codon
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*12 #normalize freq for 20 codons
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11 + \
        [1- sum([0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11)] 
        #adjust sum to 1 due to numerical issue
aa_dict = dict(zip(aa_list, pvals))
k = 1000
BATCH_SIZE = 100

## Dataset

In [3]:
def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        if aptamer == "CTTTGTAATTGGTTCTGAGTTCCGTTGTGGGAGGAACATG": #took out aptamer control
            continue
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "") #removed stop codons
            if "RRRRRR" in peptide: #took out peptide control
                continue
            if len(aptamer) == 40 and len(peptide) == 8: #making sure right length
                full_dataset.append((aptamer, peptide, 1))
    full_dataset = list(set(full_dataset)) #removed duplicates
    return full_dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, batch_size):
        super(Dataset, self).__init__() 
        num_batches = len(dataset)//batch_size
        dataset = dataset[:num_batches*batch_size]
        self.dataset = dataset 
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        aptamer, peptide, label = self.dataset[idx]
        return aptamer, peptide, label
    

aptamer_dataset_file = "../data/aptamer_dataset.json"
positive_dataset = construct_dataset()
m = len(positive_dataset)

In [4]:
# Sample x from P_X (assume apatamers follow uniform)
def get_x():
    x_idx = np.random.randint(0, 4, 40)
    x = ""
    for i in x_idx:
        x += na_list[i]
    return x

# Sample y from P_y (assume peptides follow NNK)
def get_y():
    y_idx = np.random.choice(20, 7, p=pvals)
    y = "M"
    for i in y_idx:
        y += aa_list[i]
    return y

# Return S_new (all unseen samples)
def get_S_new(k):
    S_new = []
    for _ in range(k):
        pair = (get_x(), get_y(), 0)
        S_new.append(pair)
    return list(set(S_new))

S_new = get_S_new(m)

## Construct Train/TestLoaders

In [5]:
full_dataset = positive_dataset + S_new
n = len(full_dataset)
np.random.shuffle(full_dataset)
train_dataset = Dataset(full_dataset[:int(0.8*n)], BATCH_SIZE)
validation_dataset = Dataset(full_dataset[int(0.8*n):int(0.9*n)], BATCH_SIZE)
test_dataset = Dataset(full_dataset[int(0.9*n):], BATCH_SIZE)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

## NN Model

In [6]:
class SimpleConvNet(nn.Module):
    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 120, (4,4)) #similar to 4-gram
        self.cnn_pep_1 = nn.Conv2d(1, 50, (4,20))
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(4690, BATCH_SIZE*2)
        
    def forward(self, apt, pep):
        apt = self.cnn_apt_1(apt)
        apt = self.relu(apt)
        pep = self.cnn_pep_1(pep)
        pep = self.relu(pep)
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=1)
        return x
    
    def loss(self, prediction, label):
        criterion = nn.CrossEntropyLoss()
        return criterion(prediction, label)
        
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        nn.init.zeros_(m.bias.data)
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
        nn.init.zeros_(m.bias.data)

In [7]:
class DoubleConvNet(nn.Module):
    def __init__(self):
        super(DoubleConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 1000, (5,4)) #similar to 5-gram
        self.cnn_apt_2 = nn.Conv2d(1000, 100, 1)
        self.cnn_pep_1 = nn.Conv2d(1, 500, (5,20))
        self.cnn_pep_2 = nn.Conv2d(500, 10, 1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(3640, 1)
        
    def forward(self, apt, pep):
        apt = self.cnn_apt_1(apt)
        apt = self.cnn_apt_2(apt)
        apt = self.relu(apt)
        pep = self.cnn_pep_1(pep)
        pep = self.cnn_pep_2(pep)
        pep = self.relu(pep)
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = torch.sigmoid(x)
        return x

In [8]:
class MoreComplexNet(nn.Module):
    def __init__(self):
        super(MoreComplexNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 500, 1)
        self.cnn_apt_2 = nn.Conv2d(500, 1000, 1)
        self.cnn_apt_3 = nn.Conv2d(1000, 500, 2)
        self.cnn_apt_4 = nn.Conv2d(500, 100, 2)
        self.cnn_apt_5 = nn.Conv2d(100, 10, 2)
        
        self.cnn_pep_1 = nn.Conv2d(1, 250, 1)
        self.cnn_pep_2 = nn.Conv2d(250, 500, 1)
        self.cnn_pep_3 = nn.Conv2d(500, 250, 3)
        self.cnn_pep_4 = nn.Conv2d(250, 100, 2)
        self.cnn_pep_5 = nn.Conv2d(100, 10, 2)
        
        self.relu = nn.ReLU()
        
        self.cnn_apt = nn.Sequential(self.cnn_apt_1, self.relu, self.cnn_apt_2, self.relu, self.cnn_apt_3, self.relu, self.cnn_apt_4, self.relu, self.cnn_apt_5)
        self.cnn_pep = nn.Sequential(self.cnn_pep_1, self.relu, self.cnn_pep_2, self.relu, self.cnn_pep_3, self.relu, self.cnn_pep_4, self.relu, self.cnn_pep_5)
        
        
        self.fc1 = nn.Linear(101000, BATCH_SIZE*2)
    
    def forward(self, apt, pep):
        apt = self.cnn_apt(apt)
        apt = self.relu(apt)
        pep = self.cnn_pep(pep)
        pep = self.relu(pep)
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = F.log_softmax(x, dim=1)
        return x
    
    def loss(self, prediction, label):
        criterion = nn.CrossEntropyLoss()
        return criterion(prediction, label)
        

## Helper methods

In [9]:
## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    one_hot = np.zeros((len(sequence), len(letters)))
    for i in range(len(sequence)):
        char = sequence[i]
        for _ in range(len(letters)):
            idx = letters.index(char)
            one_hot[i][idx] = 1
    return one_hot


# Convert a pair to one-hot tensor
def convert(apt, pep, ls):
    aptamers = np.zeros((BATCH_SIZE, 40, 4))
    peptides = np.zeros((BATCH_SIZE, 8, 20))
    
    for i in range(BATCH_SIZE):
        aptamers[i] = one_hot(apt[i], seq_type='aptamer')
        peptides[i] = one_hot(pep[i], seq_type='peptide')

    aptamers = torch.FloatTensor(np.reshape(aptamers, (BATCH_SIZE, 1, aptamers.shape[1], aptamers.shape[2]))).cuda() #(1, 1, 40, 4)
    peptides = torch.FloatTensor(np.reshape(peptides, (BATCH_SIZE, 1, peptides.shape[1], peptides.shape[2]))).cuda() #(1, 1, 8, 20)
    labels = []
    for l in ls:
        labels.append(l.item())
    labels = torch.LongTensor(np.asarray(labels))
    return aptamers, peptides, labels.cuda()

## Binary Classification

In [10]:
def binary_classification(t=5, lr=1e-5):
    model = SimpleConvNet()
    model.apply(weights_init)
    model.cuda()
        
    optimizer = SGD(model.parameters(), lr=lr)
    train_losses = []
    val_losses = []
    
    # Training loop
    for epoch in range(t):
        model.train()
        for i, batch in enumerate(tqdm.tqdm(train_loader)):
            aptamers = list(batch[0]) #list of BATCH_SIZE num of aptamers
            peptides = list(batch[1])
            labels = list(batch[2])
            
            x, y, labels = convert(aptamers, peptides, labels)

            optimizer.zero_grad()
            output = model(x, y)
            print(output)
            
            output = output.view((BATCH_SIZE, 2))
            print(output)
            
            loss = model.loss(output, labels)            
            print(loss)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()

            # Testing loop
            if i % 1000 == 0:
                model.eval()
                correct = 0
                total = 0
                with torch.no_grad():
                    for j, batch in enumerate(val_loader):
                        aptamers = list(batch[0])
                        peptides = list(batch[1])
                        labels = list(batch[2])
                        x, y, labels = convert(aptamers, peptides, labels)
                        output = model(x, y)
                        output = output.view((BATCH_SIZE, 2))
                        val_losses.append(model.loss(output, labels))
                        pred = torch.argmax(output, dim=1)
                        pred = pred.cpu().detach().numpy()
                        labels = labels.cpu().detach().numpy()
                        for i in range(pred.shape[0]):
                            if pred[i] == labels[i]:
                                correct += 1

                        total += BATCH_SIZE

                print('Accuracy of the network after ' + str(epoch) + ' epochs on the test samples: %d %%' % (100* correct/total))
                _, ax = plt.subplots()
                ax.plot(train_losses, 'g', label='Train loss')
                ax.plot(val_losses, 'b', label='Val loss')
                ax.set_title('Loss after ' + str(i) + " iterations ")
                ax.legend()
                plt.show()

In [11]:
binary_classification(t=10, lr=1e-3)

  0%|          | 0/9460 [00:00<?, ?it/s]


RuntimeError: size mismatch, m1: [1 x 469000], m2: [4690 x 200] at /opt/conda/conda-bld/pytorch_1579022071458/work/aten/src/THC/generic/THCTensorMathBlas.cu:290