# Validate Pion Stop Regressor

In [None]:
import json
from pathlib import Path
from datetime import datetime

import torch

from pioneerml.zenml import utils as zenml_utils
from pioneerml.training.lightning import GraphLightningModule

from pioneerml.models.regressors import PionStopRegressor
from pioneerml.data.loaders import load_pion_stop_groups
from pioneerml.training.datamodules import PionStopDataModule
from pioneerml.evaluation import plot_regression_diagnostics

# Locate the repository root and the directory that holds pion-stop checkpoints.
PROJECT_ROOT = zenml_utils.find_project_root()
project_root_path = Path(PROJECT_ROOT)
checkpoints_dir = project_root_path.joinpath('trained_models', 'pion_stop')
# Make sure the checkpoint directory exists (no-op when it is already there).
checkpoints_dir.mkdir(exist_ok=True, parents=True)
print(f'Project root: {PROJECT_ROOT}')
print(f'Checkpoints directory: {checkpoints_dir}')


In [None]:
# Load latest checkpoint.
# "Latest" here means the lexicographically greatest filename; that equals the
# most recent one only if checkpoint names embed a sortable timestamp
# (TODO confirm against the training pipeline's naming scheme).
checkpoint_files = sorted(checkpoints_dir.glob('pion_stop_*.pt'), reverse=True)
if not checkpoint_files:
    # A missing file is a filesystem condition, not a bad value; report where we looked.
    raise FileNotFoundError(
        f"No pion_stop checkpoints matching 'pion_stop_*.pt' found in {checkpoints_dir}"
    )
selected_checkpoint = checkpoint_files[0]
# map_location='cpu' lets tensors saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(selected_checkpoint, map_location='cpu')
# Assumes the file holds a bare state_dict for a default-constructed
# PionStopRegressor (same hyperparameters as at training time) — TODO confirm.
model = PionStopRegressor()
model.load_state_dict(state_dict)
# Switch to evaluation mode (affects layers such as dropout/batch-norm, if present).
model.eval()
print(f'Loaded checkpoint: {selected_checkpoint.name}')


In [None]:
# Load validation data from the pionStopGroups_*.npy files under <project>/data
# (max_files=None — presumably no cap on the number of files read).
data_glob = Path(PROJECT_ROOT).joinpath('data', 'pionStopGroups_*.npy')
file_pattern = str(data_glob)
records = load_pion_stop_groups(file_pattern, max_files=None, verbose=True)
datamodule = PionStopDataModule(
    records=records,
    batch_size=128,
    num_workers=0,
    val_split=0.0,
)
datamodule.setup('fit')
# NOTE(review): with val_split=0.0 the validation split is presumably empty/None,
# so evaluation falls back to the training split (i.e. all loaded records).
# Confirm that evaluating on the full dataset is intended.
val_dataset = datamodule.val_dataset
if not val_dataset:
    val_dataset = datamodule.train_dataset
print(f'Validation size: {len(val_dataset)}')


In [None]:
# Predict: run the model over the evaluation dataset in a fixed (unshuffled)
# order and gather predictions alongside the stored targets.
from torch_geometric.loader import DataLoader
loader = DataLoader(val_dataset, shuffle=False, batch_size=128)
all_preds = []
all_targets = []
# Gradients are not needed for evaluation.
with torch.no_grad():
    for batch in loader:
        preds = model(batch)
        all_preds.append(preds.cpu())
        all_targets.append(batch.y.cpu())
predictions, targets = (torch.cat(chunks, dim=0) for chunks in (all_preds, all_targets))
# NOTE(review): compare these shapes — if per-graph targets are collated flat by
# the loader they may not line up row-for-row with the predictions; verify.
print(predictions.shape, targets.shape)


In [None]:
# Plots: regression diagnostics of the collected predictions against the targets.
plot_regression_diagnostics(
    predictions=predictions,
    targets=targets,
    show=True,
)
