In [7]:
import json
import random
import sys
import time
from pathlib import Path
import torch
import numpy as np
import pandas as pd

# --- Setup: Add project source to the Python path ---
# Ensure this path is correct for your project structure
try:
    project_root = Path(__file__).resolve().parent.parent
except NameError:
    project_root = Path.cwd().parent
    
src_path = project_root / 'src'
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))
    
# It's good practice to handle potential import errors
try:
    from utils import load_config
    from normalizer import DataNormalizer
except ImportError as e:
    print(f"Error importing project modules from {src_path}.")
    print("Please ensure the script is run from the 'scripts' directory or adjust the project_root path.")
    sys.exit(1)

def print_header(title: str):
    """Prints a formatted, centered header to the console."""
    width = 80
    padding = (width - len(title) - 2) // 2
    print("\n" + "=" * width)
    print(" " * padding, title, " " * padding)
    print("=" * width)

def prepare_batch_from_memory(batch_size: int, all_test_data: list, species_vars: list, global_vars: list, device: torch.device):
    """
    Creates a batch of input tensors by sampling from pre-loaded in-memory data.
    """
    batch_inputs_list = []
    selected_profiles = random.choices(all_test_data, k=batch_size)
    
    for norm_profile in selected_profiles:
        initial_species = [norm_profile[key][0] for key in species_vars]
        global_conds = [norm_profile[key] for key in global_vars]
        final_norm_time = norm_profile["t_time"][-1]
        input_list = initial_species + global_conds + [final_norm_time]
        batch_inputs_list.append(input_list)
        
    return torch.tensor(batch_inputs_list, dtype=torch.float32, device=device)

