## Setup

In [35]:
!pip install -qq datasets huggingface_hub[cli] lightning tables
!pip install -qq torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

In [36]:
import json
import pickle
from pathlib import Path
from typing import NamedTuple, Literal
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy

import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from datasets import (
    Dataset,
    Features,
    Value,
    Array2D,
    DatasetInfo,
    load_dataset,
    ClassLabel,
    Sequence,
)
from lightning.pytorch import LightningDataModule


data_dir = Path("sqa_data")

## Loading and assembling the data

In [37]:
def dataset_split(pd_data, test_list):
    data_select_index_test = pd_data["context_source_file"] == "000"

    for test_file_i in test_list:
        data_select_index_test = (pd_data["context_source_file"] == test_file_i) | (
            data_select_index_test
        )

    print("Testing data percentage: ", sum(data_select_index_test) / pd_data.shape[0])
    print("Testing data number: ", sum(data_select_index_test))
    print(
        "Testing data unique scene: ",
        len(pd_data[data_select_index_test].context_index.unique()),
    )

    return data_select_index_test

In [38]:
# loading generated questions pickle
win_len = 1800
base_file_name = f"s1234_{win_len}_600"  # stride 600 is the one from the paper
pd_data = pd.read_pickle(data_dir / f"{base_file_name}_balanced.pkl")
with open(data_dir / f"{base_file_name}_context.pkl", "rb") as f:
    sensory_data = pickle.load(f)

# The OppQA data is split into a training set and a testing set. The
# training set contains SQA data generated on the first two Activity-
# of-Daily-Living (ADL) runs and a drill run of users 1-4, and the
# rest of the runs are used to generate testing data.

### splitting method 1: based on context
valid_list = [
    "S1-ADL1.dat",
    "S2-ADL1.dat",
    "S3-ADL1.dat",
    "S4-ADL1.dat",
    "S1-ADL3.dat",
    "S2-ADL3.dat",
    "S3-ADL3.dat",
    "S4-ADL3.dat",
    "S1-ADL2.dat",
    "S2-ADL2.dat",
    "S3-ADL2.dat",
    "S4-ADL2.dat",
]

train_list = [
    "S1-ADL4.dat",
    "S2-ADL4.dat",
    "S3-ADL4.dat",
    "S4-ADL4.dat",
    "S1-ADL5.dat",
    "S2-ADL5.dat",
    "S3-ADL5.dat",
    "S4-ADL5.dat",
    "S1-Drill.dat",
    "S2-Drill.dat",
    "S3-Drill.dat",
    "S4-Drill.dat",
]

#     ============  split train/valid based on no overlapping context:  ============
train_ind = dataset_split(pd_data, train_list)
valid_ind = dataset_split(pd_data, valid_list)


### splitting method 2: total random
# ============ random split train/valid:  ============
#     random_ind = np.random.rand(pd_data.shape[0])
#     train_ind = random_ind>=0.8
#     valid_ind = ~train_ind

# ====================================================

#     ### splitting method 3: based on q_struct
#     uniq_struct = ( pd_data.question_structure.unique() )
#     print('Total unique Q structure num: ',  len(uniq_struct))
#     # split the unique Q-struct to 50%-50%
#     rd_num = np.random.rand(len(uniq_struct))
#     train_ind_struct = rd_num<0.8
#     test_ind_struct = rd_num>=0.8

#     train_qstruct = uniq_struct[train_ind_struct]
#     # valid_qstruct = uniq_struct[valid_ind]
#     test_qstruct = uniq_struct[test_ind_struct]
#     train_ind = pd_data.question_structure.isin(train_qstruct)
#     valid_ind = pd_data.question_structure.isin(test_qstruct)

#     print('Train/test split:  %d / %d' %(sum(train_ind), sum(valid_ind)) )
#     # ====================================================

