<a href="https://colab.research.google.com/github/mohripan/Machine-Learning/blob/main/CycleGansHorseZebra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Force a fresh reinstall of albumentations without resolving its dependencies
# (--no-deps), then install qudida on its own.
# NOTE(review): qudida is presumably a dependency that --no-deps skips — confirm.
!pip install --upgrade --force-reinstall --no-deps albumentations
!pip install qudida

Collecting albumentations
  Using cached albumentations-1.1.0-py3-none-any.whl (102 kB)
Installing collected packages: albumentations
  Attempting uninstall: albumentations
    Found existing installation: albumentations 1.1.0
    Uninstalling albumentations-1.1.0:
      Successfully uninstalled albumentations-1.1.0
Successfully installed albumentations-1.1.0


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import cv2
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import zipfile, requests, io
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
# Signed Kaggle download link for the dataset archive.
# NOTE(review): the query string carries X-Goog-Date=20211031 and
# X-Goog-Expires=259199, so this link is time-limited and has very likely
# expired — replace it with a fresh link before re-running.
zip_url = 'https://storage.googleapis.com/kaggle-data-sets/210524/459209/bundle/archive.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20211031%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20211031T095629Z&X-Goog-Expires=259199&X-Goog-SignedHeaders=host&X-Goog-Signature=47bde570c2a95274e1b567277b186d50bd079afce8ab6d222b646bcf296a0281647af24e4216476787171bfd03c18f1c3b31fe42ec9c064f36e862e526774fd8134827e31b555763a84f89df025d5c74f022c2cfd635019320c7e00655326c610c0be98a214f65d859c0d36c80772d705d9857f1b9c8d3f53750be3c8500711186624596157b6a73360d8f38454a7054db639d0c289458967b2f061f48639266afa746df8c48679fe76dedc579c835d3a1ee9f6a8f2d6b1fffd84c8d7b6770e19b132767497ba97f385c0c67c7c5c21fa76a19be7a5feadb4bef4e8a6e10aae7b997a4741530f536fa035533d553a42cf3ea178cdd37d6c7c970a005b37a867e'
# The timeout keeps a stalled connection from hanging the cell forever.
r = requests.get(zip_url, timeout=60)
# Surface HTTP failures (e.g. an expired signature) as an HTTPError here,
# instead of letting the error page reach ZipFile and fail as BadZipFile.
r.raise_for_status()
# Unpack the in-memory archive into the current working directory.
with zipfile.ZipFile(io.BytesIO(r.content)) as my_zip:
  my_zip.extractall()

In [4]:
class Block(nn.Module):
  """Discriminator stage: 4x4 reflect-padded convolution, instance norm, LeakyReLU(0.2).

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the convolution.
  """

  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=4,
        stride=stride,
        padding=1,
        bias=True,
        padding_mode='reflect',
    )
    normalisation = nn.InstanceNorm2d(out_channels)
    activation = nn.LeakyReLU(negative_slope=0.2)
    # Attribute name and layer order are kept so state_dict keys stay the same.
    self.conv = nn.Sequential(convolution, normalisation, activation)

  def forward(self, x):
    """Run ``x`` through the convolution, normalisation and activation."""
    out = self.conv(x)
    return out

In [5]:
class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a spatial map of sigmoid scores.

  The network is an initial strided convolution (no normalisation) followed by
  one ``Block`` per remaining entry of ``features`` and a final 1-channel
  convolution.

  Args:
    in_channels: number of channels of the input image.
    features: channel widths of the successive stages. The first entry is the
      width of the initial convolution; every later entry adds one ``Block``.
  """

  def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
    super().__init__()
    # Immutable default + local copy: no list is shared between instances.
    features = list(features)
    self.initial = nn.Sequential(nn.Conv2d(in_channels,
                                           features[0],
                                           4, 2, 1,
                                           padding_mode='reflect'),
                                 nn.LeakyReLU(0.2),)

    layers = []
    in_channels = features[0]
    last_index = len(features) - 1

    for index, feature in enumerate(features[1:], start=1):
      # Only the final stage uses stride 1. This is decided by position, not by
      # comparing widths, so a repeated width (e.g. [64, 512, 512]) does not
      # switch an earlier stage to stride 1 by accident.
      layers.append(Block(in_channels, feature, stride=1 if index == last_index else 2))
      in_channels = feature

    layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode='reflect'))
    # The attribute name `mode` is kept as-is: it is part of the state_dict
    # keys, so renaming it would break loading of existing checkpoints.
    self.mode = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid-squashed score map for the image batch ``x``."""
    x = self.initial(x)
    return torch.sigmoid(self.mode(x))

