In [8]:
# 1. Imports
import torch
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import sys

# Make the project-local `models` package importable by putting the project
# root on sys.path.
# When run as a script, `__file__` exists and the root is taken to be two
# directory levels above this file. In a notebook/REPL `__file__` is undefined,
# so fall back to the parent of the current working directory.
try:
    project_root = Path(__file__).resolve().parent.parent
except NameError:
    project_root = Path().resolve().parent

# Guarded append: re-running this cell must not pile up duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from models.autoencoder.dataset_autoencoder import DatasetAutoencoder
from models.autoencoder.architectures.conv_autoencoder16 import ConvAutoencoder16

# 2. Load the same dataset you trained on
# NOTE(review): this path is relative to the current working directory, while
# the checkpoint below is anchored at project_root — confirm this is intended.
dataset = DatasetAutoencoder(
    path=Path("data/waves"),    # or wherever your raw data lives
    reduction="",               # match what you used in training
    n=0,
    save=False,                 # no need to re-cache here
    force_reload=False,
)

# 3. Load your checkpoint
ckpt = project_root / "artifacts/autoencoder/checkpoints" / "ConvAE16_2025-06-16_12-02-51.pt"
# NOTE(review): this call has been observed to fail with
# "TypeError: ConvAutoencoder16.__init__() got an unexpected keyword argument
# 'model_name'". Presumably `load` forwards saved constructor kwargs that
# `__init__` does not accept — that has to be fixed in the model class, not here.
model = ConvAutoencoder16.load(ckpt)
model.eval()  # set to inference mode

# 4. Pick a small batch from validation (or full dataset) for testing
# The batch size is derived from the plot grid so the two cannot drift apart.
N_ROWS, N_COLS = 2, 3
loader = DataLoader(dataset, batch_size=N_ROWS * N_COLS, shuffle=True)
batch = next(iter(loader))            # expected shape [B, length], B <= N_ROWS * N_COLS
batch_in = batch.unsqueeze(1).to(model.device)   # → [B, 1, length]

# 5. Run the autoencoder
with torch.no_grad():                 # inference only; no autograd bookkeeping
    recon = model(batch_in)           # expected → [B, length] — TODO confirm against the model
recon = recon.cpu()                   # bring back to CPU for plotting

# 6. Plot originals vs reconstructions
fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(12, 6))
n_shown = len(batch)                  # may be smaller than the grid for a tiny dataset
for i, ax in enumerate(axes.flatten()):
    if i >= n_shown:
        # No sample for this panel: hide it instead of indexing past the batch.
        ax.set_visible(False)
        continue
    ax.plot(batch[i].numpy(), label="Original")
    ax.plot(recon[i].numpy(), label="Reconstruction")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(fontsize="small")
fig.suptitle(f"ConvAE16 Reconstructions ({n_shown} random waves)")
fig.tight_layout()


TypeError: ConvAutoencoder16.__init__() got an unexpected keyword argument 'model_name'