In [1]:
import cv2
import numpy as np
import os
from glob import glob
import mlflow
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import random

In [2]:
mlflow.set_experiment("Image Inpainting GAN1")  # Create or select an experiment
mlflow.tensorflow.autolog()  # Automatically logs TensorFlow metrics and parameters


In [3]:
def augment_images(input_image, gt_image):
    # Apply random flipping and rotations
    input_image = tf.image.random_flip_left_right(input_image)
    input_image = tf.image.random_flip_up_down(input_image)
    input_image = tf.image.rot90(input_image, k=random.randint(0, 3))
    
    gt_image = tf.image.random_flip_left_right(gt_image)
    gt_image = tf.image.random_flip_up_down(gt_image)
    gt_image = tf.image.rot90(gt_image, k=random.randint(0, 3))

    return input_image, gt_image

In [4]:
def load_image_pairs(input_folder, gt_folder, img_size=(200, 200), batch_size=32, max_images=None):
    """
    Load paired images for training (input with holes and ground truth), with an optional limit.
    Args:
        input_folder (str): Path to images with holes (e.g., HoledImages).
        gt_folder (str): Path to original images (e.g., VegetableImages).
        img_size (tuple): Size to resize images (default is 128x128).
        batch_size (int): Number of images in a batch.
        max_images (int, optional): Maximum number of images to include. Default is None (no limit).
    Returns:
        tf.data.Dataset: Dataset with paired images.
    """
    def parse_pair(input_path, gt_path):
        # Load and preprocess input image
        input_img = tf.io.read_file(input_path)
        input_img = tf.image.decode_jpeg(input_img, channels=3)
        input_img = tf.image.resize(input_img, img_size) / 255.0  # Normalize
        
        # Load and preprocess ground truth image
        gt_img = tf.io.read_file(gt_path)
        gt_img = tf.image.decode_jpeg(gt_img, channels=3)
        gt_img = tf.image.resize(gt_img, img_size) / 255.0  # Normalize
        
        return input_img, gt_img

    # Get paired file paths
    input_paths = sorted(glob(f"{input_folder}/**/*.jpg", recursive=True))
    gt_paths = sorted(glob(f"{gt_folder}/**/*.jpg", recursive=True))
    
    # Limit the number of images
    if max_images is not None:
        input_paths = input_paths[:max_images]
        gt_paths = gt_paths[:max_images]
    
    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((input_paths, gt_paths))
    dataset = dataset.map(lambda x, y: parse_pair(x, y))
    dataset = dataset.map(augment_images)
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


In [5]:
train_data = load_image_pairs(
    "./Data/HoledImages/train",
    "./Data/VegetableImages/train",
    max_images=200  # Limit to 1000 images
)
val_data = load_image_pairs(
    "./Data/HoledImages/validation",
    "./Data/VegetableImages/validation",
    max_images=100
)
test_data = load_image_pairs(
    "./Data/HoledImages/test",
    "./Data/VegetableImages/test",
    max_images=100
)


In [6]:
def create_mask(input_images):
    """
    Create a mask for black square regions in the input images.
    Args:
        input_images (tf.Tensor): Batch of input images with holes.
    Returns:
        tf.Tensor: Binary mask with 1s for black square regions and 0s elsewhere.
    """
    # Black squares are assumed to have pixel values close to 0
    mask = tf.cast(tf.reduce_all(input_images == 0, axis=-1, keepdims=True), tf.float32)
    return mask


In [7]:
# def create_mask(input_images, max_percentage=10):
#     """
#     Create a mask with a single black square in the center, matching the black square added during preprocessing.

#     Args:
#         input_images (tf.Tensor): Batch of input images.
#         max_percentage (float): Maximum percentage of the image area covered by the square.

#     Returns:
#         tf.Tensor: Binary mask with 0s for the black square and 1s elsewhere.
#     """
#     batch_size, img_height, img_width, _ = input_images.shape
#     mask = np.ones((batch_size, img_height, img_width, 1), dtype=np.float32)

