In [None]:
import os
import cv2
import numpy as np
from glob import glob
import time
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm import tqdm
from joblib import Parallel, delayed
import tensorflow as tf
from tensorflow.keras import layers, Model

# Side length handed to data_generator as image_size by process_data.
IMAGE_SIZE = 256
# Number of images per batch handed to data_generator by process_data.
BATCH_SIZE = 32

def load_data(image_path, image_size):
    """Read one image, resize it to a square and divide its values by 255.

    Args:
        image_path: Path handed to ``cv2.imread``.
        image_size: Side length used for both dimensions of the resize.

    Returns:
        The resized image divided by ``255.0``, or ``None`` when
        ``cv2.imread`` yields ``None`` or any step raises.
    """
    try:
        raw = cv2.imread(image_path)
        if raw is not None:
            resized = cv2.resize(raw, (image_size, image_size))
            return resized / 255.0
    except Exception:
        # Best-effort loader: failures are reported as None, never raised.
        return None
    return None



def get_image_paths(folder_path):
    """Return the image files directly inside ``folder_path``.

    Entries are matched by extension (``.png``, ``.jpg``, ``.jpeg``),
    ignoring case, so files such as ``IMG_01.JPG`` are included.

    Args:
        folder_path: Directory to list (not searched recursively).

    Returns:
        A sorted list of ``os.path.join(folder_path, filename)`` strings.
        Sorting makes the order independent of the arbitrary order in
        which ``os.listdir`` reports entries, so chunking and batching are
        reproducible between runs.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. missing directory).
    """
    extensions = ('.png', '.jpg', '.jpeg')
    return sorted(
        os.path.join(folder_path, filename)
        for filename in os.listdir(folder_path)
        if filename.lower().endswith(extensions)
    )

def data_generator(image_paths, batch_size=32, image_size=128):
    """Load, randomly augment and batch a collection of images.

    Args:
        image_paths: Sequence of image file paths. A single folder path
            (``str`` or ``os.PathLike``) is also accepted and expanded with
            ``get_image_paths``; previously a string was iterated character
            by character, every "path" failed to load and the result was
            silently empty.
        batch_size: Maximum number of paths consumed per batch.
        image_size: Side length passed to ``load_data`` for resizing.

    Returns:
        A list of NumPy arrays, one per non-empty batch, each stacking the
        augmented images of that batch. Images that ``load_data`` reports
        as ``None`` are skipped, so a batch may hold fewer than
        ``batch_size`` images and batches with no loadable image are
        dropped.
    """
    if isinstance(image_paths, (str, os.PathLike)):
        # A folder was given instead of a list of files: expand it.
        image_paths = get_image_paths(image_paths)
    else:
        image_paths = list(image_paths)

    print('Loading data generator...')
    print('Number of images found:', len(image_paths))
    num_samples = len(image_paths)

    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

    batch_list = []
    for start_idx in tqdm(range(0, num_samples, batch_size)):
        batch_images = []
        for image_path in image_paths[start_idx:start_idx + batch_size]:
            image_data = load_data(image_path, image_size)
            if image_data is None:
                continue
            # Augment the single image directly instead of building a
            # one-element flow() iterator per image; cast to float32 so
            # the batches keep the dtype flow() used to produce.
            augmented_image = datagen.random_transform(
                image_data.astype('float32'))
            batch_images.append(augmented_image)

        if batch_images:
            batch_list.append(np.array(batch_images))
    return batch_list

