In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import tqdm

## Dataset and one-hot encoding

In [2]:
BATCH_SIZE = 256
random.seed(42)

In [3]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide, _ in peptides:
            peptide = peptide.replace("_", "")
            if len(aptamer) == 40 and len(peptide) == 8:
                full_dataset.append((aptamer, peptide))
    return list(set(full_dataset)) #removed duplicates

In [4]:
full_dataset = construct_dataset()
aptamers = [p[0] for p in full_dataset]
peptides = [p[1] for p in full_dataset]
training_set = full_dataset[:int(0.8*len(full_dataset))]
test_set = full_dataset[int(0.8*len(full_dataset)):]

In [5]:
class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, training_set):
        super(TrainDataset, self).__init__() 
        self.training_set = training_set
        n = len(training_set)
        self.training_set = training_set[:n-n%BATCH_SIZE]
        
    def __len__(self):
        return len(self.training_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.training_set[idx]
        return aptamer, peptide
    
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, test_set):
        super(TestDataset, self).__init__() 
        self.test_set = test_set
        n = len(test_set)
        self.test_set = test_set[:n-n%BATCH_SIZE]
        
    def __len__(self):
        return len(self.test_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.test_set[idx]
        return aptamer, peptide

In [6]:
train_dataset = TrainDataset(training_set)
test_dataset = TestDataset(test_set)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [7]:
na_list = ['A', 'C', 'G', 'T']
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
pvals = [0.089]*3 + [0.065]*5 + [0.034]*12
aa_dict = dict(zip(aa_list, pvals))

In [8]:
## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence_list, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    
    one_hot = np.zeros((len(sequence_list), len(sequence_list[0]), len(letters)))
    
    for j in range(len(sequence_list)):
        sequence = sequence_list[j]
        for i in range(len(sequence)):
            element = sequence[i]
            idx = letters.index(element)
            one_hot[j][i][idx] = 1
    return one_hot

## NN Model

In [9]:
class TwoLayer(nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)
        
        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)
        
        self.relu = nn.ReLU()
    
        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)
        
        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)
                
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [10]:
class ConvNet(nn.Module):
    def __init__(self, batch_size):
        super(ConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(BATCH_SIZE, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        self.fc_apt_1 = nn.Linear(160, 1)
        
        self.cnn_pep_1 = nn.Conv2d(BATCH_SIZE, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)
        
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()
        
        #self.dropout = nn.Dropout(0.1)
        
        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_pep_2)
        
        self.sequential_apt = nn.Sequential(self.cnn_apt_1, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_2, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_3)
        
        self.fc1 = nn.Linear(209, BATCH_SIZE)
        
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = torch.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        return l(torch.FloatTensor(prediction), label)

In [11]:
model = ConvNet(batch_size=BATCH_SIZE)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        nn.init.zeros_(m.bias.data)

#model_2 = TwoLayer()
# def weights_init_2(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)

model.apply(weights_init)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

## Training

In [12]:
# Training Loop
for epoch in range(1):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    # Come up with a trainloader
    for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(train_loader)):
        # Peptide and aptamer, one-hot encode them 
        pep = one_hot(peptide, seq_type='peptide')
        apt = one_hot(aptamer, seq_type='aptamer')
        
        pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
        apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))
        
        output = model(apt, pep)
        
        label = np.ones((1, BATCH_SIZE))
        loss = model.loss(output, label)
        optimizer.zero_grad()
        loss.backward()
        
        #hyperparameter
        clip = 5
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 10 == 9:
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, batch_idx + 1, running_loss / 10*BATCH_SIZE))
            running_loss = 0.0
    
print('Finished Training')

  0%|          | 0/1859 [00:00<?, ?it/s]

Epoch:  0


  1%|          | 11/1859 [00:01<05:25,  5.67it/s]

[1,    10] loss: 63.8710293


  1%|          | 21/1859 [00:03<04:38,  6.60it/s]

[1,    20] loss: 41.1774403


  2%|▏         | 30/1859 [00:05<07:01,  4.34it/s]

[1,    30] loss: 11.5556701


  2%|▏         | 40/1859 [00:07<06:49,  4.44it/s]

[1,    40] loss: 1.2386307


  3%|▎         | 51/1859 [00:09<06:06,  4.93it/s]

