In [1]:
from keras.models import Model
from keras.layers import Flatten, Conv2D, Conv2DTranspose, Dense, Input, Reshape
from keras.utils import plot_model
import keras
import tensorflow as tf

# Geometry: the model reads a 56x56 canvas as a 4x4 grid of 14x14
# single-channel patches; each patch feeds its own encoder input.
PATCH_SIZE = 14
GRID_SIZE = 4                          # patches per canvas side
N_PATCHES = GRID_SIZE * GRID_SIZE      # 16 encoder inputs
CANVAS_SIZE = PATCH_SIZE * GRID_SIZE   # 56
LATENT_DIM = 4                         # latent units per patch

# encoder: the layers below are created once and applied to every patch input,
# so all patches share the same conv/dense weights.
enc_inputs = [Input(shape=(PATCH_SIZE, PATCH_SIZE, 1)) for _ in range(N_PATCHES)]
conv1 = Conv2D(filters=32, kernel_size=(3,3), strides=(2,2), activation='relu', padding='same')
conv2 = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), activation='relu', padding='same')
flatten = Flatten()
latent = Dense(units=LATENT_DIM, activation='sigmoid')

latents = [latent(flatten(conv2(conv1(patch)))) for patch in enc_inputs]
latent_concat = keras.layers.concatenate(latents, name='latent_concat')

# decoder: concatenated patch latents -> Dense -> small seed feature map, which
# two stride-2 transposed convolutions upsample by 4x to a full canvas.
DEC_SEED_SIZE = CANVAS_SIZE // 4       # 14: undone by the two stride-2 deconvs
DEC_SEED_CHANNELS = 8
dec_input = Input(shape=(LATENT_DIM * N_PATCHES,), name='dec_input')
x = Dense(units=DEC_SEED_SIZE * DEC_SEED_SIZE * DEC_SEED_CHANNELS, activation='relu', name='dense')(dec_input)
x = Reshape(target_shape=(DEC_SEED_SIZE, DEC_SEED_SIZE, DEC_SEED_CHANNELS), name='reshape')(x)
x = Conv2DTranspose(filters=64, kernel_size=3, strides=2, activation='relu', padding='same', name='deconv1')(x)
x = Conv2DTranspose(filters=32, kernel_size=3, strides=2, activation='relu', padding='same', name='deconv2')(x)
# single-channel output with no activation (linear)
dec_output = Conv2DTranspose(filters=1, kernel_size=3, padding='same', name='pt_conv')(x)


class customModel(Model):
    """Functional model whose evaluation loss is measured on a crop of the output.

    Training uses the inherited ``train_step`` unchanged. During evaluation,
    when a batch carries a third element, that element is used as per-sample
    glimpse centres: a window the size of the target ``y`` is cut out of the
    prediction around each centre before the compiled loss and metrics run.
    """

    def test_step(self, data):
        """Evaluate one batch against a glimpse of the prediction.

        Args:
            data: ``(x, y, centers)``. Keras delivers the third element of a
                ``validation_data`` tuple here (where it would normally be a
                sample weight); it is used as a ``(batch, 2)`` array of
                absolute ``(row, col)`` glimpse centres, origin at the top-left.
                Any other batch structure is handled by ``Model.test_step``.

        Returns:
            Dict mapping each metric name (including ``loss``) to its value.
        """
        if not (isinstance(data, tuple) and len(data) == 3):
            # No centres supplied: plain full-output evaluation.
            return super().test_step(data)
        x, y, glimpse_centers = data
        y_pred = self(x, training=False)
        # Crop a window the size of the target around each centre. With
        # centered=False / normalized=False the offsets are absolute pixel
        # positions from the top-left; out-of-image pixels are filled with 0.
        y_pred = tf.image.extract_glimpse(
            y_pred, tf.shape(y)[1:3], glimpse_centers,
            centered=False, normalized=False, noise='zero')
        # Same call shape as Keras' built-in test_step.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
        
        
# Stand-alone sub-models: one view per patch encoder (they all reuse the same
# encoder layers) plus the decoder on its own.
encoders = [
    Model(inputs=patch_input, outputs=patch_latent, name='Encoder_%d' % idx)
    for idx, (patch_input, patch_latent) in enumerate(zip(enc_inputs, latents))
]
decoder = Model(inputs=dec_input, outputs=dec_output, name="Decoder")

# Full autoencoder: patch inputs -> shared encoder -> concatenated latents -> decoder.
model = customModel(inputs=enc_inputs, outputs=decoder(latent_concat))
model.compile(optimizer='adam', loss='mse')

In [2]:
import keras
import numpy as np

