In [2]:
# Configure ClearML credentials for this environment.
# NOTE(review): clearml-init is an interactive setup wizard; `!` shell commands
# in a notebook usually cannot receive keyboard input, so this line may hang or
# fail — consider providing credentials via CLEARML_API_* environment variables.
!clearml-init

from clearml import Task

# Register this run with ClearML; `task` is used later to obtain the logger.
task = Task.init(project_name="CycleGAN Training", task_name="Aivazovsky Dataset")

In [None]:
import glob
import os

import deepspeed as ds
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as L
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import CombinedLoader
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode, read_image
from torchvision.utils import make_grid, save_image

# ClearML logger; CycleGAN's epoch hooks report the learning rate and the
# epoch-averaged losses through it with `report_scalar`.
logger = task.get_logger()


# Data Preprocessing

In [None]:
DATA_CONFIG = {
    # Glob patterns for the two unpaired image domains.
    "style_dir": os.path.join("/kaggle/input/cyclegan-dataset/aivazovsky", "*.jpg"),
    "photo_dir": os.path.join("/kaggle/input/cyclegan-dataset/real", "*.jpg"),
    "batch_size": 1,  # training / prediction batch size
    "sample_size": 5,  # photos per validation / test batch (visualised each epoch)
    # Extra keyword arguments forwarded to every DataLoader.
    "config": {
        # os.cpu_count() returns None when the count cannot be determined, and
        # DataLoader requires an int; fall back to loading in the main process.
        "num_workers": os.cpu_count() or 0,
        "pin_memory": True,
    },
}


In [None]:
class ImageTransform:
    """Resize (and, for training, augment) an image, then map [0, 1] to [-1, 1].

    Args:
        dim: output height and width.
        load_dim: side length the image is resized to before the random
            ``dim x dim`` crop used for training (``stage == "fit"``).
            Defaults to ``round(dim * 286 / 256)``, i.e. 286 for 256, so the
            crop actually jitters the framing. Must be ``>= dim``.

    Raises:
        ValueError: if ``load_dim < dim``.
    """

    def __init__(self, dim=256, load_dim=None):
        if load_dim is None:
            load_dim = round(dim * 286 / 256)
        if load_dim < dim:
            raise ValueError(f"load_dim ({load_dim}) must be >= dim ({dim})")

        self.resize = T.Resize((dim, dim), antialias=True)
        self.train_transform = T.Compose(
            [
                # Resize slightly larger than the target: cropping dim x dim out
                # of a dim x dim image (the previous behaviour) is a no-op.
                T.Resize((load_dim, load_dim), antialias=True),
                T.RandomCrop((dim, dim)),
                T.RandomHorizontalFlip(p=0.5),
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]
        )

    def __call__(self, image, stage):
        """Transform ``image`` (float tensor in [0, 1]) for the given ``stage``.

        ``"fit"`` applies the random training augmentation; any other stage
        only resizes. The result is rescaled to [-1, 1] to match the
        generators' ``tanh`` output range.
        """
        if stage == "fit":
            img = self.train_transform(image)
        else:
            img = self.resize(image)

        return img * 2 - 1  # [0, 1] -> [-1, 1]


In [None]:
class DatasetBlock(Dataset):
    """Map-style dataset that decodes image files and applies a transform.

    Args:
        filenames: sequence of image file paths.
        transform: callable invoked as ``transform(img, stage=stage)`` on a
            float image tensor scaled to [0, 1].
        stage: forwarded to ``transform`` (``ImageTransform`` augments only
            when it is ``"fit"``).
    """

    def __init__(self, filenames, transform, stage):
        self.filenames = filenames
        self.transform = transform
        self.stage = stage

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # Force 3-channel RGB decoding: grayscale or RGBA files would otherwise
        # produce 1- or 4-channel tensors that break batching and the
        # 3-channel generators/discriminators.
        img = read_image(self.filenames[idx], mode=ImageReadMode.RGB) / 255.0
        return self.transform(img, stage=self.stage)


