<a href="https://colab.research.google.com/github/boppana-tejkiran/Pix-to-Pix-GAN/blob/main/Pix_to_Pix_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; the dataset and evaluation folders
# used later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from tqdm import tqdm

# Discriminator Network

In [3]:
class CNNBlock(nn.Module):
  """Conv(4x4) -> BatchNorm -> LeakyReLU(0.2) unit used to build the discriminator.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    stride: stride of the convolution (default 2).
    padding: padding added on every border before convolving (default 0,
      which reproduces the behaviour of the original block).

  Note:
    ``padding_mode='reflect'`` only takes effect when ``padding > 0``. With the
    default ``padding = 0`` the convolution is unpadded, so each block trims
    the feature map by the kernel extent before striding.
  """
  def __init__(self, in_channels, out_channels, stride = 2, padding = 0):
    super().__init__()
    # The attribute name and layer order are kept so existing checkpoints
    # (keys "conv.0.*", "conv.1.*") still load.
    self.conv = nn.Sequential(
      nn.Conv2d(
          in_channels,
          out_channels,
          kernel_size = 4,
          stride = stride,
          padding = padding,
          bias = False,  # the following BatchNorm supplies the shift term
          padding_mode = 'reflect',
      ),
      nn.BatchNorm2d(out_channels),
      nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    """Apply convolution, batch normalisation and LeakyReLU to ``x``."""
    return self.conv(x)


class Discriminator(nn.Module):
  """Conditional discriminator that scores an (input, candidate) image pair.

  The two images are concatenated along the channel axis and reduced by a
  stack of strided convolutions to a single-channel 2-D map of raw logits
  (no sigmoid is applied here). For 256 x 256 inputs and the default
  ``features`` the recorded output shape is ``[1, 1, 26, 26]``.

  Args:
    in_channels: channels of ONE image; the first conv takes ``in_channels*2``.
    features: output widths of the successive conv stages. The first entry is
      used by the initial (un-normalised) conv, the rest by ``CNNBlock``s.
  """
  def __init__(self, in_channels = 3, features = (64, 128, 256, 512)):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels*2, features[0], kernel_size = 4, stride = 2, padding = 1, padding_mode = 'reflect'),
        nn.LeakyReLU(0.2)
    )

    layers = []
    in_channels = features[0]
    last_index = len(features) - 1
    for index, feature in enumerate(features[1:], start = 1):
      # Only the final stage uses stride 1. Decide by position, not by value,
      # so repeated widths (e.g. [64, 128, 512, 512]) do not turn every
      # matching stage into a stride-1 block.
      layers.append(
          CNNBlock(in_channels, feature, stride = 1 if index == last_index else 2)
      )
      in_channels = feature

    # Project to one logit per spatial location.
    layers.append(
        nn.Conv2d(in_channels, 1, kernel_size = 4, stride =1, padding=1, padding_mode = 'reflect')
    )
    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    """Return the logit map for input ``x`` paired with candidate ``y``."""
    x = torch.cat([x,y], dim =1)
    x = self.initial(x)
    return self.model(x)

In [4]:
def test():
  """Smoke-test the discriminator on random data and print the output shape.

  Builds two random batches of shape (1, 3, 256, 256), runs them through a
  freshly constructed ``Discriminator`` and prints the shape of the result.
  """
  first_image = torch.randn(1, 3, 256, 256)
  second_image = torch.randn(1, 3, 256, 256)
  discriminator = Discriminator()
  print(discriminator(first_image, second_image).shape)

In [5]:
# Run the discriminator shape check defined in the previous cell.
test()

torch.Size([1, 1, 26, 26])


# Generator Network

In [6]:
class Block(nn.Module):
  """Resampling unit of the generator: (de)convolution -> BatchNorm -> activation.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels of the produced feature map.
    down: when True use a stride-2 ``Conv2d`` (reflect padding); otherwise a
      stride-2 ``ConvTranspose2d``.
    act: ``'relu'`` selects ``nn.ReLU``; any other value selects
      ``nn.LeakyReLU(0.2)``.
    use_dropout: when True, ``Dropout(0.5)`` is applied to the block output.
  """
  def __init__(self, in_channels, out_channels, down = True, act = 'relu', use_dropout = False):
    super().__init__()
    # Pick the resampling layer; both variants use kernel 4, stride 2, padding 1.
    if down:
      resample = nn.Conv2d(
          in_channels, out_channels,
          kernel_size = 4, stride = 2, padding = 1,
          bias = False, padding_mode = 'reflect',
      )
    else:
      resample = nn.ConvTranspose2d(
          in_channels, out_channels,
          kernel_size = 4, stride = 2, padding = 1,
          bias = False,
      )
    normalise = nn.BatchNorm2d(out_channels)
    if act == 'relu':
      activation = nn.ReLU()
    else:
      activation = nn.LeakyReLU(0.2)

    self.conv = nn.Sequential(resample, normalise, activation)
    self.use_dropout = use_dropout
    # Always instantiated; only consulted in forward() when use_dropout is set.
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    """Run the block on ``x``, applying dropout afterwards if enabled."""
    out = self.conv(x)
    if self.use_dropout:
      out = self.dropout(out)
    return out


