# Utilities

In [None]:
%matplotlib inline

In [None]:
import functools
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

In [None]:
# Constants
# Model name presumably used to address the served model (TensorFlow Serving
# model spec name) — confirm against the serving client code.
TENSOR_SERVER_MODEL_SPEC_NAME = "mnist-model"

# Parent directory of the versioned server exports. Each model's export
# directory is built from this path plus its numeric version, so a directory
# and its version constant cannot drift apart.
TENSOR_SERVER_MODEL_BASE_DIR = "tensorflow-server-models/mnist-model"

# Linear model: checkpoint prefix and server export (version 1).
LINEAR_MODEL_CHECKPOINT_PREFIX = "tensorflow-checkpoints/mnist-model/linear-model"
LINEAR_MODEL_TENSOR_SERVER_VERSION = 1
LINEAR_MODEL_TENSOR_SERVER_DIR = "{}/{}".format(
    TENSOR_SERVER_MODEL_BASE_DIR, LINEAR_MODEL_TENSOR_SERVER_VERSION)

# Neural-network model: checkpoint prefix and server export (version 2).
NEURAL_NETWORK_MODEL_CHECKPOINT_PREFIX = "tensorflow-checkpoints/mnist-model/neural-network-model"
NEURAL_NETWORK_MODEL_TENSOR_SERVER_VERSION = 2
NEURAL_NETWORK_MODEL_TENSOR_SERVER_DIR = "{}/{}".format(
    TENSOR_SERVER_MODEL_BASE_DIR, NEURAL_NETWORK_MODEL_TENSOR_SERVER_VERSION)

In [None]:
def display_mnist_images(images, image_shape=(28, 28)):
    """Show a row of flattened grayscale images side by side.

    Args:
        images: array whose first axis indexes images, one flattened image per
            row. Pixels are multiplied by 255 and cast to uint8, so values are
            assumed to lie in [0, 1] — TODO confirm for non-MNIST inputs.
        image_shape: ``(height, width)`` each row is reshaped to before
            display; defaults to MNIST's 28x28.
    """
    image_count = images.shape[0]
    if image_count == 0:
        # plt.subplots(1, 0) raises, and there is nothing to draw anyway.
        return
    # squeeze=False keeps `axes` a 2-D array even for a single image; without
    # it, subplots(1, 1) returns a bare Axes, which has no `.flat`.
    fig, axes = plt.subplots(1, image_count, figsize=(25, 3), squeeze=False)
    for ax, flat_image in zip(axes.flat, images):
        pixels = (flat_image.reshape(image_shape) * 255).astype(np.uint8)
        # Style is passed per call rather than written to global rcParams, so
        # other plots in the session are not affected.
        ax.imshow(pixels, cmap="gray", interpolation="nearest")
    plt.show()

In [None]:
# Load MNIST at most once per kernel: the result is memoized.
@functools.lru_cache(maxsize=128)
def get_mnist_data():
    """Return the MNIST data sets read from ``mnist_data/`` with one-hot labels.

    Memoized with ``lru_cache``: the first call reads the data, and every
    later call in the same kernel returns that same object.
    """
    datasets = input_data.read_data_sets("mnist_data/", one_hot=True)
    return datasets

In [None]:
def sample_images(n, images):
    """Draw ``n`` random rows from ``images``, with replacement.

    Args:
        n: number of rows to draw.
        images: array whose first axis indexes images (e.g. one flattened
            image per row).

    Returns:
        Array of shape ``(n,) + images.shape[1:]``. The same row may appear
        more than once.

    Raises:
        ValueError: if ``images`` has no rows to sample from.
    """
    row_count = images.shape[0]
    if row_count == 0:
        raise ValueError("cannot sample from an empty image array")
    # randint's upper bound is exclusive: pass row_count (not row_count - 1)
    # so the last row can be drawn and a single-row array works.
    row_idxs = np.random.randint(0, row_count, n)
    return images[row_idxs]

In [None]:
def sample_and_predict_one_model(mnist_images, model, sample_count=10):
    """Draw random images, display them, and return ``model``'s predictions.

    Args:
        mnist_images: array of flattened images to sample from.
        model: any object exposing ``predict(samples)``.
        sample_count: how many images to draw.

    Returns:
        Whatever ``model.predict`` returns for the drawn images.
    """
    drawn = sample_images(sample_count, mnist_images)
    display_mnist_images(drawn)
    predictions = model.predict(drawn)
    return predictions

In [None]:
def sample_and_predict_two_models(mnist_images, model1, model2, sample_count=10):
    """Compare two models on the same randomly drawn, displayed images.

    Args:
        mnist_images: array of flattened images to sample from.
        model1: first object exposing ``predict(samples)``.
        model2: second object exposing ``predict(samples)``.
        sample_count: how many images to draw.

    Returns:
        Tuple ``(model1 predictions, model2 predictions)`` for the same images.
    """
    drawn = sample_images(sample_count, mnist_images)
    display_mnist_images(drawn)
    # model1 is queried before model2, and both see the identical sample.
    return tuple(m.predict(drawn) for m in (model1, model2))

In [None]:
# Load MNIST for the rest of the notebook. get_mnist_data is memoized, so any
# later call to it in this kernel returns this same object.
mnist_data = get_mnist_data()