<a href="https://colab.research.google.com/github/martinpius/PYTORCH/blob/main/CYCLE_GAN_Pytorch_Implementantion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)}: {e}\n>>>> please correct {type(e)} and reload")
  COLAB = False
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
def time_fmt(t: float = 123.879)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"{h} hrs: {m:>02} min: {s:>05.2f} sec"
print(f">>>> time formating\tplease wait........\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are on CoLaB with torch version 1.8.1+cu101
>>>> time formating	please wait........
>>>> time elapsed	0 hrs: 02 min: 03.00 sec


In [2]:
#In this notebook we are going to implement a more advanced GAN architecture that convert one picture 
#to another and vice-versa. We bassically train 2 discriminators and 2 generators which works side by side.
#In this case the loss function will be splitted into 6 quantities related to generators, discriminators, identity and cycle.
 

In [3]:
#Importing modules and setup the seed-value and the device to deterministic:
import torch, os, copy
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import PIL
import math, random, sys, time
import numpy as np
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
from albumentations.pytorch import ToTensor
from torchvision.utils import save_image

In [4]:
#Seting the seed-values for reporoducability and the device to avoid errors during training 
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
#Model building involving the ussual structure of discriminator and generator classes:
#In CYCLE-GAN-the discriminator has a CNN with a scalar outputs (to discriminate fake-real)
#The generator will have an architecture similar to that of a U-NET (i.e; both down-sampling and upsampling)
#conv-layers to construct the images:

In [6]:
###########################Generator class ##############################Generator class ################################

In [7]:
#We first prepare the convolution block to re-use it later when defining the skip-block
class ConBlock(nn.Module):
  def __init__(self, in_channels, out_channels,down = True, act_use = True, **kwargs):
    super(ConBlock, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, padding_mode = 'reflect', **kwargs)
        if down
        else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(inplace = True) if act_use else nn.Identity()
    )
  
  def forward(self, input_tensor):
    return self.conv(input_tensor)

#We create the residual block with this class.
class SkipBlock(nn.Module):
  def __init__(self, channels):
    super(SkipBlock, self).__init__()
    self.block = nn.Sequential(
        ConBlock(channels, channels, kernel_size = 3, padding = 1),
        ConBlock(channels, channels, act_use = False, kernel_size = 3, padding = 1))
  
  def forward(self, input_tensor):
    return input_tensor + self.block(input_tensor)

#The actual generator class which reuse the above convblock and resblock classes:
class Generator(nn.Module):
  def __init__(self, img_channels, num_features = 64, num_skip = 9):
    super(Generator, self).__init__()

    self.initial_block = nn.Sequential(
        nn.Conv2d(img_channels, num_features, kernel_size = 7, stride = 1, padding_mode = 'reflect', padding = 3 ),
        nn.InstanceNorm2d(num_features),
        nn.ReLU(inplace = True))
    
    self.down_blocks = nn.ModuleList(
      [
       ConBlock(num_features, 2*num_features, kernel_size = 3, stride = 2, padding = 1),
       ConBlock(2*num_features, 4*num_features, kernel_size = 3, stride = 2, padding = 1)
      ])
    self.resblocks = nn.Sequential(
      *[SkipBlock(4*num_features) for _ in range(num_skip)])
    
    self.up_blocks = nn.ModuleList(
      [ConBlock(4*num_features, 2*num_features,down = False, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
      ConBlock(2*num_features, 1*num_features, down = False, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)])
    
    self.final = nn.Conv2d(num_features, img_channels, kernel_size = 7,stride = 1, padding = 3, padding_mode = 'reflect')
  
  def forward(self, input_tensor):
    x = self.initial_block(input_tensor)
    for layer in self.down_blocks:
      x = layer(x)
    x = self.resblocks(x)
    for layer in self.up_blocks:
      x = layer(x)
    return torch.tanh(self.final(x))

################End of Generator class ########### End of Generator class############ End of Generator Class###############################

In [8]:
#Testing the generator if its generate the intended image:
def __test__():
  img_channels = 3
  H = 256
  W = 256
  batch_size = 32
  rand_img = torch.randn(batch_size, img_channels, W, H).to(device = device)
  generator = Generator(img_channels,9).to(device = device)
  gen_output = generator(rand_img)
  return f"gen_output_shape: {gen_output.shape}"

In [9]:
__test__()

'gen_output_shape: torch.Size([32, 3, 256, 256])'

In [10]:

#########Discriminator class ########Discriminator class #############Discriminator class########################## DISCRIMINATOR#############


In [11]:
#We first create a convolution block to re-use later when defining the discriminatoe
class DConBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride):
    super(DConBlock, self).__init__()
    self.convd = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, 1, bias = True, padding_mode = 'reflect'),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace = True))
    
  def forward(self, input_tensor):
    return self.convd(input_tensor)
    

