In [1]:
import tensorflow as tf
import sys
sys.path.append('../kitti_eval/flow_tool/')
sys.path.append('..')
import flowlib as fl

from nets.flow_net import Flow_net

from utils.optical_flow_warp_fwd import transformerFwd
from utils.optical_flow_warp_old import transformer_old
from utils.loss_utils import SSIM, cal_grad2_error
from data_loader.data_loader import DataLoader
import matplotlib.pyplot as plt 
import os
import imageio
import time
from IPython.display import clear_output
import numpy as np

# Parameter Setting 

In [2]:
# Report the installed TensorFlow version.
tf.__version__

'2.0.0'

In [3]:
# Check whether this TensorFlow build was compiled with CUDA support.
tf.test.is_built_with_cuda()

True

In [4]:
# List the physical GPU devices TensorFlow can see.
tf.config.experimental.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]

In [5]:
# Make every detected physical GPU visible to the TensorFlow runtime.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_visible_devices(devices=gpus, device_type='GPU')

In [None]:
# Training mode: "train_flow" (default) or "train_dp".
mode = "train_flow"
# mode = "train_dp"

# Whether to load pretrained weights.
# The flow weights are loaded only when training in 'train_dp' mode.
LOAD_FLOW_WEIGHT = (mode == 'train_dp')
LOAD_DEPTH_WEIGHT = False
LOAD_POSE_WEIGHT = False

# Dataset path.
kitti_depth_folder = '../../../../datasets/kitti_3frames_256_832'

# Training schedule and batching.
# NOTE(review): num_GPU is hard-coded; confirm it matches the replica count
# printed at the end of this cell.
EPOCH = 20
num_GPU = 4
PER_BATCH = 8
BATCH_SIZE = num_GPU * PER_BATCH

# Model input resolution and multi-scale / multi-source settings.
IMG_HEIGHT, IMG_WIDTH = 256, 832
NUM_SCALE = 4
NUM_SOURCE = 2

# Loss hyperparameters.
SSIM_WEIGHT = 0.85
FLOW_RECONSTRUCTION_WEIGHT = 0.15
FLOW_SMOOTH_WEIGHT = 10.0
FLOW_CROSS_GEOMETRY_WEIGHT = 0.3
flow_diff_threshold = 4.0
flow_consist_weight = 0.01

# Distribution strategy shared by the dataset, model and optimizer cells.
mirrored_strategy = tf.distribute.MirroredStrategy()
print('[info] Using %d GPUS to speedup' % mirrored_strategy.num_replicas_in_sync)

# Build DataLoader

In [None]:
# Build the training input pipeline: shuffle, batch to the global batch size,
# prefetch, then hand the dataset to the distribution strategy.
dataloader = DataLoader(
    mode=mode,
    dataset_dir=kitti_depth_folder,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
    split='train',
)
dataset = (
    dataloader.build_dataloader()
    .shuffle(5000)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

# Build Model

In [None]:
# Create the flow network inside the strategy scope so its variables are
# created under the distribution strategy. The input has 9 channels.
with mirrored_strategy.scope():
    model = Flow_net(input_shape=[IMG_HEIGHT, IMG_WIDTH, 9])
# model.summary()

In [None]:
# Render a diagram of the model, including layer output shapes.
tf.keras.utils.plot_model(model, show_shapes=True, dpi=64)

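The dataset yields two items per batch: three consecutive color frames with shape (n, h, w, 9) and the camera intrinsics with shape (n, 3, 3), where n is the batch size.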

# Define the Optimizers and Checkpoint-saver

In [None]:
# Base learning rate; `fit` rescales the optimizer's learning rate from this
# value at the start of every epoch.
initial_learning_rate = 1e-4
# Create the optimizer inside the strategy scope, like the model.
with mirrored_strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

In [None]:
# Checkpoints live under ../training_checkpoints/<mode>/<batch>_<height>_<width>.
checkpoint_attr = '_'.join(str(part) for part in (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH))
checkpoint_dir = os.path.join('../training_checkpoints', mode, checkpoint_attr)
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track the model and optimizer together; keep at most the 10 newest checkpoints.
checkpoint = tf.train.Checkpoint(flowModel=model, flowOptimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(
    checkpoint,
    directory=checkpoint_dir,
    checkpoint_name='ckpt',
    max_to_keep=10,
)

# When LOAD_FLOW_WEIGHT is set, restore the latest checkpoint found in checkpoint_dir.
if LOAD_FLOW_WEIGHT:
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

## Show input image, optical flow result and soft occlusion result 

In [None]:
def show_image(image, flow, occlusion, epoch, isSave=False, count=0):
    """Display two input frames with forward/backward flow and occlusion maps.

    Args:
        image: array indexed as [row, col, channel]; channels 0:3 and 3:6 are
            shown as the "L" and "M" input images.
        flow: pair of tensors (forward, backward); each is converted with
            `.numpy()` before display.
        occlusion: pair of tensors (forward, backward); each is converted with
            `.numpy()` and squeezed for display in grayscale.
        epoch: value used as the first component of the saved file names.
        isSave: when true, also write the six images under './result'.
        count: value used as the second component of the saved file names.
    """
    plt.figure(figsize=(16, 9))

    gray = {'cmap': 'gray'}
    panels = [
        ('Input Image(L)', image[:, :, :3], {}),
        ('Input Image(M)', image[:, :, 3:6], {}),
        ('Optical Flow(FW)', flow[0].numpy(), {}),
        ('Optical Flow(BW)', flow[1].numpy(), {}),
        ('Soft Occlusion(FW)', occlusion[0].numpy().squeeze(), gray),
        ('Soft Occlusion(BW)', occlusion[1].numpy().squeeze(), gray),
    ]

    # 3x2 grid; the occlusion maps are the only panels drawn with a gray colormap.
    for position, (caption, pixels, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(3, 2, position)
        plt.title(caption)
        plt.imshow(pixels, **imshow_kwargs)
        plt.axis('off')
    plt.show()

    if not isSave:
        return

    if not os.path.isdir('./result'):
        os.makedirs('./result')

    # Written in this order; the occlusion maps are saved without squeezing.
    outputs = [
        ('./result/{}_{}_imageL.jpg', image[:, :, :3]),
        ('./result/{}_{}_imageM.jpg', image[:, :, 3:6]),
        ('./result/{}_{}_flowbw.jpg', flow[1].numpy()),
        ('./result/{}_{}_occlusionbw.jpg', occlusion[1].numpy()),
        ('./result/{}_{}_flowfw.jpg', flow[0].numpy()),
        ('./result/{}_{}_occlusionfw.jpg', occlusion[0].numpy()),
    ]
    for template, pixels in outputs:
        imageio.imwrite(template.format(epoch, count), pixels)

# Training

In [None]:

def fit(dataset, epochs):
    """Run the training loop and return the running-average loss history.

    At the start of each epoch the optimizer's learning rate is set to
    `initial_learning_rate` halved once for every 5 completed epochs. Every
    batch is passed to `train_step`; progress is printed every 20 steps (and
    the cell output is cleared every 200 steps). A checkpoint is saved through
    `ckpt_manager` after each epoch.

    Args:
        dataset: iterable yielding `(image_batch, cam)` pairs; only the image
            batch is used.
        epochs: number of epochs to train.

    Returns:
        A list with one entry per step: the mean loss over the steps seen so
        far in the current epoch.
    """
    log_line = '[info] Epoch:{} Count:{} Loss:{} Time:{:.2f} sec'
    his_loss = []

    for epoch in range(epochs):
        optimizer.learning_rate = initial_learning_rate * (0.5**(epoch//5))
        total_loss = 0
        start = time.time()

        for count, (image_batch, cam) in enumerate(dataset, start=1):
            loss = train_step(image_batch)
            total_loss += loss.numpy()

            if count % 20 == 0:
                print(log_line.format(epoch, count, total_loss/count, time.time()-start))
                if count % 200 == 0:
                    # Wipe the accumulated log and re-print the latest line.
                    clear_output(wait=True)
                    print(log_line.format(epoch, count, total_loss/count, time.time()-start))
                # The timer measures the interval between progress reports.
                start = time.time()

            his_loss.append(total_loss/count)

        path = ckpt_manager.save(checkpoint_number=epoch)
        print("[info] model saved to %s" % path)

    return his_loss
    
# Define the per-replica step and the distributed training step inside the
# strategy scope.
with mirrored_strategy.scope():
    # with mirrored_strategy.scope():
    def step_fn(inputs):
        """Run one forward/backward pass on a single replica and return its loss.

        The forward pass is executed under a GradientTape; the loss is read from
        `model.total_losses` rather than computed from `predict` here.
        NOTE(review): `total_losses` is defined in nets.flow_net (not visible in
        this notebook); presumably it is populated by the forward pass - confirm.
        """
        with tf.GradientTape() as tape:
            predict = model(inputs) #training=True
            #loss = build_flow_loss(inputs, predict)
            loss = model.total_losses

        # NOTE(review): the loss is not divided by the number of replicas here.
        # Confirm how `total_losses` is normalised, since that determines the
        # effective gradient scale when training on several replicas.
        gradients_of_model = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients_of_model, model.trainable_variables))
        return loss

    @tf.function
    def train_step(images):
        """Run `step_fn` on every replica and return the mean of the replica losses."""
        per_example_losses = mirrored_strategy.experimental_run_v2(step_fn, args=(images,))
        mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=None)
    #     mean_loss, flow, occlusion = step_fn(images)
        return mean_loss





In [None]:
# Put the Keras backend into training phase before running `fit`.
tf.keras.backend.set_learning_phase(True)
# Optional: resume from a specific saved checkpoint.
#checkpoint.restore(checkpoint_prefix+'-8')

In [None]:
# Train for EPOCH epochs on the distributed dataset; `his` receives the
# per-step running-average loss history returned by `fit`.
his = fit(dist_dataset, EPOCH)

## Plot loss result

In [None]:
def show_loss(his):
    """Plot the recorded loss values against their position in the history."""
    steps = np.arange(len(his))
    plt.plot(steps, his)
    plt.show()

In [None]:
# Plot the history returned by `fit`. It is bound to `his` at notebook level;
# `his_loss` exists only as a local variable inside `fit`.
show_loss(his)

In [None]:
# NOTE(review): `small_loss` is not defined in any visible cell of this
# notebook, so this cell fails on a fresh kernel unless it is defined elsewhere.
show_loss(small_loss)

In [None]:
@tf.function
def build_flow_loss(input_images, predict):
    """Compute the multi-scale unsupervised optical-flow training loss.

    The loss is a weighted sum of three terms accumulated over NUM_SCALE
    scales: an occlusion-masked L1 reconstruction error, an occlusion-masked
    SSIM term (only when SSIM_WEIGHT > 0) and a second-order flow smoothness
    term. Terms computed from the third flow of each direction (index 2) are
    accumulated separately as "cross" terms and weighted by
    FLOW_CROSS_GEOMETRY_WEIGHT.

    Args:
        input_images: 4-D tensor whose last axis holds three stacked images;
            channels 3:6 are used as the target image and channels 0:3 and 6:
            as the two source images.
        predict: sequence whose first three entries are used as "fw" flows and
            whose remaining entries are used as "bw" flows; each entry is
            indexed by scale.

    Returns:
        The scalar total loss.

    NOTE(review): this function is not called by the visible training cell,
    which reads `model.total_losses` instead.
    """
    tgt_image = input_images[:, :, :, 3:6]
    src_image_stack = tf.concat([input_images[:, :, :, :3], input_images[:, :, :, 6:]], axis=3)

    pred_fw_flows = predict[:3]
    pred_bw_flows = predict[3:]

    # Loss accumulators, summed over all scales.
    reconstructed_loss = 0
    cross_reconstructed_loss = 0
    flow_smooth_loss = 0
    cross_flow_smooth_loss = 0
    ssim_loss = 0
    cross_ssim_loss = 0



    # Per-scale images and scale-0 occlusion maps. They are filled in below but
    # only feed the commented-out summaries and the unused `occlusion` list.
    curr_tgt_image_all = []
    curr_src_image_stack_all = []
    occlusion_map_0_all = []
    occlusion_map_1_all = []
    occlusion_map_2_all = []
    occlusion_map_3_all = []
    occlusion_map_4_all = []
    occlusion_map_5_all = []

    # Calculate the occlusion maps at each scale, as described in 'Occlusion Aware
    # Unsupervised Learning of Optical Flow' by Yang Wang et al.
    # Indexing is [flow index][scale].
    occu_masks_bw = []
    occu_masks_bw_avg = []
    occu_masks_fw = []
    occu_masks_fw_avg = []

    for i in range(len(pred_bw_flows)):
        temp_occu_masks_bw = []
        temp_occu_masks_bw_avg = []
        temp_occu_masks_fw = []
        temp_occu_masks_fw_avg = []

        for s in range(NUM_SCALE):
            # H and W are not used inside this loop.
            H = int(IMG_HEIGHT / (2**s))
            W = int(IMG_WIDTH  / (2**s))

            mask, mask_avg = occulsion(pred_bw_flows[i][s])
            temp_occu_masks_bw.append(mask)
            temp_occu_masks_bw_avg.append(mask_avg)
            # [src0, tgt, src0_1]

            mask, mask_avg = occulsion(pred_fw_flows[i][s])
            temp_occu_masks_fw.append(mask)
            temp_occu_masks_fw_avg.append(mask_avg)
            # [tgt, src1, src1_1]

        occu_masks_bw.append(temp_occu_masks_bw)
        occu_masks_bw_avg.append(temp_occu_masks_bw_avg)
        occu_masks_fw.append(temp_occu_masks_fw)
        occu_masks_fw_avg.append(temp_occu_masks_fw_avg)

    for s in range(NUM_SCALE):
        # Resize the target and source images to the resolution of scale s.
        H = int(IMG_HEIGHT / (2**s))
        W = int(IMG_WIDTH  / (2**s))
        curr_tgt_image = tf.image.resize(
            tgt_image, [H, W], method=tf.image.ResizeMethod.BILINEAR)
        curr_src_image_stack = tf.image.resize(
            src_image_stack, [H, W], method=tf.image.ResizeMethod.BILINEAR)

        curr_tgt_image_all.append(curr_tgt_image)
        curr_src_image_stack_all.append(curr_src_image_stack)

        # Reconstruction terms: each warped image is compared with its reference
        # by absolute difference, weighted by an occlusion mask and divided by
        # that mask's mean.
        # src0
        curr_proj_image_optical_src0 = transformer_old(curr_tgt_image, pred_fw_flows[0][s], [H, W])
        curr_proj_error_optical_src0 = tf.abs(curr_proj_image_optical_src0 - curr_src_image_stack[:,:,:,0:3])
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src0 * occu_masks_bw[0][s]) / occu_masks_bw_avg[0][s]

        curr_proj_image_optical_src0_1 = transformer_old(curr_src_image_stack[:,:,:,3:6], pred_fw_flows[2][s], [H, W])
        curr_proj_error_optical_src0_1 = tf.abs(curr_proj_image_optical_src0_1 - curr_src_image_stack[:,:,:,0:3])
        cross_reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src0_1 * occu_masks_bw[2][s]) / occu_masks_bw_avg[2][s]

        # tgt
        curr_proj_image_optical_tgt = transformer_old(curr_src_image_stack[:,:,:,3:6], pred_fw_flows[1][s], [H, W])
        curr_proj_error_optical_tgt = tf.abs(curr_proj_image_optical_tgt - curr_tgt_image)
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_tgt * occu_masks_bw[1][s]) / occu_masks_bw_avg[1][s]

        curr_proj_image_optical_tgt_1 = transformer_old(curr_src_image_stack[:,:,:,0:3], pred_bw_flows[0][s], [H, W])
        curr_proj_error_optical_tgt_1 = tf.abs(curr_proj_image_optical_tgt_1 - curr_tgt_image)
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_tgt_1 * occu_masks_fw[0][s]) / occu_masks_fw_avg[0][s]

        # src1
        curr_proj_image_optical_src1 = transformer_old(curr_tgt_image, pred_bw_flows[1][s], [H, W])
        curr_proj_error_optical_src1 = tf.abs(curr_proj_image_optical_src1 - curr_src_image_stack[:,:,:,3:6])
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src1 * occu_masks_fw[1][s]) / occu_masks_fw_avg[1][s]

        curr_proj_image_optical_src1_1 = transformer_old(curr_src_image_stack[:,:,:,0:3], pred_bw_flows[2][s], [H, W])
        curr_proj_error_optical_src1_1 = tf.abs(curr_proj_image_optical_src1_1 - curr_src_image_stack[:,:,:,3:6])
        cross_reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src1_1 * occu_masks_fw[2][s]) / occu_masks_fw_avg[2][s]

        # SSIM terms: same image pairs and masks as the reconstruction terms,
        # with the mask applied to both SSIM inputs.
        if SSIM_WEIGHT > 0:
            # src0
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src0 * occu_masks_bw[0][s],
                     curr_src_image_stack[:,:,:,0:3] * occu_masks_bw[0][s])) / occu_masks_bw_avg[0][s]

            cross_ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src0_1 * occu_masks_bw[2][s],
                     curr_src_image_stack[:,:,:,0:3] * occu_masks_bw[2][s])) / occu_masks_bw_avg[2][s]

            # tgt
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_tgt * occu_masks_bw[1][s],
                     curr_tgt_image * occu_masks_bw[1][s])) / occu_masks_bw_avg[1][s]

            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_tgt_1 * occu_masks_fw[0][s],
                     curr_tgt_image * occu_masks_fw[0][s])) / occu_masks_fw_avg[0][s]

            # src1
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src1 * occu_masks_fw[1][s],
                     curr_src_image_stack[:,:,:,3:6] * occu_masks_fw[1][s])) / occu_masks_fw_avg[1][s]

            cross_ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src1_1 * occu_masks_fw[2][s],
                     curr_src_image_stack[:,:,:,3:6] * occu_masks_fw[2][s])) / occu_masks_fw_avg[2][s]

        # Compute second-order derivatives for the flow smoothness loss.
        # Each flow is divided by 20.0 before being passed to cal_grad2_error
        # together with an image at the same scale.
        flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[0][s] / 20.0, curr_src_image_stack[:,:,:,0:3], 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[1][s] / 20.0, curr_tgt_image, 1.0)

        cross_flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[2][s] / 20.0, curr_src_image_stack[:,:,:,0:3], 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[0][s] / 20.0, curr_tgt_image, 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[1][s] / 20.0, curr_src_image_stack[:,:,:,3:6], 1.0)

        cross_flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[2][s] / 20.0, curr_src_image_stack[:,:,:,3:6], 1.0)

        # [TODO] Add first-order derivatives for flow smoothness loss
        # [TODO] use robust Charbonnier penalty?

        # Keep the full-resolution (scale 0) occlusion masks for visualisation.
        if s == 0:
            occlusion_map_0_all = occu_masks_bw[0][s]
            occlusion_map_1_all = occu_masks_bw[1][s]
            occlusion_map_2_all = occu_masks_bw[2][s]
            occlusion_map_3_all = occu_masks_fw[0][s]
            occlusion_map_4_all = occu_masks_fw[1][s]
            occlusion_map_5_all = occu_masks_fw[2][s]

    # Combine the direct and cross terms, then weight the three loss families.
    recon_losses = reconstructed_loss + FLOW_CROSS_GEOMETRY_WEIGHT * cross_reconstructed_loss
    ssim_losses = ssim_loss + FLOW_CROSS_GEOMETRY_WEIGHT * cross_ssim_loss
    smooth_losses = flow_smooth_loss + FLOW_CROSS_GEOMETRY_WEIGHT*cross_flow_smooth_loss

    losses = FLOW_RECONSTRUCTION_WEIGHT * recon_losses + \
             SSIM_WEIGHT * ssim_losses + \
             FLOW_SMOOTH_WEIGHT * smooth_losses



    # Disabled TF1-style summary code, kept for reference.
#     summaries = []
#     summaries.append(tf.compat.v1.summary.scalar("total_loss", losses))
#     summaries.append(tf.compat.v1.summary.scalar("reconstructed_loss", reconstructed_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_reconstructed_loss", cross_reconstructed_loss))
#     summaries.append(tf.compat.v1.summary.scalar("ssim_loss", ssim_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_ssim_loss", cross_ssim_loss))
#     summaries.append(tf.compat.v1.summary.scalar("flow_smooth_loss", flow_smooth_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_flow_smooth_loss", cross_flow_smooth_loss))

#     s = 0
#     tf.compat.v1.summary.image('scale%d_target_image' % s, tf.image.convert_image_dtype(curr_tgt_image_all[0], dtype=tf.uint8))

#     for i in range(NUM_SOURCE):
#         tf.compat.v1.summary.image('scale%d_src_image_%d' % (s, i), \
#                         tf.image.convert_image_dtype(curr_src_image_stack_all[0][:, :, :, i*3:(i+1)*3], dtype=tf.uint8))

#     tf.compat.v1.summary.image('scale%d_flow_src02tgt' % s, fl.flow_to_color(self.pred_fw_flows[0][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_tgt2src1' % s, fl.flow_to_color(self.pred_fw_flows[1][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src02src1' % s, fl.flow_to_color(self.pred_fw_flows[2][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_tgt2src0' % s, fl.flow_to_color(self.pred_bw_flows[0][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src12tgt' % s, fl.flow_to_color(self.pred_bw_flows[1][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src12src0' % s, fl.flow_to_color(self.pred_bw_flows[2][s], max_flow=256))

#     tf.compat.v1.summary.image('scale_flyout_mask_src0', occlusion_map_0_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_tgt', occlusion_map_1_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src0_1', occlusion_map_2_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_tgt1', occlusion_map_3_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src1', occlusion_map_4_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src1_1', occlusion_map_5_all)

#     self.summ_op = tf.compat.v1.summary.merge(summaries)
    # NOTE(review): `flow` and `occlusion` are built here but never returned or
    # used; only `losses` leaves this function.
    s = 0
    flow = [fl.flow_to_color(pred_fw_flows[0][s], max_flow=256), fl.flow_to_color(pred_bw_flows[0][s], max_flow=256)]
    occlusion = [occlusion_map_3_all, occlusion_map_0_all]
    return losses


def occulsion(pred_flow):
    """
    Here, we compute the soft occlusion maps proposed in https://arxiv.org/pdf/1711.05890.pdf

    An all-ones single-channel image is passed through `transformerFwd` with
    the given flow, and the result is clipped to [0, 1].

    The batch dimension is taken from `pred_flow` itself rather than from the
    global PER_BATCH, so the function also works for batches of a different
    size (for example a final partial batch).

    Args:
        pred_flow: 4-D flow tensor (batch, height, width, channels) with
            statically known height and width. The original note describes it
            as the estimated forward optical flow; build_flow_loss calls this
            function for both its "fw" and "bw" flows.

    Returns:
        A tuple `(occu_mask, occu_mask_avg)`: the clipped mask reshaped to
        (batch, height, width, 1) and its mean over all elements.
    """
    _, h, w, _ = pred_flow.get_shape().as_list()

    # One-channel image of ones with the same batch and spatial size as the flow.
    ones_image = tf.ones_like(pred_flow[:, :, :, :1], dtype=tf.float32)

    occu_mask = tf.clip_by_value(
        transformerFwd(ones_image, pred_flow, [h, w], backprop=True),
        clip_value_min=0.0,
        clip_value_max=1.0)
    # -1 lets the batch dimension follow the input instead of PER_BATCH.
    occu_mask = tf.reshape(occu_mask, [-1, h, w, 1])
    occu_mask_avg = tf.reduce_mean(input_tensor=occu_mask)

    return occu_mask, occu_mask_avg