In [1]:
import os, sys
import numpy as np
import json
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import tqdm

## Dataset and one-hot encoding

In [2]:
BATCH_SIZE = 256
SEED = 42
# Seed every RNG the notebook draws from: `random`, NumPy's global RNG (used by
# the sampling helpers via np.random.*) and torch (weight init, any DataLoader
# shuffling). Seeding only `random` left the other two unseeded.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
aptamer_dataset_file = "../data/aptamer_dataset.json"
clustered_dataset_file = "../data/clustered_aptamer_dataset.json"

def construct_dataset(path=None, seed=42):
    """Load unique (aptamer, peptide) pairs from the clustered JSON dataset.

    The JSON is iterated as a mapping aptamer -> list of 2-element entries whose
    first element is the peptide string (the second element is ignored).
    Underscores are stripped from peptides; only 40-character aptamers paired
    with 7-character peptides are kept, and each kept peptide is prefixed "M".

    Args:
        path: JSON file to read; defaults to ``clustered_dataset_file``
            (looked up at call time).
        seed: seed of the shuffle that fixes the order of the returned pairs.

    Returns:
        List of unique (aptamer, peptide) tuples in a reproducible, shuffled order.
    """
    if path is None:
        path = clustered_dataset_file
    with open(path, 'r') as f:
        aptamer_data = json.load(f)
    pairs = set()  # removes duplicates
    for aptamer in aptamer_data:
        if len(aptamer) != 40:
            continue
        for peptide, _ in aptamer_data[aptamer]:
            peptide = peptide.replace("_", "")
            if len(peptide) == 7:
                pairs.add((aptamer, "M" + peptide))
    # set() iteration order depends on string hashing, which is randomised per
    # process in Python 3, so the downstream slice-based train/test split was
    # not reproducible. Sort for a canonical order, then shuffle with a fixed
    # seed so the split stays random but repeatable.
    full_dataset = sorted(pairs)
    random.Random(seed).shuffle(full_dataset)
    return full_dataset

In [15]:
# Load all pairs and split them 80/20 (in list order) into train and test sets.
full_dataset = construct_dataset()
aptamers = [pair[0] for pair in full_dataset]
peptides = [pair[1] for pair in full_dataset]
split_at = int(0.8 * len(full_dataset))
training_set, test_set = full_dataset[:split_at], full_dataset[split_at:]
print(str(len(test_set)))

40564


