In [None]:
import os
import subprocess
import zipfile

# Download the competition archive. check=True makes a failed download (missing
# credentials, no network, competition rules not accepted) raise immediately
# instead of surfacing later as a confusing "file not found" error.
subprocess.run(
    ["kaggle", "competitions", "download", "-c", "gan-getting-started"],
    check=True,
)

# Extract without overwriting files that already exist (same semantics as
# `unzip -n`), so re-running this cell is cheap and idempotent.
with zipfile.ZipFile("gan-getting-started.zip") as archive:
    for member in archive.namelist():
        if not os.path.exists(member):
            archive.extract(member)

In [None]:
# Image folders handed to ImageDataset below (paths are relative to the working directory).
monet_path, photo_path = "monet_jpg", "photo_jpg"

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import lightning as L

# Dataset Class
class ImageDataset(Dataset):
    """Unpaired dataset yielding one Monet image and one photo per index.

    The dataset length is the size of the larger folder; the smaller folder is
    cycled through with a modulo so every index yields a (monet, photo) pair.

    Args:
        monet_dir: Directory containing the Monet images.
        photo_dir: Directory containing the photos.
        transform: Optional callable applied independently to both images.

    Raises:
        ValueError: If either directory contains no regular files.
    """

    def __init__(self, monet_dir, photo_dir, transform=None):
        self.monet_images = self._list_files(monet_dir)
        self.photo_images = self._list_files(photo_dir)
        # Fail early with a clear message; otherwise __getitem__ would hit a
        # ZeroDivisionError in the modulo below.
        if not self.monet_images or not self.photo_images:
            raise ValueError(
                f"No image files found in {monet_dir!r} and/or {photo_dir!r}"
            )
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the regular files in ``directory`` as sorted full paths."""
        # os.listdir order is arbitrary; sorting makes index -> file stable across runs.
        # Sub-directories (e.g. Jupyter's .ipynb_checkpoints) are skipped.
        candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
        return [path for path in candidates if os.path.isfile(path)]

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return max(len(self.monet_images), len(self.photo_images))

    def __getitem__(self, idx):
        monet_img = self._load_rgb(self.monet_images[idx % len(self.monet_images)])
        photo_img = self._load_rgb(self.photo_images[idx % len(self.photo_images)])

        if self.transform:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img

# Transformations: resize to 256x256, convert to a tensor, then normalise each
# channel with mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Example Usage
dataset = ImageDataset(monet_path, photo_path, transform=transform)
# Never request more worker processes than the machine has cores; a fixed 12
# oversubscribes smaller machines and triggers a DataLoader warning.
num_workers = min(12, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=num_workers)

# Generator Model
class Generator(nn.Module):
    """Placeholder image-to-image generator.

    A stride-2 convolution (3 -> 64 channels) followed by a stride-2 transposed
    convolution (64 -> 3 channels) and a Tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        encode = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        # Add more layers following CycleGAN architecture...
        decode = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)
        stages = [encode, nn.ReLU(inplace=True), decode, nn.Tanh()]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        generated = self.main(x)
        return generated

# Discriminator Model
class Discriminator(nn.Module):
    """Placeholder convolutional discriminator.

    A stride-2 convolution (3 -> 64 channels) with LeakyReLU(0.2), then a
    stride-1 convolution down to a single channel and a Sigmoid, producing a
    map of scores in [0, 1] rather than one scalar per image.
    """

    def __init__(self):
        super().__init__()
        features = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        # Add more layers following CycleGAN architecture...
        score = nn.Conv2d(64, 1, kernel_size=4, stride=1, padding=1)
        stages = [features, nn.LeakyReLU(0.2, inplace=True), score, nn.Sigmoid()]
        self.main = nn.Sequential(*stages)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.main(x)
        return scores

