In [None]:
import sys
import yaml
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_absolute_error, r2_score

In [None]:
# Directory containing the project's `src` package. On sys.path an empty
# string means "the current working directory", so the default works when the
# notebook is started from the project root.
project_root = ""

# Guard the append so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
# Project-local modules. These imports must run after the sys.path cell
# above, because the `src` package is found through the `project_root` entry.
from src.utils import create_model
from src.data_pipeline import HDF5Dataset

# Apply seaborn's white-grid theme to every figure drawn later in the notebook.
sns.set_theme(style="whitegrid")

In [None]:
# Paths to fill in before running; all three are required by later cells.
CONFIG_PATH = ""
MODEL_PATH = ""
HDF5_PATH = ""

# Fail fast with an actionable message instead of a bare FileNotFoundError
# raised deep inside a later cell.
_unset_paths = [
    name
    for name, value in (
        ("CONFIG_PATH", CONFIG_PATH),
        ("MODEL_PATH", MODEL_PATH),
        ("HDF5_PATH", HDF5_PATH),
    )
    if not value
]
if _unset_paths:
    raise ValueError(f"Set {', '.join(_unset_paths)} in this cell before running the notebook.")

# YAML run configuration; later cells read config['model'] and config['training'].
# An explicit encoding keeps the read independent of the platform default.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Parameter names/units, in the column order later cells assume for both the
# model output and the labels (TODO confirm against HDF5Dataset).
PARAM_NAMES = ["dist", "poni1", "poni2", "rot1", "rot2", "rot3"]
PARAM_UNITS = ["m", "m", "m", "rad", "rad", "rad"]
assert len(PARAM_NAMES) == len(PARAM_UNITS), "PARAM_NAMES and PARAM_UNITS must align"

In [None]:
# Load the trained model weights and switch the model to inference mode.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = create_model(config)

print(f"Loading model weights from {MODEL_PATH}")
# The loaded object is passed straight to load_state_dict below, so only a
# plain state dict of tensors is expected. weights_only=True restricts
# unpickling to such data, so a tampered checkpoint cannot run arbitrary code.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)

# A model wrapped by torch.compile() keeps the original under `_orig_mod`, so
# a checkpoint saved from the wrapper has keys prefixed "_orig_mod.". Strip the
# prefix so the keys match the plain model built by create_model.
COMPILED_PREFIX = "_orig_mod."
if any(key.startswith(COMPILED_PREFIX) for key in state_dict):
    state_dict = {
        (key[len(COMPILED_PREFIX):] if key.startswith(COMPILED_PREFIX) else key): tensor
        for key, tensor in state_dict.items()
    }

model.load_state_dict(state_dict)
model.to(device)
model.eval()

print("Model loaded for inference.")

In [None]:
# Run the model over the whole test split and collect predictions and labels.

image_size = config['model'].get('image_size', 224)
test_dataset = HDF5Dataset(HDF5_PATH, 'test', image_size=image_size)
# shuffle=False keeps the collected arrays in the dataset's sample order.
test_loader = DataLoader(test_dataset, batch_size=config['training']['batch_size'], shuffle=False)

all_preds = []
all_labels = []

print("Running inference on the test set...")
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="test inference"):
        images = images.to(device)
        predictions = model(images)

        all_preds.append(predictions.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

# np.concatenate on an empty list fails with an opaque
# "need at least one array to concatenate"; report the real cause instead.
if not all_preds:
    raise RuntimeError(f"The 'test' split in {HDF5_PATH!r} yielded no batches; nothing to evaluate.")

preds_np = np.concatenate(all_preds, axis=0)
labels_np = np.concatenate(all_labels, axis=0)

# Later cells index columns of both arrays in parallel and compute
# preds_np - labels_np, which silently broadcasts if the shapes disagree.
assert preds_np.shape == labels_np.shape, (
    f"prediction shape {preds_np.shape} != label shape {labels_np.shape}"
)

print("Inference complete.")

In [None]:
# Per-parameter mean absolute error (in the parameter's own unit) and
# coefficient of determination R^2 between true and predicted values.
# R^2 is not a correlation coefficient: 1.0 is perfect, and it can go negative
# when predictions are worse than always predicting the mean.

print("Performance Metrics")
for i, (name, unit) in enumerate(zip(PARAM_NAMES, PARAM_UNITS)):
    y_true = labels_np[:, i]
    y_pred = preds_np[:, i]
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    print(f"{name:<6} (unit: {unit:<4}) | MAE: {mae:.6f} | R2: {r2:.4f}")

In [None]:
# Predicted-vs-true scatter per parameter. Points on the dashed y = x line are
# perfect predictions.

n_params = len(PARAM_NAMES)
n_cols = 3
n_rows = (n_params + n_cols - 1) // n_cols  # ceil division: 6 parameters -> 2x3 grid
# squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
fig.suptitle('Predicted vs. True Poni Parameters', fontsize=16)
axes = axes.flatten()

for i, (ax, name, unit) in enumerate(zip(axes, PARAM_NAMES, PARAM_UNITS)):
    ax.scatter(labels_np[:, i], preds_np[:, i], alpha=0.5, s=10)

    # One shared range for both axes, so the identity line is the true diagonal
    # and the equal-aspect panel is square.
    lims = [
        np.min([ax.get_xlim(), ax.get_ylim()]),
        np.max([ax.get_xlim(), ax.get_ylim()]),
    ]
    ax.plot(lims, lims, 'r--', alpha=0.75, zorder=0)
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect('equal', adjustable='box')

    ax.set_title(name)
    ax.set_xlabel(f"True Value ({unit})")
    ax.set_ylabel(f"Predicted Value ({unit})")

# Hide grid cells left over when the parameter count is not a multiple of n_cols.
for ax in axes[n_params:]:
    ax.set_visible(False)

# rect reserves room at the top for the suptitle; a tuple matches matplotlib's type hints.
fig.tight_layout(rect=(0, 0.03, 1, 0.95))
plt.show()

In [None]:
# Distribution of signed prediction errors per parameter.
# Sign convention: error = predicted - true, so positive values are overestimates.

errors_np = preds_np - labels_np

n_params = len(PARAM_NAMES)
n_cols = 3
n_rows = (n_params + n_cols - 1) // n_cols  # ceil division: 6 parameters -> 2x3 grid
# squeeze=False keeps `axes` 2-D even for a single row, so flatten() always works.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
fig.suptitle('Error Distribution for Poni Parameters', fontsize=16)
axes = axes.flatten()

for i, (ax, name, unit) in enumerate(zip(axes, PARAM_NAMES, PARAM_UNITS)):
    # Bars are raw counts per bin (seaborn's "frequency" stat would divide by
    # bin width), so the y-axis is labelled "Count".
    sns.histplot(errors_np[:, i], kde=True, ax=ax, bins=30, stat="count")
    ax.axvline(0, color='r', linestyle='--')  # zero-error reference
    ax.set_title(name)
    ax.set_xlabel(f"Prediction Error ({unit})")
    ax.set_ylabel("Count")

# Hide grid cells left over when the parameter count is not a multiple of n_cols.
for ax in axes[n_params:]:
    ax.set_visible(False)

# rect reserves room at the top for the suptitle; a tuple matches matplotlib's type hints.
fig.tight_layout(rect=(0, 0.03, 1, 0.95))
plt.show()