def process_data(image_folder, num_processes, num_threads):
    """Augment every image in ``image_folder`` using parallel workers.

    The image paths are split into at most ``num_processes`` contiguous
    chunks and each chunk is handed to ``data_generator`` in its own
    joblib worker.

    Args:
        image_folder: Directory scanned with ``get_image_paths``.
        num_processes: Requested number of worker processes. The number
            actually used is capped at the number of images and is never
            below 1.
        num_threads: Accepted for interface compatibility; currently
            unused by this function.

    Returns:
        A flat list with the batches produced by all workers, in chunk
        order. An empty list is returned when the folder has no images.
    """
    image_paths = get_image_paths(image_folder)
    if not image_paths:
        return []

    # Floor division used to give chunk_size == 0 (ValueError from range)
    # when there were fewer images than processes, and one chunk more than
    # workers when the count was not evenly divisible.
    num_workers = max(1, min(num_processes, len(image_paths)))
    chunk_size = -(-len(image_paths) // num_workers)  # ceiling division
    chunks = [
        image_paths[i:i + chunk_size]
        for i in range(0, len(image_paths), chunk_size)
    ]

    results = Parallel(n_jobs=num_workers, backend="multiprocessing")(
        delayed(data_generator)(chunk, BATCH_SIZE, IMAGE_SIZE) for chunk in chunks
    )

    return [batch for result in results for batch in result]

def build_dce_net():
    """Build the convolutional network that predicts the curve maps.

    The network takes a 3-channel input of arbitrary height and width,
    applies six 32-filter 3x3 ReLU convolutions (the fifth, sixth and the
    final layer consume a concatenation with an earlier feature map) and
    ends with a 24-filter 3x3 ``tanh`` convolution.

    Returns:
        A Keras ``Model`` mapping the input image to the 24-channel output.
    """
    def conv3x3(tensor, filters=32, activation='relu'):
        # Stride-1, 'same'-padded 3x3 convolution shared by every stage.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding='same'
        )(tensor)

    def merge(first, second):
        # Channel-wise concatenation used for the skip connections.
        return layers.Concatenate(axis=-1)([first, second])

    inputs = layers.Input(shape=[None, None, 3])
    feat1 = conv3x3(inputs)
    feat2 = conv3x3(feat1)
    feat3 = conv3x3(feat2)
    feat4 = conv3x3(feat3)
    feat5 = conv3x3(merge(feat4, feat3))
    feat6 = conv3x3(merge(feat5, feat2))
    curve_maps = conv3x3(merge(feat6, feat1), filters=24, activation='tanh')
    return Model(inputs=inputs, outputs=curve_maps)

# Define custom loss functions
def color_constancy_loss(x):
    """Penalise differences between the spatial means of channels 0, 1, 2.

    Args:
        x: 4-D tensor; axes 1 and 2 are averaged and the last axis is
            indexed at 0, 1 and 2.

    Returns:
        ``sqrt(d01**2 + d02**2 + d12**2)`` where each ``d`` is the squared
        difference between two channel means.
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    first, second, third = (channel_means[:, :, :, idx] for idx in range(3))
    diff_01 = tf.square(first - second)
    diff_02 = tf.square(first - third)
    diff_12 = tf.square(third - second)
    squared_sum = tf.square(diff_01) + tf.square(diff_02) + tf.square(diff_12)
    return tf.sqrt(squared_sum)

def exposure_loss(x, mean_val=0.6):
    """Mean squared distance of 16x16 patch intensities from ``mean_val``.

    Args:
        x: 4-D tensor; axis 3 is averaged before pooling.
        mean_val: Target value for the pooled intensities.

    Returns:
        Scalar tensor with the mean of the squared deviations.
    """
    intensity = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(intensity, ksize=16, strides=16, padding='VALID')
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))

def illumination_smoothness_loss(x):
    """Total-variation style penalty on neighbouring entries of ``x``.

    Sums the squared differences between adjacent entries along axis 1 and
    along axis 2, normalises each sum by its count term and divides the
    result by the batch size (``tf.shape(x)[0]``).

    Args:
        x: 4-D tensor.

    Returns:
        Scalar float32 tensor.
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    h_x = shape[1]
    w_x = shape[2]
    # NOTE(review): the counts are built from axes 2 and 3 rather than
    # axes 1 and 2 -- kept as in the original; confirm this normalisation
    # is intended for channels-last tensors.
    count_h = (shape[2] - 1) * shape[3]
    count_w = shape[2] * (shape[3] - 1)
    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, :h_x - 1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, :w_x - 1, :]))
    # Fix: cast the batch size itself; the original cast count_h here, so
    # the loss was divided by count_h twice and never by the batch size.
    batch_size = tf.cast(batch_size, dtype=tf.float32)
    count_h = tf.cast(count_h, dtype=tf.float32)
    count_w = tf.cast(count_w, dtype=tf.float32)
    return 2 * (h_tv / count_h + w_tv / count_w) / batch_size

