#### Data Pre-processing (Explain further)

1. Resizing (512x512) - split each source image into three 512x512 images (ground truth, glare image, glare mask)
2. Grayscaling
3. Data Augmentation 
4. Data Normalization (divide by 255 instead of 65535, because the images are 8-bit grayscale)

In [4]:
import os
import cv2
from glob import glob
from matplotlib import pyplot as plt
import numpy as np
from utils import augment_data, preprocess_data

In [5]:
# Root folder of the raw dataset (one sub-folder per split).
# The default is the original author's location; set SD1_DATA_DIR to run the
# notebook on another machine. os.path.join(..., '') guarantees the trailing
# separator that later string concatenations on `data_folder` rely on.
data_folder = os.path.join(os.environ.get('SD1_DATA_DIR', '/Users/denisebeh/Downloads/SD1/'), '')
data_dirs = ['train', 'val']
out_dirs = ['out', 'aug']
dirs = ['GroundTruth', 'GlareImage', 'GlareMask']

# Root folder for every generated image (override with SD1_OUTPUT_DIR).
output_path = os.path.join(os.environ.get('SD1_OUTPUT_DIR', '/Users/denisebeh/Downloads/SD1_Output/'), '')

# Lookup table: directories[split][stage][kind] -> output folder path.
directories = {}

# Create the output tree and record every folder in `directories`.
for data_dir in data_dirs:
    directories[data_dir] = {}
    for out_dir in out_dirs:
        # No 'val/aug' tree is created: only the training split gets an 'aug' folder.
        if data_dir == 'val' and out_dir == 'aug':
            continue
        directories[data_dir][out_dir] = {}
        # `kind` instead of `dir`, so the builtin dir() is not shadowed.
        for kind in dirs:
            path = os.path.join(output_path, data_dir, out_dir, kind)
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(path, exist_ok=True)
            directories[data_dir][out_dir][kind] = path

In [3]:
train_images = sorted(glob(os.path.join(data_folder, data_dirs[0], "*.png")))
val_images = sorted(glob(os.path.join(data_folder, data_dirs[1], "*.png")))


def _preprocess_split(image_paths, split):
    """Resize, grayscale and save every raw image of one dataset split.

    Args:
        image_paths: paths of the raw images, processed in the given order.
        split: key into `directories` ('train' or 'val'); results are written
            to the three folders under directories[split]['out'].

    Raises:
        IOError: if OpenCV cannot read one of the input files.

    Output files are named by 1-based position in `image_paths`, zero-padded
    to six digits, so the three outputs of one source image share a name.
    """
    out_dirs_for_split = directories[split]['out']
    for idx, img_name in enumerate(image_paths):
        # Flag -1 is cv2.IMREAD_UNCHANGED: keep the file's channels and bit depth.
        img = cv2.imread(img_name, -1)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Could not read image: " + img_name)

        # Per the original author's note, each part is 512x512 with 4 channels.
        ground_truth, glare_image, glare_mask = preprocess_data.resize_image(img)

        # Convert all three parts to grayscale before anything is written.
        grey_parts = {
            'GroundTruth': preprocess_data.grayscale_image(ground_truth),
            'GlareImage': preprocess_data.grayscale_image(glare_image),
            'GlareMask': preprocess_data.grayscale_image(glare_mask),
        }

        # Save outputs
        file_name = str(idx+1).zfill(6) + ".png"
        for kind, grey in grey_parts.items():
            cv2.imwrite(os.path.join(out_dirs_for_split[kind], file_name), grey)
        print("Saved image ", idx+1)


# Pre-process training set images (Resizing + Grayscaling)
_preprocess_split(train_images, 'train')

# Pre-process validation set images (Resizing + Grayscaling)
_preprocess_split(val_images, 'val')

