In [None]:
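# Launch with e.g.: export THEANO_FLAGS="device=gpu0,floatX=float32"  (all flags must be inside the quotes; append ",optimizer=None" only when debugging, it disables Theano graph optimizations)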

import argparse
import datetime
import importlib
import os
import pickle
import pprint
import socket
import time

import numpy as np
import theano
import theano.tensor as T

%matplotlib inline
# %matplotlib nbagg
import matplotlib.pyplot as plt
from IPython import display

import lasagne
from lasagne.utils import floatX
from lasagne.updates import rmsprop, adam, momentum
from lasagne.layers import get_all_params, get_all_layers, get_all_param_values, get_output
from lasagne.objectives import squared_error, binary_crossentropy, aggregate

from utils.helperFunctions import *

In [None]:
# -------- setup options and data ------------------
pretrained_model_path = '/path/to/model_checkpoint.p'

# Load the checkpoint written during training.
# NOTE: pickle.load can execute arbitrary code -- only load checkpoints from a trusted source.
with open(pretrained_model_path, 'rb') as checkpoint_file:  # closes the handle even if unpickling fails
    checkpoint = pickle.load(checkpoint_file)
model_values = checkpoint['model_values']  # trained parameter values, loaded into the network before evaluation
options = checkpoint['options']  # the options the model was trained with
pprint.PrettyPrinter(indent=4).pprint(options)

# Load options
np.random.seed(options['seed'])
host = socket.gethostname()  # get computer hostname
start_time = datetime.datetime.now().strftime("%y-%m-%d-%H-%M")

# Python module that provides the `Model` class for this checkpoint
model = importlib.import_module(options['model_file'])

# Optional: change some options
# options['modelOptions']['target_seqlen'] = 10
# options['datasetOptions']['num_frames'] = 15

In [None]:
# ---------- build model and compile ---------------
input_batch = T.tensor4()  # input image sequences
target = T.tensor4()  # target frames; predicted frame i is compared with target[:, [i], :, :]

print('Build model...')
model = model.Model(**options['modelOptions'])

print('Compile ...')
net, outputs = model.build_model(input_batch)

# symbolic predicted frames, one expression per future time step
outputs = get_output(outputs)
output_frames = outputs

# # Alternative that also returns the model's filters
# # (presumably what visualize_flowmap's `pred_filter` argument expects -- TODO confirm):
# outputs = get_output(outputs + [filters])
# output_frames = outputs[:-1]
# output_filter = outputs[-1]

target_seqlen = options['modelOptions']['target_seqlen']
float32_eps = np.finfo(np.float32).eps

# loss of every predicted frame, averaged over the predicted time steps
train_losses = []
for i in range(target_seqlen):
    output_frame = output_frames[i]
    frame_target = target[:, [i], :, :]

    if options['loss'] == 'squared_error':
        frame_loss = squared_error(output_frame, frame_target)
    elif options['loss'] == 'binary_crossentropy':
        # Clipping to avoid NaN's in binary crossentropy: https://github.com/Lasagne/Lasagne/issues/436
        output_frame = T.clip(output_frame, float32_eps, 1 - float32_eps)
        frame_loss = binary_crossentropy(output_frame, frame_target)
    else:
        # an explicit exception (unlike `assert False`) survives `python -O` and names the bad value
        raise ValueError(
            "Unsupported loss {!r}; expected 'squared_error' or 'binary_crossentropy'".format(options['loss']))

    train_losses.append(aggregate(frame_loss))

train_loss = sum(train_losses) / target_seqlen

# update
sh_lr = theano.shared(lasagne.utils.floatX(options['learning_rate']))  # to allow dynamic learning rate

layers = get_all_layers(net)
all_params = get_all_params(layers, trainable=True)
# NOTE(review): `updates` is built but not passed to `_test` below, so evaluation never changes the weights.
updates = adam(train_loss, all_params, learning_rate=sh_lr)
# evaluation function: returns [loss, predicted frame 0, predicted frame 1, ...]
_test = theano.function([input_batch, target], [train_loss] + output_frames, allow_input_downcast=True)

In [None]:
# ------------ data setup ----------------
print('Prepare data...')
# Bouncing-MNIST is switched to its *Test module; every other dataset module is used as-is.
test_dataset_aliases = {'datasets.bouncingMnist_original': 'datasets.bouncingMnist_originalTest'}
options['dataset_file'] = test_dataset_aliases.get(options['dataset_file'], options['dataset_file'])

dataset = importlib.import_module(options['dataset_file'])
# same dict object as options['datasetOptions'], so the 'mode' flag is visible through `options` too
datasetOptions = options['datasetOptions']
datasetOptions['mode'] = 'test'
dh = dataset.DataHandler(**datasetOptions)

In [None]:
# ------------ evaluation setup ----------------
import matplotlib.animation as animation

# Load the trained weights from the checkpoint into the compiled network.
lasagne.layers.set_all_param_values(layers, model_values)
history_train = checkpoint['history_train']
batch_size = options['batch_size']
input_seqlen = options['modelOptions']['input_seqlen']

# Set to True to write per-sample prediction figures and animated GIFs (GIFs go to images/).
savefigs = False


def _save_prediction_gif(batch, predictions, case_id, input_seqlen, image_dim, num_frames, save_path, interval=200):
    """Save an animated GIF of one batch element next to the model's prediction.

    Panels, left to right: an indicator that is red during the observed input
    frames and green during the predicted frames; the ground-truth sequence
    ``batch[case_id]``; the observed input frames followed by ``predictions[case_id]``.

    Assumes ``batch`` and ``predictions`` are 5-D with time on the last axis and
    that ``predictions`` holds only the future frames -- TODO confirm per dataset.
    """
    fig = plt.figure()

    redgreen = np.zeros((3, image_dim, image_dim, num_frames))
    redgreen[0, :, :, :input_seqlen] = 1  # red channel: input frames
    redgreen[1, :, :, input_seqlen:] = 1  # green channel: predicted frames

    truth = batch[case_id]
    observed_then_predicted = np.concatenate((batch[case_id][..., :input_seqlen], predictions[case_id]), axis=-1)
    panels = (redgreen, truth, observed_then_predicted)

    images = []
    for position, frames in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        # vmin/vmax pin the grey colormap to [0, 1] so brightness is comparable across frames
        images.append(plt.imshow(frames[..., 0].transpose(1, 2, 0).squeeze(),
                                 cmap=plt.cm.gray, vmin=0, vmax=1, interpolation="nearest"))
        plt.axis('off')

    def updatefig(j):
        for image, frames in zip(images, panels):
            image.set_array(frames[..., j].transpose(1, 2, 0).squeeze())
        # with blit=True only the returned artists are redrawn, so return all three panels
        return images

    ani = animation.FuncAnimation(fig, updatefig, frames=range(num_frames),
                                  interval=interval, repeat_delay=1000, blit=True)
    ani.save(save_path, writer='imagemagick')
    plt.show()


# ------------ evaluation loop ----------------
its = dh.GetDatasetSize() // batch_size  # number of full batches; a trailing partial batch is skipped
if its == 0:
    raise ValueError("Test set has fewer samples than batch_size={}; nothing to evaluate".format(batch_size))

if savefigs and not os.path.isdir('images'):
    os.makedirs('images')

# visualize_flowmap needs the model's filters, which the model-building cell does not
# currently return; only draw flow maps if a `pred_filter` has been defined.
flow_filter = globals().get('pred_filter')

history_batch = []
for i in range(its):
    ind = np.arange(i * batch_size, (i + 1) * batch_size)
    batch = dh.GetBatch(ind)  # generate data on the fly
    if options['dataset_file'] == 'datasets.stereoCarsColor':
        batch_input = batch[..., :input_seqlen].squeeze(axis=4)  # first frames
        batch_target = batch[..., input_seqlen:].squeeze(axis=4)  # last frame
    else:
        batch_input = batch[..., :input_seqlen].transpose(0, 4, 2, 3, 1).squeeze(axis=4)  # first frames
        batch_target = batch[..., input_seqlen:].transpose(0, 4, 2, 3, 1).squeeze(axis=4)  # last frame
    testOutputs = _test(batch_input, batch_target)
    loss_test = testOutputs[0]
    history_batch.append(loss_test)
    # stack the per-frame outputs (time on axis 0) and move time to the last axis
    predictions = np.asarray(testOutputs[1:]).transpose(1, 2, 3, 4, 0)
    print("Batch {} of {}".format(i + 1, its))
    print("Test loss:\t{:.6f}".format(np.mean(loss_test)))

    if not savefigs:
        continue

    for case_id in range(batch_size):
        save_id = (i + 1) * 100 + case_id
        # clear the screen
        display.clear_output(wait=True)

        # visualize the prediction
        visualize_prediction(batch, fut=predictions, fig=1, case_id=case_id, saveId=save_id, savefig=True)

        # visualize the flow map (only possible when the filters are available)
        if flow_filter is not None:
            visualize_flowmap(flow_filter, batch, predictions, input_seqlen, options['image_dim'],
                              options['modelOptions']['dynamic_filter_size'][0],
                              case_id=case_id, saveId=save_id, savefig=True)

        # animated gif
        _save_prediction_gif(batch, predictions, case_id, input_seqlen, options['image_dim'],
                             options['datasetOptions']['num_frames'], 'images/%d.gif' % save_id, interval=200)

# print statistics
print("  Test loss:\t{:.6f}".format(np.mean(history_batch)))
print("  Parameter count: {}".format(lasagne.layers.count_params(net)))

In [None]:
savefig = False
case_id = 0  # element of the last evaluated batch to use for visualization

# convergence plot of the training loss history stored in the checkpoint
plt.figure()
plt.plot(range(1, len(history_train) + 1), history_train, label="loss")
plt.title("Training loss history (from checkpoint)")
plt.ylabel("loss")
plt.legend()
plt.show()

# visualize the prediction
visualize_prediction(batch, fut=predictions, fig=1, case_id=case_id, savefig=savefig)

# visualize_flowmap needs `pred_filter`, which the model-building cell does not currently
# produce; calling it unconditionally raised NameError, so only draw it when available.
if globals().get('pred_filter') is not None:
    visualize_flowmap(pred_filter, batch, predictions, input_seqlen, options['image_dim'],
                      options['modelOptions']['dynamic_filter_size'][0], case_id, savefig=savefig)
# visualize_flowmapStereo(pred_filter, batch, predictions, input_seqlen, options['image_dim'], options['modelOptions']['dynamic_filter_size'][0], case_id, savefig=savefig)

In [None]:
import matplotlib.animation as animation

fig = plt.figure()  # make figure
case_id = 1  # element of the last evaluated batch to animate

# Indicator panel: red channel lit during the input frames, green during the predicted frames.
redgreen = np.zeros((3, options['image_dim'], options['image_dim'], options['datasetOptions']['num_frames']))
redgreen[0, :, :, :input_seqlen] = 1
redgreen[1, :, :, input_seqlen:] = 1

data = batch[case_id]  # ground-truth sequence

# observed input frames followed by the model's predicted frames
# (a zero-padded version used to be computed here and immediately overwritten)
data2 = np.concatenate((batch[..., :input_seqlen], predictions), axis=4)[case_id]

# make axesimage objects
# the vmin and vmax here are very important to get the color map correct
plt.subplot(1, 3, 1)
im0 = plt.imshow(redgreen[..., 0].transpose(1, 2, 0).squeeze(), cmap=plt.cm.gray, vmin=0, vmax=1, interpolation="nearest")
plt.axis('off')
plt.subplot(1, 3, 2)
im = plt.imshow(data[..., 0].transpose(1, 2, 0).squeeze(), cmap=plt.cm.gray, vmin=0, vmax=1, interpolation="nearest")
plt.axis('off')
plt.subplot(1, 3, 3)
im2 = plt.imshow(data2[..., 0].transpose(1, 2, 0).squeeze(), cmap=plt.cm.gray, vmin=0, vmax=1, interpolation="nearest")
plt.axis('off')
# function to update figure
def updatefig(j):
    """Show frame `j` in all three panels of the animation.

    Returns every artist it modifies: with ``blit=True`` FuncAnimation only
    redraws the returned artists, so omitting ``im0`` could leave the red/green
    indicator panel stale in the live display.
    """
    for image, frames in ((im0, redgreen), (im, data), (im2, data2)):
        image.set_array(frames[..., j].transpose(1, 2, 0).squeeze())
    return im0, im, im2

# Animate every frame of the sequence, then write it to "<options['name']>.gif".
num_anim_frames = options['datasetOptions']['num_frames']
ani = animation.FuncAnimation(fig, updatefig, frames=range(num_anim_frames),
                              interval=100, repeat_delay=1000, blit=True)
gif_path = options['name'] + '.gif'
ani.save(gif_path, writer='imagemagick')
plt.show()