<a href="https://colab.research.google.com/github/meklithab/graph-neural-network-task/blob/main/gnn_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install torch torchvision torchaudio
!pip install torch-geometric scikit-learn




In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, SAGEConv, GATv2Conv
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


# Load the Planetoid "Cora" dataset, using /tmp/Cora as its root directory.
dataset = Planetoid(root="/tmp/Cora", name="Cora")
# Take the first (index 0) entry of the dataset. The cells below read
# `x`, `edge_index`, `y` and the train/val/test masks from this object.
data = dataset[0]





In [None]:
# --- Models ---
class MLP(torch.nn.Module):
    """Feature-only multilayer perceptron baseline (no edge information used).

    Args:
        hidden: Width of every hidden ``Linear`` layer.
        layers: Total number of ``Linear`` layers. Values below 2 still
            produce two layers (input->hidden, hidden->output), because the
            first and last layers are always created.
        dropout: Dropout probability applied after each hidden activation.
        in_channels: Input feature size. Defaults to ``dataset.num_features``
            from the notebook-level ``dataset`` when omitted.
        out_channels: Number of output classes. Defaults to
            ``dataset.num_classes`` when omitted.
    """

    def __init__(self, hidden=64, layers=2, dropout=0.5, *,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Only touch the global ``dataset`` when a size was not supplied, so
        # the model can be built without the notebook's dataset in scope.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.layers = torch.nn.ModuleList()
        self.layers.append(torch.nn.Linear(in_channels, hidden))
        for _ in range(layers - 2):
            self.layers.append(torch.nn.Linear(hidden, hidden))
        self.layers.append(torch.nn.Linear(hidden, out_channels))
        self.dropout = dropout

    def forward(self, x):
        """Return per-row log-probabilities over the output classes.

        Args:
            x: Feature tensor whose last dimension equals the input size.

        Returns:
            ``log_softmax`` (over dim 1) of the final layer's output.
        """
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
            # Dropout is active only in training mode (``self.training``).
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(self.layers[-1](x), dim=1)

class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers producing per-node class log-probabilities.

    Args:
        hidden: Output width of every non-final ``GCNConv`` layer.
        layers: Total number of ``GCNConv`` layers. Values below 2 still
            produce two layers, because the first and last layers are always
            created.
        dropout: Dropout probability applied after each hidden activation.
        in_channels: Input feature size. Defaults to ``dataset.num_features``
            from the notebook-level ``dataset`` when omitted.
        out_channels: Number of output classes. Defaults to
            ``dataset.num_classes`` when omitted.
    """

    def __init__(self, hidden=64, layers=2, dropout=0.5, *,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Only touch the global ``dataset`` when a size was not supplied.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden))
        for _ in range(layers - 2):
            self.convs.append(GCNConv(hidden, hidden))
        self.convs.append(GCNConv(hidden, out_channels))
        self.dropout = dropout

    def forward(self, data):
        """Run all layers on ``data.x`` / ``data.edge_index``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            ``log_softmax`` (over dim 1) of the final layer's output.
        """
        x, edge_index = data.x, data.edge_index
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
            # Dropout is active only in training mode (``self.training``).
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(self.convs[-1](x, edge_index), dim=1)

class GraphSAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers producing per-node class log-probabilities.

    Args:
        hidden: Output width of every non-final ``SAGEConv`` layer.
        layers: Total number of ``SAGEConv`` layers. Values below 2 still
            produce two layers, because the first and last layers are always
            created.
        dropout: Dropout probability applied after each hidden activation.
        in_channels: Input feature size. Defaults to ``dataset.num_features``
            from the notebook-level ``dataset`` when omitted.
        out_channels: Number of output classes. Defaults to
            ``dataset.num_classes`` when omitted.
    """

    def __init__(self, hidden=64, layers=2, dropout=0.5, *,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Only touch the global ``dataset`` when a size was not supplied.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden))
        for _ in range(layers - 2):
            self.convs.append(SAGEConv(hidden, hidden))
        self.convs.append(SAGEConv(hidden, out_channels))
        self.dropout = dropout

    def forward(self, data):
        """Run all layers on ``data.x`` / ``data.edge_index``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            ``log_softmax`` (over dim 1) of the final layer's output.
        """
        x, edge_index = data.x, data.edge_index
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
            # Dropout is active only in training mode (``self.training``).
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(self.convs[-1](x, edge_index), dim=1)

