In [1]:
import pandas as pd
from rdkit import Chem
import networkx as nx
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.error')

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score
import numpy as np
from sklearn.metrics import recall_score
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F

import torch.nn as nn
import torch.nn.functional as F

In [2]:
ddi_fp = 'drugbank.tab'
ddi = pd.read_csv(ddi_fp, sep='\t')
ddi.head()

Unnamed: 0,ID1,ID2,Y,Map,X1,X2
0,DB04571,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
1,DB00855,DB00460,1,#Drug1 may increase the photosensitizing activ...,NCC(=O)CCC(O)=O,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
2,DB09536,DB00460,1,#Drug1 may increase the photosensitizing activ...,O=[Ti]=O,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
3,DB01600,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC(C(O)=O)C1=CC=C(S1)C(=O)C1=CC=CC=C1,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
4,DB09000,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC(CN(C)C)CN1C2=CC=CC=C2SC2=C1C=C(C=C2)C#N,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...


In [3]:
print(ddi['Y'].dtype)
print(ddi['Y'].unique())

int64
[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86]


In [4]:
# filter incorrect smiles rows out 

def valid_smiles(smiles): 
    if not isinstance(smiles, str): 
        return False
    return Chem.MolFromSmiles(smiles) is not None

invalid_rows = ddi[~(ddi['X1'].apply(valid_smiles) & ddi['X2'].apply(valid_smiles))]
ddi_cleaned = ddi.drop(invalid_rows.index).reset_index(drop = True)

print(f"ddi size: {ddi.shape[0]}")
print(f"ddi_cleaned size: {ddi_cleaned.shape[0]}")
print(f"Rows removed: {len(ddi) - len(ddi_cleaned)}")

ddi size: 191808
ddi_cleaned size: 191798
Rows removed: 10


In [5]:
top5_labels = ddi_cleaned['Y'].value_counts().nlargest(5).index
ddi_filt = ddi_cleaned[ddi_cleaned['Y'].isin(top5_labels)].reset_index(drop = True)

label_mapping = {label: idx for idx, label in enumerate(top5_labels)}
print("Label Mapping:", label_mapping)
ddi_filt['Y'] = ddi_filt['Y'].map(label_mapping)

n = 10000  # Number of samples per label
#ddi_filt = ddi_filt.groupby('Y', group_keys=False).apply(lambda x: x.sample(n, replace=True)).reset_index(drop=True)
ddi_filt = ddi_filt.groupby('Y', group_keys=False).apply(lambda x: x.sample(min(len(x), n), replace=False)).reset_index(drop=True)

print(ddi_filt['Y'].value_counts())  # Check how many samples per label


Label Mapping: {49: 0, 47: 1, 73: 2, 75: 3, 60: 4}
Y
0    10000
1    10000
2    10000
3     9470
4     8397
Name: count, dtype: int64


  ddi_filt = ddi_filt.groupby('Y', group_keys=False).apply(lambda x: x.sample(min(len(x), n), replace=False)).reset_index(drop=True)


In [6]:
# convert smiles string to graph
def smiles_to_graph(smiles): 
    mol = Chem.MolFromSmiles(smiles)

    if mol is None: 
        raise ValueError(f"invalid SMILES string {smiles}")

    node_features = []
    for atom in mol.GetAtoms(): 
        atomic_num = atom.GetAtomicNum()
        degree = atom.GetDegree()
        hybridization = atom.GetHybridization()
        is_aromatic = atom.GetIsAromatic()
        
        features = [atomic_num, degree, int(hybridization), int(is_aromatic)]
        node_features.append(features)
        
    edges = []
    if mol.GetNumBonds() > 0 : 
        for bond in mol.GetBonds(): 
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edges.append((i, j))
            edges.append((j, i))

    edge_index = torch.tensor(edges, dtype = torch.long).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)

    x = torch.tensor(node_features, dtype = torch.float)

    return Data(x=x, edge_index=edge_index)

def convert_to_graphs(ddi_filt): 
    graph_data = []
    for _, row in ddi_filt.iterrows(): 
        graph_X1 = smiles_to_graph(row['X1'])
        graph_X2 = smiles_to_graph(row['X2'])

        graph_data.append((graph_X1, graph_X2, row['Y']))

    return graph_data

# convert codes to graphs
graph_data = convert_to_graphs(ddi_filt)

In [7]:
print(graph_data[0])

(Data(x=[19, 4], edge_index=[2, 40]), Data(x=[26, 4], edge_index=[2, 56]), 0)


In [8]:
from torch.utils.data import Dataset
from sklearn.metrics import classification_report

class DDI_GraphDataset(Dataset): 
    def __init__(self, graph_data): 
        self.graph_data = graph_data

    def __len__(self): 
        return len(self.graph_data)

    def __getitem__(self, idx): 
        graph_X1, graph_X2, label = self.graph_data[idx]
        return graph_X1, graph_X2, torch.tensor(label, dtype = torch.long)
    
dataset = DDI_GraphDataset(graph_data)

In [9]:
dataset[0]

(Data(x=[19, 4], edge_index=[2, 40]),
 Data(x=[26, 4], edge_index=[2, 56]),
 tensor(0))

In [10]:
from torch_geometric.data import Batch

def collate_fn(batch):
    X1_batch, X2_batch, labels = zip(*batch)
    
    X1_batch = Batch.from_data_list(X1_batch)
    X2_batch = Batch.from_data_list(X2_batch)
    
    labels = torch.stack(labels)
    
    return X1_batch, X2_batch, labels

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