class Generator(nn.Module):
  """U-Net style generator: eight stride-2 encoder stages mirrored by eight
  stride-2 decoder stages, with each decoder stage receiving the matching
  encoder activation through channel concatenation.

  Args:
    in_channels: channels of the input image; the output has the same count.
    features: base channel width; deeper stages use multiples of it.

  The output passes through ``Tanh`` and has the same spatial size as the
  input (a 256 x 256 input yields a 256 x 256 output).
  """
  def __init__(self, in_channels = 3, features = 64):
    super().__init__()
    # Trailing comments give the spatial size after each stage for a
    # 256 x 256 input (every stage halves or doubles height and width).
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode = 'reflect'),
        nn.LeakyReLU(0.2)
    ) # 128

    self.down1 = Block(features, features*2, down = True, act = 'leaky', use_dropout=False) # 64
    self.down2 = Block(features*2, features*4, down = True, act = 'leaky', use_dropout=False) # 32
    self.down3 = Block(features*4, features*8, down = True, act = 'leaky', use_dropout=False) # 16
    self.down4 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout=False) # 8
    self.down5 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout=False) # 4
    self.down6 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout=False) # 2
    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode = "reflect"), nn.ReLU() # 1 x 1
    )

    # Decoder. From up2 onward the input width is doubled because the skip
    # tensor is concatenated onto the previous decoder output.
    # Dropout is enabled only on the first three decoder blocks, as in the
    # pix2pix reference architecture; the remaining blocks run without it.
    self.up1  = Block(features*8, features*8, down = False, act = 'relu', use_dropout = True) # 2
    self.up2  = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = True) # 4
    self.up3  = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = True) # 8
    self.up4  = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = False) # 16
    self.up5  = Block(features*8*2, features*4, down = False, act = 'relu', use_dropout = False) # 32
    self.up6  = Block(features*4*2, features*2, down = False, act = 'relu', use_dropout = False) # 64
    self.up7  = Block(features*2*2, features, down = False, act = 'relu', use_dropout = False) # 128

    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride = 2, padding = 1),
        nn.Tanh(),
    ) # 256

  def forward(self, x):
    """Translate ``x`` to an output image of the same shape."""
    # Encoder: keep every activation for the skip connections.
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)

    bottleneck = self.bottleneck(d7)

    # Decoder: concatenate the encoder activation of equal spatial size
    # along the channel axis before each upsampling step.
    up1 = self.up1(bottleneck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))

    return self.final_up(torch.cat([up7, d1], 1))

In [7]:
def test():
  """Smoke-test the generator on random data and print the output shape.

  Note: this reuses the name ``test`` from the discriminator check earlier in
  the notebook, so running this cell replaces that earlier function.
  """
  sample = torch.randn(1, 3, 256, 256)
  generator = Generator()
  print(generator(sample).shape)

In [8]:
# Run the generator shape check defined in the previous cell.
test()

torch.Size([1, 3, 256, 256])


In [9]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- Runtime device --------------------------------------------------------
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
  DEVICE = "cuda:0"
else:
  DEVICE = "cpu"

# --- Optimisation ----------------------------------------------------------
lr = 0.0002          # learning rate handed to both Adam optimizers
batch_size = 16      # training batch size
num_workers = 2      # DataLoader worker processes for the training set
img_size = 256
# NOTE(review): not referenced anywhere in the visible notebook, and the name
# looks like a typo of "channels_img" - confirm before relying on it.
channels_ing = 100
l1_lambda = 100      # multiplier applied to the L1 term of the generator loss
num_epochs = 500

# --- Checkpointing ---------------------------------------------------------
load_model = False   # restore generator/discriminator state before training
save_model = True    # periodically write checkpoints during training
checkpoint_disc = "disc.pth.tar"
checkpoint_gen = "gen_pth.tar"

# Geometric augmentations shared by the input and the target. Registering
# "image0" as an additional image target makes one call transform both.
both_transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets = {"image0": "image"},
)

# Input-only pipeline: occasional colour jitter, then normalisation with
# mean 0.5 / std 0.5 per channel, then conversion to a tensor.
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], max_pixel_value = 255.0),
        ToTensorV2(),
    ]
)