In [6]:
class ConvBlock(nn.Module):
  """Convolution (or transposed convolution) + instance norm + optional ReLU.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    down: when true use a reflect-padded ``nn.Conv2d``; otherwise use
      ``nn.ConvTranspose2d``.
    use_act: when true finish with an in-place ReLU; otherwise with an identity.
    **kwargs: forwarded unchanged to the chosen convolution layer.
  """

  def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
    super().__init__()

    if down:
      conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
    else:
      conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

    if use_act:
      activation = nn.ReLU(inplace=True)
    else:
      activation = nn.Identity()

    # Same attribute name and layer order as before, so checkpoints still load.
    self.conv = nn.Sequential(conv_layer, nn.InstanceNorm2d(out_channels), activation)

  def forward(self, x):
    """Apply the convolution, normalisation and activation to ``x``."""
    result = self.conv(x)
    return result

In [7]:
class ResidualBlock(nn.Module):
  """Two 3x3 ``ConvBlock`` layers wrapped in a skip connection.

  The first ConvBlock keeps its activation; the second is built with
  ``use_act=False``. Both keep the channel count at ``channels``.
  """

  def __init__(self, channels):
    super().__init__()
    first = ConvBlock(channels, channels, kernel_size=3, padding=1)
    second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
    self.block = nn.Sequential(first, second)

  def forward(self, x):
    """Return ``x`` plus the output of the two-layer branch."""
    branch = self.block(x)
    return x + branch

