In [4]:
#!pip install tqdm
import pandas as pd
import json
from rdkit import Chem
from rdkit.Chem import rdchem
from tqdm import tqdm
import torch
import numpy as np
from ast import literal_eval
# --- Input File paths ---
classification_csv = "../1_data/processed/TRPM8_graph_classification.csv"
regression_csv = "../1_data/processed/TRPM8_graph_regression.csv"
charges_json = "../2_feature_extraction/mulliken_charges.json"  # extracted from ORCA program
# --- Load partial charges of each atom in each molecule ---
# Format: { "CHEMBLxxxxxx": { "mulliken_charges": [...] }, ... }
# The encoding is given explicitly so the file is decoded the same way on every
# platform instead of depending on the locale's default text encoding.
with open(charges_json, "r", encoding="utf-8") as charges_file:
    charges_data = json.load(charges_file)
# --- Bond mapping ---
# Position of each supported bond-type name inside the one-hot bond-type vector.
BOND_TYPE_MAPPING = {
    bond_name: slot
    for slot, bond_name in enumerate(("SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"))
}
# --- Atom features with partial charges ---
def extract_atom_features(mol, charges):
    """Return one feature row per atom of ``mol``.

    Row layout: atomic number, partial charge, in-ring flag, not-in-ring flag,
    aromatic flag, not-aromatic flag, formal charge, degree, followed by a
    three-slot hybridization one-hot.

    Args:
        mol: Object whose ``GetAtoms()`` yields the atoms to describe.
        charges: Sequence of partial charges indexed by atom position; atoms
            beyond its end get a charge of 0.0.

    Returns:
        A list of feature lists, in atom iteration order.
    """
    # Keys are integer hybridization codes; presumably RDKit's SP / SP2 / SP3
    # enum values -- TODO confirm against the installed RDKit version.
    one_hot_by_code = {2: [1, 0, 0], 3: [0, 1, 0], 4: [0, 0, 1]}
    no_match = [0, 0, 0]  # any other hybridization code leaves all slots at zero
    available = len(charges)

    def describe(position, atom):
        charge = charges[position] if position < available else 0.0
        hybrid_one_hot = one_hot_by_code.get(int(atom.GetHybridization()), no_match)
        ring_flag = int(atom.IsInRing())
        aromatic_flag = int(atom.GetIsAromatic())
        return [
            atom.GetAtomicNum(),
            charge,
            ring_flag,
            1 - ring_flag,
            aromatic_flag,
            1 - aromatic_flag,
            atom.GetFormalCharge(),
            atom.GetDegree(),
            *hybrid_one_hot,
        ]

    return [describe(position, atom) for position, atom in enumerate(mol.GetAtoms())]
# --- Edge features and indices ---
def extract_bond_features(mol):
    """Return per-edge feature rows and directed index pairs for ``mol``.

    Every bond contributes two consecutive entries, (begin, end) and
    (end, begin), both carrying the same feature row: a four-slot bond-type
    one-hot followed by flag/complement pairs for aromatic, conjugated and
    in-ring.

    Args:
        mol: Object whose ``GetBonds()`` yields the bonds to describe.

    Returns:
        ``(bond_features, edge_indices)``: two parallel lists with two entries
        per bond.
    """
    n_type_slots = 4  # SINGLE, DOUBLE, TRIPLE, AROMATIC
    feature_rows, directed_pairs = [], []

    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

        # Bond types missing from BOND_TYPE_MAPPING keep an all-zero one-hot.
        type_one_hot = [0] * n_type_slots
        slot = BOND_TYPE_MAPPING.get(str(bond.GetBondType()).upper(), -1)
        if slot >= 0:
            type_one_hot[slot] = 1

        # Each boolean property is emitted as [flag, 1 - flag].
        flag_pairs = []
        for raw_flag in (bond.GetIsAromatic(), bond.GetIsConjugated(), bond.IsInRing()):
            flag = int(raw_flag)
            flag_pairs.extend((flag, 1 - flag))

        row = type_one_hot + flag_pairs
        feature_rows.extend((row, row))  # forward edge, then reverse edge
        directed_pairs.extend(((begin, end), (end, begin)))

    return feature_rows, directed_pairs
# --- Graph feature extraction ---
def process_dataset(df, label_column, output_path):
    """Extract graph features for every molecule in ``df`` and write them to CSV.

    Each row's SMILES is parsed with ``Chem.MolFromSmiles``; rows that fail to
    parse are reported and skipped. Partial charges are looked up by ChEMBL ID
    in the module-level ``charges_data`` table and padded with zeros when fewer
    charges than atoms are available. The padding is applied to a copy, so
    ``charges_data`` is left untouched and repeated runs report the same
    shortfalls.

    Args:
        df: DataFrame with "Molecule ChEMBL ID", "Smiles" and ``label_column``
            columns.
        label_column: Name of the label column; its value is stored in the
            output under the same column name.
        output_path: Path of the CSV file to write.

    Returns:
        None. The result is written to ``output_path`` and a confirmation line
        is printed.
    """
    records = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mol_id = row["Molecule ChEMBL ID"]
        smiles = row["Smiles"]
        label = row[label_column]
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f":x: Failed to parse: {mol_id}")
            continue
        # Work on a copy: padding the stored list in place would alter
        # charges_data and hide the warning on any later run for this molecule.
        # NOTE(review): assumes the stored charges follow RDKit atom order --
        # TODO confirm; entries beyond the atom count are ignored downstream.
        charges = list(charges_data.get(mol_id, {}).get("mulliken_charges", []))
        missing = mol.GetNumAtoms() - len(charges)
        if missing > 0:
            print(f":warning: Not enough charges for {mol_id}, defaulting to 0s")
            charges.extend([0.0] * missing)
        node_feats = extract_atom_features(mol, charges)
        edge_feats, edge_inds = extract_bond_features(mol)
        records.append({
            "mol_id": mol_id,
            "smiles": smiles,
            label_column: label,
            "node_features": node_feats,
            "edge_features": edge_feats,
            "edge_indices": edge_inds
        })
    output_df = pd.DataFrame(records)
    output_df.to_csv(output_path, index=False)
    print(f":white_check_mark: Saved: {output_path}")
