## Convert `.pth` PyTorch models toward the `.h5` TensorFlow format (this notebook exports the PyTorch model to ONNX as the intermediate format)

In [1]:
import os, sys
import numpy as np
import json
import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import config

In [2]:
# Fix torch's RNG seed. NOTE(review): the (aptamer, peptide) sampling in this
# notebook draws from np.random, which this call does not seed.
torch.manual_seed(12345)

na_list = ['A', 'C', 'G', 'T']  # nucleic acids
# Amino acids. Order matters: it is zipped with pvals below, which lists the
# 3-codon, then 2-codon, then 1-codon NNK residues.
aa_list = ['R', 'L', 'S', 'A', 'G', 'P', 'T', 'V', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'K', 'M', 'F', 'W', 'Y']
# Per-residue hydrophobicity values (not referenced elsewhere in this notebook).
hydrophobicity = {'G': 0, 'A': 41, 'L': 97, 'M': 74, 'F': 100, 'W': 97, 'K': -23, 'Q': -10, 'E': -31, 'S': -5,
                  'P': -46, 'V': 76, 'I': 99, 'C': 49, 'Y': 63, 'H': 8, 'R': -14, 'N': -28, 'D': -55, 'T': 13}

# Frequencies of the 21 NNK codon classes (20 amino acids + the stop codon):
# 3/32 for 3 residues, 2/32 for 5 residues, 1/32 for the remaining 12 residues
# and for the stop codon.
NNK_freq = [0.09375] * 3 + [0.0625] * 5 + [0.03125] * 13
# Total frequency of the 20 amino-acid classes (everything except the stop codon).
sum_20 = 0.0625 * 5 + 0.09375 * 3 + 0.03125 * 12

# Stop-codon-free probabilities aligned with aa_list. The last entry is set to
# 1 - sum(others) so the vector sums to 1 despite floating-point rounding
# (these are passed as `p` to np.random.choice).
_pvals_head = [0.09375 / sum_20] * 3 + [0.0625 / sum_20] * 5 + [0.03125 / sum_20] * 11
pvals = _pvals_head + [1 - sum(_pvals_head)]
aa_dict = dict(zip(aa_list, pvals))

In [3]:
class LedidiNet(nn.Module):
    """Convolutional binary classifier over a stacked 2-row sequence encoding.

    Input: float tensor of shape ``(batch, 2, L)``, or unbatched ``(2, L)``.
    With ``L == 40`` (the length used in this notebook), the conv stack ends
    with 50 channels x 2 positions = 100 features, which matches ``fc1``.

    Output: sigmoid scores of shape ``(batch, 1)``; ``(1, 1)`` for unbatched
    input.
    """

    def __init__(self):
        super().__init__()
        self.name = "LedidiNet"

        # The conv modules are registered both as attributes (cnn_k) and inside
        # `cnns`, so state_dict keys exist under both paths. Keep this layout so
        # existing .pth checkpoints still load.
        self.cnn_1 = nn.Conv1d(2, 100, 3)
        self.cnn_2 = nn.Conv1d(100, 200, 3, padding=2)
        self.cnn_3 = nn.Conv1d(200, 400, 3, padding=2)
        self.cnn_4 = nn.Conv1d(400, 600, 3, padding=2)
        self.cnn_5 = nn.Conv1d(600, 300, 3, padding=2)
        self.cnn_6 = nn.Conv1d(300, 100, 3, padding=2)
        self.cnn_7 = nn.Conv1d(100, 50, 3, padding=2)

        # Stateless layers, shared between every stage of the stack.
        self.softplus = nn.Softplus()
        self.maxpool = nn.MaxPool1d(2)

        self.cnns = nn.Sequential(self.cnn_1, self.maxpool, self.softplus,
                                  self.cnn_2, self.maxpool, self.softplus,
                                  self.cnn_3, self.maxpool, self.softplus,
                                  self.cnn_4, self.maxpool, self.softplus,
                                  self.cnn_5, self.maxpool, self.softplus,
                                  self.cnn_6, self.maxpool, self.softplus,
                                  self.cnn_7, self.maxpool, self.softplus)

        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, pair):
        """Score ``pair``; return sigmoid outputs of shape ``(batch, 1)``."""
        if pair.dim() == 2:
            # Unbatched (C, L) input: treat it as a batch of one.
            pair = pair.unsqueeze(0)
        x = self.cnns(pair)
        # Flatten per sample, (batch, C, L') -> (batch, C * L'), so each sample
        # keeps its own row for batch sizes greater than one.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so the two
        # layers compose to a single affine map. It is left unchanged so trained
        # checkpoints keep their meaning.
        x = self.fc2(x)
        return torch.sigmoid(x)

In [4]:
# Run on the first GPU when available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
model = LedidiNet()
model_name = model.name
model_id = "06242020"
model.to(device)

checkpoint = '../model_checkpoints/binary/%s/%s.pth' % (model_name, model_id)
# map_location remaps stored tensors onto `device`, so a checkpoint saved from
# a GPU run can still be loaded on a CPU-only machine.
checkpointed_model = torch.load(checkpoint, map_location=device)
model.load_state_dict(checkpointed_model['model_state_dict'])
# Inference/export only from here on.
model.eval()
# One past the epoch stored in the checkpoint (presumably the epoch a resumed
# training run would start from).
init_epoch = checkpointed_model['epoch'] + 1
print("Reloading model: ", model.name, " at epoch: ", init_epoch)

Reloading model:  LedidiNet  at epoch:  1


## Create a dummy input used to trace the model during ONNX export

