# Sources
The architecture utilises the ideas of:<br>
Khamis, Sameh, et al. (Google / PerceptiveIO) "StereoNet: Guided Hierarchical Refinement for Real-Time Edge-Aware Depth Prediction". ECCV 2018.<br>
Zhang, Yinda, et al. (Google / PerceptiveIO) "ActiveStereoNet: End-to-End Self-Supervised Learning for Active Stereo Systems". ECCV 2018.<br>
<br>
Training data is:<br>
Mayer, N. et al. (Thomas Brox, University of Freiburg) "A Large Dataset to Train Convolutional Networks for Disparity, Optical Flow, and Scene Flow Estimation". CVPR 2016.

# Import

In [None]:
import tensorflow as tf
import generator as gen
import os, sys
import matplotlib as mp
%matplotlib inline
import matplotlib.pyplot as plt
import math

from model import stereonet

import numpy as np

%load_ext autoreload
%autoreload 2

In [None]:
from tensorflow.python.client import device_lib
#device_lib.list_local_devices()

# Parameters and Flags

In [None]:
### For inference ###
# NOTE(review): the inference paths, output scales, disp_min/disp_max, resolution,
# n_views and is_training below are not referenced by any other cell visible in
# this notebook; they look like leftovers from an inference script - confirm
# before relying on them.

# path to input
inference_path          = 'inputs'

# path to save (all three outputs currently share the same directory)
save_path_depth         = 'outputs'
save_path_disparity     = 'outputs'
save_path_invalidation  = 'outputs'

# scale factor for output
depth_scale             = 1
disparity_scale         = 256
invalidation_scale      = 1024

# disparity range - presumably in pixels; TODO confirm
disp_min = 0
disp_max = 255

# full-resolution frame size in pixels; used below to fix the static tensor
# shapes and to compute the sizes of the down-scaled ground-truth maps
full_w = 960
full_h = 540
resolution = (full_h,full_w)

# checkpoint to load
# checkpoint_path         = './checkpoints/training_low_lr_2e-06_step_144755val_99.ckpt'

# True: start training from scratch; False: restore the hard-coded checkpoint
# in the "Checkpoints and Visualizations" cell
is_from_beginning = True
# batch size of the tf.data pipeline and of the static tensor shapes
n_batch = 1
n_views = 1
# NOTE(review): the stereonet(...) call passes is_training=True literally and
# does not read this flag.
is_training = False


# Dataset Input

In [None]:
tf.reset_default_graph()

# Use every n_every-th frame of the dataset (1 = use all frames).
n_every = 1
# When True, verify that left / right / disparity files line up one-to-one.
sanity_check=True
path_to_left_rgb = '/ben/kgx_nfs/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_forwards/slow/left/'
path_to_right_rgb = '/ben/kgx_nfs/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_forwards/slow/right'
path_to_disparity = '/ben/kgx_nfs/data/external/sceneflow/disparity/35mm_focallength/scene_forwards/slow/left'


def _list_sorted_paths(directory, stride):
    """Return every `stride`-th full file path of `directory`, in sorted order.

    The listing is sorted BEFORE the stride is applied: os.listdir returns
    entries in arbitrary order, so striding first could pick different frames
    from each of the three directories.
    """
    all_paths = sorted(os.path.join(directory, each_name) for each_name in os.listdir(directory))
    return all_paths[0::stride]


def _frame_id(path):
    """Return the file name of `path` without its directory and extension."""
    return os.path.splitext(os.path.basename(path))[0]


left_rgb_names_train = _list_sorted_paths(path_to_left_rgb, n_every)
right_rgb_names_train = _list_sorted_paths(path_to_right_rgb, n_every)
disparity_names_train = _list_sorted_paths(path_to_disparity, n_every)

n_train_data = len(left_rgb_names_train)

# The three lists are consumed index-by-index, so their lengths must agree.
assert len(left_rgb_names_train) == len(right_rgb_names_train) == len(disparity_names_train), "Error : number of files doesn't match"

