In [139]:
import math

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import regularizers
from tensorflow.python.ops import init_ops

slim = tf.contrib.slim
tf.reset_default_graph()


def trunc_normal(stddev):
    """Return a zero-mean truncated-normal initializer with the given stddev."""
    return init_ops.truncated_normal_initializer(0.0, stddev)

In [140]:
# Constants
image_channels = 3
time_frames_to_consider = 4

# Frame sizes (height, width) used for the training graph and the test graph.
heigth_train, width_train = 32, 32
heigth_test, width_test = 210, 160

# Output channels of every conv layer, one inner list per scale.
# Scales after the first also receive the previous scale's prediction as
# extra input channels (see GenerativeNetwork.create_graph).
scale_level_feature_maps = [
    [128, 256, 128, 3],
    [128, 256, 128, 3],
    [128, 256, 512, 256, 128, 3],
    [128, 256, 512, 256, 128, 3],
]

# Square kernel size of every conv layer, one inner list per scale;
# the larger (later) scales use larger kernels.
scale_level_kernel_size = [
    [3, 3, 3, 3],
    [5, 3, 3, 5],
    [5, 3, 3, 3, 3, 5],
    [7, 5, 5, 5, 5, 7],
]

# L2 weight-regularization scale.
l2_val = 5e-5

# Learning rate for the Adam optimizer.
adam_learning_rate = 4e-4

## ===================  COPIED CODE ==========================
#
#  TENSORBOARD VISUALIZATION FOR SHARPNESS AND (Peak Signal to Noise Ratio){PSNR}
#

def psnr_error(gen_frames, gt_frames):
    """
    Computes the Peak Signal to Noise Ratio error between the generated images and the ground
    truth images.
    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by the
                       generator model.
    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames for
                      each frame in gen_frames.
    @return: A scalar tensor. The mean Peak Signal to Noise Ratio error over each frame in the
             batch.
    """
    shape = tf.shape(gen_frames)
    num_pixels = tf.to_float(shape[1] * shape[2] * shape[3])
    square_diff = tf.square(gt_frames - gen_frames)

    # Per-frame mean squared error over height, width and channels.
    mean_square_error = (1.0 / num_pixels) * tf.reduce_sum(square_diff, [1, 2, 3])

    # No log10 helper is defined in this notebook, so the base-10 logarithm is
    # obtained from the natural log: log10(x) = ln(x) / ln(10).
    batch_errors = 10.0 * tf.log(1.0 / mean_square_error) / math.log(10.0)
    return tf.reduce_mean(batch_errors)

def sharp_diff_error(gen_frames, gt_frames):
    """
    Computes the Sharpness Difference error between the generated images and the ground truth
    images.
    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by the
                       generator model.
    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames for
                      each frame in gen_frames.
    @return: A scalar tensor. The Sharpness Difference error over each frame in the batch.
    """
    shape = tf.shape(gen_frames)
    num_pixels = tf.to_float(shape[1] * shape[2] * shape[3])

    # gradient difference
    # create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.
    # TODO: Could this be simplified with one filter [[-1, 2], [0, -1]]?
    pos = tf.constant(np.identity(3), dtype=tf.float32)
    neg = -1 * pos
    # tf.stack (not the older tf.pack) to match the rest of this notebook.
    filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [-1, 1]
    filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
    strides = [1, 1, 1, 1]  # stride of (1, 1)
    padding = 'SAME'

    def abs_gradient_sum(frames):
        # |d/dx| + |d/dy| of the frames, channel by channel.
        grad_x = tf.abs(tf.nn.conv2d(frames, filter_x, strides, padding=padding))
        grad_y = tf.abs(tf.nn.conv2d(frames, filter_y, strides, padding=padding))
        return grad_x + grad_y

    gen_grad_sum = abs_gradient_sum(gen_frames)
    gt_grad_sum = abs_gradient_sum(gt_frames)

    grad_diff = tf.abs(gt_grad_sum - gen_grad_sum)

    # Per-frame mean absolute gradient difference.
    mean_grad_diff = (1.0 / num_pixels) * tf.reduce_sum(grad_diff, [1, 2, 3])

    # No log10 helper is defined in this notebook: log10(x) = ln(x) / ln(10).
    batch_errors = 10.0 * tf.log(1.0 / mean_grad_diff) / math.log(10.0)
    return tf.reduce_mean(batch_errors)