#     # Calculate the square size
#     max_area = (img_height * img_width) * (max_percentage / 100)
#     side_length = int(max_area ** 0.5)

#     # Center position of the square
#     y0 = (img_height - side_length) // 2
#     x0 = (img_width - side_length) // 2
#     y1 = y0 + side_length
#     x1 = x0 + side_length

#     # Apply the black square mask
#     for i in range(batch_size):
#         mask[i, y0:y1, x0:x1, :] = 0

#     return tf.convert_to_tensor(mask, dtype=tf.float32)


In [8]:
# def visualize_black_square_and_mask(image_path, mask):
#     """
#     Visualize the image with a black square and the corresponding mask.

#     Args:
#         image_path (str): Path to the input image with a black square.
#         mask (tf.Tensor): Generated mask tensor.
#     """
#     # Decode and preprocess the image
#     image = tf.image.decode_jpeg(tf.io.read_file(image_path))
#     image = tf.image.resize(image, (200, 200))  # Resize to match the mask size
#     image = tf.cast(image, tf.float32) / 255.0  # Cast to float32 and normalize

#     plt.figure(figsize=(10, 5))

#     # Display the image
#     plt.subplot(1, 2, 1)
#     plt.imshow(image)
#     plt.title("Image with Black Square")
#     plt.axis("off")

#     # Display the mask
#     plt.subplot(1, 2, 2)
#     plt.imshow(mask[0, :, :, 0], cmap="gray")
#     plt.title("Mask")
#     plt.axis("off")

#     plt.show()


# # Example
# image_path = "./Data/HoledImages/train/Bean/0026.jpg"  # Replace with a valid image path
# sample_image = tf.image.decode_jpeg(tf.io.read_file(image_path))
# sample_image = tf.image.resize(sample_image, (200, 200))
# sample_image = tf.cast(sample_image, tf.float32) / 255.0  # Cast to float32 and normalize
# sample_image = tf.expand_dims(sample_image, axis=0)  # Add batch dimension

# # Create a mask for the sample image
# mask = create_mask(sample_image, max_percentage=10)
# visualize_black_square_and_mask(image_path, mask)


In [9]:
# def build_generator(img_size=(200, 200, 3)):
#     inputs = layers.Input(shape=img_size)

#     # Encoder
#     x1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
#     x1 = layers.BatchNormalization()(x1)
#     x1 = layers.MaxPooling2D((2, 2))(x1)  # Output: (100, 100, 64)

#     x2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x1)
#     x2 = layers.BatchNormalization()(x2)
#     x2 = layers.MaxPooling2D((2, 2))(x2)  # Output: (50, 50, 128)

#     # Bottleneck
#     x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x2)
#     x3 = layers.BatchNormalization()(x3)
#     x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x3)
#     x3 = layers.Dropout(0.3)(x3)  # Output: (50, 50, 256)

#     # Decoder with skip connections
#     x4 = layers.UpSampling2D((2, 2))(x3)  # Output: (100, 100, 256)
#     x4 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x4)
#     x4 = layers.BatchNormalization()(x4)

#     # Align spatial dimensions of `x2_resized` to match `x4`
#     x2_resized = layers.Conv2D(128, (1, 1), activation='relu', padding='same')(x2)
#     x2_resized = layers.UpSampling2D((2, 2))(x2_resized)  # Now (100, 100, 128)
#     x4 = layers.Add()([x4, x2_resized])  # Now compatible for addition

#     x5 = layers.UpSampling2D((2, 2))(x4)  # Output: (200, 200, 128)
#     x5 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x5)
#     x5 = layers.BatchNormalization()(x5)

#     # Align spatial dimensions of `x1_resized` to match `x5`
#     x1_resized = layers.Conv2D(64, (1, 1), activation='relu', padding='same')(x1)  # (100, 100, 64)
#     x1_resized = layers.UpSampling2D((2, 2))(x1_resized)  # Now (200, 200, 64)
#     x5 = layers.Add()([x5, x1_resized])  # Now compatible for addition