for graph_X1_batch, graph_X2_batch, labels in loader:
    # Forward pass through the GNN
    z1 = gnn(graph_X1_batch)
    z2 = gnn(graph_X2_batch)
    
    # Combine embeddings
    z_combined = torch.cat([z1, z2, torch.abs(z1 - z2), z1 * z2], dim=1)
    
    # Forward pass through the classifier
    output = classifier(z_combined)
    
    # Calculate loss and backpropagate
    loss = criterion(output, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [11]:
from torch_geometric.nn import global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        
        x = self.conv2(x, edge_index)
        x = torch.relu(x)
        
        x = global_mean_pool(x, data.batch)
        
        x = self.fc(x)
        return x


In [14]:
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.model_selection import train_test_split

dataset = DDI_GraphDataset(graph_data)

labels = [label for _, _, label in dataset]
train_indices, test_indices = train_test_split(
np.arange(len(dataset)), 
test_size = 0.2, 
stratify = labels, 
random_state = 42
)

train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset, test_indices)

train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True, collate_fn = collate_fn)
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle = False, collate_fn = collate_fn)

#loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

In [15]:
import torch.optim as optim

gnn = GNN(in_channels = 4, hidden_dim = 8, out_dim = 5)
classifier = nn.Linear(20, 5)  # Adjust the input size based on your GNN output
optimizer = optim.Adam(list(gnn.parameters()) + list(classifier.parameters()), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [16]:
from sklearn.metrics import classification_report

# training loop 
num_epochs = 16  

for epoch in range(num_epochs):
    gnn.train() 
    
    true_labels = []
    predicted_labels = []
    
    for graph_X1_batch, graph_X2_batch, labels in train_loader:
        # forward pass through the GNN
        z1 = gnn(graph_X1_batch)
        z2 = gnn(graph_X2_batch)

        # combine the graph embeddings
        z_combined = torch.cat([z1, z2, torch.abs(z1 - z2), z1 * z2], dim=1)

        # forward pass through the classifier
        output = classifier(z_combined)

        # calculate loss and backpropagate
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(output, 1)

        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    
print(classification_report(true_labels, predicted_labels))

Epoch 1, Loss: 1.562505841255188
Epoch 2, Loss: 1.580553650856018
Epoch 3, Loss: 1.5074833631515503
Epoch 4, Loss: 1.5265101194381714
Epoch 5, Loss: 1.5521981716156006
Epoch 6, Loss: 1.5246347188949585
Epoch 7, Loss: 1.7109894752502441
Epoch 8, Loss: 1.512468695640564
Epoch 9, Loss: 1.4684982299804688
Epoch 10, Loss: 1.5067403316497803
Epoch 11, Loss: 1.5810078382492065
Epoch 12, Loss: 1.4927995204925537
Epoch 13, Loss: 1.4266656637191772
Epoch 14, Loss: 1.3921960592269897
Epoch 15, Loss: 1.427808165550232
Epoch 16, Loss: 1.613053798675537
              precision    recall  f1-score   support

           0       0.33      0.13      0.18      8000
           1       0.30      0.47      0.37      8000
           2       0.34      0.27      0.30      8000
           3       0.41      0.41      0.41      7576
           4       0.36      0.47      0.41      6717

    accuracy                           0.35     38293
   macro avg       0.35      0.35      0.33     38293
weighted avg       0

In [18]:
from torch.utils.data import DataLoader

num_epochs = 16  

# test loop 
for epoch in range(num_epochs):
    gnn.eval() 
    
    true_labels = []
    predicted_labels = []
    
    for graph_X1_batch, graph_X2_batch, labels in test_loader:
        # Forward pass through the GNN
        z1 = gnn(graph_X1_batch)
        z2 = gnn(graph_X2_batch)

        # Combine the graph embeddings
        z_combined = torch.cat([z1, z2, torch.abs(z1 - z2), z1 * z2], dim=1)

        # Forward pass through the classifier
        output = classifier(z_combined)

        # Calculate loss and backpropagate
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        _, predicted = torch.max(output, 1)

        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
    
print(classification_report(true_labels, predicted_labels))

Epoch 1, Loss: 1.8476295471191406
Epoch 2, Loss: 1.8380569219589233
Epoch 3, Loss: 1.8323599100112915
Epoch 4, Loss: 1.8289803266525269
Epoch 5, Loss: 1.8243392705917358
Epoch 6, Loss: 1.820809006690979
Epoch 7, Loss: 1.8155337572097778
Epoch 8, Loss: 1.8137401342391968
Epoch 9, Loss: 1.812421441078186
Epoch 10, Loss: 1.810024380683899
Epoch 11, Loss: 1.808811068534851
Epoch 12, Loss: 1.8067926168441772
Epoch 13, Loss: 1.8044925928115845
Epoch 14, Loss: 1.8039122819900513
Epoch 15, Loss: 1.8017951250076294
Epoch 16, Loss: 1.800451397895813
              precision    recall  f1-score   support

           0       0.30      0.14      0.19      2000
           1       0.31      0.47      0.38      2000
           2       0.33      0.27      0.30      2000
           3       0.43      0.44      0.44      1894
           4       0.38      0.48      0.42      1680

    accuracy                           0.35      9574
   macro avg       0.35      0.36      0.34      9574
weighted avg       0