# --- Generate outputs for both regression and classification tasks ---
# Load both task tables, then run the same extraction for each one:
# classification first, regression second.
df_class = pd.read_csv(classification_csv)
df_reg = pd.read_csv(regression_csv)
for task_frame, task_label, task_output in (
    (df_class, "class_label", "../2_feature_extraction/TRPM8_graph_features_class.csv"),
    (df_reg, "pChEMBL Value", "../2_feature_extraction/TRPM8_graph_features_regression.csv"),
):
    process_dataset(task_frame, label_column=task_label, output_path=task_output)

100%|██████████| 529/529 [00:00<00:00, 998.47it/s] 


:white_check_mark: Saved: ../2_feature_extraction/TRPM8_graph_features_class.csv


100%|██████████| 529/529 [00:00<00:00, 1935.22it/s]


:white_check_mark: Saved: ../2_feature_extraction/TRPM8_graph_features_regression.csv


Inspect One Molecule's Features from CSV

In [None]:
import pandas as pd
from ast import literal_eval

# 📁 Path to your graph feature CSV
csv_path = "../2_feature_extraction/TRPM8_graph_features_class.csv"  # or use regression version
df = pd.read_csv(csv_path)

# Molecule to inspect; change to any Molecule ChEMBL ID present in the CSV.
mol_id_to_check = "CHEMBL3687406"

# Rows whose mol_id matches the requested molecule.
matches = df[df["mol_id"] == mol_id_to_check]

if matches.empty:
    print(f"❌ Molecule {mol_id_to_check} not found.")
else:
    row = matches.iloc[0]

    # Choose the label by column presence rather than truthiness, so that a
    # legitimate class label of 0 is shown instead of being treated as missing.
    label_column = next(
        (name for name in ("class_label", "pChEMBL Value") if name in row.index),
        None,
    )
    label = row[label_column] if label_column is not None else None

    print(f"\n✅ Molecule: {row['mol_id']}")
    print(f"🔬 SMILES: {row['smiles']}")
    print(f"🏷️ Label: {label}")

    # The list-valued columns are stored as text in the CSV; parse them back.
    node_feats = literal_eval(row["node_features"])
    edge_feats = literal_eval(row["edge_features"])
    edge_inds = literal_eval(row["edge_indices"])

    print(f"\n🧠 Node Features (Total {len(node_feats)} atoms):")
    for atom_pos, atom_row in enumerate(node_feats):
        print(f"  Atom {atom_pos}: {atom_row}")

    # The count is of directed edges: the extraction step stores every bond
    # twice, once per direction.
    print(f"\n🔗 Edge Features (Total {len(edge_feats)} directed edges):")
    for pair, edge_row in zip(edge_inds, edge_feats):
        print(f"  Edge {pair}: {edge_row}")

    print("\n📌 Sample Summary:")
    print(f"- Atoms: {len(node_feats)}")
    print(f"- Edges: {len(edge_feats)}")



✅ Molecule: CHEMBL3687406
🔬 SMILES: O=C([O-])c1ccc(S(=O)(=O)N(Cc2ccc(C(F)(F)C3CC3)c(F)c2)c2ncc3ccccc3c2C(F)(F)F)cc1
🏷️ Label: 2

🧠 Node Features (Total 41 atoms):
  Atom 0: [8, -0.547867, 0, 1, 0, 1, 0, 1, 0, 1, 0]
  Atom 1: [6, 0.413805, 0, 1, 0, 1, 0, 3, 0, 1, 0]
  Atom 2: [8, -0.655835, 0, 1, 0, 1, -1, 1, 0, 1, 0]
  Atom 3: [6, 0.106281, 1, 0, 1, 0, 0, 3, 0, 1, 0]
  Atom 4: [6, -0.157795, 1, 0, 1, 0, 0, 2, 0, 1, 0]
  Atom 5: [6, -0.140982, 1, 0, 1, 0, 0, 2, 0, 1, 0]
  Atom 6: [6, -0.158261, 1, 0, 1, 0, 0, 3, 0, 1, 0]
  Atom 7: [16, 1.188058, 0, 1, 0, 1, 0, 4, 0, 0, 1]
  Atom 8: [8, -0.571514, 0, 1, 0, 1, 0, 1, 0, 1, 0]
  Atom 9: [8, -0.545595, 0, 1, 0, 1, 0, 1, 0, 1, 0]
  Atom 10: [7, -0.591051, 0, 1, 0, 1, 0, 3, 0, 1, 0]
  Atom 11: [6, -0.255981, 0, 1, 0, 1, 0, 2, 0, 0, 1]
  Atom 12: [6, 0.199253, 1, 0, 1, 0, 0, 3, 0, 1, 0]
  Atom 13: [6, -0.18892, 1, 0, 1, 0, 0, 2, 0, 1, 0]
  Atom 14: [6, -0.153793, 1, 0, 1, 0, 0, 2, 0, 1, 0]
  Atom 15: [6, -0.065583, 1, 0, 1, 0, 0, 3, 0, 1, 0]
 