In [1]:
from keras.preprocessing.image import ImageDataGenerator
from mnist_helpers import *
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import cv2
import time

Using Theano backend.


In [2]:
mnist = input_data.read_data_sets('MNIST_data/', one_hot=False)
datagen = ImageDataGenerator(rotation_range=15, width_shift_range=0.1, 
    height_shift_range=0.1, zoom_range=0.1)

train_x = np.reshape(mnist.train.images, [-1, 28, 28, 1])
train_y = mnist.train.labels
valid_x = np.reshape(mnist.validation.images, [-1, 28, 28, 1])
valid_y = mnist.validation.labels

aug_x = np.concatenate([train_x, valid_x], axis=0)
aug_y = np.concatenate([train_y, valid_y], axis=0)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
print(aug_x.shape)
print(aug_y.shape)

(60000, 28, 28, 1)
(60000,)


In [4]:
def elastic_transform_wrapper_no_channels(tensor, kernel_dim=13, sigma=6, alpha=36, negated=False):
    # [N, W, H]
    N = tensor.shape[0]
    ret = []
    for i in range(N):
        distorted = elastic_transform(tensor[i], kernel_dim=kernel_dim, sigma=sigma, alpha=alpha, negated=negated)
        ret.append(distorted)
    
    return np.array(ret)

In [17]:
distorted_x = []
distorted_y = []

print("start distortion ...")
start = time.time()
epoch = 0
epoch_n = 10
for x, y in datagen.flow(aug_x, aug_y, batch_size=60000, shuffle=False):
    reshaped = x.reshape([-1, 28, 28])
    distorted = elastic_transform_wrapper_no_channels(reshaped, negated=True)
    distorted_x.append(distorted.reshape([-1, 784]))
    distorted_y.append(y)
    epoch += 1
    
    # 1 에폭당 약 300~400초 정도 걸릴것.
    print("[{}/{}]: {:.1f}".format(epoch, epoch_n, time.time()-start))
    if epoch == epoch_n:
        break

start distortion ...
[1/10]: 324.1
[2/10]: 651.9
[3/10]: 977.5
[4/10]: 1301.7
[5/10]: 1626.4
[6/10]: 1953.1
[7/10]: 2281.6
[8/10]: 2606.2
[9/10]: 2932.9
[10/10]: 3256.2


In [27]:
dx_npy = np.array(distorted_x)

In [31]:
dy_npy = np.array(distorted_y)

In [22]:
import sys

In [24]:
sys.getsizeof(distorted_x)

200

In [29]:
dx_npy.shape

(10, 60000, 784)

In [32]:
dy_npy.shape

(10, 60000)

In [33]:
np.save("distorted_x", dx_npy)

In [34]:
np.save("distorted_y", dy_npy)

In [52]:
images = dx_npy[0]

In [53]:
labels = dy_npy[0]

In [54]:
images.shape

(60000, 784)

In [55]:
labels.shape

(60000,)

In [61]:
# >>> a = np.array([1, 0, 3])
# >>> b = np.zeros((3, 4))
# >>> b[np.arange(3), a] = 1
def one_hot(dense, ndim=10):
    N = dense.shape[0]
    ret = np.zeros([N, ndim])
    ret[np.arange(N), dense] = 1
    return ret

In [62]:
oh_labels = one_hot(labels)

In [65]:
oh_labels.shape

(60000, 10)

In [73]:
train = np.concatenate([images, oh_labels], axis=1)

In [74]:
train.shape

(60000, 794)

In [81]:
x_batch = train[0:16, :784]

In [83]:
y_batch = train[0:16, 784:]

In [82]:
x_batch.shape

(16, 784)

In [84]:
y_batch.shape

(16, 10)

In [85]:
x_batch[0]

array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.  

In [86]:
y_batch[0]

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.])