In [1]:
from collections import OrderedDict


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from src.models.layers import (
    ConstantInput,    
    StyledConv2d,
)


from src.config import get_parser

# Parse an explicit empty argv so every option takes its default value.
# get_parser() presumably returns an argparse parser; without args=[] it would
# read sys.argv, which inside a Jupyter kernel holds the kernel launcher's
# arguments rather than this project's options.
arg_parser = get_parser()
config = arg_parser.parse_args(args=[])

In [2]:
from src.models.util import ConvBlock


class Stylist(nn.Sequential):
    """Encode an image batch into one flat style vector per image.

    The network is a chain of ``ConvBlock`` layers, one per consecutive pair
    in ``config.stylist_channels``. They are followed by global average
    pooling, a flatten, and a linear projection to ``config.style_dim``
    features. The forward pass therefore returns a ``(batch, style_dim)``
    tensor.

    Submodules are registered as ``conv0 .. convK``, ``avgpool``, ``flatten``
    and ``linear``, in that order.
    """

    def __init__(self, config):
        widths = config.stylist_channels
        # Construct modules in the same order as they will run, so parameter
        # initialisation consumes the RNG in the natural order.
        layers = OrderedDict(
            (f'conv{idx}', ConvBlock(c_in, c_out))
            for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
        )
        layers['avgpool'] = nn.AdaptiveAvgPool2d((1, 1))
        layers['flatten'] = nn.Flatten()
        layers['linear'] = nn.Linear(widths[-1], config.style_dim)
        # nn.Sequential registers OrderedDict entries under their keys, in order.
        super().__init__(layers)

stylist = Stylist(config)
print(stylist)

# Smoke test only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    style_vec = stylist(torch.rand(3, 1, 256, 256))
# avgpool -> flatten -> linear collapses each image to a single style_dim vector.
assert style_vec.shape == (3, config.style_dim), style_vec.shape
style_vec.shape

Stylist(
  (conv0): ConvBlock(
    (conv): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (lrelu): LeakyReLU(negative_slope=0.2)
  )
  (conv1): ConvBlock(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (lrelu): LeakyReLU(negative_slope=0.2)
  )
  (conv2): ConvBlock(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (lrelu): LeakyReLU(negative_slope=0.2)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=256, out_features=256, bias=True)
)


torch.Size([3, 256])

In [5]:
class Synthesis(nn.Module):
    """Render a 3-channel output map from a style vector.

    Pipeline:
    1. ``ConstantInput`` supplies the starting tensor. It is built from
       ``config.initial_input_file``, ``config.grid_size`` and
       ``config.initial_input_fixed``.
    2. A styled head conv maps 3 channels to ``synthesis_channels[0]``.
    3. A trunk of styled convs walks through consecutive
       ``synthesis_channels``.
    4. A spectrally-normalised 3x3 conv plus ``tanh`` maps back to 3 channels
       bounded in [-1, 1].

    NOTE(review): the head's input width is hard-coded to 3, which assumes
    ConstantInput emits 3 channels. TODO confirm against src.models.layers.
    """

    def __init__(self, config):
        super().__init__()
        widths = config.synthesis_channels
        w_dim = config.style_dim
        self.input = ConstantInput(
            config.initial_input_file,
            config.grid_size,
            config.initial_input_fixed,
        )
        self.head = StyledConv2d(3, widths[0], w_dim, 3)
        self.trunk = nn.ModuleList(
            StyledConv2d(c_in, c_out, w_dim, 3)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        )
        self.tail = nn.Sequential(
            spectral_norm(nn.Conv2d(widths[-1], 3, 3, 1, 1, bias=False)),
            nn.Tanh(),
        )

    def forward(self, style):
        # ConstantInput receives the style too, presumably only to read the
        # batch size. TODO confirm.
        out = self.head(self.input(style), style)
        for block in self.trunk:
            out = block(out, style)
        return self.tail(out)
    


synthesis = Synthesis(config)
print(synthesis)

# Smoke test only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    rendered = synthesis(torch.rand(3, config.style_dim))
# The tail conv emits 3 channels, and tanh bounds every value to [-1, 1].
# The spatial size comes from ConstantInput/config, so it is not asserted here.
assert rendered.shape[:2] == (3, 3), rendered.shape
assert rendered.abs().le(1).all()
rendered.shape

Synthesis(
  (input): ConstantInput()
  (head): StyledConv2d(
    (conv): ModulatedConv2d(
      (modulation): Linear(in_features=256, out_features=3, bias=True)
    )
    (noise): NoiseInjection()
    (act): LeakyReLU(negative_slope=0.2)
  )
  (trunk): ModuleList(
    (0): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Linear(in_features=256, out_features=256, bias=True)
      )
      (noise): NoiseInjection()
      (act): LeakyReLU(negative_slope=0.2)
    )
    (1): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Linear(in_features=256, out_features=256, bias=True)
      )
      (noise): NoiseInjection()
      (act): LeakyReLU(negative_slope=0.2)
    )
    (2): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Linear(in_features=256, out_features=256, bias=True)
      )
      (noise): NoiseInjection()
      (act): LeakyReLU(negative_slope=0.2)
    )
  )
  (tail): Sequential(
    (0): Conv2d(256, 3, kernel_size=(3, 3), stride=(1,

torch.Size([3, 3, 32, 32])

In [8]:
class Generator(nn.Module):
    """Image-conditioned generator: ``Stylist`` encodes, ``Synthesis`` renders.

    ``forward(image)`` encodes ``image`` into a style vector with the stylist.
    It then feeds that vector to the synthesis network and returns the
    synthesis output (3 channels, tanh-bounded).
    """

    def __init__(self, config):
        super().__init__()
        self.stylist = Stylist(config)
        self.synthesis = Synthesis(config)

    def forward(self, image):
        return self.synthesis(self.stylist(image))
    
# Named `generator` (not `G`) so it is not silently shadowed by the packaged
# StyleGenerator bound to `G` in a later cell.
generator = Generator(config)

# Smoke test only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    generated = generator(torch.rand(5, 1, 64, 64))
# Batch size follows the input images, and the tail emits 3 tanh-bounded
# channels. The spatial size appears to come from the synthesis constant
# input, not from the input image, so it is not asserted.
assert generated.shape[:2] == (5, 3), generated.shape
assert generated.abs().le(1).all()
generated.shape

torch.Size([5, 3, 256, 256])

In [3]:
from src.models.style_generator import StyleGenerator

# Instantiate the packaged generator from src.models.style_generator. The
# cells below call `G`, so they exercise this class, not the notebook-local
# Generator.
# NOTE(review): this cell ran as In [3], i.e. before the In [5]/In [8] cells
# above. Re-run the notebook top to bottom to confirm outputs still match.
G = StyleGenerator(config)
G

StyleGenerator(
  (stylist): Stylist(
    (conv0): ConvBlock(
      (conv): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (conv1): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (conv2): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear): Linear(in_features=256, out_features=256, bias=True)
  )
  (synthesis): Synthesis(
    (input): ConstantInput()
    (head): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Linear(in_features=256, out_features=3, bias=True)
      )
      (noise): NoiseInjection()
      (act): LeakyReLU(negative_slope=0.2)
    )
    (trunk)

In [4]:
# Random single-channel 32x32 images, batch of 5, fed to the packaged generator.
images_32 = torch.rand(5, 1, 32, 32)
G(images_32).shape


torch.Size([5, 3, 256, 256])