In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import tqdm

## Prepare the data --> one hot encoding matrices

In [2]:
BATCH_SIZE = 256
random.seed(42)

In [3]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "")
            if len(aptamer) == 40 and len(peptide) == 8:
                full_dataset.append((aptamer, peptide))
    return full_dataset

In [4]:
full_dataset = construct_dataset()
random.shuffle(full_dataset)
training_set = full_dataset[:int(0.8*len(full_dataset))]
test_set = full_dataset[int(0.8*len(full_dataset)):]

In [5]:
class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, training_set):
        super(TrainDataset, self).__init__() 
        self.training_set = training_set
        n = len(training_set)
        self.training_set = training_set[:n-n%BATCH_SIZE]
        
    def __len__(self):
        return len(self.training_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.training_set[idx]
        return aptamer, peptide
    
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, test_set):
        super(TestDataset, self).__init__() 
        self.test_set = test_set
        n = len(test_set)
        self.test_set = test_set[:n-n%BATCH_SIZE]
        
    def __len__(self):
        return len(self.test_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.test_set[idx]
        return aptamer, peptide

In [6]:
train_dataset = TrainDataset(training_set)
test_dataset = TestDataset(test_set)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [7]:
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
na_list = ['A', 'C', 'G', 'T']

## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence_list, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    
    one_hot = np.zeros((len(sequence_list), len(sequence_list[0]), len(letters)))
    for j in range(len(sequence_list)):
        sequence = sequence_list[j]
        for i in range(len(sequence)):
            element = sequence[i]
            idx = letters.index(element)
            one_hot[j][i][idx] = 1
    return one_hot

## Model --> CNN

In [8]:
class TwoLayer(nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)
        
        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)
        
        self.relu = nn.ReLU()
    
        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)
        
        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)
                
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [9]:
class ConvNet(nn.Module):
    def __init__(self, batch_size):
        super(ConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(BATCH_SIZE, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        self.fc_apt_1 = nn.Linear(160, 1)
        
        self.cnn_pep_1 = nn.Conv2d(BATCH_SIZE, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)
        
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()
        
        #self.dropout = nn.Dropout(0.1)
        
        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_pep_2)
        
        self.sequential_apt = nn.Sequential(self.cnn_apt_1, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_2, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_3)
        
        self.fc1 = nn.Linear(209, BATCH_SIZE)
        
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = torch.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        return l(torch.FloatTensor(prediction), label)

In [10]:
model = ConvNet(batch_size=BATCH_SIZE)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        nn.init.zeros_(m.bias.data)

#model_2 = TwoLayer()
# def weights_init_2(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)

model.apply(weights_init)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

In [11]:
# Training Loop
for epoch in range(1):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    # Come up with a trainloader
    for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(train_loader)):
        # Peptide and aptamer, one-hot encode them 
        pep = one_hot(peptide, seq_type='peptide')
        apt = one_hot(aptamer, seq_type='aptamer')
        
        pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
        apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))
        
        output = model(apt, pep)
        
        label = np.ones((1, BATCH_SIZE))
        loss = model.loss(output, label)
        optimizer.zero_grad()
        loss.backward()
        
        #hyperparameter
        clip = 5
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 10 == 9:
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, batch_idx + 1, running_loss / 10*BATCH_SIZE))
            running_loss = 0.0
    
print('Finished Training')

  0%|          | 1/2422 [00:00<04:58,  8.10it/s]

Epoch:  0


  0%|          | 11/2422 [00:00<03:42, 10.84it/s]

[1,    10] loss: 69.7710876


  1%|          | 21/2422 [00:01<02:58, 13.46it/s]

[1,    20] loss: 61.8468021


  1%|▏         | 31/2422 [00:02<04:29,  8.88it/s]

[1,    30] loss: 46.0888577


  2%|▏         | 41/2422 [00:03<03:43, 10.67it/s]

[1,    40] loss: 16.1752067


  2%|▏         | 50/2422 [00:05<08:23,  4.71it/s]

[1,    50] loss: 2.4510290


  2%|▏         | 60/2422 [00:07<09:06,  4.32it/s]

[1,    60] loss: 0.7707083


  3%|▎         | 70/2422 [00:10<09:34,  4.10it/s]

[1,    70] loss: 1.0041341


  3%|▎         | 80/2422 [00:12<08:19,  4.69it/s]

[1,    80] loss: 1.5987549


  4%|▎         | 90/2422 [00:14<08:04,  4.82it/s]

[1,    90] loss: 1.3587666


  4%|▍         | 101/2422 [00:17<07:56,  4.87it/s]

[1,   100] loss: 1.0597275


  5%|▍         | 110/2422 [00:19<08:37,  4.46it/s]

