In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import numpy as np
from data_utils.dataset import LSUNDataset
from datetime import datetime
from custom_nets.dcgan import Generator, Discriminator

# Set random seeds for reproducibility. Seed every RNG this notebook can draw
# from: Python's `random`, torch (torch.manual_seed seeds the CPU and all CUDA
# devices), and NumPy's legacy global generator. Bit-exact GPU runs would
# additionally need deterministic cuDNN settings.
manualSeed = 42

random.seed(manualSeed)
torch.manual_seed(manualSeed)
# Seeding NumPy last also keeps the cell from echoing the torch.Generator
# object returned by torch.manual_seed (np.random.seed returns None).
np.random.seed(manualSeed)

In [None]:
class Hyperparams():
  """Hyperparameters of a single DCGAN training run.

  Args:
    learning_rate: Adam learning rate.
    batch_size: number of images per DataLoader batch.
    beta1: Adam beta1 coefficient.
  """

  def __init__(self, learning_rate, batch_size, beta1):
    self.learning_rate, self.batch_size, self.beta1 = learning_rate, batch_size, beta1

  def get_network_params_name(self, name):
    """Return `name` tagged with this run's hyperparameters.

    Example: 'net' -> 'net_lr_0.001_bs_128_b_0.5'.
    """
    return '{}_lr_{}_bs_{}_b_{}'.format(name, self.learning_rate, self.batch_size, self.beta1)

In [None]:
# ---- Data ----
dataroot = "../data/data"   # root directory of the training images
workers = 1                 # DataLoader worker processes
image_size = 64             # images are resized and center-cropped to image_size x image_size
nc = 3                      # channels per training image (3 for color images)

# ---- Model ----
nz = 100                    # length of the latent vector z (generator input size)
ngf = 64                    # base number of feature maps in the generator
ndf = 64                    # base number of feature maps in the discriminator

# ---- Training ----
num_epochs = 5              # full passes over the dataset
ngpu = 1                    # GPUs requested; also passed to both networks, > 1 enables nn.DataParallel

In [None]:
# Hyperparameters for this run. They are also embedded in every saved
# artifact's file name via get_network_params_name.
# NOTE(review): the DCGAN paper uses lr=0.0002 with beta1=0.5; 0.001 is higher.
hyperparams = Hyperparams(
    beta1=0.5,             # Adam beta1 for both optimizers
    batch_size=128,        # images per DataLoader batch
    learning_rate=0.001,   # Adam learning rate for both optimizers
)

In [None]:
# Preprocessing: resize to image_size, center-crop to a square, convert to a
# tensor, then normalize each channel with mean 0.5 / std 0.5 (maps [0, 1]
# inputs to [-1, 1]).
image_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training images listed in '../data/images.txt', read relative to dataroot
# (presumably; confirm against LSUNDataset).
dataset = LSUNDataset('../data/images.txt', dataroot, transform=image_transform)

# Shuffled mini-batches of hyperparams.batch_size items.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=hyperparams.batch_size,
    shuffle=True,
    num_workers=workers,
)

# Decide which device we want to run on. ngpu == 0 forces CPU mode even when
# CUDA is available; otherwise use the first CUDA device if one exists.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

