# In [1]:
# -----------------------------
# 0. Carregar Dependências
# -----------------------------
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# 1. Definir modelos
# -----------------------------
class SimpleGCN(nn.Module):
    """Two-layer graph convolution network.

    Each layer first left-multiplies its input by the adjacency ``adj``
    (neighbour aggregation), then applies a linear projection. A ReLU sits
    between the two layers; the output layer has no activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # The attribute names fc1/fc2 are the state_dict keys used when the
        # saved checkpoint is loaded, so they must not be renamed.
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, adj):
        """Return embeddings whose last dimension is ``out_channels``.

        ``x`` has ``in_channels`` in its last dimension. ``adj`` is
        matrix-multiplied onto it from the left at both layers, so it is
        presumably an (N, N) adjacency for an (N, in_channels) ``x`` —
        confirm against the caller.
        """
        # Layer 1: aggregate, project, non-linearity.
        hidden = self.fc1(adj @ x).relu()
        # Layer 2: aggregate again and project to the output width.
        return self.fc2(adj @ hidden)

class LinkPredictor(nn.Module):
    """Two-layer MLP that scores a node pair into ``num_classes`` logits.

    The two endpoint embeddings are concatenated along the last dimension,
    passed through a ReLU hidden layer, then projected to raw logits
    (no softmax is applied here).
    """

    def __init__(self, in_channels, hidden_channels, num_classes=3):
        super().__init__()
        # The input is both endpoint embeddings side by side, hence 2x width.
        # lin1/lin2 are the state_dict keys the checkpoint is loaded into.
        self.lin1 = nn.Linear(2 * in_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x_i, x_j):
        """Return logits for the pairs formed row-wise from ``x_i`` and ``x_j``."""
        pair = torch.cat((x_i, x_j), dim=-1)
        return self.lin2(self.lin1(pair).relu())

# -----------------------------
# 2. Carregar Pesos + Artefatos
# -----------------------------
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All artifacts live in a single directory. The path is relative to the current
# working directory (presumably the notebook's own folder — confirm).
MODELS_DIR = os.path.join("..", "models")
ckpt_gcn = os.path.join(MODELS_DIR, "gcn_state.pt")
ckpt_pred = os.path.join(MODELS_DIR, "predictor_state.pt")
ckpt_meta = os.path.join(MODELS_DIR, "meta.pt")
ckpt_graph = os.path.join(MODELS_DIR, "graph_artifacts.pt")

# NOTE(review): torch.load unpickles the file. Depending on the installed
# PyTorch version this can execute arbitrary code unless weights_only=True,
# so only load these checkpoints from a trusted source.
# `meta` holds the layer sizes needed to rebuild both models exactly as saved.
meta = torch.load(ckpt_meta, map_location=device)

gcn_loaded = SimpleGCN(meta["num_features"], meta["gcn_hidden"], meta["gcn_out"]).to(device)
gcn_loaded.load_state_dict(torch.load(ckpt_gcn, map_location=device))
gcn_loaded.eval()

# The predictor consumes GCN embeddings, so its input width is meta["gcn_out"].
predictor_loaded = LinkPredictor(meta["gcn_out"], meta["pred_hidden"], meta["num_classes"]).to(device)
predictor_loaded.load_state_dict(torch.load(ckpt_pred, map_location=device))
predictor_loaded.eval()

# Graph inputs the GCN is run on: node features and the "adj_norm" adjacency.
graph_arts = torch.load(ckpt_graph, map_location=device)
x = graph_arts["x"].to(device)
adj_norm = graph_arts["adj_norm"].to(device)
# Backward-compatible alias: classify_edges reads the adjacency via this name.
# SimpleGCN consumes it as a matrix (torch.matmul(adj, x)); it is not a
# (2, E) edge-index tensor despite the name.
edge_index = adj_norm

# -----------------------------
# 3. Função de Inferência
# -----------------------------
# Class index -> label name, in the order of the predictor's output logits
# (classify_edges maps the argmax index through this table).
idx2label = dict(enumerate(["baixa", "media", "alta"]))

@torch.no_grad()
def classify_edges(edge_list, *, gcn=None, predictor=None, features=None, adj=None, labels=None):
    """Classify node pairs with the GCN + link-predictor pipeline.

    Node embeddings are computed once per call by running the GCN over the
    whole graph. Each ``(u, v)`` pair is then scored by the predictor and
    softmax-normalized. Gradients are disabled (``torch.no_grad``).

    Args:
        edge_list: Iterable of ``(u, v)`` node-index pairs. Any iterable,
            including a generator, is accepted; it is consumed exactly once.
        gcn, predictor, features, adj, labels: Optional overrides for the
            module-level ``gcn_loaded``, ``predictor_loaded``, ``x``,
            ``edge_index`` and ``idx2label`` respectively. Leave them as
            ``None`` to use the loaded objects.

    Returns:
        One dict per input pair, in input order, with keys ``"src"``,
        ``"dst"``, ``"pred_label"`` (label of the argmax class) and
        ``"probs"`` (label -> probability as a Python float). An empty
        input gives an empty list.

    Raises:
        IndexError: if a node index is negative or not smaller than the
            number of rows of the GCN output.
    """
    # Materialize first: the pairs are traversed several times below.
    edges = list(edge_list)
    if not edges:
        return []

    gcn = gcn_loaded if gcn is None else gcn
    predictor = predictor_loaded if predictor is None else predictor
    features = x if features is None else features
    adj = edge_index if adj is None else adj
    labels = idx2label if labels is None else labels

    z = gcn(features, adj)
    num_nodes = z.size(0)

    # Negative indices would silently wrap around (z[-1] is the last node), and
    # too-large ones fail with opaque errors (device-side asserts on CUDA).
    for u, v in edges:
        if not (0 <= u < num_nodes and 0 <= v < num_nodes):
            raise IndexError(f"edge ({u}, {v}) out of range for a graph with {num_nodes} nodes")

    # Explicit long dtype: index tensors must be integral, and torch.tensor
    # would otherwise infer a float dtype for some inputs (e.g. an empty list).
    src = torch.tensor([u for u, _ in edges], dtype=torch.long, device=z.device)
    dst = torch.tensor([v for _, v in edges], dtype=torch.long, device=z.device)

    probs = torch.softmax(predictor(z[src], z[dst]), dim=-1)
    # Copy results to the host once, instead of syncing for every element.
    pred_idx = probs.argmax(dim=-1).tolist()
    prob_rows = probs.cpu().tolist()

    return [
        {
            "src": u,
            "dst": v,
            "pred_label": labels[k],
            "probs": {labels[c]: p for c, p in enumerate(row)},
        }
        for (u, v), k, row in zip(edges, pred_idx, prob_rows)
    ]

# ```
# Example: score a few node pairs, then print each predicted class together
# with its per-class probabilities.
pares = [(7, 26), (31, 15), (43, 4)]
outputs = classify_edges(pares)

for o in outputs:
    print(
        f"Aresta ({o['src']}, {o['dst']}): "
        f"Classe prevista = {o['pred_label']}, Probs = {o['probs']}"
    )
# ```