## Imports

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt # plotting
import matplotlib.image as mpimg # images
import numpy as np #numpy
import time
import seaborn as sns
import tensorflow as tf
# import tensorflow.compat.v2 as tf #use tensorflow v2 as a main
import tensorflow.keras as keras # required for high level applications
from sklearn.model_selection import train_test_split # split for validation sets
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.preprocessing import normalize # normalization of the matrix
import scipy
import pandas as pd
from tensorflow.keras.models import load_model
from tensorflow.keras import layers
from tensorflow.keras.utils import Sequence
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Conv2DTranspose, Concatenate, Flatten, Dense
 

import os
import cv2
from tensorflow.keras.utils import normalize
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint

import segmentation_models as sm

tf.version.VERSION

## Metrics

In [None]:
def ms_ssim_metric(y_true, y_pred, max_val=1.0):
    """Multi-scale SSIM between reference and predicted images.

    Args:
        y_true: reference image(s), forwarded to ``tf.image.ssim_multiscale``.
        y_pred: predicted image(s), same shape as ``y_true``.
        max_val: dynamic range of the pixel values (max allowed minus min
            allowed value).  Defaults to 1.0 because the callers in this
            notebook rescale images to [0, 1] with ``normalize_images`` before
            computing the metric.  Pass ``None`` to reproduce the previous
            behaviour of using the largest value found in ``y_true``.

    Returns:
        The tensor returned by ``tf.image.ssim_multiscale``.
    """
    if max_val is None:
        # Legacy behaviour: data-dependent range.  This makes scores depend on
        # the brightest pixel of each reference and collapses to 0 for an
        # all-black reference, so it is no longer the default.
        max_val = tf.reduce_max(y_true)
    return tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val)

In [None]:
from tensorflow.keras import backend as K
def psnr(y_true, y_pred):
    """Peak signal-to-noise ratio in dB, assuming a peak pixel value of 1.0.

    The mean squared error is taken over every element of the inputs, so a
    single scalar is returned regardless of how many images are passed.
    """
    peak = 1.0
    squared_error = K.square(y_true - y_pred)
    mean_squared_error = K.mean(squared_error)
    # log10(x) expressed through natural logarithms: ln(x) / ln(10).
    ratio = (peak**2) / mean_squared_error
    return 10.0 * K.log(ratio) / K.log(10.0)

## Utility functions

In [None]:
# Learning-rate scheduler: when 'val_loss' has not improved for 5 epochs the
# learning rate is multiplied by 0.5, never going below 1e-7; verbose=1 logs
# each reduction.
# NOTE(review): no use of this callback is visible in this part of the
# notebook (the training loops here are hand-written) - confirm it is needed.
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1
)

class ShowReconstruction(keras.callbacks.Callback):
    """Keras callback that plots one random reconstruction after every epoch.

    Reads the module-level ``source_images`` / ``target_images`` arrays, runs
    the attached model on one randomly chosen source image and shows the
    source, the model output and the real target side by side.
    """

    @staticmethod
    def _to_display_range(image):
        """Map an image from [-1, 1] to [0, 1] for ``plt.imshow``.

        The loader stores images in [-1, 1]; imshow clips float RGB data
        outside [0, 1], which would turn every negative pixel black.
        """
        return np.clip((np.asarray(image) + 1.0) / 2.0, 0.0, 1.0)

    def on_epoch_end(self, epoch, logs=None):
        """Plot source / produced / real images for one random sample."""
        rand_id = np.random.randint(len(source_images))
        # The model expects a batch axis, so add one for the single image.
        source_image = source_images[rand_id][np.newaxis, ...]
        reconstructed = self.model.predict(source_image)
        real_image = target_images[rand_id]

        # NOTE(review): assumes the model output is also in [-1, 1] (the
        # generator in this notebook ends in tanh) - confirm for other models.
        panels = (
            ("Source Image", source_image[0]),
            ("Produced Image", reconstructed[0]),
            ("Real Image", real_image),
        )
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(self._to_display_range(image), cmap='gray')
            plt.title(title)
            plt.axis('off')

        plt.tight_layout()
        #plt.savefig("Reconstruction_Epoch_{}".format(epoch))
        plt.show()
            
def normalize_images(images):
    """Rescale values from [-1, 1] to [0, 1] (the range the metrics are computed on)."""
    shifted = images + 1.0  # [-1, 1] -> [0, 2]
    return shifted / 2.0    # [0, 2]  -> [0, 1]


