<a href="https://colab.research.google.com/github/harryypham/MyMLPractice/blob/main/Pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings("ignore")


%matplotlib inline

In [4]:
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/data/train"
VAL_DIR = "data/data/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_EPOCHS = 500

both_transform = A.Compose(
    [A.Resize(width=256, height=256), A.HorizontalFlip(p=0.5)], additional_targets={"image0": "image"},
)

transform_only_input = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()

In [3]:
class CustomDataset(Dataset):
  def __init__(self, root_dir):
    self.root_dir = root_dir
    self.list_files = os.listdir(root_dir)

  def __len__(self):
    return len(self.list_files)

  def __getitem__(self, index):
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    image = np.array(Image.open(img_path))
    input_image, target_image = image[:, :512, :], image[:, 512:, :]

    augmentations = both_transform(image=input_image, image0=target_image)
    input_image, target_image = augmentations["image"], augmentations["image0"]

    input_image = transform_only_input(image=input_image)["image"]
    target_image = transform_only_input(image=target_image)["image"]

    return input_image, target_image

In [5]:
class Conv(nn.Module):
  def __init__(self, in_channels, out_channels, stride):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )

  def forward(self, x):
    return self.conv(x)

class Discriminator(nn.Module):
  def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels*2, features[0], 4, 2, 1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )

    layers = []
    in_channels = features[0]
    for feature in features[1:]:
      layers.append(
          Conv(in_channels, feature, stride=1 if feature == features[-1] else 2)
      )
      in_channels = feature
    layers.append(
        nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode="reflect")
    )

    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    x = torch.cat([x, y], dim=1)
    x = self.initial(x)
    x = self.model(x)
    return x



In [6]:
class Block(nn.Module):
  def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        if down
        else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
    )
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    x = self.conv(x)
    return self.dropout(x) if self.use_dropout else x

class Generator(nn.Module):
  def __init__(self, in_channels=3, features=64):
    super().__init__()
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )

    self.down1 = Block(features, features*2, down=True, act="leaky")
    self.down2 = Block(features*2, features*4, down=True, act="leaky")
    self.down3 = Block(features*4, features*8, down=True, act="leaky")
    self.down4 = Block(features*8, features*8, down=True, act="leaky")
    self.down5 = Block(features*8, features*8, down=True, act="leaky")
    self.down6 = Block(features*8, features*8, down=True, act="leaky")

    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"),
        nn.ReLU()
    )

    self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
    self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
    self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
    self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
    self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, 4, 2, 1),
        nn.Tanh()
    )

  def forward(self, x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)

    bottleneck = self.bottleneck(d7)

    u1 = self.up1(bottleneck)
    u2 = self.up2(torch.cat([u1, d7], dim=1))
    u3 = self.up3(torch.cat([u2, d6], dim=1))
    u4 = self.up4(torch.cat([u3, d5], dim=1))
    u5 = self.up5(torch.cat([u4, d4], dim=1))
    u6 = self.up6(torch.cat([u5, d3], dim=1))
    u7 = self.up7(torch.cat([u6, d2], dim=1))
    out = self.final_up(torch.cat([u7, d1], dim=1))
    return out




In [11]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1, bce):
  loop = tqdm(loader, leave=True)
  d_loss, g_loss = 0, 0
  for idx, (x, y) in enumerate(loop):
    x, y = x.to(DEVICE), y.to(DEVICE)

    y_fake = gen(x)
    D_real = disc(x, y)
    D_fake = disc(x, y_fake.detach())
    D_real_loss = bce(D_real, torch.ones_like(D_real))
    D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
    D_loss = (D_real_loss + D_fake_loss) / 2

    disc.zero_grad()
    D_loss.backward()
    opt_disc.step()

    D_fake = disc(x, y_fake)
    G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
    L1 = l1(y_fake, y) * L1_LAMBDA
    G_loss = G_fake_loss + L1

    gen.zero_grad()
    G_loss.backward()
    opt_gen.step()
    d_loss += D_loss.item()
    g_loss += G_loss.item()

    loop.set_description("D_loss: %.6f | G_loss: %.6f" % (d_loss/((idx+1)*BATCH_SIZE), g_loss/((idx+1)*BATCH_SIZE)))


In [12]:
def main():
  disc = Discriminator(in_channels=3).to(DEVICE)
  gen = Generator(in_channels=3).to(DEVICE)
  opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
  BCE = nn.BCEWithLogitsLoss()
  L1_LOSS = nn.L1Loss()

  train_dataset = CustomDataset(TRAIN_DIR)
  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
  valset = CustomDataset(VAL_DIR)
  val_loader = DataLoader(valset, batch_size=1, shuffle=False)

  for epoch in range(NUM_EPOCHS):
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE)

    save_some_examples(gen, val_loader, epoch, folder="eval")

In [None]:
main()

D_loss: 0.026193 | G_loss: 0.976092: 100%|██████████| 889/889 [02:31<00:00,  5.88it/s]
D_loss: 0.028710 | G_loss: 0.879163: 100%|██████████| 889/889 [02:30<00:00,  5.89it/s]
D_loss: 0.028967 | G_loss: 0.882265: 100%|██████████| 889/889 [02:31<00:00,  5.89it/s]
D_loss: 0.030519 | G_loss: 0.868943: 100%|██████████| 889/889 [02:31<00:00,  5.88it/s]
D_loss: 0.029786 | G_loss: 0.859628:  83%|████████▎ | 740/889 [02:05<00:25,  5.91it/s]