# Imports

In [None]:
import random
from collections import namedtuple
from pathlib import Path
from pytorch_lightning import seed_everything

from datasets import load_dataset, DatasetDict, load_from_disk, Dataset, concatenate_datasets
from la.utils.utils import MyDatasetDict
from nn_core.common import PROJECT_ROOT

In [None]:
# When True and DATASET_DIR already exists, the dataset is read from disk instead of the Hub.
USE_CACHED: bool = True

# Global seed; seed_everything seeds Python's `random` (used below to sample classes) among others.
seed = 42
seed_everything(seed=seed)

# Data loading

In [None]:
def get_dataset(name: str, split: str, perc: float, seed: int = 42):
    """
    Load one split of a dataset from the HuggingFace datasets library, optionally subsampled.

    Args:
        name: HuggingFace dataset identifier (e.g. "cifar100").
        split: name of the split to load (e.g. "train", "test").
        perc: fraction of the split to keep, in (0, 1]. With 1 the whole split is returned
            as loaded (no shuffle); otherwise the split is shuffled with `seed` and its first
            `int(len(split) * perc)` rows are kept.
        seed: seed of the shuffle used when subsampling.

    Returns:
        The loaded, possibly subsampled, split.

    Raises:
        ValueError: if `perc` is not in (0, 1].
    """
    # Validate before downloading anything; an explicit raise (unlike `assert`) survives `python -O`.
    if not 0 < perc <= 1:
        raise ValueError(f"perc must be in (0, 1], got {perc!r}")

    dataset = load_dataset(
        name,
        split=split,
        # NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour of `token`.
        use_auth_token=True,
    )

    # Select a random subset
    if perc != 1:
        num_kept = int(len(dataset) * perc)
        # Passing the range directly avoids materialising a Python list of indices.
        dataset = dataset.shuffle(seed=seed).select(range(num_kept))

    return dataset

In [None]:
# Parameters identifying a dataset variant; the non-None values (except hf_key) form its cache key.
DatasetParams = namedtuple("DatasetParams", "name fine_grained train_split test_split perc hf_key")

In [None]:
dataset_name = "tiny_imagenet"  # tiny_imagenet or cifar100

# One entry per supported dataset; the per-attribute lookup tables below are derived from it.
_dataset_specs = {
    "tiny_imagenet": dict(
        ref="Maysee/tiny-imagenet",
        train_split="train",
        test_split="valid",
        label="label",
        train_per_class=500,
        test_per_class=50,
    ),
    "cifar100": dict(
        ref="cifar100",
        train_split="train",
        test_split="test",
        label="fine_label",
        train_per_class=500,
        test_per_class=100,
    ),
}

dataset_ref = {key: spec["ref"] for key, spec in _dataset_specs.items()}
dataset_train_split = {key: spec["train_split"] for key, spec in _dataset_specs.items()}
dataset_test_split = {key: spec["test_split"] for key, spec in _dataset_specs.items()}
dataset_label = {key: spec["label"] for key, spec in _dataset_specs.items()}
dataset_num_train_samples_per_class = {key: spec["train_per_class"] for key, spec in _dataset_specs.items()}
dataset_num_test_samples_per_class = {key: spec["test_per_class"] for key, spec in _dataset_specs.items()}

# Name of the column holding the class label for the selected dataset.
label_key = dataset_label[dataset_name]

In [None]:
# Parameters of the selected dataset: full splits (perc=1) and no fine_grained variant (None).
dataset_params: DatasetParams = DatasetParams(
    name=dataset_ref[dataset_name],
    fine_grained=None,
    train_split=dataset_train_split[dataset_name],
    test_split=dataset_test_split[dataset_name],
    perc=1,
    hf_key=(dataset_ref[dataset_name],),
)

