In [None]:
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv, BatchNorm, global_add_pool
from tqdm import tqdm

# Location of the pickled ZINC dataset. Set the ZINC_DATASET_PATH environment
# variable to point elsewhere instead of editing a machine-specific path.
DATASET_PATH = os.environ.get(
    'ZINC_DATASET_PATH', '/Users/nasibhuseynzade/Downloads/zinc_dataset.pkl'
)

# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load dataset files that come from a trusted source.
with open(DATASET_PATH, 'rb') as f:
    dataset = pickle.load(f)

# Later cells index dataset[0] to infer the node-feature width.
assert len(dataset) > 0, f"dataset loaded from {DATASET_PATH} is empty"

In [None]:
def _prepare_batch(batch, device):
    """Move a PyG batch to `device` and cast features/targets to float32 for MSE regression."""
    batch = batch.to(device)
    batch.x = batch.x.float()
    # One regression target per graph, shaped (num_graphs, 1) to match model output.
    batch.y = batch.y.float().view(-1, 1)
    return batch


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-graph MSE."""
    model.train()
    total_loss = 0.0
    num_graphs = 0
    for batch in loader:
        batch = _prepare_batch(batch, device)
        optimizer.zero_grad()
        # The whole batch is passed; PyG tracks graph membership via batch.batch.
        out = model(batch)
        loss = F.mse_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        # mse_loss averages over the batch, so re-weight by graph count.
        total_loss += loss.item() * batch.num_graphs
        num_graphs += batch.num_graphs
    return total_loss / max(num_graphs, 1)


def _evaluate_r2(model, loader, device):
    """Return the R^2 score of `model` over every graph yielded by `loader`.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in loader:
            batch = _prepare_batch(batch, device)
            out = model(batch)
            y_true.append(batch.y.cpu().numpy())
            y_pred.append(out.cpu().numpy())
    if not y_true:
        raise ValueError("test set is empty; cannot compute R2")
    return r2_score(np.concatenate(y_true), np.concatenate(y_pred))


def train_test_model(model, dataset, num_epochs=4, batch_size=32, learning_rate=0.0005,
                     test_size=0.2, random_state=42, device=None):
    """Train `model` on a split of `dataset` and track test-set R^2 after each epoch.

    Args:
        model: A torch module whose forward takes a PyG batch and returns
            predictions shaped like ``batch.y.view(-1, 1)``. It is moved to
            `device` in place.
        dataset: A sequence of PyG ``Data`` graphs with ``x`` and ``y``.
        num_epochs: Number of full passes over the training split.
        batch_size: Graphs per mini-batch for both loaders.
        learning_rate: Adam learning rate.
        test_size: Fraction of `dataset` held out for evaluation.
        random_state: Seed for the train/test split.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        A list with one test-set R^2 value per epoch.
    """
    train_dataset, test_dataset = train_test_split(
        dataset, test_size=test_size, random_state=random_state
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Batches are sent to `device`, so the parameters must live there too;
    # otherwise the forward pass fails with a device mismatch on GPU hosts.
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    r2_values = []

    for epoch in tqdm(range(num_epochs)):
        train_loss = _train_one_epoch(model, train_loader, optimizer, device)
        r2_value = _evaluate_r2(model, test_loader, device)
        r2_values.append(r2_value)
        print(f'Epoch {epoch+1}/{num_epochs}, Train MSE: {train_loss:.4f}, R2 Value: {r2_value:.4f}')

    return r2_values



In [None]:
class GINModel(torch.nn.Module):
    """Graph Isomorphism Network for graph-level regression.

    Stacks ``depth`` GINConv layers, each followed by BatchNorm and ReLU.
    It then sum-pools the node embeddings of each graph and maps the pooled
    vector to ``num_classes`` outputs with one linear layer.

    Args:
        num_features: Width of the input node features ``data.x``.
        num_classes: Number of outputs per graph (1 for scalar regression).
        hidden_dim: Width of every hidden node embedding.
        depth: Number of GINConv layers; must be at least 1.

    Raises:
        ValueError: if ``depth < 1``.
    """

    def __init__(self, num_features, num_classes=1, hidden_dim=64, depth=3):
        super().__init__()
        if depth < 1:
            # With depth <= 0 the conv/batch-norm zip below would be empty, so raw
            # features would reach final_lin: a shape error or, if
            # num_features == hidden_dim, a silently message-passing-free model.
            raise ValueError(f"depth must be >= 1, got {depth}")

        # The first GIN layer maps num_features -> hidden_dim; the rest keep hidden_dim.
        in_dims = [num_features] + [hidden_dim] * (depth - 1)
        self.convs = torch.nn.ModuleList(
            GINConv(self._make_mlp(in_dim, hidden_dim)) for in_dim in in_dims
        )

        # One batch-norm per GIN layer, applied to that layer's output.
        self.batch_norms = torch.nn.ModuleList([BatchNorm(hidden_dim) for _ in range(depth)])

        # Final regression layer applied to the pooled graph embedding.
        self.final_lin = torch.nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _make_mlp(in_dim, hidden_dim):
        """Build the two-layer MLP that a GINConv applies after neighbourhood aggregation."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
        )

    def forward(self, data):
        """Return one row of ``num_classes`` predictions per graph in the PyG batch ``data``."""
        x, edge_index = data.x, data.edge_index
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = batch_norm(x)
            # NOTE(review): the MLP inside each conv already ends in ReLU, so this
            # is a second ReLU after BatchNorm; kept to preserve model behaviour.
            x = F.relu(x)
        # Sum node embeddings per graph (data.batch maps nodes to graphs).
        x = global_add_pool(x, data.batch)
        return self.final_lin(x)

In [None]:

# Seed the global torch RNG so weight initialisation and DataLoader shuffling
# are reproducible on Restart & Run All (the train/test split is seeded
# separately via random_state inside train_test_model).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Put the model on the same device the training loop sends each batch to.
model = GINModel(num_features=dataset[0].x.shape[1]).to(device)
r2_values = train_test_model(model, dataset, num_epochs=4)
r2_values