In [1]:
import math
import numpy as np
import random
import sys

import torch
import torch.nn.functional as F

from torch import nn, Tensor
from typing import Any, Callable, Optional

In [2]:
# Make the project sources under ../src importable. The membership guard keeps
# the cell idempotent: re-running it does not prepend duplicate entries.
if '../src' not in sys.path:
    sys.path.insert(0, '../src')

In [3]:
from models.stylegan2.layers import EqualConv2d, EqualLinear, EqualLeakyReLU, \
    ModulatedConv2d, AddRandomNoise, ConcatMiniBatchStddev

from models.stylegan2.net import StyledLayer, ToRGB, FromRGB, Input, SynthesisNet, MappingNet

In [4]:
class Generator(nn.Module):
    """Mapping + synthesis generator with optional style mixing, running
    average of styles, and truncation.

    Args:
        mapping: Module called as ``mapping(z, label)``; must expose a
            ``style_dim`` attribute (used to size the ``w_avg`` buffer).
        synthesis: Module called as ``synthesis(w)`` with ``w`` laid out as
            (layers, batch, style); must expose ``num_layers``.
        p_style_mix: Probability of mixing in a second style per forward
            call. Training only; values ``<= 0`` (or ``None``) disable it.
        w_ema_decay: Decay of the running average ``w_avg``. Training only;
            values outside the open interval (0, 1) (or ``None``) disable it.
        truncation_psi: Interpolation weight between ``w_avg`` (0.0) and the
            mapped style (1.0). Inference only; ``None`` or values ``>= 1``
            disable truncation. ``0.0`` is a valid value.
        truncation_cutoff: If given, only layers with index below the cutoff
            are truncated; the remaining layers keep weight 1.
        is_training: Selects which of the options above are honoured. Note
            that this is fixed at construction and independent of
            ``nn.Module.train()`` / ``eval()``.
    """

    def __init__(self, mapping: "MappingNet", synthesis: "SynthesisNet",
                 p_style_mix=0.9, w_ema_decay=0.995,
                 truncation_psi=0.5, truncation_cutoff=None,
                 is_training=True):
        super().__init__()
        if is_training:
            if w_ema_decay is not None and not 0.0 < w_ema_decay < 1.0:
                w_ema_decay = None
            if p_style_mix is not None and p_style_mix <= 0.0:
                p_style_mix = None
            truncation_psi = None
        else:
            w_ema_decay = None
            p_style_mix = None
            # None means "no truncation"; guard it so the comparison below
            # cannot raise TypeError.
            if truncation_psi is not None and truncation_psi >= 1.0:
                truncation_psi = None

        self.mapping = mapping
        self.synthesis = synthesis
        self.p_style_mix = p_style_mix

        self.w_ema_decay = w_ema_decay
        self.register_buffer('w_avg', torch.zeros(mapping.style_dim))

        self.truncation_psi = truncation_psi
        self.truncation_cutoff = truncation_cutoff

    @property
    def num_layers(self):
        """Number of style inputs expected by the synthesis network."""
        return self.synthesis.num_layers

    def w_ema_step(self, w: Tensor):
        """Update ``w_avg`` towards the batch mean of ``w`` and return ``w``.

        Computes ``decay * w_avg + (1 - decay) * mean(w, dim=0)``.
        """
        with torch.no_grad():
            batch_mean = w.detach().mean(0)
            # In-place update keeps the registered buffer object stable
            # instead of rebinding the attribute on every step.
            self.w_avg.copy_(torch.lerp(batch_mean, self.w_avg, self.w_ema_decay))
        return w

    def mix_styles(self, z1: Tensor, label: Tensor, w1: Tensor):
        """Randomly replace the styles of the trailing layers.

        With probability ``p_style_mix`` a cutoff layer is drawn (using the
        Python ``random`` module); layers at or beyond the cutoff receive
        styles mapped from a fresh latent, earlier layers keep ``w1``.

        Args:
            z1: Latents that produced ``w1``; only used as the shape/device
                template for the second latent.
            label: Forwarded unchanged to the mapping network.
            w1: Styles laid out as (layers, batch, style).
        """
        num_layers = self.num_layers

        if random.uniform(0, 1) < self.p_style_mix:
            mix_cutoff = int(random.uniform(1, num_layers))
        else:
            mix_cutoff = num_layers

        # Every layer keeps w1, so skip the second mapping pass entirely.
        if mix_cutoff >= num_layers:
            return w1

        z2 = torch.randn_like(z1)
        w2 = self.mapping(z2, label)
        # The mask must live on the same device as the styles; a CPU mask
        # makes torch.where fail when the model runs on an accelerator.
        layer_idx = torch.arange(num_layers, device=w1.device)
        mask = (layer_idx < mix_cutoff)[:, None, None]
        return torch.where(mask, w1, w2)

    def truncate(self, w: Tensor):
        """Pull styles towards ``w_avg`` with per-layer weight ``truncation_psi``.

        Layers at or beyond ``truncation_cutoff`` (when set) use weight 1,
        i.e. they are returned unchanged.
        """
        assert w.ndim == 3, "w: layer axis is missing"
        layer_psi = torch.ones(self.num_layers, device=w.device)
        if self.truncation_cutoff is None:
            layer_psi *= self.truncation_psi
        else:
            layer_idx = torch.arange(self.num_layers, device=w.device)
            mask = layer_idx < self.truncation_cutoff
            layer_psi = torch.where(mask, layer_psi * self.truncation_psi, layer_psi)
        w = torch.lerp(self.w_avg[None, None, :], w, layer_psi[:, None, None])
        return w

    def forward(self, z, label=None):
        """Map ``z`` (and ``label``) to styles and run the synthesis network."""
        w = self.mapping(z, label)
        # Compare against None rather than truthiness: 0.0 is a legitimate
        # truncation_psi (collapse to w_avg) and must not be skipped.
        if self.w_ema_decay is not None:
            self.w_ema_step(w)
        # N, S -> L, N, S
        w = w.expand(self.num_layers, -1, -1)
        if self.p_style_mix is not None:
            w = self.mix_styles(z, label, w)
        if self.truncation_psi is not None:
            w = self.truncate(w)
        return self.synthesis(w)