#     outputs = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x5)  # Output: (200, 200, 3)

#     return tf.keras.Model(inputs, outputs, name="Generator")


# generator = build_generator()
# generator.summary()


In [10]:
def build_generator(img_size=(200, 200, 3)):
    """
    Generator model: U-Net-style architecture.
    Args:
        img_size (tuple): Input image size (default is (200, 200, 3)).
    Returns:
        keras.Model: Generator model.
    """
    inputs = layers.Input(shape=img_size)

    # Encoder
    x1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x1 = layers.MaxPooling2D((2, 2))(x1)

    x2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x1)
    x2 = layers.MaxPooling2D((2, 2))(x2)

    # Bottleneck
    x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x2)
    x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x3)

    # Decoder
    x4 = layers.UpSampling2D((2, 2))(x3)
    x4 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x4)

    x5 = layers.UpSampling2D((2, 2))(x4)
    outputs = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x5)

    return tf.keras.Model(inputs, outputs, name="Generator")

generator = build_generator()
generator.summary()


In [11]:
def build_discriminator(img_size=(200, 200, 3)):
    """
    Discriminator model: Patch-based classification.
    Args:
        img_size (tuple): Input image size (default is (200, 200, 3)).
    Returns:
        keras.Model: Discriminator model.
    """
    inputs = layers.Input(shape=img_size)

    x = layers.Conv2D(64, (4, 4), strides=2, activation='relu', padding='same')(inputs)
    x = layers.Conv2D(128, (4, 4), strides=2, activation='relu', padding='same')(x)
    x = layers.Conv2D(256, (4, 4), strides=2, activation='relu', padding='same')(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)

    return tf.keras.Model(inputs, outputs, name="Discriminator")

discriminator = build_discriminator()
discriminator.summary()


In [12]:
def build_gan(generator, discriminator):
    """
    Combines generator and discriminator into a GAN model.
    Args:
        generator (keras.Model): Generator model.
        discriminator (keras.Model): Discriminator model.
    Returns:
        keras.Model: Combined GAN model.
    """
    # discriminator.trainable = False  # Freeze discriminator for GAN training

    gan_input = layers.Input(shape=(200, 200, 3))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)

    return tf.keras.Model(gan_input, gan_output, name="GAN")

gan = build_gan(generator, discriminator)
gan.summary()


In [13]:
# Define the GAN loss function
def gan_loss(real_images, generated_images, fake_output):
    # Adversarial loss (encourages realistic generation)
    adversarial_loss = tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output)
    # Reconstruction loss (encourages similarity to ground truth)
    reconstruction_loss = tf.keras.losses.MeanSquaredError()(real_images, generated_images)
    return adversarial_loss + reconstruction_loss



In [14]:
# Add learning rate schedulers
lr_schedule_gen = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, decay_rate=0.96, staircase=True
)
lr_schedule_disc = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, decay_rate=0.96, staircase=True
)

# Use the schedulers in optimizers
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule_gen, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule_disc, beta_1=0.5)


In [15]:
# Compile the discriminator
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=disc_optimizer,
    metrics=['accuracy']
)

# Compile the GAN
gan.compile(
    loss=gan_loss,
    optimizer=gen_optimizer
)


In [16]:
print(f"Discriminator trainable variables: {len(discriminator.trainable_variables)}")
print(f"Generator trainable variables: {len(generator.trainable_variables)}")


Discriminator trainable variables: 8
Generator trainable variables: 12