[1,   110] loss: 0.9968802


  5%|▍         | 120/2422 [00:21<08:17,  4.63it/s]

[1,   120] loss: 0.9397389


  5%|▌         | 130/2422 [00:23<07:44,  4.94it/s]

[1,   130] loss: 0.8417522


  6%|▌         | 141/2422 [00:25<07:05,  5.37it/s]

[1,   140] loss: 0.7263591


  6%|▌         | 150/2422 [00:27<09:13,  4.10it/s]

[1,   150] loss: 0.7127776


  7%|▋         | 160/2422 [00:30<09:25,  4.00it/s]

[1,   160] loss: 0.6815220


  7%|▋         | 170/2422 [00:32<09:29,  3.96it/s]

[1,   170] loss: 0.6617583


  7%|▋         | 180/2422 [00:34<08:29,  4.40it/s]

[1,   180] loss: 0.6216890


  8%|▊         | 191/2422 [00:36<06:52,  5.41it/s]

[1,   190] loss: 0.5954580


  8%|▊         | 201/2422 [00:39<07:33,  4.90it/s]

[1,   200] loss: 0.5999322


  9%|▊         | 211/2422 [00:41<07:18,  5.04it/s]

[1,   210] loss: 0.5523004


  9%|▉         | 221/2422 [00:42<06:26,  5.70it/s]

[1,   220] loss: 0.5611764


 10%|▉         | 231/2422 [00:44<06:29,  5.63it/s]

[1,   230] loss: 0.5480604


 10%|▉         | 241/2422 [00:46<05:30,  6.59it/s]

[1,   240] loss: 0.5177999


 10%|█         | 251/2422 [00:48<06:07,  5.90it/s]

[1,   250] loss: 0.5250139


 11%|█         | 260/2422 [00:50<08:53,  4.05it/s]

[1,   260] loss: 0.5262754


 11%|█         | 271/2422 [00:52<07:35,  4.73it/s]

[1,   270] loss: 0.4874240


 12%|█▏        | 280/2422 [00:54<06:50,  5.22it/s]

[1,   280] loss: 0.5067344


 12%|█▏        | 291/2422 [00:56<06:19,  5.61it/s]

[1,   290] loss: 0.4993124


 12%|█▏        | 300/2422 [00:58<06:23,  5.53it/s]

[1,   300] loss: 0.4736684


 13%|█▎        | 310/2422 [01:00<07:28,  4.71it/s]

[1,   310] loss: 0.4987807


 13%|█▎        | 321/2422 [01:02<06:11,  5.65it/s]

[1,   320] loss: 0.5022946


 14%|█▎        | 330/2422 [01:03<06:52,  5.07it/s]

[1,   330] loss: 0.4440739


 14%|█▍        | 340/2422 [01:05<06:39,  5.21it/s]

[1,   340] loss: 0.5142640


 14%|█▍        | 350/2422 [01:07<06:11,  5.58it/s]

[1,   350] loss: 0.4764321


 15%|█▍        | 361/2422 [01:09<05:10,  6.64it/s]

[1,   360] loss: 0.4644539


 15%|█▌        | 371/2422 [01:11<05:21,  6.38it/s]

[1,   370] loss: 0.4860504


 16%|█▌        | 381/2422 [01:12<05:50,  5.83it/s]

[1,   380] loss: 0.4695196


 16%|█▌        | 391/2422 [01:14<05:12,  6.50it/s]

[1,   390] loss: 0.4737341


 17%|█▋        | 400/2422 [01:16<06:27,  5.22it/s]

[1,   400] loss: 0.4891867


 17%|█▋        | 410/2422 [01:17<06:14,  5.37it/s]

[1,   410] loss: 0.4529393


 17%|█▋        | 421/2422 [01:19<05:49,  5.72it/s]

[1,   420] loss: 0.4778143


 18%|█▊        | 430/2422 [01:21<05:57,  5.57it/s]

[1,   430] loss: 0.4866282


 18%|█▊        | 441/2422 [01:23<06:16,  5.27it/s]

[1,   440] loss: 0.4522553


 19%|█▊        | 451/2422 [01:25<05:03,  6.50it/s]

[1,   450] loss: 0.4743787


 19%|█▉        | 461/2422 [01:26<03:15, 10.05it/s]

[1,   460] loss: 0.4926888


 19%|█▉        | 470/2422 [01:27<03:41,  8.83it/s]

[1,   470] loss: 0.4648285


 20%|█▉        | 483/2422 [01:28<02:45, 11.71it/s]

[1,   480] loss: 0.4748993


 20%|██        | 490/2422 [01:29<03:01, 10.65it/s]

