In [49]:
# Imports (standard library and third-party)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import glob
import os
from PIL import Image
from typing import Callable, Optional
import IPython
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# Pick the compute device: the first CUDA GPU if PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Folder the training loop writes checkpoints into; exist_ok makes this cell re-runnable.
os.makedirs(name="outputs", exist_ok=True)

In [50]:
# Training and Model Hyperparameters
lr = 3e-4
batch_size = 32
img_dim = 3 * 256 * 256
num_epochs = 50

In [51]:
# Model
class Generator(nn.Module):
    """Generator network.
    
    Attempts to augoment an input image to look like a Monet painting.
    """
    def __init__(self, img_dim: int):
        """Initializes the Generator network.
        
        Args:
            img_dim (int): Dimension of the image space.
        """
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, img_dim), 
            nn.Tanh() # For images, we want the values to be between -1 and 1.
        )

    def forward(self, x):
        return self.gen(x)

class Discriminator(nn.Module):
    """Discriminator network.
    
    Attempts to classify real and fake images from the dataset and the Generator network.
    """
    def __init__(self, img_dim: int):
        """Initializes the Discriminator network.
        
        Args:
            img_dim (int): Dimension of the input space.
        """
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

# Init the models
gen_g = Generator(img_dim).to(device)
gen_f = Generator(img_dim).to(device)
disc = Discriminator(img_dim).to(device)

In [52]:
# Datasets
class ImageDataset(Dataset):
    """Image dataset.
    
    Loads images from a directory.
    """
    def __init__(self, train_dir: str, test_dir: str, transforms: Optional[Callable] = None):
        """Initializes the ImageDataset.
        
        Args:
            train_dir (str): Directory containing the raw training images.
            test_dir (str): Directory containing the raw training images.
            transforms (Optional[Callable]): Optional transform to be applied on a sample.
        """
        self.train_dir = sorted(glob.glob(os.path.join(train_dir, "*.jpg")))
        self.test_dir = sorted(glob.glob(os.path.join(test_dir, "*.jpg")))
        self.transforms = transforms

    def __getitem__(self, index):
        train_image = Image.open(self.train_dir[index % len(self.train_dir)])
        test_image = Image.open(self.test_dir[index % len(self.test_dir)])
        if self.transforms:
            train_image = self.transforms(train_image)
            test_image = self.transforms(test_image)
        return train_image, test_image

    def __len__(self):
        return len(self.train_dir)
# Image Transforms
transforms = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))
])
# Create the dataset
ds = ImageDataset(train_dir='data/photo_jpg/', test_dir='data/monet_jpg/', transforms=transforms)

# Create the data loader
dataloader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)        

In [None]:
# Optimizers
opt_gen_g = torch.optim.Adam(gen_g.parameters(), lr=lr)
opt_gen_f = torch.optim.Adam(gen_f.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)

# Loss
criterion = nn.BCELoss()
cycle_consistency_criterion = nn.MSELoss()

# Training Loop
for epoch in range(num_epochs):
    for raw_img, monet_img in tqdm(dataloader):
        # Get the data
        raw_img = raw_img.to(device)
        raw_img = raw_img.reshape(raw_img.shape[0], -1)
        monet_img = monet_img.to(device)
        monet_img = monet_img.reshape(monet_img.shape[0], -1)        
        
        # Train the discriminator
        disc_real = disc(monet_img)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        fake_img = gen_g(raw_img)
        disc_fake = disc(fake_img)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        
        lossD = (lossD_real + lossD_fake) / 2

        # Train the generator g
        lossG_1 = criterion(disc_fake, torch.ones_like(disc_fake))

        # Train the generator f
        og_img = gen_f(fake_img)
        lossG_2 = cycle_consistency_criterion(og_img, raw_img)

        lossG_g = (lossG_1 + lossG_2) /2
        lossG_f = lossG_2

        # Backward pass.
        lossD.backward(retain_graph=True)
        lossG_g.backward(retain_graph=True)
        lossG_f.backward()

        opt_gen_g.zero_grad()
        opt_gen_f.zero_grad()
        opt_disc.zero_grad()

        opt_gen_g.step()
        opt_gen_f.step()
        opt_disc.step()

    print(f"Epoch [{epoch+1}/{num_epochs}] Loss D: {lossD:.4f}, Loss G_g: {lossG_g:.4f}, Loss G_f: {lossG_f:.4f}")
    # Save model
    torch.save(gen_g.state_dict(), f'outputs/gen_g.pth')
    torch.save(gen_f.state_dict(), f'outputs/gen_f.pth')
    torch.save(disc.state_dict(), f'outputs/disc.pth')