In [0]:
# Useful libraries
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from google.colab import files
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['mathtext.fontset'] = 'stix'
plt.rcParams['font.family'] = 'STIXGeneral'
plt.rcParams['font.size'] = 16
plt.rcParams['lines.linewidth'] = 3

# Download Moving-MNIST dataset
!wget 'http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy'

# Custom metrics
def cor(y_gt, y_pred):
  """Correlation metric between ground truth and prediction.

  Delegates to `tfp.stats.correlation` with `sample_axis=None` and
  `event_axis=None`, exactly as configured below.
  """
  correlation = tfp.stats.correlation(
      x=y_gt, y=y_pred, sample_axis=None, event_axis=None)
  return correlation
def ssim(y_gt, y_pred, max_val=1.):
  """Structural-similarity metric, computed by `tf.image.ssim`.

  Args:
    y_gt: ground-truth images.
    y_pred: predicted images.
    max_val: dynamic range of the pixel values (defaults to 1.).
  """
  return tf.image.ssim(img1=y_gt, img2=y_pred, max_val=max_val)
def psnr(y_gt, y_pred, max_val=1.):
  """Peak signal-to-noise ratio metric, computed by `tf.image.psnr`.

  Args:
    y_gt: ground-truth images.
    y_pred: predicted images.
    max_val: dynamic range of the pixel values (defaults to 1.).
  """
  return tf.image.psnr(a=y_gt, b=y_pred, max_val=max_val)

In [0]:
#
# Parameters
#
# Training objective and optimizer (both handed to model.compile).
loss = 'logcosh'
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras releases reject.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Shared settings for the convolutional / ConvLSTM layers built in model_init.
activ = 'relu'       # activation
init = 'he_normal'   # kernel initializer
nk = 128             # number of filters per layer
ks = 5               # kernel size of the Conv2D / ConvLSTM2D layers
lks = 3              # kernel size of the final Conv3D layer
#
# Number of frames fed to the model and number of frames it must predict.
Ninput = 10
Noutput = 5
#
# Nsample: sequences used for training (the test split takes Nsample//10 more).
Nsample = 100
bs = 10              # batch size for fit / evaluate
ep = 4               # number of training epochs

In [0]:
#
# Data processing
#
def preprocess(data):
  """Rescale and re-layout a raw array.

  Steps, in order: divide by 255., swap axes 0 and 1, append a trailing
  axis of length 1.

  Args:
    data: array with at least two axes.

  Returns:
    A new array with one more axis than `data`.
  """
  scaled = data/255.
  reordered = np.swapaxes(scaled, 0, 1)
  return np.expand_dims(reordered, axis=-1)
def split_train_test(dataset, n):
  """Load a .npy file and cut it into preprocessed train and test arrays.

  The file is opened memory-mapped and read-only (`mmap_mode='r'`). Along
  axis 1 the first `n` entries form the training split and the following
  `n//10` entries form the test split; each split then goes through
  `preprocess`.

  Args:
    dataset: path of the .npy file to load.
    n: number of entries along axis 1 assigned to the training split.

  Returns:
    (train, test) tuple of preprocessed arrays.
  """
  raw = np.load(dataset, mmap_mode='r')
  n_test = n//10
  train_raw = raw[:, :n, :, :]
  test_raw = raw[:, n:n+n_test, :, :]
  return preprocess(train_raw), preprocess(test_raw)
def make_XY(data, Ninput, Noutput):
  """Split a 5-axis array along axis 1 into model inputs and target.

  Args:
    data: 5-axis array; axis 1 is the one being split.
    Ninput: number of leading axis-1 entries used as input.
    Noutput: number of axis-1 entries, right after the inputs, used as target.

  Returns:
    (X_content, X_motion, y) where
      X_motion  = the first `Ninput` entries along axis 1,
      y         = the next `Noutput` entries along axis 1,
      X_content = the last entry of X_motion along axis 1 (that axis is dropped).
  """
  everything = slice(None)
  past = slice(None, Ninput)
  future = slice(Ninput, Ninput+Noutput)
  X_motion = data[everything, past, everything, everything, everything]
  y = data[everything, future, everything, everything, everything]
  X_content = X_motion[everything, -1, everything, everything, everything]
  return X_content, X_motion, y