if sanity_check:
    print("sanity check started.")

    # Every left image, right image and disparity map at the same index must
    # share the same frame id (file name without extension).
    for idx, (left_rgb, right_rgb, disparity) in enumerate(
            zip(left_rgb_names_train, right_rgb_names_train, disparity_names_train)):

        assert _frame_id(left_rgb) == _frame_id(right_rgb) == _frame_id(disparity),"Error : wrong file name match"

        if idx % 500 == 0:
            # Single formatted string so the output is the same on Python 2 and 3.
            print("{} out of {} completed".format(idx, n_train_data))

    print("sanity check finished.")

# Setting up the Graph

In [None]:
# Python-side sample source: one (left rgb, right rgb, disparity) triple per step.
training_generator = iter(
    gen.TrainingGeneratorStereoNet(
        left_rgb_names_train, right_rgb_names_train, disparity_names_train))

# All three streams are delivered as float32.
generator_data_type = (tf.float32,) * 3

# Wrap the generator in a tf.data pipeline: batch first, then prefetch.
buffer_size = 1
training_set = tf.data.Dataset.from_generator(lambda: training_generator, generator_data_type)
training_set = training_set.batch(n_batch).prefetch(buffer_size)

# Iterator that must be explicitly initialized in the session before use.
training_iterator = training_set.make_initializable_iterator()

# Symbolic tensors for the next batch.
rgb_left_train, rgb_right_train, disparity_train = training_iterator.get_next()

# from_generator yields tensors of unknown static shape; pin them at graph
# construction time (3 channels for the rgb images, 1 for the disparity map).
for each_tensor, n_channels in ((rgb_left_train, 3), (rgb_right_train, 3), (disparity_train, 1)):
    each_tensor.set_shape(tf.TensorShape([n_batch, full_h, full_w, n_channels]))


# Network and Outputs

In [None]:
# Build the StereoNet graph on the current training batch. The call returns five
# outputs: the full-resolution disparity map, a second output that is discarded
# here, and the disparity maps at 1/2, 1/4 and 1/8 resolution.
# NOTE(review): is_training is hard-coded to True; the `is_training` flag from
# the parameters cell is not used by this call.
full_res_disparity_map, _, disparity_map_1_2, disparity_map_1_4, disparity_map_1_8 = stereonet(rgb_left_train,rgb_right_train, max_disp_lowres=18, is_training=True)



# Loss

In [None]:
# Ground-truth disparity at every pyramid level: the full-resolution map is
# resized to 1/8, 1/4 and 1/2 of the frame size and its values are divided by
# the same factor (presumably because disparity is measured in pixels of the
# respective resolution - confirm).
# The 1/8 height is rounded UP (540 / 8 = 67.5 -> 68) while all other sizes are
# truncated - presumably to match the size of the network's 1/8 output; verify
# against the model.
disparity_reference_1_8 = tf.image.resize_images(disparity_train,(int(math.ceil(float(full_h)/8)),int(full_w/8)),align_corners=True) / 8
disparity_reference_1_4 = tf.image.resize_images(disparity_train,(int(full_h/4),int(full_w/4)),align_corners=True) / 4
disparity_reference_1_2 = tf.image.resize_images(disparity_train,(int(full_h/2),int(full_w/2)),align_corners=True) / 2
disparity_reference_1_1 = disparity_train

# Short aliases for the network predictions at the matching levels.
disparity_1_8 = disparity_map_1_8
disparity_1_4 = disparity_map_1_4
disparity_1_2 = disparity_map_1_2
disparity_1_1 = full_res_disparity_map

#def barron_loss(x, a, c, e=1e-5):
#    b = tf.abs(2.-a) + e
#    d = tf.where(tf.greater_equal(a, 0.), a+e, a-e)
#    return b/d * (tf.pow(tf.square(x/c)/b+1., 0.5 * d)-1.)

#alpha = 1.0
#c     = 2.0

