<a href="https://www.kaggle.com/code/wolflxchuppy/monet-gan?scriptVersionId=292854562" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# **LOAD DATASET**

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

# Try to attach to a TPU; if that fails for any reason, fall back to
# TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception:
    # Deliberately broad so any TPU discovery/initialisation error still
    # triggers the fallback, but (unlike a bare ``except:``) it no longer
    # swallows KeyboardInterrupt or SystemExit.
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data choose the parallelism level for dataset.map calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

In [None]:
# Resolve the dataset's storage path once; the TFRecord glob patterns used
# later in the notebook are built from it.
GCS_PATH = KaggleDatasets().get_gcs_path()

In [None]:
# Collect the TFRecord shards for each image domain and report how many
# files matched each pattern.
monet_pattern = str(GCS_PATH + '/monet_tfrec/*.tfrec')
MONET_FILENAMES = tf.io.gfile.glob(monet_pattern)
print('Monet TFRecord Files:', len(MONET_FILENAMES))

photo_pattern = str(GCS_PATH + '/photo_tfrec/*.tfrec')
PHOTO_FILENAMES = tf.io.gfile.glob(photo_pattern)
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

In [None]:
# Height and width that decode_image uses when reshaping each decoded image.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode a JPEG byte string into a float32 image scaled to [-1, 1].

    Args:
        image: Scalar string tensor holding JPEG-encoded bytes.

    Returns:
        A float32 tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    # Map 8-bit intensities 0..255 onto -1..1.
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record is parsed with three scalar string features ("image_name",
    "image", "target"); only "image" is used.

    Args:
        example: A serialized example, as yielded by a TFRecordDataset.

    Returns:
        The image tensor produced by ``decode_image``.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        labeled: Accepted for call compatibility; not used by this function.
        ordered: Accepted for call compatibility; not used by this function.

    Returns:
        A dataset whose elements are the outputs of ``read_tfrecord``.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
# Build one dataset per domain, batched one image at a time.
# load_dataset ignores its `labeled` argument, so passing it has no effect.
monet_ds = load_dataset(MONET_FILENAMES, labeled=True).batch(1)
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=True).batch(1)

In [None]:
# Take the first batch from each dataset to use as preview samples.
example_monet = next(iter(monet_ds))
example_photo = next(iter(photo_ds))

In [None]:
# Side-by-side preview of one sample per domain. Images are stored in
# [-1, 1]; `* 0.5 + 0.5` maps them to [0, 1] for display.
photo_preview = example_photo[0] * 0.5 + 0.5
monet_preview = example_monet[0] * 0.5 + 0.5

plt.subplot(1, 2, 1)
plt.title('Photo')
plt.imshow(photo_preview)

plt.subplot(1, 2, 2)
plt.title('Monet')
plt.imshow(monet_preview)

# **The Generator (ResNet-Based)**

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

def resnet_block(input_layer, n_filters):
    """Apply a two-convolution residual block.

    The branch is Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm, and its
    output is added element-wise to ``input_layer`` (a skip connection).

    Args:
        input_layer: Input Keras tensor; it is added to the branch output.
        n_filters: Number of filters for both 3x3 convolutions.

    Returns:
        The Keras tensor produced by the element-wise addition.
    """
    branch = input_layer
    # Only the first conv/norm pair is followed by a ReLU.
    for with_relu in (True, False):
        branch = layers.Conv2D(n_filters, kernel_size=(3, 3), padding='same')(branch)
        branch = layers.BatchNormalization()(branch)
        if with_relu:
            branch = layers.Activation('relu')(branch)

    # Skip connection: element-wise sum of the branch and the block input.
    return layers.Add()([branch, input_layer])