[1,    50] loss: 0.7817502


  3%|▎         | 60/1859 [00:11<05:11,  5.78it/s]

[1,    60] loss: 1.6733526


  4%|▍         | 70/1859 [00:13<04:49,  6.18it/s]

[1,    70] loss: 1.4937803


  4%|▍         | 80/1859 [00:15<06:24,  4.62it/s]

[1,    80] loss: 1.0444667


  5%|▍         | 91/1859 [00:17<05:04,  5.80it/s]

[1,    90] loss: 1.0367521


  5%|▌         | 101/1859 [00:19<05:20,  5.48it/s]

[1,   100] loss: 0.9894654


  6%|▌         | 111/1859 [00:21<05:52,  4.96it/s]

[1,   110] loss: 0.9117784


  7%|▋         | 121/1859 [00:22<04:55,  5.89it/s]

[1,   120] loss: 0.8183960


  7%|▋         | 130/1859 [00:24<05:57,  4.83it/s]

[1,   130] loss: 0.7999410


  8%|▊         | 140/1859 [00:26<05:50,  4.90it/s]

[1,   140] loss: 0.7785866


  8%|▊         | 150/1859 [00:28<05:53,  4.83it/s]

[1,   150] loss: 0.7077753


  9%|▊         | 161/1859 [00:31<05:21,  5.28it/s]

[1,   160] loss: 0.6911426


  9%|▉         | 171/1859 [00:33<05:28,  5.14it/s]

[1,   170] loss: 0.6584550


 10%|▉         | 181/1859 [00:34<04:26,  6.30it/s]

[1,   180] loss: 0.6458887


 10%|█         | 190/1859 [00:36<05:35,  4.98it/s]

[1,   190] loss: 0.6274485


 11%|█         | 201/1859 [00:38<05:28,  5.05it/s]

[1,   200] loss: 0.5819058


 11%|█▏        | 211/1859 [00:40<04:19,  6.34it/s]

[1,   210] loss: 0.5993595


 12%|█▏        | 221/1859 [00:42<05:24,  5.05it/s]

[1,   220] loss: 0.5762372


 12%|█▏        | 230/1859 [00:43<05:49,  4.66it/s]

[1,   230] loss: 0.5465161


 13%|█▎        | 240/1859 [00:46<06:34,  4.11it/s]

[1,   240] loss: 0.5681571


 14%|█▎        | 251/1859 [00:48<05:39,  4.73it/s]

[1,   250] loss: 0.5244665


 14%|█▍        | 261/1859 [00:50<05:26,  4.90it/s]

[1,   260] loss: 0.5254282


 15%|█▍        | 271/1859 [00:52<04:25,  5.97it/s]

[1,   270] loss: 0.5574255


 15%|█▌        | 281/1859 [00:54<04:44,  5.54it/s]

[1,   280] loss: 0.5029975


 16%|█▌        | 291/1859 [00:56<04:43,  5.52it/s]

[1,   290] loss: 0.5174906


 16%|█▌        | 301/1859 [00:58<06:04,  4.28it/s]

[1,   300] loss: 0.5146613


 17%|█▋        | 310/1859 [01:00<05:54,  4.37it/s]

[1,   310] loss: 0.4897019


 17%|█▋        | 320/1859 [01:02<05:42,  4.50it/s]

[1,   320] loss: 0.5157504


 18%|█▊        | 330/1859 [01:05<05:21,  4.75it/s]

[1,   330] loss: 0.5015422


 18%|█▊        | 340/1859 [01:07<05:01,  5.04it/s]

[1,   340] loss: 0.4940511


 19%|█▉        | 351/1859 [01:09<04:52,  5.16it/s]

[1,   350] loss: 0.4908723


 19%|█▉        | 361/1859 [01:11<04:10,  5.98it/s]

[1,   360] loss: 0.4911854


 20%|█▉        | 371/1859 [01:13<04:45,  5.22it/s]

[1,   370] loss: 0.4986538


 20%|██        | 380/1859 [01:14<04:55,  5.01it/s]

[1,   380] loss: 0.4797839


 21%|██        | 391/1859 [01:16<04:14,  5.78it/s]

[1,   390] loss: 0.4816236


 22%|██▏       | 401/1859 [01:18<04:10,  5.82it/s]

[1,   400] loss: 0.4784183


 22%|██▏       | 411/1859 [01:20<04:12,  5.73it/s]