def split(im, nrows, ncols):
    """Cut an image into non-overlapping ``nrows x ncols`` tiles.

    Args:
        im: array of shape (height, width, channels). ``height`` must be a
            multiple of ``nrows`` and ``width`` a multiple of ``ncols``.
        nrows: tile height in pixels.
        ncols: tile width in pixels.

    Returns:
        Array of shape (n_tiles, nrows, ncols, channels), tiles in row-major
        order (left to right, then top to bottom).

    Raises:
        ValueError: if the image cannot be tiled exactly.
    """
    height, width, channels = im.shape
    if height % nrows or width % ncols:
        raise ValueError(
            f"image of shape {im.shape} cannot be tiled into {nrows}x{ncols} patches")
    # (tile_row, row_in_tile, tile_col, col_in_tile, channel)
    tiles = im.reshape(height // nrows, nrows, width // ncols, ncols, channels)
    # -> (tile_row, tile_col, row_in_tile, col_in_tile, channel)
    tiles = tiles.swapaxes(1, 2)
    return tiles.reshape(-1, nrows, ncols, channels)

def embed_and_translate(data, n_width, n_height, rng=None):
    """Paste each image at a uniformly random position on a zero-filled canvas.

    Args:
        data: array of shape (N, rows, cols, 1) or (N, rows, cols).
        n_width: canvas extent along axis 1 (rows). Despite the name this is
            the vertical size: the output has shape (N, n_width, n_height, 1).
        n_height: canvas extent along axis 2 (columns).
        rng: optional ``numpy.random.Generator`` for reproducible placement;
            when None the global ``np.random`` state is used.

    Returns:
        (canvases, centers): float32 canvases of shape (N, n_width, n_height, 1)
        and float32 centres of shape (N, 2) holding the (row, col) pixel
        position ``offset + size // 2`` of each pasted image.

    Raises:
        ValueError: if the canvas is smaller than the images.
    """
    rows, cols = data.shape[1], data.shape[2]
    if rows > n_width or cols > n_height:
        raise ValueError(
            f"canvas {n_width}x{n_height} is smaller than images {rows}x{cols}")
    draw = np.random.randint if rng is None else rng.integers
    canvases = np.zeros((len(data), n_width, n_height, 1), dtype='float32')
    centers = np.empty((len(data), 2), dtype='float32')
    for i, image in enumerate(data):
        # upper bound is exclusive: +1 so every offset that keeps the image on
        # the canvas is reachable, including the bottom/right edge.
        top = draw(n_width - rows + 1)
        left = draw(n_height - cols + 1)
        canvases[i, top:top + rows, left:left + cols] = np.reshape(image, (rows, cols, 1))
        centers[i] = (top + rows // 2, left + cols // 2)
    return canvases, centers

# Data: every 28x28 MNIST digit is pasted at a random position on a 56x56
# canvas, and each canvas is cut into 14x14 patches (one per encoder input).
# Validation loss is evaluated at each digit's true position: x_test stays the
# plain 28x28 digits and y_test_regr holds their centres on the canvas, which
# customModel.test_step uses to crop the prediction.
SEED = 0
np.random.seed(SEED)  # reproducible digit placement (embed_and_translate draws from np.random)


def _normalize(images):
    """Add a trailing channel axis and scale uint8 pixels to float32 in [0, 1]."""
    return images[..., np.newaxis].astype('float32') / 255.


def _to_patches(canvases, patch_size=14):
    """Tile every canvas with ``split`` and regroup the tiles by position.

    Returns a list with one (N, patch_size, patch_size, 1) array per tile
    position, i.e. the list of inputs the multi-input model expects.
    """
    tiles = np.array([split(c, patch_size, patch_size) for c in canvases], dtype='float32')
    return list(tiles.swapaxes(0, 1))  # (N, n_tiles, ...) -> n_tiles arrays of (N, ...)


(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# training: inputs are canvas patches, target is the full 56x56 canvas
x_train_augmented, _ = embed_and_translate(_normalize(x_train), 56, 56)
x_train = None  # release the raw training images
x_train_split = _to_patches(x_train_augmented)

# test: inputs are canvas patches, target is the un-embedded 28x28 digit
x_test = _normalize(x_test)
x_test_augmented, y_test_regr = embed_and_translate(x_test, 56, 56)
x_test_split = _to_patches(x_test_augmented)
x_test_augmented = None  # only the patches and centres are needed for validation

In [3]:
# validation_data is a 3-tuple: Keras hands its third element (normally a
# sample weight) to test_step, where customModel uses y_test_regr as glimpse
# centres to crop the 56x56 predictions before comparing them with x_test.
# The History is kept so per-epoch loss / val_loss can be inspected or plotted.
history = model.fit(
    x_train_split,
    x_train_augmented,
    validation_data=(x_test_split, x_test, y_test_regr),
    epochs=25,
)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


<tensorflow.python.keras.callbacks.History at 0x7f852871b130>