# creating predictions for a position given a pdb structure

Dataset description, taken from https://arxiv.org/pdf/2012.04035.pdf:

Environments are extracted from a non-redundant subset of high-resolution structures from the PDB.
Specifically, we use only X-ray structures with resolution <3.0 Å, and enforce a 60% sequence
identity threshold. We then split the dataset by structure based on domain-level CATH 4.2 topology
classes [Dawson et al., 2017], as described in [Anand et al., 2020]. This resulted in a total of
21147, 964, and 3319 PDB structures for the train, validation, and test sets, respectively. Rather
than train on every residue for each of these structures, we balance the classes in the train set by
downsampling to the frequency of the least-common amino acid (cysteine). The original class balance
is preserved in the test set. In total, the train, validation, and test sets comprised 3733710, 188530,
and 1261342 environments, respectively. We ignore all non-standard residues. We represent the
physico-chemical environment around each residue using all C, O, N, S, and P atoms in the protein
and any co-crystallized ligands or ions. All non-backbone atoms of the target residue are removed,
and each environment is centered around a “virtual” Cβ position of the target residue defined using
the average Cβ position over the training set.

In [1]:
import pandas as pd

In [2]:
import gvp
from atom3d.datasets import LMDBDataset
import torch_geometric
from functools import partial
import gvp.atom3d
import torch.nn as nn
import tqdm, torch, time, os
import numpy as np
from atom3d.util import metrics
import sklearn.metrics as sk_metrics
from collections import defaultdict
import scipy.stats as stats
# Rebind print so that every call flushes immediately.
print = partial(print, flush=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Timestamp taken when this cell runs; not referenced again in the visible part of this file.
model_id = float(time.time())
print(device)

cuda


# create lmdb data from pdbs


This section transforms a PDB file into an LMDB dataset that can be used with GVP directly.

In [3]:
#see https://github.com/drorlab/atom3d/blob/master/examples/res/dataset/prepare_lmdb.py
import numpy as np
import pandas as pd
import collections as col
import logging
import os
import re
import sys
import scipy.spatial
import parallel as par
import click

sys.path.insert(0, '../../..')
import atom3d.datasets.datasets as da
import atom3d.splits.splits as spl
import atom3d.util.file as fi
import atom3d.util.formats as fo
#import util as res_util

# Integer class label <-> three-letter residue name; the two dicts are inverses of each other.
# NOTE: this numbering differs from the alphabetical numbering used by
# `aa3_to_num` / `_amino_acids` defined later in this file.
label_res_dict={0:'HIS',1:'LYS',2:'ARG',3:'ASP',4:'GLU',5:'SER',6:'THR',7:'ASN',8:'GLN',9:'ALA',10:'VAL',11:'LEU',12:'ILE',13:'MET',14:'PHE',15:'TYR',16:'TRP',17:'PRO',18:'GLY',19:'CYS'}
res_label_dict={'HIS':0,'LYS':1,'ARG':2,'ASP':3,'GLU':4,'SER':5,'THR':6,'ASN':7,'GLN':8,'ALA':9,'VAL':10,'LEU':11,'ILE':12,'MET':13,'PHE':14,'TYR':15,'TRP':16,'PRO':17,'GLY':18,'CYS':19}
# Backbone atom names: a residue must contain all of them to be used, and they
# are the only atoms of the target residue that are kept in its environment.
bb_atoms = ['N', 'CA', 'C', 'O']
# Element symbols that are kept; atoms of any other element are dropped.
allowed_atoms = ['C', 'O', 'N', 'S', 'P', 'SE']

# computed statistics from training set
# Used by myResTransform as the per-residue probability of keeping a residue when balance=True.
res_wt_dict = {'HIS': 0.581391659111514, 'LYS': 0.266061611865989, 'ARG': 0.2796785729861747, 'ASP': 0.26563454667840314, 'GLU': 0.22814679094919596, 'SER': 0.2612916369563003, 'THR': 0.27832512315270935, 'ASN': 0.3477441570413752, 'GLN': 0.37781509139381086, 'ALA': 0.20421144813311043, 'VAL': 0.22354397064847012, 'LEU': 0.18395198072344454, 'ILE': 0.2631600545792168, 'MET': 0.6918305148744505, 'PHE': 0.3592224851905275, 'TYR': 0.4048964515721682, 'TRP': 0.9882874205355423, 'PRO': 0.32994186046511625, 'GLY': 0.2238561093317741, 'CYS': 1.0}

# Offset added to a residue's CA coordinates to place its "virtual" CB position.
gly_CB_mu = np.array([-0.5311191 , -0.75842446,  1.2198311 ], dtype=np.float32)
# 3x3 matrix accompanying gly_CB_mu; not referenced by the code visible in this file
# (presumably the covariance of the CB offset -- TODO confirm).
gly_CB_sigma = np.array([[1.63731114e-03, 2.40018381e-04, 6.38361679e-04],
       [2.40018381e-04, 6.87853419e-05, 1.43898267e-04],
       [6.38361679e-04, 1.43898267e-04, 3.25022011e-04]], dtype=np.float32)


class myResTransform(object):
    """ATOM3D transform that extracts residue environments for selected positions.

    For every residue listed in ``pos_oi`` the transform records the indices of
    all atoms within ``10 * sqrt(3)`` angstroms of the residue's virtual CB
    position (CA position plus ``gly_CB_mu``), after removing the residue's own
    non-backbone atoms.

    :param balance: if True, each residue is kept only with the probability
                    given for its residue name in ``res_wt_dict``
    :param pos_oi: positions of interest as ``(chain, residue number, three-letter
                   residue name)`` tuples, for example
                   ``[('A', 60, 'TRP'), ('A', 61, 'ASP'), ('A', 64, 'LYS')]``.
                   Residue numbers are the ones stored in the structure file.
                   Defaults to an empty list, in which case no residue is selected.
    """

    def __init__(self, balance=False, pos_oi=None):
        self.balance = balance
        # Create a fresh list per instance: a mutable default argument would be
        # shared by (and mutated across) every instance built without pos_oi.
        self.pos_oi = [] if pos_oi is None else pos_oi

    def __call__(self, x):
        """Add ``labels`` and ``subunit_indices`` to the ATOM3D item ``x``.

        :param x: item dict with at least the keys ``'id'`` and ``'atoms'``
                  (a dataframe of atoms)
        :return: the same dict, where ``'atoms'`` is the filtered atom
                 dataframe, ``'labels'`` is a dataframe with the columns
                 ``subunit, label, x, y, z`` (one row per extracted residue)
                 and ``'subunit_indices'`` is a list with, per extracted
                 residue, the row indices of its environment atoms
        """
        x['id'] = fi.get_pdb_code(x['id'])
        df = x['atoms']

        # Keep only atoms with coordinates, of an allowed element, that are
        # not hetero atoms; then renumber rows so labels equal positions.
        df = df.dropna(subset=['x', 'y', 'z'])
        df = df[df['element'].isin(allowed_atoms)]
        df = df[df['hetero'].str.strip() == '']
        df = df.reset_index(drop=True)

        subunits = []
        labels = []

        for chain_res, res_df in df.groupby(['chain', 'residue', 'resname']):
            # chain_res is something like ('B', 1, 'MET')
            if chain_res not in self.pos_oi:
                continue
            print()
            print(chain_res)
            chain, res, res_name = chain_res
            # only canonical residues have a label
            if res_name not in res_label_dict:
                continue
            # sample each residue based on its frequency in train data
            if self.balance and not np.random.random() < res_wt_dict[res_name]:
                continue

            # skip residues with an incomplete backbone
            atom_names = set(res_df['name'])
            if not all(b in atom_names for b in bb_atoms):
                continue

            CA_pos = res_df[res_df['name'] == 'CA'][['x', 'y', 'z']].astype(np.float32).to_numpy()[0]
            # virtual CB position: fixed offset from the CA position
            CB_pos = CA_pos + gly_CB_mu

            # remove the current residue's non-backbone atoms from the structure
            subunit_df = df[(df.chain != chain) | (df.residue != res) | df['name'].isin(bb_atoms)]

            # environment = all atoms within 10*sqrt(3) angstroms (to enable a 20A cube)
            kd_tree = scipy.spatial.KDTree(subunit_df[['x', 'y', 'z']].to_numpy())
            subunit_pt_idx = kd_tree.query_ball_point(CB_pos, r=10.0 * np.sqrt(3), p=2.0)

            # translate KD-tree positions back to row labels of the filtered dataframe
            subunits.append(subunit_df.index[sorted(subunit_pt_idx)].to_list())

            sub_name = '_'.join([str(part) for part in chain_res])
            labels.append([sub_name, res_label_dict[res_name], CB_pos[0], CB_pos[1], CB_pos[2]])

        assert len(labels) == len(subunits)
        print(len(labels))
        x['atoms'] = df
        x['labels'] = pd.DataFrame(labels, columns=['subunit', 'label', 'x', 'y', 'z'])
        x['subunit_indices'] = subunits

        return x

In [4]:
from torch.utils.data import IterableDataset

import torch, random, scipy, math
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from atom3d.datasets import LMDBDataset
#import atom3d.datasets.ppi.neighbors as nb
from torch.utils.data import IterableDataset
from gvp import GVP, GVPConvLayer, LayerNorm
import torch_cluster, torch_geometric, torch_scatter
from gvp.data import _normalize, _rbf

# to go from 3 letter amino acid code to one letter amino acid code
AA3_TO_AA1 = dict(
    CYS="C", ASP="D", SER="S", GLN="Q", LYS="K",
    ILE="I", PRO="P", THR="T", PHE="F", ASN="N",
    GLY="G", HIS="H", LEU="L", ARG="R", TRP="W",
    ALA="A", VAL="V", GLU="E", TYR="Y", MET="M",
)

# inverse table: one letter amino acid code -> three letter code
AA1_TO_AA3 = {one: three for three, one in AA3_TO_AA1.items()}

# three letter code -> class index (0..19, in the order listed here)
aa3_to_num = {
    three: number
    for number, three in enumerate((
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS',
        'GLU', 'GLN', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO',
        'SER', 'THR', 'TRP', 'TYR', 'VAL',
    ))
}

# inverse table: class index -> three letter code
num_to_aa3 = {number: three for three, number in aa3_to_num.items()}

_NUM_ATOM_TYPES = 9

# Element symbol -> atom-type index. Built once at import time instead of on
# every call; any symbol not listed here maps to the catch-all index 8.
_ELEMENT_CODES = {
    'H' : 0,
    'C' : 1,
    'N' : 2,
    'O' : 3,
    'F' : 4,
    'S' : 5,
    'Cl': 6, 'CL': 6,
    'P' : 7,
}
_OTHER_ELEMENT_CODE = 8

# Three-letter residue name -> class index. Any name not listed here maps to 20.
_AMINO_ACID_CODES = {
    'ALA': 0,
    'ARG': 1,
    'ASN': 2,
    'ASP': 3,
    'CYS': 4,
    'GLU': 5,
    'GLN': 6,
    'GLY': 7,
    'HIS': 8,
    'ILE': 9,
    'LEU': 10,
    'LYS': 11,
    'MET': 12,
    'PHE': 13,
    'PRO': 14,
    'SER': 15,
    'THR': 16,
    'TRP': 17,
    'TYR': 18,
    'VAL': 19,
}
_UNKNOWN_AMINO_ACID_CODE = 20


def _element_mapping(x):
    """Return the atom-type index for element symbol ``x`` (8 if unlisted)."""
    return _ELEMENT_CODES.get(x, _OTHER_ELEMENT_CODE)


def _amino_acids(x):
    """Return the class index for three-letter residue name ``x`` (20 if unlisted)."""
    return _AMINO_ACID_CODES.get(x, _UNKNOWN_AMINO_ACID_CODE)


_DEFAULT_V_DIM = (100, 16)
_DEFAULT_E_DIM = (32, 1)

def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'):
    '''
    Build the scalar and vector features of every edge.

    :param coords: node coordinates, indexed by the node ids in `edge_index`
    :param edge_index: edge indices; row 0 and row 1 hold the two endpoints
    :param D_max: forwarded to `_rbf` as `D_max`
    :param num_rbf: forwarded to `_rbf` as `D_count`
    :param device: forwarded to `_rbf`

    :return: tuple `(edge_s, edge_v)`: the `_rbf` encoding of the edge lengths
             and the `_normalize`d edge vectors with an extra axis inserted
             before the last one; NaNs in both are replaced by
             `torch.nan_to_num`
    '''
    first_end, second_end = edge_index[0], edge_index[1]
    displacement = coords[first_end] - coords[second_end]
    lengths = displacement.norm(dim=-1)

    edge_s = _rbf(lengths, D_max=D_max, D_count=num_rbf, device=device)
    edge_v = _normalize(displacement).unsqueeze(-2)

    return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)