In [None]:
class ToStyleModule(L.LightningDataModule):
    """Lightning data module for unpaired style/photo image translation.

    Training batches are dicts ``{"style": tensor, "photo": tensor}`` produced
    by a ``CombinedLoader`` (``max_size_cycle`` mode) over two independently
    shuffled loaders. Validation, test and prediction iterate over the photos
    only, resized without augmentation.

    Args:
        style_dir: glob pattern for style-domain images.
        photo_dir: glob pattern for photo-domain images.
        config: extra keyword arguments forwarded to every ``DataLoader``.
        sample_size: batch size of the validation/test loaders.
        batch_size: batch size of the training and prediction loaders.
    """

    def __init__(
        self,
        style_dir,
        photo_dir,
        config,
        sample_size,
        batch_size,
    ):
        super().__init__()
        self.config = config
        self.sample_size = sample_size
        self.batch_size = batch_size
        self.style_dir = style_dir
        self.photo_dir = photo_dir

        # glob order depends on the filesystem; sort so the unshuffled
        # validation samples are the same photos on every run.
        self.style_filenames = sorted(glob.glob(style_dir))
        self.photo_filenames = sorted(glob.glob(photo_dir))

        self.transform = ImageTransform()

    @staticmethod
    def _require_files(filenames, pattern, domain):
        """Fail fast when a glob pattern matched no files (usually a wrong path)."""
        if not filenames:
            raise FileNotFoundError(f"No {domain} images matched {pattern!r}")

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``stage`` is one of Lightning's ``"fit"``, ``"validate"``, ``"test"``
        or ``"predict"``; ``None`` prepares everything.

        Raises:
            FileNotFoundError: if a required glob pattern matched no files.
        """
        self._require_files(self.photo_filenames, self.photo_dir, "photo")

        if stage in ("fit", None):
            self._require_files(self.style_filenames, self.style_dir, "style")
            self.style_training = DatasetBlock(
                self.style_filenames, self.transform, "fit"
            )
            self.photo_training = DatasetBlock(
                self.photo_filenames, self.transform, "fit"
            )

        # Needed by every stage, including trainer.validate (stage "validate").
        self.photo_validation = DatasetBlock(
            self.photo_filenames, self.transform, None
        )

    def train_dataloader(self):
        config = {
            "shuffle": True,
            "drop_last": True,
            "batch_size": self.batch_size,
            **self.config,
        }

        style_loader = DataLoader(self.style_training, **config)
        photo_loader = DataLoader(self.photo_training, **config)
        loaders = {"style": style_loader, "photo": photo_loader}

        return CombinedLoader(loaders, mode="max_size_cycle")

    def val_dataloader(self):
        return DataLoader(
            self.photo_validation,
            batch_size=self.sample_size,
            **self.config,
        )

    def test_dataloader(self):
        return self.val_dataloader()

    def predict_dataloader(self):
        return DataLoader(
            self.photo_validation,
            batch_size=self.batch_size,
            **self.config,
        )


In [8]:
# Smoke-test the data pipeline: build a throw-away data module with the same
# config and pull a single training batch through paths, decoding and transforms.
dm_sample = ToStyleModule(**DATA_CONFIG)
dm_sample.setup(stage="fit")
train_loader = dm_sample.train_dataloader()
# NOTE(review): in Lightning 2.x, iterating a CombinedLoader directly appears to
# yield (batch, batch_idx, dataloader_idx) tuples rather than the batch dict —
# confirm for the installed version before indexing `imgs`.
imgs = next(iter(train_loader))

# CycleGAN components
## Downsampling, Upsampling and Residual blocks

In [None]:
class Downsampling(nn.Module):
    """Conv2d, then optional affine InstanceNorm2d, then optional activation.

    Spatial size only shrinks when ``stride > 1``; the defaults (kernel 4,
    stride 2, padding 1) halve even H and W.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero padding.
        lrelu: ``None`` adds no activation; a truthy value adds
            ``LeakyReLU(0.2)``; any other non-None value adds ``ReLU``.
        norm: add ``InstanceNorm2d(affine=True)``; the conv bias is dropped
            when the norm is present.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        stride=2,
        padding=1,
        lrelu=True,
        norm=True,
    ):
        super().__init__()

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            bias=not norm,  # a following norm layer makes the bias redundant
        )
        layers = [conv]

        if norm:
            layers.append(nn.InstanceNorm2d(out_channels, affine=True))

        if lrelu is None:
            pass  # caller wants the raw (normalised) conv output
        elif lrelu:
            layers.append(nn.LeakyReLU(0.2, True))
        else:
            layers.append(nn.ReLU(True))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv_block(x)
        return out