Testing data percentage:  0.8283266370433441
Testing data number:  93393
Testing data unique scene:  729
Testing data percentage:  0.17167336295665594
Testing data number:  19356
Testing data unique scene:  629


In [39]:
train_ind

127089      True
1164649    False
1246654     True
1404904     True
810542      True
           ...  
1569687     True
1572900     True
1572908     True
1572909     True
1572917     True
Name: context_source_file, Length: 112749, dtype: bool

In [40]:
# The validation is really the test set, i.e. the hold-out for final evaluation.
test_ind = valid_ind
del valid_ind

# Generate a proper validation set
percent_valid = 0.1

# train_ind is effectively a mask, and we want to set some to false and have a copy where those are then true for the validation set
# tghe selection shall be random
valid_ind = np.zeros_like(train_ind)
valid_ind[
    np.random.choice(
        np.where(train_ind)[0], int(percent_valid * np.sum(train_ind)), replace=False
    )
] = True
train_ind[valid_ind] = False

train_ind.sum(), valid_ind.sum(), test_ind.sum()

(84054, 9339, 19356)

In [41]:
pd_data

Unnamed: 0,context_source_file,context_start_point,context_index,question,answer,pred_answer,question_family_index,question_structure,question_index,split
127089,S1-ADL5.dat,9000,252,The tester closed the back Door After closing ...,No,Invalid,3,"3_['Close the back Door', 'Close the Fridge', ...",127089,Test
1164649,S4-ADL1.dat,34200,1118,Is it true that the user opened the front Door...,No,No,3,"3_['Open the front Door', 'Close the Fridge', ...",1164649,Test
1246654,S4-Drill.dat,3000,1294,The person opened the Fridge Following closing...,No,Invalid,3,"3_['Open the Fridge', 'Close the back Door', '...",1246654,Test
1404904,S4-Drill.dat,24600,1330,The person closed the front Door After closing...,No,Yes,3,"3_['Close the front Door', 'Close the back Doo...",1404904,Test
810542,S2-Drill.dat,41400,705,The tester opened the back Door After opening ...,No,Yes,3,"3_['Open the back Door', 'Open the Fridge', 'O...",810542,Test
...,...,...,...,...,...,...,...,...,...,...
1569687,S4-Drill.dat,42000,1359,Confirm if the user performs the same action F...,Yes,Invalid,9,"9_['Close the front Door', 'Close the back Doo...",1569687,Test
1572900,S4-Drill.dat,42600,1360,The subject performs the same action Preceding...,Yes,Invalid,9,"9_['Open the back Door', 'Open the front Door'...",1572900,Test
1572908,S4-Drill.dat,42600,1360,Is it the case that the subject performs the s...,Yes,Invalid,9,"9_['Close the back Door', 'Close the front Doo...",1572908,Test
1572909,S4-Drill.dat,42600,1360,The subject performs the same action Following...,Yes,Invalid,9,"9_['Open the front Door', 'Open the back Door'...",1572909,Test


In [42]:
pd_data.info()

<class 'pandas.core.frame.DataFrame'>
Index: 112749 entries, 127089 to 1572917
Data columns (total 10 columns):
 #   Column                 Non-Null Count   Dtype 
---  ------                 --------------   ----- 
 0   context_source_file    112749 non-null  object
 1   context_start_point    112749 non-null  int64 
 2   context_index          112749 non-null  int64 
 3   question               112749 non-null  object
 4   answer                 112749 non-null  object
 5   pred_answer            112749 non-null  object
 6   question_family_index  112749 non-null  int64 
 7   question_structure     112749 non-null  object
 8   question_index         112749 non-null  int64 
 9   split                  112749 non-null  object
dtypes: int64(4), object(6)
memory usage: 9.5+ MB


In [43]:
pd_data["answer"].value_counts()

