In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


## 1. Data Preparation

In [2]:
class MonetPhotoDataset(Dataset):
    """Index-aligned pairs of Monet paintings and photos read from two folders.

    Only files whose name ends with ``.jpg`` are used. Item ``idx`` is the
    ``idx``-th Monet file together with the ``idx``-th photo file, so the
    dataset length is the size of the smaller folder and any surplus files in
    the larger folder are never returned.

    Args:
        monet_dir: Directory containing the Monet ``.jpg`` files.
        photo_dir: Directory containing the photo ``.jpg`` files.
        transform: Optional callable applied to each PIL image (the same
            callable is used for both domains).
    """

    def __init__(self, monet_dir, photo_dir, transform=None):
        self.monet_images = self._list_jpgs(monet_dir)
        self.photo_images = self._list_jpgs(photo_dir)
        self.transform = transform

    @staticmethod
    def _list_jpgs(directory):
        """Return the ``.jpg`` paths inside ``directory`` in sorted order."""
        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order; sorting makes indices (and therefore pairs) reproducible.
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.jpg')
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy of the image."""
        # Image.open() is lazy and keeps the file handle open. convert()
        # forces the pixel data to be read and returns a new image, so the
        # handle can be released as soon as the `with` block exits.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return min(len(self.monet_images), len(self.photo_images))

    def __getitem__(self, idx):
        """Return ``(monet_img, photo_img)`` for position ``idx``."""
        monet_img = self._load_rgb(self.monet_images[idx])
        photo_img = self._load_rgb(self.photo_images[idx])

        if self.transform:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img

# Preprocessing shared by both domains: resize to a fixed square, convert to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
image_size = (256, 256)
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)

preprocessing_steps = [
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Locations of the two image folders, relative to the notebook.
monet_dir = './data/monet_jpg/'
photo_dir = './data/photo_jpg/'

# One (monet, photo) pair per batch, visited in a new random order each epoch.
dataset = MonetPhotoDataset(monet_dir=monet_dir, photo_dir=photo_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)


## 2. Model Architecture

#### Comparison
##### 1. Standard GAN
In a standard GAN, there are two main components:
- Generator (G): This model generates new data instances (e.g., images) that resemble the training data. Its goal is to create data that the discriminator cannot distinguish from real data.
- Discriminator (D): This model evaluates the data produced by the generator against real data. Its goal is to correctly classify data as either "real" (from the training set) or "fake" (generated by G).

##### 2. CycleGAN
- Two Generators:
  - Generator G: Transforms images from domain X (e.g., photos) to domain Y (e.g., Monet-style paintings).
  - Generator F: Transforms images from domain Y (Monet-style paintings) back to domain X (photos).
- Two Discriminators:
  - Discriminator D_X: Tries to distinguish between real images from domain X and fake images generated by F.
  - Discriminator D_Y: Tries to distinguish between real images from domain Y and fake images generated by G.
- Cycle Consistency Loss:
  - To ensure that the translation preserves the content of the input, CycleGAN adds a cycle-consistency constraint: translating an image to the other domain and back again should reproduce the original image, i.e. F(G(x)) ≈ x and G(F(y)) ≈ y.
- Identity Loss (optional but often used): This loss ensures that when a generator is given an image that already belongs to its output domain, it returns that image unchanged, i.e. G(y) ≈ y and F(x) ≈ x.


#### Summary
The code in this notebook implements a CycleGAN, which is a specific type of GAN designed for unpaired image-to-image translation. Unlike a standard GAN, which typically generates new images from random noise, a CycleGAN transforms images from one domain to another (e.g., turning photos into Monet-style paintings).

In [4]:
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual unit: conv -> InstanceNorm -> ReLU -> conv -> InstanceNorm, plus skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the output has the same shape as the input.

    Args:
        in_features: Number of channels of the input (and output) tensor.
    """

    def __init__(self, in_features):
        super().__init__()
        channels = in_features

        def conv3x3():
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        layers = [
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: stem, two downsampling steps, nine residual
    blocks, two upsampling steps and a Tanh output layer.

    Takes a 3-channel image tensor and returns a 3-channel tensor; the final
    Tanh bounds the output to (-1, 1).
    """

    def __init__(self):
        super().__init__()

        def conv_stage(c_in, c_out, kernel, stride, pad):
            # Convolution followed by instance normalisation and ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel, stride=stride, padding=pad),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        def upconv_stage(c_in, c_out):
            # Stride-2 transposed convolution followed by instance norm and ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Stem: 7x7 convolution at full resolution.
        self.conv1 = nn.Sequential(*conv_stage(3, 64, 7, 1, 3))

        # Encoder: two stride-2 convolutions.
        self.conv2 = nn.Sequential(
            *conv_stage(64, 128, 3, 2, 1),
            *conv_stage(128, 256, 3, 2, 1),
        )

        # Bottleneck: a chain of nine residual blocks on 256 channels.
        bottleneck = []
        for _ in range(9):
            bottleneck.append(ResidualBlock(256))
        self.res_blocks = nn.Sequential(*bottleneck)

        # Decoder: two stride-2 transposed convolutions.
        self.deconv1 = nn.Sequential(
            *upconv_stage(256, 128),
            *upconv_stage(128, 64),
        )

        # Head: 7x7 convolution back to 3 channels, squashed by Tanh.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 3, 7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the result."""
        stages = (self.conv1, self.conv2, self.res_blocks, self.deconv1, self.conv3)
        for stage in stages:
            x = stage(x)
        return x

class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel score map.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels),
    each followed by LeakyReLU(0.2), with InstanceNorm on all but the first,
    then a final 4x4 convolution down to a single channel. No activation is
    applied to the output.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, use_instance_norm) for each stride-2 stage.
        stage_specs = [
            (3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]

        layers = []
        for c_in, c_out, use_norm in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection to one score per spatial location.
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw score map for the image batch ``x``."""
        scores = self.model(x)
        return scores


In [8]:
# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
# `getattr` guards against PyTorch builds that have no `torch.backends.mps`.
mps_backend = getattr(torch.backends, "mps", None)
device = torch.device(
    "mps" if mps_backend is not None and mps_backend.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Domain X = photos, domain Y = Monet paintings.
#   generator_g:      photo -> Monet
#   generator_f:      Monet -> photo
#   discriminator_x:  real photo vs. fake photo produced by generator_f
#   discriminator_y:  real Monet vs. fake Monet produced by generator_g
generator_g = Generator().to(device)
generator_f = Generator().to(device)
discriminator_x = Discriminator().to(device)
discriminator_y = Discriminator().to(device)

# Optimizers
g_optimizer = torch.optim.Adam(generator_g.parameters(), lr=2e-4, betas=(0.5, 0.999))
f_optimizer = torch.optim.Adam(generator_f.parameters(), lr=2e-4, betas=(0.5, 0.999))
dx_optimizer = torch.optim.Adam(discriminator_x.parameters(), lr=2e-4, betas=(0.5, 0.999))
dy_optimizer = torch.optim.Adam(discriminator_y.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Loss functions
criterion_gan = nn.MSELoss()        # least-squares adversarial loss on the score map
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# Training
epochs = 100
lambda_cycle = 10.0
lambda_identity = 5.0

for epoch in range(epochs):
    for i, (monet, photo) in enumerate(dataloader):
        monet, photo = monet.to(device), photo.to(device)

        # ----------- Train Generators --------------
        # Both generators are updated from a single backward pass. Each cycle
        # term flows through both networks (e.g. F(G(photo)) depends on G and
        # F), so stepping only one generator per backward pass would throw
        # away the gradient of the generator that does the reconstruction.
        g_optimizer.zero_grad()
        f_optimizer.zero_grad()

        # photo -> fake Monet -> reconstructed photo, plus identity on Monet
        fake_monet = generator_g(photo)
        cycled_photo = generator_f(fake_monet)
        same_monet = generator_g(monet)

        # Monet -> fake photo -> reconstructed Monet, plus identity on photo
        fake_photo = generator_f(monet)
        cycled_monet = generator_g(fake_photo)
        same_photo = generator_f(photo)

        # Losses for generator g (one discriminator forward pass, reused for
        # both the prediction and the shape of the "real" target)
        score_fake_monet = discriminator_y(fake_monet)
        identity_loss_g = criterion_identity(same_monet, monet) * lambda_identity
        gan_loss_g = criterion_gan(score_fake_monet, torch.ones_like(score_fake_monet))
        cycle_loss_g = criterion_cycle(cycled_photo, photo) * lambda_cycle
        total_g_loss = identity_loss_g + gan_loss_g + cycle_loss_g

        # Losses for generator f
        score_fake_photo = discriminator_x(fake_photo)
        identity_loss_f = criterion_identity(same_photo, photo) * lambda_identity
        gan_loss_f = criterion_gan(score_fake_photo, torch.ones_like(score_fake_photo))
        cycle_loss_f = criterion_cycle(cycled_monet, monet) * lambda_cycle
        total_f_loss = identity_loss_f + gan_loss_f + cycle_loss_f

        (total_g_loss + total_f_loss).backward()
        g_optimizer.step()
        f_optimizer.step()

        # ----------- Train Discriminators --------------
        # The generator backward pass also left gradients on the discriminator
        # parameters; zero_grad() below discards them before each update.
        # Fakes are detached so no gradient reaches the generators here.

        # Discriminator x: real photos vs. fake photos
        dx_optimizer.zero_grad()

        pred_real_photo = discriminator_x(photo)
        pred_fake_photo = discriminator_x(fake_photo.detach())

        dx_loss_real = criterion_gan(pred_real_photo, torch.ones_like(pred_real_photo))
        dx_loss_fake = criterion_gan(pred_fake_photo, torch.zeros_like(pred_fake_photo))

        dx_loss = (dx_loss_real + dx_loss_fake) * 0.5
        dx_loss.backward()
        dx_optimizer.step()

        # Discriminator y: real Monet paintings vs. fake Monet paintings
        dy_optimizer.zero_grad()

        pred_real_monet = discriminator_y(monet)
        pred_fake_monet = discriminator_y(fake_monet.detach())

        dy_loss_real = criterion_gan(pred_real_monet, torch.ones_like(pred_real_monet))
        dy_loss_fake = criterion_gan(pred_fake_monet, torch.zeros_like(pred_fake_monet))

        dy_loss = (dy_loss_real + dy_loss_fake) * 0.5
        dy_loss.backward()
        dy_optimizer.step()

        if i % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}] Batch [{i}/{len(dataloader)}] '
                  f'G Loss: {total_g_loss.item():.4f}, F Loss: {total_f_loss.item():.4f}, '
                  f'Dx Loss: {dx_loss.item():.4f}, Dy Loss: {dy_loss.item():.4f}')


Epoch [0/100] Batch [0/300] G Loss: 11.8309, F Loss: 9.9742, Dx Loss: 0.5049, Dy Loss: 0.6386
Epoch [0/100] Batch [100/300] G Loss: 5.6284, F Loss: 4.8153, Dx Loss: 0.2371, Dy Loss: 0.2183
Epoch [0/100] Batch [200/300] G Loss: 6.3874, F Loss: 5.2291, Dx Loss: 0.1397, Dy Loss: 0.3567
Epoch [1/100] Batch [0/300] G Loss: 4.2739, F Loss: 3.6884, Dx Loss: 0.3393, Dy Loss: 0.2366
Epoch [1/100] Batch [100/300] G Loss: 5.3573, F Loss: 4.1936, Dx Loss: 0.1757, Dy Loss: 0.1648
Epoch [1/100] Batch [200/300] G Loss: 3.4330, F Loss: 4.4492, Dx Loss: 0.1828, Dy Loss: 0.1770
Epoch [2/100] Batch [0/300] G Loss: 5.8017, F Loss: 4.9839, Dx Loss: 0.1453, Dy Loss: 0.0721
Epoch [2/100] Batch [100/300] G Loss: 4.0471, F Loss: 3.5777, Dx Loss: 0.2761, Dy Loss: 0.2122


In [None]:
import shutil

import numpy as np
from torchvision.utils import save_image

# This cell relies on `generator_g`, `dataset`, `transform` and `device` from
# the cells above. On a fresh kernel it fails with
# "NameError: name 'generator_g' is not defined" until those cells are re-run.

output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

max_images = 10000  # Limiting to 10,000 images as per competition rules

# Translate every photo on disk rather than iterating the training dataloader:
# the dataloader yields (monet, photo) pairs and stops after
# min(#monet, #photo) items, so most photos would never be converted.
photo_paths = dataset.photo_images[:max_images]

generator_g.eval()
with torch.no_grad():  # inference only: do not build autograd graphs
    for i, photo_path in enumerate(photo_paths):
        with Image.open(photo_path) as img:
            photo = transform(img.convert("RGB")).unsqueeze(0).to(device)

        fake_monet = generator_g(photo)

        # The generator ends in Tanh, so its output lies in [-1, 1], while
        # save_image clamps to [0, 1]. Map back to [0, 1] before writing.
        save_image(fake_monet * 0.5 + 0.5, os.path.join(output_dir, f"monet_{i}.jpg"))

# Zipping the images
shutil.make_archive("images", 'zip', output_dir)


NameError: name 'generator_g' is not defined