## =================== COPIED CODE ENDS ======================


def l2_loss(generated_frames, expected_frames):
    """Average, over all scales, of tf.nn.l2_loss of (generated - expected).

    generated_frames / expected_frames are parallel sequences holding one
    tensor per scale; pairs are matched positionally with zip.
    """
    per_scale_losses = [
        tf.nn.l2_loss(tf.subtract(generated, expected))
        for generated, expected in zip(generated_frames, expected_frames)
    ]
    return tf.reduce_mean(tf.stack(per_scale_losses))

def gdl_loss(generated_frames, expected_frames, alpha=2):
    """
    Gradient difference loss: compares each frame's differences with its side
    pixel and its below pixel between generated and expected frames.

    @param generated_frames: Sequence of tensors, one per scale, each with 3 channels
                             (the difference filters are built from a 3x3 identity).
    @param expected_frames: Sequence of ground-truth tensors parallel to generated_frames.
    @param alpha: Exponent applied to the absolute gradient differences.
    @return: A scalar tensor. The mean over scales of the summed, alpha-powered
             gradient differences.
    """
    # create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.
    # They do not depend on the scale, so they are built once instead of per iteration.
    pos = tf.constant(np.identity(3), dtype=tf.float32)
    neg = -1 * pos
    filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [-1, 1]
    filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
    strides = [1, 1, 1, 1]  # stride of (1, 1)
    padding = 'SAME'

    scale_losses = []
    # zip instead of xrange(len(...)): works on Python 2 and 3 and matches l2_loss.
    for gen_frames, gt_frames in zip(generated_frames, expected_frames):
        gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding))
        gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding))
        gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding))
        gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding))

        grad_diff_x = tf.abs(gt_dx - gen_dx)
        grad_diff_y = tf.abs(gt_dy - gen_dy)

        scale_losses.append(tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha)))

    # condense into one tensor and avg
    return tf.reduce_mean(tf.stack(scale_losses))

def total_loss(generated_frames, expected_frames, lambda_gdl=1.0, lambda_l2=1.0):
    """Weighted sum of the gradient-difference loss and the L2 loss.

    lambda_gdl scales gdl_loss and lambda_l2 scales l2_loss; both losses are
    evaluated on the same (generated_frames, expected_frames) pair.
    """
    gdl_term = lambda_gdl * gdl_loss(generated_frames, expected_frames)
    l2_term = lambda_l2 * l2_loss(generated_frames, expected_frames)
    return gdl_term + l2_term