In [12]:
#We then use the above conv-block to create Discriminator network:
class Discriminator(nn.Module):
  def __init__(self, in_channels = 3, num_features = [64, 128, 256, 512]):
    super(Discriminator, self).__init__()
    self.first_block = nn.Sequential(
        nn.Conv2d(in_channels, num_features[0], kernel_size = 4, stride = 2, padding = 1, padding_mode = 'reflect'),
        nn.LeakyReLU(0.2, inplace = True))
    layers = []
    in_channels = num_features[0] #updating for the next conv-layer
    #We start iteration from the 2nd feature
    for feature in num_features[1:]:
      layers.append(DConBlock(in_channels, feature, stride = 1 if feature == num_features[-1] else 2))
      in_channels = feature #update again for the secong layer
    layers.append(nn.Conv2d(in_channels, 1, kernel_size = 4, stride = 1, padding = 1, padding_mode = 'reflect'))
      #Unpacking the layers and create a model object
    self.model = nn.Sequential(*layers)
  
  def forward(self, input_tensor):
    x = self.first_block(input_tensor)
    return torch.sigmoid(self.model(x)) #make sure the output is a binary


In [13]:
#Testing the Discriminator with the random generated noise
def __dtest__():
  H,W = 256, 256
  img_channels = 3
  batch_size = 64
  noise = torch.randn(batch_size, img_channels, W,H).to(device = device)
  discriminator = Discriminator().to(device = device)
  dis_out = discriminator(noise)
  return f"discriminator-output_shape: {dis_out.shape}"

In [14]:
__dtest__()

'discriminator-output_shape: torch.Size([64, 1, 118, 118])'

In [15]:
###################### End Discriminator ################# End Discriminator ########################## End Discriminator ####################

In [16]:
###Defining the checkpoint to save and load the weights ######
def save_chkpt(model, optimizer, file_name = "cycle_gan.path.tar"):
  print(f">>>> saving the checkpoint.........")
  checkpoint =  {'state_dict': model.state_dict(),
                 'optimizers': optimizer.state_dict()}
  torch.save(checkpoint, file_name)

def load_checkpoint(model, optimizer, lr, check_point_file):
  print(f">>>> loading checkpoint..........")
  checkpoint = torch.load(check_point_file, map_location = device)
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])
  for par_group in optimizer.param_groups:
    par_group['lr'] = lr

In [17]:
############ Loading and preprocess Data ############### Loading and preprocess Data ################# Loading Preprocess #######################
##******************We are going to use the dataset of horses and zebras from Kaggle*************

In [18]:
#HYPER PARAMETERS
EPOCHS = 5
batch_size = 1
learning_rate = 1e-5
lambda_ID = 0.0 # for identity loss component
lambda_cycle = 10 # for cycle loss component
num_workers = 4 # for the effectiveness of loader
train_dir = "/content/drive/MyDrive/GANS/archive (4)"
val_dir = "/content/drive/MyDrive/GANS/archive (4)"
gen_h_ckpt = "genh.path.tar" #checkpoint for generator(horse-training)
gen_z_ckpt = "genz.path.tar" #checkpoint for generator(zebra-training)
dis_h_ckpt = "dish.path.tar" #checkpoint for discriminator(horse-training)
dis_z_ckpt = "disz.path.tar" #checkpoint for discriminator(zebra-training)
load_model = True
save_model = True

transforms = A.Compose(
    [
     A.Resize(height = 256, width = 256), A.VerticalFlip(), 
     A.Normalize(mean = [0.5,0.5,0.5], 
     std = [0.5,0.5,0.5], max_pixel_value = 255), ToTensor()
    ],
    additional_targets = {'image0':'image'}
)



