In [5]:
import torch
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import accuracy_score, mean_squared_error



In [12]:
import os

# Root of the project tree that contains `4_train_test_split/`.
# The default is the Google Drive mount used on Colab. When running from a
# GitHub checkout, set the TRPM8_PROJECT_ROOT environment variable (e.g. to
# "..") instead of editing the paths below.
PROJECT_ROOT = os.environ.get(
    "TRPM8_PROJECT_ROOT",
    "/content/drive/MyDrive/GNN_model_TRPM8_Drug_Potency_prediction",
)

# `task` is read by the later cells to pick the model head, loss and metric.
task = "classification"
split_dir = f"{PROJECT_ROOT}/4_train_test_split/random_split/{task}"
train_path = f"{split_dir}/{task}_train.pt"
val_path   = f"{split_dir}/{task}_val.pt"
test_path  = f"{split_dir}/{task}_test.pt"

# Informational only: a missing file still makes torch.load raise below.
print("Train file exists:", os.path.exists(train_path))
print("Val file exists:", os.path.exists(val_path))
print("Test file exists:", os.path.exists(test_path))

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only load split files
# that were produced by this project.
# NOTE(review): presumably required because the files hold pickled dataset
# objects rather than plain tensors - confirm against the split notebook.
train_data = torch.load(train_path, weights_only=False)
val_data   = torch.load(val_path, weights_only=False)
test_data  = torch.load(test_path, weights_only=False)

print("Datasets loaded successfully.")




Train file exists: True
Val file exists: True
Test file exists: True
Datasets loaded successfully.


In [14]:
from torch_geometric.loader import DataLoader

# Mini-batch loaders with 32 samples per batch. Only the training split is
# shuffled; validation and test are iterated in their stored order.
train_loader = DataLoader(train_data, shuffle=True, batch_size=32)
val_loader = DataLoader(val_data, shuffle=False, batch_size=32)
test_loader = DataLoader(test_data, shuffle=False, batch_size=32)


In [15]:
# GCN MODEL
class GCN(torch.nn.Module):
    """Graph-level predictor built from three GCNConv layers and a two-layer head.

    Pipeline: (GCNConv -> ReLU) x 3, global mean pooling over the nodes of each
    graph, then Linear -> ReLU -> Dropout -> Linear.

    Args:
        in_channels: Size of each input node-feature vector.
        hidden_channels: Width of all three convolutions and of the first
            linear layer.
        out_channels: Size of the output vector produced per graph.
        dropout: Dropout probability applied before the last linear layer.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # Three GCN layers. Attribute names and creation order are kept so
        # existing state dicts and seeded initialisations stay compatible.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Two linear layers forming the prediction head
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)
        self.dropout_p = dropout

    def forward(self, data):
        """Return one output row per graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor with ``out_channels`` columns; raw outputs, no softmax or
            other final activation is applied.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Message passing: every convolution is followed by a ReLU
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x, edge_index))
        # Aggregate node features into a graph-level embedding
        x = global_mean_pool(x, batch)
        # Prediction head; dropout is only active while self.training is True
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        return self.lin2(x)

# Loss and evaluation setup
# Seed the global torch RNG so weight initialisation (and any shuffling that
# draws from it) is repeatable on a fresh kernel. Some CUDA kernels can still
# be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Number of features per node, taken from the first training sample
in_channels = train_data[0].x.size(1)

if task == "classification":
    train_labels = [int(data.y.item()) for data in train_data]
    # CrossEntropyLoss needs one logit per label index, so the output width
    # must reach max(label) + 1. Counting distinct labels alone undercounts
    # when the training split does not contain exactly 0..K-1.
    num_classes = max(len(set(train_labels)), max(train_labels) + 1)
    model = GCN(in_channels=in_channels, hidden_channels=64, out_channels=num_classes)
    criterion = torch.nn.CrossEntropyLoss()
else:
    # Any other task value is treated as single-target regression
    model = GCN(in_channels=in_channels, hidden_channels=64, out_channels=1)
    criterion = torch.nn.MSELoss()

