In [None]:
import os

VANGOGH_DIR = '/kaggle/input/cycleganvangogf'
REAL_DIR = '/kaggle/input/cylceganreal'
# Only these extensions are treated as images: anything else in the input folders
# (e.g. the stray `hi.py` script) would crash PIL when the dataset tries to open it.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')


def list_images(directory):
    """Return the sorted full paths of image files directly inside ``directory``.

    Sorting removes the dependence on ``os.listdir`` ordering (which is arbitrary),
    so the file lists are identical from run to run. The extension check is
    case-insensitive.
    """
    return sorted(
        f'{directory}/{name}'
        for name in os.listdir(directory)
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )


images_v = list_images(VANGOGH_DIR)
images_r = list_images(REAL_DIR)

print("no of vanggogf images ", len(images_v))
print("no of real images ", len(images_r))

In [None]:
# Discriminator
import torch
import torch.nn as nn

class DBlock(nn.Module):
    """One discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride (padding is fixed at 1).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        stage = [
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=True, padding_mode='reflect'),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Attribute name `conv` is part of the state_dict keys ("conv.0.weight", ...).
        self.conv = nn.Sequential(*stage)

    def forward(self, x):
        return self.conv(x)
    
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel spatial map of logits.

    Layout: 4x4 stride-2 conv + LeakyReLU(0.2), then one ``DBlock`` per entry of
    ``features[1:]`` (stride 2, except stride 1 for the last one), then a 4x4 stride-1
    conv down to a single channel. No sigmoid is applied, so the output is meant for
    a logits-based loss such as ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: channels of the input image.
        features: channel widths of the successive stages; ``features[0]`` is the
            width of the initial conv.
    """

    def __init__(self, in_channels, features=(64, 128, 256, 512)):
        super().__init__()

        layers = []
        init_channels = in_channels
        in_channels = features[0]
        last_index = len(features) - 1

        for index, feature in enumerate(features[1:], start=1):
            layers.append(
                DBlock(
                    in_channels=in_channels,
                    out_channels=feature,
                    # Decide by position, not by value: with repeated widths such as
                    # (64, 128, 512, 512) only the final block should keep stride 1.
                    stride=1 if index == last_index else 2
                )
            )
            in_channels = feature

        # Module nesting/order is unchanged, so state_dict keys stay compatible with
        # existing checkpoints.
        self.discriminator = nn.Sequential(
            # initial: no normalization on the first layer
            nn.Sequential(
                nn.Conv2d(
                    in_channels=init_channels,
                    out_channels=features[0],
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    padding_mode='reflect'
                ),
                nn.LeakyReLU(0.2)
            ),
            # intermediate DBlocks
            nn.Sequential(*layers),
            # final: project to one logit channel
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode='reflect'
            )
        )

    def forward(self, x):
        return self.discriminator(x)



