In [None]:
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import torch
import os
import pandas as pd
import numpy as np
import config


In [None]:
class JetbotACDataset(Dataset):
    """Per-session random clips of Jetbot frames with their motor actions.

    One dataset item corresponds to one session, i.e. all CSV rows sharing a
    ``session_id``. Every access draws a fresh random clip from that session.
    """

    # Every column the loader reads. `timestamp` is listed because the rows of
    # each session are ordered by it before clips are sampled.
    _REQUIRED_COLUMNS = ('session_id', 'image_path', 'action', 'timestamp')

    def __init__(self, csv_path, data_dir, frames_per_clip=8, frameskip=1, transform=None):
        """Dataset of Jetbot clips with single motor actions.

        The dataset groups rows by `session_id` and samples random clips from
        each session similar to :class:`DROIDVideoDataset`. Actions are
        stored in the `action` column and take values of `0.0` or `0.13`.

        Args:
            csv_path: Path of the CSV index. It must contain the columns
                `session_id`, `image_path`, `action` and `timestamp`.
            data_dir: Directory that every `image_path` entry is joined onto.
            frames_per_clip: Number of frames returned per clip (>= 1).
            frameskip: Step between consecutive sampled rows (>= 1).
            transform: Optional callable applied to each RGB PIL image. When it
                is omitted, or returns a PIL image, the frame is converted to
                a float tensor so that the clip can still be stacked.

        Raises:
            ValueError: If `frames_per_clip` or `frameskip` is smaller than 1,
                or if the CSV lacks a required column.
            FileNotFoundError: If `csv_path` does not exist.
        """
        if frames_per_clip < 1 or frameskip < 1:
            raise ValueError(
                f'frames_per_clip and frameskip must be >= 1, '
                f'got {frames_per_clip} and {frameskip}'
            )
        self.csv_path = csv_path
        self.data_dir = data_dir
        self.frames_per_clip = frames_per_clip
        self.frameskip = frameskip
        self.transform = transform
        self.sessions = self._load_sessions()

    def _load_sessions(self):
        """Read the CSV and return one timestamp-sorted DataFrame per session."""
        if not os.path.exists(self.csv_path):
            raise FileNotFoundError(f'CSV not found: {self.csv_path}')
        df = pd.read_csv(self.csv_path)
        missing = set(self._REQUIRED_COLUMNS) - set(df.columns)
        if missing:
            raise ValueError(f'Missing required columns: {missing}')
        # Sorting plus a fresh 0..n-1 index lets __getitem__ address frames by
        # position in temporal order.
        return [
            group.sort_values('timestamp').reset_index(drop=True)
            for _, group in df.groupby('session_id')
        ]

    def __len__(self):
        """Number of sessions (not the number of frames or possible clips)."""
        return len(self.sessions)

    def _load_frame(self, image_path):
        """Load one frame as RGB, apply the transform and return a tensor."""
        full_path = os.path.join(self.data_dir, image_path)
        # The context manager releases the file handle as soon as the pixel
        # data has been decoded by convert().
        with Image.open(full_path) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if isinstance(img, Image.Image):
            # No tensor-producing transform was supplied: fall back to a
            # channel-first float tensor scaled to [0, 1] so torch.stack works.
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float().div(255.0)
        return img

    def __getitem__(self, idx):
        """Sample a random clip from session `idx`.

        Returns:
            A tuple `(buffer, actions, indices)`: the frames stacked along
            dim 0, the matching `action` values as a float32 tensor, and the
            NumPy array of row positions that were sampled from the session.

        Raises:
            ValueError: If the session has fewer than
                `frames_per_clip * frameskip` rows.
        """
        session = self.sessions[idx]
        needed = self.frames_per_clip * self.frameskip
        if len(session) < needed:
            raise ValueError(f'Session length {len(session)} < required {needed}')
        # NOTE(review): this uses NumPy's global RNG; with multi-worker
        # DataLoaders the workers may need explicit seeding - confirm.
        start = np.random.randint(0, len(session) - needed + 1)
        indices = np.arange(start, start + needed, self.frameskip)
        # One positional slice instead of a per-row .iloc lookup.
        clip = session.iloc[indices]
        frames = [self._load_frame(path) for path in clip['image_path']]
        buffer = torch.stack(frames, dim=0)
        actions = torch.tensor(clip['action'].tolist(), dtype=torch.float32)
        return buffer, actions, indices


In [None]:
# Example usage: build the preprocessing pipeline, then the clip dataset.
# Frames are resized to a square of side config.IMAGE_SIZE, converted to
# tensors and normalized per channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(config.IMAGE_SIZE, config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CSV index and image root come from the project-level config module.
dataset = JetbotACDataset(
    csv_path=config.CSV_PATH,
    data_dir=config.DATA_DIR,
    frames_per_clip=8,
    frameskip=1,
    transform=transform,
)