# Move the model first, then build the optimizer over its parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One training epoch
def train():
    """Run one optimisation pass over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``device`` and ``task``.

    Returns:
        float: Sample-weighted mean loss over the epoch, or ``nan`` when the
        loader yields no batches. The function used to return ``None``;
        callers that ignore the result are unaffected.
    """
    model.train()
    total_loss, num_seen = 0.0, 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        if task == "classification":
            # Keep logits as (batch, classes). A blanket squeeze() also drops
            # the batch axis when a batch holds a single graph, which makes
            # CrossEntropyLoss fail with a size mismatch.
            target = batch.y.reshape(-1).long()
            loss = criterion(out, target)
        else:
            # Flatten both sides so a (batch, 1) target cannot broadcast
            # against (batch,) predictions into a (batch, batch) loss.
            prediction = out.reshape(-1)
            target = batch.y.reshape(-1).to(prediction.dtype)
            loss = criterion(prediction, target)
        loss.backward()
        optimizer.step()
        batch_len = target.size(0)
        total_loss += loss.item() * batch_len
        num_seen += batch_len
    return total_loss / num_seen if num_seen else float("nan")

# Evaluation function
def evaluate(loader):
    """Score the module-level ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable of batches exposing ``to(device)`` and ``y``.

    Returns:
        Accuracy when ``task == "classification"``, otherwise the mean
        squared error, as computed by scikit-learn.
    """
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch)
            if task != "classification":
                # Regression: one value per graph
                out = out.reshape(-1)
            # Classification logits stay (batch, classes). A blanket squeeze()
            # would turn a single-graph batch into a 1-D tensor, and
            # torch.cat below would then fail on the mixed ranks.
            preds.append(out.cpu())
            labels.append(batch.y.reshape(-1).cpu())
    preds = torch.cat(preds)
    labels = torch.cat(labels)
    if task == "classification":
        # Predicted class = index of the largest logit
        pred_classes = preds.argmax(dim=1)
        return accuracy_score(labels.long().numpy(), pred_classes.numpy())
    return mean_squared_error(labels.numpy(), preds.numpy())

# Train for 100 epochs, reporting the validation metric after each one
metric_name = 'Accuracy' if task == 'classification' else 'MSE'
for epoch in range(1, 101):
    train()
    metric = evaluate(val_loader)
    print(f"Epoch {epoch:03d} - {metric_name}: {metric:.4f}")

# Score the held-out test split once, using the weights from the last epoch
test_metric = evaluate(test_loader)
print(f"\nTest {metric_name}: {test_metric:.4f}")


Epoch 001 - Accuracy: 0.5192
Epoch 002 - Accuracy: 0.5192
Epoch 003 - Accuracy: 0.5192
Epoch 004 - Accuracy: 0.5192
Epoch 005 - Accuracy: 0.5192
Epoch 006 - Accuracy: 0.5192
Epoch 007 - Accuracy: 0.5192
Epoch 008 - Accuracy: 0.5192
Epoch 009 - Accuracy: 0.5192
Epoch 010 - Accuracy: 0.5192
Epoch 011 - Accuracy: 0.5192
Epoch 012 - Accuracy: 0.5192
Epoch 013 - Accuracy: 0.5192
Epoch 014 - Accuracy: 0.5192
Epoch 015 - Accuracy: 0.5192
Epoch 016 - Accuracy: 0.5192
Epoch 017 - Accuracy: 0.5192
Epoch 018 - Accuracy: 0.5192
Epoch 019 - Accuracy: 0.5192
Epoch 020 - Accuracy: 0.5192
Epoch 021 - Accuracy: 0.5192
Epoch 022 - Accuracy: 0.5192
Epoch 023 - Accuracy: 0.5192
Epoch 024 - Accuracy: 0.5192
Epoch 025 - Accuracy: 0.5192
Epoch 026 - Accuracy: 0.5192
Epoch 027 - Accuracy: 0.5192
Epoch 028 - Accuracy: 0.5192
Epoch 029 - Accuracy: 0.5192
Epoch 030 - Accuracy: 0.5192
Epoch 031 - Accuracy: 0.5192
Epoch 032 - Accuracy: 0.5192
Epoch 033 - Accuracy: 0.5192
Epoch 034 - Accuracy: 0.5192
Epoch 035 - Ac