In [19]:
###### Loading and preprocess the dataset from the directory################ Loading and Preprocess the dataset ###########################
class HZDataset(Dataset):
  def __init__(self, root_zebra, root_horse, transform = None):
    self.root_zebra = root_zebra
    self.root_horse = root_horse
    self.transform = transform
    self.zebra_imgs = os.listdir(root_zebra)
    self.horse_imgs = os.listdir(root_horse)
    self.dfm_len = max(len(self.zebra_imgs), len(self.horse_imgs)) #since the two sets are not of equal length
    self.zebra_len = len(self.zebra_imgs) #grab number of rows for the zebra dataset
    self.horse_len = len(self.horse_imgs) #grab number of rows for the horse dataset
  
  def __len__(self):
    return self.dfm_len
  
  def __getitem__(self, index):
    '''make sure we do not go out of range'''
    zebra_imgs = self.zebra_imgs[index % self.zebra_len]
    horse_imgs = self.horse_imgs[index % self.horse_len]
    zebra_path = os.path.join(self.root_zebra, zebra_imgs)
    horse_path = os.path.join(self.root_horse, horse_imgs)
    zebra_image = np.array(PIL.Image.open(zebra_path).convert('RGB'))
    horse_image = np.array(PIL.Image.open(horse_path).convert('RGB'))
    if self.transform:
      dfm_aug = self.transform(image = zebra_image, image0 = horse_image)
      zebra_image = dfm_aug['image']
      horse_image = dfm_aug['image0']
    return zebra_image, horse_image



In [20]:
#The training Loop: This is an important and longest function of cycle GAN!!!!!
#Lets get started#####

In [21]:
def train_loop(disc_h, disc_z, gen_h, gen_z, loader, opt_gen, opt_disc,l1, mse, d_scaler, g_scaler):
  h_fakes = 0
  h_reals = 0
  mytqdm = tqdm(loader, leave = True)
  for idx, (zebra, horse) in enumerate(mytqdm):
    zebra = zebra.to(device = device)
    horse = horse.to(device = device)

    #Traning the discriminators (both for horse & zebra)
    with torch.cuda.amp.autocast():
      fake_horse = gen_h(zebra)
      d_h_real = disc_h(horse)
      d_h_fake = disc_h(fake_horse.detach()) #detach to allow this component to be free after backward pass for later use
      h_reals+=d_h_real.mean().item() #
      h_fakes+=d_h_fake.mean().item()
      d_h_real_loss = mse(d_h_real, torch.ones_like(d_h_real)) #discriminator loss due to real horse
      d_h_fake_loss = mse(d_h_fake, torch.zeros_like(d_h_fake))#discriminator loss due to fake horse
      d_h_loss = (d_h_real_loss + d_h_fake_loss)/2
      #discriminator for zebra
      fake_zebra = gen_z(horse)
      d_z_real = disc_z(zebra)
      d_z_fake = disc_z(fake_zebra.detach())
      d_z_real_loss = mse(d_z_real, torch.ones_like(d_z_real))
      d_z_fake_loss = mse(d_z_fake, torch.zeros_like(d_z_fake))
      d_z_loss = (d_z_real_loss + d_z_fake_loss)/2
      #total discriminator loss
      d_loss = (d_h_loss + d_z_loss) / 2
    
    opt_disc.zero_grad()#initialize the discriminator slope parameters to zeros
    d_scaler.scale(d_loss).bakward() # backward pass for the discriminator
    d_scaler.step(opt_disc) #gradient descent 
    d_scaler.update()

    #We now train the generators (both for horse and zebras):
    with torch.cuda.amp.autocast():
      d_h_fake = disc_h(fake_horse)
      d_z_fake = disc_z(fake_zebra)
      #We fool both discriminators for horse and zebra
      loss_gen_horse = mse(d_h_fake, torch.ones(d_h_fake))
      loss_gen_zebra = mse(d_z_fake, torch.ones(d_z_fake))

      #cycle loss: penalty between a real horse and a fake horse (l1-norm)
      cycle_zebra = gen_z(fake_horse)
      cycle_horse = gen_h(fake_zebra)
      cycle_z_loss = l1(zebra, cycle_zebra) # distance btn a fake horse and a real horse
      cycle_h_loss = l1(horse, cycle_horse)

      #Identity loss: penalty between a real horse/zebra and a generated one (l1-norm)
      id_zebra = gen_z(zebra)
      id_horse = gen_h(horse)
      id_zebra_loss = l1(zebra, id_zebra)
      id_horse_loss = l1(horse, id_zebra)

      #total generator loss (adding all of them)
      gen_loss = (loss_gen_horse + 
                  loss_gen_zebra + 
                  cycle_z_loss * lambda_cycle + 
                  cycle_h_loss * lambda_cycle + 
                  id_zebra_loss * lambda_ID + 
                  id_horse_loss * lambda_ID)
      
      opt_gen.zero_grad() #initialize the grad parameters to zero
      g_scaler.scale(gen_loss).backward() #bakward pass for the generator
      g_scaler.step(opt_gen) # gradient descent for the generator
      g_scaler.update()
      if idx % 100 == 0:
        save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
        save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")
        
      mytqdm.set_postfix(h_real = h_real/(idx + 1), h_fake = h_fake/(idx + 1))



