# Project 4 - Group 17: Generating Monet-style artwork with a CycleGAN


> Charles Morris ( cmorris95 )
> Ryan Smith ( rsmit300 )
> Jeffrey Fortune ( jfortun3 )
> Kyle Shannon ( kshannon5 )
> Simon Boka ( sboka )


**Requirements**

In [None]:
!pip install opencv-contrib-python Pillow matplotlib torch torchvision pytorch-fid

In [None]:
import os
import cv2
import glob
import json
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from time import time
import shutil
import zipfile
import itertools
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
from pytorch_fid import fid_score

**Global Constants**

In [None]:
# Number of (photo, Monet) pairs per DataLoader batch.
BATCH_SIZE = 5

# glob.glob returns paths in arbitrary, filesystem-dependent order. Sort them so the
# index-based train/test slices taken in ImageDataset select the same files on every run.
MONET_IMAGES = sorted(glob.glob('/kaggle/input/gan-getting-started/monet_jpg/*.jpg'))
TEST_IMAGES = sorted(glob.glob('/kaggle/input/gan-getting-started/photo_jpg/*.jpg'))
print("Total Monet Images:", len(MONET_IMAGES), "Total Test Images:", len(TEST_IMAGES))

# Output locations inside the Kaggle working directory.
SAVED_MODEL_PATH = '/kaggle/working/monet_cyclegan_model_final.pth'  # weights saved after training
MODEL_CHECKPOINT = '/kaggle/working/monet_cyclegan_model.pth'  # intermediate checkpoint
GENERATED_MONET_IMAGES = '/kaggle/working/submission'  # generated images before zipping
SUBMISSION_FILE = '/kaggle/working/images.zip'  # zip archive built by make_submission
METRICS_FILE = '/kaggle/working/monet_cyclegan_metrics.json'  # per-batch training metrics
# [reference-image folder, generated-image folder] used for the FID computation.
FID_SCORE_DIRS = ['/kaggle/working/fid_score_real_images', '/kaggle/working/fid_score_generated_images']

**Hyperparameters**

In [None]:
lr = 0.00014  # learning rate given to every Adam optimizer
beta1 = 0.5  # first Adam beta coefficient
beta2 = 0.999  # second Adam beta coefficient
n_epoches = 150  # total number of training epochs
decay_epoch = 50  # epoch after which the learning rate starts its linear decay
display_epoch = 10  # reporting interval (same value as the default of MonetTrainer.train)

In [None]:
# Pick the compute device once; `Tensor` is the matching float tensor type used
# for the `.type(Tensor)` casts throughout the notebook.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
Tensor = torch.cuda.FloatTensor if cuda_available else torch.Tensor
print(f'CUDA Available: {cuda_available}')

