<a href="https://colab.research.google.com/github/mdrk300902/demo-repo/blob/main/CycleGAN_Translating_(Horses__Zebras).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: install necessary libraries if needed
!pip install torch torchvision tqdm plotly pillow



In [None]:
import os
import numpy as np
import pandas as pd
import math, sys, random
import time, datetime
from zipfile import ZipFile
import shutil
from PIL import Image
import glob, itertools
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import plotly.graph_objects as go
import warnings
warnings.filterwarnings("ignore")


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


## Introduction
### In this notebook we use [CycleGAN](https://arxiv.org/abs/1703.10593) to convert Horse Images to Zebra ones using [Horse2zebra Dataset](https://www.kaggle.com/balraj98/horse2zebra-dataset).

<img src="https://junyanz.github.io/CycleGAN/images/teaser.jpg" width="1000" height="900"/>
<h4></h4>
<h4><center>Image Source:  <a href="https://arxiv.org/pdf/1703.10593.pdf">Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [Jun-Yan Zhu et al.]</a></center></h4>

### Upload Data 📤

In [None]:
from google.colab import files
# Upload the dataset archive 'horse2zebra.zip' (required; the next cell extracts it) and,
# optionally, the pretrained-weights archive 'cyclegan-translating-horses-zebras-pytorch.zip'.
uploaded = files.upload()

Saving horse2zebra.zip to horse2zebra.zip


### Extract Archives

In [None]:
# Extract the uploaded archives into /content.
# The dataset archive is required (a missing file raises FileNotFoundError). The
# pretrained-weights archive is optional and only matters when resuming from checkpoints
# (see `pretrained_model_path` in the settings cell).
dataset_zip = "/content/horse2zebra.zip"
model_zip = "/content/cyclegan-translating-horses-zebras-pytorch.zip"

with ZipFile(dataset_zip, "r") as zip_ref:
    zip_ref.extractall("/content")

if os.path.isfile(model_zip):
    with ZipFile(model_zip, "r") as zip_ref:
        zip_ref.extractall("/content")
else:
    print(f"Optional archive not found, skipping: {model_zip}")
print("Files extracted successfully!")

Files extracted successfully!


### Settings & Hyperparameters

