# GAN Anime Image Generation

## Import libraries


In [None]:
# Install loguru; the import cell below uses it for logging.
!pip install loguru


In [None]:
import logging
from loguru import logger

import sys
import os
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

from tqdm import tqdm
from torchvision import transforms
from torchvision.utils import save_image, make_grid

import numpy as np
import matplotlib.pyplot as plt

from google.colab import drive

# Log the interpreter, NumPy and PyTorch versions. The `{expr = }` f-string
# form prints the expression text followed by its value.
logger.info(f"{sys.version = }")
logger.info(f"{np.__version__ = }")
logger.info(f"{torch.__version__ = }")

## Prepare the environment: mount Drive and extract the images

In [None]:
# Mount Google Drive; the image archive is read from it and the trained
# models are written back to it.
drive.mount("/content/drive")

# Use the first CUDA device when one is available, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device_config = torch.device(device_name)
logger.info(f"{device_config = }")

In [None]:
# Extract the image archive stored on Google Drive into /content/images.
!unzip '/content/drive/MyDrive/IA/GAN/anime_images.zip' -d'/content/images'

## Load dataset

In [None]:
# Mini-batch size, and the side length (in pixels) images are resized and cropped to.
batch_quantity = 64
img_dimension = 64

# Preprocessing pipeline: resize, centre-crop to a square, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
image_pipeline = transforms.Compose([
    transforms.Resize(img_dimension),
    transforms.CenterCrop(img_dimension),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Dataset backed by the extracted image folder.
anime_data = dset.ImageFolder(root='/content/images', transform=image_pipeline)

# Shuffled loader; drop_last=True discards a trailing incomplete batch so
# every batch has exactly `batch_quantity` images.
data_loader = torch.utils.data.DataLoader(
    anime_data,
    batch_size=batch_quantity,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

logger.info(f"Nb batchs/{len(data_loader) = }")
logger.info(f"Nb images : {len(anime_data)}")

## Display some images

In [None]:
def display_image_grid(images, nrow=8):
    """Plot a batch of images as one grid with matplotlib.

    Args:
        images: Batch tensor handed to ``make_grid``. Values are mapped with
            ``x / 2 + 0.5`` before plotting, which undoes a mean-0.5 /
            std-0.5 normalisation (assumes inputs lie in [-1, 1]).
        nrow: Number of images per grid row, forwarded to ``make_grid``.
            Defaults to 8, the value that was previously hard-coded.

    Returns:
        None. The figure is shown and then closed.
    """
    # Detach so the later NumPy conversion never trips over autograd state.
    rescaled = images.detach() / 2 + 0.5
    image_grid = make_grid(rescaled, nrow=nrow)

    # Move the channel axis last, which is the layout imshow expects.
    image_grid = image_grid.permute(1, 2, 0).cpu().numpy()

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    plt.close()

# Pull one batch from the loader and preview its images.
batch_iterator = iter(data_loader)
sample_batch = next(batch_iterator)

plt.figure(figsize=(8, 8))
plt.axis("off")

# Element 0 of the batch holds the images (element 1 is the folder label).
sample_images = sample_batch[0]
display_image_grid(sample_images)



# Models: generator and discriminator

## Build models


In [None]:
def initialize_weights(module):
    """Re-initialise one layer in place, chosen by its class name.

    Meant to be passed to ``nn.Module.apply`` so it visits every sub-module.

    * Class name contains ``"Conv"``: weight drawn from N(0, 0.02).
    * Class name contains ``"BatchNorm"``: weight drawn from N(1, 0.02),
      bias set to 0.
    * Anything else is left untouched.

    Parameters that do not exist on the layer (``None``, e.g. a BatchNorm
    built with ``affine=False``) are skipped instead of raising.

    Args:
        module: The layer to initialise.
    """
    module_name = type(module).__name__
    weight = getattr(module, "weight", None)
    bias = getattr(module, "bias", None)

    if "Conv" in module_name:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in module_name:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Number of channels in the training images. For color images this is 3 (R G B)
num_channels = 3

# Size of the z latent vector (channel count of the generator's input)
z_dim = 100

# Base number of feature maps in the generator (layers use multiples of it)
gen_feature_size = 64

# Base number of feature maps in the discriminator (layers use multiples of it)
disc_feature_size = 64

# Number of training epochs
epoch_count = 50

# Learning rate shared by both Adam optimizers
learning_rate = 2e-4

# Bare expression: shows the learning rate as the cell output in a notebook
learning_rate

In [None]:

class GeneratorNetwork(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    The stack is one stride-1 ``ConvTranspose2d`` followed by four stride-2
    ones. For an input of shape ``(N, latent_dim, 1, 1)`` the spatial size
    grows 1 -> 4 -> 8 -> 16 -> 32 -> 64, and the final ``Tanh`` keeps outputs
    in [-1, 1].

    Args:
        latent_dim: Channel count of the latent input. Defaults to the
            module-level ``z_dim``.
        feature_size: Base number of feature maps. Defaults to the
            module-level ``gen_feature_size``.
        out_channels: Channel count of the generated images. Defaults to the
            module-level ``num_channels``.
    """

    def __init__(self, latent_dim=None, feature_size=None, out_channels=None):
        super().__init__()

        # Fall back to the notebook-level hyper-parameters so the existing
        # no-argument construction keeps building the same network.
        latent_dim = z_dim if latent_dim is None else latent_dim
        feature_size = gen_feature_size if feature_size is None else feature_size
        out_channels = num_channels if out_channels is None else out_channels

        self.model = nn.Sequential(
            # latent (1x1) -> 4x4
            nn.ConvTranspose2d(latent_dim, feature_size * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_size * 8),
            nn.ReLU(True),

            # 4x4 -> 8x8
            nn.ConvTranspose2d(feature_size * 8, feature_size * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size * 4),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(feature_size * 4, feature_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size * 2),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(feature_size * 2, feature_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size),
            nn.ReLU(True),

            # 32x32 -> 64x64, squashed to [-1, 1]
            nn.ConvTranspose2d(feature_size, out_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input_data):
        """Run ``input_data`` through the layer stack and return the result."""
        return self.model(input_data)


# Build the generator, move it to the selected device, then re-initialise
# every layer with the custom weight initialiser and log the architecture.
net_generator = GeneratorNetwork()
net_generator = net_generator.to(device_config)
net_generator.apply(initialize_weights)
logger.info(net_generator)

In [None]:

class DiscriminatorNetwork(nn.Module):
    """Convolutional discriminator scoring images with a value in [0, 1].

    Four stride-2 convolutions are followed by one stride-1 convolution and a
    ``Sigmoid``. For 64x64 inputs the spatial size shrinks
    64 -> 32 -> 16 -> 8 -> 4 -> 1, so the output has shape ``(N, 1, 1, 1)``.

    Args:
        in_channels: Channel count of the input images. Defaults to the
            module-level ``num_channels``.
        feature_size: Base number of feature maps. Defaults to the
            module-level ``disc_feature_size``.
    """

    def __init__(self, in_channels=None, feature_size=None):
        super().__init__()

        # Fall back to the notebook-level hyper-parameters so the existing
        # no-argument construction keeps building the same network.
        in_channels = num_channels if in_channels is None else in_channels
        feature_size = disc_feature_size if feature_size is None else feature_size

        self.model = nn.Sequential(
            # 64x64 -> 32x32 (no BatchNorm on the first layer)
            nn.Conv2d(in_channels, feature_size, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 32x32 -> 16x16
            nn.Conv2d(feature_size, feature_size * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # 16x16 -> 8x8
            nn.Conv2d(feature_size * 2, feature_size * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # 8x8 -> 4x4
            nn.Conv2d(feature_size * 4, feature_size * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_size * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # 4x4 -> 1x1 score, squashed to [0, 1]
            nn.Conv2d(feature_size * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input_data):
        """Run ``input_data`` through the layer stack and return the scores."""
        return self.model(input_data)

# Build the discriminator, move it to the selected device, then re-initialise
# every layer with the custom weight initialiser and log the architecture.
net_discriminator = DiscriminatorNetwork()
net_discriminator = net_discriminator.to(device_config)
net_discriminator.apply(initialize_weights)
logger.info(net_discriminator)

In [None]:
# Draw 64 standard-normal latent vectors and run them through the generator.
latent_shape = (64, z_dim, 1, 1)
latent_vector = torch.randn(*latent_shape, device=device_config)
sample_output = net_generator(latent_vector)
logger.info(sample_output.shape)

# Bare expression: shows the first generated tensor as the cell output.
sample_output[0]


## Randomly generated output can be visualised individually or in batches

In [None]:
# Preview only the first generated image: normalise it into a grid, bring it
# to the CPU and move the channel axis last for imshow.
first_image_grid = make_grid(sample_output[0].to(device_config), padding=2, normalize=True)
first_image_hwc = np.transpose(first_image_grid.cpu(), (1, 2, 0))
plt.imshow(first_image_hwc)

In [None]:
# Preview the whole generated batch as one normalised grid, with the channel
# axis moved last for imshow.
batch_image_grid = make_grid(sample_output.to(device_config), padding=2, normalize=True)
batch_image_hwc = np.transpose(batch_image_grid.cpu(), (1, 2, 0))
plt.imshow(batch_image_hwc)

In [None]:
# Score the generated samples with the discriminator in eval mode, without
# tracking gradients.
was_training = net_discriminator.training
net_discriminator.eval()

with torch.no_grad():
    prediction = net_discriminator(sample_output)

# Restore the previous mode. Leaving the network in eval mode would make the
# later training run use BatchNorm running statistics instead of batch
# statistics.
net_discriminator.train(was_training)

logger.info(prediction[:5])

## Train models

In [None]:

# Binary cross-entropy between discriminator outputs and the target labels.
loss_function = nn.BCELoss()

# Target values used for real and generated images respectively.
real_label_val = 1.
fake_label_val = 0.

# Each network gets its own Adam optimizer with the same learning rate and betas.
adam_betas = (0.5, 0.999)
optimizer_discriminator = optim.Adam(
    net_discriminator.parameters(), lr=learning_rate, betas=adam_betas
)
optimizer_generator = optim.Adam(
    net_generator.parameters(), lr=learning_rate, betas=adam_betas
)

In [None]:
def train(epoch_count, data_loader, net_discriminator, net_generator, device_config, batch_quantity, real_label_val,
          fake_label_val, optimizer_discriminator, optimizer_generator, z_dim):
    """Alternately train the discriminator and the generator, then save both.

    Each step updates the discriminator on one real batch and one generated
    batch, then updates the generator so that the discriminator scores its
    output as real. Every 200 steps, and on the very last step, the losses
    are logged and a grid of freshly generated images is displayed.

    Relies on the module-level ``loss_function``, ``logger`` and
    ``display_image_grid``.

    Args:
        epoch_count: Number of passes over ``data_loader``.
        data_loader: Iterable of batches whose element 0 is the image tensor.
        net_discriminator: Discriminator network (updated in place).
        net_generator: Generator network (updated in place).
        device_config: Device the batches, labels and noise are placed on.
        batch_quantity: Nominal batch size. Kept for interface compatibility;
            labels and noise are sized from the batch actually received, so a
            trailing partial batch no longer breaks training.
        real_label_val: Target value for real images.
        fake_label_val: Target value for generated images.
        optimizer_discriminator: Optimizer stepping the discriminator.
        optimizer_generator: Optimizer stepping the generator.
        z_dim: Channel count of the latent noise fed to the generator.

    Returns:
        None. Both models are written to Google Drive at the end.
    """
    # Make sure both networks are in training mode: an earlier `.eval()` call
    # would otherwise make BatchNorm use its running statistics here.
    net_discriminator.train()
    net_generator.train()

    last_epoch = epoch_count - 1
    last_step = len(data_loader) - 1

    for epoch in tqdm(range(epoch_count)):
        for idx, dataset in enumerate(data_loader, 0):
            real_imgs = dataset[0].to(device_config)
            # Size everything from the real batch, not the nominal batch size.
            current_batch = real_imgs.size(0)

            # ---- Discriminator update: real batch ----
            net_discriminator.zero_grad()
            real_img_labels = torch.full((current_batch,), real_label_val, dtype=torch.float, device=device_config)

            real_output = net_discriminator(real_imgs).view(-1)
            disc_real_loss = loss_function(real_output, real_img_labels)
            disc_real_loss.backward()

            mean_real_score = real_output.mean().item()

            # ---- Discriminator update: generated batch ----
            noise_vec = torch.randn(current_batch, z_dim, 1, 1, device=device_config)

            generated_imgs = net_generator(noise_vec)
            fake_img_labels = torch.full((current_batch,), fake_label_val, dtype=torch.float, device=device_config)

            # detach(): this backward pass must not reach the generator.
            fake_output = net_discriminator(generated_imgs.detach()).view(-1)
            disc_fake_loss = loss_function(fake_output, fake_img_labels)
            disc_fake_loss.backward()

            mean_fake_score = fake_output.mean().item()

            # Gradients of the real and fake passes have accumulated; apply them.
            optimizer_discriminator.step()

            # ---- Generator update ----
            net_generator.zero_grad()

            # Real labels as targets: the generator is rewarded when the
            # discriminator scores its images as real.
            output = net_discriminator(generated_imgs).view(-1)
            gen_loss = loss_function(output, real_img_labels)
            gen_loss.backward()
            optimizer_generator.step()

            if (idx % 200 == 0) or ((epoch == last_epoch) and (idx == last_step)):
                logger.info(f"""Epoch {epoch + 1}: - Step: {idx} |
                          D_real Loss: {disc_real_loss:.3f} |
                          D_fake Loss: {disc_fake_loss:.3f} |
                          G_Loss: {gen_loss:.3f} |
                          Real_score {mean_real_score:.3f} |
                          Fake_score {mean_fake_score:.3f}
                """)

                with torch.no_grad():
                    # Re-run the just-updated generator on the same noise.
                    fake_imgs_display = net_generator(noise_vec).detach().cpu()
                    display_image_grid(fake_imgs_display)

    # NOTE: the whole module objects are saved (not their state_dict) because
    # the loading cell reads them back with torch.load(..., weights_only=False).
    full_discriminator_model_path = "/content/drive/MyDrive/IA/GAN/discriminator_model.pth"
    torch.save(net_discriminator, full_discriminator_model_path)
    logger.info(f"Final Discriminator Model saved : {full_discriminator_model_path}")

    full_generator_model_path = "/content/drive/MyDrive/IA/GAN/generator_model.pth"
    torch.save(net_generator, full_generator_model_path)
    logger.info(f"Final Generator model saved : {full_generator_model_path}")



# Launch training with the hyper-parameters, networks and optimizers defined above.
train(
    epoch_count=epoch_count,
    data_loader=data_loader,
    net_discriminator=net_discriminator,
    net_generator=net_generator,
    device_config=device_config,
    batch_quantity=batch_quantity,
    real_label_val=real_label_val,
    fake_label_val=fake_label_val,
    optimizer_discriminator=optimizer_discriminator,
    optimizer_generator=optimizer_generator,
    z_dim=z_dim,
)

In [None]:

# Reload the saved networks and switch them to eval mode.
# map_location places the tensors on the active device, so a checkpoint saved
# on a GPU also loads in a CPU-only session.
# NOTE: weights_only=False unpickles arbitrary Python objects; only load
# checkpoint files you created or trust.
net_generator_loaded = torch.load(
    "/content/drive/MyDrive/IA/GAN/generator_model.pth",
    map_location=device_config,
    weights_only=False,
)
logger.info(net_generator_loaded.eval())
logger.info("Generator model loaded.")


net_discriminator_loaded = torch.load(
    "/content/drive/MyDrive/IA/GAN/discriminator_model.pth",
    map_location=device_config,
    weights_only=False,
)
logger.info(net_discriminator_loaded.eval())
logger.info("Discriminator model loaded.")


In [None]:
# Draw 64 latent vectors and generate images with the reloaded generator.
noise_vec = torch.randn(64, z_dim, 1, 1, device=device_config)

raw_generated = net_generator_loaded(noise_vec)
# Fix the shape to 64 RGB images of 64x64 and drop the autograd history.
new_generatedimgs = raw_generated.reshape(64, 3, 64, 64).detach()

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Generated images")

display_image_grid(new_generatedimgs)

In [None]:

# Score the generated images with the reloaded discriminator. No gradients
# are needed for these summary statistics.
with torch.no_grad():
    prediction = net_discriminator_loaded(new_generatedimgs)

mean = prediction.mean().item()
std = prediction.std().item()
p_max = prediction.max().item()
# argmax without a dim works on the flattened tensor, giving the image index.
idx_max = prediction.argmax().item()

# `.5f` / `.2f` set the precision; the previous `5f` / `02f` only set a
# minimum field width.
logger.info(f"Mean proba = {mean:.5f}")
logger.info(f"Std = {std:.5f}")

logger.info(f"Max proba = {p_max:.2f}")
logger.info(f"iMax proba (image id/num) = {idx_max}")


In [None]:
# Bare expression: shows the mean discriminator output as the cell result.
prediction.mean().item()