In [22]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

In [23]:
df = pd.read_csv("datasets/target_molecules_clean_2.csv")
df.head()

Unnamed: 0,canonical_smiles,pIC50,ECFP4_0,ECFP4_1,ECFP4_2,ECFP4_3,ECFP4_4,ECFP4_5,ECFP4_6,ECFP4_7,...,ECFP4_2038,ECFP4_2039,ECFP4_2040,ECFP4_2041,ECFP4_2042,ECFP4_2043,ECFP4_2044,ECFP4_2045,ECFP4_2046,ECFP4_2047
0,O=C(O)/C=C/c1ccc(OS(=O)(=O)O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,CN(CCCNC(=O)c1ccc(O)cc1)CCCNC(=O)c1ccc(O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,CCN(CCCN(CC)C(=O)c1ccc(O)cc1)C(=O)c1ccc(O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,CN(CCCNC(=O)c1ccc(O)cc1)CCCNC(=O)c1ccc2cc(O)cc...,3.531653,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,CCN(CCOC(=O)/C=C/c1ccc(O)cc1)Cc1cc(Cl)ccc1O,4.337242,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [24]:
df.dtypes

canonical_smiles     object
pIC50               float64
ECFP4_0               int64
ECFP4_1               int64
ECFP4_2               int64
                     ...   
ECFP4_2043            int64
ECFP4_2044            int64
ECFP4_2045            int64
ECFP4_2046            int64
ECFP4_2047            int64
Length: 2050, dtype: object

In [25]:
# Create a binary target column: 1 when pIC50 > 5.0, otherwise 0
df['target'] = (df['pIC50'] > 5.0).astype(int)
df.head()

Unnamed: 0,canonical_smiles,pIC50,ECFP4_0,ECFP4_1,ECFP4_2,ECFP4_3,ECFP4_4,ECFP4_5,ECFP4_6,ECFP4_7,...,ECFP4_2039,ECFP4_2040,ECFP4_2041,ECFP4_2042,ECFP4_2043,ECFP4_2044,ECFP4_2045,ECFP4_2046,ECFP4_2047,target
0,O=C(O)/C=C/c1ccc(OS(=O)(=O)O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,CN(CCCNC(=O)c1ccc(O)cc1)CCCNC(=O)c1ccc(O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,CCN(CCCN(CC)C(=O)c1ccc(O)cc1)C(=O)c1ccc(O)cc1,-0.30103,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,CN(CCCNC(=O)c1ccc(O)cc1)CCCNC(=O)c1ccc2cc(O)cc...,3.531653,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,CCN(CCOC(=O)/C=C/c1ccc(O)cc1)Cc1cc(Cl)ccc1O,4.337242,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [26]:

import torch

# Pick the compute device. The notebook historically ran on the second GPU
# (index 1); keep that preference when it exists, but fall back to the first
# GPU and then to the CPU so the notebook also runs on single-GPU and
# CPU-only machines instead of failing with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [27]:
import rdkit
from rdkit import Chem
import torch
from torch_geometric.data import Data

def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Node features (9 per atom, stored as float), in column order:
    atomic number, chiral tag, degree, formal charge, total hydrogen count,
    radical electron count, hybridization, aromatic flag, in-ring flag.

    Edge features (3 per directed edge, stored as float), in column order:
    bond type as double, stereo configuration, conjugation flag.
    Every bond is stored twice (once per direction).

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``Data`` with ``x`` of shape (num_atoms, 9), ``edge_index`` of shape
        (2, 2 * num_bonds) and ``edge_attr`` of shape (2 * num_bonds, 3),
        all moved to the notebook-level ``device``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # message that names the offending input instead of an AttributeError.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

    # Node features: one tuple per atom. RDKit enum values (chiral tag,
    # hybridization) are converted to their integer codes explicitly.
    atom_rows = [
        (
            atom.GetAtomicNum(),
            int(atom.GetChiralTag()),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetHybridization()),
            atom.GetIsAromatic(),
            atom.IsInRing(),
        )
        for atom in mol.GetAtoms()
    ]
    # Explicit reshape keeps the (num_atoms, 9) layout even for an empty list.
    node_feats = torch.tensor(atom_rows, dtype=torch.float).reshape(len(atom_rows), 9)

    # Edges: add both directions since the graph is undirected.
    edge_pairs = []
    bond_rows = []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_row = (
            bond.GetBondTypeAsDouble(),
            int(bond.GetStereo()),
            bond.GetIsConjugated(),
        )
        edge_pairs.extend([(start, end), (end, start)])
        bond_rows.extend([bond_row, bond_row])

    # Explicit reshapes guarantee (2, E) / (E, 3) even when the molecule has
    # no bonds (E == 0); a bare torch.tensor([]) would be 1-D with shape (0,).
    num_edges = len(edge_pairs)
    edge_feats = torch.tensor(bond_rows, dtype=torch.float).reshape(num_edges, 3)
    edge_indices = (
        torch.tensor(edge_pairs, dtype=torch.long).reshape(num_edges, 2).t().contiguous()
    )

    # Create a PyG Data object
    graph = Data(
        x=node_feats.to(device),
        edge_index=edge_indices.to(device),
        edge_attr=edge_feats.to(device),
    )

    return graph

graphs = [smiles_to_graph(smiles) for smiles in df['canonical_smiles']]




In [28]:
from sklearn.model_selection import train_test_split

# Materialise graphs and labels as parallel lists so they can be split together
graphs_list = list(graphs)
targets_list = df['target'].tolist()

# Stage 1: stratified 80/20 split into training data and a held-out pool
train_graphs, temp_graphs, train_targets, temp_targets = train_test_split(
    graphs_list,
    targets_list,
    test_size=0.2,
    stratify=targets_list,
    random_state=42,
)

# Stage 2: halve the held-out pool into validation and test sets (stratified)
val_graphs, test_graphs, val_targets, test_targets = train_test_split(
    temp_graphs,
    temp_targets,
    test_size=0.5,
    stratify=temp_targets,
    random_state=42,
)

# Labels become long tensors living on the same device as the graphs
train_targets = torch.tensor(train_targets, dtype=torch.long, device=device)
val_targets = torch.tensor(val_targets, dtype=torch.long, device=device)
test_targets = torch.tensor(test_targets, dtype=torch.long, device=device)

# Attach each label to its graph as `y`, split by split (train, val, test).
# The loop variables are deliberately named `graph` / `target`: a later cell
# inspects `graph` after this loop has finished.
for split_graphs, split_targets in (
    (train_graphs, train_targets),
    (val_graphs, val_targets),
    (test_graphs, test_targets),
):
    for graph, target in zip(split_graphs, split_targets):
        graph.y = target


In [29]:
# Define the AddGaussianNoise transform
class AddGaussianNoise(object):
    """Add Gaussian noise to the node features ``data.x`` of a graph.

    Args:
        mean: constant offset added to every feature value.
        std: standard deviation the sampled standard-normal noise is scaled by.
        inplace: when True (the default, matching the historical behaviour)
            the object passed in is modified and returned, so every call on
            the same graph stacks more noise on top of the previous one.
            When False the input is left untouched and a shallow copy
            carrying the noisy ``x`` is returned.
    """

    def __init__(self, mean=0, std=0.1, inplace=True):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, data):
        """Return ``data`` (or a shallow copy) with noise added to ``x``.

        Objects whose ``x`` is None are returned unchanged.
        """
        if data.x is None:
            return data

        if not self.inplace:
            import copy

            # Shallow copy is enough: `x` is rebound to a new tensor below,
            # so the caller's object keeps its original feature tensor.
            data = copy.copy(data)

        data.x = data.x + torch.randn_like(data.x) * self.std + self.mean
        return data

# Instantiate the transform
add_gaussian_noise = AddGaussianNoise(mean=0, std=0.1)

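# Apply the transform to each graph in the training set.
# NOTE: the transform assigns to graph.x on the objects it receives, so train_graphs
# itself is modified; train_graphs_transformed holds the same graph objects.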
train_graphs_transformed = [add_gaussian_noise(graph) for graph in train_graphs]

# Check the first graph's features to see if noise is added
train_graphs_transformed[0].x[:5]  # Displaying the first 5 nodes' features


tensor([[ 5.9727e+00,  1.3139e-01,  9.9599e-01,  7.5892e-02,  2.9752e+00,
          1.4274e-02,  4.0147e+00,  4.2812e-02,  8.5679e-02],
        [ 8.3230e+00,  8.1465e-02,  1.9298e+00,  6.0213e-03, -6.1783e-02,
          5.7997e-02,  3.0968e+00,  1.8527e-02,  8.0315e-02],
        [ 6.1154e+00, -5.8172e-02,  2.9913e+00,  5.0748e-03,  1.5648e-01,
         -6.2758e-02,  2.9067e+00,  8.2520e-01,  1.0647e+00],
        [ 5.9746e+00,  4.6717e-02,  1.7104e+00,  1.9432e-01,  9.9301e-01,
          1.2092e-01,  3.0078e+00,  1.0253e+00,  1.0054e+00],
        [ 5.8330e+00, -4.3671e-02,  2.0335e+00,  9.8586e-02,  9.9091e-01,
         -1.5921e-02,  2.9736e+00,  9.8785e-01,  9.5788e-01]], device='cuda:1')

In [30]:
# Define the FeatureDropout transform
class FeatureDropout(object):
    """Zero out whole node-feature columns of ``data.x`` at random.

    One Bernoulli draw is made per feature column; a selected column is set
    to zero for every node of the graph. Surviving columns are not rescaled.

    Args:
        p: probability with which each feature column is dropped.
        inplace: when True (the default, matching the historical behaviour)
            the feature tensor of the object passed in is overwritten.
            When False the input is left untouched and a shallow copy
            carrying a modified clone of ``x`` is returned.
    """

    def __init__(self, p=0.5, inplace=True):
        self.p = p
        self.inplace = inplace

    def __call__(self, data):
        """Return ``data`` (or a shallow copy) with random columns zeroed.

        Objects whose ``x`` is None are returned unchanged.
        """
        if data.x is None:
            return data

        if not self.inplace:
            import copy

            # The assignment below writes into the tensor itself, so the copy
            # needs its own clone of `x` to keep the caller's data intact.
            data = copy.copy(data)
            data.x = data.x.clone()

        # Build the column mask on the same device as the features so the
        # indexing does not mix CPU and GPU tensors.
        mask = torch.rand(data.x.shape[1], device=data.x.device) < self.p
        data.x[:, mask] = 0

        return data

# Apply the FeatureDropout transform to the training graphs
feature_dropout_transform = FeatureDropout(p=0.05)
train_graphs_transformed = [feature_dropout_transform(graph) for graph in train_graphs]

# Let's check the effect of the transformation on the first graph's node features
train_graphs_transformed[0].x


tensor([[ 5.9727e+00,  1.3139e-01,  9.9599e-01,  0.0000e+00,  2.9752e+00,
          1.4274e-02,  4.0147e+00,  4.2812e-02,  8.5679e-02],
        [ 8.3230e+00,  8.1465e-02,  1.9298e+00,  0.0000e+00, -6.1783e-02,
          5.7997e-02,  3.0968e+00,  1.8527e-02,  8.0315e-02],
        [ 6.1154e+00, -5.8172e-02,  2.9913e+00,  0.0000e+00,  1.5648e-01,
         -6.2758e-02,  2.9067e+00,  8.2520e-01,  1.0647e+00],
        [ 5.9746e+00,  4.6717e-02,  1.7104e+00,  0.0000e+00,  9.9301e-01,
          1.2092e-01,  3.0078e+00,  1.0253e+00,  1.0054e+00],
        [ 5.8330e+00, -4.3671e-02,  2.0335e+00,  0.0000e+00,  9.9091e-01,
         -1.5921e-02,  2.9736e+00,  9.8785e-01,  9.5788e-01],
        [ 5.8137e+00,  1.7382e-01,  2.9669e+00,  0.0000e+00,  4.8954e-02,
          1.0035e-01,  2.9630e+00,  1.0520e+00,  1.0472e+00],
        [ 6.1809e+00, -1.6178e-01,  3.0157e+00,  0.0000e+00, -4.5318e-02,
          9.2799e-02,  2.9223e+00,  1.2094e+00,  9.1527e-01],
        [ 5.9831e+00, -4.4775e-02,  1.9796e+00, 

In [31]:
from torch_geometric.loader import DataLoader
# Create PyTorch Geometric DataLoaders
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=32, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)


In [32]:
graph.y.shape

torch.Size([])

In [33]:
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool

class GNN(torch.nn.Module):
    """Graph classifier: two GCNConv layers, mean-pool readout, linear head.

    Args:
        input_node_features: number of features per node.
        input_edge_features: accepted but not used by this variant; the
            GCNConv layers here are called without edge attributes.
        hidden_channels: width of both convolution layers.
    """

    def __init__(self, input_node_features, input_edge_features, hidden_channels):
        super(GNN, self).__init__()
        # Seeds the global torch RNG so the layer initialisation is repeatable.
        torch.manual_seed(42)

        # Define the convolution layers
        self.conv1 = GCNConv(input_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

        # Define the final classifier (two output logits, one per class)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Return per-graph logits of shape (num_graphs, 2).

        Args:
            x: node feature matrix.
            edge_index: edge list in PyG format.
            edge_attr: ignored by this variant (kept for a uniform call signature).
            batch: node-to-graph assignment vector; when None all nodes are
                treated as belonging to a single graph.
        """
        # 1. Node embedding
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # 2. Readout layer
        if batch is None:
            # Allocate on the input's own device rather than a notebook-level
            # global, so the model also works after being moved elsewhere.
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)

        # 3. Final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

# Create an instance of the model
model = GNN(input_node_features=9, input_edge_features=3, hidden_channels=64).to(device)
model = model.to(device)


In [34]:
# Balanced class weights: total / (n_classes * class_count), computed over
# every label in targets_list
total_samples = len(targets_list)
num_pos = sum(targets_list)
num_neg = total_samples - num_pos

# Index 0 is the negative class, index 1 the positive class
weight_class0, weight_class1 = (
    total_samples / (2 * class_count) for class_count in (num_neg, num_pos)
)

weights = torch.tensor([weight_class0, weight_class1]).to(device)
weights

tensor([0.6961, 1.7750], device='cuda:1')

In [35]:
# model = GNN(num_node_features=1, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.CrossEntropyLoss(weight=weights)

def train():
    """Run one optimisation pass over ``train_loader`` with the global model.

    Returns:
        (mean of the per-batch losses, fraction of correctly classified graphs)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        batch_loss = loss_func(logits, batch.y)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level loss and accuracy
        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        n_seen += batch.y.size(0)

    return loss_sum / len(train_loader), n_correct / n_seen

def validate():
    """Return the accuracy of the global model on ``val_loader``."""
    model.eval()
    n_correct, n_seen = 0, 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
            n_seen += batch.y.size(0)

    return n_correct / n_seen



# Train for at most `epochs` epochs, stopping early once validation accuracy
# reaches 0.85. Progress is logged every 10th epoch and on the stopping epoch;
# the two conditions share one print so the stopping epoch is never logged twice.
epochs = 512
for epoch in range(epochs):
    train_loss, train_acc = train()
    val_acc = validate()

    reached_target = val_acc >= 0.85
    if (epoch + 1) % 10 == 0 or reached_target:
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f} | Validation Accuracy: {val_acc:.4f}")

    if reached_target:
        break


Epoch 10/512 | Train Loss: 0.6591 | Train Accuracy: 0.6669 | Validation Accuracy: 0.7372
Epoch 20/512 | Train Loss: 0.6365 | Train Accuracy: 0.7326 | Validation Accuracy: 0.7436
Epoch 30/512 | Train Loss: 0.6367 | Train Accuracy: 0.7382 | Validation Accuracy: 0.7628
Epoch 40/512 | Train Loss: 0.6148 | Train Accuracy: 0.7574 | Validation Accuracy: 0.7692
Epoch 50/512 | Train Loss: 0.6153 | Train Accuracy: 0.7262 | Validation Accuracy: 0.7628
Epoch 60/512 | Train Loss: 0.6108 | Train Accuracy: 0.7566 | Validation Accuracy: 0.7949
Epoch 70/512 | Train Loss: 0.6038 | Train Accuracy: 0.7454 | Validation Accuracy: 0.7500
Epoch 80/512 | Train Loss: 0.5722 | Train Accuracy: 0.7654 | Validation Accuracy: 0.7628
Epoch 90/512 | Train Loss: 0.6295 | Train Accuracy: 0.7558 | Validation Accuracy: 0.7756
Epoch 100/512 | Train Loss: 0.5754 | Train Accuracy: 0.7646 | Validation Accuracy: 0.7500
Epoch 110/512 | Train Loss: 0.5698 | Train Accuracy: 0.7598 | Validation Accuracy: 0.7500
Epoch 120/512 | Tra

In [36]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def test_metrics():
    """Evaluate the global model on ``test_loader``.

    Returns:
        (accuracy, recall, precision, f1) computed from the argmax predictions.
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y_pred += logits.argmax(dim=1).tolist()
            y_true += batch.y.tolist()

    return (
        accuracy_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        f1_score(y_true, y_pred),
    )

# At the end of training or wherever you want to print the metrics:
acc, recall, precision, f1 = test_metrics()
print(f"Test Accuracy: {acc:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | F1 Score: {f1:.4f}")


Test Accuracy: 0.7771 | Recall: 0.8636 | Precision: 0.5672 | F1 Score: 0.6847


In [37]:
from torch_geometric.nn import NNConv, global_mean_pool
# https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.NNConv.html
class EdgeNetwork(torch.nn.Module):
    """One linear layer followed by a ReLU.

    Args:
        in_channels: size of each input feature vector.
        out_channels: size of each output feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """Apply the linear map, then clamp negative values to zero."""
        return F.relu(self.lin(x))

class GNN(torch.nn.Module):
    """Graph classifier: two NNConv layers, mean-pool readout, linear head.

    Each NNConv layer receives an ``EdgeNetwork`` that maps the edge
    attributes to ``in_channels * out_channels`` values for that layer.

    Args:
        input_node_features: number of features per node.
        input_edge_features: number of features per edge.
        hidden_channels: width of both convolution layers.
    """

    def __init__(self, input_node_features, input_edge_features, hidden_channels):
        super(GNN, self).__init__()
        # Seeds the global torch RNG so the layer initialisation is repeatable.
        torch.manual_seed(42)

        # Define the edge update networks for NNConv
        edge_nn1 = EdgeNetwork(input_edge_features, input_node_features * hidden_channels)
        self.conv1 = NNConv(input_node_features, hidden_channels, edge_nn1)

        edge_nn2 = EdgeNetwork(input_edge_features, hidden_channels * hidden_channels)
        self.conv2 = NNConv(hidden_channels, hidden_channels, edge_nn2)

        # Define the final classifier (two output logits, one per class)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Return per-graph logits of shape (num_graphs, 2).

        Args:
            x: node feature matrix.
            edge_index: edge list in PyG format.
            edge_attr: edge feature matrix passed to both NNConv layers.
            batch: node-to-graph assignment vector; when None all nodes are
                treated as belonging to a single graph.
        """
        # 1. Node embedding with edge features
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))

        # 2. Readout layer
        if batch is None:
            # Allocate on the input's own device rather than a notebook-level
            # global, so the model also works after being moved elsewhere.
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)

        # 3. Final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

# Create an instance of the model
model = GNN(input_node_features=9, input_edge_features=3, hidden_channels=64).to(device)
model = model.to(device)


In [38]:
# model = GNN(num_node_features=1, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.CrossEntropyLoss(weight=weights)

def train():
    """Run one optimisation pass over ``train_loader`` with the global model.

    Returns:
        (mean of the per-batch losses, fraction of correctly classified graphs)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        batch_loss = loss_func(logits, batch.y)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level loss and accuracy
        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        n_seen += batch.y.size(0)

    return loss_sum / len(train_loader), n_correct / n_seen

def validate():
    """Return the accuracy of the global model on ``val_loader``."""
    model.eval()
    n_correct, n_seen = 0, 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
            n_seen += batch.y.size(0)

    return n_correct / n_seen



# Fixed-length training run; metrics are logged on every 10th epoch only
epochs = 512
for epoch in range(epochs):
    train_loss, train_acc = train()
    val_acc = validate()
    if (epoch + 1) % 10:
        continue
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f} | Validation Accuracy: {val_acc:.4f}")


Epoch 10/512 | Train Loss: 0.6166 | Train Accuracy: 0.7182 | Validation Accuracy: 0.7436
Epoch 20/512 | Train Loss: 0.6025 | Train Accuracy: 0.7078 | Validation Accuracy: 0.7949
Epoch 30/512 | Train Loss: 0.5839 | Train Accuracy: 0.7246 | Validation Accuracy: 0.6410
Epoch 40/512 | Train Loss: 0.5914 | Train Accuracy: 0.6910 | Validation Accuracy: 0.7051
Epoch 50/512 | Train Loss: 0.5664 | Train Accuracy: 0.7294 | Validation Accuracy: 0.7564
Epoch 60/512 | Train Loss: 0.5476 | Train Accuracy: 0.7494 | Validation Accuracy: 0.7628
Epoch 70/512 | Train Loss: 0.5430 | Train Accuracy: 0.7406 | Validation Accuracy: 0.7115
Epoch 80/512 | Train Loss: 0.5100 | Train Accuracy: 0.7646 | Validation Accuracy: 0.6346
Epoch 90/512 | Train Loss: 0.5143 | Train Accuracy: 0.7606 | Validation Accuracy: 0.7756
Epoch 100/512 | Train Loss: 0.5077 | Train Accuracy: 0.7430 | Validation Accuracy: 0.7244
Epoch 110/512 | Train Loss: 0.5055 | Train Accuracy: 0.7710 | Validation Accuracy: 0.7564
Epoch 120/512 | Tra

In [39]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def test_metrics():
    """Evaluate the global model on ``test_loader``.

    Returns:
        (accuracy, recall, precision, f1) computed from the argmax predictions.
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y_pred += logits.argmax(dim=1).tolist()
            y_true += batch.y.tolist()

    return (
        accuracy_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        f1_score(y_true, y_pred),
    )

# At the end of training or wherever you want to print the metrics:
acc, recall, precision, f1 = test_metrics()
print(f"Test Accuracy: {acc:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | F1 Score: {f1:.4f}")


Test Accuracy: 0.7197 | Recall: 0.8409 | Precision: 0.5000 | F1 Score: 0.6271


In [40]:
from torch_geometric.nn import MessagePassing, global_mean_pool

class CustomGCNConv(MessagePassing):
    """Message-passing layer that mixes node and edge embeddings.

    Node features and edge attributes are each projected to ``out_channels``
    by their own linear layer; the message on an edge is the sum of the two
    projections, and messages are combined with "add" aggregation.

    Args:
        in_channels: number of input features per node.
        out_channels: size of the node and edge embeddings.
        edge_features: number of input features per edge.
    """

    def __init__(self, in_channels, out_channels, edge_features):
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.edge_nn = torch.nn.Linear(edge_features, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Project nodes and edges, then run one round of message passing."""
        node_emb = self.lin(x)
        edge_emb = self.edge_nn(edge_attr)

        num_nodes = node_emb.size(0)
        return self.propagate(
            edge_index,
            size=(num_nodes, num_nodes),
            x=node_emb,
            edge_attr=edge_emb,
        )

    def message(self, x_j, edge_attr):
        """Message for one edge: projected node feature plus projected edge feature."""
        return x_j + edge_attr

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out

class GNN(torch.nn.Module):
    """Graph classifier: two CustomGCNConv layers, mean-pool readout, linear head.

    Args:
        input_node_features: number of features per node.
        input_edge_features: number of features per edge.
        hidden_channels: width of both convolution layers.
    """

    def __init__(self, input_node_features, input_edge_features, hidden_channels):
        super(GNN, self).__init__()
        # Seeds the global torch RNG so the layer initialisation is repeatable.
        torch.manual_seed(42)

        # Define the convolution layers
        self.conv1 = CustomGCNConv(input_node_features, hidden_channels, input_edge_features)
        self.conv2 = CustomGCNConv(hidden_channels, hidden_channels, input_edge_features)

        # Define the final classifier (two output logits, one per class)
        self.lin = Linear(hidden_channels, 2)

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Return per-graph logits of shape (num_graphs, 2).

        Args:
            x: node feature matrix.
            edge_index: edge list in PyG format.
            edge_attr: edge feature matrix passed to both convolution layers.
            batch: node-to-graph assignment vector; when None all nodes are
                treated as belonging to a single graph.
        """
        # 1. Node embedding
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = F.relu(self.conv2(x, edge_index, edge_attr))

        # 2. Readout layer
        if batch is None:
            # Allocate on the input's own device rather than a notebook-level
            # global, so the model also works after being moved elsewhere.
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)

        # 3. Final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

# Create an instance of the model
model = GNN(input_node_features=9, input_edge_features=3, hidden_channels=64).to(device)


In [41]:
# model = GNN(num_node_features=1, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.CrossEntropyLoss(weight=weights)

def train():
    """Run one optimisation pass over ``train_loader`` with the global model.

    Returns:
        (mean of the per-batch losses, fraction of correctly classified graphs)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        batch_loss = loss_func(logits, batch.y)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level loss and accuracy
        loss_sum += batch_loss.item()
        n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        n_seen += batch.y.size(0)

    return loss_sum / len(train_loader), n_correct / n_seen

def validate():
    """Return the accuracy of the global model on ``val_loader``."""
    model.eval()
    n_correct, n_seen = 0, 0

    # No gradients are needed for evaluation
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
            n_seen += batch.y.size(0)

    return n_correct / n_seen



# Fixed-length training run; metrics are logged on every 10th epoch only
epochs = 512
for epoch in range(epochs):
    train_loss, train_acc = train()
    val_acc = validate()
    if (epoch + 1) % 10:
        continue
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f} | Validation Accuracy: {val_acc:.4f}")


Epoch 10/512 | Train Loss: 0.6957 | Train Accuracy: 0.4043 | Validation Accuracy: 0.2821
Epoch 20/512 | Train Loss: 0.6617 | Train Accuracy: 0.4988 | Validation Accuracy: 0.4872
Epoch 30/512 | Train Loss: 0.6815 | Train Accuracy: 0.5140 | Validation Accuracy: 0.6603
Epoch 40/512 | Train Loss: 0.6198 | Train Accuracy: 0.6125 | Validation Accuracy: 0.5962
Epoch 50/512 | Train Loss: 0.6036 | Train Accuracy: 0.7574 | Validation Accuracy: 0.7564
Epoch 60/512 | Train Loss: 0.5994 | Train Accuracy: 0.7718 | Validation Accuracy: 0.7821
Epoch 70/512 | Train Loss: 0.5810 | Train Accuracy: 0.7966 | Validation Accuracy: 0.7949
Epoch 80/512 | Train Loss: 0.5760 | Train Accuracy: 0.7902 | Validation Accuracy: 0.8013
Epoch 90/512 | Train Loss: 0.5747 | Train Accuracy: 0.7878 | Validation Accuracy: 0.7885
Epoch 100/512 | Train Loss: 0.5687 | Train Accuracy: 0.8038 | Validation Accuracy: 0.8077
Epoch 110/512 | Train Loss: 0.5784 | Train Accuracy: 0.7878 | Validation Accuracy: 0.7885
Epoch 120/512 | Tra

In [42]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score

def test_metrics():
    """Evaluate the global model on ``test_loader``.

    Returns:
        (accuracy, recall, precision, f1) computed from the argmax predictions.
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y_pred += logits.argmax(dim=1).tolist()
            y_true += batch.y.tolist()

    return (
        accuracy_score(y_true, y_pred),
        recall_score(y_true, y_pred),
        precision_score(y_true, y_pred),
        f1_score(y_true, y_pred),
    )

# At the end of training or wherever you want to print the metrics:
acc, recall, precision, f1 = test_metrics()
print(f"Test Accuracy: {acc:.4f} | Recall: {recall:.4f} | Precision: {precision:.4f} | F1 Score: {f1:.4f}")


Test Accuracy: 0.7580 | Recall: 0.4545 | Precision: 0.5882 | F1 Score: 0.5128
