In [None]:
# -----------------------------
# 0. Carregar Dependências
# -----------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# -----------------------------
# 1. Definir Modelo
# -----------------------------
class GCN(nn.Module):
    """Two-layer graph convolutional encoder that produces node embeddings.

    Layer 1 maps ``in_channels`` -> ``hidden_channels`` and is followed by
    ReLU and dropout (``F.dropout`` default p=0.5, active only while the
    module is in training mode). Layer 2 maps ``hidden_channels`` ->
    ``hidden_channels`` with no activation, so the output is raw embeddings.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Registration order (conv1, conv2) determines the state_dict keys.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Return one ``hidden_channels``-wide embedding per node."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, training=self.training)
        return self.conv2(hidden, edge_index)

class EdgeClassifier(nn.Module):
    """Edge classifier: a GCN node encoder followed by an MLP over endpoint pairs.

    For every column ``(src, dst)`` of ``edge_index``, the embeddings of
    ``src`` and ``dst`` are concatenated in that order, so the result depends
    on edge direction. The pair is then mapped to ``out_channels`` logits.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.gcn = GCN(in_channels, hidden_channels)
        pair_width = hidden_channels * 2  # [emb(src), emb(dst)]
        self.mlp = nn.Sequential(
            nn.Linear(pair_width, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, x, edge_index):
        """Return one row of unnormalised class logits per column of ``edge_index``."""
        node_emb = self.gcn(x, edge_index)
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]
        pair_emb = torch.cat((node_emb[src_nodes], node_emb[dst_nodes]), dim=1)
        return self.mlp(pair_emb)

# -----------------------------
# 2. Load weights + metadata
# -----------------------------
# Use the GPU when one is available; all loaded tensors and the model go to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Relative paths: resolved against the current working directory, not this file.
ckpt_model = "../models/gcn_state.pt"     
ckpt_meta  = "../models/meta.pt"
ckpt_graph = "../models/graph_artifacts.pt"

# NOTE(review): torch.load unpickles its input (and on PyTorch < 2.6 does so
# with weights_only=False by default) -- only load checkpoints from a trusted source.

# Metadata: a mapping read below for "num_features", "hidden_channels" and
# "num_classes" (presumably saved at training time -- confirm).
meta = torch.load(ckpt_meta, map_location=device)

# Rebuild the model with the stored dimensions and load its trained weights.
# eval() switches off the dropout used inside GCN.forward.
edge_model = EdgeClassifier(meta["num_features"], meta["hidden_channels"], meta["num_classes"]).to(device)
edge_model.load_state_dict(torch.load(ckpt_model, map_location=device))
edge_model.eval()

# Load the graph artifacts: node features `x` and the `edge_index` used for
# message passing and for looking up edges in predict_edges.
# NOTE(review): map_location should already place these tensors on `device`,
# so the .to(device) calls look redundant (harmless).
graph_arts = torch.load(ckpt_graph, map_location=device)
x = graph_arts["x"].to(device)
edge_index = graph_arts["edge_index"].to(device)

# -----------------------------
# 3. Inference function
# -----------------------------
def _first_edge_position(src, dst, u, v):
    """Return the first column index ``i`` with ``src[i] == u`` and ``dst[i] == v``, or ``None``."""
    hits = ((src == u) & (dst == v)).nonzero(as_tuple=True)[0]
    return hits[0].item() if hits.numel() > 0 else None


@torch.no_grad()
def predict_edges(edge_list, *, model=None, node_features=None, graph_edge_index=None):
    """Classify the requested edges with the edge model.

    One forward pass is run over the whole graph, because the GCN needs every
    edge for message passing. Each requested edge is then looked up among
    the columns of the graph's edge index.

    Args:
        edge_list: iterable of ``(u, v)`` node-index pairs, e.g. ``[(0, 1), (2, 5)]``.
        model: optional override for the module-level ``edge_model``; called as
            ``model(node_features, graph_edge_index)``.
        node_features: optional override for the module-level ``x``.
        graph_edge_index: optional override for the module-level ``edge_index``.

    Returns:
        A list with one dict per requested edge, in input order.
        - An edge found in the graph maps to ``{"edge", "pred_class", "probs"}``.
          ``probs`` is the softmax over classes as a list of Python floats.
        - An edge absent in both directions maps to ``{"edge", "error"}``.
    """
    model = edge_model if model is None else model
    feats = x if node_features is None else node_features
    graph_ei = edge_index if graph_edge_index is None else graph_edge_index

    logits_all = model(feats, graph_ei)             # one row of logits per column of graph_ei
    probs_all = F.softmax(logits_all, dim=1).cpu()  # move to CPU once, not once per edge
    preds_all = probs_all.argmax(dim=1)

    src, dst = graph_ei[0], graph_ei[1]
    results = []
    for u, v in edge_list:
        # The classifier concatenates [emb(src), emb(dst)], so the two
        # directions of an edge generally get different predictions.
        # Prefer the exact (u, v) column; fall back to (v, u) only if it is absent.
        idx = _first_edge_position(src, dst, u, v)
        if idx is None:
            idx = _first_edge_position(src, dst, v, u)
        if idx is None:
            results.append({
                "edge": (u, v),
                "error": "Aresta não encontrada no grafo"
            })
            continue

        results.append({
            "edge": (u, v),
            "pred_class": int(preds_all[idx]),
            "probs": probs_all[idx].tolist(),
        })
    return results

# -----------------------------
# 4. Usage example
# -----------------------------
arestas = [(2582, 0), (2, 1), (652, 1)]
preds = predict_edges(arestas)

for r in preds:
    # Edges missing from the graph come back with an "error" key and no
    # prediction fields; print the message instead of raising KeyError.
    if "error" in r:
        print(f"Aresta {r['edge']} -> {r['error']}")
        continue
    print(f"Aresta {r['edge']} -> Classe prevista: {r['pred_class']} | Probs: {r['probs']}")

Aresta (2582, 0) -> Classe prevista: 1 | Probs: [0.24018479883670807, 0.7598152160644531]
Aresta (2, 1) -> Classe prevista: 0 | Probs: [0.9403064250946045, 0.05969354882836342]
Aresta (652, 1) -> Classe prevista: 0 | Probs: [0.8687646985054016, 0.13123531639575958]
