In [1]:
#Trains the translator net
import os

import random
import string

import numpy as np
from Bio.Seq import Seq
from Bio.Data import CodonTable

import torch
from torch import nn
import torch.nn.functional as F
import torch_optimizer as optim

import matplotlib.pyplot as plt

os.chdir("/home/ubuntu/projects/olgdesign/")

from st import *
from translator import *

torch.set_num_threads(2)
gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
device = torch.device(f'cuda:0') if gpu_id >= 0 else torch.device('cpu')

In [2]:
nucleotides = ['A', 'T', 'G', 'C']
amino_acids = list("ARNDCQEGHILKMFPSTWYV*")

nucs = np.array([0]*100)
nucs[[65,84,71,67]] = [0,1,2,3] #ATGC

aas = np.array([0]*100)
aa_order = np.argsort(list("ARNDCQEGHILKMFPSTWYV"))
aas[[42,65,67,68,69,70,71,72,73,75,76,77,78,80,81,82,83,84,86,87,89]] = [20] + list(aa_order)

len_pro = 500
train_size = 10000
test_size = 1000

In [3]:
'''
#Generate random DNA
def gen_random_dna(len_seqs, number_to_gen):
    #returns list of uniform randomly generated nucleotide strings of a given length
    return [''.join(random.choice(nucleotides) for i in range(len_seqs)) for j in range(0, number_to_gen)]
all_dna  = gen_random_dna(len_pro * 3, train_size + test_size)
random.shuffle(all_dna)

#Generate protein sequences by translating DNA in all frames; forward strand
all_prot_f1 = [ str(Seq(s).translate()) for s in all_dna ]
all_prot_f2 = [ str(Seq(s[1:]).translate()) for s in all_dna ]
all_prot_f3 = [ str(Seq(s[2:]).translate()) for s in all_dna ]

#In reverse strand: reverse_complement, translate, then reverse
all_prot_r1 = [ str(Seq(s).reverse_complement().translate())[::-1] for s in all_dna ]
all_prot_r2 = [ str(Seq(s).reverse_complement()[2:].translate())[::-1] for s in all_dna ]
all_prot_r3 = [ str(Seq(s).reverse_complement()[1:].translate())[::-1] for s in all_dna ]
all_prot = [all_prot_f1, all_prot_f2, all_prot_f3, all_prot_r1, all_prot_r2, all_prot_r3]

#DNA seq to tensor
all_dna = np.array(all_dna)
all_dna = all_dna.view('S4').reshape((all_dna.size, -1)).view(np.uint32)
all_dna = torch.nn.functional.one_hot(torch.tensor(nucs[all_dna])).permute((0,2,1))
all_dna = all_dna.to(device) * 1.0

#Protein seq to tensor, with stop
for i in range(len(all_prot)):
    prot = np.array(all_prot[i])
    prot = prot.view('S4').reshape((prot.size, -1)).view(np.uint32)
    prot = F.one_hot(torch.tensor(aas[prot])).permute((0,2,1))
    prot = prot.to(device) * 1.0
    all_prot[i] = prot

#Split train-test sets
train_dna = all_dna[0:train_size]
test_dna = all_dna[train_size:len(all_dna)]
train_withstop = [ all_prot[i][0:train_size] for i in range(len(all_prot)) ]
test_withstop = [ all_prot[i][train_size:len(all_prot[i])] for i in range(len(all_prot)) ]
'''

#torch.save([train_dna, test_dna, train_withstop, test_withstop], "./translator_training_data.pth")
train_dna, test_dna, train_withstop, test_withstop = torch.load("./translator_training_data.pth")

In [5]:
#Loss function for translation net
sim_func = nn.CosineSimilarity(dim=1, eps=1e-16)
def loss_func(pred, target):
    loss = torch.tensor([0.0]).to(device)
    for i in range(len(target)):
        loss += (torch.mean(torch.sum(sim_func(pred[i], target[i]), dim=1)) * -1.0)
    return loss

In [12]:
#512 channel CNN translator
translator = Translator(512).to(device)

#Training translation net
batch_size = 8
n_epoch = 2
n_train = len(train_withstop[0])

opt_params = [ i for i in translator.parameters() ]
optimizer = torch.optim.SGD(opt_params, lr=1e-6, momentum=0.9)

losses = []

for epoch in range(n_epoch):
    running_loss = 0.
    last_loss = 0.
    training_iter = 0
    last_index = 0
    print('epoch ' + str(epoch))
    while last_index < n_train:
        last_index_end = min(last_index + batch_size, n_train)
        input_onehot = train_dna[last_index:last_index_end]
        withstop = [ train_withstop[i][last_index:last_index_end, :, :] for i in range(len(train_withstop)) ]
        stop = [ torch.hstack((train_withstop[i][last_index:last_index_end, 20:21, :], torch.abs(train_withstop[i][last_index:last_index_end, 20:21, :]-1))) for i in range(len(train_withstop)) ]
        
        optimizer.zero_grad()
        out_withstop, out_stop = translator(input_onehot, temperature=1.0)
        loss_withstop = loss_func(out_withstop, withstop)
        loss_stop = loss_func(out_stop, stop)
        loss = loss_withstop + loss_stop
        loss.backward()
        losses += [loss.detach().clone()]
        optimizer.step()
        
        running_loss += loss.item()
        last_index = last_index_end 
        training_iter += 1        
        
        if training_iter % 200 == 199:
            last_loss = running_loss / 200 # loss per batch
            print('  batch {} loss: {}'.format(training_iter + 1, last_loss))
            running_loss = 0.

epoch 0
  batch 200 loss: -4511.618125
  batch 400 loss: -5716.733125
  batch 600 loss: -5942.32625
  batch 800 loss: -5992.0
  batch 1000 loss: -5992.0
  batch 1200 loss: -5992.0
epoch 1
  batch 200 loss: -5962.04
  batch 400 loss: -5992.0
  batch 600 loss: -5992.0
  batch 800 loss: -5992.0
  batch 1000 loss: -5992.0
  batch 1200 loss: -5992.0


In [13]:
#Check error rate on test set
with torch.no_grad():
    test_dna_sub = test_dna[0:100]
    test_withstop_sub = [t[0:100] for t in test_withstop]
    test_stop_sub = [t[0:100][:,20:21,:] for t in test_withstop]
    pred_withstop, pred_stop = translator(test_dna_sub, temperature=1.0)
    
    error_withstop = torch.stack([ torch.mean(torch.sum(torch.argmax(pred_withstop[i], 1) != torch.argmax(test_withstop_sub[i], 1), dim=1)*1.0) / len_pro for i in range(len(pred_withstop)) ])
    error_stop = torch.stack([ torch.mean(torch.sum(pred_stop[i][:,0:1,:] != test_stop_sub[i], dim=1)*1.0) / len_pro for i in range(len(pred_stop)) ])

print(error_withstop)
print(error_stop)

tensor([0., 0., 0., 0., 0., 0.], device='cuda:0')
tensor([0., 0., 0., 0., 0., 0.], device='cuda:0')


In [7]:
torch.save(translator, "./weights/translator/translator_cnn_512ch.pth")

In [27]:
translator(test_dna_sub)[0][0][0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SelectBackward0>)