In [5]:
class TrainDataset(torch.utils.data.Dataset):
    """Map-style dataset over (aptamer, peptide) training pairs.

    The pair list is truncated to a multiple of ``batch_size`` so every
    DataLoader batch is full (ConvNet's first Conv2d layers take one input
    channel per pair in the batch, so they need exactly that many pairs).
    Fewer than ``batch_size`` pairs therefore yields an empty dataset.
    """
    def __init__(self, training_set, batch_size=None):
        super(TrainDataset, self).__init__()
        if batch_size is None:
            batch_size = BATCH_SIZE
        n = len(training_set)
        # Drop the trailing partial batch (the last n % batch_size pairs).
        self.training_set = training_set[:n - n % batch_size]

    def __len__(self):
        return len(self.training_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.training_set[idx]
        return aptamer, peptide
    
class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset over held-out (aptamer, peptide) pairs.

    Like the training dataset, the pair list is truncated to a multiple of
    ``batch_size`` because ConvNet only accepts full batches. Fewer than
    ``batch_size`` pairs therefore yields an empty dataset.
    """
    def __init__(self, test_set, batch_size=None):
        super(TestDataset, self).__init__()
        if batch_size is None:
            batch_size = BATCH_SIZE
        n = len(test_set)
        # Drop the trailing partial batch (the last n % batch_size pairs).
        self.test_set = test_set[:n - n % batch_size]

    def __len__(self):
        return len(self.test_set)

    def __getitem__(self, idx):
        aptamer, peptide = self.test_set[idx]
        return aptamer, peptide

In [6]:
train_dataset = TrainDataset(training_set)
test_dataset = TestDataset(test_set)
# Reshuffle the training pairs every epoch so batches are not identical across
# epochs. Evaluation order does not matter, so the test loader keeps list order.
# Both datasets are already truncated to full batches.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [7]:
# Symbol alphabets used for one-hot encoding and random sampling.
na_list = list("ACGT")                    # nucleotides (aptamers)
aa_list = list("RLSAGPTVNDCQEHIKMFWY")    # amino acids (peptides)
# Per-residue sampling probabilities, aligned with aa_list: R/L/S get 0.089,
# A/G/P/T/V get 0.065, the remaining 12 residues get 0.034 (sums to 1).
# Presumably an NNK-codon approximation, per the sampling section's comment.
pvals = [0.089] * 3 + [0.065] * 5 + [0.034] * 12
aa_dict = {aa: p for aa, p in zip(aa_list, pvals)}

In [8]:
## Takes a list of peptide or aptamer sequences and converts it to a one-hot array
def one_hot(sequence_list, seq_type='peptide', alphabet=None):
    """One-hot encode a batch of sequences.

    Args:
        sequence_list: list/tuple of strings.
        seq_type: 'peptide' selects aa_list; any other value selects na_list.
        alphabet: optional explicit list of symbols; overrides seq_type.

    Returns:
        float64 array of shape (len(sequence_list), L, len(alphabet)), where L is
        the length of the longest sequence. Shorter sequences are zero-padded at
        the end, and an empty input gives shape (0, 0, len(alphabet)).

    Raises:
        ValueError: if a sequence contains a symbol that is not in the alphabet.
    """
    if alphabet is None:
        alphabet = aa_list if seq_type == 'peptide' else na_list
    # symbol -> column; first occurrence wins, matching list.index semantics
    index_of = {}
    for col, symbol in enumerate(alphabet):
        index_of.setdefault(symbol, col)

    max_len = max([len(seq) for seq in sequence_list]) if len(sequence_list) else 0
    encoded = np.zeros((len(sequence_list), max_len, len(alphabet)))

    for row, sequence in enumerate(sequence_list):
        for pos, symbol in enumerate(sequence):
            if symbol not in index_of:
                raise ValueError('%r is not in alphabet %r' % (symbol, list(alphabet)))
            encoded[row, pos, index_of[symbol]] = 1
    return encoded

## NN Model

In [9]:
class TwoLayer(nn.Module):
    """Baseline with independent two-layer MLP heads for aptamer and peptide.

    Each head is Linear -> ReLU -> Linear(.., 1), applied to the LAST input
    dimension (size 40 for the aptamer head, 8 for the peptide head). Both heads'
    per-row scores are flattened to (1, n), concatenated along dim 1 and passed
    through a sigmoid.

    NOTE(review): the training cells build one-hot tensors whose last dimension
    is 4 (aptamer) / 20 (peptide), so they would need reshaping before being fed
    here. TODO confirm the intended input layout.
    """
    def __init__(self):
        super(TwoLayer, self).__init__()
        self.linear_apt_1 = nn.Linear(40, 40)
        self.linear_apt_2 = nn.Linear(40, 1)

        self.linear_pep_1 = nn.Linear(8, 8)
        self.linear_pep_2 = nn.Linear(8, 1)

        self.relu = nn.ReLU()

        self.sequential_pep = nn.Sequential(self.linear_pep_1,
                                            self.relu,
                                            self.linear_pep_2)

        self.sequential_apt = nn.Sequential(self.linear_apt_1,
                                            self.relu,
                                            self.linear_apt_2)

    def forward(self, apt, pep):
        apt = self.sequential_apt(apt).reshape(1, -1)
        pep = self.sequential_pep(pep).reshape(1, -1)
        x = torch.cat((apt, pep), 1)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values)
        return torch.sigmoid(x)

    def loss(self, prediction, label):
        """MSE between ``prediction`` and a single-element ``label`` reshaped to (1, 1).

        ``label`` may be any array-like; it is converted to float32.
        """
        label = torch.as_tensor(label, dtype=torch.float32).reshape((1, 1))
        # as_tensor returns a float32 tensor unchanged, so autograd history is kept
        prediction = torch.as_tensor(prediction, dtype=torch.float32)
        return nn.MSELoss()(prediction, label)

In [10]:
class ConvNet(nn.Module):
    """Scores a whole batch of (aptamer, peptide) pairs in one forward pass.

    The training cells feed ``apt`` as (1, batch_size, 40, 4) and ``pep`` as
    (1, batch_size, 8, 20) one-hot tensors. The batch axis is given to Conv2d
    as its *channel* axis, and the output is (1, batch_size) sigmoid scores.

    NOTE(review): because the batch is the channel axis, every output score
    depends on every pair in the batch, and only batches of exactly
    ``batch_size`` pairs are accepted. A per-pair architecture (batch axis as
    the real batch dimension) would remove both limitations.

    Args:
        batch_size: number of pairs per batch. It sets the input channels of the
            first conv layers and the width of the output layer.
    """
    def __init__(self, batch_size):
        super(ConvNet, self).__init__()
        # Aptamer branch: 1x1 convolutions mix channels only.
        # (Previously the global BATCH_SIZE was used and this argument ignored.)
        self.cnn_apt_1 = nn.Conv2d(batch_size, 40, 1)
        self.cnn_apt_2 = nn.Conv2d(40, 10, 1)
        self.cnn_apt_3 = nn.Conv2d(10, 1, 1)
        # NOTE(review): fc_apt_1 and fc_pep_1 are never used in forward(). They are
        # kept so state_dict keys and weight-init RNG draws stay unchanged.
        self.fc_apt_1 = nn.Linear(160, 1)

        # Peptide branch.
        self.cnn_pep_1 = nn.Conv2d(batch_size, 8, 1)
        self.cnn_pep_2 = nn.Conv2d(8, 1, 1)
        self.fc_pep_1 = nn.Linear(64, 1)

        # kernel 2, stride 1: each pool shrinks both spatial dims by 1
        self.pool = nn.MaxPool2d(2, 1)
        self.relu = nn.ReLU()

        self.sequential_pep = nn.Sequential(self.cnn_pep_1,
                                            self.relu,
                                            self.pool,
                                            self.cnn_pep_2)

        self.sequential_apt = nn.Sequential(self.cnn_apt_1,
                                            self.relu,
                                            self.pool,
                                            self.cnn_apt_2,
                                            self.relu,
                                            self.pool,
                                            self.cnn_apt_3)

        # For the input shapes above: aptamer (40, 4) after two pools is
        # (38, 2) -> 76 values, and peptide (8, 20) after one pool is (7, 19) -> 133.
        # 76 + 133 = 209.
        self.fc1 = nn.Linear(209, batch_size)

    def forward(self, apt, pep):
        apt = self.sequential_apt(apt).reshape(1, -1)
        pep = self.sequential_pep(pep).reshape(1, -1)
        x = torch.cat((apt, pep), 1)
        return torch.sigmoid(self.fc1(x))

    def loss(self, prediction, label):
        """Mean squared error between ``prediction`` and array-like ``label`` (as float32)."""
        label = torch.as_tensor(label, dtype=torch.float32)
        # as_tensor returns a float32 tensor unchanged, so autograd history is kept
        prediction = torch.as_tensor(prediction, dtype=torch.float32)
        return nn.MSELoss()(prediction, label)

In [11]:
model = ConvNet(batch_size=BATCH_SIZE)

def weights_init(m):
    """Custom initialisation, applied to every sub-module via ``model.apply``.

    Conv2d layers get Xavier-uniform weights. Linear layers get Kaiming-uniform
    weights with ReLU gain. Biases, when present, are zeroed. Other module types
    are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.Linear):
        # NOTE(review): ConvNet.fc1 feeds a sigmoid rather than a ReLU; the ReLU
        # gain is kept as originally chosen.
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    else:
        return
    # nn.init functions run under no_grad, so no `.data` access is needed;
    # guard for layers built with bias=False.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

