In [1]:
import pandas as pd
import networkx as nx
import logging
#import matplotlib.pyplot as plt
from Bio.PDB import PDBParser
from Bio.PDB.SASA import ShrakeRupley
import numpy as np
import csv
import os

import torch
from torch_geometric.data import Data, DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

In [2]:
# Define the GNN model
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        # Define GCN layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Fully connected layer for classification
        self.fc = torch.nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, edge_index, batch):
        # Apply GCN layers with ReLU activation
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Global pooling for graph-level embedding
        x = global_mean_pool(x, batch)  # Pool over nodes for each graph
        # Output layer
        x = self.fc(x)
        return x

# Function to train the model
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Function to test the model
def test(model, loader):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

In [5]:
from torch_geometric.data import DataLoader

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def main():    
    logger.info("Start defining hyperparameters")
    # Define hyperparameters
    input_dim = 10  # Replace with the actual feature dimension of nodes
    hidden_dim = 64
    output_dim = 7  # Number of classes in your multi-class classification
    epochs = 50
    batch_size = 1
    learning_rate = 0.01
    
    logger.info("Start Initializing the model")
    # Initialize the model, optimizer, and loss function
    model = GNN(input_dim, hidden_dim, output_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    
    logger.info("Start loading into DataLoader")
    # Load the data_list from the saved file
    loaded_data_list = torch.load('data_list.pt')
    loader = DataLoader(loaded_data_list, batch_size=batch_size, shuffle=True)

    # Example: Iterating through DataLoader and printing batch info
    for batch in loader:
        print("Batch node features shape:", batch.x.shape)
        print("Batch edge indices shape:", batch.edge_index.shape)
        print("Batch labels:", batch.y)
        print("------------")
        
    logger.info("Start training loop")
    # Training loop
    for epoch in range(epochs):
        loss = train(model, loader, optimizer, criterion)
        accuracy = test(model, loader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}')

    print("Training complete.")

In [6]:
if __name__ == "__main__":
    main()

2025-01-21 14:41:51,893 - INFO - Start defining hyperparameters
2025-01-21 14:41:51,893 - INFO - Start Initializing the model
2025-01-21 14:41:51,897 - INFO - Start loading into DataLoader
2025-01-21 14:41:51,905 - INFO - Start training loop


Batch node features shape: torch.Size([53, 15])
Batch edge indices shape: torch.Size([2, 53])
Batch labels: tensor([2])
------------
Batch node features shape: torch.Size([44, 15])
Batch edge indices shape: torch.Size([2, 42])
Batch labels: tensor([0])
------------
Batch node features shape: torch.Size([112, 15])
Batch edge indices shape: torch.Size([2, 119])
Batch labels: tensor([1])
------------
Batch node features shape: torch.Size([146, 15])
Batch edge indices shape: torch.Size([2, 153])
Batch labels: tensor([1])
------------


RuntimeError: index 197 is out of bounds for dimension 0 with size 112