# Target-only pipeline: the same normalisation and tensor conversion, without
# any colour augmentation.
transform_only_mask = A.Compose(
    [
        A.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5], max_pixel_value = 255.0),
        ToTensorV2(),
    ]
)

# Custom Dataset

In [10]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

class CustomDataset(Dataset):
  """Paired-image dataset where each file stores input and target side by side.

  Every image is cut vertically at column ``split_col``: the columns to the
  left become the generator input, the columns from ``split_col`` onward
  become the target.

  Args:
    dir_path: directory holding the combined images.
    split_col: column index separating input from target (default 600, the
      value the notebook originally hard-coded).

  ``__getitem__`` returns ``(input_image, target_image)`` after the shared
  geometric augmentation and the per-image normalisation pipelines.
  """
  def __init__(self, dir_path, split_col = 600):
    self.root_dir = dir_path
    self.split_col = split_col
    # os.listdir returns entries in arbitrary order and includes
    # sub-directories. Sort for a reproducible index -> file mapping and keep
    # only regular files so stray folders are never opened as images.
    self.list_files = sorted(
        name for name in os.listdir(self.root_dir)
        if os.path.isfile(os.path.join(self.root_dir, name))
    )

  def __len__(self):
    """Number of image files found in the directory."""
    return len(self.list_files)

  def __getitem__(self, index):
    """Load file ``index`` and return the transformed (input, target) pair."""
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    # The context manager releases the file handle right away; convert("RGB")
    # guarantees a 3-channel array so the channel-wise slicing and the
    # 3-value Normalize below are always valid.
    with Image.open(img_path) as img:
      image = np.array(img.convert("RGB"))
    input_image = image[:, :self.split_col, :]
    target_image = image[:, self.split_col:, :]

    # Resize/flip both halves with identical random parameters.
    augmentations = both_transform(image = input_image, image0 = target_image)
    input_image, target_image = augmentations['image'], augmentations['image0']

    input_image = transform_only_input(image = input_image)['image']
    # The target must use the mask pipeline: applying the input pipeline here
    # would colour-jitter the ground truth.
    target_image = transform_only_mask(image = target_image)['image']

    return input_image, target_image

# Save and Load Checkpoints

In [11]:
from torchvision.utils import save_image

def save_some_examples(gen, val_loader, epoch, folder):
  """Save generator output for the first validation batch as image grids.

  Args:
    gen: generator network; switched to eval mode for the forward pass and
      put back into train mode before returning (also on error).
    val_loader: loader yielding ``(input, target)`` batches; only the first
      batch is used.
    epoch: epoch number embedded in the output file names.
    folder: directory receiving the images; created if it does not exist.

  Writes ``y_gen_{epoch}.png`` (generated) and ``input_{epoch}.png`` (input)
  on every call, and ``label_{epoch}.png`` (target) when ``epoch == 1``.
  """
  # Create the target directory up front so a missing folder does not abort
  # the run after a full training epoch.
  os.makedirs(folder, exist_ok = True)

  x, y = next(iter(val_loader))
  x, y = x.to(DEVICE), y.to(DEVICE)
  gen.eval()
  try:
    with torch.no_grad():
      y_fake = gen(x)
      y_fake = y_fake * 0.5 + 0.5 # remove normalization
      save_image(y_fake, folder + f"/y_gen_{epoch}.png")
      # Previously written as "label_{epoch}.png", which mislabeled the input
      # and was overwritten by the real label at epoch 1.
      save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
      if epoch == 1:
        save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
  finally:
    # Always hand the generator back in training mode.
    gen.train()

def save_checkpoint(model, optimizer, filename = "mychcekpoint.pth.tar"):
  """Serialise model and optimizer state to ``filename`` with ``torch.save``.

  The file holds a dict with two entries: ``"state_dict"`` (the model's
  ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).
  """
  print("Saving checkpoint..")
  torch.save(
      {
          "state_dict": model.state_dict(),
          "optimizer": optimizer.state_dict(),
      },
      filename,
  )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
  """Restore ``model`` and ``optimizer`` from a file written by save_checkpoint.

  Args:
    checkpoint_file: path of the checkpoint to read.
    model: module whose parameters are overwritten from ``'state_dict'``.
    optimizer: optimizer whose state is overwritten from ``'optimizer'``.
    lr: learning rate to set on every parameter group after loading.
  """
  print("Loading checkpoint..")
  # map_location remaps the stored tensors onto the device in use.
  saved_state = torch.load(checkpoint_file, map_location= DEVICE)
  model.load_state_dict(saved_state['state_dict'])
  optimizer.load_state_dict(saved_state['optimizer'])

  # The restored optimizer carries the learning rate it was saved with;
  # replace it so training continues with the caller's value.
  for group in optimizer.param_groups:
    group['lr'] = lr