In [None]:
class ImageDataset(Dataset):
    """Index-aligned (photo, Monet painting) pairs built from the global path lists.

    The first 250 entries of ``MONET_IMAGES`` / ``TEST_IMAGES`` form the training
    split; the test split uses ``MONET_IMAGES[250:]`` and ``TEST_IMAGES[250:301]``.

    Args:
        test: Select the test split instead of the training split.
        transforms: Optional callable applied independently to each loaded image.
    """

    def __init__(self, test=False, transforms=None):
        self.transforms = transforms

        if test:
            self.monet_dataset = MONET_IMAGES[250:]
            self.photo_dataset = TEST_IMAGES[250:301]
        else:
            self.monet_dataset = MONET_IMAGES[:250]
            self.photo_dataset = TEST_IMAGES[:250]

    def __len__(self):
        """Return the number of pairs, limited by the shorter of the two path lists."""
        return min(len(self.monet_dataset), len(self.photo_dataset))

    def _load(self, path):
        """Open one image as RGB, release the file handle, and apply the transforms."""
        with Image.open(path) as source:
            # convert() returns an independent, fully loaded copy, so the file can be
            # closed right away; forcing RGB matches MonetUtils.transform and keeps
            # the channel count fixed at three for the normalisation step.
            image = source.convert("RGB")
        if self.transforms is not None:
            image = self.transforms(image)
        return image

    def __getitem__(self, index):
        """Return ``(photo, monet)`` for the given index (photo first)."""
        monet_item = self._load(self.monet_dataset[index])
        photo_item = self._load(self.photo_dataset[index])
        return photo_item, monet_item

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional critic that maps an image to a one-channel score map.

    Four stride-2 convolutions (64, 128, 256, 512 filters) each halve the spatial
    size; a final zero-padded convolution reduces the features to a single channel.

    Args:
        in_channels: Number of channels of the input images.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.scale_factor = 16

        # The first stage has no normalisation layer.
        layers = [
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining down-sampling stages: conv -> instance norm -> leaky ReLU.
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            layers.extend((
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        # Asymmetric zero padding (left and top only) ahead of the scoring convolution.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map produced for the image batch ``x``."""
        scores = self.model(x)
        return scores

In [None]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in a skip connection.

    Args:
        in_channels: Channel count, identical for the input and the output.
    """

    def __init__(self, in_channels):
        super().__init__()

        layers = []
        for is_final_stage in (False, True):
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_channels, in_channels, 3))
            layers.append(nn.InstanceNorm2d(in_channels))
            # Only the first conv stage is followed by a non-linearity.
            if not is_final_stage:
                layers.append(nn.ReLU(inplace=True))

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.block(x)
        return x + branch

class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 stem -> two stride-2 down-sampling convolutions -> a stack of
    residual blocks at 256 channels -> two nearest-neighbour up-sampling stages ->
    7x7 output convolution followed by ``tanh`` (outputs lie in [-1, 1]).

    Args:
        in_channels: Number of channels of both the input and the output image.
        num_residual_blocks: How many ``ResidualBlock`` modules to stack.
    """

    # Kernel size of the stem and output convolutions, with the reflection padding
    # that keeps the spatial size unchanged. These used to be derived from
    # ``in_channels`` (pad=in_channels, kernel=2*in_channels+1), which only gave
    # 7/3 for three-channel input; they are architectural constants and must not
    # depend on the channel count.
    _EDGE_KERNEL = 7
    _EDGE_PADDING = _EDGE_KERNEL // 2

    def __init__(self, in_channels, num_residual_blocks=9):
        super(GeneratorResNet, self).__init__()

        self.initial = nn.Sequential(
            nn.ReflectionPad2d(self._EDGE_PADDING),
            nn.Conv2d(in_channels, 64, self._EDGE_KERNEL),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Each stride-2 convolution halves the spatial size and doubles the channels.
        self.downsample_blocks = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.residual_blocks = nn.Sequential(*[ResidualBlock(256) for _ in range(num_residual_blocks)])

        # Upsample + stride-1 convolution instead of a transposed convolution.
        self.upsample_blocks = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.output = nn.Sequential(
            nn.ReflectionPad2d(self._EDGE_PADDING),
            nn.Conv2d(64, in_channels, self._EDGE_KERNEL),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate the image batch ``x`` and return a batch of the same shape."""
        x = self.initial(x)
        x = self.downsample_blocks(x)
        x = self.residual_blocks(x)
        x = self.upsample_blocks(x)
        return self.output(x)

In [None]:
class CycleGAN(nn.Module):
    """Container for the two generators and two discriminators of the CycleGAN.

    Args:
        in_channels: Channel count of the images handled by every sub-network.
        num_residual_blocks: Residual-block depth of each generator.
    """

    def __init__(self, in_channels, num_residual_blocks=9):
        super().__init__()

        def build_generator():
            return GeneratorResNet(in_channels, num_residual_blocks).to(device)

        def build_discriminator():
            return Discriminator(in_channels).to(device)

        # All four sub-networks are created directly on the selected device.
        self.monet_generator = build_generator()
        self.photo_generator = build_generator()
        self.monet_discriminator = build_discriminator()
        self.photo_discriminator = build_discriminator()

    def forward(self, photo, monet):
        """Translate both inputs without tracking gradients (inference only).

        Returns:
            Tuple ``(fake_monet, fake_photo)``: the photo rendered by the Monet
            generator and the painting rendered by the photo generator.
        """
        with torch.no_grad():
            return self.monet_generator(photo), self.photo_generator(monet)

    def save_model(self, filename):
        """Write the state dict of all sub-networks to ``filename``."""
        weights = self.state_dict()
        torch.save(weights, filename)

    def load_model(self, filename):
        """Load a state dict written by ``save_model`` and move the model to ``device``."""
        self.load_state_dict(torch.load(filename, map_location=device))
        self.to(device)


In [None]:
class MonetLoss:
    """Computes the generator and discriminator objectives for one batch.

    Adversarial terms use mean-squared error against all-ones / all-zeros targets;
    cycle-consistency and identity terms use L1. All generator terms are summed
    with unit weight.

    Naming note: in this class ``model.photo_discriminator`` is the network applied
    to Monet-domain images and ``model.monet_discriminator`` the one applied to
    photo-domain images.
    """

    def __init__(self, model:CycleGAN) -> None:
        self.model = model
        self.gan_loss = nn.MSELoss().to(device)
        self.cycle_loss = nn.L1Loss().to(device)
        self.identity_loss = nn.L1Loss().to(device)

    def calculate_loss(self, real_photo, real_monet, reconstructed_photo, reconstructed_monet, fake_photo, fake_monet):
        """Return ``(generator_loss, discriminator_X_loss, discriminator_Y_loss)``.

        The generator loss is adversarial + cycle-consistency + identity; the two
        discriminator losses are computed on detached fakes, so they do not
        back-propagate into the generators.
        """
        identity_term = self.calculate_identity_loss(real_photo, real_monet)

        # Generators are rewarded when their fakes are scored as real (target = 1).
        monet_scores = self.model.photo_discriminator(fake_monet)
        photo_scores = self.model.monet_discriminator(fake_photo)
        adversarial_xy = self.gan_loss(monet_scores, torch.ones_like(monet_scores))
        adversarial_yx = self.gan_loss(photo_scores, torch.ones_like(photo_scores))

        # A round trip through both generators should reproduce the input.
        cycle_photo = self.cycle_loss(reconstructed_photo, real_photo)
        cycle_monet = self.cycle_loss(reconstructed_monet, real_monet)

        generator_total = adversarial_xy + adversarial_yx + cycle_photo + cycle_monet + identity_term

        discriminator_x_total = self.calculate_discriminator_loss(
            self.model.monet_discriminator, real_photo, fake_photo
        )
        discriminator_y_total = self.calculate_discriminator_loss(
            self.model.photo_discriminator, real_monet, fake_monet
        )

        torch.cuda.empty_cache()  # Attempting to clear the CUDA cache

        return generator_total, discriminator_x_total, discriminator_y_total

    def calculate_identity_loss(self, real_photo, real_monet):
        """L1 penalty for a generator changing an image already in its target domain."""
        same_monet = self.model.monet_generator(real_monet)
        same_photo = self.model.photo_generator(real_photo)
        monet_term = self.identity_loss(same_monet, real_monet)
        photo_term = self.identity_loss(same_photo, real_photo)
        return monet_term + photo_term

    def calculate_discriminator_loss(self, discriminator, real_item, fake_item):
        """Average of the real-vs-1 and fake-vs-0 MSE terms for one discriminator."""
        real_scores = discriminator(real_item.detach())
        real_term = self.gan_loss(real_scores, torch.ones_like(real_scores))
        fake_scores = discriminator(fake_item.detach())
        fake_term = self.gan_loss(fake_scores, torch.zeros_like(real_scores))
        return (real_term + fake_term) / 2


In [None]:
class MonetTrainer:
    """Runs CycleGAN training, records per-batch metrics and evaluates FID.

    One Adam optimizer (with a LambdaLR scheduler) drives both generators jointly,
    and one each drives the two discriminators. ``self.metrics`` collects one dict
    per processed batch.

    Args:
        model: The CycleGAN to train.
        epochs: Total number of training epochs.
        lr: Initial learning rate for all optimizers.
        beta1: First Adam beta coefficient.
        beta2: Second Adam beta coefficient.
        decay_epoch: Epoch after which the learning rate decays linearly.
    """

    def __init__(self, model:CycleGAN, epochs, lr, beta1, beta2, decay_epoch) -> None:
        self.metrics = []
        self.model = model
        self.losses = MonetLoss(model)
        self.epochs = epochs
        self.init_optimizers(lr, beta1, beta2)
        self.init_lr_schedulers(epochs, decay_epoch)

    def train_step(self, real_photo, real_monet):
        """Run one optimisation step on a batch.

        Returns:
            ``(generator_loss, discriminator_X_loss, discriminator_Y_loss,
            fake_photo, fake_monet)``.
        """
        real_photo = real_photo.type(Tensor)
        real_monet = real_monet.type(Tensor)

        # Generate fake images
        fake_monet = self.model.monet_generator(real_photo)
        fake_photo = self.model.photo_generator(real_monet)

        # Cycle consistency: translate the fakes back to their source domain
        reconstructed_X = self.model.photo_generator(fake_monet)
        reconstructed_Y = self.model.monet_generator(fake_photo)

        # Calculate losses
        total_loss_G, total_loss_D_X, total_loss_D_Y = self.losses.calculate_loss(
            real_photo, real_monet, reconstructed_X, reconstructed_Y, fake_photo, fake_monet
        )

        # Update generators
        self.update_generators(total_loss_G)

        # Update discriminators
        self.update_discriminators(total_loss_D_X, total_loss_D_Y)

        torch.cuda.empty_cache()  # Attempting to clear the CUDA cache
        return total_loss_G, total_loss_D_X, total_loss_D_Y, fake_photo, fake_monet

    def update_generators(self, total_loss_G):
        """Back-propagate the generator loss and step the shared generator optimizer."""
        self.optimizer_G.zero_grad()
        total_loss_G.backward()
        self.optimizer_G.step()

    def update_discriminators(self, total_loss_D_X, total_loss_D_Y):
        """Back-propagate and step each discriminator with its own optimizer."""
        # zero_grad() here also discards the gradients that the generator backward
        # pass accumulated on the discriminator parameters.
        self.optimizer_D_X.zero_grad()
        total_loss_D_X.backward()
        self.optimizer_D_X.step()

        self.optimizer_D_Y.zero_grad()
        total_loss_D_Y.backward()
        self.optimizer_D_Y.step()

    def train(self, dataloader, display_epoch: int=10):
        """Train for ``self.epochs`` epochs, stepping the LR schedulers after each.

        Args:
            dataloader: Yields ``(real_photo, real_monet)`` batches.
            display_epoch: Reporting interval. Metrics are printed every
                ``display_epoch`` batches; samples are shown and a checkpoint is
                saved every ``display_epoch`` epochs.
        """
        for epoch in range(self.epochs):
            self.train_epoch(dataloader, epoch, display_epoch)
            self.update_lr_schedulers()

    def update_lr_schedulers(self):
        """Advance all three learning-rate schedulers by one epoch."""
        self.lr_scheduler_G.step()
        self.lr_scheduler_D_X.step()
        self.lr_scheduler_D_Y.step()

    def train_epoch(self, dataloader, epoch, display_epoch):
        """Train on every batch once, then sample and checkpoint on display epochs."""
        last_batch = None

        for batch_idx, (real_photo, real_monet) in enumerate(dataloader):
            real_photo, real_monet = real_photo.type(Tensor).detach(), real_monet.type(Tensor).detach()

            total_loss_G, total_loss_D_X, total_loss_D_Y, fake_photo, fake_monet = self.train_step(real_photo, real_monet)

            self.metrics.append({
                    "Epoch": epoch,
                    "Batch": batch_idx,
                    "Generator LR": self.lr_scheduler_G.get_last_lr()[0],
                    "Discriminator X LR": self.lr_scheduler_D_X.get_last_lr()[0],
                    "Discriminator Y LR": self.lr_scheduler_D_Y.get_last_lr()[0],
                    "Generator Loss": total_loss_G.item(),
                    "Discriminator X Loss": total_loss_D_X.item(),
                    "Discriminator Y Loss": total_loss_D_Y.item()
            })

            if batch_idx % display_epoch == 0:
                self.show_metrics(dataloader)

            # Keep only detached tensors so the autograd graph of the batch can be freed.
            last_batch = (real_photo, real_monet, fake_photo.detach(), fake_monet.detach())

            torch.cuda.empty_cache()  # Attempting to clear the CUDA cache

        # `%` binds tighter than `+`, so the former `epoch + 1 % display_epoch == 0`
        # was `epoch + 1 == 0` and never fired. The check also runs once per epoch
        # now instead of once per batch, so a display epoch yields a single sample
        # grid and a single checkpoint write.
        if last_batch is not None and (epoch + 1) % display_epoch == 0:
            MonetUtils.sample_images(*last_batch)
            self.model.save_model(MODEL_CHECKPOINT)

    def init_lr_schedulers(self, epochs, decay_epoch):
        """Create schedulers that hold the LR until ``decay_epoch``, then decay linearly.

        The multiplier reaches zero at ``epochs``.
        """
        # Guard against epochs == decay_epoch, which used to raise ZeroDivisionError
        # as soon as the scheduler evaluated the lambda; no decay happens in that case.
        decay_span = max(1, epochs - decay_epoch)

        def lr_lambda(epoch):
            return 1.0 - max(0, epoch - decay_epoch) / decay_span

        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(self.optimizer_G, lr_lambda=lr_lambda)
        self.lr_scheduler_D_X = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D_X, lr_lambda=lr_lambda)
        self.lr_scheduler_D_Y = torch.optim.lr_scheduler.LambdaLR(self.optimizer_D_Y, lr_lambda=lr_lambda)

    def init_optimizers(self, lr, beta1, beta2):
        """Create the Adam optimizers: one for both generators, one per discriminator."""
        betas = (beta1, beta2)
        generator_params = itertools.chain(
            self.model.monet_generator.parameters(),
            self.model.photo_generator.parameters(),
        )
        self.optimizer_G = torch.optim.Adam(generator_params, lr=lr, betas=betas)
        self.optimizer_D_X = torch.optim.Adam(self.model.monet_discriminator.parameters(), lr=lr, betas=betas)
        self.optimizer_D_Y = torch.optim.Adam(self.model.photo_discriminator.parameters(), lr=lr, betas=betas)

    def show_metrics(self, dataloader):
        """Print the most recently recorded metrics entry on one line."""
        latest_metrics = self.metrics[-1]

        print(f'Epoch: {latest_metrics["Epoch"]}/{self.epochs}, '
                        f'Batch: {latest_metrics["Batch"]}/{len(dataloader)}, '
                        f'Generator LR: {latest_metrics["Generator LR"]:.6f}, '
                        f'Discriminator X LR: {latest_metrics["Discriminator X LR"]:.6f}, '
                        f'Discriminator Y LR: {latest_metrics["Discriminator Y LR"]:.6f}, '
                        f'Generator Loss: {latest_metrics["Generator Loss"]:.4f}, '
                        f'Discriminator X Loss: {latest_metrics["Discriminator X Loss"]:.4f}, '
                        f'Discriminator Y Loss: {latest_metrics["Discriminator Y Loss"]:.4f}')

    @staticmethod
    def _to_unit_range(image):
        """Map a [-1, 1] image tensor to [0, 1], the range ``save_image`` expects."""
        return torch.clamp((image + 1) / 2, 0, 1)

    def evaluate_fid(self, dataloader, batch_size=50, dims=2048):
        """Compute the FID between real Monet paintings and generated ones.

        Real paintings and the Monet-generator output for the paired photos are
        written to ``FID_SCORE_DIRS[0]`` and ``FID_SCORE_DIRS[1]`` respectively,
        then compared with ``pytorch_fid``.

        Two defects of the earlier version are fixed here: the reference folder
        used to hold the input *photos*, so an identity generator would have scored
        best, and images were saved in [-1, 1], so ``save_image`` clamped every
        negative value to black.

        Args:
            dataloader: Yields ``(real_photo, real_monet)`` batches.
            batch_size: Minimum number of image pairs to write (whole batches are
                always finished); also the batch size passed to ``pytorch_fid``.
            dims: Feature dimensionality passed to ``pytorch_fid``.

        Returns:
            The FID value returned by ``calculate_fid_given_paths``.
        """
        real_dir, fake_dir = FID_SCORE_DIRS[0], FID_SCORE_DIRS[1]
        os.makedirs(real_dir, exist_ok=True)
        os.makedirs(fake_dir, exist_ok=True)

        generator = self.model.monet_generator
        was_training = generator.training
        generator.eval()

        saved = 0
        try:
            with torch.no_grad():
                for real_photos, real_monets in dataloader:
                    fake_monets = generator(real_photos.type(Tensor))

                    for real_monet, fake_monet in zip(real_monets, fake_monets):
                        save_image(self._to_unit_range(real_monet), os.path.join(real_dir, f'real_{saved}.png'))
                        save_image(self._to_unit_range(fake_monet), os.path.join(fake_dir, f'fake_{saved}.png'))
                        saved += 1

                    if saved >= batch_size:
                        break
        finally:
            # Leave the generator in the mode it was in before evaluation.
            generator.train(was_training)

        fid = fid_score.calculate_fid_given_paths([real_dir, fake_dir], batch_size=batch_size, device=device, dims=dims)
        return fid
                  

In [None]:
class MonetUtils:
    """Static helpers for visualisation, single-image inference and file output."""

    @staticmethod
    def _to_unit_range(image):
        """Map a [-1, 1] image tensor to [0, 1] for display or saving."""
        return torch.clamp((image + 1) / 2, 0, 1)

    @staticmethod
    def sample_images(real_photo, real_monet, fake_photo, fake_monet):
        """Plot a 2x2 figure of image grids: real/generated photos and paintings."""
        # (row, column, batch, title) for each of the four panels.
        panels = (
            (0, 0, real_photo, "Real Photos"),
            (0, 1, fake_monet, "Generated Monet Arts from Photos"),
            (1, 0, real_monet, "Real Monet Arts"),
            (1, 1, fake_photo, "Generated Photos from Monet Arts"),
        )
        ncols = real_photo.size(0)

        fig, axs = plt.subplots(2, 2, figsize=(3*BATCH_SIZE, BATCH_SIZE))

        for row, col, batch, title in panels:
            # normalize=True rescales each grid to [0, 1] for imshow.
            grid = make_grid(batch.type(Tensor), nrow=ncols, normalize=True)
            axs[row, col].imshow(grid.permute(1, 2, 0).cpu().numpy())
            axs[row, col].set_title(title)
            axs[row, col].axis('off')

        plt.tight_layout(h_pad=2.0, w_pad=1.0)
        plt.show()

    @staticmethod
    def transform(image_path, generator):
        """Run one image file through ``generator``.

        Returns:
            ``(output, image)``: the generated tensor and the preprocessed input
            tensor, both without a batch dimension and in [-1, 1].
        """
        generator.eval()
        image = Image.open(image_path).convert("RGB")
        # NOTE(review): transforms_dataset includes RandomHorizontalFlip, so the
        # input (and therefore the output) may come back mirrored.
        image = transforms_dataset(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = generator(image).squeeze(0)
        image = image.squeeze(0)
        return output, image

    @staticmethod
    def transform_to_monet(photo_path):
        """Placeholder; rebound to ``transform`` with the Monet generator once a model exists."""
        # Fail loudly rather than return None, which every caller would go on to
        # unpack and crash on with a less helpful TypeError.
        raise NotImplementedError("MonetUtils.transform_to_monet has not been bound to a generator yet")

    @staticmethod
    def transform_to_photo(painting_path):
        """Placeholder; rebound to ``transform`` with the photo generator once a model exists."""
        raise NotImplementedError("MonetUtils.transform_to_photo has not been bound to a generator yet")

    @staticmethod
    def view_images(real_image, generate_image, title='Original and Generated Images'):
        """Show two [-1, 1] image tensors side by side under ``title``."""
        real_image = MonetUtils._to_unit_range(real_image)
        generate_image = MonetUtils._to_unit_range(generate_image)

        grid = make_grid([real_image, generate_image])
        grid = grid.permute(1, 2, 0).cpu().numpy()

        plt.figure(figsize=(5, 4))
        plt.imshow(grid)
        plt.title(title)
        plt.axis('off')
        plt.tight_layout(h_pad=0.0, w_pad=1.0)
        plt.show()

    @staticmethod
    def make_submission():
        """Translate every path in ``TEST_IMAGES`` and zip the results.

        Each generated image is saved under ``GENERATED_MONET_IMAGES`` with the
        source file name and added to ``SUBMISSION_FILE``.
        """
        os.makedirs(GENERATED_MONET_IMAGES, exist_ok=True)
        start = time()
        total = len(TEST_IMAGES)

        with zipfile.ZipFile(SUBMISSION_FILE, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for i, file_path in enumerate(TEST_IMAGES):
                filename = os.path.basename(file_path)
                output, _ = MonetUtils.transform_to_monet(file_path)
                output_path = os.path.join(GENERATED_MONET_IMAGES, filename)
                # The generator output is in [-1, 1]; save_image clamps to [0, 1],
                # so rescale first or every negative value is written as black.
                save_image(MonetUtils._to_unit_range(output), output_path)
                zipf.write(output_path, filename)

                # Report a real percentage (the message used to print a 0-1 fraction with a % sign).
                print(f"Processed {i + 1}/{total} - {round(100 * (i + 1)/total, 2)} %: {filename}  :: Time Elapsed: {int(time() - start)} Seconds", end='\r')

        print(f"\nAll images processed and added to the zip file: {SUBMISSION_FILE}.")

    @staticmethod
    def save_metrics(metrics):
        """Dump the metrics list as indented JSON to ``METRICS_FILE``."""
        with open(METRICS_FILE, "w", encoding="utf-8") as jsonf:
            json.dump(metrics, jsonf, indent=4)

        print(f"\nAll training metrics are in file: {METRICS_FILE}.")

    @staticmethod
    def cleanup_jobs(target_dir, ext="*"):
        """Delete files matching ``*.<ext>`` in ``target_dir``, then remove the directory.

        A missing directory is skipped, so cleanup can be re-run or run when a
        step that creates the directory was skipped. ``os.rmdir`` still raises
        ``OSError`` if non-matching files remain in the directory.
        """
        if not os.path.isdir(target_dir):
            return

        for f in glob.glob(f'{target_dir}/*.{ext}'):
            os.remove(f)

        os.rmdir(target_dir)


**Loading the datasets**

In [None]:
# Preprocessing applied to every loaded image: random horizontal flip, conversion
# to a tensor, then per-channel normalisation with mean 0.5 and std 0.5, which
# maps [0, 1] pixel values to [-1, 1].
transforms_dataset = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [None]:
# Training split: batches are reshuffled on every pass.
train_loader = DataLoader(
    dataset=ImageDataset(transforms=transforms_dataset, test=False),
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# Test split: fixed order.
test_loader = DataLoader(
    dataset=ImageDataset(transforms=transforms_dataset, test=True),
    shuffle=False,
    batch_size=BATCH_SIZE,
)

**Visualize a photo example and a Monet painting example**

In [None]:
# Load the last image from each dataset
monet_image_path = MONET_IMAGES[-1]
photo_image_path = TEST_IMAGES[-1]
# Force RGB (as MonetUtils.transform does) so the three-channel Normalize step
# cannot fail on a file stored in another colour mode.
monet_painting = transforms_dataset(Image.open(monet_image_path).convert("RGB"))
example_photo = transforms_dataset(Image.open(photo_image_path).convert("RGB"))

MonetUtils.view_images(monet_painting, example_photo, 'Monet Painting and Photo Example')

**Instantiate the CycleGAN model**

In [None]:
# Build the model for 3-channel images (the preprocessing normalises three channels).
cyclegan = CycleGAN(3)

In [None]:
from functools import partial

# Replace the MonetUtils placeholders with MonetUtils.transform pre-bound to the
# matching generator of the model built above.
MonetUtils.transform_to_monet = partial(MonetUtils.transform, generator=cyclegan.monet_generator)
MonetUtils.transform_to_photo = partial(MonetUtils.transform, generator=cyclegan.photo_generator)

In [None]:
# First test batch (photos as real_X, Monet paintings as real_Y), reused for the
# sample grids shown before and after training.
real_X, real_Y = next(iter(test_loader))

In [None]:
# Show a sample grid from the untrained model as a baseline, then write an initial checkpoint.
fake_Y, fake_X = cyclegan.forward(real_X.type(Tensor), real_Y.type(Tensor))
MonetUtils.sample_images(real_X, real_Y, fake_X, fake_Y)
cyclegan.save_model(MODEL_CHECKPOINT)
torch.cuda.empty_cache()

In [None]:
# Translate the example photo into a painting and the example painting into a
# photo, showing each result next to its source image.
fake_, real_ = MonetUtils.transform_to_monet(photo_image_path)
MonetUtils.view_images(real_, fake_, "Original Photo and It's Generated Art")
fake_, real_ = MonetUtils.transform_to_photo(monet_image_path)
MonetUtils.view_images(real_, fake_, " Monet Art and It's Generated Photo")

**Training on Entire Data Set**

In [None]:
# Build the trainer from the hyperparameters defined at the top of the notebook.
trainer = MonetTrainer(cyclegan, n_epoches, lr, beta1, beta2, decay_epoch)

In [None]:
# Pass the notebook-level display_epoch so that editing the hyperparameter cell
# actually changes the reporting interval (it was defined but never used).
trainer.train(train_loader, display_epoch=display_epoch)
torch.cuda.empty_cache()

**Visualizing the Solution**

In [None]:
# Show the same sample batch again with the trained model, then save the final weights.
fake_Y, fake_X = cyclegan.forward(real_X.type(Tensor), real_Y.type(Tensor))
MonetUtils.sample_images(real_X, real_Y, fake_X, fake_Y)
cyclegan.save_model(SAVED_MODEL_PATH)
torch.cuda.empty_cache()

In [None]:
# Repeat the single-image translations with the trained generators.
fake_, real_ = MonetUtils.transform_to_monet(photo_image_path)
MonetUtils.view_images(real_, fake_, "Original Photo and It's Generated Art")
fake_, real_ = MonetUtils.transform_to_photo(monet_image_path)
MonetUtils.view_images(real_, fake_, " Monet Art and It's Generated Photo")

**Testing the Pre-Trained Model**

**Evaluating the model using FID Score**

In [None]:
# Compute the FID score on the test loader and report it.
score = trainer.evaluate_fid(test_loader)
print(f'FID Score: {score}')

**Create Submission Images**

In [None]:
# Generate an image for every path in TEST_IMAGES and pack them into SUBMISSION_FILE.
MonetUtils.make_submission()

**Trigger Download Link**

In [None]:
from IPython.display import FileLink
# As the last expression of the cell, this renders a link to the submission archive.
FileLink(SUBMISSION_FILE)

**Save Training Metrics**

In [None]:
# Persist the per-batch training metrics collected by the trainer as JSON.
MonetUtils.save_metrics(trainer.metrics)

**Clean Up (Optional)**

In [None]:
# A plain loop instead of a list comprehension: the calls are made only for their
# side effects, and the cell no longer echoes a list of None values.
for folder in FID_SCORE_DIRS + [GENERATED_MONET_IMAGES]:
    MonetUtils.cleanup_jobs(folder)