# Protein Structure Machine Translation

Based on the success of Neural Machine Translation, I'm going to try a seq2seq type of model. The encoder will be the same as the current RGN model; instead of predicting torsion angles directly after the encoder, I will implement a decoder network with attention to assist in the folding process. First and foremost I hope that the shorter distance between the trainable parameters and the evaluation of loss will improve the flow of gradients (this argument doesn't make much sense though). More importantly, the network allows information from the folding process to be directly incorporated back into the trainable parameters. There is also the possibility that by using teacher forcing (feeding in the real output to the next decoder timestep) the network will be able to learn about the folding process in addition to the features of the input sequence.

The main inspiration for this approach comes from https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import pandas as pd
import ipywidgets as ip
from matplotlib import pyplot as plt
import os
#import utils
from fastai import *
import matplotlib
import seaborn as sns
from tqdm import tqdm
import collections
from collections import Counter as cs
#import nglview as nv
import sys
import Bio.PDB as bio
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from keras.utils.np_utils import to_categorical
import torch.optim
import pdb
%matplotlib inline

Using TensorFlow backend.


In [4]:
import utils
from data import ProteinDataset, sequence_collate
from model import geometric_unit, dRMSD
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [5]:
# Make every matplotlib figure in this notebook wide by default
# (figsize is given as [width, height]).
plt.rcParams['figure.figsize'] = [16,6]

In [6]:
# Both locations are built relative to os.curdir, so the notebook has to be
# run from the directory that contains the `data/` folder.
data_path = os.curdir + '/data/'
# NOTE(review): pdb_path is not referenced by any cell visible in this
# notebook export — presumably kept for loading raw PDB structures; confirm.
pdb_path = os.curdir + '/data/pdb/structures/pdb/'

## Dataloader and Model

In [7]:
# Build the dataset from `data_path`. The meaning of the 'short' argument and
# of encoding='tokens' is defined by ProteinDataset in data.py (not shown here).
dataset = ProteinDataset(data_path, 'short', encoding='tokens')
# Shuffled mini-batches of 14 samples. sequence_collate assembles each batch
# (presumably padding variable-length sequences — confirm in data.py); the
# cells below read the keys 'sequence', 'length' and 'coords' from each batch.
trn_data = DataLoader(dataset, batch_size=14, shuffle=True, collate_fn=sequence_collate)

In [8]:
# Load the pretrained embedding matrix stored at data/c3_embs.bc; it is later
# handed to create_emb_layer to initialise the nn.Embedding weights.
# NOTE(review): `bcolz` is not imported explicitly in any visible cell —
# presumably it is pulled in by `from fastai import *`; confirm, or add an
# explicit `import bcolz`.
aa2vec = bcolz.open(data_path + 'c3_embs.bc')

def create_emb_layer(aa2vec, requires_grad=True):
    """Build an nn.Embedding initialised from a pretrained embedding matrix.

    Args:
        aa2vec: 2-D array-like of shape (vocab_sz, embed_dim); one row per
            token. Integer-valued matrices are accepted and converted to
            float.
        requires_grad: when falsy the embedding weights are frozen, otherwise
            they stay trainable.

    Returns:
        Tuple ``(emb_layer, vocab_sz, embed_dim)``.
    """
    # The values are *copied* into the layer by load_state_dict, so the source
    # tensor itself never needs to track gradients. Asking for requires_grad
    # here would also raise for integer-valued input.
    weights = torch.tensor(aa2vec, dtype=torch.float)

    vocab_sz, embed_dim = weights.size()
    emb_layer = nn.Embedding(vocab_sz, embed_dim)
    emb_layer.load_state_dict({'weight': weights})
    if not requires_grad:
        emb_layer.weight.requires_grad = False

    return emb_layer, vocab_sz, embed_dim