In [None]:
class GBlock(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: True builds an ``nn.Conv2d`` with reflect padding; False builds an
            ``nn.ConvTranspose2d``.
        use_act: True appends ``ReLU(inplace=True)``, otherwise ``nn.Identity``.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              padding_mode='reflect', **kwargs)
        else:
            layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        if use_act:
            act = nn.ReLU(inplace=True)
        else:
            act = nn.Identity()
        # Attribute name `conv` is part of the state_dict keys.
        self.conv = nn.Sequential(layer, norm, act)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Identity-skip residual unit: ``x + F(x)``.

    ``F`` is two 3x3, padding-1 ``GBlock`` convolutions with the same channel count
    (ReLU after the first one only), so ``F(x)`` has the same shape as ``x``.

    Args:
        channels: input and output channel count.
    """

    def __init__(self, channels):
        super().__init__()
        conv_opts = dict(kernel_size=3, padding=1)
        first = GBlock(in_channels=channels, out_channels=channels, use_act=True, **conv_opts)
        second = GBlock(in_channels=channels, out_channels=channels, use_act=False, **conv_opts)
        # Attribute name `block` is part of the state_dict keys.
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
    
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Layout: 7x7 conv to 64 channels + ReLU, two stride-2 ``GBlock`` downsamples
    (64 -> 128 -> 256), ``num_residuals`` independent 256-channel ``ResidualBlock``s,
    two stride-2 transposed ``GBlock`` upsamples (256 -> 128 -> 64), and a 7x7 conv to
    3 channels followed by ``tanh``. For inputs whose height and width are multiples
    of 4 the output has the same spatial size as the input.

    Args:
        in_channels: channels of the input image.
        num_residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, in_channels, num_residuals=9):
        super().__init__()

        self.generator = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
            nn.ReLU(inplace=True),
            GBlock(64, 128, down=True, use_act=True, kernel_size=3, stride=2, padding=1),
            GBlock(128, 256, down=True, use_act=True, kernel_size=3, stride=2, padding=1),
            # Build a fresh module per slot: `[ResidualBlock(256)] * n` would insert the
            # same module n times, so every residual stage would share one set of weights.
            *[ResidualBlock(256) for _ in range(num_residuals)],
            GBlock(256, 128, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            GBlock(128, 64, down=False, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(64, 3, 7, 1, 3, padding_mode="reflect"),
            # Bound the output to [-1, 1]: the training images are normalized with
            # mean=std=0.5 and samples are saved as `x * 0.5 + 0.5`.
            nn.Tanh(),
        )

    def forward(self, x):
        return self.generator(x)


In [None]:
from torchvision import transforms
import os
from typing import Callable
import torch


class Config:
    """Hyperparameters, output paths and preprocessing for the Van Gogh CycleGAN.

    Defining this class creates the checkpoint and generated-sample directories
    (under ``VangGoghGAN/`` relative to the current working directory) if they do
    not already exist.
    """

    LEARNING_RATE: float = 0.0002
    BETA_1: float = 0.5
    BETA_2: float = 0.999
    LAMBDA_CYCLE: int = 10
    LAMBDA_IDENTITY: int = 5
    NUM_EPOCHS: int = 20
    BATCH_SIZE: int = 1

    SAVE_MODEL: bool = True
    LOAD_MODEL: bool = False

    # Checkpoint file locations (plain path strings).
    CHECKPOINT_DIR: str = 'VangGoghGAN/checkpoints'
    CHECKPOINT_GEN_VANGGOGH: str = f'{CHECKPOINT_DIR}/gen_VangGogh.pth'
    CHECKPOINT_DIS_VANGGOGH: str = f'{CHECKPOINT_DIR}/dis_VangGogh.pth'
    CHECKPOINT_GEN_PHOTO: str = f'{CHECKPOINT_DIR}/gen_photo.pth'
    CHECKPOINT_DIS_PHOTO: str = f'{CHECKPOINT_DIR}/dis_photo.pth'

    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    VANGGOGF_SAVED_IMAGES: str = "VangGoghGAN/gen_VangGogh"
    PHOTO_SAVED_IMAGES: str = "VangGoghGAN/gen_photo"

    # makedirs also creates the missing "VangGoghGAN" parent and is a no-op when the
    # directory already exists, replacing the isdir/mkdir checks.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(VANGGOGF_SAVED_IMAGES, exist_ok=True)
    os.makedirs(PHOTO_SAVED_IMAGES, exist_ok=True)

    # PIL image -> tensor, resize to 256x256, random horizontal flip, scale to [-1, 1].
    preprocess: transforms.Compose = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

In [None]:
import torch
# Module-level settings instance; Utils.load reads `config.DEVICE` from this global.
config = Config()


class Utils:
    """Checkpoint helpers that write/read a ``{"model": state_dict, "epoch": epoch}`` dict."""

    def save(self, model, epoch, file_name) -> None:
        """Serialize ``model``'s state_dict together with ``epoch`` to ``file_name``."""
        payload = dict(model=model.state_dict(), epoch=epoch)
        torch.save(payload, file_name)
        print('__finished saving checkpoint__')

    def load(self, model, file_name) -> None:
        """Restore ``model``'s weights from a checkpoint produced by ``save``.

        Tensors are mapped onto the module-level ``config.DEVICE``; the stored epoch
        is not used.
        """
        weights = torch.load(file_name, map_location=config.DEVICE)['model']
        model.load_state_dict(weights)
        print('__finished loading checkpoint__')