def build_generator(n_resnet=9):
    """Build the ResNet-style image-to-image generator.

    Layout: 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``n_resnet`` residual blocks with 256 filters, two stride-2 transposed
    convolutions, and a 7x7 tanh output convolution with 3 channels.

    Args:
        n_resnet: Number of residual blocks in the middle of the network.

    Returns:
        A ``tf.keras.Model`` taking a (256, 256, 3) input.
    """
    image_in = layers.Input(shape=(256, 256, 3))

    # Stem: wide 7x7 convolution at full resolution.
    feat = layers.Conv2D(64, kernel_size=(7, 7), padding='same')(image_in)
    feat = layers.BatchNormalization()(feat)
    feat = layers.Activation('relu')(feat)

    # Encoder: each stride-2 convolution halves the spatial size.
    for width in (128, 256):
        feat = layers.Conv2D(width, kernel_size=(3, 3), strides=(2, 2), padding='same')(feat)
        feat = layers.BatchNormalization()(feat)
        feat = layers.Activation('relu')(feat)

    # Transformation: stack of residual blocks at the reduced resolution.
    for _ in range(n_resnet):
        feat = resnet_block(feat, 256)

    # Decoder: each stride-2 transposed convolution doubles the spatial size.
    for width in (128, 64):
        feat = layers.Conv2DTranspose(width, kernel_size=(3, 3), strides=(2, 2), padding='same')(feat)
        feat = layers.BatchNormalization()(feat)
        feat = layers.Activation('relu')(feat)

    # Output: 3 channels squashed to [-1, 1] by tanh.
    image_out = layers.Conv2D(3, kernel_size=(7, 7), padding='same', activation='tanh')(feat)

    return tf.keras.Model(image_in, image_out)

# **The Discriminator (PatchGAN)**

In [None]:
def build_discriminator():
    """Build the patch-based discriminator.

    Two stride-2 4x4 convolutions (64 then 128 filters), each followed by
    LeakyReLU(0.2), then a 4x4 convolution producing a single-channel
    prediction map with no activation.

    Returns:
        A ``tf.keras.Model`` taking a (256, 256, 3) input.
    """
    image_in = layers.Input(shape=(256, 256, 3))

    feat = image_in
    # Each stage halves the spatial resolution.
    for width in (64, 128):
        feat = layers.Conv2D(width, kernel_size=(4, 4), strides=(2, 2), padding='same')(feat)
        feat = layers.LeakyReLU(alpha=0.2)(feat)

    # One raw score per spatial location.
    patch_scores = layers.Conv2D(1, kernel_size=(4, 4), padding='same')(feat)
    return tf.keras.Model(image_in, patch_scores)