In [8]:
class Generator(nn.Module):
  """Image-to-image generator with a tanh output.

  Layout: 7x7 convolution + ReLU, two stride-2 ``ConvBlock`` stages, then
  ``num_residuals`` ``ResidualBlock`` layers, two transposed-convolution
  ``ConvBlock`` stages, and a final 7x7 convolution followed by ``tanh``.

  Args:
    img_channels: channel count of both the input and the output image.
    num_features: width of the first convolution; later stages use 2x and 4x.
    num_residuals: number of residual blocks in the middle of the network.
  """

  def __init__(self, img_channels, num_features=64, num_residuals=9):
    super().__init__()
    base = num_features
    double = num_features*2
    quad = num_features*4

    stem_conv = nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode='reflect')
    self.initial = nn.Sequential(stem_conv, nn.ReLU(inplace=True))

    # Encoder: each stage doubles the channel count and uses stride 2.
    encoder = [
        ConvBlock(base, double, kernel_size=3, stride=2, padding=1),
        ConvBlock(double, quad, kernel_size=3, stride=2, padding=1),
    ]
    self.down_blocks = nn.ModuleList(encoder)

    bottleneck = [ResidualBlock(quad) for _ in range(num_residuals)]
    self.residual_blocks = nn.Sequential(*bottleneck)

    # Decoder: transposed convolutions mirroring the encoder widths.
    decoder = [
        ConvBlock(quad, double, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
        ConvBlock(double, base, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
    ]
    self.up_blocks = nn.ModuleList(decoder)

    self.last = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect')

  def forward(self, x):
    """Translate the image batch ``x``; the result is passed through ``tanh``."""
    out = self.initial(x)
    for down in self.down_blocks:
      out = down(out)
    out = self.residual_blocks(out)
    for up in self.up_blocks:
      out = up(out)
    out = self.last(out)
    return torch.tanh(out)

In [9]:
# Installs the PyPI package named `config`.
# NOTE(review): `config` is imported in the next cell but nothing in the visible
# notebook references it — confirm whether this install is actually needed.
!pip install config



In [10]:
from PIL import Image
import os
import config

In [11]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset folders: trainA/testA are used as the horse set and trainB/testB as
# the zebra set.
train_horse_dir, train_zebra_dir = (
    '/content/horse2zebra/horse2zebra/trainA',
    '/content/horse2zebra/horse2zebra/trainB',
)
val_horse_dir, val_zebra_dir = (
    '/content/horse2zebra/horse2zebra/testA',
    '/content/horse2zebra/horse2zebra/testB',
)

# Optimisation settings.
batch_size = 1
num_workers = 4
num_epochs = 10
lr = 1e-5

# Weights of the cycle-consistency and identity terms in the generator loss.
lambda_cycle = 10
lambda_identity = 0.0

# Checkpoint switches and the file used for each of the four networks.
load_model = False
save_model = True
checkpoint_gen_H, checkpoint_gen_Z = 'genh.pth.tar', 'genz.pth.tar'
checkpoint_critic_H, checkpoint_critic_Z = 'critich.pth.tar', 'criticz.pth.tar'

# Shared augmentation pipeline applied to a zebra/horse pair in one call:
# `additional_targets` registers the keyword `image0` as a second image.
# NOTE: this name shadows the `torchvision.transforms` alias imported at the top
# of the notebook; it is kept because main() passes `transforms` to the dataset.
transforms = A.Compose([
                        A.Resize(256, 256),
                        # The probability must be passed by keyword: written
                        # positionally, 0.5 binds to `always_apply` in
                        # albumentations 1.x and the flip fires on every sample.
                        A.HorizontalFlip(p=0.5),
                        # Scales pixel values from [0, 255] to [-1, 1]; train_fn
                        # undoes this with `* 0.5 + 0.5` before saving images.
                        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
                        ToTensorV2(),
], additional_targets={'image0':'image'})

In [12]:
def save_checkpoint(model, optimizer, filename='horse_zebra.pth.tar'):
  """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

  Args:
    model: object exposing ``state_dict()``; stored under ``'state_dict'``.
    optimizer: object exposing ``state_dict()``; stored under ``'optimizer'``.
    filename: destination path of the checkpoint file.
  """
  print('=> Saving checkpoint')
  model_state = model.state_dict()
  optimizer_state = optimizer.state_dict()
  torch.save({'state_dict': model_state, 'optimizer': optimizer_state}, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
  """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

  Args:
    checkpoint_file: path of the checkpoint to read.
    model: module whose weights are replaced by the stored ``'state_dict'``.
    optimizer: optimizer whose state is replaced by the stored ``'optimizer'``.
    lr: learning rate written into every parameter group after loading.
  """
  print('=> Loading checkpoint')
  # Tensors are mapped onto the notebook-level `device`.
  stored = torch.load(checkpoint_file, map_location=device)
  model.load_state_dict(stored['state_dict'])
  optimizer.load_state_dict(stored['optimizer'])

  # The restored optimizer carries the learning rate it was saved with;
  # overwrite it with the one requested by the caller.
  for group in optimizer.param_groups:
    group['lr'] = lr

In [13]:
import random

In [14]:
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: value used for every generator and for ``PYTHONHASHSEED``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seeding order as before: stdlib, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [15]:
class HorseZebraDataset(Dataset):
  """Unpaired dataset yielding one zebra image and one horse image per index.

  The dataset length is the size of the larger folder; the smaller folder is
  cycled through with a modulo on the index.

  Args:
    root_zebra: directory containing the zebra images.
    root_horse: directory containing the horse images.
    transform: optional callable invoked as
      ``transform(image=zebra, image0=horse)`` that returns a mapping with the
      keys ``'image'`` and ``'image0'``.
  """

  def __init__(self, root_zebra, root_horse, transform=None):
    self.root_zebra = root_zebra
    self.root_horse = root_horse
    self.transform = transform

    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting makes the index -> file mapping reproducible across runs.
    self.zebra_images = sorted(os.listdir(root_zebra))
    self.horse_images = sorted(os.listdir(root_horse))
    self.length_dataset = max(len(self.zebra_images), len(self.horse_images))
    self.zebra_len = len(self.zebra_images)
    self.horse_len = len(self.horse_images)

  def __len__(self):
    """Return the number of files in the larger of the two folders."""
    return self.length_dataset

  def __getitem__(self, idx):
    """Return ``(zebra_img, horse_img)`` for ``idx``.

    Both images are loaded as RGB arrays; when a transform is configured the
    transformed versions are returned instead.
    """
    # Modulo wraps the index so the shorter folder is reused.
    zebra_img = self.zebra_images[idx % self.zebra_len]
    horse_img = self.horse_images[idx % self.horse_len]

    zebra_path = os.path.join(self.root_zebra, zebra_img)
    horse_path = os.path.join(self.root_horse, horse_img)

    zebra_img = np.array(Image.open(zebra_path).convert('RGB'))
    horse_img = np.array(Image.open(horse_path).convert('RGB'))

    if self.transform:
      augmentation = self.transform(image=zebra_img, image0=horse_img)
      zebra_img = augmentation['image']
      horse_img = augmentation['image0']

    return zebra_img, horse_img

In [19]:
from tqdm import tqdm
from torchvision.utils import save_image

In [31]:
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler):
  """Run one pass over ``loader``, updating both discriminators and both generators.

  Each batch performs a discriminator step followed by a generator step. Every
  200 batches the current fake horse and fake zebra images are written to the
  ``saved_images`` directory.

  Args:
    disc_H: discriminator scoring horse images.
    disc_Z: discriminator scoring zebra images.
    gen_Z: generator mapping horse images to zebra images.
    gen_H: generator mapping zebra images to horse images.
    loader: iterable yielding ``(zebra, horse)`` batches.
    opt_disc: optimizer holding the parameters of both discriminators.
    opt_gen: optimizer holding the parameters of both generators.
    L1: loss used for the cycle and identity terms.
    mse: loss used for the adversarial terms.
    d_scaler: gradient scaler used for the discriminator step.
    g_scaler: gradient scaler used for the generator step.
  """
  # save_image does not create missing directories, so make sure the output
  # folder exists before the first image is written (at idx == 0).
  os.makedirs('saved_images', exist_ok=True)

  loop = tqdm(loader, leave=True)

  for idx, (zebra, horse) in enumerate(loop):
    zebra, horse = zebra.to(device), horse.to(device)

    # ---- Discriminator step ----
    with torch.cuda.amp.autocast():
      fake_horse = gen_H(zebra)
      D_H_real = disc_H(horse)
      # detach(): the discriminator loss must not back-propagate into gen_H.
      D_H_fake = disc_H(fake_horse.detach())
      D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
      D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
      D_H_loss = D_H_real_loss+D_H_fake_loss

      fake_zebra = gen_Z(horse)
      D_Z_real = disc_Z(zebra)
      D_Z_fake = disc_Z(fake_zebra.detach())
      D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
      D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
      D_Z_loss = D_Z_real_loss+D_Z_fake_loss

      D_loss = (D_H_loss+D_Z_loss)/2

    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # ---- Generator step ----
    with torch.cuda.amp.autocast():
      # Adversarial terms: the generators want the fakes scored as real.
      D_H_fake = disc_H(fake_horse)
      D_Z_fake = disc_Z(fake_zebra)
      loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
      loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

      # Cycle terms: translating there and back should reproduce the input.
      cycle_zebra = gen_Z(fake_horse)
      cycle_horse = gen_H(fake_zebra)
      cycle_zebra_loss = L1(zebra, cycle_zebra)
      cycle_horse_loss = L1(horse, cycle_horse)

      G_loss = (loss_G_Z+
                loss_G_H+
                cycle_zebra_loss*lambda_cycle+
                cycle_horse_loss*lambda_cycle)

      # Identity terms cost two extra generator forward passes; skip them when
      # their weight is zero, since they would contribute nothing to the loss.
      if lambda_identity != 0:
        identity_zebra = gen_Z(zebra)
        identity_horse = gen_H(horse)
        identity_zebra_loss = L1(zebra, identity_zebra)
        identity_horse_loss = L1(horse, identity_horse)
        G_loss = (G_loss+
                  identity_horse_loss*lambda_identity+
                  identity_zebra_loss*lambda_identity)

    # Only the forward pass belongs under autocast; the backward pass and the
    # optimizer step run outside the context, matching the discriminator step.
    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    if idx%200 == 0:
      # `* 0.5 + 0.5` maps the tanh output range [-1, 1] back to [0, 1].
      save_image(fake_horse*0.5+0.5, f'saved_images/horse_{idx}.png')
      save_image(fake_zebra*0.5+0.5, f'saved_images/zebra_{idx}.png')
  

def main():
  """Build the four networks, optionally restore them, and train for ``num_epochs``.

  After every epoch, when ``save_model`` is set, each network is checkpointed
  together with the optimizer that owns its parameters.
  """
  # Networks are created in the same order as before: discriminators first.
  disc_H = Discriminator(in_channels=3).to(device)
  disc_Z = Discriminator(in_channels=3).to(device)
  gen_Z = Generator(img_channels=3, num_residuals=9).to(device)
  gen_H = Generator(img_channels=3, num_residuals=9).to(device)

  # One optimizer per role, each covering both networks of that role.
  disc_params = [*disc_H.parameters(), *disc_Z.parameters()]
  gen_params = [*gen_Z.parameters(), *gen_H.parameters()]
  opt_disc = optim.Adam(disc_params, lr=lr, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen_params, lr=lr, betas=(0.5, 0.999))

  L1 = nn.L1Loss()
  mse = nn.MSELoss()

  # (checkpoint file, network, optimizer) in the order used for load and save.
  checkpoint_plan = (
      (checkpoint_gen_H, gen_H, opt_gen),
      (checkpoint_gen_Z, gen_Z, opt_gen),
      (checkpoint_critic_H, disc_H, opt_disc),
      (checkpoint_critic_Z, disc_Z, opt_disc),
  )

  if load_model:
    for path, network, optimizer in checkpoint_plan:
      load_checkpoint(path, network, optimizer, lr)

  dataset = HorseZebraDataset(root_horse=train_horse_dir, root_zebra=train_zebra_dir, transform=transforms)
  loader = DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()

  for _ in range(num_epochs):
    train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

    if not save_model:
      continue
    for path, network, optimizer in checkpoint_plan:
      save_checkpoint(network, optimizer, filename=path)

In [None]:
# Start training: main() builds the models and runs train_fn for num_epochs epochs.
main()

  cpuset_checked))
100%|██████████| 1334/1334 [54:15<00:00,  2.44s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


 62%|██████▏   | 831/1334 [33:50<20:27,  2.44s/it]