[1,   410] loss: 0.4913036


 23%|██▎       | 421/1859 [01:22<05:11,  4.61it/s]

[1,   420] loss: 0.4713367


 23%|██▎       | 431/1859 [01:24<03:35,  6.64it/s]

[1,   430] loss: 0.4899160


 24%|██▎       | 441/1859 [01:26<04:52,  4.85it/s]

[1,   440] loss: 0.4721671


 24%|██▍       | 450/1859 [01:28<04:49,  4.87it/s]

[1,   450] loss: 0.4770549


 25%|██▍       | 461/1859 [01:30<04:07,  5.65it/s]

[1,   460] loss: 0.4689699


 25%|██▌       | 470/1859 [01:31<04:25,  5.24it/s]

[1,   470] loss: 0.4821397


 26%|██▌       | 480/1859 [01:33<04:23,  5.23it/s]

[1,   480] loss: 0.4714336


 26%|██▋       | 490/1859 [01:35<05:22,  4.25it/s]

[1,   490] loss: 0.4880166


 27%|██▋       | 500/1859 [01:38<05:05,  4.45it/s]

[1,   500] loss: 0.4660508


 27%|██▋       | 511/1859 [01:40<03:40,  6.12it/s]

[1,   510] loss: 0.4841307


 28%|██▊       | 520/1859 [01:41<04:09,  5.36it/s]

[1,   520] loss: 0.4523395


 29%|██▊       | 531/1859 [01:43<03:46,  5.86it/s]

[1,   530] loss: 0.4830509


 29%|██▉       | 540/1859 [01:45<04:24,  4.99it/s]

[1,   540] loss: 0.4667489


 30%|██▉       | 551/1859 [01:47<04:24,  4.95it/s]

[1,   550] loss: 0.4717455


 30%|███       | 561/1859 [01:49<03:04,  7.02it/s]

[1,   560] loss: 0.4854517


 31%|███       | 570/1859 [01:51<04:33,  4.71it/s]

[1,   570] loss: 0.4500071


 31%|███▏      | 581/1859 [01:53<03:36,  5.90it/s]

[1,   580] loss: 0.4887236


 32%|███▏      | 591/1859 [01:54<04:04,  5.19it/s]

[1,   590] loss: 0.4581998


 32%|███▏      | 600/1859 [01:56<04:47,  4.38it/s]

[1,   600] loss: 0.4684577


 33%|███▎      | 611/1859 [01:59<04:03,  5.13it/s]

[1,   610] loss: 0.4778093


 33%|███▎      | 620/1859 [02:00<03:39,  5.64it/s]

[1,   620] loss: 0.4629440


 34%|███▍      | 631/1859 [02:02<03:10,  6.43it/s]

[1,   630] loss: 0.4707256


 34%|███▍      | 641/1859 [02:04<03:17,  6.16it/s]

[1,   640] loss: 0.4660046


 35%|███▌      | 651/1859 [02:06<03:20,  6.03it/s]

[1,   650] loss: 0.4662451


 36%|███▌      | 661/1859 [02:08<03:49,  5.23it/s]

[1,   660] loss: 0.4702012


 36%|███▌      | 671/1859 [02:10<03:48,  5.21it/s]

[1,   670] loss: 0.4681686


 37%|███▋      | 680/1859 [02:11<03:44,  5.24it/s]

[1,   680] loss: 0.4664205


 37%|███▋      | 691/1859 [02:13<02:48,  6.94it/s]

[1,   690] loss: 0.4628905


 38%|███▊      | 701/1859 [02:15<02:50,  6.78it/s]

[1,   700] loss: 0.4802170


 38%|███▊      | 711/1859 [02:16<02:20,  8.15it/s]

[1,   710] loss: 0.4495775


 39%|███▉      | 721/1859 [02:17<02:34,  7.38it/s]

[1,   720] loss: 0.4769483


 39%|███▉      | 730/1859 [02:19<03:22,  5.57it/s]

[1,   730] loss: 0.4657345


 40%|███▉      | 741/1859 [02:21<03:09,  5.90it/s]

[1,   740] loss: 0.4724808


 40%|████      | 751/1859 [02:23<03:15,  5.67it/s]

[1,   750] loss: 0.4551714


 41%|████      | 760/1859 [02:25<03:50,  4.77it/s]

