In [3]:

import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, initializers, regularizers, optimizers, models
from glob import glob
from PIL import Image
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model

# -------------------------------
# Utilities
# -------------------------------
from tensorflow.keras.layers import Layer

class InstanceNormalization(Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Khởi tạo các tham số scale và offset
        self.scale = self.add_weight(name='scale',
                                     shape=(input_shape[-1],),
                                     initializer='ones',
                                     trainable=True)
        self.offset = self.add_weight(name='offset',
                                      shape=(input_shape[-1],),
                                      initializer='zeros',
                                      trainable=True)
        super(InstanceNormalization, self).build(input_shape)

    def call(self, inputs):
        # Tính toán trung bình và độ lệch chuẩn theo từng instance và channel
        mean, variance = tf.nn.moments(inputs, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(variance + self.epsilon)
        normalized = (inputs - mean) * inv
        return self.scale * normalized + self.offset

    def get_config(self):
        config = super(InstanceNormalization, self).get_config()
        config.update({"epsilon": self.epsilon})
        return config

def check_folder(path):
    """Ensure that a directory exists; create it if it doesn't."""
    if not os.path.exists(path):
        os.makedirs(path)

def show_all_variables(model):
    """Print all trainable variables."""
    trainable_vars = model.trainable_variables  # Lấy biến có thể huấn luyện từ model
    for var in trainable_vars:
        print(var.name, var.shape)


# -------------------------------
# Custom Layers and Functions
# -------------------------------




def rgb2yuv(rgb):
    """Convert RGB image to YUV color space."""
    yuv = tf.image.rgb_to_yuv((rgb + 1.0) / 2.0)  # Assuming input is in [-1, 1]
    return yuv


class DiscriminatorLossLayer(tf.keras.layers.Layer):
    def __init__(self, loss_type, **kwargs):
        super(DiscriminatorLossLayer, self).__init__(**kwargs)
        self.loss_type = loss_type

    def call(self, inputs):
        """Compute discriminator loss based on GAN type."""
        real, gray, fake, real_blur = inputs
        if self.loss_type in ['wgan-gp', 'wgan-lp']:
            real_loss = -tf.reduce_mean(real)
            gray_loss = tf.reduce_mean(gray)
            fake_loss = tf.reduce_mean(fake)
            real_blur_loss = tf.reduce_mean(real_blur)
        elif self.loss_type == 'lsgan':
            real_loss = tf.reduce_mean(tf.square(real - 1.0))
            gray_loss = tf.reduce_mean(tf.square(gray))
            fake_loss = tf.reduce_mean(tf.square(fake))
            real_blur_loss = tf.reduce_mean(tf.square(real_blur))
        elif self.loss_type in ['gan', 'dragan']:
            real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
            gray_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(gray), logits=gray))
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
            real_blur_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(real_blur), logits=real_blur))
        elif self.loss_type == 'hinge':
            real_loss = tf.reduce_mean(tf.nn.relu(1.0 - real))
            gray_loss = tf.reduce_mean(tf.nn.relu(1.0 + gray))
            fake_loss = tf.reduce_mean(tf.nn.relu(1.0 + fake))
            real_blur_loss = tf.reduce_mean(tf.nn.relu(1.0 + real_blur))
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        # Assign weights based on your specific requirements
        weights = {
            'real_loss': 1.7,
            'fake_loss': 1.7,
            'gray_loss': 1.7,
            'real_blur_loss': 1.0
        }

        # Compute the total loss
        loss = (weights['real_loss'] * real_loss +
                weights['fake_loss'] * fake_loss +
                weights['gray_loss'] * gray_loss +
                weights['real_blur_loss'] * real_blur_loss)
        return loss

class GeneratorLossLayer(Layer):
    def __init__(self, loss_type='gan', **kwargs):
        super(GeneratorLossLayer, self).__init__(**kwargs)
        self.loss_type = loss_type
    def call(self, fake):
        if self.loss_type in ['wgan-gp', 'wgan-lp']:
            fake_loss = -tf.reduce_mean(fake)
        elif self.loss_type == 'lsgan':
            fake_loss = tf.reduce_mean(tf.square(fake - 1.0))
        elif self.loss_type in ['gan', 'dragan']:
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))
        elif self.loss_type == 'hinge':
            fake_loss = -tf.reduce_mean(fake)
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        return fake_loss