class BaseTransform:
    '''
    ATOM3D transform that turns the atomic coordinates of an ATOM3D
    dataframe into a `torch_geometric.data.Data` graph. It is meant to be
    used through task-specific wrappers rather than directly. Node
    and edge features are as described in the EGNN manuscript.

    The returned graph carries these attributes:
    -x          atomic coordinates, shape [n_nodes, 3]
    -atoms      numeric encoding of atomic identity, shape [n_nodes]
    -edge_index edge indices, shape [2, n_edges]
    -edge_s     edge scalar features, shape [n_edges, 16] with the default `num_rbf`
    -edge_v     edge vector features, shape [n_edges, 1, 3]

    Task-specific code adds further attributes (for example training
    labels) on top of the ones listed above.

    :param edge_cutoff: distance cutoff to use when drawing edges
    :param num_rbf: number of radial bases to encode the distance on each edge
    :device: if "cuda", will do preprocessing on the GPU
    '''
    def __init__(self, edge_cutoff=4.5, num_rbf=16, device='cpu'):
        self.edge_cutoff, self.num_rbf, self.device = edge_cutoff, num_rbf, device

    def __call__(self, df):
        '''
        :param df: `pandas.DataFrame` of atomic coordinates
                    in the ATOM3D format

        :return: `torch_geometric.data.Data` structure graph
        '''
        with torch.no_grad():
            xyz = df[['x', 'y', 'z']].to_numpy()
            coords = torch.as_tensor(xyz, dtype=torch.float32, device=self.device)

            # one integer atom-type code per row of the dataframe
            element_codes = [_element_mapping(symbol) for symbol in df.element]
            atoms = torch.as_tensor(element_codes, dtype=torch.long, device=self.device)

            # connect every pair of atoms closer than the cutoff
            edge_index = torch_cluster.radius_graph(coords, r=self.edge_cutoff)

            edge_s, edge_v = _edge_features(
                coords,
                edge_index,
                D_max=self.edge_cutoff,
                num_rbf=self.num_rbf,
                device=self.device,
            )

            graph = torch_geometric.data.Data(
                x=coords,
                atoms=atoms,
                edge_index=edge_index,
                edge_s=edge_s,
                edge_v=edge_v,
            )
            return graph

