In [1]:
# Third-party imports
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import glob
import os
from PIL import Image
from typing import Callable, Optional
import IPython
import torchvision.transforms as transforms
import torchvision
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import save_image

# Device: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Create the outputs folder (no error if it already exists).
os.makedirs("outputs", exist_ok=True)

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Training and model hyperparameters
num_epochs = 50            # number of passes over the dataloader in the training loop
batch_size = 32            # intended mini-batch size
lr = 3e-4                  # learning rate passed to both Adam optimizers
img_dim = 256 * 256 * 3    # length of one flattened 3-channel 256x256 image

In [3]:
# Model
class Generator(nn.Module):
    """Fully connected image-to-image generator.

    Maps a flattened image vector to another flattened image vector of the
    same length. In this notebook two instances are used to translate
    photos to Monet-style images and back.
    """

    def __init__(self, img_dim: int):
        """Builds the two-layer network.

        Args:
            img_dim (int): Length of the flattened input (and output) image vector.
        """
        super().__init__()
        hidden_units = 1024
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dim),
            # Tanh bounds every output value to [-1, 1].
            nn.Tanh(),
        ]
        # Attribute name and layer positions define the state_dict keys
        # ("gen.0.*", "gen.2.*") used by saved checkpoints.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Runs ``x`` through the network and returns the generated vector."""
        generated = self.gen(x)
        return generated

class Discriminator(nn.Module):
    """Fully connected real-vs-fake classifier.

    Scores a flattened image vector with a single value in [0, 1]; the
    training loop pushes that score towards 1 for dataset images and
    towards 0 for generator outputs.
    """

    def __init__(self, img_dim: int):
        """Builds the two-layer network.

        Args:
            img_dim (int): Length of the flattened input image vector.
        """
        super().__init__()
        hidden_units = 512
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, 1),
            # Sigmoid squashes the single logit into [0, 1].
            nn.Sigmoid(),
        ]
        # Attribute name and layer positions define the state_dict keys
        # ("disc.0.*", "disc.2.*") used by saved checkpoints.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Returns one score per row of ``x``, shaped ``(batch, 1)``."""
        score = self.disc(x)
        return score

# Init the models: one generator and one discriminator per image domain.
# The construction order (gen_monet, gen_raw, disc_monet, disc_raw) is kept so
# the random weight initialisation happens in the same sequence as before.
gen_monet, gen_raw = (Generator(img_dim).to(device) for _ in range(2))
disc_monet, disc_raw = (Discriminator(img_dim).to(device) for _ in range(2))