In [None]:
from torch.utils.data import Dataset
from PIL import Image
# Module-level settings instance; PhotoToVanggogfDataset reads `config.preprocess` from it.
config = Config()


class PhotoToVanggogfDataset(Dataset):
    """Unpaired dataset of ``(real photo, Van Gogh painting)`` samples.

    The two path lists are indexed independently modulo their own lengths, so the
    shorter collection is cycled and ``len()`` equals the size of the longer one.
    """

    def __init__(self, vanggogf_photos, real_photos, transform: bool = True) -> None:
        """
        Args:
            vanggogf_photos: sequence of paths to Van Gogh images.
            real_photos: sequence of paths to real photographs.
            transform: when True both images go through ``config.preprocess``;
                when False PIL images are returned.

        Raises:
            ValueError: if either path list is empty (indexing with ``idx % 0``
                would otherwise fail on the first access).
        """
        if len(vanggogf_photos) == 0 or len(real_photos) == 0:
            raise ValueError(
                f"PhotoToVanggogfDataset needs at least one image per domain "
                f"(got {len(vanggogf_photos)} Van Gogh, {len(real_photos)} real)"
            )
        # Attribute name kept as-is (including its spelling) for compatibility.
        self.vanggohf_photos = vanggogf_photos
        self.real_photos = real_photos
        self.transform = transform

    def __len__(self) -> int:
        return max(len(self.vanggohf_photos), len(self.real_photos))

    @staticmethod
    def _open_rgb(path):
        """Load ``path`` as a 3-channel RGB image and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a loaded copy, so it stays valid after the file is
            # closed; forcing RGB turns grayscale / RGBA / palette files into the
            # 3 channels the models and Normalize expect.
            return img.convert("RGB")

    def __getitem__(self, idx) -> tuple:
        """Return ``(real_img, vanggogf_img)``; tensors when ``transform`` is truthy."""
        vanggogf_img = self._open_rgb(self.vanggohf_photos[idx % len(self.vanggohf_photos)])
        real_img = self._open_rgb(self.real_photos[idx % len(self.real_photos)])
        if self.transform:
            real_img = config.preprocess(real_img)
            vanggogf_img = config.preprocess(vanggogf_img)

        return (real_img, vanggogf_img)

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image

from tqdm import tqdm

config = Config()
utils = Utils()


class Trainer:
    """CycleGAN training loop for photo (R) <-> Van Gogh (V) translation.

    ``gen_V`` maps photos to Van Gogh style and ``gen_R`` maps paintings to photos;
    ``disc_V`` / ``disc_R`` score images of the Van Gogh / photo domain respectively.
    """

    @staticmethod
    def _amp_enabled(config) -> bool:
        """Mixed precision (autocast + GradScaler) is only enabled on CUDA devices."""
        return str(config.DEVICE).startswith("cuda")

    def train_epoch(self,
                    config,
                    disc_V,
                    gen_V,
                    disc_R,
                    gen_R,
                    trainloader,
                    opt_disc,
                    opt_gen,
                    l1_loss,
                    bce_loss,
                    d_scaler=None,
                    g_scaler=None
                    ):
        """Run one epoch of alternating discriminator / generator updates.

        Args:
            config: provides DEVICE, LAMBDA_CYCLE, LAMBDA_IDENTITY and the sample-image
                output directories.
            disc_V, gen_V, disc_R, gen_R: the four networks (see class docstring).
            trainloader: iterable of ``(real_img, vanggogf_img)`` batches.
            opt_disc, opt_gen: optimizers for the discriminators / generators.
            l1_loss: loss for the cycle-consistency and identity terms.
            bce_loss: adversarial loss applied to discriminator logits.
            d_scaler, g_scaler: AMP ``GradScaler``s for the two updates. Pass the same
                instances every epoch so the loss scale carries over; when omitted,
                new scalers are created for this call only.

        Returns:
            ``(mean_discriminator_loss, mean_generator_loss)`` over the epoch's batches.
        """
        use_amp = self._amp_enabled(config)
        # Created per call instead of as default-argument values, which would be built
        # once at class-definition time and silently shared by every caller.
        if d_scaler is None:
            d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        if g_scaler is None:
            g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        discriminator_loss_epoch = 0.0
        generator_loss_epoch = 0.0

        loader = tqdm(trainloader, colour="blue")
        for idx, (real_img, vanggogf_img) in enumerate(loader):
            real_img = real_img.to(config.DEVICE)
            vanggogf_img = vanggogf_img.to(config.DEVICE)

            # ---- discriminator update ----
            with torch.autocast(device_type="cuda", enabled=use_amp):
                # disc_V: real paintings -> 1, translated photos -> 0
                fake_v = gen_V(real_img)
                disc_v_fake_score = disc_V(fake_v.detach())
                disc_v_real_score = disc_V(vanggogf_img)
                disc_v_fake_loss = bce_loss(disc_v_fake_score, torch.zeros_like(disc_v_fake_score))
                disc_v_real_loss = bce_loss(disc_v_real_score, torch.ones_like(disc_v_real_score))
                disc_v_loss = disc_v_real_loss + disc_v_fake_loss

                # disc_R: real photos -> 1, translated paintings -> 0
                fake_r = gen_R(vanggogf_img)
                disc_r_fake_score = disc_R(fake_r.detach())
                disc_r_real_score = disc_R(real_img)
                disc_r_fake_loss = bce_loss(disc_r_fake_score, torch.zeros_like(disc_r_fake_score))
                disc_r_real_loss = bce_loss(disc_r_real_score, torch.ones_like(disc_r_real_score))
                disc_r_loss = disc_r_real_loss + disc_r_fake_loss

                disc_loss = (disc_v_loss + disc_r_loss) / 2

            opt_disc.zero_grad()
            # The fakes were detached above, so this graph holds no generator
            # activations and does not need to be retained for the generator step.
            d_scaler.scale(disc_loss).backward()
            d_scaler.step(opt_disc)
            d_scaler.update()

            # ---- generator update ----
            with torch.autocast(device_type="cuda", enabled=use_amp):
                # adversarial terms: generators want the discriminators to output 1
                disc_fake_v = disc_V(fake_v)
                disc_fake_r = disc_R(fake_r)
                gen_loss_v = bce_loss(disc_fake_v, torch.ones_like(disc_fake_v))
                gen_loss_r = bce_loss(disc_fake_r, torch.ones_like(disc_fake_r))

                # cycle consistency: photo -> V -> photo, painting -> R -> painting
                cycle_loss_v = l1_loss(real_img, gen_R(fake_v))
                cycle_loss_r = l1_loss(vanggogf_img, gen_V(fake_r))

                # identity: an image already in a generator's target domain should pass through unchanged
                identity_loss_v = l1_loss(vanggogf_img, gen_V(vanggogf_img))
                identity_loss_r = l1_loss(real_img, gen_R(real_img))

                gen_loss = (
                    gen_loss_v + gen_loss_r
                    + (cycle_loss_v + cycle_loss_r) * config.LAMBDA_CYCLE
                    + (identity_loss_v + identity_loss_r) * config.LAMBDA_IDENTITY
                )

            opt_gen.zero_grad()
            g_scaler.scale(gen_loss).backward()
            g_scaler.step(opt_gen)
            g_scaler.update()

            disc_loss_value = disc_loss.item()
            gen_loss_value = gen_loss.item()
            # Sum every batch (both losses) so the returned values are true epoch means.
            discriminator_loss_epoch += disc_loss_value
            generator_loss_epoch += gen_loss_value

            if idx % 200 == 0:
                # NOTE: names only use the batch index, so each epoch overwrites these files.
                save_image(fake_v * 0.5 + 0.5, f"{config.VANGGOGF_SAVED_IMAGES}/{idx}.png")
                save_image(fake_r * 0.5 + 0.5, f"{config.PHOTO_SAVED_IMAGES}/{idx}.png")

            loader.set_postfix(
                disc_loss=f"{disc_loss_value:.4f}",
                gen_loss=f"{gen_loss_value:.4f}"
            )

        num_batches = max(len(trainloader), 1)
        return discriminator_loss_epoch / num_batches, generator_loss_epoch / num_batches

    @staticmethod
    def _plot_losses(generator_loss, discriminator_loss):
        """Plot the per-epoch mean generator and discriminator losses (epochs numbered from 1)."""
        print("__plotting loss curves__")
        epochs = list(range(1, len(generator_loss) + 1))
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(epochs, generator_loss, color="red", label="gen_loss")
        ax.plot(epochs, discriminator_loss, color="blue", label="disc_loss")
        ax.legend()
        ax.set_title('LOSS vs EPOCH', fontdict={'fontsize': 10})
        ax.set_xlabel('EPOCH')
        ax.set_ylabel('LOSS')
        ax.set_xticks(epochs)
        ax.tick_params(labelsize=10)
        plt.show()

    def train(self):
        """Build the CycleGAN, optionally resume, train for ``config.NUM_EPOCHS`` and plot losses.

        Reads the module-level ``config``, ``utils``, ``images_v`` and ``images_r``.
        Prints the error and returns early if the dataset cannot be built.
        """
        disc_V = Discriminator(in_channels=3).to(config.DEVICE)
        gen_V = Generator(in_channels=3).to(config.DEVICE)

        disc_R = Discriminator(in_channels=3).to(config.DEVICE)
        gen_R = Generator(in_channels=3).to(config.DEVICE)

        # Both discriminators share one optimizer, both generators another.
        opt_disc = torch.optim.Adam(
            list(disc_V.parameters()) + list(disc_R.parameters()),
            lr=config.LEARNING_RATE,
            betas=(config.BETA_1, config.BETA_2)
        )
        opt_gen = torch.optim.Adam(
            list(gen_V.parameters()) + list(gen_R.parameters()),
            lr=config.LEARNING_RATE,
            betas=(config.BETA_1, config.BETA_2)
        )

        l1_loss = nn.L1Loss()
        bce_loss = nn.BCEWithLogitsLoss()

        # (model, checkpoint path) pairs shared by loading and saving.
        checkpoints = (
            (disc_V, config.CHECKPOINT_DIS_VANGGOGH),
            (disc_R, config.CHECKPOINT_DIS_PHOTO),
            (gen_V, config.CHECKPOINT_GEN_VANGGOGH),
            (gen_R, config.CHECKPOINT_GEN_PHOTO),
        )
        if config.LOAD_MODEL:
            for model, path in checkpoints:
                utils.load(model, path)

        try:
            trainset = PhotoToVanggogfDataset(images_v, images_r, transform=True)
            trainloader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.BATCH_SIZE,
                shuffle=True
            )
        except ValueError as e:
            print(e)
            return

        # One scaler pair for the whole run so the AMP loss scale persists across epochs.
        use_amp = self._amp_enabled(config)
        d_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        g_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        generator_loss = []
        discriminator_loss = []
        print("__Training started__")

        for epoch in range(config.NUM_EPOCHS):
            print(f"Epoch{epoch + 1}/{config.NUM_EPOCHS}")

            # train_epoch returns (discriminator, generator): unpack in that order.
            disc_loss, gen_loss = self.train_epoch(
                config=config,
                disc_V=disc_V,
                gen_V=gen_V,
                disc_R=disc_R,
                gen_R=gen_R,
                trainloader=trainloader,
                opt_disc=opt_disc,
                opt_gen=opt_gen,
                l1_loss=l1_loss,
                bce_loss=bce_loss,
                d_scaler=d_scaler,
                g_scaler=g_scaler,
            )

            generator_loss.append(gen_loss)
            discriminator_loss.append(disc_loss)

            # Weights are written only when SAVE_MODEL is enabled.
            if config.SAVE_MODEL:
                for model, path in checkpoints:
                    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
                    utils.save(model, epoch, path)

        print("__Training Complete__")
        self._plot_losses(generator_loss, discriminator_loss)
        

In [None]:
# Entry point: build the models and run the training loop configured by Config
# (epoch count, checkpoint loading/saving, loss plot).
trainer = Trainer()
trainer.train()