In [7]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from tqdm import tqdm


In [14]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import math
import csv
import os

# 1. Data Preprocessing
def preprocess_image(image):
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def rgb_to_grayscale(image):
    return tf.image.rgb_to_grayscale(image)

def load_and_preprocess_data(batch_size=32, num_samples=5000):
    # Load CIFAR-10 dataset
    dataset, info = tfds.load('cifar10', with_info=True, as_supervised=True)
    train_dataset = dataset['train'].take(num_samples)
    
    # Preprocess and batch the data
    total_samples = num_samples
    with tqdm(total=total_samples, desc="Preprocessing data") as pbar:
        def preprocess_and_update(img, label):
            pbar.update(1)
            return preprocess_image(img), preprocess_image(img)
        
        train_dataset = train_dataset.map(preprocess_and_update, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(lambda img_color, img_color2: (rgb_to_grayscale(img_color), img_color2), num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.batch(batch_size).shuffle(1000).prefetch(tf.data.AUTOTUNE)
    
    return train_dataset

# Update the generator and discriminator to work with 32x32 images
def build_generator():
    inputs = keras.Input(shape=(32, 32, 1))
    
    # Encoder
    e1 = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)  # 16x16x64
    e1 = keras.layers.LeakyReLU(0.2)(e1)
    
    e2 = keras.layers.Conv2D(128, 4, strides=2, padding='same')(e1)    # 8x8x128
    e2 = keras.layers.BatchNormalization()(e2)
    e2 = keras.layers.LeakyReLU(0.2)(e2)
    
    # Bridge
    b = keras.layers.Conv2D(256, 4, strides=1, padding='same')(e2)     # 8x8x256
    b = keras.layers.BatchNormalization()(b)
    b = keras.layers.LeakyReLU(0.2)(b)
    
    # Decoder
    d2 = keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same')(b)  # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    d2 = keras.layers.Concatenate()([d2, e1])                               # 16x16x192
    
    # Optionally, add a Conv2D layer to reduce channels
    d2 = keras.layers.Conv2D(128, 3, padding='same')(d2)                   # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    
    d1 = keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(d2)   # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    d1 = keras.layers.Concatenate()([d1, inputs])                            # 32x32x65
    
    # Optionally, add a Conv2D layer to reduce channels
    d1 = keras.layers.Conv2D(64, 3, padding='same')(d1)                     # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    
    outputs = keras.layers.Conv2D(3, 4, strides=1, padding='same', activation='tanh')(d1)  # 32x32x3
    
    return keras.Model(inputs, outputs)

def build_discriminator():
    input_img = keras.Input(shape=(32, 32, 1))
    target_img = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Concatenate()([input_img, target_img])
    
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)  # 16x16x64
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)  # 8x8x128
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)  # 4x4x256
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)    # 4x4x1
    
    return keras.Model([input_img, target_img], x)

# 4. Loss Functions
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_loss = gan_loss + (100 * l1_loss)
    return total_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = cross_entropy(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_loss = real_loss + generated_loss
    return total_loss

# 5. Training Step
@tf.function
def train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    
    return gen_total_loss, disc_loss

# 6. Training Loop is now integrated into main

# 7. Image Generation and Saving (Updated)
def generate_and_save_images(model, epoch, test_input, max_images=16):
    predictions = model(test_input, training=False)
    
    num_images = min(predictions.shape[0], max_images)
    grid_size = math.ceil(math.sqrt(num_images))
    
    fig = plt.figure(figsize=(12, 12))
    
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i+1)
        img = (predictions[i] * 0.5) + 0.5
        plt.imshow(img.numpy())
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(f'output/image_at_epoch_{epoch:04d}.png')
    plt.close()

# 8. Evaluation Metrics
def calculate_metrics(real_images, generated_images, win_size=7):
    """
    Calculate the average SSIM and PSNR between real and generated images.

    Args:
        real_images (tf.Tensor): Batch of real images.
        generated_images (tf.Tensor): Batch of generated images.
        win_size (int): Window size for SSIM. Must be odd and <= min(image dimensions).

    Returns:
        tuple: (average SSIM, average PSNR)
    """
    ssim_scores = []
    psnr_scores = []
    
    for real, generated in tqdm(zip(real_images, generated_images), desc="Calculating metrics", total=len(real_images)):
        # Rescale images from [-1, 1] to [0, 255]
        real = ((real * 0.5) + 0.5) * 255
        generated = ((generated * 0.5) + 0.5) * 255
        
        # Convert tensors to NumPy arrays
        real = real.numpy().astype(np.uint8)
        generated = generated.numpy().astype(np.uint8)
        
        # Ensure images have 3 channels
        if real.ndim == 3 and real.shape[-1] == 1:
            real = np.squeeze(real, axis=-1)
        if generated.ndim == 3 and generated.shape[-1] == 1:
            generated = np.squeeze(generated, axis=-1)
        
        # Calculate SSIM with updated parameters
        ssim_val = ssim(
            real, 
            generated, 
            win_size=win_size, 
            channel_axis=-1
        )
        
        # Calculate PSNR
        psnr_val = psnr(
            real, 
            generated, 
            data_range=255
        )
        
        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)
    
    return np.mean(ssim_scores), np.mean(psnr_scores)


# 9. Main Execution (Updated)
def main():
    # Create a directory to save CSV and images if not exists
    os.makedirs('output', exist_ok=True)
    
    # Initialize CSV file
    csv_file = os.path.join('output', 'training_metrics.csv')
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'Generator Loss', 'Discriminator Loss'])
    
    # Load and preprocess data
    dataset = load_and_preprocess_data()
    
    # Build models
    generator = build_generator()
    discriminator = build_discriminator()
    
    # Define optimizers
    generator_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
    
    # Define number of epochs
    epochs = 10
    
    for epoch in tqdm(range(epochs), desc="Training epochs"):
        gen_losses = []
        disc_losses = []
        
        for input_image, target in tqdm(dataset, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            gen_total_loss, disc_loss = train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer)
            gen_losses.append(gen_total_loss)
            disc_losses.append(disc_loss)
        
        avg_gen_loss = tf.reduce_mean(gen_losses).numpy()
        avg_disc_loss = tf.reduce_mean(disc_losses).numpy()
        
        tqdm.write(f"Epoch {epoch+1}/{epochs}")
        tqdm.write(f"Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}")
        
        # Log to CSV
        with open(csv_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, avg_gen_loss, avg_disc_loss])
        
        # Save images every 10 epochs
        if (epoch + 1) % 10 == 0:
            # Select a batch of test images (you might want to use a fixed set for consistency)
            test_dataset = dataset.take(1)
            test_grayscale, test_color = next(iter(test_dataset))
            generate_and_save_images(generator, epoch + 1, test_grayscale)
    
    test_dataset = dataset.take(32)
    test_grayscale, test_color = next(iter(test_dataset))
    generated_images = generator(test_grayscale, training=False)
    
    ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
    print(f"SSIM Score: {ssim_score:.4f}")
    print(f"PSNR Score: {psnr_score:.4f}")
    
    # Append evaluation metrics to CSV
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Evaluation', ssim_score, psnr_score])

if __name__ == "__main__":
    main()


Preprocessing data:   0%|          | 1/5000 [00:00<03:42, 22.50it/s]
Training epochs:   0%|          | 0/10 [00:00<?, ?it/s]2024-10-13 02:28:20.108454: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:28:41.364190: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  10%|█         | 1/10 [00:21<03:13, 21.48s/it]

Epoch 1/10
Generator Loss: 20.0593, Discriminator Loss: 0.9834


2024-10-13 02:28:41.560983: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:28:50.889105: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  20%|██        | 2/10 [00:31<01:55, 14.45s/it]

Epoch 2/10
Generator Loss: 16.2914, Discriminator Loss: 0.8967


2024-10-13 02:28:51.081556: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:00.259656: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  30%|███       | 3/10 [00:40<01:24, 12.13s/it]

Epoch 3/10
Generator Loss: 15.5151, Discriminator Loss: 0.8832


2024-10-13 02:29:00.446831: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:09.650202: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  40%|████      | 4/10 [00:49<01:06, 11.05s/it]

Epoch 4/10
Generator Loss: 15.4369, Discriminator Loss: 0.8235


2024-10-13 02:29:09.844017: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:19.023463: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  50%|█████     | 5/10 [00:59<00:52, 10.44s/it]

Epoch 5/10
Generator Loss: 15.6265, Discriminator Loss: 0.7555


2024-10-13 02:29:19.215420: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:28.308390: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  60%|██████    | 6/10 [01:08<00:40, 10.05s/it]