In [None]:
class CycleGan(tf.keras.Model):
    """Keras model that trains two generators and two discriminators jointly.

    Following the naming used in this file, domain ``x`` is Photo and domain
    ``y`` is Monet: ``gen_g`` maps x -> y, ``gen_f`` maps y -> x, ``disc_x``
    judges x images and ``disc_y`` judges y images.
    """

    def __init__(self, generator_g, generator_f, discriminator_x, discriminator_y, lambda_cycle=10.0, lambda_id=5.0):
        """Store the four sub-models and the loss weights.

        Args:
            generator_g: Generator translating Photo -> Monet.
            generator_f: Generator translating Monet -> Photo.
            discriminator_x: Discriminator for the Photo domain.
            discriminator_y: Discriminator for the Monet domain.
            lambda_cycle: Multiplier applied to the cycle-consistency losses.
            lambda_id: Multiplier applied to the identity losses.
        """
        super().__init__()
        self.gen_g = generator_g  # Photo -> Monet
        self.gen_f = generator_f  # Monet -> Photo
        self.disc_x = discriminator_x  # Discriminator for Photo
        self.disc_y = discriminator_y  # Discriminator for Monet
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id

    def compile(self, gen_g_optimizer, gen_f_optimizer, disc_x_optimizer, disc_y_optimizer, gen_loss_fn, disc_loss_fn):
        """Attach one optimizer per sub-model and the adversarial loss functions.

        Args:
            gen_g_optimizer: Optimizer for ``gen_g``.
            gen_f_optimizer: Optimizer for ``gen_f``.
            disc_x_optimizer: Optimizer for ``disc_x``.
            disc_y_optimizer: Optimizer for ``disc_y``.
            gen_loss_fn: Callable ``(disc_output_on_fake) -> loss``.
            disc_loss_fn: Callable ``(disc_output_on_real, disc_output_on_fake) -> loss``.
        """
        super().compile()
        self.gen_g_optimizer = gen_g_optimizer
        self.gen_f_optimizer = gen_f_optimizer
        self.disc_x_optimizer = disc_x_optimizer
        self.disc_y_optimizer = disc_y_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        # Cycle and identity terms both use mean absolute error.
        self.cycle_loss_fn = tf.keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = tf.keras.losses.MeanAbsoluteError()

    @tf.function
    def train_step(self, batch_data):
        """Run one optimisation step for all four sub-models.

        Args:
            batch_data: A ``(real_x, real_y)`` pair, where x is Photo and y
                is Monet.

        Returns:
            A dict with the total generator losses (``g_loss``, ``f_loss``)
            and the discriminator losses (``dx_loss``, ``dy_loss``).
        """
        real_x, real_y = batch_data  # x is Photo, y is Monet

        # Persistent because four separate gradients are taken from one
        # forward pass. Only the forward pass and the losses are recorded.
        with tf.GradientTape(persistent=True) as tape:
            # 1. Generator G translates Photo -> Fake Monet
            # 2. Generator F translates Monet -> Fake Photo
            fake_y = self.gen_g(real_x, training=True)
            fake_x = self.gen_f(real_y, training=True)

            # 3. Cycle consistency: translate back to the original domain
            cycled_x = self.gen_f(fake_y, training=True)
            cycled_y = self.gen_g(fake_x, training=True)

            # 4. Identity mapping: feed each generator an image that is
            #    already in its output domain
            same_x = self.gen_f(real_x, training=True)
            same_y = self.gen_g(real_y, training=True)

            # Adversarial generator losses. The discriminators are called
            # with training=True, matching the discriminator-loss calls below.
            gen_g_loss = self.gen_loss_fn(self.disc_y(fake_y, training=True))
            gen_f_loss = self.gen_loss_fn(self.disc_x(fake_x, training=True))

            # Cycle loss: how close is the round-trip image to the original?
            cycle_loss_g = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle
            cycle_loss_f = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle

            # Identity loss: an in-domain input should pass through unchanged
            id_loss_g = self.identity_loss_fn(real_y, same_y) * self.lambda_id
            id_loss_f = self.identity_loss_fn(real_x, same_x) * self.lambda_id

            total_gen_g_loss = gen_g_loss + cycle_loss_g + id_loss_g
            total_gen_f_loss = gen_f_loss + cycle_loss_f + id_loss_f

            # Discriminator losses; stop_gradient keeps these terms from
            # back-propagating into the generators.
            disc_x_loss = self.disc_loss_fn(
                self.disc_x(real_x, training=True),
                self.disc_x(tf.stop_gradient(fake_x), training=True)
            )

            disc_y_loss = self.disc_loss_fn(
                self.disc_y(real_y, training=True),
                self.disc_y(tf.stop_gradient(fake_y), training=True)
            )

        # Gradients are computed after leaving the tape context: calling
        # tape.gradient on a persistent tape inside its own context makes the
        # tape record the gradient ops too, which TensorFlow warns is
        # significantly less efficient.
        grads_g = tape.gradient(total_gen_g_loss, self.gen_g.trainable_variables)
        grads_f = tape.gradient(total_gen_f_loss, self.gen_f.trainable_variables)
        grads_dx = tape.gradient(disc_x_loss, self.disc_x.trainable_variables)
        grads_dy = tape.gradient(disc_y_loss, self.disc_y.trainable_variables)

        # A persistent tape holds its resources until the reference is dropped.
        del tape

        # Apply gradients
        self.gen_g_optimizer.apply_gradients(zip(grads_g, self.gen_g.trainable_variables))
        self.gen_f_optimizer.apply_gradients(zip(grads_f, self.gen_f.trainable_variables))
        self.disc_x_optimizer.apply_gradients(zip(grads_dx, self.disc_x.trainable_variables))
        self.disc_y_optimizer.apply_gradients(zip(grads_dy, self.disc_y.trainable_variables))

        return {"g_loss": total_gen_g_loss, "f_loss": total_gen_f_loss, "dx_loss": disc_x_loss, "dy_loss": disc_y_loss}