In [None]:
def evaluate_on_test_set(generator, X_val, y_val, batch_size=4):
    """Average PSNR and MS-SSIM of ``generator`` over a held-out set.

    Images are predicted batch by batch, rescaled from [-1, 1] to [0, 1] with
    ``normalize_images`` and then scored one image at a time.

    Args:
        generator: model exposing ``predict``.
        X_val: input images.
        y_val: reference images aligned with ``X_val``.
        batch_size: number of images sent to ``predict`` at once.

    Returns:
        ``(avg_psnr, avg_ssim)``: the means of the per-image scores.
    """
    num_samples = len(X_val)
    # Ceiling division: a trailing partial batch still counts as one step.
    steps = -(-num_samples // batch_size)

    psnr_values, ssim_values = [], []
    for step in range(steps):
        lo = step * batch_size
        hi = min(lo + batch_size, num_samples)

        # Metrics are computed on [0, 1] images, the data lives in [-1, 1].
        restored = normalize_images(generator.predict(X_val[lo:hi]))
        reference = normalize_images(y_val[lo:hi])

        for idx in range(hi - lo):
            psnr_values.append(psnr(reference[idx], restored[idx]))
            ssim_values.append(ms_ssim_metric(reference[idx], restored[idx]))

    return np.mean(psnr_values), np.mean(ssim_values)

# Load the images

In [None]:
source_folder = 'Dataset/source'
target_folder = 'Dataset/target'

# Upper bound on the number of image pairs that are loaded.
MAX_IMAGES = 3100
# Every image is resized to this (width, height).
IMAGE_SIZE = (512, 512)


def load_image_as_model_input(path, size=IMAGE_SIZE):
    """Read ``path`` as grayscale and return a float32 3-channel image in [-1, 1].

    Args:
        path: image file to read.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # uint8 [0, 255] -> float32 [-1, 1]
    image = (image.astype(np.float32) - 127.5) / 127.5
    # Replicate the single gray channel three times.
    image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    return cv2.resize(image, size)


# os.listdir returns names in arbitrary order; sorting keeps the positional
# train/validation split that is taken from these arrays reproducible.
source_image_files = sorted(os.listdir(source_folder))

# Source and target share file names, so one listing drives both loads.
source_images = []
target_images = []
for filename in source_image_files[:MAX_IMAGES]:
    source_images.append(load_image_as_model_input(os.path.join(source_folder, filename)))
    target_images.append(load_image_as_model_input(os.path.join(target_folder, filename)))

# Convert the lists to NumPy arrays
source_images = np.array(source_images)
target_images = np.array(target_images)

# Print the shape of the loaded and preprocessed images
print("Source Images Shape:", source_images.shape)
print("Target Images Shape:", target_images.shape)

In [None]:
# Hold out the first 100 pairs for validation/testing; everything after
# them is used for training.
X_val, X_train = source_images[:100], source_images[100:]
y_val, y_train = target_images[:100], target_images[100:]

print(f"x_train: {len(X_train)} | x_val: {len(X_val)} | y_train: {len(y_train)} | y_val: {len(y_val)}")

In [None]:
# Kaggle dataset test
test_folder = 'Dataset/test/'

# Keep only PNG files (the extension check is case-insensitive).
source_image_files = [
    filename for filename in os.listdir(test_folder) if filename.lower().endswith('.png')
]

# Only the first five listed images are loaded.
# NOTE(review): unlike the training data these images are single-channel and
# scaled to [0, 1] rather than [-1, 1] - confirm this is intended.
test_images = []
for filename in source_image_files[:5]:
    test_image = cv2.imread(os.path.join(test_folder, filename), cv2.IMREAD_GRAYSCALE)
    test_image = cv2.resize(test_image, (1024, 1024))
    #test_image = cv2.bitwise_not(test_image)
    test_images.append(test_image.astype('float32') / 255.0)

# Stack into one array and append a trailing channel axis.
test_images = np.expand_dims(np.array(test_images), axis=-1)
print("Test Images Shape:", test_images.shape)

In [None]:
# Provided images from Olomouc
test_folder = 'Test/'

# Keep only PNG files; sorted so the image order is the same on every run
# (os.listdir returns names in arbitrary order).
source_image_files = sorted(
    filename for filename in os.listdir(test_folder) if filename.lower().endswith('.png')
)

test_images = []

for filename in source_image_files:
    image_path = os.path.join(test_folder, filename)
    test_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if test_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Same preprocessing as the training data: [-1, 1] floats, 3 channels.
    test_image = (test_image.astype(np.float32) - 127.5) / 127.5
    test_image = np.repeat(test_image[:, :, np.newaxis], 3, axis=2)
    test_image = cv2.resize(test_image, (1024, 1024))
    #test_image = cv2.bitwise_not(test_image)
    test_images.append(test_image)

# The images already carry a channel axis, so the stacked array is
# (N, 1024, 1024, 3).  No extra axis is appended: that would produce a 5-D
# array that a model taking (H, W, 3) images cannot consume.
# NOTE(review): the generators in this notebook are built for 512x512 inputs
# while these images are 1024x1024 - confirm the model accepts this size.
test_images = np.array(test_images)
print("Test Images Shape:", test_images.shape)

In [None]:
# Preview the first five source images (top row) above their targets (bottom row).
plt.rcParams["figure.figsize"] = (20,10)
fig, axs = plt.subplots(2, 5)
rows = (('Source image', source_images), ('Target image', target_images))
for i, (title, images) in enumerate(rows):
    for j in range(5):
        # The loader stores images in [-1, 1]; imshow clips float RGB data
        # outside [0, 1], so rescale before displaying.
        displayable = np.clip((images[j] + 1.0) / 2.0, 0.0, 1.0)
        axs[i, j].imshow(displayable, cmap='gray')
        axs[i, j].set_title(title)
        axs[i, j].axis('off')

plt.show()

# Build GAN blocks

In [None]:
def conv2d_block(input_tensor, n_filters, kernel_size=3, batchnorm=True):
    """Apply two identical Conv2D -> (BatchNormalization) -> ReLU stages.

    Args:
        input_tensor: tensor fed to the first convolution.
        n_filters: number of filters of both convolutions.
        kernel_size: side length of the square kernel.
        batchnorm: insert BatchNormalization after each convolution when True.

    Returns:
        Output tensor of the second ReLU.
    """
    x = input_tensor
    for _ in range(2):
        x = Conv2D(
            filters=n_filters,
            kernel_size=(kernel_size, kernel_size),
            kernel_initializer='he_normal',
            padding='same',
        )(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

In [None]:
def generator_WGAN(input_shape, n_filters=64, dropout=0.5, batchnorm=True):
    """Build the U-Net shaped generator.

    Five encoder stages (conv block + 2x2 max pooling) with 1x..16x
    ``n_filters``, a 32x bottleneck, and five decoder stages that upsample
    with a stride-2 transposed convolution, concatenate the matching encoder
    output and apply another conv block.  A final 1x1 convolution with tanh
    produces a 3-channel image.

    Args:
        input_shape: shape of one input image, passed to ``Input``.
        n_filters: base number of filters.
        dropout: accepted but not used anywhere in the body.
        batchnorm: forwarded to every ``conv2d_block``.

    Returns:
        The assembled Keras ``Model``.
    """
    inputs = Input(input_shape)

    # Encoder: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for multiplier in (1, 2, 4, 8, 16):
        features = conv2d_block(x, n_filters=n_filters*multiplier, kernel_size=3, batchnorm=batchnorm)
        skips.append(features)
        x = MaxPooling2D((2, 2))(features)

    # Bottleneck
    x = conv2d_block(x, n_filters=n_filters*32, kernel_size=3, batchnorm=batchnorm)

    # Decoder: deepest skip connection first.
    for multiplier, skip in zip((16, 8, 4, 2, 1), reversed(skips)):
        upsampled = Conv2DTranspose(n_filters*multiplier, (3, 3), strides=(2, 2), padding='same')(x)
        merged = Concatenate()([upsampled, skip])
        x = conv2d_block(merged, n_filters=n_filters*multiplier, kernel_size=3, batchnorm=batchnorm)

    # tanh keeps the output in [-1, 1].
    outputs = Conv2D(3, (1, 1), activation='tanh')(x)

    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
def discriminator_WGAN(input_shape, n_filters=64):
    """Build the critic: five stride-2 convolutions followed by one linear unit.

    Every stage is Conv2D(kernel 4, stride 2) -> LeakyReLU(0.2); all stages
    except the first insert BatchNormalization before the activation.  The
    filter count goes 1x, 2x, 4x, 8x, 16x ``n_filters``.  The final Dense
    layer has no activation, so the output is an unbounded score.

    Args:
        input_shape: shape of one input image, passed to ``Input``.
        n_filters: base number of filters.

    Returns:
        The assembled Keras ``Model``.
    """
    inputs = Input(input_shape)

    x = inputs
    for stage, multiplier in enumerate((1, 2, 4, 8, 16)):
        x = Conv2D(n_filters*multiplier, kernel_size=4, strides=2, padding='same')(x)
        if stage > 0:
            # The first stage deliberately has no batch normalisation.
            x = BatchNormalization()(x)
        x = layers.LeakyReLU(alpha=0.2)(x)

    flat = Flatten()(x)
    outputs = Dense(1)(flat)

    return Model(inputs=[inputs], outputs=[outputs])

#input_shape = (1024, 1024, 3)
#discriminator_WGAN = discriminator_WGAN(input_shape)
#discriminator_WGAN.summary()

## Wasserstein | 400train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Critic loss: mean score on generated images minus mean score on real images."""
    fake_mean = tf.reduce_mean(fake_output)
    real_mean = tf.reduce_mean(real_output)
    return fake_mean - real_mean

def generator_loss(fake_output):
    """Generator loss: negated mean critic score on generated images."""
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score

# Fresh networks and optimizers for this experiment.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Run one simultaneous generator / critic update on a batch.

    Args:
        noisy_img_batch: generator inputs.
        clean_img_batch: matching real images shown to the critic.

    Returns:
        ``(gen_loss, disc_loss)`` computed before the weights were updated.
    """
    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        denoised_images = generator(noisy_img_batch, training=True)
        real_output = discriminator(clean_img_batch, training=True)
        fake_output = discriminator(denoised_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Clip every trainable critic weight to [-0.01, 0.01] after the update.
    for weight in disc_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
# Build the training pipeline from the training split only.  X_val / y_val
# are the first 100 entries of source_images / target_images, so feeding the
# full arrays here would train on the very images used for validation.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
# Pair each noisy image with its clean counterpart.
dataset = tf.data.Dataset.zip((noisy_dataset, clean_dataset))
BATCH_SIZE = 8
dataset = dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE)

In [None]:
def train(dataset, epochs, X_val, y_val):
    """Train the GAN for ``epochs`` passes over ``dataset``, printing metrics per epoch.

    Relies on the module-level ``generator``, ``train_step``, ``psnr``,
    ``ms_ssim_metric``, ``normalize_images`` and ``evaluate_on_test_set``.

    Args:
        dataset: iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: number of passes over ``dataset``.
        X_val: held-out noisy images, forwarded to ``evaluate_on_test_set``.
        y_val: held-out clean images, forwarded to ``evaluate_on_test_set``.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode, after this step's update.
            denoised_images = generator(noisy_img_batch, training=False)

            # Metrics are computed on [0, 1] images; the data is in [-1, 1].
            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): reference first.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        # Includes the validation pass above.
        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        if (epoch + 1) % 10 == 0:
            # NOTE(review): display_images is not defined in the visible part
            # of this notebook - confirm it exists before this cell runs,
            # otherwise every 10th epoch raises NameError.
            display_images(noisy_img_batch, epoch + 1)


EPOCHS = 100
train(dataset, EPOCHS, X_val, y_val)

In [None]:
# Persist the trained generator to a single .h5 file.
generator_checkpoint = 'WGAN_2500epochs_500data.h5'
generator.save(generator_checkpoint)

In [None]:
# Restore the saved generator.  compile=False: the model is trained with a
# hand-written loop, so there is no training configuration to restore.
_restored_generator = load_model('WGAN_2500epochs_500data.h5', compile=False)
try:
    # Copy the weights into the existing object rather than rebinding the
    # name: `train_step` is a tf.function, and once traced it keeps using the
    # generator object it saw at trace time.  Rebinding `generator` would
    # leave the traced step training the old model while the eager metric
    # code in `train` evaluates the new, never-updated one.
    generator.set_weights(_restored_generator.get_weights())
except NameError:
    # No generator has been built in this kernel yet: use the loaded model.
    generator = _restored_generator

