In [1]:
# Imports
import torch
import torch.nn as nn 
from torchvision import transforms
import numpy as np
from  matplotlib import pyplot as plt
import torch.nn.functional as F
import time
from utils import get_data, show_img, show_batch

In [2]:
# Class for storing things such as learning rate, image size...
class Args:
    def __init__(self):
        self.lr = 2e-4
        self.epochs = 10
        self.b1 = 0.5
        self.b2 = 0.999
        self.latent_dim = 100
        self.img_size = 128
        self.pixels = int(self.img_size ** 2)
        self.channels = 3
        self.img_tuple = (self.channels, self.img_size, self.img_size)
        self.batch_size = 32
        self.g_fmap_size = 64 
        self.d_fmap_size = 64

In [3]:
# Loading the images.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args = Args()
transform = transforms.Compose([transforms.Resize(args.img_size),
                                transforms.CenterCrop(args.img_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
paintings, photos, painting_loader, photo_loader = get_data(transform=transform, batch_size=32)

In [23]:
# Defining the DCGAN.
# Source: https://arxiv.org/pdf/1511.06434.pdf 
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(args.latent_dim, args.g_fmap_size * 8, 4, 1, 0, bias=False)
        self.conv2 = nn.ConvTranspose2d(args.g_fmap_size * 8, args.g_fmap_size * 4, 4, 2, 1, bias=False) 
        self.conv3 = nn.ConvTranspose2d(args.g_fmap_size * 4, args.g_fmap_size * 2, 4, 2, 1, bias=False)
        self.conv4 = nn.ConvTranspose2d(args.g_fmap_size * 2, args.g_fmap_size, 4, 2, 1, bias=False)
        self.conv5 = nn.ConvTranspose2d(args.g_fmap_size, args.channels, 4, 2, 1, bias=False)
        self.norm1 = nn.BatchNorm2d(args.g_fmap_size * 8)
        self.norm2 = nn.BatchNorm2d(args.g_fmap_size * 4)
        self.norm3 = nn.BatchNorm2d(args.g_fmap_size * 2)
        self.norm4 = nn.BatchNorm2d(args.g_fmap_size)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = x.relu(True)
        x = self.conv2(x)
        x = self.norm2(x)
        x = x.relu(True)
        x = self.conv3(x)
        x = self.norm3(x)
        x = x.relu(True)
        x = self.conv4(x)
        x = self.norm4(x)
        x = x.relu(True)
        x = self.conv5(x)
        return F.tanh(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(args.channels, args.d_fmap_size, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(args.d_fmap_size, args.d_fmap_size * 2, 4, 2, 1, bias=False)
        self.conv3 = nn.Conv2d(args.d_fmap_size * 2, args.d_fmap_size * 4, 4, 2, 1, bias=False)
        self.conv4 = nn.Conv2d(args.d_fmap_size * 4, args.d_fmap_size * 8, 4, 2, 1, bias=False)
        self.conv5 = nn.Conv2d(args.channels * 8, 1, 4, 1, 0, bias=False)
        self.norm1 = nn.BatchNorm2d(args.d_fmap_size * 2)
        self.norm2 = nn.BatchNorm2d(args.d_fmap_size * 4)
        self.norm3 = nn.BatchNorm2d(args.d_fmap_size * 8)
    
    def forward(self, x):
        x = self.conv1(x)
        x = F.leaky_relu(x, 0.2, inplace=True)
        x = self.conv2(x)
        x = self.norm1(x)
        x = F.leaky_relu(x, 0.2, inplace=True)
        x = self.conv3(x)
        x = self.norm2(x)
        x = F.leaky_relu(x, 0.2, inplace=True)
        x = self.conv4(x)
        x = self.norm3(x)
        x = F.leaky_relu(x, 0.2, inplace=True)
        x = self.conv5(x)
        return F.sigmoid(x)

# Weight initialization from the DCGAN paper. 
def init_weights(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif class_name.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [24]:
# Initializing all the relevant components to the DCGAN. 
generator = Generator().to(device)
generator.apply(init_weights)  # Why does this also print the model??
discriminator = Discriminator().to(device)
discriminator.apply(init_weights)
loss_fn = nn.BCELoss()
g_optim = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
d_optim = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

Discriminator(
  (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv5): Conv2d(24, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [None]:
# Training loop. 
g_loss_list, d_loss_list = [], []
for epoch in range(1, args.epochs + 1):
    g_epoch_loss = 0
    d_epoch_loss = 0
    t = time.time()
    for i, images in enumerate(painting_loader):
        real_label = torch.ones(images.size(0), 1, dtype=torch.float).to(device)
        fake_label = torch.zeros(images.size(0), 1, dtype=torch.float).to(device)

        ###################
        # DISCRIMINATOR
        ###################
        d_optim.zero_grad()

        # Feeding real images.
        images = images.to(device)
        out = discriminator(images).view(-1)
        d_real_loss = loss_fn(out, real_label)
        d_real_loss.backward()
        
        # Feeding fake images. 
        noise_vector = torch.randn(images.size(0), args.latent_dim, 1, 1).to(device)
        fake = generator(noise_vector)
        out = discriminator(fake.detach()).view(-1)
        d_fake_loss = loss_fn(out, fake_label)
        d_fake_loss.backward()
        d_loss =  d_real_loss + d_fake_loss
        d_optim.step()

        ###################
        # DISCRIMINATOR
        ###################
        g_optim.zero_grad()
        out = discriminator(fake).view(-1)
        g_loss = loss_fn(out, real_label)
        g_loss.backward()
        g_optim.step()

        d_epoch_loss += d_loss.item()
        g_epoch_loss += g_loss.item()
    
    g_epoch_loss /= len(painting_loader)
    d_epoch_loss /= len(painting_loader)
    print(f"EPOCH: {epoch}\nExecution time: {time.time() - t:.1f} sec.\nGenerator loss: {g_epoch_loss:.3f}\nDiscriminator loss: {d_epoch_loss:.3f}\n")
    g_loss_list.append(g_epoch_loss)
    d_loss_list.append(d_epoch_loss)

In [None]:
#  Plotting the loss and some generated images. 
plt.plot(g_loss_list, label="Generator loss")
plt.plot(d_loss_list, label="Discriminator loss")
plt.legend()
plt.show()

with torch.no_grad():
    noise_vector = torch.randn(images.size(0), args.latent_dim, 1, 1).to(device)
    fake = generator(noise_vector).detach().cpu()
    show_batch(fake)