class myRESDataset(IterableDataset):
    '''
    A `torch.utils.data.IterableDataset` wrapper around an
    ATOM3D RES-style LMDB dataset (as written with `myResTransform`).

    On each iteration, yields a tuple `(num, aa, graph)`: the residue
    number of the masked residue, its class index as given by
    `_amino_acids`, and a `torch_geometric.data.Data` graph with the
    attribute `label` encoding the masked residue identity, `ca_idx`
    for the node index of its alpha carbon, and all structural
    attributes as described in BaseTransform.

    :param lmdb_dataset: path to ATOM3D dataset
    :param split_path: optional path to a file of whitespace-separated
                       dataset indices to use; when omitted only entry 0
                       of the LMDB dataset is used
    '''
    def __init__(self, lmdb_dataset, split_path=None):
        self.dataset = LMDBDataset(lmdb_dataset)
        if split_path is None:
            # default: only the first structure of the LMDB dataset
            self.idx = [0]
        else:
            with open(split_path) as split_file:
                self.idx = list(map(int, split_file.read().split()))
        self.transform = BaseTransform()

    def __iter__(self):
        positions = list(range(len(self.idx)))
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # give every DataLoader worker its own contiguous share of the entries
            per_worker = int(math.ceil(len(self.idx) / float(worker_info.num_workers)))
            iter_start = worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, len(self.idx))
            positions = positions[iter_start:iter_end]
        return self._dataset_generator(positions, shuffle=False)

    def _dataset_generator(self, indices, shuffle=False):
        if shuffle:
            random.shuffle(indices)
        with torch.no_grad():
            for idx in indices:
                print('idx', idx)
                data = self.dataset[self.idx[idx]]
                atoms = data['atoms']
                for sub in data['labels'].itertuples():
                    # subunit names have the form '<chain>_<residue number>_<resname>'
                    chain, num, aa_name = sub.subunit.split('_')
                    num, aa = int(num), _amino_acids(aa_name)
                    if aa == 20:
                        print('aais20')
                        continue
                    my_atoms = atoms.iloc[data['subunit_indices'][sub.Index]].reset_index(drop=True)
                    # Locate the alpha carbon of the target residue. The chain comes
                    # from the subunit name rather than being fixed to 'A', so the
                    # same residue number on a neighbouring chain cannot be picked.
                    is_target_ca = ((my_atoms['residue'] == num)
                                    & (my_atoms['name'] == 'CA')
                                    & (my_atoms['chain'] == chain))
                    ca_idx = np.where(is_target_ca)[0]
                    if len(ca_idx) != 1:
                        print('len(ca_idx) is not 1')
                        continue

                    graph = self.transform(my_atoms)
                    graph.label = aa
                    # index explicitly: int() on a non-0-d array is deprecated in NumPy
                    graph.ca_idx = int(ca_idx[0])
                    yield num, aa, graph

                        
