In [1]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
!unzip -qq /content/drive/MyDrive/celeba.zip

In [71]:
DATASET                 = "/content/img_align_celeba"
START_TRAIN_AT_IMG_SIZE = 8 #The authors start from 8x8 images instead of 4x4
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
BATCH_SIZES             = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG            = 3
Z_DIM                   = 256
W_DIM                   = 256
IN_CHANNELS             = 256
LAMBDA_GP               = 10
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)

In [72]:
def get_loader(image_size):
  transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(CHANNELS_IMG)],
                [0.5 for _ in range(CHANNELS_IMG)],
            ),
        ]
    )
  batch_size = BATCH_SIZES[int(log2(image_size / 4))]
  dataset = datasets.ImageFolder(root=DATASET, transform=transform)
  loader = DataLoader(
      dataset,
      batch_size=batch_size,
      shuffle=True
  )
  return loader, dataset

In [73]:
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

In [74]:
class WSLinear(nn.Module):
  def __init__(self, in_features, out_features):
    super(WSLinear, self).__init__()
    self.linear = nn.Linear(in_features, out_features)
    self.scale = (2 / in_features) ** 0.5
    self.bias = self.linear.bias
    self.linear.bias = None

    nn.init.normal_(self.linear.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    return self.linear(x * self.scale) + self.bias

In [75]:
class MappingNetwork(nn.Module):
  def __init__(self, z_dim, w_dim):
    super(MappingNetwork, self).__init__()
    self.mapping = nn.Sequential(
        PixelNorm(),
        WSLinear(z_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
    )

  def forward(self, x):
    return self.mapping(x)

In [76]:
class AdaIN(nn.Module):
  def __init__(self, channels, w_dim):
    super(AdaIN, self).__init__()
    self.instance_norm = nn.InstanceNorm2d(channels)
    self.style_scale = WSLinear(w_dim, channels)
    self.style_bias = WSLinear(w_dim, channels)

  def forward(self, x, w):
    x = self.instance_norm(x)
    style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
    style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
    return style_scale * x + style_bias

In [77]:
class InjectNoise(nn.Module):
  def __init__(self, channels):
    super(InjectNoise, self).__init__()
    self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

  def forward(self, x):
    noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
    return x + self.weight * noise

In [78]:
class WSConv2d(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    super(WSConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
    self.bias = self.conv.bias
    self.conv.bias = None

    nn.init.normal_(self.conv.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)

In [79]:
class PixelNorm(nn.Module):
  def __init__(self):
    super(PixelNorm, self).__init__()
    self.epsilon = 1e-8

  def forward(self, x):
    return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)

In [80]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ConvBlock, self).__init__()
    self.conv1 = WSConv2d(in_channels, out_channels)
    self.conv2 = WSConv2d(in_channels, out_channels)
    self.lrelu = nn.LeakyReLU(0.1)

  def forward(self, x):
    x = self.lrelu(self.conv1(x))
    x = self.lrelu(self.conv2(x))
    return x

In [81]:
class Discriminator(nn.Module):
  def __init__(self, in_channels, img_channels=3):
    super(Discriminator, self).__init__()
    self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
    self.lrelu = nn.LeakyReLU(0.2)

    for i in range(len(factors) - 1, 0, -1):
      conv_in = int(in_channels * factors[i])
      conv_out = int(in_channels * factors[i - 1])
      self.prog_blocks.append(ConvBlock(conv_in, conv_out))
      self.rgb_layers.append(WSConv2d(img_channels, conv_in, 1, 1, 0))

    self.initial_rgb = WSConv2d(img_channels, in_channels, 1, 1, 0)
    self.rgb_layers.append(self.initial_rgb)
    self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

    self.final_block = nn.Sequential(
        WSConv2d(in_channels+1, in_channels, kernel_size=3, padding=1),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, in_channels, kernel_size=4, stride=1, padding=0),
        nn.LeakyReLU(0.2),
        WSConv2d(in_channels, 1, 1, 1, 0)
    )

  def fade_in(self, alpha, downscaled, out):
    return alpha * out + (1 - alpha) * downscaled

  def minibatch_std(self, x):
    batch_statistics = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
    return torch.cat([x, batch_statistics], dim=1)

  def forward(self, x, alpha, steps):
    cur_step = len(self.prog_blocks) - steps
    out = self.lrelu(self.rgb_layers[cur_step](x))
    if steps == 0:
      out = self.minibatch_std(out)
      return self.final_block(out).view(out.shape[0], -1)
    downscaled = self.lrelu(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
    out = self.avg_pool(self.prog_blocks[cur_step](out))

    out = self.fade_in(alpha, downscaled, out)

    for step in range(cur_step + 1, len(self.prog_blocks)):
      out = self.prog_blocks[step](out)
      out = self.avg_pool(out)

    out = self.minibatch_std(out)
    return self.final_block(out).view(out.shape[0], -1)

In [82]:
class GenBlock(nn.Module):
  def __init__(self, in_channels, out_channels, w_dim):
    super(GenBlock, self).__init__()
    self.conv1 = WSConv2d(in_channels, out_channels)
    self.conv2 = WSConv2d(in_channels, out_channels)
    self.lrelu = nn.LeakyReLU(0.2, inplace=True)
    self.inject_noise1 = InjectNoise(out_channels)
    self.inject_noise2 = InjectNoise(out_channels)
    self.adain1 = AdaIN(out_channels, w_dim)
    self.adain2 = AdaIN(out_channels, w_dim)

  def forward(self, x, w):
    x = self.adain1(self.lrelu(self.inject_noise1(self.conv1(x))), w)
    x = self.adain2(self.lrelu(self.inject_noise2(self.conv2(x))), w)
    return x

In [95]:
class Generator(nn.Module):
  def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
    super(Generator, self).__init__()
    self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
    self.map = MappingNetwork(z_dim, w_dim)
    self.initial_adain1 = AdaIN(in_channels, w_dim)
    self.initial_adain2 = AdaIN(in_channels, w_dim)
    self.initial_noise1 = InjectNoise(in_channels)
    self.initial_noise2 = InjectNoise(in_channels)
    self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
    self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    self.initial_rgb = WSConv2d(in_channels, img_channels, 1, 1, 0)
    self.prog_blocks, self.rgb_layers = (
        nn.ModuleList([]),
        nn.ModuleList([self.initial_rgb])
    )

    for i in range(len(factors) - 1):
      conv_in_c = int(in_channels * factors[i])
      conv_out_c = int(in_channels * factors[i + 1])
      self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
      self.rgb_layers.append(
          WSConv2d(conv_out_c, img_channels, 1, 1, 0)
      )

  def fade_in(self, alpha, upscaled, generated):
    return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

  def forward(self, noise, alpha, steps):
    w = self.map(noise)
    x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
    x = self.initial_conv(x)
    out = self.initial_adain2(self.lrelu(self.initial_noise2(x)), w)

    if steps == 0:
      return self.initial_rgb(x)

    for step in range(steps):
      upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
      out = self.prog_blocks[step](upscaled, w)

    final_upscaled = self.rgb_layers[steps - 1](upscaled)
    final_out = self.rgb_layers[steps](out)
    return self.fade_in(alpha, final_upscaled, final_out)


In [96]:
def generate_examples(gen, steps, n=100):
  gen.eval()
  alpha = 1.0
  for i in range(n):
    with torch.no_grad():
      noise = torch.randn(1, Z_DIM).to(DEVICE)
      img = gen(noise, alpha, steps)
      if not os.path.exists(f'saved_examples/step{steps}'):
          os.makedirs(f'saved_examples/step{steps}')
      save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
  gen.train()

In [97]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
  BATCH_SIZE, C, H, W = real.shape
  beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
  interpolated_imgs = real * beta + fake.detach() * (1 - beta)
  interpolated_imgs.requires_grad_(True)

  mixed_scores = critic(interpolated_imgs, alpha, train_step)

  gradient = torch.autograd.grad(
      inputs=interpolated_imgs,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True
  )[0]
  gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim=1)
  grad_penalty = torch.mean((gradient_norm - 1) ** 2)
  return grad_penalty

In [126]:
def train_fn(critic, gen, loader, dataset, step, alpha, opt_critic, opt_gen):
  loop = tqdm(loader, leave=True)
  for batch_idx, (real, _) in enumerate(loop):
    real = real.to(DEVICE)
    cur_batch_size = real.shape[0]

    noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
    fake = gen(noise, alpha, step)
    fake_ = fake.detach()
    critic_real = critic(real, alpha, step)
    critic_fake = critic(fake_, alpha, step)
    gp = gradient_penalty(critic, real, fake_, alpha, step, DEVICE)
    loss_critic = (
        -(torch.mean(critic_real) - torch.mean(critic_fake))
        + LAMBDA_GP * gp
        + (0.001 * torch.mean(critic_real ** 2))
    )

    critic.zero_grad()
    loss_critic.backward()
    opt_critic.step()

    gen_fake = critic(fake, alpha, step)
    loss_gen = -torch.mean(gen_fake)

    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    alpha += cur_batch_size / ((PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset))
    alpha = min(alpha, 1)

    loop.set_postfix(gp=gp.item(), loss_critic=loss_critic.item())

  return alpha

In [None]:
gen = Generator(Z_DIM, W_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
critic = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
                        {"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

gen.train()
critic.train()

step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-5   # start with very low alpha
    loader, dataset = get_loader(4 * 2 ** step)
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen
        )

    generate_examples(gen, step)
    step += 1  # progress to the next img size

Current image size: 8
Epoch [1/30]


100%|██████████| 1583/1583 [05:12<00:00,  5.07it/s, gp=0.276, loss_critic=-9.13]


Epoch [2/30]


100%|██████████| 1583/1583 [05:12<00:00,  5.06it/s, gp=0.0409, loss_critic=-2.57]


Epoch [3/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.0388, loss_critic=-.872]


Epoch [4/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.0513, loss_critic=0.015]


Epoch [5/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.11, loss_critic=0.601]


Epoch [6/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.129, loss_critic=1.15]


Epoch [7/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.142, loss_critic=2.2]


Epoch [8/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.121, loss_critic=0.93]


Epoch [9/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.04it/s, gp=0.115, loss_critic=1.68]


Epoch [10/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.0936, loss_critic=1.17]


Epoch [11/30]


100%|██████████| 1583/1583 [05:12<00:00,  5.06it/s, gp=0.0889, loss_critic=0.766]


Epoch [12/30]


100%|██████████| 1583/1583 [05:13<00:00,  5.05it/s, gp=0.104, loss_critic=1.15]


Epoch [13/30]


100%|██████████| 1583/1583 [05:14<00:00,  5.03it/s, gp=0.096, loss_critic=0.586]


Epoch [14/30]


100%|██████████| 1583/1583 [05:15<00:00,  5.02it/s, gp=0.0873, loss_critic=1.45]


Epoch [15/30]


100%|██████████| 1583/1583 [05:14<00:00,  5.03it/s, gp=0.109, loss_critic=1.38]


Epoch [16/30]


100%|██████████| 1583/1583 [05:26<00:00,  4.85it/s, gp=0.095, loss_critic=1.18]


Epoch [17/30]


100%|██████████| 1583/1583 [05:23<00:00,  4.90it/s, gp=0.0829, loss_critic=1.75]


Epoch [18/30]


100%|██████████| 1583/1583 [05:21<00:00,  4.93it/s, gp=0.0835, loss_critic=0.825]


Epoch [19/30]


100%|██████████| 1583/1583 [05:20<00:00,  4.94it/s, gp=0.0697, loss_critic=0.903]


Epoch [20/30]


100%|██████████| 1583/1583 [05:20<00:00,  4.94it/s, gp=0.0833, loss_critic=1.24]


Epoch [21/30]


100%|██████████| 1583/1583 [05:21<00:00,  4.92it/s, gp=0.0817, loss_critic=1.28]


Epoch [22/30]


100%|██████████| 1583/1583 [05:20<00:00,  4.93it/s, gp=0.0876, loss_critic=0.936]


Epoch [23/30]


100%|██████████| 1583/1583 [05:21<00:00,  4.92it/s, gp=0.0655, loss_critic=0.927]


Epoch [24/30]


100%|██████████| 1583/1583 [05:21<00:00,  4.92it/s, gp=0.0617, loss_critic=0.73]


Epoch [25/30]


100%|██████████| 1583/1583 [05:21<00:00,  4.93it/s, gp=0.0637, loss_critic=0.932]


Epoch [26/30]


100%|██████████| 1583/1583 [05:22<00:00,  4.91it/s, gp=0.0717, loss_critic=0.807]


Epoch [27/30]


100%|██████████| 1583/1583 [05:22<00:00,  4.90it/s, gp=0.0821, loss_critic=1.1]


Epoch [28/30]


100%|██████████| 1583/1583 [05:22<00:00,  4.91it/s, gp=0.0647, loss_critic=1.13]


Epoch [29/30]


100%|██████████| 1583/1583 [05:22<00:00,  4.91it/s, gp=0.0703, loss_critic=0.964]


Epoch [30/30]


100%|██████████| 1583/1583 [05:27<00:00,  4.83it/s, gp=0.0737, loss_critic=0.722]


Current image size: 16
Epoch [1/30]


100%|██████████| 3166/3166 [11:43<00:00,  4.50it/s, gp=0.15, loss_critic=2.2]


Epoch [2/30]


100%|██████████| 3166/3166 [11:48<00:00,  4.47it/s, gp=0.108, loss_critic=-3.17]


Epoch [3/30]


100%|██████████| 3166/3166 [11:49<00:00,  4.46it/s, gp=0.137, loss_critic=1.39]


Epoch [4/30]


100%|██████████| 3166/3166 [11:48<00:00,  4.47it/s, gp=0.169, loss_critic=2.55]


Epoch [5/30]


100%|██████████| 3166/3166 [11:45<00:00,  4.48it/s, gp=0.171, loss_critic=0.286]


Epoch [6/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.173, loss_critic=-2.32]


Epoch [7/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.44it/s, gp=0.19, loss_critic=1.01]


Epoch [8/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.192, loss_critic=-1.07]


Epoch [9/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.192, loss_critic=0.999]


Epoch [10/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.43it/s, gp=0.196, loss_critic=-.339]


Epoch [11/30]


100%|██████████| 3166/3166 [11:54<00:00,  4.43it/s, gp=0.189, loss_critic=2.65]


Epoch [12/30]


100%|██████████| 3166/3166 [11:54<00:00,  4.43it/s, gp=0.185, loss_critic=-3.39]


Epoch [13/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.44it/s, gp=0.179, loss_critic=-3.18]


Epoch [14/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.43it/s, gp=0.187, loss_critic=-1.43]


Epoch [15/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.44it/s, gp=0.178, loss_critic=0.281]


Epoch [16/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.182, loss_critic=1.69]


Epoch [17/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.175, loss_critic=-1.1]


Epoch [18/30]


100%|██████████| 3166/3166 [11:47<00:00,  4.47it/s, gp=0.178, loss_critic=-.697]


Epoch [19/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.44it/s, gp=0.179, loss_critic=1.81]


Epoch [20/30]


100%|██████████| 3166/3166 [11:53<00:00,  4.44it/s, gp=0.174, loss_critic=0.142]


Epoch [21/30]


100%|██████████| 3166/3166 [11:55<00:00,  4.43it/s, gp=0.176, loss_critic=1.16]


Epoch [22/30]


100%|██████████| 3166/3166 [11:52<00:00,  4.44it/s, gp=0.18, loss_critic=2.19]


Epoch [23/30]


100%|██████████| 3166/3166 [11:56<00:00,  4.42it/s, gp=0.176, loss_critic=2.75]


Epoch [24/30]


100%|██████████| 3166/3166 [11:50<00:00,  4.46it/s, gp=0.179, loss_critic=0.805]


Epoch [25/30]


100%|██████████| 3166/3166 [11:36<00:00,  4.55it/s, gp=0.166, loss_critic=0.521]


Epoch [26/30]


100%|██████████| 3166/3166 [11:34<00:00,  4.56it/s, gp=0.177, loss_critic=0.517]


Epoch [27/30]


100%|██████████| 3166/3166 [11:35<00:00,  4.55it/s, gp=0.17, loss_critic=-2.57]


Epoch [28/30]


100%|██████████| 3166/3166 [11:37<00:00,  4.54it/s, gp=0.174, loss_critic=-2.34]


Epoch [29/30]


100%|██████████| 3166/3166 [11:36<00:00,  4.54it/s, gp=0.167, loss_critic=0.681]


Epoch [30/30]


100%|██████████| 3166/3166 [11:37<00:00,  4.54it/s, gp=0.178, loss_critic=0.0277]


Current image size: 32
Epoch [1/30]


100%|██████████| 6332/6332 [39:58<00:00,  2.64it/s, gp=0.369, loss_critic=-38.1]


Epoch [2/30]


100%|██████████| 6332/6332 [40:02<00:00,  2.64it/s, gp=0.355, loss_critic=17.7]


Epoch [3/30]


100%|██████████| 6332/6332 [40:04<00:00,  2.63it/s, gp=0.341, loss_critic=0.519]


Epoch [4/30]


100%|██████████| 6332/6332 [40:02<00:00,  2.64it/s, gp=0.33, loss_critic=-1.9]


Epoch [5/30]


 36%|███▋      | 2297/6332 [14:30<25:26,  2.64it/s, gp=0.305, loss_critic=-18.7]