# **Import Libraries**

In [1]:
# Standard Library
import os
import warnings

# Third-Party Libraries
import numpy as np
from PIL import Image
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# PyTorch Core
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hugging Face
from accelerate import Accelerator

# TorchVision
import torchvision
from torchvision.utils import save_image

# **Ignore Warnings**

In [2]:
warnings.filterwarnings('ignore')

# **Generator**

In [3]:
class GENBlock(nn.Module):
  def __init__(self, in_channels, out_channels, down = True, act = 'relu', use_dropout = False):
    super(GENBlock, self).__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias = False, padding_mode = 'reflect') if down else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias = False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU() if act == 'relu' else nn.LeakyReLU(.2),
        nn.Dropout2d(.5) if use_dropout else nn.Identity()
    )

  def forward(self, x):
    return self.conv(x)

class Generator(nn.Module):
  def __init__(self, in_channels, features):
    super(Generator, self).__init__()

    self.initial_down = nn.Sequential(nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode = 'reflect'), nn.LeakyReLU(.2))
    self.down1 = GENBlock(features,     features* 2, act = "leaky", use_dropout = False)
    self.down2 = GENBlock(features*2, features*4, act = "leaky", use_dropout = False)
    self.down3 = GENBlock(features*4, features*8, act = "leaky", use_dropout = False)
    self.down4 = GENBlock(features*8, features*8, act = "leaky", use_dropout = False)
    self.down5 = GENBlock(features*8, features*8, act = "leaky", use_dropout = False)
    self.down6 = GENBlock(features*8, features*8, act = "leaky", use_dropout = False)

    self.bottle_neck = nn.Sequential(nn.Conv2d(features*8, features*8, 4, 2, 1), nn.ReLU())

    self.up1 = GENBlock(features*8,     features*8 , down = False, use_dropout = True)
    self.up2 = GENBlock(features*8*2, features*8,  down = False, use_dropout = True)
    self.up3 = GENBlock(features*8*2, features*8,  down = False, use_dropout = True)
    self.up4 = GENBlock(features*8*2, features*8,  down = False, use_dropout = False)
    self.up5 = GENBlock(features*8*2, features*4,  down = False, use_dropout = False)
    self.up6 = GENBlock(features*4*2, features*2,  down = False, use_dropout = False)
    self.up7 = GENBlock(features*2*2, features,      down = False, use_dropout = False)
    self.final_up = nn.Sequential(nn.ConvTranspose2d(features*2, in_channels, 4, 2, 1), nn.Tanh())

  def forward(self, x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)

    bottle_neck = self.bottle_neck(d7)

    u1 = self.up1(bottle_neck)
    u2 = self.up2(torch.cat([u1, d7], 1))
    u3 = self.up3(torch.cat([u2, d6], 1))
    u4 = self.up4(torch.cat([u3, d5], 1))
    u5 = self.up5(torch.cat([u4, d4], 1))
    u6 = self.up6(torch.cat([u5, d3], 1))
    u7 = self.up7(torch.cat([u6, d2], 1))

    return self.final_up(torch.cat([u7, d1], 1))

# **Discriminator**

In [4]:
class DISBlock(nn.Module):
  def __init__(self, in_channels, out_channels, stride):
    super(DISBlock, self).__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias = False, padding_mode = 'reflect'),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(.2)
    )

  def forward(self, x):
    return self.conv(x)

class Discriminator(nn.Module):
  def __init__(self, in_channels, features = [64, 128, 256, 512]):
    super(Discriminator, self).__init__()

    self.initial = nn.Sequential(
        nn.Conv2d(in_channels*2, features[0], 4, 2, 1, padding_mode = 'reflect'),
        nn.LeakyReLU(.2)
    )

    in_channels = features[0]
    layers = []

    for feature in features[1:]:
      layers.append(DISBlock(in_channels, feature, stride = 1 if feature == features[-1] else 2))
      in_channels = feature

    layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode = 'reflect'))
    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    x = self.initial(torch.cat([x, y], 1))
    return self.model(x)

# **Dataset**

In [5]:
!mkdir -p ./edges2shoes
!wget -q --show-progress -O ./edges2shoes/edges2shoes.tar.gz http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2shoes.tar.gz

import tarfile

tar_path = "./edges2shoes/edges2shoes.tar.gz"
extract_path = "./edges2shoes"

with tarfile.open(tar_path) as tar:
    tar.extractall(path = extract_path)



In [6]:
train_both_transform = A.Compose([
    A.Resize(256, 256)
], additional_targets={"image0": "image"})

train_input_transform = A.Compose([
    A.HorizontalFlip(0.5),
    A.ColorJitter(0.2, 0.2, 0.2, 0.2),
    A.Normalize([0.5]*3, [0.5]*3, 255),
    ToTensorV2()
])

train_target_transform = A.Compose([
    A.Normalize([0.5]*3, [0.5]*3, 255),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize([0.5] * 3, [0.5] * 3, 255),
    ToTensorV2()
], additional_targets={"image0": "image"})

