In [0]:
# Useful libraries
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from google.colab import files
import matplotlib.pyplot as plt
%matplotlib inline
# Figure defaults: STIX fonts (math text included), larger text, thicker lines.
plt.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'font.size': 16,
    'lines.linewidth': 3,
})

# Download Moving-MNIST dataset
!wget 'http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy'

# Custom metrics
def cor(y_gt, y_pred):
  """Correlation between ground truth and prediction (used as a Keras metric).

  `sample_axis` and `event_axis` are both passed as None to
  `tfp.stats.correlation`; per the TFP documentation this treats every axis
  as holding samples of a scalar event.
  """
  return tfp.stats.correlation(x=y_gt, y=y_pred,
                               sample_axis=None, event_axis=None)
def ssim(y_gt, y_pred, max_val=1.):
  """Structural similarity between `y_gt` and `y_pred` via `tf.image.ssim`.

  Args:
    y_gt: ground-truth images.
    y_pred: predicted images.
    max_val: dynamic range of the pixel values (defaults to 1.).
  """
  return tf.image.ssim(img1=y_gt, img2=y_pred, max_val=max_val)
def psnr(y_gt, y_pred, max_val=1.):
  """Peak signal-to-noise ratio between `y_gt` and `y_pred` via `tf.image.psnr`.

  Args:
    y_gt: ground-truth images.
    y_pred: predicted images.
    max_val: dynamic range of the pixel values (defaults to 1.).
  """
  return tf.image.psnr(a=y_gt, b=y_pred, max_val=max_val)

In [0]:
#
# Parameters
#
# Training objective passed to model.compile().
loss = 'logcosh'
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and no longer accepted by recent Keras versions.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# Activation and kernel initializer shared by the Conv2D / ConvLSTM2D layers.
activ = 'selu'
init = 'lecun_normal'
nk = 128   # number of filters in each Conv2D / ConvLSTM2D layer
ks = 5     # kernel size of the Conv2D / ConvLSTM2D layers
lks = 3    # kernel size of the final Conv3D prediction layer
#
Ninput = 10   # frames of each sequence given to the model as input
Noutput = 5   # frames following the input that serve as the target
#
Nsample = 100  # sequences used for training (the test split gets Nsample//10)
bs = 10        # batch size for fit() and evaluate()
ep = 4         # number of training epochs

In [0]:
#
# Data processing
#
def preprocess(data):
  """Rescale and re-lay-out a block of frames.

  The values are divided by 255, the first two axes are swapped, and a
  trailing axis of length 1 is appended.

  Args:
    data: array whose first two axes are to be exchanged.

  Returns:
    A new floating-point array with one more dimension than `data`.
  """
  scaled = data / 255.
  return np.expand_dims(scaled.swapaxes(0, 1), axis=-1)
def split_train_test(dataset, n):
  """Load a .npy sequence file and carve out train and test subsets.

  Args:
    dataset: path of the .npy file; it is opened read-only as a memory map.
    n: number of entries along the second axis used for training; the
      following n//10 entries form the test subset.

  Returns:
    (train, test), each passed through preprocess().
  """
  raw = np.load(dataset, mmap_mode='r')
  n_test = n // 10
  train_part = raw[:, :n, :, :]
  test_part = raw[:, n:n + n_test, :, :]
  return preprocess(train_part), preprocess(test_part)
def make_XY(data, Ninput, Noutput):
  """Split 5-D data along its second axis into inputs and targets.

  Args:
    data: 5-D array; the second axis is the one being split.
    Ninput: number of leading entries that form the input X.
    Noutput: number of entries right after the input that form the target y.

  Returns:
    (X, y) as slices of `data`.
  """
  everything = slice(None)
  past = slice(None, Ninput)
  future = slice(Ninput, Ninput + Noutput)
  X = data[everything, past, everything, everything, everything]
  y = data[everything, future, everything, everything, everything]
  return X, y
