In [1]:
# -----------------------------
# 0. Carregar Dependências
# -----------------------------
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# -----------------------------
# 1. Definir Modelos
# -----------------------------
class GCN(torch.nn.Module):
    """Two-layer graph convolutional encoder that produces node embeddings.

    The forward pass is ``GCNConv -> ReLU -> GCNConv`` over the same edge
    index. The submodule names ``conv1``/``conv2`` define the state_dict keys,
    so saved checkpoints depend on them.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the embedding of every node given features and edges."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


class LinkPredictor(torch.nn.Module):
    """MLP that scores a pair of node embeddings.

    The two embeddings are concatenated along the last dimension and passed
    through ``Linear -> ReLU -> Linear -> sigmoid``. The result is flattened
    to 1-D.
    """

    def __init__(self, in_channels, hidden_channels, out_channels=1):
        super().__init__()
        pair_dim = 2 * in_channels  # both endpoint embeddings are concatenated
        self.lin1 = torch.nn.Linear(pair_dim, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x_i, x_j):
        """Return sigmoid scores for the pairs ``(x_i[k], x_j[k])``, flattened."""
        pair = torch.cat((x_i, x_j), dim=-1)
        hidden = F.relu(self.lin1(pair))
        scores = self.lin2(hidden)
        return torch.sigmoid(scores).view(-1)

# -----------------------------
# 2. Carregar Pesos + Artefatos
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint locations, relative to the current working directory.
ckpt_gcn = "../models/gcn_state.pt"
ckpt_pred = "../models/predictor_state.pt"
ckpt_meta = "../models/meta.pt"
ckpt_graph = "../models/graph_artifacts.pt"


def _load_to_device(path):
    """Deserialize the file at ``path``, mapping its tensors onto ``device``.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    that come from a trusted source.
    """
    return torch.load(path, map_location=device)


def _restore(model, state_path):
    """Move ``model`` to ``device``, load its saved weights and set eval mode."""
    model = model.to(device)
    model.load_state_dict(_load_to_device(state_path))
    model.eval()
    return model


# Metadata holding the layer sizes needed to rebuild both models.
meta = _load_to_device(ckpt_meta)

gcn_loaded = _restore(
    GCN(meta["num_features"], meta["gcn_hidden"], meta["gcn_out"]),
    ckpt_gcn,
)
predictor_loaded = _restore(
    LinkPredictor(meta["gcn_out"], meta["pred_hidden"]),
    ckpt_pred,
)

# Graph artifacts: node features ("x") and the edge index stored under
# "train_pos_edge_index" (presumably the positive training edges), which is
# used for message passing at inference time.
graph_arts = _load_to_device(ckpt_graph)
x, edge_index = (
    graph_arts[key].to(device) for key in ("x", "train_pos_edge_index")
)

# -----------------------------
# 3. Função de Inferência
# -----------------------------
@torch.no_grad()
def predict_pairs(
    node_pairs, *, gcn=None, predictor=None, features=None, edges=None
):
    """Return the predictor's score for each ``(u, v)`` node pair.

    Node embeddings are recomputed with the encoder over the whole graph on
    every call. The predictor then scores each requested pair.

    Args:
        node_pairs: Iterable of ``(u, v)`` integer node indices. Any iterable
            is accepted, including a generator. An empty one yields an empty
            result.
        gcn: Encoder called as ``gcn(features, edges)``. Defaults to the
            module-level ``gcn_loaded``.
        predictor: Scorer called as ``predictor(z_u, z_v)``. Defaults to the
            module-level ``predictor_loaded``.
        features: Node feature matrix. Defaults to the module-level ``x``.
        edges: Edge index used for message passing. Defaults to the
            module-level ``edge_index``.

    Returns:
        Whatever ``predictor`` returns for the gathered embeddings. For
        ``LinkPredictor`` this is a 1-D tensor of probabilities, one per pair,
        on the embeddings' device.

    Raises:
        IndexError: On CPU, if a node index is outside the embedding matrix.
    """
    # Conditional expressions only touch the notebook globals when no
    # override was supplied.
    gcn = gcn_loaded if gcn is None else gcn
    predictor = predictor_loaded if predictor is None else predictor
    features = x if features is None else features
    edges = edge_index if edges is None else edges

    # Materialize once: the pairs are traversed twice below, and a one-shot
    # iterator would be exhausted (empty) on the second pass.
    pairs = list(node_pairs)

    z = gcn(features, edges)

    # Explicit long dtype: torch.tensor([]) defaults to float32, which is not
    # a valid index type, so an empty request would otherwise crash.
    src = torch.tensor([u for u, _ in pairs], dtype=torch.long, device=z.device)
    dst = torch.tensor([v for _, v in pairs], dtype=torch.long, device=z.device)

    return predictor(z[src], z[dst])

# -----------------------------
# 4. Exemplo de Uso
# -----------------------------
# Score a few sample node pairs and print one line per pair.
pares = [(0, 42), (10, 77), (123, 456)]
probs = predict_pairs(pares)

probs_list = probs.tolist()  # plain Python floats for formatting
for (u, v), p in zip(pares, probs_list):
    linha = f"Probabilidade de existir aresta entre ({u}, {v}): {p:.4f}"
    print(linha)


Probabilidade de existir aresta entre (0, 42): 0.0000
Probabilidade de existir aresta entre (10, 77): 0.0393
Probabilidade de existir aresta entre (123, 456): 0.0000
