In [1]:
# ==========================================================
# 04_GIN_embeddings.ipynb
# Pré-treinamento contrastivo com GIN (GCPAL-like)
# ==========================================================

# 0) Project setup (path fix for notebooks)
import sys
from pathlib import Path


def find_project_root(start: Path, marker: str = "src") -> Path:
    """Locate the project root by walking upward from ``start``.

    The root is the closest directory (``start`` itself included) that has a
    sub-directory named ``marker`` -- by default ``src``, the package this
    notebook imports from. Unlike a fixed ``cwd().parents[0]``, this works
    whether the kernel was started inside ``notebooks/`` or in the root itself.

    Args:
        start: Directory where the search begins.
        marker: Name of the sub-directory that identifies the project root.

    Returns:
        The resolved project root. If no ancestor contains ``marker``, falls
        back to ``start.parent`` (the previous "this code lives one level
        below the root" assumption), so behaviour never gets worse than before.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    return start.parent


# Scripts know their own location; notebooks only know the kernel's cwd.
if "__file__" in globals():
    _search_start = Path(__file__).resolve().parent
else:
    _search_start = Path.cwd()

ROOT = find_project_root(_search_start)

# Put the project root FIRST on sys.path so `src` resolves to this project's
# package instead of any unrelated top-level `src` already importable.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

print(f"📁 Project root detected as: {ROOT}")


📁 Project root detected as: /Users/leonardoribeiro/Documents/DataScience/MBA_USP/TCC


In [2]:
# ==========================================================
# Imports + config + device
# ==========================================================
import pandas as pd
from src.utils import ConfigLoader, EnvironmentSetup
from src.gcpal import (
    PretrainGraphBuilder,
    build_knn_edge_index,
    GINEncoder,
    ProjectionHead,
    build_positive_lists,
    GINPretrainer,
)

# Load the YAML config; EnvironmentSetup receives the configured seed
# (42 when "general.seed" is absent) and exposes the device used below.
cfg = ConfigLoader.load("base.yaml")
general_cfg = cfg.get("general", {})
env = EnvironmentSetup(seed=general_cfg.get("seed", 42))
device = env.device
print("device:", device)

# Processed-data directory, taken from the config and anchored at ROOT.
processed_rel = cfg["paths"]["data_processed"]
data_proc = (ROOT / processed_rel).resolve()
print("data_processed:", data_proc)


✅ Active device: cpu
GPU detected: None
Torch version: 2.3.1
device: cpu
data_processed: /Users/leonardoribeiro/Documents/DataScience/MBA_USP/TCC/data/processed


In [3]:
# ==========================================================
# Load the processed CSVs written by notebooks 01/03
# ==========================================================
nodes_with_class_path = data_proc.joinpath("elliptic_nodes_with_class.csv")
edges_path = data_proc.joinpath("elliptic_edges.csv")

# Read both files in the same order as the paths above.
df_nodes_with_class, df_edges = (
    pd.read_csv(csv_file) for csv_file in (nodes_with_class_path, edges_path)
)

print("nodes_with_class:", df_nodes_with_class.shape)
print("edges:", df_edges.shape)


nodes_with_class: (203769, 168)
edges: (234355, 2)


In [4]:
# ==========================================================
# Single pre-training graph (time_step <= 34) built with PretrainGraphBuilder
# ==========================================================
pretrain_cfg = cfg.get("pretrain", {})

# Keep only the columns whose name starts with "feature_".
feature_cols = list(
    filter(lambda col: col.startswith("feature_"), df_nodes_with_class.columns)
)

# max_train_step falls back to 34 when the YAML does not set it.
pt_builder = PretrainGraphBuilder(
    df_nodes_with_class=df_nodes_with_class,
    df_edges=df_edges,
    feature_cols=feature_cols,
    device=device,
    max_train_step=pretrain_cfg.get("max_train_step", 34),
)
data_train_global = pt_builder.build()


Data(x=[136265, 165], edge_index=[2, 313686], y=[136265], mask_labeled=[136265])
Total de nós (≤34): 136265
Total de arestas (bidirecionais): 313686
Nós rotulados: 29894
Proporção de ilícitos nos rotulados: 0.11580919474363327


In [5]:
# ==========================================================
# KNN view (k defaults to 15; can be overridden in the YAML)
# ==========================================================
pretrain_opts = cfg.get("pretrain", {})
k, knn_bs = pretrain_opts.get("k", 15), pretrain_opts.get("knn_batch_size", 4096)

# k-NN edge index computed from the node features, processed in batches of knn_bs.
edge_index_knn = build_knn_edge_index(
    device=device,
    k=k,
    batch_size=knn_bs,
    x=data_train_global.x.to(device),
)

print("edge_index_knn:", tuple(edge_index_knn.size()))


edge_index_knn: (2, 4087992)


In [6]:
# ==========================================================
# Model: GINEncoder + ProjectionHead
# ==========================================================
gin_cfg = cfg.get("model", {}).get("gin", {})

in_dim = data_train_global.x.size(1)  # per-node input feature width
hidden_dim = gin_cfg.get("hidden_dim", 128)
layers = gin_cfg.get("layers", 2)
proj_dim = gin_cfg.get("proj_dim", 128)

encoder = GINEncoder(
    in_channels=in_dim,
    hidden_channels=hidden_dim,
    num_layers=layers,
).to(device)
# The projection head's input width matches the encoder's hidden width.
proj_head = ProjectionHead(in_dim=hidden_dim, proj_dim=proj_dim).to(device)

print("encoder/proj prontos:", in_dim, hidden_dim, layers, proj_dim)


encoder/proj prontos: 165 128 2 128


In [7]:
# ==========================================================
# Positive lists (structural A and A_KNN)
# ==========================================================
n_nodes = data_train_global.num_nodes

# Same construction for both views: CPU edge index, self-loops included.
pos_lists_struct, pos_lists_knn = (
    build_positive_lists(edge_index=view.to("cpu"), num_nodes=n_nodes, add_self=True)
    for view in (data_train_global.edge_index, edge_index_knn)
)

len(pos_lists_struct), len(pos_lists_knn)


(136265, 136265)

In [8]:
# ==========================================================
# Contrastive training with GINPretrainer (same logic as your training loop)
# ==========================================================
pre = cfg.get("pretrain", {})

# Each hyper-parameter uses the YAML value when present, else the default here.
trainer_kwargs = {
    "device": device,
    "lambda_mix": pre.get("lambda_mix", 0.5),
    "tau": pre.get("tau", 0.5),
    "lr": pre.get("lr", 1e-3),
    "max_epochs": pre.get("epochs", 20),
    "anchor_bs": pre.get("anchor_bs", 2048),
    "target_bs": pre.get("target_bs", 32768),
    # NOTE(review): 999 presumably disables early stopping given max_epochs=20 — confirm.
    "patience": pre.get("patience", 999),
    "drop_p_edge": pre.get("drop_p_edge", 0.3),
    "drop_p_feat": pre.get("drop_p_feat", 0.3),
}
trainer = GINPretrainer(**trainer_kwargs)

fit_inputs = {
    "data_train_global": data_train_global,
    "edge_index_knn": edge_index_knn,
    "encoder": encoder,
    "proj_head": proj_head,
    "pos_lists_struct": pos_lists_struct,
    "pos_lists_knn": pos_lists_knn,
}
metrics = trainer.fit(**fit_inputs)

metrics


: 