The goal of this script will be to generate a function that, essentially, can take a coordinate tensor and a mapping between those coordiates and atom identifiers (names) and creates/writes a PDB file with that information.

In [1]:
import prody as pr
import numpy as np
import sys
import os
os.chdir("/home/jok120/protein-transformer/")
sys.path.append("/home/jok120/protein-transformer/scripts")
sys.path.append("/home/jok120/protein-transformer/scripts/utils")
import torch
from tqdm import tqdm
from prody import *
import numpy as np
from os.path import basename, splitext

import transformer.Models
import torch.utils.data
from dataset import ProteinDataset, paired_collate_fn, paired_collate_fn_with_len
from protein.Structure import generate_coords_with_tuples, generate_coords
from losses import inverse_trig_transform, copy_padding_from_gold, drmsd_loss_from_coords, mse_over_angles, combine_drmsd_mse
from protein.Sidechains import SC_DATA, ONE_TO_THREE_LETTER_MAP, THREE_TO_ONE_LETTER_MAP
from utils.structure_utils import onehot_to_seq

# 1. Load a model given a checkpoint (and maybe some args.)

In [2]:
def load_model(chkpt_path):
    """ Given a checkpoint path, loads and returns the specified transformer model. Assumes """
    chkpt = torch.load(chkpt_path)
    model_args = chkpt['settings']
    model_state = chkpt['model_state_dict']
    model_args.postnorm = False
    print(model_args)
    
    the_model = transformer.Models.Transformer(model_args,
                                               d_k=model_args.d_k,
                                               d_v=model_args.d_v,
                                               d_model=model_args.d_model,
                                               d_inner=model_args.d_inner_hid,
                                               n_layers=model_args.n_layers,
                                               n_head=model_args.n_head,
                                               dropout=model_args.dropout)
    the_model.load_state_dict(model_state)
    return the_model

In [3]:
model = load_model("data/checkpoints/casp12_30_ln_11_best.chkpt")

Namespace(batch_size=8, buffering_mode=1, chkpt_path='./data/checkpoints/casp12_30_ln_11', clip=1.0, cluster=False, combined_loss=True, cuda=True, d_inner_hid=32, d_k=12, d_model=64, d_v=12, d_word_vec=64, data='data/proteinnet/casp12_190809_30xsmall.pt', dropout=0, early_stopping=None, epochs=40, eval_train=False, learning_rate=1e-05, log=None, log_file='./data/logs/casp12_30_ln_11.train', lr_scheduling=False, max_token_seq_len=3303, n_head=8, n_layers=6, n_warmup_steps=1000, name='casp12_30_ln_11', no_cuda=False, optimizer='adam', postnorm=False, proteinnet=True, restart=False, rnn=False, save_mode='best', train_only=False, without_angle_means=False)


# 2. Load some data.

In [4]:
def get_data_loader(data_path, n=0, subset="test"):
    """ Given a subset of a dataset as a python dictionary file to make predictions from,
        this function selects n items at random from that dataset to predict. It then returns a DataLoader for those
        items, along with a list of ids.
        """
    data = torch.load(data_path)
    data_subset = data[subset]

    if n is 0:
        train_loader = torch.utils.data.DataLoader(
            ProteinDataset(
                seqs=data_subset['seq'],
                crds=data_subset['crd'],
                angs=data_subset['ang'],
                ),
            num_workers=2,
            batch_size=1,
            collate_fn=paired_collate_fn,
            shuffle=False)
        return train_loader, data_subset["ids"]

    # We just want to predict a few examples
    to_predict = set([s.upper() for s in np.random.choice(data_subset["ids"], n)])  # ["2NLP_D", "3ASK_Q", "1SZA_C"]
    will_predict = []
    ids = []
    seqs = []
    angs = []
    crds = []
    for i, prot in enumerate(data_subset["ids"]):
        if prot.upper() in to_predict and prot.upper() not in will_predict:
            seqs.append(data_subset["seq"][i])
            angs.append(data_subset["ang"][i])
            crds.append(data_subset["crd"][i])
            ids.append(prot)
            will_predict.append(prot.upper())
    assert len(seqs) == n and len(angs) == n or (len(seqs) == len(angs) and len(seqs) < n)

    data_loader = torch.utils.data.DataLoader(
        ProteinDataset(
            seqs=seqs,
            angs=angs,
            crds=crds),
        num_workers=2,
        batch_size=1,
        collate_fn=paired_collate_fn,
        shuffle=False)
    return data_loader, ids

