# Load MNIST

In [1]:
import gzip
import pickle
import os
import sys
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

# MNIST dataset file: mnist.pkl.gz
# Contains X_train (50000), X_valid (10000) and X_test (10000); each image is a
# flattened 28x28 digit, i.e. 784 values.
# input_dim: 784
# output_dim: 10 (one per digit class)
PY2 = sys.version_info[0] == 2

if PY2:
    from urllib import urlretrieve

    def pickle_load(f, encoding):
        # Python 2's pickle.load takes no `encoding` argument, so it is ignored.
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        return pickle.load(f)
else:
    from urllib.request import urlretrieve

    def pickle_load(f, encoding):
        # `encoding` tells Python 3 how to decode Python 2 `str` objects stored in
        # the pickle (the loader below passes 'latin-1').
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        return pickle.load(f, encoding=encoding)

def _load_data(url, filename):
    """Load data from `url` and store the result in `filename`."""
    if not os.path.exists(filename):
        print("Downloading MNIST dataset")
        urlretrieve(url, filename)

    with gzip.open(filename, 'rb') as f:
        return pickle_load(f, encoding='latin-1')



def load_data(filename, url=None):
    """Read the pickled MNIST splits and return them together with their sizes.

    Parameters:
    filename: path of the gzipped MNIST pickle
    url: where to download the file from if `filename` does not exist yet

    Returns:
    A dict with the arrays X_train/y_train, X_valid/y_valid, X_test/y_test,
    the number of examples in each split, the per-example input dimension
    (X_train.shape[1]) and output_dim=10.
    """
    pickled = _load_data(url, filename)
    (X_train, y_train), (X_valid, y_valid), (X_test, y_test) = (
        pickled[0], pickled[1], pickled[2])

    return {
        'X_train': X_train,
        'y_train': y_train,
        'X_valid': X_valid,
        'y_valid': y_valid,
        'X_test': X_test,
        'y_test': y_test,
        'num_examples_train': X_train.shape[0],
        'num_examples_valid': X_valid.shape[0],
        'num_examples_test': X_test.shape[0],
        'input_dim': X_train.shape[1],
        'output_dim': 10,
    }

# Create the cluttered MNIST dataset

In [4]:
''' Parameters:
ORG_SHP: shape of a single MNIST digit image
OUT_SHP: shape of each output cluttered-MNIST image
NUM_DISTORTIONS: number of distortion patches pasted into each output image
dist_size: shape of each distortion patch
NUM_DISTORTIONS_DB: number of patches in the distortion bank
'''
ORG_SHP = [28, 28]
OUT_SHP = [100, 100]
NUM_DISTORTIONS = 6
dist_size = (9, 9)
NUM_DISTORTIONS_DB = 100000

mnist_data = load_data('data/mnist.pkl.gz')
# Derived from the constants so the file name cannot drift from the settings
# (defaults give "data/cluttered_mnist_100x100_6distortions").
outfile = "data/cluttered_mnist_{}x{}_{}distortions".format(
    OUT_SHP[0], OUT_SHP[1], NUM_DISTORTIONS)

np.random.seed(1234)

### create list with distortions
# Pool the training and validation digits; test digits are never used as clutter.
all_digits = np.concatenate([mnist_data['X_train'], mnist_data['X_valid']], axis=0)
all_digits = all_digits.reshape([-1] + ORG_SHP)  # (60000, 28, 28) for the standard split
num_digits = all_digits.shape[0]

# Bank of NUM_DISTORTIONS_DB random dist_size crops, each cut from a random digit.
distortions = []
for _ in range(NUM_DISTORTIONS_DB):
    rand_digit = np.random.randint(num_digits)
    # randint's upper bound is exclusive; +1 makes the last valid crop origin
    # (ORG_SHP - dist_size) reachable.
    rand_x = np.random.randint(ORG_SHP[1] - dist_size[1] + 1)
    rand_y = np.random.randint(ORG_SHP[0] - dist_size[0] + 1)

    digit = all_digits[rand_digit]
    distortion = digit[rand_y:rand_y + dist_size[0],
                       rand_x:rand_x + dist_size[1]]
    assert distortion.shape == dist_size
    distortions.append(distortion)
print("Created distortions")


Created distortions


In [3]:
def _random_jitter(offset):
    '''Random integer shift in [-lim, lim) with lim = int(2*offset/3); 0 if that range is empty.'''
    lim = int(2 * offset / 3)
    if lim <= 0:
        # np.random.choice raises on an empty range; with no margin there is no jitter.
        return 0
    return np.random.choice(range(-lim, lim))