class GATv2(torch.nn.Module):
    """Stack of ``GATv2Conv`` layers producing per-node class log-probabilities.

    Args:
        hidden: Per-head output width of every non-final layer.
        layers: Total number of ``GATv2Conv`` layers. Values below 2 still
            produce two layers, because the first and last layers are always
            created.
        heads: Number of attention heads in every non-final layer. The next
            layer is built with ``hidden * heads`` input channels.
        dropout: Dropout probability applied after each hidden activation;
            it is also forwarded to each ``GATv2Conv``'s ``dropout`` argument.
        in_channels: Input feature size. Defaults to ``dataset.num_features``
            from the notebook-level ``dataset`` when omitted.
        out_channels: Number of output classes. Defaults to
            ``dataset.num_classes`` when omitted.
    """

    def __init__(self, hidden=64, layers=2, heads=4, dropout=0.5, *,
                 in_channels=None, out_channels=None):
        super().__init__()
        # Only touch the global ``dataset`` when a size was not supplied.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.convs = torch.nn.ModuleList()
        self.convs.append(
            GATv2Conv(in_channels, hidden, heads=heads, dropout=dropout)
        )
        for _ in range(layers - 2):
            self.convs.append(
                GATv2Conv(hidden * heads, hidden, heads=heads, dropout=dropout)
            )
        # Final layer: a single head with concat disabled, so the output has
        # exactly ``out_channels`` columns.
        self.convs.append(
            GATv2Conv(hidden * heads, out_channels, heads=1, concat=False,
                      dropout=dropout)
        )
        self.dropout = dropout

    def forward(self, data):
        """Run all layers on ``data.x`` / ``data.edge_index``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            ``log_softmax`` (over dim 1) of the final layer's output.
        """
        x, edge_index = data.x, data.edge_index
        for conv in self.convs[:-1]:
            x = F.elu(conv(x, edge_index))
            # Dropout is active only in training mode (``self.training``).
            x = F.dropout(x, p=self.dropout, training=self.training)
        return F.log_softmax(self.convs[-1](x, edge_index), dim=1)


In [None]:
from sklearn.metrics import accuracy_score, f1_score, log_loss

