In [1]:
import sys
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.data import Data

sys.path.append('..')
from src.sheaf import CSNN
from src.utils import accuracy
from src.load_data import load_data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global seed for reproducibility (torch.manual_seed seeds the RNG on all devices).
# The trailing ';' stops Jupyter from echoing the returned Generator object.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x112751790>

## Load dataset

In [8]:
# Dataset to load; the other options listed for this notebook are 'texas' and 'wisconsin'.
DATASET_NAME = 'cornell'

# Fraction of the nodes assigned to each split, ordered train / validation / test.
SPLIT = [
    0.7,  # train
    0.1,  # validation
    0.2,  # test
]

In [9]:
data: Data
data, train_mask, val_mask, test_mask, num_classes, class_weights = load_data(DATASET_NAME, SPLIT)

# `model` is moved to `device` when it is built, so every tensor it consumes
# (node features, labels, edge_index, split masks and the loss weights) must
# live on that same device; otherwise a CUDA run would hit a CPU/CUDA
# device-mismatch error on the first forward pass / loss computation.
data = data.to(device)
# torch.as_tensor accepts tensors as well as array-likes, so this works whatever
# concrete type load_data returns for the masks (NOTE(review): not visible here).
train_mask, val_mask, test_mask = (
    torch.as_tensor(mask, device=device) for mask in (train_mask, val_mask, test_mask)
)
class_weights = class_weights.to(device)

Cornell | num_nodes=183 | num_classes=5
Train=126, val=16, test=41
class_counts: [26.0, 11.0, 21.0, 57.0, 11.0]
class_weights: [0.6737174987792969, 1.5924232006072998, 0.8341264724731445, 0.30730974674224854, 1.5924232006072998]


## Sheaf-NN

In [5]:
# Re-seed right before construction so the initial weights are identical every
# time this cell is (re-)run, independent of random draws made by earlier cells.
torch.manual_seed(42)

# Architecture hyper-parameters for the CSNN model.
# NOTE(review): the reference paper reportedly uses considerably deeper stacks;
# 2 layers is a shallow setting that may be worth tuning.
csnn_hparams = dict(
    hidden_dim=64,
    num_layers=2,
    dropout=0.5,
)

model = CSNN(
    in_dim=data.x.size(-1),  # last dimension of the node-feature matrix
    out_dim=num_classes,
    num_nodes=data.num_nodes,
    edge_index=data.edge_index,
    **csnn_hparams,
).to(device)

### Training

In [6]:
# Adam over all model parameters; weight_decay adds an L2 penalty to the gradients.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Class-weighted cross-entropy: the class counts reported by load_data are
# imbalanced, and the rarer classes receive the larger weights.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# The training loop steps this scheduler with validation accuracy, hence
# mode='max': the LR is halved after 100 epochs without improvement, floored at 1e-5.
lr_scheduler = ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=100, min_lr=1e-5,
)

In [None]:
# Full-batch training. The weights with the best validation accuracy are
# snapshotted and restored at the end, and training stops early once validation
# accuracy has not improved for `early_stop_patience` consecutive epochs.
max_epochs = 500
# 2x the LR-scheduler patience, so the learning rate can be reduced at least once
# before giving up. Set to None to always run all `max_epochs` epochs.
early_stop_patience = 200
log_every = 20


def evaluate_splits():
    """Evaluate `model` on the full graph in eval mode, without gradients.

    Returns:
        (logits, train_acc, val_acc, test_acc), where each accuracy is what
        `accuracy` returns for the rows selected by the corresponding split mask.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data.x)
    train_acc = accuracy(logits[train_mask], data.y[train_mask])
    val_acc = accuracy(logits[val_mask], data.y[val_mask])
    test_acc = accuracy(logits[test_mask], data.y[test_mask])
    return logits, train_acc, val_acc, test_acc


best_val_acc = 0.0
best_epoch = 0
best_state = None
history = []  # one dict of metrics per epoch, for the accuracy-vs-epoch plot TODO below

for epoch in range(1, max_epochs + 1):
    model.train()
    optimizer.zero_grad()
    out = model(data.x)
    loss = criterion(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()

    out, train_acc, val_acc, test_acc = evaluate_splits()
    history.append({
        "epoch": epoch,
        "loss": loss.item(),
        "train": float(train_acc),
        "val": float(val_acc),
        "test": float(test_acc),
    })

    # Model selection uses validation accuracy only; test accuracy is just logged.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        # .clone() matters: on CPU `.cpu()` returns the same tensor, so without a
        # copy the snapshot would keep changing along with the live parameters.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if epoch % log_every == 0 or epoch == 1:
        print(
            f"Epoch {epoch:04d} | "
            f"loss {loss.item():.4f} | "
            f"train {train_acc:.4f} | "
            f"val {val_acc:.4f} | "
            f"test {test_acc:.4f}"
        )

    lr_scheduler.step(val_acc)

    if early_stop_patience is not None and epoch - best_epoch >= early_stop_patience:
        print(
            f"Early stopping at epoch {epoch:04d}: no val improvement since "
            f"epoch {best_epoch:04d} (best val {best_val_acc:.4f})"
        )
        break

# Restore the best-on-validation weights before the final evaluation.
if best_state is not None:
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

out, final_train, final_val, final_test = evaluate_splits()

print(
    f"best epoch {best_epoch:04d} | "
    f"train {final_train:.4f} | val {final_val:.4f} | test {final_test:.4f}"
)

# TODO: implement an F1 metric
# TODO: print metrics against the baseline
# TODO: plot accuracy and F1 vs. epochs (per-epoch metrics are collected in `history`)


Epoch 0001 | loss 1.5942 | train 0.5556 | val 0.5882 | test 0.5250
Epoch 0020 | loss 1.1840 | train 0.6111 | val 0.6471 | test 0.5500
Epoch 0040 | loss 0.5995 | train 0.8016 | val 0.8235 | test 0.6250
Epoch 0060 | loss 0.4336 | train 0.8889 | val 0.7059 | test 0.6750
Epoch 0080 | loss 0.4298 | train 0.9127 | val 0.7647 | test 0.6750
Epoch 0100 | loss 0.3738 | train 0.9444 | val 0.7647 | test 0.7500
Epoch 0120 | loss 0.3113 | train 0.9762 | val 0.8235 | test 0.6750
Epoch 0140 | loss 0.0736 | train 1.0000 | val 0.8824 | test 0.8500
Epoch 0160 | loss 0.0924 | train 1.0000 | val 0.8824 | test 0.8500
Epoch 0180 | loss 0.0389 | train 1.0000 | val 0.9412 | test 0.8500
Epoch 0200 | loss 0.0273 | train 1.0000 | val 0.9412 | test 0.8500
Epoch 0220 | loss 0.0110 | train 1.0000 | val 0.9412 | test 0.8750
Epoch 0240 | loss 0.0125 | train 1.0000 | val 0.9412 | test 0.8750
Epoch 0260 | loss 0.0053 | train 1.0000 | val 0.9412 | test 0.8500
Epoch 0280 | loss 0.0340 | train 1.0000 | val 0.9412 | test 0.