In [None]:
# Imports
import torch
import torch.nn as nn 
from torchvision import transforms
import numpy as np
from  matplotlib import pyplot as plt
import torch.nn.functional as F
import time
from utils import get_data, show_batch, accuracy
from PIL import Image

In [None]:
# Class for storing run-wide settings such as learning rate, image size, ... (read by every cell below).
class Args:
    """Plain attribute container for the DCGAN hyperparameters.

    Later cells read these as ``args.<name>`` (data loading, model definitions, training loop).
    """

    def __init__(self):
        # Adam optimiser settings: learning rate and betas=(b1, b2).
        self.lr = 2e-4
        self.epochs = 1000
        self.b1 = 0.5
        self.b2 = 0.999
        # Length of the generator's input noise vector.
        self.latent_dim = 100
        # Square training images of img_size x img_size with `channels` colour channels.
        self.img_size = 64
        self.pixels = int(self.img_size ** 2)
        self.channels = 3
        self.img_tuple = (self.channels, self.img_size, self.img_size)
        self.batch_size = 64
        # Base feature-map counts; the layers use multiples of these (x1, x2, x4, x8).
        self.g_fmap_size = 64
        self.d_fmap_size = 64
        # The discriminator is only updated while its loss on the previous batch exceeds this value.
        self.d_loss_threshold = 0.15
        # Seed for torch's global RNG (weight init, noise vectors, random transforms), so runs are reproducible.
        self.seed = 42


args = Args()
# Seed once, before data loading and model construction consume any randomness.
torch.manual_seed(args.seed)

In [None]:
# Loading the images.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Augmentation: bicubic resize to ~1.12x img_size, random img_size x img_size crop, random mirror,
# then map each channel from [0, 1] to [-1, 1] (the range of the generator's tanh output).
transform = transforms.Compose([
    transforms.Resize(int(args.img_size * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((args.img_size, args.img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # One mean/std entry per channel, so changing args.channels keeps Normalize consistent.
    transforms.Normalize((0.5,) * args.channels, (0.5,) * args.channels),
])
# NOTE(review): get_data lives in utils (not shown); the training loop assumes painting_loader
# yields bare image tensors (no labels) — confirm there.
paintings, photos, painting_loader, photo_loader = get_data(transform=transform, batch_size=args.batch_size)

In [None]:
# Defining the DCGAN.
# Source: https://arxiv.org/pdf/1511.06434.pdf
class Generator(nn.Module):
    """DCGAN generator: maps noise of shape (N, latent_dim, 1, 1) to 64x64 images in [-1, 1].

    The first transposed convolution expands the 1x1 input to 4x4; each of the
    next four doubles height and width (4 -> 8 -> 16 -> 32 -> 64). Every layer but
    the last is followed by BatchNorm + ReLU; the last is followed by tanh.
    """

    def __init__(self):
        super().__init__()
        nz, ngf, nc = args.latent_dim, args.g_fmap_size, args.channels
        # Registration order (conv1..conv5, then norm1..norm4) is kept so that
        # parameter order, state_dict keys and init_weights' RNG usage are unchanged.
        self.conv1 = nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False)
        self.conv2 = nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv3 = nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv5 = nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False)
        # One BatchNorm per hidden layer, sized to the outputs of conv1..conv4.
        self.norm1 = nn.BatchNorm2d(ngf * 8)
        self.norm2 = nn.BatchNorm2d(ngf * 4)
        self.norm3 = nn.BatchNorm2d(ngf * 2)
        self.norm4 = nn.BatchNorm2d(ngf)

    def forward(self, x):
        hidden_stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
        )
        for conv, norm in hidden_stages:
            x = norm(conv(x)).relu()
        return torch.tanh(self.conv5(x))

class Discriminator(nn.Module):
    """DCGAN discriminator: maps a batch of 64x64 images to per-image probabilities of being real.

    Four stride-2 convolutions halve the resolution (64 -> 32 -> 16 -> 8 -> 4), and a
    final unpadded 4x4 convolution reduces it to 1x1. The output has shape (N, 1, 1, 1)
    and passes through a sigmoid; callers flatten it with ``.view(-1)``.
    """

    def __init__(self):
        super().__init__()
        nc, ndf = args.channels, args.d_fmap_size
        # Registration order (conv1..conv5, then norm1..norm3) matches the original module layout.
        self.conv1 = nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False)
        self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)
        # BatchNorm follows conv2..conv4 only; the first layer is left unnormalised.
        self.norm1 = nn.BatchNorm2d(ndf * 2)
        self.norm2 = nn.BatchNorm2d(ndf * 4)
        self.norm3 = nn.BatchNorm2d(ndf * 8)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        normed_stages = (
            (self.conv2, self.norm1),
            (self.conv3, self.norm2),
            (self.conv4, self.norm3),
        )
        for conv, norm in normed_stages:
            x = F.leaky_relu(norm(conv(x)), 0.2, inplace=True)
        return torch.sigmoid(self.conv5(x))

# Weight initialization from the DCGAN paper.
def init_weights(m):
    """Re-initialise a single module in place; intended for ``model.apply(init_weights)``.

    * Modules whose class name contains ``"Conv"`` (Conv*/ConvTranspose*): weight ~ N(0, 0.02).
    * Modules whose class name contains ``"BatchNorm"``: weight ~ N(1, 0.02), bias = 0.
    * Every other module is left unchanged.

    Parameters that are missing or ``None`` (e.g. ``BatchNorm2d(affine=False)``) are
    skipped instead of raising ``AttributeError``.
    """
    class_name = type(m).__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    # nn.init functions already run under no_grad, so the Parameters are passed directly (no `.data`).
    if "Conv" in class_name:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in class_name:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Build both networks on `device` with the DCGAN weight initialisation, then the BCE loss and
