In [16]:
!pip install tqdm
import pandas as pd
import json
from rdkit import Chem
from rdkit.Chem import rdchem
from tqdm import tqdm
import torch
import numpy as np
from ast import literal_eval
# --- Input File paths ---
classification_csv = "../1_data/processed/TRPM8_graph_classification.csv"
regression_csv = "../1_data/processed/TRPM8_graph_regression.csv"
charges_json = "../2_feature_extraction/mulliken_charges.json"  # extracted from ORCA program
# --- Load partial charges of each atom in each molecule---
with open(charges_json, "r") as f:
    charges_data = json.load(f)  # Format: { "CHEMBLxxxxxx": { "mulliken_charges": [...] }, ... }
# --- Bond mapping ---
BOND_TYPE_MAPPING = {
    "SINGLE": 0,
    "DOUBLE": 1,
    "TRIPLE": 2,
    "AROMATIC": 3
}
# --- Atom features with partial charges ---
def extract_atom_features(mol, charges):
    features = []
    for i, atom in enumerate(mol.GetAtoms()):
        partial_charge = charges[i] if i < len(charges) else 0.0
        features.append([
            atom.GetAtomicNum(),
            partial_charge,
            int(atom.IsInRing()),
            1-int(atom.IsInRing()),
            int(atom.GetIsAromatic()),
            1-int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
            atom.GetDegree()
        ])
    return features
# --- Edge features and indices ---
def extract_bond_features(mol):
    bond_features = []
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_type = str(bond.GetBondType()).upper()
        one_hot_type = [0] * 4  # For SINGLE, DOUBLE, TRIPLE, AROMATIC
        idx = BOND_TYPE_MAPPING.get(bond_type, -1)
        if idx >= 0:
            one_hot_type[idx] = 1
        other_flags = [
            int(bond.GetIsAromatic()),
            1-int(bond.GetIsAromatic()),
            int(bond.GetIsConjugated()),
            1-int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
            1-int(bond.IsInRing())
        ]
        feat = one_hot_type + other_flags
        bond_features.append(feat)
        bond_features.append(feat)  # for reverse edge
        edge_indices.append((i, j))
        edge_indices.append((j, i))
    return bond_features, edge_indices
# --- Graph feature extraction ---
def process_dataset(df, label_column, output_path):
    records = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mol_id = row["Molecule ChEMBL ID"]
        smiles = row["Smiles"]
        label = row[label_column]
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f":x: Failed to parse: {mol_id}")
            continue
        charges = charges_data.get(mol_id, {}).get("mulliken_charges", [])
        if len(charges) < mol.GetNumAtoms():
            print(f":warning: Not enough charges for {mol_id}, defaulting to 0s")
            charges += [0.0] * (mol.GetNumAtoms() - len(charges))
        node_feats = extract_atom_features(mol, charges)
        edge_feats, edge_inds = extract_bond_features(mol)
        records.append({
            "mol_id": mol_id,
            "smiles": smiles,
            label_column: label,
            "node_features": node_feats,
            "edge_features": edge_feats,
            "edge_indices": edge_inds
        })
    output_df = pd.DataFrame(records)
    output_df.to_csv(output_path, index=False)
    print(f":white_check_mark: Saved: {output_path}")
# --- Generate outputs for both regression and classification tasks ---
df_class = pd.read_csv(classification_csv)
process_dataset(df_class, label_column="class_label",
                output_path="../2_feature_extraction/TRPM8_graph_features_class.csv")
df_reg = pd.read_csv(regression_csv)
process_dataset(df_reg, label_column="pChEMBL Value",
                output_path="../2_feature_extraction/TRPM8_graph_features_regression.csv")
df_reg.head()



100%|█████████████████████████████████████████████████████████████████████████████████████████████| 529/529 [00:00<00:00, 1419.65it/s]


:white_check_mark: Saved: ../2_feature_extraction/TRPM8_graph_features_class.csv


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 529/529 [00:00<00:00, 1198.37it/s]

:white_check_mark: Saved: ../2_feature_extraction/TRPM8_graph_features_regression.csv