In [7]:
class Edges2ShoesDataset(Dataset):
    def __init__(self, root_dir, mode="train"):
        self.mode = mode
        self.files = sorted(os.listdir(os.path.join(root_dir, mode)))
        self.root = os.path.join(root_dir, mode)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = np.array(Image.open(os.path.join(self.root, self.files[idx])))
        w = img.shape[1]
        input_img, target_img = img[:, :w//2, :], img[:, w//2:, :]

        if self.mode == "train":
            processing = train_both_transform(image = input_img, image0 = target_img)
            input_img, target_img = processing["image"], processing["image0"]
            input_img = train_input_transform(image = input_img)["image"]
            target_img = train_target_transform(image = target_img)["image"]

        else:
            processing = val_transform(image = input_img, image0 = target_img)
            input_img, target_img = processing["image"], processing["image0"]

        return {"input": input_img, "target": target_img}

In [8]:
train_loader = DataLoader(
    Edges2ShoesDataset("./edges2shoes/edges2shoes", mode = "train"),
    batch_size = 16, shuffle = True, num_workers = 2
)

val_loader = DataLoader(
    Edges2ShoesDataset("./edges2shoes/edges2shoes", mode = "val"),
    batch_size = 16, shuffle = False, num_workers = 2
)

# **Training**

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 2e-4
batch_size =  16
num_workers = 2
image_size = 256
channels_img = 3
l1_lambda = 100
num_epochs = 500

disc = Discriminator(in_channels = 3).to(device)
gen = Generator(in_channels = 3, features = 64).to(device)

opt_disc = optim.Adam(disc.parameters(), lr = learning_rate, betas = (0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr = learning_rate, betas = (0.5, 0.999))

l1_loss = nn.L1Loss()
bce = nn.BCEWithLogitsLoss()

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

for epoch in range(num_epochs):
    gen.train()
    disc.train()
    for batch in tqdm(train_loader):
        x, y = batch["input"].to(device), batch["target"].to(device)

        with torch.cuda.amp.autocast():
            y_fake = gen(x)

            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))

            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))

            D_loss = (D_real_loss + D_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)

            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            l1 = l1_loss(y_fake, y)

            G_loss = G_fake_loss + l1_lambda * l1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

    print(f"Epoch {epoch+1}: Disc Loss {D_loss.item():.4f} | Gen Loss {G_loss.item():.4f}")

    with torch.no_grad():
        gen.eval()
        disc.eval()

        batch = next(iter(val_loader))
        x, y = batch["input"].to(device), batch["target"].to(device)

        fake = gen(x)

        save_image(x[:3], "pix2pix_input.png", nrow=1, normalize=True)
        save_image(y[:3], "pix2pix_ground_truth.png", nrow=1, normalize=True)
        save_image(fake[:3], "pix2pix_output.png", nrow=1, normalize=True)

100%|██████████| 3115/3115 [10:18<00:00,  5.04it/s]

Epoch 1: Disc Loss 0.6706 | Gen Loss 14.7605



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 2: Disc Loss 0.6765 | Gen Loss 16.0256



100%|██████████| 3115/3115 [10:10<00:00,  5.10it/s]

Epoch 3: Disc Loss 1.1305 | Gen Loss 27.1613



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 4: Disc Loss 0.3771 | Gen Loss 30.8410



100%|██████████| 3115/3115 [10:11<00:00,  5.09it/s]

Epoch 5: Disc Loss 0.3338 | Gen Loss 24.3672



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 6: Disc Loss 0.6438 | Gen Loss 19.9066



100%|██████████| 3115/3115 [10:13<00:00,  5.08it/s]

Epoch 7: Disc Loss 0.1790 | Gen Loss 19.9721



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 8: Disc Loss 0.5824 | Gen Loss 30.2756



100%|██████████| 3115/3115 [10:15<00:00,  5.06it/s]

Epoch 9: Disc Loss 0.4930 | Gen Loss 21.8673



100%|██████████| 3115/3115 [10:16<00:00,  5.05it/s]

Epoch 10: Disc Loss 0.0678 | Gen Loss 22.1052



100%|██████████| 3115/3115 [10:12<00:00,  5.08it/s]

Epoch 11: Disc Loss 0.2250 | Gen Loss 8.9949



100%|██████████| 3115/3115 [10:11<00:00,  5.10it/s]

Epoch 12: Disc Loss 0.0113 | Gen Loss 23.8974



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 13: Disc Loss 0.4757 | Gen Loss 20.7361



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 14: Disc Loss 0.6228 | Gen Loss 16.6926



100%|██████████| 3115/3115 [10:15<00:00,  5.06it/s]

Epoch 15: Disc Loss 0.8035 | Gen Loss 27.3573



100%|██████████| 3115/3115 [10:13<00:00,  5.07it/s]