model.apply(weights_init)
# Adam's weight_decay adds an L2 term to the gradients (not decoupled AdamW decay).
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

## Training

In [12]:
# Training Loop
# NOTE(review): every target is 1 (all loader pairs are observed pairs), so the
# model can minimise the loss by predicting 1 everywhere. Negative pairs (e.g.
# from the sampling section) are needed for a discriminative model.
N_EPOCHS = 1
LOG_EVERY = 10   # batches between loss reports
CLIP_NORM = 5    # max global gradient norm (hyperparameter)

# Target for every batch; loader batches are exactly BATCH_SIZE long.
label = np.ones((1, BATCH_SIZE))

for epoch in range(N_EPOCHS):
    print("Epoch: ", epoch)
    model.train()
    running_loss = 0.0
    for batch_idx, (aptamer, peptide) in enumerate(tqdm.tqdm(train_loader)):
        # One-hot encode, then add a leading axis: (1, BATCH_SIZE, seq_len, alphabet)
        pep = torch.FloatTensor(one_hot(peptide, seq_type='peptide')[np.newaxis])
        apt = torch.FloatTensor(one_hot(aptamer, seq_type='aptamer')[np.newaxis])

        output = model(apt, pep)
        loss = model.loss(output, label)

        optimizer.zero_grad()
        loss.backward()
        # clip_grad_norm_ replaces the deprecated clip_grad_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % LOG_EVERY == LOG_EVERY - 1:
            # (mean batch loss over the window) * BATCH_SIZE. MSELoss averages
            # over the BATCH_SIZE elements, so this is the average per-batch sum
            # of squared errors. Same value as before, precedence now explicit.
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, batch_idx + 1, (running_loss / LOG_EVERY) * BATCH_SIZE))
            running_loss = 0.0

