# RGN Modeling

The original paper is at: https://www.biorxiv.org/content/biorxiv/early/2018/02/14/265231.full.pdf

The paper leaves out many details, but the architecture is fairly simple. Feed sequences into a bi-LSTM and predict a set of three torsion angles per residue. Pass the three predictions, along with the current atoms for each residue, into a "geometric unit" that adds each residue sequentially and deforms the "nascent structure" appropriately. The last step is to calculate the loss, the distance-based root mean square deviation (dRMSD), which accounts for both global and local structural details and, importantly, does not require a specific orientation of the predicted structure, since it only considers the distances between pairs of atoms.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import ipywidgets as ip
from matplotlib import pyplot as plt
import os
from tqdm import tqdm
from collections import Counter as cs
import sys
import Bio.PDB as bio
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
import pdb
%matplotlib inline

In [3]:
import utils
from data import ProteinNet, sequence_collate
from model import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

Using TensorFlow backend.


In [8]:
# Root of the local data directory, and the PDB structure mirror inside it.
data_path = '{}/data/'.format(os.curdir)
pdb_path = data_path + 'pdb/structures/pdb/'

## Pytorch Dataloader

In [5]:
# The ProteinNet files must already exist under data_path: download the data
# from GitHub and run the proteinnet notebook before this cell.
trn_dataset, val_dataset = [ProteinNet(data_path + fname)
                            for fname in ('train30.bc', 'validation.bc')]

In [6]:
# Shared loader settings; only the training split is shuffled so validation
# batches come out in the same order every epoch.
loader_kwargs = dict(batch_size=32, collate_fn=sequence_collate)
trn_data = DataLoader(trn_dataset, shuffle=True, **loader_kwargs)
val_data = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [7]:
# Sanity-check the first four batches: every residue contributes exactly three
# atoms, so coords should have three times as many rows as sequence.
for i_batch, sample_batched in enumerate(trn_data):
    vec = sample_batched['sequence']
    coords_size = sample_batched['coords'].size()
    print(i_batch, vec.size(), coords_size)
    if i_batch >= 3:
        break

(0, torch.Size([622, 32, 41]), torch.Size([1866, 32, 3]))
(1, torch.Size([676, 32, 41]), torch.Size([2028, 32, 3]))
(2, torch.Size([695, 32, 41]), torch.Size([2085, 32, 3]))
(3, torch.Size([731, 32, 41]), torch.Size([2193, 32, 3]))


## RGN Model

In [9]:
class RGN(nn.Module):
    """Recurrent geometric network: bi-LSTM -> torsion angles -> coordinates.

    A bidirectional LSTM reads the residue features (plus each residue's
    absolute position). A linear + softmax layer turns each LSTM output into
    weights over a learned alphabet of angle triplets, and the weighted mix
    (averaged on the unit circle) is fed residue by residue to
    ``geometric_unit`` (defined elsewhere) to grow the predicted structure.
    """

    def __init__(self, hidden_size, num_layers, linear_units=20, input_size=42):
        super(RGN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_size = input_size      # residue features + 1 position feature
        self.linear_units = linear_units  # number of entries in the angle alphabet
        self.grads = {}                   # gradients captured by save_grad hooks

        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, bidirectional=True)

        # Learned alphabet of angle triplets, initialised uniformly in (-pi, pi).
        u = torch.distributions.Uniform(-np.pi, np.pi)
        self.alphabet = nn.Parameter(u.rsample(torch.Size([linear_units, 3])))
        # Maps the concatenated forward/backward LSTM states to alphabet logits.
        self.linear = nn.Linear(hidden_size * 2, linear_units)

        # Fixed coordinates of the three seed atoms (unit vectors along z, y, x).
        self.A = torch.tensor([0., 0., 1.])
        self.B = torch.tensor([0., 1., 0.])
        self.C = torch.tensor([1., 0., 0.])

        # Bond lengths C-N, N-CA, CA-C.
        self.avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
        # Bond angles in radians: CA-C-N, C-N-CA, N-CA-C.
        self.avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

    def forward(self, sequences, lengths):
        """Predict coordinates for a padded batch of sequences.

        Args:
            sequences: padded residue features indexed as
                (position, batch, feature). The absolute position is appended
                as one extra feature, so the feature count must be
                ``input_size - 1``.
            lengths: unpadded length of each sequence in the batch (list or
                tensor).

        Returns:
            The coordinate tensor built by repeated ``geometric_unit`` calls,
            starting from a (3, batch, 3) seed. It is expected to be
            (3 * max_len, batch, 3), matching the target coords; this depends
            on ``geometric_unit``, which is defined elsewhere.
        """
        max_len = sequences.size(0)
        batch_sz = sequences.size(1)
        lengths = torch.as_tensor(lengths, dtype=torch.long)

        # pack_padded_sequence needs the batch sorted by decreasing length;
        # `restore` is the inverse permutation that puts the LSTM output back
        # into the caller's batch order (to match the targets).
        sorted_lengths, order = torch.sort(lengths, descending=True)
        _, restore = torch.sort(order)

        # Absolute residue index, appended to every residue as a feature.
        abs_pos = torch.arange(max_len, dtype=torch.float32)
        abs_pos = abs_pos.view(max_len, 1, 1).expand(max_len, batch_sz, 1)

        h0 = torch.zeros(self.num_layers * 2, batch_sz, self.hidden_size)
        c0 = torch.zeros(self.num_layers * 2, batch_sz, self.hidden_size)

        # The LSTM needs float32 input. Gradients w.r.t. the raw features are
        # never used, so the input does not need requires_grad.
        sequences = torch.as_tensor(sequences, dtype=torch.float32)
        pad_seq = torch.cat([sequences, abs_pos], 2)

        packed = pack_padded_sequence(pad_seq[:, order], sorted_lengths, batch_first=False)
        lstm_out, _ = self.lstm(packed, (h0, c0))
        # total_length keeps the padded length at max_len even when the longest
        # sequence is shorter than the batch padding, so the output stays
        # aligned with the 3 * max_len target coordinates.
        unpacked, _ = pad_packed_sequence(lstm_out, batch_first=False,
                                          padding_value=0.0, total_length=max_len)
        unpacked = unpacked[:, restore]

        # Softmax weights over the alphabet. The angles are averaged as
        # weighted sin/cos and recovered with atan2, so wrap-around at +-pi is
        # handled correctly. See https://bit.ly/2lXJC4m for an example.
        softmax_out = F.softmax(self.linear(unpacked), dim=2)
        sine = torch.matmul(softmax_out, torch.sin(self.alphabet))
        cosine = torch.matmul(softmax_out, torch.cos(self.alphabet))
        out = torch.atan2(sine, cosine)

        # Seed structure: the same three fixed atoms for every batch item,
        # shape (3, batch, 3).
        pred_coords = torch.stack([self.A, self.B, self.C]).unsqueeze(1).repeat(1, batch_sz, 1)

        # The first residue's predicted angles (out[0]) are not used, since
        # the seed atoms presumably stand in for it. Each later residue's
        # triplet extends the chain.
        for triplet in out[1:]:
            pred_coords = geometric_unit(pred_coords, triplet,
                                         self.avg_bond_angles,
                                         self.avg_bond_lens)
        return pred_coords

    def save_grad(self, name):
        """Return a tensor hook that stores the incoming gradient in self.grads[name]."""
        def hook(grad): self.grads[name] = grad
        return hook