In [None]:
class Upsampling(nn.Module):
    """ConvTranspose2d, InstanceNorm2d, optional Dropout(0.5), then ReLU.

    The defaults (kernel 4, stride 2, padding 1) double H and W; with
    ``kernel=3`` the same doubling needs ``output_padding=1``.

    Args:
        in_channels: input channels of the transposed convolution.
        out_channels: output channels of the transposed convolution.
        kernel: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        output_padding: extra size added to one side of the output.
        dropout: insert ``Dropout(0.5)`` before the activation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        stride=2,
        padding=1,
        output_padding=0,
        dropout=False,
    ):
        super().__init__()

        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=False,  # followed by a norm layer
        )
        layers = [deconv, nn.InstanceNorm2d(out_channels, affine=True)]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU(True))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


In [None]:
class Residual(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is: reflection pad, then conv + InstanceNorm + ReLU, then reflection
    pad, then conv + InstanceNorm (no activation), all with stride 1. With the
    defaults (kernel 3, padding 1) ``F`` preserves the spatial size, which the
    skip addition requires.

    Args:
        in_channels: channels of both the input and the output.
        kernel: convolution kernel size.
        padding: reflection padding applied before each convolution.
    """

    def __init__(self, in_channels, kernel=3, padding=1):
        super().__init__()

        # Padding comes from the reflection layers, so the convs themselves use none.
        conv_args = dict(kernel=kernel, stride=1, padding=0)
        self.block = nn.Sequential(
            nn.ReflectionPad2d(padding),
            Downsampling(in_channels, in_channels, lrelu=False, **conv_args),
            nn.ReflectionPad2d(padding),
            Downsampling(in_channels, in_channels, lrelu=None, **conv_args),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual


## U-Net 

In [None]:
class UNet(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages with mirrored skip connections.

    Each encoder stage halves H and W; for a 256x256 input the innermost stage
    is 1x1. Each decoder output is concatenated channel-wise with the encoder
    feature map of the same resolution, and a final transposed conv + tanh
    restores full resolution with outputs in [-1, 1]. Inputs are therefore
    expected to have sides divisible by 256.

    Args:
        hidden_channels: width of the first encoder stage (deeper stages use
            2x, 4x and 8x this value).
        in_channels: channels of the input image.
        out_channels: channels of the output image.
    """

    def __init__(self, hidden_channels, in_channels, out_channels):
        super().__init__()

        # Encoder widths as multiples of hidden_channels, outermost first.
        down_mults = (1, 2, 4, 8, 8, 8, 8, 8)
        innermost = len(down_mults) - 1

        encoder = []
        channels = in_channels
        for depth, mult in enumerate(down_mults):
            width = hidden_channels * mult
            # No norm on the outermost stage, nor on the innermost one, whose
            # 1x1 output (for 256x256 input) InstanceNorm cannot handle.
            has_norm = depth not in (0, innermost)
            encoder.append(Downsampling(channels, width, norm=has_norm))
            channels = width
        self.downsampling_block = nn.Sequential(*encoder)

        # Decoder stages as (input mult, output mult, dropout). From the second
        # stage on, the input width includes the concatenated skip connection.
        up_specs = (
            (8, 8, True),
            (16, 8, True),
            (16, 8, True),
            (16, 8, False),
            (16, 4, False),
            (8, 2, False),
            (4, 1, False),
        )
        self.upsampling_block = nn.Sequential(
            *(
                Upsampling(hidden_channels * m_in, hidden_channels * m_out, dropout=drop)
                for m_in, m_out, drop in up_specs
            )
        )

        self.feature_block = nn.Sequential(
            nn.ConvTranspose2d(
                hidden_channels * 2, out_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        features = []
        h = x
        for stage in self.downsampling_block:
            h = stage(h)
            features.append(h)

        # Pair decoder stages with encoder outputs deepest-first, skipping the
        # innermost output itself (it is the decoder's starting point).
        for stage, skip in zip(self.upsampling_block, features[-2::-1]):
            h = torch.cat([stage(h), skip], dim=1)

        return self.feature_block(h)


## ResNet

In [None]:
class ResNet(nn.Module):
    """ResNet-style generator (shape comments assume a 256x256 input).

    Layout: reflection-padded 7x7 stem, two stride-2 downsampling convs,
    ``num_resblocks`` residual blocks at 1/4 resolution, two transposed-conv
    upsamplers, and a reflection-padded 7x7 output conv followed by tanh, so
    outputs lie in [-1, 1].

    Args:
        hidden_channels: width after the stem (bottleneck uses 4x this).
        in_channels: channels of the input image.
        out_channels: channels of the output image.
        num_resblocks: number of residual blocks in the bottleneck.
    """

    def __init__(self, hidden_channels, in_channels, out_channels, num_resblocks):
        super().__init__()
        h = hidden_channels

        stem = [
            nn.ReflectionPad2d(3),
            Downsampling(in_channels, h, kernel=7, stride=1, padding=0, lrelu=False),
        ]  # h x 256 x 256
        encoder = [
            Downsampling(h, h * 2, kernel=3, lrelu=False),  # 2h x 128 x 128
            Downsampling(h * 2, h * 4, kernel=3, lrelu=False),  # 4h x 64 x 64
        ]
        bottleneck = [Residual(h * 4) for _ in range(num_resblocks)]  # 4h x 64 x 64
        decoder = [
            Upsampling(h * 4, h * 2, kernel=3, output_padding=1),  # 2h x 128 x 128
            Upsampling(h * 2, h, kernel=3, output_padding=1),  # h x 256 x 256
        ]
        head = [
            nn.ReflectionPad2d(3),  # pad with reflected pixels to avoid border artefacts
            nn.Conv2d(h, out_channels, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *encoder, *bottleneck, *decoder, *head)

    def forward(self, x):
        out = self.model(x)
        return out


In [None]:
class Generator:
    """Factory for the CycleGAN generator backbone.

    Args:
        name: ``"unet"`` or ``"resnet"``.
        hidden_channels: base channel width of the network.
        num_resblocks: residual blocks in the ResNet bottleneck (ignored by U-Net).
        in_channels: channels of the input image.
        out_channels: channels of the output image.
    """

    def __init__(
        self, name, hidden_channels, num_resblocks, in_channels=3, out_channels=3
    ):
        self.name = name
        self.hidden_channels = hidden_channels
        self.num_resblocks = num_resblocks
        self.in_channels = in_channels
        self.out_channels = out_channels

    def create(self):
        """Instantiate the generator selected by ``self.name``.

        Raises:
            ValueError: if ``self.name`` is not ``"unet"`` or ``"resnet"``.
        """
        if self.name == "unet":
            return UNet(self.hidden_channels, self.in_channels, self.out_channels)

        if self.name == "resnet":
            return ResNet(
                self.hidden_channels,
                self.in_channels,
                self.out_channels,
                self.num_resblocks,
            )

        # Fail fast: a sentinel return value would be stored silently as a
        # module attribute and only break much later, far from the typo.
        raise ValueError(
            f"Did not find generator {self.name!r}; expected 'unet' or 'resnet'"
        )


## Generator

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator producing a map of real/fake scores.

    Three stride-2 stages and one stride-1 stage (LeakyReLU, InstanceNorm on
    all but the first), then a 1-channel conv. For a 3x256x256 input the
    output is 1x30x30: one unnormalised score per image patch.

    Args:
        hidden_channels: width of the first stage (later stages use 2x/4x/8x).
        in_channels: channels of the input image.
    """

    def __init__(self, hidden_channels, in_channels=3):
        super().__init__()

        widths = [
            in_channels,
            hidden_channels,
            hidden_channels * 2,
            hidden_channels * 4,
            hidden_channels * 8,
        ]
        last_stage = len(widths) - 2

        layers = []
        for i, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The last stage keeps stride 1 (32x32 -> 31x31 for 256x256 input);
            # the first stage has no normalisation.
            stride = 1 if i == last_stage else 2
            layers.append(Downsampling(c_in, c_out, stride=stride, norm=i > 0))
        # 1 channel: a per-patch real/fake score map (31x31 -> 30x30).
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores


In [None]:
class ReplayBuffer:
    """History of generated images used to update the discriminators.

    The first ``max_size`` images are stored and returned unchanged. After
    that, each incoming image is either returned as-is, or, with probability
    0.5, swapped with a randomly chosen stored image, which is returned
    instead. All stored and returned images are detached from autograd.

    Args:
        max_size: number of images kept; ``0`` disables the buffer (inputs are
            returned unchanged).
    """

    def __init__(self, max_size):
        self.max_size = max_size
        self.buffer = []
        self.current_capacity = 0

    def __call__(self, images):
        """Return a batch mixing ``images`` with previously stored fakes.

        Args:
            images: batch tensor whose first dimension indexes images.

        Returns:
            Tensor with the same shape as ``images``; it is ``images`` itself
            when the buffer is disabled or the batch is empty.
        """
        if self.max_size == 0:
            return images

        imgs = []
        for img in images:
            # Detach: stored fakes are only ever consumed by the discriminator
            # (which detaches anyway), and keeping their autograd history
            # alive across iterations just wastes memory.
            img = img.detach().unsqueeze(dim=0)

            if self.current_capacity < self.max_size:
                self.current_capacity += 1
                self.buffer.append(img)
                imgs.append(img)
            else:
                p = np.random.uniform(low=0.0, high=1.0)

                if p > 0.5:
                    idx = np.random.randint(low=0, high=self.max_size)
                    tmp = self.buffer[idx].clone()
                    self.buffer[idx] = img
                    imgs.append(tmp)
                else:
                    imgs.append(img)

        return torch.cat(imgs, dim=0) if len(imgs) > 0 else images


# Build the model

In [17]:
def show_img(img_tensor, nrow, title=""):
    """Display a batch of images with values in [-1, 1] as one grid.

    Args:
        img_tensor: batch of images accepted by ``make_grid`` (e.g. generator
            output), scaled to [-1, 1].
        nrow: number of images per grid row.
        title: figure title.
    """
    imgs = img_tensor.detach().cpu() * 0.5 + 0.5  # [-1, 1] -> [0, 1]
    grid = make_grid(imgs, nrow=nrow).permute(1, 2, 0)  # CHW -> HWC for imshow

    fig, ax = plt.subplots(figsize=(18, 8))
    ax.imshow(grid)
    ax.axis("off")
    ax.set_title(title)
    plt.show()
    # Called once per validation epoch: close explicitly so figures do not
    # accumulate on backends where show() does not close them.
    plt.close(fig)

In [None]:
class CycleGAN(L.LightningModule):
    """CycleGAN translating photos (P) to a painting style (S) and back.

    Two generators (``G_PS``: photo -> style, ``G_SP``: style -> photo) and
    two patch discriminators (``D_S``, ``D_P``) are trained with least-squares
    adversarial, cycle-consistency and identity losses. Optimization is manual:
    one optimizer updates both generators, another updates both discriminators,
    and both LR schedulers are stepped once per epoch.

    Args:
        name: generator backbone, ``"unet"`` or ``"resnet"``.
        num_resblocks: residual blocks for the ResNet generator.
        hidden_channels: base channel width for generators and discriminators.
        optimizer: optimizer class, called as ``optimizer(params, lr=..., betas=...)``.
        lr: initial learning rate for both optimizers.
        betas: Adam-style ``(beta1, beta2)``.
        lambda_idt: identity-loss weight, multiplied by the matching cycle weight.
        lambda_cycle: ``(S->P->S weight, P->S->P weight)``.
        buffer_max_size: capacity of each fake-image replay buffer.
        num_epochs: total number of training epochs (used by the LR schedule).
        decay_epochs: number of initial epochs at the constant initial LR;
            afterwards the LR decays linearly.
    """

    def __init__(
        self,
        name,
        num_resblocks,
        hidden_channels,
        optimizer,
        lr,
        betas,
        lambda_idt,
        lambda_cycle,
        buffer_max_size,
        num_epochs,
        decay_epochs,
    ):
        super().__init__()

        # The optimizer class is kept as an attribute and excluded from hparams.
        self.optimizer = optimizer
        self.save_hyperparameters(ignore=["optimizer"])
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.G_PS = Generator(name, hidden_channels, num_resblocks).create()
        self.G_SP = Generator(name, hidden_channels, num_resblocks).create()

        self.D_P = Discriminator(hidden_channels)
        self.D_S = Discriminator(hidden_channels)

        self.fake_P_buffer = ReplayBuffer(buffer_max_size)
        self.fake_S_buffer = ReplayBuffer(buffer_max_size)

        # Initialise at construction rather than in setup("fit"): setup runs on
        # every trainer.fit call and would overwrite weights restored with
        # load_from_checkpoint (e.g. when fine-tuning a saved model).
        self.init_weights()

    def forward(self, img):
        """Translate photos to the style domain with ``G_PS``."""
        return self.G_PS(img)

    def init_weights(self):
        """Initialise all four networks with small normal weights.

        Conv / transposed-conv weights ~ N(0, 0.02). Affine InstanceNorm scales
        ~ N(1, 0.02); a mean of 0 would shrink every normalised activation
        towards zero and randomly flip its sign. All biases are set to 0.
        """

        def init_fn(m):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.InstanceNorm2d) and m.affine:
                # Non-affine norms have no weight/bias to initialise.
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0.0)

        for net in [self.G_PS, self.G_SP, self.D_S, self.D_P]:
            net.apply(init_fn)

    def get_lr_scheduler(self, optimizer):
        """Constant LR for ``decay_epochs`` epochs, then linear decay towards 0."""

        def lr_lambda(epoch):
            len_decay_phase = self.hparams.num_epochs - self.hparams.decay_epochs + 1.0
            curr_decay_step = max(0, epoch - self.hparams.decay_epochs + 1.0)
            val = 1.0 - curr_decay_step / len_decay_phase

            return max(0.0, val)

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def configure_optimizers(self):
        """One optimizer for both generators and one for both discriminators."""
        optimizer_config = {"lr": self.hparams.lr, "betas": self.hparams.betas}

        optimizer_G = self.optimizer(
            list(self.G_PS.parameters()) + list(self.G_SP.parameters()),
            **optimizer_config,
        )
        optimizer_D = self.optimizer(
            list(self.D_S.parameters()) + list(self.D_P.parameters()),
            **optimizer_config,
        )

        optimizers = [optimizer_G, optimizer_D]
        schedulers = [self.get_lr_scheduler(opt) for opt in optimizers]

        return optimizers, schedulers

    def training_step(self, batch, batch_idx):
        """Update the generators, then the discriminators, on one batch.

        ``batch`` is the dict from the datamodule's CombinedLoader with keys
        ``"style"`` and ``"photo"``. Intermediate images are stored on ``self``
        because the ``get_*_loss`` helpers read them from there.
        """
        self.real_S = batch["style"]
        self.real_P = batch["photo"]
        opt_gen, opt_disc = self.optimizers()

        # cross-domain translations
        self.fake_S = self.G_PS(self.real_P)
        self.fake_P = self.G_SP(self.real_S)

        # identity mappings: each generator applied to its own output domain
        self.idt_S = self.G_PS(self.real_S)
        self.idt_P = self.G_SP(self.real_P)

        # cycle reconstructions: S -> P -> S and P -> S -> P
        self.recon_S = self.G_PS(self.fake_P)
        self.recon_P = self.G_SP(self.fake_S)

        # train generators (toggle_optimizer freezes params not in opt_gen,
        # so the discriminators receive no gradients here)
        self.toggle_optimizer(opt_gen)
        gen_loss = self.get_gen_loss()
        opt_gen.zero_grad()
        self.manual_backward(gen_loss)
        opt_gen.step()
        self.untoggle_optimizer(opt_gen)

        # train discriminators
        self.toggle_optimizer(opt_disc)
        disc_loss_S = self.get_disc_loss_S()
        disc_loss_P = self.get_disc_loss_P()
        opt_disc.zero_grad()
        self.manual_backward(disc_loss_S)
        self.manual_backward(disc_loss_P)
        opt_disc.step()
        self.untoggle_optimizer(opt_disc)

        # record training losses (epoch averages)
        metrics = {
            "gen_loss": gen_loss,
            "disc_loss_S": disc_loss_S,
            "disc_loss_P": disc_loss_P,
        }
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)

    def get_cycle_loss(self, real, recon, lambda_cycle):
        """Weighted L1 distance between an image and its cycle reconstruction."""
        cycle_loss = F.l1_loss(recon, real)
        return lambda_cycle * cycle_loss

    def get_adv_loss(self, fake, disc):
        """Least-squares generator loss: push ``disc(fake)`` towards 1."""
        fake_hat = disc(fake)
        real_labels = torch.ones_like(fake_hat)
        adv_loss = F.mse_loss(fake_hat, real_labels)
        return adv_loss

    def get_idt_loss(self, real, idt, lambda_cycle):
        """Weighted L1 identity loss (``lambda_idt * lambda_cycle * L1``)."""
        idt_loss = F.l1_loss(idt, real)
        return self.hparams.lambda_idt * lambda_cycle * idt_loss

    def get_gen_loss(self):
        """Total generator loss: adversarial + identity + cycle terms."""
        # calculate adversarial loss
        adv_loss_PS = self.get_adv_loss(self.fake_S, self.D_S)
        adv_loss_SP = self.get_adv_loss(self.fake_P, self.D_P)
        total_adv_loss = adv_loss_PS + adv_loss_SP

        # calculate identity loss
        lambda_cycle = self.hparams.lambda_cycle
        idt_loss_SS = self.get_idt_loss(self.real_S, self.idt_S, lambda_cycle[0])
        idt_loss_PP = self.get_idt_loss(self.real_P, self.idt_P, lambda_cycle[1])
        total_idt_loss = idt_loss_SS + idt_loss_PP

        # calculate cycle loss
        cycle_loss_SPS = self.get_cycle_loss(self.real_S, self.recon_S, lambda_cycle[0])
        cycle_loss_PSP = self.get_cycle_loss(self.real_P, self.recon_P, lambda_cycle[1])
        total_cycle_loss = cycle_loss_SPS + cycle_loss_PSP

        # combine losses
        gen_loss = total_adv_loss + total_idt_loss + total_cycle_loss
        return gen_loss

    def get_disc_loss(self, real, fake, disc):
        """Least-squares discriminator loss: real -> 1, (detached) fake -> 0, averaged."""
        real_hat = disc(real)
        real_labels = torch.ones_like(real_hat)
        real_loss = F.mse_loss(real_hat, real_labels)

        fake_hat = disc(fake.detach())
        fake_labels = torch.zeros_like(fake_hat)
        fake_loss = F.mse_loss(fake_hat, fake_labels)

        disc_loss = (fake_loss + real_loss) * 0.5
        return disc_loss

    def get_disc_loss_S(self):
        """Style discriminator loss, using fakes drawn through the replay buffer."""
        fake_S = self.fake_S_buffer(self.fake_S)
        return self.get_disc_loss(self.real_S, fake_S, self.D_S)

    def get_disc_loss_P(self):
        """Photo discriminator loss, using fakes drawn through the replay buffer."""
        fake_P = self.fake_P_buffer(self.fake_P)
        return self.get_disc_loss(self.real_P, fake_P, self.D_P)

    def validation_step(self, batch, batch_idx):
        self.display_results(batch, batch_idx, "validate")

    def display_results(self, batch, batch_idx, stage):
        """Show a row of input photos above their style translations."""
        real_P = batch
        fake_S = self(real_P)

        if stage == "validate":
            title = f"Epoch {self.current_epoch + 1}: Photo-to-Style Translation"
        else:
            title = f"Sample {batch_idx + 1}: Photo-to-Style Translation"

        show_img(
            torch.cat([real_P, fake_S], dim=0),
            nrow=len(real_P),
            title=title,
        )

    def on_train_epoch_start(self):
        """Log the generator optimizer's current learning rate."""
        curr_lr = self.lr_schedulers()[0].get_last_lr()[0]
        self.log("lr", curr_lr, on_step=False, on_epoch=True, prog_bar=True)
        logger.report_scalar(
            "Learning Rate", "train", value=curr_lr, iteration=self.current_epoch
        )

    def on_train_epoch_end(self):
        """Report epoch-averaged losses and step both LR schedulers."""
        avg_gen_loss = self.trainer.callback_metrics["gen_loss"].item()
        avg_disc_loss_S = self.trainer.callback_metrics["disc_loss_S"].item()
        avg_disc_loss_P = self.trainer.callback_metrics["disc_loss_P"].item()

        logger.report_scalar(
            "Generator Loss",
            "epoch_avg",
            value=avg_gen_loss,
            iteration=self.current_epoch,
        )
        logger.report_scalar(
            "Discriminator Loss (S)",
            "epoch_avg",
            value=avg_disc_loss_S,
            iteration=self.current_epoch,
        )
        logger.report_scalar(
            "Discriminator Loss (P)",
            "epoch_avg",
            value=avg_disc_loss_P,
            iteration=self.current_epoch,
        )

        # Manual optimization: Lightning does not step the schedulers for us.
        for sch in self.lr_schedulers():
            sch.step()

        print(
            f"Epoch {self.current_epoch + 1}",
            f"gen_loss: {avg_gen_loss:.5f}",
            f"disc_loss_S: {avg_disc_loss_S:.5f}",
            f"disc_loss_P: {avg_disc_loss_P:.5f}",
            sep=" - ",
        )

    def on_train_end(self):
        print("Training ended.")

    def on_predict_epoch_end(self):
        """Report how many images the prediction loop generated."""
        predictions = self.trainer.predict_loop.predictions
        # Summing per-batch sizes handles an empty prediction list and a
        # smaller final batch without special cases.
        num_images = sum(len(batch) for batch in predictions)
        print(f"Number of images generated: {num_images}.")