def create_sample1(x, output_shp, num_distortions=None):
    '''Paste a horizontal row of digits onto a blank canvas and optionally add clutter.

    The row is centred on the canvas, then shifted by a random jitter of up to
    2/3 of the free margin on each axis; successive digits are additionally
    moved vertically by a random per-digit slope ("angle").

    Parameters:
    x (np.array or list): n digit images, each of shape (a, b), e.g. (n, 28, 28)
    output_shp: (height, width) of the output image
    num_distortions: number of clutter patches added via add_distortions;
        None (default) uses the notebook-level NUM_DISTORTIONS at call time,
        so this cell does not depend on the config cell having run first.

    Returns:
    np.ndarray of shape output_shp.

    Raises:
    ValueError: if the row of digits does not fit inside output_shp.
    '''
    if num_distortions is None:
        num_distortions = NUM_DISTORTIONS

    a, b = x[0].shape  # a = digit height, b = digit width
    out_h, out_w = output_shp[0], output_shp[1]
    row_w = len(x) * b
    if row_w > out_w or a > out_h:
        raise ValueError('{} digit(s) of shape {} do not fit in an output of shape {}'
                         .format(len(x), (a, b), tuple(output_shp)))

    # Centre the row: width uses the digit width b, height the digit height a.
    x_offset = (out_w - row_w) // 2
    y_offset = (out_h - a) // 2
    x_offset += _random_jitter(x_offset)  # set the x offset randomly
    y_offset += _random_jitter(y_offset)  # set the y offset randomly
    half = int(b * 0.5)
    angle = np.random.choice(range(-half, half)) if half > 0 else 0  # random slope

    output = np.zeros(output_shp)
    for i, digit in enumerate(x):
        x_start = i * b + x_offset
        x_end = x_start + b
        y_start = y_offset + np.floor(i * angle)
        y_end = y_start + a
        # Slide the digit back inside the canvas if the slope pushed it out.
        if y_end > out_h:
            m = out_h - y_end
            y_end += m
            y_start += m
        if y_start < 0:
            m = y_start
            y_end -= m
            y_start -= m
        y_start, y_end = int(y_start), int(y_end)

        output[y_start:y_end, x_start:x_end] = digit

    if num_distortions > 0:
        output = add_distortions(output, num_distortions)  # add clutter patches
    return output

def sample_digits(n, x, y, out_shp=None):
    '''Return sample `n` of `x` and its label, each wrapped in a one-element list.

    Parameters:
    n: index of the sample to take (an index, not a count)
    x: array of flattened images, one per row
    y: labels aligned with x
    out_shp: shape to give the image, e.g. [28, 28]; when None the image is
        reshaped to x.shape[1], i.e. kept flat

    Returns:
    ([image], [label])
    '''
    target_shape = out_shp if out_shp is not None else x.shape[1]
    picked_image = x[n].reshape(target_shape)
    picked_label = y[n]
    return [picked_image], [picked_label]


def add_distortions(digits, num_distortions, distortion_bank=None):
    '''Paste `num_distortions` random patches onto a copy of `digits`.

    Patches are written onto a blank canvas of the same shape (later patches
    overwrite earlier ones where they overlap), the digits are added on top,
    and the result is clipped to [0, 1].

    Parameters:
    digits: 2-D image to clutter; its own shape bounds the patch positions
    num_distortions: number of patches to paste
    distortion_bank: list of 2-D patches to draw from; None (default) uses the
        notebook-level `distortions` list

    Returns:
    np.ndarray with the same shape as `digits`.
    '''
    if distortion_bank is None:
        distortion_bank = distortions
    canvas = np.zeros_like(digits)
    out_h, out_w = digits.shape
    for _ in range(num_distortions):
        patch = distortion_bank[np.random.randint(len(distortion_bank))]
        patch_h, patch_w = patch.shape
        # randint's upper bound is exclusive; +1 lets a patch reach the last row/column.
        left = np.random.randint(out_w - patch_w + 1)
        top = np.random.randint(out_h - patch_h + 1)
        canvas[top:top + patch_h, left:left + patch_w] = patch
    canvas += digits

    return np.clip(canvas, 0, 1)


def create_dataset(n, X, labels, org_shp, out_shp):
    '''Build a cluttered-MNIST split from the first `n` digits of `X`.

    Sample i is X[i] (reshaped to org_shp) placed on an out_shp canvas by
    create_sample1 (random position plus clutter); its label is labels[i].
    Progress is printed every 1000 samples.

    Parameters:
    n: number of samples to generate (indices 0..n-1 of X are used)
    X: flattened digit images, one per row
    labels: labels aligned with X
    org_shp: shape of a single digit, e.g. [28, 28]
    out_shp: shape of each output image, e.g. [100, 100]

    Returns:
    (images, labels): float32 array of shape (n, prod(out_shp)) and int32
    array of shape (n, 1).
    '''
    # Allocate float32 up front: a float64 buffer for 50000 images of 100x100 is
    # ~4 GB, and the previous final astype() briefly needed both copies.
    # Assigning float64 rows into it rounds exactly like astype('float32').
    out_X = np.zeros((n, int(np.prod(out_shp))), dtype='float32')
    # Preallocated labels also make n == 0 work (np.vstack([]) raised).
    out_lab = np.zeros((n, 1), dtype='int32')
    for i in range(n):
        if (i + 1) % 1000 == 0:
            print(i)
        x_, y_ = sample_digits(i, X, labels, org_shp)
        out_X[i] = create_sample1(x_, out_shp).reshape(-1)
        out_lab[i] = y_

    return out_X, out_lab