In [9]:
class ProteinEncoder(nn.Module):
    """Bidirectional GRU encoder over embedded amino-acid token sequences.

    The outputs of the forward and backward directions are summed, so every
    timestep is described by a single vector of size `hidden_size`.
    """

    def __init__(self, aa2vec, hidden_size, num_layers):
        """Create the embedding layer and the bidirectional GRU.

        Args:
            aa2vec: 2-D pretrained embedding matrix; its second dimension is
                used as the GRU input size.
            hidden_size: size of the GRU hidden state (per direction).
            num_layers: number of stacked GRU layers.
        """
        super(ProteinEncoder, self).__init__()

        input_size = aa2vec.shape[1]
        self.aa2vec = aa2vec
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embeds, vocab_sz, embed_dim = create_emb_layer(aa2vec, requires_grad=True)
        self.gru = nn.GRU(input_size, hidden_size, num_layers, bidirectional=True)

    def forward(self, x):
        """Encode a padded batch of token sequences.

        Args:
            x: pair ``[inp_seq, inp_lens]`` where `inp_seq` holds token
                indices laid out as (seq_len, batch) and `inp_lens` holds the
                true length of every sequence in the batch.

        Returns:
            Tuple ``(outputs, hid_out)``. `outputs` has shape
            (longest_len, batch, hidden_size) and is zero at padded
            positions; `hid_out` is the final GRU hidden state with shape
            (num_layers * 2, batch, hidden_size). Both are returned in the
            same batch order as the input.
        """
        inp_seq = x[0]
        inp_lens = x[1]

        # pack_padded_sequence expects the batch sorted by decreasing length.
        order = sorted(range(len(inp_lens)), key=lambda idx: inp_lens[idx], reverse=True)
        # Inverse permutation: restore[original_index] == position after
        # sorting. Indexing with it undoes the sort on the way out.
        restore = [0] * len(order)
        for rank, idx in enumerate(order):
            restore[idx] = rank

        # forward propagate the GRU over the packed (sorted) batch
        emb = self.embeds(inp_seq)
        packed = pack_padded_sequence(emb[:, order], inp_lens[order], batch_first=False)

        gru_out, hid_out = self.gru(packed, None)
        unpacked, _ = pad_packed_sequence(gru_out, batch_first=False, padding_value=0.0)

        # Put the batch dimension back into the caller's order so that the
        # encoder outputs and hidden state line up with the targets.
        unpacked = unpacked[:, restore]
        hid_out = hid_out[:, restore]

        # sum the context vectors of the two directions
        outputs = unpacked[:, :, :self.hidden_size] + unpacked[:, :, self.hidden_size:]

        return outputs, hid_out

In [10]:
# Smoke test: push the first three batches through a freshly initialised
# encoder and print the resulting tensor shapes.
for i_batch, sampled_batch in enumerate(trn_data):
    # Token indices are integers; only floating-point tensors can require
    # gradients, so requires_grad must stay False here.
    inp_seq = torch.tensor(sampled_batch['sequence'], dtype=torch.long, requires_grad=False)
    inp_lens = torch.tensor(sampled_batch['length'], dtype=torch.long, requires_grad=False)
    enc = ProteinEncoder(aa2vec,20,1)
    enc_out, enc_hi = enc([inp_seq, inp_lens])
    print(i_batch, inp_seq.size(), sampled_batch['coords'].size(), enc_out.size(), enc_hi.size())

    if i_batch == 2:
        break

(0, torch.Size([29, 14]), torch.Size([87, 14, 3]), torch.Size([29, 14, 20]), torch.Size([2, 14, 20]))
(1, torch.Size([29, 14]), torch.Size([87, 14, 3]), torch.Size([29, 14, 20]), torch.Size([2, 14, 20]))
(2, torch.Size([29, 14]), torch.Size([87, 14, 3]), torch.Size([29, 14, 20]), torch.Size([2, 14, 20]))


In [11]:
class Attn(nn.Module):
    """'General' attention: energy = hidden . (W * encoder_output).

    Only the 'general' scoring method is implemented; the 'dot' and 'concat'
    variants of the reference seq2seq tutorial are not supported.
    """

    def __init__(self, hid_sz):
        """
        Args:
            hid_sz: size of both the decoder hidden state and the encoder
                output vectors.
        """
        super(Attn, self).__init__()
        self.hid_sz = hid_sz
        self.attn = nn.Linear(hid_sz, hid_sz)

    def forward(self, prev_hi, enc_out):
        """Compute attention weights over the encoder timesteps.

        Args:
            prev_hi: decoder hidden state, indexed as (layer, batch, hid_sz);
                only the first layer (``prev_hi[0]``) is used for scoring.
            enc_out: encoder outputs of shape (max_len, batch, hid_sz).

        Returns:
            Tensor of shape (batch, max_len) whose rows sum to 1.

        NOTE(review): padded encoder positions are not masked, so they
        receive non-zero attention weight.
        """
        # nn.Linear acts on the last dimension, so all timesteps and batch
        # entries are projected in one call: (max_len, batch, hid_sz).
        projected = self.attn(enc_out)
        # Dot product with the hidden state of each batch entry, computed for
        # every timestep at once: (max_len, batch).
        attn_energies = (projected * prev_hi[0].unsqueeze(0)).sum(dim=-1)

        # Normalise over the timestep axis, one distribution per batch entry.
        return F.softmax(attn_energies.transpose(0, 1), dim=-1)

    def score(self, hidden, enc_out):
        """Return the scalar energy between one hidden vector and one
        encoder output vector (both 1-D of size hid_sz)."""
        energy = self.attn(enc_out)
        energy = hidden.dot(energy)
        return energy