Saved image  1
Saved image  2
Saved image  3
Saved image  4
Saved image  5
Saved image  6
Saved image  7
Saved image  8
Saved image  9
Saved image  10
Saved image  11
Saved image  12
Saved image  13
Saved image  14
Saved image  15
Saved image  16
Saved image  17
Saved image  18
Saved image  19
Saved image  20
Saved image  21
Saved image  22
Saved image  23
Saved image  24
Saved image  25
Saved image  26
Saved image  27
Saved image  28
Saved image  29
Saved image  30
Saved image  31
Saved image  32
Saved image  33
Saved image  34
Saved image  35
Saved image  36
Saved image  37
Saved image  38
Saved image  39
Saved image  40
Saved image  41
Saved image  42
Saved image  43
Saved image  44
Saved image  45
Saved image  46
Saved image  47
Saved image  48
Saved image  49
Saved image  50
Saved image  51
Saved image  52
Saved image  53
Saved image  54
Saved image  55
Saved image  56
Saved image  57
Saved image  58
Saved image  59
Saved image  60
Saved image  61
Saved image  62
Saved image  63
S

In [6]:
# Define augmentation pipeline
def augment_image(img):
    """Build three augmented variants of ``img`` using the ``augment_data`` helpers.

    Each variant feeds the same input through a two-step chain:

    1. ``rotate_image_right`` followed by ``box_blur_image``
    2. ``flip_image_horizontal`` followed by ``decrease_brightness``
    3. ``rotate_image_left`` followed by ``positive_contrast_image``

    Returns:
        A 3-tuple with the results of the chains, in the order listed above.
    """
    right_blurred = augment_data.box_blur_image(
        augment_data.rotate_image_right(img)
    )
    flipped_dimmed = augment_data.decrease_brightness(
        augment_data.flip_image_horizontal(img)
    )
    left_contrasted = augment_data.positive_contrast_image(
        augment_data.rotate_image_left(img)
    )
    return right_blurred, flipped_dimmed, left_contrasted

# Perform data augmentation for training data.
# The three lists are paired by position, so each is sorted and their lengths
# are checked before anything is written.
train_ground_truth = sorted(glob(os.path.join(directories['train']['out']['GroundTruth'], "*.png")))
train_glare_images = sorted(glob(os.path.join(directories['train']['out']['GlareImage'], "*.png")))
train_glare_masks = sorted(glob(os.path.join(directories['train']['out']['GlareMask'], "*.png")))

num_images = len(train_ground_truth)
assert num_images == len(train_glare_images), "GroundTruth / GlareImage counts differ"
assert num_images == len(train_glare_masks), "GroundTruth / GlareMask counts differ"

aug_kinds = ('GroundTruth', 'GlareImage', 'GlareMask')
aug_dirs = directories['train']['aug']

# Augment ground truth, glare images and glare masks.
# NOTE(review): the same pipeline (including blur / brightness / contrast steps)
# is applied to the glare masks as well - confirm that altering mask pixel
# values is intended.
for idx in range(num_images):
    source_paths = (train_ground_truth[idx], train_glare_images[idx], train_glare_masks[idx])
    # Flag -1 is cv2.IMREAD_UNCHANGED.
    originals = [cv2.imread(path, -1) for path in source_paths]
    for path, image in zip(source_paths, originals):
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Could not read image: " + path)

    # For each kind: (original, variant 1, variant 2, variant 3).
    variants_per_kind = [(image,) + tuple(augment_image(image)) for image in originals]

    # Save outputs: every source image yields four consecutively numbered
    # files (the original followed by its three variants), starting at 4*idx + 1.
    first_num = idx * 4 + 1
    for offset in range(4):
        img_num = first_num + offset
        file_name = str(img_num).zfill(6) + ".png"
        for kind, variants in zip(aug_kinds, variants_per_kind):
            cv2.imwrite(os.path.join(aug_dirs[kind], file_name), variants[offset])
        print("Saved image ", img_num)

Saved image  1
Saved image  2
Saved image  3
Saved image  4
Saved image  5
Saved image  6
Saved image  7
Saved image  8
Saved image  9
Saved image  10
Saved image  11
Saved image  12
Saved image  13
Saved image  14
Saved image  15
Saved image  16
Saved image  17
Saved image  18
Saved image  19
Saved image  20
Saved image  21
Saved image  22
Saved image  23
Saved image  24
Saved image  25
Saved image  26
Saved image  27
Saved image  28
Saved image  29
Saved image  30
Saved image  31
Saved image  32
Saved image  33
Saved image  34
Saved image  35
Saved image  36
Saved image  37
Saved image  38
Saved image  39
Saved image  40
Saved image  41
Saved image  42
Saved image  43
Saved image  44
Saved image  45
Saved image  46
Saved image  47
Saved image  48
Saved image  49
Saved image  50
Saved image  51
Saved image  52
Saved image  53
Saved image  54
Saved image  55
Saved image  56
Saved image  57
Saved image  58
Saved image  59
Saved image  60
Saved image  61
Saved image  62
Saved image  63
S