Epoch 6/10
Generator Loss: 15.7089, Discriminator Loss: 0.6475


2024-10-13 02:29:28.497820: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:37.566596: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  70%|███████   | 7/10 [01:17<00:29,  9.79s/it]

Epoch 7/10
Generator Loss: 16.5692, Discriminator Loss: 0.4526


2024-10-13 02:29:37.756113: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:46.850718: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  80%|████████  | 8/10 [01:26<00:19,  9.63s/it]

Epoch 8/10
Generator Loss: 16.6566, Discriminator Loss: 0.4419


2024-10-13 02:29:47.036472: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:29:56.514569: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  90%|█████████ | 9/10 [01:36<00:09,  9.64s/it]

Epoch 9/10
Generator Loss: 17.2614, Discriminator Loss: 0.3010


2024-10-13 02:29:56.695538: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:30:05.702536: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Training epochs:  90%|█████████ | 9/10 [01:45<00:09,  9.64s/it]2024-10-13 02:30:05.911745: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat

Epoch 10/10
Generator Loss: 14.3535, Discriminator Loss: 0.9362


Training epochs: 100%|██████████| 10/10 [01:46<00:00, 10.62s/it]
2024-10-13 02:30:06.303285: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Calculating metrics: 100%|██████████| 32/32 [00:00<00:00, 225.51it/s]

SSIM Score: 0.8731
PSNR Score: 21.8862





In [15]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import math
import csv
import os
import datetime

# Enable mixed precision if desired (optional)
# from tensorflow.keras import mixed_precision
# mixed_precision.set_global_policy('mixed_float16')

# 1. Data Preprocessing
def preprocess_image(image):
    """Scales image to [-1, 1]."""
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def rgb_to_grayscale(image):
    """Converts RGB image to grayscale."""
    return tf.image.rgb_to_grayscale(image)

