## Setup

In [None]:
# seaborn is only used for the distribution plots at the end of this notebook.
!pip install -qq seaborn

In [None]:
import json
from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from datasets import Dataset, VerificationMode, concatenate_datasets, load_dataset
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

## Option #1: Load data from Hugging Face

Run `huggingface-cli login` (e.g. in a terminal) before executing the next cell.

In [None]:
class TimeQADataModule(LightningDataModule):
    """Lightning data module for one variant of the TimeQA dataset on the Hugging Face Hub.

    Args:
        task: Dataset variant to load. It is used both as the dataset config name
            and as the data directory inside the repository.
        batch_size: Batch size used by every dataloader.
        repo: Hugging Face Hub repository id.
    """

    def __init__(
        self,
        task: Literal["binary", "multi", "open"],
        batch_size: int = 32,
        repo: str = "dasyd/time-qa",
    ):
        super().__init__()

        self.task = task
        self.batch_size = batch_size
        self.repo = repo

    def _load_dataset_split(self, splits: list[str]):
        """Load only the requested splits of the configured task.

        Workaround for the missing Hugging Face support for downloading only the
        shards of selected splits: every split is mapped to its own shard glob
        (``"<split>-*"`` inside ``data_dir``).

        Args:
            splits: Split names to load, e.g. ``["train", "val"]``.

        Returns:
            The ``load_dataset`` result, indexed by split name.
        """
        return load_dataset(
            self.repo,
            self.task,
            data_dir=self.task,
            data_files={split: f"{split}-*" for split in splits},
            # Skip split/checksum verification; presumably it would fail when only
            # a subset of the splits is downloaded.
            verification_mode=VerificationMode.NO_CHECKS,
            num_proc=len(splits),  # one process per requested split
        )

    def prepare_data(self) -> None:
        """Download every split into the local cache.

        Lightning runs this hook on a single process only, so all downloading
        happens here and ``setup`` afterwards reads from the cache.
        """
        self._load_dataset_split(["train", "val", "test"])

    def setup(self, stage: str) -> None:
        """Load the splits needed for ``stage`` ("fit", "validate" or "test").

        Other stages (e.g. "predict") load nothing; this module defines no
        predict dataloader.
        """
        if stage == "fit":
            self.dataset = self._load_dataset_split(["train", "val"])
        elif stage == "validate":
            # Trainer.validate() calls setup("validate"). Without this branch,
            # val_dataloader() could hit an unset ``self.dataset``.
            self.dataset = self._load_dataset_split(["val"])
        elif stage == "test":
            self.dataset = self._load_dataset_split(["test"])

    def train_dataloader(self) -> DataLoader:
        """Shuffled dataloader over the training split."""
        return DataLoader(self.dataset["train"], batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Dataloader over the validation split."""
        return DataLoader(self.dataset["val"], batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        """Dataloader over the test split."""
        return DataLoader(self.dataset["test"], batch_size=self.batch_size)

    def all_splits(self) -> Dataset:
        """Return the train, val and test splits concatenated into one ``Dataset``."""
        splits = self._load_dataset_split(["train", "val", "test"])
        # concatenate_datasets is documented to take a list; don't pass a dict view.
        return concatenate_datasets(list(splits.values()))


# Only the binary variant is loaded for now.
module = TimeQADataModule(task="binary")  # TODO: load all variants (binary, multi, open)!
module.prepare_data()  # download train/val/test into the local cache
complete = module.all_splits()  # every split concatenated into one Dataset
complete

In [None]:
# Field names of a single (the first) record.
next(iter(complete)).keys()

In [None]:
# Fields copied verbatim from every Hugging Face record.
hf_columns = (
    "sample_id",
    "question_id",
    "textual_description",
    "question_type",
    "question",
    "answer_type",
    "answer",
)

# The Hub export carries no action sequences, so every row gets its own empty
# list to keep the same columns as the local loader (Option #2).
data = pd.DataFrame.from_records(
    [{**{column: entry[column] for column in hf_columns}, "action_sequence": []} for entry in tqdm(complete)],
)
# Cache the table on disk so a later session can skip the download.
pd.to_pickle(data, "textual_part.pkl")

In [None]:
# Reload the table cached by the previous cell. Only unpickle files you created
# yourself: loading a pickle can execute arbitrary code.
data = pd.read_pickle("textual_part.pkl")

## Option #2: Load local data

In [None]:
# Local copy of the generated dataset, located two directory levels above the
# current working directory.
base_path = Path(".").absolute().parents[1] / "generated-dataset-30_000" / "data"

# Fail with an explicit message: a bare `assert` gives no hint about the expected
# location and is stripped entirely under `python -O`. `is_dir()` (not just
# `exists()`) because the loader below globs inside this path.
if not base_path.is_dir():
    raise FileNotFoundError(f"Local dataset directory not found: {base_path}")

In [None]:
def get_instances():
    for file in base_path.glob("*/data.json"):
        with open(file) as f:
            as_json = json.load(f)

        def answer_mapping(answer_type, answer, options, correct_option):
            match answer_type:
                case "binary":
                    correct = str(options[correct_option])
                    return {"True": "A", "False": "B"}[correct]
                case "multi":
                    return correct_option
                case _:
                    return answer

        action_sequence = [entry["action"] for entry in as_json["prompt_sequence"]]

        yield from (
            {
                "sample_id": str(file.parts[-2]),
                "question_id": question_id,
                "textual_description": as_json["textual_description"],
                "question_type": qa_pair["question_type"],
                "question": qa_pair["question"],
                "answer_type": qa_pair["answer_type"],
                "answer": answer_mapping(
                    qa_pair["answer_type"], qa_pair["answer"], qa_pair["options"], qa_pair["correct_option"]
                ),
                "action_sequence": action_sequence,
            }
            for question_id, qa_pair in enumerate(as_json["qa_pairs"])
        )


# One row per QA pair across all local samples; tqdm reports progress while the files are read.
data = pd.DataFrame.from_records(list(tqdm(get_instances())))

## Handling and analysis (shared by both options)

In [None]:
# Column dtypes, non-null counts and memory usage of the loaded table.
data.info()

In [None]:
# Human-readable column names, used as table headers and plot labels below.
display_names = {
    "textual_description": "Description",
    "question_type": "Question Type",
    "question": "Question",
    "answer_type": "Answer Type",
    "answer": "Answer",
    "action_sequence": "Action Sequence",
}
data = data.rename(columns=display_names)

In [None]:
# Group rows by answer type. pandas' default quicksort is not stable, so use a
# stable sort to keep the original row order within each answer type and make
# the displayed table reproducible.
data = data.sort_values(by=["Answer Type"], kind="stable")

In [None]:
# "<sample_id>_<question_id>" identifies a single QA pair across all samples.
data["unique_id"] = data["sample_id"].astype(str).str.cat(data["question_id"].astype(str), sep="_")
# Flatten each action list into one readable string, e.g. "a - b - c".
data["Action Sequence"] = data["Action Sequence"].map(" - ".join)
data

In [None]:
# Row subsets per answer type; the trailing underscore in `open_` avoids
# shadowing the built-in `open`.
binary = data.loc[data["Answer Type"].eq("binary")]
multi = data.loc[data["Answer Type"].eq("multi")]
binary_multi = pd.concat([binary, multi])  # binary rows first, then multi
open_ = data.loc[data["Answer Type"].eq("open")]

In [None]:
# Answer distribution of the binary and multiple-choice questions: one panel per
# question type (five per row), bars coloured by answer type; saved as a PDF.
sns.set_context("talk")
g = sns.displot(
    data=binary_multi.sort_values(by=["Answer Type", "Question Type", "Answer"]),
    x="Answer",
    hue="Answer Type",
    col="Question Type",
    col_wrap=5,
)
g.set_titles("{col_name}")  # title each panel with just the question type
# Anchor the legend's lower centre at the top centre, i.e. above the panels.
sns.move_legend(g, "lower center", bbox_to_anchor=(0.5, 1), ncol=3, frameon=True)
plt.savefig("answer_distribution.pdf", bbox_inches="tight")
pass

In [None]:
# Number of questions per question type (horizontal bars, since the category is
# on the y axis), coloured by answer type; saved as a PDF.
plt.figure(figsize=(7, 14))
ax = sns.histplot(
    data=data.sort_values(by=["Answer Type", "Question Type"]),
    y="Question Type",
    hue="Answer Type",
    # shrink=0.75,
)
# Anchor the legend's lower centre at the top centre, i.e. above the axes.
sns.move_legend(ax, "lower center", bbox_to_anchor=(0.5, 1), ncol=3, frameon=True)
plt.savefig("question_distribution.pdf", bbox_inches="tight")
pass

In [None]:
# Total number of rows (QA pairs).
len(data)

In [None]:
# Distinct-value counts for the text columns. "Q&A" joins question and answer,
# so the same question with different answers counts as distinct. Both parts are
# cast to str because answers are not guaranteed to be strings (Option #2 stores
# the raw `correct_option` for multiple-choice questions), which would make `+`
# raise a TypeError.
data_unique = data.copy()  # keep the helper column out of `data`
data_unique["Q&A"] = data_unique["Question"].astype(str) + " " + data_unique["Answer"].astype(str)
# Count only the displayed columns instead of computing nunique() for every column.
data_unique[["Description", "Q&A", "Question Type", "Question", "Answer", "Action Sequence"]].nunique()

In [None]:
# from generate.prompts.base.utility import all_actions
# len(all_actions)