In [4]:
# -----------------------------
# 0. Dependências
# -----------------------------
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -----------------------------
# 1. Carregar Dataset
# -----------------------------
os.makedirs("../datasets", exist_ok=True)
dataset = Planetoid(root="../datasets/Cora", name="Cora")
data = dataset[0]

# Example task: assign a random class label to every edge.
# The labels are synthetic, so fix the global seed. Otherwise every run trains
# on different labels and the saved checkpoint/results cannot be reproduced.
# torch.manual_seed also seeds the weight initialization and dropout that
# follow (bit-exact GPU runs may additionally need deterministic kernels).
# NOTE(review): Planetoid graphs typically store each undirected edge in both
# directions; if so, (u, v) and (v, u) receive independent random labels
# here. Confirm this is intended.
SEED = 42
torch.manual_seed(SEED)
num_edges = data.edge_index.size(1)
num_classes = 2
edge_labels = torch.randint(0, num_classes, (num_edges,))

# -----------------------------
# 2. Definir Modelo
# -----------------------------
class GCN(nn.Module):
    """Two-layer graph convolutional encoder that produces node embeddings.

    Layout: GCNConv -> ReLU -> dropout (``F.dropout`` with its default rate,
    active only in training mode) -> GCNConv. Both convolutions output
    ``hidden_channels`` features, so the returned embeddings have that width.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Output width of both graph convolutions.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Encode nodes with two rounds of graph convolution over ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, training=self.training)
        return self.conv2(hidden, edge_index)

class EdgeClassifier(nn.Module):
    """Classify edges from the GCN embeddings of their two endpoints.

    Node embeddings come from a two-layer ``GCN``. Each edge is represented by
    the concatenation ``[emb(src), emb(dst)]`` and scored by a two-layer MLP.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Width of the node embeddings and of the MLP hidden layer.
        out_channels: Number of edge classes (logits produced per edge).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.gcn = GCN(in_channels, hidden_channels)
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x, edge_index, edge_label_index=None):
        """Return one row of class logits per classified edge.

        Args:
            x: Node features, passed unchanged to the GCN encoder.
            edge_index: ``(2, E)`` edges used for message passing (row 0 =
                source nodes, row 1 = target nodes).
            edge_label_index: Optional ``(2, E')`` edges to classify. Defaults
                to ``edge_index``, which classifies every message-passing edge
                as before. Pass a subset to score held-out edges while still
                propagating over the full graph.

        Returns:
            Logits of shape ``(E', out_channels)`` (``E' == E`` by default).
        """
        # Node embeddings are always computed over the message-passing graph.
        node_emb = self.gcn(x, edge_index)
        if edge_label_index is None:
            edge_label_index = edge_index
        src, dst = edge_label_index
        # Edge representation: concatenation of its endpoint embeddings.
        edge_emb = torch.cat([node_emb[src], node_emb[dst]], dim=1)
        return self.mlp(edge_emb)

# -----------------------------
# 3. Funções de treino e teste
# -----------------------------
def train(model, optimizer, data, edge_labels):
    """Run one full-batch optimization step for edge classification.

    Args:
        model: Module called as ``model(data.x, data.edge_index)`` that returns
            one row of class logits per edge.
        optimizer: Optimizer over ``model``'s parameters.
        data: Graph object exposing ``x`` and ``edge_index``.
        edge_labels: Integer class index per edge, aligned with the rows of
            the model output.

    Returns:
        The cross-entropy loss of this step (computed before the parameter
        update) as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out, edge_labels)
    loss.backward()
    optimizer.step()
    # Return a plain float. The tensor still references the autograd graph
    # (grad_fn), and callers only need the value for logging.
    return loss.item()

@torch.no_grad()
def test(model, data, edge_labels):
    """Score ``model``'s edge predictions against ``edge_labels``.

    Puts the model in eval mode and runs one forward pass without gradient
    tracking. The arg-max class of each output row is compared with the
    corresponding label.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``. Precision, recall and F1
        are macro-averaged; undefined ratios (zero denominators) count as 0.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    y_pred = logits.argmax(dim=1).cpu()
    y_true = edge_labels.cpu()
    macro = dict(average="macro", zero_division=0)
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, **macro),
        recall_score(y_true, y_pred, **macro),
        f1_score(y_true, y_pred, **macro),
    )

# -----------------------------
# 4. Treinar modelo
# -----------------------------
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
gcn = EdgeClassifier(dataset.num_features, 32, num_classes).to(device)
data, edge_labels = data.to(device), edge_labels.to(device)

optimizer = torch.optim.Adam(params=gcn.parameters(), lr=0.01, weight_decay=5e-4)

