In [None]:
# ================================================================
# Package import
# ================================================================
import os
import random
import re

import torch
from torch.utils.data import DataLoader

from src.dataset import MultiSinusoidDataset, custom_collate
from src.model import SlotFlow
from src.utils import (
    plot_dataset_consistency,
    plot_latent_space,
    plot_slot_posteriors,
    plot_confusion_matrix,
)

In [None]:
# ------------------------------------------------------------
# Run selection: the config file and the checkpoint folder are
# both looked up under pretrained_model/<run_id>/
# ------------------------------------------------------------
run_id = "test_clariden"  # e.g. depth_8 / clariden run
result_path = os.path.join("pretrained_model", run_id)

# Derive both artefact locations from the run directory in one go.
config_path, ckpt_dir = (
    os.path.join(result_path, leaf) for leaf in ("model_config.pt", "checkpoints")
)

# Fail early with a pointed message rather than a loader traceback later on.
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config not found: {config_path}")
if not os.path.exists(ckpt_dir):
    raise FileNotFoundError(f"Checkpoint directory missing: {ckpt_dir}")

# ------------------------------------------------------------
# Read back the saved training configuration. The values pulled
# out below feed the model constructor and the evaluation dataset.
#
# NOTE(review): torch.load deserialises via pickle; whether arbitrary
# objects can be unpickled depends on the installed PyTorch's
# `weights_only` default -- only load run directories you trust.
# ------------------------------------------------------------
config = torch.load(config_path, map_location="cpu")

# Model constructor arguments.
hidden_dim, max_slots = config["hidden_dim"], config["max_slots"]
# Optional key: a config without it falls back to False.
use_noise_encoder = config.get("use_noise_encoder", False)

# Sampling settings forwarded to the dataset (long and short views).
num_samples_long, num_samples_short, tEnd_long, tEnd_short = (
    config[key]
    for key in ("num_samples_long", "num_samples_short", "tEnd_long", "tEnd_short")
)

# Signal-generation settings forwarded to the dataset.
freq_range, amp_range, noise_std, max_components = (
    config[key] for key in ("freq_range", "amp_range", "noise_std", "max_components")
)

print("Loaded model configuration.")

# ------------------------------------------------------------
# Locate SlotFlow best checkpoint
# ------------------------------------------------------------
def _natural_key(name):
    """Return a sort key that orders embedded integers numerically.

    Plain string sorting ranks "epoch=9" above "epoch=10"; splitting the
    name into text / number chunks fixes that.
    """
    # re.split with a capturing group alternates text (even index) and
    # digit runs (odd index), so str and int chunks never get compared
    # with each other when two keys are ordered.
    chunks = re.split(r"([0-9]+)", name)
    return [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]


def _strip_wrapper_prefixes(key):
    """Map a checkpoint parameter name onto the bare model parameter name.

    "_orig_mod." is removed wherever it occurs, but "model." is removed
    only as a leading prefix so that parameter names which merely contain
    "model." (e.g. "flow_model.weight") are left intact.
    """
    key = key.replace("_orig_mod.", "")
    if key.startswith("model."):
        key = key[len("model."):]
    return key


ckpt_files = [
    f
    for f in os.listdir(ckpt_dir)
    if f.endswith(".ckpt") and ("best" in f or "hpc_ckpt" in f)
]

if not ckpt_files:
    raise FileNotFoundError(f"No checkpoint found in {ckpt_dir}")

# Prefer an explicit "best" checkpoint; fall back to the remaining candidates
# ("hpc_ckpt" files) only when no "best" file exists. Within the chosen group
# take the highest name in natural (number-aware) order.
best_ckpt_files = [f for f in ckpt_files if "best" in f]
ckpt_name = max(best_ckpt_files or ckpt_files, key=_natural_key)
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
print(f"📂 Using checkpoint: {ckpt_path}")

# NOTE(review): torch.load deserialises via pickle; only load checkpoints
# from trusted runs. Depending on the installed PyTorch version, the
# `weights_only` default may also need to be considered for full checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]

# Remove Lightning / compile wrapper prefixes if present
clean_state = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}

