In [1]:
import cv2
import numpy as np
import os
from glob import glob
import mlflow
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt


In [2]:
mlflow.set_experiment("Image Inpainting GAN1")  # Create or select an experiment
mlflow.tensorflow.autolog()  # Automatically logs TensorFlow metrics and parameters


In [3]:
def load_image_pairs(input_folder, gt_folder, img_size=(200, 200), batch_size=32, max_images=None):
    """
    Load paired images for training (input with holes and ground truth), with an optional limit.
    Args:
        input_folder (str): Path to images with holes (e.g., HoledImages).
        gt_folder (str): Path to original images (e.g., VegetableImages).
        img_size (tuple): Size to resize images (default is 128x128).
        batch_size (int): Number of images in a batch.
        max_images (int, optional): Maximum number of images to include. Default is None (no limit).
    Returns:
        tf.data.Dataset: Dataset with paired images.
    """
    def parse_pair(input_path, gt_path):
        # Load and preprocess input image
        input_img = tf.io.read_file(input_path)
        input_img = tf.image.decode_jpeg(input_img, channels=3)
        input_img = tf.image.resize(input_img, img_size) / 255.0  # Normalize
        
        # Load and preprocess ground truth image
        gt_img = tf.io.read_file(gt_path)
        gt_img = tf.image.decode_jpeg(gt_img, channels=3)
        gt_img = tf.image.resize(gt_img, img_size) / 255.0  # Normalize
        
        return input_img, gt_img

    # Get paired file paths
    input_paths = sorted(glob(f"{input_folder}/**/*.jpg", recursive=True))
    gt_paths = sorted(glob(f"{gt_folder}/**/*.jpg", recursive=True))
    
    # Limit the number of images
    if max_images is not None:
        input_paths = input_paths[:max_images]
        gt_paths = gt_paths[:max_images]
    
    # Create the dataset
    dataset = tf.data.Dataset.from_tensor_slices((input_paths, gt_paths))
    dataset = dataset.map(lambda x, y: parse_pair(x, y))
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset


In [4]:
train_data = load_image_pairs(
    "./Data/HoledImages/train",
    "./Data/VegetableImages/train",
    max_images=200  # Limit to 1000 images
)
val_data = load_image_pairs(
    "./Data/HoledImages/validation",
    "./Data/VegetableImages/validation",
    max_images=100
)
test_data = load_image_pairs(
    "./Data/HoledImages/test",
    "./Data/VegetableImages/test",
    max_images=100
)


In [5]:
def create_mask(input_images):
    """
    Create a mask for black square regions in the input images.
    Args:
        input_images (tf.Tensor): Batch of input images with holes.
    Returns:
        tf.Tensor: Binary mask with 1s for black square regions and 0s elsewhere.
    """
    # Black squares are assumed to have pixel values close to 0
    mask = tf.cast(tf.reduce_all(input_images < 0.1, axis=-1, keepdims=True), tf.float32)
    return mask


In [6]:
def build_generator(img_size=(200, 200, 3)):
    """
    Generator model: U-Net-style architecture.
    Args:
        img_size (tuple): Input image size (default is (200, 200, 3)).
    Returns:
        keras.Model: Generator model.
    """
    inputs = layers.Input(shape=img_size)

    # Encoder
    x1 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    x1 = layers.MaxPooling2D((2, 2))(x1)

    x2 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x1)
    x2 = layers.MaxPooling2D((2, 2))(x2)

    # Bottleneck
    x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x2)
    x3 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x3)

    # Decoder
    x4 = layers.UpSampling2D((2, 2))(x3)
    x4 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x4)

    x5 = layers.UpSampling2D((2, 2))(x4)
    outputs = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x5)

    return tf.keras.Model(inputs, outputs, name="Generator")

generator = build_generator()
generator.summary()


In [7]:
def build_discriminator(img_size=(200, 200, 3)):
    """
    Discriminator model: Patch-based classification.
    Args:
        img_size (tuple): Input image size (default is (200, 200, 3)).
    Returns:
        keras.Model: Discriminator model.
    """
    inputs = layers.Input(shape=img_size)

    x = layers.Conv2D(64, (4, 4), strides=2, activation='relu', padding='same')(inputs)
    x = layers.Conv2D(128, (4, 4), strides=2, activation='relu', padding='same')(x)
    x = layers.Conv2D(256, (4, 4), strides=2, activation='relu', padding='same')(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)

    return tf.keras.Model(inputs, outputs, name="Discriminator")

discriminator = build_discriminator()
discriminator.summary()


In [8]:
def build_gan(generator, discriminator):
    """
    Combines generator and discriminator into a GAN model.
    Args:
        generator (keras.Model): Generator model.
        discriminator (keras.Model): Discriminator model.
    Returns:
        keras.Model: Combined GAN model.
    """
    # discriminator.trainable = False  # Freeze discriminator for GAN training

    gan_input = layers.Input(shape=(200, 200, 3))
    generated_image = generator(gan_input)
    gan_output = discriminator(generated_image)

    return tf.keras.Model(gan_input, gan_output, name="GAN")

gan = build_gan(generator, discriminator)
gan.summary()


