## Setup

In [2]:
import json
from pathlib import Path
from typing import Any, Literal, NamedTuple

import pandas as pd
import torch
from datasets import (
    Array2D,
    Array3D,
    ClassLabel,
    Dataset,
    DatasetInfo,
    Features,
    Sequence,
    Value,
)
from pandas.core.groupby import DataFrameGroupBy
from tables import open_file
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map

  from .autonotebook import tqdm as notebook_tqdm


# Configure: Full or simple

In [5]:
data_dir = (Path(".").absolute().parents[1] / "generated-dataset-10_000" / "data").absolute()
target_data_repo = "dasyd/time-qa"

len(list(data_dir.iterdir()))

10000

## Loading and assembling the data

In [8]:
class Sample(NamedTuple):
    sample_id: int
    question_id: int
    trajectory: torch.Tensor
    action_sequence: dict[str, Any]
    textual_description: str
    question_type: str
    question: str
    answer_type: str
    answer: str
    options: dict[str, str | bool] | None
    correct_option: str


def get_for(num: int):
    path = data_dir / str(num)
    with open_file(path / "data.hdf5", "r") as hdf5_file:
        trajectory = hdf5_file.root["joints"][:].astype("float32")

    with open(path / "data.json") as json_file:
        data = json.load(json_file)

    return [
        Sample(
            sample_id=num,
            question_id=question_id,
            trajectory=trajectory,
            action_sequence=data["prompt_sequence"],
            textual_description=data["textual_description"],
            question_type=qa_pair["question_type"],
            question=qa_pair["question"],
            answer_type=qa_pair["answer_type"],
            answer=qa_pair["answer"],
            options=qa_pair["options"],  # can be None
            correct_option=qa_pair["correct_option"],
        )
        for question_id, qa_pair in enumerate(data["qa_pairs"])
    ]


next(iter(get_for(0)))