In [23]:
data_loader, ids = get_data_loader('data/proteinnet/casp12_190809_30xsmall.pt')
data_iter = iter(data_loader)

# 3. Use the model to make a prediction

In [24]:
device = torch.device('cpu')
src_seq, src_pos_enc, tgt_ang, tgt_pos_enc, tgt_crds, tgt_crds_enc =  next(data_iter)
print(src_seq.shape)


torch.Size([1, 129, 20])


In [25]:
tgt_ang_no_nan = tgt_ang.clone().detach()
tgt_ang_no_nan[torch.isnan(tgt_ang_no_nan)] = 0
pred = model.predict(src_seq, src_pos_enc)
d_loss, d_loss_normalized, r_loss = drmsd_loss_from_coords(pred, tgt_crds, src_seq, device,
                                                                       return_rmsd=True)
m_loss = mse_over_angles(pred, tgt_ang).to('cpu')
c_loss = combine_drmsd_mse(d_loss, m_loss)
d_loss, r_loss, m_loss, c_loss

time step: 129

(tensor(20.7834), 25.607910700575648, tensor(0.4110), tensor(1039.8550))

In [26]:
tgt_ang_no_nan = tgt_ang.clone().detach()
tgt_ang_no_nan[torch.isnan(tgt_ang_no_nan)] = 0
pred2 = model.forward(src_seq, src_pos_enc, tgt_ang_no_nan, tgt_pos_enc)
d_loss, d_loss_normalized, r_loss = drmsd_loss_from_coords(pred, tgt_crds, src_seq, device,
                                                                       return_rmsd=True)
m_loss = mse_over_angles(pred, tgt_ang).to('cpu')
c_loss = combine_drmsd_mse(d_loss, m_loss)
d_loss, r_loss, m_loss, c_loss

time step: 128

(tensor(20.7834), 25.607910700575648, tensor(0.4110), tensor(1039.8550))

In [9]:
pred[0]

tensor([[ 0.7680, -0.5909, -0.8223,  ...,  0.6796,  0.0528,  0.0445],
        [ 0.7350, -0.5812, -0.4756,  ...,  0.5694, -0.9909,  0.0753],
        [ 0.6226, -0.8203, -0.6212,  ...,  0.3859,  0.9996, -0.0054],
        ...,
        [-0.3407, -0.7865,  0.1944,  ...,  0.2362,  0.9869,  0.0121],
        [-0.3214, -0.7676,  0.3263,  ...,  0.0204,  0.9145,  0.0151],
        [-0.3246, -0.6754,  0.3097,  ...,  0.0797,  0.6726, -0.0161]])

In [10]:
pred.shape

torch.Size([1, 129, 24])

In [11]:
tgt_ang_no_nan.shape

torch.Size([1, 129, 24])

In [12]:
pred = inverse_trig_transform(pred).squeeze()
src_seq = src_seq.squeeze()

In [13]:
coords = generate_coords(pred, pred.shape[0],src_seq, device)

In [14]:
coords.shape, tgt_crds.shape

(torch.Size([1677, 3]), torch.Size([1, 1677, 3]))

In [15]:
one_letter_seq = onehot_to_seq(src_seq.squeeze().detach().numpy())
one_letter_seq

'ANKPMQPITSTANKIVWSDPTRLSTTFSASLLRQRVKVGIAELNNVSGQYVSVYKRPAPKPEGGADAGVIMPNENQSIRTVISGSAENLATLKAEWETHKRNVDTLFASGNAGLGFLDPTAAIVSSDTT'

In [21]:
cur_map = get_13atom_mapping(one_letter_seq)

title = "0925a_pred.pdb"
ttitle = title.replace("pred", "true")
pdbc = PDB_Creator(coords.squeeze(), cur_map)
pdbc.save_pdb(title)
pdbc = PDB_Creator(tgt_crds.squeeze(), cur_map)
pdbc.save_pdb(ttitle)

PDB written to 0925a_pred.pdb.
PDB written to 0925a_true.pdb.


