## Pix2pix (SPADE)

In [1]:
from pytorch_lightning.trainer import Trainer
from torchsummary import summary
import torch
from pytorch_lightning.loggers import TensorBoardLogger
from pathlib import Path
from time import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0

In [3]:
class BaseOptions:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [4]:
opt = {
    "gpu_ids": [0],
    "netG": 'global',
    "ngf": 64,
    "num_upsampling_layers": "normal",
    "crop_size": 256,
    "aspect_ratio": 1.0,
    "use_vae": False,
    "z_dim": 256,
    "norm_G": "spectralspadesyncbatch3x3",
    "norm_D": "spectralinstance",
    "norm_E": "spectralinstance",
    "label_nc": 182,
    "contain_dontcare_label": True,
    "output_nc": 3,
    "no_instance": False,
    "init_type": "xavier",
    "init_variance": 0.02,
    "isTrain": True,
    "which_epoch": "latest",
    "checkpoints_dir": './checkpoints',
    "name": 'cityscapes_pretrained',
    "netD": 'multiscale',
    "num_D": 2,
    "netD_subarch": 'n_layer',
    "ndf": 64,
    "n_layers_D": 4,
    "continue_train": False,
    "gan_mode": 'hinge',
    "no_vgg_loss": False,
    "norm": "instance",
    "n_downsample_global": 4,
    "n_blocks_global": 9,
    "n_blocks_local": 3,
    "n_local_enhancers": 1,
    "no_lsgan": True,
    "no_ganFeat_loss": True,
    "feat_num": 3,
    "nef": 16,
    "n_downsample_E": 4,
    "semantic_nc": 182,
    "resnet_initial_kernel_size": 7
}

In [5]:
opt = BaseOptions(**opt)

In [6]:
import models.networks.spade as spade

In [7]:
from models.networks.spade.generator import SPADEGenerator, Pix2PixHDGenerator

In [8]:
netG = SPADEGenerator(opt)

In [9]:
netG = netG.to(device)

In [10]:
netG.print_network()

Network [SPADEGenerator] was created. Total number of parameters: 97.4 million. To see the architecture, do print(network).


In [11]:
print(netG)

SPADEGenerator(
  (fc): Conv2d(182, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (head_0): SPADEResnetBlock(
    (conv_0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv_1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm_0): SPADE(
      (param_free_norm): SynchronizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (mlp_shared): Sequential(
        (0): Conv2d(182, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (mlp_gamma): Conv2d(128, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (mlp_beta): Conv2d(128, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (norm_1): SPADE(
      (param_free_norm): SynchronizedBatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (mlp_shared): Sequential(
        (0): Conv2d(182, 128, kernel_size=(3, 3), stride=(1, 1), padding=

In [12]:
params = list(netG.parameters())
print(len(params))

144


In [15]:
params[1].size()

torch.Size([1024])

In [14]:
summary(netG, input_size=(182, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 1024, 8, 8]       1,678,336
SynchronizedBatchNorm2d-2           [-1, 1024, 8, 8]               0
            Conv2d-3            [-1, 128, 8, 8]         209,792
              ReLU-4            [-1, 128, 8, 8]               0
            Conv2d-5           [-1, 1024, 8, 8]       1,180,672
            Conv2d-6           [-1, 1024, 8, 8]       1,180,672
             SPADE-7           [-1, 1024, 8, 8]               0
            Conv2d-8           [-1, 1024, 8, 8]       9,438,208
SynchronizedBatchNorm2d-9           [-1, 1024, 8, 8]               0
           Conv2d-10            [-1, 128, 8, 8]         209,792
             ReLU-11            [-1, 128, 8, 8]               0
           Conv2d-12           [-1, 1024, 8, 8]       1,180,672
           Conv2d-13           [-1, 1024, 8, 8]       1,180,672
            SPADE-14         



In [15]:
netG1 = Pix2PixHDGenerator(opt)

ValueError: normalization layer spadesyncbatch3x3 is not recognized