In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import pandas as pd
from utils import serialize, deserialize
from collapse import atom_info

In [14]:
from xml_parse import parse_xml

In [15]:
import random
np.random.seed(77)
random.seed(77)

In [9]:
sites = parse_xml('../data/flat_db_file.xml')

100%|██████████| 27128801/27128801 [02:35<00:00, 174009.93it/s]


In [23]:
data_list = []
for s in sites:
    if s.getSiteChain().getMolType() != 'protein':
        continue
    site_id = s.getSite()
    pdb_id = s.getPDBid()
    m = s.getMetal()
    sc = s.getSiteChain()
    
    chain = m.getChainid()
    metal_pdb_name = m.getAtomName()
    # metal_pdb_atom = m.getElement()
    metal_resname = m.getResidueName()
    metal_resid = m.getResidueid()
    metal_coord_num = m.getCoord()
    site_loc = s.getSiteLoc()
    name = sc.getMolName()
    interactions = []
    
    for lig in m.getLigands():
        if lig.getEndoExo() == 'exogenous':
            exo = True
        else:
            exo = False
        resname = lig.getResidueName()
        resid = lig.getResidueid()
        m_chain = lig.getChainid()
        if resname not in atom_info.aa:
            continue
        if m_chain != chain:
            continue
        interactions.append('_'.join((resname, resid)))
    data_list.append([site_id, pdb_id, chain, name, metal_pdb_name, metal_resname, metal_resid, metal_coord_num, site_loc, interactions])

In [24]:
metal_df = pd.DataFrame(data_list, columns=['site_id', 'pdb_id', 'chain', 'protein_name', 'metal_pdb_name', 'metal_resname', 'metal_resid', 'metal_coord_num', 'site_loc', 'interactions'])

In [25]:
metal_df[(metal_df['pdb_id'] == '1jb0') & (metal_df['metal_pdb_name'] == 'CA')]

Unnamed: 0,site_id,pdb_id,chain,protein_name,metal_pdb_name,metal_resname,metal_resid,metal_coord_num,site_loc,interactions
14334,1jb0_21,1jb0,L,Photosystem I reaction center subunit XI,CA,CA,1001,6,Within a Chain,"[ASP_70, PRO_67]"


In [26]:
metal_df = metal_df[metal_df.site_loc == 'Within a Chain']

In [30]:
import collections as col
col.Counter(metal_df.interactions.apply(len))

Counter({1: 36776,
         0: 24548,
         2: 22540,
         3: 25540,
         4: 28531,
         5: 7331,
         6: 1396,
         7: 16,
         8: 4})

In [31]:
metal_df = metal_df[metal_df.interactions.apply(len) > 1]

In [32]:
metal_df['pdb_chain'] = metal_df['pdb_id'] + metal_df['chain']

In [33]:
pdb_ec = pd.read_csv('../data/pdb_chain_enzyme.csv', skiprows=1)

In [34]:
pdb_ec['pdb_chain'] = pdb_ec['PDB'] + pdb_ec['CHAIN']

In [35]:
metal_df = pd.merge(metal_df, pdb_ec[['pdb_chain', 'EC_NUMBER']], how='left')

In [36]:
metal_enzymes = metal_df[~metal_df['EC_NUMBER'].isna()]
metal_enzymes = metal_enzymes[metal_enzymes['metal_pdb_name'] != '']

In [37]:
metal_enzymes.metal_pdb_name.value_counts()[:8]

metal_pdb_name
ZN    19678
CA    10283
MG     9930
NA     6179
FE     4046
MN     3176
K      2425
CU     1587
Name: count, dtype: int64

In [38]:
pdb_resol = {}
with open('../data/resolu.idx') as f:
    for line in f:
        l = [x.strip() for x in line.strip().split(';')]
        if (len(l) != 2) or (l[1] == ''):
            continue
        pdb_resol[l[0]] = float(l[1])

In [39]:
metal_enzymes['resolution'] = metal_enzymes['pdb_id'].apply(lambda x: pdb_resol.get(x.upper()))

In [41]:
ions = ['ZN', 'MG', 'CA', 'FE', 'NA', 'MN', 'K']

