In [None]:
import tensorflow as tf
import sys
sys.path.append('../kitti_eval/flow_tool/')
sys.path.append('..')
import flowlib as fl


from nets.flow_net import Flow_net

from utils.optical_flow_warp_fwd import TransformerFwd
from utils.optical_flow_warp_old import transformer_old
from utils.loss_utils import SSIM, cal_grad2_error_mask, charbonnier_loss, cal_grad2_error, compute_edge_aware_smooth_loss, ternary_loss, depth_smoothness
from utils.utils import average_gradients, normalize_depth_for_display, preprocess_image, deprocess_image, inverse_warp, inverse_warp_new
from data_loader.data_loader import DataLoader
import matplotlib.pyplot as plt 
import os
import cv2
import numpy as np
import imageio

# Parameter Setting 

In [None]:
# Expose only the first two physical GPUs (or fewer, if fewer exist) to TensorFlow.
# TensorFlow only accepts this before the GPUs have been initialized, so run this
# cell right after the imports.
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(gpus[:2], 'GPU')

In [None]:
# ---- Pretrained weights ---------------------------------------------------
# Whether to load pretrained weights into the model before training.
LOAD_DEPTH_WEIGHT = False

# Path of the pretrained weights to load (only used when LOAD_DEPTH_WEIGHT is True).
load_depthWeight_path = '../../models/mnv2_segment_depth_multiloss/origin/PSMNet_DataAug_KITTI_depthCorrection_epoch-49_loss-3.4839.h5'
# Alternative checkpoints:
#load_depthWeight_path = '../../models/mnv2_segment_depth_multiloss/log_depth/log_depth4~80_epoch-21_loss-3.8513.h5' # log depth
#load_depthWeight_path = '../../models/mnv2_segment_depth_multiloss/224x224/224x224_epoch-39_loss-4.3492.h5'

# Directory for saving trained weights.
# NOTE(review): only referenced by commented-out code in this notebook.
save_depthWeight_path = '../../models/mnv2_segment_depth_multiloss/log_depth/'

# ---- Data -----------------------------------------------------------------
# Dataset root directory.
kitti_depth_folder = '../../../../datasets/kitti_3frames_256_832'

# Mode passed to DataLoader; the alternative used before was "train_dp".
mode = "train_flow"

# ---- Training schedule and model input resolution -------------------------
EPOCH = 100          # NOTE(review): not referenced by active code in this notebook
BATCH_SIZE = 8
IMG_HEIGHT, IMG_WIDTH = 256, 832
NUM_SCALE = 4        # number of pyramid scales evaluated by the flow loss
NUM_SOURCE = 2       # number of source frames around the target frame

# ---- Loss hyperparameters -------------------------------------------------
SSIM_WEIGHT = 0.85                  # SSIM share of the photometric loss (L1 gets 1 - SSIM_WEIGHT)
FLOW_RECONSTRUCTION_WEIGHT = 1.0
FLOW_SMOOTH_WEIGHT = 10.0
FLOW_CROSS_GEOMETRY_WEIGHT = 0.3    # scales the src0 <-> src1 "cross" terms
flow_diff_threshold = 4.0           # NOTE(review): not referenced by active code in this notebook
flow_consist_weight = 0.01          # NOTE(review): not referenced by active code in this notebook

# ---- Optimizer used by the custom training loop ---------------------------
initial_learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

# Build Model 

In [None]:
# Flow_net consumes three RGB frames stacked on the channel axis (3 x 3 = 9 channels).
model = Flow_net(input_shape=[IMG_HEIGHT, IMG_WIDTH, 9])

# Optionally warm-start from the pretrained weights configured in the settings cell.
if LOAD_DEPTH_WEIGHT:
    model.load_weights(load_depthWeight_path)

model.summary()

In [None]:
# Render the network graph with layer shapes. Keras' plot_model requires pydot and
# graphviz, and also writes the picture to its default `to_file` path.
tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=64)

# Build Dataloader

In [None]:
# Build the tf.data input pipeline for the selected training mode.
dataloader = DataLoader(mode=mode,
                        dataset_dir=kitti_depth_folder,
                        img_height=IMG_HEIGHT,
                        img_width=IMG_WIDTH)

