In [1]:
from pathlib import Path
import numpy as np
import torch
from torch_geometric.data import Data
from Bio.PDB import PDBParser, is_aa

In [2]:
AA3 = [
    "ALA","ARG","ASN","ASP","CYS","GLU","GLN","GLY","HIS","ILE",
    "LEU","LYS","MET","PHE","PRO","SER","THR","TRP","TYR","VAL"
]
AA3_TO_IDX = {aa:i for i, aa in enumerate(AA3)}

In [4]:
def parse_pocket_residues(pdb_path: str):
    pdb_path = Path(pdb_path)
    if not pdb_path.exists():
        logger.error(f"pdb file not found: {pdb_path}")
        return []

    parser = PDBParser(QUIET=True)
    try:
        struct = parser.get_structure(pdb_path.stem, str(pdb_path))
    except Exception as e:
        return []

    residues = []
    for model in struct:
        for chain in model:
            chain_id = chain.get_id()
            for res in chain:
                # skip hetatms that are not amino acids (but include if flagged as amino acid)
                if not is_aa(res, standard=True):
                    continue
                resseq = res.get_id()[1]
                res_uid = f"{chain_id}_{res.get_resname().strip()}_{resseq}"
                residues.append((res_uid, res))
    print(f"Parsed {len(residues)} residues from {pdb_path.name}")
    return residues

In [5]:
pdb_path="../notebooks/data/raw/PDBbind_v2020_refined/refined-set/10gs/10gs_pocket.pdb"

res = parse_pocket_residues(pdb_path)

res[0]

Parsed 50 residues from 10gs_pocket.pdb


('A_TYR_7', <Residue TYR het=  resseq=7 icode= >)

In [6]:
res[0][1].get_resname().upper()
res[0][1].get_id()
list(res[0][1].get_atoms())

[<Atom N>,
 <Atom H>,
 <Atom CA>,
 <Atom C>,
 <Atom O>,
 <Atom CB>,
 <Atom CG>,
 <Atom CD1>,
 <Atom CD2>,
 <Atom CE1>,
 <Atom CE2>,
 <Atom CZ>,
 <Atom OH>,
 <Atom HH>]

In [7]:
def residue_to_feature(res):
    resname = res.get_resname().upper()
    onehot = np.zeros(len(AA3), dtype=np.float32)
    if resname in AA3_TO_IDX:
        onehot[AA3_TO_IDX[resname]] = 1.0
    try:
        resseq = res.get_id()[1]  
    except Exception:
        resseq = 0
        
    seq_feat = np.array([resseq], dtype=np.float32)

    b_factors = [atom.get_bfactor() for atom in res.get_atoms()]
    avg_b = np.mean(b_factors).astype(np.float32) if len(b_factors) > 0 else np.array([0.0], dtype=np.float32)
    feat = np.concatenate([onehot, seq_feat/100.0, np.array([avg_b/100.0], dtype=np.float32)])
    return feat

In [8]:
residue_to_feature(res[0][1])

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 1.        , 0.        ,
       0.07      , 0.16237858], dtype=float32)

In [9]:
residue_to_feature(res[1][1])

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 1.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.08      , 0.22460833], dtype=float32)

In [3]:
def get_ca_coord(res):
    for atom in res:
        name = atom.get_name()
        if name == "CA":
            coord = atom.get_coord()
            return np.array(coord, dtype=np.float32)
    # fallback: centroid of heavy atoms
    coords = [a.get_coord() for a in res if a.element != "H"]
    if len(coords) == 0:
        # last fallback: any atom
        coords = [a.get_coord() for a in res]
    centroid = np.mean(coords, axis=0) if len(coords) > 0 else np.zeros(3, dtype=np.float32)
    return np.array(centroid, dtype=np.float32)

In [12]:
residues = parse_pocket_residues(pdb_path)

Parsed 50 residues from 10gs_pocket.pdb


In [14]:
node_feats = []
positions = []
res_ids = []

for res_uid, res in residues:
    node_feats.append(residue_to_feature(res))
    positions.append(get_ca_coord(res))
    res_ids.append(res_uid)
    
print(node_feats[0], positions[0], res_ids[0])

[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 1.         0.         0.07       0.16237858] [15.226  2.331 33.585] A_TYR_7