print('Finished Training')

  0%|          | 1/633 [00:00<01:33,  6.79it/s]

('Epoch: ', 0)


  2%|▏         | 12/633 [00:01<01:10,  8.79it/s]

[1,    10] loss: 61.4689205


  3%|▎         | 22/633 [00:02<01:01,  9.96it/s]

[1,    20] loss: 47.9223087


  5%|▌         | 33/633 [00:03<00:58, 10.28it/s]

[1,    30] loss: 21.0020545


  6%|▋         | 41/633 [00:04<01:05,  8.98it/s]

[1,    40] loss: 3.3176869


  9%|▊         | 54/633 [00:05<00:48, 11.86it/s]

[1,    50] loss: 1.3538689


  9%|▉         | 60/633 [00:05<00:34, 16.64it/s]

[1,    60] loss: 1.8007408


 11%|█         | 71/633 [00:06<01:02,  9.01it/s]

[1,    70] loss: 1.3526288


 13%|█▎        | 82/633 [00:08<01:00,  9.15it/s]

[1,    80] loss: 0.6807662


 14%|█▍        | 91/633 [00:09<01:03,  8.56it/s]

[1,    90] loss: 0.5666671


 16%|█▌        | 101/633 [00:10<01:17,  6.88it/s]

[1,   100] loss: 0.5375870


 18%|█▊        | 111/633 [00:12<01:25,  6.10it/s]

[1,   110] loss: 0.5382289


 19%|█▉        | 121/633 [00:13<01:13,  6.92it/s]

[1,   120] loss: 0.4397897


 21%|██        | 131/633 [00:15<01:14,  6.78it/s]

[1,   130] loss: 0.4018961


 22%|██▏       | 141/633 [00:16<01:20,  6.10it/s]

[1,   140] loss: 0.3866037


 24%|██▍       | 151/633 [00:18<01:18,  6.17it/s]