[1,   490] loss: 0.4699406


 21%|██        | 501/2422 [01:31<05:19,  6.01it/s]

[1,   500] loss: 0.4662645


 21%|██        | 512/2422 [01:31<02:38, 12.07it/s]

[1,   510] loss: 0.4651860


 22%|██▏       | 522/2422 [01:32<01:40, 18.87it/s]

[1,   520] loss: 0.4764664


 22%|██▏       | 532/2422 [01:32<01:43, 18.26it/s]

[1,   530] loss: 0.4781662


 22%|██▏       | 544/2422 [01:33<01:25, 21.86it/s]

[1,   540] loss: 0.4663146


 23%|██▎       | 550/2422 [01:33<01:59, 15.71it/s]

[1,   550] loss: 0.4731297


 23%|██▎       | 564/2422 [01:35<02:17, 13.50it/s]

[1,   560] loss: 0.4615220


 24%|██▎       | 570/2422 [01:35<02:33, 12.03it/s]

[1,   570] loss: 0.4664586


 24%|██▍       | 579/2422 [01:36<03:31,  8.73it/s]

[1,   580] loss: 0.4745902


 24%|██▍       | 591/2422 [01:38<05:17,  5.78it/s]

[1,   590] loss: 0.4624879


 25%|██▍       | 601/2422 [01:39<03:18,  9.17it/s]

[1,   600] loss: 0.4716096


 25%|██▌       | 610/2422 [01:40<03:07,  9.68it/s]

[1,   610] loss: 0.4761197


 26%|██▌       | 620/2422 [01:41<03:03,  9.81it/s]

[1,   620] loss: 0.4595509


 26%|██▌       | 631/2422 [01:42<01:59, 14.95it/s]

[1,   630] loss: 0.4740473


 27%|██▋       | 643/2422 [01:42<01:26, 20.46it/s]

[1,   640] loss: 0.4501587


 27%|██▋       | 653/2422 [01:43<01:44, 16.89it/s]

[1,   650] loss: 0.4843523


 27%|██▋       | 661/2422 [01:44<02:26, 12.01it/s]

[1,   660] loss: 0.4651423


 28%|██▊       | 671/2422 [01:45<04:10,  6.98it/s]

[1,   670] loss: 0.4651195


 28%|██▊       | 681/2422 [01:46<03:54,  7.43it/s]

[1,   680] loss: 0.4689447


 28%|██▊       | 690/2422 [01:48<05:48,  4.97it/s]

[1,   690] loss: 0.4599059


 29%|██▉       | 702/2422 [01:50<03:47,  7.56it/s]

[1,   700] loss: 0.4688624


 29%|██▉       | 711/2422 [01:51<02:27, 11.60it/s]

[1,   710] loss: 0.4715486


 30%|██▉       | 720/2422 [01:52<03:52,  7.32it/s]

[1,   720] loss: 0.4609969


 30%|███       | 731/2422 [01:54<05:13,  5.39it/s]

[1,   730] loss: 0.4765586


 31%|███       | 740/2422 [01:56<06:20,  4.42it/s]

[1,   740] loss: 0.4527685


 31%|███       | 750/2422 [01:58<06:20,  4.39it/s]

[1,   750] loss: 0.4840901


 31%|███▏      | 760/2422 [02:01<06:40,  4.15it/s]

[1,   760] loss: 0.4547821


 32%|███▏      | 770/2422 [02:03<06:20,  4.34it/s]

[1,   770] loss: 0.4681728


 32%|███▏      | 780/2422 [02:05<05:38,  4.85it/s]

[1,   780] loss: 0.4725622


 33%|███▎      | 790/2422 [02:07<06:15,  4.35it/s]

[1,   790] loss: 0.4605312


 33%|███▎      | 800/2422 [02:10<06:11,  4.37it/s]

[1,   800] loss: 0.4725147


 33%|███▎      | 810/2422 [02:12<06:28,  4.14it/s]

[1,   810] loss: 0.4702263


 34%|███▍      | 821/2422 [02:14<05:30,  4.84it/s]

[1,   820] loss: 0.4656587


 34%|███▍      | 830/2422 [02:16<05:41,  4.67it/s]

[1,   830] loss: 0.4743984


 35%|███▍      | 840/2422 [02:19<05:58,  4.42it/s]

[1,   840] loss: 0.4561654


 35%|███▌      | 851/2422 [02:21<04:54,  5.34it/s]

[1,   850] loss: 0.4693537


 36%|███▌      | 861/2422 [02:23<04:56,  5.26it/s]

[1,   860] loss: 0.4797601


 36%|███▌      | 871/2422 [02:25<04:56,  5.24it/s]

