In [1]:
import sys

import numpy as np
import math

import torch
import torch.nn.functional as F

from torch import nn, Tensor

In [2]:
# Put '../src' (resolved against the current working directory) at the front of
# the module search path; presumably this is where the `models` package used by
# the next import lives -- confirm against the repository layout.
sys.path.insert(0, '../src')

In [3]:
from models.stylegan2.layers import EqualConv2d, EqualLinear, EqualLeakyReLU, \
    ModulatedConv2d, RandomGaussianNoise

In [4]:
class StyledLayer(nn.Module):
    """Style-driven 3x3 convolution followed by noise injection and activation.

    Args:
        in_channels: channel count of the incoming feature map; also the size
            of the vector produced from ``w`` and handed to the convolution.
        out_channels: channel count produced by the convolution.
        style_dim: size of the style vector ``w`` accepted by ``forward``.
        up: when truthy, the input is bilinearly upsampled by a factor of 2
            before the convolution.
    """

    def __init__(self, in_channels, out_channels, style_dim, up=False):
        super().__init__()

        # Linear map from the style vector to one value per input channel.
        # Its bias is set to ones right after construction.
        self.style = EqualLinear(style_dim, in_channels, bias=True)
        nn.init.ones_(self.style.bias)

        # Optional 2x bilinear upsampling; stays None when `up` is falsy.
        self.upscale = (
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            if up
            else None
        )

        # kernel 3 / stride 1 / padding 1, so the convolution itself keeps the
        # spatial size unchanged.
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.add_noise = RandomGaussianNoise()
        self.act_fn = EqualLeakyReLU(inplace=True)

    def forward(self, x, w):
        """Run (optional upsample) -> styled conv -> noise -> activation.

        Args:
            x: input feature map.
            w: style vector fed to ``self.style``.

        Returns:
            The activated output of the convolution.
        """
        if self.upscale is not None:
            x = self.upscale(x)
        per_channel_style = self.style(w)
        out = self.conv(x, per_channel_style)
        out = self.add_noise(out)
        return self.act_fn(out)

In [5]:
# Build a small layer without upsampling, then show its module tree.
sl = StyledLayer(in_channels=4, out_channels=8, style_dim=24, up=False)
sl

StyledLayer(
  (style): EqualLinear(in_features=24, out_features=4, bias=True)
  (conv): ModulatedConv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (add_noise): RandomGaussianNoise()
  (act_fn): EqualLeakyReLU(negative_slope=0.2, inplace=True)
)

In [6]:
# Smoke test inputs: one random 4-channel 32x32 feature map and one 24-dim style
# vector, matching the in_channels=4 / style_dim=24 used to build `sl`.
x = torch.rand(1, 4, 32, 32)
w = torch.rand(1, 24)

# Only the output shape is inspected; with up=False the spatial size should stay 32x32.
sl(x, w).shape

torch.Size([1, 8, 32, 32])

In [7]:
class StyledBlock(nn.Module):
    """Placeholder block: registers no sub-modules and passes its input through.

    Args:
        in_channels: accepted but currently unused.
        out_channels: accepted but currently unused.
        style_dim: accepted but currently unused.
    """

    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()

    def forward(self, x, w):
        """Return ``x`` untouched; ``w`` is ignored for now."""
        return x

