In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import keras
from keras import layers
from keras.layers import Input,Conv2D,Concatenate
from keras.models import Model
import tensorflow as tf

In [3]:
TARGET_SIZE = 256
BATCH_SIZE = 32
MAX_TRAIN_IMAGES = 400

PRE_REQUISITE FUNCTION TO GENERATE DATA FROM FOLDERS

In [4]:
def load_image(file_path):
    image_data = tf.io.read_file(file_path)
    decoded_image = tf.image.decode_png(image_data, channels=3)
    resized_image = tf.image.resize(images=decoded_image, size=[TARGET_SIZE, TARGET_SIZE])
    normalized_image = resized_image / 255.0  # Scale pixel values to [0, 1]
    return normalized_image

In [5]:
def image_data_generator(image_paths):
    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset

DATASET PREPARATION

In [6]:
# Paths to your datasets
train_low_light_image_paths = sorted(glob("/content/drive/MyDrive/lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
val_low_light_image_paths = sorted(glob("/content/drive/MyDrive/lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]

test_low_light_image_paths = sorted(glob("/content/drive/MyDrive/Train/low/*"))
test_high_light_image_paths = sorted(glob("/content/drive/MyDrive/Train/high/*"))

# Generate datasetsc
train_dataset = image_data_generator(train_low_light_image_paths)
val_dataset = image_data_generator(val_low_light_image_paths)

In [7]:
def build_dce_net():
    input_image = Input(shape=[None, None, 3])

    conv1 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(input_image)
    conv2 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv1)
    conv3 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv2)
    conv4 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(conv3)

    concat1 = Concatenate(axis=-1)([conv4, conv3])
    conv5 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(concat1)

    concat2 = Concatenate(axis=-1)([conv5, conv2])
    conv6 = Conv2D(32, (3, 3), strides=(1, 1), activation="relu", padding="same")(concat2)

    concat3 = Concatenate(axis=-1)([conv6, conv1])
    output_image = Conv2D(24, (3, 3), strides=(1, 1), activation="tanh", padding="same")(concat3)

    return Model(inputs=input_image, outputs=output_image)

CUSTOM LOSS FUNCTIONS

In [8]:
def compute_color_constancy_loss(image_batch):
    # Calculate the mean of each RGB channel
    mean_rgb_values = tf.reduce_mean(image_batch, axis=(1, 2), keepdims=True)
    mean_r_channel = mean_rgb_values[:, :, :, 0]
    mean_g_channel = mean_rgb_values[:, :, :, 1]
    mean_b_channel = mean_rgb_values[:, :, :, 2]

    # Compute the squared differences between the channel means
    diff_red_green = tf.square(mean_r_channel - mean_g_channel)
    diff_red_blue = tf.square(mean_r_channel - mean_b_channel)
    diff_green_blue = tf.square(mean_g_channel - mean_b_channel)

    # Calculate the color constancy loss
    color_loss = tf.sqrt(diff_red_green + diff_red_blue + diff_green_blue)
    return color_loss

In [9]:
def compute_exposure_loss(image, target_exposure=0.6):
    # Calculate the mean across the RGB channels
    grayscale_image = tf.reduce_mean(image, axis=3, keepdims=True)

    # Pool the image using a 16x16 kernel with non-overlapping regions
    pooled_mean = tf.nn.avg_pool2d(grayscale_image, ksize=16, strides=16, padding="VALID")

    # Calculate the exposure loss
    exposure_loss_value = tf.reduce_mean(tf.square(pooled_mean - target_exposure))

    return exposure_loss_value

In [10]:
def illumination_smoothness_loss(image):
    # Get the dimensions of the input tensor
    batch_size = tf.shape(image)[0]
    height = tf.shape(image)[1]
    width = tf.shape(image)[2]
    channels = tf.shape(image)[3]

    # Calculate the total number of horizontal and vertical differences
    horizontal_count = (width - 1) * channels
    vertical_count = width * (channels - 1)

    # Compute the horizontal and vertical total variation losses
    horizontal_tv_loss = tf.reduce_sum(tf.square(image[:, 1:, :, :] - image[:, :height - 1, :, :]))
    vertical_tv_loss = tf.reduce_sum(tf.square(image[:, :, 1:, :] - image[:, :, :width - 1, :]))

    # Convert counts and batch size to float for division
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    horizontal_count = tf.cast(horizontal_count, dtype=tf.float32)
    vertical_count = tf.cast(vertical_count, dtype=tf.float32)

    # Calculate the smoothness loss
    smoothness_loss = 2 * (horizontal_tv_loss / horizontal_count + vertical_tv_loss / vertical_count) / batch_size

    return smoothness_loss

CUSTOM ZERO DCE MODEL ----> MODIFYING IT'S Properties

In [15]:
class ZeroDCE(Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        super().compile(**kwargs)
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.illumination_smoothness_loss_tracker = keras.metrics.Mean(name="illumination_smoothness_loss")
        self.color_constancy_loss_tracker = keras.metrics.Mean(name="color_constancy_loss")
        self.exposure_loss_tracker = keras.metrics.Mean(name="exposure_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.illumination_smoothness_loss_tracker,
            self.color_constancy_loss_tracker,
            self.exposure_loss_tracker,
        ]

    def get_enhanced_image(self, data, output):
        r_layers = [output[:, :, :, i:i+3] for i in range(0, 24, 3)]
        x = data
        for r in r_layers:
            x = x + r * (tf.square(x) - x)
        return x

    def call(self, data):
        dce_net_output = self.dce_model(data)
        return self.get_enhanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        enhanced_image = self.get_enhanced_image(data, output)
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_color_constancy = 5 * tf.reduce_mean(compute_color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(compute_exposure_loss(enhanced_image))
        total_loss = loss_illumination + loss_color_constancy + loss_exposure

        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)

        gradients = tape.gradient(losses["total_loss"], self.dce_model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(losses["illumination_smoothness_loss"])
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

    def test_step(self, data):
        output = self.dce_model(data)
        losses = self.compute_losses(data, output)

        self.total_loss_tracker.update_state(losses["total_loss"])
        self.illumination_smoothness_loss_tracker.update_state(losses["illumination_smoothness_loss"])
        self.color_constancy_loss_tracker.update_state(losses["color_constancy_loss"])
        self.exposure_loss_tracker.update_state(losses["exposure_loss"])

        return {metric.name: metric.result() for metric in self.metrics}

In [12]:
def plot(images, titles, figure_size=(10, 10)):
    # Ensure images and titles are lists
    if not isinstance(images, list):
        images = [images]
    if not isinstance(titles, list):
        titles = [titles]

    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        ax = fig.add_subplot(1, len(images), i + 1)
        ax.set_title(titles[i])
        if images[i].ndim == 2:  # Grayscale image
            plt.imshow(images[i], cmap='gray')
        else:  # RGB image
            plt.imshow(images[i])
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [13]:
def calculate_psnr(original_image, enhanced_image):
    # Convert the images to numpy arrays if they are not already.
    original_image = np.asarray(original_image)
    enhanced_image = np.asarray(enhanced_image)

    # Ensure the images have the same shape.
    if original_image.shape != enhanced_image.shape:
        raise ValueError("The shapes of the input images must be the same.")

    # Calculate the mean squared error (MSE) between the two images.
    mse = np.mean(np.square(original_image - enhanced_image))

    # If MSE is close to zero, return a high PSNR (infinity in theory).
    if mse == 0:
        return float('inf')

    # Calculate the peak signal-to-noise ratio (PSNR).
    max_pixel_value = 255.0  # Assuming pixel values range from 0 to 255.
    psnr = 10 * np.log10((max_pixel_value ** 2) / mse)

    return psnr


In [17]:
Image_Enhancer = ZeroDCE()
Image_Enhancer.compile(learning_rate = 1e-4)
Image_Enhancer.fit(train_dataset,validation_data=val_dataset,epochs=50)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.src.callbacks.History at 0x7b3a1ed55750>

In [18]:
def low_to_high_light(original_image, Image_Enhancer):
    # Convert the input image to a NumPy array and normalize it to the range [0, 1]
    image = tf.keras.utils.img_to_array(original_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)

    # Enhance the image using the provided Image_Enhancer model or function
    output_image = Image_Enhancer(image)

    # Convert the enhanced image back to uint8 and create a PIL Image
    output_image = tf.cast((output_image[0, :, :, :] * 255), dtype=np.uint8)
    output_image = Image.fromarray(output_image.numpy())

    return output_image


Displaying Results for 10 images

In [20]:
psnr_ratio = []
for i in range(len(test_low_light_image_paths)):
    # Load low-light image and enhance
    low_light_image = Image.open(test_low_light_image_paths[i])
    enhanced_image = low_to_high_light(low_light_image, Image_Enhancer)

    # Load corresponding high-light image
    high_light_image = Image.open(test_high_light_image_paths[i])

    # Calculate PSNR between high-light and enhanced images
    psnr = calculate_psnr(high_light_image, enhanced_image)
    psnr_ratio.append(psnr)

In [21]:
np.average(psnr_ratio)

28.00094429035725