In [7]:
import tensorflow as tf
import sys
sys.path.append('../kitti_eval/flow_tool/')
sys.path.append('..')
import flowlib as fl

from nets.depth_net import Depth_net
from nets.pose_net import Pose_net
from nets.flow_net import Flow_net

from utils.optical_flow_warp_fwd import transformerFwd
from utils.optical_flow_warp_old import transformer_old
from utils.loss_utils import SSIM, cal_grad2_error_mask, charbonnier_loss, cal_grad2_error, compute_edge_aware_smooth_loss, ternary_loss, depth_smoothness
from utils.utils import average_gradients, normalize_depth_for_display, preprocess_image, deprocess_image, inverse_warp, inverse_warp_new

from data_loader.data_loader import DataLoader
import matplotlib.pyplot as plt 
import os
import imageio
import time
from IPython.display import clear_output

# Parameter Settings

In [8]:
# Expose at most the first two physical GPUs to TensorFlow; any further GPUs stay hidden.
# This must run before the TF runtime initialises its devices.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_visible_devices(gpus[:2], 'GPU')

In [9]:
# Training mode: "train_flow" (optical-flow network) or "train_dp"
# (depth + pose networks, which rely on a pretrained flow network).
mode = "train_dp"

# Whether to load pretrained weights. Depth/pose training needs the pretrained
# flow network, so flow weights are always loaded in 'train_dp' mode.
LOAD_FLOW_WEIGHT = (mode == 'train_dp')
LOAD_DP_WEIGHT = False

# Dataset path (3-frame KITTI sequences at 256 x 832, judging by the folder name)
kitti_depth_folder = '../../../../datasets/kitti_3frames_256_832'


# Training schedule and model input resolution
EPOCH = 50
BATCH_SIZE = 8
IMG_HEIGHT = 256
IMG_WIDTH = 832
NUM_SCALE = 4   # number of pyramid scales the losses are evaluated at
NUM_SOURCE = 2  # source frames (src0, src1) around each target frame

# Loss hyperparameters -- optical flow
SSIM_WEIGHT = 0.85
FLOW_RECONSTRUCTION_WEIGHT = 1.0
FLOW_SMOOTH_WEIGHT = 10.0
FLOW_CROSS_GEOMETRY_WEIGHT = 0.3
FLOW_DIFF_THRESHOLD = 4.0
FLOW_CONSIST_WEIGHT = 0.01

# Loss hyperparameters -- depth / pose
DP_RECONSTRUCTION_WEIGHT = 50.0
DP_SMOOTH_WEIGHT = 10.0  # originally 1.0
DP_CROSS_GEOMETRY_WEIGHT = 0.8

# Feature flags
compute_minimum_loss = False
is_depth_upsampling = False  # NOTE: broken -- enabling it makes the depth smoothness loss 0
joint_encoder = False
equal_weighting = False      # equal per-scale weight for the depth smoothness loss
depth_normalization = False  # mean-normalise disparity for the depth smoothness loss only
scale_normalize = True       # mean-normalise disparity before computing every loss

# Build DataLoader

In [None]:
# Training input pipeline: batch, then prefetch so input loading overlaps training steps.
# Batching keeps the default drop_remainder=False, so the last batch may be smaller.
dataloader = DataLoader(mode=mode, dataset_dir=kitti_depth_folder,
                        img_height=IMG_HEIGHT, img_width=IMG_WIDTH, split='train')
dataset = (dataloader.build_dataloader()
           .batch(BATCH_SIZE)
           .prefetch(buffer_size=tf.data.experimental.AUTOTUNE))

# Build Model

In [None]:
# Optical-flow network: input is the [src0 | tgt | src1] RGB stack (3 frames x 3 channels).
flow_model = Flow_net(
    input_shape=[IMG_HEIGHT, IMG_WIDTH, 9],
)
flow_model.summary()

In [None]:
# Depth backbone; options: 'shufflenetv2', 'mobilenetv2', 'mnasnet', 'mobilenetv3'.
net_name = 'shufflenetv2'
depth_model = Depth_net(
    net_name,
    input_shape=[IMG_HEIGHT, IMG_WIDTH, 3],  # one RGB frame at a time
    training=True,
)
depth_model.summary()

