In [4]:
import tensorflow as tf
tf.keras.backend.clear_session()
import tensorflow.keras.backend as K
from tensorflow.keras.layers import (
    Input, Conv2D, Conv2DTranspose, LayerNormalization, 
    MultiHeadAttention, Dense, Dropout, Add, Concatenate, 
    BatchNormalization, ReLU, Reshape, Permute, UpSampling2D, Lambda
)
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
import numpy as np
import os
from skimage.color import rgb2lab
from keras.layers import Resizing
import matplotlib.pyplot as plt
import os
from skimage import color

# Constants
LAMBDA_ADVERSARIAL = 0.1
LAMBDA_PERCEPTUAL = 100
LAMBDA_L1 = 10
LAMBDA_COLOR = 1
GLOBAL_BATCH_SIZE = 16

class WindowAttention(tf.keras.layers.Layer):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = Dense(dim * 3, use_bias=qkv_bias)
        self.proj = Dense(dim)

    def call(self, x):
        B, H, W, C = x.shape
        x = Reshape((-1, C))(x)
        qkv = self.qkv(x)
        qkv = Reshape((-1, 3, self.num_heads, C // self.num_heads))(qkv)
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = tf.matmul(q, k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn, axis=-1)
        x = tf.matmul(attn, v)
        x = tf.transpose(x, [0, 2, 1, 3])
        x = Reshape((H, W, C))(x)
        x = self.proj(x)
        return x

def swin_transformer_block(x, dim, num_heads, window_size=7):
    x_norm1 = LayerNormalization(epsilon=1e-5)(x)
    x_attn = WindowAttention(dim, window_size, num_heads)(x_norm1)
    x = Add()([x, x_attn])
    x_norm2 = LayerNormalization(epsilon=1e-5)(x)
    x_mlp = Dense(dim * 4, activation='gelu')(x_norm2)
    x_mlp = Dense(dim)(x_mlp)
    x = Add()([x, x_mlp])
    return x

def color_encoder(input_shape=(256, 256, 1)):
    random_input = Input(shape=(input_shape[0], input_shape[1], 128))
    x = Conv2D(32, (4, 4), strides=2, padding="same")(random_input)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(64, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(128, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(256, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    return Model(random_input, x, name="ColorEncoder")

def color_transformer(encoder_output, color_features):
    x = Concatenate()([encoder_output, color_features])
    x_initial = Conv2D(256, (3, 3), padding="same")(x)
    x_initial = BatchNormalization()(x_initial)
    x_initial = ReLU()(x_initial)
    x = swin_transformer_block(x_initial, dim=256, num_heads=4)
    x = swin_transformer_block(x, dim=256, num_heads=4)
    x = Add()([x, x_initial])
    return x

def build_generator(input_shape=(256, 256, 1)):
    inp = Input(shape=input_shape)
    
    # VGG Feature Extractor
    grayscale_input = Lambda(lambda x: tf.repeat(x, 3, axis=-1))(inp)
    vgg = VGG19(weights="imagenet", include_top=False)
    vgg.trainable = False
    vgg_features = []
    x = grayscale_input
    for i, layer in enumerate(vgg.layers):
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue
        x = layer(x)
        if layer.name in ['block1_conv2', 'block2_conv2', 'block3_conv4', 'block4_conv4']:
            vgg_features.append(x)
    
    random_noise = Lambda(lambda x: tf.random.normal(
        shape=(tf.shape(x)[0], input_shape[0], input_shape[1], 128),
        mean=0, stddev=0.1))(inp)
    
    color_enc = color_encoder(input_shape)
    color_features = color_enc(random_noise)
    
    # Encoder path with VGG features
    e1 = Conv2D(32, (4, 4), strides=2, padding="same")(inp)  # 128x128
    e1 = BatchNormalization()(e1)
    e1 = ReLU()(e1)
    vgg1_resized = Conv2D(64, (1, 1))(vgg_features[0])
    vgg1_resized = Resizing(128, 128)(vgg1_resized)
    e1 = Concatenate()([e1, vgg1_resized])
    
    e2 = Conv2D(64, (4, 4), strides=2, padding="same")(e1)  # 64x64
    e2 = BatchNormalization()(e2)
    e2 = ReLU()(e2)
    vgg2_resized = Conv2D(64, (1, 1))(vgg_features[1])
    vgg2_resized = Resizing(64, 64)(vgg2_resized) 
    e2 = Concatenate()([e2, vgg2_resized])
    
    e3 = Conv2D(128, (4, 4), strides=2, padding="same")(e2)  # 32x32
    e3 = BatchNormalization()(e3)
    e3 = ReLU()(e3)
    vgg3_resized = Conv2D(128, (1, 1))(vgg_features[2])
    vgg3_resized = Resizing(32, 32)(vgg3_resized)
    e3 = Concatenate()([e3, vgg3_resized])
    
    e4 = Conv2D(256, (4, 4), strides=2, padding="same")(e3)  # 16x16
    e4 = BatchNormalization()(e4)
    e4 = ReLU()(e4)
    vgg4_resized = Conv2D(256, (1, 1))(vgg_features[3])
    vgg4_resized = Resizing(16, 16)(vgg4_resized)
    e4 = Concatenate()([e4, vgg4_resized])
    
    models = {
        'vgg': vgg,
        'color_encoder': color_enc
    }
    
    x = color_transformer(e4, color_features)
    
    # Decoder path
    d1 = Conv2DTranspose(128, (4, 4), strides=2, padding="same")(x)
    d1 = Concatenate()([d1, e3])
    d1 = BatchNormalization()(d1)
    d1 = ReLU()(d1)
    
    d2 = Conv2DTranspose(64, (4, 4), strides=2, padding="same")(d1)
    d2 = Concatenate()([d2, e2])
    d2 = BatchNormalization()(d2)
    d2 = ReLU()(d2)
    
    d3 = Conv2DTranspose(32, (4, 4), strides=2, padding="same")(d2)
    d3 = Concatenate()([d3, e1])
    d3 = BatchNormalization()(d3)
    d3 = ReLU()(d3)
    
    d4 = Conv2DTranspose(16, (4, 4), strides=2, padding="same")(d3)
    d4 = BatchNormalization()(d4)
    d4 = ReLU()(d4)
    
    output = Conv2D(2, (3, 3), padding="same", activation="tanh")(d4)
    
    return Model(inp, output, name="Generator"), models

def build_discriminator(input_shape=(256, 256, 3)):
    inp = Input(shape=input_shape)
    x = Conv2D(64, (4, 4), strides=2, padding="same")(inp)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(128, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(256, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(512, (4, 4), strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    output = Conv2D(1, (4, 4), padding="same")(x)
    return Model(inp, output, name="Discriminator")

def gradient_penalty(discriminator, l_channel, fake_ab, real_ab):
    """Compute gradient penalty for WGAN-GP"""
    batch_size = tf.shape(l_channel)[0]
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
    
    # Create interpolated images
    fake_images = tf.concat([l_channel, fake_ab], axis=-1)
    real_images = tf.concat([l_channel, real_ab], axis=-1)
    interpolated = alpha * real_images + (1 - alpha) * fake_images
    
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        disc_interpolated = discriminator(interpolated, training=True)
    
    gradients = tape.gradient(disc_interpolated, interpolated)
    gradients_sqr = tf.square(gradients)
    gradients_sqr_sum = tf.reduce_sum(gradients_sqr, axis=[1, 2, 3])
    gradient_l2_norm = tf.sqrt(gradients_sqr_sum)
    gradient_penalty = tf.reduce_mean(tf.square(gradient_l2_norm - 1.0))
    
    return gradient_penalty

def preprocess_image(image_path, target_size=(256, 256)):
    def _convert_to_lab(image_path):
        # Read and normalize image to [0, 1]
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, target_size)
        image = tf.cast(image, tf.float32) / 255.0
        
        # Convert to LAB
        lab_image = color.rgb2lab(image.numpy())
        
        # Normalize L to [-1, 1] and ab to [-1, 1]
        L = (lab_image[..., :1] - 50.0) / 50.0  # Map [0, 100] to [-1, 1]
        ab = lab_image[..., 1:] / 127.0  # Map [-128, 128] to [-1, 1]
        
        return L.astype(np.float32), ab.astype(np.float32)
    
    L, ab = tf.py_function(func=_convert_to_lab, inp=[image_path], Tout=[tf.float32, tf.float32])
    L.set_shape((target_size[0], target_size[1], 1))
    ab.set_shape((target_size[0], target_size[1], 2))
    return L, ab


def build_vgg_feature_extractor(input_shape=(256, 256, 3)):
    """Build a VGG feature extractor model that outputs the features we need"""
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=input_shape)
    vgg.trainable = False
    
    # Get the specific layer we want
    layer_name = 'block4_conv4'
    feature_extractor = Model(
        inputs=vgg.input,
        outputs=vgg.get_layer(layer_name).output,
        name='vgg_feature_extractor'
    )
    return feature_extractor



# Initialize distribution strategy
strategy = tf.distribute.MirroredStrategy()
print('Number of devices:', strategy.num_replicas_in_sync)


def create_distributed_dataset(image_dir, batch_size=32):  # Add batch_size parameter with default value
    def load_and_preprocess_image(image_path):
        # Read image
        image = tf.io.read_file(image_path)
        image = tf.image.decode_jpeg(image, channels=3)
        
        # Resize
        image = tf.image.resize(image, [256, 256])
        
        # Convert to float32 and normalize to [-1, 1]
        image = tf.cast(image, tf.float32) / 127.5 - 1
        
        # Convert RGB to LAB
        lab = tf.py_function(lambda x: color.rgb2lab(x), [image], tf.float32)
        
        # Split into L and ab channels
        l_channel = lab[..., :1]
        ab_channels = lab[..., 1:]
        
        return l_channel, ab_channels

    # Get all image paths
    image_paths = tf.data.Dataset.list_files(str(image_dir + '/*.*'))
    
    # Create dataset
    dataset = image_paths.map(load_and_preprocess_image, 
                            num_parallel_calls=tf.data.AUTOTUNE)
    
    # Batch and prefetch
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    return dataset
    



@tf.function
def lab_to_rgb_tf(lab_image):
    """Convert LAB image to RGB using TensorFlow operations"""
    # Denormalize L from [-1, 1] to [0, 100]
    L = (lab_image[..., 0:1] + 1) * 50
    # Denormalize ab from [-1, 1] to [-128, 128]
    ab = lab_image[..., 1:] * 127
    
    # Instead of using skimage's lab2rgb, we'll use a simplified conversion
    # This is an approximation of the LAB to RGB conversion
    # Note: For more accurate results, you might want to implement the full conversion matrix
    lab = tf.concat([L, ab], axis=-1)
    
    # Normalize to [0, 1] range for VGG
    rgb = tf.clip_by_value(lab / 255.0, 0, 1)
    # Repeat the L channel 3 times to create an RGB image
    rgb = tf.repeat(rgb[..., 0:1], 3, axis=-1)
    
    return rgb
def train_model(image_dir, epochs=50, checkpoint_dir='checkpoints'):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
        
    with strategy.scope():
        generator, models = build_generator(input_shape=(256, 256, 1))
        discriminator = build_discriminator(input_shape=(256, 256, 3))
        vgg_feature_extractor = build_vgg_feature_extractor()
        color_encoder_model = models['color_encoder']
        
        feature_matching_conv = Conv2D(512, (1, 1), padding='same')
        dummy_input = tf.random.normal([1, 32, 32, 256])
        _ = feature_matching_conv(dummy_input)
        
        gen_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
        disc_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5, beta_2=0.999)
       
        def train_step(dist_inputs):
            l_channel, ab_channels = dist_inputs
            
            # Train discriminator n_critic times
            n_critic = 5
            for _ in range(n_critic):
                with tf.GradientTape() as disc_tape:
                    fake_ab = generator(l_channel, training=True)
                    fake_images = tf.concat([l_channel, fake_ab], axis=-1)
                    real_images = tf.concat([l_channel, ab_channels], axis=-1)
                    
                    disc_fake = discriminator(fake_images, training=True)
                    disc_real = discriminator(real_images, training=True)
                    
                    disc_loss = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
                    gp = gradient_penalty(discriminator, l_channel, fake_ab, ab_channels)
                    total_disc_loss = (disc_loss + 10.0 * gp) / strategy.num_replicas_in_sync
                
                # Calculate and apply discriminator gradients
                disc_gradients = disc_tape.gradient(total_disc_loss, discriminator.trainable_variables)
                disc_gradients, _ = tf.clip_by_global_norm(disc_gradients, 1.0)
                disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
            
            # Generator training (once)
            with tf.GradientTape() as gen_tape:
                fake_ab = generator(l_channel, training=True)
                fake_images = tf.concat([l_channel, fake_ab], axis=-1)
                real_images = tf.concat([l_channel, ab_channels], axis=-1)
                
                fake_rgb = (fake_images + 1) / 2
                real_rgb = (real_images + 1) / 2
                vgg_real = vgg_feature_extractor(real_rgb)
                vgg_fake = vgg_feature_extractor(fake_rgb)
                
                batch_size = tf.shape(l_channel)[0]
                random_noise = tf.random.normal(shape=(batch_size, 256, 256, 128), mean=0, stddev=0.1)
                color_encoded_features = color_encoder_model(random_noise)
                color_encoded_features = tf.image.resize(color_encoded_features, [32, 32])
                color_encoded_features = feature_matching_conv(color_encoded_features)
                
                disc_fake = discriminator(fake_images, training=True)
                
                # Generator losses
                gen_adv_loss = -tf.reduce_mean(disc_fake)
                perceptual_loss = tf.reduce_mean(tf.abs(vgg_fake - vgg_real))
                l1_loss = tf.reduce_mean(tf.abs(real_images - fake_images))
                color_loss = tf.reduce_mean(tf.abs(color_encoded_features - vgg_real))
            
                
                total_gen_loss = (
                    LAMBDA_ADVERSARIAL * gen_adv_loss +
                    LAMBDA_PERCEPTUAL * perceptual_loss +
                    LAMBDA_L1 * l1_loss +
                    LAMBDA_COLOR * color_loss
                ) / strategy.num_replicas_in_sync
            
            # Calculate and apply generator gradients
            gen_gradients = gen_tape.gradient(total_gen_loss, generator.trainable_variables)
            gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
            
            return total_gen_loss, total_disc_loss

        @tf.function
        def distributed_train_step(dist_inputs):
            per_replica_losses = strategy.run(train_step, args=(dist_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    dataset = create_distributed_dataset(image_dir)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    
    train_gen_loss = tf.keras.metrics.Mean(name='train_gen_loss')
    train_disc_loss = tf.keras.metrics.Mean(name='train_disc_loss')

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        train_gen_loss.reset_state()
        train_disc_loss.reset_state()
        
        for step, dist_inputs in enumerate(dist_dataset):
            gen_loss, disc_loss = distributed_train_step(dist_inputs)
            train_gen_loss.update_state(gen_loss)
            train_disc_loss.update_state(disc_loss)
            
            if step % 50 == 0:
                print(f"Step {step}: Gen Loss = {train_gen_loss.result():.4f}, "
                      f"Disc Loss = {train_disc_loss.result():.4f}")
        
        if (epoch + 1) % 5 == 0:
            generator.save_weights(f'generator_epoch_{epoch+1}.weights.h5')
            discriminator.save_weights(f'discriminator_epoch_{epoch+1}.weights.h5')
            print(f"Saved weights for epoch {epoch+1}")
        
        print(f"Epoch {epoch + 1} Results:")
        print(f"Generator Loss: {train_gen_loss.result():.4f}")
        print(f"Discriminator Loss: {train_disc_loss.result():.4f}")


Number of devices: 2


In [None]:
train_model("/kaggle/input/pascal-voc-2012-dataset/VOC2012_train_val/VOC2012_train_val/JPEGImages",epochs=50)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m80134624/80134624[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step

Epoch 1/50
Step 0: Gen Loss = 811.7355, Disc Loss = 0.1773
Step 50: Gen Loss = 771.1486, Disc Loss = -1.7468
Step 100: Gen Loss = 771.1143, Disc Loss = -3.9859
Step 150: Gen Loss = 770.3538, Disc Loss = -6.8389
Step 200: Gen Loss = 771.9755, Disc Loss = -10.0045
