In [24]:
import os
import concurrent.futures

import torch
import torchvision
import numpy
import polars


class WLRZEpisodeDataset(torch.utils.data.Dataset):
    def __init__(self, path: str):
        self._metadata = polars.read_json(path)
        self._base_path = os.path.dirname(path)

    def __len__(self):
        return len(self._metadata)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.numpy(force=True)

        metadata = self._metadata[index]

        with concurrent.futures.ThreadPoolExecutor() as executor:
            def _decode_images(paths: ...):
                return list(executor.map(
                    lambda x: torchvision.io.decode_image(
                        os.path.join(self._base_path, x),
                        mode=torchvision.io.ImageReadMode.RGB,
                    ),
                    paths,
                ))
            
            images_left, images_front, images_right = list(executor.map(
                _decode_images,
                (
                    metadata["left"],
                    metadata["front"],
                    metadata["right"],
                ),
            ))

        return dict(
            images_left=numpy.stack(images_left),
            images_front=numpy.stack(images_front),
            images_right=numpy.stack(images_right),
            dof_positions=numpy.stack(metadata["joint"]),
            text=numpy.stack(metadata["task"]),
        )



In [25]:
dataset = WLRZEpisodeDataset("samples/2026-01-21_demo_clothes/episode_0/data.json")


In [26]:
dataset

<__main__.WLRZEpisodeDataset at 0x79030e3ec890>