In [1]:
import os
import time
import glob
import random
from PIL import Image
import numpy as np
import tensorflow as tf

In [2]:
def load_image(img_path):
    img = tf.io.read_file(img_path)
    img = tf.io.decode_png(img, channels=0, dtype=tf.dtypes.uint16)
    img = tf.image.resize(img, size=(224, 224), antialias=True)
    img = img / 255.0
    return img

In [14]:
def data_path(image_path, hazy_image_path):
    images_path = list(glob.glob(image_path + "*.png"))
    hazy_images_path = list(glob.glob(hazy_image_path + "*.png"))
    
    train_img = []
    val_img = []
    
    # Match each GT image with its corresponding hazy image
    for images in images_path:
        for hazy_images in hazy_images_path:
            train_img.append(hazy_images)
            val_img.append(hazy_images)
            
    total_images = len(train_img)

    temp = list(zip(train_img, val_img))
    np.random.shuffle(temp)
    train_img, val_img = zip(*temp)

    train_img = list(train_img)
    val_img = list(val_img)
    
    # 60:40 : 27:18
    train_img = train_img[:int(total_images * 0.6)]
    train_hazy = val_img[:int(total_images * 0.6)]
    val_img = train_img[int(total_images * 0.6):]
    val_hazy = val_img[int(total_images * 0.6):]

    return train_img, val_img

In [15]:
def dataLoder(train_data, val_data, batch_size):
    train_data_img = tf.data.Dataset.from_tensor_slices(img[1] for img in train_data).map(lambda data: load_image(data))
    train_data_hazy = tf.data.Dataset.from_tensor_slices(img[0] for img in train_data).map(lambda data: load_image(data))
    train = tf.data.Dataset.zip((train_data_img, train_data_hazy)).batch(batch_size)
    
    val_data_img = tf.data.Dataset.from_tensor_slices(img[1] for img in val_data).map(lambda data: load_image(data))
    val_data_hazy = tf.data.Dataset.from_tensor_slices(img[0] for img in val_data).map(lambda data: load_image(data))
    val = tf.data.Dataset.zip((val_data_img, val_data_hazy)).batch(batch_size)
    
    return train, val

In [16]:
def result(model, hazy_img, img):

    dehazed_img = model(hazy_img, training = True)
    plt.figure(figsize = (15,12))

    display_list = [hazy_img[0], img[0], dehazed_img[0]]
    title = ['Hazy Image', 'Ground Truth', 'Dehazed Image']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')

    plt.show()

In [17]:
def create_model():
    inputs = tf.keras.Input(shape=[224, 224, 3])
    model = tf.keras.Sequential()

    # Convolutional Layers
    model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=input_shape))
    model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))

    # Upsampling Layers
    model.add(tf.keras.layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), activation='relu', padding='same'))
    model.add(tf.keras.layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), activation='relu', padding='same'))

    # Output Layer
    model.add(tf.keras.layers.Conv2D(3, (1, 1), activation='sigmoid', padding='same'))

    return model

In [20]:
epochs = 15
batch_size = 16
k_init = tf.keras.initializers.random_normal(stddev=0.008, seed = 101)
regularizer = tf.keras.regularizers.L2(1e-4)
b_init = tf.constant_initializer()

train_data, val_data = data_path(image_path="../data/test/target/", hazy_image_path="../data/test/input/")
train, val = dataLoder(train_data, val_data, batch_size)

ValueError: Attempt to convert a value (<generator object dataLoder.<locals>.<genexpr> at 0x000001C9C6A2F060>) with an unsupported type (<class 'generator'>) to a Tensor.