In [1]:
import random
from collections import namedtuple
from pathlib import Path
from pytorch_lightning import seed_everything

from datasets import load_dataset, DatasetDict, load_from_disk, Dataset, concatenate_datasets
from la.utils.utils import MyDatasetDict
from nn_core.common import PROJECT_ROOT

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# When True and DATASET_DIR already exists, the dataset is read back with load_from_disk
# instead of being downloaded again.
USE_CACHED: bool = True

In [3]:
def get_dataset(name: str, split: str, perc: float, seed: int = 42):
    """
    Load one split of a dataset from the HuggingFace datasets library, optionally keeping a random subset.

    Args:
        name: dataset identifier passed to `load_dataset`.
        split: name of the split to load (e.g. "train" or "test").
        perc: fraction of the split to keep, in (0, 1]. With 1 the split is returned as loaded (not shuffled).
        seed: seed passed to `seed_everything` and to the shuffle used when subsampling.

    Returns:
        The loaded split or, when `perc < 1`, its first `int(len(split) * perc)` rows after shuffling.

    Raises:
        AssertionError: if `perc` is outside (0, 1].
    """
    assert 0 < perc <= 1, f"perc must be in (0, 1], got {perc!r}"
    dataset = load_dataset(
        name,
        split=split,
        use_auth_token=True,
    )
    # NOTE: seeds the global RNGs as a side effect; any later `random.sample` calls in the notebook
    # draw from the state set here.
    seed_everything(seed)

    # Select a random subset
    if perc != 1:
        num_samples = int(len(dataset) * perc)
        # `select` accepts a range directly, so there is no need to materialize a list of indices
        dataset = dataset.shuffle(seed=seed).select(range(num_samples))

    return dataset

In [4]:
# Parameters describing which dataset/splits to load; every non-None field except `hf_key`
# contributes to the cache-directory key.
DatasetParams = namedtuple("DatasetParams", "name fine_grained train_split test_split perc hf_key")

In [5]:
# Full CIFAR-100 (perc=1) with its standard train/test splits.
dataset_params: DatasetParams = DatasetParams(
    name="cifar100",
    fine_grained=None,
    train_split="train",
    test_split="test",
    perc=1,
    hf_key=("cifar100",),
)
dataset_params

DatasetParams(name='cifar100', fine_grained=None, train_split='train', test_split='test', perc=1, hf_key=('cifar100',))

In [6]:
# Join every non-None parameter except `hf_key` with "_" (here: "cifar100_train_test_1").
DATASET_KEY = "_".join(
    str(value)
    for field, value in dataset_params._asdict().items()
    if field != "hf_key" and value is not None
)
DATASET_DIR: Path = PROJECT_ROOT / "data" / "encoded_data" / DATASET_KEY
DATASET_DIR

PosixPath('/home/donato/PycharmProjects/latent-aggregation/data/encoded_data/cifar100_train_test_1')

In [7]:
# Both branches produce a DatasetDict with "train" and "test" splits (the previous `Dataset`
# annotation on the cached branch was wrong).
dataset: DatasetDict
if DATASET_DIR.exists() and USE_CACHED:
    # reuse the copy previously saved under DATASET_DIR
    dataset = load_from_disk(dataset_path=str(DATASET_DIR))
    # the following cells index dataset["train"] / dataset["test"], so fail early on anything else
    assert isinstance(dataset, DatasetDict), f"expected a DatasetDict at {DATASET_DIR}, got {type(dataset).__name__}"
else:
    dataset = DatasetDict(
        train=get_dataset(name=dataset_params.name, split=dataset_params.train_split, perc=dataset_params.perc),
        test=get_dataset(name=dataset_params.name, split=dataset_params.test_split, perc=dataset_params.perc),
    )

dataset

DatasetDict({
    train: Dataset({
        features: ['img', 'fine_label', 'coarse_label'],
        num_rows: 50000
    })
    test: Dataset({
        features: ['img', 'fine_label', 'coarse_label'],
        num_rows: 10000
    })
})

# Subdivide CIFAR-100 into tasks

