Extract node and edge features (atomic number, bond type, partial charges) and save them as CSV files. The bond type is one-hot encoded.

In [8]:
!pip install tqdm
import pandas as pd
import json
from rdkit import Chem
from rdkit.Chem import rdchem
from tqdm import tqdm

# --- Input File paths ---
classification_csv = "../1_data/processed/TRPM8_graph_classification.csv"
regression_csv = "../1_data/processed/TRPM8_graph_regression.csv"
charges_json = "../2_feature_extraction/mulliken_charges.json"  # extracted from ORCA program

# --- Load partial charges of each atom in each molecule---
# JSON is UTF-8 by specification; pin the encoding so decoding does not depend
# on the platform's locale default.
with open(charges_json, "r", encoding="utf-8") as f:
    charges_data = json.load(f)  # Format: { "CHEMBLxxxxxx": { "mulliken_charges": [...] }, ... }

# --- Bond mapping ---
# Maps an upper-case bond-type name to its slot in the 4-wide one-hot vector
# built by extract_bond_features; the tuple order defines the slot order.
BOND_TYPE_MAPPING = {
    name: slot
    for slot, name in enumerate(("SINGLE", "DOUBLE", "TRIPLE", "AROMATIC"))
}

# --- Atom features with partial charges ---
def extract_atom_features(mol, charges):
    """Return one ``[atomic number, partial charge]`` pair per atom of ``mol``.

    Atoms are visited in the order yielded by ``mol.GetAtoms()``. The atom at
    position ``k`` is paired with ``charges[k]``; when ``charges`` has no entry
    at that position the charge falls back to ``0.0``. Surplus entries in
    ``charges`` are ignored.

    Args:
        mol: Molecule object exposing ``GetAtoms()``, whose items expose
            ``GetAtomicNum()``.
        charges: Sequence of per-atom partial charges, indexed by position.

    Returns:
        A list of two-element lists, one per atom.
    """
    return [
        [atom.GetAtomicNum(), charges[pos] if pos < len(charges) else 0.0]
        for pos, atom in enumerate(mol.GetAtoms())
    ]

# --- Edge features and indices ---
def extract_bond_features(mol):
    """Return per-edge feature rows and directed edge index pairs for ``mol``.

    Every bond contributes two directed edges, (begin, end) followed by
    (end, begin), and both carry the same 7-element feature row:

    * 4 one-hot slots for the bond type, positioned via ``BOND_TYPE_MAPPING``
      (all zeros when the bond-type name is not in the mapping);
    * 3 integer flags: is-aromatic, is-conjugated, is-in-ring.

    Args:
        mol: Molecule object exposing ``GetBonds()``.

    Returns:
        A tuple ``(bond_features, edge_indices)`` of equal length, two entries
        per bond, in the order the bonds are yielded.
    """
    feature_rows = []
    directed_pairs = []

    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()

        # One-hot encode the bond type; unknown names leave every slot at 0.
        type_name = str(bond.GetBondType()).upper()
        type_one_hot = [0, 0, 0, 0]
        slot = BOND_TYPE_MAPPING.get(type_name, -1)
        if slot >= 0:
            type_one_hot[slot] = 1

        row = type_one_hot + [
            int(bond.GetIsAromatic()),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]

        # The same row is stored for the forward and the reverse edge.
        feature_rows.extend((row, row))
        directed_pairs.extend(((begin, end), (end, begin)))

    return feature_rows, directed_pairs


# --- Graph feature extraction ---
def process_dataset(df, label_column, output_path):
    """Build graph features for every molecule in ``df`` and write them to CSV.

    For each row the SMILES string is parsed with RDKit, node features are
    built by ``extract_atom_features`` and edge features / edge indices by
    ``extract_bond_features``. Rows whose SMILES cannot be parsed are reported
    and skipped. The collected records are written to ``output_path`` with
    ``DataFrame.to_csv`` (list-valued columns are therefore stored as their
    string representation).

    Args:
        df: DataFrame with the columns ``"Molecule ChEMBL ID"``, ``"Smiles"``
            and ``label_column``.
        label_column: Name of the column holding the target value; it is also
            used as the column name in the output file.
        output_path: Destination path of the CSV file.

    Returns:
        None. The result is written to ``output_path``.
    """
    records = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mol_id = row["Molecule ChEMBL ID"]
        smiles = row["Smiles"]
        label = row[label_column]

        # An empty SMILES cell is read by pandas as NaN (a float), which RDKit
        # cannot parse; report it like any other unparsable entry instead of
        # crashing the whole run.
        mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
        if mol is None:
            print(f"❌ Failed to parse: {mol_id}")
            continue

        n_atoms = mol.GetNumAtoms()

        # Work on a copy: padding the list stored in ``charges_data`` in place
        # would alter the loaded data and suppress this warning the next time
        # the same molecule is processed (e.g. in the other dataset).
        # NOTE(review): charges are matched to atoms purely by position —
        # confirm the ORCA atom order/count matches the RDKit molecule.
        charges = list(charges_data.get(mol_id, {}).get("mulliken_charges", []))
        if len(charges) < n_atoms:
            print(f"⚠️ Not enough charges for {mol_id}, defaulting to 0s")
            charges.extend([0.0] * (n_atoms - len(charges)))

        node_feats = extract_atom_features(mol, charges)
        edge_feats, edge_inds = extract_bond_features(mol)

        records.append({
            "mol_id": mol_id,
            "smiles": smiles,
            label_column: label,
            "node_features": node_feats,
            "edge_features": edge_feats,
            "edge_indices": edge_inds
        })

    output_df = pd.DataFrame(records)
    output_df.to_csv(output_path, index=False)
    print(f"✅ Saved: {output_path}")

# --- Generate outputs for both regression and classification tasks ---
# Classification task: label taken from the "class_label" column.
class_features_csv = "../2_feature_extraction/TRPM8_graph_features_class.csv"
df_class = pd.read_csv(classification_csv)
process_dataset(df_class, "class_label", class_features_csv)

# Regression task: label taken from the "pChEMBL Value" column.
regression_features_csv = "../2_feature_extraction/TRPM8_graph_features_regression.csv"
df_reg = pd.read_csv(regression_csv)
process_dataset(df_reg, "pChEMBL Value", regression_features_csv)

# Last expression of the cell: shows the first rows of the regression input.
df_reg.head()




100%|██████████| 529/529 [00:00<00:00, 2449.27it/s]


✅ Saved: ../2_feature_extraction/TRPM8_graph_features_class.csv


100%|██████████| 529/529 [00:00<00:00, 2155.53it/s]


✅ Saved: ../2_feature_extraction/TRPM8_graph_features_regression.csv


Unnamed: 0,Molecule ChEMBL ID,Smiles,pChEMBL Value
0,CHEMBL3235962,N#Cc1cccc(NC(=O)N2CCc3ccccc3[C@H]2c2ccc(C(F)(F...,7.08
1,CHEMBL3235983,C[C@H](NC(=O)N1CCc2ccccc2[C@H]1c1ccc(C(F)(F)F)...,8.0
2,CHEMBL1650511,FC(F)(F)c1ccccc1-c1cc(C(F)(F)F)c2[nH]c(C3=NOC4...,9.38
3,CHEMBL2443068,O=C1CC2(CCN(C(=O)Nc3ccc(C(F)(F)F)cc3)CC2)Oc2c(...,6.64
4,CHEMBL3959823,Cc1cccc(CN(C(=O)c2ccccc2)[C@@H](C(N)=O)c2ccccc...,6.06
