# CycleGAN Implementation

**Introduction**

The following notebook contains an implementation of CycleGAN with brief descriptions of the architecture and objective functions. More information can be found in the report. 

In [1]:
from torch.nn.modules.instancenorm import InstanceNorm2d
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm 
import numpy as np 
import math
import os
import glob


**Network Architecture - Generator**

The following blocks are described in the paper for the generator. Due to tensor-size issues in PyTorch, some parameters, such as the stride, are implemented differently; this is noted with "=>".

c7s1-k: *conv_block_down*

* Convolution, kernel_size=7x7
* Instance Normalization
* ReLU
* k filters
* stride 1

dk: *conv_block_down*
* Convolution, kernel_size=3x3
* Instance Normalization
* ReLU
* k filters and stride 2

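Rk: *residual_block*
* Two convolutions, kernel_size=3x3, each followed by Instance Normalization (ReLU only after the first)
* input filter channels = output filter channels = k
* The block input is added to its output (skip connection)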

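uk: *conv_block_up*

* Fractional-strided convolution, kernel_size=3x3
* Instance Normalization
* ReLU
* k filters
* stride 1/2 => nn.ConvTranspose2d with stride 2 and output_padding=1 (doubles the spatial size)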
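The tensor-size issues mentioned above come down to the output-size formulas of `nn.Conv2d` and `nn.ConvTranspose2d`. The sketch below assumes dilation 1 and uses two hypothetical helpers, `conv_out` and `conv_transpose_out`, that apply the formulas from the PyTorch documentation to the parameters used in this notebook. It illustrates why a stride-1/2 ("fractional-strided") convolution is expressed as a transposed convolution with stride 2, and why `output_padding=1` is needed to exactly double the size.

```python
# Sketch: output-size formulas of nn.Conv2d and nn.ConvTranspose2d
# (dilation assumed to be 1), applied to the generator blocks.

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of a ConvTranspose2d layer with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# c7s1-k: 7x7 kernel, stride 1, padding 3 keeps the resolution
print(conv_out(256, kernel=7, stride=1, padding=3))  # 256
# dk: 3x3 kernel, stride 2, padding 1 halves the resolution
print(conv_out(256, kernel=3, stride=2, padding=1))  # 128
# uk: 3x3 transposed conv, stride 2, padding 1, output_padding 1 doubles it
print(conv_transpose_out(64, kernel=3, stride=2, padding=1, output_padding=1))  # 128
# without output_padding the result would be one pixel short
print(conv_transpose_out(64, kernel=3, stride=2, padding=1))  # 127
```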


The section below implements the blocks mentioned above.

In [2]:
def conv_block_down(in_channels, out_channels, activation='relu',**kwargs): 
  """cs7s1 - k: Down sampling conv block  """
  activations = nn.ModuleDict([
              ['lrelu', nn.LeakyReLU()],
              ['relu', nn.ReLU()], 
              ['none', nn.Identity()]
  ])

  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels,padding_mode="reflect", **kwargs), 
      InstanceNorm2d(out_channels), 
      activations[activation]
  )

def conv_block_up(in_channels, out_channels, activation='relu', **kwargs): 
  """uk: Up sampling Conv block """
  activations = nn.ModuleDict([
              ['lrelu', nn.LeakyReLU()],
              ['relu', nn.ReLU()], 
              ['none', nn.Identity()]
  ])

  return nn.Sequential(
    nn.ConvTranspose2d(in_channels, out_channels, **kwargs), 
    InstanceNorm2d(out_channels), 
    activations[activation]
)

def residual_block(channels=256, **kwargs):
  """Rx: Residual block. Last layer without activation function  """
  return nn.Sequential(
      conv_block_down(channels, channels, **kwargs),
      conv_block_down(channels, channels, activation='none', **kwargs)
  )
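All of the blocks above normalize with `InstanceNorm2d`. As a rough NumPy sketch (the helper `instance_norm` is hypothetical and assumes the module's defaults, `affine=False` and `eps=1e-5`), instance normalization standardizes each channel of each image separately over its spatial dimensions, unlike batch normalization, which pools statistics over the whole batch:

```python
import numpy as np


def instance_norm(x, eps=1e-5):
    """Normalize an (N, C, H, W) array without learned affine parameters:
    every (sample, channel) plane uses its own mean and biased variance."""
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)  # biased variance (ddof=0)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 8))  # 2 images, 4 channels
y = instance_norm(x)
print(np.round(y.mean(axis=(2, 3)), 6))  # close to 0 for every sample and channel
print(np.round(y.std(axis=(2, 3)), 3))   # close to 1 for every sample and channel
```

Because the statistics are computed per image, the output of a block does not depend on the other images in the batch, which suits the `BATCH_SIZE = 1` training used later in this notebook.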


**From the paper:** the network with 9 residual blocks consists of:

c7s1-64, d128, d256, R256, R256, R256, R256, R256, R256, R256, R256, R256, u128, u64, c7s1-3
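To check that the layer list above fits together, the following sketch expands the notation and tracks the number of channels and the spatial size for a 256x256 RGB input. The `SPEC` string and `trace_generator` are illustrative only, and the sketch assumes that c7s1 keeps the size, dk halves it, Rk keeps it and uk doubles it.

```python
# Hypothetical helper: expand the paper's layer string and trace (channels, H, W).
SPEC = "c7s1-64,d128,d256," + ",".join(["R256"] * 9) + ",u128,u64,c7s1-3"


def trace_generator(spec, in_channels=3, size=256):
    channels = in_channels
    for token in spec.split(","):
        if token.startswith("c7s1-"):
            channels = int(token[len("c7s1-"):])        # 7x7 conv, stride 1: size unchanged
        elif token[0] == "d":
            channels, size = int(token[1:]), size // 2  # stride-2 conv halves the size
        elif token[0] == "R":
            # a residual block must keep the channel count for the skip connection
            if int(token[1:]) != channels:
                raise ValueError(f"{token} does not match {channels} channels")
        elif token[0] == "u":
            channels, size = int(token[1:]), size * 2   # stride-1/2 conv doubles the size
        else:
            raise ValueError(f"unknown block {token!r}")
    return channels, size, size


print(trace_generator(SPEC))  # (3, 256, 256)
```

The generator therefore maps an image to an image of the same shape, which is what the cycle-consistency loss requires.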

In [3]:
class Generator(nn.Module): 
  """ Class for generator  """
  def __init__(self, in_channels, u_net_sizes = [64, 128, 256], u_net_sizes_R = [256, 128,64],  num_residuals=9): 
    super().__init__()
    self.u_net_sizes = u_net_sizes # channels of the conv down sampling 
    self.u_net_sizes_R = u_net_sizes_R # channels of conv upsampling 

    self.initial = nn.Sequential(conv_block_down(in_channels, self.u_net_sizes[0], kernel_size=7, stride=1, padding=3))

    self.down_blocks = nn.Sequential(*[conv_block_down(in_c, out_c, kernel_size=3, padding=1, stride=2) # conv down sampling 
                for in_c, out_c in zip(self.u_net_sizes[0:], self.u_net_sizes[1:])])
  
    self.res_blocks = nn.ModuleList([residual_block(kernel_size=3, stride=1, padding=1 ) for _ in range(num_residuals)]) # residual blocks 

    self.up_blocks = nn.Sequential(*[conv_block_up(in_c, out_c, kernel_size=3, stride= 2,padding =1, output_padding=1)
                for in_c, out_c in zip(self.u_net_sizes_R[0:], self.u_net_sizes_R[1:])])

    self.last = nn.Sequential(conv_block_down(self.u_net_sizes[0], 3, kernel_size=7, padding = 3, stride=1))

  def forward(self, x): 
    x = self.initial(x)
    x = self.down_blocks(x)
    for block in self.res_blocks: 
      shortcut = x
      x = block(x) + shortcut # Adding shortcut in residual block
    x = self.up_blocks(x)
    return self.last(x)
    
    

In [5]:
def test():
  img_channels = 3
  img_size = 256
  x = torch.randn((2, img_channels, img_size, img_size)) # batch of 2 random 3-channel 256x256 images

  gen = Generator(img_channels)
  print(gen)
  #print(x.size())
  pred = gen(x)
  print(pred.shape) # expected torch.Size([2, 3, 256, 256]): the generator preserves the input size
#test()

**Network Architecture - Discriminator**

For discriminator networks, we use 70 × 70 PatchGAN.

The paper builds the discriminator from the following block:

Ck: 
* Convolution, kernel_size=4x4
* InstanceNorm 
* LeakyReLU, slope 0.2
* k filters, stride 2

After the last block, a final convolution produces a 1-channel output (one score per patch). InstanceNorm is not used in the first C64 layer. 

The discriminator architecture is:
C64-C128-C256-C512
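The name "70 × 70 PatchGAN" refers to the receptive field of each output value. The sketch below computes it, assuming 4x4 kernels with padding 1 and the strides used in the `Discriminator` class below (stride 2 for C64, C128 and C256, stride 1 for C512 and the final 1-channel convolution). `receptive_field` and `output_size` are hypothetical helper names.

```python
# (kernel_size, stride) of each conv layer, first layer first:
# C64, C128, C256, C512, final 1-channel conv.
LAYERS = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]


def receptive_field(layers):
    """Receptive field of one output unit, walking back from the last layer."""
    rf = 1
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


def output_size(size, layers, padding=1):
    """Spatial size of the output grid for a square input."""
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size


print(receptive_field(LAYERS))   # 70
print(output_size(256, LAYERS))  # 30
```

Each of the 30 × 30 output values therefore scores an overlapping 70 × 70 patch of the 256 × 256 input, rather than the discriminator producing a single score for the whole image.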




In [4]:
def conv_block_D(in_channels, out_channels, *args, **kwargs): 
  """Ck: Conv-InstanceNorm-LeakyReLU block"""
  return nn.Sequential(
      nn.Conv2d(in_channels, out_channels, *args, **kwargs), 
      InstanceNorm2d(out_channels), 
      nn.LeakyReLU(0.2), # slope 0.2 as specified in the paper
  )

In [5]:
class Discriminator(nn.Module): 
  def __init__(self, in_channels, dec_sizes = [64, 128, 256, 512]): 
    super().__init__()
    self.dec_sizes = [in_channels, *dec_sizes]

    self.initial = nn.Sequential(
        nn.Conv2d(in_channels, self.dec_sizes[1], kernel_size=4, stride=2, padding=1, padding_mode="reflect"), 
        nn.LeakyReLU(0.2), 
    )
    
    conv_blocks = [conv_block_D(in_c, out_c, kernel_size=4,stride=2 if out_c != self.dec_sizes[-1] else 1, padding=1)
                for in_c, out_c in zip(self.dec_sizes[1:], self.dec_sizes[2:])]

    conv_blocks.append(nn.Conv2d(dec_sizes[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
    
    self.model = nn.Sequential(*conv_blocks)

  def forward(self, x): 
    x = self.initial(x)
    return self.model(x)
    
    

In [9]:
def test():
  x = torch.randn((5,3,256,256)) # batch of 5 random 3-channel 256x256 images

  model = Discriminator(in_channels=3)
  print(model)
  pred = model(x)
  print(pred.shape) # expected torch.Size([5, 1, 30, 30]); each value sees a 70x70 patch of the original image

#test()

# Dataset and DataLoader

In [10]:
%pip install Pillow



In [6]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


**Dataset** class for loading images and applying basic transformations. The dataset, consisting of trainA, trainB, testA and testB, must be stored in a folder named 'dataset' on Drive (the class searches recursively under /content, where Drive is mounted). 

In [7]:
class HorseZebraDataset(Dataset): 
  """
  parameter mode: specifices train or test. 
  glob module retrive all matchign pathnames with specified pattern. 
  b = zebra, a = horse

  """

  def __init__(self, mode="train", transform=None, root_dir='/content'):
    self.root_dir = root_dir
    self.mode = mode
    self.transform = transform
  
    self.modeA = glob.glob(os.path.join(self.root_dir,'**/dataset', "%sA" % mode, '*.jpg'), 
                      recursive = True)
    self.modeB = glob.glob(os.path.join(self.root_dir,'**/dataset', "%sB" % mode, '*.jpg'), 
                      recursive = True)
    self.length_dataset = max(len(self.modeA), len(self.modeB))
    self.zebra_len = len(self.modeB) # B = zebra
    self.horse_len = len(self.modeA) # A = horse

  def __getitem__(self, index):
    """ Modulus operation keeps each index within its own domain's bounds""" 
    idxZ = index % self.zebra_len
    idxH = index % self.horse_len
    zebra_img = self.modeB[idxZ]
    horse_img = self.modeA[idxH]

    zebra_img = np.array(Image.open(zebra_img).convert("RGB"))
    horse_img = np.array(Image.open(horse_img).convert("RGB"))

    if self.transform: 
      zebra_img = self.transform(zebra_img)
      horse_img = self.transform(horse_img)
      

    return zebra_img, horse_img

  def __len__(self): 
    # len(dataset)
    return self.length_dataset 
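Because the two domains are unpaired and usually differ in size, the index is wrapped separately for each folder. A pure-Python sketch with small hypothetical file lists (3 horses, 5 zebras) shows the pairs produced for one epoch when every image index is wrapped by its own domain's length:

```python
# Hypothetical file lists standing in for the globbed image paths.
horses = [f"horse_{i}.jpg" for i in range(3)]  # domain A
zebras = [f"zebra_{i}.jpg" for i in range(5)]  # domain B

length = max(len(horses), len(zebras))  # one epoch covers the larger domain


def get_pair(index):
    # Each domain is wrapped with its own length, so no index goes out of
    # bounds and every image of the larger domain is visited once per epoch.
    return zebras[index % len(zebras)], horses[index % len(horses)]


pairs = [get_pair(i) for i in range(length)]
for zebra, horse in pairs:
    print(zebra, horse)
```

With this scheme each index always maps to the same (zebra, horse) pair; `shuffle=True` in the DataLoader only changes the order in which the pairs are visited.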

  

In [8]:
from torchvision import transforms

In [9]:
LEARNING_RATE = 1e-5
BATCH_SIZE = 1
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

disc_H = Discriminator(in_channels=3).to(DEVICE)
disc_Z = Discriminator(in_channels=3).to(DEVICE)
gen_Z = Generator(in_channels=3, num_residuals=9).to(DEVICE)
gen_H = Generator(in_channels=3, num_residuals=9).to(DEVICE)

transformsA = transforms.Compose(
        [
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])    
        ]
    )

opt_disc = torch.optim.Adam(
    list(disc_H.parameters()) + list(disc_Z.parameters()),
    lr= LEARNING_RATE, 
    betas=(0.5,0.999)
)

opt_gen = torch.optim.Adam(
    list(gen_Z.parameters()) + list(gen_H.parameters()),
    lr= LEARNING_RATE, 
    betas=(0.5,0.999)
)

L1 = nn.L1Loss()
mse = nn.MSELoss()

datasetTrain = HorseZebraDataset(mode="train", transform =transformsA)
datasetTest = HorseZebraDataset(mode="test", transform =transformsA)

dataLoaderTrain = DataLoader(datasetTrain, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dataLoaderTest = DataLoader(datasetTest, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

g_scaler = torch.cuda.amp.GradScaler() # scales the loss (and thus the gradients) so that small float16 gradients do not underflow to zero under autocast
d_scaler = torch.cuda.amp.GradScaler()
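The gradient scalers exist because float16 has a very limited range. The NumPy sketch below illustrates the idea with a single made-up gradient value; `update_scale` is a simplified, hypothetical stand-in for the scaler's internal logic, assuming `GradScaler`'s defaults `init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5` and `growth_interval=2000`.

```python
import numpy as np

# A tiny gradient that float16 cannot represent: its smallest subnormal is
# 2**-24 (about 6e-8), so 1e-8 underflows to zero.
grad = 1e-8
print(np.float16(grad))  # 0.0

# Scaling the loss by S scales every gradient by S (chain rule), which moves
# the value into float16's representable range.
scale = 2.0 ** 16
scaled = np.float16(grad * scale)
print(scaled != 0)  # True

# Before the optimizer step the gradients are divided by S again in float32.
recovered = np.float32(scaled) / np.float32(scale)


def update_scale(scale, found_inf, good_steps, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
    """Simplified dynamic loss scaling: shrink the scale (and skip the step)
    on inf/NaN gradients, grow it after `growth_interval` clean steps."""
    if found_inf:
        return scale * backoff_factor, 0
    good_steps += 1
    if good_steps == growth_interval:
        return scale * growth_factor, 0
    return scale, good_steps
```

In the training loop this is why the losses go through `scaler.scale(loss).backward()` and the optimizer is stepped with `scaler.step(...)` followed by `scaler.update()`.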

In [10]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"): 
  print("=> saving checkpoint")
  torch.save(state, filename)

def load_checkpoint(checkpoint): 
  print("=> loading checkpoint")
  gen_Z.load_state_dict(checkpoint['state_dict'])
  opt_gen.load_state_dict(checkpoint['optimizer'])





In [11]:
"""Train implementaion inspired from https://github.com/aladdinpersson/Machine-Learning-Collection Credits to: Aladdin Persson """

def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler, epoch): 
  
  # H_reals = 0
  # H_fakes = 0
  loop = tqdm(loader, total=len(loader), position=0, leave=True) # wrap loader for progress bar

  checkpoint = {'state_dict' : gen_Z.state_dict(), 'optimizer': opt_gen.state_dict()}
  save_checkpoint(checkpoint)

  for idx, (zebra, horse) in enumerate(loop):
    zebra = zebra.to(DEVICE)
    horse = horse.to(DEVICE)

    # Train Discriminators H and Z

    with torch.cuda.amp.autocast(): # automatically run eligible operations in float16 to save memory and time
      fake_horse = gen_H(zebra) # Generate fake horse from real zebra
      D_H_real = disc_H(horse) # Prediction on the real horses => want it to be 1
      D_H_fake = disc_H(fake_horse.detach()) # Prediction on the fake horses => want it to be 0 (detach so this loss does not backpropagate into gen_H; fake_horse is reused below)
      # H_reals += D_H_real.mean().item()
      # H_fakes += D_H_fake.mean().item()
      D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real)) # real horses: target 1
      D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake)) # fake horses: target 0
      D_H_loss = D_H_real_loss + D_H_fake_loss

      fake_zebra = gen_Z(horse) # Generate fake zebra from real horse
      D_Z_real = disc_Z(zebra) # want it to be 1
      D_Z_fake = disc_Z(fake_zebra.detach()) # want it to be 0
      D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real)) # real zebras: target 1
      D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake)) # fake zebras: target 0
      D_Z_loss = D_Z_real_loss + D_Z_fake_loss

      D_loss = D_H_loss + D_Z_loss

    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()


    # Train Generator H and Z

    with torch.cuda.amp.autocast(): 
      # adversarial loss for both generators
      D_H_fake = disc_H(fake_horse) # no detach: gradients flow back into gen_H
      D_Z_fake = disc_Z(fake_zebra) # no detach: gradients flow back into gen_Z
      loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake)) # want to trick the discriminator into believing it's real (1)
      loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake)) # want to trick the discriminator into believing it's real (1)

      # cycle-consistency loss
      cycle_zebra = gen_Z(fake_horse) # zebra -> fake horse -> reconstructed zebra
      cycle_horse = gen_H(fake_zebra) # horse -> fake zebra -> reconstructed horse
      cycle_zebra_loss = l1(zebra, cycle_zebra)
      cycle_horse_loss = l1(horse, cycle_horse)

      G_loss = (
          loss_G_H
          + loss_G_Z
          + cycle_zebra_loss*LAMBDA_CYCLE
          + cycle_horse_loss*LAMBDA_CYCLE
      )

    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

    loop.set_description(f"Epoch [{epoch}/{NUM_EPOCHS}]") 
    #loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


    
    if idx % 1200 == 0:
      pathA = os.path.join('/content', 'epoch%i'%epoch)
      os.makedirs(pathA, exist_ok=True)
      save_image(fake_horse, f"{pathA}/horseFake_{idx}.png")
      save_image(horse, f"{pathA}/horse_{idx}.png")
      save_image(fake_zebra, f"{pathA}/zebraFake_{idx}.png")
      save_image(zebra, f"{pathA}/zebra_{idx}.png")



In [12]:
for epoch in range(NUM_EPOCHS):
    train_fn(
        disc_H,
        disc_Z,
        gen_Z,
        gen_H,
        dataLoaderTrain,
        opt_disc,
        opt_gen,
        L1,
        mse,
        d_scaler,
        g_scaler,
        epoch
    )


=> saving checkpoint
Epoch [0/10]: 100%|██████████| 1334/1334 [08:21<00:00,  2.66it/s]
=> saving checkpoint
Epoch [1/10]: 100%|██████████| 1334/1334 [03:58<00:00,  5.59it/s]
=> saving checkpoint
Epoch [2/10]: 100%|██████████| 1334/1334 [03:56<00:00,  5.63it/s]
=> saving checkpoint
Epoch [3/10]: 100%|██████████| 1334/1334 [03:56<00:00,  5.63it/s]
=> saving checkpoint
Epoch [4/10]: 100%|██████████| 1334/1334 [03:57<00:00,  5.62it/s]
=> saving checkpoint
Epoch [5/10]: 100%|██████████| 1334/1334 [03:55<00:00,  5.67it/s]
=> saving checkpoint
Epoch [6/10]: 100%|██████████| 1334/1334 [03:54<00:00,  5.70it/s]
=> saving checkpoint
Epoch [7/10]: 100%|██████████| 1334/1334 [03:53<00:00,  5.72it/s]
=> saving checkpoint
Epoch [8/10]: 100%|██████████| 1334/1334 [03:53<00:00,  5.71it/s]
=> saving checkpoint
Epoch [9/10]: 100%|██████████| 1334/1334 [03:52<00:00,  5.73it/s]