In [9]:
# Define the GAN loss function
def gan_loss(real_images, generated_images, fake_output):
    # Adversarial loss (encourages realistic generation)
    adversarial_loss = tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output)
    # Reconstruction loss (encourages similarity to ground truth)
    reconstruction_loss = tf.keras.losses.MeanSquaredError()(real_images, generated_images)
    return adversarial_loss + reconstruction_loss



In [10]:
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)


In [11]:
# Compile the discriminator
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=disc_optimizer,
    metrics=['accuracy']
)

# Compile the GAN
gan.compile(
    loss=gan_loss,
    optimizer=gen_optimizer
)


In [12]:
print(f"Discriminator trainable variables: {len(discriminator.trainable_variables)}")
print(f"Generator trainable variables: {len(generator.trainable_variables)}")


Discriminator trainable variables: 8
Generator trainable variables: 12


In [13]:
@tf.function
def train_step(generator, discriminator, input_images, gt_images, gen_optimizer, disc_optimizer):
    """
    Perform one training step for both generator and discriminator with masking for inpainting.
    """
    # Create a mask for black square regions
    mask = create_mask(input_images)

    # Ensure discriminator is trainable
    discriminator.trainable = True

    # Train discriminator
    with tf.GradientTape() as disc_tape:
        # Generate inpainted images
        generated_images = generator(input_images, training=True)

        # Merge inpainted regions with unmasked regions
        inpainted_images = mask * generated_images + (1 - mask) * input_images

        real_output = discriminator(gt_images, training=True)
        fake_output = discriminator(inpainted_images, training=True)

        disc_loss = (
            tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output) +
            tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output)
        )

    gradients_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(gradients_disc, discriminator.trainable_variables))

    # Train generator
    discriminator.trainable = False  # Freeze discriminator for GAN training
    with tf.GradientTape() as gen_tape:
        # Generate inpainted images
        generated_images = generator(input_images, training=True)

        # Merge inpainted regions with unmasked regions
        inpainted_images = mask * generated_images + (1 - mask) * input_images

        fake_output = discriminator(inpainted_images, training=False)

        # Calculate the generator loss
        gen_loss = (
            tf.keras.losses.BinaryCrossentropy()(tf.ones_like(fake_output), fake_output) +  # Adversarial loss
            tf.keras.losses.MeanSquaredError()(mask * gt_images, mask * generated_images)  # Reconstruction loss
        )

    gradients_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gradients_gen, generator.trainable_variables))

    return gen_loss, disc_loss


In [14]:
def visualize_and_log_images(generator, input_images, gt_images, log_path, epoch=None, step=None, prefix="test_results"):
    """
    Visualize sample inputs, inpainted outputs, and ground truth; save to MLflow.
    """
    generated_images = generator(input_images, training=False)
    mask = create_mask(input_images)
    inpainted_images = mask * generated_images + (1 - mask) * input_images  # Merge inpainted regions

    num_samples = min(5, input_images.shape[0])
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples * 3))

    for i in range(num_samples):
        axes[i, 0].imshow(input_images[i].numpy())
        axes[i, 0].set_title("Input (With Holes)")
        axes[i, 1].imshow(inpainted_images[i].numpy())
        axes[i, 1].set_title("Inpainted (Generated)")
        axes[i, 2].imshow(gt_images[i].numpy())
        axes[i, 2].set_title("Ground Truth")
        for ax in axes[i]:
            ax.axis("off")

    plt.tight_layout()
    file_name = f"{prefix}_epoch_{int(epoch)+1}_step_{int(step)+1}.png" if epoch is not None and step is not None else f"{prefix}.png"
    output_path = os.path.join(log_path, file_name)
    plt.savefig(output_path)
    plt.close()
    mlflow.log_artifact(output_path, artifact_path="visualizations")

    print(f"Visualization saved and logged: {output_path}")


In [None]:
EPOCHS = 50
log_path = "./training_logs"
os.makedirs(log_path, exist_ok=True)

with mlflow.start_run():
    # Log hyperparameters
    mlflow.log_param("epochs", EPOCHS)
    mlflow.log_param("learning_rate", 1e-4)
    mlflow.log_param("image_size", (200, 200))
    mlflow.log_param("batch_size", 32)

    for epoch in range(EPOCHS):
        step = 0
        gen_loss_epoch = 0
        disc_loss_epoch = 0
        
        for input_images, gt_images in train_data:  # Iterate over training batches
            gen_loss, disc_loss = train_step(generator, discriminator, input_images, gt_images, gen_optimizer, disc_optimizer)
            
            gen_loss_epoch += gen_loss.numpy()
            disc_loss_epoch += disc_loss.numpy()
            step += 1

            # Visualize and log images every 100 steps
            if step % 5 == 0:
                visualize_and_log_images(generator, input_images, gt_images, log_path, epoch, step)

        # Average losses
        gen_loss_epoch /= step
        disc_loss_epoch /= step

        # Log metrics to MLflow
        mlflow.log_metric("generator_loss", gen_loss_epoch, step=epoch)
        mlflow.log_metric("discriminator_loss", disc_loss_epoch, step=epoch)

        # Save models every 10 epochs
        if (epoch + 1) % 10 == 0:
            gen_model_path = f"generator_epoch_{epoch+1}.h5"
            disc_model_path = f"discriminator_epoch_{epoch+1}.h5"
            generator.save(gen_model_path)
            discriminator.save(disc_model_path)
            mlflow.log_artifact(gen_model_path, artifact_path="models")
            mlflow.log_artifact(disc_model_path, artifact_path="models")

        print(f"Epoch {epoch+1}/{EPOCHS} | Gen Loss: {gen_loss_epoch:.4f}, Disc Loss: {disc_loss_epoch:.4f}")


