In [None]:
import torch
import matplotlib.pyplot as plt

from modules.transformer import TransformerModule
from modules.mamba import MambaModule

from dataset import FrankaDataModule

In [None]:
# The checkpoint and its hyper-parameter file belong to the same Lightning run,
# so both paths are derived from a single run-directory prefix.
run_dir = "/work/tc064/tc064/s2567498/virtual-sensing/lightning_logs/transformer/version_11"
checkpoint = f"{run_dir}/checkpoints/epoch=0-step=520.ckpt"
hparams = f"{run_dir}/hparams.yaml"

In [None]:
# Restore the trained transformer from its Lightning checkpoint, using the
# hyper-parameters recorded next to it.
#
# map_location="cpu" makes the load independent of the device the checkpoint
# was written on: the data module in this notebook is created with
# use_cpu=True and the predictions are later converted with .numpy(), both of
# which expect CPU tensors.
# NOTE(review): assumes model.predict does not itself require a GPU — confirm.
model = TransformerModule.load_from_checkpoint(
    checkpoint_path=checkpoint,
    hparams_file=hparams,
    map_location="cpu",
)
# Put the module into evaluation mode; as the last expression of the cell this
# also displays the module.
model.eval()

In [None]:
# Build the data module and take its validation loader.
# Batches of size 1 are loaded in the main process (num_workers=0).
# NOTE(review): the exact meaning of data_portion / episode_length / use_cpu is
# defined by FrankaDataModule, which is not visible here.
data_root = "/work/tc064/tc064/s2567498/data-w-camera"
datamodule_kwargs = {
    "batch_size": 1,
    "num_workers": 0,
    "data_portion": 0.1,
    "episode_length": 200,
    "use_cpu": True,
}

dm = FrankaDataModule(data_root, **datamodule_kwargs)
dm.setup()
data_loader = dm.val_dataloader()

In [None]:
# Position inside batch["sensor_data"] that is used as the reference value for
# each prediction.
# NOTE(review): the tensor layout is presumably (batch, time step, channel) —
# confirm against FrankaDataModule before changing these.
SAMPLE_INDEX = 0
STEP_INDEX = 10
TARGET_CHANNELS = [7, 8, 9]

# Predictions and matching reference values, one entry per validation batch.
outputs = []
targets = []

# Gradients are never needed for evaluation, so autograd is disabled once
# around the whole loop instead of being re-entered for every batch.
with torch.no_grad():
    for batch in data_loader:
        target = batch["sensor_data"][SAMPLE_INDEX, STEP_INDEX, TARGET_CHANNELS]
        targets.append(target)
        output = model.predict(batch)
        outputs.append(output)
        print(output, target)

In [None]:
# Convert the collected tensors to plain Python lists.
# .detach().cpu() is applied first because Tensor.numpy() raises for tensors
# that live on a GPU or that still require grad; for CPU tensors without grad
# the extra calls do not change the values.
output_values = [out.detach().cpu().numpy().flatten().tolist() for out in outputs]
target_values = [tgt.detach().cpu().numpy().tolist() for tgt in targets]

# Plotting: one figure per batch, prediction and reference on the same axes.
for i, (predicted, expected) in enumerate(zip(output_values, target_values)):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(predicted, label="Output", marker="o")
    ax.plot(expected, label="Target", marker="x")
    ax.set_title(f"Sensor Outputs vs Targets for Batch {i+1}")
    ax.set_xlabel("Sensor Index")
    ax.set_ylabel("Sensor Value")
    ax.legend()
    plt.show()
    # Release the figure explicitly so figures do not accumulate in memory
    # when the validation set is large.
    plt.close(fig)