### Params

In [8]:
# Every task contains the `num_shared_classes` shared classes plus
# `num_novel_classes_per_task` classes that no other task uses.
num_shared_classes: int = 80
num_novel_classes_per_task: int = 5

## Add per-split sample ids

In [9]:
# Store each example's position within its own split as an "id" column (ids restart at 0 in each split).
for split_name in ("train", "test"):
    dataset[split_name] = dataset[split_name].map(
        lambda batch, indices: {"id": indices},
        batched=True,
        with_indices=True,
    )

In [10]:
# Expected number of examples per fine class in each split (50000 / 100 and 10000 / 100 for the splits loaded above).
num_train_samples_per_class: int = 500
num_test_samples_per_class: int = 100

In [11]:
# Fine-grained class names, their integer ids (0..num_classes-1), and a name -> id lookup.
all_classes = dataset["train"].features["fine_label"].names
num_classes = len(all_classes)
all_classes_ids = list(range(num_classes))
class_str_to_id = {class_name: class_id for class_id, class_name in enumerate(all_classes)}
print(f"{num_classes} classes in total")

100 classes in total


### Sample shared classes

In [12]:
# Seed right before sampling. When the cached dataset is loaded from disk, get_dataset (and its
# seed_everything call) never runs, so without this the shared classes — and the novel classes
# sampled per task later, which draw from the same `random` state — would differ between runs.
CLASS_SPLIT_SEED = 42
random.seed(CLASS_SPLIT_SEED)
shared_classes = set(random.sample(all_classes_ids, k=num_shared_classes))

assert len(shared_classes) == num_shared_classes

In [13]:
# Classes outside the shared pool; the tasks draw their novel classes from these.
non_shared_classes = {class_id for class_id in all_classes_ids if class_id not in shared_classes}

assert len(non_shared_classes) == num_classes - num_shared_classes

### Subdivide the data into tasks, each defined by a different subset of classes

In [14]:
# Each task consumes `num_novel_classes_per_task` of the non-shared classes and no class is reused.
num_non_shared_classes = num_classes - num_shared_classes
num_tasks = num_non_shared_classes // num_novel_classes_per_task

In [15]:
def _build_task_split(split: str, shared_samples, novel_classes: set, label_map: dict):
    """
    Build one split of a task: `shared_samples` followed by the samples of `dataset[split]` whose
    `fine_label` is in `novel_classes`, with `fine_label` remapped from global to task-local ids.

    Args:
        split: key of the split in `dataset` ("train" or "test").
        shared_samples: the already-filtered shared-class samples of that split.
        novel_classes: global label ids of the task's novel classes.
        label_map: global label id -> task-local label id.

    Returns:
        The concatenated, relabeled split.
    """
    novel_samples = dataset[split].filter(lambda x: x["fine_label"] in novel_classes)
    task_samples = concatenate_datasets([shared_samples, novel_samples])
    return task_samples.map(lambda row: {"fine_label": label_map[row["fine_label"]]})


new_dataset = MyDatasetDict()
global_to_local_class_mappings = {}

# task 0 is a dummy task that consists of the samples for all the classes (identity label mapping)
new_dataset["task_0_train"] = dataset["train"]
new_dataset["task_0_test"] = dataset["test"]

global_to_local_class_mappings["task_0"] = {class_str_to_id[c]: i for i, c in enumerate(all_classes)}

# the shared-class samples are the same for every task, so they are filtered only once
shared_train_samples = dataset["train"].filter(lambda x: x["fine_label"] in shared_classes)
shared_test_samples = dataset["test"].filter(lambda x: x["fine_label"] in shared_classes)

# Pool of non-shared classes not yet assigned to a task. `.difference` returns a new set, so shrinking
# the pool leaves `non_shared_classes` itself intact (it is saved in the metadata later).
available_novel_classes = non_shared_classes