In [None]:
def train(dataset, epochs, X_val, y_val):
    """Train the GAN for ``epochs`` passes over ``dataset``, printing metrics per epoch.

    Relies on the module-level ``generator``, ``train_step``, ``psnr``,
    ``ms_ssim_metric``, ``normalize_images`` and ``evaluate_on_test_set``.

    Args:
        dataset: iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: number of passes over ``dataset``.
        X_val: held-out noisy images, forwarded to ``evaluate_on_test_set``.
        y_val: held-out clean images, forwarded to ``evaluate_on_test_set``.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode, after this step's update.
            denoised_images = generator(noisy_img_batch, training=False)

            # Metrics are computed on [0, 1] images; the data is in [-1, 1].
            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): reference first.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        # Includes the validation pass above.
        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        if (epoch + 1) % 10 == 0:
            # NOTE(review): display_images is not defined in the visible part
            # of this notebook - confirm it exists before this cell runs,
            # otherwise every 10th epoch raises NameError.
            display_images(noisy_img_batch, epoch + 1)


EPOCHS = 500
train(dataset, EPOCHS, X_val, y_val)

## Wasserstein + L1_Loss | 400train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Critic loss: mean score on generated images minus mean score on real images."""
    fake_mean = tf.reduce_mean(fake_output)
    real_mean = tf.reduce_mean(real_output)
    return fake_mean - real_mean

def generator_loss(fake_output):
    """Generator loss: negated mean critic score on generated images."""
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score

def generator_loss_W_L1(fake_output, denoised_images, target_images, sigma=10000, alpha = 1):
    """Wasserstein generator loss plus a weighted L1 reconstruction term.

    Args:
        fake_output: critic scores for the generated images.
        denoised_images: generator output.
        target_images: reference images.
        sigma: overall weight of the reconstruction part.
        alpha: weight of the L1 term inside the reconstruction part.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * mean(|target - denoised|))``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    absolute_error = tf.abs(target_images - denoised_images)
    pixel_term = tf.reduce_mean(absolute_error)
    return adversarial_term + sigma * (alpha * pixel_term)

# Fresh networks and optimizers for this experiment.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Run one simultaneous generator / critic update (Wasserstein + L1 generator loss).

    Args:
        noisy_img_batch: generator inputs.
        clean_img_batch: matching real images; also the L1 target.

    Returns:
        ``(gen_loss, disc_loss)`` computed before the weights were updated.
    """
    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        denoised_images = generator(noisy_img_batch, training=True)
        real_output = discriminator(clean_img_batch, training=True)
        fake_output = discriminator(denoised_images, training=True)

        gen_loss = generator_loss_W_L1(fake_output, denoised_images, clean_img_batch)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Clip every trainable critic weight to [-0.01, 0.01] after the update.
    for weight in disc_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
# Build the training pipeline from the training split only.  X_val / y_val
# are the first 100 entries of source_images / target_images, so feeding the
# full arrays here would train on the very images used for validation.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
# Pair each noisy image with its clean counterpart.
dataset = tf.data.Dataset.zip((noisy_dataset, clean_dataset))
BATCH_SIZE = 8
dataset = dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE)

In [None]:
def train(dataset, epochs, X_val, y_val):
    """Train the GAN for ``epochs`` passes over ``dataset``, printing metrics per epoch.

    Relies on the module-level ``generator``, ``train_step``, ``psnr``,
    ``ms_ssim_metric``, ``normalize_images`` and ``evaluate_on_test_set``.

    Args:
        dataset: iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: number of passes over ``dataset``.
        X_val: held-out noisy images, forwarded to ``evaluate_on_test_set``.
        y_val: held-out clean images, forwarded to ``evaluate_on_test_set``.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode, after this step's update.
            denoised_images = generator(noisy_img_batch, training=False)

            # Metrics are computed on [0, 1] images; the data is in [-1, 1].
            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): reference first.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        # Includes the validation pass above.
        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        if (epoch + 1) % 10 == 0:
            # NOTE(review): display_images is not defined in the visible part
            # of this notebook - confirm it exists before this cell runs,
            # otherwise every 10th epoch raises NameError.
            display_images(noisy_img_batch, epoch + 1)


EPOCHS = 500
train(dataset, EPOCHS, X_val, y_val)

In [None]:
#Epoch 492, Generator Loss: 95.9011459350586, Discriminator Loss: -0.36895322799682617, PSNR: 40.37303924560547, SSIM: 0.9956716895103455
# Persist both networks of the Wasserstein + L1 run.
for trained_model, save_path in (
    (generator, 'WGAN_L1_500epoch_500dataset_gen'),
    (discriminator, 'WGAN_L1_500epoch_500dataset_disc'),
):
    trained_model.save(save_path)

## Wasserstein + L2_Loss | 400train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Critic loss: mean score on generated images minus mean score on real images."""
    fake_mean = tf.reduce_mean(fake_output)
    real_mean = tf.reduce_mean(real_output)
    return fake_mean - real_mean

def generator_loss(fake_output):
    """Generator loss: negated mean critic score on generated images."""
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score

def generator_loss_W_L2(fake_output, denoised_images, target_images, sigma=10000, alpha = 1):
    """Wasserstein generator loss plus a weighted L2 (mean squared error) term.

    Args:
        fake_output: critic scores for the generated images.
        denoised_images: generator output.
        target_images: reference images.
        sigma: overall weight of the reconstruction part.
        alpha: weight of the L2 term inside the reconstruction part.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * mean((target - denoised)**2))``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    squared_error = tf.square(target_images - denoised_images)
    pixel_term = tf.reduce_mean(squared_error)
    return adversarial_term + sigma * (alpha * pixel_term)

# Fresh networks and optimizers for this experiment.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Run one simultaneous generator / critic update (Wasserstein + L2 generator loss).

    Args:
        noisy_img_batch: generator inputs.
        clean_img_batch: matching real images; also the L2 target.

    Returns:
        ``(gen_loss, disc_loss)`` computed before the weights were updated.
    """
    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        denoised_images = generator(noisy_img_batch, training=True)
        real_output = discriminator(clean_img_batch, training=True)
        fake_output = discriminator(denoised_images, training=True)

        gen_loss = generator_loss_W_L2(fake_output, denoised_images, clean_img_batch)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Clip every trainable critic weight to [-0.01, 0.01] after the update.
    for weight in disc_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
# Build the training pipeline from the training split only.  X_val / y_val
# are the first 100 entries of source_images / target_images, so feeding the
# full arrays here would train on the very images used for validation.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
# Pair each noisy image with its clean counterpart.
dataset = tf.data.Dataset.zip((noisy_dataset, clean_dataset))
BATCH_SIZE = 8
dataset = dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE)

In [None]:
def train(dataset, epochs):
    """Train the GAN for ``epochs`` passes over ``dataset``, printing training metrics per epoch.

    This variant performs no validation pass.  Relies on the module-level
    ``generator``, ``train_step``, ``psnr``, ``ms_ssim_metric`` and
    ``normalize_images``.

    Args:
        dataset: iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode, after this step's update.
            denoised_images = generator(noisy_img_batch, training=False)

            # Metrics are computed on [0, 1] images; the data is in [-1, 1].
            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): reference first.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        avg_psnr = total_psnr / num_batches
        avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {avg_psnr}, '
              f'SSIM: {avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # Shown after every epoch (the former `% 1 == 0` test was always true).
        # NOTE(review): display_images is not defined in the visible part of
        # this notebook - confirm it exists before this cell runs, otherwise
        # the first epoch ends with NameError.
        display_images(noisy_img_batch, epoch + 1)


EPOCHS = 500
train(dataset, EPOCHS)

In [None]:
# Persist both networks of the Wasserstein + L2 run as .h5 files.
for trained_model, save_path in (
    (generator, 'WGAN_L2_500epoch_500dataset_gen.h5'),
    (discriminator, 'WGAN_L2_500epoch_500dataset_disc.h5'),
):
    trained_model.save(save_path)

In [None]:
#generator = load_model('WGAN_L2_500epoch_500dataset_gen.h5')
# Score the current generator on the held-out pairs.
held_out_scores = evaluate_on_test_set(generator, X_val, y_val)
avg_psnr, avg_ssim = held_out_scores
print(f"psnr: {avg_psnr}, ssim {avg_ssim}")

In [None]:
generator = load_model('WGAN_L2_500epoch_500dataset_gen.h5')
generator.compile()

predicted_images = generator.predict(test_images)

# Pick up to six distinct test images at random.  Capping at the number of
# available images keeps the cell working for small test folders.
num_samples_to_visualize = min(6, len(test_images))
random_indices = np.random.choice(len(test_images), num_samples_to_visualize, replace=False)

# Iterate over the sampled indices themselves; the previous fixed range(1, 6)
# ignored the random selection and never showed image 0.
for idx in random_indices:
    # NOTE(review): normalize_image (singular) is not defined in the visible
    # part of this notebook (only normalize_images is) - confirm it exists.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L2_Loss | 900train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Critic loss: mean score on generated images minus mean score on real images."""
    fake_mean = tf.reduce_mean(fake_output)
    real_mean = tf.reduce_mean(real_output)
    return fake_mean - real_mean

