<a href="https://colab.research.google.com/github/karthikeya-io/image-cartoonization/blob/main/cartoonization_cyclegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from PIL import Image

from torchvision.datasets import ImageFolder


In [None]:
# Mount Google Drive. The dataset folders, sample images and checkpoints used
# in the cells below are all read from / written to paths under /content/drive.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !ls "/content/drive/MyDrive/datasets/DBZ/vegeta"

ls: cannot access '/content/drive/MyDrive/datasets/DBZ/vegeta': No such file or directory


In [None]:
def load_dataset(dataset_path, image_size, batch_size, num_workers=4):
    """Build a shuffled DataLoader over an ImageFolder of training images.

    Each image is:
    - resized so its shorter side is ``int(image_size * 1.12)`` (bicubic);
    - randomly cropped to ``image_size`` x ``image_size``;
    - randomly flipped horizontally;
    - normalized from [0, 1] to [-1, 1], the same range as the generator's tanh output.

    Args:
        dataset_path: root directory in the layout expected by
            ``torchvision.datasets.ImageFolder`` (images inside class sub-folders).
        image_size: side length of the square training crops.
        batch_size: number of images per batch.
        num_workers: number of DataLoader worker processes. Defaults to 4 (the
            previous hard-coded value); lower it on machines with fewer CPU cores.

    Returns:
        A ``DataLoader`` yielding ``[images, class_indices]`` batches.
    """
    transform = transforms.Compose([
        # InterpolationMode is torchvision's supported way to choose the filter;
        # passing PIL's integer constant (Image.BICUBIC) is deprecated.
        transforms.Resize(int(image_size * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = ImageFolder(dataset_path, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return dataloader


# The training loop pairs celeba_dataloader (domain A) with anf_dataloader (domain B).
anf_dataloader = load_dataset("/content/drive/MyDrive/datasets/animeface", image_size=256, batch_size=1)
celeba_dataloader = load_dataset("/content/drive/MyDrive/datasets/celeba", image_size=256, batch_size=1)


In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    ``F`` is two (reflection-pad 1 -> 3x3 conv -> InstanceNorm) stages with a
    ReLU between them. Padding 1 with a 3x3 kernel preserves H and W, and the
    channel count never changes, so the skip connection is a plain addition.
    """

    def __init__(self, in_channels):
        super().__init__()

        def pad_conv_norm():
            # One shape-preserving conv stage.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, in_channels, kernel_size=3),
                nn.InstanceNorm2d(in_channels),
            ]

        # Layer order (and therefore state_dict keys) matches: pad, conv, norm, relu, pad, conv, norm.
        self.conv_block = nn.Sequential(*pad_conv_norm(), nn.ReLU(inplace=True), *pad_conv_norm())

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Pipeline:
    - 7x7 stem conv;
    - two stride-2 downsampling convs (each halves H/W and doubles channels);
    - ``num_residual_blocks`` ResidualBlocks;
    - two stride-2 transposed convs (each doubles H/W and halves channels);
    - 7x7 output conv followed by tanh, so outputs lie in [-1, 1].

    When H and W are divisible by 4, the output has the same spatial size as the input.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        num_residual_blocks: number of residual blocks at the bottleneck.
        base_channels: width of the stem and output convs; the bottleneck has
            ``4 * base_channels`` channels. Defaults to 64 (the original width).
    """

    def __init__(self, in_channels, out_channels, num_residual_blocks=9, base_channels=64):
        super(Generator, self).__init__()

        # Initial convolution block (reflection padding 3 keeps H/W for the 7x7 kernel)
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, base_channels, kernel_size=7),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(inplace=True),
        ]

        # Downsampling
        in_features = base_channels
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling (output_padding=1 makes each stage exactly double H/W)
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer. The upsampling path mirrors the downsampling path, so
        # in_features is back to base_channels here. Use the tracked value rather
        # than a hard-coded width that could drift out of sync.
        model += [nn.ReflectionPad2d(3), nn.Conv2d(in_features, out_channels, kernel_size=7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN discriminator producing a map of real/fake scores.

    Layers:
    - four 4x4 convolutions (64, 128, 256, 512 channels; strides 2, 2, 2, 1),
      each followed by LeakyReLU(0.2);
    - InstanceNorm on all of those convolutions except the first;
    - a final 4x4 conv to one channel with no activation.

    Each output element has a 70x70-pixel receptive field in the input.
    """

    def __init__(self, in_channels):
        super().__init__()

        # (out_channels, stride, use_instance_norm) for each conv + LeakyReLU stage.
        stages = [
            (64, 2, False),
            (128, 2, True),
            (256, 2, True),
            (512, 1, True),
        ]

        layers = []
        channels = in_channels
        for width, stride, use_norm in stages:
            layers.append(nn.Conv2d(channels, width, kernel_size=4, stride=stride, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            channels = width

        # One raw score per patch (no sigmoid: the loss is applied to raw outputs).
        layers.append(nn.Conv2d(channels, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
# Adversarial loss: mean squared error against all-ones / all-zeros targets.
# Cycle-consistency and identity losses: mean absolute (L1) error.
criterion_GAN = nn.MSELoss()
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

In [None]:
# Initialize the generators and discriminators.
# Domain A = celeba_dataloader images, domain B = anf_dataloader images (see the zip() below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G_A2B = Generator(3, 3).to(device)
G_B2A = Generator(3, 3).to(device)
D_A = Discriminator(3).to(device)
D_B = Discriminator(3).to(device)

# Initialize the optimizers
optimizer_G = optim.Adam(list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Generator loss weights, following the CycleGAN paper (lambda = 10 on the cycle
# term) and its reference implementation (identity weighted by 0.5 * lambda).
# This notebook previously weighted every term by 1; set both to 1.0 to reproduce that.
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE

OUTPUT_ROOT = "/content/drive/MyDrive/cartooncyclegan"
IMAGE_DIR = os.path.join(OUTPUT_ROOT, "output_images")
CHECKPOINT_DIR = os.path.join(OUTPUT_ROOT, "checkpoints")
os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _gan_loss(prediction, target_is_real):
    """criterion_GAN of a discriminator output against an all-ones (real) or all-zeros (fake) target.

    The target is derived from the already-computed prediction, so each
    discriminator runs once per loss term. It lands on the same device automatically.
    """
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return criterion_GAN(prediction, make_target(prediction))


def _save_epoch_outputs(epoch, images):
    """Save sample images and all four model checkpoints for ``epoch``.

    ``images`` maps a file-name prefix to a tensor in the generators' [-1, 1]
    range. Each tensor is mapped to [0, 1] with a fixed affine transform rather
    than per-image min-max normalization, so saved colors are not stretched.
    """
    for name, image in images.items():
        save_image(image.detach() * 0.5 + 0.5, os.path.join(IMAGE_DIR, f"{name}_{epoch}.png"))
    models = {"G_A2B": G_A2B, "G_B2A": G_B2A, "D_A": D_A, "D_B": D_B}
    for name, model in models.items():
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f"{name}_{epoch}.pth"))


# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    # zip() stops at the shorter loader, so an epoch is min(len(A), len(B)) steps.
    # Each batch is [images, class_indices]; the ImageFolder labels are unused.
    for i, ((real_A, _), (real_B, _)) in enumerate(zip(celeba_dataloader, anf_dataloader)):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # ---- Update generators ----
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain should return it unchanged.
        loss_identity_A = criterion_identity(G_B2A(real_A), real_A)
        loss_identity_B = criterion_identity(G_A2B(real_B), real_B)

        # GAN loss: the generators try to make the discriminators score their fakes as real.
        fake_A = G_B2A(real_B)
        fake_B = G_A2B(real_A)
        loss_GAN_A2B = _gan_loss(D_B(fake_B), target_is_real=True)
        loss_GAN_B2A = _gan_loss(D_A(fake_A), target_is_real=True)

        # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the input.
        recov_A = G_B2A(fake_B)
        recov_B = G_A2B(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        # Total generator loss
        loss_G = (
            LAMBDA_IDENTITY * (loss_identity_A + loss_identity_B)
            + loss_GAN_A2B + loss_GAN_B2A
            + LAMBDA_CYCLE * (loss_cycle_A + loss_cycle_B)
        )
        loss_G.backward()
        optimizer_G.step()

        # ---- Update discriminators ----
        # Zeroing here also discards the gradients loss_G.backward() left on D_A / D_B.
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        # Fakes are detached so the discriminator losses do not backpropagate into the generators.
        loss_D_A = (_gan_loss(D_A(real_A), True) + _gan_loss(D_A(fake_A.detach()), False)) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        loss_D_B = (_gan_loss(D_B(real_B), True) + _gan_loss(D_B(fake_B.detach()), False)) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        # Print progress
        if i % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {i}: "
                f"Loss D_A: {loss_D_A.item():.4f}, "
                f"Loss D_B: {loss_D_B.item():.4f}, "
                f"Loss G: {loss_G.item():.4f}"
            )

    # Save the last batch's translations/reconstructions and the model checkpoints.
    _save_epoch_outputs(
        epoch,
        {"fake_A": fake_A, "fake_B": fake_B, "recov_A": recov_A, "recov_B": recov_B},
    )


Epoch [0/100] Batch 0: Loss D_A: 0.7019, Loss D_B: 0.4892, Loss G: 3.9730
Epoch [0/100] Batch 100: Loss D_A: 0.2489, Loss D_B: 0.2489, Loss G: 1.2295
Epoch [0/100] Batch 200: Loss D_A: 0.2835, Loss D_B: 0.2680, Loss G: 1.3739
Epoch [0/100] Batch 300: Loss D_A: 0.2602, Loss D_B: 0.2333, Loss G: 1.4521
Epoch [0/100] Batch 400: Loss D_A: 0.2938, Loss D_B: 0.2708, Loss G: 1.4980
Epoch [0/100] Batch 500: Loss D_A: 0.3302, Loss D_B: 0.2512, Loss G: 2.0757
Epoch [0/100] Batch 600: Loss D_A: 0.1983, Loss D_B: 0.2291, Loss G: 1.7968
Epoch [0/100] Batch 700: Loss D_A: 0.3292, Loss D_B: 0.2948, Loss G: 1.3744
Epoch [0/100] Batch 800: Loss D_A: 0.2933, Loss D_B: 0.2770, Loss G: 1.2973
Epoch [0/100] Batch 900: Loss D_A: 0.1912, Loss D_B: 0.2483, Loss G: 1.5286
Epoch [1/100] Batch 0: Loss D_A: 0.2780, Loss D_B: 0.2394, Loss G: 1.2594
Epoch [1/100] Batch 100: Loss D_A: 0.2625, Loss D_B: 0.2206, Loss G: 1.4974
Epoch [1/100] Batch 200: Loss D_A: 0.1450, Loss D_B: 0.1930, Loss G: 1.2483
Epoch [1/100] Ba