# Dataset creation for the tutorial

In [1]:
# System modules
import os
import pathlib

# Tensorflow modules
import tensorflow as tf
import tensorflow_datasets as tfds

# Math modules
import numpy as np
import scipy

# Image modules
import PIL
import PIL.Image
from skimage.transform import resize

ModuleNotFoundError: No module named 'skimage'

### Load MNIST dataset

#### Here is a link to download the MNIST_M dataset
[MNIST_M Dataset](https://drive.google.com/file/d/0B9Z4d7lAwbnTNDdNeFlERWRGNVk/view)

In [18]:
# Load the MNIST train and test splits through the Keras datasets API.
(x_train, y_train), (x_test, y_test) = (
    tf.keras.datasets.mnist.load_data()
)

### Format MNIST images like MNIST_M images
#### To build a single model that works on both datasets, we need to rescale the MNIST images to the MNIST_M dimensions (28 \* 28 \* 1 -> 32 \* 32 \* 3)

In [4]:
def reformate_list(values, size=(32, 32)):
    """Reformat a list of grayscale images to 3-channel images of a given size.

    Each image is replicated across three channels and then resized with
    ``skimage.transform.resize``. For MNIST with the default size this maps
    28*28 -> 32*32*3, which matches the MNIST_M format.

    Args:
        values: iterable of 2-D images (rows x cols). Non-square images are
            handled correctly.
        size: target (rows, cols) of the output images. Defaults to (32, 32).

    Returns:
        list of float arrays of shape ``(*size, 3)``, scaled by 255.
    """
    res = []
    for value in values:
        gray = np.asarray(value)
        # Replicate the single channel three times along a new last axis.
        # The previous hand-written loop used len(value[0]) / len(value[1])
        # as the row/column counts, which was only correct for square images.
        rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
        # resize() keeps the channel axis when only (rows, cols) is given.
        # It presumably rescales integer images to [0, 1] by default
        # (preserve_range=False), hence the * 255.
        res.append(resize(rgb, size) * 255)
    return res

##### Reformatting the whole dataset takes a long time, so run this cell only once!

In [54]:
# Reformat both splits with the same routine (train first, then test).
resized_xtrain, resized_xtest = [reformate_list(split) for split in (x_train, x_test)]

CPU times: user 11min 35s, sys: 6min 26s, total: 18min 2s
Wall time: 6min 4s


### Save the new dataset for later use
#### To avoid losing our data and to make the dataset easier to load in TensorFlow, we save the images as PNG files in one directory per category (one sub-directory for each label).
##### Be sure to change the paths according to your own case.

In [6]:
# Global path variables (change them if you need to).
# The dataset root lives under the notebook's current working directory; the
# trailing "/" is kept so sub-directory names can be appended directly.
origin = os.getcwd()
home_dir = f"{origin}/data/MNIST_reformat/"

def create_dataset_arch(root=None, n_classes=10):
    """Create the dataset root and one directory per category ("0" .. n_classes-1).

    Uses ``os.makedirs(..., exist_ok=True)``, so missing parent directories
    (e.g. ``data/``) are created too and re-running the cell is harmless.
    Failures are reported with their cause instead of being silently swallowed.

    Args:
        root: dataset root directory. Defaults to the module-level ``home_dir``.
        n_classes: number of category directories to create (10 for MNIST).
    """
    if root is None:
        root = home_dir

    try:
        os.makedirs(root, exist_ok=True)
    except OSError as err:
        print(f"Creation dir failed: {err}")
        # No point trying the sub-directories if the root cannot be created.
        return
    print("Successfully created")

    for i in range(n_classes):
        try:
            os.makedirs(os.path.join(root, str(i)), exist_ok=True)
        except OSError as err:
            print("Creation dir " + str(i) + " failed: " + str(err))

In [7]:
def fill_dir(samples, labels, id=0, root=None):
    """Fill the per-label directories with the dataset images, saved as PNG.

    Each sample is written to ``<root>/<label>/<id>.png``, with ids assigned
    sequentially starting at ``id``. Several calls can share one directory
    tree without overwriting each other when each call starts at the next
    free id.

    Args:
        samples: sequence of images convertible with ``np.asarray``. Values
            are presumably in the 0-255 range (e.g. the output of
            ``reformate_list``).
        labels: sequence of labels, one per sample.
        id: file number of the first image. The name shadows the builtin
            ``id`` and is kept for compatibility with existing callers.
        root: dataset root directory. Defaults to the module-level ``home_dir``.

    Returns:
        The next unused file id, i.e. ``id + len(samples)``.

    Raises:
        ValueError: if ``samples`` and ``labels`` have different lengths.
    """
    if root is None:
        root = home_dir
    if len(samples) != len(labels):
        raise ValueError(f"got {len(samples)} samples but {len(labels)} labels")

    for sample, label in zip(samples, labels):
        # Round and clip before casting: a bare astype(np.uint8) truncates
        # floats (254.99 -> 254), and its result is undefined outside 0-255.
        arr = np.clip(np.rint(np.asarray(sample)), 0, 255).astype(np.uint8)
        label_dir = os.path.join(root, str(label))
        # Create the label directory on demand so this works even if
        # create_dataset_arch() was not run beforehand.
        os.makedirs(label_dir, exist_ok=True)
        PIL.Image.fromarray(arr).save(os.path.join(label_dir, f"{id}.png"))
        id += 1

    print("Successfully saved !")
    return id

##### Filling the directories takes a long time, so run this cell only once!

In [84]:
fill_dir(resized_xtrain, y_train)
# Number the test images after the training images, so both splits can share
# the same label directories without overwriting each other's files.
fill_dir(resized_xtest, y_test, len(resized_xtrain))

Successfully saved !
Successfully saved !
CPU times: user 1min 50s, sys: 1min 3s, total: 2min 54s
Wall time: 13min 2s
