# Monet Painting Generator using CycleGAN and PyTorch

This project implements the CycleGAN architecture to translate regular photos into Monet-style paintings. The architecture follows the paper *Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks* (https://arxiv.org/abs/1703.10593).

## Setup and Helper Functions

In [None]:
import os
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import itertools

import torch
import glob
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
def convert_to_RGB(image):
    """Return a new RGB image of the same size with *image* pasted onto it.

    Called for images whose mode is not already "RGB"; the mode change is
    left to ``Image.paste``.
    """
    canvas = Image.new(mode="RGB", size=image.size)
    canvas.paste(image)
    return canvas

The following code implements the replay buffer (image history) that the paper uses to stabilize discriminator training: the discriminator is updated with a mix of freshly generated images and previously generated ones.

In [None]:
class ReplayBuffer:
    """History of previously generated images used when training a discriminator.

    The discriminator sees a mix of the newest generated images and older
    ones drawn from this buffer, which helps stabilise its training.
    """

    def __init__(self, max_size=50):
        # Maximum number of images kept in the history (default: 50).
        self.max_size = max_size
        # Stored images; each is one slice of a pushed batch with a leading
        # dimension of size 1 added.
        self.data = []

    def push_and_pop(self, data):
        """Store the images of *data* in the history and return a batch to train on.

        Each image of the batch (iterated along dim 0) is handled as follows:

        * while the buffer is not full, the image is stored and returned as-is;
        * once full, with probability 0.5 a randomly chosen stored image is
          returned and replaced by the new one; otherwise the new image is
          returned unchanged.

        Args:
            data: batch tensor, iterated along its first dimension.

        Returns:
            A tensor with the same shape as *data*, detached from the autograd
            graph, so discriminator losses computed on it do not
            back-propagate into the generator.
        """
        # Detach explicitly (the deprecated ``Variable`` wrapper / ``.data``
        # are no longer needed for this).
        batch = data.detach()
        if self.max_size <= 0 or batch.size(0) == 0:
            # No history to mix in (or nothing to process): pass the batch through.
            return batch.clone()

        to_return = []
        for element in batch:
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                to_return.append(element)
        return torch.cat(to_return)

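This function initializes the model weights as in the paper: convolution weights are drawn from a Gaussian with mean 0 and standard deviation 0.02, and biases are set to 0.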

In [None]:
def initialize_conv_weights_normal(m):
    """Initialise one module's parameters; intended for use with ``nn.Module.apply``.

    * Modules whose class name contains "Conv": weight ~ N(0, 0.02), bias = 0.
    * Modules whose class name contains "BatchNorm2d": weight ~ N(1, 0.02), bias = 0.

    Any other module is left untouched. Parameters that do not exist (e.g. a
    convolution built with ``bias=False`` or a norm layer with
    ``affine=False``, whose weight and bias are ``None``) are skipped instead
    of raising.
    """
    classname = type(m).__name__
    if "Conv" in classname:
        weight_mean = 0.0
    elif "BatchNorm2d" in classname:
        weight_mean = 1.0
    else:
        return

    weight = getattr(m, "weight", None)
    if weight is not None:
        torch.nn.init.normal_(weight, weight_mean, 0.02)
    bias = getattr(m, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, 0.0)

Custom dataset class: each item pairs a Monet painting with a randomly chosen photo, so the two domains stay unpaired.

In [None]:
class ImageDataset(Dataset):
    """Unpaired dataset of Monet paintings and photos.

    Item ``index`` is the Monet painting at ``index`` (wrapping around the
    list) together with a photo sampled uniformly at random, so the two
    domains are not paired.

    Args:
        root: directory containing the ``monet_jpg`` and ``photo_jpg`` folders.
        transforms_: list of torchvision transforms, composed and applied to
            both images; it must produce a tensor. Defaults to
            ``[transforms.ToTensor()]`` when ``None``.
        mode: ``"train"``, ``"test"`` or ``"all"``; selects which slice of
            each sorted file list is used.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
        FileNotFoundError: if the selected slice of either folder is empty.
    """

    # (monet slice, photo slice) per mode, applied to the *sorted* file lists.
    _SPLITS = {
        "train": (slice(None, 600), slice(None, 250)),
        # NOTE(review): the monet "test" slice [250:] overlaps the monet
        # "train" slice [:600]; confirm whether that is intended.
        "test": (slice(250, None), slice(250, 300)),
        "all": (slice(None), slice(None)),
    }

    def __init__(self, root, transforms_=None, mode="train"):
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be one of {sorted(self._SPLITS)}, got {mode!r}")
        if transforms_ is None:
            transforms_ = [transforms.ToTensor()]
        self.transform = transforms.Compose(transforms_)

        monet_split, photo_split = self._SPLITS[mode]
        # Sort BEFORE slicing so the split does not depend on the
        # filesystem-dependent order in which glob returns files.
        self.monet_files = self._list_files(root, "monet_jpg")[monet_split]
        self.photo_files = self._list_files(root, "photo_jpg")[photo_split]
        if not self.monet_files or not self.photo_files:
            raise FileNotFoundError(
                f"no images found for mode {mode!r} under {root!r} "
                f"({len(self.monet_files)} Monet, {len(self.photo_files)} photos)"
            )

    @staticmethod
    def _list_files(root, folder):
        """Return the sorted paths of all files matching root/folder/*.*."""
        return sorted(glob.glob(os.path.join(root, folder, "*.*")))

    def _load(self, path):
        """Open *path*, convert it to RGB if needed and return the transformed float tensor."""
        # The context manager closes the file; the transform runs inside it so
        # the image data is still available.
        with Image.open(path) as img:
            if img.mode != "RGB":
                img = convert_to_RGB(img)
            return self.transform(img).float()

    def __getitem__(self, index):
        monet = self._load(self.monet_files[index % len(self.monet_files)])
        # The photo is drawn at random (not by index) so the domains stay unpaired.
        photo = self._load(random.choice(self.photo_files))
        return (monet, photo)

    def __len__(self):
        return max(len(self.monet_files), len(self.photo_files))

## Building the Network

In [None]:
# Residual block with two convolution layers
class ResidualBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is: reflection pad -> 3x3 conv -> InstanceNorm -> ReLU ->
    reflection pad -> 3x3 conv -> InstanceNorm. The channel count
    (``in_channel``) and the spatial size are preserved.
    """

    def __init__(self, in_channel):
        super().__init__()
        layers = []
        for stage in range(2):
            # A reflection pad of 1 keeps H and W unchanged through the 3x3 conv.
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_channel, in_channel, 3))
            layers.append(nn.InstanceNorm2d(in_channel))
            if stage == 0:
                # Only the first stage has a non-linearity; the second feeds the sum.
                layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return residual + x


# Generator from CycleGAN paper
# c7s1-64,d128,d256,R256 x num_residual_blocks,u128,u64,c7s1-<channels>
class GeneratorResNet(nn.Module):
    """ResNet-based CycleGAN generator.

    Layout:
    1. A reflection-padded 7x7 conv with 64 filters.
    2. Two stride-2 3x3 convs (128 and 256 filters).
    3. ``num_residual_blocks`` ResidualBlocks at 256 channels.
    4. Two upsample + 3x3 conv stages (128 and 64 filters).
    5. A reflection-padded 7x7 conv back to the input channel count, followed
       by tanh, so outputs lie in [-1, 1].

    Args:
        input_shape: (channels, height, width); only ``channels`` is used.
            Height and width should be divisible by 4 for the output size to
            match the input size.
        num_residual_blocks: number of residual blocks in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]
        # Padding for the 7x7 convolutions: kernel_size // 2 keeps H and W
        # unchanged regardless of how many channels the input has.
        large_kernel = 7
        large_pad = large_kernel // 2

        # Initial convolution block
        out_channels = 64
        model = [
            nn.ReflectionPad2d(large_pad),
            nn.Conv2d(channels, out_channels, kernel_size=large_kernel),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        in_channels = out_channels

        # Downsampling (encoder): each stage halves H and W and doubles channels.
        for _ in range(2):
            out_channels *= 2
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            in_channels = out_channels

        # Residual blocks (channels and spatial size unchanged)
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_channels)]

        # Upsampling (decoder): each stage doubles H and W and halves channels.
        for _ in range(2):
            out_channels //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            in_channels = out_channels

        # Output layer: back to the input channel count, squashed by tanh.
        model += [
            nn.ReflectionPad2d(large_pad),
            nn.Conv2d(out_channels, channels, large_kernel),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


# Discriminator from CycleGAN paper
# C64-C128-C256-C512
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial map of real/fake scores.

    Args:
        input_shape: (channels, height, width) of the images to score.

    Attributes:
        output_shape: (1, height // 16, width // 16), the shape of the score
            map produced for a single image. The training loop uses it to
            build its real/fake target tensors.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape

        # Four stride-2 blocks each halve H and W (rounding down).
        downscale = 2 ** 4
        self.output_shape = (1, height // downscale, width // downscale)

        widths = [channels, 64, 128, 256, 512]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The first block is not normalised.
            if idx > 0:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Asymmetric (left/top) zero padding so the final 4x4 conv keeps H and W.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

## Training

Training function

In [None]:
def train(gen_G, gen_F, dX, dY, dataloader, n_epochs, id_loss, cycle_loss, gan_loss,
          lambda_cycle, lambda_id, optimizer_G, optimizer_dX, optimizer_dY, buffer_X, buffer_Y, device):
    """Train a CycleGAN: generator G (photo -> Monet) and generator F (Monet -> photo).

    Naming used below: ``dX`` scores photos and ``dY`` scores Monet paintings;
    ``buffer_X`` / ``buffer_Y`` hold generated photos / generated paintings.

    Args:
        gen_G: generator mapping photos to Monet paintings.
        gen_F: generator mapping Monet paintings to photos.
        dX, dY: discriminators; ``dX.output_shape`` gives the per-image shape
            of the score map and is used to build the target tensors.
        dataloader: yields ``(monet, photo)`` batches of equal size.
        n_epochs: number of passes over ``dataloader``.
        id_loss, cycle_loss, gan_loss: loss callables ``loss(prediction, target)``.
        lambda_cycle, lambda_id: weights of the cycle-consistency and identity
            losses relative to the adversarial loss.
        optimizer_G: optimizer over the parameters of both generators.
        optimizer_dX, optimizer_dY: optimizers of the two discriminators.
        buffer_X, buffer_Y: objects with ``push_and_pop(batch)`` returning a
            batch of generated images to show the matching discriminator.
        device: device the batches are moved to.
    """
    for epoch in range(n_epochs):
        for i, (monet, photo) in enumerate(dataloader):
            monet = monet.to(device)
            photo = photo.to(device)

            # Targets for gan_loss: 1 marks real images, 0 marks generated ones.
            target_shape = (monet.size(0), *dX.output_shape)
            valid = torch.ones(target_shape, device=device)
            generated = torch.zeros(target_shape, device=device)

            # ---- TRAIN GENERATORS ----
            gen_G.train()
            gen_F.train()
            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image that is already in its
            # *target* domain should leave it unchanged (G(monet) ~ monet,
            # F(photo) ~ photo).
            id_loss_G = id_loss(gen_G(monet), monet)
            id_loss_F = id_loss(gen_F(photo), photo)
            id_loss_avg = (id_loss_G + id_loss_F) / 2

            # Adversarial loss: generated images should be scored as real.
            generated_monet = gen_G(photo)
            generated_photo = gen_F(monet)
            gan_loss_G = gan_loss(dY(generated_monet), valid)
            gan_loss_F = gan_loss(dX(generated_photo), valid)
            gan_loss_avg = (gan_loss_G + gan_loss_F) / 2

            # Cycle-consistency loss: translating there and back recovers the input.
            cycle_loss_G = cycle_loss(gen_F(generated_monet), photo)
            cycle_loss_F = cycle_loss(gen_G(generated_photo), monet)
            cycle_loss_avg = (cycle_loss_G + cycle_loss_F) / 2

            generator_loss = gan_loss_avg + lambda_id * id_loss_avg + lambda_cycle * cycle_loss_avg
            generator_loss.backward()
            optimizer_G.step()

            # ---- TRAIN DISCRIMINATORS ----
            # zero_grad inside the helper also clears gradients the generator
            # pass above accumulated in dX / dY.
            dX_loss = _discriminator_step(dX, optimizer_dX, gan_loss, photo, generated_photo,
                                          buffer_X, valid, generated)
            dY_loss = _discriminator_step(dY, optimizer_dY, gan_loss, monet, generated_monet,
                                          buffer_Y, valid, generated)

            if (i + 1) % 15 == 0:
                d_loss = (dX_loss + dY_loss) / 2
                print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] [Discriminator Loss {d_loss.item()}] [Generator Loss {generator_loss.item()}]')


def _discriminator_step(disc, optimizer, gan_loss, real, fake, buffer, valid, generated):
    """Run one optimisation step of *disc* on real images and buffered fakes; return its loss."""
    optimizer.zero_grad()
    real_loss = gan_loss(disc(real), valid)
    # Detach so the discriminator loss never back-propagates into the generator,
    # whatever the buffer implementation does.
    fake_batch = buffer.push_and_pop(fake.detach())
    fake_loss = gan_loss(disc(fake_batch), generated)
    loss = (real_loss + fake_loss) / 2
    loss.backward()
    optimizer.step()
    return loss

Parameters for training

In [None]:
params = dict(
    # Training schedule
    n_epochs=150,
    batch_size=4,
    # Adam optimiser settings (learning rate and betas)
    lr=0.0002,
    b1=0.5,
    b2=0.999,
    # Image and model shape
    img_size=256,
    channels=3,
    num_residual_blocks=12,
    # Weights of the cycle-consistency and identity losses
    lambda_cycle=10.0,
    lambda_id=5.0,
)

Create the training dataloader

In [None]:
root = "/kaggle/input/gan-getting-started"

# Resize every image to a square img_size x img_size (bicubic), convert it to a
# tensor in [0, 1], then rescale each channel to [-1, 1] via (x - 0.5) / 0.5.
transforms_ = [
    transforms.Resize((params['img_size'],) * 2, interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
]

train_dataloader = DataLoader(
    dataset=ImageDataset(root, transforms_=transforms_, mode="train"),
    batch_size=params['batch_size'],
    shuffle=True,
    num_workers=1,
)

Set up the remaining training components: loss functions, models, replay buffers, and optimizers.

In [None]:
# Loss functions: mean-squared error for the adversarial term, L1 for the
# cycle-consistency and identity terms.
gan_loss = torch.nn.MSELoss()
cycle_loss, id_loss = torch.nn.L1Loss(), torch.nn.L1Loss()
input_shape = (params['channels'], params['img_size'], params['img_size'])

# Models: generators G (photo -> Monet) and F (Monet -> photo), plus the
# discriminators dX and dY. Construction, device placement and weight
# initialisation each happen in the order G, F, dX, dY.
gen_G, gen_F = (GeneratorResNet(input_shape, params['num_residual_blocks']) for _ in range(2))
dX, dY = (Discriminator(input_shape) for _ in range(2))

networks = (gen_G, gen_F, dX, dY)
for network in networks:
    network.to(device)  # nn.Module.to moves parameters in place and returns self
for network in networks:
    network.apply(initialize_conv_weights_normal)

# Replay buffers of previously generated images, one per discriminator.
buffer_X, buffer_Y = ReplayBuffer(), ReplayBuffer()

# Adam: one optimiser over both generators' parameters, one per discriminator.
adam_kwargs = dict(lr=params['lr'], betas=(params['b1'], params['b2']))
optimizer_G = torch.optim.Adam(itertools.chain(gen_G.parameters(), gen_F.parameters()), **adam_kwargs)
optimizer_dX = torch.optim.Adam(dX.parameters(), **adam_kwargs)
optimizer_dY = torch.optim.Adam(dY.parameters(), **adam_kwargs)

Train the network!

In [None]:
# Run the training loop with the data, models, losses, optimisers and replay
# buffers set up in the previous cells.
train(
    gen_G, gen_F, dX, dY,
    dataloader=train_dataloader,
    n_epochs=params['n_epochs'],
    id_loss=id_loss,
    cycle_loss=cycle_loss,
    gan_loss=gan_loss,
    lambda_cycle=params['lambda_cycle'],
    lambda_id=params['lambda_id'],
    optimizer_G=optimizer_G,
    optimizer_dX=optimizer_dX,
    optimizer_dY=optimizer_dY,
    buffer_X=buffer_X,
    buffer_Y=buffer_Y,
    device=device,
)

## Submission

In [None]:
import PIL

# Create the output directory for the generated images. Unlike the shell
# `mkdir`, this is pure Python and does not fail when the cell is re-run and
# the directory already exists.
os.makedirs("../images", exist_ok=True)

In [None]:
# Dataset over every Monet painting and photo ("all" mode), one item per
# batch, visited in index order.
submit_dataloader = DataLoader(
    ImageDataset(root, transforms_=transforms_, mode="all"),
    batch_size=1,
    shuffle=False,
)

In [None]:
gen_G.eval()

# Translate every photo exactly once, in file order. Iterating
# `submit_dataloader` directly pairs each index with a *randomly* sampled
# photo, which duplicates some photos and skips others.
submit_dataset = submit_dataloader.dataset
with torch.no_grad():  # inference only: no autograd graph needed
    for i, photo_path in enumerate(submit_dataset.photo_files):
        with Image.open(photo_path) as img:
            if img.mode != "RGB":
                img = convert_to_RGB(img)
            photo = submit_dataset.transform(img).float()
        outputs = gen_G(photo.unsqueeze(0).to(device))
        # (1, C, H, W) tanh output in [-1, 1] -> (H, W, C) uint8 in [0, 255]
        output = outputs[0].permute(1, 2, 0).cpu().numpy()
        output = np.clip((output / 2 + 0.5) * 255, 0, 255).round().astype(np.uint8)
        im = Image.fromarray(output).convert('RGB')
        im.save(f'../images/output_img_{i}.jpg')
        if i % 100 == 0:
            print(f"Progress: {i}")

In [None]:
import shutil

# Zip the generated images into /kaggle/working/images.zip. Resolve the same
# relative directory the images were saved to ("../images") instead of
# hard-coding /kaggle/images, so the archive always matches the output
# directory regardless of the current working directory.
shutil.make_archive("/kaggle/working/images", 'zip', os.path.abspath("../images"))