# MNIST dataset handling
---

In this notebook we will develop a custom `dataset` class which will be able to:
- import the MNIST dataset from a **url**
- **read** the MNIST dataset and **store** it in a `torch.tensor`
- **save** the dataset in `.pt` format to be easily accessible within the `PyTorch` environment
- provide a method to create the dataset **splits**, according to some proportions
- provide a method to perform some **preprocessing** operations
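The notebook does not show the preprocessing step itself. A minimal sketch, assuming a common choice (scale the `uint8` pixels to `[0, 1]`, then standardize with the training-set mean and standard deviation); the `preprocess` function and the random placeholder tensors are hypothetical:

```python
from typing import Optional, Tuple

import torch


def preprocess(
    images: torch.Tensor,
    mean: Optional[float] = None,
    std: Optional[float] = None,
) -> Tuple[torch.Tensor, float, float]:
    """Scale uint8 images to [0, 1], then standardize to zero mean and unit std."""
    x = images.float() / 255.0
    # statistics default to the ones of this tensor; for validation/test data,
    # pass the training-set mean and std so every split is transformed identically
    if mean is None:
        mean = x.mean().item()
    if std is None:
        std = x.std().item()
    return (x - mean) / std, mean, std


# placeholder data standing in for the decoded MNIST images
train_images = torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8)
test_images = torch.randint(0, 256, (20, 28, 28), dtype=torch.uint8)

train_x, mean, std = preprocess(train_images)
test_x, _, _ = preprocess(test_images, mean, std)  # reuse the training statistics
```

Computing the statistics on the training split only and reusing them elsewhere avoids leaking information from the test set.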

We will proceed as follows:
- file decoding procedure
    - analysis of the MNIST dataset format (info taken from this [source](http://yann.lecun.com/exdb/mnist/))
    - download the files from the sources
    - read the files and retrieve the data dimensions and type
    - store the data into a `torch.tensor` and save it to disk in `.pt` format
- `dataset` class implementation
    - define a constructor `__init__`
    - provide a method `splits` to split the dataset
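The decoding steps above can be sketched with the standard library alone. The header layout (two zero bytes, a type code where `0x08` means unsigned byte, the number of dimensions, then one 4-byte big-endian size per dimension) follows the format description on the MNIST page; the names `parse_idx` and `read_idx_gz` are hypothetical and not necessarily how `utils.store_file_to_tensor` is implemented.

```python
import gzip
import struct
from array import array

# IDX type code for unsigned bytes, the only type used by the MNIST files
UNSIGNED_BYTE = 0x08


def parse_idx(raw: bytes):
    """Decode an uncompressed IDX buffer into (dims, flat array of values)."""
    # magic number: two zero bytes, the data type code, the number of dimensions
    zero, dtype, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or dtype != UNSIGNED_BYTE:
        raise ValueError("unsupported IDX header")
    # each dimension size is a 4-byte big-endian unsigned integer
    dims = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    offset = 4 + 4 * ndim
    count = 1
    for d in dims:
        count *= d
    # one unsigned byte per value, stored in row-major order
    values = array("B", raw[offset:offset + count])
    if len(values) != count:
        raise ValueError("truncated IDX payload")
    return dims, values


def read_idx_gz(path: str):
    """Read a gzip-compressed IDX file such as train-images-idx3-ubyte.gz."""
    with gzip.open(path, "rb") as f:
        return parse_idx(f.read())
```

For the training images `dims` should come out as `(60000, 28, 28)`; the flat values can then be turned into a tensor with `torch.tensor(list(values), dtype=torch.uint8).view(*dims)`.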

## Dataset class construction

> At the heart of PyTorch data loading utility is the `torch.utils.data.DataLoader` class. It represents a Python iterable over a dataset.

The `DataLoader` class takes several arguments, but the two most important are:
- `dataset`: an instance of `Dataset`, an abstract class representing a dataset. All subclasses of `Dataset` should override:
    - `__getitem__()`: returns a fetched data sample for a given key
    - `__len__()`: returns the size of the dataset (optional in principle, but expected by many samplers and by the default `DataLoader` options)
- `sampler` or `batch_sampler`: they define the strategy to draw samples or batches of samples from the dataset.
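A minimal sketch of how the pieces fit together: a toy in-memory `Dataset` (the class `ToyImages` and its random data are hypothetical placeholders for MNIST) overrides the two methods, and a `DataLoader` combines it with a `SubsetRandomSampler` that draws only the first 8 indices in random order.

```python
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler


class ToyImages(Dataset):
    """Hypothetical in-memory dataset: n random 28x28 'images' with labels 0-9."""

    def __init__(self, n: int = 10) -> None:
        self.data = torch.randint(0, 256, (n, 28, 28), dtype=torch.uint8)
        self.labels = torch.arange(n) % 10

    def __len__(self) -> int:
        # number of samples, used by len() and by samplers
        return len(self.data)

    def __getitem__(self, idx: int) -> tuple:
        # one (image, label) pair for a given integer key
        return self.data[idx], int(self.labels[idx])


dataset = ToyImages(10)
# the sampler decides which indices are drawn and in which order
sampler = SubsetRandomSampler(list(range(8)))
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for images, labels in loader:
    # default collation stacks the images and turns the int labels into a tensor
    print(images.shape, labels.shape)  # torch.Size([4, 28, 28]) torch.Size([4])
```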

So, the first step is to develop a subclass of `torch.utils.data.Dataset` which overrides the aforementioned methods and implements the ones discussed in the `file_decoding_procedure` notebook.


### Class constructor

The class constructor must be able to:
- download the dataset if requested and read it
- create the tensors and handle loading and storing them
- leave the dataset empty if requested (this will be useful to create splits)

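### Overriding

Overriding the `__len__()` and `__getitem__()` methods is straightforward: `__len__()` returns the number of images (0 for an empty dataset) and `__getitem__()` returns the `(image, label)` pair at the given index.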

### Save the dataset

We can reuse the `save` function developed earlier for this task: it writes the `(data, labels)` tuple as a `.pt` file in the processed-data folder.
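Under the hood, saving and loading reduce to a `torch.save`/`torch.load` round trip on a tuple of tensors. A minimal sketch in a temporary folder that mirrors the `data/processed/training.pt` layout; the random tensors are placeholders for the decoded MNIST data.

```python
import os
import tempfile

import torch

# placeholder tensors standing in for the decoded MNIST images and labels
data = torch.randint(0, 256, (3, 28, 28), dtype=torch.uint8)
labels = torch.tensor([5, 0, 4])

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "data", "processed", "training.pt")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # torch.save serializes the whole (images, labels) tuple into a single file
    torch.save((data, labels), path)
    # torch.load restores the tuple, which can be unpacked directly
    loaded_data, loaded_labels = torch.load(path)

print(torch.equal(loaded_data, data), torch.equal(loaded_labels, labels))  # True True
```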



In [12]:
import copy
import math
import os

import torch

import utils   # local helper module providing download() and store_file_to_tensor()

class MNIST(torch.utils.data.Dataset):

    def __init__(
          self
        , folder: str
        , train: bool
        , download: bool=False
        , empty: bool=False
        ) -> None:
        """
        Class constructor.

        Args:
            folder (str): folder which contains/will contain the data
            train (bool): if True the training dataset is built, otherwise the test dataset
            download (bool): if True the dataset will be downloaded (default = False)
            empty (bool): if True the tensors will be left empty (default = False)
        """

        # user folder check
        # ------------------------
        if folder is None:
            raise FileNotFoundError("Please specify the data folder")
        if not os.path.exists(folder) or os.path.isfile(folder):
            raise FileNotFoundError("Invalid data path: {}".format(folder))
        # ------------------------

        # utilities
        # ------------------------
        if train:
            urls = {
                'training-images': 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
                , 'training-labels': 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
            }
            self.save_file = 'training.pt'
        else:
            urls = {
                'test-images': 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
                , 'test-labels': 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
            }
            self.save_file = 'test.pt'
        # ------------------------

        # class members
        # ------------------------
        self.raw_folder = os.path.join(folder, "data/raw")
        self.processed_folder = os.path.join(folder, "data/processed")
        # ------------------------

        # dataset download
        # ------------------------
        if download:
            for name, url in urls.items():
                utils.download(url, self.raw_folder, name)
        # ------------------------

        # dataset folder check
        # ------------------------
        else:   # not download
            if not os.path.exists(self.raw_folder) or os.path.isfile(self.raw_folder):
                raise FileNotFoundError("Invalid data path: {}".format(self.raw_folder))
        # ------------------------

        # data storing
        # ------------------------
        if not empty:
            for name in urls:
                filepath = os.path.join(self.raw_folder, name)
                if "images" in name:
                    self.data = utils.store_file_to_tensor(filepath)
                elif "labels" in name:
                    self.labels = utils.store_file_to_tensor(filepath)
            self.save()

        else:
            self.data = None
            self.labels = None
        # ------------------------


    def __len__(self) -> int:
        """
        Return the length of the dataset.

        Returns:
            length of the dataset (int)
        """
        return len(self.data) if self.data is not None else 0


    def __getitem__(
          self
        , idx: int
        ) -> tuple:
        """
        Retrieve the item of the dataset at index idx.

        Args:
            idx (int): index of the item to be retrieved.

        Returns:
            tuple: (image, label)
        """
        img, label = self.data[idx], int(self.labels[idx])

        return (img, label)


    def save(self) -> None:
        """
        Save the dataset (tuple of torch.tensors) into a file defined by self.processed_folder and self.save_file.
        """
        os.makedirs(self.processed_folder, exist_ok=True)

        # saving the (data, labels) tuple into the processed folder
        with open(os.path.join(self.processed_folder, self.save_file), 'wb') as f:
            torch.save((self.data, self.labels), f)


    def load(self) -> None:
        """
        Load the .pt file in the path defined by self.processed_folder and self.save_file.
        """
        file_path = os.path.join(self.processed_folder, self.save_file)

        if not os.path.exists(file_path):
            raise FileNotFoundError("File not present: {}".format(file_path))

        self.data, self.labels = torch.load(file_path)


    def splits(
          self
        , proportions: tuple=(0.8, 0.2)
        , shuffle: bool=True
        ) -> tuple:
        """
        Split the dataset according to the given proportions and return two instances of MNIST, training and validation.

        Args:
            proportions (tuple): (default=(0.8, 0.2)) proportions for training set and validation set.
            shuffle (bool): (default=True) whether to shuffle the dataset or not

        Returns:
            tuple: (training split, validation split), both MNIST instances
        """

        # check proportions: exactly 2, all positive, summing to 1 (up to floating-point error)
        if not (len(proportions) == 2
                and all(p > 0. for p in proportions)
                and math.isclose(sum(proportions), 1.)):
            raise ValueError("Invalid proportions: they must (1) be 2 (2) sum up to 1 (3) be all positive.")

        # creating data indices for training and validation splits
        length = len(self)
        if length == 0:
            raise ValueError("Dataset must NOT be empty!")
        indices = torch.randperm(length) if shuffle else torch.arange(length)
        split = int(proportions[0] * length)   # floor, since the product is non-negative

        training_indices, validation_indices = indices[:split], indices[split:]

        # building one MNIST instance per split: a shallow copy keeps the folders
        # (note: also the parent's save_file), then data and labels are replaced
        # by the selected subset
        datasets = []
        for split_indices in (training_indices, validation_indices):
            subset = copy.copy(self)
            subset.data = self.data[split_indices]
            subset.labels = self.labels[split_indices]
            datasets.append(subset)

        return tuple(datasets)





In [13]:
path = "./"
# prova = MNIST(path, train=True, download=True, empty=False)
prova = MNIST(path, train=True, download=False, empty=True)
prova.load()
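A per-class (stratified) variant of the split keeps the digit proportions the same in the training and validation sets. A minimal sketch working on the labels tensor alone, assuming the labels are a 1-D integer tensor; `stratified_split_indices` is a hypothetical helper, not part of the class above.

```python
import torch


def stratified_split_indices(
    labels: torch.Tensor,
    train_fraction: float = 0.8,
    shuffle: bool = True,
    seed: int = 0,
):
    """Return (train_idx, val_idx) with each class split by the same fraction."""
    generator = torch.Generator().manual_seed(seed)
    train_parts, val_parts = [], []
    for c in torch.unique(labels):
        # indices of all samples belonging to class c
        idx = torch.nonzero(labels == c, as_tuple=True)[0]
        if shuffle:
            idx = idx[torch.randperm(len(idx), generator=generator)]
        # floor of the requested fraction goes to training, the rest to validation
        n_train = int(train_fraction * len(idx))
        train_parts.append(idx[:n_train])
        val_parts.append(idx[n_train:])
    return torch.cat(train_parts), torch.cat(val_parts)
```

The returned index tensors can be used to slice `data` and `labels` directly, or wrapped with `torch.utils.data.Subset(dataset, train_idx.tolist())` to obtain a dataset view without copying the tensors.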