In [22]:
##########Finally we train our Model using the following function####################

In [23]:
def __train__():
  global_tic = time.time()
  #instantiate the model classes
  disc_h = Discriminator(in_channels = 3).to(device = device)
  disc_z = Discriminator(in_channels = 3).to(device = device)
  gen_h = Generator(img_channels = 3, num_skip = 9).to(device = device)
  gen_z = Generator(img_channels = 3, num_skip = 9).to(device = device)

  #get the optimizers objects
  opt_gen = optim.Adam(params = list(gen_h.parameters()) + list(gen_z.parameters()),
                       lr = learning_rate, betas = (0.5, 0.999))
  opt_disc = optim.Adam(params = list(disc_h.parameters()) + list(disc_z.parameters()),
                        lr = learning_rate, betas = (0.5, 0.999))
  
  #Get the loss object (norm1 and mse)
  l1 = nn.L1Loss()
  mse = nn.MSELoss()

  #loading the check points if any.
  if not load_model:
    load_checkpoint(gen_h_ckpt, gen_h, opt_gen, learning_rate)
    load_checkpoint(gen_z_ckpt, gen_z, opt_gen, learning_rate)
    load_checkpoint(dis_h_ckpt, disc_h, opt_disc, learning_rate)
    load_checkpoint(dis_z_ckpt, disc_z, opt_disc, learning_rate)
  
  #importing and processing the dataset (using the above defined function)
  train_data = HZDataset(root_zebra = train_dir+"/trainB",root_horse = train_dir+"/trainA")
  val_data = HZDataset(root_zebra = val_dir+"/testB", root_horse = val_dir+"/testA")

  #defining the loaders
  loader = DataLoader(
      dataset = train_data, 
      shuffle = True,
       batch_size = batch_size, 
       num_workers = num_workers, 
       pin_memory = True)
  
  val_loader = DataLoader(
      dataset = val_data,
      shuffle = False,
      batch_size = batch_size,
      pin_memory = True)
  
  #get the gradient scaler object
  g_scaler = torch.cuda.amp.GradScaler() # for generator
  d_scaler = torch.cuda.amp.GradScaler() # for discriminator

  #Actual Training is going down here!!!!!!!!!!!
  for epoch in range(EPOCHS):
    tic = time.time()
    print(f"\n>>>> training starts for the epoch {epoch + 1}\tplease wait..........")
    train_loop(disc_h, disc_z, gen_h, gen_z, loader, opt_gen, opt_disc, l1,mse,d_scaler, g_scaler)
    toc = time.time()
    #saving check-points
    if save_model:
      save_chkpt(gen_h, opt_gen, file_name = gen_h_ckpt)
      save_chkpt(gen_z, opt_gen, file_name = gen_z_ckpt)
      save_chkpt(disc_h, opt_disc, file_name = dis_h_ckpt)
      save_chkpt(disc_z, opt_disc, file_name = dis_z_ckpt)
    print(f"\n>>>> time elapsed at the end of epoch {epoch + 1}:{time_fmt(toc - tic)}")