Visualization saved and logged: ./training_logs\test_results_epoch_1_step_6.png
Epoch 1/20 | Gen Loss: 0.7467, Disc Loss: 1.3804
Visualization saved and logged: ./training_logs\test_results_epoch_2_step_6.png
Epoch 2/20 | Gen Loss: 0.7212, Disc Loss: 1.3804
Visualization saved and logged: ./training_logs\test_results_epoch_3_step_6.png
Epoch 3/20 | Gen Loss: 0.6983, Disc Loss: 1.3784
Visualization saved and logged: ./training_logs\test_results_epoch_4_step_6.png
Epoch 4/20 | Gen Loss: 0.7202, Disc Loss: 1.3573
Visualization saved and logged: ./training_logs\test_results_epoch_5_step_6.png
Epoch 5/20 | Gen Loss: 0.7385, Disc Loss: 1.4084
Visualization saved and logged: ./training_logs\test_results_epoch_6_step_6.png
Epoch 6/20 | Gen Loss: 0.7406, Disc Loss: 1.3303
Visualization saved and logged: ./training_logs\test_results_epoch_7_step_6.png
Epoch 7/20 | Gen Loss: 0.7502, Disc Loss: 1.3378
Visualization saved and logged: ./training_logs\test_results_epoch_8_step_6.png
Epoch 8/20 | Gen 



Epoch 10/20 | Gen Loss: 0.8196, Disc Loss: 1.3166
Visualization saved and logged: ./training_logs\test_results_epoch_11_step_6.png
Epoch 11/20 | Gen Loss: 0.7950, Disc Loss: 1.2780
Visualization saved and logged: ./training_logs\test_results_epoch_12_step_6.png
Epoch 12/20 | Gen Loss: 0.8663, Disc Loss: 1.3182
Visualization saved and logged: ./training_logs\test_results_epoch_13_step_6.png
Epoch 13/20 | Gen Loss: 0.8521, Disc Loss: 1.2597
Visualization saved and logged: ./training_logs\test_results_epoch_14_step_6.png
Epoch 14/20 | Gen Loss: 0.8349, Disc Loss: 1.2890
Visualization saved and logged: ./training_logs\test_results_epoch_15_step_6.png
Epoch 15/20 | Gen Loss: 0.8279, Disc Loss: 1.3625
Visualization saved and logged: ./training_logs\test_results_epoch_16_step_6.png
Epoch 16/20 | Gen Loss: 0.7637, Disc Loss: 1.3057
Visualization saved and logged: ./training_logs\test_results_epoch_17_step_6.png
Epoch 17/20 | Gen Loss: 0.8373, Disc Loss: 1.2874
Visualization saved and logged: .



Epoch 20/20 | Gen Loss: 0.7898, Disc Loss: 1.3200


In [None]:
test_gen_loss = 0
test_disc_loss = 0
test_steps = 0

log_path = "./test_visualizations"
os.makedirs(log_path, exist_ok=True)

for i, (test_input_images, test_gt_images) in enumerate(test_data):
    generated_images = generator(test_input_images, training=False)
    fake_output = discriminator(generated_images, training=False)
    real_output = discriminator(test_gt_images, training=False)

    test_gen_loss += gan_loss(test_gt_images, generated_images, fake_output).numpy()
    test_disc_loss += (
        tf.keras.losses.BinaryCrossentropy()(tf.ones_like(real_output), real_output).numpy() +
        tf.keras.losses.BinaryCrossentropy()(tf.zeros_like(fake_output), fake_output).numpy()
    )
    test_steps += 1

    # Visualize and log test images (limit to 5 batches for clarity)
    if i < 5:
        visualize_and_log_images(
            generator, test_input_images, test_gt_images, log_path, prefix=f"test_batch_{i+1}"
        )

test_gen_loss /= test_steps
test_disc_loss /= test_steps

print(f"Test Gen Loss: {test_gen_loss:.4f}, Test Disc Loss: {test_disc_loss:.4f}")
mlflow.log_metric("test_generator_loss", test_gen_loss)
mlflow.log_metric("test_discriminator_loss", test_disc_loss)


Visualization saved and logged: ./test_visualizations\test_batch_1.png
Visualization saved and logged: ./test_visualizations\test_batch_2.png
Visualization saved and logged: ./test_visualizations\test_batch_3.png
Visualization saved and logged: ./test_visualizations\test_batch_4.png
Test Gen Loss: 0.7335, Test Disc Loss: 1.6580


In [17]:
mlflow.end_run()