In [3]:
#!pip install rdkit
#pip install torch-geometric

import pandas as pd
from rdkit import Chem
import networkx as nx
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from rdkit import RDLogger


In [5]:
ddi_fp = r"C:\Users\sreej\Desktop\drugbank.tab"
ddi = pd.read_csv(ddi_fp, sep='\t')
ddi.head()

Unnamed: 0,ID1,ID2,Y,Map,X1,X2
0,DB04571,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
1,DB00855,DB00460,1,#Drug1 may increase the photosensitizing activ...,NCC(=O)CCC(O)=O,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
2,DB09536,DB00460,1,#Drug1 may increase the photosensitizing activ...,O=[Ti]=O,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
3,DB01600,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC(C(O)=O)C1=CC=C(S1)C(=O)C1=CC=CC=C1,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...
4,DB09000,DB00460,1,#Drug1 may increase the photosensitizing activ...,CC(CN(C)C)CN1C2=CC=CC=C2SC2=C1C=C(C=C2)C#N,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...


In [7]:
# filter incorrect smiles rows out 

RDLogger.DisableLog('rdApp.error')

def valid_smiles(smiles): 
    if not isinstance(smiles, str): 
        return False
    return Chem.MolFromSmiles(smiles) is not None

invalid_rows = ddi[~(ddi['X1'].apply(valid_smiles) & ddi['X2'].apply(valid_smiles))]
ddi_cleaned = ddi.drop(invalid_rows.index).reset_index(drop = True)

print(f"ddi size: {ddi.shape[0]}")
print(f"ddi_cleaned size: {ddi_cleaned.shape[0]}")
print(f"Rows removed: {len(ddi) - len(ddi_cleaned)}")

ddi size: 191808
ddi_cleaned size: 191798
Rows removed: 10


In [9]:
top20_labels = ddi_cleaned['Y'].value_counts().nlargest(20).index
ddi_filt = ddi_cleaned[ddi_cleaned['Y'].isin(top20_labels)].reset_index(drop = True)
print(top20_labels)
ddi_filt.head()
ddi_filt = ddi_filt.iloc[:1000].reset_index(drop=True)


Index([49, 47, 73, 75, 60, 70, 20, 16, 4, 6, 37, 9, 72, 54, 83, 58, 32, 27, 67,
       64],
      dtype='int64', name='Y')


In [11]:
# convert smiles string to graph
def smiles_to_graph(smiles): 
    mol = Chem.MolFromSmiles(smiles)

    if mol is None: 
        raise ValueError(f"invalid SMILES string {smiles}")

    node_features = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    edges = []
    if mol.GetNumBonds() > 0 : 
        for bond in mol.GetBonds(): 
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edges.append((i, j))
            edges.append((j, i))
    else: 
        print(f"No bonds found for molecule: {smiles}")

    edge_index = torch.tensor(edges, dtype = torch.long).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)

    x = torch.tensor(node_features, dtype = torch.float).view(-1, 1)

    return Data(x=x, edge_index=edge_index)

def convert_to_graphs(ddi_filt): 
    graph_data = []
    for _, row in ddi_filt.iterrows(): 
        graph_X1 = smiles_to_graph(row['X1'])
        graph_X2 = smiles_to_graph(row['X2'])

        graph_data.append((graph_X1, graph_X2, row['Y']))

    return graph_data

# convert codes to graphs
graph_data = convert_to_graphs(ddi_filt)

In [13]:
from torch.utils.data import Dataset

class GraphDataset(Dataset): 
    def __init__(self, graph_data): 
        self.graph_data = graph_data

    def __len__(self): 
        return len(self.graph_data)

    def __getitem__(self, idx): 
        graph_X1, graph_X2, label = self.graph_data[idx]
        return graph_X1, graph_X2, label

graph_dataset = GraphDataset(graph_data)

In [15]:
from sklearn.model_selection import train_test_split

# add stratify parameter
train_data, test_data = train_test_split(graph_dataset.graph_data, test_size = 0.2)

train_dataset = GraphDataset(train_data)
test_dataset = GraphDataset(test_data)

In [17]:
# create GNN 
import torch.nn as nn
import torch.nn.functional as F


class GNNModel(nn.Module): 
    def __init__(self, in_channels, hidden_channels, out_channels): 
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

        # concatenating X1 and X2
        self.fc = nn.Linear(2 * hidden_channels, out_channels)

    def forward(self, data1, data2): 
        x1, edge_index1 = data1.x, data1.edge_index
        x2, edge_index2 = data2.x, data2.edge_index

        # apply graph convolution on graph 1 (X1)
        x1 = F.relu(self.conv1(x1, edge_index1))
        x1 = self.conv2(x1, edge_index1)
    
        # apply graph convolution on graph 2 (X2)
        x2 = F.relu(self.conv1(x2, edge_index2))
        x2 = self.conv2(x2, edge_index2)
    
        x = torch.cat([x1.mean(dim=0), x2.mean(dim=0)], dim = -1)
        # output layer predicts one of 20 classes
        out = self.fc(x)
        return out
    