# 500 full-batch epochs, logging metrics every 20 epochs.
# NOTE(review): test() is evaluated on the very same edges and labels used for
# training, so these numbers measure fit to the training labels, not
# generalization.
for epoch in range(1, 501):
    loss = train(gcn, optimizer, data, edge_labels)
    if epoch % 20:
        continue
    acc, prec, rec, f1 = test(gcn, data, edge_labels)
    print(
        f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
        f"Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | F1: {f1:.4f}"
    )

# -----------------------------
# 5. Salvar Modelo 
# -----------------------------
os.makedirs("../models", exist_ok=True)
ckpt_gcn = "../models/gcn_state.pt"
ckpt_meta = "../models/meta.pt"
ckpt_graph = "../models/graph_artifacts.pt"

# Model weights only (state_dict); rebuild EdgeClassifier from the metadata below.
torch.save(gcn.state_dict(), ckpt_gcn)

# Constructor arguments for EdgeClassifier (edge classification).
# hidden_channels is read from the trained model instead of repeating the
# literal used at construction, so the two cannot drift apart. The first MLP
# layer maps 2 * hidden -> hidden, so its out_features equals hidden.
torch.save({
    "num_features": dataset.num_features,
    "hidden_channels": gcn.mlp[0].out_features,
    "num_classes": num_classes,
}, ckpt_meta)

# Graph inputs plus the edge labels the model was trained on. The labels are
# randomly generated, so they must be saved for the checkpoint to be
# re-evaluated later.
torch.save({
    "x": data.x.cpu(),
    "edge_index": data.edge_index.cpu(),
    "edge_labels": edge_labels.cpu(),
}, ckpt_graph)

print("\n")

# -----------------------------
# 6. Métricas adicionais  
# -----------------------------
# Final evaluation of the trained model over every edge.
# NOTE(review): this repeats the computation in test() and scores the same
# edges/labels used for training, so it measures fit, not generalization.
gcn.eval()

with torch.no_grad():
    logits = gcn(data.x, data.edge_index)
    labels = edge_labels.cpu()
    pred = logits.argmax(dim=1).cpu()

    acc = accuracy_score(labels, pred)
    prec, rec, f1 = (
        score(labels, pred, average="macro", zero_division=0)
        for score in (precision_score, recall_score, f1_score)
    )

# Collect the metrics as plain Python floats and write them to JSON.
results = {
    "all": dict(
        accuracy=float(acc),
        precision=float(prec),
        recall=float(rec),
        f1=float(f1),
    )
}

os.makedirs("../results", exist_ok=True)
with open("../results/result.json", "w") as f:
    json.dump(results, f, indent=2)

print("\n")

Epoch 020 | Loss: 0.6928 | Acc: 0.5369 | Prec: 0.5912 | Rec: 0.5349 | F1: 0.4511
Epoch 040 | Loss: 0.6734 | Acc: 0.5871 | Prec: 0.5872 | Rec: 0.5871 | F1: 0.5870
Epoch 060 | Loss: 0.6493 | Acc: 0.6292 | Prec: 0.6338 | Rec: 0.6297 | F1: 0.6265
Epoch 080 | Loss: 0.6324 | Acc: 0.6654 | Prec: 0.6685 | Rec: 0.6651 | F1: 0.6635
Epoch 100 | Loss: 0.6233 | Acc: 0.6730 | Prec: 0.6789 | Rec: 0.6725 | F1: 0.6699
Epoch 120 | Loss: 0.6269 | Acc: 0.6815 | Prec: 0.6820 | Rec: 0.6816 | F1: 0.6814
Epoch 140 | Loss: 0.6179 | Acc: 0.6800 | Prec: 0.6802 | Rec: 0.6801 | F1: 0.6800
Epoch 160 | Loss: 0.6121 | Acc: 0.6782 | Prec: 0.6845 | Rec: 0.6786 | F1: 0.6757
Epoch 180 | Loss: 0.6165 | Acc: 0.6802 | Prec: 0.6830 | Rec: 0.6805 | F1: 0.6792
Epoch 200 | Loss: 0.6101 | Acc: 0.6884 | Prec: 0.6884 | Rec: 0.6884 | F1: 0.6884
Epoch 220 | Loss: 0.6102 | Acc: 0.6991 | Prec: 0.6992 | Rec: 0.6992 | F1: 0.6991
Epoch 240 | Loss: 0.6171 | Acc: 0.6883 | Prec: 0.6898 | Rec: 0.6881 | F1: 0.6876
Epoch 260 | Loss: 0.6062 | A