## 3b. Turn off teacher forcing for predicting

# 4. Create a mapping from input seq to atom name list

In [17]:
atom_map_13 = {}
for one_letter in ONE_TO_THREE_LETTER_MAP.keys():
    atom_map_13[one_letter] = ["N", "CA", "C"] + list(SC_DATA[ONE_TO_THREE_LETTER_MAP[one_letter]]["predicted"])
    atom_map_13[one_letter].extend(["PAD"]*(13-len(atom_map_13[one_letter])))

In [18]:
def get_13atom_mapping(seq):
    mapping = []
    for residue in seq:
        mapping.append((ONE_TO_THREE_LETTER_MAP[residue], atom_map_13[residue]))
    return mapping
        

# 5. Given a coordinate tensor and an atom mapping, create a PDB file

In [20]:
class PDB_Creator(object):
    """
        A class for creating PDB files given an atom mapping.
        The Python format string was taken from http://cupnet.net/pdb-format/.
    """
    def __init__(self, coords, mapping, atoms_per_res=13):
        self.coords = coords.detach().numpy()
        self.mapping = mapping
        self.atoms_per_res = atoms_per_res
        self.format_str = "{:6s}{:5d} {:^4s}{:1s}{:3s} {:1s}{:4d}{:1s}   {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f}          {:>2s}{:2s}"
        self.atom_nbr = 1
        self.res_nbr = 1
        self.defaults = {"alt_loc": "",
                         "chain_id": "",
                         "insertion_code": "",
                         "occupancy": 1,
                         "temp_factor": 0,
                         "element_sym": "",
                         "charge": ""}
        assert self.coords.shape[0] % self.atoms_per_res == 0, f"Coords is not divisible by {atoms_per_res}. {self.coords.shape}"
        self.peptide_bond_full =   np.asarray([[0.519,  -2.968,   1.340],  # CA
                                               [2.029,  -2.951,   1.374],  # C
                                               [2.654,  -2.667,   2.392],  # O 
                                               [2.682,  -3.244,   0.300]]) # next-N
        self.peptide_bond_mobile = np.asarray([[0.519,  -2.968,   1.340],  # CA
                                               [2.029,  -2.951,   1.374],  # C
                                               [2.682,  -3.244,   0.300]]) # next-N

    def get_oxy_coords(self, ca, c, n):
        target_coords = np.array([ca, c, n])
        t = calcTransformation(self.peptide_bond_mobile, target_coords)
        aligned_peptide_bond = t.apply(self.peptide_bond_full)
        return aligned_peptide_bond[2]
        
    
    def coord_generator(self):
        coord_idx = 0
        while coord_idx < self.coords.shape[0]:
            if coord_idx + self.atoms_per_res + 1 < self.coords.shape[0]: 
                next_n = self.coords[coord_idx + self.atoms_per_res + 1]
            else:
                # TODO: Fix oxygen placement for final residue
                next_n = self.coords[-1] +np.array([1.2, 0, 0])
            yield self.coords[coord_idx:coord_idx + self.atoms_per_res], next_n
            coord_idx += self.atoms_per_res
            
    def get_line_for_atom(self, res_name, atom_name, atom_coords, missing=False):
        if missing:
            occupancy = 0
        else:
            occupancy = self.defaults["occupancy"]
        return self.format_str.format("ATOM",
                                      self.atom_nbr,
                                      
                                      atom_name,
                                      self.defaults["alt_loc"],
                                      res_name,
                                      
                                      self.defaults["chain_id"],
                                      self.res_nbr,
                                      self.defaults["insertion_code"],
                                      
                                      atom_coords[0],
                                      atom_coords[1],
                                      atom_coords[2],
                                      occupancy,
                                      self.defaults["temp_factor"],
                                      
                                      atom_name[0],
                                      self.defaults["charge"])
    
    
    def get_lines_for_residue(self, res_name, atom_names, coords, next_n):
        residue_lines = []
        for atom_name, atom_coord in zip(atom_names, coords):
            if atom_name is "PAD" or np.isnan(atom_coord).sum() > 0:
                continue