def augment(image, label):
    """Applies random augmentations to images."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    return image, label

def load_and_preprocess_data(batch_size=32, num_samples=5000):
    """
    Loads CIFAR-10 dataset, applies preprocessing and augmentation,
    and prepares training, validation, and test datasets.
    """
    # Load CIFAR-10 dataset
    dataset, info = tfds.load('cifar10', with_info=True, as_supervised=True)
    train_dataset = dataset['train'].take(num_samples)

    # Define split sizes
    val_size = int(0.1 * num_samples)  # 10% for validation
    test_size = 32  # Fixed test set size

    # Preprocess and augment the data
    with tqdm(total=num_samples, desc="Preprocessing data") as pbar:
        def preprocess_and_update(img, label):
            pbar.update(1)
            img = preprocess_image(img)
            return img, img  # Input and target are initially the same

        train_dataset = train_dataset.map(preprocess_and_update, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(lambda img_color, img_color2: (rgb_to_grayscale(img_color), img_color2),
                                          num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.shuffle(1000)
        train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

    # Split into training and validation
    val_dataset = train_dataset.take(val_size // batch_size)
    train_dataset = train_dataset.skip(val_size // batch_size)

    # Create a fixed test set
    test_dataset = dataset['test'].take(test_size).map(lambda img, lbl: preprocess_image(img))
    test_dataset = test_dataset.map(lambda img: (rgb_to_grayscale(img), img))
    test_dataset = test_dataset.batch(test_size)

    return train_dataset, val_dataset, test_dataset

# 2. Build Generator
def build_generator():
    """
    Builds the generator model with an encoder-decoder architecture
    and skip connections.
    """
    inputs = keras.Input(shape=(32, 32, 1))
    
    # Encoder
    e1 = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)  # 16x16x64
    e1 = keras.layers.LeakyReLU(0.2)(e1)
    
    e2 = keras.layers.Conv2D(128, 4, strides=2, padding='same')(e1)    # 8x8x128
    e2 = keras.layers.BatchNormalization()(e2)
    e2 = keras.layers.LeakyReLU(0.2)(e2)
    
    # Bridge
    b = keras.layers.Conv2D(256, 4, strides=1, padding='same')(e2)     # 8x8x256
    b = keras.layers.BatchNormalization()(b)
    b = keras.layers.LeakyReLU(0.2)(b)
    
    # Decoder
    d2 = keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same')(b)  # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    d2 = keras.layers.Concatenate()([d2, e1])                               # 16x16x192
    
    # Reduce channels
    d2 = keras.layers.Conv2D(128, 3, padding='same')(d2)                   # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    
    d1 = keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(d2)   # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    d1 = keras.layers.Concatenate()([d1, inputs])                            # 32x32x65
    
    # Reduce channels
    d1 = keras.layers.Conv2D(64, 3, padding='same')(d1)                     # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    
    outputs = keras.layers.Conv2D(3, 4, strides=1, padding='same', activation='tanh')(d1)  # 32x32x3
    
    return keras.Model(inputs, outputs, name="Generator")

# 3. Build Discriminator
def build_discriminator():
    """
    Builds the discriminator model to classify real and generated images.
    """
    input_img = keras.Input(shape=(32, 32, 1))
    target_img = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Concatenate()([input_img, target_img])  # 32x32x4
    
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)  # 16x16x64
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)  # 8x8x128
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)  # 4x4x256
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)    # 4x4x1
    
    return keras.Model([input_img, target_img], x, name="Discriminator")

# 4. Loss Functions
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target):
    """
    Calculates generator loss which is a combination of GAN loss and L1 loss.
    """
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_loss = gan_loss + (100 * l1_loss)
    return total_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Calculates discriminator loss.
    """
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = cross_entropy(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_loss = real_loss + generated_loss
    return total_loss

# 5. Training Step
@tf.function
def train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer):
    """
    Performs one training step for both generator and discriminator.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate output
        gen_output = generator(input_image, training=True)
        
        # Discriminator output
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        
        # Calculate losses
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    
    # Calculate gradients
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    # Apply gradients
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    
    return gen_total_loss, disc_loss

# 6. Image Generation and Saving
def generate_and_save_images(model, epoch, test_input, save_dir='output', max_images=16):
    """
    Generates images using the generator and saves them to disk.
    """
    predictions = model(test_input, training=False)
    
    num_images = min(predictions.shape[0], max_images)
    grid_size = math.ceil(math.sqrt(num_images))
    
    fig = plt.figure(figsize=(12, 12))
    
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i+1)
        img = (predictions[i] * 0.5) + 0.5  # Rescale to [0, 1]
        plt.imshow(img.numpy())
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'image_at_epoch_{epoch:04d}.png'))
    plt.close()

# 7. Evaluation Metrics
def calculate_metrics(real_images, generated_images, win_size=7):
    """
    Calculates average SSIM and PSNR between real and generated images.
    
    Args:
        real_images (tf.Tensor): Batch of real images.
        generated_images (tf.Tensor): Batch of generated images.
        win_size (int): Window size for SSIM. Must be odd and <= min(image dimensions).
    
    Returns:
        tuple: (average SSIM, average PSNR)
    """
    ssim_scores = []
    psnr_scores = []
    
    for real, generated in tqdm(zip(real_images, generated_images), desc="Calculating metrics", total=len(real_images)):
        # Rescale images from [-1, 1] to [0, 255]
        real = ((real * 0.5) + 0.5) * 255
        generated = ((generated * 0.5) + 0.5) * 255
        
        # Convert tensors to NumPy arrays
        real = real.numpy().astype(np.uint8)
        generated = generated.numpy().astype(np.uint8)
        
        # Ensure images have 3 channels
        if real.ndim == 3 and real.shape[-1] == 1:
            real = np.squeeze(real, axis=-1)
        if generated.ndim ==3 and generated.shape[-1] ==1:
            generated = np.squeeze(generated, axis=-1)
        
        # Calculate SSIM with updated parameters
        try:
            ssim_val = ssim(
                real, 
                generated, 
                win_size=win_size, 
                channel_axis=-1
            )
        except ValueError as e:
            print(f"SSIM calculation error: {e}")
            ssim_val = 0  # Assign a default value or handle as needed
        
        # Calculate PSNR
        psnr_val = psnr(
            real, 
            generated, 
            data_range=255
        )
        
        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)
    
    return np.mean(ssim_scores), np.mean(psnr_scores)

# 8. Main Execution with Enhancements
def main():
    # Set random seed for reproducibility (optional)
    tf.random.set_seed(42)
    np.random.seed(42)
    
    # Create a directory to save outputs if not exists
    os.makedirs('output', exist_ok=True)
    
    # Initialize CSV files
    train_csv_file = os.path.join('output', 'training_metrics.csv')
    eval_csv_file = os.path.join('output', 'evaluation_metrics.csv')
    
    with open(train_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'Generator Loss', 'Discriminator Loss'])
    
    with open(eval_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'SSIM Score', 'PSNR Score'])
    
    # Load and preprocess data
    train_dataset, val_dataset, test_dataset = load_and_preprocess_data()
    
    # Build models
    generator = build_generator()
    discriminator = build_discriminator()
    
    # Define optimizers
    generator_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_optimizer = keras.optimizers.Adam(2e-4, beta_1=0.5)
    
    # Define number of epochs
    epochs = 100
    patience = 10  # For early stopping
    
    # Setup TensorBoard
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    
    # Setup Model Checkpointing
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer
    )
    
    # Select a fixed test set for consistent evaluation
    test_grayscale, test_color = next(iter(test_dataset))
    
    # Initialize variables for early stopping
    best_val_loss = float('inf')
    wait = 0
    
    # Training Loop
    for epoch in tqdm(range(1, epochs + 1), desc="Training epochs"):
        gen_losses = []
        disc_losses = []
        
        # Training
        for input_image, target in tqdm(train_dataset, desc=f"Epoch {epoch}/{epochs}", leave=False):
            gen_total_loss, disc_loss = train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer)
            gen_losses.append(gen_total_loss)
            disc_losses.append(disc_loss)
        
        # Calculate average losses
        avg_gen_loss = tf.reduce_mean(gen_losses).numpy()
        avg_disc_loss = tf.reduce_mean(disc_losses).numpy()
        
        # Validation
        val_gen_losses = []
        val_disc_losses = []
        for val_input, val_target in val_dataset:
            # Generate output
            gen_output = generator(val_input, training=False)
            
            # Discriminator output
            disc_real_output = discriminator([val_input, val_target], training=False)
            disc_generated_output = discriminator([val_input, gen_output], training=False)
            
            # Calculate losses
            val_gen_total_loss, val_gen_gan_loss, val_gen_l1_loss = generator_loss(disc_generated_output, gen_output, val_target)
            val_disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
            
            val_gen_losses.append(val_gen_total_loss)
            val_disc_losses.append(val_disc_loss)
        
        # Calculate average validation losses
        avg_val_gen_loss = tf.reduce_mean(val_gen_losses).numpy()
        avg_val_disc_loss = tf.reduce_mean(val_disc_losses).numpy()
        avg_val_loss = avg_val_gen_loss + avg_val_disc_loss
        
        # Log training losses to CSV
        with open(train_csv_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, avg_gen_loss, avg_disc_loss])
        
        # Log validation losses to TensorBoard
        with tf.summary.create_file_writer(log_dir).as_default():
            tf.summary.scalar('Generator Loss (Train)', avg_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Train)', avg_disc_loss, step=epoch)
            tf.summary.scalar('Generator Loss (Val)', avg_val_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Val)', avg_val_disc_loss, step=epoch)
            tf.summary.scalar('Total Validation Loss', avg_val_loss, step=epoch)
        
        # Print metrics
        tqdm.write(f"Epoch {epoch}/{epochs}")
        tqdm.write(f"Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}")
        tqdm.write(f"Validation Generator Loss: {avg_val_gen_loss:.4f}, Validation Discriminator Loss: {avg_val_disc_loss:.4f}")
        
        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            wait = 0
            # Save the best models
            checkpoint.save(file_prefix=checkpoint_prefix)
            tqdm.write("Validation loss improved. Checkpoint saved.")
        else:
            wait += 1
            tqdm.write(f"No improvement in validation loss for {wait} epochs.")
            if wait >= patience:
                tqdm.write("Early stopping triggered.")
                break
        
        # Every 10 epochs, perform evaluation
        if epoch % 10 == 0:
            # Generate images
            generate_and_save_images(generator, epoch, test_grayscale)
            
            # Generate output
            generated_images = generator(test_grayscale, training=False)
            
            # Calculate metrics
            ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
            tqdm.write(f"SSIM Score: {ssim_score:.4f}")
            tqdm.write(f"PSNR Score: {psnr_score:.4f}")
            
            # Log evaluation metrics to separate CSV
            with open(eval_csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch, ssim_score, psnr_score])
            
            # Log evaluation metrics to TensorBoard
            with tf.summary.create_file_writer(log_dir).as_default():
                tf.summary.scalar('SSIM Score', ssim_score, step=epoch)
                tf.summary.scalar('PSNR Score', psnr_score, step=epoch)
    
    # After training, perform final evaluation
    # Optionally, load the best checkpoint
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        checkpoint.restore(latest_ckpt)
        tqdm.write(f"Restored from checkpoint: {latest_ckpt}")
    
    # Final evaluation on the fixed test set
    generated_images = generator(test_grayscale, training=False)
    ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
    print(f"Final SSIM Score: {ssim_score:.4f}")
    print(f"Final PSNR Score: {psnr_score:.4f}")
    
    # Append final evaluation metrics to CSV
    with open(eval_csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Final Evaluation', ssim_score, psnr_score])
    
    # Optionally, save the final models
    generator.save('output/final_generator.h5')
    discriminator.save('output/final_discriminator.h5')
    print("Models saved.")

if __name__ == "__main__":
    main()


Preprocessing data:   0%|          | 1/5000 [00:00<12:19,  6.76it/s]
2024-10-13 02:33:48.611872: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Training epochs:   0%|          | 0/100 [00:00<?, ?it/s]2024-10-13 02:34:02.342452: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:09.487322:

Epoch 1/100
Generator Loss: 37.3097, Discriminator Loss: 0.7813
Validation Generator Loss: 35.6143, Validation Discriminator Loss: 1.5378


Training epochs:   1%|          | 1/100 [00:22<36:43, 22.26s/it]

Validation loss improved. Checkpoint saved.


2024-10-13 02:34:17.319495: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:19.509650: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:34:20.038199: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 2/100
Generator Loss: 36.9919, Discriminator Loss: 0.4614
Validation Generator Loss: 33.8303, Validation Discriminator Loss: 1.5349
Validation loss improved. Checkpoint saved.


2024-10-13 02:34:26.563716: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:28.666810: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:34:29.188695: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 3/100
Generator Loss: 37.2943, Discriminator Loss: 0.5080
Validation Generator Loss: 34.0764, Validation Discriminator Loss: 1.1682
Validation loss improved. Checkpoint saved.


2024-10-13 02:34:35.834570: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:37.981000: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:34:38.520483: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 4/100
Generator Loss: 37.1762, Discriminator Loss: 0.5794
Validation Generator Loss: 33.6501, Validation Discriminator Loss: 1.5288
Validation loss improved. Checkpoint saved.


2024-10-13 02:34:45.258497: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:47.378115: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:34:47.906805: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 5/100
Generator Loss: 36.7406, Discriminator Loss: 0.5866
Validation Generator Loss: 33.0085, Validation Discriminator Loss: 1.4930
Validation loss improved. Checkpoint saved.


2024-10-13 02:34:54.578962: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:34:56.821904: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:34:57.356317: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 6/100
Generator Loss: 36.5761, Discriminator Loss: 0.6553
Validation Generator Loss: 33.8889, Validation Discriminator Loss: 1.7836
No improvement in validation loss for 1 epochs.


2024-10-13 02:35:03.769410: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:05.895851: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:06.421357: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 7/100
Generator Loss: 36.2320, Discriminator Loss: 0.6596
Validation Generator Loss: 34.2767, Validation Discriminator Loss: 1.4527
No improvement in validation loss for 2 epochs.


2024-10-13 02:35:12.934073: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:15.020425: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:15.584233: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 8/100
Generator Loss: 36.2172, Discriminator Loss: 0.6716
Validation Generator Loss: 35.1196, Validation Discriminator Loss: 1.4186
No improvement in validation loss for 3 epochs.


2024-10-13 02:35:22.061284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:24.179623: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:24.722281: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 9/100
Generator Loss: 36.0752, Discriminator Loss: 0.6671
Validation Generator Loss: 34.4054, Validation Discriminator Loss: 1.4907
No improvement in validation loss for 4 epochs.


2024-10-13 02:35:31.471746: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:33.626639: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:34.158160: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 10/100
Generator Loss: 35.9284, Discriminator Loss: 0.6859
Validation Generator Loss: 33.2915, Validation Discriminator Loss: 1.2158
No improvement in validation loss for 5 epochs.


Calculating metrics: 100%|██████████| 32/32 [00:00<00:00, 223.59it/s]
Training epochs:  10%|█         | 10/100 [01:46<14:27,  9.64s/it]

SSIM Score: 0.4748
PSNR Score: 15.1377


2024-10-13 02:35:41.328905: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:43.480124: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:44.003116: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 11/100
Generator Loss: 35.8700, Discriminator Loss: 0.6470
Validation Generator Loss: 32.9931, Validation Discriminator Loss: 1.2532
Validation loss improved. Checkpoint saved.


2024-10-13 02:35:50.672058: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:35:52.795698: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:35:53.336674: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 12/100
Generator Loss: 35.5830, Discriminator Loss: 0.6573
Validation Generator Loss: 34.3305, Validation Discriminator Loss: 1.2967
No improvement in validation loss for 1 epochs.


2024-10-13 02:35:59.884994: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:02.061778: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:02.595520: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 13/100
Generator Loss: 35.4862, Discriminator Loss: 0.6583
Validation Generator Loss: 36.2427, Validation Discriminator Loss: 1.2783
No improvement in validation loss for 2 epochs.


2024-10-13 02:36:08.926047: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:11.001780: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:11.517785: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 14/100
Generator Loss: 35.3503, Discriminator Loss: 0.6378
Validation Generator Loss: 37.4768, Validation Discriminator Loss: 1.2825
No improvement in validation loss for 3 epochs.


2024-10-13 02:36:17.953501: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:20.021158: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:20.546441: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 15/100
Generator Loss: 35.2367, Discriminator Loss: 0.5816
Validation Generator Loss: 35.0972, Validation Discriminator Loss: 1.3394
No improvement in validation loss for 4 epochs.


2024-10-13 02:36:27.088920: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:29.253231: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:29.806263: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 16/100
Generator Loss: 35.2611, Discriminator Loss: 0.5917
Validation Generator Loss: 33.9597, Validation Discriminator Loss: 1.2762
No improvement in validation loss for 5 epochs.


2024-10-13 02:36:36.563354: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:38.687977: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:39.217938: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 17/100
Generator Loss: 35.1133, Discriminator Loss: 0.6330
Validation Generator Loss: 32.9222, Validation Discriminator Loss: 1.2236
Validation loss improved. Checkpoint saved.


2024-10-13 02:36:45.816797: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:47.936071: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:48.497438: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 18/100
Generator Loss: 34.9588, Discriminator Loss: 0.6056
Validation Generator Loss: 34.0757, Validation Discriminator Loss: 1.6225
No improvement in validation loss for 1 epochs.


2024-10-13 02:36:55.025306: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:36:57.170563: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:36:57.696668: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 19/100
Generator Loss: 34.7161, Discriminator Loss: 0.6275
Validation Generator Loss: 33.2473, Validation Discriminator Loss: 1.4901
No improvement in validation loss for 2 epochs.


2024-10-13 02:37:04.502510: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:06.653309: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:07.191817: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 20/100
Generator Loss: 34.4345, Discriminator Loss: 0.6121
Validation Generator Loss: 33.7939, Validation Discriminator Loss: 1.6383
No improvement in validation loss for 3 epochs.


Calculating metrics: 100%|██████████| 32/32 [00:00<00:00, 442.15it/s]
Training epochs:  20%|██        | 20/100 [03:18<12:32,  9.40s/it]

SSIM Score: 0.6030
PSNR Score: 15.7510


2024-10-13 02:37:13.790313: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:15.819796: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:16.335366: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 21/100
Generator Loss: 34.4313, Discriminator Loss: 0.6328
Validation Generator Loss: 32.7748, Validation Discriminator Loss: 1.4212
No improvement in validation loss for 4 epochs.


2024-10-13 02:37:22.626227: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:24.673825: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:25.176804: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 22/100
Generator Loss: 34.2099, Discriminator Loss: 0.6012
Validation Generator Loss: 34.0707, Validation Discriminator Loss: 1.3580
No improvement in validation loss for 5 epochs.


2024-10-13 02:37:31.577457: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:33.612320: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:34.125406: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 23/100
Generator Loss: 34.1278, Discriminator Loss: 0.6062
Validation Generator Loss: 40.5748, Validation Discriminator Loss: 1.4003
No improvement in validation loss for 6 epochs.


2024-10-13 02:37:40.649504: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:42.693080: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:43.209956: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 24/100
Generator Loss: 34.1880, Discriminator Loss: 0.6281
Validation Generator Loss: 34.2168, Validation Discriminator Loss: 1.6371
No improvement in validation loss for 7 epochs.


2024-10-13 02:37:49.424176: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:37:51.447302: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:37:51.960356: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 25/100
Generator Loss: 33.8618, Discriminator Loss: 0.6154
Validation Generator Loss: 34.5666, Validation Discriminator Loss: 1.4141
No improvement in validation loss for 8 epochs.


2024-10-13 02:37:58.254352: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:38:00.285513: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:38:00.787531: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 26/100
Generator Loss: 33.6681, Discriminator Loss: 0.5969
Validation Generator Loss: 36.2486, Validation Discriminator Loss: 1.6512
No improvement in validation loss for 9 epochs.


2024-10-13 02:38:07.243470: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 02:38:09.349835: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 02:38:09.858698: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 27/100
Generator Loss: 33.7762, Discriminator Loss: 0.5732
Validation Generator Loss: 38.4997, Validation Discriminator Loss: 1.3148
No improvement in validation loss for 10 epochs.
Early stopping triggered.
Restored from checkpoint: ./checkpoints/ckpt-7


Calculating metrics: 100%|██████████| 32/32 [00:00<00:00, 445.90it/s]


Final SSIM Score: 0.4191
Final PSNR Score: 14.8780




Models saved.


In [21]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import math
import csv
import os
import datetime

# 1. Data Preprocessing
def preprocess_image(image):
    """Scales image to [-1, 1]."""
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return image

def rgb_to_grayscale(image):
    """Converts RGB image to grayscale."""
    return tf.image.rgb_to_grayscale(image)

def augment(image, label):
    """Applies random augmentations to images."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    return image, label