In [None]:
# Pose network: consumes the same 9-channel [src0 | tgt | src1] stack as the flow network.
pose_model = Pose_net(
    input_shape=[IMG_HEIGHT, IMG_WIDTH, 9],
)
pose_model.summary()

# Define the Optimizers and Checkpoint Saver

In [None]:
# Two independent Adam optimizers: one for the flow network and one shared by the
# depth and pose networks.
initial_learning_rate = 1e-4
flow_optimizer, dp_optimizer = [
    tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)
    for _ in ('flow', 'depth_pose')
]

In [None]:
# Checkpoints are grouped by training mode and by the batch/resolution configuration,
# e.g. ../training_checkpoints/train_dp/8_256_832/ckpt-N
checkpoint_attr = '{}_{}_{}'.format(BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH)
checkpoint_dir_flow = os.path.join('../training_checkpoints', 'train_flow', checkpoint_attr)
checkpoint_dir_dp = os.path.join('../training_checkpoints', 'train_dp', checkpoint_attr)
os.makedirs(checkpoint_dir_flow, exist_ok=True)
os.makedirs(checkpoint_dir_dp, exist_ok=True)

checkpoint_prefix_flow = os.path.join(checkpoint_dir_flow, "ckpt")
checkpoint_prefix_dp = os.path.join(checkpoint_dir_dp, "ckpt")

checkpoint_flow = tf.train.Checkpoint(flowModel=flow_model,
                                      flowOptimizer=flow_optimizer)
checkpoint_dp = tf.train.Checkpoint(depthModel=depth_model,
                                    poseModel=pose_model,
                                    dpOptimizer=dp_optimizer)


def _restore_latest(ckpt, ckpt_dir):
    """Restore `ckpt` from the newest checkpoint in `ckpt_dir` and return the restore status.

    Raises:
        FileNotFoundError: if `ckpt_dir` holds no checkpoint. Checkpoint.restore(None)
            would otherwise silently leave the weights at their random initialisation.
    """
    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is None:
        raise FileNotFoundError('No checkpoint found in {}'.format(ckpt_dir))
    print('[info] Restoring weights from {}'.format(latest))
    return ckpt.restore(latest)


# Each flag restores its own checkpoint object from its own directory.
if LOAD_FLOW_WEIGHT:
    # In 'train_dp' mode the flow optimizer never creates its slot variables, so part of
    # the flow checkpoint stays unmatched; expect_partial() silences those warnings.
    _restore_latest(checkpoint_flow, checkpoint_dir_flow).expect_partial()
if LOAD_DP_WEIGHT:
    _restore_latest(checkpoint_dp, checkpoint_dir_dp)

# Training

In [None]:
def fit(dataset, epochs):
    """Train the depth and pose networks for `epochs` epochs.

    Each step calls `train_step`. Every 20 steps the running mean loss is printed, and
    every 200 steps the cell output is cleared first. The depth/pose checkpoint is saved
    at the end of every epoch.

    Args:
        dataset: iterable of ``(image_batch, cam)`` pairs, e.g. the batched dataset
            built in the DataLoader cell.
        epochs: number of passes over ``dataset``.

    Returns:
        list[float]: mean training loss of each epoch.
    """
    his_loss = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_steps = 0
        start = time.time()
        for image_batch, cam in dataset:
            loss = train_step(image_batch, cam)
            total_loss += float(loss.numpy())
            num_steps += 1

            if num_steps % 20 == 0:
                if num_steps % 200 == 0:
                    # Keep the notebook output short on long epochs.
                    clear_output(wait=True)
                print('[info] Epoch:{} Count:{} Loss:{} Time:{:.2f} sec'.format(
                    epoch, num_steps, total_loss / num_steps, time.time() - start))
                # NOTE(review): a show_image(image_batch[0], flow[0], occlusion[0], epoch)
                # call used to sit here, but `flow` and `occlusion` are never defined in
                # this notebook (NameError at step 20). Re-add a depth visualisation if needed.
                start = time.time()

        # Mean over the actual number of steps taken this epoch.
        his_loss.append(total_loss / max(num_steps, 1))
        # train_step only updates the depth and pose networks, so save their checkpoint.
        checkpoint_dp.save(file_prefix=checkpoint_prefix_dp)

    return his_loss

