# Evaluation for MovingMNIST

In [None]:
%matplotlib inline

import os
import sys
sys.path.append(os.path.expanduser("~/libs"))

import cv2
import numpy as np
import tensorflow as tf
import tensortools as tt

from model.frame_prediction import LSTMConv2DPredictionModel

### Hyperparams

In [None]:
# Number of frames fed to the model and number of frames it has to predict.
INPUT_SEQ_LENGTH, OUTPUT_SEQ_LENGTH = 10, 10

# Batch size for runtime.validate() / runtime.test().
EVAL_BATCH_SIZE = 50

In [None]:
# Output settings for the visualizations in this notebook: files are written
# below <train_dir>/<OUT_DIR_NAME>/, NUM_SAMPLES random sequences are drawn
# per experiment, and GIF animations are written at GIF_FPS frames/second.
OUT_DIR_NAME, NUM_SAMPLES, GIF_FPS = "out-eval", 4, 5

#### Directory Paths:

In [None]:
# Base directory containing both the dataset and the training runs.
ROOT_DIR = "/work/sauterme/"
DATA_DIR = os.path.join(ROOT_DIR, "data")

In [None]:
# Training run whose checkpoints are evaluated in this notebook.
TRAIN_DIR = ROOT_DIR + "train/mm/ss/3l3i5hp/c326464k533s212bn/wd1e-05/LV"

# Fail early with a descriptive error. A bare `assert` is stripped under
# `python -O` and would not say which path is missing.
if not os.path.isdir(TRAIN_DIR):
    raise OSError("Training directory not found: {}".format(TRAIN_DIR))

### Data

In [None]:
# Whether the datasets deliver binary frames (True) or raw float intensities
# (False); see the "Random prediction" note below.
AS_BINARY = True
# Validation set: configured with full shapes whose first entry is the sequence
# length (presumably [time, height, width, channels] -- confirm), which is why
# it is reused later for larger frames and longer sequences.
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(DATA_DIR,
                                                                 input_shape=[INPUT_SEQ_LENGTH, 64, 64, 1],
                                                                 target_shape=[OUTPUT_SEQ_LENGTH, 64, 64, 1],
                                                                 as_binary=AS_BINARY)
# Test set: only the input/target sequence lengths are configurable here.
dataset_test = tt.datasets.moving_mnist.MovingMNISTTestDataset(DATA_DIR,
                                                               input_seq_length=INPUT_SEQ_LENGTH,
                                                               target_seq_length=OUTPUT_SEQ_LENGTH,
                                                               as_binary=AS_BINARY)

### Runtime

In [None]:
# Index of the GPU handed to the runtime via `gpu_devices`.
GPU_ID = 1

In [None]:
# Runtime bound to the training directory; checkpoints are presumably listed
# and restored from TRAIN_DIR (confirm against tt.core.DefaultRuntime).
runtime = tt.core.DefaultRuntime(train_dir=TRAIN_DIR, gpu_devices=[GPU_ID])

In [None]:
# Inspect which checkpoints are available (return value shown as cell output).
runtime.list_checkpoints()

In [None]:
# Checkpoint restored by every runtime.build() call below; LATEST_CHECKPOINT
# presumably selects the most recent one listed above.
CHECKPOINT = tt.core.LATEST_CHECKPOINT

In [None]:
# Register the evaluation datasets. The first slot (presumably the training
# set -- confirm argument order) is left empty since nothing is trained here.
runtime.register_datasets(None, dataset_valid, dataset_test)

In [None]:
# Model under evaluation; tt.loss.bce is presumably binary cross-entropy,
# matching the binary frames (AS_BINARY) used above.
runtime.register_model(LSTMConv2DPredictionModel(main_loss=tt.loss.bce))

Consider restoring the **EMA variables** as well when building the model. Note that this might produce worse results in models using batch normalization, since the shadow variables might not be restored properly...

In [None]:
# Build the model and restore its parameters from CHECKPOINT. EMA (shadow)
# variables are intentionally not restored -- see the note on batch
# normalization above.
runtime.build(restore_checkpoint=CHECKPOINT, restore_model_params=True,
              restore_ema_variables=False, verbose=True)

## Evaluation

