In [1]:
from typing import Literal

from torch.utils.data import DataLoader
from lightning.pytorch import LightningDataModule
from datasets import load_dataset, concatenate_datasets, Dataset, VerificationMode

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class TimeQADataModule(LightningDataModule):

    def __init__(
        self,
        task: Literal["binary", "multi", "open"],
        batch_size: int = 32,
        repo: str = "dasyd/time-qa",
    ):
        super().__init__()

        self.task = task
        self.batch_size = batch_size
        self.repo = repo

    def _load_dataset_split(self, splits: list[str]):
        """Workaround to overcome the missing hf implementation of only dowloading the split shards"""
        dataset = load_dataset(
            self.repo,
            self.task,
            data_dir=self.task,
            data_files={split: f"{split}-*" for split in splits},
            verification_mode=VerificationMode.NO_CHECKS,
            num_proc=len(splits),
        )
        dataset.set_format(type="torch", columns=["trajectory", "trajectory_rot6d"], output_all_columns=True)
        return dataset

    def prepare_data(self) -> None:
        # Download all, since only this is run on the main process
        self._load_dataset_split(["train", "val", "test"])

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.dataset = self._load_dataset_split(["train", "val"])
        elif stage == "test":
            self.dataset = self._load_dataset_split(["test"])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["train"], batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["val"], batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["test"], batch_size=self.batch_size)

    def all_splits(self) -> Dataset:
        return concatenate_datasets(self._load_dataset_split(["train", "val", "test"]).values())


module = TimeQADataModule(task="multi", batch_size=4)
module.prepare_data()
module.setup("test")

test_data = module.dataset["test"]
len(test_data)

6790

In [6]:
test_data[0]

{'trajectory': tensor([[[ 0.0000,  0.0036,  0.0085,  ..., -0.4432, -0.4421, -0.4414],
          [ 0.0000,  0.0054,  0.0155,  ...,  0.0395,  0.0531,  0.0641],
          [ 0.8392,  0.8376,  0.8356,  ...,  0.9551,  0.9550,  0.9550]],
 
         [[ 0.0183,  0.0237,  0.0323,  ..., -0.4166, -0.4158, -0.4152],
          [-0.0904, -0.0848, -0.0740,  ..., -0.0344, -0.0208, -0.0098],
          [ 0.8840,  0.8820,  0.8795,  ...,  1.0210,  1.0212,  1.0209]],
 
         [[-0.0479, -0.0443, -0.0387,  ..., -0.4954, -0.4950, -0.4950],
          [-0.0824, -0.0776, -0.0687,  ..., -0.0545, -0.0406, -0.0293],
          [ 0.7851,  0.7843,  0.7838,  ...,  0.9338,  0.9343,  0.9345]],
 
         ...,
 
         [[ 0.1326,  0.1334,  0.1283,  ..., -0.5152, -0.5102, -0.5072],
          [ 0.2395,  0.2459,  0.2546,  ...,  0.1603,  0.1726,  0.1838],
          [ 0.6019,  0.6020,  0.6036,  ...,  0.6719,  0.6704,  0.6693]],
 
         [[-0.0157, -0.0443, -0.1066,  ..., -0.4274, -0.4251, -0.4237],
          [ 0.0964,  0

In [18]:
relevant_test_data = list(test_data)[::5]
len(relevant_test_data)

1358

In [26]:
for example in relevant_test_data[:20]:
    print(example["question_type"])
    print(" -> ".join(example["action_sequence"]["action"]))
    print(example["question"])
    print(example["options_sequence"])
    print(str(example["answer"]) + " | " + example["answer_text"])
    print()

right_before_multi
kicking -> walking -> hopping -> hopping
What did the person do right before hopping? A: chopping, B: chopping, or C: chopping?
['chopping', 'chopping', 'chopping']
2 | The correct answer is walking.

right_after_multi
moving backwards -> moving backwards -> swimming -> bowing
Following swimming, what action did the person perform directly after? A: twisting, B: twisting, or C: bowing?
['twisting', 'twisting', 'bowing']
2 | The correct answer is: bowing.

comparison_counting_multi
planting its feet -> walking -> walking -> fighting
Among A: doing jumping jacks, B: fighting and C: fighting, which action is executed as many times as walking?
['doing jumping jacks', 'fighting', 'fighting']
1 | The correct answer is: fighting.

after_multi
leaning -> leaning -> squatting -> walking
What did the person do some time after squatting? Was it A: moving backwards, B: planting its feet or C: walking?
['moving backwards', 'planting its feet', 'walking']
2 | The correct answer is