In [42]:
database = {k:{} for k in ions}
for el in ions:
    subdf = metal_enzymes[metal_enzymes['metal_pdb_name'] == el]
    pos_pdb = subdf.sort_values('resolution').groupby('EC_NUMBER')['pdb_chain'].nth(0)
    pos_df = subdf[subdf['pdb_chain'].isin(pos_pdb)]
    database[el]['pos'] = pos_df
    
    neg_df = metal_enzymes[(metal_enzymes['metal_pdb_name'] != el) & (~metal_enzymes['pdb_chain'].isin(pos_pdb))]
    neg_pdb = neg_df.sort_values('resolution').groupby('EC_NUMBER')['pdb_chain'].nth(0)
    neg_pdb = neg_pdb.sample(len(pos_pdb))
    assert len(set(neg_pdb) & set(pos_pdb)) == 0
    neg_df = neg_df[neg_df['pdb_chain'].isin(neg_pdb)]
    # neg_df = neg_df.sample(len(pos_df))
    database[el]['neg'] = neg_df
    
    print(el, len(pos_df), len(neg_df))

ZN 964 1014
MG 813 1054
CA 799 894
FE 258 313
NA 741 993
MN 305 401
K 315 414


In [43]:
serialize(database, '../data/metal_database_balanced.pkl')

### Create embeddings database

In [5]:
from sklearn.model_selection import train_test_split
from atom3d.datasets import load_dataset
from collapse import process_pdb, embed_protein, initialize_model, atom_info
from torch_geometric.nn import knn_graph
import torch

In [6]:
database = deserialize('../data/metal_database_balanced.pkl')

In [13]:
db = database['ZN']
db_pos = db['pos']
db_pos[db_pos['pdb_id']=='2c5w']

Unnamed: 0,site_id,pdb_id,chain,protein_name,metal_pdb_name,metal_resname,metal_resid,metal_coord_num,site_loc,interactions,pdb_chain,EC_NUMBER,resolution
18628,2c5w_3,2c5w,B,Penicillin-binding protein 1A,ZN,ZN,702,3,Within a Chain,"[GLU_435, HIS_395]",2c5wB,2.4.1.129,2.55
18629,2c5w_3,2c5w,B,Penicillin-binding protein 1A,ZN,ZN,702,3,Within a Chain,"[GLU_435, HIS_395]",2c5wB,3.4.16.4,2.55


In [12]:
db_pos

Unnamed: 0,site_id,pdb_id,chain,protein_name,metal_pdb_name,metal_resname,metal_resid,metal_coord_num,site_loc,interactions,pdb_chain,EC_NUMBER,resolution
77,1a2p_3,1a2p,C,Ribonuclease,ZN,ZN,112,4,Within a Chain,"[LYS_62, GLU_60]",1a2pC,3.1.27.-,1.50
364,1ak0_2,1ak0,A,Nuclease P1,ZN,ZN,274,4,Within a Chain,"[GLU_181, HIS_15]",1ak0A,3.1.30.1,1.80
540,1atl_3,1atl,B,Zinc metalloproteinase atrolysin-D,ZN,ZN,402,5,Within a Chain,"[HIS_146, HIS_142, HIS_152]",1atlB,3.4.24.42,1.80
620,1aye_1,1aye,A,Carboxypeptidase A2,ZN,ZN,400,5,Within a Chain,"[GLU_72, HIS_196, HIS_69]",1ayeA,3.4.17.15,1.80
1046,1bqb_2,1bqb,A,Zinc metalloproteinase aureolysin,ZN,ZN,350,6,Within a Chain,"[HIS_148, GLU_168, HIS_144]",1bqbA,3.4.24.29,1.72
...,...,...,...,...,...,...,...,...,...,...,...,...,...
92751,5o6c_2,5o6c,A,E3 ubiquitin-protein ligase MYCBP2,ZN,ZN,4704,4,Within a Chain,"[CYS_4631, HIS_4552, CYS_4634, CYS_4549]",5o6cA,2.3.2.33,1.75
92752,5o6c_3,5o6c,A,E3 ubiquitin-protein ligase MYCBP2,ZN,ZN,4705,4,Within a Chain,"[CYS_4579, CYS_4564, CYS_4561, CYS_4582]",5o6cA,2.3.2.33,1.75
92753,5o6c_4,5o6c,A,E3 ubiquitin-protein ligase MYCBP2,ZN,ZN,4702,4,Within a Chain,"[CYS_4408, CYS_4437, CYS_4440, HIS_4410]",5o6cA,2.3.2.33,1.75
92754,5o6c_5,5o6c,A,E3 ubiquitin-protein ligase MYCBP2,ZN,ZN,4703,4,Within a Chain,"[CYS_4506, CYS_4540, CYS_4537, CYS_4509]",5o6cA,2.3.2.33,1.75