In [None]:
# Cache key: every non-None parameter value except hf_key, stringified and joined with "_"
# (with the defaults above: "Maysee/tiny-imagenet_train_valid_1").
DATASET_KEY = "_".join(
    str(value)
    for field, value in zip(dataset_params._fields, dataset_params)
    if field != "hf_key" and value is not None
)
DATASET_DIR: Path = PROJECT_ROOT / "data" / "encoded_data" / DATASET_KEY
DATASET_DIR

In [None]:
# Both branches must produce a DatasetDict: later cells index it with "train" and "test".
dataset: DatasetDict

if not DATASET_DIR.exists() or not USE_CACHED:
    # Fetch both splits from the Hub; dataset_params.perc controls subsampling.
    dataset = DatasetDict(
        train=get_dataset(name=dataset_params.name, split=dataset_params.train_split, perc=dataset_params.perc),
        test=get_dataset(name=dataset_params.name, split=dataset_params.test_split, perc=dataset_params.perc),
    )
else:
    # NOTE(review): nothing in this notebook writes DATASET_DIR; the cache is presumably
    # populated elsewhere (it lives under "encoded_data") — confirm it holds a saved DatasetDict.
    dataset = load_from_disk(dataset_path=str(DATASET_DIR))

dataset

# Subdivide into tasks

### Params

In [None]:
# Each task sees all shared classes plus its own disjoint set of novel classes.
num_shared_classes, num_novel_classes_per_task = 100, 20

## Add ids

In [None]:
# Give every sample an "id" equal to its row index within its original split, so it can be
# traced back after the filtering/concatenation done per task.
for _split_name in ("train", "test"):
    dataset[_split_name] = dataset[_split_name].map(
        lambda batch, indices: {"id": indices},
        batched=True,
        with_indices=True,
    )

In [None]:
# Expected number of samples per class in each split, used to sanity-check the task sizes.
num_train_samples_per_class, num_test_samples_per_class = (
    dataset_num_train_samples_per_class[dataset_name],
    dataset_num_test_samples_per_class[dataset_name],
)

In [None]:
# Class names come from the label feature's `.names`; a class id is its position in that list.
all_classes = dataset["train"].features[label_key].names
num_classes = len(all_classes)
all_classes_ids = list(range(num_classes))
class_str_to_id = {class_name: class_id for class_id, class_name in enumerate(all_classes)}
print(f"{num_classes} classes in total")

### Sample shared classes

In [None]:
# Draw (with the globally seeded `random`) the class ids shared by every task.
sampled_shared_ids = random.sample(all_classes_ids, k=num_shared_classes)
shared_classes = set(sampled_shared_ids)

assert len(shared_classes) == num_shared_classes

In [None]:
# Every class not drawn as shared is a candidate novel class for some task.
non_shared_classes = {class_id for class_id in all_classes_ids if class_id not in shared_classes}

assert len(non_shared_classes) == num_classes - num_shared_classes

### Subdivide the data into tasks, each defined by a different subset of classes

In [None]:
# As many tasks as can each receive num_novel_classes_per_task disjoint non-shared classes;
# the remainder of non-shared classes is not assigned to any task.
num_tasks, _num_unassigned_classes = divmod(num_classes - num_shared_classes, num_novel_classes_per_task)

In [None]:
new_dataset = MyDatasetDict()
global_to_local_class_mappings = {}


def _filter_by_label(split, classes):
    """Return the rows of `split` whose `label_key` value belongs to `classes`."""
    # Only the label column is passed to the predicate, so the other columns (e.g. images)
    # are not formatted/decoded for every row just to test class membership.
    return split.filter(lambda label: label in classes, input_columns=label_key)


def _relabel_to_local(split, class_map):
    """Return `split` with a "fine_label" column set to `class_map[row[label_key]]`."""
    # NOTE(review): the local id is always written to "fine_label". For cifar100 that is
    # label_key itself (overwritten), but for tiny_imagenet (label_key == "label") the global
    # id stays in "label" and the local id lands in a new "fine_label" column — confirm
    # which column downstream consumers read.
    return split.map(lambda label: {"fine_label": class_map[label]}, input_columns=label_key)