# Configuration

In [None]:
MODEL_CONFIG = {
    "name": "unet",  # generator backbone: "unet" or "resnet"
    "num_resblocks": 9,  # only used by the "resnet" generator
    "hidden_channels": 64,
    # DeepSpeed's FusedAdam on GPU, plain torch Adam otherwise.
    "optimizer": (
        ds.ops.adam.FusedAdam if torch.cuda.is_available() else torch.optim.Adam
    ),
    "lr": 2e-4,
    "betas": (0.5, 0.999),
    "lambda_idt": 0.5,  # identity weight, multiplied by the matching cycle weight
    "lambda_cycle": (10, 10),  # (S->P->S direction, P->S->P direction)
    "buffer_max_size": 100,
    "num_epochs": 20,
    # Epochs trained at the initial LR before the linear decay starts.
    # NOTE(review): equal to num_epochs, so the LR never decays within this
    # run — confirm that is intended.
    "decay_epochs": 20,
}


In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath="/kaggle/working/",
    filename="epoch-{epoch:02d}",
    # With the default (True) Lightning inserts the metric name, turning the
    # template above into files like "epoch-epoch=04.ckpt".
    auto_insert_metric_name=False,
    save_top_k=-1,  # keep every periodic checkpoint
    save_last=True,
    every_n_epochs=5,
)

TRAIN_CONFIG = {
    "accelerator": "gpu" if torch.cuda.is_available() else "cpu",
    "precision": "16-mixed" if torch.cuda.is_available() else 32,
    "devices": 1,
    "enable_checkpointing": True,
    "max_epochs": MODEL_CONFIG["num_epochs"],
    "limit_train_batches": 1.0,
    "limit_predict_batches": 1.0,
    "max_time": {"hours": 4, "minutes": 55},  # hard wall-clock cap on fitting
    "limit_val_batches": 1,  # one batch of sample photos visualised per epoch
    "limit_test_batches": 5,
    "num_sanity_val_steps": 0,
    "check_val_every_n_epoch": 1,
    "callbacks": [checkpoint_callback],
}


In [21]:
# Assemble data, model and trainer from the config cells above, then train.
dm = ToStyleModule(**DATA_CONFIG)
model = CycleGAN(**MODEL_CONFIG)
trainer = L.Trainer(**TRAIN_CONFIG)
trainer.fit(model=model, datamodule=dm)