In [52]:
for p in list(db['pos']['pdb_id'].unique()) + list(db['neg']['pdb_id'].unique()):
    if not os.path.exists('../data/pdb/' + p + '.cif'):
        os.system('wget https://files.rcsb.org/download/' + p + '.cif -O ../data/pdb/' + p + '.cif')
    break

--2023-09-17 19:03:56--  https://files.rcsb.org/download/1atl.cif
Resolving files.rcsb.org (files.rcsb.org)... 132.249.210.234
Connecting to files.rcsb.org (files.rcsb.org)|132.249.210.234|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/octet-stream]
Saving to: ‘../data/pdb/1atl.cif’

     0K .......... .......... .......... .......... .......... 2.05M
    50K .......... .......... .......... .......... .......... 4.02M
   100K .......... .......... .......... .......... ..........  128M
   150K .......... .......... .......... .......... ..........  119M
   200K .......... .......... .......... .......... .......... 4.29M
   250K .......... .......... .......... .......... ..........  121M
   300K .......... .......... .......... .......... ..........  102M
   350K .......... .                                            139M=0.05s

2023-09-17 19:03:56 (7.17 MB/s) - ‘../data/pdb/1atl.cif’ saved [369667]



In [75]:
import atom3d.util.formats as fo
from atom3d.filters.filters import first_model_filter
def process_pdb(pdb_file, chain=None, include_hets=True):
    atoms = fo.bp_to_df(fo.read_mmcif(pdb_file, name=None, auth_residues=False))
    atoms = first_model_filter(atoms)
    if chain:
        atoms = atoms[atoms.chain == chain]
    atoms = atoms[atoms.resname != 'HOH']
    atoms = atoms[atoms.element != 'H'].reset_index(drop=True)
    if not include_hets:
        atoms = atoms[atoms.resname.isin(atom_info.aa)].reset_index(drop=True)
    return atoms

In [76]:
process_pdb('../data/pdb/2kpn.cif', chain='A', include_hets=False).head()#[['resname', 'residue']].iloc[:20]

TypeError: 'type' object cannot be interpreted as an integer

In [26]:
db = database['K']
db_pos = db['pos']
n_pos = len(db_pos)
num_pos_ec = db_pos['EC_NUMBER'].unique()
train_ec, test_ec = train_test_split(num_pos_ec, test_size=0.2, random_state=77)
train_df = db_pos[db_pos['EC_NUMBER'].isin(train_ec)]
test_df = db_pos[db_pos['EC_NUMBER'].isin(test_ec)]
train_keyres = dict(train_df.groupby('pdb_chain')['interactions'].apply(lambda x: [i for i in sum(x, []) if i.split('_')[0] in aa]))
test_keyres = dict(test_df.groupby('pdb_chain')['interactions'].apply(lambda x: [i for i in sum(x, []) if i.split('_')[0] in aa]))

In [42]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = initialize_model(device=device)

def compute_adjacency(df, resids, num_neighbors=8):
    df = df[df.name == 'CA']
    df['resid'] = df['resname'].apply(lambda x: atom_info.aa_to_letter(x)) + df['residue'].astype(str) + df['insertion_code'].astype(str).str.strip()
    df = df[df.resid.isin(resids)]
    edge_index = knn_graph(torch.tensor(df[['x', 'y', 'z']].to_numpy()), k=num_neighbors) # COO format
    return edge_index.numpy()

def create_graph_from_pdb(pdb_chain, pdb_dir, model, device):
    pdb, chain = pdb_chain[:4], pdb_chain[-1]
    atom_df = process_pdb(os.path.join(pdb_dir, pdb[1:3], f'pdb{pdb}.ent.gz'), chain=chain, include_hets=False)
    outdata = embed_protein(atom_df.copy(), model, device=device, include_hets=False, env_radius=10.0)
    if outdata is None:
        return
    outdata['adj'] = compute_adjacency(atom_df.copy(), outdata['resids'])
    return elem

In [43]:
pdb_dir = '/scratch/users/aderry/pdb'
for pdbc, res in train_keyres.items():
    emb_data = create_graph_from_pdb(pdbc, pdb_dir, model, device)
    print(emb_data)
    break

NameError: name 'self' is not defined