# Photo to Monet — CycleGAN⚡PyTorch Lightning

This notebook implements CycleGAN to translate photos into Monet-style paintings. The model architecture is adapted from this [tutorial](https://www.kaggle.com/code/amyjang/monet-cyclegan-tutorial/notebook), and we also convert the code to PyTorch Lightning. The references are as follows:
* Original [paper](https://arxiv.org/abs/1703.10593) and [code](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) for CycleGAN.
* Original [paper](https://arxiv.org/abs/1611.04076) for LSGAN, which has been shown to outperform the binary cross-entropy (BCE) loss and is used in the original CycleGAN implementation. The BCE loss may suffer from vanishing gradients, which leads to ineffective learning. For better training stability, we use LSGAN, which adopts the mean squared error as the adversarial criterion.
* [Documentation](https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html) for basic GAN in PyTorch Lightning. Kaggle seems to still be using [Python 3.7](https://www.kaggle.com/discussions/product-feedback/388376) at the time of writing, which does not support Lightning 2.0. Lightning 2.0 requires manual optimization for training with multiple optimizers, but for simplicity we stick to automatic optimization in version 1.9 here. A useful guide for upgrading the code to Lightning 2.0 can be found [here](https://lightning.ai/docs/pytorch/stable/upgrade/from_1_9.html).
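To make the vanishing-gradient argument concrete, here is a toy comparison of the gradient each generator loss sends back for a single fake sample that the discriminator confidently rejects. It assumes the saturating minimax form of the BCE generator loss, log(1 - D(G(z))); the function `generator_grads` and the chosen output values are illustrative only.

```python
import torch


def generator_grads(bce_logit: float, lsgan_output: float):
    """Return d(loss)/d(discriminator output) for two generator losses on one fake sample."""
    # Saturating minimax BCE generator loss: minimize log(1 - sigmoid(logit)).
    logit = torch.tensor(bce_logit, requires_grad=True)
    bce_loss = torch.log1p(-torch.sigmoid(logit))
    (bce_grad,) = torch.autograd.grad(bce_loss, logit)

    # LSGAN generator loss: (D(G(z)) - 1)^2, i.e. squared error against a target of 1.
    out = torch.tensor(lsgan_output, requires_grad=True)
    ls_loss = (out - 1.0) ** 2
    (ls_grad,) = torch.autograd.grad(ls_loss, out)
    return bce_grad.item(), ls_grad.item()


# Both discriminators confidently call the sample fake:
# a logit of -8 (sigmoid of about 0.0003) for BCE, an output of 0 (the fake label) for LSGAN.
bce_grad, ls_grad = generator_grads(-8.0, 0.0)
print(f"BCE grad: {bce_grad:.6f}, LSGAN grad: {ls_grad:.1f}")
```

With BCE the gradient is tiny because the sigmoid has saturated, while the LSGAN loss still returns a gradient of -2, so the generator keeps receiving a learning signal. This sketch only covers the saturating BCE variant.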

More work can be done to include evaluation metrics like the Inception Score (IS) or the Fréchet Inception Distance (FID).
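As a possible starting point for FID, the sketch below computes the Fréchet distance between two Gaussians fitted to feature vectors, ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2(Sigma_r Sigma_f)^(1/2)). It assumes the features have already been extracted (standard FID uses Inception-v3 pooling features, which this sketch does not compute), and `frechet_distance` and `fid_from_features` are hypothetical helpers, not a library API.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^(1/2)) is the sum of the square roots of the eigenvalues of
    # sigma1 @ sigma2; they are real and non-negative for PSD covariances, up to
    # numerical noise (hence taking the real part and clipping at 0).
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


def fid_from_features(feats_real, feats_fake):
    """Each row is one image's feature vector (assumed to be precomputed)."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_f = np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu_r, cov_r, mu_f, cov_f)


rng = np.random.default_rng(0)
real = rng.normal(size=(256, 16))
fake = rng.normal(loc=0.5, size=(256, 16))
print(fid_from_features(real, fake))
```

Lower values mean the two feature distributions are closer; identical feature sets give a distance of (numerically) zero.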

In [None]:
# install library
# pin Lightning 1.9.x: the training code below relies on optimizer_idx and
# training_epoch_end, both of which were removed in Lightning 2.0
print('installing...')

# !pip install "pytorch_lightning<2.0" torchvision

print('installed')


In [5]:
# !pip show torchmetrics
# !pip show pytorch_lightning

!pip show pytorch_lightning

Name: pytorch-lightning
Version: 2.0.1.post0
Summary: PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.
Home-page: https://github.com/Lightning-AI/lightning
Author: Lightning AI et al.
Author-email: pytorch@lightning.ai
License: Apache-2.0
Location: f:\software\installed\anaconda\src\envs\deep_learning\lib\site-packages
Requires: fsspec, lightning-utilities, numpy, packaging, PyYAML, torch, torchmetrics, tqdm, typing-extensions
Required-by: 


In [6]:
import glob
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as L
import torch
import torch.nn.functional as F
import torchvision.transforms as T
# CombinedLoader location in Lightning 1.9.x
from pytorch_lightning.trainer.supporters import CombinedLoader
# for Lightning 2.x: from pytorch_lightning.utilities.combined_loader import CombinedLoader
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.utils import make_grid, save_image

L.seed_everything(0, workers=True)
print(L.__name__, L.__version__)

---

# 1. Data Preprocessing

In [None]:
def show_img(img_tensor, nrow, title=""):
    img_tensor = img_tensor.detach().cpu()*0.5 + 0.5
    img_grid = make_grid(img_tensor, nrow=nrow).permute(1, 2, 0)
    plt.figure(figsize=(18, 8))
    plt.imshow(img_grid)
    plt.axis("off")
    plt.title(title)
    plt.show()

### Augmenting the images.

Before loading the datasets, we define `CustomTransform` for image augmentation. Augmentation improves learning by adding variety to the training images instead of showing the model the exact same images every epoch, which matters especially because we only have 300 Monet paintings. We use some basic image transformations:
* Enlarging the images with `Resize` and then randomly cropping them back to the original size of 256 with `RandomCrop`, which yields slightly different views of each image.
* Randomly flipping the images horizontally with `RandomHorizontalFlip`, since the content of the photos and Monet paintings does not depend much on horizontal orientation.
* Randomly changing the colors of the images with `ColorJitter`. This can mimic different lighting conditions for the photos and introduce variability in the colors of the Monet paintings.

Other possible transformations can be found [here](https://pytorch.org/vision/stable/transforms.html). These transformations are only needed during model training/fitting, which we indicate with the `stage` argument; in other stages, images are only resized to 256. Finally, pixel values are rescaled from [0, 1] to [-1, 1] to match the output range of the generator's `Tanh` activation.

In [None]:
class CustomTransform(object):
    def __init__(self, load_dim=286, target_dim=256):
        self.transform_train = T.Compose([
            T.Resize((load_dim, load_dim)),
            T.RandomCrop((target_dim, target_dim)),
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2,
                          saturation=0.2, hue=0.1),
        ])
        
        # ensure images outside of training dataset are also of the same size
        self.transform = T.Resize((target_dim, target_dim))
        
    def __call__(self, img, stage="fit"):
        if stage == "fit":
            img = self.transform_train(img)
        else:
            img = self.transform(img)
        return img*2 - 1

### Storing the datasets.

To load and store the datasets, we define a custom `Dataset` involving three main methods: 
* `__init__` to initialize the dataset.
* `__len__` to retrieve the size of the dataset.
* `__getitem__` to get the i-th sample of images after performing the transformations described above.

Similarly, we define the `stage` argument to differentiate between the training dataset and prediction dataset when performing the transformations. Different instances of `CustomDataset` will be used to retrieve the photos and Monet paintings separately. We look at combining them while iterating through the datasets later.
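To illustrate the `Dataset` protocol without image files, here is a minimal sketch that serves pre-loaded tensors. `InMemoryImageDataset` is a hypothetical class, not part of this notebook; `DataLoader` only relies on `__len__` and `__getitem__` to sample and batch it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical stand-in for CustomDataset that serves tensors instead of JPEG files."""

    def __init__(self, images, transform=None):
        self.images = images        # shape (N, C, H, W), values in [0, 1]
        self.transform = transform

    def __len__(self):
        # DataLoader's sampler draws indices in range(len(dataset))
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img * 2 - 1          # rescale to [-1, 1], as CustomTransform does


images = torch.rand(10, 3, 32, 32)
ds = InMemoryImageDataset(images)
loader = DataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
batches = list(loader)
print(len(batches), tuple(batches[0].shape))
```

With 10 images, `batch_size=4`, and `drop_last=True`, the loader yields 2 full batches and drops the remaining 2 images.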

In [None]:
class CustomDataset(Dataset):
    def __init__(self, filenames, transform, stage):
        self.filenames = filenames
        self.transform = transform
        self.stage = stage
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img = read_image(img_name) / 255.0
        return self.transform(img, stage=self.stage)

### Iterating through the datasets.

To prepare the datasets, we load each one into its own `DataLoader`, which iterates through it in batches as needed. Because training requires both the Monet paintings and the photos, we pass both dataloaders into `CombinedLoader`. With `mode="max_size_cycle"`, an epoch ends after one complete pass over the larger dataset (the photos), while the smaller dataset (the Monet paintings) is cycled through repeatedly. Other modes can be found [here](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.combined_loader.html). In contrast, the prediction dataset contains only photos for simplicity, since the goal of this notebook is to generate Monet-style images from photos.
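The `max_size_cycle` behavior can be sketched in plain Python as follows. This is a simplified model written for illustration: the function `max_size_cycle` is hypothetical and is not Lightning's implementation. It runs for as many steps as the longest loader and, in this sketch, restarts any shorter loader when it runs out.

```python
def max_size_cycle(loaders):
    """Yield dicts of batches until the longest loader is exhausted once."""
    num_steps = max(len(loader) for loader in loaders.values())
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    for _ in range(num_steps):
        batch = {}
        for name, loader in loaders.items():
            try:
                batch[name] = next(iterators[name])
            except StopIteration:
                # a shorter loader ran out: start a fresh pass over it
                iterators[name] = iter(loader)
                batch[name] = next(iterators[name])
        yield batch


for step in max_size_cycle({"monet": [0, 1, 2], "photo": list("abcdefg")}):
    print(step)
```

Here the 3-item "monet" loader is repeated until the 7-item "photo" loader finishes, so every photo is seen exactly once per epoch.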

To organize all the steps described above for processing data, we define a custom `LightningDataModule`. A datamodule involves many methods, but we are mainly concerned with:
* `setup` to create the datasets and apply the corresponding transformations defined above.
* `train_dataloader` to generate the dataloader for the training dataset.
* `predict_dataloader` to generate the dataloader for the prediction dataset.

Other possible methods can be found [here](https://pytorch-lightning.readthedocs.io/en/stable/data/datamodule.html). We define the following parameters:
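* `DEBUG` — to enable debugging mode on CPU, which also trims each dataset to a few batches' worth of images.
* `MONET_DIR` and `PHOTO_DIR` — the directories where the Monet paintings and photos are loaded from.
* `BATCH_SIZE` — the number of samples per batch, given as `[Monet batch size, photo batch size]`; prediction uses the photo batch size.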
* `num_workers` — the number of subprocesses (excluding the main process) used for data loading.
* `pin_memory` — to enable faster data transfer to GPU during training.
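The batch-size and `DEBUG` settings interact with `drop_last=True` and `max_size_cycle` to determine how many steps an epoch takes. The hypothetical helpers below spell out that arithmetic; the photo counts in the usage example are made up, while 300 is the number of Monet paintings mentioned above.

```python
def subset_size(num_files, batch_size, debug):
    """Number of files kept per dataset, mirroring the slicing in CustomDataModule.__init__."""
    batch_size = batch_size if isinstance(batch_size, list) else [batch_size] * 2
    idx = max(batch_size) * 2 if debug else None  # [:None] keeps every file
    return len(range(num_files)[:idx])


def steps_per_epoch(num_monet, num_photo, batch_size):
    """Training steps per epoch with drop_last=True and mode="max_size_cycle"."""
    bs_monet, bs_photo = batch_size
    # drop_last=True discards each loader's final incomplete batch...
    monet_batches = num_monet // bs_monet
    photo_batches = num_photo // bs_photo
    # ...and max_size_cycle runs for as long as the longer loader.
    return max(monet_batches, photo_batches)


print(steps_per_epoch(300, 1000, [1, 4]))
print(subset_size(300, [1, 4], debug=True))
```

For instance, with a hypothetical 1000 photos, `steps_per_epoch(300, 1000, [1, 4])` gives `max(300, 250) = 300` steps. In debug mode with `BATCH_SIZE = [1, 4]`, each dataset keeps 8 files, so an epoch has `max(8, 2) = 8` steps.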

In [None]:
MONET_DIR = "/kaggle/input/gan-getting-started/monet_jpg/*.jpg"
PHOTO_DIR = "/kaggle/input/gan-getting-started/photo_jpg/*.jpg"
DEBUG = not torch.cuda.is_available()
BATCH_SIZE = [1, 4] # [batch size for Monet paintings, batch size for photos]
LOADER_CONFIG = {
    "num_workers": os.cpu_count(),
    "pin_memory": torch.cuda.is_available(),
}

In [None]:
class CustomDataModule(L.LightningDataModule):
    def __init__(
        self,
        debug=DEBUG,
        monet_dir=MONET_DIR,
        photo_dir=PHOTO_DIR,
        batch_size=BATCH_SIZE,
        loader_config=LOADER_CONFIG,
        transform=CustomTransform(),
        mode="max_size_cycle",
    ):
        super().__init__()
        if isinstance(batch_size, list):
            self.batch_size = batch_size  
        else:
            self.batch_size = [batch_size] * 2
        if debug:
            idx = max(self.batch_size) * 2
        else:
            idx = None
        self.monet_filenames = sorted(glob.glob(monet_dir))[:idx]
        self.photo_filenames = sorted(glob.glob(photo_dir))[:idx]
        self.loader_config = loader_config
        self.transform = transform
        self.mode = mode
        
    def setup(self, stage):
        if stage == "fit":
            self.train_monet = CustomDataset(self.monet_filenames, self.transform, stage)
            self.train_photo = CustomDataset(self.photo_filenames, self.transform, stage)
        
        elif stage == "predict":
            self.predict = CustomDataset(self.photo_filenames, self.transform, stage)
            
    def train_dataloader(self):
        loader_monet = DataLoader(
            self.train_monet,
            shuffle=True,
            drop_last=True,
            batch_size=self.batch_size[0],
            **self.loader_config,
        )
        loader_photo = DataLoader(
            self.train_photo,
            shuffle=True, 
            drop_last=True,
            batch_size=self.batch_size[1],
            **self.loader_config,
        )
        loaders = {"monet": loader_monet, "photo": loader_photo}
        return CombinedLoader(loaders, mode=self.mode)
    
    def predict_dataloader(self):
        return DataLoader(
            self.predict,
            shuffle=False,
            drop_last=False,
            batch_size=self.batch_size[1],
            **self.loader_config,
        )

We check that the datamodule defined is working as intended by visualizing samples of the images below.

In [None]:
SAMPLE_SIZE = 5
dm_sample = CustomDataModule(batch_size=SAMPLE_SIZE)

dm_sample.setup("fit")
train_loader = dm_sample.train_dataloader()
monet_samples, photo_samples = next(iter(train_loader)).values()

dm_sample.setup("predict")
predict_loader = dm_sample.predict_dataloader()
photo__samples = next(iter(predict_loader)) # used to track performance of model during training later

show_img(monet_samples, nrow=SAMPLE_SIZE, title="Augmented Monet Paintings")
show_img(photo_samples, nrow=SAMPLE_SIZE, title="Augmented Photos")

---

# 2. Building CycleGAN Architecture

### Generator.

<img src="https://lh5.googleusercontent.com/9kNO6hxYJmpcfG5bOjnDazieeLC7Q8jZJi3gTtnJelbkOUL7Xz9e-3F_SNuxPpo4fZ4=w2400" width="600"/>

_Example of the U-Net architecture [[source](https://paperswithcode.com/method/u-net)]._

We use a U-Net architecture for the CycleGAN generator. A U-Net consists of downsampling blocks and upsampling blocks joined by long skip connections, which give it its U shape.

### Downsampling blocks.

The downsampling blocks use strided convolution layers to increase the number of feature maps while reducing the spatial dimensions of the image (halving them with the default `stride=2`).
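The spatial sizes in the comments of the generator and discriminator follow from the `nn.Conv2d` output-size formula, floor((n + 2p - k) / s) + 1. A quick sketch with a hypothetical helper `conv_out` reproduces them for both the generator's downsampling path and the discriminator:

```python
def conv_out(size, kernel_size=4, stride=2, padding=1):
    """Spatial output size of nn.Conv2d with dilation 1: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Generator: eight Downsampling blocks with the default k=4, s=2, p=1.
gen_sizes = [256]
for _ in range(8):
    gen_sizes.append(conv_out(gen_sizes[-1]))

# Discriminator: three stride-2 blocks, one stride-1 block, then the final stride-1 conv.
disc_sizes = [256]
for _ in range(3):
    disc_sizes.append(conv_out(disc_sizes[-1]))
disc_sizes.append(conv_out(disc_sizes[-1], stride=1))
disc_sizes.append(conv_out(disc_sizes[-1], stride=1))

print(gen_sizes)   # [256, 128, 64, 32, 16, 8, 4, 2, 1]
print(disc_sizes)  # [256, 128, 64, 32, 31, 30]
```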

In [None]:
class Downsampling(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=4,
        stride=2,
        padding=1,
        norm=True,
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=not norm),
        )
        if norm:
            self.block.append(nn.InstanceNorm2d(out_channels, affine=True))
        self.block.append(nn.LeakyReLU(0.3))
        
    def forward(self, x):
        return self.block(x)

### Upsampling blocks.

The upsampling blocks, on the other hand, use transposed convolution layers that double the spatial dimensions at each step (with the default `stride=2`), combining the learned features until the image is restored to its original size of 256.
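For `nn.ConvTranspose2d` the output size is (n - 1)s - 2p + k + output_padding, which equals 2n for the defaults k = 4, s = 2, p = 1. The sketch below, using a hypothetical helper `conv_transpose_out`, checks this against PyTorch:

```python
import torch
from torch import nn


def conv_transpose_out(size, kernel_size=4, stride=2, padding=1, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


# Seven Upsampling blocks plus the final ConvTranspose2d: 1 -> 2 -> ... -> 256.
up_sizes = [1]
for _ in range(8):
    up_sizes.append(conv_transpose_out(up_sizes[-1]))

# Cross-check the formula against PyTorch on a 1x1 bottleneck.
up = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1, bias=False)
y = up(torch.zeros(1, 16, 1, 1))
print(up_sizes)        # [1, 2, 4, 8, 16, 32, 64, 128, 256]
print(tuple(y.shape))  # (1, 8, 2, 2)
```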

In [None]:
class Upsampling(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=4,
        stride=2,
        padding=1,
        output_padding=0,
        dropout=False,
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 
                               padding=padding, output_padding=output_padding, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
        )
        if dropout:
            self.block.append(nn.Dropout(0.5))
        self.block.append(nn.ReLU())
        
    def forward(self, x):
        return self.block(x)

### Building the generator.

With the building blocks defined, we can now build the CycleGAN generator. In the upsampling path, the output of each upsampling block is concatenated with the output of the downsampling block at the same resolution. These concatenations are the long skip connections of the U-Net: they facilitate information flow in the deep network and reduce the impact of vanishing gradients. For reference, the output size of each block is noted in the comments below; for upsampling blocks, it is the size after concatenation with the skip connection.
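The concatenation pattern can be seen in isolation in this minimal two-level sketch. `TinyUNet` is a hypothetical toy model with plain convolutions (no normalization or dropout), not the generator of this notebook; it only shows how skip features are concatenated along the channel axis, which is why the later upsampling blocks take twice as many input channels.

```python
import torch
from torch import nn


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net illustrating skip-connection concatenation."""

    def __init__(self, in_ch=3, out_ch=3, hid=8):
        super().__init__()
        self.down1 = nn.Conv2d(in_ch, hid, 4, 2, 1)             # H -> H/2
        self.down2 = nn.Conv2d(hid, hid * 2, 4, 2, 1)           # H/2 -> H/4
        self.up1 = nn.ConvTranspose2d(hid * 2, hid, 4, 2, 1)    # H/4 -> H/2
        # input channels: hid from up1 plus hid from the down1 skip
        self.up2 = nn.ConvTranspose2d(hid * 2, out_ch, 4, 2, 1)  # H/2 -> H
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        d1 = self.act(self.down1(x))     # kept as the skip connection
        d2 = self.act(self.down2(d1))    # bottleneck
        u1 = self.act(self.up1(d2))
        u1 = torch.cat([u1, d1], dim=1)  # concatenate along the channel axis
        return torch.tanh(self.up2(u1))


net = TinyUNet()
print(tuple(net(torch.randn(2, 3, 32, 32)).shape))
```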

In [None]:
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, hid_channels):
        super().__init__()
        self.downsampling_path = nn.Sequential(
            Downsampling(in_channels, hid_channels, norm=False), # 64x128x128
            Downsampling(hid_channels, hid_channels*2), # 128x64x64
            Downsampling(hid_channels*2, hid_channels*4), # 256x32x32
            Downsampling(hid_channels*4, hid_channels*8), # 512x16x16
            Downsampling(hid_channels*8, hid_channels*8), # 512x8x8
            Downsampling(hid_channels*8, hid_channels*8), # 512x4x4
            Downsampling(hid_channels*8, hid_channels*8), # 512x2x2
            Downsampling(hid_channels*8, hid_channels*8, norm=False), # 512x1x1, instance norm does not work on 1x1
        )
        self.upsampling_path = nn.Sequential(
            Upsampling(hid_channels*8, hid_channels*8, dropout=True), # (512+512)x2x2
            Upsampling(hid_channels*16, hid_channels*8, dropout=True), # (512+512)x4x4
            Upsampling(hid_channels*16, hid_channels*8, dropout=True), # (512+512)x8x8
            Upsampling(hid_channels*16, hid_channels*8), # (512+512)x16x16
            Upsampling(hid_channels*16, hid_channels*4), # (256+256)x32x32
            Upsampling(hid_channels*8, hid_channels*2), # (128+128)x64x64
            Upsampling(hid_channels*4, hid_channels), # (64+64)x128x128
        )
        self.feature_block = nn.Sequential(
            nn.ConvTranspose2d(hid_channels*2, out_channels,
                               kernel_size=4, stride=2, padding=1), # 3x256x256
            nn.Tanh(),
        )
        
    def forward(self, x):
        skips = []
        for down in self.downsampling_path:
            x = down(x)
            skips.append(x)
        skips = reversed(skips[:-1])

        for up, skip in zip(self.upsampling_path, skips):
            x = up(x)
            x = torch.cat([x, skip], dim=1)
        return self.feature_block(x)

### Discriminator.

<img src="https://lh6.googleusercontent.com/UhJiaTOQWgfHQlWq50IMGBvdkJ3NDggC449cxud8XVlSxUrule8f5LyoLUV8aaYemGw=w2400" width="300"/>

_Diagram of how the PatchGAN discriminator works [[source](https://www.researchgate.net/figure/PatchGAN-discriminator-Each-value-of-the-output-matrix-represents-the-probability-of_fig1_323904616)]._

Unlike conventional discriminators that output a single probability of the input image being real or fake, CycleGAN uses the PatchGAN discriminator, which outputs a matrix of values. Intuitively, each value of the output matrix judges a corresponding patch of the input image. Because we use the LSGAN objective, the discriminator has no final sigmoid: its outputs are trained towards 1 for real patches and towards 0 for fake ones.
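How large is the patch behind each output value? The receptive field of stacked convolutions can be computed backwards from one output unit with r_in = r_out * s + (k - s). The sketch below, using a hypothetical helper `receptive_field`, applies this to the kernel sizes and strides of the discriminator defined in this section, giving the 70x70 patch size associated with the PatchGAN discriminator used in CycleGAN.

```python
def receptive_field(layers):
    """Receptive field of one output unit for stacked (kernel_size, stride) conv layers."""
    r = 1
    for k, s in reversed(layers):
        # each layer widens the field by (k - s) and scales it by its stride
        r = r * s + (k - s)
    return r


# (kernel_size, stride) of the discriminator's five convolutions
patchgan = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan))  # 70
```

So each of the 30x30 output values judges a 70x70 patch of the 256x256 input; padding shifts which pixels sit at the borders but does not change the receptive field size.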

### Building the discriminator.

In general, the PatchGAN discriminator consists of a sequence of convolution layers, which can be built using the downsampling blocks defined earlier. For reference, the output size of each block is commented below.

In [None]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, hid_channels):
        super().__init__()
        self.block = nn.Sequential(
            Downsampling(in_channels, hid_channels, norm=False), # 64x128x128
            Downsampling(hid_channels, hid_channels*2), # 128x64x64
            Downsampling(hid_channels*2, hid_channels*4), # 256x32x32
            Downsampling(hid_channels*4, hid_channels*8, stride=1), # 512x31x31
            nn.Conv2d(hid_channels*8, 1, kernel_size=4, padding=1), # 1x30x30
        )
        
    def forward(self, x):
        return self.block(x)

### CycleGAN.

With the generator and discriminator defined, we can now build CycleGAN, which consists of two generators and two discriminators:
* Generator for photo-to-Monet translation (`gen_PM`).
* Generator for Monet-to-photo translation (`gen_MP`).
* Discriminator for Monet paintings (`disc_M`).
* Discriminator for photos (`disc_P`).

Using the `init_weights` function, the weights of the convolution, transposed convolution, and instance normalization layers are initialized from a normal distribution with mean 0 and standard deviation 0.02, and the biases are initialized to 0. The Adam optimizer is used for model training. To optimize the parameters, we need to define the loss functions:
* **Discriminator loss** (`disc_loss`). For real images fed into the discriminator, the output matrix is compared against a matrix of 1s using the mean squared error. For fake images, the output matrix is compared against a matrix of 0s. Hence, to minimize the loss, the perfect discriminator outputs a matrix of 1s for real images and a matrix of 0s for fake images. The two terms are averaged.
* **Generator loss** (`gen_loss`). This is composed of the three loss functions below.
  * *Adversarial loss* (`adv_loss`). Fake images are fed into the discriminator and the output matrix is compared against a matrix of 1s using the mean squared error. To minimize the loss, the generator needs to 'fool' the discriminator into thinking that the fake images are real and outputting a matrix of 1s.
  * *Identity loss* (`id_loss`). When a Monet painting is fed into the photo-to-Monet generator, we should get back the same Monet painting because nothing needs to be transformed. The same applies to photos fed into the Monet-to-photo generator. To encourage identity mapping, the difference in pixel values between the input image and the generated image is measured using the L1 loss and weighted by `LAMBDA_ID`.
  * *Cycle loss* (`cycle_loss`). When a Monet painting is fed into the Monet-to-photo generator and the generated image is fed back into the photo-to-Monet generator, it should transform back into the original Monet painting. The same applies to photos passed through the two generators in the opposite order. To preserve information throughout this cycle, the L1 loss is used to measure the difference between the original image and the cycled image, weighted by `LAMBDA_CYCLE`.

From the above, the mean squared error and the L1 loss are defined as the adversarial criterion (`adv_criterion`) and the reconstruction criterion (`recon_criterion`), respectively.
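Written out for the photo-to-Monet direction, the losses computed by `disc_loss` and `gen_loss` are summarized below. This summary is derived from the code in this notebook using our own notation: $p$ is a batch of photos, $m$ a batch of Monet paintings, $G_{PM}$ and $G_{MP}$ the two generators, and $D_M$ the Monet discriminator. Squared terms are averaged over all patch outputs and the batch, and $\lVert\cdot\rVert_1$ denotes the mean absolute error over pixels (the default mean reduction of `F.mse_loss` and `F.l1_loss`). The Monet-to-photo losses swap the roles of $P$ and $M$.

```latex
\begin{aligned}
\mathcal{L}_{D_M} &= \tfrac{1}{2}\Big[\big(D_M(m) - 1\big)^2 + D_M\big(G_{PM}(p)\big)^2\Big] \\
\mathcal{L}_{G_{PM}} &= \big(D_M(G_{PM}(p)) - 1\big)^2
  + \lambda_{\text{id}}\,\big\lVert G_{PM}(m) - m \big\rVert_1 \\
  &\quad + \lambda_{\text{cycle}}\Big(\big\lVert G_{MP}(G_{PM}(p)) - p \big\rVert_1
  + \big\lVert G_{PM}(G_{MP}(m)) - m \big\rVert_1\Big)
\end{aligned}
```

In $\mathcal{L}_{D_M}$ the fake images are detached, so only the discriminator is updated by this loss.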

### Building the CycleGAN model.

To organize the code for modeling, we define the above functions within the `LightningModule` class together with the following methods:
* `__init__` to initialize the two generators, the two discriminators, and other parameters.
* `forward` to generate Monet-style images given input photos.
* `configure_optimizers` to initialize the Adam optimizers and learning rate schedules. The learning rate is kept constant for the first `DECAY_EPOCHS` epochs and then decays linearly towards zero.
* `training_step` to compute the loss functions for the generators and discriminators.
* `training_epoch_end` to print the average values of the loss functions over the batches per epoch, and visualize the performance of `gen_PM`. We also record the values of the learning rates and losses for plotting later.
* `predict_step` to run the `forward` method during prediction.
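The schedule returned by `get_lr_scheduler` can be checked on its own with a dummy parameter. The sketch below copies its multiplier and uses the values of `LR`, `NUM_EPOCHS`, and `DECAY_EPOCHS` from the parameter cell of this section (non-debug mode); the dummy parameter and loop are illustrative only.

```python
import torch

NUM_EPOCHS, DECAY_EPOCHS, LR = 36, 27, 2e-4


def lr_lambda(epoch):
    # Same multiplier as CycleGAN.get_lr_scheduler: 1.0 until the decay starts,
    # then a linear ramp towards 0.
    val = 1.0 - max(0, epoch - DECAY_EPOCHS + 1.0) / (NUM_EPOCHS - DECAY_EPOCHS + 1.0)
    return max(0.0, val)


param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=LR, betas=(0.5, 0.999))
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

lrs = []
for epoch in range(NUM_EPOCHS):
    lrs.append(sched.get_last_lr()[0])  # learning rate used during this epoch
    opt.step()                          # a real loop would train here
    sched.step()                        # the schedule advances once per epoch

print([f"{lr:.1e}" for lr in lrs[25:]])
```

Epochs 1 to 27 train at 2e-4, after which the rate drops by 2e-5 per epoch, reaching 2e-5 in the final epoch.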

Other useful methods can be found [here](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html). Besides the above, we define additional methods like `get_lr_scheduler` to set up the learning rate schedules, `loss_plot` to plot the loss curves, and `lr_plot` to plot the learning rate schedules. The following parameters are set:
* `IN_CHANNELS` — the number of input channels for the generator and discriminator, which equals 3 since we are working with RGB images.
* `OUT_CHANNELS` — the number of output channels for the generator, which equals 3 as we are trying to output RGB images.
* `HID_CHANNELS` — the number of output channels in the first layer for the generator and discriminator.
* `LR` and `BETAS` — the learning rate and beta parameters for the Adam optimizer.
* `LAMBDA_ID` and `LAMBDA_CYCLE` — the weights used in the identity loss and cycle loss.
* `NUM_EPOCHS` — the number of epochs for training.
* `DECAY_EPOCHS` — the number of epochs before starting the learning rate decay.
* `DISPLAY_EPOCHS` — the frequency to visualize the performance of `gen_PM`.

In [None]:
IN_CHANNELS = 3
OUT_CHANNELS = 3
HID_CHANNELS = 64
LR = 2e-4
BETAS = (0.5, 0.999)
LAMBDA_ID = 2
LAMBDA_CYCLE = 5
NUM_EPOCHS = 36 if not DEBUG else 2
DECAY_EPOCHS = 27 if not DEBUG else 1
DISPLAY_EPOCHS = 12

In [None]:
class CycleGAN(L.LightningModule):
    def __init__(
        self, 
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS, 
        hid_channels=HID_CHANNELS,
        lr=LR,
        betas=BETAS,
        lambda_id=LAMBDA_ID,
        lambda_cycle=LAMBDA_CYCLE,
        num_epochs=NUM_EPOCHS,
        decay_epochs=DECAY_EPOCHS,
        display_epochs=DISPLAY_EPOCHS,
        photo_samples=photo__samples,
    ):
        super().__init__()
        self.lr = lr
        self.betas = betas
        self.lambda_id = lambda_id
        self.lambda_cycle = lambda_cycle
        self.num_epochs = num_epochs
        self.decay_epochs = decay_epochs
        self.display_epochs = display_epochs
        self.photo_samples = photo_samples.to("cuda" if torch.cuda.is_available() else "cpu")
        
        # record learning rates and losses
        self.lr_history = [self.lr]
        self.loss_names = ["gen_loss_PM", "gen_loss_MP", "disc_loss_M", "disc_loss_P"]
        self.loss_history = {loss: [] for loss in self.loss_names}
        
        # initialize generators and discriminators
        self.gen_PM = Generator(in_channels, out_channels, hid_channels)
        self.gen_MP = Generator(in_channels, out_channels, hid_channels)
        self.disc_M = Discriminator(in_channels, hid_channels)
        self.disc_P = Discriminator(in_channels, hid_channels)
        self.init_weights()
        
    def forward(self, z):
        return self.gen_PM(z)
                
    def init_weights(self):
        def init_fn(m):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.InstanceNorm2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
                nn.init.constant_(m.bias, 0.0)
                
        self.gen_PM = self.gen_PM.apply(init_fn)
        self.gen_MP = self.gen_MP.apply(init_fn)
        self.disc_M = self.disc_M.apply(init_fn)
        self.disc_P = self.disc_P.apply(init_fn)
        
    def adv_criterion(self, y_hat, y):
        return F.mse_loss(y_hat, y)
    
    def recon_criterion(self, y_hat, y):
        return F.l1_loss(y_hat, y)
    
    def adv_loss(self, fake_Y, disc_Y):
        fake_Y_hat = disc_Y(fake_Y)
        valid = torch.ones_like(fake_Y_hat)
        adv_loss_XY = self.adv_criterion(fake_Y_hat, valid)
        return adv_loss_XY
    
    def id_loss(self, real_Y, gen_XY):
        id_Y = gen_XY(real_Y)
        id_loss_Y = self.recon_criterion(id_Y, real_Y)
        return self.lambda_id * id_loss_Y
    
    def cycle_loss(self, real_Y, fake_X, gen_XY):
        cycle_Y = gen_XY(fake_X)
        cycle_loss_Y = self.recon_criterion(cycle_Y, real_Y)
        return self.lambda_cycle * cycle_loss_Y
    
    def gen_loss(self, real_X, real_Y, gen_XY, gen_YX, disc_Y):
        fake_Y = gen_XY(real_X)
        fake_X = gen_YX(real_Y)
        
        adv_loss_XY = self.adv_loss(fake_Y, disc_Y)
        id_loss_Y = self.id_loss(real_Y, gen_XY)
        cycle_loss_Y = self.cycle_loss(real_Y, fake_X, gen_XY)
        cycle_loss_X = self.cycle_loss(real_X, fake_Y, gen_YX)
        total_cycle_loss = cycle_loss_X + cycle_loss_Y
        
        gen_loss_XY = adv_loss_XY + id_loss_Y + total_cycle_loss
        return gen_loss_XY
    
    def disc_loss(self, real_Y, fake_Y, disc_Y):
        real_Y_hat = disc_Y(real_Y)
        valid = torch.ones_like(real_Y_hat)
        real_loss_Y = self.adv_criterion(real_Y_hat, valid)
        
        fake_Y_hat = disc_Y(fake_Y.detach())
        fake = torch.zeros_like(fake_Y_hat)
        fake_loss_Y = self.adv_criterion(fake_Y_hat, fake)
        
        disc_loss_Y = (fake_loss_Y+real_loss_Y) * 0.5
        return disc_loss_Y
    
    def get_lr_scheduler(self, optimizer):
        def lr_lambda(epoch):
            val = 1.0 - max(0, epoch-self.decay_epochs+1.0)/(self.num_epochs-self.decay_epochs+1.0)
            return max(0.0, val)
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    
    def configure_optimizers(self):
        params = {
            "lr": self.lr,
            "betas": self.betas,
        }
        opt_gen_PM = torch.optim.Adam(self.gen_PM.parameters(), **params)
        opt_gen_MP = torch.optim.Adam(self.gen_MP.parameters(), **params)
        opt_disc_M = torch.optim.Adam(self.disc_M.parameters(), **params)
        opt_disc_P = torch.optim.Adam(self.disc_P.parameters(), **params)
        optimizers = [opt_gen_PM, opt_gen_MP, opt_disc_M, opt_disc_P]
        schedulers = [self.get_lr_scheduler(opt) for opt in optimizers]
        return optimizers, schedulers
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        real_M = batch["monet"]
        real_P = batch["photo"]
        if optimizer_idx == 0:
            gen_loss_PM = self.gen_loss(real_P, real_M, 
                                        self.gen_PM, self.gen_MP, self.disc_M)
            return gen_loss_PM
        
        if optimizer_idx == 1:
            gen_loss_MP = self.gen_loss(real_M, real_P,
                                        self.gen_MP, self.gen_PM, self.disc_P)
            return gen_loss_MP
        
        if optimizer_idx == 2:
            with torch.no_grad():
                fake_M = self.gen_PM(real_P)
            disc_loss_M = self.disc_loss(real_M, fake_M,
                                         self.disc_M)
            return disc_loss_M
        
        if optimizer_idx == 3:
            with torch.no_grad():
                fake_P = self.gen_MP(real_M)
            disc_loss_P = self.disc_loss(real_P, fake_P,
                                         self.disc_P)
            return disc_loss_P
    
    def training_epoch_end(self, outputs):
        current_epoch = self.current_epoch + 1
        
        lr = self.lr_schedulers()[0].get_last_lr()[0]
        self.lr_history.append(lr)
        losses = {}
        for j in range(4):
            loss = [out[j]["loss"].item() for out in outputs]
            self.loss_history[self.loss_names[j]].extend(loss)
            losses[self.loss_names[j]] = np.mean(loss)
        print(
            " - ".join([
                f"Epoch {current_epoch}",
                f"lr: {self.lr_history[-2]:.5f}",
                *[f"{loss}: {val:.5f}" for loss, val in losses.items()],
            ])
        )
        
        if current_epoch%self.display_epochs==0 or current_epoch in [1, self.num_epochs]:
            torch.set_grad_enabled(False)
            self.eval()
            gen_monets = self.forward(self.photo_samples)
            show_img(
                torch.cat([self.photo_samples, gen_monets]),
                nrow=len(self.photo_samples),
                title=f"Epoch {current_epoch}: Photo-to-Monet Translation",
            )
            torch.set_grad_enabled(True)
            self.train()
            
    def predict_step(self, batch, batch_idx):
        return self.forward(batch)
    
    def lr_plot(self):
        num_epochs = len(self.lr_history[:-1])
        plt.figure(figsize=(18, 4.5))
        plt.title("Learning Rate Schedule")
        plt.ylabel("Learning Rate")
        plt.xlabel("Epoch")
        plt.plot(
            np.arange(1, num_epochs+1),
            self.lr_history[:-1],
        )
            
    def loss_plot(self):
        titles = ["Generator Loss Curves", "Discriminator Loss Curves"]
        num_steps = len(list(self.loss_history.values())[0])
        plt.figure(figsize=(18, 4.5))
        for j in range(4):
            if j%2 == 0:
                plt.subplot(1, 2, (j//2)+1)
                plt.title(titles[j//2])
                plt.ylabel("Loss")
                plt.xlabel("Step")
            plt.plot(
                np.arange(1, num_steps+1),
                self.loss_history[self.loss_names[j]],
                label=self.loss_names[j],
            )
            plt.legend(loc="upper right")

---

# 3. Model Training

To train the model, we use `Trainer`, whose `fit` method handles the training loop automatically. The training parameters are set in `TRAIN_CONFIG` below.

In [None]:
TRAIN_CONFIG = {
    "accelerator": "gpu" if not DEBUG else "cpu",
    "devices": 1,
    "logger": False,
    "enable_checkpointing": True,
    "max_epochs": NUM_EPOCHS,
    "precision": 16 if not DEBUG else 32,
}

In [None]:
dm = CustomDataModule()
model = CycleGAN()
trainer = L.Trainer(**TRAIN_CONFIG)

trainer.fit(model, datamodule=dm)

### Plotting the learning rate schedule.

In [None]:
model.lr_plot()

### Plotting the loss curves.

In [None]:
model.loss_plot()

---

# 4. Submission

The predictions are computed with the `predict` method, which runs `predict_step` to generate Monet-style images from the input photos.

In [None]:
predictions = trainer.predict(model, datamodule=dm)

### Saving the generated images.

In [None]:
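IMAGE_DIR = "/kaggle/images"  # same as "../images" when the working directory is /kaggle/working
os.makedirs(IMAGE_DIR, exist_ok=True)

idx = 0
for tensor in predictions:
    for monet in tensor:
        # undo the [-1, 1] scaling; .float() handles fp16 outputs from precision=16
        save_image(monet.float().squeeze()*0.5 + 0.5, fp=os.path.join(IMAGE_DIR, f"{idx}.jpg"))
        idx += 1

shutil.make_archive("/kaggle/working/images", "zip", IMAGE_DIR)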
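`save_image` expects values in [0, 1], hence the `*0.5 + 0.5` in the saving cell. If the images are needed as 8-bit arrays instead (for example to hand them to PIL), a minimal sketch of the same denormalization is shown below; `to_uint8` is a hypothetical helper, and the clamp is only a safeguard.

```python
import torch


def to_uint8(batch):
    """Convert generator output in [-1, 1] with shape (N, C, H, W) to uint8 (N, H, W, C)."""
    imgs = (batch.float() * 0.5 + 0.5).clamp(0, 1)  # undo the [-1, 1] scaling
    imgs = (imgs * 255).round().to(torch.uint8)     # quantize to 8-bit pixels
    return imgs.permute(0, 2, 3, 1)                 # channels-last, as PIL/NumPy expect


pixels = to_uint8(torch.tensor([-1.0, 0.0, 1.0]).view(1, 3, 1, 1))
print(pixels.flatten().tolist())
```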

### Examining the results on other photo samples.

In [None]:
torch.set_grad_enabled(False)
model.eval()

for j, photos in enumerate(iter(predict_loader)):
    if j == 5:
        break
    gen_monets = model(photos)
    show_img(
        torch.cat([photos, gen_monets]),
        nrow=len(photos),
        title=f"Sample {j+1}: Photo-to-Monet Translation",
    )

---