@tf.function
def train_step(images, cam):
    """Run one optimisation step of the depth and pose networks.

    The flow network is evaluated outside the gradient tape. It therefore only feeds
    the (fixed) occlusion masks used in the loss and receives no updates.

    Args:
        images: batch of three RGB frames stacked on the channel axis,
            ``[src0 | tgt | src1]`` (channels 0:3, 3:6, 6:9).
        cam: per-scale camera projection matrices, passed through to ``build_dp_loss``.

    Returns:
        Scalar training loss.
    """
    # NOTE(review): flow_model is called without `training=`; if it contains batch-norm,
    # confirm its statistics are not meant to stay frozen here.
    predict_flow = flow_model(images)
    with tf.GradientTape() as tape:
        predict_disp_src0 = depth_model(images[:, :, :, :3])  # training=True
        predict_disp_tgt = depth_model(images[:, :, :, 3:6])  # training=True
        predict_disp_src1 = depth_model(images[:, :, :, 6:])  # training=True
        predict_pose = pose_model(images)
        loss = build_dp_loss(images,
                             [predict_disp_src0, predict_disp_tgt, predict_disp_src1],
                             predict_pose,
                             predict_flow,
                             cam)

    # A non-persistent tape allows only one gradient() call. Applying the optimizer twice
    # per step would also advance Adam's step counter twice. So differentiate all
    # depth + pose variables at once and apply a single update.
    trainable_vars = depth_model.trainable_variables + pose_model.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    dp_optimizer.apply_gradients(zip(gradients, trainable_vars))

    return loss

def _spatial_normalize(disp):
    """Divide each sample's disparity map by its own spatial mean.

    This is mean normalisation as proposed in https://arxiv.org/abs/1712.00175. It can
    improve depth estimation (not used in the original paper).

    Args:
        disp: rank-4 disparity tensor (assumed ``[B, H, W, 1]`` -- TODO confirm channels).

    Returns:
        Tensor of the same shape as ``disp``.
    """
    disp_mean = tf.reduce_mean(input_tensor=disp, axis=[1, 2, 3], keepdims=True)
    return disp / disp_mean