# Lightning Module
class CycleGAN(L.LightningModule):
    """CycleGAN trained with Lightning manual optimization.

    ``gen_monet`` maps photos to Monet-style images and ``gen_photo`` maps Monet
    paintings to photo-style images. ``disc_monet`` and ``disc_photo`` score
    images of the respective domain. Both translation directions are trained
    in every step, so both discriminators and both generators receive adversarial,
    cycle-consistency and identity signals.

    Args:
        lambda_cycle: Weight of the cycle-consistency (L1) term in the generator loss.
        lambda_identity: Weight of the identity (L1) term in the generator loss.
    """

    def __init__(self, lambda_cycle=10.0, lambda_identity=5.0):
        super().__init__()
        self.gen_monet = Generator()  # photo -> Monet
        self.gen_photo = Generator()  # Monet -> photo
        self.disc_monet = Discriminator()
        self.disc_photo = Discriminator()

        self.adversarial_loss = nn.MSELoss()
        self.cycle_loss = nn.L1Loss()
        self.identity_loss = nn.L1Loss()

        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

        # Two optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False

    def forward(self, x):
        """Translate a batch of photos ``x`` with the photo -> Monet generator."""
        return self.gen_monet(x)

    def _gan_loss(self, prediction, target_is_real):
        """Adversarial loss of ``prediction`` against an all-ones or all-zeros target."""
        if target_is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.adversarial_loss(prediction, target)

    def training_step(self, batch, batch_idx):
        """Run one generator update followed by one discriminator update.

        Args:
            batch: Tuple ``(monet_img, photo_img)`` as produced by the dataset.
            batch_idx: Index of the batch (unused).
        """
        monet_img, photo_img = batch
        opt_gen, opt_disc = self.optimizers()

        # ---- Generators (both directions) ----
        fake_photo = self.gen_photo(monet_img)
        fake_monet = self.gen_monet(photo_img)
        recon_monet = self.gen_monet(fake_photo)
        recon_photo = self.gen_photo(fake_monet)

        loss_gan = (
            self._gan_loss(self.disc_photo(fake_photo), True)
            + self._gan_loss(self.disc_monet(fake_monet), True)
        )
        loss_cycle = (
            self.cycle_loss(recon_monet, monet_img)
            + self.cycle_loss(recon_photo, photo_img)
        )
        # Feeding a generator an image already in its target domain should be a no-op.
        loss_identity = (
            self.identity_loss(self.gen_photo(photo_img), photo_img)
            + self.identity_loss(self.gen_monet(monet_img), monet_img)
        )

        loss_gen = loss_gan + self.lambda_cycle * loss_cycle + self.lambda_identity * loss_identity
        self.log("gen_loss", loss_gen, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        opt_gen.zero_grad()
        self.manual_backward(loss_gen)
        opt_gen.step()

        # ---- Discriminators (both domains) ----
        # Fakes are detached so no gradient flows back into the generators. The
        # discriminator gradients accumulated during the generator backward pass
        # are cleared by opt_disc.zero_grad() below.
        loss_disc_photo = (
            self._gan_loss(self.disc_photo(photo_img), True)
            + self._gan_loss(self.disc_photo(fake_photo.detach()), False)
        ) / 2
        loss_disc_monet = (
            self._gan_loss(self.disc_monet(monet_img), True)
            + self._gan_loss(self.disc_monet(fake_monet.detach()), False)
        ) / 2
        loss_disc = loss_disc_photo + loss_disc_monet

        self.log("disc_loss", loss_disc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        opt_disc.zero_grad()
        self.manual_backward(loss_disc)
        opt_disc.step()

    def configure_optimizers(self):
        """Return one Adam optimizer for both generators and one for both discriminators."""
        opt_gen = torch.optim.Adam(
            list(self.gen_monet.parameters()) + list(self.gen_photo.parameters()), lr=0.0002, betas=(0.5, 0.999)
        )
        opt_disc = torch.optim.Adam(
            list(self.disc_monet.parameters()) + list(self.disc_photo.parameters()), lr=0.0002, betas=(0.5, 0.999)
        )
        return [opt_gen, opt_disc], []

# Trainer Example
# Seed Python, NumPy and torch (and dataloader workers) so that weight
# initialisation and batch shuffling are reproducible across kernel restarts.
L.seed_everything(42, workers=True)
trainer = L.Trainer(max_epochs=10)
model = CycleGAN()
trainer.fit(model, dataloader)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_results(generator, dataloader, device, num_images=5):
    """Plot input photos next to the generator's output, one row per batch.

    Only the first image of each batch is shown.

    Args:
        generator: Module applied to the photo batch. It is moved to ``device``
            and switched to eval mode (both happen in place).
        dataloader: Iterable yielding ``(monet_img, photo_img)`` batches.
        device: Device on which the generator is run.
        num_images: Number of rows to plot (one per batch drawn from ``dataloader``).
    """
    # squeeze=False keeps `ax` two-dimensional even when num_images == 1.
    _, ax = plt.subplots(num_images, 2, figsize=(12, num_images * 3), squeeze=False)
    # The generator must live on the same device as its input.
    generator = generator.to(device)
    generator.eval()  # Set the generator to evaluation mode

    def to_uint8(img):
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 255], channels last for imshow.
        # .cpu() is required before .numpy() when the tensor lives on a GPU.
        arr = img.detach().cpu().permute(1, 2, 0).numpy()
        return np.clip((arr * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

    for i, (_, photo_img) in enumerate(dataloader):
        if i >= num_images:
            break

        # Generate Monet-style images
        with torch.no_grad():
            fake_monet = generator(photo_img.to(device))

        # Display images
        ax[i, 0].imshow(to_uint8(photo_img[0]))
        ax[i, 1].imshow(to_uint8(fake_monet[0]))
        ax[i, 0].set_title("Input Photo")
        ax[i, 1].set_title("Monet-esque")
        ax[i, 0].axis("off")
        ax[i, 1].axis("off")

    plt.tight_layout()
    plt.show()

# convert the photos to monet style
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# gen_monet is the photo -> Monet generator (the one CycleGAN.forward applies);
# gen_photo translates in the opposite direction (Monet -> photo).
visualize_results(model.gen_monet, dataloader, device)

# save the monet style images
def save_monet_images(generator, dataloader, device, save_dir="monet_output", num_images=5):
    """Run the generator on photo batches and save every output as a JPEG.

    Args:
        generator: Module applied to the photo batch. It is moved to ``device``
            and switched to eval mode (both happen in place).
        dataloader: Iterable yielding ``(monet_img, photo_img)`` batches.
        device: Device on which the generator is run.
        save_dir: Output directory; created if it does not exist.
        num_images: Number of *batches* to process. Every image of each
            processed batch is saved, so up to ``num_images * batch_size``
            files are written (behaviour kept for backward compatibility).
    """
    os.makedirs(save_dir, exist_ok=True)

    # The generator must live on the same device as its input.
    generator = generator.to(device)
    generator.eval()  # Set the generator to evaluation mode

    to_pil = transforms.ToPILImage()
    # Running counter instead of i * dataloader.batch_size + j: same file names
    # for full batches, and it still works when batch_size is None.
    saved = 0

    for i, (_, photo_img) in enumerate(dataloader):
        if i >= num_images:
            break

        # Generate Monet-style images
        with torch.no_grad():
            fake_monet = generator(photo_img.to(device)).cpu()

        # Undo Normalize(0.5, 0.5): ToPILImage expects float tensors in [0, 1],
        # while the generator's Tanh output is in [-1, 1].
        fake_monet = (fake_monet * 0.5 + 0.5).clamp(0.0, 1.0)

        # Convert tensors to PIL images and save
        for img_tensor in fake_monet:
            to_pil(img_tensor).save(os.path.join(save_dir, f"monet_{saved}.jpg"))
            saved += 1
# Save the monet style images
# gen_monet is the photo -> Monet generator (the one CycleGAN.forward applies);
# gen_photo translates in the opposite direction (Monet -> photo).
save_monet_images(model.gen_monet, dataloader, device)
# save the model
torch.save(model.state_dict(), "cyclegan_monet_photo.pth")