In [5]:
class GeneratedDataset(Dataset):
    """Dataset of ``n`` randomly sampled ``(aptamer, peptide)`` string pairs.

    Each pair is ``(x, y)``:
      * ``x`` is 40 nucleotides drawn uniformly from ``na_list`` (assumes
        aptamers follow a uniform distribution);
      * ``y`` is ``"M"`` followed by 7 residues drawn from ``aa_list`` with
        probabilities ``pvals`` (assumes peptides follow the NNK distribution).

    All pairs are generated once in ``__init__`` using ``np.random`` and then
    shuffled in place.
    """

    def __init__(self, n):
        # Per pair, the aptamer is drawn before the peptide, and the shuffle
        # happens last. This fixes the order in which np.random is consumed.
        pairs = [(self._sample_aptamer(), self._sample_peptide())
                 for _ in tqdm.tqdm(range(n))]
        np.random.shuffle(pairs)
        self.gen_ds = pairs

    @staticmethod
    def _sample_aptamer():
        """Sample x from P_X: 40 nucleotides, uniform over ``na_list``."""
        idx = np.random.randint(0, 4, 40)
        return "".join(na_list[i] for i in idx)

    @staticmethod
    def _sample_peptide():
        """Sample y from P_Y: ``"M"`` + 7 residues drawn with probabilities ``pvals``."""
        idx = np.random.choice(20, 7, p=pvals)
        return "M" + "".join(aa_list[i] for i in idx)

    def __len__(self):
        return len(self.gen_ds)

    def __getitem__(self, idx):
        return self.gen_ds[idx]

In [6]:
## Encodes an (aptamer, peptide) pair as a stacked 2-row array of alphabet indices
def stacked_translate(sequence, seq_type='peptide', single_alphabet=True):
    """Encode an ``(aptamer, peptide)`` pair as two rows of alphabet indices.

    Args:
        sequence: indexable pair; ``sequence[0]`` is the aptamer string
            (characters from ``na_list``) and ``sequence[1]`` is the peptide
            string (characters from ``aa_list``).
        seq_type: unused; kept for call compatibility.
        single_alphabet: only ``True`` is implemented.

    Returns:
        float ``np.ndarray`` of shape ``(2, len(aptamer))``. Row 0 holds each
        aptamer character's index in ``na_list``. Row 1 holds each peptide
        character's index in ``aa_list``, zero-filled past the end of the
        peptide.

    Raises:
        NotImplementedError: if ``single_alphabet`` is False.
        ValueError: if a character is not in its alphabet.
        IndexError: if the peptide is longer than the aptamer.
    """
    if not single_alphabet:
        # Only the stacked single-alphabet encoding exists; fail loudly
        # instead of silently returning None.
        raise NotImplementedError(
            "stacked_translate only supports single_alphabet=True")

    apt = sequence[0]
    pep = sequence[1]
    encoding = np.zeros((2, len(apt)))

    # Row 0: aptamer.
    for i, char in enumerate(apt):
        encoding[0][i] = na_list.index(char)

    # Row 1: peptide. NOTE(review): the zero padding equals aa_list.index('R'),
    # so padded positions look like arginine. Changing it would likely
    # invalidate existing checkpoints.
    for i, char in enumerate(pep):
        encoding[1][i] = aa_list.index(char)
    return encoding

# Convert an (aptamer, peptide, label) triple into float32 tensors on `device`.
def convert(apt, pep, label, single_alphabet=False):
    """Convert an aptamer/peptide pair and its label into model-ready tensors.

    With ``single_alphabet=True``, returns ``(pair, label)``:
      * ``pair`` is the ``stacked_translate`` index encoding with a leading
        batch axis, shape ``(1, 2, len(apt))``;
      * ``label`` has shape ``(1, 1)``.

    Otherwise, returns ``(apt, pep, label)`` with shapes ``(1, 1, len(apt))``,
    ``(1, 1, len(pep))`` and ``(1, 1)``. NOTE(review): that branch calls a
    ``translate`` helper that is not defined in this notebook. It raises
    NameError unless ``translate`` is provided elsewhere.
    """
    label = torch.tensor([[label]], dtype=torch.float32, device=device)
    if single_alphabet:
        pair = stacked_translate([apt, pep], single_alphabet=True)  # (2, 40)
        pair = np.reshape(pair, (-1, pair.shape[0], pair.shape[1]))  # (1, 2, 40)
        pair = torch.tensor(pair, dtype=torch.float32, device=device)
        return pair, label

    apt = translate(apt, seq_type='aptamer')  # (40, )
    pep = translate(pep, seq_type='peptide')  # (8, )
    apt = torch.tensor(np.reshape(apt, (-1, 1, apt.shape[0])),
                       dtype=torch.float32, device=device)  # (1, 1, 40)
    pep = torch.tensor(np.reshape(pep, (-1, 1, pep.shape[0])),
                       dtype=torch.float32, device=device)  # (1, 1, 8)
    return apt, pep, label

In [7]:
# Draw two random (aptamer, peptide) pairs and encode the first one as an
# example model input (l_val is not used below).
S_new = GeneratedDataset(2)
first_apt, first_pep = S_new[0]
p_val, l_val = convert(first_apt, first_pep, 0, single_alphabet=True)

100%|██████████| 2/2 [00:00<00:00, 1852.61it/s]


In [None]:
# Trace with the encoded example. p_val already lives on `device`; copy it with
# detach().clone() because the legacy torch.Tensor(...) constructor rejects
# CUDA tensors.
dummy_input = p_val.detach().clone()
dummy_input.requires_grad = True

input_names = ["input"]
output_names = ["output"]
ONNX_MODEL_PATH = "ledidi_net.onnx"
model.eval()
torch.onnx.export(model, dummy_input, ONNX_MODEL_PATH, verbose=True,
                  input_names=input_names, output_names=output_names)