# drop_remainder=True keeps every batch at exactly BATCH_SIZE. occulsion() builds
# its tensors with the fixed global BATCH_SIZE and reshapes to it, so a smaller
# final batch would fail there. A static batch dimension also stops the
# tf.function training step from being retraced for the last, shorter batch.
dataset = (dataloader.build_dataloader()
           .batch(BATCH_SIZE, drop_remainder=True)
           .prefetch(buffer_size=tf.data.experimental.AUTOTUNE))

# Training

In [None]:
# Legacy Keras callbacks (disabled), kept for reference:
# a per-epoch weight checkpoint with the loss in the file name, and a CSV logger.
# model_checkpoint = ModelCheckpoint(filepath=save_depthWeight_path+'log_depth4~80_epoch-{epoch:02d}_loss-{loss:.4f}.h5',
#                                    monitor='loss', verbose=1, save_best_only=False,
#                                    save_weights_only=True, mode='auto', period=1)
# csv_logger = CSVLogger(filename=save_depthWeight_path+'training_log.csv',
#                        separator=',', append=True)

# Checkpoint location used by the active ModelCheckpoint callback: one file name
# per epoch, "<checkpoint_dir>/ckpt_<epoch>".
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")


# Per-epoch learning-rate schedule (used with tf.keras.callbacks.LearningRateScheduler).
def scheduler(epoch):
    """Step schedule: 1e-4 for epochs 0-9, then 5e-5 from epoch 10 onward.

    Args:
        epoch: epoch index as passed by the LearningRateScheduler callback.

    Returns:
        The learning rate to use for that epoch.
    """
    return 0.0001 if epoch < 10 else 0.00005


# Stop training when the loss becomes NaN (disabled).
# terminate_on_nan = TerminateOnNaN()

# # Legacy callback list for model.fit (disabled).
# callbacks = [model_checkpoint,
#              csv_logger,
#              learning_rate_scheduler,
#              terminate_on_nan]

# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    """Print the optimizer's current learning rate after every epoch.

    Reads the model the callback is attached to (``self.model``, set by Keras when
    the callback is passed to ``fit``) rather than the notebook-global ``model``.
    It also uses ``optimizer.learning_rate`` instead of the legacy ``lr`` alias,
    which newer Keras releases deprecate or remove.
    """

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index; report it 1-based.
        current_lr = self.model.optimizer.learning_rate
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                          current_lr.numpy()))
    
callbacks = [
    # Optional: tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    # Save weights only, to a per-epoch file name built from checkpoint_prefix.
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    # Apply the step learning-rate schedule from scheduler().
    tf.keras.callbacks.LearningRateScheduler(scheduler),
    # Print the learning rate after each epoch.
    PrintLR(),
]


In [None]:
# Training start epoch [TODO] (passed to fit() as initial_epoch; use it to resume).
initial_epoch=0

# Training end epoch [TODO] (passed to fit() as epochs; training stops once this
# epoch index is reached).
final_epoch = 10

# Training iterations per epoch.
# NOTE(review): not passed to fit(); the number of batches per epoch is set by
# dataset.take(100) below.
steps_per_epoch = 1000

# Per-output loss weights: [depth_pred_2x, depth_pred_4x, depth_pred_8x, depth_pred_16x,seg_pred_2x,seg_pred_4x,seg_pred_8x,seg_pred_16x]
# NOTE(review): left over from a depth/segmentation setup and not used by
# compile() below.
loss_weights=[0.25,0.25,0.25,0.25,0.75,0.75,0.75,0.75]

# Build the loss function (disabled):
# depth_loss = custom_depth_loss(depth_weight=1.0, disparity_weight=0.0)


# NOTE(review): compile() specifies no loss. fit() only works if Flow_net defines
# its own loss / train_step. Confirm this, or use the custom train() loop below.
model.compile(optimizer=tf.keras.optimizers.Adam())
# The epoch settings above are wired into fit(), so editing them takes effect.
# Previously epochs=10 was hard-coded and initial_epoch was ignored.
model.fit(dataset.take(100),
          initial_epoch=initial_epoch,
          epochs=final_epoch,
          callbacks=callbacks)

In [None]:
def _to_uint8_image(img):
    """Convert an image array to uint8 for imageio.

    uint8 input is returned unchanged. Any other dtype is assumed to be a float
    image in [0, 1] (the range plt.imshow expects for float images); it is clipped
    and scaled to [0, 255].
    """
    if img.dtype == np.uint8:
        return img
    return (np.clip(img, 0.0, 1.0) * 255.0).round().astype(np.uint8)


def _visualize_epoch(epoch, image_batch, flow, occlusion):
    """Show the first sample's frames, flow and occlusion mask, and save the flow
    and occlusion images as flow{epoch}.jpg / occlusion{epoch}.jpg."""
    # Input frames: channels 0:3 and 3:6 of the stacked batch.
    plt.figure(figsize=(20, 20))
    plt.subplot(2, 2, 1)
    plt.imshow(image_batch[0][:, :, :3])
    plt.subplot(2, 2, 2)
    plt.imshow(image_batch[0][:, :, 3:6])
    plt.show()

    flow_img = flow[0].numpy()
    # Drop the trailing singleton channel so the mask is a 2-D grayscale image.
    occlusion_img = occlusion.numpy()[0].squeeze()

    plt.figure(figsize=(20, 20))
    plt.subplot(2, 2, 1)
    plt.imshow(flow_img)
    plt.subplot(2, 2, 2)
    plt.imshow(occlusion_img, cmap='gray')
    plt.show()

    # Convert explicitly instead of relying on imageio's implicit (lossy,
    # version-dependent) float -> uint8 conversion.
    imageio.imwrite('flow{}.jpg'.format(epoch), _to_uint8_image(flow_img))
    imageio.imwrite('occlusion{}.jpg'.format(epoch), _to_uint8_image(occlusion_img))


def train(dataset, epochs, start_epoch=2000, batches_per_epoch=1, vis_every=10):
    """Run the custom training loop and return the per-epoch mean loss.

    Args:
        dataset: tf.data.Dataset yielding ``(image_batch, cam)`` pairs. Only
            ``image_batch`` is used (fed to ``train_step``).
        epochs: number of epochs to run.
        start_epoch: label of the first epoch, used in log lines and output file
            names. The default of 2000 preserves the numbering this loop used
            before, presumably continuing an earlier run.
        batches_per_epoch: number of batches drawn per epoch via
            ``dataset.take``. The default of 1 preserves the previous behavior;
            -1 iterates over the whole dataset.
        vis_every: plot and save flow/occlusion images whenever
            ``epoch % vis_every == 0``. Use 0 to disable.

    Returns:
        List with the mean loss of each epoch.

    Raises:
        ValueError: if ``dataset`` yields no batches for an epoch.
    """
    his_loss = []
    for epoch in range(start_epoch, start_epoch + epochs):
        total_loss = 0
        count = 0
        for image_batch, _cam in dataset.take(batches_per_epoch):
            loss, flow, occlusion = train_step(image_batch)
            total_loss += loss.numpy()
            count += 1

        if count == 0:
            # Without this guard the mean below divides by zero, and the
            # visualization would reference unbound batch/flow variables.
            raise ValueError('dataset yielded no batches for epoch {}'.format(epoch))

        if vis_every and epoch % vis_every == 0:
            _visualize_epoch(epoch, image_batch, flow, occlusion)

        mean_loss = total_loss / count
        print('[info] Epoch:{} Loss:{}'.format(epoch, mean_loss))
        his_loss.append(mean_loss)

    return his_loss

@tf.function
def train_step(images):
    """Run one optimization step of the global ``model`` with the global ``optimizer``.

    Args:
        images: batch of stacked frames fed to ``model`` and to ``build_flow_loss``.

    Returns:
        ``(loss, flow, occlusion)`` as returned by ``build_flow_loss``.
    """
    with tf.GradientTape() as tape:
        # Request training-mode behavior explicitly rather than relying on the
        # deprecated global learning phase (tf.keras.backend.set_learning_phase).
        output_model = model(images, training=True)
        loss, flow, occlusion = build_flow_loss(images, output_model)

    gradients_of_model = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients_of_model, model.trainable_variables))

    return loss, flow, occlusion

@tf.function
def build_flow_loss(input_images, output_model):
    """Build the unsupervised optical-flow training loss for one 3-frame batch.

    The channels-last input is split into source frame 0 (channels 0:3), the
    target frame (channels 3:6) and source frame 1 (channels 6 onward). At every
    scale s in range(NUM_SCALE), the predicted flows are used with
    transformer_old to reconstruct one frame from another. The per-pixel
    reconstruction error is weighted by soft occlusion masks from occulsion().

    Args:
        input_images: stacked frames; see the channel split above.
        output_model: model outputs. Entries [:3] are used as forward flows and
            entries [3:] as backward flows. Each entry is indexed by scale s,
            presumably as a list of per-scale flow tensors of size
            IMG_HEIGHT / 2**s x IMG_WIDTH / 2**s (TODO confirm against Flow_net).
            Flows at index 0 and 1 feed the main terms. Flows at index 2 relate
            src0 and src1 directly and feed the cross_* terms.

    Returns:
        losses: scalar total loss. It combines the masked L1 photometric error
            (weight 1 - SSIM_WEIGHT) and the SSIM term (weight SSIM_WEIGHT),
            scaled by FLOW_RECONSTRUCTION_WEIGHT, plus the second-order flow
            smoothness term scaled by FLOW_SMOOTH_WEIGHT. Cross terms are
            additionally scaled by FLOW_CROSS_GEOMETRY_WEIGHT.
        flow: color visualization (fl.flow_to_color, max_flow=256) of
            pred_fw_flows[0] at scale 0.
        occlusion: occlusion mask occu_masks_bw[0] at scale 0.
    """
    # Target frame = channels 3:6; src0 (0:3) and src1 (6:) form the source stack,
    # so in src_image_stack src0 is channels 0:3 and src1 is channels 3:6.
    tgt_image = input_images[:, :, :, 3:6]
    src_image_stack = tf.concat([input_images[:, :, :, :3], input_images[:, :, :, 6:]], axis=3)
    
    # First three outputs are forward flows, the rest backward flows.
    pred_fw_flows = output_model[:3]
    pred_bw_flows = output_model[3:]
    
    # Loss accumulators, summed over all scales.
    reconstructed_loss = 0
    cross_reconstructed_loss = 0
    flow_smooth_loss = 0
    cross_flow_smooth_loss = 0
    ssim_loss = 0
    cross_ssim_loss = 0

    # curr_*_all collect the resized images per scale (only used by the disabled
    # summary code). occlusion_map_*_all are replaced by the scale-0 masks below.
    curr_tgt_image_all = []
    curr_src_image_stack_all = []
    occlusion_map_0_all = []
    occlusion_map_1_all = []
    occlusion_map_2_all = []
    occlusion_map_3_all = []
    occlusion_map_4_all = []
    occlusion_map_5_all = []

    # Calculate occlusion maps at each scale, as described in 'Occlusion Aware
    # Unsupervised Learning of Optical Flow' by Yang Wang et al.
    # occu_masks_{bw,fw}[i][s] is the mask computed from pred_{bw,fw}_flows[i][s];
    # the matching *_avg entry is its mean, used to normalize the masked errors.
    occu_masks_bw = []
    occu_masks_bw_avg = []
    occu_masks_fw = []
    occu_masks_fw_avg = []

    for i in range(len(pred_bw_flows)):
        temp_occu_masks_bw = []
        temp_occu_masks_bw_avg = []
        temp_occu_masks_fw = []
        temp_occu_masks_fw_avg = []

        for s in range(NUM_SCALE):
            H = int(IMG_HEIGHT / (2**s))
            W = int(IMG_WIDTH  / (2**s))

            mask, mask_avg = occulsion(pred_bw_flows[i][s], H, W)
            temp_occu_masks_bw.append(mask)
            temp_occu_masks_bw_avg.append(mask_avg)
            # bw masks i = 0, 1, 2 weight the [src0, tgt, src0_1] reconstructions

            mask, mask_avg = occulsion(pred_fw_flows[i][s], H, W)
            temp_occu_masks_fw.append(mask)
            temp_occu_masks_fw_avg.append(mask_avg)
            # fw masks i = 0, 1, 2 weight the [tgt, src1, src1_1] reconstructions

        occu_masks_bw.append(temp_occu_masks_bw)
        occu_masks_bw_avg.append(temp_occu_masks_bw_avg)
        occu_masks_fw.append(temp_occu_masks_fw)
        occu_masks_fw_avg.append(temp_occu_masks_fw_avg)

    for s in range(NUM_SCALE):
        H = int(IMG_HEIGHT / (2**s))
        W = int(IMG_WIDTH  / (2**s))
        # Downsample the frames to this scale's resolution.
        curr_tgt_image = tf.image.resize(
            tgt_image, [H, W], method=tf.image.ResizeMethod.BILINEAR)
        curr_src_image_stack = tf.image.resize(
            src_image_stack, [H, W], method=tf.image.ResizeMethod.BILINEAR)

        curr_tgt_image_all.append(curr_tgt_image)
        curr_src_image_stack_all.append(curr_src_image_stack)

        # Each term below follows one pattern: warp a frame with a flow, take the
        # absolute error against the frame it should match, and compute
        # mean(error * mask) / mean(mask). That is the error averaged over the
        # (softly) non-occluded pixels. Reconstructions driven by fw flows use the
        # masks computed from bw flows, and vice versa.

        # src0
        curr_proj_image_optical_src0 = transformer_old(curr_tgt_image, pred_fw_flows[0][s], [H, W])
        curr_proj_error_optical_src0 = tf.abs(curr_proj_image_optical_src0 - curr_src_image_stack[:,:,:,0:3])
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src0 * occu_masks_bw[0][s]) / occu_masks_bw_avg[0][s]

        # src0 reconstructed from src1 via the cross flow (index 2).
        curr_proj_image_optical_src0_1 = transformer_old(curr_src_image_stack[:,:,:,3:6], pred_fw_flows[2][s], [H, W])
        curr_proj_error_optical_src0_1 = tf.abs(curr_proj_image_optical_src0_1 - curr_src_image_stack[:,:,:,0:3])
        cross_reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src0_1 * occu_masks_bw[2][s]) / occu_masks_bw_avg[2][s]

        # tgt
        curr_proj_image_optical_tgt = transformer_old(curr_src_image_stack[:,:,:,3:6], pred_fw_flows[1][s], [H, W])
        curr_proj_error_optical_tgt = tf.abs(curr_proj_image_optical_tgt - curr_tgt_image)
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_tgt * occu_masks_bw[1][s]) / occu_masks_bw_avg[1][s]

        curr_proj_image_optical_tgt_1 = transformer_old(curr_src_image_stack[:,:,:,0:3], pred_bw_flows[0][s], [H, W])
        curr_proj_error_optical_tgt_1 = tf.abs(curr_proj_image_optical_tgt_1 - curr_tgt_image)
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_tgt_1 * occu_masks_fw[0][s]) / occu_masks_fw_avg[0][s]

        # src1
        curr_proj_image_optical_src1 = transformer_old(curr_tgt_image, pred_bw_flows[1][s], [H, W])
        curr_proj_error_optical_src1 = tf.abs(curr_proj_image_optical_src1 - curr_src_image_stack[:,:,:,3:6])
        reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src1 * occu_masks_fw[1][s]) / occu_masks_fw_avg[1][s]

        # src1 reconstructed from src0 via the cross flow (index 2).
        curr_proj_image_optical_src1_1 = transformer_old(curr_src_image_stack[:,:,:,0:3], pred_bw_flows[2][s], [H, W])
        curr_proj_error_optical_src1_1 = tf.abs(curr_proj_image_optical_src1_1 - curr_src_image_stack[:,:,:,3:6])
        cross_reconstructed_loss += tf.reduce_mean(
            input_tensor=curr_proj_error_optical_src1_1 * occu_masks_fw[2][s]) / occu_masks_fw_avg[2][s]

        # SSIM terms: same reconstructions and masks, with both images multiplied
        # by the mask before calling SSIM (from utils.loss_utils).
        if SSIM_WEIGHT > 0:
            # src0
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src0 * occu_masks_bw[0][s],
                     curr_src_image_stack[:,:,:,0:3] * occu_masks_bw[0][s])) / occu_masks_bw_avg[0][s]

            cross_ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src0_1 * occu_masks_bw[2][s],
                     curr_src_image_stack[:,:,:,0:3] * occu_masks_bw[2][s])) / occu_masks_bw_avg[2][s]

            # tgt
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_tgt * occu_masks_bw[1][s],
                     curr_tgt_image * occu_masks_bw[1][s])) / occu_masks_bw_avg[1][s]

            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_tgt_1 * occu_masks_fw[0][s],
                     curr_tgt_image * occu_masks_fw[0][s])) / occu_masks_fw_avg[0][s]

            # src1
            ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src1 * occu_masks_fw[1][s],
                     curr_src_image_stack[:,:,:,3:6] * occu_masks_fw[1][s])) / occu_masks_fw_avg[1][s]

            cross_ssim_loss += tf.reduce_mean(
                input_tensor=SSIM(curr_proj_image_optical_src1_1 * occu_masks_fw[2][s],
                     curr_src_image_stack[:,:,:,3:6] * occu_masks_fw[2][s])) / occu_masks_fw_avg[2][s]

        # Compute second-order derivatives for the flow smoothness loss.
        # Flows are divided by 20.0 first (presumably to normalize flow magnitude;
        # TODO confirm the origin of this constant). Each flow is paired with the
        # image that cal_grad2_error receives as its second argument.
        flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[0][s] / 20.0, curr_src_image_stack[:,:,:,0:3], 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[1][s] / 20.0, curr_tgt_image, 1.0)

        cross_flow_smooth_loss += cal_grad2_error(
            pred_fw_flows[2][s] / 20.0, curr_src_image_stack[:,:,:,0:3], 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[0][s] / 20.0, curr_tgt_image, 1.0)

        flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[1][s] / 20.0, curr_src_image_stack[:,:,:,3:6], 1.0)

        cross_flow_smooth_loss += cal_grad2_error(
            pred_bw_flows[2][s] / 20.0, curr_src_image_stack[:,:,:,3:6], 1.0)

        # [TODO] Add first-order derivatives for flow smoothness loss
        # [TODO] use robust Charbonnier penalty?

        # Keep the full-resolution (scale 0) masks; only map 0 is returned below.
        if s == 0:
            occlusion_map_0_all = occu_masks_bw[0][s]
            occlusion_map_1_all = occu_masks_bw[1][s]
            occlusion_map_2_all = occu_masks_bw[2][s]
            occlusion_map_3_all = occu_masks_fw[0][s]
            occlusion_map_4_all = occu_masks_fw[1][s]
            occlusion_map_5_all = occu_masks_fw[2][s]

    # Total = FLOW_RECONSTRUCTION_WEIGHT * [(1 - SSIM_WEIGHT) * L1 + SSIM_WEIGHT * SSIM]
    #         + FLOW_SMOOTH_WEIGHT * smoothness,
    # where every cross_* term is scaled by FLOW_CROSS_GEOMETRY_WEIGHT.
    losses = FLOW_RECONSTRUCTION_WEIGHT * ((1.0 - SSIM_WEIGHT) * \
                  (reconstructed_loss + FLOW_CROSS_GEOMETRY_WEIGHT*cross_reconstructed_loss) + \
                    SSIM_WEIGHT*(ssim_loss+FLOW_CROSS_GEOMETRY_WEIGHT*cross_ssim_loss)) + \
                  FLOW_SMOOTH_WEIGHT * (flow_smooth_loss + FLOW_CROSS_GEOMETRY_WEIGHT*cross_flow_smooth_loss)


    
    # Legacy TF1 summary code, kept for reference only. It references self.* from a
    # class-based implementation and cannot run inside this function as written.
#     summaries = []
#     summaries.append(tf.compat.v1.summary.scalar("total_loss", losses))
#     summaries.append(tf.compat.v1.summary.scalar("reconstructed_loss", reconstructed_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_reconstructed_loss", cross_reconstructed_loss))
#     summaries.append(tf.compat.v1.summary.scalar("ssim_loss", ssim_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_ssim_loss", cross_ssim_loss))
#     summaries.append(tf.compat.v1.summary.scalar("flow_smooth_loss", flow_smooth_loss))
#     summaries.append(tf.compat.v1.summary.scalar("cross_flow_smooth_loss", cross_flow_smooth_loss))

#     s = 0
#     tf.compat.v1.summary.image('scale%d_target_image' % s, tf.image.convert_image_dtype(curr_tgt_image_all[0], dtype=tf.uint8))

#     for i in range(NUM_SOURCE):
#         tf.compat.v1.summary.image('scale%d_src_image_%d' % (s, i), \
#                         tf.image.convert_image_dtype(curr_src_image_stack_all[0][:, :, :, i*3:(i+1)*3], dtype=tf.uint8))

#     tf.compat.v1.summary.image('scale%d_flow_src02tgt' % s, fl.flow_to_color(self.pred_fw_flows[0][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_tgt2src1' % s, fl.flow_to_color(self.pred_fw_flows[1][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src02src1' % s, fl.flow_to_color(self.pred_fw_flows[2][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_tgt2src0' % s, fl.flow_to_color(self.pred_bw_flows[0][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src12tgt' % s, fl.flow_to_color(self.pred_bw_flows[1][s], max_flow=256))
#     tf.compat.v1.summary.image('scale%d_flow_src12src0' % s, fl.flow_to_color(self.pred_bw_flows[2][s], max_flow=256))

#     tf.compat.v1.summary.image('scale_flyout_mask_src0', occlusion_map_0_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_tgt', occlusion_map_1_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src0_1', occlusion_map_2_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_tgt1', occlusion_map_3_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src1', occlusion_map_4_all)
#     tf.compat.v1.summary.image('scale_flyout_mask_src1_1', occlusion_map_5_all)

#     self.summ_op = tf.compat.v1.summary.merge(summaries)
    # Visualization outputs: scale-0 forward flow src0 as a color image, plus the
    # scale-0 mask that weights the src0 reconstruction.
    s = 0
    flow = fl.flow_to_color(pred_fw_flows[0][s], max_flow=256)
    occlusion = occlusion_map_0_all
    return losses, flow, occlusion

def occulsion(pred_flow, H, W):
    """Soft occlusion mask for a single flow field.

    Follows the soft occlusion maps of Wang et al.,
    https://arxiv.org/pdf/1711.05890.pdf. A tensor of ones is forward-warped with
    `pred_flow` through TransformerFwd, and the result is clipped to [0, 1]. Per
    that paper, low values mark likely-occluded pixels.

    Args:
        pred_flow: the estimated forward optical flow at resolution (H, W).
        H: mask height.
        W: mask width.

    Returns:
        (occu_mask, occu_mask_avg): the mask reshaped to
        [BATCH_SIZE, H, W, 1], and its scalar mean over all elements.
    """
    forward_warp = TransformerFwd()
    ones = tf.ones(shape=[BATCH_SIZE, H, W, 1], dtype='float32')
    warped_ones = forward_warp(ones, pred_flow, [H, W], backprop=True)

    clipped = tf.clip_by_value(warped_ones, clip_value_min=0.0, clip_value_max=1.0)
    occu_mask = tf.reshape(clipped, [BATCH_SIZE, H, W, 1])
    return occu_mask, tf.reduce_mean(input_tensor=occu_mask)




In [None]:
# Switch Keras into training mode globally. This affects layers that are called
# without an explicit `training` argument. set_learning_phase is deprecated in TF2;
# passing training=True at call time is the preferred mechanism.
tf.keras.backend.set_learning_phase(value=True)

In [None]:
# Run the custom training loop; `his` receives the per-epoch mean loss history.
his = train(dataset, epochs=1000)

## Plot loss history and show flow and occlusion mask 

In [None]:
# Plot the per-epoch mean loss returned by the last train() run.
x = np.arange(len(his))
plt.plot(x, his)
plt.title('Training loss (last run)')
plt.xlabel('epoch (relative to the start of the run)')
plt.ylabel('mean loss')
plt.show()

In [None]:
# Accumulate loss histories across repeated training runs in `temp`.
# Initialize it on first use: previously nothing defined `temp`, so this cell
# raised NameError on a fresh kernel. Re-running the cell appends `his` again.
if 'temp' not in globals():
    temp = []
temp.extend(his)

In [None]:
# Plot the loss history accumulated over all training runs.
x = np.arange(len(temp))
plt.plot(x, temp)
plt.title('Training loss (accumulated over runs)')
plt.xlabel('epoch (cumulative)')
plt.ylabel('mean loss')
plt.show()

In [None]:
# Layout check for the flow / occlusion figure using random placeholder tensors
# (one RGB flow image and one single-channel mask at full resolution).
flow = tf.random.uniform((1, 256, 832, 3), 0, 1)
occlusion = tf.random.uniform((1, 256, 832, 1), 0, 1)

print('[info] flow:', flow.shape)
print('[info] occlusion:', occlusion.shape)

fig, (ax_flow, ax_occlusion) = plt.subplots(1, 2, figsize=(20, 50))
ax_flow.imshow(flow[0].numpy())
ax_occlusion.imshow(np.squeeze(occlusion.numpy()[0]), cmap='gray')
plt.show()