[1,   150] loss: 0.4471696


 25%|██▌       | 161/633 [00:19<01:05,  7.17it/s]

[1,   160] loss: 0.3783663


 27%|██▋       | 171/633 [00:21<01:11,  6.42it/s]

[1,   170] loss: 0.3560505


 29%|██▊       | 181/633 [00:22<01:08,  6.61it/s]

[1,   180] loss: 0.3806872


 30%|███       | 191/633 [00:23<00:43, 10.12it/s]

[1,   190] loss: 0.3704663


 32%|███▏      | 204/633 [00:24<00:27, 15.81it/s]

[1,   200] loss: 0.3622590


 33%|███▎      | 212/633 [00:24<00:22, 19.08it/s]

[1,   210] loss: 0.3609684


 35%|███▍      | 221/633 [00:25<00:39, 10.48it/s]

[1,   220] loss: 0.3357642


 36%|███▋      | 231/633 [00:26<00:31, 12.65it/s]

[1,   230] loss: 0.3464744


 38%|███▊      | 243/633 [00:27<00:26, 14.54it/s]

[1,   240] loss: 0.3539533


 40%|███▉      | 251/633 [00:27<00:31, 12.04it/s]

[1,   250] loss: 0.3538663


 42%|████▏     | 263/633 [00:29<00:30, 12.29it/s]

[1,   260] loss: 0.3542832


 43%|████▎     | 273/633 [00:29<00:23, 15.07it/s]

[1,   270] loss: 0.3359359


 45%|████▍     | 283/633 [00:30<00:22, 15.59it/s]

[1,   280] loss: 0.3566108


 46%|████▋     | 293/633 [00:30<00:20, 16.62it/s]

[1,   290] loss: 0.3530959


 48%|████▊     | 303/633 [00:31<00:20, 16.38it/s]

[1,   300] loss: 0.3533762


 49%|████▉     | 313/633 [00:32<00:20, 15.47it/s]

[1,   310] loss: 0.3367805


 51%|█████     | 321/633 [00:33<00:32,  9.52it/s]

[1,   320] loss: 0.3155464


 52%|█████▏    | 331/633 [00:34<00:42,  7.10it/s]

[1,   330] loss: 0.4013364


 54%|█████▍    | 342/633 [00:35<00:36,  7.88it/s]

[1,   340] loss: 0.3346506


 56%|█████▌    | 352/633 [00:36<00:25, 10.97it/s]

[1,   350] loss: 0.3184173


 57%|█████▋    | 362/633 [00:37<00:20, 13.19it/s]

[1,   360] loss: 0.3654594


 59%|█████▉    | 372/633 [00:38<00:18, 13.79it/s]

[1,   370] loss: 0.3621911


 60%|██████    | 381/633 [00:38<00:19, 13.08it/s]

[1,   380] loss: 0.3320122


 62%|██████▏   | 392/633 [00:39<00:17, 13.77it/s]

[1,   390] loss: 0.3899092


 64%|██████▎   | 402/633 [00:40<00:18, 12.75it/s]

[1,   400] loss: 0.3275609


 65%|██████▍   | 410/633 [00:41<00:22,  9.96it/s]

[1,   410] loss: 0.3556760


 67%|██████▋   | 422/633 [00:42<00:22,  9.57it/s]

[1,   420] loss: 0.3336691


 68%|██████▊   | 430/633 [00:43<00:20,  9.90it/s]

[1,   430] loss: 0.3690396


 70%|██████▉   | 441/633 [00:44<00:22,  8.35it/s]

[1,   440] loss: 0.3625881


 71%|███████   | 451/633 [00:46<00:27,  6.73it/s]

[1,   450] loss: 0.3259897


 73%|███████▎  | 461/633 [00:47<00:29,  5.93it/s]

[1,   460] loss: 0.3679208


 74%|███████▍  | 471/633 [00:49<00:24,  6.74it/s]

