In [1]:
import os, sys
import numpy as np
import json
import random

## Prepare the data --> one hot encoding matrices

In [2]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide, _ in peptides:
            if len(aptamer) == 40 and len(peptide) == 8:
                full_dataset.append((aptamer, peptide))
    return full_dataset

In [3]:
full_dataset = construct_dataset()
random.shuffle(full_dataset)
training_set = full_dataset[:int(0.8*len(full_dataset))]
test_set = full_dataset[int(0.8*len(full_dataset)):]

In [4]:
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
na_list = ['A', 'C', 'G', 'T']

## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    one_hot = np.zeros((len(sequence), len(letters)))
    for i in range(len(sequence)):
        element = sequence[i]
        idx = letters.index(element)
        one_hot[i][idx] = 1
    return one_hot

## Model --> CNN

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD

In [41]:
class TwoLayer(nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)
        
        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)
        
        self.relu = nn.ReLU()
    
        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)
        
        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)
                
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        print(apt.shape())
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [19]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(1, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        self.fc_apt_1 = nn.Linear(160, 1)
        
        self.cnn_pep_1 = nn.Conv2d(1, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)
        
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()
        
        #self.dropout = nn.Dropout(0.1)
        
        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_pep_2)
        
        self.sequential_apt = nn.Sequential(self.cnn_apt_1, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_2, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_3)
        
        self.fc1 = nn.Linear(209, 1)
        
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [26]:
model = ConvNet()
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        nn.init.zeros_(m.bias.data)

#model_2 = TwoLayer()
# def weights_init_2(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)

model.apply(weights_init)
optimizer = Adam(model.parameters(), lr=1e-7, weight_decay=1e-5)

In [27]:
# Training Loop
import tqdm
for epoch in range(1):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    # Come up with a trainloader
    for i, data in enumerate(tqdm.tqdm(training_set[:10000])):
        # Peptide and aptamer, one-hot encode them
        pep = training_set[i][1]
        apt = training_set[i][0]
        
        pep = one_hot(pep, seq_type='peptide')
        apt = one_hot(apt, seq_type='aptamer')
        
        pep = torch.FloatTensor(np.reshape(pep, (1, 1, pep.shape[0], pep.shape[1])))
        apt = torch.FloatTensor(np.reshape(apt, (1, 1, apt.shape[0], apt.shape[1])))
        
        output = model(apt, pep)
        loss = model.loss(output, 1)
        optimizer.zero_grad()
        loss.backward()
        
        #hyperparameter
        clip = 5
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        running_loss += loss.item()
        if i % 500 == 499:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 500))
            running_loss = 0.0
    
