In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

The **generator** architecture in *Pix2Pix* uses a *U-Net* structure. *U-Net* is an *encoder-decoder* network known for its *skip connections*, which pass encoder feature maps directly to the decoder and help preserve fine details.

In [None]:
class UNetEncoder(nn.Module):
    """Contracting (down-sampling) half of the U-Net generator.

    Eight stride-2 convolutional blocks are applied one after another. Every
    block halves the height and width, so a 256x256 input is reduced to 1x1
    by the last block. The outputs of the first seven blocks are kept as skip
    connections; the eighth block is the bottleneck and has no skip.

    Args:
        in_channels: Number of channels of the input image.
        num_features: Number of filters of the first block. Later blocks use
            2x, 4x and finally 8x this value.
    """

    def __init__(self, in_channels=3, num_features=64):
        super().__init__()

        base = num_features
        deep = num_features * 8

        # One row per block: (input channels, output channels, extra options).
        # As the resolution shrinks the filter count grows, up to 8x the base.
        block_specs = [
            # No batch normalization right after the first convolution, so the
            # statistics of the raw input are not altered.
            (in_channels, base, dict(batch_norm=False)),  # 256x256x3   -> 128x128x64
            (base, base * 2, {}),                         # 128x128x64  -> 64x64x128
            (base * 2, base * 4, {}),                     # 64x64x128   -> 32x32x256
            (base * 4, deep, {}),                         # 32x32x256   -> 16x16x512
            (deep, deep, {}),                             # 16x16x512   -> 8x8x512
            (deep, deep, {}),                             # 8x8x512     -> 4x4x512
            (deep, deep, {}),                             # 4x4x512     -> 2x2x512
            # Bottleneck: a 1x1 map offers too few values for reliable
            # mean/variance estimates, so batch normalization is skipped and
            # dropout is used as regularization instead.
            (deep, deep, dict(dropout=0.5, batch_norm=False)),  # 2x2x512 -> 1x1x512
        ]

        # nn.ModuleList registers the blocks so PyTorch tracks their parameters.
        self.layers = nn.ModuleList(
            [self._create_down_block(c_in, c_out, **options) for c_in, c_out, options in block_specs]
        )

    def _create_down_block(self, input_channels, out_channels, batch_norm=True, dropout=0.0):
        """Build one block: conv -> (batch norm) -> LeakyReLU -> (dropout).

        Disabled stages are replaced by ``nn.Identity`` so every block has the
        same four-module layout.
        """
        # kernel 3 / stride 2 / padding 1 halves the spatial size.
        conv = nn.Conv2d(input_channels, out_channels, kernel_size=3, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()
        # Leaky ReLU keeps a small gradient for negative inputs.
        activation = nn.LeakyReLU(0.2, inplace=True)
        regularizer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        return nn.Sequential(conv, norm, activation, regularizer)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(bottleneck, skips)`` where ``skips`` lists the outputs
            of every block except the last, in the order they were computed.
        """
        *down_blocks, bottleneck = self.layers

        skips = []
        for block in down_blocks:
            x = block(x)
            skips.append(x)  # consumed later by the decoder

        return bottleneck(x), skips
    

# The decoder is the second half of the U-Net architecture. It consists of up-sampling blocks to reconstruct the image
# from the extracted features.
class UNetDecoder(nn.Module):
    """Expanding (up-sampling) half of the U-Net generator.

    Eight transposed-convolution blocks each double the height and width, so
    a 1x1 bottleneck is expanded back to 256x256. Every block except the
    first receives the previous output concatenated (along the channel axis)
    with the encoder feature map of the same resolution.

    Args:
        out_channels: Number of channels of the generated image.
        num_features: Base number of filters; must match the encoder's value
            so that the concatenated channel counts line up.
    """

    def __init__(self, out_channels=3, num_features=64):
        super(UNetDecoder, self).__init__()
        self.layers = nn.ModuleList()

        # The first block receives only the bottleneck; there is no skip connection yet.
        # Dropout is used in the three deepest blocks as regularization.
        self.layers.append(self._create_up_block(num_features*8, num_features*8, dropout=0.5)) # 1x1x512 -> 2x2x512

        # From here on the input is the previous output concatenated with an encoder
        # feature map, which is why the number of input channels is doubled.
        for _ in range(2):
            self.layers.append(self._create_up_block(num_features*16, num_features*8, dropout=0.5)) # 2x2x1024 -> 4x4x512
                                                                                                    # 4x4x1024 -> 8x8x512

        # Higher-resolution blocks: no dropout, and the channel count shrinks back
        # towards num_features.
        self.layers.append(self._create_up_block(num_features*16, num_features*8)) # 8x8x1024 -> 16x16x512
        self.layers.append(self._create_up_block(num_features*16, num_features*4)) # 16x16x1024 -> 32x32x256
        self.layers.append(self._create_up_block(num_features*8, num_features*2)) # 32x32x512 -> 64x64x128
        self.layers.append(self._create_up_block(num_features*4, num_features)) # 64x64x256 -> 128x128x64

        # The last block maps to the desired number of output channels. It has neither
        # batch normalization nor ReLU: a ReLU here would clamp the output to be
        # non-negative, so a following tanh could never produce values below zero.
        self.layers.append(
            self._create_up_block(num_features*2, out_channels, batch_norm=False, activation=False)
        ) # 128x128x128 -> 256x256x3

    def _create_up_block(self, input_channels, out_channels, batch_norm=True, dropout=0.0, activation=True):
        """Build one block: transposed conv -> (batch norm) -> (ReLU) -> (dropout).

        Disabled stages are replaced by ``nn.Identity`` so every block keeps the
        same four-module layout.

        Args:
            input_channels: Channels of the (possibly concatenated) input.
            out_channels: Channels produced by the block.
            batch_norm: Whether to apply ``BatchNorm2d``; the convolution bias
                is dropped when it is, since the normalization has its own shift.
            dropout: Dropout probability; ``0`` disables dropout.
            activation: Whether to apply ReLU after the normalization.
        """
        return nn.Sequential(
            # With kernel 3, stride 2 and padding 1 a transposed convolution yields
            # 2*in - 1 pixels; output_padding=1 adds the missing one so the block
            # exactly doubles the spatial size and mirrors the encoder.
            nn.ConvTranspose2d(
                input_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                bias=not batch_norm,
            ),
            nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
            nn.ReLU(True) if activation else nn.Identity(),
            nn.Dropout(dropout) if dropout else nn.Identity(),
        )

    def forward(self, x, skips):
        """Decode the bottleneck ``x`` using the encoder feature maps ``skips``.

        Args:
            x: Bottleneck tensor produced by the encoder.
            skips: Encoder feature maps ordered from highest to lowest
                resolution; they are reversed here so the deepest one is
                used first.

        Returns:
            The decoded tensor (no output non-linearity is applied here).
        """
        skips = list(reversed(skips))
        x = self.layers[0](x) # The first layer processes only the bottleneck output.

        # The remaining layers consume the previous output concatenated with the
        # matching encoder feature map.
        for i, layer in enumerate(self.layers[1:]):
            x = torch.cat((x, skips[i]), dim=1)
            x = layer(x)
        return x


class Generator(nn.Module):
    """Pix2Pix generator: a U-Net built from an encoder and a decoder.

    The encoder contributes 8 down blocks and the decoder 8 up blocks.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the generated image.
        num_features: Base number of filters shared by encoder and decoder.
    """

    def __init__(self, in_channels, out_channels, num_features=64):
        super().__init__()
        self.encoder = UNetEncoder(in_channels, num_features)
        self.decoder = UNetDecoder(out_channels, num_features)

    def forward(self, x):
        """Translate ``x``: encode, decode with the skip connections, then tanh."""
        bottleneck, skips = self.encoder(x)
        decoded = self.decoder(bottleneck, skips)
        # tanh squashes every value into the (-1, 1) range.
        return torch.tanh(decoded)

The **discriminator** uses a *PatchGAN* architecture to classify image patches.

In [None]:
# In Pix2Pix, the discriminator is a PatchGAN, which means it classifies overlapping patches of the image as real or fake. 
class Discriminator(nn.Module):
    """PatchGAN discriminator for Pix2Pix.

    The network is fully convolutional and outputs a grid of raw scores
    (logits, no sigmoid is applied), one per overlapping patch of the input
    pair. For 256x256 images the grid is 30x30.

    Args:
        in_channels: Number of channels of ONE image. The network itself
            receives two images concatenated along the channel axis, hence
            the first convolution takes ``in_channels * 2`` channels.
        num_features: Number of filters of the first block; later blocks use
            2x, 4x and 8x this value.
    """

    def __init__(self, in_channels, num_features=64):
        super(Discriminator, self).__init__()

        # Three stride-2 blocks reduce the resolution, then two stride-1
        # convolutions produce the per-patch score map.
        self.model = nn.Sequential(
            # No batch normalization on the first block.
            self._create_block(in_channels*2, num_features, stride=2, batch_norm=False), # 256x256x6 -> 128x128x64
            self._create_block(num_features, num_features*2, stride=2), # 128x128x64 -> 64x64x128
            self._create_block(num_features*2, num_features*4, stride=2), # 64x64x128 -> 32x32x256
            self._create_block(num_features*4, num_features*8, stride=1), # 32x32x256 -> 31x31x512
            nn.Conv2d(num_features*8, 1, kernel_size=4, stride=1, padding=1), # 31x31x512 -> 30x30x1
        )

    def _create_block(self, in_channels, out_channels, stride, batch_norm=True):
        """Build one block: conv (kernel 4, padding 1) -> (batch norm) -> LeakyReLU.

        The convolution bias is only dropped when batch normalization follows,
        because the normalization provides its own learnable shift. A block
        without batch normalization keeps its bias.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=not batch_norm),
            nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        Args:
            x: The input (conditioning) image. Seeing it alongside the
                candidate is what makes this a conditional GAN.
            y: The candidate image, either the real target or a generated one.

        Returns:
            A single-channel map of per-patch logits.
        """
        # Both images are concatenated along the channel dimension.
        return self.model(torch.cat([x, y], dim=1))
    

In [None]:
# Hyper-parameters.

# Adam settings shared by both optimizers.
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

# Weight of the L1 term in the generator loss.
lambda_l1 = 100

# Models. Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator(3, 3).to(device)
discriminator = Discriminator(3).to(device)

# Loss functions.

# Adversarial loss: binary cross-entropy computed directly on logits.
adv_criterion = nn.BCEWithLogitsLoss()
# L1 loss between the generated and the target image (preserves structure).
l1_criterion = nn.L1Loss()

# Optimizers: one Adam instance per network.
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

In [None]:
# Pre-processing applied to every image.
transform = transforms.Compose([
    # The generator's eight stride-2 stages reduce a 256x256 image to 1x1, so the
    # input must be exactly 256x256. An integer size would only match the shorter
    # edge and keep the aspect ratio, leaving non-square images with a longer side;
    # an explicit (height, width) forces the expected shape.
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Map [0, 1] to [-1, 1], the range of the generator's tanh output.
    transforms.Normalize((0.5,), (0.5,))
])

# Custom dataset should return pairs (input, target).
# NOTE(review): YourPairedDataset is not defined in this notebook; it looks like a
# placeholder that must be replaced by a real paired dataset class.
train_dataset = YourPairedDataset(transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [None]:
# Training loop: alternate one discriminator update and one generator update per batch.
num_epochs = 200
for epoch in range(num_epochs):
    # real_A is the input image, real_B the target image of the pair.
    for real_A, real_B in train_loader:
        real_A, real_B = real_A.to(device), real_B.to(device)

        # Run the generator once per batch; the result is shared by both updates.
        fake_B = generator(real_A)

        # Train Discriminator
        optimizer_D.zero_grad()

        D_real = discriminator(real_A, real_B)
        # detach() keeps the discriminator loss from back-propagating into the generator.
        D_fake = discriminator(real_A, fake_B.detach())

        # Real pairs should be scored as 1, generated pairs as 0.
        real_loss = adv_criterion(D_real, torch.ones_like(D_real))
        fake_loss = adv_criterion(D_fake, torch.zeros_like(D_fake))
        D_loss = (real_loss + fake_loss) * 0.5
        D_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        # Score the same generated batch again, this time with the updated
        # discriminator and with gradients flowing back into the generator.
        D_fake = discriminator(real_A, fake_B)

        # The generator wants its output scored as real, while staying close
        # to the target in the L1 sense.
        G_adv_loss = adv_criterion(D_fake, torch.ones_like(D_fake))
        G_l1_loss = l1_criterion(fake_B, real_B) * lambda_l1
        G_loss = G_adv_loss + G_l1_loss
        G_loss.backward()
        optimizer_G.step()