In [1]:
import os, sys
import numpy as np
import json
import random
import torch

## Prepare the data --> one hot encoding matrices

In [2]:
BATCH_SIZE = 1000
random.seed(42)

In [3]:
aptamer_dataset_file = "../data/aptamer_dataset.json"

def construct_dataset():
    with open(aptamer_dataset_file, 'r') as f:
        aptamer_data = json.load(f)
    full_dataset = []
    for aptamer in aptamer_data:
        peptides = aptamer_data[aptamer]
        for peptide, _ in peptides:
            if '_' in peptide:
                split = peptide.split('_')
                save = split[0]
                if len(save) < 8:
                    continue
            if len(aptamer) == 40 and len(peptide) == 8:
                full_dataset.append((aptamer, peptide))
    return full_dataset

In [4]:
full_dataset = construct_dataset()
random.shuffle(full_dataset)
training_set = full_dataset[:int(0.8*len(full_dataset))]
test_set = full_dataset[int(0.2*len(full_dataset)):]

In [5]:
class AptamerDataset(torch.utils.data.Dataset):
    def __init__(self, training_set):
        super(AptamerDataset, self).__init__()
        self.training_set = training_set
        num_batches = int(len(training_set)/BATCH_SIZE)
        self.training_set = training_set[:int(num_batches * BATCH_SIZE)]
        
    def __len__(self):
         return len(self.training_set)
         
    def __getitem__(self, idx):
        aptamer, peptide = self.training_set[idx]
        return aptamer, peptide

In [6]:
train_dataset = AptamerDataset(training_set)
test_dataset = AptamerDataset(test_set)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [7]:
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
na_list = ['A', 'C', 'G', 'T']

## Takes a peptide and aptamer sequence and converts to one-hot matrix
def one_hot(sequence_list, seq_type='peptide'):
    if seq_type == 'peptide':
        letters = aa_list
    else:
        letters = na_list
    
    one_hot = np.zeros((len(sequence_list), len(sequence_list[0]), len(letters)))
    for j in range(len(sequence_list)):
        sequence = sequence_list[j]
        for i in range(len(sequence)):
            element = sequence[i]
            idx = letters.index(element)
            one_hot[j][i][idx] = 1
    return one_hot

## Model --> CNN

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD

In [9]:
class TwoLayer(nn.Module):
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)
        
        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)
        
        self.relu = nn.ReLU()
    
        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)
        
        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)
                
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        print(apt.shape())
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        x = torch.cat((apt, pep), 1)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        label = label.reshape((1, 1))
        return l(torch.FloatTensor(prediction), label)

In [10]:
class ConvNet(nn.Module):
    def __init__(self, batch_size):
        super(ConvNet, self).__init__()
        self.cnn_apt_1 = nn.Conv2d(BATCH_SIZE, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        self.fc_apt_1 = nn.Linear(160, 1)
        
        self.cnn_pep_1 = nn.Conv2d(BATCH_SIZE, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)
        
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()
        
        #self.dropout = nn.Dropout(0.1)
        
        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_pep_2)
        
        self.sequential_apt = nn.Sequential(self.cnn_apt_1, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_2, 
                                            #self.dropout,
                                            self.relu, 
                                            self.pool, 
                                            self.cnn_apt_3)
        
        self.fc1 = nn.Linear(209, BATCH_SIZE)
        
    def forward(self, apt, pep):
        apt = self.sequential_apt(apt)
        pep = self.sequential_pep(pep)
        
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        x = torch.cat((apt, pep), 1)
        x = self.fc1(x)
        x = F.sigmoid(x)
        return x
    
    def loss(self, prediction, label):
        l = nn.MSELoss()
        label = torch.FloatTensor(label)
        return l(torch.FloatTensor(prediction), label)

In [11]:
model = ConvNet(batch_size=BATCH_SIZE)
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.zeros_(m.bias.data)
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        nn.init.zeros_(m.bias.data)

#model_2 = TwoLayer()
# def weights_init_2(m):
#     if isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)

model.apply(weights_init)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

In [12]:
# Training Loop
import tqdm
for epoch in range(1):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    # Come up with a trainloader
    for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(train_loader)):
        # Peptide and aptamer, one-hot encode them 
        pep = one_hot(peptide, seq_type='peptide')
        apt = one_hot(aptamer, seq_type='aptamer')
        
        pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
        apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))
        
        output = model(apt, pep)
        
        label = np.ones((1, BATCH_SIZE))
        loss = model.loss(output, label)
        optimizer.zero_grad()
        loss.backward()
        
        #hyperparameter
        clip = 5
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 10 == 9:
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, batch_idx + 1, running_loss / 10*BATCH_SIZE))
            running_loss = 0.0
    
print('Finished Training')

  0%|          | 0/620 [00:00<?, ?it/s]