# one Adam optimiser per network (both share the same learning rate and betas).
# Module.to() and Module.apply() both return the module itself, so the calls can be chained.
generator = Generator().to(device).apply(init_weights)
discriminator = Discriminator().to(device).apply(init_weights)
loss_fn = nn.BCELoss()
adam_kwargs = dict(lr=args.lr, betas=(args.b1, args.b2))
g_optim = torch.optim.Adam(generator.parameters(), **adam_kwargs)
d_optim = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)

In [None]:
# Training loop.
# Per batch: (1) update the discriminator on real images (label-smoothed targets of 0.9) and on
# generated images (targets 0), but only while its loss on the previous batch exceeds
# args.d_loss_threshold; (2) update the generator so the discriminator labels its fakes as real.
g_loss_list, d_loss_list, accuracy_list = [], [], []
# Fixed noise so the periodic sample grids show the same latent points throughout training.
fixed_noise = torch.randn(16, args.latent_dim, 1, 1, device=device)

for epoch in range(1, args.epochs + 1):
    g_epoch_loss = 0
    d_epoch_loss = 0
    t = time.time()
    d_prev_loss = 1  # reset every epoch, so the first batch of each epoch always trains D
    acc = 0

    # Switch back to training mode BEFORE any fakes are generated. The sampling step at the end
    # of an epoch leaves the generator in eval mode, and the first batch's fakes (which the
    # G loss back-propagates through) must use BatchNorm batch statistics.
    generator.train()
    discriminator.train()

    for images in painting_loader:
        # NOTE(review): assumes painting_loader yields bare image tensors (no labels) — confirm in utils.get_data.
        images = images.to(device)
        n_images = images.size(0)
        real_label = torch.ones(n_images, dtype=torch.float, device=device)
        real_disc_label = torch.ones(n_images, dtype=torch.float, device=device) - 0.1  # Label smoothing.
        fake_label = torch.zeros(n_images, dtype=torch.float, device=device)

        ###################
        # DISCRIMINATOR
        ###################
        # Skip D's update while it is already winning, so the generator can catch up.
        train_d = d_prev_loss > args.d_loss_threshold
        d_optim.zero_grad()

        # Feeding real images.
        out = discriminator(images).view(-1)
        d_real_loss = loss_fn(out, real_disc_label)

        # Feeding fake images; detached so D's loss does not back-propagate into the generator.
        noise_vector = torch.randn(n_images, args.latent_dim, 1, 1, device=device)
        fake = generator(noise_vector)
        out = discriminator(fake.detach()).view(-1)
        d_fake_loss = loss_fn(out, fake_label)

        acc += accuracy(fake_label, out)
        d_loss = d_real_loss + d_fake_loss
        if train_d:
            d_loss.backward()
            # Step only when D was actually trained. When gradients are zeroed rather than set to
            # None, Adam's momentum would otherwise keep moving the "frozen" discriminator.
            d_optim.step()

        ###################
        # GENERATOR
        ###################
        g_optim.zero_grad()
        # Non-detached fakes: gradients flow through D into G. They also accumulate in D's .grad,
        # which d_optim.zero_grad() clears at the start of the next batch.
        out = discriminator(fake).view(-1)
        g_loss = loss_fn(out, real_label)
        g_loss.backward()
        g_optim.step()

        d_loss_value = d_loss.item()
        d_epoch_loss += d_loss_value
        g_epoch_loss += g_loss.item()
        d_prev_loss = d_loss_value

    # Per-epoch averages over batches.
    g_epoch_loss /= len(painting_loader)
    d_epoch_loss /= len(painting_loader)
    acc /= len(painting_loader)
    g_loss_list.append(g_epoch_loss)
    d_loss_list.append(d_epoch_loss)
    accuracy_list.append(acc)

    # Periodically report stats and show samples for the fixed noise.
    if epoch % 25 == 0 or epoch == 1:
        print(f"EPOCH: {epoch}\nExecution time: {time.time() - t:.1f} sec.\nGenerator loss: {g_epoch_loss:.3f}\nDiscriminator loss: {d_epoch_loss:.3f}\nDiscriminator accuracy: {acc:.3f}")
        generator.eval()
        with torch.no_grad():
            out = generator(fixed_noise).detach().cpu()
            show_batch(out)

In [None]:
# Plotting the loss curves, the discriminator accuracy, and a batch of freshly generated images.
fig, ax = plt.subplots()
ax.plot(g_loss_list, label="Generator loss")
ax.plot(d_loss_list, label="Discriminator loss")
ax.set(title="Training losses", xlabel="Epoch", ylabel="Mean loss per batch")
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(accuracy_list, label="Discriminator accuracy")
ax.set(title="Discriminator accuracy on generated images", xlabel="Epoch", ylabel="Accuracy")
ax.legend()
plt.show()

# Sample in eval mode so BatchNorm uses (and does not update) its running statistics. The
# training loop only calls eval() on sampling epochs, so when args.epochs is not a multiple
# of 25 the generator would otherwise still be in train mode here.
generator.eval()
with torch.no_grad():
    noise_vector = torch.randn(16, args.latent_dim, 1, 1, device=device)
    fake = generator(noise_vector).cpu()
    show_batch(fake)