In [None]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

from model.conv_lstmconv2d_encoder_decoder import ConvLSTMConv2DDecoderEncoderModel

In [None]:
# Run directory handed to the runtime; the GIF export further down also
# writes into a subfolder of it.
TRAIN_DIR = "train/conv_c2dlstm_lower_lr"

# Training / model settings.
BATCH_SIZE = 24
REG_LAMBDA = 5e-4   # passed to the model as `reg_lambda`
NUM_GPUS = 2        # referenced only by the commented-out MultiGpuRuntime option

# Sequence lengths of the model input and of the target to be predicted.
INPUT_SEQ_LENGTH = 10
OUTPUT_SEQ_LENGTH = 50

# Learning-rate schedule arguments for `runtime.build` (presumably: start at
# INITIAL_LR and multiply by LR_DECAY_FACTOR every LR_DECAY_STEP_INTERVAL
# steps -- confirm against tensortools).
INITIAL_LR = 1e-3
LR_DECAY_STEP_INTERVAL = 10000
LR_DECAY_FACTOR = 0.5

### Data

In [None]:
# Per-frame shape shared by inputs and targets (presumably height, width,
# channels); only the leading sequence length differs between the two.
FRAME_SHAPE = [64, 64, 1]

# Fresh shape lists are built for every call, exactly as before.
dataset_train = tt.datasets.moving_mnist.MovingMNISTTrainDataset(
    input_shape=[INPUT_SEQ_LENGTH] + FRAME_SHAPE,
    target_shape=[OUTPUT_SEQ_LENGTH] + FRAME_SHAPE)
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(
    input_shape=[INPUT_SEQ_LENGTH] + FRAME_SHAPE,
    target_shape=[OUTPUT_SEQ_LENGTH] + FRAME_SHAPE)

# No test split is loaded. A previous MovingMNISTTestDataset call used
# `input_seq_length` / `target_seq_length` keywords rather than the shape
# keywords above; re-create it in that style if a test split is needed.
dataset_test = None

### Training

In [None]:
# Restrict this process to CUDA device 7 (machine-specific choice).
tt.hardware.set_cuda_devices([7])

# Single-device runtime. Alternative kept for reference:
#   runtime = tt.core.MultiGpuRuntime(NUM_GPUS, train_dir=TRAIN_DIR)
runtime = tt.core.DefaultRuntime(train_dir=TRAIN_DIR)

runtime.register_datasets(dataset_train, dataset_valid, dataset_test)

conv_model = ConvLSTMConv2DDecoderEncoderModel(reg_lambda=REG_LAMBDA)
runtime.register_model(conv_model)

# Build with the learning-rate schedule from the config cell and point the
# runtime at tt's LATEST_CHECKPOINT (presumably resumes from the newest
# checkpoint in TRAIN_DIR -- confirm).
runtime.build(INITIAL_LR, LR_DECAY_STEP_INTERVAL, LR_DECAY_FACTOR,
              checkpoint_file=tt.core.LATEST_CHECKPOINT)

In [None]:
# Number of training steps for this session; the GIF filename in the
# Visualization section ("anim-30000.gif") mirrors this value.
train_steps = 30000
runtime.train(BATCH_SIZE, steps=train_steps)

### Evaluation

In [None]:
# Evaluate on the registered validation split.
# NOTE(review): the meaning of the positional 50 is defined by tt's
# runtime.validate (presumably a batch size) -- confirm.
runtime.validate(50)

In [None]:
# Evaluate on the test split, but only when one was registered. The Data
# section sets `dataset_test = None`, so an unconditional runtime.test call
# has nothing to evaluate (and may raise), which breaks Restart & Run All.
# NOTE(review): as with validate, the meaning of 50 is defined by tt -- confirm.
if dataset_test is not None:
    runtime.test(50)
else:
    print("Skipping test evaluation: no test dataset registered.")

### Visualization

In [None]:
# Draw one validation sample and run the model on its input frames.
x, y = dataset_valid.get_batch(1)
pred = runtime.predict(x)

# Display scale factors. NOTE(review): the input is scaled by 255 while the
# target and prediction use 1000 (the GIF export uses 255 for everything);
# presumably this brightens faint frames -- confirm how display_batch treats
# values above 255.
input_scale = 255
future_scale = 1000

# Each panel is a 2x5 grid. NOTE(review): the target/prediction sequences are
# presumably OUTPUT_SEQ_LENGTH frames long, so how many appear depends on
# display_batch -- confirm.
tt.visualization.display_batch(x[0] * input_scale, nrows=2, ncols=5, title="Input")
tt.visualization.display_batch(y[0] * future_scale, nrows=2, ncols=5, title="GT-Future")
# This is the model's output, not ground truth, so it gets no "GT-" prefix.
tt.visualization.display_batch(pred[0] * future_scale, nrows=2, ncols=5, title="Prediction")

In [None]:
# Draw a fresh validation sample and predict its continuation.
x, y = dataset_valid.get_batch(1)
pred = runtime.predict(x)

# Prepend the observed input frames to both the ground-truth future and the
# prediction (np.concatenate joins along the first, presumably time, axis) so
# each clip plays input -> continuation.
concat_y = np.concatenate((x[0], y[0]))
concat_pred = np.concatenate((x[0], pred[0]))

# The step count in the filename mirrors the training budget used above.
gif_step = 30000
out_dir = os.path.join(TRAIN_DIR, "out")
# Make sure the output folder exists so the export does not fail on a fresh
# run directory (form without `exist_ok` for Python 2/3 compatibility).
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)

tt.utils.video.write_multiclip_gif(os.path.join(out_dir, "anim-{}.gif".format(gif_step)),
                                   [concat_y * 255, concat_pred * 255],
                                   fps=10)

### Terminate

In [None]:
# Shut down the runtime at the end of the session (presumably releases its
# TensorFlow session/devices -- confirm in tensortools).
runtime.close()