# PriorityNet Training & Evaluation

**Fully trainable notebook** for PriorityNet - the intelligent signal ordering model in PosteriFlow.

## Quick Start

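1. Run the cells sequentially from top to bottom.
2. If no dataset is found under `data/dataset/`, synthetic batches are used for demonstration.
3. The model trains and saves a checkpoint whenever the validation loss improves.
4. Validation runs automatically after every epoch.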

## 1. Setup & Configuration

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import yaml
import logging
from types import SimpleNamespace
import time
from collections import defaultdict
import matplotlib.pyplot as plt

# The repository root is taken to be the parent of the current working
# directory; its "src" and "experiments" folders are made importable.
project_root = Path.cwd().parent
for _subdir in ("src", "experiments"):
    # Every insert goes to the front, so "experiments" ends up ahead of "src".
    sys.path.insert(0, str(project_root / _subdir))

# Emit INFO-level records with timestamp, logger name and level.
_log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_log_format)
logger = logging.getLogger('PriorityNetTraining')

# Query CUDA once and reuse the answer for both the log line and the device.
_cuda_ok = torch.cuda.is_available()

logger.info(f"Project root: {project_root}")
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {_cuda_ok}")

# The device is kept as a plain string ('cuda' or 'cpu').
device = 'cuda' if _cuda_ok else 'cpu'
logger.info(f"Using device: {device}")

In [None]:
# Load the training configuration and pull out the `priority_net` section.
config_path = project_root / "configs/enhanced_training.yaml"

with open(config_path, encoding="utf-8") as f:
    # An empty YAML document parses to None; fall back to an empty mapping.
    config_dict = yaml.safe_load(f) or {}

# A present-but-empty `priority_net:` key also yields None.
priority_net_config = config_dict.get('priority_net') or {}

# The learning rate may be missing, or a string (PyYAML reads exponent-only
# literals such as `1e-4` as str), so format it without assuming a float.
_configured_lr = priority_net_config.get('learning_rate')
try:
    _lr_text = f"{float(_configured_lr):.2e}"
except (TypeError, ValueError):
    _lr_text = str(_configured_lr)

# Display key configuration parameters
print("\nüìã PriorityNet Configuration")
print("="*70)
print(f"Hidden dims: {priority_net_config.get('hidden_dims')}")
print(f"Batch size: {priority_net_config.get('batch_size')}")
print(f"Learning rate: {_lr_text}")
print(f"Epochs: {priority_net_config.get('epochs')}")
print(f"Warmup epochs: {priority_net_config.get('warmup_epochs')}")
print(f"Use Transformer encoder: {priority_net_config.get('use_transformer_encoder')}")
print("\nLoss weights:")
print(f"  Ranking: {priority_net_config.get('ranking_weight')}")
print(f"  MSE: {priority_net_config.get('mse_weight')}")
print(f"  Uncertainty: {priority_net_config.get('uncertainty_weight')}")
print("\nCalibration weights:")
print(f"  Mean: {priority_net_config.get('calib_mean_weight')}")
print(f"  Max: {priority_net_config.get('calib_max_weight')}")
print(f"  Range: {priority_net_config.get('calib_range_weight')}")
print("="*70)

## 2. Import Model & Training Classes

In [None]:
# Import core model classes
# Of these names, only PriorityNet and PriorityNetTrainer are referenced
# directly by the cells visible in this notebook.
# NOTE(review): `ahsd` is presumably resolved through the `src` directory that
# the setup cell puts on sys.path — confirm the package layout.
from ahsd.core.priority_net import (
    PriorityNet,
    PriorityLoss,
    PriorityNetTrainer,
    TemporalStrainEncoder,
    TransformerStrainEncoder
)

# Reaching this line means the import above did not raise.
logger.info("‚úÖ Imported PriorityNet classes")

In [None]:
# Try to import the data loading utilities straight from the training script.
# On any failure HAS_DATA_LOADERS stays False and the notebook falls back to
# synthetic batches.
HAS_DATA_LOADERS = False
_registered_module = None  # module object placed in sys.modules, if any

try:
    import importlib.util

    train_script = project_root / "experiments/train_priority_net.py"
    if train_script.exists():
        spec = importlib.util.spec_from_file_location(
            "train_priority_net",
            train_script
        )
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot build an import spec for {train_script}")

        train_module = importlib.util.module_from_spec(spec)
        # Register before executing so code inside the script that looks
        # itself up by module name can find it.
        sys.modules['train_priority_net'] = train_module
        _registered_module = train_module
        spec.loader.exec_module(train_module)

        # Extract needed classes
        PriorityNetDataset = train_module.PriorityNetDataset
        ChunkedGWDataLoader = train_module.ChunkedGWDataLoader
        collate_priority_batch = train_module.collate_priority_batch

        logger.info("‚úÖ Imported data loading utilities from train_priority_net.py")
        HAS_DATA_LOADERS = True
    else:
        logger.warning(f"train_priority_net.py not found at {train_script}")
except Exception as e:
    # Do not leave a half-initialised module behind for later imports.
    if (
        _registered_module is not None
        and sys.modules.get('train_priority_net') is _registered_module
    ):
        del sys.modules['train_priority_net']
    # exc_info keeps the traceback, which the message alone would lose.
    logger.error(f"Could not import data loaders: {e}", exc_info=True)
    HAS_DATA_LOADERS = False

if not HAS_DATA_LOADERS:
    logger.warning("Will use synthetic data batches for demonstration")

## 3. Initialize Model

In [None]:
# Collect the model settings: values present in the YAML section win, the
# literals below are the fallbacks, and the remaining entries are fixed here.
_model_settings = {
    'hidden_dims': priority_net_config.get('hidden_dims', [640, 512, 384, 256]),
    'dropout': priority_net_config.get('dropout', 0.25),
    'use_strain': True,
    'use_edge_conditioning': True,
    'n_edge_types': 19,
    'use_transformer_encoder': priority_net_config.get('use_transformer_encoder', False),
    'overlap_importance_hidden': priority_net_config.get('importance_hidden_dim', 32),
}
cfg = SimpleNamespace(**_model_settings)

# Build the network from the namespace and move it to the selected device.
model = PriorityNet(cfg)
model.to(device)

# Count every element of every parameter tensor.
total_params = 0
for _param in model.parameters():
    total_params += _param.numel()
logger.info(f"‚úÖ PriorityNet initialized with {total_params:,} parameters")

In [None]:
# Build the trainer from the model and the `priority_net` config section.
trainer = PriorityNetTrainer(model, priority_net_config)
logger.info(f"‚úÖ Trainer initialized")

# The configured learning rate may be missing, or a string (PyYAML reads
# exponent-only literals such as `1e-4` as str), so format it defensively.
_configured_lr = priority_net_config.get('learning_rate')
try:
    _lr_text = f"{float(_configured_lr):.2e}"
except (TypeError, ValueError):
    _lr_text = str(_configured_lr)

print(f"\nüèãÔ∏è Trainer Configuration:")
print(f"  Learning rate: {_lr_text}")
# Report the optimizer the trainer actually holds rather than a fixed name.
print(f"  Optimizer: {type(trainer.optimizer).__name__}")
print(f"  Warmup epochs: {trainer.warmup_epochs}")
print(f"  Gradient clip norm: {trainer.gradient_clip_norm}")

## 4. Data Loading Setup

In [None]:
# Look for on-disk dataset splits. Synthetic batches are used when the loader
# utilities are unavailable or no usable training data is found.
train_dir = project_root / "data/dataset/train"
val_dir = project_root / "data/dataset/val"

train_loader = None
val_loader = None
USE_SYNTHETIC_DATA = False

if HAS_DATA_LOADERS:
    if train_dir.exists():
        train_samples = list(train_dir.glob("*.h5")) + list(train_dir.glob("sample_*.pkl"))
        if train_samples:
            logger.info(f"‚úÖ Found {len(train_samples)} training samples")

            try:
                _loader = ChunkedGWDataLoader(
                    data_dir=train_dir,
                    batch_size=priority_net_config.get('batch_size', 12),
                    shuffle=True,
                    collate_fn=collate_priority_batch,
                    num_workers=0  # Set to 0 in Jupyter to avoid multiprocessing issues
                )
                logger.info(f"‚úÖ Created train loader with {len(_loader)} batches")
                # Publish the loader only after it proved usable, so a failure
                # above cannot leave a half-working loader assigned.
                train_loader = _loader
            except Exception as e:
                logger.warning(f"Could not create train loader: {e}")
                USE_SYNTHETIC_DATA = True
        else:
            logger.warning(f"No samples found in {train_dir}")
            USE_SYNTHETIC_DATA = True
    else:
        logger.warning(f"Train directory not found: {train_dir}")
        USE_SYNTHETIC_DATA = True

    # A validation loader is only built alongside a real training loader.
    if train_loader is not None:
        if val_dir.exists():
            val_samples = list(val_dir.glob("*.h5")) + list(val_dir.glob("sample_*.pkl"))
            if val_samples:
                logger.info(f"‚úÖ Found {len(val_samples)} validation samples")

                try:
                    _loader = ChunkedGWDataLoader(
                        data_dir=val_dir,
                        batch_size=priority_net_config.get('batch_size', 12),
                        shuffle=False,
                        collate_fn=collate_priority_batch,
                        num_workers=0
                    )
                    logger.info(f"‚úÖ Created val loader with {len(_loader)} batches")
                    val_loader = _loader
                except Exception as e:
                    logger.warning(f"Could not create val loader: {e}")
            else:
                logger.warning(f"No samples found in {val_dir}")
        else:
            logger.warning(f"Validation directory not found: {val_dir}")

        if val_loader is None:
            # The training loop validates on synthetic batches when there is no
            # validation loader; say so instead of switching silently.
            logger.warning(
                "No validation loader available: validation will run on synthetic "
                "batches, so the validation loss will not reflect real data"
            )
else:
    logger.info("Data loader utilities not available, will use synthetic batches")
    USE_SYNTHETIC_DATA = True

if USE_SYNTHETIC_DATA:
    logger.info("‚ö†Ô∏è  Using synthetic data for demonstration")

## 5. Define Training Functions

In [None]:
def create_synthetic_batch(
    batch_size=8,
    n_signals_range=(2, 4),
    *,
    n_detectors=3,
    n_samples=2048,
    n_params=16,
    n_edge_types=19,
    target_device=None,
):
    """
    Create a synthetic batch for demonstration.

    Args:
        batch_size: Number of samples in batch
        n_signals_range: (low, high) bounds for the number of signals shared by
            every sample in the batch. ``high`` is exclusive, so (2, 4) yields
            2 or 3; when low == high exactly that many signals are used.
        n_detectors: Size of the second strain dimension
        n_samples: Size of the last strain dimension
        n_params: Number of random features per signal
        n_edge_types: Exclusive upper bound for the random edge type ids
        target_device: Device for the tensors; None uses the notebook-level
            ``device``

    Returns:
        Dict with 'strain_data' [batch, n_detectors, n_samples],
        'parameters' [batch, signals, n_params], and 'edge_type_ids',
        'priorities', 'snr', each [batch, signals]

    Raises:
        ValueError: If low > high (raised by np.random.randint)
    """
    low, high = n_signals_range
    # np.random.randint rejects an empty interval, so pin a degenerate range to
    # its single value instead of crashing.
    n_signals = int(low) if low == high else int(np.random.randint(low, high))

    # The global is only read when no explicit device is given.
    dev = device if target_device is None else target_device
    per_signal = (batch_size, n_signals)

    # Tensors are created directly on the target device (no CPU round trip).
    return {
        'strain_data': torch.randn(batch_size, n_detectors, n_samples, device=dev),
        'parameters': torch.randn(batch_size, n_signals, n_params, device=dev),
        'edge_type_ids': torch.randint(0, n_edge_types, per_signal, device=dev),
        # Targets: uniform random priorities in [0, 1)
        'priorities': torch.rand(per_signal, device=dev),
        # Constant SNR of 20 for every signal
        'snr': torch.full(per_signal, 20.0, device=dev),
    }

# Cell-level confirmation that the definition above executed.
logger.info("‚úÖ Defined synthetic batch generator")

In [None]:
def train_one_epoch(model, train_loader, trainer, device='cpu', use_synthetic=False):
    """
    Train for one epoch.

    Args:
        model: PriorityNet model
        train_loader: Training data loader (or None if synthetic)
        trainer: PriorityNetTrainer instance; its ``loss_fn``, ``optimizer``
            and ``gradient_clip_norm`` attributes are used
        device: 'cpu' or 'cuda'
        use_synthetic: Use five synthetic batches instead of ``train_loader``

    Returns:
        avg_loss: Mean loss over the batches that trained successfully, or NaN
            when no batch did (so a fully failed epoch cannot look like a
            perfect loss of 0.0)
        grad_stats: Dict with keys 'min', 'max', 'mean' (statistics of the
            per-parameter maximum absolute gradient, taken after clipping) and
            'count' (number of batches that contributed)
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    grad_stats = {'min': float('inf'), 'max': 0.0, 'mean': 0.0, 'count': 0}

    # Use synthetic batches if no real data
    if use_synthetic:
        batches = [create_synthetic_batch(batch_size=8) for _ in range(5)]
    else:
        batches = train_loader

    for batch_idx, batch in enumerate(batches):
        try:
            # Batches are either dicts or sequences whose first five items
            # follow the order unpacked below.
            if isinstance(batch, dict):
                strain_data = batch.get('strain_data', batch.get('strain'))
                parameters = batch.get('parameters')
                edge_ids = batch.get('edge_type_ids')
                targets = batch.get('priorities', batch.get('targets'))
                snr_values = batch.get('snr')
            else:
                strain_data, parameters, edge_ids, targets, snr_values = batch[:5]

            if strain_data is None or parameters is None or targets is None:
                continue

            # .to() returns the tensor unchanged when it already lives on
            # `device`, and also handles a CUDA tensor with a CPU target.
            strain_data = strain_data.to(device)
            parameters = parameters.to(device)
            targets = targets.to(device)
            if edge_ids is not None:
                edge_ids = edge_ids.to(device)
            if snr_values is not None:
                snr_values = snr_values.to(device)

            # Forward pass
            predictions, uncertainties = model(
                strain_data=strain_data,
                parameters=parameters,
                edge_type_ids=edge_ids
            )

            # Compute loss
            loss = trainer.loss_fn(
                predictions=predictions,
                targets=targets,
                uncertainties=uncertainties,
                snr_values=snr_values
            )

            # Backward pass
            trainer.optimizer.zero_grad()
            loss.backward()

            # clip_grad_norm_ returns the total norm measured before clipping.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                trainer.gradient_clip_norm
            )

            # Track gradient stats (largest |grad| of each parameter tensor)
            grads = [p.grad.abs().max().item() for p in model.parameters() if p.grad is not None]
            if grads:
                grad_stats['min'] = min(grad_stats['min'], min(grads))
                grad_stats['max'] = max(grad_stats['max'], max(grads))
                grad_stats['mean'] += np.mean(grads)
                grad_stats['count'] += 1

            # Step optimizer
            trainer.optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            if (batch_idx + 1) % 5 == 0 or batch_idx == 0:
                logger.info(
                    f"  Batch {batch_idx+1} | Loss: {loss_value:.6f} | Grad norm: {grad_norm:.4f}"
                )

        except Exception as e:
            # Best effort: one bad batch must not abort the epoch.
            logger.error(f"Error in batch {batch_idx}: {e}")
            continue

    if grad_stats['count'] > 0:
        grad_stats['mean'] /= grad_stats['count']

    if num_batches == 0:
        logger.warning("No training batch was processed successfully; reporting NaN loss")
        return float('nan'), grad_stats

    return total_loss / num_batches, grad_stats

# Cell-level confirmation that the definition above executed.
logger.info("‚úÖ Defined train_one_epoch function")

In [None]:
def validate(model, val_loader, trainer, device='cpu', use_synthetic=False):
    """
    Validate model on validation set.

    Args:
        model: PriorityNet model
        val_loader: Validation data loader (or None if synthetic)
        trainer: PriorityNetTrainer instance; only ``loss_fn`` is used
        device: 'cpu' or 'cuda'
        use_synthetic: Use three synthetic batches instead of ``val_loader``

    Returns:
        avg_loss: Mean loss over the batches evaluated successfully, or NaN
            when no batch was (so a fully failed pass cannot look like the
            best loss so far)
        metrics: Dict with 'mae', 'pred_min', 'pred_max', 'pred_range',
            'target_range', 'compression_ratio' and, when it can be computed,
            'uncertainty_corr'; empty if nothing was evaluated
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_predictions = []
    all_targets = []
    all_uncertainties = []
    all_errors = []

    # Use synthetic batches if no real data
    if use_synthetic:
        batches = [create_synthetic_batch(batch_size=8) for _ in range(3)]
    else:
        batches = val_loader

    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            try:
                # Batches are either dicts or sequences whose first five items
                # follow the order unpacked below.
                if isinstance(batch, dict):
                    strain_data = batch.get('strain_data', batch.get('strain'))
                    parameters = batch.get('parameters')
                    edge_ids = batch.get('edge_type_ids')
                    targets = batch.get('priorities', batch.get('targets'))
                    snr_values = batch.get('snr')
                else:
                    strain_data, parameters, edge_ids, targets, snr_values = batch[:5]

                if strain_data is None or parameters is None or targets is None:
                    continue

                # .to() returns the tensor unchanged when it already lives on
                # `device`, and also handles a CUDA tensor with a CPU target.
                strain_data = strain_data.to(device)
                parameters = parameters.to(device)
                targets = targets.to(device)
                if edge_ids is not None:
                    edge_ids = edge_ids.to(device)
                if snr_values is not None:
                    snr_values = snr_values.to(device)

                # Forward pass
                predictions, uncertainties = model(
                    strain_data=strain_data,
                    parameters=parameters,
                    edge_type_ids=edge_ids
                )

                # Compute loss
                loss = trainer.loss_fn(
                    predictions=predictions,
                    targets=targets,
                    uncertainties=uncertainties,
                    snr_values=snr_values
                )
                loss_value = loss.item()

                # Flatten per batch: batches may carry different numbers of
                # signals, so the 2-D tensors cannot be concatenated as-is.
                flat_preds = predictions.cpu().reshape(-1)
                flat_targets = targets.cpu().reshape(-1)
                flat_uncs = uncertainties.cpu().reshape(-1)
                flat_errors = torch.abs(predictions - targets).cpu().reshape(-1)

                # Record results only once everything for this batch succeeded,
                # so the four lists always stay aligned.
                total_loss += loss_value
                num_batches += 1
                all_predictions.append(flat_preds)
                all_targets.append(flat_targets)
                all_uncertainties.append(flat_uncs)
                all_errors.append(flat_errors)

            except Exception as e:
                # Best effort: one bad batch must not abort validation.
                logger.error(f"Error in val batch {batch_idx}: {e}")
                continue

    if num_batches == 0:
        logger.warning("No validation batch was processed successfully; reporting NaN loss")
        return float('nan'), {}

    avg_loss = total_loss / num_batches

    # Compute metrics
    metrics = {}
    preds = torch.cat(all_predictions).numpy()
    targets_np = torch.cat(all_targets).numpy()
    errors = torch.cat(all_errors).numpy()
    uncs = torch.cat(all_uncertainties).numpy()

    if preds.size and targets_np.size:
        metrics['mae'] = float(np.mean(errors))
        metrics['pred_min'] = float(preds.min())
        metrics['pred_max'] = float(preds.max())
        metrics['pred_range'] = float(preds.max() - preds.min())
        metrics['target_range'] = float(targets_np.max() - targets_np.min())
        # The epsilon guards against a zero target range.
        metrics['compression_ratio'] = metrics['pred_range'] / (metrics['target_range'] + 1e-8)

        # Uncertainty correlation: needs equally sized inputs and a
        # non-constant uncertainty; a NaN coefficient is reported as 0.0.
        if errors.size > 1 and errors.shape == uncs.shape and uncs.std() > 0:
            corr = np.corrcoef(errors, uncs)[0, 1]
            metrics['uncertainty_corr'] = float(corr) if not np.isnan(corr) else 0.0

    return avg_loss, metrics

# Cell-level confirmation that the definition above executed.
logger.info("‚úÖ Defined validate function")

## 6. Run Training Loop

In [None]:
# Run parameters: the epoch count is capped at 5 for the demo; patience and
# batch size come from the config, warmup length from the trainer.
_configured_epochs = priority_net_config.get('epochs', 50)
num_epochs = min(_configured_epochs, 5)
patience = priority_net_config.get('patience', 15)
warmup_epochs = trainer.warmup_epochs
batch_size = priority_net_config.get('batch_size', 12)

# Assemble the start-up banner first, then print it line by line.
_rule = "=" * 70
_banner_lines = [
    "\n" + _rule,
    f"üöÄ STARTING TRAINING",
    _rule,
    f"Epochs: {num_epochs}",
    f"Warmup epochs: {warmup_epochs}",
    f"Patience: {patience}",
    f"Batch size: {batch_size}",
    f"Device: {device}",
    f"Using synthetic data: {USE_SYNTHETIC_DATA}",
    _rule + "\n",
]
for _banner_line in _banner_lines:
    print(_banner_line)

# Checkpoints and plots are written below this directory; create it up front.
checkpoint_dir = project_root / "models/priority_net"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Checkpoint directory: {checkpoint_dir}")

In [None]:
# Main training loop: train, validate, schedule the learning rate, checkpoint
# on improvement and stop early once `patience` epochs pass without one.
history = defaultdict(list)
best_val_loss = float('inf')
patience_counter = 0
checkpoint_saved = False  # True once a checkpoint is written in this run
start_time = time.time()

for epoch in range(num_epochs):
    epoch_start = time.time()
    in_warmup = epoch < warmup_epochs

    # Warmup phase
    if in_warmup:
        trainer.warmup_scheduler.step()
        current_lr = trainer.optimizer.param_groups[0]['lr']
        epoch_title = f"Epoch {epoch+1}/{num_epochs} [WARMUP] - LR: {current_lr:.2e}"
    else:
        epoch_title = f"Epoch {epoch+1}/{num_epochs}"
    logger.info(f"\n{'='*70}")
    logger.info(epoch_title)
    logger.info(f"{'='*70}")

    # Training (falls back to synthetic batches when there is no loader)
    print(f"\nüìà Training...")
    train_loss, grad_stats = train_one_epoch(
        model, train_loader, trainer, device,
        use_synthetic=USE_SYNTHETIC_DATA or train_loader is None
    )

    # Validation (same fallback rule)
    print(f"‚úÖ Validating...")
    val_loss, metrics = validate(
        model, val_loader, trainer, device,
        use_synthetic=USE_SYNTHETIC_DATA or val_loader is None
    )

    # Learning rate scheduling (after warmup); this scheduler is stepped with
    # the validation loss.
    if not in_warmup:
        trainer.scheduler.step(val_loss)

    # Track history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for key, val in metrics.items():
        history[f'val_{key}'].append(val)

    # Log epoch
    epoch_time = time.time() - epoch_start
    logger.info(f"\nResults:")
    logger.info(f"  Train loss: {train_loss:.6f}")
    logger.info(f"  Val loss: {val_loss:.6f}")
    logger.info(f"  Time: {epoch_time:.1f}s")

    # Log gradient stats
    if grad_stats['count'] > 0:
        logger.info(
            f"  Grad stats: min={grad_stats['min']:.2e}, "
            f"max={grad_stats['max']:.2e}, mean={grad_stats['mean']:.2e}"
        )

    # Log metrics if available
    if metrics:
        logger.info(
            f"  MAE: {metrics.get('mae', 0):.4f} | "
            f"Range: [{metrics.get('pred_min', 0):.3f}, {metrics.get('pred_max', 0):.3f}] | "
            f"Compression: {metrics.get('compression_ratio', 0):.1%} | "
            f"Unc Corr: {metrics.get('uncertainty_corr', 0):.3f}"
        )

    # Save checkpoint if validation improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0

        checkpoint = {
            'epoch': epoch,
            'loss': val_loss,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'scheduler_state_dict': trainer.scheduler.state_dict(),
            'model_config': vars(cfg),
            # Plain dict copy so the checkpoint does not hold the defaultdict.
            'history': dict(history)
        }

        checkpoint_path = checkpoint_dir / "priority_net_best.pth"
        torch.save(checkpoint, checkpoint_path)
        checkpoint_saved = True
        logger.info(f"\n‚úÖ CHECKPOINT SAVED: {checkpoint_path}")

    else:
        patience_counter += 1
        logger.info(f"\nNo improvement ({patience_counter}/{patience})")
        if patience_counter >= patience:
            logger.info(f"\nüõë Early stopping: no improvement for {patience} epochs")
            break

# Summary
total_time = time.time() - start_time
print(f"\n" + "="*70)
print(f"‚úÖ TRAINING COMPLETE")
print(f"="*70)
print(f"Total time: {total_time/60:.1f} minutes")
print(f"Best validation loss: {best_val_loss:.6f}")
# Only claim a checkpoint when this run wrote one; a file already at that path
# may be left over from an earlier run.
if checkpoint_saved:
    print(f"Checkpoint saved to: {checkpoint_dir / 'priority_net_best.pth'}")
else:
    print("No checkpoint was saved during this run")
print(f"="*70)

## 7. Plot Training Results

In [None]:
# Draw a 2x2 grid: losses, validation MAE, prediction range and
# uncertainty/error correlation. A panel stays empty when its series is absent.
if not history['train_loss']:
    print("No training history to plot")
else:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    loss_ax = axes[0, 0]
    mae_ax = axes[0, 1]
    range_ax = axes[1, 0]
    corr_ax = axes[1, 1]

    # Top-left: train and validation loss per epoch
    loss_ax.plot(history['train_loss'], label='Train', marker='o')
    loss_ax.plot(history['val_loss'], label='Val', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Top-right: validation MAE
    if 'val_mae' in history:
        mae_ax.plot(history['val_mae'], label='Val MAE', marker='o')
        mae_ax.set_xlabel('Epoch')
        mae_ax.set_ylabel('MAE')
        mae_ax.set_title('Mean Absolute Error')
        mae_ax.grid(True, alpha=0.3)

    # Bottom-left: smallest and largest prediction, with the band shaded
    if 'val_pred_min' in history and 'val_pred_max' in history:
        lowest = history['val_pred_min']
        highest = history['val_pred_max']
        range_ax.plot(lowest, label='Min', marker='o')
        range_ax.plot(highest, label='Max', marker='s')
        range_ax.fill_between(range(len(lowest)), lowest, highest, alpha=0.2)
        range_ax.set_xlabel('Epoch')
        range_ax.set_ylabel('Value')
        range_ax.set_title('Output Range Expansion')
        range_ax.legend()
        range_ax.grid(True, alpha=0.3)

    # Bottom-right: correlation between error and uncertainty, with a
    # dashed reference line at 0.15
    if 'val_uncertainty_corr' in history:
        corr_ax.plot(history['val_uncertainty_corr'], label='Unc Corr', marker='o')
        corr_ax.axhline(y=0.15, color='r', linestyle='--', label='Target (0.15)')
        corr_ax.set_xlabel('Epoch')
        corr_ax.set_ylabel('Correlation')
        corr_ax.set_title('Uncertainty-Error Correlation')
        corr_ax.legend()
        corr_ax.grid(True, alpha=0.3)

    # Save the figure next to the checkpoints, then display it inline.
    plot_path = checkpoint_dir / 'training_history.png'
    plt.tight_layout()
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"\nüìä Training plots saved to {plot_path}")
    plt.show()

## 8. Load and Test Checkpoint

In [None]:
# Load the best checkpoint (if one exists), report what it contains and
# restore its weights into the in-memory model.
checkpoint_path = project_root / "models/priority_net/priority_net_best.pth"

if checkpoint_path.exists():
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # A missing or non-numeric loss must not break the report: applying a
    # float format spec to the 'unknown' placeholder would raise ValueError.
    try:
        loss_text = f"{float(checkpoint.get('loss')):.6f}"
    except (TypeError, ValueError):
        loss_text = 'unknown'

    print("üîß Checkpoint Details")
    print("="*70)
    print(f"Path: {checkpoint_path}")
    print(f"Epoch: {checkpoint.get('epoch', 'unknown')}")
    print(f"Loss: {loss_text}")

    # Count state dict keys
    state_dict = checkpoint.get('model_state_dict', {})
    print(f"Model parameters: {len(state_dict)} keys")
    print(f"Total model size: {sum(v.numel() for v in state_dict.values()):,} params")

    # strict=True: any missing or unexpected key raises instead of being
    # skipped silently.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    print(f"\n‚úÖ Model loaded successfully")

    # Show the training history stored with the checkpoint
    hist = checkpoint.get('history', {})
    if hist:
        print(f"\nTraining history:")
        print(f"  Epochs trained: {len(hist.get('train_loss', []))}")
        if hist.get('train_loss'):
            print(f"  Final train loss: {hist['train_loss'][-1]:.6f}")
        if hist.get('val_loss'):
            print(f"  Final val loss: {hist['val_loss'][-1]:.6f}")
else:
    print(f"‚ö†Ô∏è  Checkpoint not found at {checkpoint_path}")

## 9. Inference Test

In [None]:
# Smoke-test inference on one small synthetic batch.
print("\nüîÆ Testing Inference")
print("="*70)

# The checkpoint cell only switches to eval mode when a checkpoint file exists,
# so set it here explicitly before running the forward pass.
model.eval()

with torch.no_grad():
    # Create test batch; (2, 3) has an exclusive upper bound, so every sample
    # carries exactly 2 signals.
    test_batch = create_synthetic_batch(batch_size=2, n_signals_range=(2, 3))

    # Forward pass
    predictions, uncertainties = model(
        strain_data=test_batch['strain_data'],
        parameters=test_batch['parameters'],
        edge_type_ids=test_batch['edge_type_ids']
    )

    # Print every sample of the batch rather than hard-coding two indices.
    for sample_idx in range(test_batch['priorities'].shape[0]):
        print(f"\nSample {sample_idx + 1}:")
        print(f"  Predictions: {predictions[sample_idx].cpu().numpy()}")
        print(f"  Uncertainties: {uncertainties[sample_idx].cpu().numpy()}")
        print(f"  Target priorities: {test_batch['priorities'][sample_idx].cpu().numpy()}")

    # Summary stats over all predictions in the batch
    all_preds = predictions.cpu().numpy().flatten()
    print(f"\nPrediction statistics:")
    print(f"  Min: {all_preds.min():.4f}")
    print(f"  Max: {all_preds.max():.4f}")
    print(f"  Mean: {all_preds.mean():.4f}")
    print(f"  Std: {all_preds.std():.4f}")
    print(f"  Range: {all_preds.max() - all_preds.min():.4f}")

print("="*70 + "\n‚úÖ Inference test complete")