In [35]:
from torch.optim import Adam

# define model, loss func, and optimizer
model = GNNModel(in_channels=1, hidden_channels=32, out_channels=20)
optimizer = Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()  

#train_loader = DataLoader(train_dataset, batch_size = 32, shuffle = True)
# training (10 epochs)
for epoch in range(10):  
    model.train()
    total_loss = 0
    
    #for batch_idx, (data1, data2, label) in enumerate(train_loader): 
    for data1, data2, label in train_dataset:
        label = torch.tensor(label, dtype = torch.long)
        label = label -1
        label = label % 20
        
        optimizer.zero_grad()

        output = model(data1, data2)

        # Debugging print statements
        print(f"Epoch {epoch+1}, Output: {output}")
        print(f"Epoch {epoch+1}, Label: {label}")
        
        loss = criterion(output, label)
        
        loss.backward()
        
        optimizer.step()

        total_loss += loss.item()
        
        #print(f'Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item()}')


    # Print the loss for the current epoch
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_dataset)}')


Epoch 1, Output: tensor([ 0.2668, -1.0109,  1.0230, -0.6005, -0.0393,  1.6127,  0.1700,  0.4327,
        -1.0175, -1.2545, -0.0463,  0.2700, -0.6872, -0.0638,  0.4931, -0.8514,
         0.7464, -0.1852, -0.4803, -0.6304], grad_fn=<ViewBackward0>)
Epoch 1, Label: 3
Epoch 1, Output: tensor([-0.2105, -1.5564,  0.3371,  0.4525, -0.4177,  0.9704, -0.4650, -0.2036,
        -1.3308, -1.8020, -0.6730, -0.4678, -1.2543, -0.6040, -0.0770, -1.4122,
         0.1030, -0.8808, -1.1752, -1.0872], grad_fn=<ViewBackward0>)
Epoch 1, Label: 3
Epoch 1, Output: tensor([-0.6417, -2.0572, -0.2767,  1.4472, -0.8375,  0.4015, -1.0717, -0.8551,
        -1.6401, -2.3696, -1.3366, -1.1647, -1.7851, -1.1175, -0.6224, -1.9477,
        -0.5011, -1.5293, -1.8145, -1.5624], grad_fn=<ViewBackward0>)
Epoch 1, Label: 3
Epoch 1, Output: tensor([-1.1315, -2.6324, -0.9001,  2.4368, -1.3772, -0.1607, -1.7072, -1.5702,
        -2.0755, -3.0246, -2.0627, -1.8948, -2.4147, -1.6674, -1.2135, -2.5546,
        -1.1303, -2.1911, -2

In [33]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data1, data2, label in test_dataset:
        output = model(data1, data2)

        print(f"Output: {output}")
        
        predicted = output.argmax(dim = -1)

        print(f"Predicted: {predicted}")
        print(f"Actual Label: {label}")
        
        correct += (predicted == label).sum().item()
        if isinstance(label, torch.Tensor): 
            total += label.size(0)
        else: 
            total += 1

print(f'Accuracy: {correct/total * 100}%')


Output: tensor([-14.6799, -13.5582, -12.6944,  19.7873, -16.0750, -15.6747, -14.8757,
        -15.5374, -15.6314, -12.8781, -15.4957, -16.5518, -15.8624, -12.5998,
        -15.7222, -14.6583, -14.6204, -13.2666, -11.3226, -14.8237])
Predicted: 3
Label: 4
Output: tensor([-14.8547, -13.7534, -12.8926,  20.0076, -16.3284, -15.9250, -15.0816,
        -15.6896, -15.9443, -13.1132, -15.6923, -16.8543, -16.0125, -12.8159,
        -15.9416, -14.9455, -14.8237, -13.4248, -11.5460, -15.0477])
Predicted: 3
Label: 4
Output: tensor([-16.2684, -15.1958, -14.3033,  21.8502, -18.1240, -17.6795, -16.6266,
        -17.0578, -17.9486, -14.6768, -17.2313, -18.8579, -17.3853, -14.2893,
        -17.5813, -16.7964, -16.3499, -14.7100, -12.9895, -16.6637])
Predicted: 3
Label: 4
Output: tensor([-14.2602, -13.1468, -12.2997,  19.2328, -15.5737, -15.1876, -14.4320,
        -15.1141, -15.1019, -12.4558, -15.0452, -16.0122, -15.4349, -12.1967,
        -15.2521, -14.1678, -14.1820, -12.8843, -10.9392, -14.3684])
Pr