# Online Local Adaptive Model - Notebook 1

* Prior probability shift is one of the common problems encountered in machine learning algorithms.   
* There are some approaches for dealing with this problem in a 'static' scenario. However, there are situations in which we need a model that deals with sequential data as input (e.g. a server which receives input from different users, each with a different data distribution).   
* In this project, we try to build a model which self-adapts its predictions based on the local label distribution. 

### About notebook 1

In this notebook we implement the standard version of the LeNet-5 architecture and test it on the entire MNIST dataset (which has a uniform label distribution).

#### LeNet5 model used (with 28x28 inputs):
![title](https://cdnpythonmachinelearning.azureedge.net/wp-content/uploads/2017/09/lenet-5.png?x64257)

### Notebook setup

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%matplotlib inline
# %matplotlib qt
%load_ext autoreload
%autoreload 2

### Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import time
from collections import deque
import os
import pickle
from training_plotter import TrainingPlotter
from dataset import MNISTDataset
import utils
from lenet5 import Lenet5
from lenet5_with_distr import Lenet5WithDistr

# numpy print options: wide lines, more edge items per axis, 3 decimal digits
np.set_printoptions(linewidth=150, edgeitems=10, precision=3)

### Set seed

In [None]:
# A constant seed feeds a dedicated RandomState so that draws made through
# `nprg` can be reproduced from run to run.
seed = 112358
nprg = np.random.RandomState(seed=seed)

### Import MNIST dataset

In [None]:
# Relative paths of the four MNIST idx files (train/test x images/labels).
MNIST_TRAIN_IMAGES_FILEPATH = 'MNIST_dataset/train-images.idx3-ubyte'
MNIST_TRAIN_LABELS_FILEPATH = 'MNIST_dataset/train-labels.idx1-ubyte'
MNIST_TEST_IMAGES_FILEPATH = 'MNIST_dataset/t10k-images.idx3-ubyte'
MNIST_TEST_LABELS_FILEPATH = 'MNIST_dataset/t10k-labels.idx1-ubyte'

# Dataset wrapper built from those files; later cells read its `train`,
# `validation` and `test` attributes.
mnist_ds = MNISTDataset(
    MNIST_TRAIN_IMAGES_FILEPATH,
    MNIST_TRAIN_LABELS_FILEPATH,
    MNIST_TEST_IMAGES_FILEPATH,
    MNIST_TEST_LABELS_FILEPATH,
)


### Data augmentation

In [None]:
# Augmentation step: random rotations.
# NOTE(review): the exact meaning of `ratio` is defined in MNISTDataset — confirm there.
mnist_ds.enhance_with_random_rotate(ratio=2)

In [None]:
# Augmentation step: random zoom-in.
# NOTE(review): the exact meaning of `ratio` is defined in MNISTDataset — confirm there.
mnist_ds.enhance_with_random_zoomin(ratio=2)

In [None]:
# Augmentation step: random zoom-in combined with rotation.
# NOTE(review): the exact meaning of `ratio` is defined in MNISTDataset — confirm there.
mnist_ds.enhance_with_random_zoomin_and_rotate(ratio=2)

### Analyze dataset

In [None]:
# Print the dataset's `summary` attribute.
print(mnist_ds.summary)

In [None]:
def plot_images_sample(images, image_size, labels, nlines, ncols):
    """Plot a random grid of images, titled with their digit labels.

    `nlines * ncols` distinct images are drawn at random (via the global
    `np.random` state), tiled with `utils.concat_images` and shown with an
    inverted gray colormap. The title is the grid of labels in the same layout.

    Args:
        images: collection of images, indexed here with an integer index
            array (presumably a numpy array with one image per entry —
            TODO confirm against the dataset class).
        image_size: image size forwarded to `utils.concat_images`.
        labels: labels aligned with `images`; the digit is recovered with
            `argmax` along axis 1, so each row holds per-class scores.
        nlines: number of rows in the grid.
        ncols: number of columns in the grid.

    Raises:
        ValueError: raised by `np.random.choice` when `nlines * ncols` is
            larger than the number of images (sampling without replacement).
    """
    plt.figure(figsize=(8, 3), dpi=150)
    # Size the sampling population from the `images` argument itself. The
    # function previously read it from a global `dataset` variable, which
    # only existed when the caller happened to define one with that name.
    sample_indices = np.random.choice(a=len(images), size=nlines * ncols, replace=False)
    plt.imshow(utils.concat_images(images[sample_indices], image_size, nlines, ncols), cmap='gray_r')
    plt.title(str(np.argmax(labels[sample_indices], axis=1).reshape(nlines, ncols)), fontsize=8)
    plt.show()

In [None]:
# show a random 5 x 20 grid of images for each split: train, validation, test
for dataset in (mnist_ds.train, mnist_ds.validation, mnist_ds.test):
    plot_images_sample(
        dataset.images,
        mnist_ds.image_size,
        dataset.labels,
        nlines=5,
        ncols=20,
    )

In [None]:
# plot label distribution of each train, validation and test set
plt.figure(figsize=(30, 3))

# (subplot position, split, title) for the three side-by-side histograms
label_hist_specs = [
    (131, mnist_ds.train, 'Train set distribution'),
    (132, mnist_ds.validation, 'Validation set distribution'),
    (133, mnist_ds.test, 'Test set distribution'),
]

for subplot_pos, split, title in label_hist_specs:
    plt.subplot(subplot_pos)
    # Raw counts are `hist`'s default, so the `normed` keyword (removed from
    # current matplotlib) is not passed. bins=0..10 with align='left' centres
    # each bar on its digit.
    plt.hist(np.argmax(split.labels, axis=1), bins=np.arange(11), align='left', rwidth=0.8)
    plt.xticks(range(0, 10))
    plt.xlabel('digit')
    plt.ylabel('frequency')
    plt.title(title)

plt.show()

### Train a model

In [None]:
# Lenet5 instance for this run: 40 epochs, batches of 128, learning rate 1e-3,
# dropout keep probability 0.5.
lenet5_model = Lenet5(
    mnist_ds,
    "zoomin_and_rotate_x2_allDigits_dropoutAfterF5F6",
    epochs=40,
    batch_size=128,
    variable_mean=0,
    variable_stddev=0.1,
    learning_rate=0.001,
    drop_out_keep_prob=0.5,
)

In [None]:
# Train the model configured in the previous cell.
lenet5_model.train()

### Analyze results

#### 1. Test a previously trained model on all test examples

In [None]:
# Evaluate a saved checkpoint on the whole test set through the test_data method.
temp = Lenet5(mnist_ds, "temp")
# Other checkpoints that can be restored instead:
# temp.restore_session(ckpt_dir='./results/', ckpt_filename='Lenet5_allDigits_dropoutAfterF5F6_2018_02_10_02_59.model.ckpt')
# temp.restore_session(ckpt_dir='./results/', ckpt_filename='Lenet5_rotated_x2_allDigits_dropoutAfterF5F6_2018_02_10_23_50.model.ckpt')
temp.restore_session(
    ckpt_dir='./results/',
    ckpt_filename='Lenet5_zoomin_and_rotate_x2_allDigits_dropoutAfterF5F6_2018_03_18---15_46.model.ckpt',
)
(test_loss, test_acc, total_predict, total_actual,
 wrong_predict_images, _) = temp.test_data(mnist_ds.test)
# Correct count = all test examples minus the returned misclassified images.
print('test_loss = {:.3f}, test_acc = {:.3f} ({}/{})'.format(
    test_loss,
    test_acc,
    mnist_ds.test.num_examples - len(wrong_predict_images),
    mnist_ds.test.num_examples,
))

In [None]:
# Order the misclassified images by their target label and write them to a file.
mismatch = total_actual != total_predict  # True where the prediction is wrong
wrong_predict = total_predict[mismatch]
wrong_actual = total_actual[mismatch]
wrong_predict_images = np.array(wrong_predict_images)
# Reorder along the first axis by ascending target label, then unpack into a
# plain list of images for combine_images.
wrong_predict_images_sorted = list(wrong_predict_images[wrong_actual.argsort(), ])
from training_plotter import TrainingPlotter
TrainingPlotter.combine_images(wrong_predict_images_sorted, "wrong_predicted_after_restore_session.png")

In [None]:
# plot target and predicted label distributions of wrong predicted examples
plt.figure(figsize=(20, 3))

# (subplot position, digit array, title) for the two side-by-side histograms
wrong_hist_specs = [
    (121, wrong_predict, 'predicted label distribution corresponding to wrong predictions'),
    (122, wrong_actual, 'target label distribution corresponding to wrong predictions'),
]

for subplot_pos, digits, title in wrong_hist_specs:
    plt.subplot(subplot_pos)
    # Raw counts are `hist`'s default, so the `normed` keyword (removed from
    # current matplotlib) is not passed.
    plt.hist(digits, bins=np.arange(11), align='left', rwidth=0.8)
    plt.xticks(range(0, 10))
    plt.xlabel('digit')
    plt.ylabel('frequency')
    plt.title(title)

plt.show()

#### 2. Test the model on a sample of images

In [None]:
# use the complete test set as the sample
test_sample_images = mnist_ds.test.images
test_sample_labels = mnist_ds.test.labels
test_sample_size = mnist_ds.test.num_examples

In [None]:
# alternatively, draw a sample whose labels follow a chosen distribution
print('Counts per class:{}'.format(mnist_ds.test.counts_per_class))
test_sample_size = 3000
# one relative weight per digit 0..9, normalised so the weights sum to 1
test_sample_weights = np.array([1, 4, 2, 7, 4, 12, 44, 33, 22, 11])
test_sample_weights = test_sample_weights / test_sample_weights.sum()

plt.bar(range(10), test_sample_weights)
plt.xticks(range(10))
plt.title('sample distribution')
plt.show()

test_sample_images, test_sample_labels = mnist_ds.test.next_batch(
    test_sample_size,
    weights=test_sample_weights,
)

In [None]:
# Evaluate the sample through the predict_images method of a restored model.
temp_model = Lenet5(mnist_ds, "temp")
# temp_model.restore_session(ckpt_dir='./results/', ckpt_filename='Lenet5_allDigits_dropoutAfterF5F6_2018_02_10_02_59.model.ckpt')
temp_model.restore_session(
    ckpt_dir='./results/',
    ckpt_filename='Lenet5_zoomin_and_rotate_x2_allDigits_dropoutAfterF5F6_2018_03_18---15_46.model.ckpt',
)
preds = temp_model.predict_images(test_sample_images)
# Both predictions and labels are reduced to a digit index with argmax over axis 1.
predicted_labels = np.argmax(preds, axis=1)
target_labels = np.argmax(test_sample_labels, axis=1)
count_correct_predicted = np.sum(target_labels == predicted_labels)
print('Accuracy: {:.3f} ({}/{})'.format(
    count_correct_predicted / test_sample_size,
    count_correct_predicted,
    test_sample_size,
))

In [None]:
# sort wrong_predict_images by target label and plot them
wrong_mask = target_labels != predicted_labels  # True where the prediction is wrong
wrong_predict = predicted_labels[wrong_mask]
wrong_actual = target_labels[wrong_mask]
wrong_predicted_images = test_sample_images[wrong_mask]

# A single sort order, applied to the images and both label arrays, keeps the
# three aligned with each other.
order_by_actual = wrong_actual.argsort()
wrong_predict_images_sorted = wrong_predicted_images[order_by_actual]
wrong_actual_sorted = wrong_actual[order_by_actual]
wrong_predict_sorted = wrong_predict[order_by_actual]

plt.figure(figsize=(15, 3), dpi=120)
plt.imshow(utils.concat_images(wrong_predict_images_sorted, mnist_ds.image_size, num_images_on_x=5, num_images_on_y=20), cmap='gray_r')
plt.title("Actual: {}\nPredicted: {}".format(wrong_actual_sorted, wrong_predict_sorted), fontsize=8)
plt.show()

# plot target and predicted label distributions of wrong predicted examples
plt.figure(figsize=(20, 3))

# (subplot position, digit array, title) for the two side-by-side histograms
wrong_sample_hist_specs = [
    (121, wrong_predict, 'predicted label distribution corresponding to wrong predictions'),
    (122, wrong_actual, 'target label distribution corresponding to wrong predictions'),
]

for subplot_pos, digits, title in wrong_sample_hist_specs:
    plt.subplot(subplot_pos)
    # `density=True` is the supported spelling of the removed `normed=True`:
    # bar heights are normalised so the histogram integrates to 1.
    plt.hist(digits, bins=np.arange(11), align='left', rwidth=0.8, density=True)
    plt.xticks(range(0, 10))
    plt.xlabel('digit')
    plt.ylabel('frequency')
    plt.title(title)

plt.show()