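# IMSEQ TRAIN — image-sequence prediction on Moving MNIST

Trains a convolutional-LSTM encoder-decoder (`ConvLSTMConv2DDecoderEncoderModel`) to predict the next `OUTPUT_SEQ_LENGTH` frames of a Moving MNIST sequence from the preceding `INPUT_SEQ_LENGTH` frames, then evaluates it and visualizes its predictions.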

In [1]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

from model.conv_lstmconv2d_encoder_decoder import ConvLSTMConv2DDecoderEncoderModel
from model.lstmconv2d_encoder_decoder import LSTMConv2DDecoderEncoderModel

In [2]:
# Run directory: checkpoints, summaries and generated GIFs (see write_gifs) go here.
TRAIN_DIR = "train/conv_c2dlstm_bce_no-nonlin-low_decay_5x5lstm"

# Batching, regularisation and hardware.
BATCH_SIZE = 32
REG_LAMBDA = 1e-7  # forwarded to the model constructor as reg_lambda
NUM_GPUS = 2       # in this notebook only referenced by the disabled MultiGpuRuntime line

# Number of observed (input) frames and of frames to predict (target).
INPUT_SEQ_LENGTH = OUTPUT_SEQ_LENGTH = 10

# Learning-rate schedule, handed positionally to runtime.build().
INITIAL_LR, LR_DECAY_STEP_INTERVAL, LR_DECAY_FACTOR = 1e-3, 10000, 0.9

### Data

In [3]:
# Per-frame shape shared by every split. NOTE(review): presumably
# (height, width, channels) for the 64x64 grayscale Moving MNIST frames — confirm.
FRAME_SHAPE = [64, 64, 1]


def _seq_shape(seq_length):
    """Return a new [seq_length, *FRAME_SHAPE] list (fresh object on every call)."""
    return [seq_length] + FRAME_SHAPE


dataset_train = tt.datasets.moving_mnist.MovingMNISTTrainDataset(input_shape=_seq_shape(INPUT_SEQ_LENGTH),
                                                                 target_shape=_seq_shape(OUTPUT_SEQ_LENGTH))
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(input_shape=_seq_shape(INPUT_SEQ_LENGTH),
                                                                 target_shape=_seq_shape(OUTPUT_SEQ_LENGTH))

# The test split is disabled. To re-enable it, construct
# tt.datasets.moving_mnist.MovingMNISTTestDataset here. NOTE(review): an earlier
# version called it with input_seq_length/target_seq_length keyword arguments,
# unlike the input_shape/target_shape used above — check its signature first.
dataset_test = None

File mnist.h5 has already been downloaded.
File mnist.h5 has already been downloaded.


### Training

In [4]:
# Restrict this process to CUDA device 1.
CUDA_DEVICE_IDS = [1]
tt.hardware.set_cuda_devices(CUDA_DEVICE_IDS)

# Single-device runtime. For multi-GPU training swap in:
#   runtime = tt.core.MultiGpuRuntime(NUM_GPUS, train_dir=TRAIN_DIR)
runtime = tt.core.DefaultRuntime(train_dir=TRAIN_DIR)

Launing default runtime...
Selecting GPU device: 1


In [5]:
# Attach the data splits (the test split may be None).
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)

# Model under training. Alternative kept for experiments:
#   LSTMConv2DDecoderEncoderModel(reg_lambda=REG_LAMBDA)
seq_model = ConvLSTMConv2DDecoderEncoderModel(reg_lambda=REG_LAMBDA)
runtime.register_model(seq_model)

# Build the graph with the learning-rate schedule (initial LR, decay interval, decay factor).
lr_schedule = (INITIAL_LR, LR_DECAY_STEP_INTERVAL, LR_DECAY_FACTOR)
runtime.build(*lr_schedule, verbose=True)

ConvStack/Conv1/W:0: 800
ConvStack/Conv1/b:0: 32
ConvStack/Conv2/W:0: 18432
ConvStack/Conv2/b:0: 64
ConvStack/Conv3/W:0: 36864
ConvStack/Conv3/b:0: 64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xi/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xi/b:0: 64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xj/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xj/b:0: 64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xf/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xf/b:0: 64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xo/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xo/b:0: 64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_hi/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_hj/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_hf/W:0: 102400
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_ho/W:0: 102400
decoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_xi/W:0: 102400
decoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Con

In [6]:
def write_gifs(rt, dataset, gstep):
    """Write a ground-truth vs. prediction GIF for a single sample.

    Draws one example from ``dataset``, runs ``rt.predict`` on its input and
    writes ``<rt.train_dir>/out/anim-<gstep>.gif`` through
    ``tt.utils.video.write_multi_gif``. Both animations begin with the observed
    input frames, followed by the ground-truth future frames or the predicted
    frames respectively.

    Args:
        rt: runtime exposing ``predict(x)`` and ``train_dir``.
        dataset: object exposing ``get_batch(batch_size) -> (x, y)``.
        gstep: integer step used in the file name (zero-padded to 6 digits).
    """
    x, y = dataset.get_batch(1)
    pred = rt.predict(x)

    # Prepend the observed input frames so each animation shows the whole sequence.
    concat_y = np.concatenate((x[0], y[0]))
    concat_pred = np.concatenate((x[0], pred[0]))

    # Build the path portably and make sure the output directory exists,
    # otherwise the first write into a fresh train_dir can fail.
    out_dir = os.path.join(rt.train_dir, "out")
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    gif_path = os.path.join(out_dir, "anim-{:06d}.gif".format(gstep))

    # NOTE(review): frames appear to be in [0, 1]; scaled to 0..255 for the GIF — confirm.
    tt.utils.video.write_multi_gif(gif_path,
                                   [concat_y * 255, concat_pred * 255],
                                   fps=4)
    
def on_valid(rt, gstep):
    """Validation hook: dump one validation-sample GIF tagged with ``gstep``."""
    valid_split = rt.datasets.valid
    write_gifs(rt, valid_split, gstep)

In [None]:
# Fixed-length training run; on_valid is passed as the on_validate callback.
TRAIN_STEPS = 50000
runtime.train(BATCH_SIZE, steps=TRAIN_STEPS, on_validate=on_valid)

Starting epoch 1...
@    10: loss:     0.309, total-loss:     0.309 (   72.0 examples/sec,  0.44 sec/batch)
@    20: loss:     0.211, total-loss:     0.211 (   71.3 examples/sec,  0.45 sec/batch)
@    30: loss:     0.196, total-loss:     0.196 (   72.1 examples/sec,  0.44 sec/batch)
@    40: loss:     0.183, total-loss:     0.183 (   72.1 examples/sec,  0.44 sec/batch)
@    50: loss:     0.177, total-loss:     0.177 (   70.7 examples/sec,  0.45 sec/batch)
@    60: loss:     0.174, total-loss:     0.174 (   74.7 examples/sec,  0.43 sec/batch)
@    70: loss:     0.168, total-loss:     0.168 (   73.4 examples/sec,  0.44 sec/batch)
@    80: loss:     0.170, total-loss:     0.170 (   73.5 examples/sec,  0.44 sec/batch)
@    90: loss:     0.164, total-loss:     0.164 (   73.1 examples/sec,  0.44 sec/batch)


### Evaluation

In [None]:
# Evaluate the model on the registered validation split.
# NOTE(review): the meaning of 50 (batch size vs. number of steps) is defined by Runtime.validate — confirm.
runtime.validate(50)

In [None]:
# Evaluate on the test split only when one was registered: the data section
# sets dataset_test = None, and running the test pass without data breaks
# "Restart & Run All".
# NOTE(review): the meaning of 50 (batch size vs. number of steps) is defined by Runtime.test — confirm.
if dataset_test is not None:
    runtime.test(50)
else:
    print("Skipping test evaluation: no test dataset registered.")

### Visualization

In [None]:
# Show one validation sample: observed input frames, ground-truth future frames
# and predicted future frames. All three panels use the same 0..255 scaling as
# write_gifs, so their intensities are directly comparable.
# NOTE(review): assumes frames are in [0, 1] — confirm.
PIXEL_SCALE = 255
x, y = dataset_valid.get_batch(1)
pred = runtime.predict(x)

tt.visualization.display_batch(x[0] * PIXEL_SCALE, nrows=2, ncols=5, title="Input")
tt.visualization.display_batch(y[0] * PIXEL_SCALE, nrows=2, ncols=5, title="GT-Future")
tt.visualization.display_batch(pred[0] * PIXEL_SCALE, nrows=2, ncols=5, title="Prediction")

In [None]:
# Write one more GIF after training; the step value is presumably chosen so
# this file name cannot collide with the per-validation GIFs.
FINAL_GIF_STEP = 999999999
write_gifs(runtime, dataset_valid, FINAL_GIF_STEP)

### Terminate

In [None]:
# Shut down the runtime; run this last. NOTE(review): presumably releases its TF session/GPU — confirm.
runtime.close()