Epoch:  0


  2%|▏         | 10/620 [00:02<01:59,  5.10it/s]

[1,    10] loss: 262.8596246


  3%|▎         | 20/620 [00:04<02:58,  3.36it/s]

[1,    20] loss: 249.5824128


  5%|▍         | 30/620 [00:07<02:54,  3.37it/s]

[1,    30] loss: 245.8046347


  6%|▋         | 40/620 [00:10<02:14,  4.30it/s]

[1,    40] loss: 238.0706340


  8%|▊         | 50/620 [00:12<02:18,  4.11it/s]

[1,    50] loss: 214.9693936


 10%|▉         | 61/620 [00:15<02:16,  4.11it/s]

[1,    60] loss: 137.7282754


 11%|█▏        | 70/620 [00:18<02:41,  3.40it/s]

[1,    70] loss: 30.8818967


 13%|█▎        | 80/620 [00:20<02:14,  4.00it/s]

[1,    80] loss: 4.4097976


 15%|█▍        | 90/620 [00:23<02:32,  3.47it/s]

[1,    90] loss: 6.4695180


 16%|█▌        | 100/620 [00:26<02:17,  3.77it/s]

[1,   100] loss: 8.9608268


 18%|█▊        | 110/620 [00:29<02:34,  3.30it/s]

[1,   110] loss: 6.0989915


 19%|█▉        | 120/620 [00:32<02:26,  3.41it/s]

[1,   120] loss: 5.4338183


 21%|██        | 130/620 [00:35<02:17,  3.56it/s]

[1,   130] loss: 5.9257256


 23%|██▎       | 140/620 [00:38<02:23,  3.35it/s]

[1,   140] loss: 5.5864691


 24%|██▍       | 150/620 [00:41<02:16,  3.43it/s]

[1,   150] loss: 4.6322359


 26%|██▌       | 160/620 [00:43<02:05,  3.67it/s]

[1,   160] loss: 4.3466561


 27%|██▋       | 170/620 [00:46<02:13,  3.37it/s]

[1,   170] loss: 4.1953410


 29%|██▉       | 180/620 [00:49<01:53,  3.86it/s]

[1,   180] loss: 4.1553673


 31%|███       | 190/620 [00:52<01:57,  3.66it/s]

[1,   190] loss: 4.0849059


 32%|███▏      | 200/620 [00:54<01:47,  3.93it/s]

[1,   200] loss: 3.9929293


 34%|███▍      | 210/620 [00:57<01:58,  3.46it/s]

[1,   210] loss: 3.9279811


 35%|███▌      | 220/620 [01:00<01:58,  3.38it/s]

[1,   220] loss: 3.8642235


 37%|███▋      | 230/620 [01:03<02:11,  2.96it/s]

[1,   230] loss: 3.8803976


 39%|███▊      | 240/620 [01:06<01:45,  3.59it/s]

[1,   240] loss: 3.7527795


 40%|████      | 250/620 [01:09<01:47,  3.45it/s]

[1,   250] loss: 3.7302403


 42%|████▏     | 260/620 [01:12<01:42,  3.52it/s]

[1,   260] loss: 3.6239669


 44%|████▎     | 270/620 [01:15<01:40,  3.48it/s]

[1,   270] loss: 3.6076224


 45%|████▌     | 280/620 [01:18<01:36,  3.51it/s]

[1,   280] loss: 3.4772878


 47%|████▋     | 290/620 [01:20<01:27,  3.77it/s]

[1,   290] loss: 3.4132683


 48%|████▊     | 300/620 [01:23<01:14,  4.29it/s]

[1,   300] loss: 3.3436632


 50%|█████     | 310/620 [01:26<01:26,  3.57it/s]

[1,   310] loss: 3.2133554


 52%|█████▏    | 321/620 [01:28<01:14,  4.01it/s]

[1,   320] loss: 3.1794399


 53%|█████▎    | 330/620 [01:31<01:19,  3.64it/s]

[1,   330] loss: 3.0711760


 55%|█████▍    | 340/620 [01:34<01:14,  3.76it/s]

[1,   340] loss: 3.0032045


 56%|█████▋    | 350/620 [01:36<01:02,  4.29it/s]

[1,   350] loss: 2.9872715


 58%|█████▊    | 360/620 [01:39<01:03,  4.12it/s]

[1,   360] loss: 2.9130010


 60%|█████▉    | 370/620 [01:41<01:07,  3.72it/s]

[1,   370] loss: 2.8814459


 61%|██████▏   | 380/620 [01:44<01:07,  3.57it/s]

[1,   380] loss: 2.8417462


 63%|██████▎   | 390/620 [01:47<01:14,  3.07it/s]

[1,   390] loss: 2.8530647


 65%|██████▍   | 400/620 [01:50<00:55,  3.96it/s]

