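Data and project source: Kaggle competition "I'm Something of a Painter Myself" — https://www.kaggle.com/competitions/gan-getting-started/overview. This notebook trains a CycleGAN that translates photos into Monet-style paintings.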

In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import os
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm import tqdm

In [None]:
# Seed PyTorch's RNGs (CPU and every CUDA device) so weight initialisation and
# DataLoader shuffling repeat across Restart & Run All. cuDNN kernels may still be
# nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Folders holding the competition JPEGs, relative to the notebook's working directory.
photo_path, monet_path = 'picashu/data/photo_jpg', 'picashu/data/monet_jpg'

In [None]:
# PIL image -> float CxHxW tensor in [0, 1] (ToTensor), then (t - 0.5) / 0.5 per
# channel maps every pixel into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# Dataset class:

class Images(Dataset):
    """Unpaired photo / Monet dataset for CycleGAN training.

    Each item is a ``(photo, monet)`` pair, each image passed through ``transform``.
    The two folders may hold different numbers of files: ``len()`` is the larger
    count and indices wrap around (modulo) the shorter list.

    Args:
        photo_path: directory containing the photo images.
        monet_path: directory containing the Monet paintings.
        transform: callable applied to each RGB PIL image.

    Raises:
        ValueError: if either directory is empty.
    """

    def __init__(self, photo_path, monet_path, transform):
        self.photo_path = photo_path
        self.monet_path = monet_path
        self.transform = transform
        # os.listdir order is arbitrary; sorting makes the index -> file mapping deterministic.
        self.photos = sorted(os.listdir(photo_path))
        self.monets = sorted(os.listdir(monet_path))
        self.l_photo = len(self.photos)
        self.l_monet = len(self.monets)
        if self.l_photo == 0 or self.l_monet == 0:
            # Fail here instead of with a ZeroDivisionError from `idx % 0` in __getitem__.
            raise ValueError(
                f"Empty image directory: {photo_path!r} has {self.l_photo} files, "
                f"{monet_path!r} has {self.l_monet} files"
            )

    def __len__(self):
        return max(self.l_photo, self.l_monet)

    @staticmethod
    def _file_path(directory, names, idx):
        """Return the path of ``names[idx % len(names)]`` inside ``directory``."""
        # os.path.join inserts the separator that plain `+` concatenation omitted
        # ('picashu/data/photo_jpg' + 'x.jpg' -> 'picashu/data/photo_jpgx.jpg').
        return os.path.join(directory, names[idx % len(names)])

    def _load(self, path):
        """Open ``path`` as RGB, apply ``self.transform``, and close the file handle."""
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        photo = self._load(self._file_path(self.photo_path, self.photos, idx))
        monet = self._load(self._file_path(self.monet_path, self.monets, idx))
        return photo, monet

In [None]:
# Define dataset: every item is a (photo, monet) pair of transformed images.
dataset = Images(
    photo_path=photo_path,
    monet_path=monet_path,
    transform=transform,
)

In [None]:
# Define torch dataloader: batches of 8 (photo, monet) pairs, reshuffled every epoch.
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

In [None]:
# Check examples of pics: first photo and first Monet painting of one shuffled batch.
photo_batch, monet_batch = next(iter(dataloader))

fig, (ax_photo, ax_monet) = plt.subplots(1, 2)
for axis, title, img in (
    (ax_photo, 'Photo example', photo_batch[0]),
    (ax_monet, 'Monet example', monet_batch[0]),
):
    axis.set_title(title)
    # Undo Normalize(mean=0.5, std=0.5) ([-1, 1] -> [0, 1]) and move channels last for imshow.
    axis.imshow(img.permute(1, 2, 0) * 0.5 + 0.5)
    axis.axis('off')
# plt.show() renders the figure without echoing the last AxesImage repr.
plt.show()

Discriminator model

In [None]:
class conv_block(nn.Module):
    """Discriminator building block: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        stride: convolution stride; with kernel 4 and padding 1, stride 2 halves the
            spatial size (rounding down) and stride 1 shrinks it by one pixel.
    """
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
                      padding=1, bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Kept as `self.conv` so state_dict keys stay `conv.0.weight` / `conv.0.bias`.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: a fully convolutional net that outputs a map of
    per-patch real/fake scores (after a final Sigmoid) instead of one scalar per image.

    Args:
        in_channels: number of channels of the input image (3 for RGB).
    """
    def __init__(self, in_channels=3):
        super().__init__()
        # First layer has no normalization, only conv + LeakyReLU.
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=64,
                kernel_size=4, 
                stride=2,
                padding=1,
                padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )
        self.process = nn.Sequential(
            conv_block(64, 128, 2),
            conv_block(128, 256, 2),
            conv_block(256, 512, 1),
            nn.Conv2d(
                in_channels=512,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode='reflect'),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        """Return per-patch scores in (0, 1).

        Conv2d output size (dilation 1):
            OUT = floor((IN + 2 * padding - kernel_size) / stride) + 1

        For a 256x256 input:
            [batch_size, 3, 256, 256]   input
            [batch_size, 64, 128, 128]  initial conv       (k=4, s=2, p=1)
            [batch_size, 128, 64, 64]   conv_block          (s=2)
            [batch_size, 256, 32, 32]   conv_block          (s=2)
            [batch_size, 512, 31, 31]   conv_block          (s=1)
            [batch_size, 1, 30, 30]     final conv (k=4, s=1, p=1) + Sigmoid
        """
        x = self.initial(x)
        x = self.process(x)
        return x

In [None]:
# Sanity check: one random 256x256 RGB image is scored on a 30x30 grid of patches.
x = torch.randn(1, 3, 256, 256)
dis = Discriminator()
assert dis(x).shape == torch.Size([1, 1, 30, 30])

Generator model

In [None]:
class gen_conv_block(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        TYPE: 'down' builds an ``nn.Conv2d`` with reflect padding; 'up' builds an
            ``nn.ConvTranspose2d`` (transposed convs only accept zero padding).
        activation: append ``ReLU(inplace=True)`` when True, ``nn.Identity`` otherwise.
            The default is False, so callers must opt in to the non-linearity.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding, ...).

    Raises:
        ValueError: if TYPE is neither 'down' nor 'up' (previously any other value
            silently produced a transposed convolution).
    """
    def __init__(self, in_channels, out_channels, TYPE='down', activation=False, **kwargs):
        super().__init__()
        if TYPE == 'down':
            conv = nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             padding_mode="reflect",
                             **kwargs)
        elif TYPE == 'up':
            conv = nn.ConvTranspose2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      **kwargs)
        else:
            raise ValueError(f"TYPE must be 'down' or 'up', got {TYPE!r}")
        # Kept as a 3-element Sequential so state_dict keys stay `conv.0.*`.
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if activation else nn.Identity()
        )
        
    def forward(self, x):
        return self.conv(x)

In [None]:
class res_block(nn.Module):
    """Residual block: x + (conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm)(x).

    Channel count and spatial size are preserved (3x3 kernels, padding 1, stride 1),
    so the skip connection can be a plain addition.

    Args:
        channels: number of input and output feature maps.
    """
    def __init__(self, channels):
        super().__init__()
        self.block_ = nn.Sequential(
            # gen_conv_block defaults to activation=False, so the ReLU between the two
            # convolutions must be requested explicitly; without it the block has no ReLU.
            gen_conv_block(channels, channels, activation=True, kernel_size=3, padding=1),
            # No activation after the second conv: its output is added to x as-is.
            gen_conv_block(channels, channels, activation=False, kernel_size=3, padding=1)
        )
    
    def forward(self, x):
        return x + self.block_(x)

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator mapping an image to an image of the same size.

    Pipeline: 7x7 conv (64 ch) + ReLU -> two stride-2 down-sampling blocks (128, 256 ch)
    -> ``num_residuals_blocks`` residual blocks at 256 ch -> two stride-2 transposed-conv
    up-sampling blocks (128, 64 ch) -> 7x7 conv back to ``in_channels`` -> tanh.

    Args:
        in_channels: channels of both the input and the output image (3 for RGB).
        num_residuals_blocks: number of ``res_block``s at the bottleneck.

    forward returns a tensor with the input's shape (for spatial sizes divisible by 4)
    and values in [-1, 1], the same range as the training images after
    Normalize(mean=0.5, std=0.5).
    """
    def __init__(self, in_channels=3, num_residuals_blocks=9):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True)
        )
        # gen_conv_block defaults to activation=False; the sampling stages need ReLU,
        # otherwise the only non-linearity outside the residual blocks is `initial`.
        self.down = nn.Sequential(
            gen_conv_block(64, 64*2, TYPE='down', activation=True, kernel_size=3, stride=2, padding=1),
            gen_conv_block(64*2, 64*4, TYPE='down', activation=True, kernel_size=3, stride=2, padding=1)
        )
        self.residual = nn.Sequential(
            *[res_block(64*4) for _ in range(num_residuals_blocks)]
        )
        self.up = nn.Sequential(
            gen_conv_block(64*4, 64*2, TYPE='up', activation=True, kernel_size=3, stride=2, padding=1, output_padding=1),
            gen_conv_block(64*2, 64, TYPE='up', activation=True, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        self.get_img = nn.Conv2d(64, in_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        
    def forward(self, x):
        x = self.initial(x)
        x = self.down(x)
        x = self.residual(x)
        x = self.up(x)
        # tanh bounds the output to [-1, 1], matching the normalized training images.
        return torch.tanh(self.get_img(x))

In [None]:
# Sanity check: the generator maps a 3x256x256 image to a 3x256x256 image.
# Named `gen` (not `dis`) so it does not clobber the discriminator check's variable.
x = torch.randn((1, 3, 256, 256))
gen = Generator()
with torch.no_grad():  # shape check only; no autograd graph needed
    assert gen(x).shape == (1, 3, 256, 256)

In [None]:
# Adam learning rate, and the weight of the cycle-consistency (L1) term in the generator loss.
lr, lambda_cycle = 2e-4, 10

In [None]:
# gen_photo maps Monet -> photo and gen_monet maps photo -> Monet (as used in the
# training loop); each domain has its own discriminator. All four live on `device`.
disc_photo, disc_monet = Discriminator().to(device), Discriminator().to(device)
gen_photo, gen_monet = Generator().to(device), Generator().to(device)

In [None]:
# Both discriminators share one Adam optimizer; both generators share the other.
# beta1 = 0.5 (instead of Adam's default 0.9) is a common choice for GAN training.
disc_optimizer = optim.Adam(
    [*disc_photo.parameters(), *disc_monet.parameters()], lr=lr, betas=(0.5, 0.999)
)
gen_optimizer = optim.Adam(
    [*gen_photo.parameters(), *gen_monet.parameters()], lr=lr, betas=(0.5, 0.999)
)

In [None]:
dis_scaler = torch.amp.GradScaler('cuda')
gen_scaler = torch.amp.GradScaler('cuda')

In [None]:
# Least-squares adversarial criterion (MSE against 1/0 targets) and the L1 criterion
# for cycle consistency; both average over all elements.
MSE, L1 = nn.MSELoss(), nn.L1Loss()

Training the models

In [None]:
# Number of passes over the (photo, monet) dataset.
epochs = 1

for epoch in range(epochs):
    # Per-epoch mean losses, accumulated as plain Python floats.
    running_dis_loss = 0.0
    running_gen_loss = 0.0
    for photo, monet in tqdm(dataloader, leave=True):
        photo = photo.to(device)
        monet = monet.to(device)
        
        # ---- Train discriminators: real -> 1, generated -> 0 ----
        # autocast supplies the mixed precision the GradScalers exist for; enabled on CUDA only.
        with torch.autocast(device_type=device, enabled=(device == "cuda")):
            fake_photo = gen_photo(monet)
            Dis_photo_real = disc_photo(photo)
            # detach(): the discriminator step must not send gradients into the generators.
            Dis_photo_fake = disc_photo(fake_photo.detach())
            
            Dis_photo_loss = MSE(Dis_photo_real, torch.ones_like(Dis_photo_real)) + \
                             MSE(Dis_photo_fake, torch.zeros_like(Dis_photo_fake))
            
            fake_monet = gen_monet(photo)
            Dis_monet_real = disc_monet(monet)
            Dis_monet_fake = disc_monet(fake_monet.detach())
            
            Dis_monet_loss = MSE(Dis_monet_real, torch.ones_like(Dis_monet_real)) + \
                             MSE(Dis_monet_fake, torch.zeros_like(Dis_monet_fake))
            
            Dis_loss = (Dis_photo_loss + Dis_monet_loss) / 2.0
        # .item() gives a float; adding the tensor itself would chain every iteration's
        # autograd graph onto the running total and keep that history alive.
        running_dis_loss += Dis_loss.item() / len(dataloader)
        
        disc_optimizer.zero_grad()
        dis_scaler.scale(Dis_loss).backward()
        dis_scaler.step(disc_optimizer)
        dis_scaler.update()
        
        # ---- Train generators: fool the updated discriminators + cycle consistency ----
        with torch.autocast(device_type=device, enabled=(device == "cuda")):
            Dis_photo_fake = disc_photo(fake_photo)
            Dis_monet_fake = disc_monet(fake_monet)
            
            Gen_photo_loss = MSE(Dis_photo_fake, torch.ones_like(Dis_photo_fake))
            Gen_monet_loss = MSE(Dis_monet_fake, torch.ones_like(Dis_monet_fake))
            
            # monet -> photo -> monet and photo -> monet -> photo should reconstruct the inputs.
            Cycled_monet = gen_monet(fake_photo) 
            Cycled_photo = gen_photo(fake_monet)
            
            Cycled_loss = L1(monet, Cycled_monet) + L1(photo, Cycled_photo)
            
            Gen_loss = Gen_photo_loss + Gen_monet_loss + Cycled_loss * lambda_cycle
        running_gen_loss += Gen_loss.item() / len(dataloader)
        
        gen_optimizer.zero_grad()
        gen_scaler.scale(Gen_loss).backward()
        gen_scaler.step(gen_optimizer)
        gen_scaler.update()
    print(f"Epoch {epoch + 1}. Generator loss by epoch: {running_gen_loss:.4f}, discriminator loss by epoch: {running_dis_loss:.4f}")

In [None]:
# Save all four networks. Use Kaggle's writable output directory when it exists,
# otherwise a local `checkpoints/` folder; torch.save does not create missing directories.
checkpoint_dir = '/kaggle/working' if os.path.isdir('/kaggle/working') else 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
for name, model in {
    'disc_photo': disc_photo,
    'disc_monet': disc_monet,
    'gen_photo': gen_photo,
    'gen_monet': gen_monet,
}.items():
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'{name}.pth'))

In [None]:
# Translate up to five photos from one shuffled batch into Monet style and show
# each next to its original.
batch = next(iter(dataloader))[0]
n_show = min(5, len(batch))

with torch.no_grad():
    # One batched forward pass; the generator's InstanceNorm layers normalise each
    # image on its own, so batching does not change the per-image result.
    predicted = gen_monet(batch[:n_show].to(device)).cpu()

# squeeze=False keeps `ax` two-dimensional even when n_show == 1.
_, ax = plt.subplots(n_show, 2, figsize=(12, 12), squeeze=False)

for i in range(n_show):
    # Undo Normalize(mean=0.5, std=0.5) and clamp so imshow receives floats in [0, 1].
    ax[i, 0].imshow((batch[i].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
    ax[i, 1].imshow((predicted[i].permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
    
    ax[i, 0].set_title("Original photo")
    ax[i, 1].set_title("Monet like")
    
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()

This project is inspired by the original CycleGAN implementation: https://github.com/junyanz/CycleGAN