# Generate Double-Mnist Data
A simple dataset in which each image contains one even and one odd MNIST digit.

#### Load MNIST Data

In [None]:
import os
import numpy as np
# Short alias for np.newaxis, used to insert broadcast axes when indexing.
na = np.newaxis
# Seed NumPy's global random generator so the np.random draws in this notebook are repeatable.
np.random.seed(7160)

import matplotlib.pyplot as plt

In [None]:
from keras.datasets import mnist
# Load the MNIST train and test splits as (images, labels) pairs.
# NOTE(review): presumably uint8 images of shape (N, 28, 28) with integer
# labels 0-9 -- confirm against the Keras documentation.
(XTrain, YTrain), (XTest, YTest) = mnist.load_data()

#### Helper Functions

In [None]:
from skimage.transform import resize

In [None]:
def crop_and_rescale_digit(digit, rescale_range=(.8, 2.0)):
    """Crop a digit to its bounding box and rescale it by a random factor.

    Parameters
    ----------
    digit : 2-D array
        Single digit image. Rows/columns whose sum is > 0 are considered to
        contain ink; everything outside their extent is cropped away. An
        image without any such row/column is kept at its full size.
    rescale_range : sequence of two floats
        ``(low, high)`` bounds handed to ``np.random.uniform`` to draw the
        scale factor, which is applied to both axes.

    Returns
    -------
    2-D array
        The cropped (and usually resized) digit. Unsigned-integer input is
        returned as floats in [0, 1] on both code paths.
    """
    height, width = digit.shape
    rows_with_ink = digit.sum(1) > 0
    cols_with_ink = digit.sum(0) > 0

    # argmax on a boolean vector yields the first True; applying it to the
    # reversed vector locates the last True counted from the end.
    top = np.argmax(rows_with_ink)
    bottom = height - np.argmax(rows_with_ink[::-1])
    left = np.argmax(cols_with_ink)
    right = width - np.argmax(cols_with_ink[::-1])

    digit_cropped = digit[top:bottom, left:right]

    scale_factor = np.random.uniform(*rescale_range)
    if not np.isclose(scale_factor, 1.):
        # Never request a zero-sized output (possible for a 1-pixel crop
        # combined with a scale factor below 1).
        out_shape = (max(1, int(digit_cropped.shape[0] * scale_factor)),
                     max(1, int(digit_cropped.shape[1] * scale_factor)))
        return resize(digit_cropped, out_shape)

    # No resize needed. skimage's resize converts unsigned-integer images to
    # floats in [0, 1] by default, so apply the same conversion here;
    # otherwise the rare unscaled digit would sit on a 0..255 intensity scale
    # while every resized digit is on 0..1.
    if np.issubdtype(digit_cropped.dtype, np.unsignedinteger):
        return digit_cropped / float(np.iinfo(digit_cropped.dtype).max)
    return digit_cropped

def crop_and_randomplace_digits(d1, d2, canvassize=(96,96), shuffle=True):
    """Draw two digits at random, non-overlapping positions on an empty canvas.

    Both digits are passed through ``crop_and_rescale_digit``. The first one
    gets a random position; the second one is then given up to 20 random
    positions. If all 20 collide with the first digit, both digits are
    rescaled again and the first one is repositioned before retrying.

    Parameters
    ----------
    d1, d2 : 2-D arrays
        The two digit images.
    canvassize : tuple of two ints
        Shape of the zero-filled output canvas.
    shuffle : bool
        If true, a coin flip decides whether the two digits swap roles, i.e.
        which of them is positioned first.

    Returns
    -------
    2-D array of shape ``canvassize`` containing both digits.

    Notes
    -----
    Positions come from ``np.random.randint(0, canvas_dim - digit_dim)``,
    whose upper bound is exclusive.
    """
    def spans_collide(start_a, len_a, start_b, len_b):
        # Two 1-D spans collide when the start of one lies inside the other.
        return (start_b <= start_a < start_b + len_b) or (start_a <= start_b < start_a + len_a)

    if shuffle and np.random.randint(2):
        # swap which digit is positioned first
        d1, d2 = d2, d1

    canvas = np.zeros(canvassize)
    canvas_h = canvas.shape[0]
    canvas_w = canvas.shape[1]

    placed = False
    while not placed:
        first = crop_and_rescale_digit(d1)
        second = crop_and_rescale_digit(d2)
        first_h, first_w = first.shape[0], first.shape[1]
        second_h, second_w = second.shape[0], second.shape[1]

        row1 = np.random.randint(0, canvas_h - first_h)
        col1 = np.random.randint(0, canvas_w - first_w)

        for _ in range(20):
            row2 = np.random.randint(0, canvas_h - second_h)
            col2 = np.random.randint(0, canvas_w - second_w)
            rows_collide = spans_collide(row1, first_h, row2, second_h)
            cols_collide = spans_collide(col1, first_w, col2, second_w)
            # The bounding boxes intersect only if they collide on both axes.
            if not (rows_collide and cols_collide):
                placed = True
                break

    canvas[row1:row1 + first_h, col1:col1 + first_w] = first
    canvas[row2:row2 + second_h, col2:col2 + second_w] = second

    return canvas