In [None]:
# Paths and parameters
# Directory searched for G_AB.pth / G_BA.pth / D_A.pth / D_B.pth when epoch_start != 0.
pretrained_model_path = "/content/cyclegan-translating-horses-zebras-pytorch/saved_models"
# Dataset root; ImageDataset reads <dataset_path>/{train,test}A and <dataset_path>/{train,test}B.
dataset_path = "/content/horse2zebra"
# Epoch index to start from. NOTE(review): a value > 0 only makes sense when the checkpoints
# above exist; otherwise training starts at this epoch index with freshly initialised weights.
epoch_start = 25
n_epochs = 60  # the training loop runs epochs epoch_start .. n_epochs - 1
batch_size = 4
lr = 0.0001
b1, b2 = 0.5, 0.999  # Adam betas
decay_epoch = 1  # NOTE(review): not referenced by the training cell, so no LR decay is applied
n_workers = 2  # Use small value in Colab
img_height = 256
img_width = 256
channels = 3
sample_interval = 100  # NOTE(review): not referenced by the training cell
checkpoint_interval = -1  # NOTE(review): -1 conventionally disables periodic checkpoints; confirm against the training cell
n_residual_blocks = 9  # ResidualBlocks in each generator's bottleneck
lambda_cyc = 10.0  # weight of the cycle-consistency term in the generator loss
lambda_id = 5.0  # weight of the identity term in the generator loss
debug_mode = False  # when True, ImageDataset keeps only the first 100 files per domain
os.makedirs("images", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

# Seed every RNG the notebook uses (Python `random` for the unaligned pairing and replay
# buffer, torch for weight init / augmentation / shuffling) so runs are repeatable.
# cuDNN kernels may still introduce small run-to-run differences on GPU.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators


### Define Utilities (RGB Conversion, Replay Buffer)

In [None]:
def to_rgb(image):
    """Return a new RGB image of the same size with `image` pasted onto it.

    Used by ImageDataset for inputs whose PIL mode is not already "RGB".
    """
    canvas = Image.new("RGB", image.size)
    canvas.paste(image)
    return canvas

class ReplayBuffer:
    """Pool of previously generated images that is sampled when training a discriminator.

    While the pool holds fewer than ``max_size`` images, every incoming image is stored
    and returned as-is. Once full, each incoming image is, with probability ~1/2, swapped
    into a random slot (the image previously in that slot is returned). Otherwise it is
    returned directly without being stored.
    """

    def __init__(self, max_size=50):
        assert max_size > 0
        self.max_size = max_size
        self.data = []  # stored images, each with a leading batch dimension of 1

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same size.

        ``data.data`` is iterated, so stored images carry no autograd history.
        """
        outgoing = []
        for sample in data.data:
            sample = sample.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Still filling the pool: keep the image and pass it through.
                self.data.append(sample)
                outgoing.append(sample)
                continue
            if random.uniform(0, 1) > 0.5:
                # Return an older image and replace it with the new one.
                slot = random.randint(0, self.max_size - 1)
                outgoing.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                outgoing.append(sample)
        return Variable(torch.cat(outgoing))


### Define Dataset Class (`ImageDataset`)

In [None]:
class ImageDataset(Dataset):
    """Two-domain (A/B) image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Each folder contributes the sorted paths of its entries whose names contain a dot.
    Item ``index`` pairs the A file at ``index % len(files_A)`` with either a uniformly
    random B file (``unaligned=True``) or the B file at ``index % len(files_B)``.
    Non-RGB images are converted with ``to_rgb`` before the composed ``transforms_`` run.
    The dataset length is the larger of the two file counts. When the global
    ``debug_mode`` flag is set, each domain is truncated to its first 100 files.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.files_A = self._list_files(root, f"{mode}A")
        self.files_B = self._list_files(root, f"{mode}B")
        if debug_mode:
            self.files_A = self.files_A[:100]
            self.files_B = self.files_B[:100]

    @staticmethod
    def _list_files(root, subdir):
        """Return the sorted glob matches of ``root/subdir/*.*``."""
        return sorted(glob.glob(os.path.join(root, subdir) + "/*.*"))

    def _load(self, path):
        """Open ``path``, force RGB mode, and apply the configured transform."""
        picture = Image.open(path)
        if picture.mode != "RGB":
            picture = to_rgb(picture)
        return self.transform(picture)

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]
        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]
        item_A = self._load(path_A)
        item_B = self._load(path_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))


<h3><center>CycleGAN Model Architecture</center></h3>
<img src="https://miro.medium.com/max/700/1*_KxtJIVtZjVaxxl-Yl1vJg.png" width="900" height="900"/>
<h4></h4>
<h4><center>Image Source:  <a href="https://arxiv.org/pdf/1703.10593.pdf">Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [Jun-Yan Zhu et al.]</a></center></h4>

### Build Train/Test Dataloaders