[1,   470] loss: 0.3515721


 76%|███████▌  | 480/633 [00:50<00:19,  7.79it/s]

[1,   480] loss: 0.3347246


 78%|███████▊  | 491/633 [00:51<00:15,  9.30it/s]

[1,   490] loss: 0.3577720


 79%|███████▉  | 501/633 [00:53<00:16,  8.03it/s]

[1,   500] loss: 0.3906060


 81%|████████  | 511/633 [00:54<00:15,  7.94it/s]

[1,   510] loss: 0.2983237


 82%|████████▏ | 521/633 [00:55<00:14,  7.96it/s]

[1,   520] loss: 0.3809748


 84%|████████▍ | 531/633 [00:56<00:13,  7.62it/s]

[1,   530] loss: 0.3464099


 85%|████████▌ | 541/633 [00:58<00:11,  8.02it/s]

[1,   540] loss: 0.3708178


 87%|████████▋ | 551/633 [00:59<00:11,  7.04it/s]

[1,   550] loss: 0.3290696


 89%|████████▊ | 561/633 [01:00<00:11,  6.53it/s]

[1,   560] loss: 0.3778466


 90%|█████████ | 571/633 [01:02<00:10,  5.91it/s]

[1,   570] loss: 0.3535264


 92%|█████████▏| 581/633 [01:04<00:08,  6.31it/s]

[1,   580] loss: 0.3388720


 93%|█████████▎| 591/633 [01:05<00:06,  6.18it/s]

[1,   590] loss: 0.3836388


 95%|█████████▍| 601/633 [01:07<00:05,  5.98it/s]

[1,   600] loss: 0.3157116


 97%|█████████▋| 611/633 [01:08<00:03,  6.73it/s]

[1,   610] loss: 0.3897336


 98%|█████████▊| 621/633 [01:10<00:01,  6.32it/s]

[1,   620] loss: 0.3457360


100%|█████████▉| 631/633 [01:11<00:00,  7.13it/s]

[1,   630] loss: 0.3450323


100%|██████████| 633/633 [01:12<00:00,  8.76it/s]

Finished Training





## Recall test

In [14]:
# Every test pair is an observed (positive) pair, so the fraction scored above
# RECALL_THRESHOLD is the recall on the held-out set.
RECALL_THRESHOLD = 0.9

model.eval()
correct = 0
incorrect = 0
with torch.no_grad():
    # Evaluate ALL test batches (a stray `break` previously stopped after the first).
    for aptamer, peptide in tqdm.tqdm(test_loader):
        pep = torch.FloatTensor(one_hot(peptide, seq_type='peptide')[np.newaxis])
        apt = torch.FloatTensor(one_hot(aptamer, seq_type='aptamer')[np.newaxis])

        output = model(apt, pep).detach().numpy().flatten()
        # NaN compares False, so it counts as incorrect, as in the per-element loop
        hits = int((output > RECALL_THRESHOLD).sum())
        correct += hits
        incorrect += output.shape[0] - hits

total = correct + incorrect
if total == 0:
    print('Recall of the network on the test samples: n/a (empty test set)')
else:
    print('Recall of the network on the test samples: %d %%' % (100 * correct / total))

  0%|          | 0/158 [00:00<?, ?it/s]

Recall of the network on the test samples: 100 %





## Sampling methods

In [None]:
# Sample x from P_X (assume peptides follow NNK)
def get_x():
    """Draw one random peptide: a fixed leading "M" plus 7 residues drawn i.i.d. with probabilities `pvals`."""
    residues = [aa_list[k] for k in np.random.choice(20, 7, p=pvals)]
    return "M" + "".join(residues)

# Sample y from P_Y (assume aptamers are uniform over bases)
def get_y():
    """Draw one random 40-character aptamer with each base uniform over `na_list`."""
    bases = np.random.randint(0, 4, 40)
    return "".join(na_list[b] for b in bases)