class ContentLossLayer(layers.Layer):
    def call(self, real_features, generated_features):
        return tf.reduce_mean(tf.square(real_features - generated_features))

class StyleLossLayer(layers.Layer):
    def call(self, real_features, generated_features):
        def gram_matrix(features):
            shape = tf.shape(features)
            reshaped_features = tf.reshape(features, [shape[0], -1, shape[-1]])
            gram = tf.matmul(reshaped_features, reshaped_features, transpose_a=True)
            return gram / tf.cast(shape[1] * shape[2], tf.float32)
        
        return tf.reduce_mean(tf.square(gram_matrix(real_features) - gram_matrix(generated_features)))

class TotalVariationLossLayer(layers.Layer):
    def call(self, generated_image):
        return tf.reduce_sum(tf.image.total_variation(generated_image))
class RgbToYuvLayer(Layer):
    def __init__(self, **kwargs):
        super(RgbToYuvLayer, self).__init__(**kwargs)

    def call(self, inputs):
        """Convert RGB to YUV."""
        yuv = tf.image.rgb_to_yuv((inputs + 1.0) / 2.0)  # Assuming input is in [-1, 1]
        return yuv
class L1LossLayer(Layer):
    def __init__(self, **kwargs):
        super(L1LossLayer, self).__init__(**kwargs)

    def call(self, inputs):
        x, y = inputs
        return tf.reduce_mean(tf.abs(x - y))
class HuberLossLayer(Layer):
    def __init__(self, delta=1.0, **kwargs):
        super(HuberLossLayer, self).__init__(**kwargs)
        self.delta = delta

    def call(self, inputs):
        y_true, y_pred = inputs
        huber_loss_fn = tf.keras.losses.Huber()
        return huber_loss_fn(y_true, y_pred)
def color_loss_fn(real, fake):
    real_yuv = rgb2yuv(real)
    fake_yuv = rgb2yuv(fake)
    
    # Call the layer as a function to compute the actual loss values
    y_loss = HuberLossLayer()([real_yuv[..., 0], fake_yuv[..., 0]])
    u_loss = HuberLossLayer()([real_yuv[..., 1], fake_yuv[..., 1]])
    v_loss = HuberLossLayer()([real_yuv[..., 2], fake_yuv[..., 2]])
    
    return y_loss + u_loss + v_loss


# -------------------------------
# Data Loader
# -------------------------------
import numpy as np
import cv2
import os
from glob import glob