In [None]:
# Training augmentation: resize the shorter edge to 112% of the target, random-crop back to
# (img_height, img_width), random horizontal flip, then map pixels from [0, 1] to [-1, 1].
transforms_ = [
    transforms.Resize(int(img_height * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
# Evaluation pipeline is deterministic (no random crop/flip) so test losses and sample
# grids are comparable from one epoch to the next.
test_transforms_ = [
    transforms.Resize((img_height, img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
train_dataloader = DataLoader(
    ImageDataset(dataset_path, transforms_=transforms_, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_workers,
)
# Fixed order and fixed A/B pairing for evaluation.
test_dataloader = DataLoader(
    ImageDataset(dataset_path, transforms_=test_transforms_, unaligned=False, mode="test"),
    batch_size=1,
    shuffle=False,
    num_workers=1,
)


### Define Model Classes (Generator, Discriminator), then Train CycleGAN

In [None]:
def weights_init_normal(m):
    """Initialise one module in place; meant to be used as ``net.apply(weights_init_normal)``.

    * Class name contains "Conv": weight ~ N(0, 0.02); bias (when present) set to 0.
    * Otherwise, class name contains "BatchNorm2d": weight ~ N(1, 0.02); bias set to 0.
    * Any other module is left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

class ResidualBlock(nn.Module):
    """Identity skip around two reflection-padded 3x3 conv + InstanceNorm layers (ReLU between).

    Output has the same shape as the input (N, in_features, H, W).
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_norm():
            # Reflection padding of 1 keeps H and W unchanged through the 3x3 convolution.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        # Keep this exact layer order: the Sequential indices define the state_dict keys
        # (block.1.*, block.5.*) that saved checkpoints rely on.
        self.block = nn.Sequential(*conv_norm(), nn.ReLU(inplace=True), *conv_norm())

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class GeneratorResNet(nn.Module):
    """ResNet-based image-to-image generator.

    Layout:
    * a 7x7 conv to 64 features;
    * two stride-2 3x3 convs (128, 256);
    * ``num_residual_blocks`` ResidualBlocks;
    * two (nearest-neighbour x2 upsample + 3x3 conv) stages (128, 64);
    * a 7x7 conv back to ``channels``, then tanh.

    Every conv except the last is followed by InstanceNorm2d + ReLU. The spatial size is
    preserved when H and W are divisible by 4.

    Args:
        input_shape: (channels, height, width); only ``channels`` is used here.
        num_residual_blocks: number of ResidualBlocks in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]
        # Pad by 3 (= 7 // 2) so each 7x7 convolution keeps H and W. Padding by `channels`
        # only coincides with this for 3-channel input. ReflectionPad2d has no parameters,
        # so state_dict keys are unaffected.
        pad_7x7 = 7 // 2
        out_features = 64
        model = [
            nn.ReflectionPad2d(pad_7x7),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
        # Downsampling: each stride-2 conv halves H and W and doubles the feature count.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]
        # Upsampling: x2 nearest-neighbour resize, then a size-preserving 3x3 conv.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
        # tanh keeps outputs in [-1, 1].
        model += [nn.ReflectionPad2d(pad_7x7), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of real/fake scores.

    Four stride-2 4x4 conv stages (64, 128, 256, 512 filters; InstanceNorm on all but the
    first; LeakyReLU(0.2) after each) are followed by a left/top zero-pad and a 4x4 conv
    down to one output channel.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape
        # Four halvings of the resolution -> an H/16 x W/16 score map.
        self.output_shape = (1, height // 16, width // 16)

        # (in_filters, out_filters, apply InstanceNorm) for each downsampling stage.
        stages = (
            (channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        )
        layers = []
        for in_filters, out_filters, normalize in stages:
            layers.append(nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1))
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # The asymmetric pad lets the final 4x4 conv (padding=1) keep the H/16 x W/16 size.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map; shape is (batch, *self.output_shape) when H, W are multiples of 16."""
        return self.model(img)


In [None]:
# Losses, models, optimizers
# Adversarial loss is MSE against all-ones / all-zeros targets; cycle and identity losses are L1.
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_cycle = torch.nn.L1Loss().to(device)
criterion_identity = torch.nn.L1Loss().to(device)

input_shape = (channels, img_height, img_width)

# G_AB translates domain A -> B and G_BA translates B -> A;
# D_A and D_B score images of domain A and domain B respectively.
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

import os

# Always initialize weights first, so every network has a defined starting point even
# when no checkpoint is loaded or loading fails part-way.
def _reset_weights():
    """Apply `weights_init_normal` to G_AB, G_BA, D_A and D_B (in that order)."""
    for net in (G_AB, G_BA, D_A, D_B):
        net.apply(weights_init_normal)


_reset_weights()

# Only try to load pretrained weights when resuming (epoch_start != 0) and every file exists.
if epoch_start != 0:
    networks = {"G_AB": G_AB, "G_BA": G_BA, "D_A": D_A, "D_B": D_B}
    checkpoint_paths = {name: f"{pretrained_model_path}/{name}.pth" for name in networks}
    missing = [path for path in checkpoint_paths.values() if not os.path.isfile(path)]
    if missing:
        print(f"Missing checkpoint file(s): {missing}")
        print(f"Training will start from scratch (fresh weights) at epoch index {epoch_start}.")
    else:
        try:
            # Read every file before touching the models, so a bad file cannot leave a
            # half-loaded set. map_location lets GPU-saved checkpoints load on a CPU-only
            # runtime; weights_only=True refuses to unpickle arbitrary objects.
            states = {
                name: torch.load(path, map_location=device, weights_only=True)
                for name, path in checkpoint_paths.items()
            }
            for name, state in states.items():
                networks[name].load_state_dict(state)
            print(f"Loaded pretrained weights from {pretrained_model_path}")
        except Exception as e:
            # load_state_dict may have updated some networks before failing; start all over.
            _reset_weights()
            print(f"Could not load weights: {e}\nTraining will start from scratch.")


# One Adam optimizer over both generators (they share a single loss) and one per discriminator.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# CUDA float tensor type when a GPU is available, otherwise the default CPU tensor type.
if torch.cuda.is_available():
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

# Pools of previously generated images used when training each discriminator.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Per-iteration training histories: x-positions plus one list per logged loss.
train_counter = []
train_losses_gen = []
train_losses_id = []
train_losses_gan = []
train_losses_cyc = []
train_losses_disc = []
train_losses_disc_a = []
train_losses_disc_b = []

# Per-epoch test x-positions: examples seen (A+B) at the end of each epoch, counted from epoch 0.
test_counter = [len(train_dataloader.dataset) * 2 * idx for idx in range(epoch_start + 1, n_epochs + 1)]
test_losses_gen = []
test_losses_disc = []

def _to_device(images):
    """Return a batch of images as float32 on `device`."""
    return images.to(device=device, dtype=torch.float32)


def _targets(n):
    """Return (valid, fake): all-ones / all-zeros targets of shape (n, *D_A.output_shape) on `device`."""
    valid = torch.ones((n, *D_A.output_shape), device=device)
    return valid, torch.zeros_like(valid)


def _generator_losses(real_A, real_B, valid):
    """Compute the combined generator objective for one batch.

    Returns (loss_G, loss_identity, loss_GAN, loss_cycle, fake_A, fake_B), where
    loss_G = lambda_id * loss_identity + loss_GAN + lambda_cyc * loss_cycle.
    """
    loss_identity = (criterion_identity(G_BA(real_A), real_A) + criterion_identity(G_AB(real_B), real_B)) / 2
    fake_B = G_AB(real_A)
    fake_A = G_BA(real_B)
    loss_GAN = (criterion_GAN(D_B(fake_B), valid) + criterion_GAN(D_A(fake_A), valid)) / 2
    loss_cycle = (criterion_cycle(G_BA(fake_B), real_A) + criterion_cycle(G_AB(fake_A), real_B)) / 2
    loss_G = lambda_id * loss_identity + loss_GAN + lambda_cyc * loss_cycle
    return loss_G, loss_identity, loss_GAN, loss_cycle, fake_A, fake_B


def _discriminator_loss(D, real, generated, valid, fake):
    """Mean of criterion_GAN on real images (target `valid`) and generated images (target `fake`)."""
    return (criterion_GAN(D(real), valid) + criterion_GAN(D(generated), fake)) / 2


def _evaluate(epoch):
    """Run all four networks on `test_dataloader` without updating them.

    Returns (mean generator loss, mean discriminator loss) over the test batches, or
    (nan, nan) if the loader yields nothing. The first batch is also written to
    images/epoch_<epoch>.png as a grid with rows real A, fake B, real B, fake A.
    """
    for net in (G_AB, G_BA, D_A, D_B):
        net.eval()
    gen_total = disc_total = 0.0
    n_batches = 0
    os.makedirs("images", exist_ok=True)
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_dataloader, desc=f'Testing Epoch {epoch} ', leave=False)):
            real_A = _to_device(batch["A"])
            real_B = _to_device(batch["B"])
            valid, fake = _targets(real_A.size(0))
            loss_G, _, _, _, fake_A, fake_B = _generator_losses(real_A, real_B, valid)
            loss_D_A = _discriminator_loss(D_A, real_A, fake_A, valid, fake)
            loss_D_B = _discriminator_loss(D_B, real_B, fake_B, valid, fake)
            gen_total += loss_G.item()
            disc_total += ((loss_D_A + loss_D_B) / 2).item()
            n_batches += 1
            if batch_idx == 0:
                # Inputs were normalised to [-1, 1] and the generators end in tanh; map back to [0, 1].
                grid = torch.cat((real_A, fake_B, real_B, fake_A)) * 0.5 + 0.5
                save_image(grid, os.path.join("images", f"epoch_{epoch:03d}.png"), nrow=real_A.size(0))
    if n_batches == 0:
        return float("nan"), float("nan")
    return gen_total / n_batches, disc_total / n_batches


def _save_checkpoint():
    """Write the four state dicts to saved_models/<name>.pth (the names the resume cell loads)."""
    os.makedirs("saved_models", exist_ok=True)
    for name, net in (("G_AB", G_AB), ("G_BA", G_BA), ("D_A", D_A), ("D_B", D_B)):
        torch.save(net.state_dict(), os.path.join("saved_models", f"{name}.pth"))


for epoch in range(epoch_start, n_epochs):
    # _evaluate() switches every network to eval mode; switch back before training.
    for net in (G_AB, G_BA, D_A, D_B):
        net.train()
    loss_gen = loss_id = loss_gan = loss_cyc = 0.0
    loss_disc = loss_disc_a = loss_disc_b = 0.0
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=len(train_dataloader))
    for batch_idx, batch in enumerate(tqdm_bar):
        real_A = _to_device(batch["A"])
        real_B = _to_device(batch["B"])
        valid, fake = _targets(real_A.size(0))

        # Train Generators (identity + adversarial + cycle-consistency terms).
        optimizer_G.zero_grad()
        loss_G, loss_identity, loss_GAN, loss_cycle, fake_A, fake_B = _generator_losses(real_A, real_B, valid)
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator-A on real A images vs. fakes drawn through the replay buffer.
        # zero_grad also discards the gradients loss_G.backward() left on D_A's parameters.
        optimizer_D_A.zero_grad()
        loss_D_A = _discriminator_loss(D_A, real_A, fake_A_buffer.push_and_pop(fake_A).detach(), valid, fake)
        loss_D_A.backward()
        optimizer_D_A.step()

        # Train Discriminator-B the same way.
        optimizer_D_B.zero_grad()
        loss_D_B = _discriminator_loss(D_B, real_B, fake_B_buffer.push_and_pop(fake_B).detach(), valid, fake)
        loss_D_B.backward()
        optimizer_D_B.step()
        loss_D = (loss_D_A + loss_D_B) / 2

        # Log Progress: running means for the progress bar, raw per-iteration values for the plots.
        loss_gen += loss_G.item(); loss_id += loss_identity.item(); loss_gan += loss_GAN.item(); loss_cyc += loss_cycle.item()
        loss_disc += loss_D.item(); loss_disc_a += loss_D_A.item(); loss_disc_b += loss_D_B.item()
        train_counter.append(2*(batch_idx*batch_size + real_A.size(0) + epoch*len(train_dataloader.dataset)))
        train_losses_gen.append(loss_G.item()); train_losses_id.append(loss_identity.item()); train_losses_gan.append(loss_GAN.item()); train_losses_cyc.append(loss_cycle.item())
        train_losses_disc.append(loss_D.item()); train_losses_disc_a.append(loss_D_A.item()); train_losses_disc_b.append(loss_D_B.item())
        tqdm_bar.set_postfix(Gen_loss=loss_gen/(batch_idx+1), identity=loss_id/(batch_idx+1), adv=loss_gan/(batch_idx+1), cycle=loss_cyc/(batch_idx+1),
                             Disc_loss=loss_disc/(batch_idx+1), disc_a=loss_disc_a/(batch_idx+1), disc_b=loss_disc_b/(batch_idx+1))

    # Evaluate on the test split; one value per epoch lines up with `test_counter`.
    test_loss_gen, test_loss_disc = _evaluate(epoch)
    test_losses_gen.append(test_loss_gen)
    test_losses_disc.append(test_loss_disc)

    # Periodic checkpoints when checkpoint_interval > 0, and always after the final epoch,
    # so a finished run can be resumed by pointing pretrained_model_path at saved_models/.
    is_last_epoch = epoch == n_epochs - 1
    if is_last_epoch or (checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0):
        _save_checkpoint()


No checkpoint files were found; training will start from scratch.


Training Epoch 25 :   0%|          | 0/334 [00:00<?, ?it/s]

Training Epoch 26 :   0%|          | 0/334 [00:00<?, ?it/s]

Training Epoch 27 :   0%|          | 0/334 [00:00<?, ?it/s]

Training Epoch 28 :   0%|          | 0/334 [00:00<?, ?it/s]

Training Epoch 29 :   0%|          | 0/334 [00:00<?, ?it/s]

Training Epoch 30 :   0%|          | 0/334 [00:00<?, ?it/s]

In [None]:
def _plot_losses(train_series, test_losses, test_name, title, yaxis_title):
    """Plot per-iteration training curves and per-epoch test markers on one figure.

    Args:
        train_series: sequence of (trace name, values) pairs, each aligned with `train_counter`.
        test_losses: per-epoch test values. They are plotted against the matching prefix of
            `test_counter`, so an interrupted run keeps each value at its own epoch.
        test_name: legend label for the test markers.
        title, yaxis_title: figure title and y-axis label.

    Returns the plotly Figure after showing it.
    """
    fig = go.Figure()
    for trace_name, values in train_series:
        fig.add_trace(go.Scatter(x=train_counter, y=values, mode='lines', name=trace_name))
    fig.add_trace(go.Scatter(x=test_counter[:len(test_losses)], y=test_losses, marker_symbol='star-diamond',
                             marker_color='orange', marker_line_width=1, marker_size=9, mode='markers',
                             name=test_name))
    fig.update_layout(width=1000, height=500, title=title,
                      xaxis_title="Number of training examples seen (A+B)", yaxis_title=yaxis_title)
    fig.show()
    return fig


_plot_losses(
    [
        ('Train Gen Loss (Loss_G)', train_losses_gen),
        ('Train Gen Identity Loss', train_losses_id),
        ('Train Gen GAN Loss', train_losses_gan),
        ('Train Gen Cyclic Loss', train_losses_cyc),
    ],
    test_losses_gen,
    'Test Gen Loss (Loss_G)',
    "Train vs. Test Generator Loss",
    "Generator Losses",
)

fig = _plot_losses(
    [
        ('Train Disc Loss (Loss_D)', train_losses_disc),
        ('Train Disc-A Loss', train_losses_disc_a),
        ('Train Disc-B Loss', train_losses_disc_b),
    ],
    test_losses_disc,
    'Test Disc Loss (Loss_D)',  # was mislabelled "(Loss_G)"
    "Train vs. Test Discriminator Loss",
    "Discriminator Losses",
)


In [None]:
from IPython.display import display
from PIL import Image
import glob

# Show up to five saved sample grids, in filename order.
sample_images = sorted(glob.glob("images/*.png"))[:5]
for sample_path in sample_images:
    display(Image.open(sample_path))