class SpatialConsistencyLoss(tf.keras.losses.Loss):
    """Compare neighbour differences of two images after 4x4 pooling.

    Both inputs are averaged over axis 3, average-pooled with a 4x4
    window, and convolved with four fixed kernels (left, right, up, down).
    The loss is the sum of the squared differences between the two sets of
    convolution responses.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        def as_kernel(values):
            # Fixed (non-trainable) float32 convolution kernel.
            return tf.constant(values, dtype=tf.float32)

        self.left_kernel = as_kernel([[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]])
        self.right_kernel = as_kernel([[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]])
        self.up_kernel = as_kernel([[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]])
        self.down_kernel = as_kernel([[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]])

    def call(self, y_true, y_pred):
        """Return the summed squared response differences for the inputs."""
        def pooled_intensity(image):
            intensity = tf.reduce_mean(image, 3, keepdims=True)
            return tf.nn.avg_pool2d(intensity, ksize=4, strides=4, padding="VALID")

        def response(pooled, kernel):
            return tf.nn.conv2d(pooled, kernel, strides=[1, 1, 1, 1], padding="SAME")

        original_pool = pooled_intensity(y_true)
        enhanced_pool = pooled_intensity(y_pred)

        total = None
        # Accumulate in left, right, up, down order.
        for kernel in (self.left_kernel, self.right_kernel,
                       self.up_kernel, self.down_kernel):
            term = tf.square(
                response(original_pool, kernel) - response(enhanced_pool, kernel)
            )
            total = term if total is None else total + term
        return total

class ZeroDCE(Model):
    """Zero-reference curve-estimation model with custom train/test steps.

    ``dce_model`` predicts curve-parameter maps; ``get_enhanced_image``
    applies them to the input. The colour, exposure and spatial losses are
    computed on the enhanced image, while the illumination smoothness loss
    is computed on the curve maps themselves.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dce_model = build_dce_net()
        self.spatial_constancy_loss = SpatialConsistencyLoss()

    def call(self, data):
        """Return the raw output of ``dce_model`` for ``data``.

        Use ``get_enhanced_image(data, self(data))`` to obtain the
        enhanced image.
        """
        return self.dce_model(data)

    def get_enhanced_image(self, data, output):
        """Apply the predicted curve maps to ``data`` iteratively.

        ``output`` is split along its last axis into groups of three
        channels; each group ``r`` updates the image as
        ``x = x + r * (x**2 - x)``.

        Args:
            data: Input image batch (cast to the dtype of ``output``).
            output: Curve maps produced by ``dce_model``; the size of the
                last axis must be statically known and divisible by 3.

        Returns:
            The enhanced image batch.
        """
        enhanced = tf.cast(data, output.dtype)
        num_iterations = output.shape[-1] // 3
        for curve in tf.split(output, num_iterations, axis=-1):
            enhanced = enhanced + curve * (tf.square(enhanced) - enhanced)
        return enhanced

    def compute_losses(self, data, output):
        """Compute the weighted loss terms for one batch.

        Args:
            data: Input image batch.
            output: Curve maps produced by ``dce_model`` for ``data``.

        Returns:
            Dict with ``total_loss`` and the four individual loss terms.
        """
        data = tf.cast(data, output.dtype)
        # Fix: the image-space losses must see the enhanced image, not the
        # raw curve maps that the network outputs.
        enhanced_image = self.get_enhanced_image(data, output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_constancy_loss(data, enhanced_image)
        )
        loss_color_constancy = 5 * tf.reduce_mean(color_constancy_loss(enhanced_image))
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enhanced_image))
        loss_illumination = 200 * illumination_smoothness_loss(output)
        total_loss = (
            loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
            + loss_illumination
        )
        return {
            "total_loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy_loss": loss_spatial_constancy,
            "color_constancy_loss": loss_color_constancy,
            "exposure_loss": loss_exposure,
        }

    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the losses."""
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)
        gradients = tape.gradient(losses["total_loss"], self.dce_model.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
        return losses

    def test_step(self, data):
        """Compute the losses on ``data`` without updating any weights."""
        output = self.dce_model(data)
        return self.compute_losses(data, output)

def train_model(model, train_batches, val_batches, epochs):
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch in train_batches:
            losses = model.train_step(batch)
            print("Training Loss:", losses["total_loss"].numpy())
        val_losses = []
        for val_batch in val_batches:
            val_loss = model.test_step(val_batch)
            val_losses.append(val_loss["total_loss"].numpy())
        print("Validation Loss:", np.mean(val_losses))


# Main part of the script: time the parallel preprocessing pass, then build
# the training/validation batches and time one epoch of training.
train_image_folder = 'first_zero_de_images'
# NOTE(review): validation currently reads the same folder as training.
val_image_folder = 'first_zero_de_images'
num_processes = int(input("Enter the number of processes: "))
num_threads = int(input("Enter the number of threads per process: "))

start_time = time.time()
process_data(train_image_folder, num_processes, num_threads)
process_data(val_image_folder, num_processes, num_threads)
end_time = time.time()
print("Execution for Data processing time:", end_time - start_time)
p1 = end_time - start_time

# Fix: data_generator expects a list of image paths, not the folder name.
# Passing the folder string made it iterate over single characters, so no
# image was loaded, no batch was produced and training ran on nothing.
train_data_generator = data_generator(
    get_image_paths(train_image_folder), BATCH_SIZE, IMAGE_SIZE
)
val_data_generator = data_generator(
    get_image_paths(val_image_folder), BATCH_SIZE, IMAGE_SIZE
)

zero_dce_model = ZeroDCE()
zero_dce_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

start_time = time.time()
train_model(zero_dce_model, train_data_generator, val_data_generator, epochs=1)
end_time = time.time()
m1 = end_time - start_time
print("Model Execution time:", end_time - start_time)

print("overall time of execution :",p1+m1)


Enter the number of processes: 8
Enter the number of threads per process: 2


  self.pid = os.fork()


Loading data generator...Loading data generator...Loading data generator...
Number of images found:
 
Number of images found: Loading data generator...25
Loading data generator...

  0%|          | 0/1 [00:00<?, ?it/s]

25Loading data generator...Loading data generator...Loading data generator...Number of images found:





 Number of images found: Number of images found:

  0%|          | 0/1 [00:00<?, ?it/s]

Number of images found:Number of images found:Number of images found:2525
    25252525





  0%|          | 0/1 [00:00<?, ?it/s]




100%|██████████| 1/1 [00:16<00:00, 16.66s/it]
100%|██████████| 1/1 [00:23<00:00, 23.63s/it]
100%|██████████| 1/1 [00:24<00:00, 24.86s/it]
100%|██████████| 1/1 [00:25<00:00, 25.13s/it]
100%|██████████| 1/1 [00:25<00:00, 25.37s/it]
100%|██████████| 1/1 [00:25<00:00, 25.92s/it]
100%|██████████| 1/1 [00:26<00:00, 26.43s/it]
100%|██████████| 1/1 [00:26<00:00, 26.47s/it]


Loading data generator...Loading data generator...Loading data generator...Loading data generator...

Number of images found:Loading data generator...Loading data generator...Loading data generator...Number of images found:
  Loading data generator...



Number of images found:Number of images found:Number of images found:Number of images found:  Number of images found:25
2525Number of images found: 
 2525


  0%|          | 0/1 [00:00<?, ?it/s]

2525  




  0%|          | 0/1 [00:00<?, ?it/s]

25

  0%|          | 0/1 [00:00<?, ?it/s]






100%|██████████| 1/1 [00:16<00:00, 16.42s/it]
100%|██████████| 1/1 [00:24<00:00, 24.20s/it]
100%|██████████| 1/1 [00:24<00:00, 24.64s/it]
100%|██████████| 1/1 [00:24<00:00, 24.74s/it]
100%|██████████| 1/1 [00:25<00:00, 25.14s/it]
100%|██████████| 1/1 [00:25<00:00, 25.18s/it]
100%|██████████| 1/1 [00:25<00:00, 25.49s/it]
100%|██████████| 1/1 [00:25<00:00, 25.60s/it]


Execution for Data processing time: 54.06022810935974
Loading data generator...
Number of images found: 53


100%|██████████| 2/2 [00:00<00:00, 1507.39it/s]


Loading data generator...
Number of images found: 53


100%|██████████| 2/2 [00:00<00:00, 2038.05it/s]


Epoch 1/1
Validation Loss: nan
Model Execution time: 0.008009195327758789
overall time of execution : 54.0682373046875


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