def build_dp_loss(images, predict_disp, predict_pose, predict_flow, cam):
    """Photometric + smoothness loss for the depth and pose networks.

    At each of NUM_SCALE scales, every frame of the (src0, tgt, src1) triplet is
    reconstructed from the other two by rigid warping with predicted depth and pose.
    The L1 and SSIM errors are masked by soft occlusion maps derived from the flow
    network and re-normalised by each mask's mean. An edge-aware disparity smoothness
    term is added on top.

    Args:
        images: ``[B, H, W, 9]`` tensor holding src0, tgt, src1 on the channel axis.
        predict_disp: ``[src0, tgt, src1]``; each entry is a per-scale list of disparity
            maps, where index ``s`` matches resolution IMG_HEIGHT/2**s x IMG_WIDTH/2**s.
        predict_pose: pose-net output; ``predict_pose[:, 0, 6*k:6*k+6]`` is the k-th
            6-DoF pose, in the order listed in ``warp_specs`` below.
        predict_flow: flow-net output; the first 3 entries are forward flows and the
            last 3 are backward flows, each a per-scale list.
        cam: projection matrices indexed per scale as ``cam[:, s, :, :]``.

    Returns:
        Scalar loss weighted by DP_RECONSTRUCTION_WEIGHT, SSIM_WEIGHT,
        DP_SMOOTH_WEIGHT and DP_CROSS_GEOMETRY_WEIGHT from the config cell.

    Raises:
        NotImplementedError: if ``compute_minimum_loss`` is enabled.
    """
    if compute_minimum_loss:
        # No minimum-reprojection variant exists. Without this guard the flag silently
        # dropped every photometric term and trained on the smoothness loss alone.
        raise NotImplementedError('compute_minimum_loss=True is not supported by build_dp_loss')

    tgt_image = images[:, :, :, 3:6]
    src_image_stack = tf.concat([images[:, :, :, :3], images[:, :, :, 6:]], axis=3)

    # Per-frame disparity and depth pyramids. New lists are built, so the caller's
    # `predict_disp` is left untouched.
    disp = {}
    depth = {}
    for name, frame_disps in zip(('src0', 'tgt', 'src1'), predict_disp):
        if scale_normalize:
            frame_disps = [_spatial_normalize(d) for d in frame_disps]
        disp[name] = frame_disps
        depth[name] = [1. / d for d in frame_disps]

    pred_fw_flows = predict_flow[:3]
    pred_bw_flows = predict_flow[3:]

    proj_cam2pix = cam
    proj_pix2cam = tf.linalg.inv(cam)

    # Soft occlusion maps (Wang et al., "Occlusion Aware Unsupervised Learning of
    # Optical Flow"), one per flow and per scale. The *_avg lists hold each mask's mean.
    # Frame each mask is applied to:
    #   occu_masks_bw[0..2] -> [src0, tgt, src0 (w.r.t. src1)]
    #   occu_masks_fw[0..2] -> [tgt, src1, src1 (w.r.t. src0)]
    occu_masks_bw, occu_masks_bw_avg = [], []
    occu_masks_fw, occu_masks_fw_avg = [], []
    for bw_flow, fw_flow in zip(pred_bw_flows, pred_fw_flows):
        bw_masks, bw_avgs, fw_masks, fw_avgs = [], [], [], []
        for s in range(NUM_SCALE):
            H = IMG_HEIGHT // (2 ** s)
            W = IMG_WIDTH // (2 ** s)
            mask, mask_avg = occulsion(bw_flow[s], H, W)
            bw_masks.append(mask)
            bw_avgs.append(mask_avg)
            mask, mask_avg = occulsion(fw_flow[s], H, W)
            fw_masks.append(mask)
            fw_avgs.append(mask_avg)
        occu_masks_bw.append(bw_masks)
        occu_masks_bw_avg.append(bw_avgs)
        occu_masks_fw.append(fw_masks)
        occu_masks_fw_avg.append(fw_avgs)

    # The six rigid warps evaluated at every scale. Each entry is:
    #   (frame whose depth drives the warp and which is reconstructed,
    #    frame whose image is warped onto it,
    #    first channel of its 6-DoF pose in predict_pose[:, 0, :],
    #    occlusion masks, their means, mask index,
    #    True for the cross src0 <-> src1 terms)
    warp_specs = [
        ('src0', 'tgt',  0,  occu_masks_bw, occu_masks_bw_avg, 0, False),  # src0 -> tgt  (fw0)
        ('src0', 'src1', 6,  occu_masks_bw, occu_masks_bw_avg, 2, True),   # src0 -> src1 (fw2)
        ('tgt',  'src1', 12, occu_masks_bw, occu_masks_bw_avg, 1, False),  # tgt  -> src1 (fw1)
        ('tgt',  'src0', 18, occu_masks_fw, occu_masks_fw_avg, 0, False),  # tgt  -> src0 (bw0)
        ('src1', 'src0', 24, occu_masks_fw, occu_masks_fw_avg, 2, True),   # src1 -> src0 (bw2)
        ('src1', 'tgt',  30, occu_masks_fw, occu_masks_fw_avg, 1, False),  # src1 -> tgt  (bw1)
    ]

    smooth_loss = 0
    reconstructed_loss = 0
    cross_reconstructed_loss = 0
    ssim_loss = 0
    cross_ssim_loss = 0

    for s in range(NUM_SCALE):
        H = IMG_HEIGHT // (2 ** s)
        W = IMG_WIDTH // (2 ** s)
        if s == 0:  # Just as a precaution. TF often has interpolation bugs.
            curr_tgt_image = tgt_image
            curr_src_image_stack = src_image_stack
        else:
            curr_tgt_image = tf.image.resize(
                tgt_image, [H, W], method=tf.image.ResizeMethod.BILINEAR)
            curr_src_image_stack = tf.image.resize(
                src_image_stack, [H, W], method=tf.image.ResizeMethod.BILINEAR)
        scaled_images = {
            'src0': curr_src_image_stack[:, :, :, 0:3],
            'tgt': curr_tgt_image,
            'src1': curr_src_image_stack[:, :, :, 3:6],
        }
        cam2pix = proj_cam2pix[:, s, :, :]
        pix2cam = proj_pix2cam[:, s, :, :]

        for ref, other, pose_start, masks, mask_avgs, k, is_cross in warp_specs:
            depth_flow, _ = inverse_warp(
                depth[ref][s],
                predict_pose[:, 0, pose_start:pose_start + 6],
                cam2pix,
                pix2cam)
            proj_image = transformer_old(scaled_images[other], depth_flow, [H, W])
            ref_image = scaled_images[ref]
            mask = masks[k][s]
            mask_avg = mask_avgs[k][s]

            # Occlusion-masked L1 photometric error, re-normalised by the mask's mean.
            l1_term = tf.reduce_mean(input_tensor=tf.abs(proj_image - ref_image) * mask) / mask_avg
            if is_cross:
                cross_reconstructed_loss += l1_term
            else:
                reconstructed_loss += l1_term

            if SSIM_WEIGHT > 0:
                # The mask and its normaliser always come from the same map (the src1
                # cross SSIM term previously divided an fw mask by a bw mean).
                ssim_term = tf.reduce_mean(
                    input_tensor=SSIM(proj_image * mask, ref_image * mask)) / mask_avg
                if is_cross:
                    cross_ssim_loss += ssim_term
                else:
                    ssim_loss += ssim_term

        if DP_SMOOTH_WEIGHT > 0:
            scaling_f = 1.0 if equal_weighting else 1.0 / (2 ** s)
            # Edge-aware first-order smoothness of each frame's disparity.
            for name in ('tgt', 'src0', 'src1'):
                disp_input = disp[name][s]
                if depth_normalization:
                    # Divide by the per-sample mean disparity.
                    disp_input = disp_input / tf.reduce_mean(
                        input_tensor=disp_input, axis=[1, 2, 3], keepdims=True)
                smooth_loss += scaling_f * depth_smoothness(disp_input, scaled_images[name])

    # Loss weights come from the config cell.
    # NOTE(review): an earlier draft of this cell used DP_RECONSTRUCTION_WEIGHT=0.15 and
    # DP_SMOOTH_WEIGHT=0.1, while the config cell sets 50.0 / 10.0 -- confirm which is intended.
    reconstruct_losses = reconstructed_loss + DP_CROSS_GEOMETRY_WEIGHT * cross_reconstructed_loss
    ssim_losses = ssim_loss + DP_CROSS_GEOMETRY_WEIGHT * cross_ssim_loss
    losses = (DP_RECONSTRUCTION_WEIGHT * reconstruct_losses
              + SSIM_WEIGHT * ssim_losses
              + DP_SMOOTH_WEIGHT * smooth_loss)

    return losses