# task 0 is a dummy task that consists of the samples for all the classes,
# so its global -> local class mapping is the identity.
new_dataset["task_0_train"] = dataset["train"]
new_dataset["task_0_test"] = dataset["test"]

global_to_local_class_mappings["task_0"] = {class_id: class_id for class_id in all_classes_ids}

# The shared-class samples are identical for every task: compute them once.
shared_train_samples = _filter_by_label(dataset["train"], shared_classes)
shared_test_samples = _filter_by_label(dataset["test"], shared_classes)

# Pool of non-shared classes not yet assigned to a task. It is rebound, never mutated, so
# `non_shared_classes` keeps the full non-shared set for the metadata saved afterwards.
remaining_novel_classes = non_shared_classes

for i in range(1, num_tasks + 1):
    task_novel_classes = set(random.sample(list(remaining_novel_classes), k=num_novel_classes_per_task))

    # remove the classes sampled for this task so that all tasks have disjoint novel classes
    remaining_novel_classes = remaining_novel_classes.difference(task_novel_classes)

    task_classes = shared_classes.union(task_novel_classes)

    # Local ids are assigned following the set's iteration order.
    global_to_local_class_map = {global_id: local_id for local_id, global_id in enumerate(task_classes)}

    novel_train_samples = _filter_by_label(dataset["train"], task_novel_classes)
    task_train_samples = _relabel_to_local(
        concatenate_datasets([shared_train_samples, novel_train_samples]), global_to_local_class_map
    )

    novel_test_samples = _filter_by_label(dataset["test"], task_novel_classes)
    task_test_samples = _relabel_to_local(
        concatenate_datasets([shared_test_samples, novel_test_samples]), global_to_local_class_map
    )

    print(task_classes)
    # Sanity check: assumes every class has exactly the configured number of samples per split.
    assert len(task_train_samples) == num_train_samples_per_class * len(task_classes)
    assert len(task_test_samples) == num_test_samples_per_class * len(task_classes)

    global_to_local_class_mappings[f"task_{i}"] = global_to_local_class_map

    new_dataset[f"task_{i}_train"] = task_train_samples
    new_dataset[f"task_{i}_test"] = task_test_samples

In [None]:
# Everything needed to interpret the task splits; stored in new_dataset under "metadata".
metadata = dict(
    num_train_samples_per_class=num_train_samples_per_class,
    num_test_samples_per_class=num_test_samples_per_class,
    num_shared_classes=num_shared_classes,
    num_novel_classes_per_task=num_novel_classes_per_task,
    num_tasks=num_tasks,
    shared_classes=list(shared_classes),
    non_shared_classes=list(non_shared_classes),
    all_classes=all_classes,
    all_classes_ids=all_classes_ids,
    num_classes=num_classes,
    global_to_local_class_mappings=global_to_local_class_mappings,
)

new_dataset["metadata"] = metadata

In [None]:
# Show the per-task global -> local class-id mappings recorded in the metadata.
task_class_mappings = metadata["global_to_local_class_mappings"]
print(task_class_mappings)

# Save to file

In [None]:
# Output layout: <PROJECT_ROOT>/data/<dataset_name>/S<num shared>_N<num novel per task>
dataset_folder = PROJECT_ROOT / "data" / dataset_name

output_folder = dataset_folder / f"S{num_shared_classes}_N{num_novel_classes_per_task}"

# parents=True also creates PROJECT_ROOT/data and dataset_folder when missing;
# exist_ok=True makes re-running the cell a no-op instead of an exists()/mkdir() race.
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the task splits to output_folder (serialization of the "metadata" entry depends on
# MyDatasetDict.save_to_disk — presumably written alongside; confirm).
new_dataset.save_to_disk(output_folder)