In [10]:
# Make sure the output size matches the target size on a few batches.
# A single small model is built once (re-initialising it for every batch only
# wastes time), and autograd is disabled since nothing is trained here.
rgn = RGN(20, 1, 20)
with torch.no_grad():
    for i_batch, sampled_batch in enumerate(trn_data):
        inp_seq = sampled_batch['sequence']
        inp_lens = sampled_batch['length']
        out = rgn(inp_seq, inp_lens)
        print(i_batch, inp_seq.size(), sampled_batch['coords'].size(), out.size())
        # Fail loudly rather than relying on someone reading the printout.
        assert out.size() == sampled_batch['coords'].size()

        if i_batch == 3:
            break

(0, torch.Size([676, 32, 41]), torch.Size([2028, 32, 3]), torch.Size([2028, 32, 3]))
(1, torch.Size([731, 32, 41]), torch.Size([2193, 32, 3]), torch.Size([2193, 32, 3]))
(2, torch.Size([414, 32, 41]), torch.Size([1242, 32, 3]), torch.Size([1242, 32, 3]))
(3, torch.Size([563, 32, 41]), torch.Size([1689, 32, 3]), torch.Size([1689, 32, 3]))


## Training

From the author:

The biggest issue with training these models, especially if you’re using ProteinNet with the full length proteins, is that they’re very seed sensitive, extremely so. Often you won’t find a good seed for hundreds of trials. What I do to get around this problem is set up a milestone scheme where if the validation error hasn’t dropped below a certain threshold by a certain iteration, I kill the model and start over. For example for ProteinNet12, here are my milestones using validation dRMSD (angstroms) (showing iterations not epochs):

 
<ul>
    <li>1k: 13.5</li>
    <li>5k: 12.6</li>
    <li>20k: 12.2</li>
    <li>50k: 11.4</li>
    <li>100k: 10.6</li>
</ul>

In [11]:
# Hyperparameters come directly from the paper's author: 800 hidden units per
# LSTM direction, 2 stacked bidirectional layers, and a 60-entry angle alphabet.
rgn = RGN(hidden_size=800, num_layers=2, linear_units=60)
# To start from a pretrained checkpoint instead of a fresh initialisation:
# rgn.load_state_dict(torch.load(data_path + 'models/rgn.pt'))
optimizer = torch.optim.Adam(params=rgn.parameters(), lr=1e-3)
drmsd = dRMSD()

In [None]:
for epoch in range(30):
    # Training pass. The loss is averaged over every batch in the epoch; the
    # old code divided by the last index, i.e. one batch too few.
    rgn.train()
    running_loss = 0.0
    n_batches = 0
    for data in tqdm(trn_data):
        optimizer.zero_grad()
        outputs = rgn(data['sequence'], data['length'])
        loss = drmsd(outputs, data['coords'], data['mask'])

        loss.backward()
        # Cap the global gradient norm at 50 before the update.
        nn.utils.clip_grad_norm_(rgn.parameters(), max_norm=50)
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    print('Epoch {}, Train Loss {}'.format(epoch, running_loss / max(n_batches, 1)))

    # Validation pass: no parameter updates, so skip building autograd graphs.
    rgn.eval()
    running_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for data in tqdm(val_data):
            outputs = rgn(data['sequence'], data['length'])
            loss = drmsd(outputs, data['coords'], data['mask'])

            running_loss += loss.item()
            n_batches += 1
    print('Epoch {}, Val Loss {}'.format(epoch, running_loss / max(n_batches, 1)))

print('Finished Training')

In [154]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure models/ exists first.
model_dir = data_path + 'models/'
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)
torch.save(rgn.state_dict(), model_dir + 'rgn.pt')