In [15]:
x = torch.tensor(np.vstack(node_feats), dtype=torch.float32)
pos = torch.tensor(np.vstack(positions), dtype=torch.float32)

In [16]:
x

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0700, 0.1624],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0800, 0.2246],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0900, 0.2501],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.9700, 0.1941],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.9800, 0.1831],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.0200, 0.1807]])

In [17]:
cutoff = 8.0
coords = pos.numpy()
N = coords.shape[0]
edge_index = []
edge_attr = []

for i in range(N):
    for j in range(i + 1, N):
        d = np.linalg.norm(coords[i] - coords[j])
        print(d)
        if d <= cutoff:
            edge_index.append([i, j])
            edge_index.append([j, i])
            edge_attr.append([d])
            edge_attr.append([d])

3.8041472
6.576998
7.799639
9.016218
9.234269
8.968405
7.7516665
11.919353
5.6895065
8.404007
10.225339
12.883716
10.055275
13.621239
11.205105
11.430394
11.759535
14.067958
12.386234
9.233244
5.7374954
4.9615965
6.8144536
10.850437
11.564168
10.6579485
14.43536
15.129734
12.900584
14.899175
18.711285
17.934727
18.336866
16.631416
18.56884
17.704512
17.188105
12.890905
12.870509
12.132729
11.125574
14.87673
16.251474
16.972385
18.073217
18.04904
21.311855
18.950932
23.377779
3.8365161
5.576741
8.2524395
8.6334715
9.838013
9.389637
14.254345
4.947376
5.835318
7.3100467
10.444556
8.886307
11.967676
12.224593
12.153758
13.88403
15.784937
13.517427
10.403983
7.713347
7.940643
10.34366
13.990725
14.182832
13.16005
16.892506
18.132719
16.15987
17.751234
19.685328
17.965446
18.225838
15.550702
16.489616
15.527781
18.540583
10.307749
10.934884
9.737014
8.014354
11.626725
12.976323
19.913153
20.290442
20.075851
22.898094
20.12778
23.654882
3.8327382
7.0955806
8.956313
11.123535
10.107804
14.682

In [18]:
edge_index

