In [11]:
import hashlib
import os
import random
from abc import ABC, abstractmethod
from typing import Any, List, Optional

class Dataset(ABC):
    """Minimal abstract dataset: a name plus ``samples`` / ``labels`` properties.

    Subclasses must implement :meth:`print_properties`.
    """

    name: Optional[str] = None
    # Annotations only: the values live in ``_samples`` / ``_labels`` behind the
    # properties defined below (assigning here would just be shadowed by them).
    samples: List[int]
    labels: List[int]
    # No property backs ``class_names``, so give it a real default rather than a
    # typing object.
    class_names: Optional[List[str]] = None

    def __init__(self, name):
        """Store ``name`` and start with empty samples/labels."""
        self.name = name
        self._samples = []
        self._labels = []

    @property
    def samples(self):
        """The dataset's samples (backed by ``_samples``)."""
        return self._samples

    @samples.setter
    def samples(self, samples):
        self._samples = samples

    @property
    def labels(self):
        """The dataset's labels (backed by ``_labels``)."""
        return self._labels

    @labels.setter
    def labels(self, labels):
        self._labels = labels

    @abstractmethod
    def print_properties(self):
        """Print a human-readable summary of the dataset (subclass-defined)."""
        raise NotImplementedError

    def set_other_stuff(self):
        """Demo mutation: relabel as ``0..len(samples)-1``, then set samples to ``[20, ..., 24]``.

        The new labels are sized from the *old* samples, so afterwards
        ``len(labels)`` need not equal ``len(samples)``.
        """
        self.labels = list(range(len(self.samples)))
        self.samples = list(range(20, 25))


In [30]:
class OtherDataset(Dataset):
    """Toy concrete ``Dataset`` used to exercise the base-class properties.

    Args:
        name: Dataset name, forwarded to ``Dataset.__init__``.
        samples: Initial samples.
        labels: Optional initial labels; defaults to ``[0, 5, 3, 7, 9]`` as before.
    """

    def __init__(self, name, samples, labels=None):
        super().__init__(name)
        self.samples = samples
        # Copy caller-supplied labels so later mutation of their list doesn't leak in.
        self.labels = [0, 5, 3, 7, 9] if labels is None else list(labels)

    def print_properties(self):
        """Print ``name, samples, labels`` on one line."""
        print(f"{self.name}, {self.samples}, {self.labels}")

    def set_other_properties(self):
        """Apply ``set_other_stuff`` and then duplicate ``samples`` into two rows."""
        self.set_other_stuff()
        # Two independent copies: ``[s, s]`` would alias the same list twice, so
        # mutating one row would silently change the other.
        self.samples = [list(self.samples), list(self.samples)]

In [32]:
# Build two toy datasets, show them, then show `x` after the demo mutation.
x = OtherDataset("example", list(range(10)))
y = OtherDataset("example2", list(range(11, 21)))
for toy_dataset in (x, y):
    toy_dataset.print_properties()

x.set_other_properties()
x.print_properties()

example, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 5, 3, 7, 9]
example2, [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [0, 5, 3, 7, 9]
example, [[20, 21, 22, 23, 24], [20, 21, 22, 23, 24]], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [33]:
class BaseDataset(ABC):
    """Abstract in-memory dataset holding parallel ``samples`` / ``labels`` sequences.

    Subclasses fill ``samples`` / ``labels`` (and optionally ``class_names``) and
    implement :meth:`get_classes`; :meth:`subset_data` can then down-sample them.

    Args:
        dataset_name: Name stored as ``_dataset_name``.
        subset_fraction: Fraction of items kept by :meth:`subset_data`; values
            ``>= 1`` keep everything.
        is_training: Stored as ``_is_training`` and as the instance attribute
            ``is_training``.
        dataset_root: Directory subclasses read their files from
            (default ``"data/cifar-10-batches-py"``).
    """

    # Annotations only: values live in the ``_``-prefixed instance attributes behind
    # the properties below (a ``= []`` here would be shadowed and shared anyway).
    samples: List[Any]
    labels: List[int]
    class_names: List[str]
    # Plain class attribute on purpose: ``BaseDataset.is_training`` is used as a
    # default argument value elsewhere, so it must not become a property.
    is_training: bool = False

    def __init__(
        self,
        dataset_name: str,
        subset_fraction: float,
        is_training: bool = False,
        dataset_root: str = "data/cifar-10-batches-py",
    ):
        self._dataset_name = dataset_name
        self._dataset_root = dataset_root
        self._subset_fraction = subset_fraction
        self._is_training = is_training
        # Shadow the class-level default so ``instance.is_training`` reflects the argument.
        self.is_training = is_training
        self._samples = []
        self._labels = []
        self._class_names = []

    @property
    def samples(self):
        """Dataset samples (backed by ``_samples``)."""
        return self._samples

    @samples.setter
    def samples(self, samples):
        self._samples = samples

    @property
    def labels(self):
        """Dataset labels, parallel to ``samples`` (backed by ``_labels``)."""
        return self._labels

    @labels.setter
    def labels(self, labels):
        self._labels = labels

    @property
    def class_names(self):
        """Human-readable class names (backed by ``_class_names``)."""
        return self._class_names

    @class_names.setter
    def class_names(self, class_names):
        self._class_names = class_names

    def __len__(self):
        # Raise (not return) so ``len(ds)`` fails with a clear NotImplementedError
        # instead of a TypeError about a non-integer return value.
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError

    @abstractmethod
    def get_classes(self):
        """Populate class-name metadata (subclass-defined)."""
        raise NotImplementedError

    def subset_data(self, seed: Optional[int] = None):
        """Keep a random ``subset_fraction`` of the items, in place.

        Samples and labels are indexed with the same random indices, so they stay
        aligned. Nothing happens when ``subset_fraction >= 1``.

        Args:
            seed: Optional seed for a private ``random.Random``; when ``None`` the
                module-level ``random`` state is used.
        """
        if self._subset_fraction >= 1:
            return
        n_items = len(self.labels)
        subset_size = int(n_items * self._subset_fraction)
        rng = random.Random(seed) if seed is not None else random
        sampled_indexes = rng.sample(range(n_items), subset_size)
        self.samples = [self.samples[index] for index in sampled_indexes]
        self.labels = [self.labels[index] for index in sampled_indexes]

In [36]:
import pickle
import numpy as np
class CIFAR10Dataset(BaseDataset):
    """CIFAR-10 loaded from the python-pickle batch files under ``self._dataset_root``.

    After construction:

    * ``samples`` is a NumPy array reshaped to ``(N, 3, 32, 32)`` and transposed
      to ``(N, 32, 32, 3)``, i.e. height/width/channel order.
    * ``labels`` is a list parallel to ``samples``.
    * ``class_names`` and ``class_to_idx`` are filled from the metadata file.

    Only the split selected by ``is_training`` is loaded (train batches when
    ``True``, the test batch otherwise). When ``subset_fraction < 1``, the loaded
    split is randomly down-sampled.
    """

    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]
    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        dataset_name: str,
        subset_fraction: float,
        is_training: bool = BaseDataset.is_training,
    ):
        super().__init__(
            dataset_name=dataset_name,
            subset_fraction=subset_fraction,
            is_training=is_training,
        )

        # Load only the requested split (previously train + test were always merged).
        downloaded_list = self.train_list if is_training else self.test_list

        imgs = []
        for file_name, _md5 in downloaded_list:
            file_path = os.path.join(self._dataset_root, file_name)
            # SECURITY: pickle.load can execute arbitrary code — only point
            # dataset_root at files from a trusted source.
            with open(file_path, "rb") as f:
                entry = pickle.load(f, encoding="latin1")
            imgs.append(entry["data"])
            # Fall back to "fine_labels" when "labels" is absent (presumably for
            # CIFAR-100-style batch files — confirm if that case matters).
            if "labels" in entry:
                self._labels.extend(entry["labels"])
            else:
                self._labels.extend(entry["fine_labels"])

        # Rows are reshaped to (N, 3, 32, 32) then transposed to HWC.
        self.samples = np.vstack(imgs).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))

        self.get_classes()
        if subset_fraction < 1.0:
            self.subset_data()

    @staticmethod
    def _check_integrity(path: str, md5: str) -> bool:
        """Return True if ``path`` is an existing file whose MD5 hex digest equals ``md5``."""
        if not os.path.isfile(path):
            return False
        digest = hashlib.md5()
        with open(path, "rb") as fh:
            for chunk in iter(lambda: fh.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest() == md5

    def subset_data(self, seed=None):
        """Keep a random ``subset_fraction`` of samples/labels without converting ``samples`` to a list.

        Args:
            seed: Optional seed for ``np.random.default_rng``.
        """
        if self._subset_fraction >= 1:
            return
        n_items = len(self._labels)
        subset_size = int(n_items * self._subset_fraction)
        rng = np.random.default_rng(seed)
        indexes = rng.choice(n_items, size=subset_size, replace=False)
        # Fancy indexing keeps samples as an ndarray; labels stay a parallel list.
        self.samples = self.samples[indexes]
        self.labels = [self._labels[i] for i in indexes]

    def get_classes(self) -> None:
        """Load class names from the metadata file and build ``class_to_idx``.

        Raises:
            RuntimeError: if the metadata file is missing or its MD5 does not match.
        """
        path = os.path.join(self._dataset_root, self.meta["filename"])
        if not self._check_integrity(path, self.meta["md5"]):
            raise RuntimeError(f"Dataset metadata file not found or corrupted: {path}")
        # SECURITY: pickle.load on a trusted, checksum-verified file only.
        with open(path, "rb") as infile:
            meta_entry = pickle.load(infile, encoding="latin1")
        # Write the backing field directly (as __init__ does for _labels).
        self._class_names = meta_entry[self.meta["key"]]
        self.class_to_idx = {_class: i for i, _class in enumerate(self._class_names)}

In [37]:
# Build the CIFAR-10 dataset object from the batch files on disk.
data = CIFAR10Dataset(
    dataset_name="CIFAR10",
    subset_fraction=0.5,
    is_training=True,
)

In [39]:
# Number of labels held by the dataset, shown as the cell's output.
len(data.labels)

, 5, 0, 0, 7, 0, 9, 6, 1, 0, 3, 9, 7, 4, 9, 1, 6, 8, 1, 2, 3, 3, 5, 4, 8, 9, 7, 4, 4, 1, 2, 4, 9, 8, 7, 9, 5, 1, 2, 1, 6, 6, 4, 5, 7, 4, 5, 8, 5, 2, 8, 7, 8, 2, 3, 6, 1, 3, 3, 1, 5, 1, 9, 0, 9, 2, 0, 6, 2, 4, 8, 5, 7, 6, 1, 2, 9, 4, 5, 0, 3, 3, 7, 7, 7, 1, 4, 5, 0, 2, 8, 5, 0, 0, 6, 2, 0, 8, 4, 5, 4, 5, 6, 4, 7, 9, 4, 2, 0, 6, 4, 0, 0, 6, 4, 6, 1, 9, 5, 5, 2, 2, 6, 3, 4, 5, 9, 1, 7, 2, 3, 9, 6, 5, 0, 2, 9, 7, 1, 7, 2, 2, 0, 8, 6, 4, 3, 2, 7, 7, 0, 4, 1, 6, 5, 1, 3, 0, 3, 9, 0, 0, 2, 5, 0, 4, 0, 1, 9, 8, 4, 9, 4, 2, 4, 3, 3, 4, 0, 4, 3, 2, 8, 9, 1, 5, 8, 1, 8, 2, 4, 5, 2, 4, 1, 1, 6, 6, 8, 5, 2, 2, 5, 0, 8, 2, 3, 6, 2, 9, 6, 1, 4, 5, 9, 0, 1, 0, 0, 8, 1, 1, 6, 6, 9, 5, 4, 1, 7, 8, 6, 9, 1, 7, 6, 0, 9, 3, 5, 3, 2, 5, 3, 4, 9, 7, 1, 4, 4, 6, 1, 3, 8, 8, 0, 6, 7, 7, 6, 7, 2, 3, 2, 2, 6, 2, 7, 4, 0, 3, 6, 2, 6, 3, 3, 0, 9, 5, 1, 1, 5, 3, 6, 4, 3, 4, 1, 0, 4, 5, 5, 2, 8, 9, 4, 3, 1, 8, 0, 1, 3, 3, 4, 4, 2, 9, 7, 6, 8, 1, 8, 9, 1, 3, 1, 7, 3, 0, 0, 2, 8, 3, 9, 2, 7, 2, 6, 0, 1, 6, 1, 6, 7, 5,