In [1]:
import os, sys
import numpy as np
import json
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from numpy import linalg as LA

In [2]:
torch.manual_seed(12345)
k = 10000
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
na_list = ['A', 'C', 'G', 'T'] #nucleic acids
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y'] #amino acids
NNK_freq = [0.09375]*3 + [0.0625]*5 + [0.03125]*13 #freq of 21 NNK codons including the stop codon
sum_20 = 0.0625*5 + 0.09375*3 + 0.03125*12 #sum of freq without the stop codon
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*12 #normalize freq for 20 codons
pvals = [0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11 + \
        [1- sum([0.09375/sum_20]*3 + [0.0625/sum_20]*5 + [0.03125/sum_20]*11)] 
        #adjust sum to 1 due to numerical issue
aa_dict = dict(zip(aa_list, pvals))
encoding_style = 'clipped'

In [3]:
# Original BLOSUM62 matrix
original_blosum62 = {}
with open('../blosum62.txt', 'r') as f:
    for line in f:
        split_line = line.strip().split()
        aa = split_line[0]
        encoding = [int(x) for x in split_line[1:-3]]
        original_blosum62[aa] = encoding
blosum_matrix = np.zeros((20, 20))
for i, aa in enumerate(original_blosum62.keys()):
    sims = original_blosum62[aa]
    for j, s in enumerate(sims):
        blosum_matrix[i][j] = s   
u, V = LA.eig(blosum_matrix)
clipped_u = u
clipped_u[clipped_u < 0] = 0
lamb = np.diag(clipped_u)
T = V
clip_blosum62 = {}
for i, aa in enumerate(original_blosum62.keys()):
    clip_blosum62[aa] = np.dot(np.sqrt(lamb), V[i])

In [4]:
peptide = "MTATRLST"
aptamer_0 = np.full((40, 4), 0.25)

## Model

In [5]:
# Expects peptides to be encoding according to BLOSUM62 matrix
# Expects aptamers to be one hot encoded
class BlosumLinearNet(nn.Module):
    def __init__(self):
        super(BlosumLinearNet, self).__init__()
        self.name = "BlosumLinearNet"
        self.single_alphabet = False
        
        self.fc_apt_1 = nn.Linear(160, 200) 
        self.fc_apt_2 = nn.Linear(200, 250)
        self.fc_apt_3 = nn.Linear(250, 300)
        
        self.fc_pep_1 = nn.Linear(160, 200)
        self.fc_pep_2 = nn.Linear(200, 250)
        
        self.relu = nn.ReLU()
        
        self.fc_apt = nn.Sequential(self.fc_apt_1, self.fc_apt_2, self.fc_apt_3)
        self.fc_pep = nn.Sequential(self.fc_pep_1, self.fc_pep_2)
        
        self.fc1 = nn.Linear(550, 600)
        self.fc2 = nn.Linear(600, 1)
        
    def forward(self, apt, pep):
        apt = apt.view(-1, 1).T
        pep = pep.view(-1, 1).T
        
        apt = self.fc_apt(apt)
        pep = self.fc_pep(pep)
        x = torch.cat((apt, pep), 1)
        x = self.fc2(self.fc1(x))
        x = torch.sigmoid(x)
        return x

In [6]:
# Reinstantiate the model with the proper weights
model = BlosumLinearNet()
model_name = model.name
model_id = "07132020"
model.to(device)
checkpointed_model = '../model_checkpoints/binary/%s/%s.pth' % (model_name, "07102020")
checkpoint = torch.load(checkpointed_model)
optimizer = SGD(model.parameters(), lr=1e-2)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
init_epoch = checkpoint['epoch'] +1
print("Reloading model: ", model.name, " at epoch: ", init_epoch)
model.eval()

Reloading model:  BlosumLinearNet  at epoch:  34


BlosumLinearNet(
  (fc_apt_1): Linear(in_features=160, out_features=200, bias=True)
  (fc_apt_2): Linear(in_features=200, out_features=250, bias=True)
  (fc_apt_3): Linear(in_features=250, out_features=300, bias=True)
  (fc_pep_1): Linear(in_features=160, out_features=200, bias=True)
  (fc_pep_2): Linear(in_features=200, out_features=250, bias=True)
  (relu): ReLU()
  (fc_apt): Sequential(
    (0): Linear(in_features=160, out_features=200, bias=True)
    (1): Linear(in_features=200, out_features=250, bias=True)
    (2): Linear(in_features=250, out_features=300, bias=True)
  )
  (fc_pep): Sequential(
    (0): Linear(in_features=160, out_features=200, bias=True)
    (1): Linear(in_features=200, out_features=250, bias=True)
  )
  (fc1): Linear(in_features=550, out_features=600, bias=True)
  (fc2): Linear(in_features=600, out_features=1, bias=True)
)

## SGD based search

In [7]:
# Encode the peptide appropriately
def blosum62_encoding(sequence, seq_type='peptide', single_alphabet=False, style=encoding_style):
    if single_alphabet:
        pass
    else:
        if seq_type == 'peptide':
            encoding = []
            for i in range(len(sequence)):
                if style == "clipped":
                    encoding.append(clip_blosum62[sequence[i]])
                else:
                    encoding.append(original_blosum62[sequence[i]])
            encoding = np.asarray(encoding)
        else:
            #Translation
            letters = na_list
            encoding = np.zeros(len(sequence))
            for i in range(len(sequence)):
                char = sequence[i]
                idx = letters.index(char)
                encoding[i] = idx
        return encoding 
# Convert a pair to one-hot tensor
def convert(apt, pep, label, single_alphabet=False): 
    if single_alphabet:
        pair = translate([apt, pep], single_alphabet=True) #(2, 40)
        pair = torch.FloatTensor(np.reshape(pair, (-1, pair.shape[0], pair.shape[1]))).to(device)
        label = torch.FloatTensor([[label]]).to(device)
        return pair, label
    else:
        pep = blosum62_encoding(pep, seq_type='peptide') 
        apt = torch.FloatTensor(np.reshape(apt, (-1, apt.shape[1], apt.shape[0]))).to(device) #(1, 4, 40)
        pep = torch.FloatTensor(np.reshape(pep, (-1, pep.shape[1], pep.shape[0]))).to(device) #(1, 20, 8)
        
        label = torch.FloatTensor([[label]]).to(device)
        return apt, pep, label

# Getting the output of the model for a pair (aptamer, peptide)
def update(x, y, p, single_alphabet=False):
    if single_alphabet:
        p.requires_grad=True
        p = p.to(device)
        out = model(p)
        return out
    else:
        x.requires_grad=True
        y.requires_grad=False
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        return out

In [8]:
# Un one-hot the aptamer
def stringify(oh):
    # oh.shape = (1, 4, 40)
    aptamer_string = ""
    na_list = ['A', 'C', 'G', 'T']
    for i in range(40):
        column = oh[0, :, i]
        ind = np.argmax(column)
        aptamer_string += na_list[ind]
    return aptamer_string

In [9]:
# Round the resulting aptamer
def round_aptamer(apt):
    rounded_aptamer = np.zeros((1, 4, 40))
    for i in range(40):
        ind = np.argmax(curr_aptamer[i, :, :])
        rounded_aptamer[0, ind, i] = 1
    return rounded_aptamer

## Use SGD to find an aptamer

In [12]:
curr_aptamer = aptamer_0
for k in range(100):
    a, p, l = convert(curr_aptamer, peptide, 1, single_alphabet=False)
    train_score = update(a, p, None, single_alphabet=False)
    train_score.backward()
    new_aptamer = np.zeros((40, 4, 1))
    alpha_k = 1/(2*(k + 1))
    for i in range(40):
        ind = np.argmax(a.grad[:, :, i].cpu().numpy())
        for j in range(4):
            # new_aptamer.shape = 40, 4, 1
            # curr_aptamer.shape = 1, 4, 40
            new_aptamer[i, j, 0] = (1 - alpha_k)*a[0, j, i] + alpha_k*(j == ind)
    
    curr_aptamer = new_aptamer
    # Round the aptamer and find the resulting string
    rounded_aptamer = round_aptamer(curr_aptamer)
    aptamer_string = stringify(rounded_aptamer)
    print(str(aptamer_string))    

GTTTGGATCGCCCCCTGGGGGAGTGAAACCCGGATTGGAC
GTATGCTACCGATCCAGCGTGAATGACAGCCATACTGGAC
GTGTGATAACACTCTGGTGTGACTGGCAACTTTACTGGAC
CTGTGCTAAGACTCCAGTGAAACTGGCAACTCCACTGGAC
CTTTGCACTGGCCCCATTAACACCGGCAAATCGACCAGAA
CACTGTATTGTACCCAGTTACAGCAGCCACTCGGCAGGAA
CGCCGTATATTTGCCAGTAACAGAAGTTACTATGCCGGAA
CGCCGTACATTTCCCAACACCAGAAGTTTCTACGCCGGAA
TGCTGTACTTCTCCCAGCAACAGAAGTGCCTGCTTGTGAA
CCTTGAACTTCTCCCAGGAGCAGTAGCTCAGGCGCAGGAA
TCTTGAACTCGACCTAGGAGCAGCAGCATATGCGCAGACA
GTCTCATACCCACCTAGGAGCAGGAGTTTATGCTGGTATA
GGCTCAAACAGACCGAGATGAAGGAGCTTAGGCAGGTACA
GGCTCATACCGGCCCAGATGCAGCAGCTGTGGGTGGAAAA
GGCTAATAAAGACCCTGATGCTGCAGCACCGGGAGGTAAG
GGCCCCAGAAGATCCTGATGCTGCAGCAGCGGAGGGAGTA
CGCCCGCGTAGATCTCGATCGAGCAGCAGCTGAGGCAGCA
TACTCCCGTAGACCCAGATGGTGTTACAGCAGATGAACAA
TGGTCGAGTAGACCCAGCTGGAGTTACACCCGTGGCAGCA
CTGAGTAGTCATCCTAGCTGTTGCTGCACCTGCGGCAACG
CGCTGCAGGAATACTTGAAGGTGCTGCGCCTGCTGCAGCG
CACAACAGCAGTCCATCGTCGTGCTGCGTATGCTACACCA
TACTACAGCTGTACTTCCTCCTGCTGCATTACTTGGACCA
TGCAATAGCAGTTCCTGCTGCGTTTCTTTTACTTGGACCA
TGCAACACCAGGACCT