# MNIST dataset handling
---

In this notebook we will develop a custom `dataset` class which will be able to:
- import the MNIST dataset from a **url**
- **read** the MNIST dataset and **load** it in a `torch.tensor`
- **save** the dataset in `.pt` format to be easily accessible within the `PyTorch` environment
- provide a method to create the dataset **splits**, according to some proportions
- provide a method to perform some **preprocessing** operations

We will proceed as follows:
- file decoding procedure
    - analysis of the MNIST dataset format (info taken from this [source](http://yann.lecun.com/exdb/mnist/))
    - download the files from the sources
    - reading the file and retrieving the data dimensions and type
    - loading the data into a `torch.tensor`
- `dataset` class construction
    - define a constructor `__init__`
    - provide a method `create_splits` to split the dataset

## File decoding procedure
### MNIST Dataset format


The **IDX file format** is a simple format for vectors and multidimensional matrices of various numerical types.
The basic format is

```
magic number
size in dimension 0
size in dimension 1
size in dimension 2
.....
size in dimension N
data
```
    
The magic number is an integer (MSB first). The first 2 bytes are always 0.

The third byte codes the type of the data:
- 0x08: unsigned byte
- 0x09: signed byte
- 0x0B: short (2 bytes)
- 0x0C: int (4 bytes)
- 0x0D: float (4 bytes)
- 0x0E: double (8 bytes)

The 4-th byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices....

The sizes in each dimension are 4-byte integers (MSB first, i.e. big-endian, like in most non-Intel processors).

The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.
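
As a rough illustration of the layout described above, the following sketch parses the header of a small IDX buffer built in memory. The helper name `parse_idx_header`, the `IDX_TYPES` table and the example buffer are hypothetical and only serve this example; they are not part of the MNIST distribution.

```python
import struct

# Type codes from the IDX description above, mapped to
# (struct format character, size in bytes).
IDX_TYPES = {
    0x08: ("B", 1),  # unsigned byte
    0x09: ("b", 1),  # signed byte
    0x0B: ("h", 2),  # short
    0x0C: ("i", 4),  # int
    0x0D: ("f", 4),  # float
    0x0E: ("d", 8),  # double
}


def parse_idx_header(buffer: bytes):
    """Return (type_code, dimensions, data_offset) for an IDX byte buffer."""
    # the magic number is made of 4 single bytes: 0, 0, type code, number of dimensions
    zero_1, zero_2, type_code, n_dims = struct.unpack(">BBBB", buffer[:4])
    if zero_1 != 0 or zero_2 != 0 or type_code not in IDX_TYPES:
        raise ValueError("not a valid IDX header")
    # each dimension size is a big-endian unsigned 4-byte integer
    dimensions = list(struct.unpack(">" + "I" * n_dims, buffer[4:4 + 4 * n_dims]))
    # the data starts right after the magic number and the dimension sizes
    return type_code, dimensions, 4 + 4 * n_dims


# hypothetical buffer: a 2 x 3 matrix of unsigned bytes
example = bytes([0, 0, 0x08, 2]) + struct.pack(">II", 2, 3) + bytes(range(6))
type_code, dimensions, offset = parse_idx_header(example)
print(hex(type_code), dimensions, offset)  # 0x8 [2, 3] 12
```

The returned offset marks where the payload begins, so `example[offset:]` holds the 6 data bytes in C order (last index changing fastest).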

### File downloading from url

The MNIST dataset is freely available at this [source](http://yann.lecun.com/exdb/mnist/).
Our aim is to download it automatically (if not already present in the project folders) and save it so that it is accessible from the `dataset` class that we are going to develop.

In [None]:
urls = {
          'training-images': 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'
        , 'training-labels': 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'
        , 'test-images': 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'
        , 'test-labels': 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'
}

In order to download something from a url we can take advantage of the [requests](https://2.python-requests.org/en/master/api/) library.

In [None]:
import requests

First of all we need to know whether the sources are _downloadable_ or not. This is a simple check on the `content-type` of the document: if it is **application/x-gzip**, then it is a valid source for our dataset.

To do so, we can use the `requests.head` function, which is useful for obtaining meta-information about the entity implied by the request without transferring the entity-body itself.

In [None]:
for name, url in urls.items():
    h = requests.head(url)
    print("{} -> {}".format(name, h.headers['content-type']))

In [None]:
def is_downloadable(url: str) -> bool:
    """
    Does the url contain a downloadable resource for our project.

    Args:
        url     (str): url of the source to be downloaded

    Returns:
        is_downloadable (bool): True if the source has application/x-gzip as content-type 
    """
    h = requests.head(url, allow_redirects=True)
    header = h.headers
    content_type = header.get('content-type', '')   # '' if the header is missing
    return 'application/x-gzip' in content_type.lower()

In [None]:
for name, url in urls.items():
    print("{} -> {}".format(name, is_downloadable(url)))
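
The check itself can be tried without any network access. The sketch below applies the same idea to a plain dictionary of headers; the helper name `has_gzip_content_type` and the `GZIP_TYPES` tuple are assumptions made for this example. Accepting `application/gzip` in addition to `application/x-gzip` is also an assumption about what a server might return, not something the MNIST source is known to do.

```python
# media types accepted as gzip archives (the second one is an assumption)
GZIP_TYPES = ('application/x-gzip', 'application/gzip')


def has_gzip_content_type(headers) -> bool:
    """True if the 'content-type' entry of `headers` names a gzip media type."""
    # a missing (or None) header becomes '' so that .lower() is always safe
    content_type = headers.get('content-type') or ''
    return any(gzip_type in content_type.lower() for gzip_type in GZIP_TYPES)


print(has_gzip_content_type({'content-type': 'application/x-gzip'}))        # True
print(has_gzip_content_type({'content-type': 'text/html; charset=UTF-8'}))  # False
print(has_gzip_content_type({}))                                            # False
```

Note that a plain `dict` is case-sensitive, while the headers returned by `requests` are case-insensitive, so the key must be written in lowercase here.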

In order to write the content of the file to be downloaded in a file located in our machine, we have to retrieve the filename. 

For instance: `http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz` gives us the filename `t10k-labels-idx1-ubyte.gz`.

In [None]:
# retrieve the filename
for name, url in urls.items():
    # everything after the last '/' is the compressed filename
    zipped_filename = url.rsplit('/', 1)[-1]
    # dropping the last extension ('.gz') gives the uncompressed filename
    filename = zipped_filename.rsplit('.', 1)[0]
    print("{} -> {} | {}".format(name, zipped_filename, filename))
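
The same filenames can also be obtained with the standard library instead of manual string splitting. This is only an alternative sketch; the function name `filenames_from_url` is hypothetical.

```python
import posixpath
from urllib.parse import urlparse


def filenames_from_url(url: str):
    """Return (compressed_filename, filename) for a url ending in '<name>.gz'."""
    # keep only the path component, e.g. '/exdb/mnist/t10k-labels-idx1-ubyte.gz'
    path = urlparse(url).path
    compressed_filename = posixpath.basename(path)
    # splitext separates the last extension ('.gz') from the rest of the name
    filename, _extension = posixpath.splitext(compressed_filename)
    return compressed_filename, filename


print(filenames_from_url('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'))
# ('t10k-labels-idx1-ubyte.gz', 't10k-labels-idx1-ubyte')
```

Parsing the url first means that a query string or fragment, if present, would not end up in the filename.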

Once we have verified the correctness of the files to be downloaded and retrieved the filename, we can proceed to download the actual content and write it into a file in our machine. 

For this purpose, we can take advantage of the `requests.get` function which, compared to the `requests.head` function, also returns the body (i.e. the actual data we want to download).
Once the filenames are available we can save the content into them.

In the next cell, we obtain a `requests.Response` object, whose content is the data that we are going to use.

In [None]:
# get the file content
r = requests.get(urls['test-labels'], stream=True)
r.raw.read(10)  # read raw content (not decoded yet)


Then we can write this content into the actual file.

In [None]:
# get the file content
r = requests.get(urls['test-labels'], stream=True)
with open('test.gz', 'wb') as f:
    f.write(r.raw.data)
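
The cell above holds the whole body in memory before writing it. For larger files, one possible variation is to write the content chunk by chunk. The helper below is a sketch under that assumption: `write_chunks` is a hypothetical name, and it is demonstrated with an in-memory list of chunks so that it runs without a network connection.

```python
import os
import tempfile


def write_chunks(chunks, path: str) -> int:
    """Write an iterable of byte chunks to `path` and return the bytes written."""
    written = 0
    with open(path, 'wb') as f:
        for chunk in chunks:
            if chunk:  # skip empty chunks
                written += f.write(chunk)
    return written


# demonstration with fake chunks written to a temporary folder
demo_path = os.path.join(tempfile.mkdtemp(), 'demo.bin')
n_written = write_chunks([b'abc', b'', b'defg'], demo_path)
print(n_written)  # 7
```

With a real response, the chunks could come from `r.iter_content(chunk_size=8192)`, for example `write_chunks(r.iter_content(chunk_size=8192), 'test.gz')`. Unlike `r.raw`, `iter_content` decodes any transfer-level `Content-Encoding`, which should make no difference for files that are simply served as `.gz` archives.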

Wrapping up the lines of code, we come up with the following function.

In [None]:
def download(url: str, folder: str) -> None:
    """
    Download a .gz file from the provided url and save it to a folder.

    Args:
        url         (str): url of the source to be downloaded
        folder      (str): folder in which the file will be saved
    """
    if is_downloadable(url):
        # retrieve the filename from the url passed as argument
        zipped_filename = url.rsplit('/', 1)[-1]

        # get the file content
        r = requests.get(url, stream=True)

        # write the file (`folder` is expected to end with a path separator)
        with open("{}{}".format(folder, zipped_filename), 'wb') as f:
            f.write(r.raw.data)


The next step is the decompression of the downloaded files. To do so, we can exploit the [gzip](https://docs.python.org/3/library/gzip.html) library and, in particular, the `gzip.open` function along with the [shutil](https://docs.python.org/3/library/shutil.html) library with the `shutil.copyfileobj` function.

In [None]:
import gzip
import shutil

# open the zipped file
with gzip.open('test.gz', 'rb') as f_in:
    # open the uncompressed file to be filled
    with open('test', 'wb') as f_out:
        # fill the uncompressed file
        shutil.copyfileobj(f_in, f_out)
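
To check that the decompression step behaves as expected without downloading anything, a small round trip can be run in a temporary folder. The sketch below is only illustrative: the `gunzip` helper, the file names and the payload are made up for this example.

```python
import gzip
import os
import shutil
import tempfile


def gunzip(compressed_path: str, output_path: str) -> None:
    """Decompress a .gz file into `output_path`, streaming the content."""
    with gzip.open(compressed_path, 'rb') as f_in, open(output_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)


# build a small compressed file in a temporary folder
work_dir = tempfile.mkdtemp()
gz_path = os.path.join(work_dir, 'sample.gz')
out_path = os.path.join(work_dir, 'sample')
payload = b'\x00\x00\x08\x01' + bytes(16)  # arbitrary bytes, not a real MNIST file
with gzip.open(gz_path, 'wb') as f:
    f.write(payload)

# decompress it and compare with the original content
gunzip(gz_path, out_path)
with open(out_path, 'rb') as f:
    restored = f.read()
print(restored == payload)  # True
```

Since `shutil.copyfileobj` copies in blocks, the uncompressed content never has to fit in memory all at once.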

Before completing the download function, we should include some checks regarding the existence of the directory in which we are going to save the files.

Thus, we have to import the [os](https://docs.python.org/3/library/os.html) library.

In [None]:
import os

folder = './data/'

if not os.path.exists(folder):   # check for the existence of directory ./data/
    os.makedirs(folder)          # creation of the folder if it doesn't exist

We are ready to build our `download` function, which also checks for the presence of already downloaded files (in that case the download is skipped).

In [None]:
def download(url: str, folder: str) -> None:
    """
    Download a .gz file from the provided url and save it to a folder.

    Args:
        url         (str): url of the source to be downloaded
        folder      (str): folder in which the file will be saved
    """
    if is_downloadable(url):
        # retrieve the filenames: 'name.gz' (compressed) and 'name' (extracted)
        compressed_filename = url.rsplit('/', 1)[-1]
        filename = compressed_filename.rsplit('.', 1)[0]

        # create the destination folder if it doesn't exist
        os.makedirs(folder, exist_ok=True)

        compressed_file_path = os.path.join(folder, compressed_filename)
        file_path = os.path.join(folder, filename)

        print("Checking presence of {} ...".format(compressed_file_path))
        if not os.path.exists(compressed_file_path):   # not downloaded yet

            print("Downloading {} ...".format(compressed_file_path))

            # get the file content (only when needed) and write it to disk
            r = requests.get(url, stream=True)
            with open(compressed_file_path, 'wb') as f:
                f.write(r.raw.data)

        else:   # the compressed file has already been downloaded
            print("Already downloaded.")

        print("Checking presence of uncompressed file {} ...".format(file_path))
        if not os.path.exists(file_path):   # not extracted yet

            print("Extracting {} ...".format(file_path))

            # open the compressed file
            with gzip.open(compressed_file_path, 'rb') as f_in:
                # open the uncompressed file to be filled
                with open(file_path, 'wb') as f_out:
                    # fill the uncompressed file
                    shutil.copyfileobj(f_in, f_out)

        else:   # the file has already been extracted
            print("Already extracted.")
        
            


Now we can easily download the files.

In [None]:
for name, url in urls.items():
    download(url, './data/raw')

### Reading and information retrieving

Taking into account the MNIST dataset format described at the beginning of this notebook, we can start by reading the magic number, stored in the first 4 bytes of the file.

In [None]:
with open('train-images-idx3-ubyte', 'rb') as f:    # open the training images file for reading in binary mode 'rb'
    m_numb_32bit = f.read(4)    # magic number
    print(m_numb_32bit)

At the moment, all the retrieved information is in the form of bytes. 

The `f.read(4)` call reads 4 bytes (32 bits) from the file and advances the file position, so each subsequent call returns the next 4 bytes.

The variable above is a `bytes` object, which Python displays using hexadecimal escape sequences.

We can then retrieve the bytes composing the magic number and store them in a list (`m_numb_list`) in which the indices are as follows:
 - \[0\]: 0
 - \[1\]: 0
 - \[2\]: encoding number for the type of the data
 - \[3\]: number of dimensions


In [None]:
m_numb_list = [byte for byte in m_numb_32bit]
print('\nm_numb_list: ',m_numb_list)

As anticipated, the bytes of the magic number give us some information. In particular:
- the third byte defines the type of the data (in this case it is `0x08`, which encodes `unsigned byte`)
- the fourth byte defines the number of dimensions (in this case `3`, so the data is a 3-dimensional array)

We can then read the following bytes to retrieve some other information about the dimensions of the data:
- d_list_32bit\[0\]: size in dimension 0
- d_list_32bit\[1\]: size in dimension 1
- d_list_32bit\[2\]: size in dimension 2
- ...
- d_list_32bit\[N\]: size in dimension N

**NOTE**: here we are considering the `train-images-idx3-ubyte` file, which contains the training images. The other files may have more or fewer dimensions.

In [None]:
with open('train-images-idx3-ubyte', 'rb') as f:    # open the file for reading in binary mode 'rb'       
    f.read(4)                                       # discard the first 4 bytes (magic number)
    d_list_32bit = [f.read(4) for _ in range(m_numb_list[3])]

print('d_list_32bit:', d_list_32bit)

We obtain a list of three elements, which represent the size of each dimension.

We use Python's `struct` module to convert the raw bytes into integers.
To do so, we use the big-endian prefix `">"` followed by the type/size of the value that we want to unpack at a time: 
- for the header fields (the first N + 1 4-byte (32-bit) elements of the file, N being the number of dimensions) we use the `">I"` format (i.e. big-endian 4-byte unsigned int)
- for the actual data we rely on the type encoded in the third byte of the magic number.

In [None]:
import struct

dimensions = [struct.unpack('>I', dimension)[0] for dimension in d_list_32bit]
print('dimensions:', dimensions)
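
To make the role of the byte order concrete, here is a small self-contained sketch that decodes a hand-written 4-byte value. The bytes `00 00 EA 60` are chosen for the example because they correspond to 60000 in big-endian order; they are not read from an actual file.

```python
import struct

raw = b'\x00\x00\xea\x60'  # 4 bytes, most significant byte first

# big-endian unsigned int, as used for the IDX dimension sizes
value_struct = struct.unpack('>I', raw)[0]

# the same conversion with the built-in int.from_bytes
value_int = int.from_bytes(raw, byteorder='big')

# reading the same bytes as little-endian gives a completely different number
value_little = struct.unpack('<I', raw)[0]

print(value_struct, value_int)  # 60000 60000
print(value_struct == value_little)  # False
```

Both `struct.unpack` and `int.from_bytes` give the same result here; `struct` becomes more convenient when several fields or non-integer types have to be decoded at once.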


At this point we can wrap all the lines of code written so far into a prototype function which will read an input IDX file and return the list of its dimensions.

In [None]:
def read_idx_file(file_path: str) -> list:
    
    # open the file for reading in binary mode 'rb'
    with open(file_path, 'rb') as f:     
        # magic number list   
        m_numb_list = [byte for byte in f.read(4)] 
        # dimensions list  
        d_list_32bit = [f.read(4) for _ in range(m_numb_list[3])]
        # unpack data
        dimensions = [struct.unpack('>I', dimension)[0] for dimension in d_list_32bit]

        return dimensions

In [None]:
print(read_idx_file(r'train-images-idx3-ubyte'))
print(read_idx_file(r'train-labels-idx1-ubyte'))
print(read_idx_file(r't10k-images-idx3-ubyte'))
print(read_idx_file(r't10k-labels-idx1-ubyte'))

### Loading the data into a `torch.tensor`

Now we can focus on the last part of the _file decoding procedure_: retrieving the data and loading it into a `torch.tensor`.

It will be useful to report the type codes here and to create an ad-hoc dictionary to handle them.
- 0x08: unsigned byte
- 0x09: signed byte
- 0x0B: short (2 bytes)
- 0x0C: int (4 bytes)
- 0x0D: float (4 bytes)
- 0x0E: double (8 bytes)

The encoding dictionary contains, for each hexadecimal type code, the corresponding `struct` format character, the standard size (the number of bytes passed to `f.read(...)` to match the format) and the PyTorch type to be used when loading the dataset into a `torch.tensor`.

The format that we will use to unpack the data is reported below.


In [None]:
import torch

encoding = {
      b'\x08':['B',1,torch.uint8]
    , b'\x09':['b',1,torch.int8]
    , b'\x0B':['h',2,torch.short]
    , b'\x0C':['i',4,torch.int32]
    , b'\x0D':['f',4,torch.float32]
    , b'\x0E':['d',8,torch.float64]
    }

e_format = ">" + encoding[m_numb_list[2].to_bytes(1, byteorder='big')][0]
n_bytes = encoding[m_numb_list[2].to_bytes(1, byteorder='big')][1]
d_type = encoding[m_numb_list[2].to_bytes(1, byteorder='big')][2]
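
The sizes listed in the dictionary can be cross-checked against the `struct` module itself. The following sketch compares each format character with the size reported by `struct.calcsize`; the `expected_sizes` dictionary simply restates the values used above.

```python
import struct

# format character -> size in bytes, as listed in the encoding dictionary
expected_sizes = {'B': 1, 'b': 1, 'h': 2, 'i': 4, 'f': 4, 'd': 8}

# with an explicit byte-order prefix such as '>', struct uses standard sizes,
# independent of the platform
sizes = {code: struct.calcsize('>' + code) for code in expected_sizes}

print(sizes)                    # {'B': 1, 'b': 1, 'h': 2, 'i': 4, 'f': 4, 'd': 8}
print(sizes == expected_sizes)  # True
```

Keeping the sizes in the dictionary is still handy for readability, but they could equally be computed from the format string.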

Now that we have the dimensions and the encoding format we can proceed to retrieve the actual data.
To do so, we will consider again the case of the training images dataset trying to generalize it as much as possible, in order to exploit the code for the other files too.

In [None]:
import time

start = time.time()

with open('train-images-idx3-ubyte', 'rb') as f:    # open the file for reading in binary mode 'rb'       
    f.read(16)                                      # discard the first 16 bytes (magic number + dimension sizes)
    
    # read all the bytes of the file progressively and store them in a torch.tensor according to the dimensions
    dataset = torch.tensor([[[struct.unpack(e_format, f.read(n_bytes))[0] 
                                for _ in range(dimensions[2])] 
                                for _ in range(dimensions[1])] 
                                for _ in range(dimensions[0])]
                            , dtype=d_type)

end = time.time()
print("Loading time: {:.2f} s".format(end-start))


The largest file, which in this case is the one containing the training images, takes approximately 20s to be read and loaded into a `torch.tensor`.

We can update the previously defined function, renaming it `read_idx_file_to_tensor()` and adapting the tensor initialization to handle files with a different number of dimensions.
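
Much of that time is likely spent issuing one `f.read` and one `struct.unpack` call per value. As a possible variation, the payload can be read and unpacked in a single call. The sketch below uses only the standard library and a tiny IDX file written to a temporary folder; the function name `read_idx_flat` and the demo file are assumptions made for this example, and no timing is claimed for the real dataset.

```python
import os
import struct
import tempfile


def read_idx_flat(file_path: str):
    """Return (dimensions, values), unpacking the whole payload in one call."""
    formats = {0x08: 'B', 0x09: 'b', 0x0B: 'h', 0x0C: 'i', 0x0D: 'f', 0x0E: 'd'}
    with open(file_path, 'rb') as f:
        # magic number: two zero bytes, type code, number of dimensions
        _, _, type_code, n_dims = f.read(4)
        dimensions = list(struct.unpack('>' + 'I' * n_dims, f.read(4 * n_dims)))
        # total number of values stored in the file
        n_items = 1
        for size in dimensions:
            n_items *= size
        # a repeat count in the format string unpacks all the values at once
        e_format = '>{}{}'.format(n_items, formats[type_code])
        values = struct.unpack(e_format, f.read(struct.calcsize(e_format)))
    return dimensions, values


# tiny demo file: a 2 x 2 x 3 array of unsigned bytes
demo_path = os.path.join(tempfile.mkdtemp(), 'demo-idx3-ubyte')
with open(demo_path, 'wb') as f:
    f.write(bytes([0, 0, 0x08, 3]) + struct.pack('>III', 2, 2, 3) + bytes(range(12)))

dimensions, values = read_idx_flat(demo_path)
print(dimensions, values[:4])  # [2, 2, 3] (0, 1, 2, 3)
```

The flat tuple could then be turned into a tensor of the right shape with something like `torch.tensor(values, dtype=d_type).reshape(dimensions)`. For the full training set the intermediate tuple is large, so this remains a trade-off between speed and memory.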

In [None]:
def read_idx_file_to_tensor(file_path: str) -> torch.Tensor:
    
    # open the file for reading in binary mode 'rb'
    with open(file_path, 'rb') as f:     
        # magic number list   
        m_numb_list = [byte for byte in f.read(4)] 
        # dimensions list  
        d_list_32bit = [f.read(4) for _ in range(m_numb_list[3])]
        dimensions = [struct.unpack('>I', dimension)[0] for dimension in d_list_32bit]
        
        encoding = {
                      b'\x08':['B',1,torch.uint8]
                    , b'\x09':['b',1,torch.int8]
                    , b'\x0B':['h',2,torch.short]
                    , b'\x0C':['i',4,torch.int32]
                    , b'\x0D':['f',4,torch.float32]
                    , b'\x0E':['d',8,torch.float64]
                    }

        e_format = ">" + encoding[m_numb_list[2].to_bytes(1, byteorder='big')][0]
        n_bytes = encoding[m_numb_list[2].to_bytes(1, byteorder='big')][1]
        d_type = encoding[m_numb_list[2].to_bytes(1, byteorder='big')][2]


        if len(dimensions) == 3:    # images
           
            print('Loading {} ...'.format(file_path))
            
            dataset = torch.tensor(
                [
                    [
                        [struct.unpack(e_format, f.read(n_bytes))[0] 
                        for _ in range(dimensions[2])] 
                    for _ in range(dimensions[1])] 
                for _ in range(dimensions[0])]
                , dtype=d_type
            )

            print('{} loaded!'.format(file_path))
        

        elif len(dimensions) == 1:  # labels
        
            print('Loading {} ...'.format(file_path))

            dataset = torch.tensor(
                [struct.unpack(e_format, f.read(n_bytes))[0]
                for _ in range(dimensions[0])]
                , dtype=d_type
            )

            print('{} loaded!'.format(file_path))
        

        else:   # wrong dimensions
            raise ValueError("Invalid dimensions in the IDX file!")

        
        return dataset

Having created this main function, we can call it with the paths to the files. One option is to store the training images and the training labels in a tuple called `training_set`, which is then simply a tuple of `torch.tensor`s. The same is done for the `test_set` tuple, as shown below.

In [None]:
training_set = (
    read_idx_file_to_tensor(r'train-images-idx3-ubyte')
    , read_idx_file_to_tensor(r'train-labels-idx1-ubyte')
)

test_set = (
    read_idx_file_to_tensor(r't10k-images-idx3-ubyte')
    , read_idx_file_to_tensor(r't10k-labels-idx1-ubyte')

)

We can check the correctness of this function by plotting the first 10 image-label pairs of each dataset.

In [None]:
import matplotlib.pyplot as plt

n_images = 10

fig = plt.figure(figsize=(10,5))
for i in range(n_images):
    image = training_set[0][i]
    label = training_set[1][i].item()
    sp = fig.add_subplot(2, 5, i+1)
    sp.set_title(label)
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.show()


In [None]:
n_images = 10

fig = plt.figure(figsize=(10,5))
for i in range(n_images):
    image = test_set[0][i]
    label = test_set[1][i].item()
    sp = fig.add_subplot(2, 5, i+1)
    sp.set_title(label)
    plt.axis('off')
    plt.imshow(image, cmap='gray')
plt.show()