Epoch 16: Disc Loss 0.2270 | Gen Loss 21.6279



100%|██████████| 3115/3115 [10:09<00:00,  5.11it/s]

Epoch 17: Disc Loss 0.2738 | Gen Loss 18.9717



100%|██████████| 3115/3115 [10:11<00:00,  5.09it/s]

Epoch 18: Disc Loss 0.2417 | Gen Loss 14.5289



100%|██████████| 3115/3115 [10:10<00:00,  5.10it/s]

Epoch 19: Disc Loss 0.7007 | Gen Loss 12.7648



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 20: Disc Loss 0.0183 | Gen Loss 31.5510



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 21: Disc Loss 0.1655 | Gen Loss 23.6577



100%|██████████| 3115/3115 [10:15<00:00,  5.06it/s]

Epoch 22: Disc Loss 0.0949 | Gen Loss 17.4959



100%|██████████| 3115/3115 [10:13<00:00,  5.08it/s]

Epoch 23: Disc Loss 0.1554 | Gen Loss 19.4686



100%|██████████| 3115/3115 [10:13<00:00,  5.08it/s]

Epoch 24: Disc Loss 0.0679 | Gen Loss 29.4201



100%|██████████| 3115/3115 [10:16<00:00,  5.05it/s]

Epoch 25: Disc Loss 0.1640 | Gen Loss 21.0363



100%|██████████| 3115/3115 [10:12<00:00,  5.08it/s]

Epoch 26: Disc Loss 0.6424 | Gen Loss 10.3390



100%|██████████| 3115/3115 [10:13<00:00,  5.08it/s]

Epoch 27: Disc Loss 0.7941 | Gen Loss 19.1311



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 28: Disc Loss 2.0458 | Gen Loss 22.3883



100%|██████████| 3115/3115 [10:16<00:00,  5.05it/s]

Epoch 29: Disc Loss 0.0836 | Gen Loss 29.6754



100%|██████████| 3115/3115 [10:19<00:00,  5.02it/s]

Epoch 30: Disc Loss 0.3127 | Gen Loss 28.0532



100%|██████████| 3115/3115 [10:19<00:00,  5.03it/s]

Epoch 31: Disc Loss 0.1082 | Gen Loss 21.0844



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 32: Disc Loss 0.0780 | Gen Loss 24.1920



100%|██████████| 3115/3115 [10:15<00:00,  5.06it/s]

Epoch 33: Disc Loss 0.1808 | Gen Loss 15.4235



100%|██████████| 3115/3115 [10:12<00:00,  5.08it/s]

Epoch 34: Disc Loss 0.7099 | Gen Loss 12.0877



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 35: Disc Loss 0.9952 | Gen Loss 15.8438



100%|██████████| 3115/3115 [10:11<00:00,  5.09it/s]

Epoch 36: Disc Loss 0.2146 | Gen Loss 25.0097



100%|██████████| 3115/3115 [10:09<00:00,  5.11it/s]

Epoch 37: Disc Loss 0.0107 | Gen Loss 16.7074



100%|██████████| 3115/3115 [10:10<00:00,  5.11it/s]

Epoch 38: Disc Loss 0.7617 | Gen Loss 24.8963



100%|██████████| 3115/3115 [10:09<00:00,  5.11it/s]

Epoch 39: Disc Loss 0.7680 | Gen Loss 27.2121



100%|██████████| 3115/3115 [10:11<00:00,  5.09it/s]

Epoch 40: Disc Loss 0.5777 | Gen Loss 27.5452



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 41: Disc Loss 1.4654 | Gen Loss 22.3483



100%|██████████| 3115/3115 [10:10<00:00,  5.10it/s]

Epoch 42: Disc Loss 0.2448 | Gen Loss 21.0437



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 43: Disc Loss 0.6762 | Gen Loss 9.8892



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 44: Disc Loss 0.0074 | Gen Loss 19.8906



100%|██████████| 3115/3115 [10:12<00:00,  5.09it/s]

Epoch 45: Disc Loss 0.5680 | Gen Loss 39.0513



100%|██████████| 3115/3115 [10:11<00:00,  5.09it/s]

Epoch 46: Disc Loss 0.0957 | Gen Loss 17.1567



100%|██████████| 3115/3115 [10:13<00:00,  5.08it/s]

Epoch 47: Disc Loss 0.9574 | Gen Loss 24.9338



100%|██████████| 3115/3115 [10:15<00:00,  5.06it/s]

Epoch 48: Disc Loss 1.0996 | Gen Loss 18.0050



100%|██████████| 3115/3115 [10:14<00:00,  5.07it/s]

Epoch 49: Disc Loss 0.1094 | Gen Loss 13.4752



100%|██████████| 3115/3115 [10:16<00:00,  5.05it/s]

Epoch 50: Disc Loss 1.3806 | Gen Loss 15.9858



 15%|█▌        | 481/3115 [01:35<08:42,  5.04it/s]


KeyboardInterrupt: 