# RGN Modeling

The original paper is at: https://www.biorxiv.org/content/biorxiv/early/2018/02/14/265231.full.pdf

There are a lot of details missing, but the architecture is fairly simple. Feed sequences into an bi-LSTM and predict a set of three torsion angles. Pass the three predictions along with the current atoms for each residue into a "geometric unit", add each residue sequentially and deform the "nascent structure" appropriately. The last step is to calculate the loss, distance-based root mean square deviation (dRMSD), which accounts for global and local structural details and importantly does not require a specific orientation of the predicted structure since it only considers distance between pairs of atoms.

In [64]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [65]:
import numpy as np
import ipywidgets as ip
from matplotlib import pyplot as plt
import os
from tqdm import tqdm
from collections import Counter as cs
import sys
import Bio.PDB as bio
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim
import pdb
%matplotlib inline

In [88]:
import utils
from data import ProteinDataset, ProteinNet, sequence_collate
from model import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [67]:
data_path = os.curdir + '/data/'
pdb_path = os.curdir + '/data/pdb/structures/pdb/'
encoding = 'onehot'

In [68]:
#t = bcolz.carray(rootdir=data_path+'train30.bc')[0]

In [69]:
#for creating new data subsets based on sequence lengths
#sp = utils.subset(data_path+'proteins_1.bc', 150, 50, save_path=data_path+'proteins_short.bc')

## Pytorch Dataloader

First construct the dataloader for training the model

Know PDB file errors and issues:
<ul>
    <li>38 of 1992 chain_1 proteins have no coordinates, caused by weird files like pdb5da6.ent</li>
    <li>some chain_1 proteins have hetatms in the main coordinate section because the residues are special transformations of the standard residue (i.e. selenomethionone in pdb1rfe.ent)</li>
    <li>in 634 of 1992 chain_1 proteins the index of the last residue is greater than the number of residues in the sequence, because atoms in many files do not start at one (neither does sequence)</li>
</ul>

In [70]:
no_chains = '1' #1-8, number of chains in the protein
dataset_size = len(bcolz.open(data_path+'proteins_{}.bc'.format(no_chains)))
mask = np.random.random(dataset_size) < 0.8
trn_ixs = np.arange(dataset_size)[mask]
val_ixs = np.arange(dataset_size)[~mask]

In [71]:
trn_dataset = ProteinNet(data_path+'train30.bc')
val_dataset = ProteinNet(data_path+'validation.bc')
#trn_dataset = ProteinDataset(data_path+'proteins_1.bc', encoding=encoding, indices=trn_ixs)
#val_dataset = ProteinDataset(data_path+'proteins_1.bc', encoding=encoding, indices=val_ixs)

In [72]:
trn_data = DataLoader(trn_dataset, batch_size=32, shuffle=True, collate_fn=sequence_collate)
val_data = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=sequence_collate)
#trn_data = DataLoader(trn_dataset, batch_size=32, shuffle=True, collate_fn=sequence_collate)
#val_data = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=sequence_collate)

In [73]:
for i_batch, sample_batched in enumerate(trn_data):
    vec = sample_batched['sequence']
    print(i_batch, sample_batched['sequence'].size(),
         sample_batched['coords'].size())
    if i_batch == 3:
        break

(0, torch.Size([322, 32, 41]), torch.Size([966, 32, 3]))
(1, torch.Size([695, 32, 41]), torch.Size([2085, 32, 3]))
(2, torch.Size([731, 32, 41]), torch.Size([2193, 32, 3]))
(3, torch.Size([450, 32, 41]), torch.Size([1350, 32, 3]))


In [74]:
sample_batched['mask'].size()

torch.Size([395, 32])

## RGN Model

