In [1]:
%load_ext autoreload
%autoreload 2

In [34]:
import os
import numpy as np
import pandas as pd
from utils import serialize, deserialize

In [7]:
from xml_parse import parse_xml

In [8]:
import random
np.random.seed(77)
random.seed(77)

In [9]:
sites = parse_xml('../data/flat_db_file.xml')

100%|██████████| 27128801/27128801 [02:35<00:00, 174658.89it/s]


In [10]:
data_list = []
for s in sites:
    if s.getSiteChain().getMolType() != 'protein':
        continue
    site_id = s.getSite()
    pdb_id = s.getPDBid()
    m = s.getMetal()
    sc = s.getSiteChain()
    
    chain = m.getChainid()
    metal_pdb_name = m.getAtomName()
    # metal_pdb_atom = m.getElement()
    metal_resname = m.getResidueName()
    metal_resid = m.getResidueid()
    metal_coord_num = m.getCoord()
    site_loc = s.getSiteLoc()
    name = sc.getMolName()
    interactions = []
    for lig in m.getLigands():
        if lig.getEndoExo() == 'exogenous':
            exo = True
        else:
            exo = False
        resname = lig.getResidueName()
        resid = lig.getResidueid()
        chain = lig.getChainid()
        interactions.append('_'.join((resname, resid)))
    data_list.append([site_id, pdb_id, chain, name, metal_pdb_name, metal_resname, metal_resid, metal_coord_num, site_loc, interactions])

In [11]:
metal_df = pd.DataFrame(data_list, columns=['site_id', 'pdb_id', 'chain', 'protein_name', 'metal_pdb_name', 'metal_resname', 'metal_resid', 'metal_coord_num', 'site_loc', 'interactions'])

In [12]:
metal_df = metal_df[metal_df.site_loc == 'Within a Chain']

In [13]:
metal_df['pdb_chain'] = metal_df['pdb_id'] + metal_df['chain']

In [14]:
pdb_ec = pd.read_csv('../data/pdb_chain_enzyme.csv', skiprows=1)

In [15]:
pdb_ec['pdb_chain'] = pdb_ec['PDB'] + pdb_ec['CHAIN']

In [16]:
metal_df = pd.merge(metal_df, pdb_ec[['pdb_chain', 'EC_NUMBER']], how='left')

In [17]:
metal_enzymes = metal_df[~metal_df['EC_NUMBER'].isna()]
metal_enzymes = metal_enzymes[metal_enzymes['metal_pdb_name'] != '']

In [18]:
metal_enzymes.metal_pdb_name.value_counts()[:8]

metal_pdb_name
ZN    21408
MG    15817
CA    12081
FE     9338
NA     8447
MN     3717
K      2843
NI     1851
Name: count, dtype: int64

In [19]:
metal_enzymes.groupby('EC_NUMBER')['pdb_chain'].sample(1)

32217     3l4oB
11017     3ip1C
141025    5vn1B
14340     1uzlA
134043    5kiaA
          ...  
11864     2onjB
50098     3g5uA
50084     3g5uA
33773     2cbzA
86758     4fi3C
Name: pdb_chain, Length: 2116, dtype: object

In [20]:
pdb_resol = {}
with open('../data/resolu.idx') as f:
    for line in f:
        l = [x.strip() for x in line.strip().split(';')]
        if (len(l) != 2) or (l[1] == ''):
            continue
        pdb_resol[l[0]] = float(l[1])

In [21]:
metal_enzymes['resolution'] = metal_enzymes['pdb_id'].apply(lambda x: pdb_resol.get(x.upper()))

In [22]:
ions = ['ZN', 'MG', 'CA', 'FE', 'NA', 'MN', 'K']

In [23]:
database = {k:{} for k in ions}
for el in ions:
    subdf = metal_enzymes[metal_enzymes['metal_pdb_name'] == el]
    pos_pdb = subdf.sort_values('resolution').groupby('EC_NUMBER')['pdb_chain'].nth(0)
    pos_df = subdf[subdf['pdb_chain'].isin(pos_pdb)]
    database[el]['pos'] = pos_df
    
    neg_df = metal_enzymes[(metal_enzymes['metal_pdb_name'] != el) & (~metal_enzymes['pdb_chain'].isin(pos_pdb))]
    neg_pdb = neg_df.sort_values('resolution').groupby('EC_NUMBER')['pdb_chain'].nth(0)
    neg_pdb = neg_pdb.sample(len(pos_pdb))
    assert len(set(neg_pdb) & set(pos_pdb)) == 0
    neg_df = neg_df[neg_df['pdb_chain'].isin(neg_pdb)]
    # neg_df = neg_df.sample(len(pos_df))
    database[el]['neg'] = neg_df
    
    print(el, len(pos_df), len(neg_df))

ZN 1230 1257
MG 1177 1628
CA 1039 1248
FE 357 507
NA 1086 1390
MN 378 579
K 398 633


In [24]:
serialize(database, '../data/metal_database_balanced.pkl')

### Create embeddings database

In [41]:
from sklearn.model_selection import train_test_split
from atom3d.datasets import load_dataset
from collapse import process_pdb, embed_protein, initialize_model, atom_info
from torch_geometric.nn import knn_graph
import torch

In [25]:
database = deserialize('../data/metal_database_balanced.pkl')

In [26]:
db = database['K']
db_pos = db['pos']
n_pos = len(db_pos)
num_pos_ec = db_pos['EC_NUMBER'].unique()
train_ec, test_ec = train_test_split(num_pos_ec, test_size=0.2, random_state=77)
train_df = db_pos[db_pos['EC_NUMBER'].isin(train_ec)]
test_df = db_pos[db_pos['EC_NUMBER'].isin(test_ec)]
train_keyres = dict(train_df.groupby('pdb_chain')['interactions'].apply(lambda x: [i for i in sum(x, []) if i.split('_')[0] in aa]))
test_keyres = dict(test_df.groupby('pdb_chain')['interactions'].apply(lambda x: [i for i in sum(x, []) if i.split('_')[0] in aa]))

In [42]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = initialize_model(device=device)

def compute_adjacency(df, resids, num_neighbors=8):
    df = df[df.name == 'CA']
    df['resid'] = df['resname'].apply(lambda x: atom_info.aa_to_letter(x)) + df['residue'].astype(str) + df['insertion_code'].astype(str).str.strip()
    df = df[df.resid.isin(resids)]
    edge_index = knn_graph(torch.tensor(df[['x', 'y', 'z']].to_numpy()), k=num_neighbors) # COO format
    return edge_index.numpy()

def create_graph_from_pdb(pdb_chain, pdb_dir, model, device):
    pdb, chain = pdb_chain[:4], pdb_chain[-1]
    atom_df = process_pdb(os.path.join(pdb_dir, pdb[1:3], f'pdb{pdb}.ent.gz'), chain=chain, include_hets=False)
    outdata = embed_protein(atom_df.copy(), model, device=device, include_hets=False, env_radius=10.0)
    if outdata is None:
        return
    outdata['adj'] = compute_adjacency(atom_df.copy(), outdata['resids'])
    return elem

In [43]:
pdb_dir = '/scratch/users/aderry/pdb'
for pdbc, res in train_keyres.items():
    emb_data = create_graph_from_pdb(pdbc, pdb_dir, model, device)
    print(emb_data)
    break

NameError: name 'self' is not defined