# ------------------------------------------------------------
# Build the network from the saved hyper-parameters, then
# restore its weights from the cleaned state dict
# ------------------------------------------------------------
model = SlotFlow(
    hidden_dim=hidden_dim, max_slots=max_slots, use_noise_encoder=use_noise_encoder
)

# strict=True is requested explicitly so a key mismatch is not silently
# ignored (assuming SlotFlow follows torch.nn.Module semantics -- confirm).
model.load_state_dict(clean_state, strict=True)

# Run on the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device).eval()

print(f"SlotFlow model loaded (epoch={ckpt.get('epoch')})")

# ============================================================
# === Create evaluation dataset
# ============================================================
# Signal parameters come from the training config; the set size, minimum
# frequency separation, seed and mode are fixed here for evaluation.
dataset_cfg = {
    "set_size": 10_000,
    "num_samples_long": num_samples_long,
    "tEnd_long": tEnd_long,
    "num_samples_short": num_samples_short,
    "tEnd_short": tEnd_short,
    "max_components": max_components,
    "amp_range": amp_range,
    "freq_range": freq_range,
    "noise_std": noise_std,
    "min_freq_sep": 0.01,
    "seed": 42,
    "num_comp": None,
    "mode": "inference",
    # Every component count from 1 up to max_components (inclusive).
    "allowed_K_values": [k for k in range(1, max_components + 1)],
}

test_set = MultiSinusoidDataset(**dataset_cfg)
# One signal per batch, assembled by the project's collate function.
test_loader = DataLoader(dataset=test_set, batch_size=1, collate_fn=custom_collate)

print(f"Dataset ready: {len(test_set)} signals")

In [None]:
# ---------------------------------------------------------
# Dataset Consistency Visualization
# ---------------------------------------------------------
# Sanity-check that the downsampled/segmented views (x_long, x_short)
# are consistent with the reference/master signal for a single sample.
#
# What this shows:
#   - TIME DOMAIN: overlays to confirm alignment (no unintended shifts,
#     padding artifacts, or windowing glitches).
#   - FREQUENCY DOMAIN (only when show_fft=True): spectra match up to
#     expected resolution differences from sampling/segment length.
# ---------------------------------------------------------
# Draw the index from the actual dataset length rather than a hard-coded
# bound, so every sample is reachable and a smaller set_size cannot
# produce an out-of-range index.
idx = random.randrange(len(test_set))
plot_dataset_consistency(
    dataset=test_set,
    idx=idx,
    show_fft=False,  # set to True to also plot the spectra
)

In [None]:
# ---------------------------------------------------------
# Posterior Visualization via Corner Plot
# ---------------------------------------------------------
# For a randomly chosen test signal, draw posterior samples
# over the sinusoid parameters and visualize them with a
# corner plot (marginal + pairwise joint distributions).
#
# This provides a qualitative check of:
#   - Uncertainty representation
#   - Presence of multimodality
#   - Calibration of the learned posterior
# ---------------------------------------------------------
# Draw the index from the actual dataset length rather than a hard-coded
# bound, so every sample is reachable and a smaller set_size cannot
# produce an out-of-range index.
idx = random.randrange(len(test_set))
plot_slot_posteriors(
    model=model,
    dataset=test_set,
    idx=idx,
    n_samples=5000,
    device=device,
    use_gt_k=False,
    use_noise_encoder=use_noise_encoder,
    freq_range=freq_range,
)

In [None]:
# ---------------------------------------------------------
# Confusion Matrix Evaluation
# ---------------------------------------------------------
# Compares the model's predicted number of components (K)
# with the true K of the test data.
#
# NOTE(review): unlike the other plotting calls in this notebook,
# no `device` is passed here -- confirm the helper handles device
# placement on its own.
# ---------------------------------------------------------
plot_confusion_matrix(
    model,
    test_loader,
    num_samples=1000,
)

In [None]:
# ---------------------------------------------------------
# Latent-space view of INFERENCE-MODE data
# - Per the dataset's inference mode, items are 6-tuples
#   (no training metadata)
# - Handy for checking how the model embeds "realistic"
#   inference-style samples
# ---------------------------------------------------------
plot_latent_space(
    model=model,
    device=device,
    data_loader=test_loader,
    num_samples=1000,  # randomly subsample this many points
    method="tsne",  # dimensionality reduction method
    plot_3d=False,
)