# Load both subsets, record the frame shape from the last three axes, then
# cut each subset into model inputs and targets.
data_train, data_test = split_train_test('mnist_test_seq.npy', Nsample)
H, W, C = data_train.shape[-3:]
(X, y), (Xtest, ytest) = [make_XY(subset, Ninput, Noutput)
                          for subset in (data_train, data_test)]

In [0]:
# Model Architecture
def model_init():
  """Build the (uncompiled) frame-prediction network.

  Structure: per-frame convolutions and two ConvLSTM layers encode the input
  sequence; a first ConvLSTM decoder unrolls Noutput steps; a U-Net-like
  stack of per-frame convolutions with pooling, upsampling and skip
  connections refines them; a second ConvLSTM decoder and a Conv3D layer
  with sigmoid activation produce the predicted frames.

  Reads the module-level settings H, W, C, nk, ks, lks, activ, init and
  Noutput at call time. Because of the four pooling / upsampling stages with
  skip concatenations, H and W must be divisible by 16.

  Returns:
    A tf.keras.Model taking sequences of shape (batch, time, H, W, C) and
    producing Noutput single-channel frames per sequence.

  Raises:
    ValueError: if Noutput is smaller than 1.
  """
  if Noutput < 1:
    raise ValueError('Noutput must be >= 1, got %r' % (Noutput,))

  layers = tf.keras.layers

  def conv_block(x):
    # Same 2-D convolution applied independently to every time step.
    return layers.TimeDistributed(
        layers.Conv2D(nk, ks, padding='same',
                      activation=activ, kernel_initializer=init))(x)

  def conv_lstm(**kwargs):
    # ConvLSTM2D layer with the shared hyper-parameters.
    return layers.ConvLSTM2D(nk, ks, padding='same',
                             activation=activ, kernel_initializer=init,
                             **kwargs)

  def pool(x):
    # Per-frame 2x2 max pooling (halves H and W).
    return layers.TimeDistributed(layers.MaxPooling2D())(x)

  def upsample(x):
    # Per-frame 2x upsampling (doubles H and W).
    return layers.TimeDistributed(layers.UpSampling2D())(x)

  def norm(x):
    return layers.LayerNormalization()(x)

  inputs = tf.keras.Input(shape=[None, H, W, C])
  # Encoder
  x = conv_block(inputs)
  x = conv_block(x)
  x = norm(x)
  # With return_state=True the bidirectional wrapper returns two (h, c) state
  # pairs after the output; only the first pair (the forward direction, per
  # the Keras docs) is kept.
  CL1, cl1_h, cl1_c, _, _ = layers.Bidirectional(
      conv_lstm(return_sequences=True, return_state=True))(x)
  CL2, cl2_h, cl2_c = conv_lstm(return_state=True)(CL1)
  # Decoder 1
  # The decoder input is the encoder summary followed by all-zero frames, one
  # entry per frame to predict. This used to be hard-coded to 5 entries, which
  # silently disagreed with the targets whenever Noutput != 5.
  input_dec = tf.stack(
      [CL2] + [tf.zeros_like(CL2) for _ in range(Noutput - 1)], axis=1)
  CL3 = conv_lstm(return_sequences=True)(input_dec, initial_state=[cl1_h, cl1_c])
  CL4 = conv_lstm(return_sequences=True)(CL3, initial_state=[cl2_h, cl2_c])
  x = norm(CL4)
  # Deepen: contracting path (conv1..conv4 are kept for the skip connections)
  conv1 = conv_block(x)
  conv2 = conv_block(pool(conv1))
  x = norm(pool(conv2))
  #
  conv3 = conv_block(x)
  conv4 = conv_block(pool(conv3))
  x = norm(pool(conv4))
  # Expanding path: upsample and concatenate with the matching-resolution
  # feature map from the contracting path.
  conv5 = conv_block(x)
  x = layers.Concatenate()([upsample(conv5), conv4])
  conv6 = conv_block(x)
  x = layers.Concatenate()([upsample(conv6), conv3])
  x = norm(x)
  #
  conv7 = conv_block(x)
  x = layers.Concatenate()([upsample(conv7), conv2])
  conv8 = conv_block(x)
  x = layers.Concatenate()([upsample(conv8), conv1])
  x = norm(x)
  # Decoder 2 (re-uses the encoder states as initial states)
  CL5 = conv_lstm(return_sequences=True)(x, initial_state=[cl1_h, cl1_c])
  CL6 = conv_lstm(return_sequences=True)(CL5, initial_state=[cl2_h, cl2_c])
  x = norm(CL6)
  # Prediction
  # The bias starts at -log(99), i.e. sigmoid(bias) = 0.01.
  preds = layers.Conv3D(1, lks, padding='same',
                        bias_initializer=tf.keras.initializers.Constant(value=-np.log(99)),
                        activation='sigmoid')(x)
  return tf.keras.Model(inputs=inputs, outputs=preds)