#### Model Selection
- Using: https://github.com/ChenyangLEI/polarization-reflection-removal?tab=readme-ov-file
- List and explain any assumptions or interpretations you have made about the requirements

#### Model Training
Explain the model training process

- The model should be able to take in a single OR a batch of [512 x 512] input images for inference.


In [7]:
import os
import cv2
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import time
import tensorflow as tf
from utils import utils
from utils.network import DialUNet as UNet
import utils.losses as loss

[i] Loaded pre-trained vgg19 parameters


In [8]:
# Report every device TensorFlow can see, then the details of the first GPU
# (if any), to confirm that training will run on the accelerator.
devices = tf.config.list_physical_devices()
print("\nDevices: ", devices)

gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    first_gpu = gpus[0]
    details = tf.config.experimental.get_device_details(first_gpu)
    print("GPU details: ", details)


Devices:  [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU details:  {'device_name': 'METAL'}


In [9]:
# Prepare data
def _sorted_png_paths(folder):
    """Return the sorted paths of all ``.png`` files directly inside ``folder``."""
    return sorted(glob(os.path.join(folder, '*.png')))


# NOTE(review): the file lists come from the 'out' folders; the augmented
# images written to directories['train']['aug'] are not read here - confirm
# that training on the un-augmented set is intended.
train_gi_names = _sorted_png_paths(directories['train']['out']['GlareImage'])
train_gt_names = _sorted_png_paths(directories['train']['out']['GroundTruth'])
train_gm_names = _sorted_png_paths(directories['train']['out']['GlareMask'])

val_gi_names = _sorted_png_paths(directories['val']['out']['GlareImage'])
val_gt_names = _sorted_png_paths(directories['val']['out']['GroundTruth'])
val_gm_names = _sorted_png_paths(directories['val']['out']['GlareMask'])

# Samples are paired by list position during training, so the three lists of a
# split must have the same length AND the same file names in the same order.
for split_name, gi_names, gt_names, gm_names in (
    ('train', train_gi_names, train_gt_names, train_gm_names),
    ('val', val_gi_names, val_gt_names, val_gm_names),
):
    assert len(gi_names) == len(gt_names), split_name + ": GlareImage / GroundTruth counts differ"
    assert len(gi_names) == len(gm_names), split_name + ": GlareImage / GlareMask counts differ"
    gi_base, gt_base, gm_base = (
        [os.path.basename(path) for path in names]
        for names in (gi_names, gt_names, gm_names)
    )
    assert gi_base == gt_base == gm_base, split_name + ": file names are not aligned across folders"

num_train, num_val = len(train_gi_names), len(val_gi_names)
print('Data load succeed!', num_train, num_val)

Data load succeed! 12000 650


In [10]:
# Folder holding the checkpoint files; the session cell passes it to
# tf.train.get_checkpoint_state() to locate the weights to restore.
# NOTE(review): machine-specific absolute path - adjust before running elsewhere.
ckpt_path = '/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt'

# Set up the model and define the graph (TF1-style: placeholders fed through a Session).
with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
    # Placeholders are a graph-mode construct, so eager execution is switched off first.
    tf.compat.v1.disable_eager_execution()
    # Three float32 placeholders with free batch/height/width and 5 channels.
    # NOTE(review): `input` shadows the Python builtin; it is also used as a
    # feed_dict key in the training cell, so renaming needs a coordinated change.
    input=tf.compat.v1.placeholder(tf.float32,shape=[None,None,None,5])
    reflection=tf.compat.v1.placeholder(tf.float32,shape=[None,None,None,5])
    target=tf.compat.v1.placeholder(tf.float32,shape=[None,None,None,5])
    # Mask derived from the input alone; presumably flags over-exposed pixels - TODO confirm in utils.
    overexp_mask = utils.tf_overexp_mask(input)
    tf_input, tf_reflection, tf_target, real_input = utils.prepare_real_input(input, target, reflection, overexp_mask)
    # Two networks in sequence: the first ('Ref_' variables) predicts the reflection
    # from real_input; the second ('Tran_' variables) predicts the transmission from
    # real_input concatenated with that prediction along the channel axis (axis=3).
    reflection_layer=UNet(real_input, ext='Ref_')
    transmission_layer = UNet(tf.concat([real_input, reflection_layer],axis=3),ext='Tran_') 

    # Every named loss term is stored here so the whole dict can be fetched in one sess.run.
    lossDict = {}

    # Perceptual losses use only channel index 4 (slice 4:5), scaled by 0.5 and weighted by 0.2.
    # The first call passes reuse=False and the second reuse=True, so the order of these
    # two lines matters (presumably the loss network's variables are created once, then shared).
    lossDict["percep_t"]=0.2*loss.compute_percep_loss(0.5 * tf_target[...,4:5],  0.5*transmission_layer[...,4:5], overexp_mask, reuse=False )
    lossDict["percep_r"]=0.2*loss.compute_percep_loss(0.5 * tf_reflection[...,4:5], 0.5*reflection_layer[...,4:5], overexp_mask, reuse=True)

    # Loss between the masked transmission and masked reflection predictions, weighted by 6.
    lossDict["pncc"] = 6*loss.compute_percep_ncc_loss(tf.multiply(0.5*transmission_layer[...,4:5],overexp_mask), 
        tf.multiply(0.5*reflection_layer[...,4:5],overexp_mask))

    # Computed (and therefore available when lossDict is fetched) but NOT part of "all_loss" below.
    lossDict["reconstruct"]= loss.mask_reconstruct_loss(tf_input[...,4:5], transmission_layer[...,4:5], reflection_layer[...,4:5], overexp_mask)
    
    # Aliases for the two perceptual terms, then the objective that is actually optimised.
    lossDict["reflection"] = lossDict["percep_r"]
    lossDict["transmission"]=lossDict["percep_t"]
    lossDict["all_loss"] = lossDict["reflection"] + lossDict["transmission"] + lossDict["pncc"]

[i] Hypercolumn ON, building hypercolumn features ... 
[i] Hypercolumn ON, building hypercolumn features ... 
[i] Hypercolumn ON, building hypercolumn features ... 
[i] Hypercolumn ON, building hypercolumn features ... 
[i] Hypercolumn ON, building hypercolumn features ... 
[i] Hypercolumn ON, building hypercolumn features ... 




In [11]:
######### Session #########
# Only trainable variables whose name contains 'g_' are handed to the optimiser.
all_vars = list(filter(lambda candidate: 'g_' in candidate.name,
                       tf.compat.v1.trainable_variables()))
adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.0001)
all_opt = adam.minimize(lossDict["all_loss"], var_list=all_vars)
for var in tf.compat.v1.trainable_variables():
    print("Listing trainable variables ... ",var)