In [None]:
# Custom weights initialization called on netG and netD
def weights_init(m):
    """DCGAN-style in-place initialization of one module; use via `net.apply(weights_init)`.

    - Class name contains 'Conv' (e.g. Conv2d, ConvTranspose2d): weight ~ N(0, 0.02);
      any bias is left untouched.
    - Class name contains 'BatchNorm': weight ~ N(1, 0.02), bias = 0.
    - Any other module is left unchanged.

    Modules that match by name but have no learnable weight/bias (for example
    ``BatchNorm2d(affine=False)`` or a container class named ``ConvBlock``)
    are skipped instead of raising.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    # getattr guards: name matching alone does not guarantee these parameters exist.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [None]:
# Build the generator and move it to the selected device.
netG = Generator(ngpu, nz, ngf, nc).to(device)

# Spread G across GPUs only when CUDA is in use and more than one GPU is requested.
if ngpu > 1 and device.type == 'cuda':
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

# Initialize every submodule with weights_init (conv weights ~ N(0, 0.02)).
netG.apply(weights_init)

# Show the architecture.
print(netG)

In [None]:
# Create the Discriminator
netD = Discriminator(ngpu, nc, ndf).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to every submodule: conv weights are drawn
#  from N(mean=0, stdev=0.02) and BatchNorm weights from N(1, 0.02) with zero bias.
netD.apply(weights_init)

# Print the model
print(netD)

In [None]:
# Binary cross-entropy between D's outputs and the real/fake targets below.
# NOTE(review): BCELoss expects probabilities, so this presumes the
# Discriminator ends in a sigmoid -- confirm in custom_nets.dcgan.
criterion = nn.BCELoss()

# Target values for real and generated images.
real_label, fake_label = 1., 0.

# One Adam optimizer per network, both using the run's lr and (beta1, 0.999).
optimizerD = optim.Adam(netD.parameters(),
                        lr=hyperparams.learning_rate,
                        betas=(hyperparams.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(),
                        lr=hyperparams.learning_rate,
                        betas=(hyperparams.beta1, 0.999))

In [None]:
# ---- Training loop ----

# Per-iteration loss histories, kept for plotting later.
G_losses, D_losses = [], []
# Per-iteration batch-mean D outputs on real images and on fakes.
D_real_mean_out, D_fake_mean_out = [], []
iters = 0
# Run identifier embedded in every artifact file name written by this run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

print("Starting Training Loop...")
# Artifact paths. They depend only on the fixed run timestamp and the
# hyperparameters, so they are built once; every epoch overwrites the same
# files with the latest weights and the complete history collected so far.
netD_ckpt_path = f'../nets/{hyperparams.get_network_params_name(f"dcgan_netD_{timestamp}")}'
netG_ckpt_path = f'../nets/{hyperparams.get_network_params_name(f"dcgan_netG_{timestamp}")}'
D_loss_path = f'../loss/{hyperparams.get_network_params_name(f"dcgan_netD_{timestamp}")}.txt'
G_loss_path = f'../loss/{hyperparams.get_network_params_name(f"dcgan_netG_{timestamp}")}.txt'
D_real_out_path = f'../mean_out/{hyperparams.get_network_params_name(f"dcgan_netD_real_{timestamp}")}.txt'
D_fake_out_path = f'../mean_out/{hyperparams.get_network_params_name(f"dcgan_netD_fake_{timestamp}")}.txt'

# Create output directories before any training so a bad path or missing
# permission fails immediately rather than after the first epoch.
# exist_ok=True makes this idempotent and avoids the exists()/mkdir() race.
for out_dir in ('../nets', '../loss', '../mean_out'):
    os.makedirs(out_dir, exist_ok=True)

# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch. Assumes the first element of each batch is the image
        # tensor (e.g. an (images, labels) pair) -- TODO confirm against what
        # LSUNDataset yields; if it yields bare tensors, data[0] is one image.
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D; detach() keeps this backward pass out of G
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches (used for
        # logging; the gradients already came from the two backward() calls)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G. This also writes gradients into netD's
        # parameters; they are never stepped and are cleared by netD.zero_grad()
        # at the start of the next iteration.
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Save mean scores for plotting later (D_G_z1 is D's output on fakes before its update)
        D_real_mean_out.append(D_x)
        D_fake_mean_out.append(D_G_z1)

        iters += 1

    # End of epoch: checkpoint both networks and dump the metric histories.
    # Under nn.DataParallel (ngpu > 1) the state_dict keys carry a 'module.' prefix.
    torch.save(netD.state_dict(), netD_ckpt_path)
    torch.save(netG.state_dict(), netG_ckpt_path)

    np.savetxt(D_loss_path, np.array(D_losses))
    np.savetxt(G_loss_path, np.array(G_losses))

    np.savetxt(D_real_out_path, np.array(D_real_mean_out))
    np.savetxt(D_fake_out_path, np.array(D_fake_mean_out))