[1,   870] loss: 0.4571097


 36%|███▋      | 880/2422 [02:27<05:24,  4.75it/s]

[1,   880] loss: 0.4745048


 37%|███▋      | 890/2422 [02:29<05:48,  4.40it/s]

[1,   890] loss: 0.4567200


 37%|███▋      | 900/2422 [02:31<05:31,  4.59it/s]

[1,   900] loss: 0.4711542


 38%|███▊      | 911/2422 [02:33<05:07,  4.92it/s]

[1,   910] loss: 0.4673008


 38%|███▊      | 921/2422 [02:35<04:26,  5.63it/s]

[1,   920] loss: 0.4823452


 38%|███▊      | 931/2422 [02:37<04:33,  5.46it/s]

[1,   930] loss: 0.4473713


 39%|███▉      | 940/2422 [02:39<05:01,  4.92it/s]

[1,   940] loss: 0.4682946


 39%|███▉      | 950/2422 [02:41<05:54,  4.15it/s]

[1,   950] loss: 0.4676426


 40%|███▉      | 961/2422 [02:43<04:41,  5.19it/s]

[1,   960] loss: 0.4623486


 40%|████      | 970/2422 [02:45<05:11,  4.66it/s]

[1,   970] loss: 0.4706513


 40%|████      | 980/2422 [02:47<04:40,  5.13it/s]

[1,   980] loss: 0.4716268


 41%|████      | 990/2422 [02:49<04:44,  5.03it/s]

[1,   990] loss: 0.4631804


 41%|████▏     | 1000/2422 [02:51<05:08,  4.61it/s]

[1,  1000] loss: 0.4627995


 42%|████▏     | 1010/2422 [02:53<04:35,  5.12it/s]

[1,  1010] loss: 0.4721315


 42%|████▏     | 1020/2422 [02:55<04:42,  4.97it/s]

[1,  1020] loss: 0.4690325


 43%|████▎     | 1031/2422 [02:57<04:04,  5.70it/s]

[1,  1030] loss: 0.4620430


 43%|████▎     | 1041/2422 [02:59<03:46,  6.09it/s]

[1,  1040] loss: 0.4629469


 43%|████▎     | 1051/2422 [03:01<04:16,  5.35it/s]

[1,  1050] loss: 0.4697813


 44%|████▍     | 1061/2422 [03:03<04:49,  4.69it/s]

[1,  1060] loss: 0.4697418


 44%|████▍     | 1070/2422 [03:05<04:39,  4.84it/s]

[1,  1070] loss: 0.4669367


 45%|████▍     | 1080/2422 [03:07<04:51,  4.61it/s]

[1,  1080] loss: 0.4607599


 45%|████▌     | 1090/2422 [03:09<04:51,  4.57it/s]

[1,  1090] loss: 0.4764998


 45%|████▌     | 1101/2422 [03:11<03:58,  5.54it/s]

[1,  1100] loss: 0.4553115


 46%|████▌     | 1111/2422 [03:13<03:18,  6.62it/s]

[1,  1110] loss: 0.4723447


 46%|████▋     | 1121/2422 [03:15<03:56,  5.51it/s]

[1,  1120] loss: 0.4597180


 47%|████▋     | 1131/2422 [03:16<03:47,  5.68it/s]

[1,  1130] loss: 0.4786626


 47%|████▋     | 1140/2422 [03:18<03:39,  5.85it/s]

[1,  1140] loss: 0.4546703


 48%|████▊     | 1151/2422 [03:20<03:57,  5.36it/s]

[1,  1150] loss: 0.4684452


 48%|████▊     | 1162/2422 [03:21<02:13,  9.43it/s]

[1,  1160] loss: 0.4686710


 48%|████▊     | 1172/2422 [03:22<01:47, 11.60it/s]

[1,  1170] loss: 0.4707995


 49%|████▉     | 1183/2422 [03:23<01:19, 15.63it/s]

[1,  1180] loss: 0.4573336


 49%|████▉     | 1192/2422 [03:23<01:22, 14.89it/s]

[1,  1190] loss: 0.4788434


 50%|████▉     | 1201/2422 [03:24<02:51,  7.11it/s]

[1,  1200] loss: 0.4525548


 50%|████▉     | 1210/2422 [03:26<03:42,  5.45it/s]

[1,  1210] loss: 0.4830940


 50%|█████     | 1221/2422 [03:28<03:40,  5.44it/s]

[1,  1220] loss: 0.4498659


 51%|█████     | 1230/2422 [03:30<04:25,  4.49it/s]

[1,  1230] loss: 0.4690899


 51%|█████     | 1240/2422 [03:32<03:57,  4.98it/s]