[[0, 1],
 [1, 0],
 [0, 2],
 [2, 0],
 [0, 3],
 [3, 0],
 [0, 7],
 [7, 0],
 [0, 9],
 [9, 0],
 [0, 21],
 [21, 0],
 [0, 22],
 [22, 0],
 [0, 23],
 [23, 0],
 [1, 2],
 [2, 1],
 [1, 3],
 [3, 1],
 [1, 9],
 [9, 1],
 [1, 10],
 [10, 1],
 [1, 11],
 [11, 1],
 [1, 21],
 [21, 1],
 [1, 22],
 [22, 1],
 [2, 3],
 [3, 2],
 [2, 4],
 [4, 2],
 [2, 9],
 [9, 2],
 [2, 10],
 [10, 2],
 [2, 38],
 [38, 2],
 [2, 41],
 [41, 2],
 [3, 4],
 [4, 3],
 [3, 5],
 [5, 3],
 [3, 38],
 [38, 3],
 [3, 39],
 [39, 3],
 [3, 40],
 [40, 3],
 [3, 41],
 [41, 3],
 [4, 5],
 [5, 4],
 [4, 6],
 [6, 4],
 [4, 7],
 [7, 4],
 [4, 38],
 [38, 4],
 [4, 39],
 [39, 4],
 [4, 40],
 [40, 4],
 [5, 6],
 [6, 5],
 [5, 7],
 [7, 5],
 [5, 39],
 [39, 5],
 [5, 40],
 [40, 5],
 [5, 41],
 [41, 5],
 [6, 7],
 [7, 6],
 [6, 22],
 [22, 6],
 [6, 26],
 [26, 6],
 [7, 8],
 [8, 7],
 [7, 22],
 [22, 7],
 [7, 26],
 [26, 7],
 [8, 26],
 [26, 8],
 [8, 29],
 [29, 8],
 [8, 30],
 [30, 8],
 [8, 37],
 [37, 8],
 [9, 10],
 [10, 9],
 [9, 11],
 [11, 9],
 [9, 13],
 [13, 9],
 [10, 11],
 [11, 10]

In [19]:
edge_attr

[[3.8041472],
 [3.8041472],
 [6.576998],
 [6.576998],
 [7.799639],
 [7.799639],
 [7.7516665],
 [7.7516665],
 [5.6895065],
 [5.6895065],
 [5.7374954],
 [5.7374954],
 [4.9615965],
 [4.9615965],
 [6.8144536],
 [6.8144536],
 [3.8365161],
 [3.8365161],
 [5.576741],
 [5.576741],
 [4.947376],
 [4.947376],
 [5.835318],
 [5.835318],
 [7.3100467],
 [7.3100467],
 [7.713347],
 [7.713347],
 [7.940643],
 [7.940643],
 [3.8327382],
 [3.8327382],
 [7.0955806],
 [7.0955806],
 [6.9390197],
 [6.9390197],
 [6.8260283],
 [6.8260283],
 [7.628834],
 [7.628834],
 [7.998134],
 [7.998134],
 [3.8067436],
 [3.8067436],
 [5.6062617],
 [5.6062617],
 [5.1106343],
 [5.1106343],
 [5.6702123],
 [5.6702123],
 [5.8920627],
 [5.8920627],
 [6.18798],
 [6.18798],
 [3.8018625],
 [3.8018625],
 [6.3950753],
 [6.3950753],
 [5.70044],
 [5.70044],
 [6.2957726],
 [6.2957726],
 [4.902502],
 [4.902502],
 [6.6312633],
 [6.6312633],
 [3.7905297],
 [3.7905297],
 [5.5796027],
 [5.5796027],
 [5.626104],
 [5.626104],
 [5.5025525],
 [5.5025

In [29]:
torch.tensor(edge_index, dtype=torch.long).T

tensor([[ 0,  1,  0,  2,  0,  3,  0,  7,  0,  9,  0, 21,  0, 22,  0, 23,  1,  2,
          1,  3,  1,  9,  1, 10,  1, 11,  1, 21,  1, 22,  2,  3,  2,  4,  2,  9,
          2, 10,  2, 38,  2, 41,  3,  4,  3,  5,  3, 38,  3, 39,  3, 40,  3, 41,
          4,  5,  4,  6,  4,  7,  4, 38,  4, 39,  4, 40,  5,  6,  5,  7,  5, 39,
          5, 40,  5, 41,  6,  7,  6, 22,  6, 26,  7,  8,  7, 22,  7, 26,  8, 26,
          8, 29,  8, 30,  8, 37,  9, 10,  9, 11,  9, 13, 10, 11, 10, 12, 10, 13,
         11, 12, 11, 13, 11, 14, 12, 13, 12, 14, 13, 14, 13, 15, 13, 16, 15, 16,
         15, 17, 16, 17, 16, 19, 16, 20, 17, 18, 17, 19, 17, 20, 17, 21, 17, 23,
         17, 24, 18, 19, 18, 20, 18, 24, 18, 46, 19, 20, 19, 21, 20, 21, 20, 22,
         20, 23, 20, 24, 21, 22, 21, 23, 21, 24, 21, 25, 22, 23, 22, 24, 22, 25,
         22, 26, 23, 24, 23, 25, 24, 25, 24, 26, 24, 28, 24, 29, 24, 44, 25, 26,
         25, 27, 25, 28, 25, 29, 25, 44, 25, 45, 25, 46, 26, 27, 26, 28, 26, 29,
         26, 30, 27, 28, 27,

In [30]:
edge_index = torch.tensor(edge_index, dtype=torch.long).T
edge_attr = torch.tensor(edge_attr, dtype=torch.float32)

In [40]:
data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr)

In [41]:
data

Data(x=[50, 22], edge_index=[2, 302], edge_attr=[302, 1], pos=[50, 3])

In [42]:
Path(pdb_path).parent.name

'10gs'

In [43]:
data.metadata = {
    "complex_id": Path(pdb_path).parent.name,
    "pdb_path": pdb_path,
    "residue_ids": res_ids,
    "node_type": "residue",
    "residue_cutoff": cutoff
    }

In [44]:
data

Data(
  x=[50, 22],
  edge_index=[2, 302],
  edge_attr=[302, 1],
  pos=[50, 3],
  metadata={
    complex_id='10gs',
    pdb_path='../notebooks/data/raw/PDBbind_v2020_refined/refined-set/10gs/10gs_pocket.pdb',
    residue_ids=[50],
    node_type='residue',
    residue_cutoff=8.0,
  }
)