In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2
from torchsummary import summary

In [3]:
factors = [1, 1, 1, 1, 1/2, 1/4, 1/16, 1/32]

In [4]:
class PixelNorm(nn.Module):
  def __init__(self):
    super(PixelNorm, self).__init__()
    self.epsilon = 1e-8


  def forward(self, x):
    x = x.unsqueeze(1)
    return x / torch.sqrt((torch.mean(x**2, dim = 1, keepdim = True)) + self.epsilon)


In [5]:
class WSLinear(nn.Module):
  def __init__(self, in_features, out_features, gains = 2):
    super(WSLinear, self).__init__()
    self.linear = nn.Linear(in_features, out_features)
    self.scale = (gains/ in_features)**(0.5)
    self.bias = self.linear.bias
    self.linear.bias = None
    nn.init.normal_(self.linear.weight)
    nn.init.zeros_(self.bias)
  def forward(self, x):
    return self.linear(x * self.scale) + self.bias

In [6]:
class WSConv2d(nn.Module):
  def __init__(self, in_features, out_features, kernel_size = 3, stride = 1, padding = 1, gains = 2):
    super(WSConv2d, self).__init__()
    self.conv = nn.Conv2d(in_features, out_features, kernel_size, stride, padding)
    self.scale = (gains / (kernel_size)**2 * in_features)**(0.5)
    self.bias = self.conv.bias
    self.conv.bias = None
    nn.init.normal_(self.conv.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    return self.conv(x * self.scale) + self.bias

In [17]:
class StyleScale(torch.nn.Module):
    def __init__(self, num_channels, num_styles):
        super().__init__()
        self.scale = torch.zeros((num_channels, num_styles, 1, 1))


    def forward(self, x):
        return self.scale

In [18]:
class AdaIN(nn.Module):
    def __init__(self, in_channels, style_dim):
        super().__init__()
        self.InstanceNorm = nn.InstanceNorm2d(in_channels)
        self.scale = nn.Parameter(torch.ones(1, in_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.style_scale = StyleScale(style_dim, in_channels)
        self.style_bias = StyleScale(style_dim, in_channels)

    def forward(self, x, w):
        self.scale = self.scale.to(x.device)
        self.bias = self.bias.to(x.device)

        x = self.InstanceNorm(x)
        style_scale = self.style_scale(w).unsqueeze(3)
        style_bias = self.style_bias(w).unsqueeze(3)
        return style_scale * x + style_bias

In [19]:
class MappingNetwork(nn.Module):
  def __init__(self, z_dim, w_dim):
    super().__init__()
    self.mapping = nn.Sequential(
        PixelNorm(),
        WSLinear(z_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim),
        nn.ReLU(),
        WSLinear(w_dim, w_dim)
    )
  def forward(self, x):
    return self.mapping(x)

In [20]:
class InjectNoise(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

  def forward(self, x):
    noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
    return x + self.weight * noise

In [21]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ConvBlock, self).__init__()
    self.conv1 = WSConv2d(in_channels, out_channels)
    self.conv2 = WSConv2d(out_channels, out_channels)
    self.leaky = nn.LeakyReLU(0.2)

  def forward(self, x):
    x = self.leaky(self.conv1(x))
    x = self.leaky(self.conv2(x))
    return x

In [22]:
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, w_dim):
        super(GenBlock, self).__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)
        self.inject_noise1 = InjectNoise(out_channels)
        self.inject_noise2 = InjectNoise(out_channels)
        self.adain1 = AdaIN(out_channels, w_dim)
        self.adain2 = AdaIN(out_channels, w_dim)

    def forward(self, x, w):
        x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
        x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
        return x

In [23]:
class Generator(nn.Module):
  def __init__(self, z_dim, w_dim, in_channels, img_channels = 3):
    super(Generator, self).__init__()
    self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
    self.map = MappingNetwork(z_dim, w_dim)
    self.initial_adain1 = AdaIN(in_channels, z_dim)
    self.initial_adain2 = AdaIN(in_channels, w_dim)
    self.initial_noise1 = InjectNoise(in_channels)
    self.initial_noise2 = InjectNoise(in_channels)
    self.initial_conv = nn.Conv2d(in_channels, in_channels, kernel_size = 3, stride = 1, padding = 1)
    self.leaky = nn.LeakyReLU(0.2, inplace = True)
    self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size = 1, stride = 1, padding = 0)
    self.prog_blocks, self.rgb_layers = (nn.ModuleList([]), nn.ModuleList([self.initial_rgb]))
    for i in range(len(factors) - 1):
      conv_in_c = int(in_channels * factors[i])
      conv_out_c = int(in_channels * factors[i + 1])
      self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
      self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size = 1, stride = 1, padding = 0))

  def fade_in(self, alpha, upscaled, generated):
    return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

  def forward(self, noise, alpha, steps):
    w = self.map(noise)
    x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
    x = self.initial_conv(x)
    out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)
    if steps == 0:
      return self.initial_rgb(x)

    for step in range(steps):
      upscaled = F.interpolate(out, scale_factor = 2, mode = "bilinear")
      out = self.prog_blocks[step](upscaled, w)
      final_upscaled = self.rgb_layers[steps - 1](upscaled)
      final_out = self.rgb_layers[steps](out)
      return self.fade_in(alpha, final_upscaled, final_out)