[1,   760] loss: 0.4795878


 41%|████▏     | 770/1859 [02:27<03:48,  4.76it/s]

[1,   770] loss: 0.4485057


 42%|████▏     | 781/1859 [02:29<03:14,  5.54it/s]

[1,   780] loss: 0.4843896


 42%|████▏     | 790/1859 [02:31<03:36,  4.95it/s]

[1,   790] loss: 0.4545255


 43%|████▎     | 800/1859 [02:33<03:30,  5.03it/s]

[1,   800] loss: 0.4778819


 44%|████▎     | 811/1859 [02:35<03:13,  5.41it/s]

[1,   810] loss: 0.4569006


 44%|████▍     | 821/1859 [02:36<03:06,  5.55it/s]

[1,   820] loss: 0.4692132


 45%|████▍     | 830/1859 [02:38<03:09,  5.43it/s]

[1,   830] loss: 0.4540211


 45%|████▌     | 841/1859 [02:40<02:59,  5.67it/s]

[1,   840] loss: 0.4756480


 46%|████▌     | 851/1859 [02:42<02:20,  7.18it/s]

[1,   850] loss: 0.4738371


 46%|████▋     | 861/1859 [02:43<02:20,  7.10it/s]

[1,   860] loss: 0.4537503


 47%|████▋     | 871/1859 [02:45<02:45,  5.97it/s]

[1,   870] loss: 0.4748430


 47%|████▋     | 880/1859 [02:46<02:59,  5.44it/s]

[1,   880] loss: 0.4497405


 48%|████▊     | 891/1859 [02:48<02:41,  6.00it/s]

[1,   890] loss: 0.4818449


 48%|████▊     | 901/1859 [02:50<02:48,  5.70it/s]

[1,   900] loss: 0.4569391


 49%|████▉     | 910/1859 [02:52<02:45,  5.72it/s]

[1,   910] loss: 0.4692311


 50%|████▉     | 921/1859 [02:54<02:42,  5.76it/s]

[1,   920] loss: 0.4747794


 50%|█████     | 930/1859 [02:55<02:53,  5.35it/s]

[1,   930] loss: 0.4532784


 51%|█████     | 941/1859 [02:58<02:56,  5.20it/s]

[1,   940] loss: 0.4677229


 51%|█████     | 951/1859 [02:59<02:29,  6.07it/s]

[1,   950] loss: 0.4643159


 52%|█████▏    | 960/1859 [03:01<02:46,  5.41it/s]

[1,   960] loss: 0.4567976


 52%|█████▏    | 971/1859 [03:03<02:31,  5.87it/s]

[1,   970] loss: 0.4733768


 53%|█████▎    | 980/1859 [03:04<02:34,  5.69it/s]

[1,   980] loss: 0.4681798


 53%|█████▎    | 991/1859 [03:07<02:42,  5.33it/s]

[1,   990] loss: 0.4576215


 54%|█████▍    | 1001/1859 [03:09<02:28,  5.78it/s]

[1,  1000] loss: 0.4730136


 54%|█████▍    | 1010/1859 [03:11<03:16,  4.31it/s]

[1,  1010] loss: 0.4560671


 55%|█████▍    | 1021/1859 [03:13<02:26,  5.73it/s]

[1,  1020] loss: 0.4674398


 55%|█████▌    | 1030/1859 [03:15<02:52,  4.80it/s]

[1,  1030] loss: 0.4758401


 56%|█████▌    | 1041/1859 [03:17<02:07,  6.42it/s]

[1,  1040] loss: 0.4607800


 57%|█████▋    | 1051/1859 [03:18<02:08,  6.31it/s]

[1,  1050] loss: 0.4758346


 57%|█████▋    | 1060/1859 [03:20<01:58,  6.73it/s]

[1,  1060] loss: 0.4530327


 58%|█████▊    | 1071/1859 [03:22<02:01,  6.51it/s]

[1,  1070] loss: 0.4823646


 58%|█████▊    | 1081/1859 [03:23<02:12,  5.86it/s]

[1,  1080] loss: 0.4499319


 59%|█████▊    | 1091/1859 [03:25<02:49,  4.53it/s]

[1,  1090] loss: 0.4798401


 59%|█████▉    | 1100/1859 [03:27<02:37,  4.81it/s]

