In [1]:
import torch
from   torch import optim
import torchvision.datasets   as datasets
import torchvision.transforms as transforms
import torchvision.utils      as vutils
from   torch.utils.data import DataLoader

import matplotlib.pyplot      as plt
import numpy                  as np

from   model import Generator
from   model import Critic
from   tqdm  import tqdm
from   math  import log2

from   IPython.display import clear_output

In [2]:
# --- Paths -------------------------------------------------------------------
DATA_DIR        = "../data/flikr_hq/"   # dataset root handed to ImageFolder in get_loader
GEN_CHECKPOINT  = "gen.pth"             # generator checkpoint file (saved / loaded)
CRIT_CHECKPOINT = "crit.pth"            # critic checkpoint file (saved / loaded)

# --- Run control -------------------------------------------------------------
DEVICE          = "cuda"                # device the models and batches are moved to
SAVE_MODEL      = True                  # write checkpoints at the end of every epoch
LOAD_MODEL      = False                 # restore checkpoints before training starts

# --- Progressive-growing schedule --------------------------------------------
START_IMG_SIZE  = 4                     # first resolution; stage index = log2(size / 4)
LEARNING_RATE   = 1e-3                  # Adam learning rate for both networks
BATCH_SIZES     = [512, 512, 256, 128, 64, 32]  # one entry per stage: 4, 8, 16, 32, 64, 128 px
IMG_CHANNELS    = 3                     # channels used for normalisation and model construction
Z_DIM           = 512                   # channel size of the generator's noise input
IN_CHANNELS     = 512                   # passed to both Generator and Critic constructors
LAMBDA_GP       = 10                    # weight of the gradient penalty in the critic loss
PROG_EPOCHS     = [24] * len(BATCH_SIZES)  # epochs trained at each stage
NUM_WORKERS     = 24                    # DataLoader worker processes

# Let cuDNN benchmark convolution algorithms for the input shapes it sees.
# The flag is spelled `benchmark`; the previous `benchmarks` assignment merely
# created an unused attribute, so the autotuner was never actually switched on.
torch.backends.cudnn.benchmark = True

