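# IMSEQ Train: ConvLSTM encoder-decoder on Moving MNIST

Trains `ConvLSTMConv2DDecoderEncoderModelND` to map `INPUT_SEQ_LENGTH` input frames of a Moving MNIST sequence to `OUTPUT_SEQ_LENGTH` target frames, writing input+ground-truth vs. input+prediction GIFs to the training directory's `out/` folder whenever validation runs.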

In [1]:
# Force matplotlib to use inline rendering
%matplotlib inline

import os
import sys

# add path to libraries for ipython
sys.path.append(os.path.expanduser("~/libs"))

import numpy as np
import tensorflow as tf
import tensortools as tt

from model.conv_lstmconv2d_encoder_decoder_nodiag import ConvLSTMConv2DDecoderEncoderModelND
from model.conv_lstmconv2d_encoder_decoder import ConvLSTMConv2DDecoderEncoderModel

In [2]:
# --- Configuration -------------------------------------------------------------
# Paths default to the original cluster locations but can be overridden with the
# IMSEQ_TRAIN_DIR / IMSEQ_DATA_DIR environment variables, so the notebook can run
# on another machine without editing this cell.
TRAIN_DIR = os.environ.get("IMSEQ_TRAIN_DIR", "/work/sauterme/train/lstmconv2d-basic")

BATCH_SIZE = 32    # training batch size passed to runtime.train
REG_LAMBDA = 1e-6  # passed to the model as reg_lambda
NUM_GPUS = 4       # only used by the commented-out tt.core.MultiGpuRuntime alternative

INPUT_SEQ_LENGTH = 10   # first axis of the dataset input_shape
OUTPUT_SEQ_LENGTH = 10  # first axis of the dataset target_shape

INITIAL_LR = 0.0005            # initial learning rate given to tt.training.Optimizer
LR_DECAY_STEP_INTERVAL = 10000 # decay interval argument of tt.training.Optimizer
LR_DECAY_FACTOR = 0.9          # decay factor argument of tt.training.Optimizer

DATA_DIR = os.environ.get("IMSEQ_DATA_DIR", "/work/sauterme/data")

### Data

In [3]:
# Both splits are configured with input/target shapes [seq_length, 64, 64, 1];
# only the leading sequence length differs between input and target.
FRAME_SHAPE = [64, 64, 1]


def _moving_mnist_shape_kwargs():
    """Return fresh ``input_shape``/``target_shape`` keyword args for a MovingMNIST split.

    New lists are built on every call, so no two datasets share the same list
    object (the same situation as with separate list literals per call).
    """
    return dict(input_shape=[INPUT_SEQ_LENGTH] + FRAME_SHAPE,
                target_shape=[OUTPUT_SEQ_LENGTH] + FRAME_SHAPE)


dataset_train = tt.datasets.moving_mnist.MovingMNISTTrainDataset(DATA_DIR, **_moving_mnist_shape_kwargs())
dataset_valid = tt.datasets.moving_mnist.MovingMNISTValidDataset(DATA_DIR, **_moving_mnist_shape_kwargs())

# The test split is not loaded; dataset_test stays None and is what gets passed
# to runtime.register_datasets. To enable it, build
# tt.datasets.moving_mnist.MovingMNISTTestDataset like the splits above.
# NOTE(review): an earlier version of this call used input_seq_length /
# target_seq_length keywords, which differ from the input_shape / target_shape
# keywords used here; verify the test-split signature in tensortools first.
dataset_test = None

File mnist.h5 has already been downloaded.
File mnist.h5 has already been downloaded.


### Training

In [4]:
# Select CUDA device 7 before the runtime is constructed.
CUDA_DEVICE_IDS = [7]
tt.hardware.set_cuda_devices(CUDA_DEVICE_IDS)

runtime = tt.core.DefaultRuntime(train_dir=TRAIN_DIR)
# Multi-GPU alternative (uses NUM_GPUS from the config cell):
#runtime = tt.core.MultiGpuRuntime(NUM_GPUS, train_dir=TRAIN_DIR)

Launing default runtime...
Selecting GPU device: 7


In [5]:
# Wire datasets, model and optimizer into the runtime, then build it.
# The construction/registration order below matches the original cell exactly.
runtime.register_datasets(dataset_train, dataset_valid, dataset_test)

model = ConvLSTMConv2DDecoderEncoderModelND(reg_lambda=REG_LAMBDA)
runtime.register_model(model)

optimizer = tt.training.Optimizer('adam', INITIAL_LR, LR_DECAY_STEP_INTERVAL, LR_DECAY_FACTOR)
runtime.register_optimizer(optimizer)

runtime.build(verbose=True)

Found 140 update ops.
Initializing variables...

--------------------------------------------------------------------------------
ConvStack/x_bn/beta:0                                                 |        1
ConvStack/Conv1/W:0                                                   |      800
ConvStack/Conv1/b:0                                                   |       32
ConvStack/conv1_bn/beta:0                                             |       32
ConvStack/Conv2/W:0                                                   |    18432
ConvStack/Conv2/b:0                                                   |       64
ConvStack/conv2_bn/beta:0                                             |       64
ConvStack/Conv3/W:0                                                   |    36864
ConvStack/Conv3/b:0                                                   |       64
ConvStack/conv3_bn/beta:0                                             |       64
encoder-lstm/RNNConv2D/BasicLSTMConv2DCell/Conv_x/xi/W:0    

In [None]:
def write_gifs(rt, dataset, gstep, fps=4):
    """Write a GIF comparing the ground-truth future with the model's prediction.

    One sample is drawn from ``dataset`` and run through ``rt.predict``. Two
    sequences are built by concatenating along the first (sequence) axis:

    * the input frames followed by the ground-truth target frames, and
    * the input frames followed by the predicted frames.

    Both are handed to ``tt.utils.video.write_multi_gif`` with a single output
    path ``<rt.train_dir>/out/anim-<gstep:06d>.gif``.

    Args:
        rt: Runtime exposing ``train_dir`` and ``predict(x)``.
        dataset: Object exposing ``get_batch(batch_size) -> (x, y)``.
        gstep: Global step, used (zero-padded to 6 digits) in the file name.
        fps: Frames per second passed to the GIF writer (default 4).
    """
    x, y = dataset.get_batch(1)
    pred = rt.predict(x)

    # Prefix both target and prediction with the observed input frames so each
    # clip shows context + future.
    truth_clip = np.concatenate((x[0], y[0]))
    pred_clip = np.concatenate((x[0], pred[0]))

    out_dir = os.path.join(rt.train_dir, "out")
    # Ensure the output directory exists before handing the path to the writer.
    os.makedirs(out_dir, exist_ok=True)
    gif_path = os.path.join(out_dir, "anim-{:06d}.gif".format(gstep))

    tt.utils.video.write_multi_gif(gif_path, [truth_clip, pred_clip], fps=fps)


def on_valid(rt, gstep):
    """Validation callback for ``runtime.train``: write a GIF for one validation sample.

    Args:
        rt: The runtime; a sample is drawn from ``rt.datasets.valid``.
        gstep: Global step at which validation ran.
    """
    write_gifs(rt, rt.datasets.valid, gstep)

In [None]:
# Validation uses batches of 50; on_valid is registered as the validation
# callback and writes a comparison GIF each time it is invoked.
VALID_BATCH_SIZE = 50
TRAIN_STEPS = 50000
runtime.train(
    BATCH_SIZE,
    valid_batch_size=VALID_BATCH_SIZE,
    steps=TRAIN_STEPS,
    on_validate=on_valid,
)

Starting epoch 1...
@    1: loss:   0.791, t-loss:   0.791 (    2.6 examples/sec, 12.19 sec/batch)
@   25: loss:   0.669, t-loss:   0.669 (   14.3 examples/sec,  2.23 sec/batch)


### Evaluation

In [None]:
# Evaluate on the validation split; 50 matches the valid_batch_size used during training.
VALIDATION_EVAL_BATCH_SIZE = 50
runtime.validate(VALIDATION_EVAL_BATCH_SIZE)

In [None]:
# The data cell may leave dataset_test as None (the test split is not loaded), and
# that None is what was registered with the runtime. Only run the test evaluation
# when a test dataset actually exists. The result is still shown as the cell output.
if dataset_test is None:
    print("No test dataset registered; skipping runtime.test().")
    test_result = None
else:
    test_result = runtime.test(50)
test_result

### Visualization

In [None]:
def _show_sequence(frames, title, nrows=2):
    """Display a sequence of frames as a grid with ``nrows`` rows.

    The column count is derived from the sequence length, so the grid follows
    INPUT_SEQ_LENGTH / OUTPUT_SEQ_LENGTH (10 frames -> 2x5, as before).
    NOTE(review): for odd lengths the grid has one spare cell; assumes
    display_batch tolerates that. TODO confirm.
    """
    ncols = (len(frames) + nrows - 1) // nrows
    tt.visualization.display_batch(frames, nrows=nrows, ncols=ncols, title=title)


# Sanity-check one validation sample: dtype and value range of the input,
# the ground-truth future and the model prediction.
x, y = dataset_valid.get_batch(1)
pred = runtime.predict(x)
for label, arr in (("input", x), ("target", y), ("prediction", pred)):
    print(label, arr.dtype, arr.min(), arr.max())

_show_sequence(x[0], "Input")
_show_sequence(y[0], "GT-Future")
# The model output is not ground truth, so it is labeled plainly as a prediction.
_show_sequence(pred[0], "Prediction")

In [None]:
# Write one extra comparison GIF on demand. The large step value only gives the file
# a distinct name (anim-999999999.gif) alongside the per-validation GIFs.
FINAL_GIF_STEP = 999999999
write_gifs(runtime, dataset_valid, FINAL_GIF_STEP)

### Terminate

In [None]:
# Close the runtime when finished. Run this last: every cell above uses `runtime`,
# and it presumably cannot be used after close(); confirm in tensortools.
runtime.close()