# Generate uniformly from S without replacement
def get_xy(k):
    """Return a list of k distinct pairs drawn uniformly at random from `full_dataset`."""
    chosen = np.random.choice(len(full_dataset), k, replace=False)
    picked = []
    for idx in chosen:
        picked.append(full_dataset[idx])
    return picked

# S' = S plus k freshly sampled (aptamer, peptide) pairs; with k = |S| it is
# roughly double the size of S (domain for importance sampling)
def get_S_prime(k):
    """Build the importance-sampling domain S'.

    Args:
        k: number of random pairs (aptamer from get_y, peptide from get_x) to add.

    Returns:
        (S_prime, S_prime_new). S_prime is the deduplicated union of
        `full_dataset` and the new pairs, in first-seen order (S first), which is
        deterministic unlike list(set(...)). S_prime_new is the list of the k
        newly drawn pairs.
    """
    # get_y() is evaluated before get_x() for each pair, as before.
    # Taking the new pairs directly fixes the old `[-n:]` slice, which read the
    # global `n` instead of `k` (and `[-0:]` would return everything).
    S_prime_new = [(get_y(), get_x()) for _ in range(k)]
    seen = set()
    S_prime = []
    for pair in full_dataset + S_prime_new:
        if pair not in seen:
            seen.add(pair)
            S_prime.append(pair)
    return S_prime, S_prime_new

# Sample from S' without replacement
def get_xy_prime(k):
    """Return a list of k distinct pairs drawn uniformly at random from the global `S_prime`."""
    chosen = np.random.choice(len(S_prime), k, replace=False)
    picked = []
    for idx in chosen:
        picked.append(S_prime[idx])
    return picked

## Check whether naive sampling generates pairs already in S (it does not)

In [None]:
n = len(full_dataset)
# Draw |S| naive pairs; S' then holds S plus those pairs.
S_prime, S_prime_new = get_S_prime(n)
print("Size of S: ", n)
print("Size of naively sampled dataset: ", len(S_prime_new))
# Pairs of S that the naive sampler did not reproduce. Because full_dataset is
# deduplicated, a value equal to n means the two sets do not overlap at all.
diff = set(full_dataset).difference(S_prime_new)
print("Size of set difference: ", len(diff))

## This motivates importance sampling in SGD: naively drawn pairs almost never fall in S

In [None]:
# Returns pmf of a peptide
def get_x_pmf(x):
    """P_X(x) under the i.i.d. residue model: product of aa_dict[c] over x[1:].

    The first character is skipped because every peptide starts with the fixed "M".
    """
    probs = [aa_dict[residue] for residue in x[1:]]
    pmf = 1
    for p in probs:
        pmf *= p
    return pmf

# Returns pmf of an aptamer
def get_y_pmf():
    """P_Y(y) for any 40-character aptamer under uniform bases: 0.25 ** 40."""
    return 0.25 ** 40

In [None]:
def sgd(t=int(1e5), #num of iter
        gamma=1e-2, #step size
        batch=1): #batch size
    """Importance-sampled SGD over S' (work in progress).

    Each iteration draws one pair from S (get_xy) and one from S' (get_xy_prime).
    It then computes the pmfs of the S' pair and an indicator of whether that
    pair is an observed pair in S.

    NOTE(review): gamma, batch and the pair drawn from S are not used yet, and
    no parameter update is performed. The function returns None.
    """
    # Set for O(1) membership; scanning the list each iteration costs O(|S|).
    in_S = set(full_dataset)
    for _ in range(t):
        xy = get_xy(1) #sample pair from S (a list holding one (aptamer, peptide) tuple)
        xy_prime = get_xy_prime(1) #sample pair from S' (same format)
        pair_prime = xy_prime[0]
        const = 1 #indicator: 0 when the sampled S' pair is in S
        x_prime_pmf = get_x_pmf(pair_prime[1])
        y_prime_pmf = get_y_pmf()
        # Test the tuple itself: a one-element list is never an element of S.
        if pair_prime in in_S:
            const = 0

## Evaluation