In [1]:
import os, sys
import numpy as np
import json
import random

## Prepare the data --> one hot encoding matrices

In [2]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide, _ in peptides:
            if len(aptamer) == 40 and len(peptide) == 8:
                full_dataset.append((aptamer, peptide))
    return full_dataset

In [3]:
full_dataset = construct_dataset()
random.shuffle(full_dataset)
training_set = full_dataset[:int(0.8*len(full_dataset))]
test_set = full_dataset[int(0.8*len(full_dataset)):]

In [4]:
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
na_list = ['A', 'C', 'G', 'T']

## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    one_hot = np.zeros((len(sequence), len(letters)))
    for i in range(len(sequence)):
        element = sequence[i]
        idx = letters.index(element)
        one_hot[i][idx] = 1
    return one_hot

## Model --> CNN

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD

In [41]:
class TwoLayer(nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)
        
        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)
        
        self.relu = nn.ReLU()
    
        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)
        
        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)
                
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        print(apt.shape())
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [9]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        self.fc_apt_1 = nn.Linear(160, 1)
        
        self.cnn_pep_1 = nn.Conv2d(1, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)
        
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()
        
        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_pep_2)
        
        self.sequential_apt = nn.Sequential(self.cnn_apt_1, 
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_2, 
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_3)
        
        self.fc1 = nn.Linear(209, 1)
        
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [12]:
model = ConvNet()
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)

#model_2 = TwoLayer()
# def weights_init_2(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)

model.apply(weights_init)
optimizer = Adam(model.parameters(), lr=1e-7, weight_decay=1e-4)

In [13]:
# Training Loop
import tqdm
for epoch in range(1):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    # Come up with a trainloader
    for i, data in enumerate(tqdm.tqdm(training_set[:10000])):
        # Peptide and aptamer, one-hot encode them
        pep = training_set[i][1]
        apt = training_set[i][0]
        
        pep = one_hot(pep, seq_type='peptide')
        apt = one_hot(apt, seq_type='aptamer')
        
        pep = torch.FloatTensor(np.reshape(pep, (1, 1, pep.shape[0], pep.shape[1])))
        apt = torch.FloatTensor(np.reshape(apt, (1, 1, apt.shape[0], apt.shape[1])))
        
        output = model(apt, pep)
        loss = model.loss(output, 1)
        optimizer.zero_grad()
        loss.backward()
        
        clip = 5
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        
        optimizer.step()
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    
print('Finished Training')

  0%|          | 3/10000 [00:00<05:59, 27.83it/s]

Epoch:  0


  2%|▏         | 215/10000 [00:02<02:03, 79.30it/s]

[1,   200] loss: 0.429


  4%|▍         | 410/10000 [00:05<02:09, 74.25it/s]

[1,   400] loss: 0.474


  6%|▌         | 607/10000 [00:08<02:14, 69.61it/s]

[1,   600] loss: 0.430


  8%|▊         | 812/10000 [00:10<02:00, 75.94it/s]

[1,   800] loss: 0.392


 10%|█         | 1018/10000 [00:13<01:39, 90.38it/s]

[1,  1000] loss: 0.428


 12%|█▏        | 1211/10000 [00:15<01:39, 88.69it/s]

[1,  1200] loss: 0.431


 14%|█▍        | 1410/10000 [00:17<01:28, 96.84it/s]

[1,  1400] loss: 0.415


 16%|█▌        | 1612/10000 [00:19<01:25, 98.67it/s]

[1,  1600] loss: 0.386


 18%|█▊        | 1811/10000 [00:21<01:47, 76.23it/s] 

[1,  1800] loss: 0.418


 20%|██        | 2006/10000 [00:24<01:46, 75.31it/s]

[1,  2000] loss: 0.440


 22%|██▏       | 2212/10000 [00:26<01:33, 83.03it/s]

[1,  2200] loss: 0.417


 24%|██▍       | 2408/10000 [00:29<01:40, 75.87it/s]

[1,  2400] loss: 0.440


 26%|██▌       | 2609/10000 [00:31<01:30, 81.24it/s]

[1,  2600] loss: 0.442


 28%|██▊       | 2811/10000 [00:34<01:28, 81.07it/s]

[1,  2800] loss: 0.422


 30%|███       | 3010/10000 [00:36<01:13, 94.87it/s]

[1,  3000] loss: 0.389


 32%|███▏      | 3217/10000 [00:39<01:10, 95.61it/s]

[1,  3200] loss: 0.389


 34%|███▍      | 3409/10000 [00:41<01:03, 103.86it/s]

[1,  3400] loss: 0.405


 36%|███▌      | 3612/10000 [00:43<01:03, 99.84it/s] 

[1,  3600] loss: 0.409


 38%|███▊      | 3818/10000 [00:45<01:03, 97.48it/s]