#             if np.isnan(atom_coord).sum() > 0:
#                 residue_lines.append(self.get_line_for_atom(res_name, atom_name, atom_coord, missing=True))
#                 self.atom_nbr += 1
#                 continue
            residue_lines.append(self.get_line_for_atom(res_name, atom_name, atom_coord))
            self.atom_nbr += 1
        try:
            oxy_coords = self.get_oxy_coords(coords[1], coords[2], next_n)
            residue_lines.append(self.get_line_for_atom(res_name, "O", oxy_coords))
            self.atom_nbr += 1
        except ValueError:
            pass
        return residue_lines
        
    def get_lines_for_protein(self):
        self.lines = []
        self.res_nbr = 1
        self.atom_nbr = 1
        mapping_coords = zip(self.mapping, self.coord_generator())
        prev_n = torch.tensor([0,0,-1])
        for (res_name, atom_names), (res_coords, next_n) in mapping_coords:
            self.lines.extend(self.get_lines_for_residue(res_name, atom_names, res_coords, next_n))
            prev_n = res_coords[0]
            self.res_nbr += 1
        return self.lines
    
    def make_header(self, title):
        return f"REMARK  {title}"
    
    def make_footer(self):
        return "TER\nEND          \n"
    
    def save_pdb(self, path, title="test"):
        self.get_lines_for_protein()
        self.lines = [self.make_header(title)] + self.lines + [self.make_footer()]
        with open(path, "w") as outfile:
            outfile.write("\n".join(self.lines))
        print(f"PDB written to {path}.")
        
    def get_seq(self):
        return "".join([THREE_TO_ONE_LETTER_MAP[m[0]] for m in self.mapping])
                  

In [None]:
cur_map = get_13atom_mapping(one_letter_seq)

In [None]:
title = "0924f_pred.pdb"
ttitle = title.replace("pred", "true")
pdbc = PDB_Creator(coords.squeeze(), cur_map)
pdbc.save_pdb(title)
pdbc = PDB_Creator(tgt_crds.squeeze(), cur_map)
pdbc.save_pdb(ttitle)

In [None]:
# Align
# p = parsePDB(title)
# t = parsePDB(ttitle)
# print(t.getCoords().shape, p.getCoords().shape)
# tr = calcTransformation(t.getCoords(), p.getCoords())
# t.setCoords(tr.apply(t.getCoords()))

# writePDB(ttitle, t)

In [None]:
import prody

In [None]:
def do_a_prediction(title, data_iter):
    src_seq, src_pos_enc, tgt_ang, tgt_pos_enc, tgt_crds, tgt_crds_enc =  next(data_iter)
    tgt_ang_no_nan = tgt_ang.clone().detach()
    tgt_ang_no_nan[torch.isnan(tgt_ang_no_nan)] = 0
    pred = model(src_seq, src_pos_enc, tgt_ang_no_nan, tgt_pos_enc)
    
    # Calculate loss
    d_loss, d_loss_normalized, r_loss = drmsd_loss_from_coords(pred, tgt_crds, src_seq, device,
                                                                       return_rmsd=True)
    m_loss = mse_over_angles(pred, tgt_ang).to('cpu')
    
    # Generate coords
    pred = inverse_trig_transform(pred).squeeze()
    src_seq = src_seq.squeeze()
    coords = generate_coords(pred, pred.shape[0],src_seq, device)
    
    # Generate coord, atom_name mapping
    one_letter_seq = onehot_to_seq(src_seq.squeeze().detach().numpy())
    cur_map = get_13atom_mapping(one_letter_seq)
    
    # Make PDB Creator objects
    pdb_pred = PDB_Creator(coords.squeeze(), cur_map)
    pdb_true = PDB_Creator(tgt_crds.squeeze(), cur_map)
    
    # Save PDB files
    pdb_pred.save_pdb(f"{title}_pred.pdb")
    pdb_true.save_pdb(f"{title}_true.pdb")
    
    # Align PDB files
    p = parsePDB(f"{title}_pred.pdb")
    t = parsePDB(f"{title}_true.pdb")
    tr = calcTransformation(p.getCoords()[:-1], t.getCoords())
    p.setCoords(tr.apply(p.getCoords()))
    
    writePDB(f"{title}_pred.pdb", p)
    
    print("Constructed PDB files for", title, ".")
    
    
    

In [None]:
do_a_prediction(8, data_iter)

In [None]:
prody.apps.prody_apps.prody_align