# Training conditional GANs (cGANs): Pix2Pix

This notebook implements the Pix2Pix conditional GAN as described in the paper:
- Isola et al. (2018) Image-to-Image Translation with Conditional Adversarial Networks: https://arxiv.org/pdf/1611.07004.pdf

Use the "creating_a_pix2pix_dataset" notebook to create a dataset, or alternatively use an existing dataset by downloading it from one of these links:
-   Standard pix2pix datasets: [http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/)
-   Comic faces: [https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic](https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic)
-   Maps: [https://www.kaggle.com/datasets/alincijov/pix2pix-maps](https://www.kaggle.com/datasets/alincijov/pix2pix-maps)
-   Edges to Rembrandt: [https://www.kaggle.com/datasets/grafstor/rembrandt-pix2pix-dataset](https://www.kaggle.com/datasets/grafstor/rembrandt-pix2pix-dataset)
-   Depth: [https://www.kaggle.com/datasets/greg115/pix2pix-depth](https://www.kaggle.com/datasets/greg115/pix2pix-depth)


Now let's do our usual list of imports:

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

And then proceed with setting up the notebook so it can find and parse our dataset and train the network:

In [None]:
dataset_path = "./datasets/edge2comics/" # Change this for your custom dataset
target_index = 1 # 0 if the target image is on the left half, 1 if it is on the right half

img_channels = 3 # Do not change
img_size = 256   # Do not change 
batch_size = 1 
val_size = 0.05  # Validation set size
# Try changing 'cpu' to 'mps' if using a Mac with M1/M2; this may speed things up
device = 'cuda' if torch.cuda.is_available() else 'cpu' 
if device == 'mps':
    torch.set_default_tensor_type(torch.FloatTensor)

Each training image in a standard pix2pix dataset consists of one image divided into two adjacent halves: a **source** image and a **target** image.
The layout of the source and target may vary from dataset to dataset, so we provide a `target_index` flag that determines on which side the target is (`0` if on the left and `1` if on the right). Set it so that the examples from the dataset appear with the source image on the left.
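The slicing that `target_index` controls can be sketched with NumPy arrays standing in for the image tensors. This is a minimal illustration, assuming a `(C, H, 2W)` layout as produced by `ToTensor()`; the helper `split_pair` is a hypothetical name, not part of the notebook:

```python
import numpy as np

def split_pair(image, target_index):
    """Split a side-by-side pix2pix image of shape (C, H, 2W) into (source, target).

    target_index == 0: target on the left half, source on the right.
    target_index == 1: target on the right half, source on the left.
    """
    half = image.shape[-1] // 2
    left, right = image[..., :half], image[..., half:]
    if target_index == 0:
        return right, left
    return left, right

# Toy "image": 1 channel, 2 rows, 4 columns; the left half is 0s, the right half is 1s
pair = np.concatenate([np.zeros((1, 2, 2)), np.ones((1, 2, 2))], axis=-1)
source, target = split_pair(pair, target_index=1)
print(source.mean(), target.mean())  # 0.0 1.0
```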

The following code also **augments** the dataset by applying random jitter to each input-output pair: both images are upscaled to 286 x 286 and then randomly cropped back to 256 x 256 (with the same crop for both), and both are mirrored horizontally with probability 0.5. According to the original pix2pix paper, this should lead to a more stable model. Finally, the images are normalized to the [-1, 1] range, as required by our GAN-based model.
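The key point of the jitter is that the *same* random crop and flip are applied to both images, so they stay pixel-aligned. A NumPy sketch of this idea (the helpers `resize_nearest` and `paired_jitter` are hypothetical; the nearest-neighbour index mapping is an approximation of torchvision's `NEAREST` resize, not a copy of it):

```python
import numpy as np

def resize_nearest(image, size):
    """Nearest-neighbour resize of a (C, H, W) array to (C, size, size)."""
    _, h, w = image.shape
    rows = (np.arange(size) * h) // size  # source row for each output row
    cols = (np.arange(size) * w) // size  # source column for each output column
    return image[:, rows[:, None], cols[None, :]]

def paired_jitter(source, target, rng, load_size=286, crop_size=256):
    # 1) Upscale both images
    source = resize_nearest(source, load_size)
    target = resize_nearest(target, load_size)
    # 2) Draw ONE crop offset and apply it to both images
    top = rng.integers(0, load_size - crop_size + 1)
    left = rng.integers(0, load_size - crop_size + 1)
    source = source[:, top:top + crop_size, left:left + crop_size]
    target = target[:, top:top + crop_size, left:left + crop_size]
    # 3) Flip both images, or neither, with probability 0.5
    if rng.uniform() < 0.5:
        source, target = source[:, :, ::-1], target[:, :, ::-1]
    # 4) Map pixel values from [0, 1] to [-1, 1]
    return source * 2 - 1, target * 2 - 1

rng = np.random.default_rng(0)
img = rng.uniform(size=(3, 256, 256))
# Identical inputs must stay identical after jitter if crop and flip are shared
src, tgt = paired_jitter(img, img.copy(), rng)
print(src.shape, np.array_equal(src, tgt))  # (3, 256, 256) True
```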

We will organize the dataset in batches of size `1`, as that is generally suggested for pix2pix models. That means that we will update the weights of the model for each image pair separately.
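With `batch_size = 1`, the `BatchNorm2d` layers of the model normalize each image using only its own per-channel statistics during training; the pix2pix paper notes that this has been termed "instance normalization". A small check of that equivalence in PyTorch (tensor shape and seed are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)  # a single image, as with batch_size = 1

bn = nn.BatchNorm2d(3)        # freshly initialized: weight=1, bias=0
inorm = nn.InstanceNorm2d(3)  # no affine parameters by default
bn.train()                    # training mode: normalize with the batch's own statistics

# With one sample, batch statistics over (N, H, W) equal per-instance statistics over (H, W)
same = torch.allclose(bn(x), inorm(x), atol=1e-5)
print(same)  # True
```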

Run the code below and examine the resulting example images. Then set the `target_index` variable to reflect the position of the target image; that is, `target_index=0` if the target image is on the left and `target_index=1` if it is on the right.

NOTE: visualization will break here on M1/M2 Macs if you used `mps` for the device (https://github.com/pytorch/pytorch/issues/84523)



In [None]:
def random_jitter(input_image, target_image):
    # Resizing to 286x286
    resize_transform = transforms.Resize(size=(286, 286), interpolation=transforms.InterpolationMode.NEAREST)
    input_image = resize_transform(input_image)
    target_image = resize_transform(target_image)

    # Random cropping back to 256x256
    i, j, h, w = transforms.RandomCrop.get_params(input_image, output_size=(256, 256))
    input_image = transforms.functional.crop(input_image, i, j, h, w)
    target_image = transforms.functional.crop(target_image, i, j, h, w)

    # Random mirroring
    if np.random.uniform() < 0.5:
        input_image = transforms.functional.hflip(input_image)
        target_image = transforms.functional.hflip(target_image)

    return input_image, target_image

class Pix2PixImageDataset(Dataset):
    def __init__(self, path, target_index):
        super().__init__()
        # Sort the file list so the train/val split below is reproducible
        self.files = sorted(os.path.join(path, f) for f in os.listdir(path)
                            if f.lower().endswith(('.jpg', '.jpeg', '.png')))
        self.target_index = target_index
        
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        path = self.files[idx]
        # Convert to RGB so grayscale or RGBA files also yield 3 channels
        image = Image.open(path).convert('RGB')
        image = torchvision.transforms.ToTensor()(image)  # (C, H, W) in [0, 1]
        w = image.shape[-1] // 2  # width of each half

        # Use the instance attribute rather than the global variable
        if self.target_index == 0:
            input_image = image[:, :, w:]
            target_image = image[:, :, :w]
        else:
            target_image = image[:, :, w:]
            input_image = image[:, :, :w]
        # Jitter (same random crop and flip for both images)
        input_image, target_image = random_jitter(input_image, target_image)
        # Normalize from [0, 1] to [-1, 1]
        input_image = input_image*2 - 1
        target_image = target_image*2 - 1
        return input_image.to(device), target_image.to(device)


train_dataset = Pix2PixImageDataset(dataset_path, target_index)
val_dataset = Pix2PixImageDataset(dataset_path, target_index)

# get the length of the full dataset before the split
num_train = len(train_dataset)

# create a list of indices, one for each element of the full dataset
idx = list(range(num_train))

# perform train / val split for data points
train_indices, val_indices = train_test_split(idx, test_size=val_size, random_state=42)

# override datasets to only be samples for each split
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# get a batch of training images (islice avoids loading the whole dataset just to keep 64 pairs)
from itertools import islice
sample_batch = [torch.cat(image_pair, dim=-1)[0] for image_pair in islice(train_loader, 64)]

# create a grid of images
img_grid = vutils.make_grid(sample_batch, padding=2, normalize=True)

# convert to NumPy and transpose dimensions for matplotlib
img_grid_np = np.transpose(img_grid.detach().cpu().numpy(), (1, 2, 0))

# plot the grid of images
plt.figure(figsize=(15, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(img_grid_np)
plt.show()


## Build the model



The pix2pix model is a conditional generative adversarial network (cGAN). A cGAN
is a type of GAN model used for generating new data samples with specific
attributes or characteristics. In a cGAN, both the generator and discriminator
are *conditioned* on additional information, such as class labels, tags, or
other types of metadata. The generator network typically takes in random noise
as well as the conditional information as input and produces a new data sample
that matches the desired attributes. The discriminator network, on the other
hand, tries to distinguish between generated samples and real samples based on
both their visual appearance and the conditional information. In the case of a
pix2pix model, the network is conditioned on an image, which should be
transformed into an output image. The pix2pix generator has no explicit noise
input: the authors found that it learned to ignore the noise, so randomness is
provided only through dropout.
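In code, conditioning the discriminator on the source image can be as simple as stacking it with the real or generated image along the channel axis, which is what `Discriminator.forward` does with `torch.cat(..., dim=1)`. A NumPy sketch of that shape bookkeeping, assuming batched `(N, C, H, W)` arrays (the helper `discriminator_input` is a hypothetical name):

```python
import numpy as np

def discriminator_input(condition, image):
    """Stack the conditioning image and the (real or generated) image along channels."""
    assert condition.shape == image.shape, "condition and image must have the same shape"
    return np.concatenate([condition, image], axis=1)

x = np.zeros((1, 3, 256, 256))  # conditioning (source) image
y = np.ones((1, 3, 256, 256))   # real target or generated image
print(discriminator_input(x, y).shape)  # (1, 6, 256, 256)
```

This is why the first discriminator layer takes `image_channels*2` input channels.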



### Generator



Unlike a DCGAN, the generator of the pix2pix model is based on the
[U-net](https://arxiv.org/abs/1505.04597) architecture. A U-net is a CNN architecture that is typically used
for image segmentation tasks. Its name derives from the shape of the architecture,
which resembles the letter "U". It consists of two main parts: an *encoder* and
a *decoder*. The encoder consists of a series of convolutional layers (`Conv2d`),
which *decrease* the spatial dimensionality of the input image while increasing its
depth. They lead to a bottleneck layer that holds the most compressed representation
of the input image. The decoder is a "mirror image" of the encoder: a series of
transposed convolutional layers (`ConvTranspose2d`) that gradually *increase* the spatial
dimensionality of the output while decreasing its depth. The output of each encoder layer is
concatenated with the output of the decoder layer of matching spatial size. These
"skip connections" help preserve spatial information and correlations and avoid information
loss during the encoding and decoding process.
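The channel counts listed in the `Generator` comments can be checked with a little arithmetic. This pure-Python sketch (the helpers `conv_out` and `conv_transpose_out` are hypothetical names) assumes kernel size 4, stride 2 and padding 1, as in the `Downsample`/`Upsample` blocks, and tracks the spatial size through the eight encoder layers and the number of channels entering each decoder layer once skip connections are concatenated:

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a Conv2d layer."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a ConvTranspose2d layer (no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

enc_channels = [64, 128, 256, 512, 512, 512, 512, 512]  # encoder output channels
dec_channels = [512, 512, 512, 512, 256, 128, 64]       # decoder output channels before concatenation

# Encoder: each layer halves the spatial size
sizes = [256]
for _ in enc_channels:
    sizes.append(conv_out(sizes[-1]))

# Skip connections: (channels, size) of every encoder output except the bottleneck, deepest first
skips = list(zip(enc_channels[:-1], sizes[1:-1]))[::-1]

# Decoder: each layer doubles the size, then its output is concatenated with a skip
channels, size = enc_channels[-1], sizes[-1]  # bottleneck: 512 channels at 1x1
decoder_in = []
for out_channels, (skip_channels, skip_size) in zip(dec_channels, skips):
    decoder_in.append(channels)
    size = conv_transpose_out(size)
    assert size == skip_size, "decoder and skip sizes must match to concatenate"
    channels = out_channels + skip_channels
decoder_in.append(channels)  # input channels of the final ConvTranspose2d
final_size = conv_transpose_out(size)

print(sizes)       # [256, 128, 64, 32, 16, 8, 4, 2, 1]
print(decoder_in)  # [512, 1024, 1024, 1024, 1024, 512, 256, 128]
print(final_size)  # 256
```

The resulting list matches the "in" channels `CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128` noted in the generator code: after the first decoder layer, every input doubles because of the concatenated skip.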


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels, size=4, stride=2, apply_batchnorm=True):
        # Convolution-BatchNorm-ReLU
        super().__init__() 
        self.conv = nn.Conv2d(in_channels, out_channels, size, stride=stride, padding=1, bias=not apply_batchnorm)
        self.batchnorm = nn.BatchNorm2d(out_channels) if apply_batchnorm else None
        self.leakyrelu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        x = self.conv(x)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.leakyrelu(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, size=4, stride=2, apply_dropout=False):
        # Convolution-BatchNorm-Dropout-ReLU
        super().__init__() 
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, size, stride=stride, padding=1, bias=True)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(0.5) if apply_dropout else None
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.batchnorm(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.relu(x)
        return x

class Generator(nn.Module):
    def __init__(self, img_channels=3):
        super().__init__() 
        # encoder:
        # C64-C128-C256-C512-C512-C512-C512-C512
        # decoder with skip (in/out):
        # CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
        # CD512-CD512 -CD512 -C512 -C256 -C128-C64
        
        self.encoders = nn.ModuleList([
            Downsample(img_channels, 64, apply_batchnorm=False),
            Downsample(64, 128),
            Downsample(128, 256),
            Downsample(256, 512),
            Downsample(512, 512),
            Downsample(512, 512),
            Downsample(512, 512),
            # Innermost layer: 1x1 output, so no batchnorm (one value per channel with batch size 1)
            Downsample(512, 512, apply_batchnorm=False)
        ])

        # After the first layer, input channels double because each decoder output
        # is concatenated with the encoder output of matching size
        self.decoders = nn.ModuleList([
            Upsample(512, 512, apply_dropout=True),
            Upsample(1024, 512, apply_dropout=True),
            Upsample(1024, 512, apply_dropout=True),
            Upsample(1024, 512),
            Upsample(1024, 256),
            Upsample(512, 128),
            Upsample(256, 64),
        ])
        # Final layer maps the 64 + 64 concatenated channels to the output image
        self.last = nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for down in self.encoders:
            x = down(x)
            skips.append(x)
        # Drop the bottleneck output and reverse, so skips[0] is the deepest remaining encoder output
        skips = skips[:-1][::-1]
        for i, up in enumerate(self.decoders):
            x = up(x) 
            x = torch.cat([x, skips[i]], dim=1)  # skip connection
        x = self.last(x)
        return torch.tanh(x)

generator = Generator().to(device)
print(generator)

### Discriminator



The discriminator in the pix2pix model is a convolutional "PatchGAN" classifier.
It tries to classify whether each **image patch** is real or fake. In the
following discriminator, for a 256 x 256 input, each of the 30 x 30 output values
classifies a 70 x 70 patch of the input image.
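Both numbers follow from the layer configuration. A pure-Python sketch using the standard receptive-field recurrence, assuming kernel size 4, padding 1 and strides 2, 2, 2, 1, 1 as in the `Discriminator` cell (the function name `patchgan_geometry` is hypothetical):

```python
def patchgan_geometry(input_size=256, layers=((4, 2), (4, 2), (4, 2), (4, 1), (4, 1)), padding=1):
    """Return (output_size, receptive_field) for a stack of (kernel, stride) conv layers."""
    size, field, jump = input_size, 1, 1
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1  # spatial output size
        field += (kernel - 1) * jump  # each layer widens the field by (k - 1) steps of `jump` input pixels
        jump *= stride                # distance in input pixels between adjacent outputs
    return size, field

print(patchgan_geometry())  # (30, 70)
```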



In [None]:
class Discriminator(nn.Module):
    def __init__(self, image_channels=3):
        # C64-C128-C256-C512
        super(Discriminator, self).__init__()
        self.down1 = Downsample(image_channels*2, 64, 4, apply_batchnorm=False)
        self.down2 = Downsample(64, 128, 4)
        self.down3 = Downsample(128, 256, 4)
        self.down4 = Downsample(256, 512, 4, stride=1)
        self.last = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1)

    def forward(self, inp, tar):
        x = torch.cat([inp, tar], dim=1)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.last(x)
        return torch.sigmoid(x)

discriminator = Discriminator().to(device)
print(discriminator)

### Generate some images before training



Let's generate some images before training to see what the network will output:



In [None]:
def numpy_image(x):
    return np.transpose(x.detach().cpu().numpy(), (1, 2, 0))*0.5 + 0.5

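def generate_images(model, test_input, tar, fname=''):
    # No gradients are needed for visualization. The model is left in training mode,
    # so dropout and batch statistics behave as during training, as the pix2pix paper does at test time
    with torch.no_grad():
        prediction = model(test_input)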
    plt.figure(figsize=(10, 10))

    if tar is not None:
        display_list = [test_input[0], tar[0], prediction[0]]
        title = ['Input Image', 'Ground Truth', 'Predicted Image']
    else:
        display_list = [test_input[0], prediction[0]]
        title = ['Input Image', 'Predicted Image']

    for i in range(len(title)):
        plt.subplot(1, len(title), i+1)
        plt.title(title[i])
        # Getting the pixel values in the [0, 1] range to plot.
        plt.imshow(numpy_image(display_list[i]))
        plt.axis('off')

    if fname:
        plt.savefig(fname)
        plt.close()
    else:
        plt.show()

# Take a single batch instead of materializing the whole loader with list()
example_input, example_target = next(iter(train_loader))
print(example_input.device, example_target.device)
generate_images(generator, example_input, example_target)


## Training the model



Feel free to experiment with modifying the value of `LAMBDA`, the weight of the L1 term in the generator loss described below (if you have time to spare :))



## Training loop



The training loop proceeds by separately optimizing the discriminator and the generator at each iteration. The procedure can be summarized as follows:
- For each example input, we use the generator to generate an output.
- Update the discriminator by:
    -  (1) Feeding it the input image together with the example target image, to classify the ground-truth (example) pair.
    -  (2) Feeding it the input image together with the generated output, to classify the generated pair.
    -  Using these two outputs (1 and 2) to compute the discriminator loss and updating the discriminator parameters to minimize this loss. To update only the discriminator, in step (2) the generated image is "detached" from the Torch computation graph (using the `.detach()` method), so that no gradients are propagated back to the generator, which is effectively "frozen" during this update.
- Update the generator by:
    -  Computing (2) again with the updated discriminator, but this time without detaching the generated image.
    -  Computing the generator loss by combining the adversarial loss on the discriminator's output and the [L1 distance](https://montjoile.medium.com/l0-norm-l1-norm-l2-norm-l-infinity-norm-7a7d18a4f40c) between the generated image and the target, and finally updating the parameters of the generator to minimize this loss.
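The effect of `.detach()` can be seen on two toy linear layers standing in for the generator and the discriminator (hypothetical models; the "losses" are simple means, not the BCE and L1 terms used in the notebook). After the discriminator's backward pass the generator has no gradients; after the generator's pass it does:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 4)  # hypothetical stand-in for the generator
D = nn.Linear(4, 1)  # hypothetical stand-in for the discriminator
x = torch.randn(2, 4)

fake = G(x)  # generated once, as in train_step

# Discriminator update: the fake sample is detached, so backward() stops before G
d_loss = D(fake.detach()).mean()  # toy loss
d_loss.backward()
g_untouched = G.weight.grad is None
print(g_untouched, D.weight.grad is not None)  # True True

# Generator update: no detach, so gradients flow through D back into G
g_loss = -D(fake).mean()  # toy loss
g_loss.backward()
print(G.weight.grad is not None)  # True
```

Note that the generator update also accumulates gradients in `D`; this is harmless here because `disc_optimizer.zero_grad()` clears them at the start of the next discriminator update.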

#### Discriminator loss
The discriminator loss (`disc_loss`) consists of the average of two terms, a `real_loss` and a `fake_loss`:
- The `real_loss` is a [binary cross-entropy loss](https://gombru.github.io/2018/05/23/cross_entropy_loss/) of the (discriminated) real images and an array of ones (since these are the real images). 
- The `fake_loss` is a binary cross-entropy loss of the (discriminated) fake images and an array of zeros (since these are the fake images). 
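A pure-Python sketch of this loss, treating the discriminator's patch output as a flat list of probabilities (the helpers `bce` and `disc_loss` are hypothetical; the notebook uses `nn.BCELoss` on tensors instead):

```python
import math

def bce(predictions, label):
    """Mean binary cross-entropy between probabilities and a constant 0/1 label."""
    eps = 1e-12  # avoid log(0)
    losses = [-(label * math.log(p + eps) + (1 - label) * math.log(1 - p + eps)) for p in predictions]
    return sum(losses) / len(losses)

def disc_loss(real_patch, fake_patch):
    real_loss = bce(real_patch, 1.0)  # real pairs should be classified as 1
    fake_loss = bce(fake_patch, 0.0)  # generated pairs should be classified as 0
    return (real_loss + fake_loss) / 2

# A maximally confused discriminator outputs 0.5 for every patch
print(round(disc_loss([0.5] * 4, [0.5] * 4), 4))  # 0.6931 (= ln 2)
# A confident, correct discriminator gets a loss close to 0
print(round(disc_loss([0.99] * 4, [0.01] * 4), 4))  # 0.0101
```

A discriminator loss hovering around ln 2 ≈ 0.693 therefore means the discriminator cannot tell real from generated pairs.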

#### Generator loss
While GANs learn a loss that adapts to the data, cGANs (such as pix2pix) learn a structured loss that can, in principle, penalize any possible structure that differs between the network output and the target image. As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), the generator loss consists of two terms:

-   Similarly to the discriminator case, the first term `fake_gan_loss` is a sigmoid cross-entropy loss of the (discriminated) generated images and an array of ones, i.e. considering the generated output as a real sample.
-   The second term `dist_loss` quantifies the L1 distance, i.e. the mean absolute error (absolute value of differences), between the generated image and the target image. This allows the generated image to become structurally similar to the target image.
-   These two terms are combined as `fake_gan_loss + LAMBDA * dist_loss`, where `LAMBDA = 100`. This value was decided by the authors of the paper.
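For reference, the objective in the pix2pix paper can be written as follows, where $x$ is the input image and $y$ the target (the paper writes $G(x, z)$, with the noise $z$ supplied by dropout; $z$ is omitted here to match the code):

$$
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]
$$

$$
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x) \rVert_1\right]
$$

$$
G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \qquad \lambda = 100
$$

Following the original GAN paper, the generator is trained to maximize $\log D(x, G(x))$ rather than to minimize $\log\left(1 - D(x, G(x))\right)$; this is why `fake_gan_loss` compares the discriminator output for generated images with an array of ones. The paper also divides the discriminator objective by 2, which corresponds to `disc_loss = (real_loss + fake_loss)/2`.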



In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

epochs = 200
LAMBDA = 100 # Weight of L1 loss in optimization 
save_interval = 1

# Automatically create model path from dataset path; change this if you want to customize the name.
# normpath strips a trailing slash, so this works with or without one
model_path = os.path.join("./models/", os.path.basename(os.path.normpath(dataset_path)))

os.makedirs(model_path, exist_ok=True)

gen_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

BCE_loss = nn.BCELoss()
L1_loss = nn.L1Loss()

def train_step(input_image, target_image):
    # Generate output
    gen_output = generator(input_image) 

    # ---- Update discriminator ----
    disc_optimizer.zero_grad() # Clear gradients

    # Classify real and fake patches.
    # gen_output is detached, so this update does not propagate gradients to the generator
    real_patch = discriminator(input_image, target_image)
    fake_patch = discriminator(input_image, gen_output.detach())

    # Compute loss for real/fake patches
    # log(D(x,y)) + log(1 - D(x,G(x)))
    real_class = torch.ones_like(real_patch)   # ones_like already matches device and dtype
    fake_class = torch.zeros_like(fake_patch)

    real_loss = BCE_loss(real_patch, real_class)
    fake_loss = BCE_loss(fake_patch, fake_class)
    disc_loss = (real_loss + fake_loss)/2

    # Propagate gradients and perform gradient descent step
    disc_loss.backward()
    disc_optimizer.step()

    # ---- Update generator ---- 
    gen_optimizer.zero_grad() # Clear gradients
    # Classify fake samples, now considering generator gradients
    fake_patch = discriminator(input_image, gen_output)
    # Compute loss according to paper:
    # maximize log(D(x,G(x))) and minimize LAMBDA * L1(y,G(x))
    fake_gan_loss = BCE_loss(fake_patch, real_class)
    dist_loss = L1_loss(gen_output, target_image)
    gen_loss = fake_gan_loss + LAMBDA * dist_loss

    # Propagate gradients and perform gradient descent step
    gen_loss.backward()
    gen_optimizer.step()

    return gen_loss, disc_loss

g_losses = []
d_losses = []

torch.autograd.set_detect_anomaly(False)

n = len(train_loader)
for epoch in range(epochs):
    batch_d_losses = []
    batch_g_losses = []
    for i, (input_image, target_image) in enumerate(train_loader):
        gen_loss, disc_loss = train_step(input_image, target_image)
        batch_d_losses.append(disc_loss.item())
        batch_g_losses.append(gen_loss.item())
        sys.stdout.write("\r" + "Epoch %d - image %d of %d " % (epoch+1, i+1, n) + "[gen loss: %.4f | disc loss: %.4f]" % (gen_loss.item(), disc_loss.item()))

    g_losses.append(np.mean(batch_g_losses))
    d_losses.append(np.mean(batch_d_losses))

    plt.figure(figsize=(6,5))
    plt.title('Losses')
    plt.plot(np.array(d_losses) * 40, label='Discriminator (x40)')  # scaled for visibility next to the generator loss
    plt.plot(g_losses, label='Generator')
    plt.legend()
    plt.savefig(os.path.join(model_path, "losses.pdf"))
    plt.close()

    if epoch % save_interval == 0:
        print('\nSaving epoch %d to %s' % (epoch+1, model_path))
        # zip with range(3) takes only the first 3 batches instead of loading the whole dataset
        for j, (example_input, example_target) in zip(range(3), train_loader):
            generate_images(generator, example_input, example_target, fname=os.path.join(model_path, "e%03d_generated_image_%d.png" % (epoch+1, j+1)))
        # Save the model once per epoch (not once per example image)
        generator_scripted = torch.jit.script(generator) 
        generator_scripted.save(os.path.join(model_path, "e%03d_generator.pt" % (epoch+1)))
        # The following saves only model parameters
        #torch.save(generator.state_dict(), os.path.join(model_path, "e%03d_generator.pth" % (epoch+1)))
            

Finally, let's try generating some images using the validation set:

In [None]:

from itertools import islice

for example_input, example_target in islice(val_loader, 3):
    generate_images(generator, example_input, example_target)
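To reuse the generator later, the training loop's commented-out alternative saves only the parameters (`state_dict`). A minimal sketch of that round trip, using a hypothetical `TinyGenerator` in place of the notebook's `Generator` to keep it small, and a temporary directory instead of `model_path`:

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the notebook's Generator."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.tanh(self.conv(x))

torch.manual_seed(0)
trained = TinyGenerator()
x = torch.randn(1, 3, 16, 16)

with tempfile.TemporaryDirectory() as tmp:
    weights_path = os.path.join(tmp, "e001_generator.pth")
    # Save only the parameters
    torch.save(trained.state_dict(), weights_path)

    # Rebuild the architecture from its class, then load the weights;
    # map_location="cpu" lets weights saved on a GPU be loaded on a CPU-only machine
    restored = TinyGenerator()
    restored.load_state_dict(torch.load(weights_path, map_location="cpu"))

with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
print(same_output)  # True
```

The TorchScript archives (`.pt`) written by the training loop can instead be loaded with `torch.jit.load(path)`, which does not require the class definition to be available.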
    
