# RGN Modeling

The most recent possibly useful paper that I could find is: https://www.biorxiv.org/content/biorxiv/early/2018/02/14/265231.full.pdf

Many details are missing, so I expect it will be difficult to implement, but the architecture is fairly simple. Feed the sequence into a bi-LSTM and predict three bond characteristics (angle, extension and torsion) per residue. Pass the three predictions, along with the current atoms for each residue, into a "geometric unit" that adds residues sequentially and deforms the "nascent structure" appropriately. The last step is to calculate the loss, the distance-based root mean square deviation (dRMSD), which accounts for both global and local structural details and, importantly, does not require a specific orientation of the predicted structure, since it only considers the distances between each atom and all other atoms.
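
To make the orientation argument concrete, here is a minimal NumPy sketch of a dRMSD-style loss. The function names and the normalization (mean over unique atom pairs) are my assumptions; the paper and the `dRMSD` class used later in this notebook may normalize differently.

```python
import numpy as np

def pair_dist(coords):
    """Return the (n, n) matrix of Euclidean distances between rows of coords."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))

def drmsd(pred, true):
    """RMS difference between the two distance matrices over unique atom pairs.
    Assumed normalization: mean over the n*(n-1)/2 pairs above the diagonal."""
    n = pred.shape[0]
    iu = np.triu_indices(n, k=1)
    delta = pair_dist(pred)[iu] - pair_dist(true)[iu]
    return np.sqrt((delta ** 2).mean())

# A rigidly rotated and translated copy keeps every pairwise distance
rng = np.random.default_rng(0)
true = rng.normal(size=(9, 3))
theta = 0.7
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta),  np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
moved = true @ rot.T + np.array([5.0, -2.0, 1.0])

print(drmsd(moved, true))       # numerically zero
print(drmsd(true * 1.1, true))  # positive: scaling changes the distances
```

Because only distances enter the loss, no superposition step is needed before comparing structures. The flip side is that a mirror image has the same distance matrix, so this kind of loss cannot distinguish a structure from its reflection.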

For training data the author uses targets from CASP 1-10 and tests results on CASP 11.

Task list:
<ul>
    <li>Create new bcolz array to attach sequence and structure together</li>
    <li>Pad structures to match length of sequences</li>
    <li>Handling of inaccurate PDB files</li>
</ul>
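
For the padding item, a minimal sketch of what I have in mind, assuming three backbone atoms (N, CA, C) per residue and zero rows for missing atoms. `pad_coords` is a hypothetical helper, not something that exists in `utils` yet.

```python
import numpy as np

def pad_coords(coords, seq_len, atoms_per_residue=3):
    """Zero-pad an (m, 3) backbone array to (atoms_per_residue * seq_len, 3).
    Returns the padded array and a boolean mask that is True for real atoms."""
    target = atoms_per_residue * seq_len
    if coords.shape[0] > target:
        raise ValueError("structure has more atoms than the sequence allows")
    padded = np.zeros((target, 3), dtype=coords.dtype)
    padded[:coords.shape[0]] = coords
    mask = np.zeros(target, dtype=bool)
    mask[:coords.shape[0]] = True
    return padded, mask

coords = np.arange(18, dtype=float).reshape(6, 3)  # 2 residues x (N, CA, C)
padded, mask = pad_coords(coords, seq_len=4)
print(padded.shape, int(mask.sum()))  # (12, 3) 6
```

Returning the mask alongside the padded array keeps the padding from being confused with a real atom at the origin.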

In [254]:
%load_ext autoreload
%autoreload 2



In [276]:
import numpy as np
import pandas as pd
import ipywidgets as ip
from matplotlib import pyplot as plt
import os
import matplotlib
import seaborn as sns
from tqdm import tqdm
import collections
from collections import Counter as cs
#import nglview as nv
import sys
import Bio.PDB as bio
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from keras.utils.np_utils import to_categorical
import torch.optim
import pdb
%matplotlib inline

In [277]:
import utils
from data import ProteinDataset, sequence_collate
from model import *
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [278]:
data_path = os.curdir + '/data/'
pdb_path = os.curdir + '/data/pdb/structures/pdb/'

In [279]:
#sp = utils.subset(data_path+'proteins_1.bc', 150, 50, save_path=data_path+'proteins_short.bc')

## Pytorch Dataloader

First construct the dataloader for training the model

Known PDB file errors and issues:
<ul>
    <li>38 of 1992 chain_1 proteins have no coordinates, caused by weird files like pdb5da6.ent</li>
    <li>some chain_1 proteins have HETATMs in the main coordinate section because the residues are modified forms of a standard residue (e.g. selenomethionine in pdb1rfe.ent)</li>
    <li>in 634 of 1992 chain_1 proteins the index of the last residue is greater than the number of residues in the sequence, because atoms in many files do not start at one (neither does sequence)</li>
</ul>
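
One possible way to handle the residue numbering offset is to shift the PDB residue numbers so that the first observed residue lands on sequence position 0. This is only a sketch: `rebase_residue_numbers` is a hypothetical helper, and it assumes the first residue with coordinates aligns with the first residue of the sequence, which will not hold for every file.

```python
def rebase_residue_numbers(resnums, seq_len):
    """Map PDB residue numbers to zero-based sequence positions.

    Assumption: the first observed residue corresponds to sequence
    position 0. Raises ValueError if any position falls outside the sequence.
    """
    if not resnums:
        return []
    first = resnums[0]
    positions = [r - first for r in resnums]
    if min(positions) < 0 or max(positions) >= seq_len:
        raise ValueError("residue numbering does not fit the sequence")
    return positions

# Residues numbered 5..9 with residue 8 missing, sequence of length 6
print(rebase_residue_numbers([5, 6, 7, 9], seq_len=6))  # [0, 1, 2, 4]
```

Files that still raise after re-basing are the ones that need a real alignment between the coordinate records and the sequence.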

In [296]:
dataset = ProteinDataset(data_path, 'short', encoding='onehot')
trn_data = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=sequence_collate)

In [297]:
for i_batch, sample_batched in enumerate(trn_data):
    vec = sample_batched['sequence']
    print(i_batch, sample_batched['sequence'].size(),
         sample_batched['coords'].size())
    if i_batch == 3:
        break

(0, torch.Size([149, 32, 20]), torch.Size([447, 32, 3]))
(1, torch.Size([142, 32, 20]), torch.Size([426, 32, 3]))
(2, torch.Size([149, 32, 20]), torch.Size([447, 32, 3]))
(3, torch.Size([148, 32, 20]), torch.Size([444, 32, 3]))


Potential todos with PDB data because of exceptions and errors:
<ul>
    <li>Atoms with multiple possible positions (A, B)</li>
    <li>PDB files with multiple chains</li>
    <li>Masking to use chains with atoms that don't have position 1</li>
    <li>HETATMs like water can play a substantial role in the final folds</li>
    <li>Consider adjusting loss function to reduce penalty for atoms with multiple occupancy</li>
</ul>
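
The masking and multiple-occupancy items could both be handled with per-atom weights in the loss. Below is a minimal NumPy sketch of that idea; the weighting scheme (pair weight = product of the two atom weights) is my assumption, not something taken from the paper.

```python
import numpy as np

def weighted_drmsd(pred, true, weights):
    """Distance-based RMSD where each atom pair (i, j) is weighted by
    weights[i] * weights[j]. A weight of 0 masks an atom out entirely;
    a weight between 0 and 1 reduces the penalty for uncertain atoms."""
    d_pred = np.linalg.norm(pred[:, None, :] - pred[None, :, :], axis=-1)
    d_true = np.linalg.norm(true[:, None, :] - true[None, :, :], axis=-1)
    iu = np.triu_indices(len(weights), k=1)
    w = (weights[:, None] * weights[None, :])[iu]
    if w.sum() == 0:
        return 0.0
    delta2 = (d_pred[iu] - d_true[iu]) ** 2
    return float(np.sqrt((w * delta2).sum() / w.sum()))

rng = np.random.default_rng(0)
true = rng.normal(size=(5, 3))
true[3:] = 0.0                      # last two rows are padding
pred = true.copy()
pred[3:] = rng.normal(size=(2, 3))  # predictions at padded positions are arbitrary

mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
print(weighted_drmsd(pred, true, mask))        # 0.0: padding is ignored
print(weighted_drmsd(pred, true, np.ones(5)))  # positive: padding is penalized
```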

NOTE: Always make input tensor a float and wrap the input as an autograd variable!!!

## RGN Model

In [298]:
#aa2vec = bcolz.open(data_path + 'c3_embs.bc')

In [299]:
class RGN(nn.Module):
    def __init__(self, hidden_size, num_layers, model_type='hardtanh', input_type='onehot', 
                 aa2vec=None, linear_units=None, input_size=21):
        super(RGN, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_type = input_type
        self.model_type = model_type
        self.grads = {}
        
        if self.input_type == 'onehot':
            self.input_size = input_size
        elif self.input_type == 'tokens':
            self.input_size = aa2vec.shape[1] + 1
            self.embeds, vocab_sz, embed_dim = create_emb_layer(data_path + 'c3_embs.bc')
        
        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, bidirectional=True)
        
        if self.model_type == 'hardtanh':
            self.linear1 = nn.Linear(hidden_size*2, 3)
            self.linear2 = nn.Linear(hidden_size*2, 3)
            self.hardtanh = nn.Hardtanh()
        elif self.model_type == 'alphabet':
            u = torch.distributions.Uniform(-3.14, 3.14)
            self.alphabet = nn.Parameter(u.rsample(torch.Size([linear_units,3])))
            self.linear = nn.Linear(hidden_size*2, linear_units)
        
        #as per Mohammed, we simply use the identity matrix to define the first 3 residues
        #self.A = torch.tensor([0., 0., 1.])
        #self.B = torch.tensor([0., 1., 0.])
        #self.C = torch.tensor([1., 0., 0.])
        self.A = torch.tensor([0.,0.,0.])
        self.B = torch.tensor([1.384,-0.348,-0.463])
        self.C = torch.tensor([1.920,0.789,-1.319])

        #bond length vectors C-N, N-CA, CA-C
        self.avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
        #bond angle vector, in radians, CA-C-N, C-N-CA, N-CA-C
        self.avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

    
    def forward(self, sequences, lengths):
        max_len = sequences.size(0)
        batch_sz = sequences.size(1)
        lengths = torch.tensor(lengths, dtype=torch.long, requires_grad=False)
        order = [x for x,y in sorted(enumerate(lengths), key=lambda x: x[1], reverse=True)]
        
        abs_pos = torch.tensor(range(max_len), dtype=torch.float32).unsqueeze(1)
        abs_pos = (abs_pos * torch.ones((1, batch_sz))).unsqueeze(2)
        
        h0 = Variable(torch.zeros((self.num_layers*2, batch_sz, self.hidden_size)))
        c0 = Variable(torch.zeros((self.num_layers*2, batch_sz, self.hidden_size)))
        
        #set sequence input type
        if self.input_type == 'onehot':
            sequences = torch.tensor(sequences, dtype=torch.float32, requires_grad=True)
            pad_seq = torch.cat([sequences, abs_pos], 2)
        elif self.input_type == 'tokens':
            sequences = torch.tensor(sequences, dtype=torch.long, requires_grad=False)
            pad_seq = self.embeds(sequences)
            pad_seq = torch.cat([pad_seq, abs_pos], 2)
    
        packed = pack_padded_sequence(pad_seq[:, order], lengths[order], batch_first=False)
        
        lstm_out, _ = self.lstm(packed, (h0,c0))
        unpacked, _ = pad_packed_sequence(lstm_out, batch_first=False, padding_value=0.0)
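        #undo the length sort so outputs line up with the original batch order
        inverse = sorted(range(batch_sz), key=order.__getitem__)
        unpacked = unpacked[:, inverse] #reorder to match target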

        if self.model_type == 'hardtanh':
            sin_out = self.hardtanh(self.linear1(unpacked))
            cos_out = self.hardtanh(self.linear2(unpacked))
            out = torch.atan2(sin_out, cos_out)
            #out.register_hook(self.save_grad('out'))
        elif self.model_type == 'alphabet':
            softmax_out = F.softmax(self.linear(unpacked), dim=2)
            sine = torch.matmul(softmax_out, torch.sin(self.alphabet))
            cosine = torch.matmul(softmax_out, torch.cos(self.alphabet))
            out = torch.atan2(sine, cosine)
        
        #create as many copies of first residue as there are samples in the batch
        broadcast = torch.ones((batch_sz, 3))
        pred_coords = torch.stack([self.A*broadcast, self.B*broadcast, self.C*broadcast])
        
        for ix, triplet in enumerate(out[1:]):
            pred_coords = geometric_unit(pred_coords, triplet, 
                                         self.avg_bond_angles, 
                                         self.avg_bond_lens)
        #pred_coords.register_hook(self.save_grad('pc'))
        
            
        #pdb.set_trace()
        return pred_coords
    
    def save_grad(self, name):
        def hook(grad): self.grads[name] = grad
        return hook
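
Both output heads in the class above turn a (sine, cosine) pair into an angle with `atan2` instead of predicting the angle directly. For the 'alphabet' head this amounts to a weighted circular mean of the alphabet angles. Here is a small NumPy sketch of why that matters near the $\pm\pi$ wrap-around; the function names are mine and the two-entry alphabet is a toy example.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def alphabet_angles(logits, alphabet):
    """logits: (..., m) scores over m alphabet entries.
    alphabet: (m, 3) candidate torsion triplets in radians.
    Returns (..., 3) angles as the weighted circular mean of the alphabet."""
    p = softmax(logits)
    sine = p @ np.sin(alphabet)
    cosine = p @ np.cos(alphabet)
    return np.arctan2(sine, cosine)

# Two alphabet entries on either side of the wrap-around at +/- pi
alphabet = np.array([[3.0, 3.0, 3.0],
                     [-3.0, -3.0, -3.0]])
logits = np.zeros(2)  # equal weight on both entries

naive = softmax(logits) @ alphabet           # arithmetic mean of the angles
circular = alphabet_angles(logits, alphabet)
print(naive)              # zeros: the opposite direction on the circle
print(np.abs(circular))   # approximately pi: between the two entries
```

The arithmetic mean of 3.0 and -3.0 is 0, which points the wrong way, while the circular mean lands at $\pm\pi$, between the two entries. The 'hardtanh' head relies on the same property: `atan2` only uses the direction of the (sine, cosine) pair, so the two linear outputs do not need to lie on the unit circle.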

In [300]:
#for i_batch, sampled_batch in enumerate(trn_data):
#    inp_seq = sampled_batch['sequence']
#    inp_lens = sampled_batch['length']
#    rgn = RGN(20, 1, 'hardtanh', 'onehot')
#    out = rgn(inp_seq, inp_lens)
#    print(i_batch, inp_seq.size(), sampled_batch['coords'].size(), out.size())
    
#    if i_batch == 1:
#        break

In [301]:
def adaptive_lr(optimizer, step_size):
    #for now just linear scaling
    for param_group in optimizer.param_groups:
        param_group['lr'] += step_size
        new_lr = param_group['lr']
    
    return optimizer

In [305]:
rgn = RGN(800, 2, 'hardtanh', 'onehot') #aa2vec is only needed for input_type='tokens'
drmsd = dRMSD()

In [306]:
#optimizer = torch.optim.SGD(rgn.parameters(), lr=1e-1, momentum=0.9)
#rgn.load_state_dict(torch.load(data_path+'models/rgn.pt'))
optimizer = torch.optim.Adam(rgn.parameters(), lr=1e-3)

Next steps, try debugging gradient using https://gist.github.com/apaszke/f93a377244be9bfcb96d3547b9bc424d.

In [307]:
loss_history=[]
running_loss = 0.0
last_batch = len(trn_data) - 1
c = 0
for epoch in range(50):
    #c = 0
    for i, data in tqdm(enumerate(trn_data)):
    #for i, data in enumerate(trn_data):
        #try:
        names = data['name']
        coords = data['coords']

        optimizer.zero_grad()
        outputs = rgn(data['sequence'], data['length'])

        loss = drmsd(outputs, coords)

        #print(i, loss.item(), rgn.embeds.state_dict()['weight'][0][0])
        loss.backward()
        nn.utils.clip_grad_norm_(rgn.parameters(), max_norm=1)
        optimizer.step()

        running_loss += loss.item()
        if (i != 0) and (i % last_batch == 0):
            print('Epoch {}, Loss {}'.format(epoch, running_loss/(i-c)))
            running_loss = 0.0
        #except KeyboardInterrupt:
        #    raise
        #except:
        #    c += 1
        #    pass

print('Finished Training')

18it [06:01, 20.08s/it]
0it [00:00, ?it/s]

Epoch 0, Loss 43.1176270878


18it [06:01, 20.10s/it]
0it [00:00, ?it/s]

Epoch 1, Loss 17.2254386229


13it [04:25, 20.42s/it]

KeyboardInterrupt: 

In [224]:
torch.save(rgn.state_dict(), data_path+'models/rgn.pt')
#rgn.load_state_dict(torch.load(data_path+'models/rgn.pt'))

In [None]:
#plt.plot(np.array(loss_history)[:, 0], np.array(loss_history)[:, 1])

In [None]:
torch.__version__

## Validation

To actually reproduce the results from the RGN paper, I need to use the proteinnet dataset, https://github.com/aqlaboratory/proteinnet. In particular, Mohammed used the CASP 11 data to test his model. The full dataset may be too large for my memory without deleting all the hard work I did with the PDB files. However, if I delete all the PDB files I currently have, I at least still have the tools to reproduce the datasets if necessary.


## Geometric Units

Some basic information about bond angles and lengths can be found here: https://www.ruppweb.org/Xray/tutorial/protein_structure.htm

I'll use this as my primary source, but it may be somewhat inaccurate (I have since found a more reliable source, saved in my Dropbox).

To validate that my implementation of the NeRF (natural extension reference frame) algorithm is correct, I want to take a PDB file, use Bio.PDB to calculate the torsion angles, and then use the ground truth (gt) torsion angles to reconstruct the coordinates. The goal is for the dRMSD between the rendered structure and the gt structure to be zero. This would imply that if my LSTM model correctly predicts the torsion angles, the calculated coordinates should match the gt PDB file.
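
The same round-trip check can be done on a single atom, which is a quick way to catch sign or convention errors. This is a NumPy sketch of one NeRF step using the same local frame as the loop further down; `place_atom` and `dihedral` are names I made up for illustration, and all angles are in radians.

```python
import numpy as np

def place_atom(a, b, c, bond_len, bond_angle, torsion):
    """Place atom D from the previous three atoms A, B, C.
    bond_len is |CD|, bond_angle is the angle B-C-D and torsion is the
    dihedral A-B-C-D, both in radians."""
    d2 = np.array([-bond_len * np.cos(bond_angle),
                   bond_len * np.cos(torsion) * np.sin(bond_angle),
                   bond_len * np.sin(torsion) * np.sin(bond_angle)])
    bc = (c - b) / np.linalg.norm(c - b)
    n = np.cross(b - a, bc)
    n = n / np.linalg.norm(n)
    m = np.stack([bc, np.cross(n, bc), n], axis=1)  # local frame as columns
    return m @ d2 + c

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle p0-p1-p2-p3 in radians."""
    b0 = p0 - p1
    b1 = (p2 - p1) / np.linalg.norm(p2 - p1)
    b2 = p3 - p2
    v = b0 - np.dot(b0, b1) * b1  # b0 projected onto the plane normal to b1
    w = b2 - np.dot(b2, b1) * b1  # b2 projected onto the same plane
    return np.arctan2(np.dot(np.cross(b1, v), w), np.dot(v, w))

# Measure the internal coordinates of a random 4-atom chain, then rebuild atom 4
rng = np.random.default_rng(1)
a, b, c, d = rng.normal(size=(4, 3))
bond_len = np.linalg.norm(d - c)
u1, u2 = b - c, d - c
bond_angle = np.arccos(np.dot(u1, u2) / (np.linalg.norm(u1) * np.linalg.norm(u2)))
torsion = dihedral(a, b, c, d)

rebuilt = place_atom(a, b, c, bond_len, bond_angle, torsion)
print(np.abs(rebuilt - d).max())  # numerically zero
```

If the rebuilt atom does not match, the usual suspects are degrees versus radians, the sign of the torsion, or using the supplement of the bond angle.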

In [162]:
#First find a pdb file with no missing coordinates
chain_1 = load_array(data_path+'proteins_1.bc')

In [163]:
for ix, chain in enumerate(chain_1[:20]):
    msk = chain[2].sum(1) == 0
    if not np.any(msk):
        print(ix)

2
15
19


In [164]:
chain_1[2][0]

['1zur']

Protein at index 2 in the proteins_1.bc dataset has no missing atoms, so we can use it for testing

In [165]:
t_angles, b_angles, b_len = utils.gt_dihedral_angles(pdb_path+'pdb1zur.ent')

Note that these angles are in radians, whereas my implementation assumes degrees (so the multiplication by 180 can be removed). The omega angles are all roughly equal to $\pi$, in accordance with the literature I've read.

In [180]:
A = torch.tensor(chain_1[2][2][0], dtype=torch.float)
B = torch.tensor(chain_1[2][2][1], dtype=torch.float)
C = torch.tensor(chain_1[2][2][2], dtype=torch.float)

#A = torch.tensor([0., 0., 1.], dtype=torch.float)
#B = torch.tensor([0., 1., 0.], dtype=torch.float)
#C = torch.tensor([1., 0., 0.], dtype=torch.float)

#avg_bond_lens = torch.tensor([1.329, 1.459, 1.525])
#avg_bond_angles = torch.tensor([2.034, 2.119, 1.937])

pred_coords = torch.stack([A, B, C])

for ix,triplet in enumerate(t_angles):
    for i in range(3):
        T = b_angles[ix][i] #avg_bond_angles[i] #angle_BCD
        R = b_len[ix][i] #avg_bond_lens[i] #bond_CD
        P = triplet[i] #torsionBC
        
        D2 = torch.stack([-R*torch.cos(T),
                          R*torch.cos(P)*torch.sin(T),
                          R*torch.sin(P)*torch.sin(T)])

        BC = C - B
        bc = BC/torch.norm(BC, 2)

        AB = B - A

        N = torch.cross(AB, bc)
        n = N/torch.norm(torch.cross(AB, bc), 2)
        
        if ix==0:
            print(n)

        M = torch.stack([bc, torch.cross(n, bc), n], dim=1)

        D = torch.mm(M, D2.view(-1,1)).squeeze() + C
        
        pred_coords = torch.cat([pred_coords, D.view(1,3)])
        
        A = pred_coords[-3]
        B = pred_coords[-2]
        C = pred_coords[-1]

tensor([ 0.3138,  0.2345,  0.9201])
tensor([-0.3717,  0.2770, -0.8861])
tensor([ 0.3293, -0.3031,  0.8942])


In [171]:
pair_dist(pred_coords)[:7, :7]

tensor([[ 0.0000,  1.4874,  2.5037,  3.6471,  4.9417,  6.0591,  7.2858],
        [ 1.4874,  0.0000,  1.5289,  2.4205,  3.8290,  4.7594,  6.0301],
        [ 2.5037,  1.5289,  0.0000,  1.3304,  2.4853,  3.6839,  4.8436],
        [ 3.6471,  2.4205,  1.3304,  0.0000,  1.4620,  2.4165,  3.6473],
        [ 4.9417,  3.8290,  2.4853,  1.4620,  0.0000,  1.5286,  2.4434],
        [ 6.0591,  4.7594,  3.6839,  2.4165,  1.5286,  0.0000,  1.3305],
        [ 7.2858,  6.0301,  4.8436,  3.6473,  2.4434,  1.3305,  0.0000]])

In [170]:
gt_coords = torch.tensor(chain_1[2][2])
pair_dist(gt_coords)[:7, :7]

tensor([[ 0.0000,  1.4875,  2.5037,  3.6471,  4.9417,  6.0591,  7.2858],
        [ 1.4875,  0.0000,  1.5288,  2.4205,  3.8289,  4.7595,  6.0301],
        [ 2.5037,  1.5288,  0.0000,  1.3304,  2.4853,  3.6839,  4.8436],
        [ 3.6471,  2.4205,  1.3304,  0.0000,  1.4620,  2.4165,  3.6473],
        [ 4.9417,  3.8289,  2.4853,  1.4620,  0.0000,  1.5286,  2.4434],
        [ 6.0591,  4.7595,  3.6839,  2.4165,  1.5286,  0.0000,  1.3305],
        [ 7.2858,  6.0301,  4.8436,  3.6473,  2.4434,  1.3305,  0.0000]], dtype=torch.float64)

With the gt bond lengths, bond angles and torsions, the reconstructed pairwise distances above match the gt distances to within about 1e-4 in the 7x7 block shown. By contrast, a considerable amount of error is injected into the geometric units when just using average bond lengths and angles. In particular, since bond lengths are fixed, it is actually impossible to train any model that achieves zero loss (dRMSD is directly affected by the bond lengths). Using the identity matrix as Mohammed suggested also leads to larger errors, even when using the gt torsions. I dislike the idea of lazily allowing these sources of loss to remain in the model, but I want to see if it is possible to reproduce the paper's results before tinkering with the architecture. At the very least, it seems like these parameters should be learnable.