def barron_spec(x):
    """Element-wise robust penalty sqrt((x / 2)^2 + 1) - 1.

    This is the commented-out general `barron_loss` above specialised to
    alpha = 1.0 and c = 2.0 (without the epsilon terms). It is 0 at x = 0.

    Args:
        x: tensor of residuals.

    Returns:
        Tensor of the same shape as `x` holding the per-element penalty.
    """
    scaled_residual = x / 2.0
    squared_plus_one = tf.square(scaled_residual) + 1.
    return tf.sqrt(squared_plus_one) - 1.

# Normalization parameters for disparities.
# NOTE(review): disp_max_1_* and disp_max_const are only referenced by the
# commented-out normalized residuals below; the active loss does not use them.
disp_max_1_8 = tf.reduce_max(disparity_reference_1_8)
disp_max_1_4 = tf.reduce_max(disparity_reference_1_4)
disp_max_1_2 = tf.reduce_max(disparity_reference_1_2)
disp_max_1_1 = tf.reduce_max(disparity_reference_1_1)
disp_max_const = 1000.

# Disabled alternative: residuals normalized by a constant.
#diff_1_8 = (disparity_1_8 - disparity_reference_1_8) / disp_max_const
#diff_1_4 = (disparity_1_4 - disparity_reference_1_4) / disp_max_const
#diff_1_2 = (disparity_1_2 - disparity_reference_1_2) / disp_max_const
#diff_1_1 = (disparity_1_1 - disparity_reference_1_1) / disp_max_const

# Active: raw residual (prediction - ground truth) at every pyramid level.
diff_1_8 = (disparity_1_8 - disparity_reference_1_8)
diff_1_4 = (disparity_1_4 - disparity_reference_1_4)
diff_1_2 = (disparity_1_2 - disparity_reference_1_2)
diff_1_1 = (disparity_1_1 - disparity_reference_1_1)

# Pixel count per level (same sizes as used for the resized ground truth).
# NOTE(review): only the commented-out per-pixel loss below uses these counts.
pixels_1_8 = int(math.ceil(float(full_h)/8)) * int(full_w/8)
pixels_1_4 = int(full_h/4) * int(full_w/4)
pixels_1_2 = int(full_h/2) * int(full_w/2)
pixels_1_1 = int(full_h) * int(full_w)

# Disabled alternative: per-level loss averaged over the level's pixel count.
#loss_1_8 = tf.reduce_sum(barron_spec(diff_1_8)) / pixels_1_8
#loss_1_4 = tf.reduce_sum(barron_spec(diff_1_4)) / pixels_1_4
#loss_1_2 = tf.reduce_sum(barron_spec(diff_1_2)) / pixels_1_2
#loss_1_1 = tf.reduce_sum(barron_spec(diff_1_1)) / pixels_1_1

# Active: per-level loss is the plain SUM of the robust penalty over all
# elements, so levels with more pixels contribute more terms to the total.
loss_1_8 = tf.reduce_sum(barron_spec(diff_1_8))
loss_1_4 = tf.reduce_sum(barron_spec(diff_1_4))
loss_1_2 = tf.reduce_sum(barron_spec(diff_1_2))
loss_1_1 = tf.reduce_sum(barron_spec(diff_1_1))

# Disabled alternative: sum of squared errors.
#loss_1_8 = tf.reduce_sum(tf.square(disparity_1_8 - disparity_reference_1_8))
#loss_1_4 = tf.reduce_sum(tf.square(disparity_1_4 - disparity_reference_1_4))
#loss_1_2 = tf.reduce_sum(tf.square(disparity_1_2 - disparity_reference_1_2))
#loss_1_1 = tf.reduce_sum(tf.square(disparity_1_1 - disparity_reference_1_1))

# Total training loss: unweighted sum over the four pyramid levels.
loss = loss_1_8 + loss_1_4 + loss_1_2 + loss_1_1

# Training

