In [None]:
#!/usr/bin/env python
# coding: utf-8

from __future__ import division

import os
import sys

import numpy as np

import import_ipynb
from data_generators import mnist_generator

# %matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Seed NumPy's global random number generator before anything else runs, so that
# values drawn from np.random later on (presumably including the noise images made
# by mnist_generator — TODO confirm) repeat on every Restart & Run All.
seed = 0
np.random.seed(seed=seed)

In [None]:
# Locate MNIST files

# Download the MNIST data from http://yann.lecun.com/exdb/mnist/ and point dir_mnist
# at the download location. Either edit the default below, or set the MNIST_DIR
# environment variable so the notebook itself does not need a machine-specific path.
dir_mnist = os.environ.get('MNIST_DIR', './mnist')

In [None]:
## Get all the images
# Every request below passes p_null_class=100, so presumably every returned sample is
# a null (noise) image rather than a digit — TODO confirm against mnist_generator.

batch_size = 10  # images per figure row; must be a multiple of len(tile_sizes)
tile_sizes = [1, 2, 4, 7, 14]  # tile_size values used for the shuffled-noise examples

if batch_size % len(tile_sizes) != 0:
    raise ValueError('batch_size (%d) must be a multiple of %d'
                     % (batch_size, len(tile_sizes)))
# The shuffled row is split evenly between the tile sizes.
batch_size_shuff = batch_size // len(tile_sizes)


def _first_batch(null_types, n_images, **extra_kwargs):
    """Return only the images from the first batch of a test-set mnist_generator.

    The generator is built with random_order=False and p_null_class=100. Any extra
    keyword arguments (e.g. tile_size) are forwarded to mnist_generator unchanged. The
    second and third values yielded by the generator (labels and an unused third
    output) are discarded.
    """
    gen = mnist_generator(dir_mnist, batch_size=n_images, dataset='test',
                          random_order=False, null_types=null_types,
                          p_null_class=100, **extra_kwargs)
    images, _, _ = next(gen)
    return images


# The calls happen in the same order as before (uniform, mixed, then one shuffled
# batch per tile size), so any RNG draws inside the generator occur in the same order.
im_uniform = _first_batch('u', batch_size)   # uniform noise
im_mixed = _first_batch('m', batch_size)     # mixed noise
im_shuff = [_first_batch('s', batch_size_shuff, tile_size=tile_size)
            for tile_size in tile_sizes]     # shuffled noise, one batch per tile size

In [None]:
## Plot images
# Row 1: uniform noise. Row 2: mixed noise. Row 3: shuffled noise, batch_size_shuff
# images per tile size, laid out left to right in the order the batches were generated.


def _show_image(ax, image):
    """Draw one image, reshaped to 28x28, in grayscale on ax with both axes hidden."""
    ax.imshow(np.reshape(image, (28, 28)), cmap='gray', aspect='equal')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


rows = [
    list(im_uniform[:batch_size]),
    list(im_mixed[:batch_size]),
    # Flatten the per-tile-size batches; the tile count is taken from im_shuff itself.
    [im for tile_batch in im_shuff for im in tile_batch[:batch_size_shuff]],
]

# squeeze=False keeps axes 2-D even when batch_size == 1.
fig, axes = plt.subplots(len(rows), batch_size, squeeze=False,
                         figsize=(1.5*batch_size, 1.5*len(rows)))
for row_axes, row_images in zip(axes, rows):
    for ax, image in zip(row_axes, row_images):
        _show_image(ax, image)
    # Hide panels with nothing to show, e.g. if a row has fewer than batch_size images.
    for ax in row_axes[len(row_images):]:
        ax.set_visible(False)

# savefig raises if the output directory is missing, so create it first.
fig_path = 'figures/fig_null_image_examples.png'
fig_dir = os.path.dirname(fig_path)
if fig_dir and not os.path.isdir(fig_dir):
    os.makedirs(fig_dir)
fig.savefig(fig_path, bbox_inches='tight')