In [None]:
# Evaluate on the registered validation set in batches of EVAL_BATCH_SIZE.
runtime.validate(EVAL_BATCH_SIZE)

In [None]:
# Evaluate on the registered test set in batches of EVAL_BATCH_SIZE.
runtime.test(EVAL_BATCH_SIZE)

## Visualizations
Choose which dataset to use in the next section. For sequences longer than 10 frames, we have to use the validation set...

In [None]:
# Dataset sampled by the visualization cells below; switch to dataset_valid
# for sequences longer than 10 frames (see the note above).
dataset_eval = dataset_test

### Random prediction
The frames are either **binary** (as in training) or **float** (as in the raw dataset)...

In [None]:
def write_animation(dir_path, inputs, targets, predictions, fps, index=None):
    """Write a GIF and a PNG timeline comparing ground truth with prediction.

    Two sequences are rendered side by side: ``inputs`` followed by
    ``targets`` (ground truth), and ``inputs`` followed by ``predictions``.

    Args:
        dir_path: Output directory; created if it does not exist yet.
        inputs: Input frames of a single sequence.
        targets: Ground-truth future frames of the same sequence.
        predictions: Predicted future frames of the same sequence.
        fps: Frame rate of the GIF animation.
        index: Number used in the file names ``anim-XX.gif`` and
            ``timeline-XX.png``. If omitted, the notebook-global loop
            variable ``i`` is used (the behaviour of earlier versions of
            this helper); passing it explicitly is preferred.
    """
    if index is None:
        # Backward compatibility: existing callers rely on the global loop
        # variable `i` (intentionally a global lookup here).
        index = i

    # np.concatenate joins along axis 0 -- presumably the time axis of a
    # single sequence (confirm against callers).
    concat_tgt = np.concatenate((inputs, targets))
    concat_pred = np.concatenate((inputs, predictions))

    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)

    tt.utils.video.write_multi_gif(os.path.join(dir_path, "anim-{:02d}.gif".format(index)),
                                   [concat_tgt, concat_pred],
                                   fps=fps, pad_value=1.0)

    tt.utils.video.write_multi_image_sequence(os.path.join(dir_path, "timeline-{:02d}.png".format(index)),
                                              [concat_tgt, concat_pred],
                                              pad_value=1.0)

def show(inputs, targets, predictions, rows=2):
    """Display inputs, targets and predictions inline as three frame grids.

    Each grid has 5 columns and ``rows`` rows; the panels are drawn in the
    order inputs, targets, predictions.
    """
    panels = (("Inputs", inputs), ("Targets", targets), ("Predictions", predictions))
    for title, frames in panels:
        tt.visualization.display_batch(frames, title=title, nrows=rows, ncols=5)

In [None]:
# Predict NUM_SAMPLES random sequences from dataset_eval and write a GIF and a
# PNG timeline per sequence to <train_dir>/<OUT_DIR_NAME>/random.
dir_path = os.path.join(runtime.train_dir, OUT_DIR_NAME, "random")

inputs, targets = dataset_eval.get_batch(NUM_SAMPLES)

predictions = runtime.predict(inputs)

# Only the first sequence of the batch is displayed inline.
show(inputs[0], targets[0], predictions[0])
# The loop variable must stay named `i`: write_animation() derives the output
# file names from the global `i`.
for i in range(inputs.shape[0]):
    write_animation(dir_path, inputs[i], targets[i], predictions[i], GIF_FPS)

### Specific Predictions
We use the input sequences from _Unsupervised Learning of Video Representations using LSTMs_ (Srivastava et al., 2015), cropped out of the paper. These consist of two normal sequences, one sequence with only one digit and one sequence with three digits...

In [None]:
# Root of the sequences cropped from the paper: one sub-directory per sequence
# id, each holding 00.png, 01.png, ... (see read_sequence). Relative path, so it
# is resolved against the notebook's working directory.
SOURCE_PATH = "assets/other/moving_mnist/"