def generator_loss(fake_output):
    """Generator loss: negated mean critic score on generated images."""
    mean_fake_score = tf.reduce_mean(fake_output)
    return -mean_fake_score

def generator_loss_W_L2(fake_output, denoised_images, target_images, sigma=10000, alpha = 1):
    """Wasserstein generator loss plus a weighted L2 (mean squared error) term.

    Args:
        fake_output: critic scores for the generated images.
        denoised_images: generator output.
        target_images: reference images.
        sigma: overall weight of the reconstruction part.
        alpha: weight of the L2 term inside the reconstruction part.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * mean((target - denoised)**2))``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    squared_error = tf.square(target_images - denoised_images)
    pixel_term = tf.reduce_mean(squared_error)
    return adversarial_term + sigma * (alpha * pixel_term)

# Fresh networks and optimizers for this experiment.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Run one simultaneous generator / critic update (Wasserstein + L2 generator loss).

    Args:
        noisy_img_batch: generator inputs.
        clean_img_batch: matching real images; also the L2 target.

    Returns:
        ``(gen_loss, disc_loss)`` computed before the weights were updated.
    """
    # One tape per network so each loss is differentiated independently.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        denoised_images = generator(noisy_img_batch, training=True)
        real_output = discriminator(clean_img_batch, training=True)
        fake_output = discriminator(denoised_images, training=True)

        gen_loss = generator_loss_W_L2(fake_output, denoised_images, clean_img_batch)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Clip every trainable critic weight to [-0.01, 0.01] after the update.
    for weight in disc_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
# Training pipeline over the training split: pair each noisy image with its
# clean counterpart, shuffle with a 1000-element buffer, then batch.
BATCH_SIZE = 4
noisy_dataset, clean_dataset = (
    tf.data.Dataset.from_tensor_slices(images) for images in (X_train, y_train)
)
paired_dataset = tf.data.Dataset.zip((noisy_dataset, clean_dataset))
dataset = paired_dataset.shuffle(buffer_size=1000).batch(BATCH_SIZE)