print('Finished Training')



  0%|          | 0/10000 [00:00<?, ?it/s][A[A

  0%|          | 1/10000 [00:00<21:58,  7.58it/s][A[A

Epoch:  0




  0%|          | 8/10000 [00:00<16:05, 10.35it/s][A[A

  0%|          | 16/10000 [00:00<11:55, 13.96it/s][A[A

  0%|          | 24/10000 [00:00<09:00, 18.46it/s][A[A

  0%|          | 33/10000 [00:00<06:51, 24.23it/s][A[A

  0%|          | 43/10000 [00:00<05:20, 31.05it/s][A[A

  1%|          | 53/10000 [00:00<04:17, 38.60it/s][A[A

  1%|          | 61/10000 [00:00<03:49, 43.36it/s][A[A

  1%|          | 70/10000 [00:01<03:15, 50.84it/s][A[A

  1%|          | 78/10000 [00:01<02:53, 57.07it/s][A[A

  1%|          | 86/10000 [00:01<02:44, 60.34it/s][A[A

  1%|          | 95/10000 [00:01<02:28, 66.60it/s][A[A

  1%|          | 104/10000 [00:01<02:18, 71.26it/s][A[A

  1%|          | 114/10000 [00:01<02:09, 76.50it/s][A[A

  1%|          | 124/10000 [00:01<02:05, 78.84it/s][A[A

  1%|▏         | 134/10000 [00:01<01:58, 83.00it/s][A[A

  1%|▏         | 143/10000 [00:01<01:56, 84.62it/s][A[A

  2%|▏         | 152/10000 [00:01<01:55, 85.61it/s][A[A

  2%|▏ 

[1,   500] loss: 0.541




  5%|▌         | 521/10000 [00:06<01:55, 82.06it/s][A[A

  5%|▌         | 532/10000 [00:06<01:48, 87.26it/s][A[A

  5%|▌         | 542/10000 [00:06<01:46, 88.99it/s][A[A

  6%|▌         | 551/10000 [00:06<01:48, 86.95it/s][A[A

  6%|▌         | 561/10000 [00:06<01:46, 88.56it/s][A[A

  6%|▌         | 571/10000 [00:06<01:44, 89.90it/s][A[A

  6%|▌         | 581/10000 [00:07<01:43, 90.65it/s][A[A

  6%|▌         | 591/10000 [00:07<01:48, 86.43it/s][A[A

  6%|▌         | 600/10000 [00:07<01:56, 80.49it/s][A[A

  6%|▌         | 609/10000 [00:07<02:00, 77.95it/s][A[A

  6%|▌         | 617/10000 [00:07<02:03, 76.15it/s][A[A

  6%|▋         | 625/10000 [00:07<02:05, 74.88it/s][A[A

  6%|▋         | 633/10000 [00:07<02:07, 73.57it/s][A[A

  6%|▋         | 641/10000 [00:07<02:08, 72.89it/s][A[A

  6%|▋         | 649/10000 [00:08<02:07, 73.58it/s][A[A

  7%|▋         | 658/10000 [00:08<02:04, 75.31it/s][A[A

  7%|▋         | 666/10000 [00:08<02:03, 75.62it/s][A

[1,  1000] loss: 0.518




 10%|█         | 1017/10000 [00:13<02:03, 72.73it/s][A[A

 10%|█         | 1025/10000 [00:13<02:04, 72.34it/s][A[A

 10%|█         | 1034/10000 [00:13<01:58, 75.66it/s][A[A

 10%|█         | 1042/10000 [00:13<01:58, 75.79it/s][A[A

 11%|█         | 1051/10000 [00:13<01:54, 77.85it/s][A[A

 11%|█         | 1061/10000 [00:13<01:50, 80.88it/s][A[A

 11%|█         | 1070/10000 [00:13<01:50, 81.00it/s][A[A

 11%|█         | 1079/10000 [00:13<01:48, 82.17it/s][A[A

 11%|█         | 1088/10000 [00:13<01:51, 79.71it/s][A[A

 11%|█         | 1097/10000 [00:14<01:54, 77.48it/s][A[A

 11%|█         | 1105/10000 [00:14<01:54, 77.35it/s][A[A

 11%|█         | 1113/10000 [00:14<01:54, 77.32it/s][A[A

 11%|█         | 1121/10000 [00:14<01:58, 74.71it/s][A[A

 11%|█▏        | 1129/10000 [00:14<01:57, 75.37it/s][A[A

 11%|█▏        | 1137/10000 [00:14<01:56, 76.15it/s][A[A

 11%|█▏        | 1146/10000 [00:14<01:53, 77.69it/s][A[A

 12%|█▏        | 1155/10000 [00:14<01:

[1,  1500] loss: 0.530




 15%|█▌        | 1520/10000 [00:19<02:01, 69.93it/s][A[A

 15%|█▌        | 1528/10000 [00:19<02:00, 70.47it/s][A[A

 15%|█▌        | 1536/10000 [00:19<01:58, 71.49it/s][A[A

 15%|█▌        | 1545/10000 [00:19<01:53, 74.19it/s][A[A

 16%|█▌        | 1554/10000 [00:20<01:52, 75.20it/s][A[A

 16%|█▌        | 1562/10000 [00:20<01:59, 70.52it/s][A[A

 16%|█▌        | 1570/10000 [00:20<02:22, 59.05it/s][A[A

 16%|█▌        | 1579/10000 [00:20<02:08, 65.63it/s][A[A

 16%|█▌        | 1588/10000 [00:20<02:00, 69.80it/s][A[A

 16%|█▌        | 1596/10000 [00:20<02:07, 65.87it/s][A[A

 16%|█▌        | 1604/10000 [00:20<02:02, 68.64it/s][A[A

 16%|█▌        | 1612/10000 [00:20<01:57, 71.27it/s][A[A

 16%|█▌        | 1620/10000 [00:21<01:55, 72.53it/s][A[A

 16%|█▋        | 1628/10000 [00:21<01:53, 74.05it/s][A[A

 16%|█▋        | 1637/10000 [00:21<01:49, 76.38it/s][A[A

 16%|█▋        | 1645/10000 [00:21<01:54, 73.17it/s][A[A

 17%|█▋        | 1653/10000 [00:21<01:

[1,  2000] loss: 0.519




 20%|██        | 2017/10000 [00:26<01:57, 67.87it/s][A[A

 20%|██        | 2026/10000 [00:26<01:50, 72.22it/s][A[A

 20%|██        | 2034/10000 [00:26<01:49, 72.69it/s][A[A

 20%|██        | 2042/10000 [00:26<01:48, 73.59it/s][A[A

 21%|██        | 2051/10000 [00:26<01:42, 77.84it/s][A[A

 21%|██        | 2061/10000 [00:27<01:37, 81.02it/s][A[A

 21%|██        | 2070/10000 [00:27<01:42, 77.63it/s][A[A

 21%|██        | 2078/10000 [00:27<01:44, 76.02it/s][A[A

 21%|██        | 2086/10000 [00:27<01:46, 74.10it/s][A[A

 21%|██        | 2094/10000 [00:27<01:48, 72.80it/s][A[A

 21%|██        | 2102/10000 [00:27<01:48, 72.55it/s][A[A

 21%|██        | 2110/10000 [00:27<01:48, 72.91it/s][A[A

 21%|██        | 2119/10000 [00:27<01:43, 76.33it/s][A[A

 21%|██▏       | 2128/10000 [00:27<01:39, 79.28it/s][A[A

 21%|██▏       | 2137/10000 [00:28<01:37, 80.92it/s][A[A

 21%|██▏       | 2147/10000 [00:28<01:34, 83.29it/s][A[A

 22%|██▏       | 2156/10000 [00:28<01:

[1,  2500] loss: 0.534




 25%|██▌       | 2517/10000 [00:32<01:29, 83.23it/s][A[A

 25%|██▌       | 2526/10000 [00:32<01:32, 80.78it/s][A[A

 25%|██▌       | 2536/10000 [00:32<01:29, 83.58it/s][A[A

 25%|██▌       | 2546/10000 [00:32<01:26, 85.81it/s][A[A

 26%|██▌       | 2555/10000 [00:32<01:26, 86.56it/s][A[A

 26%|██▌       | 2564/10000 [00:33<01:27, 85.10it/s][A[A

 26%|██▌       | 2574/10000 [00:33<01:24, 87.50it/s][A[A

 26%|██▌       | 2583/10000 [00:33<01:29, 82.60it/s][A[A

 26%|██▌       | 2592/10000 [00:33<01:28, 84.12it/s][A[A

 26%|██▌       | 2602/10000 [00:33<01:25, 86.82it/s][A[A

 26%|██▌       | 2611/10000 [00:33<01:24, 87.23it/s][A[A

 26%|██▌       | 2621/10000 [00:33<01:21, 90.64it/s][A[A

 26%|██▋       | 2631/10000 [00:33<01:21, 90.05it/s][A[A

 26%|██▋       | 2641/10000 [00:33<01:20, 90.94it/s][A[A

 27%|██▋       | 2651/10000 [00:34<01:21, 90.18it/s][A[A

 27%|██▋       | 2661/10000 [00:34<01:19, 92.67it/s][A[A

 27%|██▋       | 2671/10000 [00:34<01:

[1,  3000] loss: 0.509




 30%|███       | 3018/10000 [00:38<01:23, 84.12it/s][A[A

 30%|███       | 3028/10000 [00:38<01:21, 85.94it/s][A[A

 30%|███       | 3038/10000 [00:38<01:19, 87.12it/s][A[A

 30%|███       | 3047/10000 [00:38<01:24, 82.11it/s][A[A

 31%|███       | 3056/10000 [00:38<01:32, 75.21it/s][A[A

 31%|███       | 3064/10000 [00:39<01:35, 72.90it/s][A[A

 31%|███       | 3072/10000 [00:39<01:36, 72.12it/s][A[A

 31%|███       | 3080/10000 [00:39<01:35, 72.55it/s][A[A

 31%|███       | 3088/10000 [00:39<01:35, 72.66it/s][A[A

 31%|███       | 3097/10000 [00:39<01:32, 74.30it/s][A[A

 31%|███       | 3107/10000 [00:39<01:27, 79.17it/s][A[A

 31%|███       | 3116/10000 [00:39<01:29, 76.95it/s][A[A

 31%|███       | 3124/10000 [00:39<01:29, 77.17it/s][A[A

 31%|███▏      | 3133/10000 [00:39<01:26, 79.06it/s][A[A

 31%|███▏      | 3142/10000 [00:40<01:24, 80.88it/s][A[A

 32%|███▏      | 3151/10000 [00:40<01:22, 83.31it/s][A[A

 32%|███▏      | 3160/10000 [00:40<01:

[1,  3500] loss: 0.528




 35%|███▌      | 3527/10000 [00:44<01:12, 88.96it/s][A[A

 35%|███▌      | 3536/10000 [00:44<01:13, 87.73it/s][A[A

 35%|███▌      | 3545/10000 [00:44<01:27, 73.74it/s][A[A

 36%|███▌      | 3553/10000 [00:44<01:30, 71.31it/s][A[A

 36%|███▌      | 3561/10000 [00:45<01:28, 72.77it/s][A[A

 36%|███▌      | 3569/10000 [00:45<01:27, 73.51it/s][A[A

 36%|███▌      | 3577/10000 [00:45<01:27, 73.82it/s][A[A

 36%|███▌      | 3585/10000 [00:45<01:25, 75.33it/s][A[A

 36%|███▌      | 3594/10000 [00:45<01:24, 76.19it/s][A[A

 36%|███▌      | 3603/10000 [00:45<01:21, 78.62it/s][A[A

 36%|███▌      | 3612/10000 [00:45<01:20, 79.26it/s][A[A

 36%|███▌      | 3621/10000 [00:45<01:18, 81.04it/s][A[A

 36%|███▋      | 3630/10000 [00:45<01:19, 80.59it/s][A[A

 36%|███▋      | 3639/10000 [00:45<01:19, 80.26it/s][A[A

 36%|███▋      | 3648/10000 [00:46<01:19, 80.04it/s][A[A

 37%|███▋      | 3657/10000 [00:46<01:18, 80.52it/s][A[A

 37%|███▋      | 3666/10000 [00:46<01:

[1,  4000] loss: 0.536




 40%|████      | 4015/10000 [00:51<01:32, 65.05it/s][A[A

 40%|████      | 4023/10000 [00:51<01:28, 67.89it/s][A[A

 40%|████      | 4031/10000 [00:51<01:24, 70.46it/s][A[A

 40%|████      | 4040/10000 [00:51<01:20, 74.37it/s][A[A

 40%|████      | 4048/10000 [00:51<01:21, 72.95it/s][A[A

 41%|████      | 4057/10000 [00:51<01:18, 75.78it/s][A[A

 41%|████      | 4066/10000 [00:51<01:16, 77.66it/s][A[A

 41%|████      | 4074/10000 [00:51<01:20, 73.96it/s][A[A

 41%|████      | 4083/10000 [00:51<01:16, 77.50it/s][A[A

 41%|████      | 4091/10000 [00:52<01:25, 69.39it/s][A[A

 41%|████      | 4099/10000 [00:52<01:21, 71.99it/s][A[A

 41%|████      | 4107/10000 [00:52<01:20, 72.89it/s][A[A

 41%|████      | 4116/10000 [00:52<01:18, 75.24it/s][A[A

 41%|████      | 4124/10000 [00:52<01:19, 74.09it/s][A[A

 41%|████▏     | 4132/10000 [00:52<01:18, 74.78it/s][A[A

 41%|████▏     | 4140/10000 [00:52<01:20, 73.10it/s][A[A

 41%|████▏     | 4148/10000 [00:52<01:

[1,  4500] loss: 0.522


 45%|████▌     | 4516/10000 [00:57<01:13, 74.74it/s][A[A

 45%|████▌     | 4525/10000 [00:57<01:11, 76.52it/s][A[A

 45%|████▌     | 4533/10000 [00:58<01:10, 77.42it/s][A[A

 45%|████▌     | 4543/10000 [00:58<01:07, 81.28it/s][A[A

 46%|████▌     | 4553/10000 [00:58<01:04, 84.73it/s][A[A

 46%|████▌     | 4563/10000 [00:58<01:01, 88.60it/s][A[A

 46%|████▌     | 4573/10000 [00:58<01:00, 89.75it/s][A[A

 46%|████▌     | 4583/10000 [00:58<00:59, 91.16it/s][A[A

 46%|████▌     | 4593/10000 [00:58<01:00, 90.10it/s][A[A

 46%|████▌     | 4603/10000 [00:58<01:01, 87.86it/s][A[A

 46%|████▌     | 4613/10000 [00:58<00:59, 90.10it/s][A[A

 46%|████▌     | 4623/10000 [00:59<01:00, 88.91it/s][A[A

 46%|████▋     | 4632/10000 [00:59<01:01, 87.46it/s][A[A

 46%|████▋     | 4641/10000 [00:59<01:01, 86.85it/s][A[A

 46%|████▋     | 4650/10000 [00:59<01:02, 85.79it/s][A[A

 47%|████▋     | 4660/10000 [00:59<01:00, 88.72it/s][A[A

 47%|████▋     | 4669/10000 [00:59<01:02

[1,  5000] loss: 0.509




 50%|█████     | 5021/10000 [01:03<00:55, 90.43it/s][A[A

 50%|█████     | 5031/10000 [01:03<00:54, 91.95it/s][A[A

 50%|█████     | 5041/10000 [01:04<00:58, 84.80it/s][A[A

 50%|█████     | 5050/10000 [01:04<00:58, 84.03it/s][A[A

 51%|█████     | 5059/10000 [01:04<00:58, 84.11it/s][A[A

 51%|█████     | 5068/10000 [01:04<00:57, 85.67it/s][A[A

 51%|█████     | 5077/10000 [01:04<00:57, 85.26it/s][A[A

 51%|█████     | 5086/10000 [01:04<00:58, 83.33it/s][A[A

 51%|█████     | 5095/10000 [01:04<00:58, 84.23it/s][A[A

 51%|█████     | 5105/10000 [01:04<00:56, 85.97it/s][A[A

 51%|█████     | 5114/10000 [01:04<00:57, 84.95it/s][A[A

 51%|█████     | 5124/10000 [01:04<00:55, 87.76it/s][A[A

 51%|█████▏    | 5134/10000 [01:05<00:54, 89.97it/s][A[A

 51%|█████▏    | 5144/10000 [01:05<00:53, 91.62it/s][A[A

 52%|█████▏    | 5154/10000 [01:05<00:53, 90.14it/s][A[A

 52%|█████▏    | 5164/10000 [01:05<00:55, 87.82it/s][A[A

 52%|█████▏    | 5174/10000 [01:05<00:

[1,  5500] loss: 0.515




 55%|█████▌    | 5523/10000 [01:09<00:48, 92.70it/s][A[A

 55%|█████▌    | 5533/10000 [01:09<00:51, 86.75it/s][A[A

 55%|█████▌    | 5543/10000 [01:09<00:50, 88.08it/s][A[A

 56%|█████▌    | 5553/10000 [01:09<00:49, 90.55it/s][A[A

 56%|█████▌    | 5563/10000 [01:09<00:48, 91.10it/s][A[A

 56%|█████▌    | 5573/10000 [01:09<00:47, 92.68it/s][A[A

 56%|█████▌    | 5583/10000 [01:10<00:46, 93.98it/s][A[A

 56%|█████▌    | 5593/10000 [01:10<00:47, 92.81it/s][A[A

 56%|█████▌    | 5603/10000 [01:10<00:46, 94.53it/s][A[A

 56%|█████▌    | 5613/10000 [01:10<00:47, 92.93it/s][A[A

 56%|█████▌    | 5623/10000 [01:10<00:47, 92.52it/s][A[A

 56%|█████▋    | 5633/10000 [01:10<00:48, 90.53it/s][A[A

 56%|█████▋    | 5643/10000 [01:10<00:47, 91.03it/s][A[A

 57%|█████▋    | 5653/10000 [01:10<00:48, 89.61it/s][A[A

 57%|█████▋    | 5662/10000 [01:10<00:48, 89.46it/s][A[A

 57%|█████▋    | 5672/10000 [01:11<00:47, 91.58it/s][A[A

 57%|█████▋    | 5682/10000 [01:11<00:

[1,  6000] loss: 0.519




 60%|██████    | 6021/10000 [01:15<00:45, 88.09it/s][A[A

 60%|██████    | 6030/10000 [01:15<00:45, 87.58it/s][A[A

 60%|██████    | 6040/10000 [01:15<00:44, 89.96it/s][A[A

 60%|██████    | 6050/10000 [01:15<00:43, 91.52it/s][A[A

 61%|██████    | 6060/10000 [01:15<00:42, 91.87it/s][A[A

 61%|██████    | 6070/10000 [01:15<00:43, 90.78it/s][A[A

 61%|██████    | 6082/10000 [01:15<00:40, 95.81it/s][A[A

 61%|██████    | 6092/10000 [01:15<00:41, 95.06it/s][A[A

 61%|██████    | 6103/10000 [01:15<00:39, 97.46it/s][A[A

 61%|██████    | 6113/10000 [01:16<00:40, 96.56it/s][A[A

 61%|██████    | 6123/10000 [01:16<00:39, 97.08it/s][A[A

 61%|██████▏   | 6133/10000 [01:16<00:44, 85.98it/s][A[A

 61%|██████▏   | 6143/10000 [01:16<00:43, 89.22it/s][A[A

 62%|██████▏   | 6153/10000 [01:16<00:42, 90.32it/s][A[A

 62%|██████▏   | 6163/10000 [01:16<00:42, 90.53it/s][A[A

 62%|██████▏   | 6173/10000 [01:16<00:43, 88.56it/s][A[A

 62%|██████▏   | 6183/10000 [01:16<00:

[1,  6500] loss: 0.533




 65%|██████▌   | 6519/10000 [01:21<00:40, 86.40it/s][A[A

 65%|██████▌   | 6529/10000 [01:21<00:39, 88.40it/s][A[A

 65%|██████▌   | 6539/10000 [01:21<00:38, 89.90it/s][A[A

 65%|██████▌   | 6549/10000 [01:21<00:37, 91.68it/s][A[A

 66%|██████▌   | 6559/10000 [01:21<00:37, 91.23it/s][A[A

 66%|██████▌   | 6569/10000 [01:21<00:37, 90.40it/s][A[A

 66%|██████▌   | 6579/10000 [01:21<00:37, 92.13it/s][A[A

 66%|██████▌   | 6589/10000 [01:21<00:37, 89.79it/s][A[A

 66%|██████▌   | 6599/10000 [01:21<00:37, 91.76it/s][A[A

 66%|██████▌   | 6610/10000 [01:22<00:35, 95.06it/s][A[A

 66%|██████▌   | 6620/10000 [01:22<00:36, 92.73it/s][A[A

 66%|██████▋   | 6630/10000 [01:22<00:37, 90.94it/s][A[A

 66%|██████▋   | 6640/10000 [01:22<00:39, 85.72it/s][A[A

 66%|██████▋   | 6649/10000 [01:22<00:40, 83.47it/s][A[A

 67%|██████▋   | 6658/10000 [01:22<00:41, 80.42it/s][A[A

 67%|██████▋   | 6667/10000 [01:22<00:41, 80.60it/s][A[A

 67%|██████▋   | 6676/10000 [01:22<00:

[1,  7000] loss: 0.508




 70%|███████   | 7019/10000 [01:27<00:35, 83.82it/s][A[A

 70%|███████   | 7028/10000 [01:27<00:35, 84.37it/s][A[A

 70%|███████   | 7038/10000 [01:27<00:34, 86.74it/s][A[A

 70%|███████   | 7047/10000 [01:27<00:35, 83.90it/s][A[A

 71%|███████   | 7056/10000 [01:27<00:35, 82.66it/s][A[A

 71%|███████   | 7065/10000 [01:27<00:35, 83.09it/s][A[A

 71%|███████   | 7074/10000 [01:27<00:34, 83.74it/s][A[A

 71%|███████   | 7083/10000 [01:27<00:34, 84.37it/s][A[A

 71%|███████   | 7092/10000 [01:27<00:34, 83.12it/s][A[A

 71%|███████   | 7101/10000 [01:28<00:34, 83.35it/s][A[A

 71%|███████   | 7110/10000 [01:28<00:36, 79.07it/s][A[A

 71%|███████   | 7120/10000 [01:28<00:34, 82.98it/s][A[A

 71%|███████▏  | 7129/10000 [01:28<00:34, 84.11it/s][A[A

 71%|███████▏  | 7139/10000 [01:28<00:32, 86.87it/s][A[A

 71%|███████▏  | 7148/10000 [01:28<00:34, 81.90it/s][A[A

 72%|███████▏  | 7157/10000 [01:28<00:37, 76.10it/s][A[A

 72%|███████▏  | 7165/10000 [01:28<00:

[1,  7500] loss: 0.508




 75%|███████▌  | 7516/10000 [01:33<00:36, 68.17it/s][A[A

 75%|███████▌  | 7524/10000 [01:33<00:35, 69.59it/s][A[A

 75%|███████▌  | 7531/10000 [01:34<00:36, 67.71it/s][A[A

 75%|███████▌  | 7539/10000 [01:34<00:35, 68.49it/s][A[A

 75%|███████▌  | 7548/10000 [01:34<00:33, 73.06it/s][A[A

 76%|███████▌  | 7556/10000 [01:34<00:33, 73.21it/s][A[A

 76%|███████▌  | 7564/10000 [01:34<00:32, 74.25it/s][A[A

 76%|███████▌  | 7573/10000 [01:34<00:31, 76.41it/s][A[A

 76%|███████▌  | 7581/10000 [01:34<00:34, 70.59it/s][A[A

 76%|███████▌  | 7590/10000 [01:34<00:32, 74.27it/s][A[A

 76%|███████▌  | 7598/10000 [01:34<00:32, 74.87it/s][A[A

 76%|███████▌  | 7606/10000 [01:35<00:31, 75.76it/s][A[A

 76%|███████▌  | 7614/10000 [01:35<00:31, 75.93it/s][A[A

 76%|███████▌  | 7622/10000 [01:35<00:34, 68.70it/s][A[A

 76%|███████▋  | 7630/10000 [01:35<00:33, 70.55it/s][A[A

 76%|███████▋  | 7638/10000 [01:35<00:32, 72.03it/s][A[A

 76%|███████▋  | 7646/10000 [01:35<00:

[1,  8000] loss: 0.530




 80%|████████  | 8022/10000 [01:40<00:23, 84.71it/s][A[A

 80%|████████  | 8031/10000 [01:40<00:23, 84.83it/s][A[A

 80%|████████  | 8040/10000 [01:40<00:24, 81.28it/s][A[A

 80%|████████  | 8049/10000 [01:40<00:23, 82.43it/s][A[A

 81%|████████  | 8059/10000 [01:40<00:22, 85.08it/s][A[A

 81%|████████  | 8069/10000 [01:40<00:21, 87.79it/s][A[A

 81%|████████  | 8078/10000 [01:41<00:22, 86.76it/s][A[A

 81%|████████  | 8087/10000 [01:41<00:21, 87.65it/s][A[A

 81%|████████  | 8097/10000 [01:41<00:21, 90.09it/s][A[A

 81%|████████  | 8107/10000 [01:41<00:21, 89.23it/s][A[A

 81%|████████  | 8116/10000 [01:41<00:21, 88.07it/s][A[A

 81%|████████▏ | 8125/10000 [01:41<00:21, 87.15it/s][A[A

 81%|████████▏ | 8134/10000 [01:41<00:21, 87.06it/s][A[A

 81%|████████▏ | 8143/10000 [01:41<00:22, 84.31it/s][A[A

 82%|████████▏ | 8152/10000 [01:41<00:21, 85.69it/s][A[A

 82%|████████▏ | 8161/10000 [01:42<00:21, 84.24it/s][A[A

 82%|████████▏ | 8171/10000 [01:42<00:

[1,  8500] loss: 0.527




 85%|████████▌ | 8523/10000 [01:46<00:16, 89.75it/s][A[A

 85%|████████▌ | 8533/10000 [01:46<00:16, 91.06it/s][A[A

 85%|████████▌ | 8543/10000 [01:46<00:15, 91.20it/s][A[A

 86%|████████▌ | 8553/10000 [01:46<00:17, 81.80it/s][A[A

 86%|████████▌ | 8562/10000 [01:46<00:17, 80.26it/s][A[A

 86%|████████▌ | 8571/10000 [01:46<00:18, 78.26it/s][A[A

 86%|████████▌ | 8579/10000 [01:46<00:18, 75.23it/s][A[A

 86%|████████▌ | 8587/10000 [01:47<00:19, 71.66it/s][A[A

 86%|████████▌ | 8595/10000 [01:47<00:19, 73.47it/s][A[A

 86%|████████▌ | 8603/10000 [01:47<00:19, 72.54it/s][A[A

 86%|████████▌ | 8611/10000 [01:47<00:19, 71.43it/s][A[A

 86%|████████▌ | 8619/10000 [01:47<00:19, 72.04it/s][A[A

 86%|████████▋ | 8627/10000 [01:47<00:18, 73.31it/s][A[A

 86%|████████▋ | 8635/10000 [01:47<00:18, 74.13it/s][A[A

 86%|████████▋ | 8643/10000 [01:47<00:18, 72.81it/s][A[A

 87%|████████▋ | 8651/10000 [01:47<00:19, 69.48it/s][A[A

 87%|████████▋ | 8658/10000 [01:48<00:

[1,  9000] loss: 0.495




 90%|█████████ | 9020/10000 [01:52<00:13, 70.59it/s][A[A

 90%|█████████ | 9028/10000 [01:53<00:13, 70.49it/s][A[A

 90%|█████████ | 9037/10000 [01:53<00:13, 73.75it/s][A[A

 90%|█████████ | 9046/10000 [01:53<00:12, 77.32it/s][A[A

 91%|█████████ | 9055/10000 [01:53<00:11, 79.32it/s][A[A

 91%|█████████ | 9064/10000 [01:53<00:11, 81.01it/s][A[A

 91%|█████████ | 9073/10000 [01:53<00:11, 81.84it/s][A[A

 91%|█████████ | 9082/10000 [01:53<00:11, 77.88it/s][A[A

 91%|█████████ | 9090/10000 [01:53<00:12, 75.40it/s][A[A

 91%|█████████ | 9099/10000 [01:53<00:11, 78.42it/s][A[A

 91%|█████████ | 9107/10000 [01:54<00:11, 78.62it/s][A[A

 91%|█████████ | 9115/10000 [01:54<00:11, 78.24it/s][A[A

 91%|█████████ | 9124/10000 [01:54<00:10, 79.86it/s][A[A

 91%|█████████▏| 9134/10000 [01:54<00:10, 84.55it/s][A[A

 91%|█████████▏| 9143/10000 [01:54<00:10, 84.91it/s][A[A

 92%|█████████▏| 9153/10000 [01:54<00:09, 88.25it/s][A[A

 92%|█████████▏| 9162/10000 [01:54<00:

[1,  9500] loss: 0.486




 95%|█████████▌| 9519/10000 [01:58<00:05, 93.47it/s][A[A

 95%|█████████▌| 9529/10000 [01:58<00:05, 90.52it/s][A[A

 95%|█████████▌| 9539/10000 [01:59<00:05, 85.05it/s][A[A

 95%|█████████▌| 9548/10000 [01:59<00:05, 84.47it/s][A[A

 96%|█████████▌| 9557/10000 [01:59<00:05, 86.00it/s][A[A

 96%|█████████▌| 9567/10000 [01:59<00:04, 87.65it/s][A[A

 96%|█████████▌| 9576/10000 [01:59<00:04, 87.61it/s][A[A

 96%|█████████▌| 9586/10000 [01:59<00:04, 89.95it/s][A[A

 96%|█████████▌| 9596/10000 [01:59<00:04, 91.21it/s][A[A

 96%|█████████▌| 9606/10000 [01:59<00:04, 91.96it/s][A[A

 96%|█████████▌| 9616/10000 [01:59<00:04, 91.51it/s][A[A

 96%|█████████▋| 9626/10000 [01:59<00:04, 90.13it/s][A[A

 96%|█████████▋| 9637/10000 [02:00<00:03, 94.21it/s][A[A

 96%|█████████▋| 9647/10000 [02:00<00:03, 90.45it/s][A[A

 97%|█████████▋| 9657/10000 [02:00<00:03, 86.43it/s][A[A

 97%|█████████▋| 9666/10000 [02:00<00:03, 83.95it/s][A[A

 97%|█████████▋| 9675/10000 [02:00<00:

[1, 10000] loss: 0.513
Finished Training


## Evaluation --> compare to random

In [28]:
correct = 0
incorrect = 0
for i, data in enumerate(tqdm.tqdm(test_set[:10000])):
    pep = test_set[i][1]
    apt = test_set[i][0]
    
    pep = one_hot(pep, seq_type='peptide')
    apt = one_hot(apt, seq_type='aptamer')

    pep = torch.FloatTensor(np.reshape(pep, (1, 1, pep.shape[0], pep.shape[1])))
    apt = torch.FloatTensor(np.reshape(apt, (1, 1, apt.shape[0], apt.shape[1])))

    output = model(apt, pep)
    #print("Output: ", output)
    
    if output > 0.5:
        correct += 1
    else:
        incorrect += 1

print('Accuracy of the network on the test samples: %d %%' % (100* correct/(correct + incorrect)))



  0%|          | 0/10000 [00:00<?, ?it/s][A[A

  0%|          | 45/10000 [00:00<00:22, 448.78it/s][A[A

  1%|          | 94/10000 [00:00<00:21, 458.87it/s][A[A

  1%|▏         | 143/10000 [00:00<00:21, 467.20it/s][A[A

  2%|▏         | 192/10000 [00:00<00:20, 473.48it/s][A[A

  2%|▏         | 241/10000 [00:00<00:20, 478.31it/s][A[A

  3%|▎         | 290/10000 [00:00<00:20, 481.48it/s][A[A

  3%|▎         | 339/10000 [00:00<00:20, 481.56it/s][A[A

  4%|▍         | 388/10000 [00:00<00:19, 483.10it/s][A[A

  4%|▍         | 437/10000 [00:00<00:19, 483.83it/s][A[A

  5%|▍         | 486/10000 [00:01<00:19, 484.72it/s][A[A

  5%|▌         | 535/10000 [00:01<00:19, 485.31it/s][A[A

  6%|▌         | 584/10000 [00:01<00:19, 485.30it/s][A[A

  6%|▋         | 634/10000 [00:01<00:19, 487.05it/s][A[A

  7%|▋         | 684/10000 [00:01<00:19, 489.30it/s][A[A

  7%|▋         | 734/10000 [00:01<00:18, 491.42it/s][A[A

  8%|▊         | 784/10000 [00:01<00:18, 493.59it/s

 90%|█████████ | 9039/10000 [00:13<00:01, 499.33it/s][A[A

 91%|█████████ | 9089/10000 [00:13<00:01, 499.43it/s][A[A

 91%|█████████▏| 9139/10000 [00:13<00:01, 499.47it/s][A[A

 92%|█████████▏| 9189/10000 [00:14<00:01, 499.37it/s][A[A

 92%|█████████▏| 9239/10000 [00:14<00:01, 499.43it/s][A[A

 93%|█████████▎| 9289/10000 [00:14<00:01, 499.21it/s][A[A

 93%|█████████▎| 9339/10000 [00:14<00:01, 499.17it/s][A[A

 94%|█████████▍| 9389/10000 [00:14<00:01, 499.00it/s][A[A

 94%|█████████▍| 9439/10000 [00:14<00:01, 499.03it/s][A[A

 95%|█████████▍| 9490/10000 [00:14<00:01, 499.40it/s][A[A

 95%|█████████▌| 9540/10000 [00:14<00:00, 499.23it/s][A[A

 96%|█████████▌| 9590/10000 [00:14<00:00, 499.26it/s][A[A

 96%|█████████▋| 9640/10000 [00:14<00:00, 499.18it/s][A[A

 97%|█████████▋| 9690/10000 [00:15<00:00, 499.19it/s][A[A

 97%|█████████▋| 9740/10000 [00:15<00:00, 499.39it/s][A[A

 98%|█████████▊| 9790/10000 [00:15<00:00, 499.22it/s][A[A

 98%|█████████▊| 9840/10

Accuracy of the network on the test samples: 65 %


# sampling --> need to change

In [29]:
k = 1e5

# Generate uniformly without replacement
def get_samples(kind="pep",num=k):
    if kind == "apt":
        samples = [all_aptamers[i] for i in np.random.choice(len(all_aptamers), num_samples, replace=False)]
    else:
        samples = [all_peptides[i] for i in np.random.choice(len(all_peptides), num_samples, replace=False)]
    return samples

# Sample x' from P_X (assume peptides follow NNK)
def get_x_prime(k):
    x_primes = []
    for _ in range(k):
        pvals = [0.089]*3 + [0.065]*5 + [0.034]*12
        x_idx = np.random.choice(20, 7, p=pvals)
        x_prime = "M"
        for i in x_idx:
            x_prime += aa_list[i]
        x_primes.append(x_prime)
    return x_primes

# Sample y' from P_Y (assume apatamers follow uniform)
def get_y_prime(k):
    y_primes = []
    for _ in range(k):
        y_idx = np.random.randint(0, 4, 40)
        y_prime = ""
        for i in y_idx:
            y_prime += na_list[i]
        y_primes.append(y_prime)
    return y_primes