def sample_oddeven_dataset(X, Y, shuffle=False, maxlength=None):
    """Pair even-labelled with odd-labelled samples and render one image per pair.

    Parameters
    ----------
    X : indexable collection of digit images
    Y : 1-D array of labels; only labels 0-9 are matched
    shuffle : bool
        If true, the even and the odd index lists are permuted independently
        before pairing.
    maxlength : int or None
        Upper bound on the number of pairs; ``None`` means ``len(X)``. The
        count is additionally limited by the number of available even and
        odd samples.

    Returns
    -------
    (images, labels) : tuple of arrays
        ``images[k]`` is the canvas produced by ``crop_and_randomplace_digits``
        for pair ``k``; ``labels[k]`` is ``[even_label, odd_label]``.
    """
    even_digits = np.arange(0, 10, 2)
    odd_digits = np.arange(1, 10, 2)

    # A sample belongs to a group when its label equals any digit of that group.
    is_even = (Y[:, na] == even_digits[na, :]).any(axis=1)
    is_odd = (Y[:, na] == odd_digits[na, :]).any(axis=1)

    sample_ids = np.arange(Y.shape[0])
    even_indices = sample_ids[is_even]
    odd_indices = sample_ids[is_odd]

    if shuffle:
        even_indices = np.random.permutation(even_indices)
        odd_indices = np.random.permutation(odd_indices)

    if maxlength is None:
        maxlength = len(X)
    maxlength = np.min([len(even_indices), len(odd_indices), maxlength])

    pairs = list(zip(even_indices[:maxlength], odd_indices[:maxlength]))

    # generate the joint images and their [even, odd] label pairs
    images = [crop_and_randomplace_digits(X[e_idx], X[o_idx]) for e_idx, o_idx in pairs]
    labels = [[Y[e_idx], Y[o_idx]] for e_idx, o_idx in pairs]

    return np.asarray(images), np.asarray(labels)

#### Generate Dataset

In [None]:
# Build the shuffled train and test sets (third positional argument is
# `shuffle`) and sanity-check each one: an even plus an odd label always sums
# to an odd number.
joint_X, joint_Y = sample_oddeven_dataset(XTrain, YTrain, True)
assert np.all((joint_Y.sum(1) % 2) == 1), "ERROR: each image should contain exactly one odd and one even number"
joint_X_test, joint_Y_test = sample_oddeven_dataset(XTest, YTest, True)
# Check the *test* labels here (previously the train labels were re-checked).
assert np.all((joint_Y_test.sum(1) % 2) == 1), "ERROR: each image should contain exactly one odd and one even number"

In [None]:
# Bundle images and labels of each split under the keys 'data' and 'labels'.
train = {'data': joint_X, 'labels': joint_Y}
test  = {'data': joint_X_test, 'labels': joint_Y_test}

# Target paths without extension; np.savez appends '.npz' itself.
trainpath = os.path.join('../datasets/doublemnist-train')
testpath  = os.path.join('../datasets/doublemnist-test')

# np.savez does not create missing directories, so make sure they exist
# before writing (otherwise the freshly generated data would be lost to an
# error at the very last step).
for _path in (trainpath, testpath):
    _outdir = os.path.dirname(_path)
    if _outdir and not os.path.isdir(_outdir):
        os.makedirs(_outdir)

np.savez(trainpath, **train)
np.savez(testpath,  **test)

#### Qualitative Test

In [None]:
# Re-open the archives written above and collect the arrays per split.
train = np.load(trainpath + '.npz')
test  = np.load(testpath  + '.npz')

dataset = {}
dataset['train'] = (train['data'], train['labels'])
dataset['test'] = (test['data'], test['labels'])
# The validation split is simply the test split under a second key.
dataset['valid'] = dataset['test']

In [None]:
# Show the first ten test images, each titled with its label pair.
preview_images, preview_labels = dataset['test']
for sample_no in range(10):
    plt.imshow(preview_images[sample_no], cmap='gray')
    plt.title(preview_labels[sample_no])
    plt.show()