In [1]:
import sys

import numpy as np
import math

import torch
import torch.nn.functional as F

from torch import nn, Tensor

In [2]:
# Make the project's `src/` package importable. Guarded so that re-running the
# cell does not keep prepending duplicate entries to sys.path.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [3]:
from models.stylegan2.layers import EqualConv2d, EqualLinear, EqualLeakyReLU, \
    ModulatedConv2d, RandomGaussianNoise, InputNoise

from models.stylegan2.net import Layer, ToRGB

In [4]:
def nf(stage, fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    """Number of feature maps (channels) for a generator stage.

    ``fmap_base`` is halved once per stage (scaled by ``fmap_decay``), truncated
    to an integer and clipped to ``[fmap_min, fmap_max]``. The defaults match
    the values used by the SynthesisNet constructor in this notebook.

    Args:
        stage: non-negative stage index (larger stage -> fewer channels).
        fmap_base, fmap_decay, fmap_min, fmap_max: channel schedule knobs.

    Returns:
        The clipped channel count (a numpy integer scalar from ``np.clip``).
    """
    fmaps = int(fmap_base / (2.0 ** (stage * fmap_decay)))
    return np.clip(fmaps, fmap_min, fmap_max)


# For a 1024x1024 output: number of style inputs (2 per resolution level, minus 2)
# and the channel count of every stage.
res_log2 = int(math.log2(1024))
num_layers = res_log2 * 2 - 2
print(f'n: {num_layers}\n')

for i in range(res_log2 + 1):
    print(f'{i}: {nf(i)}')

n: 18

0: 512
1: 512
2: 512
3: 512
4: 512
5: 512
6: 256
7: 128
8: 64
9: 32
10: 16


In [5]:
# For each log2-resolution level `res` (3 .. res_log2) print:
#   (upsampling layer index, following layer index, input stage, output stage)
# These mirror the order in which SynthesisNet appends layers to `main`
# (index 0 is the first, non-upsampling layer) and the nf() stages it uses
# for their channel counts.
for res in range(3, res_log2 + 1):
    up_idx, conv_idx = res * 2 - 5, res * 2 - 4
    print(up_idx, conv_idx, res - 2, res - 1)

1 2 1 2
3 4 2 3
5 6 3 4
7 8 4 5
9 10 5 6
11 12 6 7
13 14 7 8
15 16 8 9


In [6]:
# Style index fed to the ToRGB head at each level 3 .. res_log2: the head runs
# after even main-layer index i = 2 * (level - 2) and reads w[i + 1].
for level in range(3, res_log2 + 1):
    rgb_style_idx = 2 * level - 3
    print(rgb_style_idx)

3
5
7
9
11
13
15
17


In [7]:
def upscale(x, factor):
    """Resize ``x`` by ``factor`` along each spatial axis.

    Bilinear interpolation with ``align_corners=False``. Bilinear mode in
    ``F.interpolate`` expects a 4-D ``(N, C, H, W)`` tensor.
    """
    interp_kwargs = dict(scale_factor=factor, mode='bilinear', align_corners=False)
    return F.interpolate(x, **interp_kwargs)


class SynthesisNet(nn.Module):
    """Synthesis network built from the project's stylegan2 layers.

    ``main`` holds one layer on the base input (``InputNoise`` with ``size=4``)
    followed by an (upsampling layer, layer) pair for each resolution doubling.
    ``outs`` holds one ToRGB head per resolution level. ``forward`` sums the
    heads' outputs, upscaling the running sum by 2 before each new level.

    Args:
        img_res: output resolution; must be a power of 2 greater than 4.
        img_channels: number of channels produced by each ToRGB head.
        style_dim: width of each style vector ``w``.
        fmap_base, fmap_decay, fmap_min, fmap_max: per-stage channel schedule
            (see the nested ``nf``).

    Raises:
        AttributeError: if ``img_res`` is not a power of 2 greater than 4.
    """

    def __init__(self, img_res=1024, img_channels=3, style_dim=512,
                 fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
        super().__init__()

        # NOTE(review): ValueError would be more conventional; AttributeError is
        # kept so existing callers that catch it continue to work.
        if img_res <= 4:
            raise AttributeError("Image resolution must be greater than 4")

        res_log2 = int(math.log2(img_res))
        if img_res != 2**res_log2:
            raise AttributeError("Image resolution must be a power of 2")

        self.res_log2 = res_log2
        # Stored so forward() sizes the style tensor from the real style width.
        self.style_dim = style_dim

        def nf(stage):
            """Channel count for ``stage``: fmap_base halved per stage, clipped."""
            fmaps = int(fmap_base / (2.0 ** (stage * fmap_decay)))
            # Plain Python int so layer constructors never see a numpy scalar.
            return int(np.clip(fmaps, fmap_min, fmap_max))

        main = [Layer(nf(1), nf(1), style_dim)]
        outs = [ToRGB(nf(1), img_channels, style_dim)]

        # One (upsampling, same-resolution) layer pair and one ToRGB head per
        # level after the base one.
        for stage in range(1, res_log2 - 1):
            inp_ch, out_ch = nf(stage), nf(stage + 1)
            main += [Layer(inp_ch, out_ch, style_dim, up=True),
                     Layer(out_ch, out_ch, style_dim)]
            outs += [ToRGB(out_ch, img_channels, style_dim)]

        # Registration order (input, main, outs) kept as in the original so
        # parameter / state_dict ordering is unchanged.
        self.input = InputNoise(nf(1), size=4)
        self.main = nn.ModuleList(main)
        self.outs = nn.ModuleList(outs)

    def forward(self, n, w=None):
        """Synthesize a batch of ``n`` images.

        Args:
            n: batch size passed to the input layer.
            w: optional style tensor of shape
                ``(len(self.main) + 1, n, self.style_dim)``. Row ``i`` drives
                ``main[i]``; row ``i + 1`` drives the ToRGB head that follows
                the even-indexed layer ``i``. When omitted, styles are sampled
                uniformly with ``torch.rand`` (previous behaviour).

        Returns:
            The accumulated ToRGB output at the final resolution.
        """
        if w is None:
            # Sample on the module's device so the net also works off-CPU.
            param = next(self.parameters(), None)
            device = param.device if param is not None else None
            w = torch.rand(len(self.main) + 1, n, self.style_dim, device=device)

        x = self.input(n).clone()
        y = None

        for i, layer in enumerate(self.main):
            x = layer(x, w[i])

            # Even indices close a resolution level (0 = base layer, 2k = the
            # non-upsampling layer of level k), so emit that level's RGB.
            if i % 2 == 0:
                to_rgb = self.outs[i // 2]

                if i == 0:
                    y = to_rgb(x, w[i + 1])
                else:
                    y = upscale(y, 2) + to_rgb(x, w[i + 1])

        return y

In [8]:
# Small 32x32 network; fmap_base = 2 << 6 = 128 keeps channel counts low
# (the base stage gets 128 / 2 = 64 channels).
sn = SynthesisNet(fmap_base=2 << 6, img_res=32)

In [9]:
sn_out = sn(3)
# Expect a batch of 3 images with img_channels (default 3) at 32x32.
assert sn_out.shape == (3, 3, 32, 32), sn_out.shape
sn_out.shape

torch.Size([3, 3, 32, 32])

In [10]:
# Number of style vectors the 32x32 network consumes (2 * log2(res) - 2),
# derived from `sn` instead of repeating the resolution, and cross-checked
# against the size forward() allocates for w.
num_styles = sn.res_log2 * 2 - 2
assert num_styles == len(sn.main) + 1
num_styles

8.0