In [141]:
class GenerativeNetwork:
    """Multi-scale convolutional generator that predicts a frame from stacked past frames.

    The per-scale convolution stacks are instantiated twice: once on the
    training-size placeholders and once, with variable reuse, on the test-size
    placeholders. Reads the module-level settings ``time_frames_to_consider``,
    ``image_channels``, ``l2_val`` and ``adam_learning_rate``.
    """

    def __init__(self, heigth_train, width_train, heigth_test, width_test, scale_level_feature_maps, scale_level_kernel_size):
        """Create placeholders, build the train and test graphs and the training op.

        heigth_*/width_*: frame sizes for the train and test graphs.
        scale_level_feature_maps: per scale, the output channels of each conv layer.
        scale_level_kernel_size: per scale, the square kernel size of each conv layer.
        Raises AssertionError if the two per-scale lists differ in length.
        """
        self.heigth_train = heigth_train
        self.width_train = width_train
        self.heigth_test = heigth_test
        self.width_test = width_test

        self.scale_level_feature_maps = scale_level_feature_maps
        self.scale_level_kernel_size = scale_level_kernel_size
        self.len_scale = len(self.scale_level_kernel_size)
        assert len(self.scale_level_feature_maps) == len(self.scale_level_kernel_size), "Length should be equal !"

        # Placeholders for inputs and outputs ... !
        # Inputs stack time_frames_to_consider frames along the channel axis.
        input_channels = time_frames_to_consider * image_channels
        self.input_train = tf.placeholder(dtype=tf.float32, shape=[None, self.heigth_train, self.width_train, input_channels])
        self.output_train = tf.placeholder(dtype=tf.float32, shape=[None, self.heigth_train, self.width_train, image_channels])
        self.input_test = tf.placeholder(dtype=tf.float32, shape=[None, self.heigth_test, self.width_test, input_channels])
        self.output_test = tf.placeholder(dtype=tf.float32, shape=[None, self.heigth_test, self.width_test, image_channels])

        # Filled by create_graph: one prediction / resized ground truth per scale.
        self.each_scale_predication_train = []
        self.each_scale_ground_truth_train = []
        self.each_scale_predication_test = []
        self.each_scale_ground_truth_test = []

        self.create_graph(self.input_train, self.output_train, heigth_train, width_train,
                          self.each_scale_predication_train,
                          self.each_scale_ground_truth_train,
                          reuse=None)

        # reuse graph at time of test !
        # The test graph must be fed from the test placeholders; wiring it to the
        # train placeholders would leave input_test / output_test unused.
        self.create_graph(self.input_test, self.output_test, heigth_test, width_test,
                          self.each_scale_predication_test,
                          self.each_scale_ground_truth_test,
                          reuse=True)

        self.loss()

    def rescale_image(self, scaling_factor, heigth, width, input_data, ground_truths, last_generated_frame):
        """Resize the inputs of one scale to (scaling_factor * heigth, scaling_factor * width).

        scaling_factor, heigth, width = values
        input_data, ground_truths = Tensors
        last_generated_frame = Tensor from the previous scale, or None for the first scale.

        Returns (resized input, resized ground truth, resized previous prediction or None).
        Raises AssertionError if either resized dimension truncates to zero.
        """
        rescaled_heigth = int(scaling_factor * heigth)
        rescaled_width = int(scaling_factor * width)
        assert rescaled_heigth != 0 and rescaled_width != 0, "scaling factor should not be zero !"
        target_size = [rescaled_heigth, rescaled_width]

        input_reshaped = tf.image.resize_images(input_data, target_size)
        ground_truths_reshaped = tf.image.resize_images(ground_truths, target_size)
        last_generated_frame_reshaped = None
        # Identity check: comparing a tensor with `!= None` relies on its comparison operators.
        if last_generated_frame is not None:
            last_generated_frame_reshaped = tf.image.resize_images(last_generated_frame, target_size)
        return (input_reshaped, ground_truths_reshaped, last_generated_frame_reshaped)

    def create_graph(self, input_data, ground_truths, heigth, width,
                     predicated_at_each_scale_tensor, ground_truth_at_each_scale_tensor, reuse):
        """Build the conv stack of every scale and append its outputs to the given lists.

        For each scale the input and ground truth are resized, the previous scale's
        prediction (if any) is resized and concatenated on the channel axis, and the
        scale's conv layers are applied. The prediction and the resized ground truth
        are appended to predicated_at_each_scale_tensor and
        ground_truth_at_each_scale_tensor respectively. `reuse` is forwarded to
        every variable scope.
        """
        # for each scale ...
        for each_scale in range(self.len_scale):
            with tf.variable_scope('scale_'+str(each_scale),reuse=reuse):
                # Each earlier scale is half the size of the next; the last scale is
                # full size, e.g. [1/8, 1/4, 1/2, 1] for four scales.
                scaling_factor = 1.0 / (2**(self.len_scale - 1 - each_scale))
                last_generated_frame = None
                if each_scale > 0:
                    last_generated_frame = predicated_at_each_scale_tensor[each_scale-1]

                input_reshaped, ground_truths_reshaped, last_generated_frame_reshaped = self.rescale_image(scaling_factor, heigth, width, input_data, ground_truths, last_generated_frame)

                # append last scale output
                if each_scale > 0:
                    input_reshaped = tf.concat([input_reshaped, last_generated_frame_reshaped],axis=3)

                predication = input_reshaped

                # for each conv layers in that scale ...
                # Use the configuration passed to the constructor rather than module globals.
                feature_maps = self.scale_level_feature_maps[each_scale]
                kernel_size = self.scale_level_kernel_size[each_scale]

                assert len(feature_maps)==len(kernel_size), "Length should be equal !"
                last_index = len(kernel_size) - 1
                for index, (each_filter, each_kernel) in enumerate(zip(feature_maps, kernel_size)):
                    with tf.variable_scope('conv_'+str(index),reuse=reuse):
                        # relu everywhere except the last layer, which uses tanh !
                        activiation = tf.nn.tanh if index == last_index else tf.nn.relu
                        predication = slim.conv2d(predication, each_filter, [each_kernel, each_kernel],
                                              weights_initializer=trunc_normal(0.01),
                                              weights_regularizer=regularizers.l2_regularizer(l2_val),
                                              activation_fn=activiation)

                # APPEND LAST GENERATED FRAME
                predicated_at_each_scale_tensor.append(predication)
                ground_truth_at_each_scale_tensor.append(ground_truths_reshaped)

    def loss(self):
        """Define the training objective (self.combined_loss) and the Adam update op (self.step)."""
        # NOTE(review): only total_loss enters the objective; the L2 weight regularizers
        # passed to slim.conv2d are not added here -- confirm whether that is intended.
        self.combined_loss = total_loss(self.each_scale_predication_train, self.each_scale_ground_truth_train)
        self.optimizer = tf.train.AdamOptimizer(adam_learning_rate)
        global_step = tf.Variable(0,name="global_step_var",trainable=False)
        self.step = self.optimizer.minimize(self.combined_loss, global_step=global_step)

