Imports

import torch
from src.config import SMDConfig, ModelConfig
from src.data.smd_dataset import SMDSequenceDataset
from src.data.preprocessing import compute_smd_normalization_stats
from src.models.hybrid_model import HybridAnomalyModel


Load a small batch for debugging

In [None]:
# Debug data config: window_size=50 is deliberately small so this pass is fast.
smd_cfg = SMDConfig(window_size=50)

# Mean/std statistics from the project helper, computed from smd_cfg.root_dir,
# are passed to the training-split dataset.
mean, std = compute_smd_normalization_stats(smd_cfg.root_dir)
dataset = SMDSequenceDataset(
    smd_cfg,
    split="train",
    mean=mean,
    std=std,
)

# Draw a single shuffled batch of two samples; only its "input" entry is used.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
)
batch = next(iter(loader))
x = batch["input"]

print("Batch shape:", x.shape)


Initialize the hybrid model

In [None]:
# Size the model from the debug batch: the last axis of x becomes input_dim
# (presumably the per-timestep feature count — confirm against the dataset).
model_cfg = ModelConfig(input_dim=x.shape[-1])

# Architecture hyper-parameters come from ModelConfig; the window length is
# taken from the data config so it matches the windows the dataset produced.
_model_kwargs = dict(
    input_dim=model_cfg.input_dim,
    window_size=smd_cfg.window_size,
    d_model=model_cfg.d_model,
    n_heads=model_cfg.n_heads,
    num_layers=model_cfg.num_layers,
    dim_feedforward=model_cfg.dim_feedforward,
    dropout=model_cfg.dropout,
    latent_dim=model_cfg.latent_dim,
)
model = HybridAnomalyModel(**_model_kwargs)

print(model)


Test the forward pass

In [None]:
# Run one forward pass on the debug batch.
# - model.eval() turns off train-only behaviour such as the dropout configured
#   via model_cfg.dropout, so the outputs (including the reconstruction
#   inspected later) depend only on x and z_noise. The model stays in eval
#   mode afterwards; call model.train() before any training.
# - torch.no_grad() skips building an autograd graph that is never used here.
model.eval()
with torch.no_grad():
    # One noise vector of width latent_dim per sample in the batch
    # (presumably the generator/discriminator input — confirm in the model).
    z_noise = torch.randn(x.size(0), model_cfg.latent_dim)

    recon, z, d_real, d_fake = model(x, z_noise)

print("Reconstruction:", recon.shape)
print("Latent:", z.shape)
print("Disc real:", d_real.shape)
print("Disc fake:", d_fake.shape)


Visually check the reconstruction

In [None]:
import matplotlib.pyplot as plt

# Which batch sample and which column (last-axis index) to compare.
sample_idx = 0
feature_idx = 0

# Tensor.numpy() only works on CPU tensors that do not require grad, so
# detach and move to CPU first; this also works if the model runs on a GPU.
x_np = x[sample_idx].detach().cpu().numpy()
recon_np = recon[sample_idx].detach().cpu().numpy()

plt.figure(figsize=(10, 4))
# Plot the chosen column of the sample across its first axis
# (assumed to be the time axis of the window — confirm).
plt.plot(x_np[:, feature_idx], label='Original')
plt.plot(recon_np[:, feature_idx], label='Reconstructed')
plt.legend()
plt.title("Quick Reconstruction Check")
plt.show()
