## Train LeNet5 on the MNIST dataset using TensorFlow

#### Notebook setup

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))
%matplotlib inline
# %matplotlib qt
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import time
from dataset import MNISTDataset
from utils import Utils
from lenet5 import Lenet5

In [None]:
# numpy print options: allow 150 characters per printed line and show 10
# items at each edge of an array before it is summarized
np.set_printoptions(linewidth=150, edgeitems=10)

In [None]:
# Fixed seed: every random draw made through `nprg` is reproducible
# from one run of the notebook to the next.
seed = 112358
nprg = np.random.RandomState(seed=seed)

#### Import MNIST dataset

In [None]:
# Paths of the four raw MNIST files (train/test images and labels),
# relative to the notebook's working directory.
MNIST_TRAIN_IMAGES_FILEPATH = 'MNIST_dataset/train-images.idx3-ubyte'
MNIST_TRAIN_LABELS_FILEPATH = 'MNIST_dataset/train-labels.idx1-ubyte'
MNIST_TEST_IMAGES_FILEPATH = 'MNIST_dataset/t10k-images.idx3-ubyte'
MNIST_TEST_LABELS_FILEPATH = 'MNIST_dataset/t10k-labels.idx1-ubyte'

# Build the dataset object from the four files; the rest of the notebook
# reads its .train, .validation and .test splits.
mnist_ds = MNISTDataset(
    MNIST_TRAIN_IMAGES_FILEPATH,
    MNIST_TRAIN_LABELS_FILEPATH,
    MNIST_TEST_IMAGES_FILEPATH,
    MNIST_TEST_LABELS_FILEPATH,
)


In [None]:
# Optional augmentation step: presumably adds randomly rotated samples.
# NOTE(review): the exact meaning of `ratio` is defined in dataset.py - confirm.
mnist_ds.enhance_with_random_rotate(ratio=2)

In [None]:
# Optional augmentation step: presumably adds randomly zoomed-in samples.
# NOTE(review): the exact meaning of `ratio` is defined in dataset.py - confirm.
mnist_ds.enhance_with_random_zoomin(ratio=2)

In [None]:
# Optional augmentation step: presumably adds samples that are both
# zoomed in and rotated at random.
# NOTE(review): the exact meaning of `ratio` is defined in dataset.py - confirm.
mnist_ds.enhance_with_random_zoomin_and_rotate(ratio=2)

#### Analyze data

In [None]:
# Print the dataset's `summary` attribute (provided by MNISTDataset in dataset.py).
print(mnist_ds.summary)

In [None]:
# Plot a random sample from each of the train, validation and test sets and
# print the matching labels laid out in the same nlines x ncols grid.
nlines = 5
ncols = 25

for dataset in [mnist_ds.train, mnist_ds.validation, mnist_ds.test]:
    # figsize is (width, height) in inches: the grid is ncols images wide
    # and nlines images tall, so the figure has to be wide, not tall.
    plt.figure(figsize=(ncols, nlines), dpi=150)
    # Indices are drawn with replacement (the default of RandomState.choice),
    # so the same image can appear more than once in a sample.
    sample_indices = nprg.choice(a=dataset.num_examples, size=nlines * ncols)
    image_grid = Utils.concat_images(dataset.images[sample_indices], mnist_ds.image_size, nlines, ncols)
    plt.imshow(image_grid, cmap='gray_r')
    plt.show()
    # argmax over axis 1 turns each label row into a single class index.
    print(np.argmax(dataset.labels[sample_indices], axis=1).reshape(nlines, ncols))


In [None]:
# Plot the label distribution of each of the train, validation and test sets.
for dataset in [mnist_ds.train, mnist_ds.validation, mnist_ds.test]:
    plt.figure(figsize=(15, 5))
    # argmax over axis 1 turns each label row into a single class index.
    digit_labels = np.argmax(dataset.labels, axis=1)
    # bins 0..10 with align='left' centre one bar on each digit 0-9.
    # `density` replaces the removed `normed` keyword; False keeps raw counts.
    plt.hist(digit_labels, bins=np.arange(11), align='left', rwidth=0.5, density=False)
    plt.xticks(range(0, 10))
    plt.xlabel('digit')
    plt.ylabel('frequency')
    plt.show()


#### Build and train LeNet5 model using TensorFlow

In [None]:
# Create the LeNet5 model on top of the dataset. The second argument is the
# model's name string; the keyword arguments are the training hyper-parameters.
lenet5_model = Lenet5(
    mnist_ds,
    "zoomin_and_rotate_x2_allDigits_dropoutAfterF5F6",
    epochs=40,
    batch_size=128,
    variable_mean=0,
    variable_stddev=0.1,
    learning_rate=0.001,
    drop_out_keep_prob=0.5,
)

In [None]:
# Run training on the model configured in the previous cell.
lenet5_model.train()

In [None]:
# Evaluate on the test set through the test_data method.
# NOTE(review): this runs on a fresh Lenet5 instance named "temp" rather than
# on lenet5_model - presumably it restores saved weights; confirm in lenet5.py.
(test_loss,
 test_acc,
 total_predict,
 total_actual,
 wrong_predict_images) = Lenet5(mnist_ds,"temp").test_data(mnist_ds.test)
print('test_loss = {:.3f}, test_acc = {:.3f}'.format(test_loss,test_acc))
print('#wrong_predicted_images = {}'.format(len(wrong_predict_images)))

In [None]:
# Keep only the misclassified samples and order their images by target label.
# NOTE(review): assumes wrong_predict_images is in the same order as the
# misclassified entries of total_actual / total_predict - confirm in lenet5.py.
misclassified = total_actual != total_predict
wrong_predict = total_predict[misclassified]
wrong_actual = total_actual[misclassified]
wrong_predict_images = np.array(wrong_predict_images)
order_by_actual = np.argsort(wrong_actual)
wrong_predict_images_sorted = list(wrong_predict_images[order_by_actual])
# Hand the sorted images to TrainingPlotter together with the output file name.
from training_plotter import TrainingPlotter
TrainingPlotter.combine_images(wrong_predict_images_sorted, "wrong_predicted_after_restore_session.png")

In [None]:
# Test again, this time through the predict_images method of the model
# that was built and trained in the cells above.
preds = lenet5_model.predict_images(mnist_ds.test.images)
# argmax over axis 1 turns each label / prediction row into a class index.
target_labels = np.argmax(mnist_ds.test.labels, axis=1)
predicted_labels = np.argmax(preds, axis=1)
print('Targets: \n', target_labels)
print('Predictions: \n', predicted_labels)

# Select the misclassified samples and sort them by target label. The same
# permutation is applied to images and predicted labels so they stay aligned.
is_wrong = target_labels != predicted_labels
wrong_predict = predicted_labels[is_wrong]
wrong_actual = target_labels[is_wrong]
wrong_predicted_images = mnist_ds.test.images[is_wrong]
order_by_target = wrong_actual.argsort()
wrong_predict_images_sorted = wrong_predicted_images[order_by_target]
wrong_predicted_labels = wrong_predict[order_by_target]
print(wrong_predict_images_sorted.shape)

# Plot the wrongly predicted images, sorted by target label; the title lists
# the predicted label of each image in the same order.
# NOTE(review): the grid holds nlines * ncols images; how Utils.concat_images
# behaves when the number of wrong predictions differs is not visible here - confirm.
nlines = 5
ncols = 25
plt.figure(figsize=(25, 10), dpi=150)
plt.imshow(Utils.concat_images(wrong_predict_images_sorted, mnist_ds.image_size, nlines, ncols), cmap='gray_r')
plt.title(str(wrong_predicted_labels))
plt.show()