[1,  1240] loss: 0.4752061


 52%|█████▏    | 1251/2422 [03:35<03:35,  5.44it/s]

[1,  1250] loss: 0.4566525


 52%|█████▏    | 1261/2422 [03:36<03:34,  5.41it/s]

[1,  1260] loss: 0.4736565


 52%|█████▏    | 1271/2422 [03:38<03:10,  6.03it/s]

[1,  1270] loss: 0.4492797


 53%|█████▎    | 1284/2422 [03:39<01:38, 11.53it/s]

[1,  1280] loss: 0.4781303


 53%|█████▎    | 1290/2422 [03:40<02:12,  8.52it/s]

[1,  1290] loss: 0.4614719


 54%|█████▍    | 1302/2422 [03:41<01:24, 13.33it/s]

[1,  1300] loss: 0.4712911


 54%|█████▍    | 1311/2422 [03:41<01:04, 17.25it/s]

[1,  1310] loss: 0.4619586


 55%|█████▍    | 1323/2422 [03:42<00:57, 19.09it/s]

[1,  1320] loss: 0.4755083


 55%|█████▍    | 1332/2422 [03:42<01:23, 13.00it/s]

[1,  1330] loss: 0.4608300


 55%|█████▌    | 1341/2422 [03:44<02:55,  6.15it/s]

[1,  1340] loss: 0.4641741


 56%|█████▌    | 1351/2422 [03:46<03:07,  5.70it/s]

[1,  1350] loss: 0.4666834


 56%|█████▌    | 1361/2422 [03:47<02:37,  6.73it/s]

[1,  1360] loss: 0.4791486


 57%|█████▋    | 1371/2422 [03:49<02:46,  6.32it/s]

[1,  1370] loss: 0.4534833


 57%|█████▋    | 1380/2422 [03:50<03:52,  4.49it/s]

[1,  1380] loss: 0.4650784


 57%|█████▋    | 1391/2422 [03:52<03:14,  5.29it/s]

[1,  1390] loss: 0.4808349


 58%|█████▊    | 1401/2422 [03:54<02:38,  6.44it/s]

[1,  1400] loss: 0.4619990


 58%|█████▊    | 1410/2422 [03:55<02:14,  7.52it/s]

[1,  1410] loss: 0.4679276


 59%|█████▊    | 1420/2422 [03:58<03:51,  4.32it/s]

[1,  1420] loss: 0.4563385


 59%|█████▉    | 1430/2422 [04:00<03:30,  4.70it/s]

[1,  1430] loss: 0.4738917


 59%|█████▉    | 1440/2422 [04:02<03:23,  4.83it/s]

[1,  1440] loss: 0.4672041


 60%|█████▉    | 1450/2422 [04:04<03:18,  4.89it/s]

[1,  1450] loss: 0.4796405


 60%|██████    | 1461/2422 [04:06<02:51,  5.60it/s]

[1,  1460] loss: 0.4560249


 61%|██████    | 1470/2422 [04:08<03:12,  4.95it/s]

[1,  1470] loss: 0.4653647


 61%|██████    | 1480/2422 [04:10<03:08,  4.99it/s]

[1,  1480] loss: 0.4687989


 62%|██████▏   | 1490/2422 [04:12<03:04,  5.06it/s]

[1,  1490] loss: 0.4724515


 62%|██████▏   | 1500/2422 [04:14<03:13,  4.77it/s]

[1,  1500] loss: 0.4684815


 62%|██████▏   | 1510/2422 [04:16<03:00,  5.04it/s]

[1,  1510] loss: 0.4674724


 63%|██████▎   | 1521/2422 [04:18<02:41,  5.57it/s]

[1,  1520] loss: 0.4596797


 63%|██████▎   | 1531/2422 [04:20<02:23,  6.20it/s]

[1,  1530] loss: 0.4691990


 64%|██████▎   | 1541/2422 [04:22<02:35,  5.67it/s]

[1,  1540] loss: 0.4689468


 64%|██████▍   | 1551/2422 [04:24<02:34,  5.65it/s]

[1,  1550] loss: 0.4718216


 64%|██████▍   | 1561/2422 [04:25<01:48,  7.94it/s]

[1,  1560] loss: 0.4618802


 65%|██████▍   | 1570/2422 [04:27<02:35,  5.48it/s]

[1,  1570] loss: 0.4585785


 65%|██████▌   | 1580/2422 [04:28<02:16,  6.15it/s]

[1,  1580] loss: 0.4749393


 66%|██████▌   | 1590/2422 [04:30<02:40,  5.19it/s]

[1,  1590] loss: 0.4745710


 66%|██████▌   | 1600/2422 [04:32<02:21,  5.79it/s]