In [None]:
def train(dataset, epochs, X_val, y_val):
    """Train the GAN for ``epochs`` passes over ``dataset``, printing metrics per epoch.

    Relies on the module-level ``generator``, ``train_step``, ``psnr``,
    ``ms_ssim_metric``, ``normalize_images`` and ``evaluate_on_test_set``.

    Args:
        dataset: iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: number of passes over ``dataset``.
        X_val: held-out noisy images, forwarded to ``evaluate_on_test_set``.
        y_val: held-out clean images, forwarded to ``evaluate_on_test_set``.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode, after this step's update.
            denoised_images = generator(noisy_img_batch, training=False)

            # Metrics are computed on [0, 1] images; the data is in [-1, 1].
            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): reference first.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        # Includes the validation pass above.
        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        if (epoch + 1) % 10 == 0:
            # NOTE(review): display_images is not defined in the visible part
            # of this notebook - confirm it exists before this cell runs,
            # otherwise every 10th epoch raises NameError.
            display_images(noisy_img_batch, epoch + 1)


EPOCHS = 500
train(dataset, EPOCHS, X_val, y_val)

In [None]:
# Persist both networks of the Wasserstein + L2 run as .h5 files.
for trained_model, save_path in (
    (generator, 'WGAN_L2_500epoch_1000dataset_gen.h5'),
    (discriminator, 'WGAN_L2_500epoch_1000dataset_disc.h5'),
):
    trained_model.save(save_path)

In [None]:
# Drop any previous predictions before computing new ones.
predicted_images = None
predicted_images = generator.predict(test_images)

# Pick up to six distinct test images at random (no duplicates, and never
# more than there are images).
num_samples_to_visualize = min(6, len(test_images))
random_indices = np.random.choice(len(test_images), num_samples_to_visualize, replace=False)

# Index with the sampled idx; the previous loop used the enumerate counter,
# so it always displayed the first images regardless of the random draw.
for idx in random_indices:
    # NOTE(review): normalize_image (singular) is not defined in the visible
    # part of this notebook (only normalize_images is) - confirm it exists.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L2_Loss + Perceptual | 1900train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
perceptual_layers = ['block5_conv4']
vgg_model = tf.keras.Model([vgg.input], [vgg.get_layer(layer).output for layer in perceptual_layers])

def perceptual_loss(generated, target):
    """Mean squared difference between the VGG features of two image batches.

    NOTE(review): both batches are fed to ``vgg_model`` as-is. The Keras docs
    state that VGG19's pretrained weights expect inputs passed through
    ``tf.keras.applications.vgg19.preprocess_input`` - confirm that skipping
    it here is intentional.
    """
    features_of_generated = vgg_model(generated)
    features_of_target = vgg_model(target)
    feature_gap = features_of_target - features_of_generated
    return tf.reduce_mean(tf.square(feature_gap))

def generator_loss_W_L2_Perceptual(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1):
    """Generator objective: Wasserstein term plus weighted L2 and perceptual terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L2 term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L2 + beta * perceptual)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.square(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    content_term = alpha * pixel_term + beta * feature_term
    return adversarial_term + sigma * content_term

# Instantiate the generator and discriminator for 512x512 three-channel inputs.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

# One independent Adam instance per network, both with the same settings.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

#history_buffer = HistoryBuffer(max_size=8)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one critic update, then clip critic weights.

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Matching batch of clean reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` computed before the weight update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(restored, training=True)

        gen_loss = generator_loss_W_L2_Perceptual(score_fake, restored, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_vars = generator.trainable_variables
    critic_vars = discriminator.trainable_variables

    # Both gradients are taken before either optimizer changes any weight.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    critic_grads = disc_tape.gradient(disc_loss, critic_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Weight clipping: force every trainable critic variable into [-0.01, 0.01].
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
BATCH_SIZE = 8

# Pair every noisy input with its clean target, then shuffle and batch the pairs.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
dataset = (
    tf.data.Dataset.zip((noisy_dataset, clean_dataset))
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
)

In [None]:
def train(dataset, epochs, X_val, y_val, display_every=1):
    """Train the WGAN for ``epochs`` passes over ``dataset`` and log metrics.

    Each batch is passed to ``train_step``; the generator is then run in
    inference mode on the same batch to accumulate training PSNR / MS-SSIM.
    After every epoch the generator is scored with ``evaluate_on_test_set``
    and a one-line summary is printed.

    Args:
        dataset: Iterable of ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs, forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets, forwarded to ``evaluate_on_test_set``.
        display_every: Call ``display_images`` every this many epochs; a
            falsy value disables the preview.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Re-run the generator after the weight update, in inference mode.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Reference first, prediction second: ms_ssim_metric takes its
            # max_val from the first argument, which should be the clean image.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # The preview uses the last batch seen in this epoch.
        if display_every and (epoch + 1) % display_every == 0:
            display_images(noisy_img_batch, epoch + 1)


# Epoch budget for this run, then the run itself.
EPOCHS = 1
train(dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both networks; the relative .h5 paths resolve against the current working directory.
generator.save('WGAN_L2Per_500epoch_2000dataset_gen.h5')
discriminator.save('WGAN_L2Per_500epoch_2000dataset_disc.h5')

In [None]:
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L2_Loss + Perceptual + Sobel | 1900train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

# Frozen VGG19 without its classification head, accepting any spatial size;
# it serves as the fixed feature extractor called by perceptual_loss.
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
# Names of the VGG layers whose activations are exposed by vgg_model.
perceptual_layers = ['block5_conv4']
vgg_model = tf.keras.Model([vgg.input], [vgg.get_layer(layer).output for layer in perceptual_layers])

def perceptual_loss(generated, target):
    """Mean squared difference between the VGG features of two image batches.

    NOTE(review): both batches are fed to ``vgg_model`` as-is. The Keras docs
    state that VGG19's pretrained weights expect inputs passed through
    ``tf.keras.applications.vgg19.preprocess_input`` - confirm that skipping
    it here is intentional.
    """
    features_of_generated = vgg_model(generated)
    features_of_target = vgg_model(target)
    feature_gap = features_of_target - features_of_generated
    return tf.reduce_mean(tf.square(feature_gap))

def sobel_loss(generated, target):
    """Mean squared difference between the Sobel edge maps of two image batches."""
    edges_of_generated = tf.image.sobel_edges(generated)
    edges_of_target = tf.image.sobel_edges(target)
    edge_gap = edges_of_target - edges_of_generated
    return tf.reduce_mean(tf.square(edge_gap))

def generator_loss_W_L2_Perceptual(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1):
    """Generator objective: Wasserstein term plus weighted L2 and perceptual terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L2 term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L2 + beta * perceptual)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.square(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    content_term = alpha * pixel_term + beta * feature_term
    return adversarial_term + sigma * content_term

def generator_loss_W_L2_Perceptual_Sobel(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1, gamma=10):
    """Generator objective: Wasserstein term plus weighted L2, perceptual and Sobel terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L2 term inside the content loss.
        gamma: Weight of the Sobel edge term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L2 + beta * perceptual + gamma * sobel)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.square(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    edge_term = sobel_loss(denoised_images, target_images)

    content_term = alpha * pixel_term + beta * feature_term + gamma * edge_term
    return adversarial_term + sigma * content_term

# Instantiate the generator and discriminator for 512x512 three-channel inputs.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

# One independent Adam instance per network, both with the same settings.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one critic update, then clip critic weights.

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Matching batch of clean reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` computed before the weight update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(restored, training=True)

        gen_loss = generator_loss_W_L2_Perceptual_Sobel(score_fake, restored, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_vars = generator.trainable_variables
    critic_vars = discriminator.trainable_variables

    # Both gradients are taken before either optimizer changes any weight.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    critic_grads = disc_tape.gradient(disc_loss, critic_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Weight clipping: force every trainable critic variable into [-0.01, 0.01].
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
def train(dataset, epochs, X_val, y_val, display_every=1):
    """Train the WGAN for ``epochs`` passes over ``dataset`` and log metrics.

    Each batch is passed to ``train_step``; the generator is then run in
    inference mode on the same batch to accumulate training PSNR / MS-SSIM.
    After every epoch the generator is scored with ``evaluate_on_test_set``
    and a one-line summary is printed.

    Args:
        dataset: Iterable of ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs, forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets, forwarded to ``evaluate_on_test_set``.
        display_every: Call ``display_images`` every this many epochs; a
            falsy value disables the preview.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Re-run the generator after the weight update, in inference mode.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Reference first, prediction second: ms_ssim_metric takes its
            # max_val from the first argument, which should be the clean image.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # The preview uses the last batch seen in this epoch.
        if display_every and (epoch + 1) % display_every == 0:
            display_images(noisy_img_batch, epoch + 1)


# Epoch budget for this run, then the run itself.
EPOCHS = 1
train(dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both networks; the relative .h5 paths resolve against the current working directory.
generator.save('WGAN_L2Sob_850epoch_2000dataset_gen.h5')
discriminator.save('WGAN_L2Sob_850epoch_2000dataset_disc.h5')

In [None]:
# 750 epochs | Generator Loss: 22.633766174316406, Discriminator Loss: -0.008981402032077312, PSNR: 49.84444046020508, SSIM: 0.9991695880889893, Validation PSNR: 39.01051330566406, Validation SSIM: 0.992892324924469, Time: 387.34 sec
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

In [None]:
# 850 epoch | Generator Loss: 18.908538818359375, Discriminator Loss: -0.011361256241798401, PSNR: 49.376426696777344, SSIM: 0.999204695224762, Validation PSNR: 39.010196685791016, Validation SSIM: 0.9930490255355835, Time: 386.87 sec
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L1_Loss + Perceptual + Sobel | 1900train + 100test

In [None]:
def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

# Frozen VGG19 without its classification head, accepting any spatial size;
# it serves as the fixed feature extractor called by perceptual_loss.
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
# Names of the VGG layers whose activations are exposed by vgg_model.
perceptual_layers = ['block5_conv4']
vgg_model = tf.keras.Model([vgg.input], [vgg.get_layer(layer).output for layer in perceptual_layers])

def perceptual_loss(generated, target):
    """Mean squared difference between the VGG features of two image batches.

    NOTE(review): both batches are fed to ``vgg_model`` as-is. The Keras docs
    state that VGG19's pretrained weights expect inputs passed through
    ``tf.keras.applications.vgg19.preprocess_input`` - confirm that skipping
    it here is intentional.
    """
    features_of_generated = vgg_model(generated)
    features_of_target = vgg_model(target)
    feature_gap = features_of_target - features_of_generated
    return tf.reduce_mean(tf.square(feature_gap))

def sobel_loss(generated, target):
    """Mean squared difference between the Sobel edge maps of two image batches."""
    edges_of_generated = tf.image.sobel_edges(generated)
    edges_of_target = tf.image.sobel_edges(target)
    edge_gap = edges_of_target - edges_of_generated
    return tf.reduce_mean(tf.square(edge_gap))


def generator_loss_W_L1_Perceptual_Sobel(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1, gamma=10):
    """Generator objective: Wasserstein term plus weighted L1, perceptual and Sobel terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L1 term inside the content loss.
        gamma: Weight of the Sobel edge term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L1 + beta * perceptual + gamma * sobel)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.abs(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    edge_term = sobel_loss(denoised_images, target_images)

    content_term = alpha * pixel_term + beta * feature_term + gamma * edge_term
    return adversarial_term + sigma * content_term

# Instantiate the generator and discriminator for 512x512 three-channel inputs.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

# One independent Adam instance per network, both with the same settings.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one critic update, then clip critic weights.

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Matching batch of clean reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` computed before the weight update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(restored, training=True)

        gen_loss = generator_loss_W_L1_Perceptual_Sobel(score_fake, restored, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_vars = generator.trainable_variables
    critic_vars = discriminator.trainable_variables

    # Both gradients are taken before either optimizer changes any weight.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    critic_grads = disc_tape.gradient(disc_loss, critic_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Weight clipping: force every trainable critic variable into [-0.01, 0.01].
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
def train(dataset, epochs, X_val, y_val, display_every=10):
    """Train the WGAN for ``epochs`` passes over ``dataset`` and log metrics.

    Each batch is passed to ``train_step``; the generator is then run in
    inference mode on the same batch to accumulate training PSNR / MS-SSIM.
    After every epoch the generator is scored with ``evaluate_on_test_set``
    and a one-line summary is printed.

    Args:
        dataset: Iterable of ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs, forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets, forwarded to ``evaluate_on_test_set``.
        display_every: Call ``display_images`` every this many epochs; a
            falsy value disables the preview.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Re-run the generator after the weight update, in inference mode.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Reference first, prediction second: ms_ssim_metric takes its
            # max_val from the first argument, which should be the clean image.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # The preview uses the last batch seen in this epoch.
        if display_every and (epoch + 1) % display_every == 0:
            display_images(noisy_img_batch, epoch + 1)


# Epoch budget for this run, then the run itself.
EPOCHS = 500
train(dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both networks; the relative .h5 paths resolve against the current working directory.
generator.save('WGAN_L1Sob_750epoch_2000dataset_gen.h5')
discriminator.save('WGAN_L1Sob_750epoch_2000dataset_disc.h5')

In [None]:
# 500 Epoch | Generator Loss: 77.4582748413086, Discriminator Loss: -0.02741212397813797, PSNR: 48.36365509033203, SSIM: 0.9988846778869629, Validation PSNR: 40.32948303222656, Validation SSIM: 0.9937796592712402
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

In [None]:
#750 epoch | Generator Loss: 59.79399490356445, Discriminator Loss: -0.009202822111546993, PSNR: 49.406856536865234, SSIM: 0.9992486834526062, Validation PSNR: 40.53458023071289, Validation SSIM: 0.9937618970870972, Time: 386.69 sec
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L1_Loss + Perceptual + Sobel | 3000 dataset

In [None]:
def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

# Frozen VGG19 without its classification head, accepting any spatial size;
# it serves as the fixed feature extractor called by perceptual_loss.
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
# Names of the VGG layers whose activations are exposed by vgg_model.
perceptual_layers = ['block5_conv4']
vgg_model = tf.keras.Model([vgg.input], [vgg.get_layer(layer).output for layer in perceptual_layers])

def perceptual_loss(generated, target):
    """Mean squared difference between the VGG features of two image batches.

    NOTE(review): both batches are fed to ``vgg_model`` as-is. The Keras docs
    state that VGG19's pretrained weights expect inputs passed through
    ``tf.keras.applications.vgg19.preprocess_input`` - confirm that skipping
    it here is intentional.
    """
    features_of_generated = vgg_model(generated)
    features_of_target = vgg_model(target)
    feature_gap = features_of_target - features_of_generated
    return tf.reduce_mean(tf.square(feature_gap))

def sobel_loss(generated, target):
    """Mean squared difference between the Sobel edge maps of two image batches."""
    edges_of_generated = tf.image.sobel_edges(generated)
    edges_of_target = tf.image.sobel_edges(target)
    edge_gap = edges_of_target - edges_of_generated
    return tf.reduce_mean(tf.square(edge_gap))


def generator_loss_W_L1_Perceptual_Sobel(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1, gamma=10):
    """Generator objective: Wasserstein term plus weighted L1, perceptual and Sobel terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L1 term inside the content loss.
        gamma: Weight of the Sobel edge term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L1 + beta * perceptual + gamma * sobel)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.abs(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    edge_term = sobel_loss(denoised_images, target_images)

    content_term = alpha * pixel_term + beta * feature_term + gamma * edge_term
    return adversarial_term + sigma * content_term

# Instantiate the generator and discriminator for 512x512 three-channel inputs.
generator = generator_WGAN(input_shape=(512, 512, 3))
discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

# One independent Adam instance per network, both with the same settings.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

#history_buffer = HistoryBuffer(max_size=8)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one critic update, then clip critic weights.

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Matching batch of clean reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` computed before the weight update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(restored, training=True)

        gen_loss = generator_loss_W_L1_Perceptual_Sobel(score_fake, restored, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_vars = generator.trainable_variables
    critic_vars = discriminator.trainable_variables

    # Both gradients are taken before either optimizer changes any weight.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    critic_grads = disc_tape.gradient(disc_loss, critic_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Weight clipping: force every trainable critic variable into [-0.01, 0.01].
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
def evaluate_on_test_set(generator, X_val, y_val,batch_size= 4):
    """Average per-image PSNR and MS-SSIM of ``generator`` over a validation set.

    The inputs are processed in mini-batches of ``batch_size``; every
    prediction/target pair is rescaled with ``normalize_images`` and scored
    individually with ``psnr`` and ``ms_ssim_metric``.

    Args:
        generator: Model exposing ``predict``.
        X_val: Validation inputs (must support ``len`` and slicing).
        y_val: Validation targets aligned with ``X_val``.
        batch_size: Number of images sent to ``predict`` at once.

    Returns:
        Tuple ``(avg_psnr, avg_ssim)`` - the means over all validation images.
    """
    num_samples = len(X_val)

    psnr_values = []
    ssim_values = []

    # Striding by batch_size also yields the final, possibly shorter, batch.
    for batch_start in range(0, num_samples, batch_size):
        batch_end = batch_start + batch_size
        X_batch = X_val[batch_start:batch_end]
        y_batch = y_val[batch_start:batch_end]

        # predict runs once per mini-batch, so silence its progress bar to
        # keep the notebook output readable.
        predicted_batch = generator.predict(X_batch, verbose=0)
        # Normalize images from [-1, 1] to [0, 1] for PSNR and SSIM calculations
        predicted_batch = normalize_images(predicted_batch)
        y_batch_normalized = normalize_images(y_batch)

        for i in range(len(X_batch)):
            # Reference first: ms_ssim_metric takes max_val from its first argument.
            psnr_values.append(psnr(y_batch_normalized[i], predicted_batch[i]))
            ssim_values.append(ms_ssim_metric(y_batch_normalized[i], predicted_batch[i]))

    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)

    return avg_psnr, avg_ssim

In [None]:
BATCH_SIZE = 4

# Pair every noisy input with its clean target, then shuffle, batch and
# prefetch the pairs.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
dataset = (
    tf.data.Dataset.zip((noisy_dataset, clean_dataset))
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def train(dataset, epochs, X_val, y_val, display_every=10):
    """Train the WGAN for ``epochs`` passes over ``dataset`` and log metrics.

    Each batch is passed to ``train_step``; the generator is then run in
    inference mode on the same batch to accumulate training PSNR / MS-SSIM.
    After every epoch the generator is scored with ``evaluate_on_test_set``
    and a one-line summary is printed.

    Args:
        dataset: Iterable of ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs, forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets, forwarded to ``evaluate_on_test_set``.
        display_every: Call ``display_images`` every this many epochs; a
            falsy value disables the preview.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Re-run the generator after the weight update, in inference mode.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Reference first, prediction second: ms_ssim_metric takes its
            # max_val from the first argument, which should be the clean image.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # The preview uses the last batch seen in this epoch.
        if display_every and (epoch + 1) % display_every == 0:
            display_images(noisy_img_batch, epoch + 1)


# Epoch budget for this run, then the run itself.
EPOCHS = 2
train(dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both networks; the relative .h5 paths resolve against the current working directory.
generator.save('WGAN_L1Sob_750epoch_3000dataset_gen.h5')
discriminator.save('WGAN_L1Sob_750epoch_3000dataset_disc.h5')

In [None]:
#Epoch 500 | Generator Loss: 60.009727478027344, Discriminator Loss: -0.004583950154483318, PSNR: 48.44008255004883, SSIM: 0.9990533590316772, Validation PSNR: 43.31389617919922, Validation SSIM: 0.9967049360275269, Time: 630.21 sec
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

In [None]:
#Epoch 750 | Generator Loss: 52.27979278564453, Discriminator Loss: -0.004730249289423227, PSNR: 49.8203239440918, SSIM: 0.9993076920509338, Validation PSNR: 44.08540344238281, Validation SSIM: 0.9968264102935791, Time: 624.27 sec
# Drop any earlier predictions before allocating the new array.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct positions; cap the count in case the test set is smaller
# than the requested number of samples.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index by the sampled position so the displayed pairs are actually random.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input. Right panel: generator output.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

## Wasserstein + L1_Loss + Perceptual + Sobel | 1024x1024 1000/100 dataset

In [None]:
def discriminator_loss(real_output, fake_output):
    """Wasserstein critic loss: mean score on fakes minus mean score on reals."""
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

# Frozen VGG19 without its classification head, accepting any spatial size;
# it serves as the fixed feature extractor called by perceptual_loss.
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
# Names of the VGG layers whose activations are exposed by vgg_model.
perceptual_layers = ['block5_conv4']
vgg_model = tf.keras.Model([vgg.input], [vgg.get_layer(layer).output for layer in perceptual_layers])

def perceptual_loss(generated, target):
    """Mean squared difference between the VGG features of two image batches.

    NOTE(review): both batches are fed to ``vgg_model`` as-is. The Keras docs
    state that VGG19's pretrained weights expect inputs passed through
    ``tf.keras.applications.vgg19.preprocess_input`` - confirm that skipping
    it here is intentional.
    """
    features_of_generated = vgg_model(generated)
    features_of_target = vgg_model(target)
    feature_gap = features_of_target - features_of_generated
    return tf.reduce_mean(tf.square(feature_gap))

def sobel_loss(generated, target):
    """Mean squared difference between the Sobel edge maps of two image batches."""
    edges_of_generated = tf.image.sobel_edges(generated)
    edges_of_target = tf.image.sobel_edges(target)
    edge_gap = edges_of_target - edges_of_generated
    return tf.reduce_mean(tf.square(edge_gap))


def generator_loss_W_L1_Perceptual_Sobel(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1, gamma=10):
    """Generator objective: Wasserstein term plus weighted L1, perceptual and Sobel terms.

    Args:
        fake_output: Critic scores for the generated images.
        denoised_images: Generator output.
        target_images: Clean reference images.
        beta: Weight of the perceptual term inside the content loss.
        sigma: Overall weight of the content loss against the Wasserstein term.
        alpha: Weight of the L1 term inside the content loss.
        gamma: Weight of the Sobel edge term inside the content loss.

    Returns:
        ``-mean(fake_output) + sigma * (alpha * L1 + beta * perceptual + gamma * sobel)``.
    """
    adversarial_term = -tf.reduce_mean(fake_output)
    pixel_term = tf.reduce_mean(tf.abs(target_images - denoised_images))
    feature_term = perceptual_loss(denoised_images, target_images)
    edge_term = sobel_loss(denoised_images, target_images)

    content_term = alpha * pixel_term + beta * feature_term + gamma * edge_term
    return adversarial_term + sigma * content_term


#generator = generator_WGAN(input_shape=(1024, 1024, 3), n_filters=32)
#generator.summary()
#discriminator = discriminator_WGAN(input_shape=(1024, 1024, 3), n_filters=32)
#discriminator.summary()

# Load both networks from previously saved .h5 files instead of building them here.
generator = load_model('WGAN_L1Sob_100epoch_1100dataset_fullsize_gen.h5')
discriminator = load_model('WGAN_L1Sob_100epoch_1100dataset_fullsize_disc.h5')

# One independent Adam instance per network, both with the same settings.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5) for _ in range(2)
)

#history_buffer = HistoryBuffer(max_size=8)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one critic update, then clip critic weights.

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Matching batch of clean reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` computed before the weight update.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        restored = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(restored, training=True)

        gen_loss = generator_loss_W_L1_Perceptual_Sobel(score_fake, restored, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_vars = generator.trainable_variables
    critic_vars = discriminator.trainable_variables

    # Both gradients are taken before either optimizer changes any weight.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    critic_grads = disc_tape.gradient(disc_loss, critic_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

    # Weight clipping: force every trainable critic variable into [-0.01, 0.01].
    for weight in critic_vars:
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))

    return gen_loss, disc_loss

In [None]:
BATCH_SIZE = 1

# Pair every noisy input with its clean target, then shuffle, batch and
# prefetch the pairs.
noisy_dataset = tf.data.Dataset.from_tensor_slices(X_train)
clean_dataset = tf.data.Dataset.from_tensor_slices(y_train)
dataset = (
    tf.data.Dataset.zip((noisy_dataset, clean_dataset))
    .shuffle(buffer_size=1100)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def train(dataset, epochs, X_val, y_val, display_every=10):
    """Train the WGAN for ``epochs`` passes over ``dataset`` and log metrics.

    Each batch is passed to ``train_step``; the generator is then run in
    inference mode on the same batch to accumulate training PSNR / MS-SSIM.
    After every epoch the generator is scored with ``evaluate_on_test_set``
    (validation batch size 1) and a one-line summary is printed.

    Args:
        dataset: Iterable of ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs, forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets, forwarded to ``evaluate_on_test_set``.
        display_every: Call ``display_images`` every this many epochs; a
            falsy value disables the preview.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Re-run the generator after the weight update, in inference mode.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Reference first, prediction second: ms_ssim_metric takes its
            # max_val from the first argument, which should be the clean image.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val,1)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # The preview uses the last batch seen in this epoch.
        if display_every and (epoch + 1) % display_every == 0:
            display_images(noisy_img_batch, epoch + 1)


# Epoch budget for this run, then the run itself.
EPOCHS = 100
train(dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both networks; the relative .h5 paths resolve against the current working directory.
generator.save('WGAN_L1Sob_500epoch_1100dataset_fullsize_gen.h5')
discriminator.save('WGAN_L1Sob_500epoch_1100dataset_fullsize_disc.h5')

In [None]:
# Epoch 500 (train): Generator Loss: 122.08033752441406, Discriminator Loss: -0.012598012574017048, PSNR: 34.723182678222656, SSIM: 0.9896923303604126, Time: 477.38 sec

# Next: fine-tune on inverted colors (3000-image dataset) | 50/50 test set (50 normal, 50 bitwise-inverted)

In [None]:
source_folder = 'Dataset/source'
target_folder = 'Dataset/target'

# Number of leading image pairs kept in their original colors; every pair
# loaded after these is color-inverted. With the first 100 pairs used as the
# test set this yields the intended 50 normal / 50 inverted split.
NUM_NORMAL_IMAGES = 50
# Upper bound on the number of image pairs to load.
MAX_IMAGES = 3100
# (width, height) every image is resized to.
IMAGE_SIZE = (512, 512)


def _load_pair_image(path, invert):
    """Read one image as grayscale and preprocess it.

    Args:
        path: Path of the image file to read.
        invert: When True, invert the 8-bit image with cv2.bitwise_not before
            it is scaled.

    Returns:
        A float32 array scaled with (x - 127.5) / 127.5, replicated to three
        channels and resized to IMAGE_SIZE, or None when OpenCV could not
        read the file.
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        return None
    if invert:
        image = cv2.bitwise_not(image)
    image = (image.astype(np.float32) - 127.5) / 127.5
    image = np.repeat(image[:, :, np.newaxis], 3, axis=2)
    return cv2.resize(image, IMAGE_SIZE)


# List all file names in the source folder. os.listdir returns them in
# arbitrary order, so which files end up first is not fixed by this code.
source_image_files = os.listdir(source_folder)

# Preprocessed images, kept pairwise aligned (same index == same file name).
source_images = []
target_images = []

# Number of pairs loaded so far.
i = 0

for filename in source_image_files:
    # `>=` (not `>`): pairs 0..49 stay normal, pair 50 onwards is inverted.
    invert = i >= NUM_NORMAL_IMAGES

    source_image = _load_pair_image(os.path.join(source_folder, filename), invert)
    target_image = _load_pair_image(os.path.join(target_folder, filename), invert)

    if source_image is None or target_image is None:
        # Skip non-image entries and sources without a readable target instead
        # of failing later with an opaque OpenCV / attribute error.
        print(f"Skipping '{filename}': could not read source and/or target image")
        continue

    source_images.append(source_image)
    target_images.append(target_image)

    i += 1
    if i == MAX_IMAGES:
        break

# Convert the lists to NumPy arrays
source_images = np.array(source_images)
target_images = np.array(target_images)

# Print the shape of the loaded and preprocessed images
print("Source Images Shape:", source_images.shape)
print("Target Images Shape:", target_images.shape)

In [None]:
# Hold out the first 100 pairs as the test/validation set; the rest is training data.
X_val, X_train = source_images[:100], source_images[100:]
y_val, y_train = target_images[:100], target_images[100:]

print(
    f"x_train: {len(X_train)} | x_val: {len(X_val)} | y_train: {len(y_train)} | y_val: {len(y_val)}"
)

In [None]:
# Preview the first five pairs: source images on the top row, targets below.
plt.rcParams["figure.figsize"] = (20,10)
fig, axs = plt.subplots(2, 5)
for i, (row_images, row_title) in enumerate(
    ((source_images, 'Source image'), (target_images, 'Target image'))
):
    for j in range(5):
        panel = axs[i, j]
        panel.imshow(normalize_images(row_images[j]), cmap='gray')
        panel.set_title(row_title)
        panel.axis('off')


plt.show()

In [None]:
def discriminator_loss(real_output, fake_output):
    """Critic loss: mean score on generated samples minus mean score on real ones.

    Args:
        real_output: Discriminator outputs for real (clean) images.
        fake_output: Discriminator outputs for generated images.

    Returns:
        Scalar tensor ``mean(fake_output) - mean(real_output)``.
    """
    mean_fake_score = tf.reduce_mean(fake_output)
    mean_real_score = tf.reduce_mean(real_output)
    return mean_fake_score - mean_real_score

# Frozen VGG19 backbone (no classifier head, any spatial size) whose selected
# layer outputs are exposed through vgg_model.
perceptual_layers = ['block5_conv4']
vgg = tf.keras.applications.VGG19(include_top=False, input_shape=(None, None, 3))
vgg.trainable = False
vgg_model = tf.keras.Model(
    [vgg.input],
    [vgg.get_layer(layer_name).output for layer_name in perceptual_layers],
)

def perceptual_loss(generated, target):
    """Mean squared difference between the vgg_model features of two batches.

    NOTE(review): both batches are passed to vgg_model as-is; no VGG-specific
    input preprocessing is applied here — confirm that is intended for the
    value range the images use.

    Args:
        generated: Batch of generated images.
        target: Batch of reference images.

    Returns:
        Scalar tensor with the mean squared feature difference.
    """
    generated_feats = vgg_model(generated)
    target_feats = vgg_model(target)
    squared_gap = tf.square(target_feats - generated_feats)
    return tf.reduce_mean(squared_gap)

def sobel_loss(generated, target):
    """Mean squared difference between the Sobel edge maps of two image batches.

    Args:
        generated: Batch of generated images.
        target: Batch of reference images.

    Returns:
        Scalar tensor with the mean squared edge-map difference.
    """
    edges_generated = tf.image.sobel_edges(generated)
    edges_target = tf.image.sobel_edges(target)

    edge_gap = edges_target - edges_generated
    return tf.reduce_mean(tf.square(edge_gap))


def generator_loss_W_L1_Perceptual_Sobel(fake_output, denoised_images, target_images, beta=10, sigma=10000, alpha=1, gamma=10):
    """Generator objective: adversarial term plus a weighted reconstruction term.

    total = -mean(fake_output)
            + sigma * (alpha * L1 + beta * perceptual + gamma * sobel)

    Args:
        fake_output: Discriminator outputs for the generated images.
        denoised_images: Images produced by the generator.
        target_images: Reference images the generator should reproduce.
        beta: Weight of the perceptual term.
        sigma: Overall weight of the combined reconstruction term.
        alpha: Weight of the L1 term.
        gamma: Weight of the Sobel edge term.

    Returns:
        Scalar tensor with the total generator loss.
    """
    adversarial = -tf.reduce_mean(fake_output)
    pixel_l1 = tf.reduce_mean(tf.abs(target_images - denoised_images))
    perceptual = perceptual_loss(denoised_images, target_images)
    edges = sobel_loss(denoised_images, target_images)

    reconstruction = alpha * pixel_l1 + beta * perceptual + gamma * edges
    return adversarial + sigma * reconstruction


# Load the pretrained networks without restoring their saved compile state.
# Both models are re-compiled right below with fresh optimizers, so the stored
# state would be discarded anyway, and compile=False avoids having to
# deserialize any custom loss/metric functions the .h5 files may reference
# (which otherwise requires passing `custom_objects`).
generator = load_model('WGAN_L1Sob_750epoch_3000dataset_gen.h5', compile=False)
discriminator = load_model('WGAN_L1Sob_750epoch_3000dataset_disc.h5', compile=False)

# To train from scratch instead of fine-tuning:
#generator = generator_WGAN(input_shape=(512, 512, 3))
#discriminator = discriminator_WGAN(input_shape=(512, 512, 3))

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)


# train_step() calls the loss functions and optimizers directly, so these
# compile() calls only attach the optimizer/loss/metrics to the model objects.
generator.compile(optimizer=generator_optimizer,
              loss=generator_loss_W_L1_Perceptual_Sobel,
              metrics=[psnr, ms_ssim_metric])

discriminator.compile(optimizer=discriminator_optimizer,
              loss=discriminator_loss,
              metrics=[psnr, ms_ssim_metric])

#history_buffer = HistoryBuffer(max_size=8)

@tf.function
def train_step(noisy_img_batch, clean_img_batch):
    """Apply one generator update and one discriminator update on a batch.

    Both losses are computed from the same forward pass; the gradients are
    taken first, then applied, and finally every trainable discriminator
    variable is clipped to [-0.01, 0.01].

    Args:
        noisy_img_batch: Batch of generator inputs.
        clean_img_batch: Batch of matching reference images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` for this batch.
    """
    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        denoised_images = generator(noisy_img_batch, training=True)
        score_real = discriminator(clean_img_batch, training=True)
        score_fake = discriminator(denoised_images, training=True)

        gen_loss = generator_loss_W_L1_Perceptual_Sobel(score_fake, denoised_images, clean_img_batch)
        disc_loss = discriminator_loss(score_real, score_fake)

    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    # Weight clipping on the discriminator after its update.
    clip_limit = 0.01
    for weight in disc_vars:
        weight.assign(tf.clip_by_value(weight, -clip_limit, clip_limit))

    return gen_loss, disc_loss

In [None]:
BATCH_SIZE = 4

# Pair each noisy training image with its clean target, then shuffle, batch and prefetch.
noisy_dataset, clean_dataset = (
    tf.data.Dataset.from_tensor_slices(images) for images in (X_train, y_train)
)
dataset = (
    tf.data.Dataset.zip((noisy_dataset, clean_dataset))
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
def train(dataset, epochs, X_val, y_val):
    """Run the adversarial training loop and print per-epoch metrics.

    For every batch this calls ``train_step`` and then re-runs the generator
    with ``training=False`` to measure PSNR / MS-SSIM on that batch. After each
    epoch the generator is scored with ``evaluate_on_test_set`` and a summary
    line is printed; every 10th epoch ``display_images`` is called with the
    last batch of the epoch.

    Args:
        dataset: Iterable yielding ``(noisy_img_batch, clean_img_batch)`` pairs.
        epochs: Number of passes over ``dataset``.
        X_val: Validation inputs forwarded to ``evaluate_on_test_set``.
        y_val: Validation targets forwarded to ``evaluate_on_test_set``.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    for epoch in range(epochs):
        start_time = time.time()

        total_gen_loss = 0
        total_disc_loss = 0
        total_psnr = 0
        total_ssim = 0
        num_batches = 0

        for noisy_img_batch, clean_img_batch in dataset:
            gen_loss, disc_loss = train_step(noisy_img_batch, clean_img_batch)
            total_gen_loss += gen_loss
            total_disc_loss += disc_loss

            # Second forward pass in inference mode so the metrics reflect the
            # generator after this step's weight update.
            denoised_images = generator(noisy_img_batch, training=False)

            norm_denoised_images = normalize_images(denoised_images)
            norm_clean_img_batch = normalize_images(clean_img_batch)

            # Both metrics are declared as (y_true, y_pred): the clean batch is
            # the reference. This matters for ms_ssim_metric, which derives its
            # max_val from y_true.
            batch_psnr = tf.reduce_mean(psnr(norm_clean_img_batch, norm_denoised_images))
            batch_ssim = tf.reduce_mean(ms_ssim_metric(norm_clean_img_batch, norm_denoised_images))

            total_psnr += batch_psnr
            total_ssim += batch_ssim

            num_batches += 1

        if num_batches == 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # after the (potentially slow) validation pass.
            raise ValueError("dataset yielded no batches; nothing to train on")

        val_avg_psnr, val_avg_ssim = evaluate_on_test_set(generator, X_val, y_val)

        avg_gen_loss = total_gen_loss / num_batches
        avg_disc_loss = total_disc_loss / num_batches
        train_avg_psnr = total_psnr / num_batches
        train_avg_ssim = total_ssim / num_batches

        epoch_duration = time.time() - start_time

        print(f'Epoch {epoch + 1}/{epochs}, '
              f'Generator Loss: {avg_gen_loss}, '
              f'Discriminator Loss: {avg_disc_loss}, '
              f'PSNR: {train_avg_psnr}, '
              f'SSIM: {train_avg_ssim}, '
              f'Validation PSNR: {val_avg_psnr}, '
              f'Validation SSIM: {val_avg_ssim}, '
              f'Time: {epoch_duration:.2f} sec')

        # Periodic visual check on the last batch of the epoch.
        if (epoch + 1) % 10 == 0:
            display_images(noisy_img_batch, epoch + 1)


# Number of epochs for this fine-tuning run.
EPOCHS = 100
train(dataset=dataset, epochs=EPOCHS, X_val=X_val, y_val=y_val)

In [None]:
# Persist both fine-tuned networks as .h5 files (paths are relative to the working directory).
generator.save('WGAN_L1Sob_500epoch_3000dataset_invert_gen.h5')
discriminator.save('WGAN_L1Sob_500epoch_3000dataset_invert_disc.h5')

In [None]:
#Epoch 250 | Generator Loss: 67.63257598876953, Discriminator Loss: -0.01677171140909195, PSNR: 47.56999969482422, SSIM: 0.9988458752632141, Validation PSNR: 33.62474060058594, Validation SSIM: 0.9820313453674316, Time: 634.36 sec
# Drop the reference to any earlier predictions before computing new ones.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct indices; cap the count so a small test set does not raise.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index with the sampled idx (not a loop counter) so the images shown are
    # the randomly chosen ones rather than always the first few.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    # Right panel: generator output for the same image.
    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()

In [None]:
# 500 epochs | Generator Loss: 50.092411041259766, Discriminator Loss: -0.003123227972537279, PSNR: 48.603302001953125, SSIM: 0.9992698431015015, Validation PSNR: 33.26730728149414, Validation SSIM: 0.9800387024879456, Time: 627.18 sec
# Drop the reference to any earlier predictions before computing new ones.
predicted_images = None
predicted_images = generator.predict(test_images)
num_samples_to_visualize = 6
# Sample distinct indices; cap the count so a small test set does not raise.
random_indices = np.random.choice(
    len(test_images),
    min(num_samples_to_visualize, len(test_images)),
    replace=False,
)

for idx in random_indices:
    # Index with the sampled idx (not a loop counter) so the images shown are
    # the randomly chosen ones rather than always the first few.
    original_image = normalize_image(test_images[idx].squeeze())
    predicted_image = normalize_image(predicted_images[idx].squeeze())

    plt.figure(figsize=(12, 6))

    # Left panel: generator input.
    plt.subplot(1, 2, 1)
    plt.title("Original Image")
    plt.imshow(original_image, cmap='gray')

    # Right panel: generator output for the same image.
    plt.subplot(1, 2, 2)
    plt.title("Predicted Image")
    plt.imshow(predicted_image, cmap='gray')

    plt.tight_layout()
    plt.show()