In [17]:
@tf.function
def train_step(generator, discriminator, input_images, gt_images, gen_optimizer, disc_optimizer):
    """
    Perform one training step for both generator and discriminator with masking for inpainting.
    """
    # Create a mask for black square regions
    mask = create_mask(input_images)

    # Ensure discriminator is trainable
    discriminator.trainable = True

    # Train discriminator
    with tf.GradientTape() as disc_tape:
        # Generate inpainted images
        generated_images = generator(input_images, training=True)

        # Merge inpainted regions with unmasked regions
        inpainted_images = mask * generated_images + (1 - mask) * input_images

        real_output = discriminator(gt_images, training=True)
        fake_output = discriminator(inpainted_images, training=True)

        disc_loss = (
            tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output) +
            tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output)
        )

    gradients_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(gradients_disc, discriminator.trainable_variables))

    # Train generator
    discriminator.trainable = False  # Freeze discriminator for GAN training
    with tf.GradientTape() as gen_tape:
        # Generate inpainted images
        generated_images = generator(input_images, training=True)

        # Merge inpainted regions with unmasked regions
        inpainted_images = mask * generated_images + (1 - mask) * input_images

        fake_output = discriminator(inpainted_images, training=False)

        # Calculate the generator loss
        gen_loss = (
            tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output) +  # Adversarial loss
            tf.keras.losses.MeanSquaredError()(mask * gt_images, mask * generated_images)  # Reconstruction loss
        )

    gradients_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gradients_gen, generator.trainable_variables))

    return gen_loss, disc_loss


In [18]:
def visualize_and_log_images(generator, input_images, gt_images, log_path, epoch=None, step=None, prefix="test_results"):
    """
    Visualize sample inputs, inpainted outputs, and ground truth; save to MLflow.
    """
    generated_images = generator(input_images, training=False)
    mask = create_mask(input_images)
    inpainted_images = mask * generated_images + (1 - mask) * input_images  # Merge inpainted regions

    num_samples = min(5, input_images.shape[0])
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3))

    for i in range(num_samples):
        axes[i, 0].imshow(input_images[i].numpy())
        axes[i, 0].set_title("Input (With Holes)")
        axes[i, 1].imshow(inpainted_images[i].numpy())
        axes[i, 1].set_title("Inpainted (Generated)")
        axes[i, 2].imshow(gt_images[i].numpy())
        axes[i, 2].set_title("Ground Truth")
        for ax in axes[i]:
            ax.axis("off")

    plt.tight_layout()
    file_name = f"{prefix}_epoch_{int(epoch)+1}_step_{int(step)+1}.png" if epoch is not None and step is not None else f"{prefix}.png"
    output_path = os.path.join(log_path, file_name)
    plt.savefig(output_path)
    plt.close()
    mlflow.log_artifact(output_path, artifact_path="visualizations")

    print(f"Visualization saved and logged: {output_path}")


In [19]:
EPOCHS = 60
log_path = "./training_logs"
os.makedirs(log_path, exist_ok=True)

with mlflow.start_run():
    # Log hyperparameters
    mlflow.log_param("epochs", EPOCHS)
    mlflow.log_param("learning_rate", 1e-4)
    mlflow.log_param("image_size", (200, 200))
    mlflow.log_param("batch_size", 32)

    for epoch in range(EPOCHS):
        step = 0
        gen_loss_epoch = 0
        disc_loss_epoch = 0
        
        for input_images, gt_images in train_data:  # Iterate over training batches
            gen_loss, disc_loss = train_step(generator, discriminator, input_images, gt_images, gen_optimizer, disc_optimizer)
            
            gen_loss_epoch += gen_loss.numpy()
            disc_loss_epoch += disc_loss.numpy()
            step += 1

            # Visualize and log images every 100 steps
            if step % 5 == 0:
                visualize_and_log_images(generator, input_images, gt_images, log_path, epoch, step)

        # Average losses
        gen_loss_epoch /= step
        disc_loss_epoch /= step

        # Log metrics to MLflow
        mlflow.log_metric("generator_loss", gen_loss_epoch, step=epoch)
        mlflow.log_metric("discriminator_loss", disc_loss_epoch, step=epoch)

        # Save models every 10 epochs
        if (epoch + 1) % 10 == 0:
            gen_model_path = f"generator_epoch_{epoch+1}.h5"
            disc_model_path = f"discriminator_epoch_{epoch+1}.h5"
            generator.save(gen_model_path)
            discriminator.save(disc_model_path)
            mlflow.log_artifact(gen_model_path, artifact_path="models")
            mlflow.log_artifact(disc_model_path, artifact_path="models")

        print(f"Epoch {epoch+1}/{EPOCHS} | Gen Loss: {gen_loss_epoch:.4f}, Disc Loss: {disc_loss_epoch:.4f}")