In [0]:
# Initialize model
model = model_init()

# Compile: attach the optimizer, the loss and the metrics to track.
# The custom metric functions are reported under their function names
# ('ssim', 'psnr', 'cor').
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=[
        'acc', ssim, psnr, cor,
        tf.keras.metrics.Precision(name='prec'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)

# Fit the model, holding out 10% of the training data for validation
history = model.fit(x=X, y=y,
                    batch_size=bs, epochs=ep,
                    validation_split=0.1)

# Get results on test set as a {name: value} dict
results = model.evaluate(x=Xtest, y=ytest,
                         batch_size=bs, return_dict=True)

In [0]:
# One figure per tracked quantity: training and validation curves over the
# epochs, plus the test-set value as a dashed horizontal line. Each figure is
# saved as '<name>.png'.
for name in ('loss', 'prec', 'recall', 'cor', 'acc', 'ssim', 'psnr'):
    plt.figure(figsize=(12,6))
    plt.plot(history.history[name], label='Train')
    plt.plot(history.history['val_' + name], label='Val')
    # The test line uses this metric's own test value. The acc, ssim and psnr
    # figures previously drew results['cor'] by mistake.
    plt.axhline(results[name], linestyle='--', color='k', label='Test')
    plt.ylabel(name)
    plt.xlabel('epoch')
    plt.legend(loc='upper left')
    plt.savefig(name + '.png')
    #files.download(name + '.png')

In [0]:
# Plot Train
# Run the model on one training sequence and show, frame by frame, the
# input-then-predicted track (left) next to the ground truth (right).
itest = 1
track = X[itest][np.newaxis]
track = np.concatenate([track, model.predict(track)], axis=1)
true_track = data_train[itest]
for i in range(Ninput+Noutput):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    # Frames past the first Ninput come from the model.
    plt.text(1, 3, 'Prediction' if i >= Ninput else 'Initial trajectory', color='w')
    plt.imshow(track[0, i, :, :, 0])
    plt.subplot(1, 2, 2)
    plt.text(1, 3, 'Ground truth', color='w')
    plt.imshow(true_track[i, :, :, 0])
    plt.savefig('train_%i.png' % (i+1))
    #files.download('train_%i.png' % (i+1))

In [0]:
# Plot Test
# Run the model on one test sequence and show, frame by frame, the
# input-then-predicted track (left) next to the ground truth (right).
itest = 1
track = Xtest[itest][np.newaxis]
track = np.concatenate([track, model.predict(track)], axis=1)
true_track = data_test[itest]
for i in range(Ninput+Noutput):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    # Frames past the first Ninput come from the model.
    plt.text(1, 3, 'Prediction' if i >= Ninput else 'Initial trajectory', color='w')
    plt.imshow(track[0, i, :, :, 0])
    plt.subplot(1, 2, 2)
    plt.text(1, 3, 'Ground truth', color='w')
    plt.imshow(true_track[i, :, :, 0])
    plt.savefig('test_%i.png' % (i+1))
    #files.download('test_%i.png' % (i+1))