class ImageGenerator:
    def __init__(self, dataset_path, img_size, batch_size):
        self.dataset_path = dataset_path
        self.img_size = img_size
        self.batch_size = batch_size
        self.image_paths = glob(os.path.join(dataset_path, '*.*'))
        self.num_images = len(self.image_paths)
        self.pointer = 0

    def load_images(self):
        """Load a batch of images (both types)."""
        batch_images = []
        for _ in range(self.batch_size):
            if self.pointer >= self.num_images:
                self.pointer = 0
                np.random.shuffle(self.image_paths)
            img, img_gray = self.load_image(self.image_paths[self.pointer])
            batch_images.append((img, img_gray))  # Thêm ảnh màu và ảnh xám
            self.pointer += 1
        return np.array(batch_images)

    def load_image(self, image_path):
        """Load and preprocess a single image (color and grayscale)."""
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Image at path {image_path} could not be loaded.")
        
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.img_size, self.img_size))
        
        # Normalize image to [-1, 1]
        img = img.astype(np.float32) / 127.5 - 1
        
        # Create grayscale image from color image
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_gray = cv2.resize(img_gray, (self.img_size, self.img_size))
        img_gray = np.stack([img_gray] * 3, axis=-1)  # Convert to 3 channels
        img_gray = img_gray.astype(np.float32) / 127.5 - 1  # Normalize grayscale image to [-1, 1]
        
        return img, img_gray

    def save_images(self, images, image_path, photo_path=None):
        fake = self.inverse_transform(images.squeeze())
        if photo_path:
            return self.imsave(self.adjust_brightness_from_src_to_dst(fake, self.read_img(photo_path)[0]), image_path)
        else:
            return self.imsave(fake, image_path)

    def inverse_transform(self, images):
        images = (images + 1.) / 2 * 255
        images = np.clip(images, 0, 255)
        return images.astype(np.uint8)

    def imsave(self, images, path):
        return cv2.imwrite(path, cv2.cvtColor(images, cv2.COLOR_RGB2BGR))

    def read_img(self, image_path):
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_gray = np.stack([img_gray] * 3, axis=-1)  # Chuyển đổi về 3 kênh
        return img, img_gray

    def calculate_average_brightness(self, img):
        R = img[..., 0].mean()
        G = img[..., 1].mean()
        B = img[..., 2].mean()
        brightness = 0.299 * R + 0.587 * G + 0.114 * B
        return brightness, B, G, R

    def adjust_brightness_from_src_to_dst(self, dst, src, path=None, if_show=None, if_info=None):
        brightness1, B1, G1, R1 = self.calculate_average_brightness(src)
        brightness2, B2, G2, R2 = self.calculate_average_brightness(dst)
        brightness_difference = brightness1 / brightness2

        if if_info:
            print('Average brightness of source image:', brightness1)
            print('Average brightness of target image:', brightness2)
            print('Brightness difference:', brightness_difference)

        dstf = dst * brightness_difference
        dstf = np.clip(dstf, 0, 255).astype(np.uint8)

        ma, na, _ = src.shape
        mb, nb, _ = dst.shape
        result_show_img = np.zeros((max(ma, mb), 3 * max(na, nb), 3), dtype=np.uint8)
        result_show_img[:mb, :nb, :] = dst
        result_show_img[:ma, nb:nb + na, :] = src
        result_show_img[:mb, nb + na:nb + na + nb, :] = dstf

        if if_show:
            cv2.imshow('Brightness Adjustment', cv2.cvtColor(result_show_img, cv2.COLOR_BGR2RGB))
            cv2.waitKey(0)
            cv2.destroyAllWindows()

        if path is not None:
            cv2.imwrite(path, cv2.cvtColor(result_show_img, cv2.COLOR_RGB2BGR))

        return dstf




# -------------------------------
# Models
# -------------------------------

