# Pix2Pix

#### References
* [Datasets](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)
* [Code](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/pix2pix/pix2pix.py)
* [Another Code](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
* [Paper](https://arxiv.org/pdf/1611.07004.pdf)
* [Confluence page](https://machinereinforcedbook.atlassian.net/wiki/spaces/ML/pages/22052990/Pix2Pix)

In [1]:
import sys
sys.path.insert(0,'..')
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import optim
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from models import utils_unet
from dataset_utils.img_folder_dataset import ImageDataset
from torch.optim.lr_scheduler import StepLR
import numpy as np
from tqdm import tqdm
# Select the compute device: the first CUDA device when CUDA is available,
# otherwise the CPU. Also record how many GPUs torch can see.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
num_gpu = torch.cuda.device_count()

# Report the runtime environment.
print('Device:', device)
print('Number of GPUs Available:', num_gpu)
print('Pytorch version:', torch.__version__)

# Tensorboard
from torch.utils.tensorboard import SummaryWriter
!rm -rf ./runs
writer = SummaryWriter('./runs/train')

# Metaparameters
learning_rate = 0.0002
b1, b2 = 0.5, 0.999  # Adam beta coefficients
img_height = img_width = 256
dataset_name = 'maps'
batch_size = 5
num_epochs = 200
# Weight of the L1 pixel-wise loss between the translated image and the real image
lambda_pixel = 100

# Output shape of the PatchGAN discriminator: one channel, with each spatial
# side divided by 2**4 (the discriminator has four stride-2 conv blocks).
_disc_downsample = 2 ** 4
patch = (1, img_height // _disc_downsample, img_width // _disc_downsample)
print('Patch size:', patch)

# When True, the training loop swaps the A and B images (learns B -> A).
invert_targets = False

Device: cuda:0
Number of GPUs Available: 8
Pytorch version: 1.2.0
Patch size: (1, 16, 16)


#### DataLoaders

In [2]:
# Preprocessing shared by both splits:
# resize to the working resolution, convert to a tensor, then apply
# (x - 0.5) / 0.5 per channel.
transforms_ = [
    transforms.Resize((img_height, img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Both splits read from the same dataset directory. ImageDataset's default
# mode is used for training, and mode="val" for validation.
data_root = "../data/%s" % dataset_name

train_set = ImageDataset(data_root, transforms_=transforms_)
dataloader_train = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
)

val_set = ImageDataset(data_root, transforms_=transforms_, mode="val")
val_dataloader = DataLoader(
    val_set,
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

#### Declare Generator and Discriminator Models

In [3]:
class GeneratorUNet(nn.Module):
    """U-Net generator.

    Eight encoder stages (down1..down8) feed seven decoder stages (up1..up7).
    Each decoder stage also receives the mirrored encoder activation
    (d7 for up1, ..., d1 for up7). A final upsample + conv + tanh head
    produces ``out_channels`` channels.

    Args:
        in_channels: channels of the input (conditioning) image.
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder. The first and innermost stages skip normalization.
        # Stages 4-8 use dropout=0.5.
        self.down1 = utils_unet.UNetDown(in_channels, 64, normalize=False)
        self.down2 = utils_unet.UNetDown(64, 128)
        self.down3 = utils_unet.UNetDown(128, 256)
        self.down4 = utils_unet.UNetDown(256, 512, dropout=0.5)
        self.down5 = utils_unet.UNetDown(512, 512, dropout=0.5)
        self.down6 = utils_unet.UNetDown(512, 512, dropout=0.5)
        self.down7 = utils_unet.UNetDown(512, 512, dropout=0.5)
        self.down8 = utils_unet.UNetDown(512, 512, normalize=False, dropout=0.5)

        # Decoder. up2..up7 (and the final head) take twice the previous
        # stage's output width. Presumably UNetUp concatenates its skip input
        # onto its output (see utils_unet) — TODO confirm.
        self.up1 = utils_unet.UNetUp(512, 512, dropout=0.5)
        self.up2 = utils_unet.UNetUp(1024, 512, dropout=0.5)
        self.up3 = utils_unet.UNetUp(1024, 512, dropout=0.5)
        self.up4 = utils_unet.UNetUp(1024, 512, dropout=0.5)
        self.up5 = utils_unet.UNetUp(1024, 256)
        self.up6 = utils_unet.UNetUp(512, 128)
        self.up7 = utils_unet.UNetUp(256, 64)

        # Output head: 2x upsample, pad left/top by one pixel, 4x4 conv, tanh.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoder = (self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7, self.down8)
        decoder = (self.up1, self.up2, self.up3, self.up4,
                   self.up5, self.up6, self.up7)

        # Run the encoder, keeping every stage's activation (d1..d8).
        skips = []
        h = x
        for stage in encoder:
            h = stage(h)
            skips.append(h)

        # The innermost activation (d8) seeds the decoder. Each decoder stage
        # then consumes the next-shallower encoder activation (d7, ..., d1).
        h = skips.pop()
        for stage in decoder:
            h = stage(h, skips.pop())
        return self.final(h)
    

class Discriminator(nn.Module):
    """PatchGAN discriminator conditioned on a second image.

    The two input images are concatenated along dim 1, so the first conv sees
    ``in_channels * 2`` channels. Next come four stride-2 conv blocks
    (conv -> [InstanceNorm] -> LeakyReLU(0.2); no norm on the first block).
    A left/top zero pad and a bias-free 4x4 conv then produce a
    single-channel score map of spatial size (H // 16, W // 16).

    Args:
        in_channels: channels of each of the two input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        widths = (in_channels * 2, 64, 128, 256, 512)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            # Every downsampling block except the first is instance-normalized.
            if stage > 0:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # The one-pixel left/top pad, combined with padding=1 on the 4x4 conv,
        # keeps the spatial size unchanged through the scoring conv.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        # Stack the image and its condition along dim 1 and score each patch.
        return self.model(torch.cat([img_A, img_B], dim=1))

# Build the generator and trace its graph into TensorBoard with a CPU dummy
# input of shape (1, 3, 256, 256). Only then is it moved to the device.
G = GeneratorUNet()
writer.add_graph(G, torch.ones(1, 3, 256, 256))
G.to(device)  # nn.Module.to moves parameters in place and returns the module

# The discriminator goes straight to the training device.
D = Discriminator().to(device)
print(D)

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  )
)


#### Define Loss Functions

In [4]:
# Adversarial loss: mean squared error between discriminator patch scores
# and the all-ones / all-zeros targets.
criterion_GAN = nn.MSELoss()
# Reconstruction loss: mean absolute (L1) error between generated and target images.
criterion_pixelwise = nn.L1Loss()

#### Define Optimizers

In [5]:
# Generator and discriminator use Adam with identical hyperparameters.
adam_kwargs = dict(lr=learning_rate, betas=(b1, b2))
optimizer_G = optim.Adam(G.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(D.parameters(), **adam_kwargs)

#### Train loop

In [None]:
# Main pix2pix training loop. Each batch gets one generator update, then one
# discriminator update. Losses are averaged per sample over each epoch and
# logged to TensorBoard, together with the last batch's images.
for epoch in tqdm(range(num_epochs)):
    running_loss_G = 0.0
    running_loss_D = 0.0
    running_pixelwise_loss = 0.0
    # Iterate over the training data
    for batch in dataloader_train:
        # Learn A -> B by default; invert_targets swaps the roles (B -> A).
        if invert_targets:
            real_A = batch["B"].to(device)
            real_B = batch["A"].to(device)
        else:
            real_A = batch["A"].to(device)
            real_B = batch["B"].to(device)
        # Size of this batch (the last one may be smaller). Use a separate
        # name so the global `batch_size` metaparameter is not overwritten.
        cur_batch_size = real_B.size(0)

        # Adversarial ground truths, one value per discriminator patch.
        # Soft labels could be substituted here. Tensors are created
        # directly on the device to avoid a host->device copy each batch.
        valid = torch.ones(cur_batch_size, *patch, device=device)
        fake = torch.zeros(cur_batch_size, *patch, device=device)

        # ---- Train generator ----
        optimizer_G.zero_grad()
        fake_B = G(real_A)
        pred_fake = D(fake_B, real_A)
        # GAN loss: the generator wants its output scored as real.
        loss_GAN = criterion_GAN(pred_fake, valid)
        # Pixel-wise L1 loss against the target image.
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        # Total generator loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()

        # ---- Train discriminator ----
        # zero_grad also discards the gradients loss_G.backward() left on D.
        optimizer_D.zero_grad()
        # Real loss
        pred_real = D(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)
        # Fake loss. Detach so this loss does not backpropagate into G.
        pred_fake = D(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)
        # Total discriminator loss
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        # Accumulate batch-size-weighted sums. The epoch average is then a
        # per-sample mean even when the last batch is short.
        running_loss_G += loss_G.item() * cur_batch_size
        running_loss_D += loss_D.item() * cur_batch_size
        running_pixelwise_loss += loss_pixel.item() * cur_batch_size

    # Epoch ends: per-sample mean losses
    num_train_samples = len(dataloader_train.dataset)
    epoch_loss_generator = running_loss_G / num_train_samples
    epoch_loss_discriminator = running_loss_D / num_train_samples
    epoch_loss_pixelwise = running_pixelwise_loss / num_train_samples

    # Send results to tensorboard
    writer.add_scalar('train/loss_generator', epoch_loss_generator, epoch)
    writer.add_scalar('train/loss_discriminator', epoch_loss_discriminator, epoch)
    writer.add_scalar('train/loss_pixelwise', epoch_loss_pixelwise, epoch)

    # Send images to tensorboard. Inputs were normalized with
    # Normalize(0.5, 0.5), and G ends in tanh, so values lie in [-1, 1].
    # add_images expects floats in [0, 1], so map back with x * 0.5 + 0.5.
    writer.add_images('train/real_A', real_A * 0.5 + 0.5, epoch)
    writer.add_images('train/real_B', real_B * 0.5 + 0.5, epoch)
    writer.add_images('train/G', fake_B.detach() * 0.5 + 0.5, epoch)

  3%|▎         | 6/200 [02:52<1:33:02, 28.78s/it]