def load_and_preprocess_data(batch_size=32, num_samples=5000):
    """
    Loads CIFAR-10 dataset, applies preprocessing and augmentation,
    and prepares training, validation, and test datasets.
    """
    # Load CIFAR-10 dataset
    dataset, info = tfds.load('cifar10', with_info=True, as_supervised=True)
    train_dataset = dataset['train'].take(num_samples)

    # Define split sizes
    val_size = int(0.1 * num_samples)  # 10% for validation
    test_size = 32  # Fixed test set size

    # Preprocess and augment the data
    with tqdm(total=num_samples, desc="Preprocessing data") as pbar:
        def preprocess_and_update(img, label):
            pbar.update(1)
            img = preprocess_image(img)
            return img, img  # Input and target are initially the same

        train_dataset = train_dataset.map(preprocess_and_update, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(lambda img_color, img_color2: (rgb_to_grayscale(img_color), img_color2),
                                          num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.shuffle(1000)
        train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

    # Split into training and validation
    val_dataset = train_dataset.take(val_size // batch_size)
    train_dataset = train_dataset.skip(val_size // batch_size)

    # Create a fixed test set
    test_dataset = dataset['test'].take(test_size).map(lambda img, lbl: preprocess_image(img))
    test_dataset = test_dataset.map(lambda img: (rgb_to_grayscale(img), img))
    test_dataset = test_dataset.batch(test_size)

    return train_dataset, val_dataset, test_dataset

# Update the generator and discriminator to work with 32x32 images
def build_generator():
    inputs = keras.Input(shape=(32, 32, 1))
    
    # Encoder
    e1 = keras.layers.Conv2D(64, 4, strides=2, padding='same')(inputs)  # 16x16x64
    e1 = keras.layers.LeakyReLU(0.2)(e1)
    
    e2 = keras.layers.Conv2D(128, 4, strides=2, padding='same')(e1)    # 8x8x128
    e2 = keras.layers.BatchNormalization()(e2)
    e2 = keras.layers.LeakyReLU(0.2)(e2)
    
    # Bridge
    b = keras.layers.Conv2D(256, 4, strides=1, padding='same')(e2)     # 8x8x256
    b = keras.layers.BatchNormalization()(b)
    b = keras.layers.LeakyReLU(0.2)(b)
    
    # Decoder
    d2 = keras.layers.Conv2DTranspose(128, 4, strides=2, padding='same')(b)  # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    d2 = keras.layers.Concatenate()([d2, e1])                               # 16x16x192
    
    # Optionally, add a Conv2D layer to reduce channels
    d2 = keras.layers.Conv2D(128, 3, padding='same')(d2)                   # 16x16x128
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    
    d1 = keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same')(d2)   # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    d1 = keras.layers.Concatenate()([d1, inputs])                            # 32x32x65
    
    # Optionally, add a Conv2D layer to reduce channels
    d1 = keras.layers.Conv2D(64, 3, padding='same')(d1)                     # 32x32x64
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    
    outputs = keras.layers.Conv2D(3, 4, strides=1, padding='same', activation='tanh')(d1)  # 32x32x3
    
    return keras.Model(inputs, outputs)

def build_discriminator():
    input_img = keras.Input(shape=(32, 32, 1))
    target_img = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Concatenate()([input_img, target_img])
    
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same')(x)  # 16x16x64
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(128, 4, strides=2, padding='same')(x)  # 8x8x128
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(256, 4, strides=2, padding='same')(x)  # 4x4x256
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)    # 4x4x1
    
    return keras.Model([input_img, target_img], x)
# 4. Loss Functions
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target, l1_weight=10):
    """
    Calculates generator loss which is a combination of GAN loss and L1 loss.
    """
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_loss = gan_loss + (l1_weight * l1_loss)
    return total_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Calculates discriminator loss.
    """
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = cross_entropy(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_loss = real_loss + generated_loss
    return total_loss

# 5. Training Step
@tf.function
def train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer, l1_weight=10):
    """
    Performs one training step for both generator and discriminator.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate output
        gen_output = generator(input_image, training=True)
        
        # Discriminator output
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        
        # Calculate losses
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target, l1_weight=l1_weight
        )
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    
    # Calculate gradients
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    # Apply gradients
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    
    return gen_total_loss, disc_loss

# 6. Image Generation and Saving
def generate_and_save_images(model, epoch, test_input, save_dir='output', max_images=16):
    """
    Generates images using the generator and saves them to disk.
    """
    predictions = model(test_input, training=False)
    
    num_images = min(predictions.shape[0], max_images)
    grid_size = math.ceil(math.sqrt(num_images))
    
    fig = plt.figure(figsize=(12, 12))
    
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i+1)
        img = (predictions[i] * 0.5) + 0.5  # Rescale to [0, 1]
        plt.imshow(img.numpy())
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'image_at_epoch_{epoch:04d}.png'))
    plt.close()