[1,  1100] loss: 0.4603769


 60%|█████▉    | 1110/1859 [03:30<03:05,  4.05it/s]

[1,  1110] loss: 0.4694461


 60%|██████    | 1120/1859 [03:32<03:04,  4.00it/s]

[1,  1120] loss: 0.4630286


 61%|██████    | 1130/1859 [03:35<02:57,  4.10it/s]

[1,  1130] loss: 0.4760260


 61%|██████▏   | 1140/1859 [03:37<02:34,  4.66it/s]

[1,  1140] loss: 0.4531632


 62%|██████▏   | 1150/1859 [03:39<02:29,  4.76it/s]

[1,  1150] loss: 0.4691198


 62%|██████▏   | 1160/1859 [03:41<02:05,  5.58it/s]

[1,  1160] loss: 0.4716786


 63%|██████▎   | 1170/1859 [03:43<02:34,  4.45it/s]

[1,  1170] loss: 0.4631518


 63%|██████▎   | 1180/1859 [03:46<02:44,  4.13it/s]

[1,  1180] loss: 0.4684028


 64%|██████▍   | 1190/1859 [03:48<02:20,  4.77it/s]

[1,  1190] loss: 0.4579507


 65%|██████▍   | 1200/1859 [03:50<02:25,  4.51it/s]

[1,  1200] loss: 0.4786004


 65%|██████▌   | 1210/1859 [03:53<02:44,  3.95it/s]

[1,  1210] loss: 0.4551209


 66%|██████▌   | 1220/1859 [03:55<02:47,  3.82it/s]

[1,  1220] loss: 0.4673359


 66%|██████▌   | 1230/1859 [03:57<02:03,  5.08it/s]

[1,  1230] loss: 0.4716299


 67%|██████▋   | 1240/1859 [03:59<02:11,  4.71it/s]

[1,  1240] loss: 0.4618933


 67%|██████▋   | 1251/1859 [04:02<01:51,  5.48it/s]

[1,  1250] loss: 0.4693204


 68%|██████▊   | 1261/1859 [04:03<01:41,  5.88it/s]

[1,  1260] loss: 0.4584726


 68%|██████▊   | 1270/1859 [04:05<02:03,  4.77it/s]

[1,  1270] loss: 0.4732781


 69%|██████▉   | 1281/1859 [04:07<01:52,  5.14it/s]

[1,  1280] loss: 0.4552325


 69%|██████▉   | 1291/1859 [04:09<01:50,  5.14it/s]

[1,  1290] loss: 0.4812680


 70%|██████▉   | 1300/1859 [04:11<02:02,  4.56it/s]

[1,  1300] loss: 0.4547305


 70%|███████   | 1310/1859 [04:14<02:00,  4.55it/s]

[1,  1310] loss: 0.4716391


 71%|███████   | 1321/1859 [04:16<02:00,  4.48it/s]

[1,  1320] loss: 0.4683678


 72%|███████▏  | 1330/1859 [04:18<01:57,  4.51it/s]

[1,  1330] loss: 0.4645567


 72%|███████▏  | 1340/1859 [04:20<02:01,  4.26it/s]

[1,  1340] loss: 0.4717351


 73%|███████▎  | 1351/1859 [04:23<01:36,  5.24it/s]

[1,  1350] loss: 0.4555947


 73%|███████▎  | 1360/1859 [04:24<01:22,  6.06it/s]

[1,  1360] loss: 0.4783658


 74%|███████▎  | 1371/1859 [04:27<01:35,  5.10it/s]

[1,  1370] loss: 0.4588055


 74%|███████▍  | 1380/1859 [04:29<01:49,  4.39it/s]

[1,  1380] loss: 0.4686226


 75%|███████▍  | 1390/1859 [04:31<01:29,  5.24it/s]

[1,  1390] loss: 0.4621230


 75%|███████▌  | 1401/1859 [04:33<01:32,  4.93it/s]

[1,  1400] loss: 0.4680777


 76%|███████▌  | 1410/1859 [04:35<01:41,  4.42it/s]

[1,  1410] loss: 0.4670255


 76%|███████▋  | 1420/1859 [04:38<01:49,  4.01it/s]

[1,  1420] loss: 0.4580178


 77%|███████▋  | 1431/1859 [04:40<01:27,  4.87it/s]

