# Initial exploration

The first step in conducting the experiments was to port Finn's architecture from TensorFlow (TF) to Chainer. Since there is no automatic conversion tool, the port had to be done by hand. Below are some code snippets illustrating the differences between the TF code and the Chainer code.

## Data importation

The code used to import the dataset is practically the same as in the original, except that the actions, states and images are extracted once from the original TF binary (TFRecord) files and stored as NumPy arrays in the `/data/processed` folder, along with a `map.csv` file that indexes them.

In [None]:
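import csv

import numpy as np
from PIL import Image

# 'files', 'sess', the *_seq tensors, 'out_dir', 'create_img' and 'logger' are assumed to be defined earlier in the script.
csv_ref = []
for j in xrange(len(files)):
    logger.info("Creating data from tfrecords {0}/{1}".format(j+1, len(files)))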
    raw, act, sta = sess.run([image_seq, action_seq, state_seq])
    ref = []
    ref.append(j)

    if create_img == 1:
        for k in xrange(raw.shape[0]):
            img = Image.fromarray(raw[k], 'RGB')
            img.save(out_dir + '/image_batch_' + str(j) + '_' + str(k) + '.png')
        ref.append('image_batch_' + str(j) + '_*' + '.png')
    else:
        ref.append('')

    np.save(out_dir + '/image_batch_' + str(j), raw)
    np.save(out_dir + '/action_batch_' + str(j), act)
    np.save(out_dir + '/state_batch_' + str(j), sta)

    ref.append('image_batch_' + str(j) + '.npy')
    ref.append('action_batch_' + str(j) + '.npy')
    ref.append('state_batch_' + str(j) + '.npy')
    csv_ref.append(ref)

logger.info("Writing the results into map file '{0}'".format('map.csv'))
with open(out_dir + '/map.csv', 'wb') as csvfile:
    writer = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
    writer.writerow(['id', 'img_bitmap_path', 'img_np_path', 'action_np_path', 'state_np_path'])
    for row in csv_ref:
        writer.writerow(row)

This extraction step only needs to run once. Afterwards, the actions, states and images are loaded into the training code, which is easy since the CSV map file lists where each array is stored.

In [None]:
import csv
import sys

import numpy as np

logger.info("Fetching the models and inputs")
data_map = []
with open(data_dir + '/map.csv', 'rb') as f:
    reader = csv.reader(f)
    for row in reader:
        data_map.append(row)

if len(data_map) <= 1: # empty or only header
    logger.error("No file map found")
    sys.exit(1)

# Load the images, actions and states
images = []
actions = []
states = []
for i in xrange(1, len(data_map)): # Exclude the header
    logger.info("Loading data {0}/{1}".format(i, len(data_map)-1))
    images.append(np.load(data_dir + '/' + data_map[i][2]))
    actions.append(np.load(data_dir + '/' + data_map[i][3]))
    states.append(np.load(data_dir + '/' + data_map[i][4]))
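
The loader above reads the map by column position (`data_map[i][2]` and so on). A minimal Python 3 sketch of an alternative, assuming the same header that the extraction step writes: `csv.DictReader` looks the columns up by name, so reordering the map file does not silently load the wrong arrays. The function name `load_from_map` and the dummy array shapes in the usage part are hypothetical, chosen only for illustration.

```python
import csv
import os
import tempfile

import numpy as np


def load_from_map(data_dir):
    """Load the image/action/state arrays listed in data_dir/map.csv.

    Columns are looked up by header name instead of by position.
    """
    images, actions, states = [], [], []
    with open(os.path.join(data_dir, 'map.csv'), newline='') as f:
        for row in csv.DictReader(f):  # the header row is consumed automatically
            images.append(np.load(os.path.join(data_dir, row['img_np_path'])))
            actions.append(np.load(os.path.join(data_dir, row['action_np_path'])))
            states.append(np.load(os.path.join(data_dir, row['state_np_path'])))
    if not images:
        raise RuntimeError('map.csv lists no data files')
    return images, actions, states


# Usage: write a one-entry map in a temporary folder (dummy shapes) and load it back.
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, 'image_batch_0'), np.zeros((10, 64, 64, 3), np.uint8))
    np.save(os.path.join(tmp, 'action_batch_0'), np.zeros((10, 5), np.float32))
    np.save(os.path.join(tmp, 'state_batch_0'), np.zeros((10, 5), np.float32))
    with open(os.path.join(tmp, 'map.csv'), 'w', newline='') as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow(['id', 'img_bitmap_path', 'img_np_path', 'action_np_path', 'state_np_path'])
        writer.writerow([0, '', 'image_batch_0.npy', 'action_batch_0.npy', 'state_batch_0.npy'])
    images, actions, states = load_from_map(tmp)
```

Raising an exception instead of exiting keeps the helper usable from a notebook as well as from a script.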

## Training the model

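It's important to understand how TensorFlow works. TF (version 1.x, which the original code targets) first builds a static computation graph and then runs it once per training iteration with `sess.run`, feeding new values through placeholders. Chainer instead follows a define-by-run approach: the graph is recorded on the fly each time the model is called with actual data. Therefore, one main difference between the original code and the Chainer code is how the model is created and how data is fed to it at each iteration.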
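
To make the contrast concrete, here is a toy sketch in plain Python with no TF or Chainer involved. The two operations (a learning-rate decay and a sign flip) are invented for illustration only; what matters is that the TF-style version describes the computation once and later executes it with values taken from a `feed_dict`, while the Chainer-style version is ordinary code that runs as soon as it is called.

```python
# Define-and-run (TF 1.x style): the graph is described once, as data...
graph = [
    # (output name, operation, input names); these toy operations are hypothetical
    ('scaled_lr', lambda lr, iter_num: lr / (1.0 + iter_num), ('lr', 'iter_num')),
    ('update', lambda scaled_lr: -scaled_lr, ('scaled_lr',)),
]


def run(fetch, feed_dict):
    """...and executed later, with the placeholders 'lr' and 'iter_num' taken from feed_dict."""
    values = dict(feed_dict)
    for name, op, inputs in graph:
        values[name] = op(*(values[i] for i in inputs))
    return values[fetch]


# Define-by-run (Chainer style): the same computation as plain Python code.
def update(lr, iter_num):
    scaled_lr = lr / (1.0 + iter_num)
    return -scaled_lr


tf_style = [run('update', {'lr': 0.1, 'iter_num': float(itr)}) for itr in range(3)]
chainer_style = [update(0.1, float(itr)) for itr in range(3)]
```

Both loops produce the same numbers; the difference is only where the inputs enter: through a feed dictionary into a prebuilt graph, or as ordinary call arguments.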

In [None]:
import numpy as np
import tensorflow as tf
import chainer

In [None]:
# TensorFlow.
# The model's variables are created inside a named variable scope. The validation model is
# built with 'training_scope' as its reuse scope, so it shares the training model's weights.
with tf.variable_scope('model', reuse=None) as training_scope:
  images, actions, states = build_tfrecord_input(training=True)
  model = Model(images, actions, states, FLAGS.sequence_length,
                prefix='train')

with tf.variable_scope('val_model', reuse=None):
  val_images, val_actions, val_states = build_tfrecord_input(training=False)
  val_model = Model(val_images, val_actions, val_states,
                    FLAGS.sequence_length, training_scope, prefix='val')
    
# ...

# After building the computation graph, the model is trained via this training loop.
for itr in range(FLAGS.num_iterations):
  # feed_dict maps each placeholder (a graph input replaced at every iteration of the loop)
  # to its value. E.g. model.iter_num is fed np.float32(itr) at iteration itr.
  feed_dict = {model.iter_num: np.float32(itr),
               model.lr: FLAGS.learning_rate}

  # Finally, 'sess.run' executes one training step on one mini-batch (not a full epoch)
  cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                  feed_dict)
    
# Chainer.
# In TF, the images, actions and states were passed to the model constructor.
# Here, only the model type and the prefix are passed; the data is given at call time.
training_model = Model(
    is_cdna=model_type == 'CDNA',
    is_dna=model_type == 'DNA',
    is_stp=model_type == 'STP',
    prefix='train'
)
validation_model = Model(
    is_cdna=model_type == 'CDNA',
    is_dna=model_type == 'DNA',
    is_stp=model_type == 'STP',
    prefix='val'
)

# ...

for itr in xrange(epoch):
    # ...
    
    # Instead of placeholders, Chainer receives the data directly as call arguments.
    # 'img_training_set', 'act_training_set' and 'sta_training_set' are random mini-batches
    # drawn at each iteration
    loss, psnr_all, summaries = training_model(
        img_training_set, 
        act_training_set, 
        sta_training_set, 
        itr, 
        schedsamp_k, 
        use_state, 
        num_masks, 
        context_frames
    ) 

## The models

In this first iteration, the main goal was to reproduce the TF code in Chainer as exactly as possible. Thus, the same 'stateless' architecture found in TF was used in Chainer. Below is a summary of some of the key differences. This version of the code can be found up to commit `2346519b30f045181985e9d9dbceb0dc57214fa0`.

### Stateless LSTM

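The LSTM is 'stateless' because the instance of the LSTM class is recreated at each iteration, which resets its internal state. To carry the computation over from one iteration to the next, the previous `state` (the cell state `c` and the hidden state `h`, concatenated along the channel axis) is passed in explicitly as an argument, and the updated state is returned alongside the new hidden output.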
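
A minimal NumPy sketch of this state threading, assuming NCHW arrays as in the Chainer code. To keep it short, the 5x5 gate convolution is replaced by a 1x1 convolution (a per-pixel matrix product over channels), and `conv_lstm_step` and `weights` are hypothetical names. The point it illustrates is that the state is an ordinary input and output, and that the split at the start must use the same `(c, h)` order as the concatenation at the end.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def conv_lstm_step(inputs, state, weights, forget_bias=1.0):
    """One step of a stateless conv-LSTM cell on NCHW arrays.

    inputs:  (B, C_in, H, W)
    state:   (B, 2 * N, H, W), cell state c and hidden state h concatenated on axis 1
    weights: (4 * N, C_in + N), a 1x1 gate convolution (simplification of the 5x5 one)
    """
    c, h = np.split(state, 2, axis=1)  # same (c, h) order as the concatenation below
    inputs_h = np.concatenate([inputs, h], axis=1)
    gates = np.einsum('oc,bchw->bohw', weights, inputs_h)
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, np.concatenate([new_c, new_h], axis=1)


rng = np.random.default_rng(0)
B, C_in, N, H, W = 2, 3, 4, 8, 8
weights = rng.normal(scale=0.1, size=(4 * N, C_in + N))
state = np.zeros((B, 2 * N, H, W))  # fresh state at the start of a sequence
for t in range(3):
    frame = rng.normal(size=(B, C_in, H, W))
    hidden, state = conv_lstm_step(frame, state, weights)  # state is threaded through explicitly
```

Because the cell object itself keeps nothing between calls, recreating it at every iteration loses no information as long as the caller keeps passing the returned `state` back in.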

In [None]:
# TensorFlow
def basic_conv_lstm_cell(inputs, state, num_channels, filter_size=5, forget_bias=1.0, scope=None, reuse=None):
    c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
    inputs_h = tf.concat(axis=3, values=[inputs, h])

    i_j_f_o = layers.conv2d(inputs_h,
                            4 * num_channels, [filter_size, filter_size],
                            stride=1,
                            activation_fn=None,
                            scope='Gates')

    i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o)

    new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
    new_h = tf.tanh(new_c) * tf.sigmoid(o)

    return new_h, tf.concat(axis=3, values=[new_c, new_h])

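# Chainer
def __call__(self, inputs, state, num_channels, filter_size=5, forget_bias=1.0):
    # Split in the same (c, h) order used by the concat at the end of the function
    # (channels are on axis 1 in Chainer's NCHW layout)
    c, h = F.split_axis(state, indices_or_sections=2, axis=1)
    inputs_h = F.concat((inputs, h))

    # Note: a link created inside __call__ gets freshly initialized weights on every call
    # and is not registered as a child of the chain, so the optimizer never updates it
    i_j_f_o = L.Convolution2D(
        in_channels=inputs_h.shape[1],
        out_channels=4*num_channels,
        ksize=(filter_size, filter_size),
        pad=filter_size // 2
    )(inputs_h)

    i, j, f, o = F.split_axis(i_j_f_o, indices_or_sections=4, axis=1)

    new_c = c * F.sigmoid(f + forget_bias) + F.sigmoid(i) * F.tanh(j)
    new_h = F.tanh(new_c) * F.sigmoid(o)

    return new_h, F.concat((new_c, new_h))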

### StatelessModel

This class is a wrapper around the three model variants (StatelessCDNA, StatelessDNA and StatelessSTP). It computes the loss and the other statistics (PSNR, summaries) common to all of them.
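
The loss bookkeeping shared by both versions can be sketched in NumPy. This assumes pixel values in [0, 1] (so the PSNR peak is 1) and that `mean_squared_error` and `peak_signal_to_noise_ratio` behave like the simple definitions below; `sequence_loss` is a hypothetical name. The sketch shows the one-frame offset between the slices: `gen_images[t]` is the prediction for `images[t + 1]`, and only the frames after the context frames are scored.

```python
import numpy as np


def mean_squared_error(true, pred):
    return np.mean(np.square(true - pred))


def peak_signal_to_noise_ratio(true, pred):
    # Assumes pixel values in [0, 1], so the peak value is 1.
    return 10.0 * np.log10(1.0 / mean_squared_error(true, pred))


def sequence_loss(images, gen_images, states, gen_states, context_frames, state_weight=1e-4):
    """Average reconstruction cost plus down-weighted state cost over the predicted frames."""
    loss, psnr_all = 0.0, 0.0
    # gen_images[t] predicts images[t + 1], hence the offset of one between the slices.
    for x, gx in zip(images[context_frames:], gen_images[context_frames - 1:]):
        loss += mean_squared_error(x, gx)
        psnr_all += peak_signal_to_noise_ratio(x, gx)
    for s, gs in zip(states[context_frames:], gen_states[context_frames - 1:]):
        loss += mean_squared_error(s, gs) * state_weight
    return loss / float(len(images) - context_frames), psnr_all


# Usage: 5 frames, 2 context frames, every prediction off by 0.1 (dummy data).
T, context = 5, 2
images = [np.zeros((4, 3, 8, 8)) for _ in range(T)]
gen_images = [np.full((4, 3, 8, 8), 0.1) for _ in range(T - 1)]
states = [np.zeros((4, 5)) for _ in range(T)]
gen_states = [np.full((4, 5), 0.1) for _ in range(T - 1)]
loss, psnr_all = sequence_loss(images, gen_images, states, gen_states, context)
```

Each of the 3 scored frames has an MSE of 0.01, i.e. a PSNR of 20 dB, so `psnr_all` is about 60 and the loss is (0.03 + 3e-6) / 3.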

In [None]:
# TensorFlow
class Model(object):
  def __init__(self, images=None, actions=None, states=None, sequence_length=None, reuse_scope=None, prefix=None):
    #...
    
    gen_images, gen_states = construct_model(
        images,
        actions,
        states,
        iter_num=self.iter_num,
        k=FLAGS.schedsamp_k,
        use_state=FLAGS.use_state,
        num_masks=FLAGS.num_masks,
        cdna=FLAGS.model == 'CDNA',
        dna=FLAGS.model == 'DNA',
        stp=FLAGS.model == 'STP',
        context_frames=FLAGS.context_frames
    )
    
    # ...
    
    # L2 loss, PSNR for eval.
    loss, psnr_all = 0.0, 0.0
    for i, x, gx in zip(
        range(len(gen_images)), images[FLAGS.context_frames:],
        gen_images[FLAGS.context_frames - 1:]):
      recon_cost = mean_squared_error(x, gx)
      psnr_i = peak_signal_to_noise_ratio(x, gx)
      psnr_all += psnr_i
      summaries.append(
          tf.summary.scalar(prefix + '_recon_cost' + str(i), recon_cost))
      summaries.append(tf.summary.scalar(prefix + '_psnr' + str(i), psnr_i))
      loss += recon_cost

    for i, state, gen_state in zip(
        range(len(gen_states)), states[FLAGS.context_frames:],
        gen_states[FLAGS.context_frames - 1:]):
      state_cost = mean_squared_error(state, gen_state) * 1e-4
      summaries.append(
          tf.summary.scalar(prefix + '_state_cost' + str(i), state_cost))
      loss += state_cost
    summaries.append(tf.summary.scalar(prefix + '_psnr_all', psnr_all))
    self.psnr_all = psnr_all

    self.loss = loss = loss / np.float32(len(images) - FLAGS.context_frames)

    summaries.append(tf.summary.scalar(prefix + '_loss', loss))

    self.lr = tf.placeholder_with_default(FLAGS.learning_rate, ())

    self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
    self.summ_op = tf.summary.merge(summaries)

# Chainer
class Model(chainer.Chain):
    def __init__(self, is_cdna=True, is_dna=False, is_stp=False, prefix=None):
        super(Model, self).__init__()
        self.prefix = prefix
        # ... (creates self.model, the CDNA, DNA or STP variant)

    def __call__(self, images, actions=None, states=None, iter_num=-1.0, scheduled_sampling_k=-1,
                 use_state=True, num_masks=10, num_frame_before_prediction=2):
        gen_images, gen_states = self.model(images, actions, states, iter_num, scheduled_sampling_k,
                                            use_state, num_masks, num_frame_before_prediction)

        # L2 loss, PSNR for eval
        loss, psnr_all = 0.0, 0.0
        summaries = []
        for i, x, gx in zip(range(len(gen_images)), images[num_frame_before_prediction:],
                            gen_images[num_frame_before_prediction - 1:]):
            x = variable.Variable(x)
            recon_cost = mean_squared_error(x, gx)
            psnr_i = peak_signal_to_noise_ratio(x, gx)
            psnr_all += psnr_i
            summaries.append(self.prefix + '_recon_cost' + str(i) + ': ' + str(recon_cost.data))
            summaries.append(self.prefix + '_psnr' + str(i) + ': ' + str(psnr_i.data))
            loss += recon_cost

        for i, state, gen_state in zip(range(len(gen_states)), states[num_frame_before_prediction:],
                                       gen_states[num_frame_before_prediction - 1:]):
            state = variable.Variable(state)
            state_cost = mean_squared_error(state, gen_state) * 1e-4
            summaries.append(self.prefix + '_state_cost' + str(i) + ': ' + str(state_cost.data))
            loss += state_cost

        summaries.append(self.prefix + '_psnr_all: ' + str(psnr_all.data))
        self.psnr_all = psnr_all
        self.loss = loss = loss / np.float32(len(images) - num_frame_before_prediction)
        summaries.append(self.prefix + '_loss: ' + str(loss.data))

        return self.loss, self.psnr_all, summaries

The network consists of seven convolutional LSTM cells and seven convolutions/deconvolutions, which are mostly identical across the model variants. Below is an example of a convolution, a hidden LSTM cell and a deconvolution.
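
The `pad` and `outsize` arguments in the Chainer version exist to reproduce TF's default 'SAME' padding. The arithmetic can be checked with a few lines of Python; the helper names are hypothetical and the 64x64 frame and 8x8 feature map are example sizes. The formulas are the standard convolution / transposed-convolution output sizes without `cover_all`.

```python
def conv_out(size, ksize, stride, pad):
    """Output size of a convolution (standard formula, no cover_all)."""
    return (size + 2 * pad - ksize) // stride + 1


def deconv_out(size, ksize, stride, pad):
    """Default output size of a transposed convolution when no outsize is given."""
    return stride * (size - 1) + ksize - 2 * pad


def tf_same_out(size, stride):
    """TF 'SAME' padding: the output size is ceil(size / stride)."""
    return -(-size // stride)


# enc0: 5x5 convolution, stride 2, pad 5 // 2 == 2, on a 64x64 frame
print(conv_out(64, 5, 2, 5 // 2), tf_same_out(64, 2))  # 32 32
# enc4: 3x3 deconvolution, stride 2, pad 1, on an 8x8 feature map
print(deconv_out(8, 3, 2, 3 // 2))  # 15, one short of the 16 that TF produces
print(conv_out(16, 3, 2, 3 // 2))   # 8: a requested outsize of 16 maps back to the input size
```

With `pad = k // 2` the strided convolution already matches 'SAME', but no padding value gives 16 for the 3x3 stride-2 deconvolution, which is why `outsize=(2*h, 2*w)` is passed explicitly (as far as we know, Chainer accepts an explicit outsize when the forward-convolution formula maps it back to the input size). Writing `k // 2` instead of `k / 2` also keeps the padding an integer under Python 3.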

In [None]:
# TensorFlow
enc0 = slim.layers.conv2d(prev_image, 32, [5, 5], stride=2, scope='scale1_conv1', normalizer_fn=tf_layers.layer_norm, 
                          normalizer_params={'scope': 'layer_norm1'})

hidden1, lstm_state1 = lstm_func(enc0, lstm_state1, lstm_size[0], scope='state1')
hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')

# ...

enc4 = slim.layers.conv2d_transpose(hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')

# Chainer
enc0 = L.Convolution2D(in_channels=3, out_channels=32, ksize=(5, 5), stride=2, pad=5 // 2)(prev_image)
# As in the TF code, layer normalization is applied to the convolution output
enc0 = layer_normalization_conv_2d(enc0)

hidden1, lstm_state1 = self.stateless_lstm(inputs=enc0, state=lstm_state1, num_channels=32)
hidden1 = layer_normalization_conv_2d(hidden1)

# ...

# outsize doubles the spatial size, like TF's conv2d_transpose with 'SAME' padding and stride 2
enc4 = L.Deconvolution2D(in_channels=hidden5.shape[1], out_channels=hidden5.shape[1], ksize=(3, 3),
                         stride=2, outsize=(hidden5.shape[2]*2, hidden5.shape[3]*2), pad=3 // 2)(hidden5)

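Finally, the transformed images produced by the different model types are combined with a set of masks: a softmax over `num_masks + 1` channels gives, at every pixel, one weight for the previous image and the remaining weights for the transformed images, and the predicted frame is their weighted sum.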
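
A minimal NumPy sketch of this compositing step, assuming NCHW arrays with the mask logits on axis 1; `softmax` and `composite` are hypothetical helper names. It also shows a layout pitfall: the TF recipe `reshape(-1, num_masks + 1)` only groups the mask channels of one pixel when channels are the last axis (NHWC); applied to NCHW data it groups neighbouring pixels instead, so the masks no longer sum to 1 at each pixel.

```python
import numpy as np


def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift for numerical stability
    return e / e.sum(axis=axis, keepdims=True)


def composite(prev_image, transformed, mask_logits):
    """Blend the previous frame and the transformed frames with per-pixel masks.

    prev_image:  (B, 3, H, W)
    transformed: list of num_masks arrays, each (B, 3, H, W)
    mask_logits: (B, num_masks + 1, H, W), mask channels on axis 1 (NCHW)
    """
    masks = softmax(mask_logits, axis=1)  # weights sum to 1 over the masks at each pixel
    output = prev_image * masks[:, 0:1]   # (B, 1, H, W) broadcasts over the colour channels
    for k, layer in enumerate(transformed, start=1):
        output = output + layer * masks[:, k:k + 1]
    return output, masks


rng = np.random.default_rng(0)
B, H, W, num_masks = 2, 4, 4, 3
image = rng.random((B, 3, H, W))
mask_logits = rng.normal(size=(B, num_masks + 1, H, W))
# If every candidate is the same image, the per-pixel convex combination returns it unchanged.
output, masks = composite(image, [image] * num_masks, mask_logits)

# Pitfall: the NHWC reshape applied to NCHW data normalizes over the wrong elements.
naive = softmax(mask_logits.reshape(-1, num_masks + 1), axis=1).reshape(mask_logits.shape)
```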

In [None]:
# TensorFlow
#...
masks = slim.layers.conv2d_transpose(
    enc6, num_masks + 1, 1, stride=1, scope='convt7')
masks = tf.reshape(masks, [-1, num_masks + 1])
masks = tf.nn.softmax(masks)
masks = tf.reshape(masks, [int(batch_size), int(img_height), int(img_width), num_masks + 1])
mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks)
output = mask_list[0] * prev_image
for layer, mask in zip(transformed, mask_list[1:]):
  output += layer * mask
gen_images.append(output)

# Chainer
#...
masks = L.Deconvolution2D(in_channels=enc6.shape[1], out_channels=num_masks+1, ksize=(1, 1), stride=1)(enc6)
# Channels are on axis 1 (NCHW), so the softmax is taken directly over that axis.
# Reshaping to (-1, num_masks + 1) as in the NHWC TF code would mix pixels instead of masks.
masks = F.softmax(masks, axis=1)
mask_list = F.split_axis(masks, indices_or_sections=num_masks+1, axis=1)
# Each mask is (batch, 1, height, width); broadcast it over the colour channels
output = prev_image * F.broadcast_to(mask_list[0], prev_image.shape)
for layer, mask in zip(transformed, mask_list[1:]):
    output += layer * F.broadcast_to(mask, layer.shape)
gen_images.append(output)

### StatelessCDNA

This model outputs multiple normalized convolution kernels, which are applied to the previous image to compute new pixel values.
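
A minimal NumPy sketch of the CDNA idea, assuming raw kernel predictions laid out as `(B, num_masks, k, k)` (our own layout choice) and the `RELU_SHIFT = 1e-12` used in the code. Each predicted kernel is made non-negative, normalized to sum to 1, and applied to every colour channel of the previous frame with 'SAME' zero padding, giving one transformed image per kernel. As with deep-learning "convolutions", the sketch computes a cross-correlation (no kernel flip). The function names are hypothetical.

```python
import numpy as np

RELU_SHIFT = 1e-12


def normalize_kernels(raw, relu_shift=RELU_SHIFT):
    """raw: (B, num_masks, k, k) -> non-negative kernels that each sum to 1."""
    kerns = np.maximum(raw - relu_shift, 0.0) + relu_shift
    return kerns / kerns.sum(axis=(2, 3), keepdims=True)


def apply_kernel(image, kernel):
    """'SAME' cross-correlation of every channel of image (C, H, W) with kernel (k, k)."""
    k = kernel.shape[0]
    pad = k // 2
    padded = np.pad(image, ((0, 0), (pad, pad), (pad, pad)))  # zero padding
    H, W = image.shape[1:]
    out = np.zeros_like(image, dtype=float)
    for dy in range(k):
        for dx in range(k):
            out += kernel[dy, dx] * padded[:, dy:dy + H, dx:dx + W]
    return out


def cdna(prev_image, raw_kernels):
    """prev_image: (B, C, H, W); returns num_masks transformed images, each (B, C, H, W)."""
    kernels = normalize_kernels(raw_kernels)
    return [np.stack([apply_kernel(img, kern[m]) for img, kern in zip(prev_image, kernels)])
            for m in range(kernels.shape[1])]


rng = np.random.default_rng(0)
prev_image = rng.random((2, 3, 6, 6))
raw = np.full((2, 2, 5, 5), -1.0)
raw[:, 0, 2, 2] = 1.0  # kernel 0: all mass on the centre, i.e. identity
raw[:, 1, 2, 3] = 1.0  # kernel 1: mass one step right, i.e. copy from the right neighbour
transformed = cdna(prev_image, raw)
```

The ReLU shift keeps every kernel entry strictly positive, so the normalization never divides by zero.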

In [None]:
# TensorFlow
#...

def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
  batch_size = int(cdna_input.get_shape()[0])

  # Predict kernels using linear function of last hidden layer.
  cdna_kerns = slim.layers.fully_connected(
      cdna_input,
      DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
      scope='cdna_params',
      activation_fn=None)

  # Reshape and normalize.
  cdna_kerns = tf.reshape(
      cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
  cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
  norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
  cdna_kerns /= norm_factor

  cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
  cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
  prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)

  # Transform image.
  transformed = []
  for kernel, preimg in zip(cdna_kerns, prev_images):
    kernel = tf.squeeze(kernel)
    if len(kernel.get_shape()) == 3:
      kernel = tf.expand_dims(kernel, -1)
    conv = tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME')
    transformed.append(conv)
  transformed = tf.concat(axis=0, values=transformed)
  transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
  return transformed


# Chainer
#... 
cdna_input = F.reshape(hidden5, (int(batch_size), -1))
cdna_kerns = L.Linear(in_size=None, out_size=5*5*num_masks)(cdna_input)

# Reshape and normalize
cdna_kerns = F.reshape(cdna_kerns, (batch_size, 1, 5, 5, num_masks))
cdna_kerns = F.relu(cdna_kerns - 1e-12) + 1e-12
norm_factor = F.sum(cdna_kerns, axis=(1, 2, 3), keepdims=True)
# Explicitly broadcast the (batch, 1, 1, 1, num_masks) norm factor to the kernel shape
cdna_kerns = cdna_kerns / F.broadcast_to(norm_factor, cdna_kerns.shape)

cdna_kerns = F.tile(cdna_kerns, (1, 3, 1, 1, 1))
cdna_kerns = F.split_axis(cdna_kerns, indices_or_sections=batch_size, axis=0)
prev_images = F.split_axis(prev_image, indices_or_sections=batch_size, axis=0)

# Transform image
tmp_transformed = []
for kernel, preimg in zip(cdna_kerns, prev_images):
    kernel = F.squeeze(kernel)  # (3, 5, 5, num_masks)
    if len(kernel.shape) == 3:  # the num_masks axis was squeezed away when num_masks == 1
        kernel = F.expand_dims(kernel, axis=3)
    # Apply the predicted kernel itself (not a new, randomly initialized link);
    # F.depthwise_convolution_2d expects W as (channel_multiplier, in_channels, kh, kw)
    W = F.transpose(kernel, (3, 0, 1, 2))
    conv = F.depthwise_convolution_2d(preimg, W, stride=1, pad=W.shape[2] // 2)
    tmp_transformed.append(conv)
tmp_transformed = F.concat(tmp_transformed, axis=0)
tmp_transformed = F.split_axis(tmp_transformed, indices_or_sections=num_masks, axis=1) # Previously axis=3, but Chainer's channels are on axis=1
transformed = transformed + list(tmp_transformed)

### StatelessDNA

This model outputs a distribution over locations in the previous frame for each pixel in the new frame. 
The predicted value becomes the expectation under the distribution.
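
A minimal NumPy sketch of the DNA transformation, assuming NCHW arrays, a 5x5 neighbourhood and per-pixel logits of shape `(B, 25, H, W)` (the role played by `enc7` in the Chainer code). The 25 shifted copies of the zero-padded previous frame are stacked in the same row-major order as the TF loops, each pixel's logits are turned into a distribution with the ReLU-shift normalization, and the output is the expected value. The name `dna` and its signature are hypothetical.

```python
import numpy as np

RELU_SHIFT = 1e-12


def dna(prev_image, dna_logits, kern_size=5):
    """prev_image: (B, C, H, W); dna_logits: (B, kern_size**2, H, W)."""
    B, C, H, W = prev_image.shape
    pad = kern_size // 2
    padded = np.pad(prev_image, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    # kern_size**2 shifted copies, each (B, C, H, W), stacked on axis 1: (B, K, C, H, W)
    shifted = np.stack([padded[:, :, dy:dy + H, dx:dx + W]
                        for dy in range(kern_size) for dx in range(kern_size)], axis=1)
    # Non-negative weights that sum to 1 over the neighbourhood at every pixel
    kernel = np.maximum(dna_logits - RELU_SHIFT, 0.0) + RELU_SHIFT
    kernel = kernel / kernel.sum(axis=1, keepdims=True)
    # (B, K, 1, H, W) broadcasts over the colour channels; sum out the neighbourhood axis
    return (shifted * kernel[:, :, np.newaxis]).sum(axis=1)


rng = np.random.default_rng(0)
prev_image = rng.random((2, 3, 6, 6))
logits = np.full((2, 25, 6, 6), -1.0)
logits[:, 12] = 1.0  # index 12 = centre of the 5x5 window: identity
identity = dna(prev_image, logits)
logits = np.full((2, 25, 6, 6), -1.0)
logits[:, 13] = 1.0  # one step right of the centre: copy from the right neighbour
shifted = dna(prev_image, logits)
```

Slicing each window with the full `H x W` extent keeps every shifted copy at the original size, so no extra padding is needed after slicing.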

In [None]:
# TensorFlow
def dna_transformation(prev_image, dna_input):
  # Construct translated images.
  prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
  image_height = int(prev_image.get_shape()[1])
  image_width = int(prev_image.get_shape()[2])

  inputs = []
  for xkern in range(DNA_KERN_SIZE):
    for ykern in range(DNA_KERN_SIZE):
      tmp = tf.slice(prev_image_pad, [0, xkern, ykern, 0], [-1, image_height, image_width, -1])
      tmp = tf.expand_dims(tmp, [3])
      inputs.append(tmp)
  inputs = tf.concat(axis=3, values=inputs)

  # Normalize channels to 1.
  kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
  kernel_sum = tf.reduce_sum(kernel, [3], keep_dims=True)
  kernel = kernel / kernel_sum
  kernel = tf.expand_dims(kernel, [4])
  kernel = tf.reduce_sum(kernel * inputs, [3], keep_dims=False)
  return kernel

# Chainer
# ...
prev_image_pad = F.pad(prev_image, pad_width=[[0,0], [0,0], [2,2], [2,2]], mode='constant', constant_values=0)
kernel_inputs = []
for xkern in range(5):
    for ykern in range(5):
        # Same window as tf.slice(prev_image_pad, [0, xkern, ykern, 0], [-1, img_height, img_width, -1]),
        # so every shifted copy keeps the original (img_height, img_width) size
        tmp = prev_image_pad[:, :, xkern:xkern + img_height, ykern:ykern + img_width]
        tmp = F.expand_dims(tmp, axis=1) # TF uses axis=3 (NHWC); channels are on axis=1 here
        kernel_inputs.append(tmp) # keep the Variable (not .data) so gradients reach prev_image
kernel_inputs = F.concat(kernel_inputs, axis=1) # (batch, 25, channels, height, width)

# Normalize channels to 1
kernel_normalized = F.relu(enc7 - 1e-12) + 1e-12
kernel_normalized_sum = F.sum(kernel_normalized, axis=1, keepdims=True) # TF sums over axis=3 (NHWC)
kernel_normalized = broadcasted_division(kernel_normalized, kernel_normalized_sum)
kernel_normalized = F.expand_dims(kernel_normalized, axis=2)
kernel_normalized = kernel_inputs * F.broadcast_to(kernel_normalized, kernel_inputs.shape)
kernel_normalized = F.sum(kernel_normalized, axis=1, keepdims=False)
transformed = [kernel_normalized]

### StatelessSTP

This model outputs the parameters of multiple affine transformations to apply to the previous image.
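
A minimal NumPy sketch of what `F.spatial_transformer_grid` and `F.spatial_transformer_sampler` compute together, written from our reading of their conventions (normalized `[-1, 1]` coordinates, `x` first, bilinear interpolation, zero outside the image); treat it as an illustration rather than a reference implementation. Each `(2, 3)` parameter matrix maps output pixel coordinates to source coordinates, which is why `[1, 0, 0, 0, 1, 0]` is the identity that the predicted parameters are added to. The helper names are hypothetical.

```python
import numpy as np


def affine_grid(theta, height, width):
    """theta: (B, 2, 3) -> sampling grid (B, 2, height, width) of source (x, y) in [-1, 1]."""
    ys, xs = np.meshgrid(np.linspace(-1, 1, height), np.linspace(-1, 1, width), indexing='ij')
    coords = np.stack([xs.ravel(), ys.ravel(), np.ones(height * width)])  # homogeneous (3, H*W)
    return (theta @ coords).reshape(theta.shape[0], 2, height, width)


def bilinear_sample(image, grid):
    """image: (B, C, H, W); points that fall outside the image contribute zero."""
    B, C, H, W = image.shape
    x = (grid[:, 0] + 1) * (W - 1) / 2  # back to pixel coordinates
    y = (grid[:, 1] + 1) * (H - 1) / 2
    x0, y0 = np.floor(x).astype(int), np.floor(y).astype(int)
    b = np.arange(B)[:, None, None]
    out = np.zeros((B, C) + grid.shape[2:])
    for dx in (0, 1):
        for dy in (0, 1):
            xi, yi = x0 + dx, y0 + dy
            weight = (1 - np.abs(x - xi)) * (1 - np.abs(y - yi))  # bilinear weight
            valid = (xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)
            values = image[b, :, np.clip(yi, 0, H - 1), np.clip(xi, 0, W - 1)]  # (B, H', W', C)
            out += np.moveaxis(values, -1, 1) * (weight * valid)[:, None]
    return out


rng = np.random.default_rng(0)
B, H, W = 2, 5, 6
image = rng.random((B, 3, H, W))
identity = np.tile(np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), (B, 1, 1))
shift = identity.copy()
shift[:, 0, 2] = 2.0 / (W - 1)  # one pixel along x in normalized coordinates
same = bilinear_sample(image, affine_grid(identity, H, W))
moved = bilinear_sample(image, affine_grid(shift, H, W))
```

With the identity parameters the sampler returns the input unchanged, and a translation of `2 / (W - 1)` in normalized units reads each output pixel from its right neighbour.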

In [None]:
# TensorFlow
def stp_transformation(prev_image, stp_input, num_masks):
  # Only import spatial transformer if needed.
  from spatial_transformer import transformer

  identity_params = tf.convert_to_tensor(
      np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
  transformed = []
  width, height = prev_image.get_shape()[1:3]
  for i in range(num_masks - 1):
    params = slim.layers.fully_connected(
        stp_input, 6, scope='stp_params' + str(i),
        activation_fn=None) + identity_params
    trans = transformer(prev_image, params)
    transformed.append(trans)

  return transformed

# Chainer
#...
stp_input0 = F.reshape(hidden5, (int(batch_size), -1))
stp_input1 = L.Linear(in_size=None, out_size=100)(stp_input0)
identity_params = np.array([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]], dtype=np.float32)
identity_params = np.repeat(identity_params, int(batch_size), axis=0)
identity_params = variable.Variable(identity_params)

stp_transformations = []
for i in range(num_masks-1):
    params = L.Linear(in_size=None, out_size=6)(stp_input1) + identity_params
    params = F.reshape(params, (int(params.shape[0]), 2, 3))
    grid = F.spatial_transformer_grid(params, (prev_image.shape[2], prev_image.shape[3]))
    trans = F.spatial_transformer_sampler(prev_image, grid)
    stp_transformations.append(trans)

transformed += stp_transformations