# 7. Evaluation Metrics
def calculate_metrics(real_images, generated_images, win_size=7):
    """
    Calculates average SSIM and PSNR between real and generated images.
    
    Args:
        real_images (tf.Tensor): Batch of real images.
        generated_images (tf.Tensor): Batch of generated images.
        win_size (int): Window size for SSIM. Must be odd and <= min(image dimensions).
    
    Returns:
        tuple: (average SSIM, average PSNR)
    """
    ssim_scores = []
    psnr_scores = []
    
    for real, generated in tqdm(zip(real_images, generated_images), desc="Calculating metrics", total=len(real_images)):
        # Rescale images from [-1, 1] to [0, 255]
        real = ((real * 0.5) + 0.5) * 255
        generated = ((generated * 0.5) + 0.5) * 255
        
        # Convert tensors to NumPy arrays
        real = real.numpy().astype(np.uint8)
        generated = generated.numpy().astype(np.uint8)
        
        # Ensure images have 3 channels
        if real.ndim == 3 and real.shape[-1] == 1:
            real = np.squeeze(real, axis=-1)
        if generated.ndim ==3 and generated.shape[-1] ==1:
            generated = np.squeeze(generated, axis=-1)
        
        # Calculate SSIM with updated parameters
        try:
            ssim_val = ssim(
                real, 
                generated, 
                win_size=win_size, 
                channel_axis=-1
            )
        except ValueError as e:
            print(f"SSIM calculation error: {e}")
            ssim_val = 0  # Assign a default value or handle as needed
        
        # Calculate PSNR
        psnr_val = psnr(
            real, 
            generated, 
            data_range=255
        )
        
        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)
    
    return np.mean(ssim_scores), np.mean(psnr_scores)

