In [2]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
import math

In [34]:
class BatchGenerator2(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding quadrant tiles of randomly placed images.

    Each sample of ``x_set`` (expected to be a 32x32 image with a trailing
    channel axis, e.g. ``(32, 32, 1)``) is pasted at a random offset on a
    64x64 zero canvas. Each batch is returned as::

        ([x1, x2, x3, x4], [imgs, locs])

    where ``imgs`` is the batch of 64x64 canvases and ``x1..x4`` are its four
    32x32 quadrants (top-left, top-right, bottom-left, bottom-right).
    ``locs`` holds the (row, col) centre of the pasted image on each canvas.

    Args:
        x_set: Array of images indexed along the first axis.
        batch_size: Samples per batch; the final batch may be smaller.
    """

    def __init__(self, x_set, batch_size):
        # Keras 3's PyDataset base expects this call; harmless for older Sequence.
        super().__init__()
        self.batch_size = batch_size
        self.xs = x_set
        self.on_epoch_end()

    def __len__(self):
        # Round up so a trailing partial batch is still served.
        return math.ceil(len(self.xs) / self.batch_size)

    def split(self, arr, nrows, ncols):
        """Cut an (H, W, C) array into non-overlapping (nrows, ncols, C) tiles.

        Tiles are stacked along a new first axis in row-major order, giving
        shape (H//nrows * W//ncols, nrows, ncols, C). H must be divisible by
        nrows and W by ncols.
        """
        h, w, c = arr.shape
        return (arr.reshape(h // nrows, nrows, w // ncols, ncols, c)
                   .swapaxes(1, 2)
                   .reshape(-1, nrows, ncols, c))

    def embed(self, batch):
        """Paste each 32x32 image of ``batch`` at a random offset on a 64x64 canvas.

        The offset along each axis is independently one of 0, 16 or 32 pixels,
        drawn with the global ``np.random`` state.

        Returns:
            imgs: array of shape (len(batch), 64, 64, 1) holding the canvases.
            locs: array of shape (len(batch), 2) with the (row, col) centre
                of the pasted image on each canvas.
        """
        # Size by the actual batch: the last batch can be shorter than batch_size.
        n = len(batch)
        imgs = np.zeros((n, 64, 64, 1))
        locs = np.zeros((n, 2))
        for k in range(n):
            i, j = np.random.randint(0, 3), np.random.randint(0, 3)
            top, left = i * 16, j * 16
            imgs[k, top:top + 32, left:left + 32] = batch[k]
            # Centre of the pasted 32x32 patch, written in place.
            locs[k] = (top + 16, left + 16)
        return imgs, locs

    def __getitem__(self, idx):
        """Return ``([x1, x2, x3, x4], [imgs, locs])`` for batch ``idx``."""
        batch = self.xs[idx * self.batch_size:(idx + 1) * self.batch_size]
        imgs, locs = self.embed(batch)

        # One (4, 32, 32, C) tile stack per canvas; regroup by quadrant so each
        # model input has shape (n, 32, 32, C).
        tiles = [self.split(img, 32, 32) for img in imgs]
        xs = [np.array([t[q] for t in tiles]) for q in range(4)]
        return xs, [imgs, locs]

    def on_epoch_end(self):
        # Reshuffle sample order (np.random.permutation returns a shuffled copy).
        self.xs = np.random.permutation(self.xs)

In [35]:
# Seed the global NumPy RNG: BatchGenerator2 shuffles and picks embedding
# offsets with np.random, so this makes generated batches reproducible.
np.random.seed(0)

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()


def _preprocess(images, size=32):
    """Scale uint8 MNIST digits to [0, 1] floats and resize to (N, size, size, 1)."""
    images = images[..., np.newaxis].astype('float32') / 255.
    resized = resize(images, (images.shape[0], size, size, 1))
    # resize may return float64; cast back to float32 to halve dataset memory.
    return resized.astype('float32')


x_train = _preprocess(x_train)
x_test = _preprocess(x_test)

batch_size = 32
train_gen = BatchGenerator2(x_train, batch_size)
test_gen = BatchGenerator2(x_test, batch_size)

In [36]:
# Fetch the first batch directly (zip over the generator built an extra batch).
xs, (imgs, locs) = train_gen[0]
# Four 32x32 quadrant inputs and one 64x64 canvas per sample.
assert len(xs) == 4 and xs[0].shape == (batch_size, 32, 32, 1)
assert imgs.shape == (batch_size, 64, 64, 1)
locs.shape

(32, 2)
