# Generate Double-MNIST Data
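A simple dataset in which each image contains one even and one odd MNIST digit, each cropped to its bounding box and placed at a random, non-overlapping position on a blank canvas.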

#### Load MNIST Data

In [None]:
import numpy as np
# shorthand for inserting new length-1 axes when broadcasting, e.g. Y[:, na]
na = np.newaxis
# fixed seed so the random shuffling and digit placement below are reproducible
np.random.seed(7160)

import matplotlib.pyplot as plt

In [None]:
from keras.datasets import mnist
# Load MNIST as ((train_images, train_labels), (test_images, test_labels)).
# NOTE(review): per the Keras docs the images are uint8 arrays of shape
# (N, 28, 28) and the labels are shape (N,) -- confirm for the installed Keras.
(XTrain, YTrain), (XTest, YTest) = mnist.load_data()

#### Helper Functions

In [None]:
def crop_digit(digit):
    """Crop a 2-D image to the bounding box of its non-empty rows and columns.

    A row (column) counts as non-empty when its pixel sum is positive. If the
    image has no positive row/column sums along an axis, that axis is left
    uncropped, so an all-zero image is returned unchanged.

    Args:
        digit: 2-D array (height, width).

    Returns:
        A view ``digit[top:bottom, left:right]`` of the tight bounding box.

    Raises:
        ValueError: if ``digit`` is not 2-D.
    """
    n_rows, n_cols = digit.shape

    def _span(has_ink, length):
        # first True position and one-past-last True position; argmax of an
        # all-False mask is 0, which yields the full range (0, length)
        first = np.argmax(has_ink)
        last_exclusive = length - np.argmax(has_ink[::-1])
        return first, last_exclusive

    top, bottom = _span(digit.sum(axis=1) > 0, n_rows)
    left, right = _span(digit.sum(axis=0) > 0, n_cols)
    return digit[top:bottom, left:right]

def crop_and_randomplace_digits(d1, d2, canvassize=(100,100), shuffle=True, max_tries=10000):
    """Crop two digit images and paste them at random, non-overlapping positions.

    Both digits are cropped with ``crop_digit``. The first digit gets a
    uniformly random top-left corner such that it lies fully inside the canvas.
    The second digit's corner is re-drawn until its bounding box does not
    overlap the first one (touching edges is allowed).

    Args:
        d1, d2: 2-D digit images.
        canvassize: (height, width) of the output canvas.
        shuffle: if True, randomly swap which digit is placed first (the first
            digit's position is unconstrained, the second one's is not).
        max_tries: maximum number of positions drawn for the second digit
            before giving up.

    Returns:
        A float array of shape ``canvassize`` containing both cropped digits.

    Raises:
        ValueError: if a cropped digit is larger than the canvas.
        RuntimeError: if no non-overlapping position for the second digit was
            found within ``max_tries`` draws (e.g. the canvas is too small).
    """
    if shuffle and np.random.randint(2):
        # change which digit is placed first
        d1, d2 = d2, d1

    canvas = np.zeros(canvassize)
    d1c = crop_digit(d1)
    d2c = crop_digit(d2)

    for patch in (d1c, d2c):
        if patch.shape[0] > canvas.shape[0] or patch.shape[1] > canvas.shape[1]:
            raise ValueError(
                f"digit of shape {patch.shape} does not fit on canvas {canvas.shape}")

    def random_corner(patch):
        # randint's `high` is exclusive: +1 makes the position flush with the
        # bottom/right edge reachable (and avoids randint(0, 0) raising when
        # the patch spans a whole canvas dimension)
        return (np.random.randint(0, canvas.shape[0] - patch.shape[0] + 1),
                np.random.randint(0, canvas.shape[1] - patch.shape[1] + 1))

    def boxes_overlap(pos_a, patch_a, pos_b, patch_b):
        # half-open boxes [pos, pos + size) overlap iff they overlap on both axes
        return all(pos_a[ax] < pos_b[ax] + patch_b.shape[ax]
                   and pos_b[ax] < pos_a[ax] + patch_a.shape[ax]
                   for ax in (0, 1))

    d1pos = random_corner(d1c)
    # rejection-sample the second digit's position until it does not overlap
    for _ in range(max_tries):
        d2pos = random_corner(d2c)
        if not boxes_overlap(d1pos, d1c, d2pos, d2c):
            break
    else:
        raise RuntimeError(
            f"could not place digits without overlap on canvas {canvas.shape} "
            f"after {max_tries} tries")

    canvas[d1pos[0]:d1pos[0]+d1c.shape[0], d1pos[1]:d1pos[1]+d1c.shape[1]] = d1c
    canvas[d2pos[0]:d2pos[0]+d2c.shape[0], d2pos[1]:d2pos[1]+d2c.shape[1]] = d2c

    return canvas

def sample_oddeven_dataset(X, Y, shuffle=False, maxlength=None, canvassize=(100, 100)):
    """Build canvases that each contain one even and one odd digit.

    Samples labelled with an even digit (0, 2, 4, 6, 8) are paired in order
    with samples labelled with an odd digit (1, 3, 5, 7, 9). The order is
    dataset order, or a random permutation when ``shuffle`` is True. Each pair
    is pasted onto a blank canvas by ``crop_and_randomplace_digits``.

    Args:
        X: digit images indexable by sample index (each a 2-D image).
        Y: 1-D array of digit labels aligned with ``X``.
        shuffle: pair the digits in a random order instead of dataset order.
        maxlength: upper bound on the number of generated samples
            (default: ``len(X)``).
        canvassize: (height, width) of every generated canvas.

    Returns:
        (images, labels), where ``images`` has shape ``(n, *canvassize)`` and
        ``labels`` has shape ``(n, 2)`` holding ``[even_label, odd_label]``
        per row. ``n = min(#even samples, #odd samples, maxlength)``.
    """
    Y = np.asarray(Y)
    if maxlength is None:
        maxlength = len(X)

    if shuffle:
        order = np.random.permutation(Y.shape[0])
    else:
        order = np.arange(Y.shape[0])

    # Evaluate the even/odd masks on the labels *in the chosen order*; masking
    # a permutation with masks computed in dataset order would select the
    # wrong samples.
    ordered_labels = Y[order]
    even_indices = order[np.isin(ordered_labels, np.arange(0, 10, 2))]
    odd_indices = order[np.isin(ordered_labels, np.arange(1, 10, 2))]

    n = max(0, min(len(even_indices), len(odd_indices), maxlength))
    pairs = list(zip(even_indices[:n], odd_indices[:n]))

    images = [crop_and_randomplace_digits(X[e_idx], X[o_idx], canvassize=canvassize)
              for e_idx, o_idx in pairs]
    labels = [[Y[e_idx], Y[o_idx]] for e_idx, o_idx in pairs]

    # reshape keeps the documented shapes even when no pairs were generated
    images = np.asarray(images).reshape((-1,) + tuple(canvassize))
    labels = np.asarray(labels, dtype=Y.dtype).reshape(-1, 2)
    return images, labels

#### Generate Dataset

In [None]:
# Build a small shuffled sample (at most 20 canvases) from the training split.
joint_X, joint_Y = sample_oddeven_dataset(XTrain, YTrain, shuffle=True, maxlength=20)

#### Qualitative Test

In [None]:
# Show up to the first 10 generated canvases, titled with their
# [even, odd] labels. Slicing + zip (instead of range(10)) avoids an
# IndexError when fewer than 10 samples were generated.
for image, label in zip(joint_X[:10], joint_Y[:10]):
    plt.imshow(image, cmap='gray')
    plt.title(label)
    plt.show()