# Load the downloaded file and build the train / test arrays.
data_train, data_test = split_train_test(dataset='mnist_test_seq.npy', n=Nsample)
# The three trailing axes of the preprocessed array define the model's frame shape.
H, W, C = data_train.shape[-3:]
# Inputs (last input frame + full input sequence) and targets for each split.
X_content, X_motion, y = make_XY(data=data_train, Ninput=Ninput, Noutput=Noutput)
Xtest_content, Xtest_motion, ytest = make_XY(data=data_test, Ninput=Ninput, Noutput=Noutput)

In [0]:
# Model Architecture
def model_init():
  """Build the two-stream (motion + content) frame-prediction model.

  Reads the notebook-level settings H, W, C (frame shape), nk, ks, lks,
  activ, init (layer hyper-parameters) and Noutput (number of predicted
  time steps).

  Structure:
    * Motion encoder: time-distributed Conv2D / MaxPooling2D blocks over the
      input sequence, then LayerNormalization and two ConvLSTM2D layers; the
      second one returns only its last output (`ME`).
    * Content encoder: Conv2D / BatchNormalization / MaxPooling2D blocks over
      a single frame, ending in `CE`.
    * Combination: `CE` and `ME` are concatenated and passed through three
      Conv2D layers.
    * Decoder: two UpSampling2D stages, each concatenated with the skip
      tensors taken from both encoders at the matching stage.
    * Prediction: the decoded map is placed at time step 0 of a sequence of
      length Noutput (remaining steps are zeros), then two ConvLSTM2D layers,
      LayerNormalization and a single-filter sigmoid Conv3D.

  Returns:
    tf.keras.Model with inputs [inputs_motion, inputs_content] and the
    Conv3D output as prediction.
  """
  layers = tf.keras.layers
  # Keyword arguments shared by every Conv2D / ConvLSTM2D layer in the model.
  conv_kwargs = dict(padding='same', activation=activ, kernel_initializer=init)

  def conv2d():
    # Fresh Conv2D layer with the shared hyper-parameters.
    return layers.Conv2D(nk, ks, **conv_kwargs)

  def conv_lstm(return_sequences):
    # Fresh ConvLSTM2D layer with the shared hyper-parameters.
    return layers.ConvLSTM2D(nk, ks, return_sequences=return_sequences, **conv_kwargs)

  def last_step():
    # Layer that keeps only the last entry along axis 1 (the time axis).
    return layers.Lambda(lambda t: t[:, -1])

  def seed_sequence(t):
    # Put `t` at time step 0 and fill the remaining Noutput-1 steps with
    # zeros, so the sequence length always matches the target length.
    return tf.stack([t] + [tf.zeros_like(t)] * (Noutput - 1), axis=1)

  # Inputs
  inputs_motion = tf.keras.Input(shape=[None, H, W, C])
  inputs_content = tf.keras.Input(shape=[H, W, C])

  # Motion Encoder
  x = layers.TimeDistributed(conv2d())(inputs_motion)
  res1a = last_step()(x)
  x = layers.TimeDistributed(layers.MaxPooling2D())(x)
  x = layers.TimeDistributed(conv2d())(x)
  res2a = last_step()(x)
  x = layers.TimeDistributed(layers.MaxPooling2D())(x)
  x = layers.LayerNormalization()(x)
  x = conv_lstm(return_sequences=True)(x)
  ME = conv_lstm(return_sequences=False)(x)

  # Content Encoder
  x = conv2d()(inputs_content)
  x = conv2d()(x)
  res1b = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D()(res1b)
  x = conv2d()(x)
  x = conv2d()(x)
  res2b = layers.BatchNormalization()(x)
  x = layers.MaxPooling2D()(res2b)
  x = conv2d()(x)
  CE = conv2d()(x)

  # Combination layers
  x = layers.Concatenate()([CE, ME])
  x = conv2d()(x)
  x = conv2d()(x)
  x = conv2d()(x)
  x = layers.BatchNormalization()(x)

  # Decoder layers
  x = layers.UpSampling2D()(x)
  x = layers.Concatenate()([x, res2a, res2b])
  x = conv2d()(x)
  x = conv2d()(x)
  x = layers.BatchNormalization()(x)
  x = layers.UpSampling2D()(x)
  x = layers.Concatenate()([x, res1a, res1b])
  x = conv2d()(x)
  x = conv2d()(x)
  x = layers.BatchNormalization()(x)

  # Prediction
  # Wrapped in a Lambda layer (like the skip-connection slices above) instead
  # of calling raw TensorFlow ops on Keras tensors.
  x = layers.Lambda(seed_sequence)(x)
  x = conv_lstm(return_sequences=True)(x)
  x = conv_lstm(return_sequences=True)(x)
  x = layers.LayerNormalization()(x)
  # Bias of -log(99) makes the initial sigmoid output 1/(1+99) = 0.01.
  preds = layers.Conv3D(1, lks, padding='same', activation='sigmoid',
                        bias_initializer=tf.keras.initializers.Constant(value=-np.log(99)))(x)
  return tf.keras.Model(inputs=[inputs_motion, inputs_content], outputs=preds)

