<div class="alert alert-block alert-info" align="center">
    <h1>
        Imports
    </h1>
</div>

In [None]:
from itertools import islice

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.preprocessing.image import random_shift, random_shear, random_rotation, random_zoom
from keras.datasets import fashion_mnist
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator

<div class="alert alert-block alert-info" align="center">
    <h1>
        Functions
    </h1>
</div>

In [None]:
def plot_images(imgs, labels=None, rows=1, figsize=(20,8), fontsize=14):
    """Show a grid of grayscale images in a single matplotlib figure.

    Args:
        imgs: sequence (list or array) of images. Each image is either 2-D
            (H, W) or 3-D (H, W, C); for 3-D images only channel 0 is drawn.
        labels: optional sequence (list or numpy array) of subplot titles,
            indexed in parallel with ``imgs``. ``None`` or an empty sequence
            means no titles are drawn.
        rows: number of grid rows. The column count is
            ``ceil(len(imgs) / rows)`` so every image gets its own cell.
        figsize: matplotlib figure size in inches.
        fontsize: font size of the subplot titles.
    """
    n_images = len(imgs)
    # Ceil division so the grid always has room for every image; a floored
    # column count drops the remainder and makes add_subplot reject the
    # trailing subplot indices.
    cols = max(1, int(np.ceil(n_images / rows)))
    # Test for presence explicitly: an all-zero label array (e.g. only class
    # 0) is still a valid set of labels, so `.any()` is the wrong check.
    labels_present = labels is not None and len(labels) > 0

    figure = plt.figure(figsize=figsize)
    for i in range(n_images):
        ax = figure.add_subplot(rows, cols, i + 1)
        # hide ticks and tick labels but keep the axes frame as a bounding
        # box; matplotlib expects booleans here, not 'off' strings
        ax.tick_params(
            axis='both',
            which='both',
            bottom=False,
            top=False,
            left=False,
            right=False,
            labelbottom=False,
            labelleft=False)
        if labels_present:
            ax.set_title(labels[i], fontsize=fontsize)
        img = np.asarray(imgs[i])
        ax.imshow(img[:, :, 0] if img.ndim == 3 else img, cmap='Greys')

    plt.show()

def get_random_sample(number_of_samples=10, data=None, labels=None, num_classes=10):
    """Draw ``number_of_samples`` random images from each class.

    Args:
        number_of_samples: number of images to draw per class.
        data: image array indexed along axis 0. Defaults to the notebook
            global ``train_data``.
        labels: integer class label for each row of ``data``. Defaults to the
            notebook global ``train_labels``.
        num_classes: classes ``0 .. num_classes - 1`` are sampled, in order.

    Returns:
        ``(x, y)``:
            - ``x`` holds the drawn images reshaped to (-1, 28, 28, 1) and
              grouped by class (class 0 first).
            - ``y`` is a list with one inner list ``[class] * number_of_samples``
              per class.

    Raises:
        ValueError: if some class has no samples in ``data``.
    """
    if data is None:
        data = train_data
    if labels is None:
        labels = train_labels

    x = []
    y = []
    for category_number in range(num_classes):
        # all samples of this category
        category_data = data[labels == category_number]
        category_size = category_data.shape[0]
        if category_size == 0:
            raise ValueError(f"no samples found for class {category_number}")
        # Sample without replacement so the same image is not picked twice;
        # allow repeats only when the class is smaller than the request.
        picks = np.random.choice(category_size, size=number_of_samples,
                                 replace=category_size < number_of_samples)
        x.extend(category_data[picks])
        y.append([category_number] * number_of_samples)

    return np.asarray(x).reshape(-1, 28, 28, 1), y

<div class="alert alert-block alert-info" align="center">
    <h1>
        Generating data
    </h1>
</div>

In [None]:
(raw_train_data, raw_train_labels), (raw_test_data, raw_test_labels) = fashion_mnist.load_data()

# Normalize pixel values into [0, 1] (assumes raw intensities span 0-255),
# keeping images as float64 and labels as int32.
train_data = raw_train_data.astype(float) / 255.0
test_data = raw_test_data.astype(float) / 255.0
train_labels = raw_train_labels.astype(np.int32)
test_labels = raw_test_labels.astype(np.int32)

In [None]:
x_ten_samples, y_ten_samples = get_random_sample(number_of_samples=5)

# get_random_sample returns the labels as one inner list per class; flatten
# them so there is exactly one one-hot row per image in x_ten_samples.
y_ten_samples = to_categorical(np.ravel(y_ten_samples))
assert len(x_ten_samples) == len(y_ten_samples), "images and labels are misaligned"

print(np.shape(x_ten_samples))
print(np.shape(y_ten_samples))
plot_images(x_ten_samples, rows=10, figsize=(20,20))

In [None]:
# Shapes of the normalized arrays before reshaping.
print(np.shape(train_data), np.shape(train_labels), sep="\n")

# Add a trailing channel axis so every image set becomes a 4D tensor
# (N, 28, 28, 1), and one-hot encode the integer labels.
x_train_data = train_data.reshape(-1, 28, 28, 1)
y_train_data = to_categorical(train_labels)
print(np.shape(x_train_data), np.shape(y_train_data), sep="\n")

x_test_data = test_data.reshape(-1, 28, 28, 1)
y_test_data = to_categorical(test_labels)
print(np.shape(x_test_data), np.shape(y_test_data), sep="\n")

In [None]:
# One sample used to demonstrate each augmentation in the following cells.
# Samples are grouped by class, five per class, so index 1 is the second
# image drawn for class 0.
img = x_ten_samples[1]
plot_images(imgs=[img])

In [None]:
# Arguments shared by every transform: the image is laid out as
# (rows, cols, channels) on axes 0/1/2, and pixels exposed by a transform
# are filled with the constant 0.
augment_kwargs = dict(row_axis=0, col_axis=1, channel_axis=2, fill_mode='constant', cval=0)
n_variants = 5

# random shift: up to 10% of the width and 20% of the height
img_shifted = [random_shift(img, wrg=0.1, hrg=0.2, **augment_kwargs) for _ in range(n_variants)]
plot_images(img_shifted)

# random rotation with a rotation range of 20
img_rotated = [random_rotation(img, 20, **augment_kwargs) for _ in range(n_variants)]
plot_images(img_rotated)

# random shear with intensity 0.5
img_sheared = [random_shear(img, intensity=0.5, **augment_kwargs) for _ in range(n_variants)]
plot_images(img_sheared)

# random zoom between 0.7x and 1.3x
img_zoomed = [random_zoom(img, zoom_range=(0.7, 1.3), **augment_kwargs) for _ in range(n_variants)]
plot_images(img_zoomed)

In [None]:
# Random augmentation pipeline that combines rotation, shift, shear and zoom.
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.5,
    zoom_range=(0.9, 1.1),
    horizontal_flip=False,
    vertical_flip=False,
    fill_mode='constant',
    cval=0
)

max_batches = 10
# flow() takes a batch of images, so feed the single demo image as a batch
# of one. The resulting iterator loops indefinitely, so islice takes exactly
# max_batches batches; each batch contains one randomly transformed copy.
single_image_batch = np.expand_dims(img, axis=0)
img_gen = [
    x_batch[0]
    for x_batch in islice(datagen.flow(single_image_batch, batch_size=1), max_batches)
]

plot_images(img_gen, rows=10, figsize=(20,16))