[1,  1430] loss: 0.4762945


 77%|███████▋  | 1440/1859 [04:42<01:28,  4.72it/s]

[1,  1440] loss: 0.4671191


 78%|███████▊  | 1451/1859 [04:44<01:21,  5.02it/s]

[1,  1450] loss: 0.4558078


 79%|███████▊  | 1460/1859 [04:46<01:31,  4.35it/s]

[1,  1460] loss: 0.4740142


 79%|███████▉  | 1470/1859 [04:49<01:18,  4.96it/s]

[1,  1470] loss: 0.4705621


 80%|███████▉  | 1480/1859 [04:51<01:28,  4.31it/s]

[1,  1480] loss: 0.4647959


 80%|████████  | 1490/1859 [04:53<01:20,  4.60it/s]

[1,  1490] loss: 0.4586843


 81%|████████  | 1500/1859 [04:56<01:22,  4.35it/s]

[1,  1500] loss: 0.4757656


 81%|████████▏ | 1511/1859 [04:58<01:06,  5.24it/s]

[1,  1510] loss: 0.4621587


 82%|████████▏ | 1520/1859 [05:00<01:07,  4.99it/s]

[1,  1520] loss: 0.4584776


 82%|████████▏ | 1530/1859 [05:02<01:06,  4.95it/s]

[1,  1530] loss: 0.4784347


 83%|████████▎ | 1541/1859 [05:04<00:58,  5.45it/s]

[1,  1540] loss: 0.4650954


 83%|████████▎ | 1550/1859 [05:05<00:57,  5.42it/s]

[1,  1550] loss: 0.4587187


 84%|████████▍ | 1560/1859 [05:07<00:58,  5.09it/s]

[1,  1560] loss: 0.4739643


 84%|████████▍ | 1570/1859 [05:09<00:58,  4.94it/s]

[1,  1570] loss: 0.4630546


 85%|████████▍ | 1580/1859 [05:12<01:10,  3.94it/s]

[1,  1580] loss: 0.4596885


 86%|████████▌ | 1590/1859 [05:14<01:10,  3.79it/s]

[1,  1590] loss: 0.4756278


 86%|████████▌ | 1601/1859 [05:17<00:56,  4.59it/s]

[1,  1600] loss: 0.4600969


 87%|████████▋ | 1610/1859 [05:19<01:06,  3.75it/s]

[1,  1610] loss: 0.4670688


 87%|████████▋ | 1620/1859 [05:22<00:59,  4.00it/s]

[1,  1620] loss: 0.4647116


 88%|████████▊ | 1630/1859 [05:24<00:59,  3.82it/s]

[1,  1630] loss: 0.4717068


 88%|████████▊ | 1640/1859 [05:27<00:56,  3.90it/s]

[1,  1640] loss: 0.4705615


 89%|████████▉ | 1651/1859 [05:29<00:42,  4.87it/s]

[1,  1650] loss: 0.4520240


 89%|████████▉ | 1660/1859 [05:31<00:47,  4.18it/s]

[1,  1660] loss: 0.4721767


 90%|████████▉ | 1671/1859 [05:33<00:36,  5.17it/s]

[1,  1670] loss: 0.4614564


 90%|█████████ | 1680/1859 [05:35<00:40,  4.37it/s]

[1,  1680] loss: 0.4716232


 91%|█████████ | 1691/1859 [05:38<00:35,  4.76it/s]

[1,  1690] loss: 0.4642341


 91%|█████████▏| 1700/1859 [05:40<00:33,  4.75it/s]

[1,  1700] loss: 0.4716170


 92%|█████████▏| 1710/1859 [05:42<00:31,  4.70it/s]

[1,  1710] loss: 0.4672316


 93%|█████████▎| 1720/1859 [05:44<00:31,  4.35it/s]

[1,  1720] loss: 0.4617426


 93%|█████████▎| 1730/1859 [05:47<00:31,  4.07it/s]

[1,  1730] loss: 0.4764496


 94%|█████████▎| 1740/1859 [05:49<00:27,  4.26it/s]

[1,  1740] loss: 0.4620858


 94%|█████████▍| 1751/1859 [05:51<00:22,  4.84it/s]

[1,  1750] loss: 0.4631374


 95%|█████████▍| 1761/1859 [05:53<00:19,  4.97it/s]