# 8. Plotting Function
def plot_metrics(train_csv, eval_csv, output_dir='output'):
    """
    Plots training and evaluation metrics from CSV files.
    
    Args:
        train_csv (str): Path to the training metrics CSV file.
        eval_csv (str): Path to the evaluation metrics CSV file.
        output_dir (str): Directory to save the plots.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Load training metrics
    train_epochs = []
    gen_losses = []
    disc_losses = []
    with open(train_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            train_epochs.append(int(row['Epoch']))
            gen_losses.append(float(row['Generator Loss']))
            disc_losses.append(float(row['Discriminator Loss']))
    
    # Load evaluation metrics
    eval_epochs = []
    ssim_scores = []
    psnr_scores = []
    with open(eval_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            eval_epochs.append(int(row['Epoch']))
            ssim_scores.append(float(row['SSIM Score']))
            psnr_scores.append(float(row['PSNR Score']))
    
    # Plot Generator and Discriminator Losses
    plt.figure(figsize=(10, 6))
    plt.plot(train_epochs, gen_losses, label='Generator Loss')
    plt.plot(train_epochs, disc_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'training_losses.png'))
    plt.close()
    
    # Plot SSIM and PSNR Scores
    plt.figure(figsize=(10, 6))
    plt.plot(eval_epochs, ssim_scores, label='SSIM Score')
    plt.plot(eval_epochs, psnr_scores, label='PSNR Score')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Evaluation Metrics')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'evaluation_metrics.png'))
    plt.close()
    
    print(f"Plots saved in '{output_dir}' directory.")

# 9. Main Execution with Enhancements
def main():
    # Set random seed for reproducibility (optional)
    tf.random.set_seed(42)
    np.random.seed(42)
    
    # Create a directory to save outputs if not exists
    os.makedirs('output', exist_ok=True)
    
    # Initialize CSV files
    train_csv_file = os.path.join('output', 'training_metrics.csv')
    eval_csv_file = os.path.join('output', 'evaluation_metrics.csv')
    
    with open(train_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'Generator Loss', 'Discriminator Loss'])
    
    with open(eval_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'SSIM Score', 'PSNR Score'])
    
    # Load and preprocess data
    train_dataset, val_dataset, test_dataset = load_and_preprocess_data()
    
    # Build models
    generator = build_generator()
    discriminator = build_discriminator()
    
    # Define optimizers
    generator_optimizer = keras.optimizers.Adam(1e-4, beta_1=0.5)  # Reduced learning rate
    discriminator_optimizer = keras.optimizers.Adam(1e-4, beta_1=0.5)  # Reduced learning rate
    
    # Define number of epochs
    epochs = 100
    patience = 15  # For early stopping
    
    # Setup TensorBoard
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    
    # Setup Model Checkpointing
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer
    )
    
    # Select a fixed test set for consistent evaluation
    test_grayscale, test_color = next(iter(test_dataset))
    
    # Initialize variables for early stopping
    best_val_loss = float('inf')
    wait = 0
    
    # Training Loop
    for epoch in tqdm(range(1, epochs + 1), desc="Training epochs"):
        gen_losses = []
        disc_losses = []
        
        # Training
        for input_image, target in tqdm(train_dataset, desc=f"Epoch {epoch}/{epochs}", leave=False):
            gen_total_loss, disc_loss = train_step(
                input_image, target, generator, discriminator, 
                generator_optimizer, discriminator_optimizer, 
                l1_weight=10  # Adjusted L1 weight
            )
            gen_losses.append(gen_total_loss)
            disc_losses.append(disc_loss)
        
        # Calculate average losses
        avg_gen_loss = tf.reduce_mean(gen_losses).numpy()
        avg_disc_loss = tf.reduce_mean(disc_losses).numpy()
        
        # Validation
        val_gen_losses = []
        val_disc_losses = []
        for val_input, val_target in val_dataset:
            # Generate output
            gen_output = generator(val_input, training=False)
            
            # Discriminator output
            disc_real_output = discriminator([val_input, val_target], training=False)
            disc_generated_output = discriminator([val_input, gen_output], training=False)
            
            # Calculate losses
            val_gen_total_loss, val_gen_gan_loss, val_gen_l1_loss = generator_loss(
                disc_generated_output, gen_output, val_target, l1_weight=10
            )
            val_disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
            
            val_gen_losses.append(val_gen_total_loss)
            val_disc_losses.append(val_disc_loss)
        
        # Calculate average validation losses
        avg_val_gen_loss = tf.reduce_mean(val_gen_losses).numpy()
        avg_val_disc_loss = tf.reduce_mean(val_disc_losses).numpy()
        avg_val_loss = avg_val_gen_loss + avg_val_disc_loss
        
        # Log training losses to CSV
        with open(train_csv_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, avg_gen_loss, avg_disc_loss])
        
        # Log validation losses to TensorBoard
        with tf.summary.create_file_writer(log_dir).as_default():
            tf.summary.scalar('Generator Loss (Train)', avg_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Train)', avg_disc_loss, step=epoch)
            tf.summary.scalar('Generator Loss (Val)', avg_val_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Val)', avg_val_disc_loss, step=epoch)
            tf.summary.scalar('Total Validation Loss', avg_val_loss, step=epoch)
        
        # Print metrics
        tqdm.write(f"Epoch {epoch}/{epochs}")
        tqdm.write(f"Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}")
        tqdm.write(f"Validation Generator Loss: {avg_val_gen_loss:.4f}, Validation Discriminator Loss: {avg_val_disc_loss:.4f}")
        
        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            wait = 0
            # Save the best models
            checkpoint.save(file_prefix=checkpoint_prefix)
            tqdm.write("Validation loss improved. Checkpoint saved.")
        else:
            wait += 1
            tqdm.write(f"No improvement in validation loss for {wait} epochs.")
            if wait >= patience:
                tqdm.write("Early stopping triggered.")
                break
        
        # Every 10 epochs, perform evaluation
        if epoch % 10 == 0:
            # Generate images
            generate_and_save_images(generator, epoch, test_grayscale)
            
            # Generate output
            generated_images = generator(test_grayscale, training=False)
            
            # Calculate metrics
            ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
            tqdm.write(f"SSIM Score: {ssim_score:.4f}")
            tqdm.write(f"PSNR Score: {psnr_score:.4f}")
            
            # Log evaluation metrics to separate CSV
            with open(eval_csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch, ssim_score, psnr_score])
            
            # Log evaluation metrics to TensorBoard
            with tf.summary.create_file_writer(log_dir).as_default():
                tf.summary.scalar('SSIM Score', ssim_score, step=epoch)
                tf.summary.scalar('PSNR Score', psnr_score, step=epoch)
    
    # 9. Plotting Function
def plot_metrics(train_csv, eval_csv, output_dir='output'):
    """
    Plots training and evaluation metrics from CSV files.
    
    Args:
        train_csv (str): Path to the training metrics CSV file.
        eval_csv (str): Path to the evaluation metrics CSV file.
        output_dir (str): Directory to save the plots.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Load training metrics
    train_epochs = []
    gen_losses = []
    disc_losses = []
    with open(train_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            train_epochs.append(int(row['Epoch']))
            gen_losses.append(float(row['Generator Loss']))
            disc_losses.append(float(row['Discriminator Loss']))
    
    # Load evaluation metrics
    eval_epochs = []
    ssim_scores = []
    psnr_scores = []
    with open(eval_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            eval_epochs.append(int(row['Epoch']))
            ssim_scores.append(float(row['SSIM Score']))
            psnr_scores.append(float(row['PSNR Score']))
    
    # Plot Generator and Discriminator Losses
    plt.figure(figsize=(10, 6))
    plt.plot(train_epochs, gen_losses, label='Generator Loss')
    plt.plot(train_epochs, disc_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'training_losses.png'))
    plt.close()
    
    # Plot SSIM and PSNR Scores
    plt.figure(figsize=(10, 6))
    plt.plot(eval_epochs, ssim_scores, label='SSIM Score')
    plt.plot(eval_epochs, psnr_scores, label='PSNR Score')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Evaluation Metrics')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'evaluation_metrics.png'))
    plt.close()
    
    print(f"Plots saved in '{output_dir}' directory.")

# 10. Final Evaluation and Plotting
def finalize_training(checkpoint_dir='./checkpoints', train_csv='output/training_metrics.csv', eval_csv='output/evaluation_metrics.csv'):
    """
    Restores the best model, performs final evaluation, and plots metrics.
    """
    # Restore the best checkpoint
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        print(f"Restoring from checkpoint: {latest_ckpt}")
        checkpoint.restore(latest_ckpt)
    else:
        print("No checkpoint found. Proceeding without restoring.")
    
    # Load test dataset
    _, _, test_dataset = load_and_preprocess_data()
    test_grayscale, test_color = next(iter(test_dataset))
    
    # Generate images
    generated_images = generator(test_grayscale, training=False)
    
    # Calculate metrics
    ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
    print(f"Final SSIM Score: {ssim_score:.4f}")
    print(f"Final PSNR Score: {psnr_score:.4f}")
    
    # Append final evaluation metrics to CSV
    with open(eval_csv, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Final Evaluation', ssim_score, psnr_score])
    
    # Save the final models
    generator.save('output/final_generator.h5')
    discriminator.save('output/final_discriminator.h5')
    print("Final models saved.")
    
    # Plot metrics
    plot_metrics(train_csv, eval_csv, output_dir='output')

if __name__ == "__main__":
    main()
    # After training completes, plot the metrics
    plot_metrics(
        train_csv='output/training_metrics.csv',
        eval_csv='output/evaluation_metrics.csv',
        output_dir='output'
    )


Preprocessing data:   0%|          | 1/5000 [00:00<09:11,  9.07it/s]
2024-10-13 09:26:00.456888: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Training epochs:   0%|          | 0/100 [00:00<?, ?it/s]2024-10-13 09:26:14.676003: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 09:26:22.892781:

Epoch 1/100
Generator Loss: 4.6054, Discriminator Loss: 1.2104
Validation Generator Loss: 4.2390, Validation Discriminator Loss: 1.3984


Training epochs:   1%|          | 1/100 [00:23<38:49, 23.53s/it]

Validation loss improved. Checkpoint saved.


2024-10-13 09:26:30.621640: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-10-13 09:26:32.852796: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-10-13 09:26:33.380295: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Epoch 2/100
Generator Loss: 4.7043, Discriminator Loss: 1.1639
Validation Generator Loss: 4.0784, Validation Discriminator Loss: 1.3760
Validation loss improved. Checkpoint saved.


2024-10-13 09:26:40.327583: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Training epochs:   2%|▏         | 2/100 [00:40<32:45, 20.05s/it]


KeyboardInterrupt: 

In [25]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import math
import csv
import os
import datetime

# 1. Data Preprocessing
def preprocess_image(image, target):
    image = tf.image.resize(image, [256, 256])  # Ensure the image is resized to 256x256
    target = tf.image.resize(target, [256, 256])  # Ensure the target is resized to 256x256

    image = (image / 127.5) - 1  # Normalize the image to [-1, 1] range
    target = (target / 127.5) - 1  # Normalize the target to [-1, 1] range

    return image, target

def rgb_to_grayscale(image):
    """Converts RGB image to grayscale."""
    return tf.image.rgb_to_grayscale(image)

def augment(image, label):
    """Applies random augmentations to images."""
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    return image, label

def load_and_preprocess_data(batch_size=32, num_samples=10000):
    """
    Loads CIFAR-10 dataset, applies preprocessing and augmentation,
    and prepares training, validation, and test datasets.
    """
    # Load CIFAR-10 dataset
    dataset, info = tfds.load('cifar10', with_info=True, as_supervised=True)
    train_dataset = dataset['train'].take(num_samples)

    # Define split sizes
    val_size = int(0.1 * num_samples)  # 10% for validation
    test_size = 32  # Fixed test set size

    # Preprocess and augment the data
    with tqdm(total=num_samples, desc="Preprocessing data") as pbar:
        def preprocess_and_update(img, label):
            pbar.update(1)
            img = preprocess_image(img)
            return img, img  # Input and target are initially the same

        train_dataset = train_dataset.map(preprocess_and_update, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(lambda img_color, img_color2: (rgb_to_grayscale(img_color), img_color2),
                                          num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        train_dataset = train_dataset.cache()
        train_dataset = train_dataset.shuffle(1000)
        train_dataset = train_dataset.batch(batch_size)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

    # Split into training and validation
    val_dataset = train_dataset.take(val_size // batch_size)
    train_dataset = train_dataset.skip(val_size // batch_size)

    # Create a fixed test set
    test_dataset = dataset['test'].take(test_size).map(lambda img, lbl: preprocess_image(img))
    test_dataset = test_dataset.map(lambda img: (rgb_to_grayscale(img), img))
    test_dataset = test_dataset.batch(test_size)

    return train_dataset, val_dataset, test_dataset

# 2. Residual Block for Generator
# Define the residual block with an option to control output channels
# Define the residual block with an option to control output channels
def residual_block(x, filters):
    skip = x  # Save the input tensor for the skip connection

    # First convolutional layer
    x = keras.layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.ReLU()(x)

    # Second convolutional layer
    x = keras.layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = keras.layers.BatchNormalization()(x)

    # If skip connection has a different number of channels, project it to match
    if skip.shape[-1] != filters:
        skip = keras.layers.Conv2D(filters, (1, 1), padding='same')(skip)
    
    # Skip connection: add input and output
    x = keras.layers.Add()([x, skip])
    return x


# Generator model (U-Net style with residual blocks)
def build_generator():
    inputs = keras.Input(shape=(256, 256, 3))
    
    # Encoder
    e1 = keras.layers.Conv2D(64, (4, 4), strides=2, padding='same')(inputs)    # 128x128x64
    e1 = keras.layers.ReLU()(e1)
    
    e2 = keras.layers.Conv2D(128, (4, 4), strides=2, padding='same')(e1)       # 64x64x128
    e2 = keras.layers.BatchNormalization()(e2)
    e2 = keras.layers.ReLU()(e2)
    
    e3 = keras.layers.Conv2D(256, (4, 4), strides=2, padding='same')(e2)       # 32x32x256
    e3 = keras.layers.BatchNormalization()(e3)
    e3 = keras.layers.ReLU()(e3)
    
    e4 = keras.layers.Conv2D(512, (4, 4), strides=2, padding='same')(e3)       # 16x16x512
    e4 = keras.layers.BatchNormalization()(e4)
    e4 = keras.layers.ReLU()(e4)
    
    e5 = keras.layers.Conv2D(512, (4, 4), strides=2, padding='same')(e4)       # 8x8x512
    e5 = keras.layers.BatchNormalization()(e5)
    e5 = keras.layers.ReLU()(e5)
    
    # Bottleneck
    b = residual_block(e5, 512)                                                # 8x8x512
    
    # Decoder
    d1 = keras.layers.Conv2DTranspose(512, (4, 4), strides=2, padding='same')(b)  # 16x16x512
    d1 = keras.layers.BatchNormalization()(d1)
    d1 = keras.layers.ReLU()(d1)
    d1 = keras.layers.Concatenate()([d1, e4])                                  # 16x16x1024
    
    d1 = residual_block(d1, 512)                                               # Ensure it has 512 channels
    
    d2 = keras.layers.Conv2DTranspose(256, (4, 4), strides=2, padding='same')(d1)  # 32x32x256
    d2 = keras.layers.BatchNormalization()(d2)
    d2 = keras.layers.ReLU()(d2)
    d2 = keras.layers.Concatenate()([d2, e3])                                  # 32x32x512
    
    d2 = residual_block(d2, 256)                                               # Ensure it has 256 channels
    
    d3 = keras.layers.Conv2DTranspose(128, (4, 4), strides=2, padding='same')(d2)  # 64x64x128
    d3 = keras.layers.BatchNormalization()(d3)
    d3 = keras.layers.ReLU()(d3)
    d3 = keras.layers.Concatenate()([d3, e2])                                  # 64x64x256
    
    d3 = residual_block(d3, 128)                                               # Ensure it has 128 channels
    
    d4 = keras.layers.Conv2DTranspose(64, (4, 4), strides=2, padding='same')(d3)  # 128x128x64
    d4 = keras.layers.BatchNormalization()(d4)
    d4 = keras.layers.ReLU()(d4)
    d4 = keras.layers.Concatenate()([d4, e1])                                  # 128x128x128
    
    d4 = residual_block(d4, 64)                                                # Ensure it has 64 channels
    
    # Final layer
    outputs = keras.layers.Conv2DTranspose(3, (4, 4), strides=2, padding='same', activation='tanh')(d4)  # 256x256x3
    
    return keras.Model(inputs, outputs)

# Call the generator function
generator = build_generator()
generator.summary()
# 4. Spectral Normalization Wrapper
def SpectralNormalization(layer):
    """Applies spectral normalization to a given layer."""
    return keras.layers.LayerNormalization()(layer)

# 5. Build Discriminator with Spectral Normalization
def build_discriminator():
    """
    Builds the discriminator model to classify real and generated images.
    Incorporates spectral normalization for stability.
    """
    input_img = keras.Input(shape=(32, 32, 1))
    target_img = keras.Input(shape=(32, 32, 3))
    x = keras.layers.Concatenate()([input_img, target_img])  # 32x32x4
    
    # Convolutional Layers with Spectral Normalization
    x = keras.layers.Conv2D(64, 4, strides=2, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal(0, 0.02),
                            kernel_regularizer=keras.regularizers.l2(0.0001))(x)  # 16x16x64
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(128, 4, strides=2, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal(0, 0.02),
                            kernel_regularizer=keras.regularizers.l2(0.0001))(x)  # 8x8x128
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    x = keras.layers.Conv2D(256, 4, strides=2, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal(0, 0.02),
                            kernel_regularizer=keras.regularizers.l2(0.0001))(x)  # 4x4x256
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.LeakyReLU(0.2)(x)
    
    # Final Layer
    x = keras.layers.Conv2D(1, 4, strides=1, padding='same',
                            kernel_initializer=keras.initializers.RandomNormal(0, 0.02),
                            kernel_regularizer=keras.regularizers.l2(0.0001))(x)    # 4x4x1
    
    return keras.Model([input_img, target_img], x, name="Discriminator")

# 6. Loss Functions
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def generator_loss(disc_generated_output, gen_output, target, l1_weight=10):
    """
    Calculates generator loss which is a combination of GAN loss and L1 loss.
    """
    gan_loss = cross_entropy(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_loss = gan_loss + (l1_weight * l1_loss)
    return total_loss, gan_loss, l1_loss

def discriminator_loss(disc_real_output, disc_generated_output):
    """
    Calculates discriminator loss.
    """
    real_loss = cross_entropy(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = cross_entropy(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_loss = real_loss + generated_loss
    return total_loss

# 7. Training Step
@tf.function
def train_step(input_image, target, generator, discriminator, generator_optimizer, discriminator_optimizer, l1_weight=10):
    """
    Performs one training step for both generator and discriminator.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate output
        gen_output = generator(input_image, training=True)
        
        # Discriminator output
        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        
        # Calculate losses
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
            disc_generated_output, gen_output, target, l1_weight=l1_weight
        )
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    
    # Calculate gradients
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    # Apply gradients
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
    
    return gen_total_loss, disc_loss

# 8. Image Generation and Saving
def generate_and_save_images(model, epoch, test_input, save_dir='output', max_images=16):
    """
    Generates images using the generator and saves them to disk.
    """
    predictions = model(test_input, training=False)
    
    num_images = min(predictions.shape[0], max_images)
    grid_size = math.ceil(math.sqrt(num_images))
    
    fig = plt.figure(figsize=(12, 12))
    
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i+1)
        img = (predictions[i] * 0.5) + 0.5  # Rescale to [0, 1]
        plt.imshow(img.numpy())
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'image_at_epoch_{epoch:04d}.png'))
    plt.close()

# 9. Evaluation Metrics
def calculate_metrics(real_images, generated_images, win_size=7):
    """
    Calculates average SSIM and PSNR between real and generated images.
    
    Args:
        real_images (tf.Tensor): Batch of real images.
        generated_images (tf.Tensor): Batch of generated images.
        win_size (int): Window size for SSIM. Must be odd and <= min(image dimensions).
    
    Returns:
        tuple: (average SSIM, average PSNR)
    """
    ssim_scores = []
    psnr_scores = []
    
    for real, generated in tqdm(zip(real_images, generated_images), desc="Calculating metrics", total=len(real_images)):
        # Rescale images from [-1, 1] to [0, 255]
        real = ((real * 0.5) + 0.5) * 255
        generated = ((generated * 0.5) + 0.5) * 255
        
        # Convert tensors to NumPy arrays
        real = real.numpy().astype(np.uint8)
        generated = generated.numpy().astype(np.uint8)
        
        # Ensure images have 3 channels
        if real.ndim == 3 and real.shape[-1] == 1:
            real = np.squeeze(real, axis=-1)
        if generated.ndim ==3 and generated.shape[-1] ==1:
            generated = np.squeeze(generated, axis=-1)
        
        # Calculate SSIM with updated parameters
        try:
            ssim_val = ssim(
                real, 
                generated, 
                win_size=win_size, 
                multichannel=True
            )
        except ValueError as e:
            print(f"SSIM calculation error: {e}")
            ssim_val = 0  # Assign a default value or handle as needed
        
        # Calculate PSNR
        psnr_val = psnr(
            real, 
            generated, 
            data_range=255
        )
        
        ssim_scores.append(ssim_val)
        psnr_scores.append(psnr_val)
    
    return np.mean(ssim_scores), np.mean(psnr_scores)

# 10. Plotting Function
def plot_metrics(train_csv, eval_csv, output_dir='output'):
    """
    Plots training and evaluation metrics from CSV files.
    
    Args:
        train_csv (str): Path to the training metrics CSV file.
        eval_csv (str): Path to the evaluation metrics CSV file.
        output_dir (str): Directory to save the plots.
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Load training metrics
    train_epochs = []
    gen_losses = []
    disc_losses = []
    with open(train_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            train_epochs.append(int(row['Epoch']))
            gen_losses.append(float(row['Generator Loss']))
            disc_losses.append(float(row['Discriminator Loss']))
    
    # Load evaluation metrics
    eval_epochs = []
    ssim_scores = []
    psnr_scores = []
    with open(eval_csv, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            eval_epochs.append(int(row['Epoch']))
            ssim_scores.append(float(row['SSIM Score']))
            psnr_scores.append(float(row['PSNR Score']))
    
    # Plot Generator and Discriminator Losses
    plt.figure(figsize=(10, 6))
    plt.plot(train_epochs, gen_losses, label='Generator Loss')
    plt.plot(train_epochs, disc_losses, label='Discriminator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'training_losses.png'))
    plt.close()
    
    # Plot SSIM and PSNR Scores
    plt.figure(figsize=(10, 6))
    plt.plot(eval_epochs, ssim_scores, label='SSIM Score')
    plt.plot(eval_epochs, psnr_scores, label='PSNR Score')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Evaluation Metrics')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, 'evaluation_metrics.png'))
    plt.close()
    
    print(f"Plots saved in '{output_dir}' directory.")

# 11. Main Execution with Enhancements
def main():
    # Set random seed for reproducibility (optional)
    tf.random.set_seed(42)
    np.random.seed(42)
    
    # Create a directory to save outputs if not exists
    os.makedirs('output', exist_ok=True)
    
    # Initialize CSV files
    train_csv_file = os.path.join('output', 'training_metrics.csv')
    eval_csv_file = os.path.join('output', 'evaluation_metrics.csv')
    
    with open(train_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'Generator Loss', 'Discriminator Loss'])
    
    with open(eval_csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write headers
        writer.writerow(['Epoch', 'SSIM Score', 'PSNR Score'])
    
    # Load and preprocess data
    train_dataset, val_dataset, test_dataset = load_and_preprocess_data()
    
    # Build models
    generator = build_generator()
    discriminator = build_discriminator()
    
    # Define optimizers with lower learning rates
    generator_optimizer = keras.optimizers.Adam(1e-4, beta_1=0.5)
    discriminator_optimizer = keras.optimizers.Adam(1e-4, beta_1=0.5)
    
    # Define number of epochs and patience for early stopping
    epochs = 100
    patience = 15  # For early stopping
    
    # Setup TensorBoard
    log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    
    # Setup Model Checkpointing
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer
    )
    
    # Select a fixed test set for consistent evaluation
    test_grayscale, test_color = next(iter(test_dataset))
    
    # Initialize variables for early stopping
    best_val_loss = float('inf')
    wait = 0
    
    # Training Loop
    for epoch in tqdm(range(1, epochs + 1), desc="Training epochs"):
        gen_losses = []
        disc_losses = []
        
        # Training
        for input_image, target in tqdm(train_dataset, desc=f"Epoch {epoch}/{epochs}", leave=False):
            gen_total_loss, disc_loss = train_step(
                input_image, target, generator, discriminator, 
                generator_optimizer, discriminator_optimizer, 
                l1_weight=10  # Adjusted L1 weight
            )
            gen_losses.append(gen_total_loss)
            disc_losses.append(disc_loss)
        
        # Calculate average losses
        avg_gen_loss = tf.reduce_mean(gen_losses).numpy()
        avg_disc_loss = tf.reduce_mean(disc_losses).numpy()
        
        # Validation
        val_gen_losses = []
        val_disc_losses = []
        for val_input, val_target in val_dataset:
            # Generate output
            gen_output = generator(val_input, training=False)
            
            # Discriminator output
            disc_real_output = discriminator([val_input, val_target], training=False)
            disc_generated_output = discriminator([val_input, gen_output], training=False)
            
            # Calculate losses
            val_gen_total_loss, val_gen_gan_loss, val_gen_l1_loss = generator_loss(
                disc_generated_output, gen_output, val_target, l1_weight=10
            )
            val_disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
            
            val_gen_losses.append(val_gen_total_loss)
            val_disc_losses.append(val_disc_loss)
        
        # Calculate average validation losses
        avg_val_gen_loss = tf.reduce_mean(val_gen_losses).numpy()
        avg_val_disc_loss = tf.reduce_mean(val_disc_losses).numpy()
        avg_val_loss = avg_val_gen_loss + avg_val_disc_loss
        
        # Log training losses to CSV
        with open(train_csv_file, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([epoch, avg_gen_loss, avg_disc_loss])
        
        # Log validation losses to TensorBoard
        with tf.summary.create_file_writer(log_dir).as_default():
            tf.summary.scalar('Generator Loss (Train)', avg_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Train)', avg_disc_loss, step=epoch)
            tf.summary.scalar('Generator Loss (Val)', avg_val_gen_loss, step=epoch)
            tf.summary.scalar('Discriminator Loss (Val)', avg_val_disc_loss, step=epoch)
            tf.summary.scalar('Total Validation Loss', avg_val_loss, step=epoch)
        
        # Print metrics
        tqdm.write(f"Epoch {epoch}/{epochs}")
        tqdm.write(f"Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}")
        tqdm.write(f"Validation Generator Loss: {avg_val_gen_loss:.4f}, Validation Discriminator Loss: {avg_val_disc_loss:.4f}")
        
        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            wait = 0
            # Save the best models
            checkpoint.save(file_prefix=checkpoint_prefix)
            tqdm.write("Validation loss improved. Checkpoint saved.")
        else:
            wait += 1
            tqdm.write(f"No improvement in validation loss for {wait} epochs.")
            if wait >= patience:
                tqdm.write("Early stopping triggered.")
                break
        
        # Every 10 epochs, perform evaluation
        if epoch % 10 == 0:
            # Generate images
            generate_and_save_images(generator, epoch, test_grayscale)
            
            # Generate output
            generated_images = generator(test_grayscale, training=False)
            
            # Calculate metrics
            ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
            tqdm.write(f"SSIM Score: {ssim_score:.4f}")
            tqdm.write(f"PSNR Score: {psnr_score:.4f}")
            
            # Log evaluation metrics to separate CSV
            with open(eval_csv_file, mode='a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([epoch, ssim_score, psnr_score])
            
            # Log evaluation metrics to TensorBoard
            with tf.summary.create_file_writer(log_dir).as_default():
                tf.summary.scalar('SSIM Score', ssim_score, step=epoch)
                tf.summary.scalar('PSNR Score', psnr_score, step=epoch)
    
# 12. Final Evaluation and Plotting
def finalize_training(checkpoint_dir='./checkpoints', train_csv='output/training_metrics.csv', eval_csv='output/evaluation_metrics.csv'):
    """
    Restores the best model, performs final evaluation, and plots metrics.
    """
    # Restore the best checkpoint
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        print(f"Restoring from checkpoint: {latest_ckpt}")
        checkpoint.restore(latest_ckpt)
    else:
        print("No checkpoint found. Proceeding without restoring.")
    
    # Load test dataset
    _, _, test_dataset = load_and_preprocess_data()
    test_grayscale, test_color = next(iter(test_dataset))
    
    # Generate images
    generated_images = generator(test_grayscale, training=False)
    
    # Calculate metrics
    ssim_score, psnr_score = calculate_metrics(test_color, generated_images)
    print(f"Final SSIM Score: {ssim_score:.4f}")
    print(f"Final PSNR Score: {psnr_score:.4f}")
    
    # Append final evaluation metrics to CSV
    with open(eval_csv, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Final Evaluation', ssim_score, psnr_score])
    
    # Save the final models
    generator.save('output/final_generator.h5')
    discriminator.save('output/final_discriminator.h5')
    print("Final models saved.")
    
    # Plot metrics
    plot_metrics(train_csv, eval_csv, output_dir='output')

# Run the main function and finalize training
if __name__ == "__main__":
    main()
    # After training completes, perform final evaluation and plotting
    finalize_training(
        checkpoint_dir='./checkpoints',
        train_csv='output/training_metrics.csv',
        eval_csv='output/evaluation_metrics.csv',
        output_dir='output'
    )


Preprocessing data:   0%|          | 1/10000 [00:00<05:13, 31.88it/s]


TypeError: in user code:

    File "/var/folders/61/x8wy5qqd0rx0wn1t_91hms500000gn/T/ipykernel_4331/951804591.py", line 53, in preprocess_and_update  *
        img = preprocess_image(img)

    TypeError: tf__preprocess_image() missing 1 required positional argument: 'target'