In [158]:
class RGN(nn.Module):
    def __init__(self, hidden_size, num_layers, model_type='hardtanh', input_type='onehot', 
                 aa2vec=None, linear_units=None, input_size=42):
        super(RGN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_type = input_type
        self.model_type = model_type
        self.grads = {}
        
        if self.input_type == 'onehot':
            self.input_size = input_size
        elif self.input_type == 'tokens':
            self.embeds, vocab_sz, embed_dim = create_emb_layer(data_path + 'c3_embs.bc')
            self.input_size = embed_dim + 1
        
        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, bidirectional=True)
        
        if self.model_type == 'hardtanh':
            self.linear1 = nn.Linear(hidden_size*2, 3)
            self.bn1 = nn.BatchNorm1d(3)
            self.linear2 = nn.Linear(hidden_size*2, 3)
            self.bn2 = nn.BatchNorm1d(3)
            self.hardtanh = nn.Hardtanh()
        elif self.model_type == 'alphabet':
            u = torch.distributions.Uniform(-3.14, 3.14)
            self.alphabet = nn.Parameter(u.rsample(torch.Size([linear_units,3])))
            self.linear = nn.Linear(hidden_size*2, linear_units)
        
        #set first coordinates to approximate values
        self.A = torch.tensor([0.,0.,0.])
        self.B = torch.tensor([1.384,-0.348,-0.463])
        self.C = torch.tensor([1.920,0.789,-1.319])

        #bond length vectors C-N, N-CA, CA-C
        self.avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
        #bond angle vector, in radians, CA-C-N, C-N-CA, N-CA-C
        self.avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

    
    def forward(self, sequences, lengths):
        max_len = sequences.size(0)
        batch_sz = sequences.size(1)
        lengths = torch.tensor(lengths, dtype=torch.long, requires_grad=False)
        order = [x for x,y in sorted(enumerate(lengths), key=lambda x: x[1], reverse=True)]
        
        abs_pos = torch.tensor(range(max_len), dtype=torch.float32).unsqueeze(1)
        abs_pos = (abs_pos * torch.ones((1, batch_sz))).unsqueeze(2)
        
        h0 = Variable(torch.zeros((self.num_layers*2, batch_sz, self.hidden_size)))
        c0 = Variable(torch.zeros((self.num_layers*2, batch_sz, self.hidden_size)))
        
        #set sequence input type
        if self.input_type == 'onehot':
            sequences = torch.tensor(sequences, dtype=torch.float32, requires_grad=True)
            pad_seq = torch.cat([sequences, abs_pos], 2)
        elif self.input_type == 'tokens':
            sequences = torch.tensor(sequences, dtype=torch.long, requires_grad=False)
            pad_seq = self.embeds(sequences)
            pad_seq = torch.cat([pad_seq, abs_pos], 2)
    
        packed = pack_padded_sequence(pad_seq[:, order], lengths[order], batch_first=False)
        
        lstm_out, _ = self.lstm(packed, (h0,c0))
        unpacked, _ = pad_packed_sequence(lstm_out, batch_first=False, padding_value=0.0)
        unpacked = unpacked[:, range(batch_sz)] #reorder to match target

        if self.model_type == 'hardtanh':
            flat = unpacked.view(-1, unpacked.size(2))
            sin_out = self.hardtanh(self.bn1(self.linear1(flat))).view(max_len, batch_sz, 3)
            cos_out = self.hardtanh(self.bn2(self.linear2(flat))).view(max_len, batch_sz, 3)
            out = torch.atan2(sin_out, cos_out)
            out.register_hook(self.save_grad('out'))
        elif self.model_type == 'alphabet':
            softmax_out = F.softmax(self.linear(unpacked), dim=2)
            sine = torch.matmul(softmax_out, torch.sin(self.alphabet))
            cosine = torch.matmul(softmax_out, torch.cos(self.alphabet))
            out = torch.atan2(sine, cosine)
        
        #create as many copies of first residue as there are samples in the batch
        broadcast = torch.ones((batch_sz, 3))
        pred_coords = torch.stack([self.A*broadcast, self.B*broadcast, self.C*broadcast])
        
        for ix, triplet in enumerate(out[1:]):
            pred_coords = geometric_unit(pred_coords, triplet, 
                                         self.avg_bond_angles, 
                                         self.avg_bond_lens)
        #pred_coords.register_hook(self.save_grad('pc'))
        
            
        #pdb.set_trace()
        return pred_coords
    
    def save_grad(self, name):
        def hook(grad): self.grads[name] = grad
        return hook

In [149]:
#make sure output size and target sizes are the same
for i_batch, sampled_batch in enumerate(trn_data):
    inp_seq = sampled_batch['sequence']
    inp_lens = sampled_batch['length']
    rgn = RGN(20, 1, 'hardtanh', encoding)
    out = rgn(inp_seq, inp_lens)
    print(i_batch, inp_seq.size(), sampled_batch['coords'].size(), out.size())
    
    if i_batch == 3:
        break

(0, torch.Size([382, 32, 41]), torch.Size([1146, 32, 3]), torch.Size([1146, 32, 3]))
(1, torch.Size([450, 32, 41]), torch.Size([1350, 32, 3]), torch.Size([1350, 32, 3]))
(2, torch.Size([731, 32, 41]), torch.Size([2193, 32, 3]), torch.Size([2193, 32, 3]))
(3, torch.Size([622, 32, 41]), torch.Size([1866, 32, 3]), torch.Size([1866, 32, 3]))


## Training

In [159]:
rgn = RGN(500, 1, 'hardtanh', encoding)
#rgn.load_state_dict(torch.load(data_path+'models/rgn_no_bn.pt')) #load pretrained model
optimizer = torch.optim.Adam(rgn.parameters(), lr=1e-3)
drmsd = dRMSD()

Next steps, try debugging gradient using https://gist.github.com/apaszke/f93a377244be9bfcb96d3547b9bc424d.

In [160]:
running_loss = 0.0

for epoch in range(1):
    last_batch = len(trn_data) - 1
    for i, data in tqdm(enumerate(trn_data)):
        names = data['name']
        coords = data['coords']
        mask = data['mask']

        optimizer.zero_grad()
        outputs = rgn(data['sequence'], data['length'])

        loss = drmsd(outputs, coords, mask)

        loss.backward()
        nn.utils.clip_grad_norm_(rgn.parameters(), max_norm=50)
        optimizer.step()

        running_loss += loss.item()
        if (i != 0) and (i % last_batch == 0):
            print('Epoch {}, Train Loss {}'.format(epoch, running_loss/i))
            running_loss = 0.0
            break
    """
    last_batch = len(val_data) - 1
    for i, data in tqdm(enumerate(val_data)):
        names = data['name']
        coords = data['coords']
        mask = data['mask']
        
        outputs = rgn(data['sequence'], data['length'])
        loss = drmsd(outputs, coords, mask)

        running_loss += loss.item()
        if (i != 0) and (i % last_batch == 0):
            print('Epoch {}, Val Loss {}'.format(epoch, running_loss/i))
            running_loss = 0.0
    """
print('Finished Training')

4it [00:36,  9.10s/it]

Epoch 0, Train Loss 39.637737751
Finished Training





In [154]:
#torch.save(rgn.state_dict(), data_path+'models/rgn_no_bn.pt')