In [12]:
# Smoke test: run the attention module on the first three batches and print
# the shape of the attention weights.
hid_sz = 20
for i_batch, sampled_batch in enumerate(trn_data):
    # Token indices are integers; only floating-point tensors can require
    # gradients, so requires_grad must stay False here.
    inp_seq = torch.tensor(sampled_batch['sequence'], dtype=torch.long, requires_grad=False)
    inp_lens = torch.tensor(sampled_batch['length'], dtype=torch.long, requires_grad=False)
    enc = ProteinEncoder(aa2vec,hid_sz,1)
    enc_out, enc_hi = enc([inp_seq, inp_lens])
    # Use the first entry of the encoder's final hidden state as the initial
    # decoder hidden state.
    dec_hi = enc_hi[:1]

    attn = Attn(hid_sz)
    attn_out = attn(dec_hi, enc_out)
    print(attn_out.size())
    #softmax in correct dimension because each sample in batch adds to 1

    if i_batch == 2:
        break

torch.Size([14, 29])
torch.Size([14, 29])
torch.Size([14, 29])


In [13]:
class ProteinDecoder(nn.Module):
    """Attention decoder that emits three torsion angles per step and extends
    the predicted backbone coordinates with them."""

    def __init__(self, aa2vec, hid_sz, num_layers):
        """
        Args:
            aa2vec: 2-D pretrained embedding matrix passed to
                create_emb_layer.
            hid_sz: size of the GRU hidden state; must match the size of the
                encoder outputs the attention is computed over.
            num_layers: number of stacked GRU layers.
        """
        super(ProteinDecoder, self).__init__()

        self.aa2vec = aa2vec
        self.hid_sz = hid_sz
        self.num_layers = num_layers

        self.embeds, vocab_sz, embed_dim = create_emb_layer(aa2vec, requires_grad=True)
        # forward() relies on self.attn, so it has to be created here.
        self.attn = Attn(hid_sz)
        # The GRU input is the token embedding concatenated with the
        # attention context vector.
        self.gru = nn.GRU(hid_sz+embed_dim, hid_sz, num_layers)

        # Size the output heads from the constructor argument rather than
        # from a notebook-level global.
        self.linear1 = nn.Linear(hid_sz, 3)
        self.linear2 = nn.Linear(hid_sz, 3)
        self.hardtanh = nn.Hardtanh()

        #as per Mohammed, we simply use the identity matrix to define the first 3 residues
        self.A = torch.tensor([0., 0., 1.])
        self.B = torch.tensor([0., 1., 0.])
        self.C = torch.tensor([1., 0., 0.])

        #bond length vectors C-N, N-CA, CA-C
        self.avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
        #bond angle vector, in radians, CA-C-N, C-N-CA, N-CA-C
        self.avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

    def forward(self, prev_aa, last_hidden, enc_out, prev_pred_coords):
        """Run one decoder timestep.

        Args:
            prev_aa: token indices fed to this step, shaped (1, batch).
            last_hidden: previous GRU hidden state,
                (num_layers, batch, hid_sz).
            enc_out: encoder outputs, (max_len, batch, hid_sz).
            prev_pred_coords: coordinates predicted so far; passed straight
                to geometric_unit.

        Returns:
            Tuple ``(new_pred_coords, hid_out, attn_weights)`` where
            `attn_weights` has shape (batch, 1, max_len).
        """
        emb = self.embeds(prev_aa)

        # (batch, max_len) -> (batch, 1, max_len) so it can be batch-multiplied
        # with the encoder outputs laid out as (batch, max_len, hid_sz).
        attn_weights = self.attn(last_hidden, enc_out).unsqueeze(1)
        context = attn_weights.bmm(enc_out.transpose(0, 1))
        # back to (1, batch, hid_sz) to match the layout of `emb`
        context = context.transpose(0, 1)

        rnn_input = torch.cat((emb, context), 2)
        gru_out, hid_out = self.gru(rnn_input, last_hidden)

        # Predict clamped sin/cos-like values and recover the angles with
        # atan2; squeeze removes the length-1 time dimension.
        sin_out = self.hardtanh(self.linear1(gru_out))
        cos_out = self.hardtanh(self.linear2(gru_out))
        out = torch.atan2(sin_out, cos_out).squeeze(0)

        # geometric_unit is defined in model.py (not shown here); presumably
        # it places the next atoms from the angles — confirm there.
        new_pred_coords = geometric_unit(prev_pred_coords, out,
                                         self.avg_bond_angles, self.avg_bond_lens)

        return new_pred_coords, hid_out, attn_weights