In [4]:
# Datasets
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    Lists the ``*.jpg`` files of two directories and serves one image from
    each per index. The dataset length is the number of files found in
    ``train_dir``; both lists are indexed modulo their own size, so the
    second directory wraps around when it holds fewer files.
    """

    def __init__(self, train_dir: str, test_dir: str, transforms: Optional[Callable] = None):
        """Collects the sorted image paths of both directories.

        Args:
            train_dir (str): Directory containing the first-domain images
                (the raw photos in this notebook).
            test_dir (str): Directory containing the second-domain images
                (the Monet paintings in this notebook).
            transforms (Optional[Callable]): Optional transform applied to
                every loaded image.
        """
        self.train_dir = sorted(glob.glob(os.path.join(train_dir, "*.jpg")))
        self.test_dir = sorted(glob.glob(os.path.join(test_dir, "*.jpg")))
        self.transforms = transforms

    @staticmethod
    def _load_rgb(path: str):
        """Opens ``path``, returns an in-memory RGB copy and closes the file."""
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load and returns an independent image, so the
        # handle can be released as soon as the with-block exits.
        # Forcing RGB guarantees three channels even for grayscale or CMYK
        # JPEGs, which would otherwise not match the flattened image size
        # expected by the models.
        with Image.open(path) as handle:
            return handle.convert("RGB")

    def __getitem__(self, index):
        """Returns the ``(train_image, test_image)`` pair for ``index``.

        Raises:
            ZeroDivisionError: If either directory contained no ``*.jpg`` files.
        """
        train_image = self._load_rgb(self.train_dir[index % len(self.train_dir)])
        test_image = self._load_rgb(self.test_dir[index % len(self.test_dir)])
        if self.transforms:
            train_image = self.transforms(train_image)
            test_image = self.transforms(test_image)
        return train_image, test_image

    def __len__(self):
        """Number of images found in ``train_dir``."""
        return len(self.train_dir)
# Image transforms.
# Bound to a new name so the `torchvision.transforms` module (imported as
# `transforms`) is not overwritten; rebinding it made this cell raise
# AttributeError when run a second time.
image_transforms = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 rescales ToTensor's [0, 1] output to [-1, 1], the same
    # range the generators' final Tanh produces.
    transforms.Normalize((0.5,), (0.5,)),
])
# Create the dataset
ds = ImageDataset(train_dir='data/photo_jpg/', test_dir='data/monet_jpg/', transforms=image_transforms)

# Create the data loader (batch size comes from the hyperparameter cell rather than a second hard-coded 32)
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [6]:
# Tensorboard.
writer_fake = SummaryWriter("runs/GAN_MONET/fake")
writer_real = SummaryWriter("runs/GAN_MONET/real")

# Optimizers: one Adam instance updates both generators, the other both discriminators.
opt_gen = torch.optim.Adam(list(gen_monet.parameters()) + list(gen_raw.parameters()), lr=lr)
opt_disc = torch.optim.Adam(list(disc_monet.parameters()) + list(disc_raw.parameters()), lr=lr)

# Loss
L1 = nn.L1Loss() # Cycle consistency loss (no identity loss is computed in this loop)
mse = nn.MSELoss() # Adversarial loss

log_every = 100 # Number of batches between console / image / tensorboard logs
step = 0
# Training Loop
for epoch in range(num_epochs):
    for idx, (raw_img, monet_img) in enumerate(tqdm(dataloader)):
        # Move each batch to the device and flatten every image to a vector,
        # because the models are built from nn.Linear layers.
        raw_img = raw_img.to(device)
        raw_img = raw_img.reshape(raw_img.shape[0], -1)
        monet_img = monet_img.to(device)
        monet_img = monet_img.reshape(monet_img.shape[0], -1)

        # Train discriminators Monet and Raw
        # Real images are scored against a target of 1, generated ones against 0.
        # The fakes are detached so this step does not backpropagate into the generators.
        fake_monet = gen_monet(raw_img)
        D_monet_real = disc_monet(monet_img)
        D_monet_fake = disc_monet(fake_monet.detach())
        D_monet_real_loss = mse(D_monet_real, torch.ones_like(D_monet_real))
        D_monet_fake_loss = mse(D_monet_fake, torch.zeros_like(D_monet_fake))
        D_monet_loss = D_monet_real_loss + D_monet_fake_loss

        fake_raw = gen_raw(monet_img)
        D_raw_real = disc_raw(raw_img)
        D_raw_fake = disc_raw(fake_raw.detach())
        D_raw_real_loss = mse(D_raw_real, torch.ones_like(D_raw_real))
        D_raw_fake_loss = mse(D_raw_fake, torch.zeros_like(D_raw_fake))
        D_raw_loss = D_raw_real_loss + D_raw_fake_loss

        D_loss = (D_monet_loss + D_raw_loss)/2

        opt_disc.zero_grad()
        # No retain_graph: the discriminator graph built above is not reused,
        # since the generator step below runs fresh forward passes.
        D_loss.backward()
        opt_disc.step()

        # Train generators Monet and Raw
        # The (just updated) discriminators score the non-detached fakes; the
        # generators are rewarded when those scores approach 1.
        D_monet_fake = disc_monet(fake_monet)
        D_raw_fake = disc_raw(fake_raw)
        loss_G_monet = mse(D_monet_fake, torch.ones_like(D_monet_fake)) # Adversarial loss
        loss_G_raw = mse(D_raw_fake, torch.ones_like(D_raw_fake)) # Adversarial loss

        # Cycle loss: translating to the other domain and back should reproduce the input.
        cycle_monet = gen_monet(fake_raw)
        cycle_raw = gen_raw(fake_monet)
        cycle_monet_loss = L1(monet_img, cycle_monet)
        cycle_raw_loss = L1(raw_img, cycle_raw)

        G_loss = (loss_G_monet + loss_G_raw) + (cycle_monet_loss + cycle_raw_loss)

        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

        step += 1
        if idx % log_every == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the source indentation into the printed message.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {idx}/{len(dataloader)} "
                f"Loss D: {D_loss.item():.4f}, loss G: {G_loss.item():.4f}"
            )

            with torch.no_grad():
                fake = gen_monet(raw_img)
                # Unnormalise the images: back to (N, 3, 256, 256) and from [-1, 1] to [0, 1]
                fake = fake.reshape(-1, 3, 256, 256)
                fake = fake * 0.5 + 0.5
                real = raw_img.reshape(-1, 3, 256, 256) * 0.5 + 0.5
                # Save image
                save_image(fake, f"outputs/monet_{step}.png")
                # Log image grids so training progress can be viewed in tensorboard
                writer_fake.add_image("Generated Monet", torchvision.utils.make_grid(fake), global_step=step)
                writer_real.add_image("Input photos", torchvision.utils.make_grid(real), global_step=step)

    # Save model (one checkpoint per network at the end of every epoch)
    torch.save(gen_monet.state_dict(), f'outputs/gen_monet_{epoch}.pth')
    torch.save(gen_raw.state_dict(), f'outputs/gen_raw_{epoch}.pth')
    torch.save(disc_monet.state_dict(), f'outputs/disc_monet_{epoch}.pth')
    torch.save(disc_raw.state_dict(), f'outputs/disc_raw_{epoch}.pth')

  0%|          | 0/220 [00:00<?, ?it/s]

Epoch [0/50] Batch 0/220                   Loss D: 0.0307, loss G: 2.9700


 45%|████▌     | 100/220 [00:22<00:25,  4.64it/s]

Epoch [0/50] Batch 100/220                   Loss D: 0.0121, loss G: 3.3856


 91%|█████████ | 200/220 [00:44<00:04,  4.64it/s]

Epoch [0/50] Batch 200/220                   Loss D: 0.0625, loss G: 3.3053


100%|██████████| 220/220 [00:49<00:00,  4.44it/s]
  0%|          | 0/220 [00:00<?, ?it/s]

Epoch [1/50] Batch 0/220                   Loss D: 0.0000, loss G: 3.2804


 45%|████▌     | 100/220 [00:22<00:25,  4.63it/s]

Epoch [1/50] Batch 100/220                   Loss D: 0.0000, loss G: 3.2420


 91%|█████████ | 200/220 [00:44<00:04,  4.63it/s]

Epoch [1/50] Batch 200/220                   Loss D: 0.1014, loss G: 3.2182


100%|██████████| 220/220 [00:49<00:00,  4.44it/s]
  0%|          | 0/220 [00:00<?, ?it/s]

Epoch [2/50] Batch 0/220                   Loss D: 0.0000, loss G: 3.2167


 45%|████▌     | 100/220 [00:22<00:26,  4.59it/s]

Epoch [2/50] Batch 100/220                   Loss D: 0.0000, loss G: 3.1510


 91%|█████████ | 200/220 [00:44<00:04,  4.56it/s]

Epoch [2/50] Batch 200/220                   Loss D: 0.0000, loss G: 3.0874


100%|██████████| 220/220 [00:49<00:00,  4.41it/s]
  0%|          | 0/220 [00:00<?, ?it/s]

Epoch [3/50] Batch 0/220                   Loss D: 0.0000, loss G: 3.0779


 24%|██▎       | 52/220 [00:12<00:39,  4.23it/s]


KeyboardInterrupt: 