# Pix2pix: Image to Image Translation

Dataset preparation & training



Notebook originally adapted from [this tutorial](https://www.tensorflow.org/tutorials/generative/pix2pix), ported to PyTorch, integrating changes from [DMLAP](https://github.com/IriniKlz/DMLAP-2024/tree/main/python/06-GANs-pix2pix).

See also [the original repo](https://github.com/phillipi/pix2pix) – in Lua – and the 'official' [PyTorch port](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/tree/master) as well as the [paper](https://arxiv.org/abs/1611.07004).

In [None]:
import os
import sys
import glob
import time
import pathlib

import numpy as np
import matplotlib.pyplot as plt

import cv2
from skimage import filters
from skimage import feature
from skimage import transform

import torch
from torch import nn
import torch.nn.functional as F

import torchvision as tv
from torchvision.transforms import v2
import torchvision.transforms.functional as TF

# Get cpu, gpu or mps device for training
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# ------------------------------------------------------------------------------
# for dataset processing below
from multiprocessing import Pool
from multiprocessing import Lock
from multiprocessing import cpu_count
from multiprocessing import set_start_method

# make sure we are forking (hacky, for M1/2/3)
# force=True avoids a RuntimeError when this cell is run a second time
if device == "mps":
    set_start_method("fork", force=True)

print(f"Multiprocessing: found {cpu_count()} CPUs")
# ------------------------------------------------------------------------------

In [None]:
BATCH_SIZE = 1
IMG_SIZE = 256
IMG_CHANNELS = 3
IMG_EXTENSION = ".png" # check directory and change to e.g. jpg instead!

# fixed directory structure ----------------------------------------------------
DATASETS_DIR = pathlib.Path("datasets")
DATASETS_DIR.mkdir(exist_ok=True)

MODELS_DIR = pathlib.Path("models")
MODELS_DIR.mkdir(exist_ok=True)

GENERATED_DIR = pathlib.Path("generated")
GENERATED_DIR.mkdir(exist_ok=True)
# ------------------------------------------------------------------------------

## Datasets: Download and Preparation

### Choose a Pix2pix Dataset

**Original datasets**

- cityscapes.tar.gz
- edges2handbags.tar.gz
- edges2shoes.tar.gz
- facades.tar.gz
- maps.tar.gz
- night2day.tar.gz

Available [here](http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/). To get a download URL, right-click on a link and choose 'Copy link'. Beware of the sizes: some of them are pretty big! For Colab, see **Web download** below.

**Kaggle**

- [Comic faces](https://www.kaggle.com/datasets/defileroff/comic-faces-paired-synthetic): `defileroff/comic-faces-paired-synthetic`
- [Rembrandt](https://www.kaggle.com/datasets/grafstor/rembrandt-pix2pix-dataset): `grafstor/rembrandt-pix2pix-dataset`
- [Depth](https://www.kaggle.com/datasets/greg115/pix2pix-depth): `greg115/pix2pix-depth`
- [Maps](https://www.kaggle.com/datasets/alincijov/pix2pix-maps): `alincijov/pix2pix-maps` (looks identical to `maps` above)

There might be more! Also, if you want to use the transformation pipeline below, any image dataset would work!


### Directory structure

Our goal is to create one of the following two directory structures.

Either the images directly inside our pix2pix dataset directory:

```bash
datasets
└── pix2pix_dataset_name
    ├── 1.jpg
    ...
    └── 1198.jpg
```

Or the same, but with `train`, `test` or `val` sub-directories:

```bash
datasets
└── pix2pix_dataset_name
    ├── train
    │   ├── 1.jpg
    │   ├── 2.jpg
    ...
    ├── test
    ...
```
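If you are unsure which of the two layouts a downloaded dataset follows, a small check can tell them apart. This is a sketch: `describe_layout` is a hypothetical helper (not used elsewhere in the notebook), and it only looks for the `train`/`test`/`val` names mentioned above.

```python
import pathlib

SPLIT_NAMES = ("train", "test", "val")  # sub-directory names used above

def describe_layout(ds_dir):
    """Return ('split', [names]) if ds_dir has train/test/val sub-directories,
    otherwise ('flat', []) (images directly inside ds_dir)."""
    ds_dir = pathlib.Path(ds_dir)
    splits = [d.name for d in ds_dir.iterdir()
              if d.is_dir() and d.name in SPLIT_NAMES]
    return ("split", sorted(splits)) if splits else ("flat", [])
```

For instance, `describe_layout(DATASETS_DIR / "pix2pix_facades")` would tell you whether to use an existing split or create one yourself.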

**Working with transformations**


However, if we apply transformations to the downloaded dataset (like turning source images black and white), we might want to keep the original download under another name, so that our final, ready dataset is still `pix2pix_dataset_name`, for instance:

```bash
datasets
├── pix2pix_dataset_name
└── pix2pix_dataset_name_orig
```

**Datasets with source and target directories**

For the comic faces dataset, for instance, the sources (photos) and targets (comics) are in two separate subfolders, which we will have to take into account:

```bash
datasets
├── pix2pix_dataset_name
└── pix2pix_dataset_name_orig
    ├── source
    │   ├── 100.jpg
    ...
    ├── target
    ...
```
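With two separate folders, each source has to be matched with its target. Assuming a source and its target share the same file name (as `100.jpg` suggests above, but check your dataset), it is worth verifying the pairing before combining. A minimal sketch with a hypothetical `check_pairs` helper:

```python
import pathlib

def check_pairs(source_dir, target_dir):
    """Compare file names in two folders, assuming pairs share a name.
    Returns (matched, only_in_source, only_in_target), each sorted."""
    src = {p.name for p in pathlib.Path(source_dir).iterdir() if p.is_file()}
    tgt = {p.name for p in pathlib.Path(target_dir).iterdir() if p.is_file()}
    return sorted(src & tgt), sorted(src - tgt), sorted(tgt - src)
```

If the last two lists are not empty, some images have no partner and would shift the pairing when both folders are read in order.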


### Note on Colab workflow



I recommend working in the same way as with the DCGAN:

1. Either find a reliable URL you can download an archive from, and use `wget` directly in the notebook, followed by:

    ```bash
    unzip -q downloaded-file.zip -d <target-dir>
    ```
    or:
    ```bash
    tar xzf downloaded-file.tar.gz -C <target-dir>
    ```

2. Or download the dataset locally first, upload it to your drive, change the sharing settings of the zip file so that anyone with the link can access it, copy its `ID` and use `gdown` with that `ID`.

To save generated images and model checkpoints to your drive, it's best to connect to your drive *after* downloading the dataset.

```python
# reminder: Colab code to mount your drive
if 'google.colab' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive/')
    # Then you can access its contents, with '/content/drive/MyDrive' as the default location of your Google Drive
```

After that, you can `cp` and `mv` files to and from your drive as if they were on the local machine.

### Web download


Perhaps the easiest way to get the dataset is to use Unix tools:

```python
# download to datasets dir
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz -P {DATASETS_DIR}

# uncompress (still in datasets dir)
!tar xzf {DATASETS_DIR}/facades.tar.gz -C {DATASETS_DIR}

# rename the uncompressed dir to the right format: pix2pix_{name}
# remember: our goal is to have either just the pictures in
# `pix2pix_{MODEL_NAME}`, or folders like `train`/`test`/`val`
!mv {DATASETS_DIR}/facades {DATASETS_DIR}/pix2pix_facades

# clean-up, !rm -r for directories
# !rm {DATASETS_DIR}/facades.tar.gz
```
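If shell commands are not available (for instance when running the notebook locally on Windows), the same download and extraction can be done in plain Python. A minimal sketch: `download_and_extract` is a hypothetical helper, and `shutil.unpack_archive` infers the archive format (`.zip`, `.tar.gz`, ...) from the file extension.

```python
import pathlib
import shutil
import urllib.request

def download_and_extract(url, dest_dir):
    """Download an archive into dest_dir and unpack it there (no shell needed)."""
    dest_dir = pathlib.Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    archive = dest_dir / url.rsplit("/", 1)[-1]   # e.g. facades.tar.gz
    if not archive.exists():                       # skip re-downloading
        urllib.request.urlretrieve(url, archive)
    shutil.unpack_archive(archive, dest_dir)       # format inferred from extension
    return archive
```

For example, `download_and_extract("http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz", DATASETS_DIR)`, followed by the same renaming step as above.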


In [None]:
# web download code here


Below, our **global variables** will be:

```python
MODEL_NAME = "facades" # change accordingly (e.g. facades, bw2comics, etc.)

# source and processed datasets
PIX2PIX_DS_DIR_ORIG = DATASETS_DIR / f"pix2pix_{MODEL_NAME}_orig"
# no need to create the dir now, it will be created later

PIX2PIX_DS_DIR = DATASETS_DIR / f"pix2pix_{MODEL_NAME}"
# no need to create it now, it will be created later

PIX2PIX_DIR = MODELS_DIR / f"pix2pix_{MODEL_NAME}"
PIX2PIX_DIR.mkdir(exist_ok=True)

PIX2PIX_GEN_DIR = GENERATED_DIR / f"pix2pix_{MODEL_NAME}_images"
PIX2PIX_GEN_DIR.mkdir(exist_ok=True)

# Check when plotting or open an image manually!
# 0: [target, source]
# 1: [source, target]
TARGET_INDEX = 0
```

You can paste this into the `# global variables code here` cell further below.

### Kaggle download

To download a dataset from Kaggle, follow these steps:

1. Create a Kaggle account (free)
2. Go to your [settings](https://www.kaggle.com/settings), create a new API token (a `kaggle.json` file), download it and upload it to the main Colab folder (if you work locally, you can simply download the dataset manually from the website). You can also save this file to your drive and copy it from there, like so:

```python
from google.colab import drive
drive.mount('/content/drive')
# copy file from drive
!cp "drive/MyDrive/IS53024B-Artificial-Intelligence/kaggle.json" .
```
Then copy it to the Kaggle configuration directory (`/root/.config/kaggle` on Colab):
```python
if os.path.isfile("kaggle.json"):
    print("Found token file `kaggle.json`")
    !mkdir -p /root/.config/kaggle
    !cp kaggle.json /root/.config/kaggle
    !chmod 600 /root/.config/kaggle/kaggle.json
else:
    print("Could not find token file `kaggle.json`, please see the steps above!")
```
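The same steps can also be written in plain Python, which may be handy outside Colab. This is a sketch: `install_kaggle_token` is a hypothetical helper, and it targets the same `~/.config/kaggle` location as the shell commands above (on Colab, the home directory is `/root`).

```python
import os
import pathlib
import shutil

def install_kaggle_token(token_file="kaggle.json", config_dir=None):
    """Copy kaggle.json into the Kaggle config dir, readable by the owner only."""
    if config_dir is None:
        config_dir = pathlib.Path.home() / ".config" / "kaggle"
    token_file = pathlib.Path(token_file)
    if not token_file.is_file():
        print("Could not find token file `kaggle.json`, please see the steps above!")
        return None
    config_dir = pathlib.Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)   # like `mkdir -p`
    dest = config_dir / "kaggle.json"
    shutil.copy(token_file, dest)
    os.chmod(dest, 0o600)                           # like `chmod 600`
    return dest
```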

In [None]:
# kaggle.json copy

Then download:
```python
# also: "grafstor/rembrandt-pix2pix-dataset"
ID = pathlib.Path("defileroff/comic-faces-paired-synthetic") # on Kaggle: username/dataset
!kaggle datasets download {ID} -p {DATASETS_DIR}
```

In [None]:
# kaggle download

Unzip:
```bash
!unzip -q {DATASETS_DIR}/comic-faces-paired-synthetic.zip -d {DATASETS_DIR}
```

In [None]:
# unzip

Clean-up:
```bash
# Then retrieve the name & structure, in this case:
# face2comics_v1.0.0_by_Sxela/face2comics_v1.0.0_by_Sxela/
# I want to simplify this, so:
!mv {DATASETS_DIR}/face2comics_v1.0.0_by_Sxela/face2comics_v1.0.0_by_Sxela {DATASETS_DIR}/pix2pix_face2comics_orig

# # clean-up
# !rm -r "{DATASETS_DIR}/face2comics_v1.0.0_by_Sxela"
# !rm "{DATASETS_DIR}/comic-faces-paired-synthetic.zip"
```

In [None]:
# directory structure and clean-up

Finally, change your **global variables** accordingly:

```python
MODEL_NAME = "face2comics" # change accordingly (e.g. edge2comics, etc.)

# source and processed datasets
PIX2PIX_DS_DIR_ORIG = DATASETS_DIR / f"pix2pix_{MODEL_NAME}_orig"
# PIX2PIX_DS_DIR_ORIG.mkdir(exist_ok=True) # already done, uncomment if you need

PIX2PIX_DS_DIR = DATASETS_DIR / f"pix2pix_{MODEL_NAME}"
# PIX2PIX_DS_DIR.mkdir(exist_ok=True) # created later, uncomment if you need it now

PIX2PIX_DIR = MODELS_DIR / f"pix2pix_{MODEL_NAME}"
PIX2PIX_DIR.mkdir(exist_ok=True)

PIX2PIX_GEN_DIR = GENERATED_DIR / f"pix2pix_{MODEL_NAME}_images"
PIX2PIX_GEN_DIR.mkdir(exist_ok=True)

# 0: [target, source]
# 1: [source, target]
TARGET_INDEX = 0
```

You can paste this into the `# global variables code here` cell below.

### Setting up our global variables

Each training image in a standard pix2pix dataset consists of one image divided into two adjacent **source** and **target** images.
The layout of the source and target may vary from dataset to dataset, so we provide a `TARGET_INDEX` flag that determines on which side the target is (`0` if on the left and `1` if on the right). Set it so that the examples from the dataset appear with the source image on the left.
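To double-check which value to use, the convention can be written as a tiny function. `split_pair` is a hypothetical helper that mirrors the slicing done in `Pix2PixImageDataset` further down. It assumes images are NumPy arrays of shape `(height, 2 * width, channels)`.

```python
import numpy as np

def split_pair(combined, target_index):
    """Split a side-by-side pix2pix image (H, 2W, C) into (source, target).

    target_index 0: [target, source]  (target on the left)
    target_index 1: [source, target]  (target on the right)
    """
    w = combined.shape[1] // 2
    left, right = combined[:, :w], combined[:, w:]
    return (right, left) if target_index == 0 else (left, right)
```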

In [None]:
# global variables code here

## Working with transformations

**This entire section can be ignored if our dataset is already `pix2pix` formatted (with AB or BA images).**

### More Variables


This section is meant for two scenarios:

1. We have two separate directories with A and B images, which will be:
    - `SOURCE_DIR`
    - `TARGET_DIR`

2. We only want to transform one set of images, and that will be:
    - `TARGET_DIR`
    - (`SOURCE_DIR = None`)

In [None]:
# True for a standard pix2pix dataset (source/target in one image), False otherwise
IS_INPUT_PIX_TO_PIX = False # set to True or False!

# Only used if we have separate source and target directories,
# e.g. 'datasets/pix2pix_face2comics_orig/face/'
# Set SOURCE_DIR = None if you only have target images to transform
SOURCE_DIR = PIX2PIX_DS_DIR_ORIG / "face"
TARGET_DIR = PIX2PIX_DS_DIR_ORIG / "comics"

### Load the images to process

In [None]:
def load_image(path):
    w, h = (IMG_SIZE, IMG_SIZE)
    if IS_INPUT_PIX_TO_PIX: # in case we are already loading a pix2pix image
        w = 2 * IMG_SIZE
    img = cv2.imread(str(path))
    if img is None: # e.g. a non-image file in the folder
        raise ValueError(f"Could not read image '{path}'")
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_NEAREST)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # OpenCV is BGR by default
    # If we are loading a pix2pix dataset, only extract the target
    if IS_INPUT_PIX_TO_PIX:
        half = w // 2
        if TARGET_INDEX == 0: # [target, source]
            img = img[:, :half, :]
        else:                 # [source, target]
            img = img[:, half:, :]

    return img

def load_images_in_path(path, shuffle=False, limit=0):
    # sorted, so that source and target files are paired up in the same order
    fnames = sorted(glob.glob(os.path.join(path, "*")))
    print(f"Found {len(fnames)} files in '{path}'")
    if shuffle:
        np.random.shuffle(fnames)
    if limit > 0:
        fnames = fnames[:limit]
        print(f"Limiting number of files to {limit}")
    for f in fnames:
        yield load_image(f) # See this: https://realpython.com/introduction-to-python-generators/

In [None]:
if SOURCE_DIR:
    source_loader = iter(load_images_in_path(SOURCE_DIR))
    plt.axis("off")
    plt.imshow(next(source_loader))
    plt.show()

target_loader = iter(load_images_in_path(TARGET_DIR)) # create an iterator
plt.axis("off")
plt.imshow(next(target_loader))
plt.show()

### The transformation pipeline



#### Transformation utils

Note: if you plan to use `apply_face_landmarks`, you will need to install mediapipe first!

```bash
!pip install mediapipe
```

In [None]:
def apply_bw_cv2(img):
    """Turn an image black and white using OpenCV"""
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.merge([grey_img, grey_img, grey_img]) # Force three channels for shape compat, thanks ChatGPT!

def apply_canny_cv2(img, thresh1=160, thresh2=250, invert=False):
    """Apply the OpenCV Canny edge detector to an image"""
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(grey_img, thresh1, thresh2)
    if invert:
        edges = cv2.bitwise_not(edges)
    return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

def apply_canny_skimage(img, sigma=1.5, invert=False):
    """Apply the Scikit-Image Canny edge detector to an image"""
    grey_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    edges = (feature.canny(grey_img, sigma=sigma)*255).astype(np.uint8)
    if invert:
        edges = cv2.bitwise_not(edges)
    return cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

# IDEA: It might be possible to use other Mediapipe functionalities, like:
#       - segmentation: https://developers.google.com/mediapipe/solutions/vision/image_segmenter
#       - pose landmarks: https://developers.google.com/mediapipe/solutions/vision/pose_landmarker
#       to write other transformation functions... (For both of those, you then need to find datasets!)

def apply_face_landmarks(
        img, stroke_weight=2, overlay=False, overlay_color='black'
    ):
    """Apply the MediaPipe face landmarker to an image"""
    import urllib.request # `import urllib` alone does not load `urllib.request`
    import mediapipe as mp # requires pip install mediapipe
    from mediapipe import solutions
    from mediapipe.framework.formats import landmark_pb2
    from mediapipe.tasks.python import vision
    from mediapipe.tasks.python.core import base_options as base_options_module

    # Path to the model file
    model_path = PIX2PIX_DIR / "face_landmarker.task"

    # Check if the model file exists, if not, download it
    if not model_path.exists():
        url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
        print(f"Downloading model from {url}...")
        urllib.request.urlretrieve(url, model_path)
        print(f"Model downloaded and saved as {model_path}")

    # Initialize MediaPipe FaceLandmarker
    base_options = base_options_module.BaseOptions(model_asset_path=str(model_path))
    options = vision.FaceLandmarkerOptions(
        base_options=base_options,
        output_face_blendshapes=True,
        output_facial_transformation_matrixes=True,
        num_faces=1
    )
    detector = vision.FaceLandmarker.create_from_options(options)

    # Function to draw landmarks on the image
    def draw_landmarks_on_image(rgb_image, detection_result,
                                overlay=False, overlay_color='black'):
        face_landmarks_list = detection_result.face_landmarks

        if overlay:
            annotated_image = np.copy(rgb_image)
        else:
            if overlay_color == "white":
                annotated_image = np.ones_like(rgb_image) * 255
            else: # default to black
                annotated_image = np.zeros_like(rgb_image)

        # Loop through the detected faces to visualize.
        for idx in range(len(face_landmarks_list)):
            face_landmarks = face_landmarks_list[idx]

            # Draw the face landmarks.
            face_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
            face_landmarks_proto.landmark.extend([
                landmark_pb2.NormalizedLandmark(
                    x=landmark.x, y=landmark.y, z=landmark.z
                ) for landmark in face_landmarks
            ])

            solutions.drawing_utils.draw_landmarks(
                image=annotated_image,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_TESSELATION,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_tesselation_style())
            solutions.drawing_utils.draw_landmarks(
                image=annotated_image,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_CONTOURS,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_contours_style())
            solutions.drawing_utils.draw_landmarks(
                image=annotated_image,
                landmark_list=face_landmarks_proto,
                connections=solutions.face_mesh.FACEMESH_IRISES,
                landmark_drawing_spec=None,
                connection_drawing_spec=solutions.drawing_styles.get_default_face_mesh_iris_connections_style())

        return annotated_image

    # The image is already RGB (see `load_image`): wrap it as a MediaPipe Image
    rgb_frame = np.ascontiguousarray(img)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)

    # Detect face landmarks in the frame
    detection_result = detector.detect(mp_image)

    # Annotate frame with detected landmarks (only if a face was found)
    if detection_result.face_landmarks:
        annotated_img = draw_landmarks_on_image(
            img, detection_result, overlay=overlay, overlay_color=overlay_color
        )
    else: # you could also imagine returning a purely black/white image
        annotated_img = img

    return annotated_img

# IDEA: Use Canvas (or openCV) to remove parts of the image (draw a rectangle/circle somewhere)
#       so that the net learns to complete an image with a hole in it (inpainting)
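A minimal sketch of the inpainting idea above, using plain NumPy slicing instead of Canvas or OpenCV drawing. `apply_random_hole` is hypothetical and not part of the notebook. Because it takes and returns an RGB image like the other `apply_*` functions, it could be assigned to `image_transformation` below: the source is then the image with a hole, and the target is the intact original.

```python
import numpy as np

def apply_random_hole(img, hole_frac=0.25, fill=0, rng=None):
    """Return a copy of img with a random rectangle filled with `fill`.

    hole_frac: hole height/width as a fraction of the image height/width.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    hh, hw = max(1, int(h * hole_frac)), max(1, int(w * hole_frac))
    y = rng.integers(0, h - hh + 1)   # top-left corner, hole stays inside image
    x = rng.integers(0, w - hw + 1)
    out = img.copy()                  # leave the original (the target) intact
    out[y:y + hh, x:x + hw] = fill
    return out
```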

#### Select our transformation

In [None]:
# `load_source` ignores the image it receives and simply returns the next
# image from the given iterator (e.g. over SOURCE_DIR).
def load_source(img, img_source_iterator):
    return next(img_source_iterator)

# set to True to apply one of the transforms above
# False will load the next image from SOURCE_DIR as is
USE_TRANSFORM = False

if USE_TRANSFORM:
    # Set this to the transformation you want to apply. If you are only working with
    # a single folder of images that you want to process, set image_transformation
    # to one of the filtering operations above,
    # e.g.
    image_transformation = apply_bw_cv2
    # image_transformation = apply_canny_cv2
    # image_transformation = apply_canny_skimage
    # image_transformation = apply_face_landmarks
else:
    # If you are working with existing sources, you can use
    # a Python partial to assign a fixed argument to load_source
    # and use it exactly like the other image transformations
    # (See: https://docs.python.org/3/library/functools.html#functools.partial)

    from functools import partial
    image_transformation = partial(
        load_source, img_source_iterator=iter(load_images_in_path(SOURCE_DIR))
    )

# Note that the partial can be used to specify arguments for `apply_face_landmarks`
# from functools import partial
# image_transformation = partial(apply_face_landmarks, overlay_color="white")

img = next(load_images_in_path(TARGET_DIR))

plt.figure()
plt.subplot(1, 2, 1)
plt.axis("off")
plt.imshow(img)

plt.subplot(1, 2, 2)
plt.axis("off")
plt.imshow(image_transformation(img))
plt.show()
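The `partial` trick can be seen in isolation with strings standing in for images. This is a toy illustration, not notebook code. Every call advances the shared iterator, so the preview above already consumes one image from it.

```python
from functools import partial

def load_source(img, img_source_iterator):
    return next(img_source_iterator)  # ignores `img`, just takes the next source

sources = iter(["face_1", "face_2", "face_3"])  # stand-in for load_images_in_path(...)
image_transformation = partial(load_source, img_source_iterator=sources)

first = image_transformation("comic_1")   # the target argument is ignored
second = image_transformation("comic_2")
```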

#### Create the dataset!

In [None]:
MAX_IMAGES = 0     # max images to process, 0: disabled
PRINT_EVERY = 1000 # printing progress

img_loader = iter(load_images_in_path(TARGET_DIR, limit=MAX_IMAGES))

def combine_images(source, target):
    if TARGET_INDEX == 1:
        combined = np.hstack([source, target])
    else:
        combined = np.hstack([target, source])
    return combined

def process_source_target(i, source, target, out_dir):
    """Combine source and target and save them"""
    # IDEA: you could also apply additional processing to either
    # your source or target here
    # source = image_transformation(source)
    # target = image_transformation(target)
    combined = combine_images(source, target)
    combined = cv2.cvtColor(combined, cv2.COLOR_RGB2BGR)
    # multiprocessing: only one process at a time
    with lock:
        cv2.imwrite(str(out_dir / f"{i+1}.png"), combined) # OpenCV expects a str path

def process_target(i, target, out_dir):
    """Create the source from the target, combine them and save the result"""
    source = image_transformation(target)
    combined = combine_images(source, target)
    combined = cv2.cvtColor(combined, cv2.COLOR_RGB2BGR)
    # multiprocessing: only one process at a time
    with lock:
        cv2.imwrite(str(out_dir / f"{i+1}.png"), combined)

# ------------------------------------------------------------------------------
# multiprocessing!

t = time.time()

l = Lock()
def init(l):
    """Initialize a lock to avoid race conditions"""
    global lock
    lock = l

if SOURCE_DIR:
    # CASE 1: we have source images, we want to combine them
    PIX2PIX_DS_DIR.mkdir(exist_ok=True)

    # create a source image iterator
    source_loader = iter(load_images_in_path(SOURCE_DIR, limit=MAX_IMAGES))

    with Pool(processes=cpu_count(), initializer=init, initargs=(l,)) as p:
        for i, (source, target) in enumerate(zip(source_loader, img_loader)):
            if (i+1) % PRINT_EVERY == 0:
                print(f"Processing source #{i+1}")
            p.apply_async(
                process_source_target, args=(i, source, target, PIX2PIX_DS_DIR)
            )
        p.close()
        p.join()

else:
    # CASE 2: we only have targets, we create the sources and combine them
    PIX2PIX_DS_DIR.mkdir(exist_ok=True)

    with Pool(processes=cpu_count(), initializer=init, initargs=(l,)) as p:
        for i, target in enumerate(img_loader):
            if (i+1) % PRINT_EVERY == 0:
                print(f"Processing target #{i+1}")
            p.apply_async(process_target, args=(i, target, PIX2PIX_DS_DIR))
        p.close()
        p.join()

print("Total time:", time.time() - t)
# ------------------------------------------------------------------------------

# verify the number of files in our directory (wc counts words, lines or bytes)
!ls {PIX2PIX_DS_DIR} | wc -l
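One caveat with the cell above: `apply_async` does not report exceptions raised inside a worker unless the returned `AsyncResult` is queried with `.get()`. A failing image can therefore silently result in fewer files than expected. The sketch below shows one way to surface such errors. It uses `ThreadPool`, which shares `Pool`'s `apply_async` API. The function `risky` is a hypothetical stand-in for `process_target`.

```python
from multiprocessing.pool import ThreadPool

def risky(i):
    """Hypothetical worker: fails on one input, like a corrupt image would."""
    if i == 2:
        raise ValueError(f"could not process image #{i}")
    return i * 10

errors, results = [], []
with ThreadPool(processes=4) as p:
    jobs = [(i, p.apply_async(risky, args=(i,))) for i in range(4)]
    for i, job in jobs:
        try:
            results.append(job.get())   # re-raises the worker's exception here
        except Exception as e:
            errors.append((i, e))
```

With the real `Pool`, the same pattern applies: keep the `AsyncResult` objects and call `.get()` on them before leaving the `with` block.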

## Dataset ready: load and preprocess it

The following code also **augments** the dataset by applying random jitter (upscaling to 286x286, then randomly cropping back to 256x256) and random horizontal mirroring to the input/output pairs, as recommended in the original pix2pix paper. Finally, the images are normalized to the [-1, 1] range, which matches the `tanh` output of our generator.

We will organize the dataset in batches of size `1`, as that is generally suggested for pix2pix models. That means that we will update the weights of the model for each image pair separately.
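The key detail is that input and target must receive identical random parameters. Otherwise the pair no longer lines up pixel by pixel. Below, roughly the same steps as `random_jitter` are sketched in NumPy on `(H, W, C)` arrays. The helpers `paired_jitter`, `normalize` and `denormalize` are hypothetical and not used by the notebook itself.

```python
import numpy as np

def paired_jitter(inp, tar, load_size=286, crop_size=256, rng=None):
    """Apply the *same* resize, crop and flip to an (H, W, C) input/target pair."""
    rng = np.random.default_rng() if rng is None else rng
    # nearest-neighbour upscaling via index arrays
    rows = np.arange(load_size) * inp.shape[0] // load_size
    cols = np.arange(load_size) * inp.shape[1] // load_size
    inp, tar = inp[rows][:, cols], tar[rows][:, cols]
    # one random crop offset, shared by both images
    y = rng.integers(0, load_size - crop_size + 1)
    x = rng.integers(0, load_size - crop_size + 1)
    inp = inp[y:y + crop_size, x:x + crop_size]
    tar = tar[y:y + crop_size, x:x + crop_size]
    if rng.uniform() < 0.5:   # one coin flip for both images
        inp, tar = inp[:, ::-1], tar[:, ::-1]
    return inp, tar

def normalize(x):      # [0, 1] -> [-1, 1], matching the generator's tanh range
    return x * 2 - 1

def denormalize(x):    # [-1, 1] -> [0, 1], as done by `numpy_image` for plotting
    return x * 0.5 + 0.5
```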



In [None]:
def random_jitter(input_image, target_image):
    # Resizing to 286x286
    resize_transform = v2.Resize(size=(286, 286), interpolation=v2.InterpolationMode.NEAREST)
    input_image = resize_transform(input_image)
    target_image = resize_transform(target_image)

    # Random cropping back to 256x256
    i, j, h, w = v2.RandomCrop.get_params(input_image, output_size=(256, 256))
    input_image = TF.crop(input_image, i, j, h, w)
    target_image = TF.crop(target_image, i, j, h, w)

    # Random mirroring
    if np.random.uniform() < 0.5:
        input_image = TF.hflip(input_image)
        target_image = TF.hflip(target_image)

    return input_image, target_image

class Pix2PixImageDataset(torch.utils.data.Dataset):
    def __init__(self, path, target_index):
        super(Pix2PixImageDataset, self).__init__()
        self.files = [os.path.join(path, f) for f in os.listdir(path) if '.jpg' in f or '.png' in f]
        self.target_index = target_index

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        image = cv2.imread(path)
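        if image is None:
            # unreadable file: return blank images with the usual per-sample shape
            # (no batch dimension, the DataLoader adds it)
            shape = (IMG_CHANNELS, IMG_SIZE, IMG_SIZE)
            return torch.zeros(shape).to(device), torch.zeros(shape).to(device)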
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # OpenCV is BGR by default
        image = v2.ToImage()(image)
        image = v2.ToDtype(torch.float32, scale=True)(image) # from [0,255] to [0,1]
        w = image.shape[-1]
        w = w // 2

        if self.target_index == 0:          # 0: [target, source]
            input_image = image[:, :, w:]
            target_image = image[:, :, :w]
        else:                               # 1: [source, target]
            input_image = image[:, :, :w]
            target_image = image[:, :, w:]
        # Jitter
        input_image, target_image = random_jitter(input_image, target_image)
        # Normalize
        input_image = input_image * 2 - 1
        target_image = target_image * 2 - 1
        return input_image.to(device), target_image.to(device)

In [None]:
split_train_test = True # change accordingly!
# only used if `split_train_test` is False (a split already present)
TRAIN_DIR = "train"
TEST_DIR = "val"

if split_train_test:
    train_data_orig = Pix2PixImageDataset(PIX2PIX_DS_DIR, TARGET_INDEX)

    train_data, test_data = torch.utils.data.random_split(train_data_orig, [.9,.1])
    print(f"Train samples: {len(train_data)} | Test samples: {len(test_data)}")
else:
    train_data = Pix2PixImageDataset(PIX2PIX_DS_DIR / TRAIN_DIR, TARGET_INDEX)
    test_data = Pix2PixImageDataset(PIX2PIX_DS_DIR / TEST_DIR, TARGET_INDEX)

# create data loaders
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# get a batch of training images
for i, t in enumerate(train_dataloader):

    x, y = t

    plt.figure(figsize=(13,16))
    plt.axis("off")
    plt.title("Training Images")
    lim = 16
    plt.imshow(TF.to_pil_image(
        tv.utils.make_grid(
            torch.cat([x[:lim], y[:lim]], dim = -1),
            # x[:64],
            padding=2, normalize=True, nrow=4
        ).detach().cpu())
    )
    plt.show()

    break

Note: sometimes you run into errors because of faulty files, etc. One way to test that everything is right in your dataset is to loop over it once, doing nothing:

In [None]:
for i, (input_image, target_image) in enumerate(train_dataloader):
    continue
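For files that OpenCV cannot read, `Pix2PixImageDataset` silently returns blank tensors. It can help to list the offending files directly so you can delete them. Below is a sketch of a hypothetical helper, `find_unreadable`. It takes the loader as a parameter, so you could pass `cv2.imread` together with the `files` list of a `Pix2PixImageDataset`.

```python
import os

def find_unreadable(paths, load_fn):
    """Return the paths for which load_fn raises or returns None (as cv2.imread does)."""
    bad = []
    for p in paths:
        try:
            if load_fn(os.fspath(p)) is None:
                bad.append(p)
        except Exception:
            bad.append(p)
    return bad
```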

## Build the model



The pix2pix model is a conditional generative adversarial network (cGAN). A cGAN
is a type of GAN model used for generating new data samples with specific
attributes or characteristics. In a cGAN, both the generator and discriminator
are *conditioned* on additional information, such as class labels, tags, or
other types of metadata. The generator network takes in random noise as well as
the conditional information as input and produces a new data sample that matches
the desired attributes. The discriminator network, on the other hand, tries to
distinguish between the generated samples and real samples based on both their
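visual appearance and the conditional information. In the case of pix2pix, the
network is conditioned on an image, which should be transformed into an output
image. Note that the pix2pix generator does not take an explicit noise vector:
the paper provides noise only in the form of dropout.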
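For reference, the pix2pix paper formalizes this as a conditional GAN loss plus an L1 reconstruction term. Here $x$ is the input image, $y$ the target, $z$ the noise, and $\lambda$ a weighting factor (the paper uses $\lambda = 100$):

$$
\begin{aligned}
\mathcal{L}_{cGAN}(G, D) &= \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x,z}\left[\log\left(1 - D(x, G(x, z))\right)\right] \\
\mathcal{L}_{L1}(G) &= \mathbb{E}_{x,y,z}\left[\lVert y - G(x, z) \rVert_1\right] \\
G^* &= \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)
\end{aligned}
$$

The L1 term pushes outputs towards the ground truth at the pixel level, while the adversarial term encourages sharp, realistic detail.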



### Generator



Unlike a DC-GAN, the generator of the pix2pix model is based on the
[U-net](https://arxiv.org/abs/1505.04597) architecture. A U-net model is a CNN architecture that is typically used
for image segmentation tasks. The name U-net derives from the architecture,
which resembles the letter "U". It consists of two main parts: an *encoder* and
a *decoder*. The encoder part consists of a series of convolutional layers
(`Conv2d`), which reduce the spatial dimensionality of the input image while
increasing its depth. This is followed by a bottleneck layer that holds a compact
representation of the most important features of the input image. The decoder part is a "mirror image" of the
encoder. It consists of a series of layers that gradually increase the spatial
dimensionality of the output, while decreasing its depth. As in the DC-GAN
example, this is done with "transposed convolution" layers (`ConvTranspose2d`).
The output of each encoder layer is also concatenated with the output of the
decoder layer at the same spatial resolution. These "skip connections" help
preserve spatial information that would otherwise be lost in the bottleneck.
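Every encoder block halves the resolution and every decoder block doubles it, so each encoder output has exactly the spatial size of one decoder output. This is what makes the concatenation possible. A quick check follows, using the standard PyTorch output-size formulas for the `kernel_size=4, stride=2, padding=1` layers used below, and assuming a 256x256 input.

```python
def conv_out(n, k=4, s=2, p=1):
    """Spatial size after nn.Conv2d: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

def convT_out(n, k=4, s=2, p=1):
    """Spatial size after nn.ConvTranspose2d (no output_padding): (n - 1) * s - 2p + k."""
    return (n - 1) * s - 2 * p + k

sizes = [256]
for _ in range(8):                 # 8 encoder blocks, each halves the size
    sizes.append(conv_out(sizes[-1]))

up = [sizes[-1]]
for _ in range(8):                 # 8 transposed-conv steps, each doubles it
    up.append(convT_out(up[-1]))
```

`sizes` goes from 256 down to 1 (the bottleneck), and `up` climbs back from 1 to 256. The intermediate sizes match in reverse order, which is why the skips are reversed before decoding.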



In [None]:
class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels, size=4, stride=2, apply_batchnorm=True):
        # Convolution-BatchNorm-LeakyReLU
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, size, stride=stride, padding=1, bias=not apply_batchnorm)
        self.batchnorm = nn.BatchNorm2d(out_channels) if apply_batchnorm else None
        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.batchnorm is not None:
            x = self.batchnorm(x)
        x = self.leakyrelu(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, size=4, stride=2, apply_dropout=False):
        # Convolution-BatchNorm-Dropout-ReLU
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, size, stride=stride, padding=1, bias=True)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(0.5) if apply_dropout else None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv_transpose(x)
        x = self.batchnorm(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.relu(x)
        return x

class Generator(nn.Module):
    def __init__(self, img_channels=3):
        super(Generator, self).__init__()
        # encoder:
        # C64-C128-C256-C512-C512-C512-C512-C512
        # decoder with skip (in/out):
        # CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
        # CD512-CD512 -CD512 -C512 -C256 -C128-C64

        self.encoders = nn.ModuleList([
            Downsample(img_channels, 64, apply_batchnorm=False),
            Downsample(64, 128),
            Downsample(128, 256),
            Downsample(256, 512),
            Downsample(512, 512),
            Downsample(512, 512),
            Downsample(512, 512),
            # innermost (1x1) layer: no batchnorm (with a batch of 1 there
            # would be a single value per channel)
            Downsample(512, 512, apply_batchnorm=False)
        ])

        self.decoders = nn.ModuleList([
            Upsample(512, 512, apply_dropout=True),
            Upsample(1024, 512, apply_dropout=True),
            Upsample(1024, 512, apply_dropout=True),
            Upsample(1024, 512),
            Upsample(1024, 256),
            Upsample(512, 128),
        ])
        self.last_decoder = Upsample(256, 64)
        # 128 input channels: 64 from last_decoder + 64 from the first encoder (skip)
        self.last = nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for down in self.encoders:
            x = down(x)
            skips.append(x)
        # drop the bottleneck output, then reverse: deepest skip first
        skips = skips[:-1][::-1]
        for up, skip in zip(self.decoders, skips):
            x = up(x)
            x = torch.cat([x, skip], dim=1) # skip connection (concatenation)
        x = self.last_decoder(x)
        x = torch.cat([x, skips[-1]], dim=1) # outermost skip connection
        x = self.last(x)
        return torch.tanh(x)

G = Generator().to(device)
print(G)

### Discriminator

The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier: it tries to classify whether each image _patch_ is real or fake, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.
- The shape of the output after the last layer is `(batch_size, 1, 30, 30)` (PyTorch puts channels first).
- Each value in this `30 x 30` output classifies a `70 x 70` patch of the input image.
- The discriminator receives 2 inputs:
    - The input image and the target image, which it should classify as real.
    - The input image and the generated image (the output of the generator), which it should classify as fake.
    - Use `torch.cat([inp, tar], dim=1)` to concatenate these 2 inputs together.
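Both numbers above follow from the layer hyperparameters. A small worked check in plain Python, assuming a 256x256 input and the five convolutions of the `Discriminator` below:

```python
# (kernel, stride, padding) of each conv in the Discriminator below
LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256
for k, s, p in LAYERS:
    size = (size + 2 * p - k) // s + 1     # nn.Conv2d output size

# receptive field: walk back from one output unit to the input
rf = 1
for k, s, p in reversed(LAYERS):
    rf = (rf - 1) * s + k
```

This gives `size == 30` (output grid) and `rf == 70` (input patch seen by each output value).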



In [None]:
class Discriminator(nn.Module):
    def __init__(self, image_channels=3):
        # C64-C128-C256-C512
        super(Discriminator, self).__init__()
        self.down1 = Downsample(image_channels*2, 64, 4, apply_batchnorm=False)
        self.down2 = Downsample(64, 128, 4)
        self.down3 = Downsample(128, 256, 4)
        self.down4 = Downsample(256, 512, 4, stride=1)
        self.last = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1)

    def forward(self, inp, tar):
        x = torch.cat([inp, tar], dim=1)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = self.last(x)
        return torch.sigmoid(x)

D = Discriminator().to(device)
print(D)

### Generate some images before training



Let's generate some images before training to see what the network will output.



In [None]:
def numpy_image(x):
    return np.transpose(x.detach().cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5

def generate_images(
        model, inputs, target=None,
        max_images=1, show=True, fname_no_ext=""
    ):
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(inputs)

    for i, (input_image, predicted_image) in enumerate(zip(inputs, outputs)):
        if i >= max_images:
            break

        plt.figure(figsize=(10, 10))
        if target is not None:
            display_list = [input_image, target[i], predicted_image]
            title = ['Input Image', 'Target', 'Predicted Image']
        else:
            display_list = [input_image, predicted_image]
            title = ['Input Image', 'Predicted Image']

        for j in range(len(title)):
            plt.subplot(1, len(title), j+1)
            plt.title(title[j])
            # Getting the pixel values in the [0, 1] range to plot.
            plt.imshow(numpy_image(display_list[j]))
            plt.axis('off')

        if fname_no_ext:
            plt.savefig(f"{fname_no_ext}_{i:04d}.png")

        if show:
            plt.show()

        plt.close()


# Take a single batch from the dataloader and display the untrained output
example_input, example_target = next(iter(train_dataloader))
generate_images(G, example_input, example_target)
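`numpy_image` undoes two things before plotting: the generator ends with `torch.tanh`, so its values lie in `[-1, 1]`, while `plt.imshow` expects float images in `[0, 1]` with channels last. A NumPy-only sketch of the same conversion (the name `to_displayable` is hypothetical) makes this explicit:

```python
import numpy as np

def to_displayable(chw):
    """(C, H, W) array in [-1, 1] -> (H, W, C) array in [0, 1]."""
    hwc = np.transpose(chw, (1, 2, 0))  # channels-last layout for plt.imshow
    return hwc * 0.5 + 0.5              # map [-1, 1] to [0, 1]

fake = np.random.uniform(-1, 1, size=(3, 256, 256))  # stand-in for a tanh output
img = to_displayable(fake)
print(img.shape)  # (256, 256, 3)
```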

## Train the model


This follows the procedure described in the [pix2pix paper](https://arxiv.org/abs/1611.07004). As with unconditional GANs, a generator and a discriminator are trained against each other; here, the conditional GAN (cGAN) learns to map edge maps to photos.

The discriminator D learns to distinguish real {edges, photo} pairs from fake ones (synthesised by the generator). The generator G learns to fool the discriminator.

Unlike an unconditional GAN, here, both G and D observe the input edge map.
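For reference, here is the objective from the pix2pix paper, with $x$ the input edge map, $y$ the target photo and $\lambda$ corresponding to `LAMBDA`. The paper writes the generator as $G(x, z)$; the noise $z$ is left implicit here since, in the paper's final models, it only enters through dropout.

$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\left[\log D(x, y)\right] + \mathbb{E}_{x}\left[\log\left(1 - D(x, G(x))\right)\right]$$

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\left[\lVert y - G(x) \rVert_1\right]$$

$$G^{*} = \arg\min_{G}\max_{D}\ \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)$$

In practice, D's loss is halved (the paper also divides D's objective by 2), and G is trained to maximize $\log D(x, G(x))$ rather than to minimize $\log(1 - D(x, G(x)))$, which is what a binary cross-entropy against a "real" label computes.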

### Training loop



The training loop proceeds by separately optimizing the discriminator and the generator at each iteration. The procedure can be summarized as follows:
- For each example input, we use the generator to generate an output.
- Update the discriminator by:
    -  (1) Feeding it the input image and the example target image to classify the ground-truth (example) pair.
    -  (2) Feeding it the input image together with the generated output to classify the generated pair.
    -  Using these two outputs (1 and 2) to compute the discriminator loss and updating the discriminator parameters to minimize this loss. To update only the discriminator, the generated image used in step (2) is "detached" from the Torch computation graph (using the `.detach()` method), so that no gradients are propagated back to the generator.
- Update the generator by:
    -  Computing (2) again with the updated discriminator, but this time without detaching the generated image.
    -  Computing the generator loss by combining the adversarial loss (the discriminator's classification of the generated pair, compared against the "real" label) with the [L1 distance](https://montjoile.medium.com/l0-norm-l1-norm-l2-norm-l-infinity-norm-7a7d18a4f40c) between the generated image and the target one, and finally updating the parameters of the generator to minimize this loss.
 
The generator training procedure is summarized in this diagram (from the TensorFlow tutorial):

![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)



#### Generator loss


While GANs learn a loss that adapts to the data, cGANs (such as pix2pix) learn a structured loss that penalizes any structure that differs between the network output and the target image. The generator loss consists of two terms:

-   Similarly to the discriminator loss (below), the first term `fake_gan_loss` is a binary cross-entropy loss between the discriminator's output for the generated images and an array of ones, i.e. the generator is rewarded when its output is classified as a real sample.
-   The second term `l1_loss` is the L1 distance, i.e. the mean absolute error (mean of the absolute differences), between the generated image and the target image. This pushes the generated image to become structurally similar to the target image.
-   These two terms are combined as `fake_gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was chosen by the authors of the paper.

Feel free to experiment with modifying the value of `LAMBDA` (if you have time to spare :)).
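To see how the two terms compare in size, here is a NumPy sketch of the generator loss on toy arrays. The helper names are hypothetical, and `bce` mimics `nn.BCELoss` with mean reduction (the clipping to avoid `log(0)` is an assumption; PyTorch clamps the log terms instead):

```python
import numpy as np

LAMBDA = 100

def bce(pred, label, eps=1e-12):
    # Binary cross-entropy averaged over all elements
    pred = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(label * np.log(pred) + (1 - label) * np.log(1 - pred))))

def generator_loss(fake_patch, generated, target, lam=LAMBDA):
    fake_gan_loss = bce(fake_patch, np.ones_like(fake_patch))  # "fool D": label ones
    l1_loss = float(np.mean(np.abs(target - generated)))       # mean absolute error
    return fake_gan_loss + lam * l1_loss, fake_gan_loss, l1_loss

# Toy values: D is undecided (0.5 everywhere), every pixel is off by 0.1
fake_patch = np.full((1, 1, 30, 30), 0.5)
target = np.zeros((1, 3, 256, 256))
generated = np.full((1, 3, 256, 256), 0.1)
total, gan, l1 = generator_loss(fake_patch, generated, target)
print(round(gan, 4), round(l1, 4), round(total, 4))  # 0.6931 0.1 10.6931
```

With `LAMBDA = 100`, even a small average pixel error dominates the adversarial term, which is why the outputs first become structurally close to the targets.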

#### Discriminator loss



The discriminator loss (`D_loss`) is the average of two terms, a `real_loss` and a `fake_loss`:
- The `real_loss` is a [binary cross-entropy loss](https://gombru.github.io/2018/05/23/cross_entropy_loss/) between the discriminator's output for the real pairs and an array of ones (since these are the real images).
- The `fake_loss` is a binary cross-entropy loss between the discriminator's output for the generated pairs and an array of zeros (since these are the fake images).
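A NumPy sketch of the discriminator loss on toy patches (the helper names are hypothetical; `bce` mimics `nn.BCELoss` with mean reduction, with an assumed clipping to avoid `log(0)`) shows that a confident, correct discriminator has a low loss while an undecided one sits at `log(2)`:

```python
import numpy as np

def bce(pred, label, eps=1e-12):
    # Binary cross-entropy averaged over all elements
    pred = np.clip(pred, eps, 1 - eps)
    return float(np.mean(-(label * np.log(pred) + (1 - label) * np.log(1 - pred))))

def discriminator_loss(real_patch, fake_patch):
    real_loss = bce(real_patch, np.ones_like(real_patch))   # real pairs -> 1
    fake_loss = bce(fake_patch, np.zeros_like(fake_patch))  # generated pairs -> 0
    return (real_loss + fake_loss) / 2

good = discriminator_loss(np.full((1, 1, 30, 30), 0.9), np.full((1, 1, 30, 30), 0.1))
unsure = discriminator_loss(np.full((1, 1, 30, 30), 0.5), np.full((1, 1, 30, 30), 0.5))
print(round(good, 4), round(unsure, 4))  # 0.1054 0.6931
```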

In [None]:
G_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
D_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

BCE_loss = nn.BCELoss()
L1_loss = nn.L1Loss()

LAMBDA = 100 # Weight of L1 loss in optimization

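def train_step(inputs, targets):
    # Make sure the batch is on the same device as the models
    # (a no-op if the dataloader already returns tensors on `device`)
    inputs, targets = inputs.to(device), targets.to(device)

    # Generate output
    G_outputs = G(inputs)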

    # ---- Update discriminator ----
    D_optimizer.zero_grad() # Clear gradients

    # Classify real and fake patches
    # Here we "freeze" generator gradients since we only optimize the discriminator
    real_patch = D(inputs, targets)
    fake_patch = D(inputs, G_outputs.detach())

    # Compute loss for real/fake patches
    # log(D(x,y)) + log(1 - D(x,G(x)))
    real_class = torch.ones_like(real_patch)   # *_like tensors keep the device
    fake_class = torch.zeros_like(fake_patch)

    real_loss = BCE_loss(real_patch, real_class)
    fake_loss = BCE_loss(fake_patch, fake_class)
    D_loss = (real_loss + fake_loss)/2

    # Propagate gradients and perform gradient descent step
    D_loss.backward()
    D_optimizer.step()

    # ---- Update generator ----
    G_optimizer.zero_grad() # Clear gradients
    # Classify fake samples, now considering generator gradients
    fake_patch = D(inputs, G_outputs)
    # Compute loss according to paper (non-saturating form):
    # -log(D(x,G(x))) + LAMBDA * L1(y,G(x))
    fake_gan_loss = BCE_loss(fake_patch, real_class)
    l1_loss = L1_loss(G_outputs, targets)
    G_loss = fake_gan_loss + LAMBDA * l1_loss

    # Propagate gradients and perform gradient descent step
    G_loss.backward()
    G_optimizer.step()

    return G_loss, D_loss

In [None]:
def plot(d_losses, g_losses, show=True, save=False):
    """
    Book-keeping: visualize the (averaged) generator and discriminator losses
    """
    plt.figure(figsize=(12, 3))
    plt.title('Losses')
    plt.plot(g_losses, label='Generator')
    plt.plot(d_losses, label='Discriminator')
    plt.legend()
    # Save before showing: in notebooks, plt.show() may leave an empty figure behind
    if save:
        plt.savefig(PIX2PIX_DIR / "losses.pdf")
    if show:
        plt.show()
    plt.close()

In [None]:
tot = len(train_dataloader) # for print formatting
print(f"The train dataset contains {tot} batches (# of iterations/epoch)!")
# Using this information, we can modify the variables below...

In [None]:
# Those will contain our losses (keeping them in a different cell
# allows them to persist as we run the next cell several times)
g_losses = []
d_losses = []
iters = 0

In [None]:
EPOCHS = 20 # change accordingly (if the dataset is big, you don't need as many)

# All stats according to iters, rather than batch or epoch
AVG_EVERY = 10      # collect batch losses for plotting
PRINT_EVERY = 100   # print stats
GEN_EVERY = 2000    # generate imgs
PLOT_EVERY = 2000   # plot losses (skips iter 0 unless it is 1: 'every iter')
                    # ↓ save model (skips iter 0 unless it is 1: 'every iter')
SAVE_EVERY = iters + EPOCHS * tot # save at the very end! (iters counts batches, not samples)

for epoch in range(EPOCHS):

    print()
    print(f"Epoch {epoch+1:>{len(str(EPOCHS))}}/{EPOCHS} | Iter: {iters}")

    batch_d_losses = []
    batch_g_losses = []

    for i, (input_image, target_image) in enumerate(train_dataloader):

        G_loss, D_loss = train_step(input_image, target_image)
        g_l, d_l = G_loss.item(), D_loss.item()
        batch_g_losses.append(g_l)
        batch_d_losses.append(d_l)

        if (iters + 1) % PRINT_EVERY == 0:
            print(
                f"  Iter {iters+1} (i {i+1:>{len(str(tot))}}/{tot}) "
                + f"[G loss: {g_l:.4f} | D loss: {d_l:.4f}]"
            )

        if iters % AVG_EVERY == 0:
            # print("Averaging losses")
            # average values collected until now, reset temporary lists
            g_losses.append(np.mean(batch_g_losses))
            d_losses.append(np.mean(batch_d_losses))
            batch_d_losses = []
            batch_g_losses = []

        if (iters + 1) % GEN_EVERY == 0:
            generate_images(
                G, input_image, target_image,
                fname_no_ext=PIX2PIX_GEN_DIR / f"generated_image.iter_{iters+1:04d}"
            )

        if (iters > 0 or PLOT_EVERY == 1) and (iters + 1) % PLOT_EVERY == 0:
            # print()
            # print(f"  Iter {iters + 1} | Plotting")
            # print()
            plot(d_losses, g_losses)

        if (iters > 0 or SAVE_EVERY == 1) and (iters + 1) % SAVE_EVERY == 0:
            print()
            print(f"  Iter {iters+1} | Saving model to {PIX2PIX_DIR}")
            print()
            G_scripted = torch.jit.script(G)
            G_scripted.save(PIX2PIX_DIR / f"pix2pix_{MODEL_NAME}.iter_{iters+1:04d}_scripted.pt")
            # The following saves only model parameters
            torch.save(G.state_dict(), PIX2PIX_DIR / f"pix2pix_{MODEL_NAME}.iter_{iters+1:04d}.pt")

        iters += 1 # technically should be BATCH_SIZE, but oh well
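All `*_EVERY` checks compare against `iters`, which counts batches. Before committing to a long run, a pure-Python sketch can preview when a check of the form `(iters + 1) % every == 0` fires. The `schedule` helper and the sizes below are hypothetical:

```python
def schedule(start_iter, epochs, tot, every, skip_first=True):
    """List the (1-based) iteration numbers at which `(iters + 1) % every == 0`
    fires, mirroring the plot/save conditions in the training loop."""
    fired = []
    for iters in range(start_iter, start_iter + epochs * tot):
        if skip_first and not (iters > 0 or every == 1):
            continue  # the loop skips iter 0 unless every == 1
        if (iters + 1) % every == 0:
            fired.append(iters + 1)
    return fired

tot, EPOCHS = 500, 20                         # assumed batches/epoch and epochs
print(schedule(0, EPOCHS, tot, EPOCHS * tot))  # [10000]: saves once, at the end
print(len(schedule(0, EPOCHS, tot, 2000)))     # 5 generation/plot events
print(schedule(0, EPOCHS, tot, EPOCHS * tot * 4))  # []: a threshold in samples never fires
```

Because the threshold is computed from the current `iters`, re-running the cell continues the count: starting from `iters = 10000`, the final save happens at iteration 20000.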

### Download your data

```bash
# -r for 'recursive', required for directories
!zip -r generated.zip generated
!zip -r models.zip models
# then use the left-hand side bar to download manually.
```

(The process is similar for `datasets`, if you want to save that.)

If you want to save things to your drive, you can then copy them to the destination of your choice (provided you mounted your drive in the first place):

```bash
!cp -r models/* drive/MyDrive/IS53055B-DMLCP/DMLCP/python/models
```
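As a pure-Python alternative to the shell commands above, the standard library's `shutil` can zip and copy folders. The helper names `archive_dir` and `copy_dir` are hypothetical; this is a sketch, not part of the original workflow:

```python
import shutil
from pathlib import Path

def archive_dir(name, root="."):
    """Zip <root>/<name> into <root>/<name>.zip, keeping the top-level
    folder inside the archive (like `zip -r name.zip name`)."""
    base = Path(root) / name
    return shutil.make_archive(str(base), "zip", root_dir=root, base_dir=name)

def copy_dir(src, dst):
    """Copy src into dst, merging with existing content (Python 3.8+).
    Like `cp -r`, this leaves the original files in place."""
    return shutil.copytree(src, dst, dirs_exist_ok=True)

# Zip whichever of the notebook's output folders exist in the working directory
for folder in ("generated", "models"):
    if Path(folder).is_dir():
        print("Created", archive_dir(folder))
```

For the drive, something like `copy_dir("models", "drive/MyDrive/IS53055B-DMLCP/DMLCP/python/models")` would mirror the `cp -r` command.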