Unnamed: 0,Molecule ChEMBL ID,Smiles,pChEMBL Value
0,CHEMBL3235962,N#Cc1cccc(NC(=O)N2CCc3ccccc3[C@H]2c2ccc(C(F)(F...,7.08
1,CHEMBL3235983,C[C@H](NC(=O)N1CCc2ccccc2[C@H]1c1ccc(C(F)(F)F)...,8.0
2,CHEMBL1650511,FC(F)(F)c1ccccc1-c1cc(C(F)(F)F)c2[nH]c(C3=NOC4...,9.38
3,CHEMBL2443068,O=C1CC2(CCN(C(=O)Nc3ccc(C(F)(F)F)cc3)CC2)Oc2c(...,6.64
4,CHEMBL3959823,Cc1cccc(CN(C(=O)c2ccccc2)[C@@H](C(N)=O)c2ccccc...,6.06


In [43]:
hybrid_dict = {2: [1,0,0], 3: [0,1,0], 4: [0,0,1]}
def get_hybridization(smiles): # function to get the hybridization of each atom and return a series
    mol = Chem.MolFromSmiles(smiles) # create molecule from smiles string
    if mol:
        hybrid_list = []
        for atom in mol.GetAtoms(): # iterate over each atom in the molecule
            hybridization = atom.GetHybridization()
            hybrid_list.append(hybrid_dict[hybridization])
        return pd.DataFrame(hybrid_list)
    else:
        return pd.Series(None)

In [44]:
df = pd.read_csv("TRPM8_graph_features_regression.csv")

In [45]:
df['Hybridization'] = df['smiles'].apply(get_hybridization)

In [46]:
# torch.tensor(np.array(df["Hybridization"][0]), dtype = torch.float)
# i = 0
# node_features = torch.tensor(literal_eval(df["node_features"][i]), dtype = torch.float)
# node_feat_hybridization = torch.tensor(np.array(df["Hybridization"][i]))
# # node_features2 = torch.cat([torch.tensor(np.array(df[feature][i])).unsqueeze(1) for feature in Features], dim = 1)
# node_features_final = torch.cat([node_features, node_feat_hybridization], dim = 1)
# torch.save(node_features_final, "node_features")
# # node_features_final
# edge_features_final = torch.tensor(literal_eval(df["edge_features"][i]), dtype = torch.float)
# torch.save(edge_features_final, "edge_features")

In [48]:
# df["Hybridization"][1]

Unnamed: 0,0,1,2
0,0,0,1
1,0,0,1
2,0,1,0
3,0,1,0
4,0,1,0
5,0,1,0
6,0,0,1
7,0,0,1
8,0,1,0
9,0,1,0


In [50]:
from torch_geometric.data import Data, Dataset, DataLoader
import torch
import numpy as np
from ast import literal_eval

graphs = []
for i in df.index:
    node_features = torch.tensor(literal_eval(df["node_features"][i]), dtype = torch.float)
    node_feat_hybridization = torch.tensor(np.array(df["Hybridization"][i]))
    node_features_final = torch.cat([node_features, node_feat_hybridization], dim = 1)
    edge_features_final = torch.tensor(literal_eval(df["edge_features"][i]), dtype = torch.float)

    
    node_features = node_features_final
    edge_features = edge_features_final
    edges = torch.tensor(literal_eval(df["edge_indices"][i]), dtype = torch.int64).transpose(0,1)
    target = torch.tensor([df["pChEMBL Value"][i]], dtype = torch.float)
    # print(node_features.shape)
    graphs.append(Data(x = node_features, edge_index = edges, y = target, edge_attr = edge_features))

torch.save(graphs, "data_collection")

# Create a DataLoader to batch graphs together
loader = DataLoader(graphs, batch_size=1, shuffle=True)

graphs[0]

Data(x=[31, 11], edge_index=[2, 68], edge_attr=[68, 10], y=[1])

In [52]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data, Batch
from torch_scatter import scatter

class GNN(torch.nn.Module):
    def __init__(self, node_feat_dim, edge_feat_dim, hidden_dim):
        super(GNN, self).__init__()
        self.node_mlp = torch.nn.Linear(node_feat_dim, hidden_dim)
        self.edge_mlp = torch.nn.Linear(edge_feat_dim, hidden_dim)
        self.conv1 = SAGEConv(hidden_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.out_layer = torch.nn.Linear(hidden_dim, 1)  # Outputs a single number

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.node_mlp(x)  # Transform node features
        edge_attr = self.edge_mlp(edge_attr)  # Transform edge features

        # Aggregate edge features using scatter
        edge_aggr = scatter(edge_attr, edge_index[1], dim=0, reduce="mean")
        x = self.conv1(x, edge_index) + edge_aggr
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        
        # Global mean pooling (aggregate over nodes)
        x = scatter(x, batch, dim=0, reduce="mean")  
        
        return self.out_layer(x)  # Outputs a single number per graph

node_feat_dim = graphs[0].x.shape[1]
edge_feat_dim = graphs[0].edge_attr.shape[1]
hidden_dim = 32

model = GNN(node_feat_dim, edge_feat_dim, hidden_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)

# Training loop
for epoch in range(100):
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.mse_loss(out, batch.y.float().unsqueeze(1))  # Adjust loss function based on your task
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        loss.backward()
        optimizer.step()

    # optimizer.zero_grad()
    # out = model(graphs[0])
    # loss = F.mse_loss(out, batch.y.float().unsqueeze(1))  # Adjust loss function based on your task
    # loss.backward()
    # optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

Epoch 1, Loss: 43.336273193359375
Epoch 2, Loss: 34.35185241699219
Epoch 3, Loss: 46.498023986816406
Epoch 4, Loss: 26.20098304748535
Epoch 5, Loss: 48.809288024902344
Epoch 6, Loss: 32.88529968261719
Epoch 7, Loss: 63.1019287109375
Epoch 8, Loss: 29.090251922607422
Epoch 9, Loss: 36.49740219116211
Epoch 10, Loss: 61.78485107421875
Epoch 11, Loss: 16.550643920898438
Epoch 12, Loss: 55.342559814453125
Epoch 13, Loss: 47.48004150390625
Epoch 14, Loss: 38.87635040283203
Epoch 15, Loss: 49.93194580078125
Epoch 16, Loss: 15.019915580749512
Epoch 17, Loss: 47.276161193847656
Epoch 18, Loss: 11.701358795166016
Epoch 19, Loss: 23.909486770629883
Epoch 20, Loss: 26.0580997467041
Epoch 21, Loss: 18.14934539794922
Epoch 22, Loss: 14.76583194732666
Epoch 23, Loss: 22.430469512939453
Epoch 24, Loss: 6.729621887207031
Epoch 25, Loss: 19.156023025512695
Epoch 26, Loss: 14.05169677734375
Epoch 27, Loss: 12.876519203186035
Epoch 28, Loss: 10.366278648376465
Epoch 29, Loss: 19.307571411132812
Epoch 30, 