In [None]:
# Initializers
# The learning rate is a placeholder so that it can be supplied (and decayed)
# from Python through feed_dict on every training step.
learning_rate = tf.placeholder(tf.float32)
optimizer = tf.train.RMSPropOptimizer(learning_rate = learning_rate)
train_op = optimizer.minimize(loss=loss)

# Let TensorFlow grow its GPU memory allocation on demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)

# Scalar summary of the total loss, written to ./logs.
train_summary_op = tf.summary.scalar("train_loss",loss)
writer = tf.summary.FileWriter("./logs")

# Initialize all variables, then the input iterator.
init = tf.global_variables_initializer()
sess.run(init)
sess.run(training_iterator.initializer)
# NOTE(review): the iterator is initialized a second time here with the same
# dataset; this looks redundant with the line above - confirm before removing.
train_init_op = training_iterator.make_initializer(training_set)
sess.run(train_init_op)
# Debug: fetch one batch and print its shape.
#img_left = sess.run(rgb_left_train)
#print img_left.shape

# Checkpoints and Visualizations

In [None]:
# Loss history used for plotting, one row per logged step:
# [total loss, 1/8-scale loss, full-resolution loss]. The leading zero row is a
# placeholder that the plots skip via losses[1:].
# It is defined for BOTH branches: the training loop appends to it regardless
# of whether training starts from scratch or resumes from a checkpoint
# (previously a resumed run failed with NameError on the first append).
losses = np.zeros((1,3))
val_step = 0

if is_from_beginning:
    # Fresh run: initial learning rate and step counter.
    lr = 1e-4
    step = 0
else:
    # Resume: restore all global variables from the hard-coded checkpoint.
    saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    saver1.restore(sess, './checkpoints/stereoNet_testrun_9.69773729788e-07_step_88000val_0.ckpt')
    # Learning rate and step are the values encoded in the checkpoint file name above.
    lr = 9.69773729788e-07
    step = 88000

In [None]:
# Training schedule. The loop runs n_epoch * n_iter steps in total; all
# intervals below are measured in global training steps.
n_epoch = 100
n_iter = 200000
# show the input, ground truth and all four predicted disparity maps
show_all_n_every_step = 10000
# decay the learning rate (x0.9) and write a checkpoint
save_n_every_step = 2000
# show the input, ground truth, the 1/8 prediction and the loss curves
show_n_every_step = 100
# print the current losses and append them to the loss history
show_n_losses = 10

In [None]:
def display_images(tuple_of_images,tuple_of_titles,statement,idx_in_batch=0):
    """Print `statement` and show each image in its own matplotlib figure.

    Args:
        tuple_of_images: sequence of image batches; from each batch the entry
            at `idx_in_batch` is displayed. For backward compatibility a single
            bare array of length 1 (anything with a `.squeeze()` method) is
            also accepted and displayed as one image.
        tuple_of_titles: sequence of figure titles, indexed like the images.
        statement: text printed once before the figures.
        idx_in_batch: index of the sample to display from every batch.
    """
    print(statement)
    n_images = len(tuple_of_images)

    # Legacy call style: a single array passed directly instead of a tuple.
    # A real one-element tuple/list has no .squeeze() and must go through the
    # regular per-image path (it used to raise AttributeError).
    is_bare_array = n_images == 1 and hasattr(tuple_of_images, 'squeeze')

    for each_idx in range(n_images):

        plt.figure(figsize=(15,60))
        plt.title(tuple_of_titles[each_idx])
        plt.axis('off')
        if is_bare_array:
            each_image = tuple_of_images.squeeze()
        else:
            # Pick one sample of the batch and drop singleton dimensions
            # (e.g. the channel axis of a disparity map).
            each_image = tuple_of_images[each_idx][idx_in_batch].squeeze()
        plt.imshow(each_image)
        #plt.imshow(each_image,cmap='gray')

        plt.show()

In [None]:
# Tensors fetched on every step; built once because the list never changes.
run_list = [train_op,
            train_summary_op,
            loss, loss_1_8, loss_1_1,
            rgb_left_train, disparity_reference_1_1,
            disparity_1_8, disparity_1_4, disparity_1_2, disparity_1_1
           ]

