<a href="https://colab.research.google.com/github/kairavkkp/ML-Tutorials/blob/pix2pix/pix2pix/pix2pix_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q kaggle

In [2]:
from google.colab import files
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"sagarjiyani","key":"03e058eeaa881793908f116aafb89e22"}'}

In [3]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [4]:
! kaggle datasets download -d vikramtiwari/pix2pix-dataset

Downloading pix2pix-dataset.zip to /content
100% 2.39G/2.40G [00:33<00:00, 32.5MB/s]
100% 2.40G/2.40G [00:33<00:00, 76.6MB/s]


In [None]:
! unzip pix2pix-dataset.zip -d maps

In [6]:
import torch
import torch.nn as nn

In [7]:
PATH = '/content/maps/maps/maps'

# Descriminator

In [8]:
class CNNBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride=2):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode='reflect'),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    return self.conv(x)

  
class Descriminator(nn.Module):
  def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
        nn.LeakyReLU(0.2),
    )

    layers=[]
    in_channels = features[0]
    for out_channels in features[1:]:
      layers.append(
          CNNBlock(in_channels, out_channels, stride=1 if out_channels == features[-1] else 2)
      )
      in_channels = out_channels

    layers.append(
        nn.Conv2d(in_channels, 1, kernel_size=4, padding=1, padding_mode='reflect')
    )

    self.model = nn.Sequential(*layers)
  
  def forward(self, x, y):
     x = torch.cat([x, y], dim=1)
     x = self.initial(x)
     x = self.model(x)
     return x

# Generator

In [9]:
class Block(nn.Module):
  def __init__(self, in_channels, features, down=True, activation='relu', use_dropout=False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, features, 4, 2, 1, bias=False, padding_mode='reflect')
        if down
        else nn.ConvTranspose2d(in_channels, features, 4, 2, 1, bias=False),
        nn.BatchNorm2d(features),
        nn.ReLU() if activation=='relu' else nn.LeakyReLU(0.2),
    )
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    x = self.conv(x)
    x = self.dropout(x) if self.use_dropout else x
    return x

class Generator(nn.Module):
  def __init__(self, in_channels=3, features=64):
    super().__init__()
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode='reflect'),
        nn.LeakyReLU(0.2),
    )

    self.down1 = Block(features, features*2, down=True, activation='leaky', use_dropout=False)
    self.down2 = Block(features*2, features*4, down=True, activation='leaky', use_dropout=False)
    self.down3 = Block(features*4, features*8, down=True, activation='leaky', use_dropout=False)
    self.down4 = Block(features*8, features*8, down=True, activation='leaky', use_dropout=False)
    self.down5 = Block(features*8, features*8, down=True, activation='leaky', use_dropout=False)
    self.down6 = Block(features*8, features*8, down=True, activation='leaky', use_dropout=False)
    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode='reflect'),
        nn.ReLU(),
    )

    self.up1 = Block(features*8, features*8, down=False, activation='relu', use_dropout=True)
    self.up2 = Block(features*8*2, features*8, down=False, activation='relu', use_dropout=True)
    self.up3 = Block(features*8*2, features*8, down=False, activation='relu', use_dropout=True)
    self.up4 = Block(features*8*2, features*8, down=False, activation='relu', use_dropout=False)
    self.up5 = Block(features*8*2, features*4, down=False, activation='relu', use_dropout=False)
    self.up6 = Block(features*4*2, features*2, down=False, activation='relu', use_dropout=False)
    self.up7 = Block(features*2*2, features, down=False, activation='relu', use_dropout=False)
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def farward(self, x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)
    bottleneck = self.bottleneck(d7)
    up1 = self.up1(bottelneck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))
    final_up = self.final_up(torch.cat([up7, d1], 1))
    return final_up

# Dataset 

In [None]:
!pip install --upgrade albumentations

In [11]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset

In [12]:
# Config
import albumentations as A
from albumentations.pytorch import ToTensorV2

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMG_SIZE = 256
CHANNELS = 3
L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL =True
SAVE_MODEL =True
CHECKPOINT_DISC = 'disc.pth.tar'
CHECKPOINT_GEN = 'gen.pth.tar'