In [5]:
## create train, valid, and test sets
N_TRAIN = 50000
N_VALID = 10000
N_TEST = 10000

# Splits are generated in a fixed order (train, valid, test): every call draws
# from the global NumPy RNG, so this order determines the exact images produced.
split_specs = [
    ('X_train', 'y_train', N_TRAIN),
    ('X_valid', 'y_valid', N_VALID),
    ('X_test', 'y_test', N_TEST),
]
generated = {}
for images_key, labels_key, split_size in split_specs:
    split_images, split_labels = create_dataset(
        split_size, mnist_data[images_key], mnist_data[labels_key],
        ORG_SHP, OUT_SHP)
    generated[images_key] = split_images
    generated[labels_key] = split_labels

X_train, y_train = generated['X_train'], generated['y_train']
X_valid, y_valid = generated['X_valid'], generated['y_valid']
X_test, y_test = generated['X_test'], generated['y_test']

np.savez_compressed(outfile, **generated)

999
1999
2999
3999
4999
5999
6999
7999
8999
9999
10999
11999
12999
13999
14999
15999
16999
17999
18999
19999
20999
21999
22999
23999
24999
25999
26999
27999
28999
29999
30999
31999
32999
33999
34999
35999
36999
37999
38999
39999
40999
41999
42999
43999
44999
45999
46999
47999
48999
49999
999
1999
2999
3999
4999
5999
6999
7999
8999
9999
999
1999
2999
3999
4999
5999
6999
7999
8999
9999


# Convert the datasets to .JPG images

In [7]:
'''for a more intuitive result, convert the dataset to .JPG format'''
import numpy as np
import os
import cv2
path = 'data/cluttered_mnist_100x100_6distortions.npz'
img_dim = 100
# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context manager closes it once the needed arrays have been read into memory.
with np.load(path) as data:
    X_train, y_train = data['X_train'], data['y_train']
    X_test, y_test = data['X_test'], data['y_test']
X_train = X_train.reshape(-1, img_dim, img_dim)
X_test = X_test.reshape(-1, img_dim, img_dim)

digit_size = 28
mnist_train_X, mnist_train_y = mnist_data['X_train'], mnist_data['y_train']
mnist_test_X, mnist_test_y = mnist_data['X_test'], mnist_data['y_test']
mnist_train_X = mnist_train_X.reshape(-1, digit_size, digit_size)
mnist_test_X = mnist_test_X.reshape(-1, digit_size, digit_size)

In [14]:
def JPG_img(path, X, y):
    '''Write every image of X to <path>/<label>/<index>.jpg.

    Parameters:
    path: root output directory; one sub-folder per label is created on demand
    X: iterable of 2-D images; values are scaled by 255 and cast to uint8,
        so they are expected to lie in [0, 1]
    y: labels aligned with X; each may be a scalar of any integer type or a
        length-1 array (create_dataset stores labels with shape (n, 1))

    Raises:
    IOError: if OpenCV reports that an image could not be written.
    '''
    created_dirs = set()
    for i, (img_data, label) in enumerate(zip(X, y)):
        # Normalise scalar and shape-(1,) labels alike; the old np.int64 type check
        # failed for int32 or Python-int labels.
        label_value = np.asarray(label).ravel()[0]
        path_folder = os.path.join(path, str(label_value))
        if path_folder not in created_dirs:
            if not os.path.exists(path_folder):
                os.makedirs(path_folder)
            created_dirs.add(path_folder)
        img = np.uint8(img_data * 255)
        out_name = os.path.join(path_folder, '{}.jpg'.format(i))
        # cv2.imwrite signals failure through its boolean return value.
        if not cv2.imwrite(out_name, img):
            raise IOError('cv2.imwrite failed for ' + out_name)

In [15]:
# JPG_img only writes files and returns None, so its results are not kept.
JPG_img('data/MNIST/train', mnist_train_X, mnist_train_y)
JPG_img('data/MNIST/test', mnist_test_X, mnist_test_y)
print('mnist dataset convert to JPG')
JPG_img('data/cluttered_mnist/train', X_train, y_train)
JPG_img('data/cluttered_mnist/test', X_test, y_test)
print('cluttered mnist dataset convert to JPG')

mnist dataset convert to JPG
cluttered mnist dataset convert to JPG