[1,  1600] loss: 0.4587445


 67%|██████▋   | 1611/2422 [04:34<02:22,  5.70it/s]

[1,  1610] loss: 0.4625873


 67%|██████▋   | 1620/2422 [04:36<02:20,  5.69it/s]

[1,  1620] loss: 0.4673975


 67%|██████▋   | 1630/2422 [04:38<02:24,  5.48it/s]

[1,  1630] loss: 0.4644774


 68%|██████▊   | 1641/2422 [04:40<02:37,  4.97it/s]

[1,  1640] loss: 0.4746134


 68%|██████▊   | 1651/2422 [04:41<02:10,  5.89it/s]

[1,  1650] loss: 0.4552269


 69%|██████▊   | 1660/2422 [04:43<02:21,  5.39it/s]

[1,  1660] loss: 0.4863706


 69%|██████▉   | 1671/2422 [04:45<02:26,  5.13it/s]

[1,  1670] loss: 0.4505146


 69%|██████▉   | 1681/2422 [04:47<02:14,  5.52it/s]

[1,  1680] loss: 0.4694206


 70%|██████▉   | 1691/2422 [04:49<02:03,  5.91it/s]

[1,  1690] loss: 0.4685817


 70%|███████   | 1700/2422 [04:51<02:24,  5.01it/s]

[1,  1700] loss: 0.4589506


 71%|███████   | 1710/2422 [04:53<02:58,  3.99it/s]

[1,  1710] loss: 0.4675216


 71%|███████   | 1720/2422 [04:55<02:38,  4.43it/s]

[1,  1720] loss: 0.4731868


 71%|███████▏  | 1730/2422 [04:58<02:56,  3.92it/s]

[1,  1730] loss: 0.4658564


 72%|███████▏  | 1741/2422 [05:00<02:22,  4.78it/s]

[1,  1740] loss: 0.4636966


 72%|███████▏  | 1751/2422 [05:02<01:58,  5.65it/s]

[1,  1750] loss: 0.4703145


 73%|███████▎  | 1760/2422 [05:03<01:51,  5.94it/s]

[1,  1760] loss: 0.4764308


 73%|███████▎  | 1771/2422 [05:05<01:41,  6.43it/s]

[1,  1770] loss: 0.4526245


 74%|███████▎  | 1781/2422 [05:06<00:53, 12.09it/s]

[1,  1780] loss: 0.4710017


 74%|███████▍  | 1792/2422 [05:07<00:54, 11.49it/s]

[1,  1790] loss: 0.4575782


 74%|███████▍  | 1801/2422 [05:08<01:04,  9.67it/s]

[1,  1800] loss: 0.4839160


 75%|███████▍  | 1811/2422 [05:09<01:07,  9.03it/s]

[1,  1810] loss: 0.4596851


 75%|███████▌  | 1821/2422 [05:10<01:40,  5.97it/s]

[1,  1820] loss: 0.4640886


 76%|███████▌  | 1831/2422 [05:12<01:24,  6.96it/s]

[1,  1830] loss: 0.4649456


 76%|███████▌  | 1841/2422 [05:13<01:32,  6.28it/s]

[1,  1840] loss: 0.4834308


 76%|███████▋  | 1851/2422 [05:15<01:28,  6.46it/s]

[1,  1850] loss: 0.4544279


 77%|███████▋  | 1861/2422 [05:17<01:47,  5.23it/s]

[1,  1860] loss: 0.4642893


 77%|███████▋  | 1871/2422 [05:19<01:45,  5.25it/s]

[1,  1870] loss: 0.4818742


 78%|███████▊  | 1881/2422 [05:21<01:39,  5.41it/s]

[1,  1880] loss: 0.4485263


 78%|███████▊  | 1890/2422 [05:23<01:52,  4.74it/s]

[1,  1890] loss: 0.4680040


 78%|███████▊  | 1900/2422 [05:25<01:54,  4.56it/s]

[1,  1900] loss: 0.4714532


 79%|███████▉  | 1911/2422 [05:27<01:27,  5.86it/s]

[1,  1910] loss: 0.4674715


 79%|███████▉  | 1921/2422 [05:29<01:43,  4.85it/s]

[1,  1920] loss: 0.4710798


 80%|███████▉  | 1930/2422 [05:30<01:29,  5.49it/s]

[1,  1930] loss: 0.4624509


 80%|████████  | 1940/2422 [05:33<01:42,  4.71it/s]

[1,  1940] loss: 0.4665506


 81%|████████  | 1950/2422 [05:35<02:11,  3.59it/s]

[1,  1950] loss: 0.4695052


 81%|████████  | 1960/2422 [05:38<01:54,  4.03it/s]