# Simplified ConvBlock: Convolution followed by Normalization and Activation
class ConvBlock(layers.Layer):
    def __init__(self, filters, kernel_size=3, strides=1, use_depthwise=False, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        if use_depthwise:
            self.conv = layers.SeparableConv2D(filters, kernel_size, strides=strides, padding='same', use_bias=False)
        else:
            self.conv = layers.Conv2D(filters, kernel_size, strides=strides, padding='same', use_bias=False)
        self.norm = layers.BatchNormalization()
        self.activation = layers.LeakyReLU(0.2)
    
    def call(self, inputs, training=False):
        x = self.conv(inputs)
        x = self.norm(x, training=training)
        return self.activation(x)

# Simplified Residual Block with optional depthwise separable convolution
class ResidualBlock(layers.Layer):
    def __init__(self, filters, use_depthwise=False, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        self.conv1 = ConvBlock(filters, kernel_size=1, strides=1)
        self.conv2 = ConvBlock(filters, kernel_size=3, strides=1, use_depthwise=use_depthwise)
        self.conv3 = layers.Conv2D(filters, kernel_size=1, strides=1, padding='same', use_bias=False)
        self.norm = layers.BatchNormalization()
        self.activation = layers.LeakyReLU(0.2)
    
    def call(self, inputs, training=False):
        x = self.conv1(inputs, training=training)
        x = self.conv2(x, training=training)
        x = self.conv3(x)
        x = self.norm(x, training=training)
        return self.activation(x + inputs)

# Simplified Downsampling Block
class DownsampleBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super(DownsampleBlock, self).__init__(**kwargs)
        self.conv = ConvBlock(filters, kernel_size=3, strides=2)
    
    def call(self, inputs, training=False):
        return self.conv(inputs, training=training)

# Simplified Upsampling Block
class UpsampleBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super(UpsampleBlock, self).__init__(**kwargs)
        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.conv = ConvBlock(filters, kernel_size=3, strides=1)
    
    def call(self, inputs, training=False):
        x = self.upsample(inputs)
        return self.conv(x, training=training)

# Simplified Gnet
class G_net(Model):
    def __init__(self, **kwargs):
        super(G_net, self).__init__(**kwargs)
        self.initial_conv = ConvBlock(32, kernel_size=7, strides=1)
        self.conv1 = ConvBlock(64, kernel_size=3, strides=2)
        self.conv2 = ConvBlock(64)
        self.down1 = DownsampleBlock(128)
        self.conv3 = ConvBlock(128)
        self.dsconv1 = ConvBlock(128, use_depthwise=True)
        self.down2 = DownsampleBlock(256)

        # Residual Blocks
        self.res_blocks = [ResidualBlock(256, use_depthwise=True) for _ in range(4)]
        
        self.up1 = UpsampleBlock(128)
        self.conv4 = ConvBlock(128)
        self.up2 = UpsampleBlock(64)
        self.conv5 = ConvBlock(64)
        self.final_conv = layers.Conv2D(3, kernel_size=7, strides=1, padding='same', activation='tanh')
    
    def call(self, inputs, training=False):
        x = self.initial_conv(inputs, training=training)
        x = self.conv1(x, training=training)
        x = self.conv2(x, training=training)
        x = self.down1(x, training=training)
        x = self.conv3(x, training=training)
        x = self.dsconv1(x, training=training)
        x = self.down2(x, training=training)

        for block in self.res_blocks:
            x = block(x, training=training)
        
        x = self.up1(x, training=training)
        x = self.conv4(x, training=training)
        x = self.up2(x, training=training)
        x = self.conv5(x, training=training)
        return self.final_conv(x)


# Simplified Discriminator

class D_net(Model):
    def __init__(self, **kwargs):
        super(D_net, self).__init__(**kwargs)
        self.model = tf.keras.Sequential([
            layers.Conv2D(32, kernel_size=3, strides=1, padding='same'),  # K3,S1,C32
            layers.LeakyReLU(),  # Leaky ReLU activation
            layers.Conv2D(64, kernel_size=3, strides=2, padding='same'),  # K3,S2,C64
            layers.LeakyReLU(),  # Leaky ReLU activation
            layers.Conv2D(128, kernel_size=3, strides=1, padding='same'),  # K3,S1,C128
            layers.LayerNormalization(),  # Instance Normalization
            layers.LeakyReLU(),  # Leaky ReLU activation
            layers.Conv2D(256, kernel_size=3, strides=2, padding='same'),  # K3,S2,C256
            layers.LayerNormalization(),  # Instance Normalization
            layers.LeakyReLU(),  # Leaky ReLU activation
            layers.Conv2D(1, kernel_size=3, strides=1, padding='same')  # K3,S1,C1
        ])
    
    def call(self, inputs, training=False):
        return self.model(inputs, training=training)


# Testing Function
def test_models(generator, discriminator, input_shape):
    dummy_input = tf.random.normal([1] + list(input_shape))  # Batch size of 1
    print("Testing Generator...")
    generated_output = generator(dummy_input, training=False)
    print(f"Generator output shape: {generated_output.shape}")
    
    print("Testing Discriminator...")
    discriminator_output = discriminator(generated_output, training=False)
    print(f"Discriminator output shape: {discriminator_output.shape}")

# Clear previous models from memory
tf.keras.backend.clear_session()

# Initialize models
img_size = (256, 256, 3)  # Example image size
generator = G_net()
discriminator = D_net()

# Test the models with a sample input shape
test_models(generator, discriminator, img_size)
generator.summary()
discriminator.summary()



Testing Generator...
Generator output shape: (1, 128, 128, 3)
Testing Discriminator...
Discriminator output shape: (1, 32, 32, 1)


In [None]:
# -------------------------------
# AnimeGANv2 Model Class
# -------------------------------
import tf2onnx

class AnimeGANv2(tf.keras.Model):
    def __init__(self, config):
        super(AnimeGANv2, self).__init__()
        # Configuration parameters
        self.model_name = config.get('model_name', 'AnimeGANv2')
        self.dataset_name = config['dataset_name']
        self.train_photo_path = config['train_photo_path']
        self.style_path = config['style_path']
        self.smooth_path = config['smooth_path']
        self.val_path = config['val_path']
        
        self.epoch = config['epoch']
        self.init_epoch = config['init_epoch']
        
        self.gan_type = config['gan_type']
        self.batch_size = config['batch_size']
        self.save_freq = config['save_freq']
        
        self.init_lr = config['init_lr']
        self.g_lr = config['g_lr']
        self.d_lr = config['d_lr']
        self.ld = config['ld']
        
        self.tv_weight = config['tv_weight']
        self.con_weight = config['con_weight']
        self.sty_weight = config['sty_weight']
        self.color_weight = config['color_weight']
        self.g_adv_weight = config['g_adv_weight']
        self.d_adv_weight = config['d_adv_weight']
        
        self.img_size = config['img_size']
        self.img_ch = config['img_ch']
        
        self.ch = config['ch']
        self.n_dis = config['n_dis']
        self.sn = config['sn']
        
        self.checkpoint_dir = config['checkpoint_dir']
        self.log_dir = config['log_dir']
        self.sample_dir = config['sample_dir']
        
        self.model_dir = self.get_model_dir(config)
        
        # Ensure directories exist
        check_folder(os.path.join(self.checkpoint_dir, self.model_dir))
        check_folder(os.path.join(self.log_dir, self.model_dir))
        check_folder(os.path.join(self.sample_dir, self.model_dir))
        
        # Initialize Models
        self.generator = G_net()
        self.discriminator = D_net()
        
        # Initialize VGG for content and style loss
        vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=(self.img_size[0], self.img_size[1], 3))
        self.vgg = Model(inputs=vgg19.input, outputs=[vgg19.get_layer(name).output for name in ['block4_conv4']])
        self.vgg.trainable = False
        
        # Initialize Loss Layers
        self.content_loss_layer = ContentLossLayer()
        self.style_loss_layer = StyleLossLayer()
        self.total_variation_loss_layer = TotalVariationLossLayer()
        
        # Initialize Optimizers
        self.g_optimizer = optimizers.Adam(learning_rate=self.g_lr, beta_1=0.5, beta_2=0.999)
        self.d_optimizer = optimizers.Adam(learning_rate=self.d_lr, beta_1=0.5, beta_2=0.999)
        
        # Initialize Loss Layers
        self.generator_loss_layer = GeneratorLossLayer(loss_type=self.gan_type)
        self.discriminator_loss_layer = DiscriminatorLossLayer(loss_type=self.gan_type)
        
        # Initialize Image Generators
        self.real_image_generator = ImageGenerator(self.train_photo_path, self.img_size[0], self.batch_size)
        self.anime_image_generator = ImageGenerator(self.style_path, self.img_size[0], self.batch_size)
        self.anime_smooth_generator = ImageGenerator(self.smooth_path, self.img_size[0], self.batch_size)
        self.dataset_num = max(self.real_image_generator.num_images, 
                               self.anime_image_generator.num_images, 
                               self.anime_smooth_generator.num_images)
        
        # Prepare Checkpoints
        self.checkpoint = tf.train.Checkpoint(generator=self.generator,
                                              discriminator=self.discriminator,
                                              optimizer_g=self.g_optimizer,
                                              optimizer_d=self.d_optimizer)
        self.checkpoint_manager = tf.train.CheckpointManager(self.checkpoint, 
                                                               os.path.join(self.checkpoint_dir, self.model_dir), 
                                                               max_to_keep=self.epoch)
        
        # Restore latest checkpoint if available
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(f"Restored from {self.checkpoint_manager.latest_checkpoint}")
        else:
            print("Initializing from scratch.")
        
        # Initialize TensorBoard writer
        self.summary_writer = tf.summary.create_file_writer(os.path.join(self.log_dir, self.model_dir))
        
    def get_model_dir(self, config):
        """Define the model directory based on hyperparameters."""
        return f"{self.model_name}_{self.dataset_name}_{self.gan_type}_G{int(self.g_adv_weight)}_D{int(self.d_adv_weight)}_" \
               f"C{int(self.con_weight)}_S{int(self.sty_weight)}_Color{int(self.color_weight)}_TV{int(self.tv_weight)}"

    @tf.function
    def train_step_fn(self, real, anime, anime_smooth):
        """Single training step."""
        with tf.GradientTape(persistent=True) as tape:
            # Generate images
            generated = self.generator(real, training=True)
            
            # Discriminator outputs
            real_output = self.discriminator(anime, training=True)
            anime_gray_output = self.discriminator(anime, training=True)  # Assuming anime_gray is same as anime
            fake_output = self.discriminator(generated, training=True)
            smooth_output = self.discriminator(anime_smooth, training=True)
            
            # Compute losses
            content_loss, style_loss = self.compute_con_sty_loss(real, generated, anime_gray_output)
            color_loss = color_loss_fn(real, generated)
            tv_loss = self.tv_weight * self.total_variation_loss_layer(generated)
            total_loss = (self.con_weight * content_loss +
                          self.sty_weight * style_loss +
                          self.color_weight * color_loss +
                          tv_loss)
            
            # Generator adversarial loss
            g_adv_loss = self.g_adv_weight * self.generator_loss_layer(fake_output)
            generator_loss = total_loss + g_adv_loss
            
            # Discriminator adversarial loss
            d_adv_loss = self.d_adv_weight * self.discriminator_loss_layer([real_output, anime_gray_output, fake_output, smooth_output])
            discriminator_loss = d_adv_loss + self.gradient_penalty(anime, generated)
        
        # Compute gradients
        gradients_of_generator = tape.gradient(generator_loss, self.generator.trainable_variables)
        gradients_of_discriminator = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        
        # Apply gradients
        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
        
        return {"g_loss": generator_loss, "d_loss": discriminator_loss}

    def compute_con_sty_loss(self, real, generated, anime_gray):
        """Compute content and style loss."""
        real_features = self.vgg(real)
        generated_features = self.vgg(generated)
        anime_gray_features = self.vgg(anime_gray)
        
        content_loss = sum(self.content_loss_layer(rf, gf) for rf, gf in zip(real_features, generated_features))
        style_loss = sum(self.style_loss_layer(agf, gf) for agf, gf in zip(anime_gray_features, generated_features))
        tf.print("Content Loss:", content_loss)
        tf.print("Style Loss:", style_loss)
        return content_loss, style_loss

    def gradient_penalty(self, real, fake):
        """Compute gradient penalty for WGAN-GP and related losses."""
        if 'dragan' in self.gan_type:
            eps = tf.random.uniform(shape=tf.shape(real), minval=0., maxval=1.)
            _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)
            fake = real + 0.5 * x_std * eps

        alpha = tf.random.uniform(shape=[tf.shape(real)[0], 1, 1, 1], minval=0., maxval=1.)
        interpolated = real + alpha * (fake - real)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            disc_interpolated = self.discriminator(interpolated, training=True)
        
        grad = gp_tape.gradient(disc_interpolated, [interpolated])[0]
        grad_norm = tf.norm(tf.reshape(grad, [tf.shape(real)[0], -1]), axis=1)
        
        if 'lp' in self.gan_type:
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        elif 'gp' in self.gan_type or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
        else:
            GP = 0.0
        
        return GP

    def save_sample(self, epoch):
        """Save sample images for visual inspection."""
        val_files = glob(os.path.join(self.val_path, '*.*'))
        save_path = os.path.join(self.sample_dir, self.model_dir, f'{epoch:03d}')
        check_folder(save_path)
        for i, sample_file in enumerate(val_files):
            print(f'Saving sample: {i} {sample_file}')
            try:
                sample_image, _ = self.real_image_generator.load_image(sample_file)
                sample_image = np.expand_dims(sample_image, axis=0)  # Add batch dimension
                sample_image = tf.convert_to_tensor(sample_image, dtype=tf.float32)
                test_generated = self.generator(sample_image, training=False)
                test_generated = tf.squeeze(test_generated, axis=0)
                # Convert to [0, 255] and uint8
                sample_image = tf.cast((sample_image[0] + 1.0) * 127.5, tf.uint8)
                test_generated = tf.cast((test_generated + 1.0) * 127.5, tf.uint8)
                # Save images
                Image.fromarray(sample_image.numpy()).save(os.path.join(save_path, f'{i:03d}_a.jpg'))
                Image.fromarray(test_generated.numpy()).save(os.path.join(save_path, f'{i:03d}_b.jpg'))
            except Exception as e:
                print(f"Error saving sample {sample_file}: {e}")
        print(f"Saved samples for epoch {epoch} at {save_path}")
    def save_onnx_model(self, epoch):
        """Save the model in ONNX format."""
        model_path = os.path.join(self.checkpoint_dir, self.model_dir, f'model_epoch_{epoch}.onnx')
        # Convert the Keras model to ONNX format
        model_proto = tf2onnx.convert.from_keras(self)
        with open(model_path, "wb") as f:
            f.write(model_proto.SerializeToString())
        print(f"Model saved in ONNX format at {model_path}")
        
    def train(self):
        """Train the AnimeGANv2 model."""
        # Initialize TensorBoard writer
        self.summary_writer = tf.summary.create_file_writer(os.path.join(self.log_dir, self.model_dir))
        
        for epoch in range(1, self.epoch + 1):
            start_time = time.time()
            print(f"Starting epoch {epoch}/{self.epoch}")
            
            init_mean_loss = []
            mean_loss = []
            for idx in range(int(self.dataset_num / self.batch_size)):
                # Load batch data
                batch = self.real_image_generator.load_images()
                real = batch[:,0,:,:,:]
                anime = self.anime_image_generator.load_images()[:,0,:,:,:]
                anime_smooth = self.anime_smooth_generator.load_images()[:,0,:,:,:]

                # Training phase
                if epoch <= self.init_epoch:
                    # Pre-training Generator
                    with tf.GradientTape() as tape:
                        generated = self.generator(real, training=True)