In [3]:
def gradient_penalty(critic, real, fake, alpha, train_step, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Each real image is blended with its fake counterpart using an independent
    uniform coefficient, the critic scores the blend, and the penalty is the
    mean squared deviation of the per-sample gradient L2 norm from 1.

    Args:
        critic: Callable invoked as ``critic(images, alpha, train_step)``.
        real: 4-D batch of real images ``(N, C, H, W)``.
        fake: Batch of generated images, broadcast-compatible with ``real``.
            It is detached here, so the penalty never back-propagates into
            whatever produced it.
        alpha: Forwarded unchanged to ``critic``.
        train_step: Forwarded unchanged to ``critic``.
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor holding the penalty. It stays attached to the autograd
        graph (``create_graph=True``) so it can be part of the critic loss.
    """
    # Unpacking four dimensions keeps the original requirement that `real` is 4-D.
    batch_size, _channels, _height, _width = real.shape

    # One coefficient per sample, created directly on `device`; broadcasting
    # spreads it over C/H/W, so no full-size tensor is built on the host and
    # copied across on every call.
    beta = torch.rand((batch_size, 1, 1, 1), device=device)
    mixed_imgs = real * beta + fake.detach() * (1 - beta)
    mixed_imgs.requires_grad_(True)
    mixed_scores = critic(mixed_imgs, alpha, train_step)

    gradient = torch.autograd.grad(
        inputs=mixed_imgs,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # reshape (not view): the returned gradient is not guaranteed to be contiguous.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)

    return torch.mean((gradient_norm - 1) ** 2)


In [4]:
def save_checkpoint(model, optim, filename="progan_ckpnt.pth"):
    """Serialise a model and its optimizer to ``filename`` with ``torch.save``.

    The file holds a dict with two entries: ``"state_dict"`` (the model's
    ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).

    Note: inside this function the ``optim`` parameter shadows the
    ``torch.optim`` module imported at the top of the notebook.
    """
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optim.state_dict()),
        filename,
    )


def load_checkpoint(filename, model, optim, lr):
    """Restore ``model`` and ``optim`` in place from a checkpoint file.

    The checkpoint is expected to be a dict with ``"state_dict"`` and
    ``"optimizer"`` entries, as written by ``save_checkpoint``.

    Args:
        filename: Path of the checkpoint file.
        model: Module whose weights are overwritten.
        optim: Optimizer whose state is overwritten (the name shadows the
            ``torch.optim`` module inside this function).
        lr: Learning rate written into every parameter group after loading,
            replacing whatever value was stored in the checkpoint.
    """
    # Load tensors onto the device the model already lives on instead of a
    # hard-coded "cuda", so the function also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else "cpu"

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(filename, map_location=target_device)
    model.load_state_dict(checkpoint["state_dict"])
    optim.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state restores its saved learning rate; override it
    # with the caller's value.
    for group in optim.param_groups:
        group["lr"] = lr

In [5]:
def get_loader(img_size):
    """Build the dataset and DataLoader for one progressive-growing resolution.

    Args:
        img_size: Target square resolution. Must be ``4 * 2 ** k`` for a stage
            index ``k`` that has an entry in ``BATCH_SIZES``.

    Returns:
        ``(loader, dataset)`` - a shuffling ``DataLoader`` over an
        ``ImageFolder`` rooted at ``DATA_DIR``, and the dataset itself.

    Raises:
        ValueError: If ``img_size`` does not correspond to a configured stage.
    """
    # Stage index: 4 -> 0, 8 -> 1, 16 -> 2, ... Validate explicitly, because a
    # too-small size would otherwise produce a negative index that silently
    # wraps around BATCH_SIZES, and a non-power-of-two would be truncated.
    stage = int(log2(img_size / 4)) if img_size >= 4 else -1
    if stage < 0 or 4 * 2 ** stage != img_size or stage >= len(BATCH_SIZES):
        raise ValueError(
            f"img_size must be 4 * 2**k with 0 <= k < {len(BATCH_SIZES)}, got {img_size}"
        )

    transform = transforms.Compose([
        # Pass (h, w) explicitly: Resize(int) only matches the shorter edge and
        # would leave non-square inputs non-square.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        # Mean 0.5 / std 0.5 per channel.
        transforms.Normalize(
            [0.5 for _ in range(IMG_CHANNELS)],
            [0.5 for _ in range(IMG_CHANNELS)]
        )
    ])

    dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZES[stage],
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    return loader, dataset

In [6]:
# Build both networks and move them to the training device.
gen = Generator(Z_DIM, IN_CHANNELS, IMG_CHANNELS).to(DEVICE)
crit = Critic(IN_CHANNELS, IMG_CHANNELS).to(DEVICE)

# Separate Adam optimizers for generator and critic, same learning rate and betas.
gen_optim = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
crit_optim = optim.Adam(crit.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

# One AMP gradient scaler per optimizer (the variables are named "scalar" but
# hold GradScaler instances).
gen_scalar = torch.cuda.amp.GradScaler()
crit_scalar = torch.cuda.amp.GradScaler()

# Optionally restore weights and optimizer state; load_checkpoint also resets
# every parameter group's learning rate to LEARNING_RATE.
# NOTE(review): the checkpoints written below contain only model and optimizer
# state, so a resumed run still starts at START_IMG_SIZE with alpha = 0.
if LOAD_MODEL:
    print("### Loading checkpoints ###")
    load_checkpoint(GEN_CHECKPOINT, gen, gen_optim, LEARNING_RATE)
    load_checkpoint(CRIT_CHECKPOINT, crit, crit_optim, LEARNING_RATE)

gen.train()
crit.train()

# Stage index of the starting resolution: 4 px -> 0, 8 px -> 1, ...
step = int(log2(START_IMG_SIZE / 4))

# One outer iteration per resolution stage, starting from `step`.
for epochs in PROG_EPOCHS[step:]:
    # alpha is passed to both networks and ramps from 0 to 1 during the stage;
    # it is reset at the start of every stage.
    alpha = 0
    # Rebuild the loader at the current resolution (4 * 2**step pixels).
    loader, dataset = get_loader(4 * 2 ** step)
    clear_output()
    print(f"\nImage size: {4 * 2 ** step}")

    for epoch in range(epochs):
        print(f"Epoch [{epoch + 1}/{epochs}]")

        loop = tqdm(loader, leave=True)
        for batch_idx, (real, _) in enumerate(loop):
            real = real.to(DEVICE)
            cur_batch_size = real.shape[0]

            # ---- Critic update ----
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)

            # NOTE(review): gradient_penalty calls torch.autograd.grad inside
            # autocast on an unscaled output; PyTorch's AMP examples treat
            # gradient penalties specially - confirm this is numerically safe.
            with torch.cuda.amp.autocast():
                fake = gen(noise, alpha, step)
                real_pred = crit(real, alpha, step)
                # Detached so the critic loss does not back-propagate into the generator.
                fake_pred = crit(fake.detach(), alpha, step)
                gp = gradient_penalty(crit, real, fake, alpha, step, DEVICE)
                # Negated (mean real score - mean fake score), plus the weighted
                # gradient penalty, plus a small penalty on squared real scores.
                crit_loss = -(torch.mean(real_pred) - torch.mean(fake_pred)) + LAMBDA_GP * gp + (0.001 * torch.mean(real_pred ** 2))

            # zero_grad here also clears the critic gradients left behind by
            # the previous iteration's generator backward pass.
            crit_optim.zero_grad()
            crit_scalar.scale(crit_loss).backward()
            crit_scalar.step(crit_optim)
            crit_scalar.update()

            # ---- Generator update ----
            # Re-score the same (non-detached) fake batch with the critic that
            # was just updated; the generator minimises the negated mean score.
            with torch.cuda.amp.autocast():
                gen_fake = crit(fake, alpha, step)
                gen_loss = -torch.mean(gen_fake)

            gen_optim.zero_grad()
            gen_scalar.scale(gen_loss).backward()
            gen_scalar.step(gen_optim)
            gen_scalar.update()

            # Linear ramp: each batch adds batch_size / (0.5 * stage_epochs * dataset_size),
            # so alpha reaches 1 after roughly half of the stage's epochs, then is clamped.
            if alpha < 1:
                alpha += cur_batch_size / ((PROG_EPOCHS[step] * 0.5) * len(dataset))
                alpha = min(alpha, 1)

            # Show the latest losses and alpha on the progress bar.
            loop.set_postfix(
                gp=gp.item(), 
                crit_loss=crit_loss.item(), 
                gen_loss=gen_loss.item(),
                alpha=alpha
            )

            
        # End of epoch: overwrite the checkpoint files with the latest state.
        if SAVE_MODEL:
            print("### Saving checkpoints ###")
            save_checkpoint(gen, gen_optim, filename=GEN_CHECKPOINT)
            save_checkpoint(crit, crit_optim, filename=CRIT_CHECKPOINT)

        # Preview up to 16 images from the last generated batch of this epoch.
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Sample Generation")
        plt.imshow(np.transpose(vutils.make_grid(fake.to(DEVICE)[:16], padding=2, normalize=True).cpu(), (1, 2, 0)))
        plt.show()

    # Move on to the next (doubled) resolution.
    step += 1


Image size: 4
Epoch [1/24]


 88%|████████▊ | 120/137 [01:17<00:10,  1.55it/s, alpha=0.0731, crit_loss=-1.67, gen_loss=1.75, gp=0.0254]


KeyboardInterrupt: 