# `saver` writes training checkpoints (at most 20 are kept); it is created
# after the optimiser and before the session, as in the original ordering.
saver = tf.compat.v1.train.Saver(max_to_keep=20)

# Let the session grow its GPU memory allocation on demand.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
sess.run(tf.compat.v1.global_variables_initializer())

# A second saver, restricted to the trainable variables, is used for restoring.
var_restore = list(tf.compat.v1.trainable_variables())
saver_restore = tf.compat.v1.train.Saver(var_restore)

# Checkpoint state found under `ckpt_path`; the next cell asserts that it is not empty.
ckpt = tf.train.get_checkpoint_state(ckpt_path)

Listing trainable variables ...  <tf.Variable 'Ref_g_conv1_1/weights:0' shape=(1, 1, 655, 32) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv1_1/biases:0' shape=(32,) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv1_2/weights:0' shape=(3, 3, 32, 32) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv1_2/biases:0' shape=(32,) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv2_1/weights:0' shape=(3, 3, 32, 64) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv2_1/biases:0' shape=(64,) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv2_2/weights:0' shape=(3, 3, 64, 64) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv2_2/biases:0' shape=(64,) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv3_1/weights:0' shape=(3, 3, 64, 128) dtype=float32>
Listing trainable variables ...  <tf.Variable 'Ref_g_conv3_1/biase

