In [11]:
from argparse import Namespace
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.stylegan2.op import fused_leaky_relu
from src.stylegan2.Blocks import ModConvLayer

In [13]:
# Generator
# It takes an image (presumably a low-resolution image already upsampled to
# the target size — TODO confirm) and a style vector, and produces an image
# at that same spatial size.
class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) building block.

    Args:
        in_channel: number of input channels of the convolution.
        out_channel: number of output channels (also the BatchNorm width).
        ker_size: convolution kernel size.
        padd: zero-padding passed to ``nn.Conv2d``.
        stride: convolution stride.

    Sub-modules are registered as ``conv``, ``norm`` and ``LeakyRelu``; the
    names are kept unchanged so existing state_dicts still load.
    """

    def __init__(self, in_channel, out_channel, ker_size, padd, stride):
        super().__init__()
        self.add_module('conv', nn.Conv2d(in_channel, out_channel,
                                          kernel_size=ker_size, stride=stride,
                                          padding=padd))
        self.add_module('norm', nn.BatchNorm2d(out_channel))
        self.add_module('LeakyRelu', nn.LeakyReLU(0.2, inplace=True))

    @staticmethod
    def weights_init(m):
        """Initialise module ``m`` in place, e.g. ``model.apply(ConvBlock.weights_init)``.

        Matching is by class name (as in the original):
          * name contains 'Conv2d' -> weight ~ N(0, 0.02)
          * name contains 'Norm'   -> weight ~ N(1, 0.02), bias = 0
        Other modules are left untouched, as are parameters that are ``None``
        (e.g. ``affine=False`` norms or ``bias=False`` convs), which previously
        raised AttributeError.
        """
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if classname.find('Conv2d') != -1:
            if weight is not None:
                weight.data.normal_(0.0, 0.02)
        elif classname.find('Norm') != -1:
            if weight is not None:
                weight.data.normal_(1.0, 0.02)
            if bias is not None:
                bias.data.fill_(0)
        
class Generator(nn.Module):
    """Style-modulated generator: ModConvLayer head -> ConvBlock body -> Conv2d + Tanh tail.

    ``opt`` attributes read here:
        dlatent_size, in_channel, out_channel, kernel -- passed to ``ModConvLayer``
        num_layer      -- the body holds ``num_layer - 2`` ConvBlocks
        nfc, min_nfc   -- block i (0-based) outputs ``max(int(nfc / 2**(i+1)), min_nfc)`` channels
        ker_size, padd_size -- kernel size / padding of the body and tail convolutions
        nc_im          -- number of output image channels
    """

    def __init__(self, opt):
        super().__init__()
        self.head = ModConvLayer(
            opt.dlatent_size,
            opt.in_channel,
            opt.out_channel,
            opt.kernel,
        )
        # Each block's input width is chained from the previous layer's output
        # (starting at the head's out_channel) instead of being recomputed as
        # 2*N. int() rounding made 2*N differ from the previous N for odd widths,
        # and the head width was never tied to the first block.
        self.body = nn.Sequential()
        channels = opt.out_channel
        for i in range(opt.num_layer - 2):
            n = int(opt.nfc / pow(2, i + 1))
            out_channels = max(n, opt.min_nfc)
            block = ConvBlock(channels, out_channels, opt.ker_size, opt.padd_size, 1)
            self.body.add_module('block%d' % (i + 1), block)
            channels = out_channels
        # `channels` is defined even with an empty body (num_layer <= 2), where
        # the original read the unbound loop variable N and raised NameError.
        self.tail = nn.Sequential(
            nn.Conv2d(channels, opt.nc_im, kernel_size=opt.ker_size,
                      stride=1, padding=opt.padd_size),
            nn.Tanh(),
        )

    def forward(self, x, style):
        """Run head (modulated by ``style``), body, and tail.

        Args:
            x: input tensor for the head; presumably (batch, opt.in_channel, H, W) — TODO confirm.
            style: style/latent tensor for ModConvLayer; presumably (batch, opt.dlatent_size).

        Returns:
            Tensor with ``opt.nc_im`` channels, squashed to (-1, 1) by the final Tanh.
        """
        x = self.head(x, style)
        x = self.body(x)
        return self.tail(x)

# Smoke test: build the generator from a small config and run one forward pass.
torch.manual_seed(0)  # reproducible weights, input image and style vector

KER_SIZE = 3
opt = Namespace(
    # Head (ModConvLayer)
    dlatent_size=128,
    in_channel=3,
    out_channel=32,
    kernel=KER_SIZE,
    # Body (ConvBlocks) and tail
    ker_size=KER_SIZE,
    num_layer=5,
    nfc=32,
    min_nfc=32,
    padd_size=KER_SIZE // 2,  # "same" padding for odd kernel sizes
    nc_im=3,  # image # channels
)
G = Generator(opt)
print(G)
img = torch.rand(1, opt.in_channel, 64, 64)
style = torch.rand(1, opt.dlatent_size)
out = G(img, style)
assert out.shape == (1, opt.nc_im, 64, 64), out.shape
out.shape

Generator(
  (head): ModConvLayer(
    (conv): EqualizedModConv2d(3, 32, 3, upsample=False, downsample=False)
  )
  (body): Sequential(
    (block1): ConvBlock(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block2): ConvBlock(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block3): ConvBlock(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
  )
  (tail): Sequential(
    (0): Conv2d(32, 3, kernel_size

torch.Size([1, 3, 64, 64])