# Imports

In [None]:
import tensorflow as tf
import os
import matplotlib as mp
import numpy as np
import matplotlib.pyplot as plt

from stereonet.generator import TrainingGeneratorStereoNet
from refocus_algorithms.layered_dof_tf import layered_bluring
from blur_refinement.model import refnet_blur_refinement
from tf_utils import optimistic_restore

# Parameters and Flags

In [None]:
# --- Output -------------------------------------------------------------------
# TODO: make the save location configurable (checkpoints currently go to ./checkpoints)

# --- Frame geometry -------------------------------------------------------------
# Every training frame is expected at this size (used to pin the input shapes).
full_w, full_h = 960, 540
resolution = (full_h, full_w)  # (height, width)

# --- Checkpoints ----------------------------------------------------------------
# Pretrained StereoNet weights restored when training starts from scratch.
stereonet_weigts = '/home/matthieu/kgx_nfs2/ben/stereo_net_refactored/checkpts/sceneFlow/stereoNet_ScenFlow_driving_1.7888498861e-07_step_120000.ckpt'
# checkpoint_path         = './checkpoints/training_low_lr_2e-06_step_144755val_99.ckpt'

# --- Run mode -------------------------------------------------------------------
# True: start from the StereoNet weights above; False: resume a saved run.
is_from_beginning = True
n_batch = 1
n_views = 1
is_training = False

# --- Refocusing -----------------------------------------------------------------
target_disparity = 128  # disparity plane kept in focus
blur_magnitude = 0.1
min_disp, max_disp = 0, 300
from_scale = 1

# Parsing files

In [None]:
# Keep every n_every-th frame of the (sorted) split; 1 = use all of it.
n_every = 1
# When True, verify file-by-file that the left, right and disparity lists line up.
sanity_check = True
path_to_left_rgb = '/home/matthieu/kgx_nfs2/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_mixed_train/left/'
path_to_right_rgb = '/home/matthieu/kgx_nfs2/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_mixed_train/right'
path_to_disparity = '/home/matthieu/kgx_nfs2/data/external/sceneflow/disparity/35mm_focallength/scene_mixed_train'