both_transform = A.Compose(
    [
     A.Resize(width=256, height=256), 
     A.HorizontalFlip(p=0.5)
    ],
    additional_targets = {'image0' : 'image'},
)

transform_only_input = A.Compose(
    [
     A.ColorJitter(p=0.1),
     A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
     ToTensorV2(),
    ]
)

transform_only_mask = A.Compose(
    [
     A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5 ,0.5], max_pixel_value=255.0),
     ToTensorV2(),
    ]
)

In [13]:
# Utils
from torchvision.utils import save_image

def save_some_example(gen, val_loader, epoch, folder):
  x, y = next(iter(val_loader))
  x, y = x.to(DEVICE), y.to(DEVICE)
  gen.eval()
  with torch.no_grad():
    y_fake = gen(x)
    y_fake = y_fake * 0.5 + 0.5
    save_image(y_fake, folder + f"/y_gen_{epoch}.png")
    save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
    if epoch == 1:
      save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
  gen.train()

def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
  print('=> Saving Checkpoint')
  checkpoint = {
      'state_dict' : model.state_dict(),
      'optimizer' : optimizer.state_dict(),
  }
  torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
  print('=> Loading Checkpoint')
  checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
  model.load_state_dict(checpoint['state_dict'])
  optimizer.load_state_dict(checpoint['optimizer'])

  for param_group in optimizer.param_groups:
    param_group['lr'] = lr

In [14]:
class MapDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.list_files = os.listdir(root_dir)
    print(self.list_files)

  def __len__(self):
    return len(self.list_files)

  def __get_item__(self, index):
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    image = np.array(Image.open(img_path))
    inp_image = image[:, :600, :]
    target_image = image[:, 600:, :]

    augmentations = both_transform(image=input_image, image0=target_image)
    input_image, target_image = augmentations['image'], augmentations['image0']

    input_image = transform_only_input(image=input_image)['image']
    target_image = transform_only_input(image=target_image)['image']

    return input_image, target_image

# Training

In [15]:
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [16]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scalar, d_scalar):
  loop = tqdm(loader, leave=True)

  for idx, (x, y) in enumerate(loop):
    x, y = x.to(DEVICE), y.to(DEVICE)

    # Train Discriminator
    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      D_real = disc(x, y)
      D_fake = disc(x, y_fake.detach())
      D_real_loss = bce(D_real, torch.ones_like(D_real))
      D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
      D_loss = (D_real_loss + D_fake_loss) / 2

    disc.zero_gead()
    d_scalar.scale(D_loss).backward()
    d_scalar.step(opt_disc)
    d_scalar.update()

    # Train Generator
    with torch.cuda.amp.autocast():
      D_fake = disc(x, y_fake)
      G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
      L1 = l1(y_fake, y) * L1_LAMBDA
      G_loss = G_fake_loss + l1

    opt_gen.zero_gead()
    g_scalar.scale(G_loss).backward()
    g_scalar.step(opt_gen)
    g_scalar.update()

def main():
  disc = Descriminator(in_channels=3).to(DEVICE)
  gen = Generator(in_channels=3).to(DEVICE)
  opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  BCE = nn.BCEWithLogitsLoss()
  L1_LOSS = nn.L1Loss()

  if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
    
  train_dataset = MapDataset(root_dir=os.path.join(PATH, 'train'))
  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
  g_scalar = torch.cuda.amp.GradScaler()
  d_scalar = torch.cuda.amp.GradScaler()

  val_dataset = MapDataset(root_dir=os.path.join(PATH, 'val'))
  val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

  for epoch in NUM_EPOCHS:
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scalar, d_scalar)

    if SAVE_MODEL and epoch % 5 == 0:
      save_checkpoint(gen, opt_gen, CHECKPOINT_GEN)
      save_checkpoint(disc, opt_disc, CHECKPOINT_DISC)

    save_some_example(gen, val_loader, epoch, folder='evaluation')

if __name__ == '__main__':
  main()

=> Loading Checkpoint


FileNotFoundError: ignored