# IMSEQ TRAIN

In [None]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

from model.frame_prediction import LSTMConv2DPredictionModel

In [None]:
# --- Filesystem layout -------------------------------------------------------
# Working-area prefix. The trailing slash is part of the value; keep it so that
# every path derived from this prefix stays unchanged.
ROOT_DIR = "/work/sauterme/"
# Directory handed to the dataset constructors.
DATA_DIR = os.path.join(ROOT_DIR, "data")

# --- Batch sizes -------------------------------------------------------------
# Used for training and for evaluation (validate/test) respectively.
BATCH_SIZE = 32
EVAL_BATCH_SIZE = 50

# --- Regularization ----------------------------------------------------------
# Forwarded to the model as its `weight_decay` argument.
WEIGHT_DECAY = 1e-5

# --- Sequence lengths --------------------------------------------------------
# Leading dimension of the dataset input/target shapes.
INPUT_SEQ_LENGTH = 10
OUTPUT_SEQ_LENGTH = 10

# --- Learning-rate schedule --------------------------------------------------
# Passed, in this order, to the optimizer wrapper after the optimizer name.
INITIAL_LR = 0.001
LR_DECAY_STEP_INTERVAL = 1000
LR_DECAY_FACTOR = 0.95

### Data

In [None]:
# Moving MNIST training split. Input and target shapes differ only in their
# leading (sequence-length) dimension.
# NOTE(review): the trailing [64, 64, 1] is presumably height, width, channels
# -- confirm against the tensortools dataset documentation.
dataset_train = tt.datasets.moving_mnist.MovingMNISTTrainDataset(
    DATA_DIR,
    input_shape=[INPUT_SEQ_LENGTH, 64, 64, 1],
    target_shape=[OUTPUT_SEQ_LENGTH, 64, 64, 1],
    as_binary=True,
)

# Moving MNIST validation split, configured exactly like the training split.
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    DATA_DIR,
    input_shape=[INPUT_SEQ_LENGTH, 64, 64, 1],
    target_shape=[OUTPUT_SEQ_LENGTH, 64, 64, 1],
    as_binary=True,
)

### Training

In [None]:
# Directory for this run; it is passed to the runtime as `train_dir`.
# Alternative run directory kept for reference:
#TRAIN_DIR = ROOT_DIR + "train/wd1e5_64filters_peep_no-input-bn"
TRAIN_DIR = os.path.join(ROOT_DIR, "train/refac_test")

In [None]:
# Create the tensortools runtime that owns training, evaluation and prediction
# for this notebook.
# NOTE(review): the GPU device list is hard-coded to [7]; adjust it for the
# machine this notebook runs on.
runtime = tt.core.DefaultRuntime(
    train_dir=TRAIN_DIR,
    gpu_devices=[7],
)

In [None]:
# Datasets.
# NOTE(review): `dataset_valid` is passed as both the second and the third
# argument; presumably the third slot is the test set, in which case
# `runtime.test()` evaluates on validation data -- confirm this is intended.
runtime.register_datasets(dataset_train, dataset_valid, dataset_valid)

# Model: convolutional LSTM frame-prediction network.
runtime.register_model(
    LSTMConv2DPredictionModel(
        weight_decay=WEIGHT_DECAY,
        # Per-layer lists; all three have one entry per conv layer.
        filters=[32, 64, 64],
        ksizes=[(5, 5), (3, 3), (3, 3)],
        strides=[(2, 2), (1, 1), (2, 2)],
        bias_init=0.1,
        output_activation=tf.nn.sigmoid,
        bn_feature_enc=True,
        bn_feature_dec=True,
        # LSTM-related options.
        lstm_layers=2,
        lstm_ksize_input=(3, 3),
        lstm_ksize_hidden=(3, 3),
        lstm_use_peepholes=True,
        lstm_cell_clip=None,
        lstm_bn_input_hidden=False,
        lstm_bn_hidden_hidden=False,
        lstm_bn_peepholes=False,
    )
)

# Optimizer: Adam, configured with the learning-rate constants from the
# configuration cell (initial rate, decay step interval, decay factor).
runtime.register_optimizer(
    tt.training.Optimizer(
        'adam',
        INITIAL_LR,
        LR_DECAY_STEP_INTERVAL,
        LR_DECAY_FACTOR,
    )
)

# Build everything registered above; verbose=True requests detailed output.
runtime.build(verbose=True)