In [5]:
# Deliberately tiny configuration so the generator can be built and run quickly
# as a wiring/shape check.
latent_dim = 16
style_dim = 16

synthesis = SynthesisNet(
    img_res=32,
    fmap_base=2 << 5,  # 2 << 5 == 64
    style_dim=style_dim
)

mapping = MappingNet(
    latent_dim=latent_dim,
    label_dim=1,
    style_dim=style_dim,
    num_layers=3,
    hidden_dim=16
)

# NOTE: is_training is left at its default (True), and Generator.__init__ resets
# truncation_psi to None in that mode, so truncation_psi / truncation_cutoff
# passed here have no effect on the forward pass below.
g = Generator(mapping, synthesis, 
    p_style_mix=0.9, 
    truncation_psi=0.5, 
    truncation_cutoff=4)

In [6]:
# Smoke test: run a batch of N random latents and labels through the generator
# and display the shape of the result.
N = 3
z = torch.rand(N, latent_dim)  # uniform [0, 1) samples (torch.rand, not randn)
y = torch.rand(N, 1)  # width 1 matches label_dim=1 given to MappingNet
g(z, y).shape

torch.Size([3, 3, 32, 32])

In [None]:
class FromRGB(nn.Module):
    """Project an image to feature maps: 1x1 ``EqualConv2d`` (with bias)
    followed by an in-place ``EqualLeakyReLU``.

    Note: this definition shadows the ``FromRGB`` imported from
    ``models.stylegan2.net`` earlier in the notebook.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = EqualConv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.act_fn = EqualLeakyReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        return self.act_fn(features)

In [None]:
def conv_lrelu(in_ch: int, out_ch: int):
    """Return ``[conv, activation]`` as a list of two layers.

    The conv is an ``EqualConv2d`` with a 3x3 kernel, stride 1, padding 1 and
    a bias; the activation is an in-place ``EqualLeakyReLU``. A list is
    returned (not a module) so callers can splat it into ``nn.Sequential``.
    """
    conv = EqualConv2d(in_ch, out_ch, kernel_size=3, stride=1,
                       padding=1, bias=True)
    activation = EqualLeakyReLU(inplace=True)
    return [conv, activation]


class ResidualBlock(nn.Module):
    """Downsampling residual block.

    Main path: two ``conv_lrelu`` stages (``in -> in -> out`` channels) then
    2x average pooling. Skip path: bias-free 1x1 ``EqualConv2d`` then 2x
    average pooling. The sum of both paths is scaled by ``1 / sqrt(2)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Build the main path first so layers are created in the same order
        # as they are applied.
        main_path = conv_lrelu(in_channels, in_channels)
        main_path += conv_lrelu(in_channels, out_channels)
        main_path.append(nn.AvgPool2d(2))
        self.conv = nn.Sequential(*main_path)

        skip_conv = EqualConv2d(in_channels, out_channels,
                                kernel_size=1, bias=False)
        self.down = nn.Sequential(skip_conv, nn.AvgPool2d(2))

    def forward(self, x: Tensor):
        residual = self.conv(x)
        shortcut = self.down(x)
        scale = 1 / math.sqrt(2)
        return (residual + shortcut) * scale

