<a href="https://colab.research.google.com/github/isurushanaka/iPURSE2023/blob/main/iPURSE_Workshop_GAN_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Image-to-image translation with a conditional GAN**

### [Dataset](https://github.com/isurushanaka/N2D250K)


*   Night-to-day image translation
*   Paired images

<img src="https://github.com/isurushanaka/N2D250K/blob/main/Sample%20Images/paired_dataset-v.png?raw=true"  width="22%" height="70%">



### GAN Implementation


1.   Generator
2.   Discriminator
3.   Training Function
4.   Data Loader
5.   Helper Functions
6.   Hyperparameters
7.   Training



# 1. Generator

![](https://www.researchgate.net/profile/Chenxing-Wang/publication/349487608/figure/fig4/AS:1017165721378817@1619522614571/The-network-structure-of-pix2pix-including-the-structure-of-U-Net.png)

In [None]:
import torch
import torch.nn as nn


class genBlock(nn.Module):
    """One U-Net stage: stride-2 (transposed) convolution, batch norm, activation.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        down: ``True`` builds a reflect-padded ``Conv2d`` (kernel 4, stride 2,
            padding 1); ``False`` builds a ``ConvTranspose2d`` with the same
            kernel, stride and padding.
        act: ``"relu"`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        use_dropout: When true, ``forward`` passes the result through
            ``nn.Dropout(0.5)``.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()

        # The convolutions carry no bias because BatchNorm follows directly.
        if down:
            resample = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False, padding_mode="reflect",
            )
        else:
            resample = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=4, stride=2, padding=1,
                bias=False,
            )
        normalise = nn.BatchNorm2d(out_channels)
        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        # Attribute name and layer order are kept so saved state dicts still match.
        self.conv = nn.Sequential(resample, normalise, activation)

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Apply the stage to ``x`` and, if enabled, dropout to the result."""
        features = self.conv(x)
        if self.use_dropout:
            return self.dropout(features)
        return features


class Generator(nn.Module):
    """U-Net generator: eight stride-2 encoder stages mirrored by eight decoder stages.

    Every decoder stage after the first receives the previous decoder output
    concatenated (channel-wise) with the encoder output of matching depth.
    The last layer is a ``Tanh``, so outputs lie in ``[-1, 1]``.

    Args:
        in_channels: Channel count of the input image; the output image has
            the same number of channels.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def encoder(c_in, c_out):
            # Down-sampling stage: LeakyReLU, no dropout.
            return genBlock(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def decoder(c_in, c_out, dropout):
            # Up-sampling stage: ReLU, dropout only where requested.
            return genBlock(c_in, c_out, down=False, act="relu", use_dropout=dropout)

        # Submodule names and creation order are kept unchanged so that saved
        # model/optimizer state dicts continue to line up.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(64, 128)
        self.down2 = encoder(128, 256)
        self.down3 = encoder(256, 512)
        self.down4 = encoder(512, 512)
        self.down5 = encoder(512, 512)
        self.down6 = encoder(512, 512)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # From up2 onwards the input width is doubled by the skip concatenation.
        self.up1 = decoder(512, 512, dropout=True)
        self.up2 = decoder(1024, 512, dropout=True)
        self.up3 = decoder(1024, 512, dropout=True)
        self.up4 = decoder(1024, 512, dropout=False)
        self.up5 = decoder(1024, 256, dropout=False)
        self.up6 = decoder(512, 128, dropout=False)
        self.up7 = decoder(256, 64, dropout=False)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(128, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate the image batch ``x`` and return a batch of the same spatial size."""
        # Encoder: remember every intermediate map for the skip connections.
        skips = [self.initial_down(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4, self.down5, self.down6):
            skips.append(stage(skips[-1]))

        out = self.up1(self.bottleneck(skips[-1]))

        # Decoder: pair each stage with the encoder map of the same depth,
        # deepest first; the shallowest map is kept for the final layer.
        decoder_stages = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)
        for stage, skip in zip(decoder_stages, reversed(skips[1:])):
            out = stage(torch.cat([out, skip], 1))

        return self.final_up(torch.cat([out, skips[0]], 1))


# 2. Discriminator

![](https://www.researchgate.net/profile/Chenxing-Wang/publication/349487608/figure/fig4/AS:1017165721378817@1619522614571/The-network-structure-of-pix2pix-including-the-structure-of-U-Net.png)

In [None]:
import torch
import torch.nn as nn


class discBlock(nn.Module):
    """Discriminator stage: reflect-padded 4x4 convolution, batch norm, LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        stages = [
            # No bias: the BatchNorm that follows has its own shift term.
            nn.Conv2d(
                in_channels, out_channels,
                kernel_size=4, stride=stride, padding=1,
                bias=False, padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through convolution, normalisation and activation."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Patch discriminator that scores an (input, candidate) image pair.

    The two images are concatenated along the channel axis, passed through an
    initial stride-2 convolution, a chain of ``discBlock`` stages and a final
    1-channel convolution. No sigmoid is applied, so the output is a map of
    raw logits.

    Args:
        in_channels: Channel count of EACH of the two images.
        features: Widths of the successive convolution stages. The first
            entry is the width of the initial convolution; every later entry
            adds one ``discBlock``. Any sequence (list or tuple) is accepted.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default plus a private copy: the argument is never shared
        # between instances or mutated behind the caller's back.
        features = list(features)

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels * 2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        previous_width = features[0]
        last_position = len(features) - 1
        for position, width in enumerate(features[1:], start=1):
            # Only the LAST block uses stride 1. It is picked by position, not
            # by comparing widths, so repeated widths (e.g. [64, 512, 512])
            # do not turn earlier blocks into stride-1 blocks as well.
            stride = 1 if position == last_position else 2
            layers.append(discBlock(previous_width, width, stride=stride))
            previous_width = width

        layers.append(
            nn.Conv2d(previous_width, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return the logit map for input batch ``x`` paired with candidate batch ``y``."""
        pair = torch.cat([x, y], dim=1)
        return self.model(self.initial(pair))

# 3. Training Function

In [None]:
from tqdm import tqdm
import torch
from torchvision.utils import save_image,make_grid

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Weight applied to the L1 reconstruction term of the generator loss in train_fn.
L1_LAMBDA = 10

def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler, writer, writer_step):
    """Train discriminator and generator for one pass over ``loader``.

    For every batch the discriminator is updated first (real pair vs. detached
    fake pair), then the generator (adversarial term plus ``L1_LAMBDA`` times
    the L1 term). Losses, mean discriminator probabilities and a grid of the
    generated images are written to ``writer`` after every batch.

    Args:
        disc: Discriminator, called as ``disc(night, day)``.
        gen: Generator, called as ``gen(night)``.
        loader: Iterable yielding ``(night, day)`` batches.
        opt_disc: Optimizer of the discriminator.
        opt_gen: Optimizer of the generator.
        l1_loss: Reconstruction loss, called as ``l1_loss(day_fake, day)``.
        bce: Adversarial loss applied to the discriminator output.
        g_scaler: Gradient scaler used for the generator step.
        d_scaler: Gradient scaler used for the discriminator step.
        writer: Summary writer providing ``add_scalar`` and ``add_image``.
        writer_step: ``global_step`` used for the first batch. It is
            incremented once per batch, but only locally; the caller's
            variable is NOT updated.

    Returns:
        dict: ``Disc_loss``, ``Gen_loss``, ``d_real`` and ``d_fake`` of the
        LAST batch, each formatted as a string.

    Note:
        With a loader that yields no batches ``log_dict`` is never assigned
        and the final ``return`` raises ``UnboundLocalError``.
    """
    loop = tqdm(loader, leave=True)

    for idx, (night, day) in enumerate(loop):
        night = night.to(DEVICE)
        day = day.to(DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            day_fake = gen(night)
            D_real = disc(night, day)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach(): the discriminator step must not back-propagate into the generator.
            D_fake = disc(night, day_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            # Re-score the fake with the just-updated discriminator, this time
            # keeping the graph so gradients reach the generator.
            D_fake = disc(night, day_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(day_fake, day) * L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # training log: every metric is computed once, without autograd
        # tracking, and reused for TensorBoard, the progress bar and the
        # returned dict.
        with torch.no_grad():
            disc_loss_value = D_loss.mean().item()
            gen_loss_value = G_loss.mean().item()
            d_real_prob = torch.sigmoid(D_real).mean().item()
            d_fake_prob = torch.sigmoid(D_fake).mean().item()

        # Values are kept as strings because callers convert them with float().
        log_dict = {
            "Disc_loss": f"{disc_loss_value}",
            "Gen_loss": f"{gen_loss_value}",
            "d_real": f"{d_real_prob}",
            "d_fake": f"{d_fake_prob}",
        }
        writer.add_scalar("Disc_loss", disc_loss_value, global_step=writer_step)
        writer.add_scalar("Gen_loss", gen_loss_value, global_step=writer_step)
        writer.add_scalar("d_real", d_real_prob, global_step=writer_step)
        writer.add_scalar("d_fake", d_fake_prob, global_step=writer_step)

        # Detach before de-normalising so the image grid is not built on top
        # of the generator's autograd graph.
        img_grid = make_grid(day_fake.detach() * 0.5 + 0.5)  # remove normalization
        writer.add_image('fake_day', img_grid, global_step=writer_step)

        writer_step += 1

        if idx % 10 == 0:
            loop.set_postfix(D_real=d_real_prob, D_fake=d_fake_prob,)

    return log_dict

# 4. Data Loader

In [None]:
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2


# Shared preprocessing for a night/day pair; LoadDataset calls it as
# transform(image=<night>, image0=<day>).
transform = A.Compose(
    [
        # Bring every image to the 256x256 size used for training.
        A.Resize(width=256, height=256),
        # mean = std = 0.5 with max_pixel_value=255 rescales pixel values to [-1, 1];
        # train_fn and val_fn undo this with `* 0.5 + 0.5`.
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ],
    # Declares the extra keyword "image0" as a second image target.
    additional_targets={"image0": "image"},)


class LoadDataset(Dataset):
    """Paired night/day image dataset.

    Expects ``root_dir`` to contain the folders ``night_images`` and
    ``day_images`` holding files with identical names. File names are taken
    from ``night_images``.

    Args:
        root_dir: Folder containing ``night_images`` and ``day_images``.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # os.listdir returns names in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.list_files = sorted(os.listdir(self.root_dir+"/night_images"))

    def __len__(self):
        """Return the number of file names found in ``night_images``."""
        return len(self.list_files)

    def get_images(self, index):
        """Load and transform the pair at ``index``, skipping unreadable pairs.

        If the requested pair cannot be loaded or transformed, the following
        files are tried in turn, wrapping around at the end of the list, so
        an unreadable file at the end of the dataset no longer ends in an
        ``IndexError``.

        Args:
            index: Position in ``list_files``.

        Returns:
            tuple: ``(input_image, target_image)`` as produced by ``transform``
            (night image first, day image second).

        Raises:
            IndexError: If ``index`` itself is out of range.
            RuntimeError: If not a single pair in the dataset could be loaded.
        """
        # Plain indexing first: keeps the usual IndexError for an out-of-range
        # request (and for an empty dataset) instead of wrapping it around.
        self.list_files[index]

        total = len(self.list_files)
        start = index % total
        last_error = None
        for offset in range(total):
            img_file = self.list_files[(start + offset) % total]
            try:
                day_img_path = os.path.join(self.root_dir+"/day_images", img_file)
                ngt_img_path = os.path.join(self.root_dir+"/night_images", img_file)
                # Context managers close the image files as soon as the pixel
                # data has been copied into the arrays.
                with Image.open(ngt_img_path) as night_file:
                    input_image = np.array(night_file)
                with Image.open(day_img_path) as day_file:
                    target_image = np.array(day_file)

                augmentations = transform(image=input_image, image0=target_image)
                return augmentations["image"], augmentations["image0"]
            except Exception as e:
                # Best effort: report the bad pair and substitute the next one.
                print(f"=> Skipping {img_file}: {e}")
                last_error = e

        raise RuntimeError(
            f"No loadable image pair found under {self.root_dir}"
        ) from last_error

    def __getitem__(self, index):
        """Return the ``(night, day)`` pair for ``index`` (see ``get_images``)."""
        return self.get_images(index)

# 5. Helper Functions

In [None]:
import torch
import torchvision
from torchvision.utils import save_image,make_grid

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the state dicts of ``model`` and ``optimizer`` to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        filename: Destination passed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: Path of the checkpoint; its tensors are mapped to ``DEVICE``.
        model: Module that receives the stored ``"state_dict"``.
        optimizer: Optimizer that receives the stored ``"optimizer"`` state.
        lr: Learning rate to use from now on. It replaces the learning rate
            that was stored inside the checkpoint.
    """
    print("=> Loading checkpoint")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state also restores the learning rate it was saved
    # with. Re-apply the requested value, otherwise `lr` would be silently
    # ignored and the old rate would keep being used.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

def val_fn(gen, validation_loader, epoch, saved_imgs_folder, writer):
    """Generate images for the first validation batch and store them.

    Writes ``y_gen_<epoch>.png`` (generator output) and ``input_<epoch>.png``
    into ``saved_imgs_folder``, logs the output grid to ``writer`` under the
    tag ``fake_day`` and, only when ``epoch == 1``, also writes
    ``label_<epoch>.png`` with the target images. The generator is put into
    eval mode for the forward pass and back into train mode afterwards.

    Args:
        gen: Generator to evaluate.
        validation_loader: Loader yielding ``(input, target)`` batches; only
            its first batch is used.
        epoch: Epoch number, used in the file names and as ``global_step``.
        saved_imgs_folder: Folder path (string) that receives the PNG files.
        writer: Summary writer providing ``add_image``.
    """
    night_batch, day_batch = next(iter(validation_loader))
    night_batch = night_batch.to(DEVICE)
    day_batch = day_batch.to(DEVICE)

    gen.eval()
    with torch.no_grad():
        # `* 0.5 + 0.5` removes the normalization before saving.
        generated = gen(night_batch) * 0.5 + 0.5
        save_image(generated, saved_imgs_folder + f"/y_gen_{epoch}.png")

        inputs_denormalised = night_batch * 0.5 + 0.5
        save_image(inputs_denormalised, saved_imgs_folder + f"/input_{epoch}.png")

        # Grid of the generated images for TensorBoard.
        writer.add_image('fake_day', make_grid(generated), global_step=epoch)

        if epoch == 1:
            labels_denormalised = day_batch * 0.5 + 0.5
            save_image(labels_denormalised, saved_imgs_folder + f"/label_{epoch}.png")
    gen.train()

# 6. Hyperparameters

In [None]:
# Dataset roots; LoadDataset reads "<root>/night_images" and "<root>/day_images".
training_data_folder_path = "/content/drive/MyDrive/iPURSE/Dataset/train"
validation_data_folder_path = "/content/drive/MyDrive/iPURSE/Dataset/validation"
# Checkpoint files written by save_checkpoint and read by load_checkpoint.
CHECKPOINT_DISC_PATH = "/content/drive/MyDrive/iPURSE/Checkpoint/disc.pth.tar"
CHECKPOINT_GEN_PATH = "/content/drive/MyDrive/iPURSE/Checkpoint/gen.pth.tar"

# Folder for the PNGs written by val_fn, and root folder of the TensorBoard event files.
saved_imgs_folder_path = "/content/drive/MyDrive/iPURSE/Saved_Images"
runs_path = "/content/drive/MyDrive/iPURSE/runs"

LEARNING_RATE = 2e-4  # Adam learning rate for both networks
BATCH_SIZE = 16  # batch size of the training loader
NUM_WORKERS = 16  # worker processes of the training loader
NUM_EPOCHS = 300  # exclusive upper bound of the epoch loop


LOAD_MODEL = True  # restore both checkpoints before training starts
SAVE_MODEL = True  # write both checkpoints after every epoch

# 7. Training

In [None]:
%load_ext tensorboard

In [None]:
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# from tensorflow import summary

# Build both networks, their optimizers and the two loss functions.
disc = Discriminator(in_channels=3).to(DEVICE)
gen = Generator(in_channels=3).to(DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

disc.train()
gen.train()

# Optionally resume from the checkpoints of a previous run.
if LOAD_MODEL:
  load_checkpoint(CHECKPOINT_GEN_PATH, gen, opt_gen, LEARNING_RATE,)
  load_checkpoint(CHECKPOINT_DISC_PATH, disc, opt_disc, LEARNING_RATE,)

train_dataset = LoadDataset(root_dir=training_data_folder_path)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True, num_workers=NUM_WORKERS,)

validation_dataset = LoadDataset(root_dir=validation_data_folder_path)
validation_loader = DataLoader(validation_dataset, batch_size=20, shuffle=False,)

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

# Number of batches per epoch, taken from the loader instead of a hard-coded
# count so the step arithmetic follows the dataset size and batch size.
steps_per_epoch = len(train_loader)

# Epoch to resume from; `residue` is an extra offset (in batches) added to the
# first per-batch logging step.
curr_epoch = 70
residue = 0
batchwriter_step = steps_per_epoch*curr_epoch+residue

# purge_step makes the writers drop previously logged events at or after the resume point.
batchwriter = SummaryWriter(runs_path + "/N2D/batch", purge_step=batchwriter_step)
epochwriter = SummaryWriter(runs_path + "/N2D/epoch", purge_step=curr_epoch)


for epoch in range(curr_epoch, NUM_EPOCHS):
  tr_log = train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler, batchwriter, batchwriter_step)
  # train_fn increments only its local copy of the step. Advance the counter
  # here, otherwise every epoch would log its batches to the same step range.
  batchwriter_step += steps_per_epoch

  #update the tensor board (values of the last batch of the epoch)
  epochwriter.add_scalar("d_real", float(tr_log["d_real"]), global_step=epoch)
  epochwriter.add_scalar("d_fake", float(tr_log["d_fake"]), global_step=epoch)
  epochwriter.add_scalar("Disc_loss", float(tr_log["Disc_loss"]), global_step=epoch)
  epochwriter.add_scalar("Gen_loss", float(tr_log["Gen_loss"]), global_step=epoch)

  if SAVE_MODEL:
    save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN_PATH)
    save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC_PATH)
  val_fn(gen, validation_loader, epoch, saved_imgs_folder=saved_imgs_folder_path, writer=epochwriter)
  print(f"epoch {epoch} is completed")

=> Loading checkpoint
=> Loading checkpoint


100%|██████████| 632/632 [15:48<00:00,  1.50s/it, D_fake=0.49, D_real=0.511]


=> Saving checkpoint
=> Saving checkpoint
epoch 56 is completed


100%|██████████| 632/632 [12:19<00:00,  1.17s/it, D_fake=0.489, D_real=0.509]


=> Saving checkpoint
=> Saving checkpoint
epoch 57 is completed


100%|██████████| 632/632 [12:23<00:00,  1.18s/it, D_fake=0.453, D_real=0.483]


=> Saving checkpoint
=> Saving checkpoint
epoch 58 is completed


100%|██████████| 632/632 [12:28<00:00,  1.18s/it, D_fake=0.229, D_real=0.812]


=> Saving checkpoint
=> Saving checkpoint
epoch 59 is completed


100%|██████████| 632/632 [12:32<00:00,  1.19s/it, D_fake=0.326, D_real=0.631]


=> Saving checkpoint
=> Saving checkpoint
epoch 60 is completed


100%|██████████| 632/632 [12:28<00:00,  1.18s/it, D_fake=0.498, D_real=0.495]


=> Saving checkpoint
=> Saving checkpoint
epoch 61 is completed


100%|██████████| 632/632 [12:28<00:00,  1.18s/it, D_fake=0.163, D_real=0.696]


=> Saving checkpoint
=> Saving checkpoint
epoch 62 is completed


100%|██████████| 632/632 [12:36<00:00,  1.20s/it, D_fake=0.275, D_real=0.575]


=> Saving checkpoint
=> Saving checkpoint
epoch 63 is completed


100%|██████████| 632/632 [12:24<00:00,  1.18s/it, D_fake=0.357, D_real=0.668]


=> Saving checkpoint
=> Saving checkpoint
epoch 64 is completed


100%|██████████| 632/632 [12:28<00:00,  1.18s/it, D_fake=0.47, D_real=0.495]


=> Saving checkpoint
=> Saving checkpoint
epoch 65 is completed


100%|██████████| 632/632 [12:39<00:00,  1.20s/it, D_fake=0.273, D_real=0.59]


=> Saving checkpoint
=> Saving checkpoint
epoch 66 is completed


100%|██████████| 632/632 [12:34<00:00,  1.19s/it, D_fake=0.465, D_real=0.549]


=> Saving checkpoint
=> Saving checkpoint
epoch 67 is completed


100%|██████████| 632/632 [12:38<00:00,  1.20s/it, D_fake=0.412, D_real=0.444]


=> Saving checkpoint
=> Saving checkpoint
epoch 68 is completed


100%|██████████| 632/632 [12:32<00:00,  1.19s/it, D_fake=0.501, D_real=0.472]


=> Saving checkpoint
=> Saving checkpoint
epoch 69 is completed


100%|██████████| 632/632 [12:26<00:00,  1.18s/it, D_fake=0.496, D_real=0.516]


=> Saving checkpoint
=> Saving checkpoint
epoch 70 is completed


 41%|████      | 258/632 [05:10<07:52,  1.26s/it, D_fake=0.49, D_real=0.528]

In [None]:
# Start a TensorBoard server on port 6006 that reads the event files under the runs folder.
! tensorboard --logdir "/content/drive/MyDrive/iPURSE/runs" --port=6006

In [None]:
from tensorboard import notebook
# Show the TensorBoard instance listening on port 6006 inside the notebook output
# (height=1000, presumably pixels; confirm against the tensorboard notebook API).
notebook.display(port=6006, height=1000)

In [None]:
# Upload the event files under the runs folder as an experiment named "N2D".
# NOTE(review): this relies on the hosted TensorBoard.dev service, which appears
# to have been discontinued upstream; confirm the command still works before use.
! tensorboard dev upload \
  --logdir "/content/drive/MyDrive/iPURSE/runs" \
  --name "N2D" \
  --description "Training" \
  --one_shot