Sample(sample_id=0, question_id=0, trajectory=array([[[-0.0000000e+00,  0.0000000e+00,  9.0687686e-01],
        [-6.7598253e-02, -9.7624855e-03,  8.1419015e-01],
        [ 6.9608040e-02, -7.2357678e-03,  8.1790030e-01],
        ...,
        [ 9.2216330e-03,  2.9207093e-01,  1.0123324e+00],
        [-1.0229243e-01,  3.1216171e-01,  1.0552542e+00],
        [-2.9680135e-02,  3.6213297e-01,  1.0495307e+00]],

       [[-1.5695986e-03,  3.3287038e-04,  9.0606552e-01],
        [-6.9089182e-02, -9.8003680e-03,  8.1337023e-01],
        [ 6.8105243e-02, -6.6510309e-03,  8.1711417e-01],
        ...,
        [ 3.8956096e-03,  2.9604143e-01,  1.0154035e+00],
        [-1.0466182e-01,  3.1318346e-01,  1.0573907e+00],
        [-3.6545765e-02,  3.6466584e-01,  1.0536150e+00]],

       [[-3.1320225e-03,  8.3422964e-04,  9.0520048e-01],
        [-7.0561066e-02, -9.7233178e-03,  8.1248724e-01],
        [ 6.6613667e-02, -5.9788004e-03,  8.1629044e-01],
        ...,
        [-2.2514353e-03,  3.0089089e-01, 

In [9]:
all_ids = sorted([int(path.name) for path in data_dir.glob("*") if (path / "data.hdf5").exists()])

# Some file may only be partially written, we ignore them
for num in tqdm(all_ids[:]):
    try:
        with open_file(data_dir / str(num) / "data.hdf5", "r") as hdf5_file:
            hdf5_file.root["joints"]
    except Exception:
        print(f"Error in {num}, ignoring")
        all_ids.remove(num)

len(all_ids)

100%|██████████| 10000/10000 [00:09<00:00, 1046.51it/s]


10000

In [11]:
splits = ["train", "val", "test"]

# For a predefined split:
# def get_ids_in_split(split: str) -> list[int]:
#     with open(data_dir.parent / f"{split}.json", "r") as json_file:
#         return {entry["org_idx"] for entry in json.load(json_file)}
# ids_in_split = {split: get_ids_in_split(split) for split in splits}

# For a random split:
train = 0.8
val = 0.10
test = 0.10
ids_in_split = {
    "train": sorted(all_ids[: int(train * len(all_ids))]),
    "val": sorted(all_ids[int(train * len(all_ids)) : int((train + val) * len(all_ids))]),
    "test": sorted(all_ids[int((train + val) * len(all_ids)) :]),
}

assert len(all_ids)
assert len(set([*ids_in_split["train"], *ids_in_split["val"], *ids_in_split["test"]])) == len(all_ids)

len(ids_in_split["train"]), len(ids_in_split["val"]), len(ids_in_split["test"])

(8000, 1000, 1000)

In [12]:
def make_info(task: Literal["binary", "open", "multi"]):
    base_features = {
        "sample_id": Value("int32"),
        "question_id": Value("int32"),
        "trajectory": Array3D(dtype="float32", shape=next(iter(get_for(0))).trajectory.shape),
        "action_sequence": Sequence(
            Features(
                {
                    "start": Value("float32"),
                    "end": Value("float32"),
                    "action": Value("string"),
                    "action_sentence": Value("string"),
                }
            ),
            length=4,
        ),
        "textual_description": Value("string"),
        "question_type": Value("string"),
        "question": Value("string"),
        "answer_type": Value("string"),
        "answer_text": Value("string"),  # That's the answer in text form, always present
    }

    match task:
        case "binary":
            feat_type = Features(
                base_features | {"answer": ClassLabel(names=["true", "false"], num_classes=2)}
            )

        case "multi":
            feat_type = Features(
                base_features
                | {
                    "answer": ClassLabel(num_classes=3),
                    "options_sequence": Sequence(Value("string"), length=3),
                    "options": Features(
                        {
                            "A": Value("string"),
                            "B": Value("string"),
                            "C": Value("string"),
                        }
                    ),
                }
            )

        case "open":
            feat_type = Features(base_features | {"answer_text": Value("string")})

    return DatasetInfo(features=feat_type)


def get_all(ids: set[int]) -> pd.DataFrame:
    # Using process_map to automatically manage the progress bar with executor.map
    results = process_map(get_for, ids, max_workers=None, chunksize=1, total=len(ids))
    return pd.DataFrame([elem._asdict() for sublist in results for elem in sublist])

In [13]:
def push_grouped_df_to_hub(
    df_group: DataFrameGroupBy,
    split: Literal["test", "val", "train"],
    limit_task: list[Literal["open", "multi", "binary"]],
    token: str = None,
):
    """
    Takes a grouped DataFrame, feature types, dataset info, and a Hugging Face authentication token,
    then pushes each group to the Hugging Face Hub under specified configurations.

    :param df_group: Grouped Pandas DataFrame object.
    :param feat_type: Feature type for the dataset.
    :param info: Information about the dataset.
    :param token: Hugging Face authentication token.
    """
    for name, group in df_group:
        if name not in limit_task:
            continue
        print(f"Group Name: {name}")

        group = group.rename(columns={"answer": "answer_text"})

        match name:
            case "binary":
                group["answer"] = (group["correct_option"] == "A").astype(int)
                group = group.drop(columns=["options"])

            case "multi":
                mapping = {"A": 0, "B": 1, "C": 2}
                group["answer"] = [mapping[option] for option in group["correct_option"]]

                group["options_sequence"] = group["options"].apply(lambda x: list(x.values()))

            case "open":
                group = group.drop(columns=["options"])

        group = group.drop(columns=["correct_option"])

        lst_dict = group.to_dict(orient="records")

        info = make_info(name)

        # Create a dataset from the list of dictionaries and push it to the hub
        dataset = Dataset.from_list(lst_dict, info=info, split=split)
        dataset.push_to_hub(
            target_data_repo,
            config_name=name,
            token=token,
            split=split,
        )

        print(f"Pushed to huggingface: {target_data_repo}/{split}/{name}")

## Persisting the dataset

In [15]:
splits = ["val", "test", "train"]

for s in splits:
    _df = get_all(ids_in_split[s])
    _df_group = _df.groupby("answer_type")

    push_grouped_df_to_hub(_df_group, split=s, limit_task=["binary", "multi", "open"])

100%|██████████| 1000/1000 [00:01<00:00, 723.40it/s]


Group Name: binary


Creating parquet from Arrow format: 100%|██████████| 3/3 [00:01<00:00,  1.75ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:11<00:00, 11.76s/it]


Pushed to huggingface: dasyd/time-qa/val/binary
Group Name: multi


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.82ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.99s/it]


Pushed to huggingface: dasyd/time-qa/val/multi
Group Name: open


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.75ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.89s/it]


Pushed to huggingface: dasyd/time-qa/val/open


100%|██████████| 1000/1000 [00:01<00:00, 680.82it/s]


Group Name: binary


Creating parquet from Arrow format: 100%|██████████| 3/3 [00:01<00:00,  1.90ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:08<00:00,  8.60s/it]


Pushed to huggingface: dasyd/time-qa/test/binary
Group Name: multi


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.83ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.54s/it]


Pushed to huggingface: dasyd/time-qa/test/multi
Group Name: open


Creating parquet from Arrow format: 100%|██████████| 2/2 [00:01<00:00,  1.85ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.00s/it]


Pushed to huggingface: dasyd/time-qa/test/open


100%|██████████| 8000/8000 [00:11<00:00, 690.78it/s]


Group Name: binary


Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.84ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.88ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.73ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.68ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.68ba/s]
Uploading the dataset shards: 100%|██████████| 5/5 [00:58<00:00, 11.65s/it]


Pushed to huggingface: dasyd/time-qa/train/binary
Group Name: multi


Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.55ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.38ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.40ba/s]
Uploading the dataset shards: 100%|██████████| 3/3 [00:41<00:00, 13.75s/it]