In [None]:
def read_sequence(dir_path, seq_id, input_length=None, output_length=None):
    """Load one frame sequence from PNG files and split it into inputs/targets.

    Frames are read as grayscale images from ``<dir_path>/<seq_id>/00.png``,
    ``01.png``, ... and divided by 255 (presumably 8-bit images, so the
    result lies in [0, 1]).

    Args:
        dir_path: Root directory holding one sub-directory per sequence.
        seq_id: Name of the sequence sub-directory (converted with ``str``).
        input_length: Number of input frames. Defaults to the notebook-global
            ``INPUT_SEQ_LENGTH`` at call time (note that the notebook rebinds
            the globals later on).
        output_length: Number of target frames. Defaults to the
            notebook-global ``OUTPUT_SEQ_LENGTH`` at call time.

    Returns:
        Tuple ``(inputs, targets)``: the first ``input_length`` frames and
        the remaining ``output_length`` frames, each with a leading batch
        axis of size 1.
    """
    if input_length is None:
        input_length = INPUT_SEQ_LENGTH
    if output_length is None:
        output_length = OUTPUT_SEQ_LENGTH

    seq_dir = os.path.join(dir_path, str(seq_id))
    frames = [tt.utils.image.read(os.path.join(seq_dir, "{:02d}.png".format(t)),
                                  color_flags=cv2.IMREAD_GRAYSCALE)
              for t in range(input_length + output_length)]
    seq = np.array(frames) / 255.0
    # Add a batch axis of size 1 so the result can be passed to runtime.predict().
    seq = np.expand_dims(seq, 0)
    return seq[:, :input_length], seq[:, input_length:]

In [None]:
# Predict each sequence cropped from the paper (ids 0..5 under SOURCE_PATH),
# show it inline and write a GIF/PNG timeline to <train_dir>/<OUT_DIR_NAME>/spec.
# NOTE(review): the text above describes four sequences but six ids are read
# here -- confirm how many sub-directories SOURCE_PATH actually contains.
dir_path = os.path.join(runtime.train_dir, OUT_DIR_NAME, "spec")

# The loop variable must stay named `i`: write_animation() derives the output
# file names from the global `i`.
for i in range(6):
    inputs, targets = read_sequence(SOURCE_PATH, i)
    predictions = runtime.predict(inputs)
    
    show(inputs[0], targets[0], predictions[0])
    write_animation(dir_path, inputs[0], targets[0], predictions[0], GIF_FPS)

### Bigger Image

#### a) Scaled

In [None]:
# Spatial up-scaling factor for every frame (64x64 -> 128x128 for 2.0).
SCALE_FACTOR = 2.0

In [None]:
# Rebuild the model for SCALE_FACTOR-times larger frames, restoring the same
# checkpoint. The registered datasets are dropped first -- presumably because
# their 64x64 shapes no longer match the new input/target shapes (confirm).
runtime.unregister_datasets()
runtime.build(restore_checkpoint=CHECKPOINT, restore_model_params=True,
              restore_ema_variables=False,
              input_shape=[INPUT_SEQ_LENGTH, int(64 * SCALE_FACTOR), int(64 * SCALE_FACTOR), 1],
              target_shape=[OUTPUT_SEQ_LENGTH, int(64 * SCALE_FACTOR), int(64 * SCALE_FACTOR), 1])

In [None]:
# Up-scale NUM_SAMPLES random 64x64 sequences by SCALE_FACTOR, predict them with
# the rebuilt model and write the results to <train_dir>/<OUT_DIR_NAME>/scaled.
dir_path = os.path.join(runtime.train_dir, OUT_DIR_NAME, "scaled")

inputs, targets = dataset_eval.get_batch(NUM_SAMPLES)

# Flatten batch and time into a single frame axis and resize frame by frame.
inputs = np.stack([tt.utils.image.resize(frame, SCALE_FACTOR)
                   for frame in np.reshape(inputs, [-1, 64, 64, 1])])
targets = np.stack([tt.utils.image.resize(frame, SCALE_FACTOR)
                    for frame in np.reshape(targets, [-1, 64, 64, 1])])

# Restore the [batch, time, height, width, 1] layout at the enlarged size.
inputs = np.reshape(inputs, [-1, INPUT_SEQ_LENGTH, int(64 * SCALE_FACTOR), int(64 * SCALE_FACTOR), 1])
targets = np.reshape(targets, [-1, OUTPUT_SEQ_LENGTH, int(64 * SCALE_FACTOR), int(64 * SCALE_FACTOR), 1])

predictions = runtime.predict(inputs)

