In [None]:
from mcap_data_loader.datasets.mcap_dataset import (
    McapFlatBuffersEpisodeDataset,
    McapFlatBuffersEpisodeDatasetConfig,
)
import numpy as np
from collections import defaultdict

# Configuration: where the episodes live and which keys to load (one per camera).
DATA_ROOT = "../mcap_data/reach_tag_blip2_features"
CAMERAS = ("env", "follow")
FEATURE_KEYS = [f"/{cam}_camera/color/image_raw/features_proj" for cam in CAMERAS]

dataset = McapFlatBuffersEpisodeDataset(
    McapFlatBuffersEpisodeDatasetConfig(
        data_root=DATA_ROOT,
        keys=FEATURE_KEYS,
    )
)
dataset.load()

In [2]:
# Per-episode, per-key series of a drift metric. For every sample the value is
#     ||x_t - x_0|| + ||x_0||
# where x_0 is the first sample seen for that key within the episode.
# NOTE(review): this is the distance from the first frame offset by the first
# frame's norm (an upper bound on ||x_t|| by the triangle inequality) — confirm
# this is the intended metric rather than plain ||x_t|| or ||x_t - x_0||.
ep_norms = defaultdict(lambda: defaultdict(list))
for episode in dataset:
    ref_data = {}
    ref_data_norm = {}
    ep_key = episode.config.data_root
    n_samples = 0
    for sample in episode:
        n_samples += 1
        for key, value in sample.items():
            data = value["data"]
            if key not in ref_data:
                # Copy the reference frame so it cannot change underneath us if the
                # loader hands out views into a reused buffer (defensive; cheap, once per key).
                ref_data[key] = np.array(data, copy=True)
                ref_data_norm[key] = np.linalg.norm(ref_data[key])
            norm = np.linalg.norm(ref_data[key] - data) + ref_data_norm[key]
            ep_norms[ep_key][key].append(norm)
    # Build the x axis from the samples actually iterated (not len(episode)) so its
    # length equals the number of samples processed for this episode.
    ep_norms[ep_key]["t"] = list(range(n_samples))

In [59]:
import holoviews as hv
from holoviews import opts

from pathlib import Path

hv.extension("bokeh")

# One curve per (episode, key): x is the sample index "t", y is the metric above.
curves: list[hv.Curve] = []
for ep_key, episode in ep_norms.items():
    # Sort the keys: `dict_keys - set` yields a set whose string iteration order
    # depends on per-process hash randomization, which would reshuffle panels on
    # every kernel restart.
    for s_key in sorted(episode.keys() - {"t"}):
        # Path(...) accepts both str and PathLike, so this works whatever type
        # `data_root` has on the episode config.
        curve = hv.Curve(episode, "t", s_key, label=Path(ep_key).name)
        curves.append(curve)
hv.Layout(curves).cols(2)

In [73]:
# Overlay every episode's curve for the same key so episodes can be compared
# directly: one panel per key, one line per episode.
X_LIMIT = (0, 300)  # NOTE(review): fixed x-range in samples; presumably covers the longest episode — confirm
PANEL_SIZE = dict(width=450, height=400)

# Group curves by their value dimension (the key) instead of by stride through
# `curves`; stride grouping silently mixes keys if any episode lacks a key.
curves_by_key: dict[str, list[hv.Curve]] = defaultdict(list)
for c in curves:
    curves_by_key[c.vdims[0].name].append(c)

overlays = []
for key_curves in curves_by_key.values():
    labels = [c.label for c in key_curves]
    # The overlay is keyed by episode label; duplicates would be dropped silently.
    assert len(set(labels)) == len(labels), f"duplicate episode labels would be dropped: {labels}"
    overlays.append(
        hv.NdOverlay(
            # clone=True so the curves shown by the previous cell are not mutated in place.
            {c.label: c.opts(xlim=X_LIMIT, clone=True, **PANEL_SIZE) for c in key_curves},
            kdims="episode",
        ).opts(title="")
    )
hv.Layout(overlays)