Pushed to huggingface: dasyd/time-qa/train/multi
Group Name: open


Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.39ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.44ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:02<00:00,  1.44ba/s]
Uploading the dataset shards: 100%|██████████| 3/3 [00:45<00:00, 15.09s/it]


Pushed to huggingface: dasyd/time-qa/train/open


You have to first run:

```shell
huggingface-cli login
```

## Test if it works (this re-downloads the dataset)

In [17]:
from typing import Literal

from datasets import VerificationMode, concatenate_datasets, load_dataset
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader


class TimeQADataModule(LightningDataModule):
    def __init__(
        self,
        task: Literal["binary", "multi", "open"],
        batch_size: int = 32,
        repo: str = target_data_repo,
    ):
        super().__init__()

        self.task = task
        self.batch_size = batch_size
        self.repo = repo

    def _load_dataset_split(self, splits: list[str]):
        """Workaround to overcome the missing hf implementation of only dowloading the split shards"""
        dataset = load_dataset(
            self.repo,
            self.task,
            data_dir=self.task,
            data_files={split: f"{split}-*" for split in splits},
            verification_mode=VerificationMode.NO_CHECKS,
            num_proc=len(splits),
        )
        dataset.set_format(type="torch", columns=["trajectory"], output_all_columns=True)
        return dataset

    def prepare_data(self) -> None:
        # Download all, since only this is run on the main process
        self._load_dataset_split(["train", "val", "test"])

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.dataset = self._load_dataset_split(["train", "val"])
        elif stage == "test":
            self.dataset = self._load_dataset_split(["test"])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["train"], batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["val"], batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["test"], batch_size=self.batch_size)

    def all_splits(self) -> Dataset:
        return concatenate_datasets(self._load_dataset_split(["train", "val", "test"]).values())


module = TimeQADataModule(task="multi", batch_size=3)
module.prepare_data()
module.setup("fit")
# module.setup("test")

Generating train split: 100%|██████████| 11695/11695 [00:05<00:00, 2239.31 examples/s]
Setting num_proc from 2 back to 1 for the val split to disable multiprocessing as it only contains one shard.
Generating val split: 100%|██████████| 1416/1416 [00:00<00:00, 1578.17 examples/s]


In [None]:
loader = module.val_dataloader()
batch = next(iter(loader))
list(batch.keys())

['trajectory',
 'trajectory_rot6d',
 'sample_id',
 'question_id',
 'action_sequence',
 'textual_description',
 'question_type',
 'question',
 'answer_type',
 'answer_text',
 'options',
 'answer',
 'options_sequence']

In [None]:
batch

{'trajectory': tensor([[[[ 0.0000e+00, -1.4608e-03, -2.0801e-03,  ..., -2.8628e-01,
            -2.8653e-01, -2.8676e-01],
           [ 0.0000e+00, -2.5851e-04, -2.6852e-04,  ..., -7.0418e-02,
            -7.0737e-02, -7.1097e-02],
           [ 9.2889e-01,  9.2918e-01,  9.2991e-01,  ...,  7.2556e-01,
             7.2503e-01,  7.2449e-01]],
 
          [[ 3.1463e-02,  2.9694e-02,  2.8941e-02,  ..., -2.4621e-01,
            -2.4644e-01, -2.4672e-01],
           [-8.4738e-02, -8.5087e-02, -8.5188e-02,  ..., -1.5554e-01,
            -1.5586e-01, -1.5622e-01],
           [ 9.7729e-01,  9.7763e-01,  9.7829e-01,  ...,  7.6634e-01,
             7.6578e-01,  7.6528e-01]],
 
          [[-3.0521e-02, -3.2319e-02, -3.3269e-02,  ..., -3.1650e-01,
            -3.1687e-01, -3.1718e-01],
           [-9.0789e-02, -9.0961e-02, -9.0880e-02,  ..., -1.6001e-01,
            -1.6033e-01, -1.6070e-01],
           [ 8.7560e-01,  8.7595e-01,  8.7671e-01,  ...,  6.7012e-01,
             6.6967e-01,  6.6918e-01]]