# Validate Group Splitter

In [None]:
import json
from pathlib import Path
from datetime import datetime

import torch

from pioneerml.zenml import utils as zenml_utils
from pioneerml.training.lightning import GraphLightningModule

from pioneerml.models.classifiers import GroupSplitter
from pioneerml.data.loaders import load_splitter_groups
from pioneerml.training.datamodules import SplitterDataModule
from pioneerml.evaluation import (
    plot_multilabel_confusion_matrix,
    plot_roc_curves,
    plot_precision_recall_curves,
    plot_probability_distributions,
    plot_confidence_analysis,
)
from pioneerml.data import NODE_LABEL_TO_NAME, NUM_NODE_CLASSES

# Resolve the project root and make sure the checkpoint folder exists so the
# later cells can glob it without first checking for the directory.
PROJECT_ROOT = zenml_utils.find_project_root()
checkpoints_dir = Path(PROJECT_ROOT).joinpath('trained_models', 'group_splitter')
checkpoints_dir.mkdir(parents=True, exist_ok=True)
print(
    f'Project root: {PROJECT_ROOT}',
    f'Checkpoints directory: {checkpoints_dir}',
    sep='\n',
)


In [None]:
# Load latest checkpoint
# "Latest" here means the path that sorts last: candidates are ordered by
# descending path, so this is the newest file only if the filenames sort
# chronologically (e.g. a timestamp suffix) -- TODO confirm naming scheme.
checkpoint_pattern = 'group_splitter_*.pt'
checkpoint_files = sorted(checkpoints_dir.glob(checkpoint_pattern), reverse=True)
if not checkpoint_files:
    raise ValueError('No group_splitter checkpoints found')
selected_checkpoint = checkpoint_files[0]

# NOTE(review): depending on the installed PyTorch version, torch.load may
# unpickle arbitrary objects; only point this at checkpoints you trust.
state_dict = torch.load(selected_checkpoint, map_location='cpu')

# The model is built with its default constructor arguments, so the
# checkpoint must have been produced by a model with the same defaults.
model = GroupSplitter()
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode before running predictions
print(f'Loaded checkpoint: {selected_checkpoint.name}')


In [None]:
# Load validation data
# Every file matching the pattern is read (max_files=None).
data_dir = Path(PROJECT_ROOT) / 'data'
file_pattern = str(data_dir / 'splitterGroups_*.npy')
records = load_splitter_groups(file_pattern, max_files=None, verbose=True)

datamodule = SplitterDataModule(
    records=records,
    batch_size=128,
    num_workers=0,
    val_split=0.0,
)
datamodule.setup('fit')

# With val_split=0.0 the validation split may be missing or empty; in that
# case evaluate on the training dataset instead.
val_dataset = datamodule.val_dataset
if not val_dataset:
    val_dataset = datamodule.train_dataset
print(f'Validation size: {len(val_dataset)}')


In [None]:
# Predict
from torch_geometric.loader import DataLoader

# Iterate in a fixed order (shuffle=False) so predictions and targets stay
# aligned batch by batch.
loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

all_preds = []
all_targets = []
with torch.no_grad():  # inference only; no gradients are needed
    for batch in loader:
        # NOTE(review): the raw model output is stored as-is; whether it is
        # logits or probabilities is not visible here -- confirm before
        # interpreting thresholds downstream.
        preds = model(batch)
        all_preds.append(preds.cpu())
        all_targets.append(batch.y.cpu())

# Stitch the per-batch tensors together along the first dimension.
predictions = torch.cat(all_preds)
targets = torch.cat(all_targets)
print(predictions.shape, targets.shape)


In [None]:
# Plots
# Class names follow the iteration order of NODE_LABEL_TO_NAME.
# NOTE(review): presumably that order matches the prediction columns -- verify.
class_names = [*NODE_LABEL_TO_NAME.values()]

# Keyword arguments shared by every plotting helper below.
plot_kwargs = {
    'predictions': predictions,
    'targets': targets,
    'class_names': class_names,
    'show': True,
}

plot_multilabel_confusion_matrix(threshold=0.5, normalize=True, **plot_kwargs)
plot_roc_curves(**plot_kwargs)
plot_precision_recall_curves(**plot_kwargs)
plot_probability_distributions(**plot_kwargs)
plot_confidence_analysis(**plot_kwargs)