#                         self.generator.summary()
                        content_loss, style_loss = self.compute_con_sty_loss(real, generated, anime)
                        color_loss = color_loss_fn(real, generated)
                        tv_loss = self.tv_weight * self.total_variation_loss_layer(generated)
                        total_loss = (self.con_weight * content_loss +
                                      self.sty_weight * style_loss +
                                      self.color_weight * color_loss +
                                      tv_loss)
                    gradients = tape.gradient(total_loss, self.generator.trainable_variables)
                    
                    init_mean_loss.append(total_loss)
                    # Kiểm tra nếu cả gradient và biến trainable đều có giá trị
                    if gradients and self.generator.trainable_variables:
                        self.g_optimizer.apply_gradients(zip(gradients, self.generator.trainable_variables))
                    else:
                        print("Error: Either gradients or trainable variables are empty!")
                    if (idx + 1) % 200 == 0:
                        print(f"Epoch: {epoch} Step: {idx}/{int(self.dataset_num / self.batch_size)} "
                              f"Time: {time.time() - start_time:.2f}s Init_Loss: {total_loss.numpy():.8f} "
                              f"Mean_Init_Loss: {np.mean(init_mean_loss).item():.8f}")
                        init_mean_loss.clear()
                else:
                    # Update Discriminator
                    with tf.GradientTape() as tape:
                        generated = self.generator(real, training=True)
                        real_output = self.discriminator(anime, training=True)
                        anime_gray_output = self.discriminator(anime, training=True)
                        fake_output = self.discriminator(generated, training=True)
                        smooth_output = self.discriminator(anime_smooth, training=True)
                        
                        # Compute losses
                        d_loss = self.discriminator_loss_layer([real_output, anime_gray_output, fake_output, smooth_output]) + self.gradient_penalty(anime, generated)
                    gradients = tape.gradient(d_loss, self.discriminator.trainable_variables)
                    self.d_optimizer.apply_gradients(zip(gradients, self.discriminator.trainable_variables))

                    # Update Generator
                    with tf.GradientTape() as tape:
                        generated = self.generator(real, training=True)
                        fake_output = self.discriminator(generated, training=True)
                        content_loss, style_loss = self.compute_con_sty_loss(real, generated, anime)
                        color_loss = color_loss_fn(real, generated)
                        tv_loss = self.tv_weight * self.total_variation_loss_layer(generated)
                        g_adv_loss = self.g_adv_weight * self.generator_loss_layer(fake_output)
                        g_loss = (self.con_weight * content_loss +
                                  self.sty_weight * style_loss +
                                  self.color_weight * color_loss +
                                  tv_loss +
                                  g_adv_loss)
                    gradients = tape.gradient(g_loss, self.generator.trainable_variables)
                    self.g_optimizer.apply_gradients(zip(gradients, self.generator.trainable_variables))
                    mean_loss.append([d_loss.numpy(),g_loss.numpy()])

                    if (idx + 1) % 200 == 0:
                        print(f"Epoch: {epoch} Step: {idx}/{int(self.dataset_num / self.batch_size)} "
                              f"Time: {time.time() - start_time:.2f}s D_Loss: {d_loss.numpy():.8f}, G_Loss: {g_loss.numpy():.8f} "
                              f"Mean_dLoss: {np.mean(mean_loss,axis=0)[0]:.8f}"
                              f"Mean_gLoss: {np.mean(mean_loss,axis=0)[1]:.8f}")
                        mean_loss.clear()

            # Save checkpoints
            if (epoch % self.save_freq == 0):
                self.checkpoint_manager.save()
                print(f"Checkpoint saved at epoch {epoch}")
            
            # Save sample images
            self.save_sample(epoch)
            self.save_onnx_model(epoch)
            
            print(f"Epoch {epoch} completed in {time.time() - start_time:.2f}s")

        