# Training Loop

In [12]:
def train(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scalar, d_scalar):
  """Run one pass over ``loader``, alternating discriminator and generator steps.

  Args:
    disc, gen: discriminator and generator networks.
    loader: iterable of ``(input, target)`` batches.
    opt_disc, opt_gen: optimizers for the two networks.
    l1: reconstruction criterion applied to (generated, target).
    bce: adversarial criterion applied to discriminator outputs.
    g_scalar, d_scalar: gradient scalers used for the generator and the
      discriminator backward/step respectively.
  """
  progress = tqdm(loader, leave = True)

  for x, y in progress:
    x = x.to(DEVICE)
    y = y.to(DEVICE)

    # ---- Discriminator step ----
    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      real_scores = disc(x, y)
      # detach(): the discriminator loss must not back-propagate into the generator.
      fake_scores = disc(x, y_fake.detach())

      real_loss = bce(real_scores, torch.ones_like(real_scores))
      fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
      disc_loss = (real_loss + fake_loss)/2

    disc.zero_grad()
    d_scalar.scale(disc_loss).backward()
    d_scalar.step(opt_disc)
    d_scalar.update()

    # ---- Generator step ----
    with torch.cuda.amp.autocast():
      # Re-score the same fake batch, this time keeping the graph to the generator.
      fooled_scores = disc(x, y_fake)
      adversarial_loss = bce(fooled_scores, torch.ones_like(fooled_scores))
      reconstruction_loss = l1(y_fake, y) * l1_lambda
      gen_loss = adversarial_loss + reconstruction_loss

    opt_gen.zero_grad()
    g_scalar.scale(gen_loss).backward()
    g_scalar.step(opt_gen)
    g_scalar.update()

def main():
  """Build the networks, optimizers and data loaders, then train.

  Runs ``num_epochs`` epochs. When ``save_model`` is set, both networks are
  checkpointed on every fifth epoch (starting with epoch 0). Example images
  from the validation set are written after every epoch.
  """
  # Networks and their optimizers.
  discriminator = Discriminator(in_channels=3).to(DEVICE)
  generator = Generator(in_channels =3).to(DEVICE)
  adam_betas = (0.5, 0.999)
  disc_optimizer = optim.Adam(discriminator.parameters(), lr = lr, betas = adam_betas)
  gen_optimizer = optim.Adam(generator.parameters(), lr = lr, betas = adam_betas)
  adversarial_criterion = nn.BCEWithLogitsLoss()
  reconstruction_criterion = nn.L1Loss()

  # Optionally resume from previously saved state.
  if load_model:
    load_checkpoint(checkpoint_gen, generator, gen_optimizer, lr)
    load_checkpoint(checkpoint_disc, discriminator, disc_optimizer, lr)

  # Data.
  train_dataset = CustomDataset(dir_path="/content/drive/MyDrive/data/maps/train")
  train_loader = DataLoader(
      train_dataset,
      batch_size = batch_size,
      shuffle = True,
      num_workers = num_workers,
  )
  gen_scaler = torch.cuda.amp.GradScaler()
  disc_scaler = torch.cuda.amp.GradScaler()
  val_dataset = CustomDataset(dir_path = "/content/drive/MyDrive/data/maps/val")
  # Unshuffled, so the example images come from the same batch every epoch.
  val_loader = DataLoader(val_dataset, batch_size = 8, shuffle = False)

  for epoch in range(num_epochs):
    train(
        discriminator, generator, train_loader,
        disc_optimizer, gen_optimizer,
        reconstruction_criterion, adversarial_criterion,
        gen_scaler, disc_scaler,
    )

    checkpoint_due = epoch % 5 == 0
    if save_model and checkpoint_due:
      save_checkpoint(generator, gen_optimizer, filename = checkpoint_gen)
      save_checkpoint(discriminator, disc_optimizer, filename = checkpoint_disc)

    save_some_examples(generator, val_loader, epoch, folder = '/content/drive/MyDrive/evaluation')

In [None]:
# Start training. This is long-running: it loops for num_epochs epochs and
# writes checkpoints and example images as it goes.
main()

100%|██████████| 70/70 [00:47<00:00,  1.49it/s]


Saving checkpoint..
Saving checkpoint..


100%|██████████| 70/70 [00:19<00:00,  3.53it/s]
100%|██████████| 70/70 [00:19<00:00,  3.54it/s]
 44%|████▍     | 31/70 [00:09<00:11,  3.46it/s]