In [None]:
def write_animations(rt, dataset, gstep, samples=4):
    """Write ground-truth vs. prediction visualizations for a few samples.

    A batch of `samples` sequences is drawn from `dataset` and predicted with
    `rt`. For each sample, the input frames are prepended to both the target
    frames and the predicted frames, and the two resulting sequences are
    written as an animated GIF and as a timeline image.

    Files are written to ``<rt.train_dir>/out/<gstep, zero-padded to 6 digits>/``
    as ``anim-XX.gif`` and ``timeline-XX.png``.

    Args:
        rt: Runtime object; must provide `train_dir` and `predict(x)`.
        dataset: Dataset object; must provide `get_batch(n)` returning a pair
            `(inputs, targets)` indexable by sample.
        gstep: Integer global step, used to name the output sub-directory.
        samples: Number of sequences to draw and visualize. Defaults to 4,
            which was previously hard-coded.
    """
    root = os.path.join(rt.train_dir, "out", "{:06d}".format(gstep))
    x, y = dataset.get_batch(samples)
    pred = rt.predict(x)

    # Make sure the target directory exists before writing. Whether the
    # tensortools writers create missing directories is not visible from this
    # notebook, and creating it here is harmless either way.
    if not os.path.isdir(root):
        os.makedirs(root)

    # Iterate over the samples actually returned rather than indexing up to
    # `samples`, so a short batch cannot raise an IndexError.
    for i, (inputs, targets, predicted) in enumerate(zip(x, y, pred)):
        # Prepend the input frames (concatenation along the first axis) so both
        # sequences show the same lead-in before they diverge.
        concat_y = np.concatenate((inputs, targets))
        concat_pred = np.concatenate((inputs, predicted))

        tt.utils.video.write_multi_gif(os.path.join(root, "anim-{:02d}.gif".format(i)),
                                       [concat_y, concat_pred],
                                       fps=5, pad_value=1.0)

        tt.utils.video.write_multi_image_sequence(os.path.join(root, "timeline-{:02d}.png".format(i)),
                                                  [concat_y, concat_pred],
                                                  pad_value=1.0)
    
def on_valid(rt, gstep):
    """Validation callback: dump animations for the runtime's validation set.

    Args:
        rt: Runtime object; its `datasets.valid` attribute supplies the data.
        gstep: Global step, forwarded to `write_animations`.
    """
    valid_dataset = rt.datasets.valid
    write_animations(rt, valid_dataset, gstep)

In [None]:
# Train for 50000 steps. `on_valid` is handed over as the validation callback,
# which writes sample animations into the training directory.
runtime.train(
    BATCH_SIZE,
    EVAL_BATCH_SIZE,
    steps=50000,
    on_validate=on_valid,
)

### Evaluation

In [None]:
# Run the runtime's validation pass using the evaluation batch size.
runtime.validate(EVAL_BATCH_SIZE)

In [None]:
# Run the runtime's test pass using the evaluation batch size.
# NOTE(review): the validation dataset was registered in the third dataset slot
# as well; if that slot is the test set, this repeats the validation data -- confirm.
runtime.test(EVAL_BATCH_SIZE)

### Visualization

In [None]:
# Draw a single validation sequence and predict its future frames.
x, y = dataset_valid.get_batch(1)
pred = runtime.predict(x)

# Sanity checks: dtypes and value ranges of inputs, targets and predictions.
print(x.dtype, y.dtype, pred.dtype)
print(x.min(), x.max())
print(y.min(), y.max())
print(pred.min(), pred.max())

# Show the frames of the first (only) sample on a 2x5 grid.
# NOTE(review): 2x5 = 10 tiles matches the sequence lengths of 10 configured
# above; presumably the grid must be adjusted if those lengths change -- confirm.
tt.visualization.display_batch(x[0], nrows=2, ncols=5, title="Input")
tt.visualization.display_batch(y[0], nrows=2, ncols=5, title="GT-Future")
# These frames are the model output, not ground truth, so they are not
# labelled "GT".
tt.visualization.display_batch(pred[0], nrows=2, ncols=5, title="Prediction")

In [None]:
# Write animations for the validation set outside of the training loop.
# 999999999 is a sentinel step value: it becomes the output sub-directory name
# (<train_dir>/out/999999999), which keeps these files apart from per-step output.
write_animations(runtime, dataset_valid, 999999999)

### Terminate

In [None]:
# Close the runtime once all work is done. Cells that use `runtime` should not
# be re-run after this without recreating it (presumably -- confirm with tensortools).
runtime.close()