Import the required packages

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

Make plots grayscale and display inline. Fix random seed.

In [None]:
%matplotlib inline
plt.rcParams['image.cmap'] = 'gray'
np.random.seed(seed=485)

Create a custom plotting function for matrix data. pyplot's built-in `matshow` creates a new figure by default, which gets in the way of tiling multiple images as subplots.

In [None]:
def plot_image(image):
    """Draw a 2-D image into the current axes, one cell per pixel.

    Draws with ``plt.imshow`` into the current axes, so it can be called after
    ``plt.subplot`` to tile several images in one figure.

    Args:
        image: 2-D array-like of shape (rows, cols), e.g. a CPU torch tensor
            or a NumPy array.

    Raises:
        ValueError: if ``image`` is not two-dimensional.
    """
    # np.asarray accepts both NumPy arrays and CPU torch tensors; the original
    # image.size()/.numpy() calls only worked for tensors.
    pixels = np.asarray(image)
    if pixels.ndim != 2:
        raise ValueError(
            "plot_image expects a 2-D image, got shape {}".format(pixels.shape))
    nr, nc = pixels.shape
    # Pixel centres land on integer coordinates; row 0 is drawn at the top.
    extent = [-0.5, nc - 0.5, nr - 0.5, -0.5]
    plt.imshow(pixels, extent=extent, origin='upper', interpolation='nearest')

Download MNIST data or load it if already downloaded

In [None]:
# Training split only; files are fetched into ./data if not already present.
mnist_train = torchvision.datasets.MNIST(
    root='data',
    train=True,
    download=True,
)

Check the data shape

In [None]:
# `train_data` / `train_labels` are deprecated torchvision aliases that emit a
# warning on access; `data` / `targets` are the current attribute names.
print("Data shape: {}".format(mnist_train.data.size()))
print("Labels shape: {}".format(mnist_train.targets.size()))

The dataset is stored in a fixed order (note: the MNIST training set is not sorted by label; its first example is a '5'). To get randomly ordered training examples, we index the dataset with a random permutation of all 60,000 indices.

In [None]:
# Random permutation of every training index (60,000 for MNIST) taken from the
# data itself rather than hard-coded, so each example appears exactly once.
indices = torch.randperm(mnist_train.data.size(0))
trainimages = mnist_train.data[indices]
trainlabels = mnist_train.targets[indices]

Plot the first image

In [None]:
plot_image(trainimages[0])
# trainimages was shuffled above, so this is NOT entry 0 of the raw dataset.
plt.title("First entry in shuffled MNIST training set (label {})".format(
    trainlabels[0].item()))
plt.show()

Mean of all images in training set

In [None]:
# Pixel-wise average over every training image (cast to float before averaging).
mean = trainimages.float().mean(dim=0)
plot_image(mean)
plt.title("Mean of entries in MNIST dataset")
plt.show()

Show pixels that are zero for all images in training set as black, pixels that are nonzero for at least one image as white. 

In [None]:
# True where at least one training image has a nonzero pixel value. Testing
# the raw images directly avoids materialising a float copy of the whole
# training set just to sum it.
nonzero = (trainimages != 0).any(dim=0)
plot_image(nonzero)
plt.title("Pixels nonzero in at least one training image (white)")
plt.show()


Define a function for displaying a stack of images in a grid.
`imgstack` is a sequence of m×n images: a list of 2-D tensors, or a 3-D tensor indexed along its first dimension.

In [None]:
def montage(imgstack):
    """Tile a stack of 2-D images in a near-square grid in a new figure.

    The grid has ``ceil(sqrt(n))`` columns and just enough rows to hold all
    ``n`` images; each image is drawn with ``plot_image`` and the figure is
    then shown.

    Args:
        imgstack: sequence of m×n images supporting ``len()`` and integer
            indexing, e.g. a list of 2-D tensors or a 3-D tensor whose first
            dimension indexes the images.
    """
    n = len(imgstack)
    ncols = int(np.ceil(np.sqrt(n)))
    # Rows must be ceil(n / ncols). The former floor(sqrt(n)) gave too few
    # cells for n = 3, 7, 13, ... and plt.subplot rejected the last indices.
    nrows = int(np.ceil(n / ncols)) if ncols else 0
    plt.figure()
    for i in range(n):
        plt.subplot(nrows, ncols, i + 1)
        plot_image(imgstack[i])
    plt.show()

Show the first 9 training images

In [None]:
# Nine images -> a 3x3 grid.
montage(trainimages[:9])

Display labels of the images above

In [None]:
# Labels of the nine images shown above, in the same order.
trainlabels[:9]

Compute mean of images in each digit class

In [None]:
# Per-class mean image: entry d is the average of all training images
# labelled d (digits 0-9). A boolean mask selects each class directly, which
# avoids the nonzero() index detour, keeps the image shape data-driven, and
# no longer overwrites the global shuffle permutation `indices`.
trainmeans = torch.stack([
    trainimages[trainlabels == digit].float().mean(dim=0)
    for digit in range(10)
])

In [None]:
# Iterating the (10, H, W) tensor yields one mean image per digit class.
montage(list(trainmeans))

Compute distribution over digit classes in training set

In [None]:
# Bin edges at -0.5, 0.5, ..., 9.5: one unit-wide bin centred on each digit,
# so each bar counts exactly one class.
plt.hist(trainlabels.numpy(), bins=np.arange(11) - 0.5)
plt.xlim([-0.5, 9.5])
plt.xticks(range(10))
plt.xlabel("Digit label")
plt.ylabel("Number of training images")
plt.title("Distribution of digit classes in MNIST training set")
plt.show()