def train_eval(model, data, lr=0.01, weight_decay=5e-4, epochs=100):
    """Train ``model`` full-batch and restore its best-validation weights.

    Each epoch performs one Adam step on the NLL loss over
    ``data.train_mask`` and then measures accuracy on ``data.val_mask``.
    The weights from the first epoch reaching the highest validation
    accuracy are loaded back into ``model`` before returning.

    Args:
        model: One of the notebook's models. ``MLP`` instances are called
            with ``data.x``; all other models are called with ``data``.
        data: Object exposing ``x``, ``y``, ``train_mask`` and ``val_mask``
            (plus whatever the model's ``forward`` reads).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        epochs: Number of training epochs. With ``epochs <= 0`` the model is
            left unchanged.

    Returns:
        None. ``model`` is modified in place.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    is_mlp = isinstance(model, MLP)

    # Start below any reachable accuracy so the first epoch is always
    # snapshotted, even when validation accuracy is exactly 0.
    best_val_acc, best_state = float("-inf"), None
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data.x) if is_mlp else model(data)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        # No gradients are needed for validation.
        with torch.no_grad():
            out = model(data.x) if is_mlp else model(data)
        pred = out.argmax(dim=1)
        val_acc = (pred[data.val_mask] == data.y[data.val_mask]).sum().item() / data.val_mask.sum().item()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live tensors, which the
            # optimizer keeps updating in place; clone them so the snapshot
            # really holds this epoch's weights.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    # best_state is None only when no epoch ran.
    if best_state is not None:
        model.load_state_dict(best_state)

# --- Evaluation metrics ---
def evaluate_metrics(model, data):
    """Compute test-split accuracy, macro-F1 and log loss for ``model``.

    Args:
        model: One of the notebook's models. ``MLP`` instances are called
            with ``data.x``; all other models are called with ``data``. The
            model is expected to return log-probabilities (it is
            exponentiated to obtain probabilities).
        data: Object exposing ``x``, ``y`` and ``test_mask`` (plus whatever
            the model's ``forward`` reads).

    Returns:
        Tuple ``(accuracy, macro_f1, log_loss)`` over nodes selected by
        ``data.test_mask``.
    """
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(data) if not isinstance(model, MLP) else model(data.x)

    test_mask = data.test_mask.cpu().numpy()
    y_true = data.y.cpu().numpy()[test_mask]
    pred = out.argmax(dim=1).cpu().numpy()[test_mask]

    acc = accuracy_score(y_true, pred)
    f1 = f1_score(y_true, pred, average="macro")

    # Convert log-softmax outputs back to probabilities
    probs = torch.exp(out).cpu().numpy()[test_mask]
    # Column i of ``probs`` is class i (targets are used as column indices by
    # nll_loss during training). Pass the full label set explicitly so the
    # columns stay aligned even if some class is absent from the test split.
    ll = log_loss(y_true, probs, labels=list(range(probs.shape[1])))

    return acc, f1, ll

# --- Minimal hyperparameter experiments ---
# Small grid: every architecture is trained once per (hidden width, depth).
hidden_dims = [64, 128]
layers = [2, 3]

# Flattened grid, ordered hidden width -> depth -> architecture.
configs = [
    (hd, lay, model_class)
    for hd in hidden_dims
    for lay in layers
    for model_class in [MLP, GCN, GraphSAGE, GATv2]
]

results = []
for hd, lay, model_class in configs:
    model = model_class(hidden=hd, layers=lay, dropout=0.5)
    train_eval(model, data, lr=0.01, weight_decay=5e-4, epochs=100)
    acc, f1, ll = evaluate_metrics(model, data)
    row = {
        "Model": model_class.__name__,
        "Hidden": hd,
        "Layers": lay,
        "Accuracy": acc,
        "F1": f1,
        "LogLoss": ll
    }
    results.append(row)

# --- Print results table ---
# Left-aligned fixed-width columns; metric columns are shown to 3 decimals.
header_fmt = "{:<10} {:<6} {:<6} {:<8} {:<8} {:<8}"
row_fmt = "{:<10} {:<6} {:<6} {:<8.3f} {:<8.3f} {:<8.3f}"
columns = ("Model", "Hidden", "Layers", "Accuracy", "F1", "LogLoss")

print(header_fmt.format("Model","Hidden","Layers","Acc","F1","LogLoss"))
for r in results:
    print(row_fmt.format(*[r[key] for key in columns]))



Model      Hidden Layers Acc      F1       LogLoss 
MLP        64     2      0.578    0.564    1.323   
GCN        64     2      0.813    0.804    0.601   
GraphSAGE  64     2      0.799    0.790    0.643   
GATv2      64     2      0.786    0.775    0.804   
MLP        64     3      0.555    0.546    1.900   
GCN        64     3      0.784    0.779    0.876   
GraphSAGE  64     3      0.793    0.787    1.019   
GATv2      64     3      0.802    0.792    1.169   
MLP        128    2      0.568    0.558    1.351   
GCN        128    2      0.806    0.801    0.615   
GraphSAGE  128    2      0.804    0.798    0.635   
GATv2      128    2      0.769    0.771    0.960   
MLP        128    3      0.558    0.550    1.850   
GCN        128    3      0.800    0.795    0.819   
GraphSAGE  128    3      0.784    0.774    1.070   
GATv2      128    3      0.749    0.753    1.834   