[1,  1960] loss: 0.4677861


 81%|████████▏ | 1970/2422 [05:40<01:40,  4.51it/s]

[1,  1970] loss: 0.4681294


 82%|████████▏ | 1980/2422 [05:42<01:47,  4.10it/s]

[1,  1980] loss: 0.4673650


 82%|████████▏ | 1990/2422 [05:45<01:50,  3.90it/s]

[1,  1990] loss: 0.4733500


 83%|████████▎ | 2000/2422 [05:47<01:46,  3.95it/s]

[1,  2000] loss: 0.4710179


 83%|████████▎ | 2010/2422 [05:50<01:34,  4.37it/s]

[1,  2010] loss: 0.4593552


 83%|████████▎ | 2020/2422 [05:52<01:32,  4.36it/s]

[1,  2020] loss: 0.4635159


 84%|████████▍ | 2030/2422 [05:54<01:29,  4.40it/s]

[1,  2030] loss: 0.4587702


 84%|████████▍ | 2040/2422 [05:56<01:32,  4.15it/s]

[1,  2040] loss: 0.4726067


 85%|████████▍ | 2050/2422 [05:59<01:23,  4.47it/s]

[1,  2050] loss: 0.4771091


 85%|████████▌ | 2061/2422 [06:01<01:17,  4.64it/s]

[1,  2060] loss: 0.4641555


 86%|████████▌ | 2071/2422 [06:03<01:05,  5.39it/s]

[1,  2070] loss: 0.4586671


 86%|████████▌ | 2080/2422 [06:05<01:07,  5.10it/s]

[1,  2080] loss: 0.4781881


 86%|████████▋ | 2090/2422 [06:07<01:09,  4.79it/s]

[1,  2090] loss: 0.4677442


 87%|████████▋ | 2100/2422 [06:09<01:13,  4.40it/s]

[1,  2100] loss: 0.4722863


 87%|████████▋ | 2110/2422 [06:11<01:10,  4.40it/s]

[1,  2110] loss: 0.4512054


 88%|████████▊ | 2120/2422 [06:13<00:57,  5.29it/s]

[1,  2120] loss: 0.4751598


 88%|████████▊ | 2131/2422 [06:16<01:00,  4.81it/s]

[1,  2130] loss: 0.4745426


 88%|████████▊ | 2141/2422 [06:17<00:44,  6.31it/s]

[1,  2140] loss: 0.4632393


 89%|████████▉ | 2151/2422 [06:19<00:59,  4.52it/s]

[1,  2150] loss: 0.4688376


 89%|████████▉ | 2160/2422 [06:22<01:06,  3.92it/s]

[1,  2160] loss: 0.4644844


 90%|████████▉ | 2170/2422 [06:24<01:05,  3.83it/s]

[1,  2170] loss: 0.4597722


 90%|█████████ | 2180/2422 [06:27<01:02,  3.88it/s]

[1,  2180] loss: 0.4788057


 90%|█████████ | 2190/2422 [06:30<00:55,  4.19it/s]

[1,  2190] loss: 0.4620663


 91%|█████████ | 2200/2422 [06:32<00:51,  4.35it/s]

[1,  2200] loss: 0.4722945


 91%|█████████ | 2210/2422 [06:34<00:52,  4.02it/s]

[1,  2210] loss: 0.4633785


 92%|█████████▏| 2220/2422 [06:37<00:47,  4.22it/s]

[1,  2220] loss: 0.4639327


 92%|█████████▏| 2230/2422 [06:39<00:40,  4.73it/s]

[1,  2230] loss: 0.4745985


 92%|█████████▏| 2240/2422 [06:41<00:39,  4.58it/s]

[1,  2240] loss: 0.4550344


 93%|█████████▎| 2251/2422 [06:43<00:31,  5.39it/s]

[1,  2250] loss: 0.4822009


 93%|█████████▎| 2260/2422 [06:45<00:31,  5.22it/s]

[1,  2260] loss: 0.4534733


 94%|█████████▎| 2270/2422 [06:47<00:32,  4.64it/s]

[1,  2270] loss: 0.4739352


 94%|█████████▍| 2280/2422 [06:49<00:30,  4.66it/s]

[1,  2280] loss: 0.4663671


 95%|█████████▍| 2291/2422 [06:52<00:25,  5.13it/s]

[1,  2290] loss: 0.4593419


 95%|█████████▍| 2300/2422 [06:53<00:25,  4.74it/s]

[1,  2300] loss: 0.4732232


 95%|█████████▌| 2310/2422 [06:56<00:25,  4.47it/s]

[1,  2310] loss: 0.4700605


 96%|█████████▌| 2320/2422 [06:58<00:21,  4.82it/s]