2024-04-10 14:55:14.662497: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-04-10 14:55:14.662521: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-04-10 14:55:14.662527: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-04-10 14:55:14.662566: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-04-10 14:55:14.662584: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
2024-04-10 14:55:16.256126: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:375] MLIR V1 optimization pass is not enabled
2024-04-10 14:55:17.214744: I tensorflow/core/grapple

In [12]:
# Log the checkpoint state that was found, then restore the weights from it.
print("[i] contain checkpoint: ", ckpt)
# Fail with a readable message (instead of a bare AssertionError) when the
# checkpoint folder holds no usable checkpoint.
assert ckpt and ckpt.model_checkpoint_path, "No checkpoint found under " + ckpt_path
# Restore only the trainable variables.
saver_restore=tf.compat.v1.train.Saver([var for var in tf.compat.v1.trainable_variables()])
saver_restore.restore(sess,ckpt.model_checkpoint_path)
# Report success only after restore() has returned without raising.
print('loaded ' + ckpt.model_checkpoint_path)

[i] contain checkpoint:  model_checkpoint_path: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0055/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0060/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0065/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0070/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0075/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-project/Submission_ckpt/0080/model.ckpt"
all_model_checkpoint_paths: "/Users/denisebeh/Desktop/NUS/Job Apps/HTX/image-deglaring-pr

2024-04-10 14:55:19.218268: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2024-04-10 14:55:19.301672: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] PluggableGraphOptimizer failed: INVALID_ARGUMENT: Failed to deserialize the `graph_buf`.
2024-04-10 14:55:19.315894: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] PluggableGraphOptimizer failed: INVALID_ARGUMENT: Failed to deserialize the `graph_buf`.


In [14]:
# Configurations
maxepoch = 11        # exclusive upper bound: epochs 1 .. maxepoch-1 are run
step = 0             # running step counter, updated by utils.cnts_add_display
save_model_freq = 2  # checkpoint + validate every `save_model_freq` epochs

# What to fetch on each step; these lists do not change between iterations,
# so they are built once instead of once per training sample.
fetch_list=[all_opt, overexp_mask, transmission_layer, reflection_layer, tf_input, tf_target, tf_reflection, lossDict]
val_fetch_list=[lossDict["transmission"], overexp_mask, transmission_layer, reflection_layer,
                tf_input, tf_target, tf_reflection]

