[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/khetansarvesh/CV/blob/main/style_transfer/cycle_gan.ipynb)

In [1]:
from PIL import Image
import os
import numpy as np
import random
import copy
import sys
from tqdm import tqdm

import kagglehub
import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision.utils import save_image

  check_for_updates()


In [2]:
os.environ["PYTHONHASHSEED"] = str(42)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

cuda


In [3]:
# creating a folder to store images while training
!mkdir saved_images

# **Dataset**


In [4]:
# downloading the dataset from kaggle : https://www.kaggle.com/datasets/suyashdamle/cyclegan
path = kagglehub.dataset_download("suyashdamle/cyclegan")
print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/suyashdamle/cyclegan/versions/1


In [5]:
class HorseZebraDataset(Dataset):
    def __init__(self):
        self.base = "/root/.cache/kagglehub/datasets/suyashdamle/cyclegan/versions/1/horse2zebra/horse2zebra/"
        self.horse_images = os.listdir(self.base + "trainA")
        self.zebra_images = os.listdir(self.base + "trainB")
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)

    def __len__(self):
        return max(self.zebra_len, self.horse_len)

    def __getitem__(self, index):

        # finding image index
        zebra_img = self.zebra_images[index % self.zebra_len]
        horse_img = self.horse_images[index % self.horse_len]

        # finding image path
        horse_path = os.path.join(self.base + "trainA", horse_img)
        zebra_path = os.path.join(self.base + "trainB", zebra_img)

        # opening image and storing in array
        zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
        horse_img = np.array(Image.open(horse_path).convert("RGB"))

        # performing transformations on the images zebra and horses
        transform = A.Compose(
                                [
                                    A.Resize(width=256, height=256),
                                    A.HorizontalFlip(p=0.5),
                                    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
                                    ToTensorV2(),
                                ],
                                additional_targets={"image0": "image"},
                            )
        augmentations = transform(image=zebra_img, image0=horse_img)

        return augmentations["image"], augmentations["image0"]

In [6]:
dataset = HorseZebraDataset()
loader = DataLoader( dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

# val_dataset = HorseZebraDataset(root_horse="cyclegan_test/horse1", root_zebra="cyclegan_test/zebra1", transform=transforms)
# val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)



# **Modelling**

In [7]:
class Generator(nn.Module):
    def __init__(self, img_channels=3, num_features=64, num_residuals=9):
        super().__init__()
        self.model = nn.Sequential(
                                        nn.Conv2d(3, 64, 7, 1, 3, padding_mode="reflect"), nn.InstanceNorm2d(64), nn.ReLU(inplace=True),
                                        nn.Conv2d(64, 128, 3, 2, 1, padding_mode="reflect"), nn.InstanceNorm2d(128), nn.ReLU(inplace=True),
                                        nn.Conv2d(128, 256, 3, 2, 1, padding_mode="reflect"), nn.InstanceNorm2d(256), nn.ReLU(inplace=True),
                                        nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), nn.InstanceNorm2d(128),nn.ReLU(inplace=True),
                                        nn.ConvTranspose2d(128, 64, 3, 2, 1, 1), nn.InstanceNorm2d(64),nn.ReLU(inplace=True),
                                        nn.Conv2d(64, 3, 7, 1, 3, padding_mode="reflect")
                        )

    def forward(self, x):
        return torch.tanh(self.model(x))

In [8]:
class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
                                    nn.Conv2d(3, 64, 4, 2, 1, padding_mode="reflect"), nn.LeakyReLU(0.2, inplace=True),
                                    nn.Conv2d(64, 128, 4, 2, 1, bias = True, padding_mode="reflect"), nn.InstanceNorm2d(128), nn.LeakyReLU(0.2, inplace=True),
                                    nn.Conv2d(128, 256, 4, 2, 1, bias = True, padding_mode="reflect"), nn.InstanceNorm2d(256), nn.LeakyReLU(0.2, inplace=True),
                                    nn.Conv2d(256, 512, 4, 1, 1, bias = True, padding_mode="reflect"), nn.InstanceNorm2d(512), nn.LeakyReLU(0.2, inplace=True),
                                    nn.Conv2d(512, 1, 4, 1, 1, padding_mode="reflect")
                                  )

    def forward(self, x):
        return torch.sigmoid(self.model(x))

# **Training**

In [9]:
# initialing models
disc_H = Discriminator().to(DEVICE)
disc_Z = Discriminator().to(DEVICE)
gen_Z = Generator().to(DEVICE)
gen_H = Generator().to(DEVICE)

In [10]:
# defining optimizers
opt_disc = optim.Adam(list(disc_H.parameters()) + list(disc_Z.parameters()),lr=1e-5,betas=(0.5, 0.999),)
opt_gen = optim.Adam(list(gen_Z.parameters()) + list(gen_H.parameters()), lr=1e-5, betas=(0.5, 0.999),)

# defining loss functions
l1 = nn.L1Loss()
mse = nn.MSELoss()

# scalers
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()


In [11]:
for epoch in range(10):
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):

        # sending inputs to device at hand
        zebra, horse = zebra.to(DEVICE), horse.to(DEVICE)








        '''Train Discriminators H and Z while keeping Generator Constant'''
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            D_H_fake = disc_H(fake_horse.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()










        '''Train Generators H and Z while keeping Discriminator Constant'''
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            # add all togethor
            G_loss = loss_G_Z + loss_G_H + 10*cycle_zebra_loss + 10*cycle_horse_loss

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()









        if idx % 200 == 0:
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")



        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 1334/1334 [02:15<00:00,  9.88it/s, H_fake=0.406, H_real=0.593]
100%|██████████| 1334/1334 [02:09<00:00, 10.27it/s, H_fake=0.39, H_real=0.607]
100%|██████████| 1334/1334 [02:10<00:00, 10.25it/s, H_fake=0.368, H_real=0.625]
100%|██████████| 1334/1334 [02:09<00:00, 10.27it/s, H_fake=0.357, H_real=0.636]
100%|██████████| 1334/1334 [02:10<00:00, 10.25it/s, H_fake=0.351, H_real=0.643]
100%|██████████| 1334/1334 [02:09<00:00, 10.26it/s, H_fake=0.337, H_real=0.66]
100%|██████████| 1334/1334 [02:09<00:00, 10.33it/s, H_fake=0.327, H_real=0.671]
100%|██████████| 1334/1334 [02:09<00:00, 10.32it/s, H_fake=0.318, H_real=0.679]
100%|██████████| 1334/1334 [02:09<00:00, 10.34it/s, H_fake=0.315, H_real=0.686]
100%|██████████| 1334/1334 [02:09<00:00, 10.30it/s, H_fake=0.307, H_real=0.689]