In [142]:
# Build the network (train graph, test graph and training op) from the constants above.
g = GenerativeNetwork(
    heigth_train=heigth_train,
    width_train=width_train,
    heigth_test=heigth_test,
    width_test=width_test,
    scale_level_feature_maps=scale_level_feature_maps,
    scale_level_kernel_size=scale_level_kernel_size,
)

In [57]:
# Scratch cell: IPython `obj?` help lookup for tf.concat, left over from development.
tf.concat?

In [32]:
# Scratch cell: IPython `obj?` help lookup for slim.conv2d, left over from development.
slim.conv2d?

In [None]:
import numpy as np

# Scratch check of the difference filters used by gdl_loss / sharp_diff_error.
# print(...) with a single argument behaves the same on Python 2 and Python 3.
pos = tf.constant(np.identity(3), dtype=tf.float32)
neg = -1 * pos
sess = tf.Session()
print(sess.run(pos))
print(sess.run(tf.expand_dims(tf.stack([neg, pos]), 0)))
print(sess.run(tf.expand_dims(pos, 0)))
print(sess.run(tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])))

# Two small [1, 4, 2, 3] test tensors holding consecutive integers.
x = tf.constant(np.arange(4*2*3).reshape(1,4,2,3), dtype=tf.float32)
y = tf.constant(np.arange(24,24+4*2*3).reshape(1,4,2,3), dtype=tf.float32)
print(sess.run(x))
print("----")
print(sess.run(y))

# pos / neg from above are reused; they were previously redefined here unchanged.
filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [-1, 1]
filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
strides = [1, 1, 1, 1]  # stride of (1, 1)
padding = 'SAME'
# Take .shape of the evaluated array, not of print's return value.
print(sess.run(filter_y).shape)
sess.run(tf.nn.conv2d(x, filter_x, strides, padding=padding))
sess.run(tf.nn.conv2d(x, filter_y, strides, padding=padding))