# Pix2Pix GAN implementation

[Original video](https://youtu.be/SuddDSqGRzg)

[Paper walkthrough video](https://youtu.be/9SGs4Nm0VR4)

[Pix2Pix paper](https://arxiv.org/abs/1611.07004)

[Datasets](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/) for training the neural network

Training on a GPU should take around **1.5 hours**.

## Install `albumentations` image augmentation library

**NOTE:** After installation restart the runtime.

In [None]:
%%script echo "Skip this cell. Use installed albumentations library"

# Google Colab may have an old version of the albumentations library. This cell updates it.
# The %%script magic above hands the cell body to `echo`, which ignores it: the pip command
# below is NOT executed and only the message is printed. Delete the first line to install.
# After installation restart the runtime.
# NOTE(review): this installs the unpinned git HEAD; consider pinning a release version.
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

Skip this cell. Use installed albumentations library


## Get the dataset and import libraries

In [None]:
# Mount your Google Drive to this Colab
from google.colab import drive

gdrive = '/content/gdrive'  # mount point of Google Drive inside the Colab VM
drive.mount(gdrive)

# Directory on Google Drive that holds the saved checkpoints and the maps.tar.gz archive.
# NOTE(review): personal absolute path - adjust it to your own Drive layout.
data_dir = f'{gdrive}/My Drive/Colab Notebooks/2025.07.25_execises/models'

# List the directory; this also confirms that the Drive mount works
!ls -hal "{data_dir}"

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
total 1.2G
-rw------- 1 root root  16M Jul 28 11:26 3d_image_classification.keras
-rw------- 1 root root  32M Jul 30 16:32 dcgan_discriminator_celeb.pth.tar
-rw------- 1 root root  32M Jul 30 16:32 dcgan_discriminator_checkpoint.pth.tar
-rw------- 1 root root  41M Jul 30 16:32 dcgan_generator_celeb.pth.tar
-rw------- 1 root root 145M Jul 30 16:32 dcgan_generator_checkpoint.pth.tar
-rw------- 1 root root 239M Sep  2  2018 maps.tar.gz
-rw------- 1 root root  32M Aug  7 15:41 pix2pix_discriminator.pth.tar
-rw------- 1 root root 623M Aug  7 15:41 pix2pix_generator.pth.tar
-rw------- 1 root root 1.2M Jul 29 08:44 simple_gan_discriminator_checkpoint.pth.tar
-rw------- 1 root root 2.6M Jul 29 08:44 simple_gan_generator_checkpoint.pth.tar


In [None]:
import os
import shutil

# Pretrained model checkpoints (the Configuration cell defines the same names)
CHECKPOINT_DISC = 'pix2pix_discriminator.pth.tar'
CHECKPOINT_GEN = 'pix2pix_generator.pth.tar'
DATASET_ARCHIVE = 'maps.tar.gz'


def copy_from_drive(fname):
    """ Copy `fname` from the Google Drive `data_dir` into the current directory
    (overwriting an existing copy) and print its size. """
    shutil.copy(os.path.join(data_dir, fname), '.')
    print(f'{fname}: {os.path.getsize(fname) / 2**20:.0f} MiB')


# Copy the checkpoints only as a pair - main() loads them only when both files exist.
if os.path.exists(os.path.join(data_dir, CHECKPOINT_DISC)) and \
   os.path.exists(os.path.join(data_dir, CHECKPOINT_GEN)):
    copy_from_drive(CHECKPOINT_DISC)
    copy_from_drive(CHECKPOINT_GEN)

# Copy the dataset archive independently of the checkpoints, so a first run
# without saved checkpoints still avoids re-downloading it.
if os.path.exists(os.path.join(data_dir, DATASET_ARCHIVE)):
    copy_from_drive(DATASET_ARCHIVE)

-rw------- 1 root root 32M Aug  7 17:23 pix2pix_discriminator.pth.tar
-rw------- 1 root root 623M Aug  7 17:23 pix2pix_generator.pth.tar
-rw------- 1 root root 239M Aug  7 17:23 maps.tar.gz
-rw------- 1 root root 32M Aug  7 18:23 pix2pix_discriminator.pth.tar
-rw------- 1 root root 623M Aug  7 18:23 pix2pix_generator.pth.tar
-rw------- 1 root root 239M Aug  7 18:23 maps.tar.gz


In [None]:
import os
import cv2  # OpenCV library
import numpy as np
import multiprocessing
import matplotlib.pyplot as plt
import albumentations as A  # image augmentation library

import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from torchvision.utils import save_image
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm  # library to show the progressbar

In [None]:
import os

# Download the maps dataset archive (~239 MB) unless it is already present,
# e.g. because an earlier cell copied it from Google Drive.
if not os.path.exists('maps.tar.gz'):
    !wget 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz'

--2025-08-08 08:10:23--  http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 250242400 (239M) [application/x-gzip]
Saving to: ‘maps.tar.gz’


2025-08-08 08:11:17 (4.42 MB/s) - ‘maps.tar.gz’ saved [250242400/250242400]



In [None]:
# Extract data
import zipfile
import tarfile

def extract(fname, path='.'):
    """ Extract all members of an archive.

    The format is chosen from the file extension:
    .tar.gz / .tgz, .tar, .tar.bz2 / .tbz, or .zip.

    Args:
        fname: path to the archive file.
        path: destination directory (default: the current working directory).

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    if fname.endswith(('.tar.gz', '.tgz')):
        ref = tarfile.open(fname, mode='r:gz')
    elif fname.endswith('.tar'):
        ref = tarfile.open(fname, mode='r:')
    elif fname.endswith(('.tar.bz2', '.tbz')):
        ref = tarfile.open(fname, mode='r:bz2')
    elif fname.endswith('.zip'):
        ref = zipfile.ZipFile(fname, mode='r')
    else:
        # Without this branch `ref` would be unbound and fail with UnboundLocalError.
        raise ValueError(f'Unsupported archive format: {fname}')

    # `with` closes the archive even if extraction fails part-way.
    with ref:
        if isinstance(ref, tarfile.TarFile) and hasattr(tarfile, 'data_filter'):
            # The archive comes from the network: refuse members that would escape
            # `path` (absolute paths, '..', dangerous links) when Python supports it.
            ref.extractall(path, filter='data')
        else:
            ref.extractall(path)

# Unpack the archive into the current directory; later cells read ./maps/train and ./maps/val.
extract('maps.tar.gz')

## Discriminator model

In [None]:
class CNNBlock(nn.Module):
    """ Discriminator building block: 4x4 Conv2d -> InstanceNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
        padding: reflect padding added on each side before the convolution.
            ``padding_mode='reflect'`` only has an effect when this is > 0.
            The default 0 reproduces the original block, so saved checkpoints and the
            Discriminator output sizes (286x286 -> 30x30) are unchanged.
            With ``padding=1`` the default Discriminator maps a 256x256 input to 30x30.
    """
    def __init__(self, in_channels, out_channels, stride=2, padding=0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
                      padding=padding, padding_mode='reflect', bias=False),

            # Do not normalize across the batch; normalize each sample (instance) separately.
            # InstanceNorm2d gave better results than BatchNorm2d here (no artifacts).
            nn.InstanceNorm2d(out_channels, affine=True),

            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    """ PatchGAN discriminator conditioned on the input image.

    The input image ``x`` and the candidate output ``y`` are concatenated along the
    channel axis and mapped to a grid of raw per-patch logits (no sigmoid; used with
    ``nn.BCEWithLogitsLoss`` during training).

    With the default features the output grid is 30x30 for a 286x286 input and
    26x26 for a 256x256 input.

    Args:
        in_channels: channels of each of the two images (they are concatenated).
        features: output channels of the initial conv followed by each CNNBlock.
            Every CNNBlock uses stride 2 except the last one, which uses stride 1.
    """
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2,
                      padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Pick the stride by position, not by value, so a repeated channel count
            # (e.g. [64, 512, 512]) does not give stride 1 to more than the last block.
            stride = 1 if i == last else 2
            layers.append(CNNBlock(prev_channels, feature, stride=stride))
            prev_channels = feature

        # Final 1-channel convolution: one logit per patch.
        layers.append(nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1,
                                padding=1, padding_mode='reflect'))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        # x - satellite image, y - transformed real or generated fake image
        x = torch.cat([x, y], dim=1)  # concatenate along channels
        x = self.initial(x)
        return self.model(x)


def test():  # test function
    """ Check the Discriminator output grid size for 286x286 and 256x256 inputs. """
    model = Discriminator()
    for size, expected in ((286, 30), (256, 26)):
        x = torch.randn((8, 3, size, size))  # satellite images
        y = torch.randn((8, 3, size, size))  # real or generated maps
        predictions = model(x, y)
        print(predictions.shape)
        # Both sizes are now asserted (previously the 256x256 case was only printed).
        assert predictions.shape == (8, 1, expected, expected)

    print('Test - OK')


test()

torch.Size([8, 1, 30, 30])
torch.Size([8, 1, 26, 26])
Test - OK
torch.Size([8, 1, 30, 30])
torch.Size([8, 1, 26, 26])
Test - OK


## Generator model

In [None]:
class Block(nn.Module):
    """ U-Net generator block: 4x4 convolution with stride 2 -> InstanceNorm2d -> activation,
    optionally followed by Dropout(0.5).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        down: True -> Conv2d (downsampling); False -> ConvTranspose2d (upsampling).
        act: 'relu' selects ReLU; any other value selects LeakyReLU(0.2).
        use_dropout: apply Dropout(0.5) to the block output.
    """
    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()

        shared = dict(kernel_size=4, stride=2, padding=1, bias=False)
        if not down:
            # ConvTranspose2d does not support padding_mode='reflect'.
            sampler = nn.ConvTranspose2d(in_channels, out_channels, **shared)
        else:
            sampler = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **shared)

        # Normalize each sample (instance) separately instead of across the batch;
        # InstanceNorm2d gave better results than BatchNorm2d (no artifacts).
        self.conv = nn.Sequential(
            sampler,
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        out = self.conv(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out


class Generator(nn.Module):
    """ U-Net generator.

    Seven downsampling stages (initial_down, down1-down6), a bottleneck, seven upsampling
    Blocks (up1-up7) that each receive the matching encoder output as a skip connection,
    and a final transposed convolution with tanh, so outputs lie in [-1, 1] like the
    normalized training images. Spatial sizes in the comments are for a 256x256 input.
    """
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        # Encoder
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=4, stride=2, padding=1,
                      padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )  # 128x128

        self.down1 = Block(features,   features*2, down=True, act='leaky', use_dropout=False)  # 64x64
        self.down2 = Block(features*2, features*4, down=True, act='leaky', use_dropout=False)  # 32x32
        self.down3 = Block(features*4, features*8, down=True, act='leaky', use_dropout=False)  # 16x16
        self.down4 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 8x8
        self.down5 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 4x4
        self.down6 = Block(features*8, features*8, down=True, act='leaky', use_dropout=False)  # 2x2
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, 4, 1, 1, padding_mode='reflect'),
            nn.ReLU(),
        )  # 1x1

        # Decoder: from up2 on, the input is the previous output concatenated with a skip
        # connection, hence the doubled input channel counts.
        self.up1 = Block(features*8,   features*8, down=False, act='relu', use_dropout=True)   # 2x2
        self.up2 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=True)   # 4x4
        self.up3 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=True)   # 8x8
        self.up4 = Block(features*8*2, features*8, down=False, act='relu', use_dropout=False)  # 16x16
        self.up5 = Block(features*8*2, features*4, down=False, act='relu', use_dropout=False)  # 32x32
        self.up6 = Block(features*4*2, features*2, down=False, act='relu', use_dropout=False)  # 64x64
        self.up7 = Block(features*2*2, features,   down=False, act='relu', use_dropout=False)  # 128x128
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # use hyperbolic tangent
        )  # 256x256

    def forward(self, x):
        # Encoder: keep the output of every stage for the skip connections.
        skips = [self.initial_down(x)]
        for down in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(down(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))

        # Decoder: concatenate (on channels) with the encoder output of the same
        # resolution, deepest first, then upsample.
        ups = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for up, skip in zip(ups, reversed(skips[1:])):
            out = up(torch.cat([out, skip], dim=1))

        return self.final_up(torch.cat([out, skips[0]], dim=1))


def test():
    """ Check that the Generator maps 8 RGB 256x256 images to 8 RGB 256x256 images. """
    batch = torch.randn((8, 3, 256, 256))
    output = Generator()(batch)
    print(output.shape)
    assert output.shape == (8, 3, 256, 256)
    print('Test - OK')

test()

torch.Size([8, 3, 256, 256])
Test - OK
torch.Size([8, 3, 256, 256])
Test - OK


## Configuration

In [None]:
# Hardware
USE_CUDA = torch.cuda.is_available()  # check if CUDA is available
DEVICE = 'cuda' if USE_CUDA else 'cpu'
NUM_WORKERS = multiprocessing.cpu_count()  # DataLoader workers: one per CPU core

# Training hyperparameters
LEARNING_RATE = 2e-4  # Adam learning rate for both networks
BATCH_SIZE = 160
NUM_EPOCHS = 200  # main() runs epochs 0..NUM_EPOCHS, i.e. NUM_EPOCHS + 1 passes
L1_LAMBDA = 100  # weight of the L1 reconstruction term in the generator loss

# Images
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# Checkpoints and outputs
LOAD_MODEL = True  # resume from the checkpoints below when both files exist
SAVE_MODEL = True  # save both checkpoints every 10 epochs
CHECKPOINT_DISC = 'pix2pix_discriminator.pth.tar'
CHECKPOINT_GEN = 'pix2pix_generator.pth.tar'
EVALUATION = 'pix2pix_evaluation_images'  # folder for the per-epoch generated samples


# Step 1, applied to each half separately: resize to IMAGE_SIZE x IMAGE_SIZE.
transform1 = A.Compose([
    A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
])

# Step 2, only when MapDataset(augment=True): one random D4 symmetry of the square
# (rotations and reflections). Input ('image') and target ('image0') are passed in
# one call so both receive the same transformation.
augment_both_images = A.Compose(
    [A.D4(p=1.0)],
    additional_targets={'image0': 'image'},
)

# Step 3, applied to each half separately: occasional color jitter, scale pixel values
# from [0, 255] to [-1, 1] ((x / 255 - 0.5) / 0.5), and convert the array to a tensor.
transform2 = A.Compose([
    A.ColorJitter(p=0.1),  # NOTE: this ColorJitter seems important
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])

## Dataset

In [None]:
class MapDataset(Dataset):
    """ Paired satellite-image / map dataset.

    Each file in ``root`` stores the input (satellite) image on the left half and the
    target map on the right half. Each item is (input_tensor, target_tensor), produced by
    transform1, then (if ``augment``) augment_both_images, then transform2.

    Args:
        root: directory containing the combined image files.
        augment: if True, apply the shared random D4 augmentation to both halves.
    """
    def __init__(self, root, augment):
        super().__init__()
        self.root = root
        self.augment = augment
        # Sorted so the item order (e.g. of the non-shuffled val loader) is deterministic.
        self.list_files = sorted(os.listdir(self.root))

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        filepath = os.path.join(self.root, self.list_files[index])
        # The context manager closes the file handle right away. convert('RGB') makes sure
        # both halves have 3 channels even for grayscale/RGBA files.
        with Image.open(filepath) as img:
            image = np.array(img.convert('RGB'))

        # Split at the middle column. This replaces a hard-coded 600, which assumed
        # 1200-pixel-wide files.
        half = image.shape[1] // 2
        input_image = image[:, :half, :]
        target_image = image[:, half:, :]

        input_image = transform1(image=input_image)['image']
        target_image = transform1(image=target_image)['image']

        if self.augment:
            augmentations = augment_both_images(image=input_image, image0=target_image)
            input_image, target_image = augmentations['image'], augmentations['image0']

        input_image = transform2(image=input_image)['image']
        target_image = transform2(image=target_image)['image']

        return input_image, target_image

## Helper functions

In [None]:
def save_checkpoint(model, optimizer, filename):
    """ Write the model and optimizer state dicts to `filename` with torch.save.

    The saved object is {'state_dict': ..., 'optimizer': ...}, the layout that
    load_checkpoint reads back.
    """
    print('=> Saving checkpoint')
    torch.save(
        {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """ Restore model and optimizer state from a file written by save_checkpoint.

    Tensors are mapped onto DEVICE. Afterwards every optimizer param group gets
    learning rate `lr`, replacing the one stored in the checkpoint.
    """
    print('=> Loading checkpoint')
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])

    # Override the learning rate that came with the saved optimizer state
    for group in optimizer.param_groups:
        group.update(lr=lr)


def prepare(image):
    """ Turn a normalized image tensor into a NumPy array ready for plt.imshow.

    Values are mapped from [-1, 1] to [0, 1], the data is copied to CPU as NumPy,
    size-1 axes are dropped (e.g. (1, 3, 256, 256) -> (3, 256, 256)), and the layout
    is changed from (channel, height, width) to (height, width, channel).
    """
    rescaled = (image * 0.5 + 0.5).cpu().numpy()
    return rescaled.squeeze().transpose(1, 2, 0)


def add_image(image, title, cmap="viridis"):
    """ Draw `image` with the given title into the current subplot and hide its axes.

    `cmap` is forwarded to plt.imshow (default 'viridis').
    """
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis("off")  # no tick labels or frame


def show_one_example(gen, image_with_map):
    """ Plot the input image, the ground-truth map and the generated map side by side.

    Args:
        gen: generator model. It is run in eval mode (dropout disabled), as in
            save_example, and its previous train/eval mode is restored afterwards.
        image_with_map: (x, y) pair of batched tensors with batch size 1, as yielded by
            the DataLoaders in this notebook (x - satellite image, y - ground truth map).
    """
    x, y = image_with_map
    x, y = x.to(DEVICE), y.to(DEVICE)  # move tensors to the CPU or GPU device

    was_training = gen.training
    gen.eval()  # without this, dropout is active when called from main() during training
    try:
        with torch.no_grad():  # inference only: do not calculate gradients
            y_fake = gen(x)  # generate fake map
    finally:
        gen.train(was_training)  # restore the caller's mode

    panels = (
        (prepare(x), "Original image"),
        (prepare(y), "Ground truth map"),
        (prepare(y_fake), "Generated map"),
    )
    fig = plt.figure(figsize=(12, 4))  # 12x4 inches, one row of three cells
    for position, (image, title) in enumerate(panels, start=1):
        fig.add_subplot(1, 3, position)
        add_image(image, title)

    plt.show()



def show_examples(root, augment, num=25):
    """ Load the saved generator checkpoint and plot up to `num` examples from `root`.

    Args:
        root: dataset directory passed to MapDataset.
        augment: whether MapDataset applies the D4 augmentation.
        num: maximum number of examples to plot (default 25).
    """
    generator = Generator().to(DEVICE)
    generator.eval()  # evaluation mode: dropout disabled
    # load_checkpoint also restores an optimizer state, so a throwaway Adam is required.
    optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    load_checkpoint(CHECKPOINT_GEN, generator, optimizer, LEARNING_RATE)

    loader = DataLoader(
        MapDataset(root=root, augment=augment),
        batch_size=1, shuffle=False, num_workers=NUM_WORKERS,
    )
    for count, image_with_map in enumerate(loader):
        if count >= num:
            break
        show_one_example(generator, image_with_map)


def save_example(gen, test_loader, epoch, folder):
    """ Save the generator output for the first batch of `test_loader` as an image.

    Writes ``<folder>/y_gen_<epoch>.jpg``. At epoch 0 it also writes the input
    (``_input_.jpg``) and the ground truth (``_label_.jpg``) for reference. Tensors are
    mapped from [-1, 1] to [0, 1] before saving. The generator runs in eval mode, and its
    previous train/eval mode is restored afterwards.
    """
    x, y = next(iter(test_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            save_image(y_fake * 0.5 + 0.5, os.path.join(folder, f'y_gen_{epoch}.jpg'))
            if epoch == 0:
                save_image(x * 0.5 + 0.5, os.path.join(folder, '_input_.jpg'))
                save_image(y * 0.5 + 0.5, os.path.join(folder, '_label_.jpg'))
    finally:
        gen.train(was_training)

## Train

In [None]:
def train(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce_loss, d_scaler, g_scaler):
    """ Run one epoch of pix2pix training over `loader`.

    For every batch (x - satellite image, y - target map):
      1. Discriminator step: average of BCE(D(x, y), 1) and BCE(D(x, G(x)), 0).
      2. Generator step: BCE(D(x, G(x)), 1) + L1_LAMBDA * L1(G(x), y).

    Each step uses its own GradScaler (d_scaler for D, g_scaler for G), so the
    mixed-precision loss scale of one network is not affected by the other.
    Autocast is only enabled on CUDA, consistent with the GradScalers built in main().
    """
    loop = tqdm(loader, leave=False)

    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)

        # Train Discriminator
        with torch.autocast(DEVICE, dtype=torch.float16, enabled=USE_CUDA):
            y_fake = gen(x)
            d_real = disc(x, y)
            d_fake = disc(x, y_fake.detach())  # detach: no generator gradients in the D step
            d_real_loss = bce_loss(d_real, torch.ones_like(d_real))
            d_fake_loss = bce_loss(d_fake, torch.zeros_like(d_fake))
            d_loss = (d_real_loss + d_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generator
        with torch.autocast(DEVICE, dtype=torch.float16, enabled=USE_CUDA):
            d_fake = disc(x, y_fake)
            g_fake_loss = bce_loss(d_fake, torch.ones_like(d_fake))
            l1 = l1_loss(y_fake, y) * L1_LAMBDA
            g_loss = g_fake_loss + l1

        # g_loss.backward() also writes gradients into the discriminator's parameters;
        # opt_disc.zero_grad() at the next discriminator step clears them.
        opt_gen.zero_grad()
        g_scaler.scale(g_loss).backward()  # generator uses its own scaler, not d_scaler
        g_scaler.step(opt_gen)
        g_scaler.update()


def main():
    """ Train the pix2pix discriminator and generator on ./maps.

    Resumes from CHECKPOINT_DISC / CHECKPOINT_GEN when LOAD_MODEL is set and both files
    exist, then runs epochs 0..NUM_EPOCHS (NUM_EPOCHS + 1 passes). Checkpoints are saved
    on epochs divisible by 10 (if SAVE_MODEL), a validation example is plotted on epochs
    divisible by 5, and a generated validation image is written to EVALUATION every epoch.
    """
    disc = Discriminator().to(DEVICE)
    gen = Generator().to(DEVICE)
    adam_kwargs = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), **adam_kwargs)
    opt_gen = optim.Adam(gen.parameters(), **adam_kwargs)
    bce_loss = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    if LOAD_MODEL and os.path.exists(CHECKPOINT_DISC) and os.path.exists(CHECKPOINT_GEN):
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)

    train_loader = DataLoader(MapDataset(root='./maps/train', augment=True),
                              batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    test_loader = DataLoader(MapDataset(root='./maps/val', augment=False),
                             batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

    # One GradScaler per network for mixed-precision training (disabled without CUDA)
    d_scaler = torch.amp.GradScaler(enabled=USE_CUDA)
    g_scaler = torch.amp.GradScaler(enabled=USE_CUDA)

    disc.train()
    gen.train()

    for epoch in range(NUM_EPOCHS + 1):
        train(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, bce_loss, d_scaler, g_scaler)

        if SAVE_MODEL and epoch % 10 == 0:
            print(f'epoch: {epoch}')
            save_checkpoint(disc, opt_disc, CHECKPOINT_DISC)
            save_checkpoint(gen, opt_gen, CHECKPOINT_GEN)

        if epoch % 5 == 0:
            show_one_example(gen, next(iter(test_loader)))

        save_example(gen, test_loader, epoch, folder=EVALUATION)


if __name__ == '__main__':
    main()

## Show results

### Show validation images without augmentations

In [None]:
# Plot up to 25 validation examples (input, ground truth, generated map) with the saved generator
show_examples('./maps/val', augment=False)

Output hidden; open in https://colab.research.google.com to view.

### Show train images with augmentations

In [None]:
# Plot up to 25 training examples with the random D4 augmentation applied to both images
show_examples('./maps/train', augment=True)

Output hidden; open in https://colab.research.google.com to view.

## Save results to Google Drive

In [None]:
# 'pictures' folder next to data_dir on Google Drive
# NOTE(review): assumes this folder already exists - otherwise the first cp below
# creates a file named 'pictures' instead of copying into a folder.
pictures_dir = os.path.dirname(data_dir) + '/pictures'
print(pictures_dir)

# Zip the per-epoch evaluation images written by save_example
!zip -qr "{EVALUATION}.zip"  "{EVALUATION}/"

# Copy the zip archive and the trained checkpoints back to Google Drive
!cp -rf "{EVALUATION}.zip"   "{pictures_dir}"
!cp -rf "{CHECKPOINT_DISC}"  "{data_dir}"
!cp -rf "{CHECKPOINT_GEN}"   "{data_dir}"

/content/gdrive/My Drive/Colab Notebooks/2025.07.25_execises/pictures
