<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image


%matplotlib inline

In [3]:
#https://www.kaggle.com/datasets/balraj98/horse2zebra-dataset
def rename(dir):
  path = f"/content/dataset/{dir}"
  i = 0
  for filename in os.listdir(path):
      f = os.path.join(path, filename)
      os.rename(f, os.path.join(path, f'{i}.jpg'))
      i += 1

rename("trainA")
rename("trainB")
rename("testA")
rename("testB")

In [6]:
class HorseZebraDataset(Dataset):
    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        self.zebra_images = os.listdir(root_zebra)
        self.horse_images = os.listdir(root_horse)
        self.length_dataset = max(len(self.zebra_images), len(self.horse_images)) # 1000, 1500
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        zebra_img = self.zebra_images[index % self.zebra_len]
        horse_img = self.horse_images[index % self.horse_len]

        zebra_path = os.path.join(self.root_zebra, zebra_img)
        horse_path = os.path.join(self.root_horse, horse_img)

        zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
        horse_img = np.array(Image.open(horse_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations["image"]
            horse_img = augmentations["image0"]

        return zebra_img, horse_img

transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [7]:
class Conv(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_reflect=True, res_block=False):
    super().__init__()
    self.res_block = res_block
    self.conv = nn.Sequential(
        nn.ReflectionPad2d(padding),
        nn.Conv2d(in_channels, out_channels, kernel_size, stride),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(),
    )

  def forward(self, x):
    if self.res_block:
      return self.conv(x) + x
    return self.conv(x)

class ConvTranspose(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_reflect=True):
    super().__init__()
    self.conv = nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding=padding),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(),
    )

  def forward(self, x):
    return self.conv(x)

class Generator(nn.Module):
  def __init__(self, in_channels, out_channels, num_res_block=6):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.encoder = self._make_encoder()
    self.decoder = self._make_decoder()
    self.bottleneck = self._make_bottleneck(num_res_block)

  def _make_encoder(self):
    net = nn.Sequential(
        Conv(self.in_channels, 64, 7, 1, 3),
        Conv(64, 128, 3, 2, 1),
        Conv(128, 256, 3, 2, 1)
    )
    return net

  def _make_bottleneck(self, num_blocks):
    ls = []
    for _ in range(num_blocks):
      ls.append(Conv(256, 256, 3, 1, 1, res_block=True))
    return nn.Sequential(*ls)


  def _make_decoder(self):
    net = nn.Sequential(
        ConvTranspose(256, 128, 3, 2, 1),
        ConvTranspose(128, 64, 3, 2, 1),
        Conv(64, 3, 7, 1, 3)
    )
    return net
  def forward(self, x):
    x = self.encoder(x)
    x = self.bottleneck(x)
    x = self.decoder(x)
    return x

In [8]:
class DConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, 2),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    return self.conv(x)

class Discriminator(nn.Module):
  def __init__(self, in_channels):
    super().__init__()
    self.net = nn.Sequential(
        DConv(in_channels, 64),
        DConv(64, 128),
        DConv(128, 256),
        DConv(256, 512),
        nn.Conv2d(512, 1, 4, 1),
        nn.Sigmoid()
    )

  def forward(self, x):
    return self.net(x)

In [9]:
def train_fn(
    disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, device):
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)
    LAMBDA_IDENTITY = 0.0
    LAMBDA_CYCLE = 10

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(device)
        horse = horse.to(device)

        # Train Discriminators H and Z
        fake_horse = gen_H(zebra)
        D_H_real = disc_H(horse)
        D_H_fake = disc_H(fake_horse.detach())
        H_reals += D_H_real.mean().item()
        H_fakes += D_H_fake.mean().item()
        D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
        D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
        D_H_loss = D_H_real_loss + D_H_fake_loss

        fake_zebra = gen_Z(horse)
        D_Z_real = disc_Z(zebra)
        D_Z_fake = disc_Z(fake_zebra.detach())
        D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
        D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
        D_Z_loss = D_Z_real_loss + D_Z_fake_loss

        # put it togethor
        D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        D_loss.backward()
        opt_disc.step()

        # Train Generators H and Z
        # adversarial loss for both generators
        D_H_fake = disc_H(fake_horse)
        D_Z_fake = disc_Z(fake_zebra)
        loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

        # cycle loss
        cycle_zebra = gen_Z(fake_horse)
        cycle_horse = gen_H(fake_zebra)
        cycle_zebra_loss = l1(zebra, cycle_zebra)
        cycle_horse_loss = l1(horse, cycle_horse)

        # identity loss (remove these for efficiency if you set lambda_identity=0)
        identity_zebra = gen_Z(zebra)
        identity_horse = gen_H(horse)
        identity_zebra_loss = l1(zebra, identity_zebra)
        identity_horse_loss = l1(horse, identity_horse)

        # add all togethor
        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_zebra_loss * LAMBDA_CYCLE
            + cycle_horse_loss * LAMBDA_CYCLE
            + identity_horse_loss * LAMBDA_IDENTITY
            + identity_zebra_loss * LAMBDA_IDENTITY
        )

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        if idx % 200 == 0:
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


In [10]:
def main(dirA, dirB, batch_size, lr, device):
    disc_H = Discriminator(3).to(device)
    disc_Z = Discriminator(3).to(device)
    gen_Z = Generator(3, 3).to(device)
    gen_H = Generator(3, 3).to(device)
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=lr,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=lr,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()


    dataset = HorseZebraDataset(
        root_horse=dirA,
        root_zebra=dirB,
        transform=transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )

    for epoch in range(10):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            device
        )


In [12]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
main("/content/dataset/trainA", "/content/dataset/trainB",batch_size=1,lr=1e-5,device=device)

100%|██████████| 1334/1334 [02:10<00:00, 10.20it/s, H_fake=0.39, H_real=0.599]
100%|██████████| 1334/1334 [02:13<00:00,  9.97it/s, H_fake=0.373, H_real=0.61]
100%|██████████| 1334/1334 [02:13<00:00,  9.97it/s, H_fake=0.346, H_real=0.631]
100%|██████████| 1334/1334 [02:13<00:00,  9.96it/s, H_fake=0.33, H_real=0.643]
100%|██████████| 1334/1334 [02:13<00:00,  9.97it/s, H_fake=0.317, H_real=0.652]
100%|██████████| 1334/1334 [02:13<00:00,  9.96it/s, H_fake=0.313, H_real=0.659]
100%|██████████| 1334/1334 [02:13<00:00,  9.96it/s, H_fake=0.304, H_real=0.67]
100%|██████████| 1334/1334 [02:13<00:00,  9.97it/s, H_fake=0.301, H_real=0.677]
100%|██████████| 1334/1334 [02:13<00:00,  9.96it/s, H_fake=0.301, H_real=0.678]
100%|██████████| 1334/1334 [02:13<00:00,  9.98it/s, H_fake=0.295, H_real=0.689]
