In [None]:
import random
from pathlib import Path

from datasets import load_dataset, DatasetDict, load_from_disk, Dataset, concatenate_datasets
from la.utils.utils import MyDatasetDict
from nn_core.common import PROJECT_ROOT

# Subdivide CIFAR100 into tasks

### Params

In [None]:
# Number of classes that are part of every task.
num_shared_classes = 80
# Number of classes that are specific to each task (novel classes are disjoint across tasks).
num_novel_classes_per_task = 5

# Seed the global RNG so that the random class sampling performed below yields the
# same split on every Restart & Run All; change the value to draw a different split.
random_seed = 42
random.seed(random_seed)

### Data loading

In [None]:
# Directory of the saved CIFAR100 dataset, located under <project root>/data/encoded_data.
DATASET_DIR: Path = PROJECT_ROOT / "data" / "encoded_data" / "cifar100_train_test_1"

# The loaded object is indexed by split name ("train" / "test") in the cells below,
# so it is annotated as a DatasetDict rather than a single Dataset.
dataset: DatasetDict = load_from_disk(dataset_path=str(DATASET_DIR))
dataset

## Add ids

In [None]:
# Attach an "id" column to both splits. With batched=True and with_indices=True the
# callback receives a batch plus the list of row indices of that batch, and the
# indices are stored as the ids.
for split_name in ("train", "test"):
    dataset[split_name] = dataset[split_name].map(
        lambda batch, indices: {"id": indices},
        batched=True,
        with_indices=True,
    )

In [None]:
# Expected number of samples per class in the train and test splits respectively;
# used further down to sanity-check the size of every generated task.
num_train_samples_per_class, num_test_samples_per_class = 500, 100

In [None]:
# Class names as declared by the "fine_label" feature of the train split.
all_classes = dataset["train"].features["fine_label"].names
num_classes = len(all_classes)

# Global class ids are the positions of the names in `all_classes`.
all_classes_ids = list(range(num_classes))
# Lookup table from class name to global class id.
class_str_to_id = dict(zip(all_classes, all_classes_ids))

print(f"{num_classes} classes in total")

### Sample shared classes

In [None]:
# Randomly pick (without replacement) the global class ids that all tasks will share.
sampled_shared_ids = random.sample(all_classes_ids, k=num_shared_classes)
shared_classes = set(sampled_shared_ids)

# Sampling is without replacement, so no id may have been collapsed by the set.
assert len(shared_classes) == num_shared_classes

In [None]:
# Every class id that was not picked as shared is a candidate novel class.
non_shared_classes = set(all_classes_ids) - shared_classes

assert len(non_shared_classes) == num_classes - num_shared_classes

### Subdivide data into tasks defined by different class subsets

In [None]:
# Each task receives its own group of `num_novel_classes_per_task` novel classes, so
# the number of tasks is how many complete groups fit into the non-shared classes
# (floor division: a remainder smaller than one group is not turned into a task).
num_non_shared_classes = num_classes - num_shared_classes
num_tasks = num_non_shared_classes // num_novel_classes_per_task

In [None]:
# Container for all task splits, plus the per-task mapping from global class ids
# (ids in the full dataset) to local class ids (0 .. number of task classes - 1).
new_dataset = MyDatasetDict()
global_to_local_class_mappings = {}

# task 0 is a dummy task that consists of the samples for all the classes
new_dataset["task_0_train"] = dataset["train"]
new_dataset["task_0_test"] = dataset["test"]

# For task 0 every global class id is mapped to its own position in `all_classes`.
global_to_local_class_mappings["task_0"] = {class_str_to_id[c]: i for i, c in enumerate(all_classes)}

# Samples of the shared classes belong to every task, so they are filtered only once.
shared_train_samples = dataset["train"].filter(lambda x: x["fine_label"] in shared_classes)
shared_test_samples = dataset["test"].filter(lambda x: x["fine_label"] in shared_classes)

# Draw novel classes from a working copy so that `non_shared_classes` itself is left
# untouched: this keeps the cell re-runnable and keeps the variable meaningful for
# any later cell that reads it.
remaining_novel_classes = set(non_shared_classes)

for i in range(1, num_tasks + 1):
    # Sort before sampling so the draw depends only on the RNG state and not on the
    # iteration order of the set.
    task_novel_classes = set(random.sample(sorted(remaining_novel_classes), k=num_novel_classes_per_task))

    # remove the classes sampled for this task so that all tasks have disjoint novel classes
    remaining_novel_classes -= task_novel_classes

    task_classes = shared_classes.union(task_novel_classes)

    # Local ids are assigned in ascending order of the global class id, which makes
    # the mapping independent of set iteration order.
    global_to_local_class_map = {c: i for i, c in enumerate(sorted(task_classes))}

    novel_train_samples = dataset["train"].filter(lambda x: x["fine_label"] in task_novel_classes)

    task_train_samples = concatenate_datasets([shared_train_samples, novel_train_samples])

    # Rewrite the labels from global ids to the task-local ids.
    # NOTE(review): the "fine_label" feature presumably still carries the class names
    # of the full dataset after this remapping - confirm that consumers rely on
    # `global_to_local_class_mappings` rather than on the feature's names.
    task_train_samples = task_train_samples.map(
        lambda row: {"fine_label": global_to_local_class_map[row["fine_label"]]}
    )

    novel_test_samples = dataset["test"].filter(lambda x: x["fine_label"] in task_novel_classes)

    task_test_samples = concatenate_datasets([shared_test_samples, novel_test_samples])

    task_test_samples = task_test_samples.map(lambda row: {"fine_label": global_to_local_class_map[row["fine_label"]]})

    # Every class is expected to contribute a fixed number of samples per split.
    assert len(task_train_samples) == num_train_samples_per_class * len(task_classes)
    assert len(task_test_samples) == num_test_samples_per_class * len(task_classes)

    global_to_local_class_mappings[f"task_{i}"] = global_to_local_class_map

    new_dataset[f"task_{i}_train"] = task_train_samples
    new_dataset[f"task_{i}_test"] = task_test_samples

In [None]:
# Description of how the tasks were generated; stored together with the task splits.
metadata = {
    "num_train_samples_per_class": num_train_samples_per_class,
    "num_test_samples_per_class": num_test_samples_per_class,
    "num_shared_classes": num_shared_classes,
    "num_novel_classes_per_task": num_novel_classes_per_task,
    "num_tasks": num_tasks,
    "shared_classes": list(shared_classes),
    # Derived directly from `all_classes_ids` and `shared_classes` so that the stored
    # list always holds every non-shared class, regardless of any later modification
    # of the `non_shared_classes` variable.
    "non_shared_classes": [c for c in all_classes_ids if c not in shared_classes],
    "all_classes": all_classes,
    "all_classes_ids": all_classes_ids,
    "num_classes": num_classes,
    "global_to_local_class_mappings": global_to_local_class_mappings,
}

new_dataset["metadata"] = metadata

# Persist the task splits and their metadata under <project root>/data/cifar100_tasks.
output_folder = PROJECT_ROOT / "data" / "cifar100_tasks"
new_dataset.save_to_disk(output_folder)

In [None]:
# Show the global -> local class id mapping that was recorded for every task.
saved_class_mappings = metadata["global_to_local_class_mappings"]
print(saved_class_mappings)