[1,  1760] loss: 0.4616867


 95%|█████████▌| 1770/1859 [05:55<00:19,  4.58it/s]

[1,  1770] loss: 0.4647052


 96%|█████████▌| 1780/1859 [05:58<00:17,  4.52it/s]

[1,  1780] loss: 0.4697491


 96%|█████████▋| 1791/1859 [06:00<00:15,  4.34it/s]

[1,  1790] loss: 0.4753877


 97%|█████████▋| 1801/1859 [06:02<00:11,  4.84it/s]

[1,  1800] loss: 0.4498781


 97%|█████████▋| 1811/1859 [06:04<00:08,  5.55it/s]

[1,  1810] loss: 0.4746187


 98%|█████████▊| 1820/1859 [06:06<00:09,  4.14it/s]

[1,  1820] loss: 0.4696766


 98%|█████████▊| 1831/1859 [06:09<00:05,  4.81it/s]

[1,  1830] loss: 0.4692844


 99%|█████████▉| 1840/1859 [06:11<00:04,  4.19it/s]

[1,  1840] loss: 0.4628020


100%|█████████▉| 1851/1859 [06:13<00:01,  5.56it/s]

[1,  1850] loss: 0.4700799


100%|██████████| 1859/1859 [06:14<00:00,  5.77it/s]

Finished Training





## Recall test

In [13]:
correct = 0
incorrect = 0
for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(test_loader)):
    pep = one_hot(peptide, seq_type='peptide')
    apt = one_hot(aptamer, seq_type='aptamer')
    
    pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
    apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))

    output = model(apt, pep).detach().numpy().flatten()
    for i in range(output.shape[0]):
        o = output[i]
        if o > 0.9:
            correct += 1
        else:
            incorrect += 1
    break

print('Recall of the network on the test samples: %d %%' % (100* correct/(correct + incorrect)))

  0%|          | 0/464 [00:00<?, ?it/s]

Recall of the network on the test samples: 100 %





## Sampling methods

In [14]:
# Sample x from P_X (assume peptides follow NNK)
def get_x():
    x_idx = np.random.choice(20, 7, p=pvals)
    x = "M"
    for i in x_idx:
        x += aa_list[i]
    return x

# Sample y from P_Y (assume apatamers follow uniform)
def get_y():
    y_idx = np.random.randint(0, 4, 40)
    y = ""
    for i in y_idx:
        y += na_list[i]
    return y

# Generate uniformly from S without replacement
def get_xy(k):
    samples = [full_dataset[i] for i in np.random.choice(len(full_dataset), k, replace=False)]
    return samples

# S' contains S with double the size of S (domain for Importance Sampling)
def get_S_prime(k):
    S_prime = full_dataset.copy()
    for _ in range(k):
        S_prime.append((get_y(), get_x()))
    S_prime_new = S_prime.copy()[-n:]
    return list(set(S_prime)), S_prime_new

# Sample from S' without replacement
def get_xy_prime(k):
    samples = [S_prime[i] for i in np.random.choice(len(S_prime), k, replace=False)]
    return samples

## Test whether naive sampling would generate pairs in S -- nope

In [15]:
n = len(full_dataset)
S_prime, S_prime_new = get_S_prime(n)
print("Size of S: ", n)
print("Size of naively sampled dataset: ", len(S_prime_new))
diff = set(full_dataset) - set(S_prime_new)
print("Size of set difference: ", len(diff))

Size of S:  594900
Size of naively sampled dataset:  594900
Size of set difference:  594900


## Motivates importance sampling in SGD

In [16]:
# Returns pmf of a peptide
def get_x_pmf(x):
    pmf = 1
    for char in x[1:]: #skips first char "M"
        pmf *= aa_dict[char]
    return pmf

# Returns pmf of an aptamer
def get_y_pmf():
    return 0.25**40

In [18]:
def sgd(t=int(1e5), #num of iter
        gamma=1e-2, #step size
        batch=1): #batch size
    for _ in range(t):
        xy = get_xy(1) #sample pair from S
        xy_prime = get_xy_prime(1) #sample pair from S'
        const = 1 #indicator
        x_prime_pmf = get_x_pmf(xy_prime[0][1])
        y_prime_pmf = get_y_pmf()
        if xy_prime in full_dataset:
            const = 0

## Evaluation