def _sorted_paths(directory, stride=1):
    """Return the full paths of the entries of `directory`, sorted by name, keeping every `stride`-th one.

    Sorting must happen BEFORE sub-sampling: os.listdir returns entries in
    arbitrary order, and the three lists are later paired by index, so slicing
    first would pair unrelated frames whenever stride > 1.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))][::stride]


def _frame_id(path):
    """Return the second-to-last dot-separated part of the file name (e.g. '0006' for '.../0006.png')."""
    return os.path.basename(path).split('.')[-2]


left_rgb_names_train = _sorted_paths(path_to_left_rgb, n_every)
right_rgb_names_train = _sorted_paths(path_to_right_rgb, n_every)
disparity_names_train = _sorted_paths(path_to_disparity, n_every)

n_train_data = len(left_rgb_names_train)
assert n_train_data == len(right_rgb_names_train) == len(disparity_names_train), "Error : number of files doesn't match"

if sanity_check:
    print("sanity check started.")
    for idx, (left_path, right_path, disparity_path) in enumerate(
            zip(left_rgb_names_train, right_rgb_names_train, disparity_names_train)):
        assert _frame_id(left_path) == _frame_id(right_path) == _frame_id(disparity_path), \
            "Error : wrong file name match " + left_path + " " + right_path + " " + disparity_path

        if idx % 500 == 0:
            print(idx, "out of", n_train_data, "completed")

    print("sanity check finished.")

# Setting up the inputs

In [None]:
# Start from an empty default graph so re-running the graph-building cells does not duplicate ops.
tf.reset_default_graph()

# Dataset of (left RGB, right RGB, disparity) float32 triplets from the project's
# training generator. from_generator calls its argument each time the iterator is
# initialised, so the lambda builds a FRESH generator on every call; handing back
# one shared, pre-built iterator would leave the dataset empty once that iterator
# has been exhausted (e.g. after re-running the initialisation cell).
generator_data_type = (tf.float32, tf.float32, tf.float32)
training_set = tf.data.Dataset.from_generator(
    lambda: iter(TrainingGeneratorStereoNet(left_rgb_names_train,
                                            right_rgb_names_train,
                                            disparity_names_train)),
    generator_data_type)
training_set = training_set.batch(n_batch)
buffer_size = 1
training_set = training_set.prefetch(buffer_size)  # prefetch one batch ahead

# Initializable iterator: it must be initialised in the session before get_next() yields data.
training_iterator = training_set.make_initializable_iterator()
rgb_left_train, rgb_right_train, disparity_train = training_iterator.get_next()

# from_generator only declares dtypes; pin the static shapes so downstream ops see
# fully defined [batch, height, width, channels] tensors.
rgb_left_train.set_shape(tf.TensorShape([n_batch, full_h, full_w, 3]))
rgb_right_train.set_shape(tf.TensorShape([n_batch, full_h, full_w, 3]))
disparity_train.set_shape(tf.TensorShape([n_batch, full_h, full_w, 1]))


# Network and Outputs

In [None]:
# Stereo network + blur-refinement head. Returns the refocused image (upsampled,
# per its summary name below), its low-resolution counterpart, and the predicted
# disparity.
tf.logging.info("Network Declaration:")
blur_image, blur_image_lowres, disparity = refnet_blur_refinement(
    rgb_left_train,
    rgb_right_train,
    target_disparity,
    blur_magnitude,
    min_disp=min_disp,
    max_disp=max_disp,
    from_scale=from_scale,
)

# Reference refocused image: the same layered blur, driven by the ground-truth
# disparity instead of the predicted one.
tf.logging.info("GT rebluring:")
blur_image_gt = layered_bluring(
    rgb_left_train,
    disparity_train,
    target_disparity,
    blur_magnitude,
    min_disp=min_disp,
    max_disp=max_disp,
    downsampling_trick_max_kernel_size=11,
    differenciable=False,
)

# Mean squared error between the predicted and reference refocused images.
blur_residual = blur_image - blur_image_gt
loss = tf.reduce_mean(tf.pow(blur_residual, 2))

# Learning step. The learning rate is a placeholder so the training loop can
# feed (and decay) it from Python on every sess.run.
learning_rate = tf.placeholder(tf.float32)
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss=loss)



# Initialisation

In [None]:
# =============================================================================
# Session and initializers
# =============================================================================
tf.logging.info("Initialisaiton")

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow GPU memory on demand instead of reserving it all upfront
sess = tf.Session(config=config)

writer = tf.summary.FileWriter("./logs")

init = tf.global_variables_initializer()
sess.run(init)
# Initialise the input pipeline once; a second initializer op for the same
# dataset would only add a redundant op to the graph.
sess.run(training_iterator.initializer)

# =============================================================================
# Checkpoints and Visualizations
# =============================================================================
# Checkpoint restored when resuming (is_from_beginning == False). lr and step in
# the else-branch below must match the values encoded in this file name.
resume_checkpoint = './checkpoints/stereoNet_testrun_9.69773729788e-07_step_88000val_0.ckpt'

if is_from_beginning:
    lr = 1e-5
    step = 0
    val_step = 0
    # Load the pretrained StereoNet weights on top of the fresh initialisation.
    # NOTE(review): optimistic_restore presumably restores only the variables
    # present in the checkpoint, and `rv` is presumably their count (see the
    # "%d" log format) — confirm in tf_utils.
    tf.logging.info("Restoring variables from " + stereonet_weigts)
    rv = optimistic_restore(sess, stereonet_weigts)
    tf.logging.info("Restored %d vars" % rv)
else:
    saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
    saver1.restore(sess, resume_checkpoint)
    lr = 9.69773729788e-07
    step = 88000
    val_step = 0

# Summaries written to ./logs by the training loop. Note: re-running this cell
# without rebuilding the graph registers a second copy of every summary, and
# merge_all() would pick up both.
tf.summary.scalar('loss', loss)
tf.summary.image('GT_blur', blur_image_gt)
tf.summary.image('blur_upsampled', tf.clip_by_value(blur_image, 0, 1))
tf.summary.image('blur_lowres', blur_image_lowres)
tf.summary.image('disparity', disparity)

train_summary_op = tf.summary.merge_all()

# Training

In [None]:

n_epoch = 100
n_iter = 200000

save_n_every_step = 2000  # checkpoint (and decay the learning rate) every this many steps
show_n_every_step = 1000  # log the loss and write summaries every this many steps
lr_decay = 0.9            # multiplicative learning-rate decay applied at each checkpoint

# Extra early steps at which the loss is logged and summaries are written.
force_show_steps = [2, 4, 8, 16, 32, 64, 128, 256, 512]

checkpoint_dir = "checkpoints"


# =============================================================================
# Training loop
# =============================================================================
tf.logging.info("Starting training")

# tf.train.Saver refuses to save into a directory that does not exist.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Build the Saver once: every tf.train.Saver adds its own save/restore ops to the
# graph, so creating one per checkpoint would make the graph grow without bound.
saver = tf.train.Saver(var_list=tf.global_variables())

for each_epoch in range(n_epoch):
    for each_iter in range(n_iter):
        is_show_step = (step != 0 and step % show_n_every_step == 0) or step in force_show_steps

        # The merged summaries include several image summaries; only evaluate
        # them on the steps where they are actually written out.
        if is_show_step:
            fetches = [train_op, loss, train_summary_op]
        else:
            fetches = [train_op, loss]
        output = sess.run(fetches, feed_dict={learning_rate: lr})
        loss_out = output[1]

        # Abort as soon as training diverges, on every step and before any
        # checkpoint of the diverged weights is written.
        if not np.isfinite(loss_out):
            raise FloatingPointError("Loss is not finite (%f) at step %d" % (loss_out, step))

        if step != 0 and step % save_n_every_step == 0:
            # The file name records the already-decayed learning rate.
            lr *= lr_decay
            save_path = saver.save(sess, checkpoint_dir + "/run_" + str(lr) + "_step_" + str(step) + "val_" + str(val_step) + '.ckpt')

        if is_show_step:
            tf.logging.info("Step %d. Loss=%f" % (step, loss_out))
            writer.add_summary(output[2], int(step))

        step += 1

# Test

In [1]:
from PIL import Image
from stereonet.utils import readPFM

import tensorflow as tf
import os
import numpy as np
from refocus_algorithms.layered_dof_tf import layered_bluring
from blur_refinement.model import refnet_blur_refinement

# Blur-refinement checkpoint evaluated on the test split.
checkpoint_to_load = "/home/matthieu/dev/refnet/blur_refinement/checkpoints/run_2.54186582833e-06_step_26000val_0.ckpt"

# --- Refocusing parameters -----------------------------------------------------
# One refocused image per test frame is written for each of these focus disparities.
target_disparity = [0, 50, 100, 150, 200, 250]
blur_magnitude = 0.1
min_disp, max_disp = 0, 300
downsampling_trick_max_kernel_size = 11  # NOTE(review): not consumed by the test loop
from_scale = 1

# --- Frame geometry ------------------------------------------------------------
full_w, full_h = 960, 540
resolution = (full_h, full_w)  # (height, width)

# --- Output ----------------------------------------------------------------------
out_folder  = "/home/matthieu/dev/refnet/blur_refinement/test_out"

  from ._conv import register_converters as _register_converters


In [2]:


# Test-split inputs, each list sorted by path.
path_to_left_rgb = '/home/matthieu/kgx_nfs2/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_mixed_test/left/'
path_to_right_rgb = '/home/matthieu/kgx_nfs2/data/external/sceneflow/frames_cleanpass/35mm_focallength/scene_mixed_test/right'
path_to_disparity = '/home/matthieu/kgx_nfs2/data/external/sceneflow/disparity/35mm_focallength/scene_mixed_test'
left_rgb_names_test = sorted(os.path.join(path_to_left_rgb, name) for name in os.listdir(path_to_left_rgb))
right_rgb_names_test = sorted(os.path.join(path_to_right_rgb, name) for name in os.listdir(path_to_right_rgb))
disparity_names_test = sorted(os.path.join(path_to_disparity, name) for name in os.listdir(path_to_disparity))

# Left and right frames are paired by index, so make sure they actually line up.
assert len(left_rgb_names_test) == len(right_rgb_names_test), "Error : number of test files doesn't match"
for left_name, right_name in zip(left_rgb_names_test, right_rgb_names_test):
    assert os.path.basename(left_name) == os.path.basename(right_name), \
        "Error : wrong file name match " + left_name + " " + right_name

if not os.path.isdir(out_folder):
    os.makedirs(out_folder)


def _load_rgb_batch(path):
    """Load an image as an RGB float array scaled to [0, 1], with a leading batch axis of size 1."""
    # convert('RGB') guarantees the 3 channels the placeholders expect (e.g. drops an alpha channel).
    return np.expand_dims(np.asarray(Image.open(path).convert('RGB')) / 255.0, 0)


def _to_uint8(image):
    """Convert a float image nominally in [0, 1] to uint8, clipping first so out-of-range values saturate instead of wrapping."""
    return (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)


config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# Forward pass for every focus disparity. The focus value is passed to the model
# as a Python number, so the graph is rebuilt for each one.
for focus_disparity in target_disparity:
    tf.reset_default_graph()

    # Sizes at compile time: one full-resolution frame per run.
    rgb_left_test = tf.placeholder(shape=[1, full_h, full_w, 3], dtype=tf.float32)
    rgb_right_test = tf.placeholder(shape=[1, full_h, full_w, 3], dtype=tf.float32)

    blur_image, blur_image_lowres, disparity = refnet_blur_refinement(rgb_left_test, rgb_right_test,
                                                                       focus_disparity, blur_magnitude,
                                                                       min_disp=min_disp, max_disp=max_disp,
                                                                       from_scale=from_scale)

    # The context manager closes the session (releasing its GPU memory) before
    # the next focus value builds a new graph.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        saver1.restore(sess, checkpoint_to_load)

        for i, (left_name, right_name) in enumerate(zip(left_rgb_names_test, right_rgb_names_test)):
            print(i)
            blur_image_out = sess.run(blur_image,
                                      feed_dict={rgb_left_test: _load_rgb_batch(left_name),
                                                 rgb_right_test: _load_rgb_batch(right_name)})

            Image.fromarray(_to_uint8(blur_image_out[0])).save(
                out_folder + "/%d_%d_blur_upsampling.png" % (i, focus_disparity))



At low res focus is 0.000000 and magnitude 0.100000 min max disparity 0.000000 150.000000
INFO:tensorflow:Refocusing with disparity range 0.000000 150.000000
INFO:tensorflow:Disparity slice 0.000000
INFO:tensorflow:Bluring with radius size 0.000000 kernel size 3
INFO:tensorflow:Bluring with radius size 0.000000 kernel size 3
INFO:tensorflow:Disparity slice 10.000000
INFO:tensorflow:Bluring with radius size 1.000000 kernel size 5
INFO:tensorflow:Bluring with radius size 1.000000 kernel size 5
INFO:tensorflow:Disparity slice 20.000000
INFO:tensorflow:Bluring with radius size 2.000000 kernel size 7
INFO:tensorflow:Bluring with radius size 2.000000 kernel size 7
INFO:tensorflow:Disparity slice 30.000000
INFO:tensorflow:Bluring with radius size 3.000000 kernel size 9
INFO:tensorflow:Bluring with radius size 3.000000 kernel size 9
INFO:tensorflow:Disparity slice 40.000000
INFO:tensorflow:Bluring with radius size 4.000000 kernel size 11
INFO:tensorflow:Bluring with radius size 4.000000 kernel

INFO:tensorflow:Bluring with radius size 7.000000 kernel size 17
INFO:tensorflow:Bluring with radius size 7.000000 kernel size 17
INFO:tensorflow:Disparity slice 130.000000
INFO:tensorflow:Bluring with radius size 8.000000 kernel size 19
INFO:tensorflow:Bluring with radius size 8.000000 kernel size 19
INFO:tensorflow:Disparity slice 140.000000
INFO:tensorflow:Bluring with radius size 9.000000 kernel size 21
INFO:tensorflow:Bluring with radius size 9.000000 kernel size 21
INFO:tensorflow:Image composition
INFO:tensorflow:Upsampling from scale 1/2 to 1/1
INFO:tensorflow:Restoring parameters from /home/matthieu/dev/refnet/blur_refinement/checkpoints/run_2.54186582833e-06_step_26000val_0.ckpt
0
1
2
3
4
5
6
7
8
9
10
At low res focus is 75.000000 and magnitude 0.100000 min max disparity 0.000000 150.000000
INFO:tensorflow:Refocusing with disparity range 0.000000 150.000000
INFO:tensorflow:Disparity slice 0.000000
INFO:tensorflow:Bluring with radius size 7.500000 kernel size 17
INFO:tensorflo

INFO:tensorflow:Disparity slice 80.000000
INFO:tensorflow:Bluring with radius size 4.500000 kernel size 11
INFO:tensorflow:Bluring with radius size 4.500000 kernel size 11
INFO:tensorflow:Disparity slice 90.000000
INFO:tensorflow:Bluring with radius size 3.500000 kernel size 9
INFO:tensorflow:Bluring with radius size 3.500000 kernel size 9
INFO:tensorflow:Disparity slice 100.000000
INFO:tensorflow:Bluring with radius size 2.500000 kernel size 7
INFO:tensorflow:Bluring with radius size 2.500000 kernel size 7
INFO:tensorflow:Disparity slice 110.000000
INFO:tensorflow:Bluring with radius size 1.500000 kernel size 5
INFO:tensorflow:Bluring with radius size 1.500000 kernel size 5
INFO:tensorflow:Disparity slice 120.000000
INFO:tensorflow:Bluring with radius size 0.500000 kernel size 3
INFO:tensorflow:Bluring with radius size 0.500000 kernel size 3
INFO:tensorflow:Disparity slice 130.000000
INFO:tensorflow:Bluring with radius size 0.500000 kernel size 3
INFO:tensorflow:Bluring with radius si