In [1]:
import os
import time
import gc

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as dset
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import torchvision.utils as vutils

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from gen_model import Generator
from disc_model import Discriminator
from utils import *
from step_assert import AssertStep

In [2]:
# Reproducibility: seed every RNG library used in this notebook.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name only accepts CUDA devices, so report the plain
# device string when we fell back to the CPU instead of crashing.
device_name = torch.cuda.get_device_name(device) if device.type == "cuda" else str(device)
print(f"Current device: {device_name}")

Current device: GeForce GTX 1080


In [3]:
# training parameters
# Per-step schedules: index `step` selects the entry used for the current phase.
res_list = [2 ** k for k in range(2, 7)]       # output resolution per step: 4, 8, 16, 32, 64
batch_size = [256, 256, 128, 64, 64]           # batch size per step
phase_size = [800_000, 1_600_000, 2_000_000, 2_500_000, 3_500_000]  # samples to consume before growing
step = 1                                       # starting index into the per-step lists
max_steps = 4                                  # index of the last step (res_list[4] == 64)
fade_size = 800_000                            # samples over which alpha ramps from ~0 to 1

# Optimisation
noise_dim = 256                                # length of the latent vector fed to the generator
lr = 0.005
betas = (0.0, 0.99)                            # Adam (beta1, beta2)
GRAD_VAL = 10                                  # weight of the gradient-penalty term in the critic loss
disc_train_count = 5                           # generator is updated once every N discriminator steps

In [4]:
# Model definitions: each network is moved to `device`, initialised with
# `weights_init`, and paired with its own Adam optimizer (shared lr/betas).
def _init_with_adam(model):
    """Move `model` to `device`, apply `weights_init`, and return (model, Adam optimizer)."""
    model = model.to(device)
    model.apply(weights_init)
    return model, optim.Adam(model.parameters(), lr=lr, betas=betas)


generator, g_optimizer = _init_with_adam(Generator(noise_dim))
discriminator, d_optimizer = _init_with_adam(Discriminator())

In [5]:
# Report the trainable generator weight count, then display the module tree.
g_pr = large_num_period(count_parameters(generator))
print("Trainable Generator parameters:", g_pr)
generator

Trainable Generator parameters: 2.229.392