Visualization saved and logged: ./training_logs\test_results_epoch_1_step_6.png
Epoch 1/20 | Gen Loss: 0.7453, Disc Loss: 1.3314
Visualization saved and logged: ./training_logs\test_results_epoch_2_step_6.png
Epoch 2/20 | Gen Loss: 0.9212, Disc Loss: 1.1567
Visualization saved and logged: ./training_logs\test_results_epoch_3_step_6.png
Epoch 3/20 | Gen Loss: 1.2391, Disc Loss: 1.0963
Visualization saved and logged: ./training_logs\test_results_epoch_4_step_6.png
Epoch 4/20 | Gen Loss: 1.3911, Disc Loss: 0.6953
Visualization saved and logged: ./training_logs\test_results_epoch_5_step_6.png
Epoch 5/20 | Gen Loss: 1.9077, Disc Loss: 0.3981
Visualization saved and logged: ./training_logs\test_results_epoch_6_step_6.png
Epoch 6/20 | Gen Loss: 3.2542, Disc Loss: 0.2513
Visualization saved and logged: ./training_logs\test_results_epoch_7_step_6.png
Epoch 7/20 | Gen Loss: 4.0303, Disc Loss: 0.1168
Visualization saved and logged: ./training_logs\test_results_epoch_8_step_6.png
Epoch 8/20 | Gen 



Epoch 10/20 | Gen Loss: 3.9407, Disc Loss: 0.2680
Visualization saved and logged: ./training_logs\test_results_epoch_11_step_6.png
Epoch 11/20 | Gen Loss: 4.1008, Disc Loss: 0.0510
Visualization saved and logged: ./training_logs\test_results_epoch_12_step_6.png
Epoch 12/20 | Gen Loss: 4.1702, Disc Loss: 0.0639
Visualization saved and logged: ./training_logs\test_results_epoch_13_step_6.png
Epoch 13/20 | Gen Loss: 4.1942, Disc Loss: 0.0638
Visualization saved and logged: ./training_logs\test_results_epoch_14_step_6.png


KeyboardInterrupt: 

In [None]:
test_gen_loss = 0
test_disc_loss = 0
test_steps = 0

log_path = "./test_visualizations"
os.makedirs(log_path, exist_ok=True)

for i, (test_input_images, test_gt_images) in enumerate(test_data):
    generated_images = generator(test_input_images, training=False)
    fake_output = discriminator(generated_images, training=False)
    real_output = discriminator(test_gt_images, training=False)

    test_gen_loss += gan_loss(test_gt_images, generated_images, fake_output).numpy()
    test_disc_loss += (
        tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output).numpy() +
        tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output).numpy()
    )
    test_steps += 1

    # Visualize and log test images (limit to 5 batches for clarity)
    if i < 5:
        visualize_and_log_images(
            generator, test_input_images, test_gt_images, log_path, prefix=f"test_batch_{i+1}"
        )

test_gen_loss /= test_steps
test_disc_loss /= test_steps

print(f"Test Gen Loss: {test_gen_loss:.4f}, Test Disc Loss: {test_disc_loss:.4f}")
mlflow.log_metric("test_generator_loss", test_gen_loss)
mlflow.log_metric("test_discriminator_loss", test_disc_loss)


Visualization saved and logged: ./test_visualizations\test_batch_1.png
Visualization saved and logged: ./test_visualizations\test_batch_2.png
Visualization saved and logged: ./test_visualizations\test_batch_3.png
Visualization saved and logged: ./test_visualizations\test_batch_4.png
Test Gen Loss: 1.4358, Test Disc Loss: 0.3656


In [40]:
mlflow.end_run()