def get_model(task):
    """Instantiate and return the model registered for ``task``.

    Only ``'RES'`` is registered; any other task name raises ``KeyError``.
    """
    model_classes = {'RES': gvp.atom3d.RESModel}
    model_class = model_classes[task]
    return model_class()

def forward(model, batch, device):
    """Move ``batch`` to ``device`` and return ``model(batch)``.

    :param model: callable applied to the moved batch
    :param batch: either a single object with a ``.to(device)`` method, or a
                  list/tuple (including subclasses such as namedtuples) whose
                  first two elements have one; in the latter case the model
                  receives a 2-tuple of the moved elements
    :param device: target passed to ``.to``
    :return: whatever ``model`` returns
    """
    # isinstance also accepts list/tuple subclasses, unlike an exact type check
    if isinstance(batch, (list, tuple)):
        batch = batch[0].to(device), batch[1].to(device)
    else:
        batch = batch.to(device)
    return model(batch)

# specific for AT

In [4]:
# Load dataset from directory of PDB files and filter for the positions of interest.
# Each position is a (chain, residue number, three-letter residue name) tuple.
pos_oi = [('A', 60, 'TRP'),('A', 61, 'ASP'), ('A', 64, 'LYS'), ('A', 80, 'GLU')]
dataset = da.load_dataset('./pdbs/at/', 'pdb', transform = myResTransform(balance=False, pos_oi =pos_oi))
# With balance=False myResTransform skips its random down-sampling step, so every
# listed position that is a canonical residue with a complete backbone is converted.

# Create LMDB dataset from PDB dataset, and write to file.
# NOTE: the recorded output shows this raising "RuntimeError: LMDB entry 0 ... already exists"
# when the output directory already holds a dataset.
da.make_lmdb_dataset(dataset, './pdbs/lmdb_spec/')


  0%|          | 0/4 [00:00<?, ?it/s]

0


  0%|          | 0/4 [00:00<?, ?it/s]


RuntimeError: LMDB entry 0 in ./pdbs/lmdb_spec/ already exists

## to inspect the LMDBdataset


In [None]:
lds = LMDBDataset('./pdbs/lmdb_spec/') # this file has an amino acid label. make sure this label fits with the training data labels
lds.ids()
lds.ids_to_indices(['bio_'])
lds_get=lds.get('bio_')

In [None]:
lds_get.keys()

In [None]:
lds_get['atoms']#these are the xyz of all the atoms

In [None]:
len(lds_get['subunit_indices']) #around 748
#len(lds_get['subunit_indices'][100])# around 556,587, 895, 729
for i in lds_get['subunit_indices']:
    print(len(i))

# create all amino acid lmdb datasets for antitoxin


In [4]:
# wild-type sequence in one-letter amino acid codes
wt_at = 'MANVEKMSVAVTPQQAAVMREAVEAGEYATASEIVREAVRDWLAKRELRHDDIRRLRQLWDEGKASGRPEPVDFDALRKEARQKLTEVPPNGR'
len(wt_at)

# one (chain, 1-based position, three-letter residue name) tuple per residue, all on chain 'A'
pos_oi_all_at = [
    ('A', position, AA1_TO_AA3[letter])
    for position, letter in enumerate(wt_at, start=1)
]
'''
pos_oi_all_at = [('A', 1, 'MET'),
 ('A', 2, 'ALA'),
 ('A', 3, 'ASN'),
 ('A', 4, 'VAL'),
 ('A', 5, 'GLU'),
 ]
 '''
#pos_oi_all_at

"\npos_oi_all_at = [('A', 1, 'MET'),\n ('A', 2, 'ALA'),\n ('A', 3, 'ASN'),\n ('A', 4, 'VAL'),\n ('A', 5, 'GLU'),\n ]\n "

In [5]:
# Load dataset from directory of PDB files 
# this is recursive, all pdb files in subdirectories will also be used
dataset = da.load_dataset('./pdbs/ta/', 'pdb', 
                          transform = myResTransform(balance=False, pos_oi =pos_oi_all_at)) 

# Create LMDB dataset from PDB dataset, and write to file
da.make_lmdb_dataset(dataset, './pdbs/lmdb_at_all/')

NameError: name 'da' is not defined

In [7]:
# this dataset is fine
lds = LMDBDataset('./pdbs/lmdb_at_all/') # this file has an amino acid label. make sure this label fits with the training data labels