for task_idx in range(1, num_tasks + 1):
    task_novel_classes = set(random.sample(list(available_novel_classes), k=num_novel_classes_per_task))

    # remove the classes sampled for this task so that all tasks have disjoint novel classes
    available_novel_classes = available_novel_classes.difference(task_novel_classes)

    task_classes = shared_classes.union(task_novel_classes)

    # local labels 0..len(task_classes)-1 follow the ascending order of the global labels
    global_to_local_class_map = {
        global_label: local_label for local_label, global_label in enumerate(sorted(task_classes))
    }

    task_train_samples = _build_task_split(
        "train", shared_train_samples, task_novel_classes, global_to_local_class_map
    )
    task_test_samples = _build_task_split(
        "test", shared_test_samples, task_novel_classes, global_to_local_class_map
    )

    assert len(task_train_samples) == num_train_samples_per_class * len(task_classes)
    assert len(task_test_samples) == num_test_samples_per_class * len(task_classes)

    global_to_local_class_mappings[f"task_{task_idx}"] = global_to_local_class_map

    new_dataset[f"task_{task_idx}_train"] = task_train_samples
    new_dataset[f"task_{task_idx}_test"] = task_test_samples

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/42500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/42500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/42500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/42500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/8500 [00:00<?, ? examples/s]

In [18]:
metadata = {
    "num_train_samples_per_class": num_train_samples_per_class,
    "num_test_samples_per_class": num_test_samples_per_class,
    "num_shared_classes": num_shared_classes,
    "num_novel_classes_per_task": num_novel_classes_per_task,
    "num_tasks": num_tasks,
    # sorted so the saved lists are deterministic and easy to read
    "shared_classes": sorted(shared_classes),
    # Derived from the full class list (not from `non_shared_classes`) so that it always holds every class
    # outside `shared_classes`, whatever state the task-sampling cell left that variable in.
    "non_shared_classes": sorted(c for c in all_classes_ids if c not in shared_classes),
    "all_classes": all_classes,
    "all_classes_ids": all_classes_ids,
    "num_classes": num_classes,
    "global_to_local_class_mappings": global_to_local_class_mappings,
}

new_dataset["metadata"] = metadata

output_folder = PROJECT_ROOT / "data" / "cifar100_tasks"
new_dataset.save_to_disk(output_folder)

Saving the dataset (0/1 shards):   0%|          | 0/50000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/42500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/42500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/42500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/42500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8500 [00:00<?, ? examples/s]

In [17]:
# Inspect the global -> local label mapping recorded for each task.
task_label_mappings = metadata["global_to_local_class_mappings"]
print(task_label_mappings)


{'task_0': {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16, 17: 17, 18: 18, 19: 19, 20: 20, 21: 21, 22: 22, 23: 23, 24: 24, 25: 25, 26: 26, 27: 27, 28: 28, 29: 29, 30: 30, 31: 31, 32: 32, 33: 33, 34: 34, 35: 35, 36: 36, 37: 37, 38: 38, 39: 39, 40: 40, 41: 41, 42: 42, 43: 43, 44: 44, 45: 45, 46: 46, 47: 47, 48: 48, 49: 49, 50: 50, 51: 51, 52: 52, 53: 53, 54: 54, 55: 55, 56: 56, 57: 57, 58: 58, 59: 59, 60: 60, 61: 61, 62: 62, 63: 63, 64: 64, 65: 65, 66: 66, 67: 67, 68: 68, 69: 69, 70: 70, 71: 71, 72: 72, 73: 73, 74: 74, 75: 75, 76: 76, 77: 77, 78: 78, 79: 79, 80: 80, 81: 81, 82: 82, 83: 83, 84: 84, 85: 85, 86: 86, 87: 87, 88: 88, 89: 89, 90: 90, 91: 91, 92: 92, 93: 93, 94: 94, 95: 95, 96: 96, 97: 97, 98: 98, 99: 99}, 'task_1': {0: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7, 9: 8, 10: 9, 11: 10, 12: 11, 13: 12, 14: 13, 17: 14, 18: 15, 19: 16, 20: 17, 22: 18, 23: 19, 24: 20, 25: 21, 27: 22, 28: 23, 29: 24, 31: 25, 32: 26,