# One Saver for the whole run: constructing a Saver adds save/restore ops to
# the graph, so creating a new one at every checkpoint would keep growing it.
saver = tf.train.Saver(var_list = tf.global_variables())

# Saver.save fails if the parent directory of the checkpoint does not exist.
if not os.path.isdir("checkpoints"):
    os.makedirs("checkpoints")

# `step` is the global step counter (it continues from the restored value when
# resuming); each_epoch / each_iter only bound the total number of steps.
# NOTE(review): n_iter is not derived from the dataset size - if the generator
# is finite, sess.run will raise tf.errors.OutOfRangeError; confirm that the
# generator loops.
for each_epoch in range(n_epoch):
    for each_iter in range(n_iter):

        output = sess.run(run_list, feed_dict={learning_rate:lr})

        # output[0] is the (empty) result of train_op.
        train_summary = output[1]
        loss_out, loss_1_8_out, loss_1_1_out = output[2:5]
        rgb_out, disparity_out = output[5:7]
        disparity_1_8_out, disparity_1_4_out, disparity_1_2_out, disparity_1_1_out = output[7:11]

        writer.add_summary(train_summary,int(step))
        if step % show_n_losses == 0:
            # Single formatted string: valid and identical on Python 2 and 3.
            print("step : {} loss : {}  ( {} + {}  )".format(step, loss_out, loss_1_8_out, loss_1_1_out))
            #print np.amin(disparity_1_8_out), np.amax(disparity_1_8_out)
            #print np.amin(disparity_1_1_out), np.amax(disparity_1_1_out)
            losses = np.append(losses, [[loss_out, loss_1_8_out, loss_1_1_out]], axis=0)

        if step != 0 and step % save_n_every_step == 0:
            # Decay first: the checkpoint name records the learning rate to
            # continue with when resuming from this file.
            lr *= 0.9
            save_path = saver.save(sess,"checkpoints/stereoNet_testrun_" + str(lr) + "_step_" + str(step) + "val_" + str(val_step) + '.ckpt')

        if step != 0 and step % show_n_every_step == 0:
            #print disparity_1_8_out, disparity_1_4_out, disparity_1_2_out, disparity_1_1_out
            display_images((rgb_out,disparity_out,
                                 disparity_1_8_out
                                ),
                                ('original img','disparity_gt',
                                 'disparity_raw 1/8'
                                ),
                                'testrun')

            # Loss curves, one figure each; row 0 of `losses` is a placeholder.
            plt.plot(losses[1:,1])
            plt.legend(['loss_1_8_out'])
            plt.figure()
            plt.plot(losses[1:,2])
            plt.legend(['loss_1_1_out'])
            plt.show()

        if step != 0 and step % show_all_n_every_step == 0:
            #print disparity_1_8_out, disparity_1_4_out, disparity_1_2_out, disparity_1_1_out
            display_images((rgb_out,disparity_out,
                                 disparity_1_8_out,
                                 disparity_1_4_out,
                                 disparity_1_2_out,
                                 disparity_1_1_out
                                ),
                                ('original img','disparity_gt',
                                 'disparity_raw 1/8',
                                 'disparity 1/4',
                                 'disparity 1/2',
                                 'disparity 1/1'
                                ),
                                'testrun')
        step += 1
        
        

In [None]:
# Show the input, ground truth and all four predicted disparity maps of the
# most recent training step. display_images is the function defined in this
# notebook; no `util` module is imported here.
display_images((rgb_out,disparity_out,
                                 disparity_1_8_out,
                                 disparity_1_4_out,
                                 disparity_1_2_out,
                                 disparity_1_1_out
                                ),
                                ('original img','disparity_gt',
                                 'disparity_raw 1/8',
                                 'disparity 1/4',
                                 'disparity 1/2',
                                 'disparity 1/1'
                                ),
                                'testrun')