In [8]:
class SynthesisNet(nn.Module):
    """Skeleton of the synthesis network.

    At this stage the constructor only validates ``out_res`` and derives a few
    sizing quantities; no layers are registered and ``forward`` is a stub.

    Args:
        style_dim: accepted but not used yet.
        out_channels: accepted but not used yet.
        out_res: output resolution; must be a power of 2 greater than 4.
        fmap_base: numerator of the per-stage feature-map formula in ``nf``.
        fmap_decay: exponent multiplier applied to the stage index in ``nf``.
        fmap_min: lower bound on the feature-map count returned by ``nf``.
        fmap_max: upper bound on the feature-map count returned by ``nf``.

    Raises:
        AttributeError: if ``out_res`` is not greater than 4, or is not a
            power of 2.
    """

    def __init__(self, style_dim=512, out_channels=3, out_res=1024,
                 fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
        super(SynthesisNet, self).__init__()

        # NOTE(review): ValueError would be the conventional type for a bad
        # argument value; AttributeError is kept so existing callers that catch
        # it keep working.
        if out_res <= 4:
            raise AttributeError("The output resolution must be greater than 4")

        res_log2 = int(math.log2(out_res))
        if out_res != 2**res_log2:
            raise AttributeError("The output resolution must be a power of 2")

        def nf(stage):
            """Feature-map count for ``stage``, clamped to [fmap_min, fmap_max]."""
            fmaps = int(fmap_base / (2.0 ** (stage * fmap_decay)))
            # Clamp with plain Python ints: np.clip would hand back a numpy
            # scalar, which is not what channel-count arguments should carry.
            return int(min(max(fmaps, fmap_min), fmap_max))

        # Computed but not consumed yet in this skeleton (as is `nf`).
        num_layers = res_log2 * 2 - 2

    def forward(self, *args, **kwargs):
        """Stub: the forward pass has not been written yet.

        Raises:
            NotImplementedError: always.
        """
        # The previous definition lacked `self`, so calling the module failed
        # with a confusing TypeError; fail explicitly instead.
        raise NotImplementedError("SynthesisNet.forward is not implemented yet")

In [9]:
class ToRGB(nn.Module):
    """Style-driven 1x1 convolution with an optional additive skip input.

    Args:
        in_channels: channel count of the incoming feature map; also the size
            of the vector produced from ``w`` and handed to the convolution.
        out_channels: channel count produced by the convolution.
        style_dim: size of the style vector ``w`` accepted by ``forward``.
    """

    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()

        # Linear map from the style vector to one value per input channel.
        # Its bias is set to ones right after construction.
        self.style = EqualLinear(style_dim, in_channels, bias=True)
        nn.init.ones_(self.style.bias)

        # 1x1 kernel, no padding, and demodulation switched off.
        self.conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            demodulate=False,
        )

    def forward(self, x, w, x0=None):
        """Project ``x`` with the styled convolution and add ``x0`` if given.

        Args:
            x: input feature map.
            w: style vector fed to ``self.style``.
            x0: optional tensor added to the convolution output.

        Returns:
            The convolution output, plus ``x0`` when it is not None.
        """
        per_channel_style = self.style(w)
        projected = self.conv(x, per_channel_style)
        if x0 is None:
            return projected
        return projected + x0

In [10]:
# Small ToRGB instance for a quick shape check.
trgb = ToRGB(in_channels=4, out_channels=8, style_dim=24)

In [11]:
# Shape check with a random 4-channel 12x12 input and a 24-dim style vector;
# no skip input (x0) is passed.
trgb(
    torch.rand(1, 4, 12, 12),
    torch.rand(1, 24)
).shape

torch.Size([1, 8, 12, 12])

In [12]:
class InputNoise(nn.Module):
    """Learnable starting tensor that is broadcast to the requested batch size.

    Args:
        channels: number of channels of the stored tensor.
        size: height and width of the stored tensor (default 4).
    """

    def __init__(self, channels, size=4):
        super().__init__()
        shape = (1, channels, size, size)
        # Allocated uninitialised; reset_parameters() fills it right away.
        self.weight = nn.Parameter(torch.empty(*shape), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the stored tensor from a normal distribution (mean 0, std 1)."""
        nn.init.normal_(self.weight, mean=0.0, std=1.0)

    def forward(self, n):
        """Return the stored tensor expanded to ``n`` entries along dim 0.

        ``expand`` produces a view, so the ``n`` entries share the parameter's
        storage rather than being copies.
        """
        batched = self.weight.expand(n, -1, -1, -1)
        return batched

In [13]:
# 512-channel, 4x4 learnable input tensor.
inp = InputNoise(channels=512, size=4)

In [14]:
# Expanding to a batch of 11 should give shape (11, 512, 4, 4).
inp(11).shape

torch.Size([11, 512, 4, 4])