def run_benchmark_and_validate():
    """
    Main function to load artifacts, run a performance benchmark on Apple Silicon (MPS)
    if available, and then validate a single prediction.
    """
    # 1. LOAD ARTIFACTS
    # --------------------------------------------------------------------------
    print_header("Initialization")
    print("Loading model and configuration artifacts...")
    CONFIG_FILE = project_root / "inputs/model_input_params.jsonc"
    DATA_ROOT = project_root / "data"
    
    config = load_config(CONFIG_FILE)
    if not config:
        print("Error: Could not load configuration. Exiting.")
        return
        
    model_folder = DATA_ROOT / config["output_paths_config"]["fixed_model_foldername"]
    normalized_data_folder = DATA_ROOT / config["data_paths_config"]["normalized_profiles_foldername"]
    
    # ============================ MPS CHANGE 1: START ============================
    # Updated device selection logic for Apple Silicon
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # ============================ MPS CHANGE 1: END ==============================
    
    model_path = model_folder / "best_model_jit.pt"
    if not model_path.exists():
        print(f"Error: Model file not found at {model_path}. Exiting.")
        return

    # Load the JIT model and move it to the selected device (CPU, CUDA, or MPS)
    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    
    norm_meta_path = normalized_data_folder / "normalization_metadata.json"
    if not norm_meta_path.exists():
        print(f"Error: Normalization metadata not found at {norm_meta_path}. Exiting.")
        return
    with norm_meta_path.open("r") as f:
        norm_metadata = json.load(f)

    test_info_path = model_folder / "test_set_info.json"
    if not test_info_path.exists():
        print(f"Error: Test set info file not found at {test_info_path}. Exiting.")
        return
    with test_info_path.open("r") as f:
        test_filenames = json.load(f)["test_filenames"]
        
    print(f"Pre-loading {len(test_filenames)} test profiles into memory...")
    all_test_data_in_memory = []
    for filename in test_filenames:
        with (normalized_data_folder / filename).open("r") as f:
            all_test_data_in_memory.append(json.load(f))
    print("Pre-loading complete.")

    species_vars = sorted(config["species_variables"])
    global_vars = sorted(config["global_variables"])
    print(f"Setup complete. Model loaded on device: {device.type.upper()}")
    
    # 2. RUN PERFORMANCE BENCHMARK
    # --------------------------------------------------------------------------
    print_header("Performance Benchmark")
    BATCH_SIZE = 4096
    NUM_WARMUP_RUNS = 50
    NUM_TIMING_RUNS = 100

    batch_tensor = prepare_batch_from_memory(BATCH_SIZE, all_test_data_in_memory, species_vars, global_vars, device)
    
    with torch.no_grad():
        for _ in range(NUM_WARMUP_RUNS):
            _ = model(batch_tensor)
    
    timings = []
    with torch.no_grad():
        for _ in range(NUM_TIMING_RUNS):
            # For MPS, perf_counter is sufficient as PyTorch handles synchronization
            start_time = time.perf_counter()
            _ = model(batch_tensor)
            end_time = time.perf_counter()
            timings.append(end_time - start_time)

    # 3. REPORT BENCHMARK RESULTS
    # --------------------------------------------------------------------------
    total_time = sum(timings)
    avg_batch_time_ms = (total_time / NUM_TIMING_RUNS) * 1000
    avg_prediction_time_us = (avg_batch_time_ms / BATCH_SIZE) * 1000

    print(f"{'Batch Size:':<28} {BATCH_SIZE}")
    print(f"{'Device:':<28} {device.type.upper()}")
    print(f"{'Average time per batch:':<28} {avg_batch_time_ms:.4f} ms")
    print(f"{'Average time per single pred:':<28} {avg_prediction_time_us:.4f} µs (microseconds)")

    # 4. RUN SINGLE PREDICTION VALIDATION
    # --------------------------------------------------------------------------
    print_header("Single Prediction Validation")
    norm_profile = random.choice(all_test_data_in_memory)
    
    query_time_idx = random.randint(1, len(norm_profile["t_time"]) - 1)
    query_time_idx = len(norm_profile["t_time"]) - 1


    norm_time_to_predict = norm_profile["t_time"][query_time_idx]
    
    initial_species = [norm_profile[key][0] for key in species_vars]
    global_conds = [norm_profile[key] for key in global_vars]
    
    input_vector = torch.tensor(
        initial_species + global_conds + [norm_time_to_predict], dtype=torch.float32, device=device
    ).unsqueeze(0)
    
    with torch.no_grad():
        norm_prediction_tensor = model(input_vector).squeeze(0)

    # Move the prediction back to the CPU for numpy/pandas operations
    norm_prediction_tensor_cpu = norm_prediction_tensor.cpu()
    
    norm_true_values_tensor = torch.tensor([norm_profile[key][query_time_idx] for key in species_vars])
    
    # 5. DENORMALIZE AND REPORT VALIDATION RESULTS
    # --------------------------------------------------------------------------
    results = []
    real_time_to_predict = DataNormalizer.denormalize(norm_time_to_predict, norm_metadata, "t_time")
    print(f"  - Predicting at time index {query_time_idx} (t ≈ {real_time_to_predict:.4e} s)")

    for i, key in enumerate(species_vars):
        predicted_val = DataNormalizer.denormalize(norm_prediction_tensor_cpu[i], norm_metadata, key).item()
        true_val = DataNormalizer.denormalize(norm_true_values_tensor[i], norm_metadata, key).item()
        
        results.append({
            "Species": key.replace('_evolution', ''),
            "Predicted Value": predicted_val,
            "True Value": true_val,
            "Abs. Error": abs(predicted_val - true_val),
            "Rel. Error (%)": abs(predicted_val - true_val) / (true_val + 1e-20) * 100 if true_val != 0 else float('inf')
        })

    df = pd.DataFrame(results)
    pd.options.display.float_format = '{:,.4e}'.format
    print("\n" + df.to_string(index=False))
    print("-" * 80 + "\n")


if __name__ == "__main__":
    run_benchmark_and_validate()


                                 Initialization                                 
Loading model and configuration artifacts...
Pre-loading 14983 test profiles into memory...
Pre-loading complete.
Setup complete. Model loaded on device: MPS

                             Performance Benchmark                             
Batch Size:                  4096
Device:                      MPS
Average time per batch:      0.7808 ms
Average time per single pred: 0.1906 µs (microseconds)

                          Single Prediction Validation                          
  - Predicting at time index 39 (t ≈ 1.0000e+04 s)

Species  Predicted Value  True Value  Abs. Error  Rel. Error (%)
   C2H2       1.1393e-01  1.5416e-01  4.0223e-02      2.6092e+01
    CH4       7.2136e-02  8.1398e-02  9.2614e-03      1.1378e+01
    CO2       1.3531e-12  1.9115e-12  5.5839e-13      2.9213e+01
     CO       5.1327e-04  5.0093e-04  1.2335e-05      2.4624e+00
    H2O       8.1341e-09  9.4394e-09  1.3053e-09      1.382