# **Loss Function**

In [None]:
# Shared loss object used by discriminator_loss and generator_loss below.
# Binary Cross Entropy is fine, but MSE is often more stable for CycleGAN
loss_obj = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real, generated):
    """Average the discriminator's loss on real and generated inputs.

    Args:
        real: Discriminator output for real images (target is all ones).
        generated: Discriminator output for generated images (target is
            all zeros).

    Returns:
        Half the sum of the two ``loss_obj`` values.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
    combined = loss_on_real + loss_on_fake
    return combined * 0.5

def generator_loss(generated):
    """Score how close the discriminator's output on fakes is to all ones.

    Args:
        generated: Discriminator output for generated images.

    Returns:
        ``loss_obj`` evaluated against a target of ones.
    """
    target = tf.ones_like(generated)
    return loss_obj(target, generated)

# **Data Augmentation**

In [None]:
def load_and_decode_function(file_path):
    """Read a JPEG file and return it as a 256x256 float image in [0, 1].

    NOTE(review): this yields values in [0, 1], whereas ``decode_image``
    scales to [-1, 1] — confirm the intended range before feeding the
    result to the generators.

    Args:
        file_path: Path of the JPEG file to read.

    Returns:
        A float32 image tensor resized to 256x256.
    """
    raw_bytes = tf.io.read_file(file_path)
    # Compressed bytes -> uint8 tensor with 3 channels.
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    # uint8 -> float32, rescaled into [0, 1].
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, [256, 256])

In [None]:
# Two generators (one per translation direction) and two discriminators
# (one per image domain).
monet_generator = build_generator()  # Photo -> Monet
photo_generator = build_generator()  # Monet -> Photo
monet_discriminator = build_discriminator()
photo_discriminator = build_discriminator()

# Wire the sub-models into the CycleGAN wrapper: x is Photo, y is Monet.
cycle_gan_model = CycleGan(
    generator_g=monet_generator,
    generator_f=photo_generator,
    discriminator_x=photo_discriminator,
    discriminator_y=monet_discriminator,
)

# Each sub-model gets its own Adam instance with identical settings.
optimizers = {
    name: tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    for name in (
        "gen_g_optimizer",
        "gen_f_optimizer",
        "disc_x_optimizer",
        "disc_y_optimizer",
    )
}
cycle_gan_model.compile(
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    **optimizers,
)

# Pair each photo batch with a Monet batch and train.
# (At least 25-50 epochs are suggested for decent results; this runs 5.)
paired_ds = tf.data.Dataset.zip((photo_ds, monet_ds))
cycle_gan_model.fit(paired_ds, epochs=5)


# **SUBMISSION**

In [None]:
import os, shutil
from PIL import Image
# Translate photos with the Photo -> Monet generator and write each result
# as a JPEG into ../images.
generator = cycle_gan_model.gen_g

os.makedirs("../images", exist_ok=True)
for i, photo in enumerate(photo_ds.repeat().take(7000)):
    # Accept either a batched (rank-4) or an unbatched element, then feed
    # the generator a batch of exactly one image.
    img = photo[0] if len(photo.shape) == 4 else photo
    input_tensor = tf.expand_dims(img, axis=0)

    # Generate image
    pred = generator(input_tensor, training=False)[0].numpy()

    # De-normalize from [-1, 1] to 8-bit pixel values.
    pred = np.clip(pred * 127.5 + 127.5, 0, 255).astype(np.uint8)

    Image.fromarray(pred).save(f"../images/{i}.jpg", quality=95)

In [None]:
# Zip the generated images into /kaggle/working/images.zip.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format="zip",
    root_dir="../images",
)