lds.ids()
lds.ids_to_indices(['bio_'])
lds_get=lds.get('bio_')

In [8]:
lds_get['atoms']#these are the xyz of all the atoms

Unnamed: 0,ensemble,subunit,structure,model,chain,hetero,insertion_code,residue,segid,resname,altloc,occupancy,bfactor,x,y,z,element,name,fullname,serial_number
0,bio_all.pdb,0,bio_all.pdb,0,A,,,3,,ASN,,1.00,53.23,-37.250000,-65.727997,-55.965000,N,N,N,1
1,bio_all.pdb,0,bio_all.pdb,0,A,,,3,,ASN,,1.00,48.71,-36.237000,-66.579002,-56.580002,C,CA,CA,2
2,bio_all.pdb,0,bio_all.pdb,0,A,,,3,,ASN,,1.00,38.44,-34.804001,-66.112999,-56.310001,C,C,C,3
3,bio_all.pdb,0,bio_all.pdb,0,A,,,3,,ASN,,1.00,37.30,-34.316002,-66.199997,-55.183998,O,O,O,4
4,bio_all.pdb,0,bio_all.pdb,0,A,,,3,,ASN,,1.00,50.30,-36.479000,-66.672997,-58.089001,C,CB,CB,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6023,bio_all.pdb,0,bio_all.pdb,0,H,,,103,,PHE,A,0.56,22.81,-25.395000,-76.075996,-34.758999,C,CD1,CD1,6827
6024,bio_all.pdb,0,bio_all.pdb,0,H,,,103,,PHE,A,0.56,22.29,-23.186001,-76.666000,-35.421001,C,CD2,CD2,6829
6025,bio_all.pdb,0,bio_all.pdb,0,H,,,103,,PHE,A,0.56,23.10,-25.566000,-77.329002,-34.188999,C,CE1,CE1,6831
6026,bio_all.pdb,0,bio_all.pdb,0,H,,,103,,PHE,A,0.56,26.33,-23.350000,-77.918999,-34.856998,C,CE2,CE2,6833


In [9]:
print(len(lds_get['subunit_indices'])) #how many positions
for i in lds_get['subunit_indices']:
    print(len(i))

85
556
587
528
533
572
562
698
695
743
697
551
572
678
572
460
530
616
485
429
563
564
386
445
561
682
732
660
669
724
801
855
862
875
862
895
913
885
850
873
857
767
763
832
748
650
688
716
596
584
661
623
489
600
654
530
462
576
592
439
415
527
466
335
343
315
322
465
455
494
531
446
497
388
368
474
440
343
366
430
349
284
320
335
227
190


# make resdataset

going from lmdb dataset to res dataset

## make resdataset for 4 pos only

In [79]:
# making a resdataset from the lmdb dataset
ds = myRESDataset('./pdbs/lmdb_spec/')
dl = torch_geometric.data.DataLoader(ds, num_workers=4, batch_size=1)

In [80]:

#ds.dataset[0]

In [104]:
for num, aa, b in dl:
    print(num.numpy()[0], num_to_aa3[aa.numpy()[0]])
    #print(num, num_to_aa3[aa.values])

idx 0
60 TRP
61 ASP
64 LYS
80 GLU


## make resds for all at positions

In [12]:
ds_all = myRESDataset('./pdbs/lmdb_at_all/')
dl_all = torch_geometric.data.DataLoader(ds_all, num_workers=4, batch_size=1)



In [141]:
#ds_all.dataset.get('bio_') # so ds_all does have the full dataset


