In [None]:
# -----------------------------
# 0. Carregar Dependências
# -----------------------------
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# -----------------------------
# 1. Definir Modelo
# -----------------------------
class NodeClassifier(torch.nn.Module):
    """Two-layer graph convolutional network producing per-node class logits.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: width of the intermediate GCN layer.
        out_channels: number of output classes.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Attribute names/order define the state_dict keys of saved checkpoints.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return unnormalised class scores (logits) for every node in the graph."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout uses F.dropout's default p and is active only while training.
        hidden = F.dropout(hidden, training=self.training)
        return self.conv2(hidden, edge_index)

# -----------------------------
# 2. Carregar Pesos + Metadados
# -----------------------------
# Checkpoint locations, relative to the notebook's working directory.
ckpt_model = "../models/gcn_state.pt"
ckpt_meta = "../models/meta.pt"
ckpt_graph = "../models/graph_artifacts.pt"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles its input; only point these paths at
# trusted checkpoints (consider weights_only=True if the files allow it).

# Training metadata; only the three size keys read below are used here.
meta = torch.load(ckpt_meta, map_location=device)

# Rebuild the architecture from the metadata, restore the trained weights,
# and switch to eval mode so dropout in forward() is disabled.
gcn_loaded = NodeClassifier(
    meta["num_features"],
    meta["hidden_channels"],
    meta["num_classes"],
)
gcn_loaded = gcn_loaded.to(device)
gcn_loaded.load_state_dict(torch.load(ckpt_model, map_location=device))
gcn_loaded.eval()

# Graph tensors (node features and edge index), both placed on `device`.
graph_arts = torch.load(ckpt_graph, map_location=device)
x, edge_index = (graph_arts[key].to(device) for key in ("x", "edge_index"))

# -----------------------------
# 3. Função de Inferência
# -----------------------------
@torch.no_grad()
def predict_nodes(node_indices, *, model=None, features=None, edges=None):
    """Classify the given nodes with the trained GCN.

    A GCN needs the whole graph to compute any node's output, so one full
    forward pass is run per call and the requested rows are gathered from it.

    Args:
        node_indices: iterable of non-negative integer node ids.
        model: callable invoked as ``model(features, edges)`` returning a
            ``[num_nodes, num_classes]`` logits tensor; defaults to the
            global ``gcn_loaded``.
        features: node feature tensor; defaults to the global ``x``.
        edges: edge index tensor; defaults to the global ``edge_index``.

    Returns:
        One dict per requested node, in input order, with keys ``"node"``
        (the index exactly as given), ``"pred_class"`` (int, argmax class)
        and ``"probs"`` (list of floats, softmax over classes).

    Raises:
        IndexError: if any index is negative or >= the number of nodes.
        TypeError: if an index is not an integer.
    """
    import operator

    indices = list(node_indices)
    if not indices:
        return []  # nothing requested: skip the full-graph forward pass

    model = gcn_loaded if model is None else model
    features = x if features is None else features
    edges = edge_index if edges is None else edges

    logits = model(features, edges)  # [num_nodes, num_classes] class logits
    num_nodes = logits.size(0)

    # operator.index rejects floats; negatives are rejected explicitly because
    # tensor indexing would silently wrap them to the last nodes.
    positions = [operator.index(i) for i in indices]
    bad = [p for p in positions if not 0 <= p < num_nodes]
    if bad:
        raise IndexError(f"node indices out of range [0, {num_nodes}): {bad}")

    # Gather only the requested rows and copy them to the host in a single
    # transfer instead of one device->host sync per node. Softmax is per row,
    # so applying it to the subset gives the same values as on all nodes.
    rows = torch.tensor(positions, dtype=torch.long, device=logits.device)
    probs = F.softmax(logits[rows], dim=1).cpu()
    preds = probs.argmax(dim=1)

    return [
        {"node": idx, "pred_class": int(pred), "probs": row.tolist()}
        for idx, pred, row in zip(indices, preds, probs)
    ]


# -----------------------------
# 4. Exemplo de Uso
# -----------------------------
# Example: classify three nodes and print, for each, the predicted class and
# the full class-probability vector returned by predict_nodes.
nos = [0, 10, 123]  # node indices to classify
probs = predict_nodes(nos)  # list of result dicts (keys: node, pred_class, probs)

for r in probs :
    print(f"Nó {r['node']} -> Classe prevista: {r['pred_class']} | Probs: {r['probs']}")


Nó 0 -> Classe prevista: 3 | Probs: [0.00018875040404964238, 0.0002575435792095959, 0.00027790165040642023, 0.9984663724899292, 0.00022564477694686502, 0.0001672050857450813, 0.00041659874841570854]
Nó 10 -> Classe prevista: 0 | Probs: [0.9975009560585022, 0.0020546468440443277, 3.4819483971659793e-06, 2.5743382138898596e-05, 8.3106046076864e-05, 2.8420610760804266e-05, 0.0003037290589418262]
Nó 123 -> Classe prevista: 6 | Probs: [1.3812036741001066e-05, 1.3780419976683334e-05, 1.5782463378855027e-05, 8.077052370936144e-06, 1.9456465452094562e-05, 9.343179044662975e-06, 0.9999197721481323]