[1,  2320] loss: 0.4671829


 96%|█████████▌| 2330/2422 [07:00<00:19,  4.66it/s]

[1,  2330] loss: 0.4608948


 97%|█████████▋| 2340/2422 [07:02<00:17,  4.76it/s]

[1,  2340] loss: 0.4762989


 97%|█████████▋| 2350/2422 [07:04<00:14,  4.97it/s]

[1,  2350] loss: 0.4612647


 97%|█████████▋| 2360/2422 [07:06<00:13,  4.67it/s]

[1,  2360] loss: 0.4699706


 98%|█████████▊| 2371/2422 [07:08<00:09,  5.20it/s]

[1,  2370] loss: 0.4628771


 98%|█████████▊| 2380/2422 [07:10<00:10,  3.90it/s]

[1,  2380] loss: 0.4720470


 99%|█████████▊| 2390/2422 [07:13<00:07,  4.48it/s]

[1,  2390] loss: 0.4762239


 99%|█████████▉| 2400/2422 [07:15<00:04,  4.68it/s]

[1,  2400] loss: 0.4578701


100%|█████████▉| 2410/2422 [07:18<00:02,  4.13it/s]

[1,  2410] loss: 0.4722994


100%|█████████▉| 2420/2422 [07:20<00:00,  4.58it/s]

[1,  2420] loss: 0.4621632


100%|██████████| 2422/2422 [07:20<00:00,  4.34it/s]

Finished Training





## Evaluation --> compare to random

In [14]:
correct = 0
incorrect = 0
for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(test_loader)):
    pep = one_hot(peptide, seq_type='peptide')
    apt = one_hot(aptamer, seq_type='aptamer')
    
    pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
    apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))

    output = model(apt, pep).detach().numpy().flatten()
    for i in range(output.shape[0]):
        o = output[i]
        print(o)
        if o > 0.9:
            correct += 1
        else:
            incorrect += 1
    break

print('Recall of the network on the test samples: %d %%' % (100* correct/(correct + incorrect)))

  0%|          | 0/605 [00:00<?, ?it/s]

0.9567549
0.957085
0.9568714
0.957408
0.95679474
0.95700353
0.95702887
0.95687366
0.9571973
0.9572274
0.9568822
0.9568617
0.9568713
0.95701253
0.9574633
0.9573711
0.9573384
0.9572117
0.9570139
0.9570279
0.9568444
0.9571089
0.9572363
0.95738965
0.9571937
0.95730424
0.95736265
0.9568217
0.9568214
0.9570708
0.9569886
0.95685804
0.95681494
0.95696676
0.957396
0.95748353
0.9572729
0.95731765
0.9570516
0.95701504
0.95679146
0.9572955
0.9571367
0.9570932
0.95738256
0.95700616
0.9568899
0.9572897
0.9570907
0.9569316
0.95729786
0.95738834
0.9571122
0.95737904
0.9573532
0.9568454
0.95685476
0.9573713
0.95733887
0.9568479
0.9572363
0.9574307
0.95733786
0.9571064
0.9569708
0.95716006
0.9572149
0.95698816
0.95695084
0.95740706
0.9569526
0.956785
0.9569647
0.956881
0.9569896
0.95706743
0.9568834
0.95719403
0.9568906
0.9569418
0.9568436
0.95743215
0.9575058
0.95702386
0.9568679
0.9568499
0.9568144
0.9572034
0.95740277
0.9575205
0.95686996
0.9567855
0.95737535
0.9572734
0.9571198
0.9574195
0.95708543





# sampling --> need to change

In [13]:
k = 1e5

# Generate uniformly without replacement
def get_samples(kind="pep",num=k):
    if kind == "apt":
        samples = [all_aptamers[i] for i in np.random.choice(len(all_aptamers), num_samples, replace=False)]
    else:
        samples = [all_peptides[i] for i in np.random.choice(len(all_peptides), num_samples, replace=False)]
    return samples

# Sample x' from P_X (assume peptides follow NNK)
def get_x_prime(k):
    x_primes = []
    for _ in range(k):
        pvals = [0.089]*3 + [0.065]*5 + [0.034]*12
        x_idx = np.random.choice(20, 7, p=pvals)
        x_prime = "M"
        for i in x_idx:
            x_prime += aa_list[i]
        x_primes.append(x_prime)
    return x_primes

# Sample y' from P_Y (assume apatamers follow uniform)
def get_y_prime(k):
    y_primes = []
    for _ in range(k):
        y_idx = np.random.randint(0, 4, 40)
        y_prime = ""
        for i in y_idx:
            y_prime += na_list[i]
        y_primes.append(y_prime)
    return y_primes