In [None]:
class Lambda(nn.Module):
    """Adapter that lets an arbitrary callable be used as a module.

    Args:
        fn: Callable applied to the input in ``forward``; its result is
            returned unchanged.
    """

    def __init__(self, fn: Callable[[Any], Tensor]):
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor):
        result = self.fn(x)
        return result


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x: Tensor):
        return torch.flatten(x, start_dim=1)

In [None]:
# Scratch check: flatten(1) keeps dim 0 and merges the rest -> torch.Size([3, 20]).
torch.rand(3,5,2,2).flatten(1).shape

In [None]:
class Discriminator(nn.Module):
    """Residual discriminator built from the StyleGAN2 layers of this project.

    The network is ``FromRGB`` -> one ``ResidualBlock`` per resolution halving
    down to 4x4 -> optional ``ConcatMiniBatchStddev`` -> 3x3 conv -> flatten
    -> two ``EqualLinear`` layers.

    Args:
        img_res: Input resolution; must be a power of two greater than 4.
        img_channels: Number of channels of the input image.
        label_dim: Number of outputs of the final layer; ``0`` yields a
            single output.
        fmap_base, fmap_decay, fmap_min, fmap_max: Control the channel count
            per stage: ``clip(int(fmap_base / 2 ** (stage * fmap_decay)),
            fmap_min, fmap_max)``.
        mbstd_group_size: Passed to ``ConcatMiniBatchStddev``; the layer is
            only added when this is greater than 1.
        mbstd_num_features: Passed to ``ConcatMiniBatchStddev``; also the
            number of extra input channels given to the conv that follows it.

    Raises:
        AttributeError: If ``img_res`` is not a power of two greater than 4.
    """

    def __init__(self, img_res=1024, img_channels=3, label_dim=0,
                 fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512,
                 mbstd_group_size=4, mbstd_num_features=1):
        super().__init__()

        if img_res <= 4:
            raise AttributeError("Image resolution must be greater than 4")

        res_log2 = int(math.log2(img_res))
        if img_res != 2 ** res_log2:
            raise AttributeError("Image resolution must be a power of 2")

        def nf(stage: int) -> int:
            """Channel count for ``stage``, clamped to [fmap_min, fmap_max]."""
            fmaps = int(fmap_base / (2.0 ** (stage * fmap_decay)))
            # Plain Python int: avoids leaking NumPy scalar types into layer
            # sizes (np.clip returns a NumPy integer).
            return int(min(max(fmaps, fmap_min), fmap_max))

        inp = FromRGB(img_channels, nf(res_log2 - 1))
        main = [ResidualBlock(nf(res - 1), nf(res - 2))
                for res in range(res_log2, 2, -1)]

        # Extra channels appended by the minibatch-stddev layer (0 if disabled).
        mbstd_ch = mbstd_num_features * int(mbstd_group_size > 1)
        out = [*conv_lrelu(nf(1) + mbstd_ch, nf(1)),
               # nn.Flatten() flattens from dim 1, same as x.flatten(1), but
               # unlike a lambda it keeps the module picklable.
               nn.Flatten(),
               # 4**2: the residual stack reduces the input to 4x4.
               EqualLinear(nf(1) * 4 ** 2, nf(0), bias=True),
               EqualLeakyReLU(inplace=True),
               EqualLinear(nf(0), max(label_dim, 1), bias=True)]
        if mbstd_ch:
            mbstd = ConcatMiniBatchStddev(mbstd_group_size, mbstd_num_features)
            out = [mbstd] + out

        self.layers = nn.Sequential(inp, *main, *out)

    def forward(self, image: Tensor, label: Optional[Tensor] = None):
        """Score ``image``.

        Without ``label`` the raw output of the last linear layer is returned.
        With ``label`` the outputs are multiplied element-wise by it and
        summed over dim 1 (``keepdim=True``), giving one value per sample.
        """
        x = self.layers(image)
        if label is not None:
            x = torch.sum(x * label, dim=1, keepdim=True)
        return x

In [None]:
# Small discriminator for a shape check: 64x64 input, fmap_base = 2 << 6 = 128,
# and label_dim=3 so the final linear layer has 3 outputs.
d = Discriminator(img_res=64, fmap_base=2 << 6, label_dim=3)

In [None]:
# Expected shape is (2, 1): with a label, forward() weights the 3 outputs by the
# label and sums over dim 1 with keepdim=True.
d(torch.rand(2, 3, 64, 64), torch.rand(2, 3)).shape