[1,  3800] loss: 0.403


 40%|████      | 4012/10000 [00:47<01:07, 88.10it/s] 

[1,  4000] loss: 0.425


 42%|████▏     | 4217/10000 [00:49<00:59, 96.86it/s] 

[1,  4200] loss: 0.397


 44%|████▍     | 4418/10000 [00:51<00:58, 96.15it/s] 

[1,  4400] loss: 0.425


 46%|████▌     | 4613/10000 [00:53<00:54, 98.53it/s]

[1,  4600] loss: 0.414


 48%|████▊     | 4809/10000 [00:56<01:05, 79.59it/s] 

[1,  4800] loss: 0.437


 50%|█████     | 5012/10000 [00:58<01:00, 82.44it/s]

[1,  5000] loss: inf


 52%|█████▏    | 5215/10000 [01:01<01:00, 78.62it/s]

[1,  5200] loss: 0.417


 54%|█████▍    | 5416/10000 [01:04<00:57, 79.38it/s]

[1,  5400] loss: inf


 56%|█████▌    | 5619/10000 [01:06<00:44, 99.30it/s]

[1,  5600] loss: 0.425


 58%|█████▊    | 5815/10000 [01:08<00:46, 89.81it/s]

[1,  5800] loss: 0.408


 60%|██████    | 6017/10000 [01:10<00:42, 92.73it/s]

[1,  6000] loss: 0.458


 62%|██████▏   | 6210/10000 [01:12<00:52, 72.78it/s] 

[1,  6200] loss: 0.479


 64%|██████▍   | 6418/10000 [01:14<00:37, 96.08it/s] 

[1,  6400] loss: 0.413


 66%|██████▌   | 6612/10000 [01:17<00:35, 94.62it/s] 

[1,  6600] loss: 0.431


 68%|██████▊   | 6814/10000 [01:19<00:33, 95.68it/s]

[1,  6800] loss: 27062053779819630428160.000


 70%|███████   | 7016/10000 [01:21<00:32, 91.93it/s]

[1,  7000] loss: 0.402


 72%|███████▏  | 7211/10000 [01:23<00:35, 78.06it/s]

[1,  7200] loss: 0.413


 74%|███████▍  | 7408/10000 [01:26<00:38, 67.32it/s]

[1,  7400] loss: 0.430


 76%|███████▌  | 7618/10000 [01:28<00:26, 90.74it/s]

[1,  7600] loss: 0.427


 78%|███████▊  | 7816/10000 [01:31<00:26, 82.27it/s]

[1,  7800] loss: 0.413


 80%|████████  | 8015/10000 [01:33<00:21, 90.63it/s]

[1,  8000] loss: 0.436


 82%|████████▏ | 8213/10000 [01:36<00:22, 79.38it/s]

[1,  8200] loss: 0.385


 84%|████████▍ | 8408/10000 [01:38<00:17, 90.95it/s]

[1,  8400] loss: 0.421


 86%|████████▌ | 8614/10000 [01:40<00:15, 91.84it/s]

[1,  8600] loss: 0.415


 88%|████████▊ | 8816/10000 [01:43<00:12, 93.77it/s]

[1,  8800] loss: 0.424


 90%|█████████ | 9010/10000 [01:45<00:12, 77.17it/s]

[1,  9000] loss: 0.419


 92%|█████████▏| 9213/10000 [01:48<00:08, 87.80it/s]

[1,  9200] loss: 0.449


 94%|█████████▍| 9418/10000 [01:50<00:06, 86.76it/s] 

[1,  9400] loss: 0.441


 96%|█████████▌| 9615/10000 [01:52<00:03, 98.28it/s]

[1,  9600] loss: 0.407


 98%|█████████▊| 9812/10000 [01:54<00:01, 97.43it/s] 

[1,  9800] loss: 0.441


100%|██████████| 10000/10000 [01:56<00:00, 85.88it/s]


[1, 10000] loss: 0.435
Finished Training


## Evaluation --> compare to random

In [15]:
correct = 0
incorrect = 0
for i, data in enumerate(tqdm.tqdm(test_set[:10000])):
    pep = test_set[i][1]
    apt = test_set[i][0]
    
    pep = one_hot(pep, seq_type='peptide')
    apt = one_hot(apt, seq_type='aptamer')

    pep = torch.FloatTensor(np.reshape(pep, (1, 1, pep.shape[0], pep.shape[1])))
    apt = torch.FloatTensor(np.reshape(apt, (1, 1, apt.shape[0], apt.shape[1])))

    output = model(apt, pep)
    #print("Output: ", output)
    
    if output > 0.5:
        correct += 1
    else:
        incorrect += 1

print('Accuracy of the network on the test samples: %d %%' % (100* correct/(correct + incorrect)))

100%|██████████| 10000/10000 [00:16<00:00, 609.13it/s]

Accuracy of the network on the test samples: 0 %