# Train model with training data
for epoch in range(1,maxepoch):
    print("Processing epoch %d"%epoch)
    epoch_dir = "result/%04d"%(epoch)

    # Resume support: an epoch whose result folder already exists is skipped.
    # NOTE(review): the folder is created BEFORE the epoch is trained, so an
    # interrupted epoch is also skipped on the next run - delete its folder to redo it.
    if os.path.isdir(epoch_dir):
        continue
    os.makedirs(epoch_dir)
    cnts = {"cnt":0, "all_t":0, "all_r":0, "all_pncc":0,"all_recon":0}

    # One optimisation step per training sample, visited in random order.
    # (`train_idx` rather than `id`, so the builtin id() is not shadowed.)
    for train_idx in np.random.permutation(num_train):
        tmp_M = utils.prepare_single_item(train_gi_names[train_idx])
        tmp_R = utils.prepare_single_item(train_gm_names[train_idx])
        tmp_T = utils.prepare_single_item(train_gt_names[train_idx])
        tmp_M, tmp_T, tmp_R = utils.crop_augmentation_MRT(tmp_M, tmp_T, tmp_R)
        st=time.time()
        h,w=utils.crop_shape(tmp_M)
        downsample_draw = np.random.random()
        tmp_M = tmp_M[:,:h,:w,:]
        tmp_R = tmp_R[:,:h,:w,:]
        tmp_T = tmp_T[:,:h,:w,:]

        # With probability 0.5, keep every second row and column.
        if downsample_draw < 0.5:
            tmp_M = tmp_M[:,::2,::2,:]
            tmp_R = tmp_R[:,::2,::2,:]
            tmp_T = tmp_T[:,::2,::2,:]

        _,out_mask, pred_image_t,pred_image_r,gt_input,gt_target,gt_reflection,crt_lossDict=sess.run(fetch_list,
                feed_dict={input:tmp_M, reflection:tmp_R, target:tmp_T})

        cnts,step=utils.cnts_add_display(epoch,cnts,step,crt_lossDict,st)
        # Save a preview for every 20th sample index on checkpoint epochs, and
        # whenever the step counter is 1 modulo 100.
        if ((train_idx % 20) == 0 and (epoch % save_model_freq == 0)) or (step % 100 == 1):
            utils.save_concat_img(out_mask, gt_input, gt_target,gt_reflection,pred_image_t,pred_image_r, epoch_dir + "/train_%06d.jpg"%(train_idx))

    # Save the model and run validation every `save_model_freq` epochs.
    if epoch % save_model_freq == 0:
        all_loss_test=np.zeros(num_val, dtype=float)
        metrics = {"T_ssim":0,"T_psnr":0,"R_ssim":0, "R_psnr":0}
        saver.save(sess, "result/model.ckpt")
        saver.save(sess, epoch_dir + "/model.ckpt")
        for val_idx in range(num_val):
            tmp_M = utils.prepare_single_item(val_gi_names[val_idx])
            tmp_R = utils.prepare_single_item(val_gm_names[val_idx])
            tmp_T = utils.prepare_single_item(val_gt_names[val_idx])

            # Validation always runs on the cropped, stride-2 subsampled images.
            h, w = utils.crop_shape(tmp_M)
            out_loss,out_mask,pred_image_t, pred_image_r, gt_input,gt_target,gt_reflection=sess.run(val_fetch_list,
                feed_dict={input:tmp_M[:,:h:2,:w:2,:], reflection:tmp_R[:,:h:2,:w:2,:], target:tmp_T[:,:h:2,:w:2,:]})
            print("Epc: %3d, shape of outputs: "%epoch,pred_image_t.shape, pred_image_r.shape)
            utils.save_concat_img(out_mask, gt_input, gt_target, gt_reflection, pred_image_t, pred_image_r, epoch_dir + "/val_%06d.jpg"%(val_idx))
            all_loss_test[val_idx]=out_loss
            metrics = utils.get_metrics(metrics,out_mask, gt_target,gt_reflection,pred_image_t,pred_image_r)
        # The original passed the leaked loop variable here, i.e. the index of
        # the last validation sample; that value is now passed explicitly.
        utils.save_results(all_loss_test, metrics, num_val - 1, epoch)

Processing epoch 1
Processing epoch 2
Processing epoch 3
Processing epoch 4
Processing epoch 5


iter: 005 001 || r:20.77 | t:10.50 | pncc:3.12 |recon:0.036 |time:23.09
Epc:   5, shape of outputs:  (1, 128, 128, 5) (1, 128, 128, 5)
gt_target is:  float32
out mask is:  float32
pred image T is:  float32
first is:  float32
second is:  float32
Processing epoch 6
iter: 006 001 || r:15.25 | t:8.00 | pncc:2.93 |recon:0.032 |time:4.18
Epc:   6, shape of outputs:  (1, 128, 128, 5) (1, 128, 128, 5)
gt_target is:  float32
out mask is:  float32
pred image T is:  float32
first is:  float32
second is:  float32
Processing epoch 7
iter: 007 001 || r:17.45 | t:8.26 | pncc:2.83 |recon:0.030 |time:2.57
Epc:   7, shape of outputs:  (1, 128, 128, 5) (1, 128, 128, 5)
gt_target is:  float32
out mask is:  float32
pred image T is:  float32
first is:  float32
second is:  float32
Processing epoch 8
iter: 008 001 || r:15.37 | t:7.27 | pncc:2.64 |recon:0.029 |time:3.07
Epc:   8, shape of outputs:  (1, 128, 128, 5) (1, 128, 128, 5)
gt_target is:  float32
out mask is:  float32
pred image T is:  float32
first is

KeyboardInterrupt: 

#### Results
- Present the final training and validation results (L1 loss per epoch)
- Provide a section to demonstrate the model’s functionality by showing the inference output removing the glare in an input image