answer
No                         33361
Yes                        29649
1                          20650
0                          20343
Close the Fridge             771
Open the front Door          755
Open the back Door           729
Close the back Door          710
Toggle the Switch            696
Close the front Door         691
Open the Fridge              685
Drink from the Cup           681
Close the third Drawer       649
Open the Dishwasher          630
2                            345
Open the third Drawer        270
Clean the Table              251
Close the Dishwasher         246
Open the first Drawer        223
Close the second Drawer      167
Close the first Drawer       151
Open the second Drawer        88
3                              8
Name: count, dtype: int64

In [44]:
len(pd_data["answer"].value_counts())

23

In [45]:
num_multiclass = 17

In [46]:
pd_data.iloc[15000]

context_source_file                                           S2-Drill.dat
context_start_point                                                  18000
context_index                                                          666
question                 Is it the case that the user closed the front ...
answer                                                                  No
pred_answer                                                             No
question_family_index                                                    3
question_structure       3_['Close the front Door', 'Open the Fridge', ...
question_index                                                      652310
split                                                                 Test
Name: 652310, dtype: object

In [47]:
sensory_data.keys()

dict_keys(['raw', 'embedding'])

In [48]:
sensory_data["raw"]["S1-ADL1.dat_0"].shape

(1800, 77)

In [49]:
context_key_list = (
    pd_data["context_source_file"] + "_" + pd_data["context_start_point"].astype(str)
)
sensory_matrix = np.zeros((len(pd_data), win_len, 77), dtype="float32")

for key, values in tqdm(sensory_data["raw"].items()):
    sensory_matrix[np.where(context_key_list == key), :] = values

100%|██████████| 1362/1362 [00:54<00:00, 25.18it/s]


In [50]:
with open(Path("sqa_data_gen") / "question_family.json") as f:
    data = json.load(f)

question_family_index_to_type = {
    int(entry["index"]): entry["question_type"] for entry in data["questions"]
}
question_family_index_to_type

{0: 'existence',
 1: 'counting',
 2: 'action_compare',
 3: 'action_compare',
 4: 'counting',
 5: 'counting',
 6: 'action_query',
 7: 'action_query',
 8: 'existence',
 9: 'action_compare',
 10: 'number_compare',
 11: 'number_compare',
 12: 'action_query',
 13: 'counting',
 14: 'time_query',
 15: 'time_query'}

In [51]:
class Sample(NamedTuple):
    # sample_id: int
    question_id: int
    trajectory: torch.Tensor
    # textual_description: str
    question_family_index: int
    question_type: str
    question: str
    # answer_type: str
    answer: str
    # options: dict[str, str | bool] | None
    # correct_option: str


def get_for(data: pd.DataFrame) -> list[Sample]:
    return [
        Sample(
            # sample_id=num,
            question_id=row["question_index"],
            trajectory=sensory_matrix[0, ...],
            # textual_description=data["textual_description"],
            question_family_index=int(row["question_family_index"]),
            question_type=question_family_index_to_type[
                int(row["question_family_index"])
            ],
            question=row["question"],
            answer=row["answer"],
            # options=json.dumps(qa_pair["options"]),  # can be None
            # correct_option=qa_pair["correct_option"],
        )
        for _, row in data.iterrows()
    ]


get_for(pd_data.iloc[0:2])

[Sample(question_id=127089, trajectory=array([[-1.103, -0.458,  0.174, ...,  0.442, -0.037, -0.174],
        [-0.919, -0.351,  0.1  , ...,  1.22 ,  0.183, -0.172],
        [-0.719, -0.302,  0.079, ...,  1.762,  0.37 , -0.17 ],
        ...,
        [-0.948, -0.315,  0.009, ..., -0.007,  0.036, -0.007],
        [-0.946, -0.319,  0.013, ...,  0.047, -0.014, -0.007],
        [-0.945, -0.317,  0.011, ...,  0.056, -0.005, -0.008]],
       dtype=float32), question_family_index=3, question_type='action_compare', question='The tester closed the back Door After closing the refrigerator OR Before closing the front Door?', answer='No'),
 Sample(question_id=1164649, trajectory=array([[-1.103, -0.458,  0.174, ...,  0.442, -0.037, -0.174],
        [-0.919, -0.351,  0.1  , ...,  1.22 ,  0.183, -0.172],
        [-0.719, -0.302,  0.079, ...,  1.762,  0.37 , -0.17 ],
        ...,
        [-0.948, -0.315,  0.009, ..., -0.007,  0.036, -0.007],
        [-0.946, -0.319,  0.013, ...,  0.047, -0.014, -0.007],


In [52]:
# pd_data

In [53]:
df_all = pd.DataFrame.from_dict(get_for(pd_data))
df_all["question-answer"] = df_all["question"] + " " + df_all["answer"]
df_all

Unnamed: 0,question_id,trajectory,question_family_index,question_type,question,answer,question-answer
0,127089,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",3,action_compare,The tester closed the back Door After closing ...,No,The tester closed the back Door After closing ...
1,1164649,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",3,action_compare,Is it true that the user opened the front Door...,No,Is it true that the user opened the front Door...
2,1246654,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",3,action_compare,The person opened the Fridge Following closing...,No,The person opened the Fridge Following closing...
3,1404904,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",3,action_compare,The person closed the front Door After closing...,No,The person closed the front Door After closing...
4,810542,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",3,action_compare,The tester opened the back Door After opening ...,No,The tester opened the back Door After opening ...
...,...,...,...,...,...,...,...
112744,1569687,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",9,action_compare,Confirm if the user performs the same action F...,Yes,Confirm if the user performs the same action F...
112745,1572900,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",9,action_compare,The subject performs the same action Preceding...,Yes,The subject performs the same action Preceding...
112746,1572908,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",9,action_compare,Is it the case that the subject performs the s...,Yes,Is it the case that the subject performs the s...
112747,1572909,"[[-1.103, -0.458, 0.174, -1.914, -0.546, 0.091...",9,action_compare,The subject performs the same action Following...,Yes,The subject performs the same action Following...


In [54]:
df_all["question_family_index"].value_counts().sort_index()

question_family_index
0       503
1       505
2      2345
3     38289
4      2347
5     38494
6      1505
7      6888
8       371
9      6900
10     4766
11     9836
Name: count, dtype: int64

In [55]:
df_all["question_type"].value_counts().sort_index()

question_type
action_compare    47534
action_query       8393
counting          41346
existence           874
number_compare    14602
Name: count, dtype: int64

In [56]:
# Make sure that the question_type makes sense

by_family_index = df_all.groupby("question_type")
for family_index, group in by_family_index:
    print(family_index, ", from:", group["question_family_index"].unique())
    print("\n".join(group["question"].iloc[:15]))
    print()

action_compare , from: [3 2 9]
The tester closed the back Door After closing the refrigerator OR Before closing the front Door?
Is it true that the user opened the front Door Preceding closing the Fridge OR After opening the Fridge?
The person opened the Fridge Following closing the back Door OR Preceding opening the back Door, correct?
The person closed the front Door After closing the back Door AND After opening the back Door?
The tester opened the back Door After opening the Fridge AND Before opening the front Door, correct?
Is it correct to say that the person closed the Fridge After closing the back Door OR Preceding opening the Fridge?
Is it true that the user opened the back Door Preceding closing the front Door AND After closing the refrigerator?
The subject closed the back Door After opening the Fridge OR Following opening the Fridge?
Does the tester opens the front Door After closing the back Door OR Preceding opening the back Door?
Confirm if the tester closed the back Door 

In [57]:
len(df_all)

112749

In [58]:
df_all["question"].isna().any()

False

In [59]:
df_all["trajectory-str"] = df_all["trajectory"].apply(lambda x: str(x))
df_all[
    [
        "question_type",
        "answer",
        "question",
        "question-answer",
        "trajectory-str",
    ]
].nunique()

question_type          5
answer                23
question           84389
question-answer    89735
trajectory-str         1
dtype: int64

In [60]:
trajectory_shape = next(iter(get_for(pd_data.iloc[:1]))).trajectory.shape
trajectory_shape

(1800, 77)

In [61]:
multi_options = pd.Series(
    df_all[df_all["question_type"] == "action_query"]["answer"].unique()
).sort_values()
multi_option_to_int = {answer: index for index, answer in enumerate(multi_options)}
multi_option_to_int

{'Clean the Table': 0,
 'Close the Dishwasher': 1,
 'Close the Fridge': 2,
 'Close the back Door': 3,
 'Close the first Drawer': 4,
 'Close the front Door': 5,
 'Close the second Drawer': 6,
 'Close the third Drawer': 7,
 'Drink from the Cup': 8,
 'Open the Dishwasher': 9,
 'Open the Fridge': 10,
 'Open the back Door': 11,
 'Open the first Drawer': 12,
 'Open the front Door': 13,
 'Open the second Drawer': 14,
 'Open the third Drawer': 15,
 'Toggle the Switch': 16}

In [62]:
question_type_to_answer_type = {
    "existence": "binary",
    "action_compare": "binary",
    "number_compare": "binary",
    "action_query": "multi",
    "counting": "count",
}

In [63]:
def make_info(answer_type: str) -> DatasetInfo:
    base_features = {
        # "sample_id": Value("int32"),
        "question_id": Value("int32"),
        "trajectory": Array2D(dtype="float32", shape=trajectory_shape),
        # "textual_description": Value("string"),
        "question_type": Value("string"),
        "question": Value("string"),
        "answer_type": Value("string"),
        # "answer": Value("string"),
        # "options": Value("string"),  # JSON encoded
        # "correct_option": Value("string"),
    }

    match answer_type:
        case "binary":
            answer_features = {
                "answer": ClassLabel(names=["true", "false"], num_classes=2)
            }
        case "multi":
            answer_features = {
                "answer": ClassLabel(
                    names=multi_options.to_list(), num_classes=num_multiclass
                ),
                "options": Sequence(Value("string")),
            }
        case "count":
            answer_features = {"answer": Value("uint8")}
        case _:
            raise ValueError(f"Invalid task '{answer_type}'")

    return DatasetInfo(features=Features(base_features | answer_features))

## Persisting the dataset

In [64]:
def push_grouped_df_to_hub(
    df_group: DataFrameGroupBy,
    split: Literal["test", "val", "train"],
    limit_task: list[str] | None = None,
    token: str = None,
):
    """
    Takes a grouped DataFrame, feature types, dataset info, and a Hugging Face authentication token,
    then pushes each group to the Hugging Face Hub under specified configurations.

    :param df_group: Grouped Pandas DataFrame object.
    :param feat_type: Feature type for the dataset.
    :param info: Information about the dataset.
    :param token: Hugging Face authentication token.
    """
    for name, group in df_group:
        if limit_task and name not in limit_task:
            continue
        print(f"Group Name: {name} of split: {split}")

        answer_type = question_type_to_answer_type[name]

        match answer_type:
            case "binary":
                group["answer"] = (group["answer"] == "Yes").astype(int)

            case "multi":
                # group["answer_index"] = [
                #     multi_option_to_int[ans] for ans in group["answer"]
                # ]
                group["options"] = group["answer"].apply(
                    lambda _: multi_options.to_list()
                )

            case "count":
                group["answer"] = group["answer"].astype(int)

            case _:
                raise ValueError(f"Invalid task name: {name}")

        def gen_it():
            yield from group.to_dict(orient="records")

        match win_len:
            case 500:
                repo_name = "dasyd/OppQA-500"
            case 1800:
                repo_name = "dasyd/OppQA"
            case _:
                raise ValueError(
                    f"This window length has no dataset attached: {win_len}"
                )

        # Create a dataset from the list of dictionaries and push it to the hub
        dataset = Dataset.from_generator(
            gen_it,
            info=make_info(answer_type),
        )
        dataset.push_to_hub(
            repo_name,
            config_name=name,
            token=token,
            split=split,
            #     commit_message=f"[Version Revision] Restructured item shape of {split} split of {name} task dataset",
        )
        print(f"Pushed {name} of split: {split} to Hub {repo_name}")

In [65]:
splits = {
    "train": train_ind,
    "val": valid_ind,
    "test": test_ind,
}

for name, mask in splits.items():
    data = get_for(pd_data[mask])
    df = pd.DataFrame(elem._asdict() for elem in data)
    push_grouped_df_to_hub(
        df.groupby("question_type"),
        split=name,
    )

Group Name: action_compare of split: train


Generating train split: 37227 examples [01:01, 601.32 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.00s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.85s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.42s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.29s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.27s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.33s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.38s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.04s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.85s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.92s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.42s/ba]
Creating parquet from Arrow form

Pushed action_compare of split: train to Hub dasyd/OppQA
Group Name: action_query of split: train


Generating train split: 6161 examples [00:10, 571.85 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.27s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.39s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.45s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.22s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.94s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.60s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.51s/ba]
Uploading the dataset shards: 100%|██████████| 7/7 [01:14<00:00, 10.65s/it]


Pushed action_query of split: train to Hub dasyd/OppQA
Group Name: counting of split: train


Generating train split: 32215 examples [00:59, 539.40 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.89s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.49s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.56s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.20s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.56s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.52s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.32s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.36s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.96s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.63s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.51s/ba]
Creating parquet from Arrow form

Pushed counting of split: train to Hub dasyd/OppQA
Group Name: existence of split: train


Generating train split: 523 examples [00:00, 655.46 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:02<00:00,  2.90s/ba]
Uploading the dataset shards: 100%|██████████| 1/1 [00:11<00:00, 11.88s/it]


Pushed existence of split: train to Hub dasyd/OppQA
Group Name: number_compare of split: train


Generating train split: 7928 examples [00:13, 582.71 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.54s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.03s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.48s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.96s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.24s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.62s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.58s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.08s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.14s/ba]
Uploading the dataset shards: 100%|██████████| 9/9 [01:47<00:00, 11.98s/it]


Pushed number_compare of split: train to Hub dasyd/OppQA
Group Name: action_compare of split: val


Generating train split: 4144 examples [00:07, 559.93 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.18s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.74s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.54s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.20s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.55s/ba]
Uploading the dataset shards: 100%|██████████| 5/5 [00:51<00:00, 10.34s/it]


Pushed action_compare of split: val to Hub dasyd/OppQA
Group Name: action_query of split: val


Generating train split: 620 examples [00:01, 586.16 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.33s/ba]
Uploading the dataset shards: 100%|██████████| 1/1 [00:07<00:00,  7.96s/it]


Pushed action_query of split: val to Hub dasyd/OppQA
Group Name: counting of split: val


Generating train split: 3670 examples [00:06, 529.98 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.06s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.82s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.83s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.94s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  4.00s/ba]
Uploading the dataset shards: 100%|██████████| 5/5 [00:47<00:00,  9.51s/it]


Pushed counting of split: val to Hub dasyd/OppQA
Group Name: existence of split: val


Generating train split: 68 examples [00:00, 581.98 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00,  2.74ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.20s/it]


Pushed existence of split: val to Hub dasyd/OppQA
Group Name: number_compare of split: val


Generating train split: 837 examples [00:01, 646.82 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.15s/ba]
Uploading the dataset shards: 100%|██████████| 1/1 [00:10<00:00, 10.88s/it]


Pushed number_compare of split: val to Hub dasyd/OppQA
Group Name: action_compare of split: test


Generating train split: 6163 examples [00:11, 541.74 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.69s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.60s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.57s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.34s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.58s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.64s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.88s/ba]
Uploading the dataset shards: 100%|██████████| 7/7 [01:23<00:00, 11.94s/it]


Pushed action_compare of split: test to Hub dasyd/OppQA
Group Name: action_query of split: test


Generating train split: 1612 examples [00:03, 445.99 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  5.00s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.96s/ba]
Uploading the dataset shards: 100%|██████████| 2/2 [00:20<00:00, 10.33s/it]


Pushed action_query of split: test to Hub dasyd/OppQA
Group Name: counting of split: test


Generating train split: 5461 examples [00:10, 512.99 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.33s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.79s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:03<00:00,  3.99s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.40s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.67s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.11s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.14s/ba]
Uploading the dataset shards: 100%|██████████| 7/7 [01:14<00:00, 10.67s/it]


Pushed counting of split: test to Hub dasyd/OppQA
Group Name: existence of split: test


Generating train split: 283 examples [00:00, 658.47 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:01<00:00,  1.65s/ba]
Uploading the dataset shards: 100%|██████████| 1/1 [00:04<00:00,  4.16s/it]


Pushed existence of split: test to Hub dasyd/OppQA
Group Name: number_compare of split: test


Generating train split: 5837 examples [00:10, 531.08 examples/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.94s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.98s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.50s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.73s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:04<00:00,  4.80s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:07<00:00,  7.08s/ba]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:05<00:00,  5.44s/ba]
Uploading the dataset shards: 100%|██████████| 7/7 [01:21<00:00, 11.61s/it]


Pushed number_compare of split: test to Hub dasyd/OppQA


You have to first run:

```shell
huggingface-cli login
```

## Test if it works (this re-downloads the dataset)

In [67]:
from datasets import VerificationMode


class TimeQADataModule(LightningDataModule):
    KEY = "dasyd/OppQA"

    def __init__(
        self,
        batch_size: int = 32,
        task: Literal[
            "existence", "action_compare", "number_compare", "action_query", "counting"
        ] = "action_compare",
    ):
        super().__init__()

        self.batch_size = batch_size
        self.task = task

    def _load_dataset_split(self, splits: list[str]):
        """Workaround to overcome the missing hf implementation of only dowloading the split shards"""

        return load_dataset(
            TimeQADataModule.KEY,
            self.task,
            data_dir=self.task,
            data_files={split: f"{split}-*" for split in splits},
            verification_mode=VerificationMode.NO_CHECKS,
            num_proc=len(splits),
        )

    def prepare_data(self) -> None:
        self._load_dataset_split(["val", "train", "test"])

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.dataset = self._load_dataset_split(["train", "val"])
        elif stage == "test":
            self.dataset = self._load_dataset_split(["test"])

        self.dataset = self.dataset.with_format("torch")

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.dataset["train"], batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["val"], batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.dataset["test"], batch_size=self.batch_size)


module = TimeQADataModule(batch_size=4, task="action_compare")
module.prepare_data()
module.setup("fit")


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Downloading data: 100%|██████████| 187M/187M [00:08<00:00, 21.7MB/s]

[A
[A
Downloading data: 100%|██████████| 187M/187M [00:09<00:00, 20.2MB/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Downloading data: 100%|██████████| 187M/187M [00:08<00:00, 23.0MB/s]
Downloading data: 100%|██████████| 187M/187M [00:09<00:00, 19.4MB/s]
Downloading data: 100%|██████████| 187M/187M [00:04<00:00, 38.0MB/s]



[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A

[A[A


[A[A[A


[A[A[A

[A[A

[A[A


[A[A[A


[A[A[A

[A[A

[A[A

KeyboardInterrupt: 

In [None]:
loader = module.train_dataloader()
batch = next(iter(loader))

# list(batch.keys())
batch["trajectory"].shape