Generator(
  (tanh): Tanh()
  (inp): GenBlock(
    (conv1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky1): LeakyReLU(negative_slope=0.2)
    (conv2): ConvTranspose2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky2): LeakyReLU(negative_slope=0.2)
  )
  (conv1): GenBlock(
    (upsample): Upsample(scale_factor=2.0, mode=nearest)
    (conv1): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (leaky1): LeakyReLU(negative_slope=0.2)
    (conv2): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track

In [6]:
# Report the trainable discriminator weight count, then display the module tree.
d_pr = large_num_period(count_parameters(discriminator))
print("Trainable Discriminator parameters:", d_pr)
discriminator

Trainable Discriminator parameters: 2.229.968


Discriminator(
  (conv1): DiscBlock(
    (leaky): LeakyReLU(negative_slope=0.2)
    (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv2): DiscBlock(
    (leaky): LeakyReLU(negative_slope=0.2)
    (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (pool): AvgPool2d(kernel_size=2, 

In [7]:
# Fixed latents so preview grids are comparable across iterations.
preview_noise = torch.randn(64, noise_dim).to(device)


def generate_and_save_images(iteration, folder=r"C:\Users\Johnny\Desktop\PROGAN\intermediate_images"):
    """Render the fixed ``preview_noise`` batch as an image grid and save it to disk.

    Uses the current global ``step`` / ``alpha`` so the samples match the
    resolution being trained. The file name encodes ``iteration`` and the
    current resolution.

    Args:
        iteration: training iteration; only used in the output file name.
        folder: output directory, created if missing (defaults to the original location).
    """
    os.makedirs(folder, exist_ok=True)
    res = res_list[step]
    fake_img_path = os.path.join(folder, f"iteration{iteration}resolution{res}x{res}")
    with torch.no_grad():
        images = generator(preview_noise, step=step, alpha=alpha).detach().cpu()
    # make_grid returns (C, H, W); matplotlib's imshow expects (H, W, C).
    grid = vutils.make_grid(images, padding=2, normalize=True).permute(1, 2, 0).numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.set_title("Generated Images")
    ax.imshow(grid)
    fig.savefig(fake_img_path)
    plt.close(fig)  # avoid accumulating open figures during long training runs
    
def generate_final_images(num=128, folder=r"C:\Users\Johnny\Desktop\PROGAN\final_images"):
    """Sample ``num`` images from the generator and save each as ``image{i}.jpg``.

    Uses the current global ``step`` / ``alpha``. Existing files with the same
    names in ``folder`` are overwritten.

    Args:
        num: number of images to sample.
        folder: output directory, created if missing (defaults to the original location).
    """
    os.makedirs(folder, exist_ok=True)
    with torch.no_grad():
        noise = torch.randn(num, noise_dim).to(device)
        images = generator(noise, step=step, alpha=alpha).detach().cpu()
    # Same mapping as save_image(..., normalize=True, range=(-1, 1)): clamp to
    # [-1, 1] then rescale to [0, 1]. Done explicitly because torchvision
    # deprecated and later removed the `range` keyword (now `value_range`).
    images = (images.clamp(-1, 1) + 1) / 2
    for i, image in enumerate(images):
        save_image(image, os.path.join(folder, f"image{i}.jpg"))
        
# Fixed latents so the metric below is comparable over the course of training.
val_noise = torch.randn(256, noise_dim).to(device)


def calcMFID(real_images):
    """Return ``mean D(real_images) - mean D(G(val_noise))`` as a Python float.

    Despite the name this is not a Frechet Inception Distance: it is the
    critic's score gap between a real batch and samples from the fixed
    ``val_noise`` latents, evaluated at the current global ``step`` / ``alpha``
    with gradient tracking disabled.

    Args:
        real_images: a batch of real images already on the model's device.
    """
    with torch.no_grad():
        real_score = discriminator(real_images, step=step, alpha=alpha).mean()
        fake_images = generator(val_noise, step=step, alpha=alpha)
        fake_score = discriminator(fake_images, step=step, alpha=alpha).mean()
    return (real_score - fake_score).item()

In [8]:
def train(iterations=10_000_000, log_every=5_000):
    """Run the progressive-growing WGAN-GP training loop.

    Starts at the current global ``step``. The discriminator is updated on every
    loop iteration; the generator is updated on every ``disc_train_count``-th
    iteration, and only those iterations count towards ``used_samples``.
    Once ``used_samples`` exceeds ``phase_size[step]``, both networks are
    pickled to ``generator.pt`` / ``discriminator.pt``, a preview grid is
    saved, and training moves to the next entry of ``res_list``. This means a
    new dataloader and batch size are used, and ``adjust_lr(..., lr)`` is called
    on both optimizers. Training ends after the phase at ``max_steps``
    completes, or after ``iterations`` loop iterations, whichever comes first.

    Side effects: rebinds the globals ``step`` and ``alpha`` (read by the image
    helpers), updates both networks and optimizers in place, writes checkpoints
    and preview images, and prints progress.

    Args:
        iterations: upper bound on loop iterations (discriminator updates).
        log_every: period, in iterations, of metric logging / preview images.
    """
    global step, alpha
    # NOTE: g_losses records the critic gap D(real) - D(fake), not the generator loss.
    g_losses = []
    d_losses = []
    data = new_dataloader(batch_size[step], res_list[step])
    used_samples = 0
    start = time.time()
    print("Starting Training Loop...")
    # Fixed real batch used by calcMFID for the whole phase.
    mfid_batch = next(iter(data))[0].to(device).float()
    loader = iter(data)
    step_start = time.time()
    for current_iteration in range(iterations):
        if used_samples > phase_size[step]:
            # Phase finished: checkpoint, report, and grow to the next resolution.
            torch.save(generator, "generator.pt")
            torch.save(discriminator, "discriminator.pt")
            step_time_taken = (time.time() - step_start) // 60
            step_start = time.time()
            print(f"Time taken for resolution {res_list[step]}x{res_list[step]} is {step_time_taken} minutes, Used Samples: {used_samples}, loss_MFID: {calcMFID(mfid_batch)}")
            print()
            generate_and_save_images(current_iteration)
            used_samples = 0
            step += 1
            if step > max_steps:
                step = max_steps
                break
            adjust_lr(d_optimizer, lr)
            adjust_lr(g_optimizer, lr)
            data = new_dataloader(batch_size[step], res_list[step])
            mfid_batch = next(iter(data))[0].to(device).float()
            loader = iter(data)

        # Fade-in weight for the newest block, ramping with samples seen in this
        # phase; step <= 1 trains without fading.
        alpha = min(1, (used_samples + 1) / fade_size) if step > 1 else 1
        try:
            batch = next(loader)
        except StopIteration:
            # Dataset exhausted: start another pass.
            loader = iter(data)
            batch = next(loader)

        real_images = batch[0].to(device).float()
        current_bs = real_images.shape[0]
        real_predict = discriminator(real_images, step=step, alpha=alpha).mean()

        # Random noise vector sampled from a standard gaussian.
        gen_noise = torch.randn(current_bs, noise_dim).to(device)
        gen_imgs = generator(gen_noise, step=step, alpha=alpha)
        fake_predict = discriminator(gen_imgs, step=step, alpha=alpha).mean()

        # WGAN-GP critic loss: maximise D(real) - D(fake) (Wasserstein estimate)
        # plus the weighted gradient penalty.
        gp = gradient_penalty(discriminator, real_images, gen_imgs, step, device=device, alpha=alpha)
        disc_loss = -(real_predict - fake_predict) + (GRAD_VAL * gp)

        train_generator = current_iteration % disc_train_count == 0

        discriminator.zero_grad()
        # gen_imgs is not detached, so when the generator is updated below from
        # the same gen_imgs, its part of the graph must survive this backward.
        # NOTE(review): detaching gen_imgs for the critic pass would avoid
        # backpropagating into the generator here, but only if gradient_penalty
        # does not rely on the fake batch requiring grad -- confirm first.
        disc_loss.backward(retain_graph=train_generator)
        d_optimizer.step()

        if train_generator:
            used_samples += current_bs
            # Fresh critic pass with the just-updated discriminator.
            gen_predict = discriminator(gen_imgs, step=step, alpha=alpha).mean()
            # Generator loss: maximise D(fake).
            gen_loss = -gen_predict
            # Also discards the gradients disc_loss.backward() pushed into the generator.
            generator.zero_grad()
            gen_loss.backward()
            g_optimizer.step()

        if current_iteration % log_every == 0:
            mfid = calcMFID(mfid_batch)
            g_losses.append((real_predict - fake_predict).detach().item())
            d_losses.append(disc_loss.detach().item())
            generate_and_save_images(current_iteration)
            plot_losses(g_losses, d_losses)
            samples = large_num_period(used_samples)
            iter_nr = large_num_period(current_iteration)
            print(f"[{iter_nr}] Resolution: {res_list[step]}x{res_list[step]}, loss_MFID: {mfid}, Samples: {samples}, alpha: {alpha}, Time: {(time.time()-start) // 60} minutes")
    print(f"Training took {(time.time()-start) // 60} minutes.")

In [9]:
# Run the full progressive-growing schedule; train() also checkpoints at each phase end.
train()

# Persist the final networks as whole-module pickles.
torch.save(generator, "generator.pt")
torch.save(discriminator, "discriminator.pt")

# Write the default-sized batch of samples from the final generator.
generate_final_images(num=128)

Starting Training Loop...
[0] Resolution: 8x8, loss_MFID: -6.335700988769531, Samples: 256, alpha: 1, Time: 0.0 minutes
[5.000] Resolution: 8x8, loss_MFID: -0.6701259613037109, Samples: 256.103, alpha: 1, Time: 12.0 minutes
[10.000] Resolution: 8x8, loss_MFID: -0.18227386474609375, Samples: 511.950, alpha: 1, Time: 26.0 minutes
[15.000] Resolution: 8x8, loss_MFID: 0.06898641586303711, Samples: 767.644, alpha: 1, Time: 39.0 minutes
[20.000] Resolution: 8x8, loss_MFID: 0.05788230895996094, Samples: 1.023.491, alpha: 1, Time: 52.0 minutes
[25.000] Resolution: 8x8, loss_MFID: -0.19026947021484375, Samples: 1.279.338, alpha: 1, Time: 66.0 minutes
[30.000] Resolution: 8x8, loss_MFID: 0.04847431182861328, Samples: 1.535.185, alpha: 1, Time: 79.0 minutes
Time taken for resolution 8x8 is 82.0 minutes, Used Samples: 1600056, loss_MFID: 0.005260467529296875

[35.000] Resolution: 16x16, loss_MFID: 1.9728317260742188, Samples: 95.488, alpha: 0.11920125, Time: 91.0 minutes
[40.000] Resolution: 16x16

[355.000] Resolution: 64x64, loss_MFID: 1.4723968505859375, Samples: 643.189, alpha: 0.8039075, Time: 853.0 minutes
[360.000] Resolution: 64x64, loss_MFID: 0.949920654296875, Samples: 707.189, alpha: 0.8839075, Time: 872.0 minutes
[365.000] Resolution: 64x64, loss_MFID: 1.0304718017578125, Samples: 771.189, alpha: 0.9639075, Time: 891.0 minutes
[370.000] Resolution: 64x64, loss_MFID: 1.9460296630859375, Samples: 835.164, alpha: 1, Time: 909.0 minutes
[375.000] Resolution: 64x64, loss_MFID: 1.6615447998046875, Samples: 899.164, alpha: 1, Time: 927.0 minutes
[380.000] Resolution: 64x64, loss_MFID: 1.8035888671875, Samples: 963.164, alpha: 1, Time: 945.0 minutes
[385.000] Resolution: 64x64, loss_MFID: 2.1310272216796875, Samples: 1.027.139, alpha: 1, Time: 963.0 minutes
[390.000] Resolution: 64x64, loss_MFID: 1.4591064453125, Samples: 1.091.139, alpha: 1, Time: 981.0 minutes
[395.000] Resolution: 64x64, loss_MFID: 1.471527099609375, Samples: 1.155.139, alpha: 1, Time: 999.0 minutes
[400.0

In [10]:
# Larger final sample dump; reuses the image{i}.jpg names, so the first 128 files are overwritten.
generate_final_images(num=1000)

In [12]:

def generate_and_save_images2(folder=r"C:\Users\Johnny\Desktop\PROGAN"):
    """Save a grid of 64 samples from fresh random latents at the current resolution.

    Unlike ``generate_and_save_images``, this draws new noise on every call, and
    the file name only encodes the resolution (from the global ``step``).

    Args:
        folder: output directory, created if missing. The original literal
            ``r"...PROGAN\""`` embedded a stray double quote in the path, which
            is not a valid Windows file name character.
    """
    sample_noise = torch.randn(64, noise_dim).to(device)
    os.makedirs(folder, exist_ok=True)
    res = res_list[step]
    fake_img_path = os.path.join(folder, f"resolution{res}x{res}")
    with torch.no_grad():
        images = generator(sample_noise, step=step, alpha=alpha).detach().cpu()
    # make_grid returns (C, H, W); matplotlib's imshow expects (H, W, C).
    grid = vutils.make_grid(images, padding=2, normalize=True).permute(1, 2, 0).numpy()
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.axis("off")
    ax.set_title("Generated Images")
    ax.imshow(grid)
    fig.savefig(fake_img_path)
    plt.close(fig)


generate_and_save_images2()

NameError: name 'iteration' is not defined