In [0]:
# Instantiate the network
model = model_init()

# Compile the model: attach the optimizer, the loss and the tracked metrics
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[
        'acc',
        ssim,
        psnr,
        cor,
        tf.keras.metrics.Precision(name='prec'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)

# Train; Keras holds out a 0.1 fraction of the training arrays for validation
history = model.fit(
    x=[X_motion, X_content],
    y=y,
    batch_size=bs,
    epochs=ep,
    validation_split=0.1,
)

# Score the trained model on the test split; results is a {metric name: value} dict
results = model.evaluate(
    x=[Xtest_motion, Xtest_content],
    y=ytest,
    batch_size=bs,
    return_dict=True,
)

In [0]:
def plot_metric_history(metric, history_dict, test_value):
  """Plot the train / validation curves of one metric and save the figure.

  Args:
    metric: metric name; `history_dict[metric]` and
      `history_dict['val_' + metric]` hold the per-epoch values.
    history_dict: mapping as found in `history.history` after `model.fit`.
    test_value: scalar drawn as a dashed horizontal 'Test' line.

  Side effects:
    Opens a new matplotlib figure and writes it to '<metric>.png'.
  """
  plt.figure(figsize=(12,6))
  plt.plot(history_dict[metric], label='Train')
  plt.plot(history_dict['val_' + metric], label='Val')
  plt.axhline(test_value, linestyle='--', color='k', label='Test')
  plt.ylabel(metric)
  plt.xlabel('epoch')
  plt.legend(loc='upper left')
  plt.savefig(metric + '.png')
  #files.download(metric + '.png')

# One figure per tracked quantity. Each test line uses the metric's own test
# value (the 'acc', 'ssim' and 'psnr' figures previously reused results['cor']).
for metric in ('loss', 'prec', 'recall', 'cor', 'acc', 'ssim', 'psnr'):
  plot_metric_history(metric, history.history, results[metric])

In [0]:
# Plot Train
# Pick one training sequence and add a leading batch axis before predicting.
itest = 1
track_motion = X_motion[itest]
track_content = X_content[itest]
# Input frames followed by the predicted frames, joined along the time axis.
track = np.concatenate(
    (
        data_train[itest, :Ninput][np.newaxis],
        model.predict([track_motion[np.newaxis], track_content[np.newaxis]]),
    ),
    axis=1,
)
true_track = data_train[itest]
# One figure per frame: left = input/prediction, right = ground truth.
for i in range(Ninput+Noutput):
    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.text(1, 3, 'Prediction' if i >= Ninput else 'Initial trajectory', color='w')
    plt.imshow(track[0, i, ..., 0])
    plt.subplot(122)
    plt.text(1, 3, 'Ground truth', color='w')
    plt.imshow(true_track[i, ..., 0])
    plt.savefig('train_%i.png' % (i+1))
    #files.download('train_%i.png' % (i+1))

In [0]:
# Plot Test
# Pick one test sequence and add a leading batch axis before predicting.
itest = 1
track_motion = Xtest_motion[itest]
track_content = Xtest_content[itest]
# Input frames followed by the predicted frames, joined along the time axis.
track = np.concatenate(
    (
        data_test[itest, :Ninput][np.newaxis],
        model.predict([track_motion[np.newaxis], track_content[np.newaxis]]),
    ),
    axis=1,
)
true_track = data_test[itest]
# One figure per frame: left = input/prediction, right = ground truth.
for i in range(Ninput+Noutput):
    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.text(1, 3, 'Prediction' if i >= Ninput else 'Initial trajectory', color='w')
    plt.imshow(track[0, i, ..., 0])
    plt.subplot(122)
    plt.text(1, 3, 'Ground truth', color='w')
    plt.imshow(true_track[i, ..., 0])
    plt.savefig('test_%i.png' % (i+1))
    #files.download('test_%i.png' % (i+1))