In [24]:
class Discriminator(nn.Module):
  def __init__(self, in_channels, img_channels = 3):
    super(Discriminator, self).__init__()
    self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
    self.leaky = nn.LeakyReLU(0.2)
    for i in range(len(factors) - 1, 0, -1):
      conv_in = int(in_channels * factors[i])
      conv_out = int(in_channels * factors[i - 1])
      self.prog_blocks.append(ConvBlock(conv_in, conv_out))
      self.rgb_layers.append(WSConv2d(img_channels, conv_in, kernel_size = 1, stride = 1, padding = 0))
    self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size = 1, stride = 1, padding = 0)
    self.avg_pool = nn.AvgPool2d(kernel_size = 2, stride = 2)
    self.final_block = nn.Sequential(WSConv2d(in_channels + 1, in_channels, kernel_size = 3, padding = 1),
                                     nn.LeakyReLU(0.2),
                                     WSConv2d(in_channels, in_channels, kernel_size = 4, padding = 0, stride = 1),
                                     nn.LeakyReLU(0.2),
                                     WSConv2d(in_channels, 1, kernel_size = 1, padding = 0, stride = 1))
  def fade_in(self, alpha, downscaled, out):
    return alpha * out + (1 - alpha) * downscaled
  def minibatch_std(self, x):
    batch_statistics = (torch.std(x, dim = 0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3]))
    return torch.cat([x, batch_statistics], dim = 1)
  def forward(self, x, alpha, steps):
    cur_step = len(self.prog_blocks) - steps
    out = self.leaky(self.rgb_layers[cur_step](x))

    if steps == 0:
      out = self.minibatch_std(out)
      return self.final_block(out).view(out.shape[0], -1)

    downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
    out = self.avg_pool(self.prog_blocks[cur_step](out))

    for step in range(cur_step + 1, len(self.prog_blocks)):
      out = self.prog_blocks[step](out)
      out = self.avg_pool(out)

    out = self.minibatch_std(out)
    return self.final_block(out).view(out.shape[0], -1)

In [None]:
Z_DIM = 512
W_DIM = 512
IN_CHANNELS = 512
gen = Generator(Z_DIM, W_DIM, IN_CHANNELS, img_channels=3)
disc = Discriminator(IN_CHANNELS, img_channels=3)
tot = 0
for param in gen.parameters():
  tot += param.numel()
print(tot)
for img_size in [4, 8, 16, 32, 64, 128, 256, 512, 1024]:
  num_steps = int(log2(img_size / 4))
  x = torch.randn((2, Z_DIM))
  z = gen(x, alpha = 0.5, steps = num_steps)
  assert z.shape == (2, 3, img_size, img_size)
  out = disc(z, alpha=0.5, steps=num_steps)
  assert out.shape == (2, 1)
  print(f"Success! At img size: {img_size}")

20916136