{'atoms':          ensemble  subunit    structure  model chain hetero insertion_code  \
 0     bio_all.pdb        0  bio_all.pdb      0     A                         
 1     bio_all.pdb        0  bio_all.pdb      0     A                         
 2     bio_all.pdb        0  bio_all.pdb      0     A                         
 3     bio_all.pdb        0  bio_all.pdb      0     A                         
 4     bio_all.pdb        0  bio_all.pdb      0     A                         
 ...           ...      ...          ...    ...   ...    ...            ...   
 6023  bio_all.pdb        0  bio_all.pdb      0     H                         
 6024  bio_all.pdb        0  bio_all.pdb      0     H                         
 6025  bio_all.pdb        0  bio_all.pdb      0     H                         
 6026  bio_all.pdb        0  bio_all.pdb      0     H                         
 6027  bio_all.pdb        0  bio_all.pdb      0     H                         
 
       residue segid resname altloc  occu

In [142]:
for b in dl_all:
    print(b)

idx 0
[tensor([3]), tensor([2]), DataBatch(x=[556, 3], edge_index=[2, 8380], atoms=[556], edge_s=[8380, 16], edge_v=[8380, 1, 3], label=[1], ca_idx=[1], batch=[556], ptr=[2])]
[tensor([4]), tensor([19]), DataBatch(x=[587, 3], edge_index=[2, 8906], atoms=[587], edge_s=[8906, 16], edge_v=[8906, 1, 3], label=[1], ca_idx=[1], batch=[587], ptr=[2])]
[tensor([5]), tensor([5]), DataBatch(x=[528, 3], edge_index=[2, 7898], atoms=[528], edge_s=[7898, 16], edge_v=[7898, 1, 3], label=[1], ca_idx=[1], batch=[528], ptr=[2])]
[tensor([6]), tensor([11]), DataBatch(x=[533, 3], edge_index=[2, 7970], atoms=[533], edge_s=[7970, 16], edge_v=[7970, 1, 3], label=[1], ca_idx=[1], batch=[533], ptr=[2])]
[tensor([7]), tensor([12]), DataBatch(x=[572, 3], edge_index=[2, 8620], atoms=[572], edge_s=[8620, 16], edge_v=[8620, 1, 3], label=[1], ca_idx=[1], batch=[572], ptr=[2])]
[tensor([8]), tensor([15]), DataBatch(x=[562, 3], edge_index=[2, 8714], atoms=[562], edge_s=[8714, 16], edge_v=[8714, 1, 3], label=[1], ca_id

# evaluate the pdb with the model

In [13]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# build the model and move it to the selected device
model = get_model('RES').to(device)

# load model
# previously used checkpoints:
#model_path = '/n/groups/marks/users/david/res/model_save/model2/RES_1639761348.356219_25.pt'
#model_path = '/n/groups/marks/users/david/res/model_save/model2/RES_1640009835.334532_30.pt'
model_path = '/n/groups/marks/users/david/res/model_save/model2/RES_1646945484.3030427_8.pt'
# Training log for this checkpoint:
#   EPOCH 8 TRAIN loss: 1.40427650
#   EPOCH 8 VAL loss: 1.47019486
#   BEST /n/groups/marks/users/david/res/model_save/model2/RES_1646945484.3030427_8.pt VAL loss: 1.47019486

# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only machine as well, matching the device chosen above.
model.load_state_dict(torch.load(model_path, map_location=device))
# switch to evaluation mode
model = model.eval()

In [52]:
dl

<torch_geometric.loader.dataloader.DataLoader at 0x768551b90d90>

In [138]:
# predicting for 4 positions

n_ave = 10
pos_to_do = 4

# One result frame per position; they are concatenated once after the loop
# instead of re-copying a growing DataFrame on every iteration.
frames = []

c = 0
# Inference only, so no autograd bookkeeping is needed.
with torch.no_grad():
    for num, aa, b in dl:
        if c >= pos_to_do:
            # stop here instead of iterating over the rest of the loader
            break
        pos = num.numpy()[0]
        aa3 = num_to_aa3[aa.numpy()[0]]
        print(pos, aa3)
        print('b', b.x)

        # one row of 20 model outputs per repeated forward pass
        x = np.zeros([n_ave, 20])
        for i in range(n_ave):
            out = forward(model, b, device)
            print(out.shape)
            m_out = out.cpu().detach().numpy().reshape(-1)  # there is some stochasticity in this output
            x[i, :] = m_out

        mean_x = x.mean(axis=0)
        std_x = x.std(axis=0)
        c += 1

        aa1 = AA3_TO_AA1[aa3]
        wt_pos = aa1 + str(pos)

        # Mutation names such as 'W60A'.
        # NOTE(review): assumes the 20 model outputs are ordered like aa3_to_num -- confirm.
        muts = [wt_pos + AA3_TO_AA1[k] for k in aa3_to_num.keys()]

        zipped = list(zip(muts, mean_x, std_x))
        frames.append(pd.DataFrame(zipped, columns=['mut', 'mean_x', 'std_x']))

df_result = pd.concat(frames, axis=0) if frames else pd.DataFrame()
df_result = df_result.reset_index()
print(df_result)

        
        

idx 0


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9e3035ba70>
Traceback (most recent call last):
  File "/n/groups/marks/software/anaconda_o2/envs/dd_torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/n/groups/marks/software/anaconda_o2/envs/dd_torch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/n/groups/marks/software/anaconda_o2/envs/dd_torch/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


60 TRP
b tensor([[-28.4430, -51.1320, -34.5120],
        [-28.5160, -50.3650, -33.5590],
        [-27.5300, -49.3960, -35.9640],
        ...,
        [-26.5470, -52.1610, -23.2800],
        [-25.8270, -51.4110, -24.4100],
        [-26.9570, -52.3740, -26.8200]])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
61 ASP
b tensor([[-28.7340, -48.4560, -35.9130],
        [-28.6680, -47.3690, -35.3320],
        [-31.0610, -48.0870, -36.4590],
        ...,
        [-24.6550, -45.5570, -28.1340],
        [-23.5220, -44.6060, -28.5110],
        [-25.8560, -45.3550, -29.0650]])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
torch.Size([1, 20])
64 LYS
b tensor([[-29.2040, -45.3800, -32.2620],
        [-29.2250, -44.46

# making dataset for all positions

In [14]:
# predicting for all at positions

n_ave = 100
max_pos_to_do = 200

# One result frame per position; they are concatenated once after the loop
# instead of re-copying a growing DataFrame on every iteration.
frames = []

with torch.no_grad():
    c = 0
    for num, aa, b in dl_all:
        if c >= max_pos_to_do:
            # stop here instead of iterating over the rest of the loader
            break
        pos = num.numpy()[0]
        aa3 = num_to_aa3[aa.numpy()[0]]

        # one row of 20 model outputs per repeated forward pass
        x = np.zeros([n_ave, 20])
        for i in range(n_ave):
            out = forward(model, b, device)
            m_out = out.cpu().detach().numpy().reshape(-1)  # there is some stochasticity in this output
            x[i, :] = m_out

        mean_x = x.mean(axis=0)
        std_x = x.std(axis=0)

        aa1 = AA3_TO_AA1[aa3]
        wt_pos = aa1 + str(pos)

        # Mutation names such as 'N3A'.
        # NOTE(review): assumes the 20 model outputs are ordered like aa3_to_num -- confirm.
        muts = [wt_pos + AA3_TO_AA1[k] for k in aa3_to_num.keys()]

        zipped = list(zip(muts, mean_x, std_x))
        frames.append(pd.DataFrame(zipped, columns=['mut', 'mean_x', 'std_x']))
        c += 1
        print(c)

df_result = pd.concat(frames, axis=0) if frames else pd.DataFrame()
df_result = df_result.reset_index()
print(df_result)
df_result.to_csv('./out/gvp_{}_m_{}_220711.csv'.format(n_ave, model_path.split('/')[-1][:-3]))

idx 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
      index   mut    mean_x         std_x
0         0   N3A -5.922711  9.081839e-07
1         1   N3R -4.185834  7.297946e-07
2         2   N3N -4.946471  9.083842e-07
3         3   N3D -6.752131  1.042912e-06
4         4   N3C -6.214829  7.957754e-07
...     ...   ...       ...           ...
1695     15  E87S -2.564033  2.234530e-07
1696     16  E87T -2.989485  3.064769e-07
1697     17  E87W -6.175044  4.774090e-07
1698     18  E87Y -5.364314  4.598204e-07
1699     19  E87V -4.475881  3.873843e-07

[1700 rows x 4 columns]


# make res predictions for gfp

In [5]:
# wild-type GFP sequence in one-letter amino acid codes
wt_gfp = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKRHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKTRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYN'
len(wt_gfp)

# one (chain, 1-based position, three-letter residue name) tuple per residue, all on chain 'A'
# (the variable keeps the name used by the following cell)
pos_oi_all_at = [
    ('A', position, AA1_TO_AA3[letter])
    for position, letter in enumerate(wt_gfp, start=1)
]

In [6]:
# Load dataset from directory of PDB files 
# this is recursive, all pdb files in subdirectories will also be used
dataset = da.load_dataset('./pdbs/gfp_sark/', 'pdb', 
                          transform = myResTransform(balance=False, pos_oi =pos_oi_all_at)) 

# Create LMDB dataset from PDB dataset, and write to file
da.make_lmdb_dataset(dataset, './pdbs/lmdb_gfp_all/')

  0%|          | 0/1 [00:00<?, ?it/s]


('A', 3, 'LYS')

('A', 4, 'GLY')

('A', 5, 'GLU')

('A', 6, 'GLU')

('A', 7, 'LEU')

('A', 8, 'PHE')

('A', 9, 'THR')

('A', 10, 'GLY')

('A', 11, 'VAL')

('A', 12, 'VAL')

('A', 13, 'PRO')

('A', 14, 'ILE')

('A', 15, 'LEU')

('A', 16, 'VAL')

('A', 17, 'GLU')

('A', 18, 'LEU')

('A', 19, 'ASP')

('A', 20, 'GLY')

('A', 21, 'ASP')

('A', 22, 'VAL')

('A', 23, 'ASN')

('A', 24, 'GLY')

('A', 25, 'HIS')

('A', 26, 'LYS')

('A', 27, 'PHE')

('A', 28, 'SER')

('A', 29, 'VAL')

('A', 30, 'SER')

('A', 31, 'GLY')

('A', 32, 'GLU')

('A', 33, 'GLY')

('A', 34, 'GLU')

('A', 35, 'GLY')

('A', 36, 'ASP')

('A', 37, 'ALA')

('A', 38, 'THR')

('A', 39, 'TYR')

('A', 40, 'GLY')

('A', 41, 'LYS')

('A', 42, 'LEU')

('A', 43, 'THR')

('A', 44, 'LEU')

('A', 45, 'LYS')

('A', 46, 'PHE')

('A', 47, 'ILE')

('A', 48, 'CYS')

('A', 49, 'THR')

('A', 50, 'THR')

('A', 51, 'GLY')

('A', 52, 'LYS')

('A', 53, 'LEU')

('A', 54, 'PRO')

('A', 55, 'VAL')

('A', 56, 'PRO')

('A', 57, 'TRP')

('A', 58, 'PRO')

100%|██████████| 1/1 [00:02<00:00,  2.19s/it]


In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# build the model and move it to the selected device
model = get_model('RES').to(device)

# load model
model_path = '/n/groups/marks/users/david/res/model_save/model2/RES_1646945484.3030427_8.pt'
# Training log for this checkpoint:
#   EPOCH 8 TRAIN loss: 1.40427650
#   EPOCH 8 VAL loss: 1.47019486
#   BEST /n/groups/marks/users/david/res/model_save/model2/RES_1646945484.3030427_8.pt VAL loss: 1.47019486

# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only machine as well, matching the device chosen above.
model.load_state_dict(torch.load(model_path, map_location=device))
# switch to evaluation mode
model = model.eval()

In [7]:
ds_all = myRESDataset('./pdbs/lmdb_gfp_all/')
dl_all = torch_geometric.data.DataLoader(ds_all, num_workers=4, batch_size=1)



In [10]:
# predicting for all GFP positions

n_ave = 100
max_pos_to_do = 300

# One result frame per position; they are concatenated once after the loop
# instead of re-copying a growing DataFrame on every iteration.
frames = []

with torch.no_grad():
    c = 0
    for num, aa, b in dl_all:
        if c >= max_pos_to_do:
            # stop here instead of iterating over the rest of the loader
            break
        pos = num.numpy()[0]
        aa3 = num_to_aa3[aa.numpy()[0]]

        # one row of 20 model outputs per repeated forward pass
        x = np.zeros([n_ave, 20])
        for i in range(n_ave):
            out = forward(model, b, device)
            m_out = out.cpu().detach().numpy().reshape(-1)  # there is some stochasticity in this output
            x[i, :] = m_out

        mean_x = x.mean(axis=0)
        std_x = x.std(axis=0)

        aa1 = AA3_TO_AA1[aa3]
        wt_pos = aa1 + str(pos)

        # Mutation names such as 'K3A'.
        # NOTE(review): assumes the 20 model outputs are ordered like aa3_to_num -- confirm.
        muts = [wt_pos + AA3_TO_AA1[k] for k in aa3_to_num.keys()]

        zipped = list(zip(muts, mean_x, std_x))
        frames.append(pd.DataFrame(zipped, columns=['mut', 'mean_x', 'std_x']))
        c += 1
        print(c)

df_result = pd.concat(frames, axis=0) if frames else pd.DataFrame()
df_result = df_result.reset_index()
print(df_result)
df_result.to_csv('./out/gvp_{}_m_{}_230519_gfp.csv'.format(n_ave, model_path.split('/')[-1][:-3]))

idx 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
      index    mut    mean_x         std_x
0         0    K3A -2.658675  5.085642e-07
1         1    K3R -3.110558  5.441953e-07
2         2    K3N -3.254103  4.832022e-07
3         3    K3D -2.6776

# misc: playing with resdataset

In [10]:

from torch.utils.data import IterableDataset

class RESDataset(IterableDataset):
    '''
    A `torch.utils.data.IterableDataset` wrapper around a
    ATOM3D RES dataset.

    On each iteration, returns a `torch_geometric.data.Data`
    graph with the attribute `label` encoding the masked residue
    identity, `ca_idx` for the node index of the alpha carbon,
    and all structural attributes as described in BaseTransform.

    Entries are visited in shuffled order. Residues whose name is not
    known to `_amino_acids`, or for which the alpha carbon cannot be
    identified uniquely, are skipped.

    :param lmdb_dataset: path to ATOM3D dataset
    :param split_path: path to the ATOM3D split file (whitespace-separated
                       dataset indices)
    '''
    def __init__(self, lmdb_dataset, split_path):
        self.dataset = LMDBDataset(lmdb_dataset)
        # close the split file as soon as it has been read
        with open(split_path) as split_file:
            self.idx = list(map(int, split_file.read().split()))
        self.transform = BaseTransform()

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            gen = self._dataset_generator(list(range(len(self.idx))),
                      shuffle=True)
        else:
            # give every DataLoader worker its own contiguous share of the entries
            per_worker = int(math.ceil(len(self.idx) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.idx))
            gen = self._dataset_generator(list(range(len(self.idx)))[iter_start:iter_end],
                      shuffle=True)
        return gen

    def _dataset_generator(self, indices, shuffle=True):
        if shuffle:
            random.shuffle(indices)
        with torch.no_grad():
            for idx in indices:
                data = self.dataset[self.idx[idx]]
                atoms = data['atoms']
                for sub in data['labels'].itertuples():
                    # the first field of the subunit name is not used here
                    _, num, aa = sub.subunit.split('_')
                    num, aa = int(num), _amino_acids(aa)
                    if aa == 20:
                        continue
                    my_atoms = atoms.iloc[data['subunit_indices'][sub.Index]].reset_index(drop=True)
                    ca_idx = np.where((my_atoms['residue'] == num) & (my_atoms['name'] == 'CA'))[0]
                    if len(ca_idx) != 1:
                        continue

                    graph = self.transform(my_atoms)
                    graph.label = aa
                    # index explicitly: int() on a non-0-d array is deprecated in NumPy
                    graph.ca_idx = int(ca_idx[0])
                    yield graph

def get_datasets(task, lba_split=30):
    """Build the train, validation and test datasets for ``task``.

    Only ``'RES'`` is supported; any other task name raises ``KeyError``.
    ``lba_split`` is accepted but not used by this function.

    :return: tuple ``(trainset, valset, testset)`` of ``gvp.atom3d.RESDataset``
    """
    known_data_paths = {
        'RES': '/n/groups/marks/users/david/res/atom3d_data/raw/RES/data/',
    }
    data_path = known_data_paths[task]

    if task == 'RES':
        split_dir = '/n/groups/marks/users/david/res/atom3d_data/split-by-cath-topology/indices/'
        build = partial(gvp.atom3d.RESDataset, data_path)
        # one dataset per split file, built in train / val / test order
        trainset, valset, testset = [
            build(split_path=split_dir + file_name)
            for file_name in ('train_indices.txt', 'val_indices.txt', 'test_indices.txt')
        ]

    return trainset, valset, testset

lba_split = 30
datasets = get_datasets('RES', lba_split=lba_split)

In [17]:
batch_size = 8 # controls the memory use of the model; as long as a batch size of 1 fits, gradient accumulation can simulate a larger batch size.
num_workers = 4
# DataLoader factory with the worker count and batch size above pre-filled.
dataloader = partial(torch_geometric.data.DataLoader, 
                    num_workers=num_workers, batch_size=batch_size)

In [18]:
trainset, valset, testset = map(dataloader, datasets)   



In [20]:
# Print the first two batches only, then stop instead of iterating over
# the rest of the test loader without doing anything.
c = 0
for b in testset:
    if c >= 2:
        break
    print(b)
    print(b.label)
    c += 1

DataBatch(x=[4455, 3], edge_index=[2, 74774], atoms=[4455], edge_s=[74774, 16], edge_v=[74774, 1, 3], label=[8], ca_idx=[8], batch=[4455], ptr=[9])
tensor([15, 11, 13, 13, 11, 17, 15,  0])
DataBatch(x=[4220, 3], edge_index=[2, 64128], atoms=[4220], edge_s=[64128, 16], edge_v=[64128, 1, 3], label=[8], ca_idx=[8], batch=[4220], ptr=[9])
tensor([ 8,  7, 16, 12,  8, 11, 14, 11])
