In [2]:
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
from glob import glob

import torch
from torch import nn
from torch.utils import data
import torchvision
import torch.nn.functional as F
from  torch.nn.utils import spectral_norm
from torchvision import models
import torchvision.transforms as transforms
from config import Config as cfg


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class SPADE(nn.Module):
    """Spatially-adaptive normalization layer.

    The input feature map is normalized with a parameter-free
    ``BatchNorm2d`` (``affine=False``) and then modulated element-wise by a
    scale (``gamma``) and a shift (``beta``) that are predicted from the
    segmentation map by small spectrally-normalized convolutions.

    Args:
        num_channels: Number of channels of the feature map being normalized;
            ``gamma`` and ``beta`` are produced with this many channels.
        num_classes: Number of channels of the segmentation map. When left as
            ``None`` the value is read from ``cfg.NUM_CLASSES``.
        hidden_channels: Width of the shared convolution that feeds both the
            ``gamma`` and the ``beta`` heads.
    """

    def __init__(self, num_channels, num_classes=None, hidden_channels=128):
        super().__init__()
        # Resolve the config default lazily so the layer can also be built
        # with an explicit channel count, independent of the global config.
        if num_classes is None:
            num_classes = cfg.NUM_CLASSES

        # Submodule names and creation order are kept stable so existing
        # state_dict keys keep loading.
        self.bn = nn.BatchNorm2d(num_channels, affine=False)
        self.conv_1 = nn.Sequential(
            spectral_norm(nn.Conv2d(num_classes, hidden_channels, kernel_size=3, padding=1)),
            nn.ReLU(),
        )
        self.conv_1_1 = spectral_norm(
            nn.Conv2d(hidden_channels, num_channels, kernel_size=3, padding=1)
        )
        self.conv_2 = spectral_norm(
            nn.Conv2d(hidden_channels, num_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, segmentation_map):
        """Normalize ``x`` and modulate it with the segmentation map.

        Args:
            x: Feature map with ``num_channels`` channels (4-D, as required
                by ``BatchNorm2d``).
            segmentation_map: Map with ``num_classes`` channels; its spatial
                size may differ from ``x`` because it is resized here.

        Returns:
            Tensor with the same shape as ``x``: ``bn(x) * gamma + beta``.
        """
        normalized = self.bn(x)

        # Bring the map to the feature resolution; nearest-neighbour keeps
        # the map values unblended.
        resized_map = F.interpolate(
            segmentation_map, size=normalized.size()[2:], mode='nearest'
        )

        # Both modulation heads share the first conv + ReLU.
        shared = self.conv_1(resized_map)
        gamma = self.conv_1_1(shared)
        beta = self.conv_2(shared)

        return (normalized * gamma) + beta
class SPADEResBlk(nn.Module):
    """Residual block whose normalization layers are SPADE layers.

    Two paths are computed from the same input and summed:

    * main path: SPADE -> LeakyReLU(0.2) -> 3x3 conv, applied twice
      (``num_features_in -> num_features_out -> num_features_out``);
    * skip path: SPADE -> LeakyReLU(0.2) -> bias-free 1x1 conv
      (``num_features_in -> num_features_out``).

    All convolutions are wrapped in spectral normalization.
    """

    def __init__(self, num_features_in, num_features_out):
        super().__init__()

        # Main path, first stage.
        self.spade1 = SPADE(num_channels=num_features_in)
        self.conv1 = spectral_norm(
            nn.Conv2d(
                in_channels=num_features_in,
                out_channels=num_features_out,
                kernel_size=3,
                padding=1,
            )
        )

        # Main path, second stage.
        self.spade2 = SPADE(num_channels=num_features_out)
        self.conv2 = spectral_norm(
            nn.Conv2d(
                in_channels=num_features_out,
                out_channels=num_features_out,
                kernel_size=3,
                padding=1,
            )
        )

        # Skip path: always a learned projection, even when in == out.
        self.skip_connection_spade = SPADE(num_channels=num_features_in)
        self.skip_connection_conv = spectral_norm(
            nn.Conv2d(
                in_channels=num_features_in,
                out_channels=num_features_out,
                kernel_size=1,
                bias=False,
            )
        )

    def forward(self, x, segmentation_map):
        """Return ``skip(x) + main(x)``, both conditioned on the segmentation map."""
        negative_slope = 0.2

        # Skip branch.
        shortcut = self.skip_connection_spade(x, segmentation_map)
        shortcut = F.leaky_relu(shortcut, negative_slope)
        shortcut = self.skip_connection_conv(shortcut)

        # Main branch, stage 1.
        hidden = self.spade1(x, segmentation_map)
        hidden = self.conv1(F.leaky_relu(hidden, negative_slope))

        # Main branch, stage 2.
        hidden = self.spade2(hidden, segmentation_map)
        hidden = self.conv2(F.leaky_relu(hidden, negative_slope))

        return shortcut + hidden
class Generator(nn.Module):
    """Image generator built from six SPADE residual blocks.

    A latent vector of size ``cfg.LATENT_DIM`` is projected to a
    ``1024 x 4 x 4`` feature map. Each of the six residual blocks is followed
    by a 2x upsampling, so the spatial size grows from 4 to 256. A final
    LeakyReLU(0.2), 3x3 convolution and ``tanh`` produce a 3-channel output.
    """

    # (in_channels, out_channels) of block1 .. block6, in order.
    _BLOCK_CHANNELS = (
        (1024, 1024),
        (1024, 1024),
        (1024, 512),
        (512, 256),
        (256, 128),
        (128, 64),
    )

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(cfg.LATENT_DIM, 1024 * 4 * 4)
        self.upsample = nn.Upsample(scale_factor=2)

        # Registered as block1 .. block6 so parameter names stay unchanged.
        for index, (channels_in, channels_out) in enumerate(self._BLOCK_CHANNELS, start=1):
            setattr(
                self,
                f"block{index}",
                SPADEResBlk(num_features_in=channels_in, num_features_out=channels_out),
            )

        self.conv = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)

    def forward(self, latent_vec, segmentation_map):
        """Generate an image from a latent vector and a segmentation map.

        Args:
            latent_vec: Tensor whose last dimension is ``cfg.LATENT_DIM``.
            segmentation_map: Conditioning map passed to every residual block.

        Returns:
            Tensor with 3 channels and values in ``[-1, 1]`` (``tanh`` output).
        """
        features = self.linear1(latent_vec).reshape(-1, 1024, 4, 4)

        # Every stage: SPADE residual block, then double the resolution.
        for index in range(1, len(self._BLOCK_CHANNELS) + 1):
            stage = getattr(self, f"block{index}")
            features = self.upsample(stage(features, segmentation_map))

        features = F.leaky_relu(features, 0.2)
        return torch.tanh(self.conv(features))
# Smoke test: one random latent vector and one random map must produce a
# 3-channel 256x256 image. The latent size and class count come from cfg so
# the check keeps working when the configuration changes.
latent_vec = torch.randn(1, cfg.LATENT_DIM)
segmentation_map = torch.randn(1, cfg.NUM_CLASSES, 256, 256)
gen = Generator()
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    generated = gen(latent_vec, segmentation_map)
generated.shape

torch.Size([1, 3, 256, 256])

In [12]:
# List the attributes available on a parameter tensor, using only the first
# generator parameter as a sample.
param = next(iter(gen.parameters()), None)
if param is not None:
    print(dir(param))

['H', 'T', '__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__class__', '__complex__', '__contains__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__dlpack__', '__dlpack_device__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ifloordiv__', '__ilshift__', '__imod__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rfloordiv__', '__rlshift__', '__rmatmul__', '__rmod__', '__rmul_