In [58]:
# Hyper-parameters shared by the encoder and the decoder.
hidden_size = 100
num_layers = 3

# Both networks are initialised from the same pretrained embedding matrix.
encoder = ProteinEncoder(aa2vec, hidden_size, num_layers)
decoder = ProteinDecoder(aa2vec, hidden_size, num_layers)

# One Adam optimizer per network, both with learning rate 1e-3.
enc_opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
dec_opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)

In [None]:
# Seed coordinates used to start every predicted structure.
A = torch.tensor([0., 0., 1.])
B = torch.tensor([0., 1., 0.])
C = torch.tensor([1., 0., 0.])

drmsd = dRMSD()
running_loss = 0.0
n_losses = 0  # number of batch losses accumulated since the last report

for epoch in range(50):
    for i, data in enumerate(trn_data):
        try:
            inp_seq = torch.tensor(data['sequence'], dtype=torch.long, requires_grad=False)
            inp_lens = torch.tensor(data['length'], dtype=torch.long, requires_grad=False)
            gt_coords = data['coords']

            enc_out, enc_hi = encoder([inp_seq, inp_lens])

            # Prepare decoder input and outputs
            prev_aa = inp_seq[0].unsqueeze(0)
            # NOTE(review): the encoder GRU is bidirectional, so enc_hi holds
            # 2 * num_layers states; only the first decoder.num_layers of
            # them are kept here — confirm this is the intended subset.
            last_hidden = enc_hi[:decoder.num_layers]

            # Replicate the three seed points for every sample in the batch:
            # shape (3, batch, 3).
            broadcast = torch.ones((inp_seq.size(1), 3))
            new_pred_coords = torch.stack([A*broadcast, B*broadcast, C*broadcast])

            # Run through decoder one time step at a time (teacher forcing:
            # the true residue is fed in as the next input).
            for t in range(1, int(inp_lens.max())):
                new_pred_coords, last_hidden, decoder_attn = decoder(
                    prev_aa, last_hidden, enc_out, new_pred_coords
                )
                prev_aa = inp_seq[t].unsqueeze(0) # Next input is current target

            enc_opt.zero_grad()
            dec_opt.zero_grad()

            loss = drmsd(new_pred_coords, gt_coords)
            loss.backward()

            nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1)
            nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1)

            enc_opt.step()
            dec_opt.step()

            running_loss += loss.item()
            n_losses += 1

        except Exception as err:
            # Keep training on a best-effort basis, but report the failure
            # instead of hiding it. KeyboardInterrupt is not an Exception
            # subclass, so it still stops the run.
            print('Skipping batch {} of epoch {}: {!r}'.format(i, epoch, err))
            continue

        # Report every second batch, averaging over the batches that actually
        # contributed a loss since the previous report.
        if (i != 0) and (i % 2 == 0) and n_losses:
            print('Epoch {}, Loss {}'.format(epoch, running_loss/n_losses))
            running_loss = 0.0
            n_losses = 0

Epoch 0, Loss 11.0974680583
Epoch 1, Loss 10.6576738358
Epoch 2, Loss 10.8651529948
Epoch 3, Loss 10.022840182