def occulsion(pred_flow, H, W):
    """Soft occlusion map for one flow field (https://arxiv.org/pdf/1711.05890.pdf).

    A map of ones is forward-warped along ``pred_flow`` with ``transformerFwd`` and
    clipped to [0, 1]. Pixels that receive little or no warped mass end up near 0
    (treated as occluded).

    Args:
        pred_flow: an estimated optical flow at resolution H x W (assumed
            ``[B, H, W, 2]`` -- TODO confirm).
        H: height of the flow / output mask.
        W: width of the flow / output mask.

    Returns:
        (occu_mask, occu_mask_avg): the ``[B, H, W, 1]`` mask and its scalar mean over
        the whole batch.
    """
    # Take the batch size from the flow itself rather than the fixed BATCH_SIZE, so the
    # smaller final batch (the dataset is batched without drop_remainder) still works.
    # NOTE(review): assumes transformerFwd accepts a dynamic batch dimension.
    batch_size = tf.shape(pred_flow)[0]
    ones = tf.ones(shape=tf.stack([batch_size, H, W, 1]), dtype='float32')
    warped_ones = transformerFwd(ones, pred_flow, [H, W], backprop=True)
    occu_mask = tf.clip_by_value(warped_ones, clip_value_min=0.0, clip_value_max=1.0)
    occu_mask = tf.reshape(occu_mask, [-1, H, W, 1])
    occu_mask_avg = tf.reduce_mean(input_tensor=occu_mask)

    return occu_mask, occu_mask_avg
        

In [None]:
# Make Keras layers default to training behaviour when called without an explicit
# `training` argument. NOTE(review): set_learning_phase is deprecated in recent TF 2.x
# releases; passing `training=` to each model call is the supported replacement.
tf.keras.backend.set_learning_phase(value=True)