show(inputs[0], targets[0], predictions[0])
# The loop variable must stay named `i`: write_animation() derives the output
# file names from the global `i`.
for i in range(inputs.shape[0]):
    write_animation(dir_path, inputs[i], targets[i], predictions[i], GIF_FPS)

#### b) Zero-Padded

In [None]:
# Canvas enlargement factor for the zero-padded experiment, and binary frames.
SIZE_FACTOR, AS_BINARY = 2.0, True

For this, we have to use the **validation set**, because it supports variable image sizes...

In [None]:
# Validation set with a SIZE_FACTOR-times larger frame size. Per the section
# title, the digits presumably keep their size on a larger (zero-padded)
# canvas -- confirm against MovingMNISTValidDataset.
dataset_eval = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    DATA_DIR, input_shape=[INPUT_SEQ_LENGTH, int(64 * SIZE_FACTOR), int(64 * SIZE_FACTOR), 1],
    target_shape=[OUTPUT_SEQ_LENGTH, int(64 * SIZE_FACTOR), int(64 * SIZE_FACTOR), 1],
    as_binary=AS_BINARY)

In [None]:
# Register the enlarged validation set and rebuild the model from CHECKPOINT.
# Unlike the scaled case, no explicit shapes are passed; presumably build()
# takes them from the registered dataset -- confirm.
runtime.register_datasets(valid_ds=dataset_eval)
runtime.build(restore_checkpoint=CHECKPOINT, restore_model_params=True,
              restore_ema_variables=False)

In [None]:
# Predict NUM_SAMPLES random sequences on the enlarged canvas and write the
# results to <train_dir>/<OUT_DIR_NAME>/padded.
dir_path = os.path.join(runtime.train_dir, OUT_DIR_NAME, "padded")

inputs, targets = dataset_eval.get_batch(NUM_SAMPLES)

predictions = runtime.predict(inputs)

show(inputs[0], targets[0], predictions[0])
# The loop variable must stay named `i`: write_animation() derives the output
# file names from the global `i`.
for i in range(inputs.shape[0]):
    write_animation(dir_path, inputs[i], targets[i], predictions[i], GIF_FPS)

### Longer Time Range

In [None]:
# Data
# NOTE: this rebinds the notebook-global sequence lengths. read_sequence()
# reads INPUT_SEQ_LENGTH + OUTPUT_SEQ_LENGTH frames taken from these globals,
# so re-running the "Specific Predictions" cells after this one would try to
# read 60 frames per sequence.
INPUT_SEQ_LENGTH = 10
OUTPUT_SEQ_LENGTH = 50

AS_BINARY = True
# Sequences longer than 10 frames require the validation set (see the note at
# the start of the visualization section).
dataset_eval = tt.datasets.moving_mnist.MovingMNISTValidDataset(DATA_DIR,
                                                                input_shape=[INPUT_SEQ_LENGTH, 64, 64, 1],
                                                                target_shape=[OUTPUT_SEQ_LENGTH, 64, 64, 1],
                                                                as_binary=AS_BINARY)

In [None]:
# Register the long-sequence validation set; the first slot (presumably the
# training set -- confirm argument order) stays empty.
runtime.register_datasets(None, dataset_eval)

In [None]:
# Rebuild the model from CHECKPOINT for the longer target sequence; no explicit
# shapes are passed, so they are presumably taken from the registered dataset.
runtime.build(restore_checkpoint=CHECKPOINT, restore_model_params=True,
              restore_ema_variables=False)

In [None]:
# Predict OUTPUT_SEQ_LENGTH (50) future frames for NUM_SAMPLES random sequences
# and write the results to <train_dir>/<OUT_DIR_NAME>/long.
dir_path = os.path.join(runtime.train_dir, OUT_DIR_NAME, "long")

inputs, targets = dataset_eval.get_batch(NUM_SAMPLES)

predictions = runtime.predict(inputs)

# rows=10 with show()'s 5 columns gives room for all 50 target frames.
show(inputs[0], targets[0], predictions[0], rows=10)
# The loop variable must stay named `i`: write_animation() derives the output
# file names from the global `i`.
for i in range(inputs.shape[0]):
    write_animation(dir_path, inputs[i], targets[i], predictions[i], GIF_FPS)