In [1]:
def load_mnist(path, kind='train'):
    """Load an MNIST-format (IDX) image/label pair from the directory `path`.

    Reads ``<kind>-labels-idx1-ubyte.gz`` and ``<kind>-images-idx3-ubyte.gz``
    (e.g. ``kind='train'`` or ``kind='t10k'``) and validates their headers.

    Parameters
    ----------
    path : str or os.PathLike
        Directory containing the gzipped IDX files.
    kind : str, default 'train'
        File-name prefix selecting the split.

    Returns
    -------
    images : numpy.ndarray, shape (n, rows * cols), dtype uint8
        One flattened image per row (784 columns for 28x28 MNIST).
        Read-only, since it is a view over the decompressed bytes.
    labels : numpy.ndarray, shape (n,), dtype uint8
        Read-only view over the decompressed label bytes.

    Raises
    ------
    ValueError
        If a file is too short, has the wrong IDX magic number, or its
        header counts disagree with the payload or with each other.
    """
    import os
    import gzip
    import struct
    import numpy as np

    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        label_bytes = lbpath.read()
    if len(label_bytes) < 8:
        raise ValueError('%s: too short for an IDX1 header' % labels_path)
    # IDX1 header: big-endian uint32 magic number (2049), then item count.
    magic, n_labels = struct.unpack('>II', label_bytes[:8])
    if magic != 2049:
        raise ValueError('%s: bad IDX1 magic number %d (expected 2049)'
                         % (labels_path, magic))
    labels = np.frombuffer(label_bytes, dtype=np.uint8, offset=8)
    if len(labels) != n_labels:
        raise ValueError('%s: header declares %d labels but file holds %d'
                         % (labels_path, n_labels, len(labels)))

    with gzip.open(images_path, 'rb') as imgpath:
        image_bytes = imgpath.read()
    if len(image_bytes) < 16:
        raise ValueError('%s: too short for an IDX3 header' % images_path)
    # IDX3 header: magic (2051), image count, rows, cols -- big-endian uint32s.
    magic, n_images, n_rows, n_cols = struct.unpack('>IIII', image_bytes[:16])
    if magic != 2051:
        raise ValueError('%s: bad IDX3 magic number %d (expected 2051)'
                         % (images_path, magic))
    if n_images != n_labels:
        raise ValueError('%s holds %d images but %s holds %d labels'
                         % (images_path, n_images, labels_path, n_labels))
    # Flatten each image into one row; reshape raises ValueError if the
    # pixel payload is truncated.
    images = np.frombuffer(image_bytes, dtype=np.uint8,
                           offset=16).reshape(n_images, n_rows * n_cols)

    return images, labels


In [2]:
import os

# Directory holding the gzipped MNIST IDX files. Set the MNIST_DATA_DIR
# environment variable to point elsewhere instead of editing this
# machine-specific absolute path.
filePath = os.environ.get(
    'MNIST_DATA_DIR',
    'C:/Users/dvjr2/Google Drive/Documents/Syracuse/IST_718_BigDataAnalytics/Labs/Lab_003',
)

In [3]:
# Training split: `test` is an (images, labels) tuple.
test = load_mnist(path=filePath, kind='train')

In [4]:
test.__class__

tuple

In [5]:
# Last element of the (images, labels) pair: the label array.
test[-1]

array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

In [8]:
test[0][:5]

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)

In [11]:
test[0].shape[0]

60000

In [19]:
import matplotlib.pyplot as plt
import numpy as np
import os
import subprocess
from datetime import datetime
from threading import Thread, Event

def get_sprite_image(to_visualise, do_invert=True):
    """Turn flattened MNIST vectors into one tiled sprite image.

    The vectors are reshaped to 28x28 matrices, optionally inverted
    (``255 - v``), and tiled with ``create_sprite_image``.
    """
    digit_matrices = vector_to_matrix_mnist(to_visualise)
    prepared = invert_grayscale(digit_matrices) if do_invert else digit_matrices
    return create_sprite_image(prepared)

def vector_to_matrix_mnist(mnist_digits):
    """Reshape flattened MNIST digits (batch, 28*28) into matrices (batch, 28, 28).

    The batch dimension is inferred (``-1``), so any input whose total size
    is a multiple of 784 is accepted; otherwise numpy raises ValueError.
    """
    import numpy as np

    return np.reshape(mnist_digits, (-1, 28, 28))

def invert_grayscale(mnist_digits):
    """Flip 8-bit grayscale intensities: each value ``v`` becomes ``255 - v``."""
    white_level = 255
    inverted = white_level - mnist_digits
    return inverted

def create_sprite_image(images):
    """Tile a stack of images into one square sprite sheet.

    Parameters
    ----------
    images : array-like, shape (count, height, width)
        Any array-like (ndarray, list, tuple of 2-D arrays).

    Returns
    -------
    numpy.ndarray, shape (height * k, width * k), dtype float64
        ``k = ceil(sqrt(count))``. Images are laid out row-major; grid cells
        left over when ``count`` is not a perfect square keep the value 1.0.
    """
    import numpy as np

    # asarray accepts any array-like (the original only converted lists).
    images = np.asarray(images)
    n_images = images.shape[0]
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(n_images)))

    spriteimage = np.ones((img_h * n_plots, img_w * n_plots))

    for idx in range(n_images):
        row, col = divmod(idx, n_plots)
        spriteimage[row * img_h:(row + 1) * img_h,
                    col * img_w:(col + 1) * img_w] = images[idx]

    return spriteimage


X, Y = load_mnist(path=filePath, kind='t10k')

# Fashion-MNIST (Zalando) class names, indexed by label value 0-9.
# NOTE(review): the training labels displayed earlier (5, 0, 4, ..., 5, 6, 8)
# look like the handwritten-digit MNIST set, not Fashion-MNIST. If so, these
# names are wrong for this data and should be the digits '0'-'9' -- confirm.
labels = ['t_shirt_top', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots']
Y_str = np.asarray(labels)[Y]

# X is uint8 pixel data, so write integers: '%.6e' wrote every value as a
# 12-character float (e.g. '0.000000e+00') with no added precision.
np.savetxt('Xtest.txt', X, fmt='%d', delimiter='\t')
np.savetxt('Ytest.txt', Y_str, fmt='%s')

plt.imsave('zalando-mnist-sprite.png', get_sprite_image(X), cmap='gray')