[1,   400] loss: 2.7967954


 66%|██████▌   | 410/620 [01:53<01:00,  3.47it/s]

[1,   410] loss: 2.8023027


 68%|██████▊   | 420/620 [01:55<00:50,  3.99it/s]

[1,   420] loss: 2.7710998


 69%|██████▉   | 430/620 [01:58<00:58,  3.24it/s]

[1,   430] loss: 2.7291590


 71%|███████   | 440/620 [02:02<00:54,  3.30it/s]

[1,   440] loss: 2.7600611


 73%|███████▎  | 450/620 [02:04<00:40,  4.21it/s]

[1,   450] loss: 2.7367390


 74%|███████▍  | 460/620 [02:06<00:36,  4.44it/s]

[1,   460] loss: 2.7061131


 76%|███████▌  | 470/620 [02:09<00:37,  3.97it/s]

[1,   470] loss: 2.7345390


 77%|███████▋  | 480/620 [02:11<00:35,  3.93it/s]

[1,   480] loss: 2.7237380


 79%|███████▉  | 490/620 [02:14<00:36,  3.57it/s]

[1,   490] loss: 2.6670255


 81%|████████  | 500/620 [02:16<00:32,  3.65it/s]

[1,   500] loss: 2.6962666


 82%|████████▏ | 510/620 [02:19<00:27,  4.04it/s]

[1,   510] loss: 2.6984680


 84%|████████▍ | 520/620 [02:22<00:27,  3.59it/s]

[1,   520] loss: 2.6608465


 86%|████████▌ | 531/620 [02:25<00:19,  4.48it/s]

[1,   530] loss: 2.6824406


 87%|████████▋ | 540/620 [02:27<00:19,  4.15it/s]

[1,   540] loss: 2.6579493


 89%|████████▊ | 550/620 [02:29<00:17,  4.07it/s]

[1,   550] loss: 2.6604593


 90%|█████████ | 560/620 [02:32<00:15,  3.92it/s]

[1,   560] loss: 2.6654822


 92%|█████████▏| 570/620 [02:35<00:14,  3.39it/s]

[1,   570] loss: 2.6410210


 94%|█████████▎| 580/620 [02:38<00:12,  3.16it/s]

[1,   580] loss: 2.6595160


 95%|█████████▌| 590/620 [02:41<00:07,  3.76it/s]

[1,   590] loss: 2.6412007


 97%|█████████▋| 600/620 [02:43<00:04,  4.16it/s]

[1,   600] loss: 2.6444253


 98%|█████████▊| 610/620 [02:46<00:02,  3.85it/s]

[1,   610] loss: 2.6540520


100%|██████████| 620/620 [02:48<00:00,  3.94it/s]

[1,   620] loss: 2.6256444
Finished Training





## Evaluation --> compare to random

In [13]:
correct = 0
incorrect = 0
for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(test_loader)):
    pep = one_hot(peptide, seq_type='peptide')
    apt = one_hot(aptamer, seq_type='aptamer')
    
    pep = torch.FloatTensor(np.reshape(pep, (1, pep.shape[0], pep.shape[1], pep.shape[2])))
    apt = torch.FloatTensor(np.reshape(apt, (1, apt.shape[0], apt.shape[1], apt.shape[2])))

    output = model(apt, pep).detach().numpy().flatten()
    for i in range(output.shape[0]):
        o = output[i]
        if o > 0.5:
            correct += 1
        else:
            incorrect += 1
    break

print('Recall of the network on the test samples: %d %%' % (100* correct/(correct + incorrect)))

  0%|          | 0/620 [00:00<?, ?it/s]

Recall of the network on the test samples: 100 %





# sampling --> need to change

In [14]:
k = 1e5

# Generate uniformly without replacement
def get_samples(kind="pep",num=k):
    if kind == "apt":
        samples = [all_aptamers[i] for i in np.random.choice(len(all_aptamers), num_samples, replace=False)]
    else:
        samples = [all_peptides[i] for i in np.random.choice(len(all_peptides), num_samples, replace=False)]
    return samples

# Sample x' from P_X (assume peptides follow NNK)
def get_x_prime(k):
    x_primes = []
    for _ in range(k):
        pvals = [0.089]*3 + [0.065]*5 + [0.034]*12
        x_idx = np.random.choice(20, 7, p=pvals)
        x_prime = "M"
        for i in x_idx:
            x_prime += aa_list[i]
        x_primes.append(x_prime)
    return x_primes

# Sample y' from P_Y (assume apatamers follow uniform)
def get_y_prime(k):
    y_primes = []
    for _ in range(k):
        y_idx = np.random.randint(0, 4, 40)
        y_prime = ""
        for i in y_idx:
            y_prime += na_list[i]
        y_primes.append(y_prime)
    return y_primes