In [None]:
# -------------------------------
# Main Function
# -------------------------------

def main():
    """Main function to train AnimeGANv2."""
    
    # Enable mixed precision if desired (optional)
    # mixed_precision.set_global_policy('mixed_float16')
    
    # Set memory growth option for GPU in TensorFlow 2.x
    physical_devices = tf.config.list_physical_devices('GPU')
    if physical_devices:
        try:
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, True)
            print("GPU memory growth set successfully.")
        except RuntimeError as e:
            print(f"Error setting GPU memory growth: {e}")
    
    # Configuration dictionary
    config = {
        'model_name': 'AnimeGANv2',
        'dataset_name': 'Hayao',  # Change as needed
        'train_photo_path': '/kaggle/input/animegan/train_photo',
        'style_path': '/kaggle/input/animegan/Hayao/style',
        'smooth_path': '/kaggle/input/animegan/Hayao/smooth',
        'val_path': '/kaggle/input/animegan/val',
        'epoch': 101,
        'init_epoch': 20,
        'gan_type': 'lsgan',  # Options: 'gan', 'lsgan', 'wgan-gp', etc.
        'batch_size': 12,
        'save_freq': 1,
        'init_lr': 0.0001,
        'g_lr': 0.00001,
        'd_lr': 0.00001,
        'ld': 10.0,  # Gradient penalty lambda
        'tv_weight': 1.0,  # Total Variation Loss weight
        'con_weight': 1.5,  # Content Loss weight
        'sty_weight': 3.0,  # Style Loss weight
        'color_weight': 10.0,  # Color Loss weight
        'g_adv_weight': 300.0,  # Adversarial Loss weight for Generator
        'd_adv_weight': 300.0,  # Adversarial Loss weight for Discriminator
        'img_size': [256, 256],
        'img_ch': 3,
        'ch': 64,  # Base channel number per layer
        'n_dis': 3,  # Number of discriminator layers
        'sn': True,  # Use spectral normalization
        'checkpoint_dir': 'checkpoint',
        'log_dir': 'logs',
        'sample_dir': 'samples'
    }
    
    # Initialize and train the model
    gan = AnimeGANv2(config)
    gan.train()
    print(" [*] Training finished!")

if __name__ == '__main__':
    main()
