# I'm Something of a Painter Myself

*by Len Fu 2025/3/20*

---

A practice project with generative adversarial networks (GANs): a CycleGAN that translates photos into Monet-style paintings.

# Import the Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shutil
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler
from tqdm import tqdm
from PIL import Image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Load and Preprocess the Dataset

In [None]:
# Location of the competition data on Kaggle
ROOT_PATH = '/kaggle/input/gan-getting-started/'

# Preprocessing pipeline applied to every image:
#   1. resize to 256 x 256,
#   2. convert the PIL image to a (C, H, W) float tensor in [0, 1],
#   3. normalise each channel with (x - 0.5) / 0.5, i.e. into [-1, 1],
#      which matches the range of the generator's tanh output.
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# Define the CustomDataset Class
class CustomDataset(Dataset):
    """Dataset of the ``.jpg`` files found directly inside selected sub-folders of ``root_dir``.

    Args:
        root_dir: Directory whose immediate sub-folders are scanned.
        target_folders: Names of the sub-folders to include (others are ignored).
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, target_folders, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_folders = target_folders
        self.image_files = self._get_image_files()

    def _get_image_files(self):
        """Return the ``.jpg`` paths inside the target sub-folders, in sorted order.

        os.listdir returns entries in an arbitrary, filesystem-dependent order;
        sorting keeps the index -> file mapping reproducible across runs/machines.
        """
        image_files = []
        for folder_name in sorted(os.listdir(self.root_dir)):
            folder_path = os.path.join(self.root_dir, folder_name)
            if folder_name not in self.target_folders or not os.path.isdir(folder_path):
                continue
            image_files.extend(
                os.path.join(folder_path, file_name)
                for file_name in sorted(os.listdir(folder_path))
                if file_name.endswith('.jpg')
            )
        return image_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.image_files[idx]
        # The context manager closes the file opened by Image.open; convert()
        # returns a new, fully loaded image that stays valid afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

class RepeatSampler(Sampler):
    """Yield exactly ``target_length`` indices into ``data_source`` by cycling through it.

    Each pass over the data source is an independent random permutation drawn
    with ``np.random``, so every index appears once per full pass. Passes are
    concatenated and truncated to ``target_length``.

    Args:
        data_source: Sized dataset to sample indices from.
        target_length: Number of indices produced per iteration (and ``len(self)``).

    Raises:
        ValueError: on iteration, if ``data_source`` is empty but indices are requested.
    """

    def __init__(self, data_source, target_length):
        self.data_source = data_source
        self.target_length = target_length

    def __iter__(self):
        n_items = len(self.data_source)
        if self.target_length <= 0:
            return iter([])
        if n_items == 0:
            raise ValueError("RepeatSampler cannot draw indices from an empty data_source")

        # Ceiling division: just enough passes to cover target_length.
        n_repeats = -(-self.target_length // n_items)

        indices = []
        for _ in range(n_repeats):
            indices.extend(np.random.permutation(n_items).tolist())

        return iter(indices[:self.target_length])

    def __len__(self):
        return self.target_length

# Load in the dataset: one dataset per domain, each restricted to one sub-folder.
monet_dataset = CustomDataset(root_dir=ROOT_PATH, target_folders=['monet_jpg'], transform=data_transforms)
photo_dataset = CustomDataset(root_dir=ROOT_PATH, target_folders=['photo_jpg'], transform=data_transforms)

# Create the batch loader
batch_size, shuffle, num_workers = 32, True, 4

# The Monet loader draws len(photo_dataset) samples per epoch (re-permuting the
# Monet images on every pass), so that zipping it with the photo loader pairs
# up the whole photo set.
monet_sampler = RepeatSampler(monet_dataset, len(photo_dataset))

monet_dataloader = DataLoader(
    monet_dataset,
    batch_size=batch_size,
    sampler=monet_sampler,
    num_workers=num_workers,
)
photo_dataloader = DataLoader(
    photo_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
)

In [None]:
# Verify whether the dataset is loaded successfully
def imshow(img):
    """Display a (C, H, W) image tensor that was normalised to [-1, 1]."""
    # Un-normalise back to [0, 1]; clamp so float round-off never pushes
    # values outside matplotlib's valid range for float images.
    img = (img / 2 + 0.5).clamp(0, 1)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Fetch one batch from each loader
monet_images = next(iter(monet_dataloader))
photo_images = next(iter(photo_dataloader))

# Show the first four images of each batch
imshow(torchvision.utils.make_grid(monet_images[:4]))
imshow(torchvision.utils.make_grid(photo_images[:4]))

# Reuse the batch already fetched rather than starting yet another
# (multi-worker) iterator just to inspect the shape.
test_imgs = photo_images

print("DataLoader Output Form:")
print("Data Shape:", test_imgs.shape)

# Define the Structure of the Generator and the Discriminator

In [None]:
class CycleGenerator(nn.Module):
    """
    ResNet-style image-to-image generator, used for both translation directions.

    Encoder: two stride-2 4x4 convolutions (halving H and W twice, 128 channels).
    Body: six residual blocks, each adding its output back onto its input.
    Decoder: two stride-2 transposed convolutions (doubling H and W twice) and a
    7x7 convolution followed by tanh, so outputs lie in [-1, 1].

    Args:
        conv_dim: Unused; kept for backward compatibility.
        input_channels: Number of channels of both the input and output images.
        init_zero_weight: Unused; kept for backward compatibility.
    """

    def __init__(self, conv_dim=256, input_channels=3, init_zero_weight=False):
        super(CycleGenerator, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Residual branches F(x); the skip connection x + F(x) is applied per
        # block in forward(). Kept as a Sequential of Sequentials so the
        # state_dict keys (res_blocks.<i>.<j>.*) are unchanged.
        self.res_blocks = nn.Sequential(*[
            nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(128, 128, 3, stride=1, padding=0),
                nn.InstanceNorm2d(128),
                nn.ReLU(),
                nn.ReflectionPad2d(1),
                nn.Conv2d(128, 128, 3, stride=1, padding=0),
                nn.InstanceNorm2d(128)
            ) for _ in range(6)
        ])

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, input_channels, 7, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        # Skip connection around every residual block (ResNet style), rather
        # than a single skip around the whole 12-convolution stack.
        for block in self.res_blocks:
            x = x + block(x)
        return self.decoder(x)


class Discriminator(nn.Module):
    """
    PatchGAN-style discriminator: a stack of 4x4 convolutions with LeakyReLU(0.2)
    that outputs a map of per-patch logits (one channel) instead of one scalar.

    Args:
        conv_dim: Unused; kept for interface compatibility.
        input_channels: Number of channels of the input image.
    """

    def __init__(self, conv_dim=256, input_channels=3):
        super().__init__()

        # (in_channels, out_channels, stride) of each 4x4, padding-1 convolution.
        conv_specs = [
            (input_channels, 64, 2),
            (64, 128, 2),
            (128, 256, 2),
            (256, 512, 1),
        ]
        layers = []
        for c_in, c_out, step in conv_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=step, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*layers)

        # Final 1-channel projection producing the raw logits map.
        self.classifier = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        return self.classifier(self.features(x))


In [None]:
# Initialize the generator Monet and Photo
input_channels = 3
generator_M = CycleGenerator()  # Photo -> Monet
generator_P = CycleGenerator()  # Monet -> Photo

# Initialize the discriminator Monet and Photo
discriminator_M = Discriminator()  # real vs. generated Monet paintings
discriminator_P = Discriminator()  # real vs. generated photos

# Adversarial loss: the discriminators output raw logits.
criterion_GAN = nn.BCEWithLogitsLoss()
# Cycle-consistency and identity losses compare images in [-1, 1] pixel space,
# so they need a pixel-wise distance (L1, as in CycleGAN). BCEWithLogitsLoss
# expects probability targets in [0, 1] and is not a meaningful image distance.
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()


optimizer_MG  = torch.optim.Adam(generator_M.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_PG  = torch.optim.Adam(generator_P.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_MD  = torch.optim.Adam(discriminator_M.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_PD  = torch.optim.Adam(discriminator_P.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Wrapping in DataParallel keeps the same parameter objects, so the optimizers
# created above still update the wrapped models.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    generator_M = nn.DataParallel(generator_M)
    generator_P = nn.DataParallel(generator_P)
    discriminator_M = nn.DataParallel(discriminator_M)
    discriminator_P = nn.DataParallel(discriminator_P)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator_M.to(device), discriminator_M.to(device), generator_P.to(device), discriminator_P.to(device)

In [None]:
epoch_nums = 66
lambda_cycle = 10      # weight of the cycle-consistency loss
lambda_identity = 0.5  # weight of the identity loss


def denorm(tensor):
    """Map a tensor normalised to [-1, 1] back to [0, 1] for display."""
    return tensor * 0.5 + 0.5


def _discriminator_step(discriminator, optimizer, real_imgs, fake_imgs, real_labels, fake_labels):
    """Run one optimisation step of `discriminator` on a real and a fake batch.

    `fake_imgs` must not require gradients (they are produced under no_grad).
    Returns the summed real + fake loss, detached.
    """
    optimizer.zero_grad()
    d_loss_real = criterion_GAN(discriminator(real_imgs), real_labels)
    d_loss_real.backward()
    d_loss_fake = criterion_GAN(discriminator(fake_imgs), fake_labels)
    d_loss_fake.backward()
    optimizer.step()
    return (d_loss_real + d_loss_fake).detach()


def _generator_step(generator, inverse_generator, discriminator, optimizer,
                    source_imgs, target_imgs, real_labels):
    """Run one optimisation step of `generator` (source domain -> target domain).

    Loss = adversarial loss (make `discriminator` call the fakes real)
         + lambda_cycle    * L1(inverse_generator(generator(source)), source)
         + lambda_identity * L1(generator(target), target).
    Only `optimizer`'s parameters are stepped. Gradients that also land in the
    other networks are cleared by their own optimizer's zero_grad() before
    those networks are next updated. Returns the total loss, detached.
    """
    optimizer.zero_grad()

    fake_imgs = generator(source_imgs)
    g_loss_gan = criterion_GAN(discriminator(fake_imgs), real_labels)

    reconstructed_imgs = inverse_generator(fake_imgs)
    g_loss_cycle = criterion_cycle(reconstructed_imgs, source_imgs)

    identity_imgs = generator(target_imgs)
    g_loss_identity = criterion_identity(identity_imgs, target_imgs)

    g_loss = g_loss_gan + g_loss_cycle * lambda_cycle + g_loss_identity * lambda_identity
    g_loss.backward()
    optimizer.step()
    return g_loss.detach()


for epoch in range(epoch_nums):
    # The image-generation cell calls generator_M.eval(); put every network
    # back into training mode in case this cell is re-run afterwards.
    for net in (generator_M, generator_P, discriminator_M, discriminator_P):
        net.train()

    for monet_imgs, photo_imgs in tqdm(
        zip(monet_dataloader, photo_dataloader),
        total=min(len(monet_dataloader), len(photo_dataloader)),
        desc=f"Epoch:{epoch+1}/{epoch_nums}",
    ):
        # Defensive: trim both batches to a common size.
        n = min(monet_imgs.size(0), photo_imgs.size(0))
        monet_imgs = monet_imgs[:n].to(device)
        photo_imgs = photo_imgs[:n].to(device)

        # Target maps shaped like the PatchGAN output. Both discriminators share
        # one architecture, so one forward pass without autograd gives the shape.
        with torch.no_grad():
            patch_shape = discriminator_P(monet_imgs).shape
        real_labels = torch.full(patch_shape, 0.9, device=device)  # one-sided label smoothing
        fake_labels = torch.zeros(patch_shape, device=device)

        # Fakes for the discriminator updates; no generator graph is needed.
        with torch.no_grad():
            fake_photos = generator_P(monet_imgs)
            fake_monets = generator_M(photo_imgs)

        # Train the photo discriminator, then the Monet discriminator
        d_loss_P = _discriminator_step(discriminator_P, optimizer_PD, photo_imgs, fake_photos,
                                       real_labels, fake_labels)
        d_loss_M = _discriminator_step(discriminator_M, optimizer_MD, monet_imgs, fake_monets,
                                       real_labels, fake_labels)

        # Train the photo generator (Monet -> Photo), then the Monet generator (Photo -> Monet)
        g_loss_P = _generator_step(generator_P, generator_M, discriminator_P, optimizer_PG,
                                   monet_imgs, photo_imgs, real_labels)
        g_loss_M = _generator_step(generator_M, generator_P, discriminator_M, optimizer_MG,
                                   photo_imgs, monet_imgs, real_labels)

    if (epoch + 1) % 20 == 0:
        n_show = min(5, monet_imgs.size(0))  # the last batch may hold fewer than 5 images
        with torch.no_grad():
            monet_to_photo = generator_P(monet_imgs[:n_show])
            photo_to_monet = generator_M(photo_imgs[:n_show])

        panels = (
            ("Monet Original", monet_imgs),
            ("Generated Photo", monet_to_photo),
            ("Photo Original", photo_imgs),
            ("Generated Monet", photo_to_monet),
        )

        plt.figure(figsize=(12, 8))
        for j in range(n_show):
            for col, (title, imgs) in enumerate(panels):
                plt.subplot(n_show, 4, j * 4 + col + 1)
                plt.imshow(denorm(imgs[j].cpu().permute(1, 2, 0)).clamp(0, 1))
                plt.title(title if j == 0 else "")
                plt.axis('off')

        plt.tight_layout()
        plt.show()

# Generate the Figures

In [None]:
os.makedirs('../models', exist_ok=True)
# Save the weights of both generators and both discriminators.
# nn.DataParallel (used above when several GPUs are visible) prefixes every
# state_dict key with "module."; save the wrapped network's own state_dict so
# the checkpoints load directly into plain CycleGenerator / Discriminator models.
for _name, _net in (
    ('generator_M', generator_M),
    ('generator_P', generator_P),
    ('discriminator_M', discriminator_M),
    ('discriminator_P', discriminator_P),
):
    if isinstance(_net, nn.DataParallel):
        _net = _net.module
    torch.save(_net.state_dict(), f'../models/{_name}.pth')

In [None]:
os.makedirs('../images', exist_ok=True)

generator_M.eval()  # inference with the Photo -> Monet generator

i = 0  # running file index across all batches
with torch.no_grad():
    for imgs in tqdm(photo_dataloader, desc='Generating Images'):
        batch_out = generator_M(imgs.to(device))
        # [-1,1] -> [0,1], channels-last (N, H, W, C), clipped for saving
        batch_np = np.clip(
            (batch_out.cpu() * 0.5 + 0.5).permute(0, 2, 3, 1).numpy(),
            0.0,
            1.0,
        )
        for picture in batch_np:
            plt.imsave(f"../images/{i}.jpg", picture)
            i += 1

# Create submission files

In [None]:
# Zip the created figures into /kaggle/working/images.zip. Archive the same
# '../images' directory the generation cell wrote to, instead of a separately
# hard-coded absolute path that only matches when cwd is /kaggle/working.
shutil.make_archive("/kaggle/working/images", 'zip', '../images')

In [6]:
import torch

# Load the saved Photo -> Monet generator weights onto the CPU and list the
# parameter names stored in the checkpoint.
# NOTE(review): this reads ./models/, whereas the save cell writes to
# ../models/ — confirm which location is intended.
model_file = torch.load(
    './models/generator_M.pth',
    map_location='cpu',
)

print(model_file.keys())

odict_keys(['encoder.0.weight', 'encoder.0.bias', 'res_blocks.0.1.weight', 'res_blocks.0.1.bias', 'res_blocks.0.5.weight', 'res_blocks.0.5.bias', 'res_blocks.1.1.weight', 'res_blocks.1.1.bias', 'res_blocks.1.5.weight', 'res_blocks.1.5.bias', 'res_blocks.2.1.weight', 'res_blocks.2.1.bias', 'res_blocks.2.5.weight', 'res_blocks.2.5.bias', 'res_blocks.3.1.weight', 'res_blocks.3.1.bias', 'res_blocks.3.5.weight', 'res_blocks.3.5.bias', 'res_blocks.4.1.weight', 'res_blocks.4.1.bias', 'res_blocks.4.5.weight', 'res_blocks.4.5.bias', 'res_blocks.5.1.weight', 'res_blocks.5.1.bias', 'res_blocks.5.5.weight', 'res_blocks.5.5.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.3.weight', 'decoder.3.bias'])
