# In[6]:
import torch
import torch.nn as nn
import numpy as np
import functools

# In[10]:
class GeneratorNetwork(nn.Module):
    """Generator made of a downsampling encoder, a fused residual trunk and a
    background encoder/decoder pair.

    Sub-modules registered by the constructor (in this order):

    * ``downsample_model`` -- 7x7 reflection-padded conv followed by
      ``n_downsample`` stride-2 convolutions, each doubling the channel count.
    * ``model`` -- a 1x1 conv that maps ``ngf * 2**n_downsample + embed_nc``
      channels back to ``ngf * 2**n_downsample``, then ``n_blocks_global``
      ``ResnetBlock`` modules, then ``n_downsample`` transposed-conv
      upsampling stages, each halving the channel count.
    * ``bg_encoder`` -- 7x7 reflection-padded conv from 3 to ``ngf`` channels.
    * ``bg_decoder`` -- 1x1 conv from ``2 * ngf`` to ``ngf`` channels followed
      by a 7x7 reflection-padded conv to ``output_nc`` channels and ``Tanh``.

    Args:
        input_nc: number of channels of the tensor fed to ``downsample_model``.
        output_nc: number of channels produced by ``bg_decoder``.
        ngf: base number of convolution filters. Defaults to 64, the value the
            network was previously hard-wired to.
    """

    def __init__(self, input_nc, output_nc, ngf=64):
        # nn.Module must be initialised (with *this* class) before any
        # sub-module is assigned; the previous call named a different class.
        super(GeneratorNetwork, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        # Use the caller-supplied width everywhere; before, ``self.ngf`` was
        # fixed at 64 while ``bg_decoder`` read the raw ``ngf`` argument.
        self.ngf = ngf
        self.norm_type = 'batch'
        self.n_downsample = 2
        self.n_blocks_global = 9
        self.n_local_enhancers = 1
        self.n_blocks_local = 3
        self.embed_nc = 256*5
        self.padding_type = 'reflect'

        norm_layer = get_norm_layer(norm_type=self.norm_type)
        activation = nn.ReLU(True)

        ### downsample
        downsample_model = [nn.ReflectionPad2d(3),
                            nn.Conv2d(self.input_nc, self.ngf, kernel_size=7, padding=0),
                            norm_layer(self.ngf), activation]
        for i in range(self.n_downsample):
            in_ch = self.ngf * 2**i
            out_ch = in_ch * 2
            downsample_model += [nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                                 norm_layer(out_ch), activation]
        self.downsample_model = nn.Sequential(*downsample_model)

        ### fusion conv + residual trunk
        bottleneck_nc = self.ngf * 2**self.n_downsample
        # 1x1 conv squeezes (encoder features + embedding channels) back to
        # the encoder width.
        model = [nn.Conv2d(in_channels=bottleneck_nc + self.embed_nc, out_channels=bottleneck_nc,
                           kernel_size=1, padding=0, stride=1, bias=True)]
        for _ in range(self.n_blocks_global):
            model += [ResnetBlock(bottleneck_nc, padding_type=self.padding_type,
                                  activation=activation, norm_layer=norm_layer)]

        ### upsample
        for i in range(self.n_downsample):
            in_ch = self.ngf * 2**(self.n_downsample - i)
            out_ch = in_ch // 2
            model += [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, output_padding=0),
                      norm_layer(out_ch), activation]

        self.model = nn.Sequential(*model)

        ### background branch
        bg_encoder = [nn.ReflectionPad2d(3), nn.Conv2d(3, self.ngf, kernel_size=7, padding=0),
                      norm_layer(self.ngf), activation]
        self.bg_encoder = nn.Sequential(*bg_encoder)

        bg_decoder = [nn.Conv2d(in_channels=self.ngf*2, out_channels=self.ngf,
                                kernel_size=1, padding=0, stride=1, bias=True)]
        bg_decoder += [nn.ReflectionPad2d(3), nn.Conv2d(self.ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.bg_decoder = nn.Sequential(*bg_decoder)

    def forward(self, input):
        """Run ``input`` through ``self.model`` and return the result.

        NOTE(review): only ``self.model`` is executed here, so ``input`` must
        already carry ``ngf * 2**n_downsample + embed_nc`` channels.
        ``downsample_model``, ``bg_encoder`` and ``bg_decoder`` are built but
        not used by this method -- confirm whether the full pipeline is
        wired up elsewhere.
        """
        return self.model(input)

    
class DiscriminatorNetwork(nn.Module):
    """Container that registers the layers of ``num_D`` ``NLayerDiscriminator``
    instances, plus an average-pooling ``downsample`` module.

    When ``getIntermFeat`` is true (the value set here), the sub-layers
    ``model0`` .. ``model{dis_n_layers + 1}`` of every discriminator are
    registered individually as ``scale{i}_layer{j}``; otherwise the whole
    ``netD.model`` is registered as ``layer{i}``.

    Args:
        input_nc: number of input channels passed to each discriminator.
        dis_n_layers: layer count passed to each ``NLayerDiscriminator``.
        use_sigmoid: forwarded to each ``NLayerDiscriminator``.
    """

    def __init__(self, input_nc, dis_n_layers, use_sigmoid):
        # Initialise nn.Module first so sub-modules can be registered below.
        super(DiscriminatorNetwork, self).__init__()

        self.input_nc = input_nc
        self.dis_n_layers = dis_n_layers
        self.use_sigmoid = use_sigmoid
        self.ndf = 64
        self.norm_type = 'batch'
        self.num_D = 2
        self.getIntermFeat = True

        norm_layer = get_norm_layer(norm_type=self.norm_type)

        for i in range(self.num_D):
            netD = NLayerDiscriminator(self.input_nc, self.ndf, self.dis_n_layers, norm_layer,
                                       self.use_sigmoid, self.getIntermFeat)
            # The flag and the layer count live on ``self``; the bare names
            # used previously (getIntermFeat, n_layers) were undefined.
            if self.getIntermFeat:
                for j in range(self.dis_n_layers + 2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
            else:
                setattr(self, 'layer'+str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        

def weights_init(m):
    """Re-initialise the parameters of module ``m`` in place, by class name.

    * Class name contains ``'Conv'``: weights drawn from N(0.0, 0.02);
      the bias is left untouched.
    * Otherwise, class name contains ``'BatchNorm2d'``: weights drawn from
      N(1.0, 0.02) and the bias filled with zeros.
    * Any other module is left unchanged.

    Suitable as the callback for ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm2d' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalisation layer.

    Args:
        norm_type: ``'batch'`` for ``nn.BatchNorm2d`` with ``affine=True`` or
            ``'instance'`` (default) for ``nn.InstanceNorm2d`` with
            ``affine=False``.

    Returns:
        A ``functools.partial`` that builds the layer when called with the
        remaining constructor arguments (e.g. the channel count).

    Raises:
        NotImplementedError: if ``norm_type`` is neither of the two names.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

class DecoderGenerator_mask_skin_image(nn.Module):
    """Decode a 512-d vector into a 3-channel 256x256 image.

    A linear layer expands the vector to 512x2x2, seven ``DecoderBlock``
    stages follow, and a reflection-padded 5x5 conv with ``Tanh`` produces the
    image.  The shape checks in ``forward`` require a 256x256 result, so each
    ``DecoderBlock`` is presumably a 2x upsampler -- confirm against its
    definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_skin_image, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*2*2))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 512)] * 4 + [(512, 256), (256, 128), (128, 64)]
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4,
                         padding=1, stride=2, output_padding=0)
            for c_in, c_out in channel_plan
        ]
        # Size-preserving RGB head: pad by 2, then an unpadded 5x5 conv.
        stages += [nn.ReflectionPad2d(2), nn.Conv2d(64, 3, kernel_size=5, padding=0), nn.Tanh()]

        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to an image tensor and check it is (B, 3, 256, 256)."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 2, 2)
        image = self.conv(hidden)
        assert image.size(1) == 3
        assert image.size(2) == 256
        assert image.size(3) == 256
        return image

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_skin_image, self).__call__(*args, **kwargs)
class DecoderGenerator_mask_mouth_image(nn.Module):
    """Decode a 512-d vector into a 3-channel 80x144 image.

    A linear layer expands the vector to 512x5x9, four ``DecoderBlock``
    stages follow, and a reflection-padded 5x5 conv with ``Tanh`` produces the
    image.  The shape checks in ``forward`` require 80x144 (16x the 5x9 seed),
    so each ``DecoderBlock`` is presumably a 2x upsampler -- confirm against
    its definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_mouth_image, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*5*9))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 256), (256, 128), (128, 64), (64, 64)]
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4,
                         padding=1, stride=2, output_padding=0)
            for c_in, c_out in channel_plan
        ]
        # Size-preserving RGB head: pad by 2, then an unpadded 5x5 conv.
        stages += [nn.ReflectionPad2d(2), nn.Conv2d(64, 3, kernel_size=5, padding=0), nn.Tanh()]

        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to an image tensor and check it is (B, 3, 80, 144)."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 5, 9)
        image = self.conv(hidden)
        assert image.size(1) == 3
        assert image.size(2) == 80
        assert image.size(3) == 144
        return image

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_mouth_image, self).__call__(*args, **kwargs)


class DecoderGenerator_mask_eye_image(nn.Module):
    """Decode a 512-d vector into a 3-channel 32x48 image.

    A bias-free linear layer expands the vector to 512x2x3, four
    ``DecoderBlock`` stages follow, and a reflection-padded 5x5 conv with
    ``Tanh`` produces the image.  The shape checks in ``forward`` require
    32x48 (16x the 2x3 seed), so each ``DecoderBlock`` is presumably a 2x
    upsampler -- confirm against its definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_eye_image, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*2*3, bias=False))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 256), (256, 128), (128, 64), (64, 64)]
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4,
                         padding=1, stride=2, output_padding=0)
            for c_in, c_out in channel_plan
        ]
        # Size-preserving RGB head: pad by 2, then an unpadded 5x5 conv.
        stages += [nn.ReflectionPad2d(2), nn.Conv2d(64, 3, kernel_size=5, padding=0), nn.Tanh()]

        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to an image tensor and check it is (B, 3, 32, 48)."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 2, 3)
        image = self.conv(hidden)
        assert image.size(1) == 3
        assert image.size(2) == 32
        assert image.size(3) == 48
        return image

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_eye_image, self).__call__(*args, **kwargs)


class DecoderGenerator_mask_mouth(nn.Module):
    """Decode a 512-d vector into a 256-channel 20x36 feature map.

    A linear layer expands the vector to 512x5x9 and two ``DecoderBlock``
    stages follow.  The shape checks in ``forward`` require 20x36 (4x the
    5x9 seed), so each ``DecoderBlock`` is presumably a 2x upsampler --
    confirm against its definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_mouth, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*5*9))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 256), (256, 256)]
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
            for c_in, c_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to a feature map and check it is (B, 256, 20, 36)."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 5, 9)
        features = self.conv(hidden)
        assert features.size(1) == 256
        assert features.size(2) == 20
        assert features.size(3) == 36
        return features

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_mouth, self).__call__(*args, **kwargs)


class DecoderGenerator_mask_eye(nn.Module):
    """Decode a 512-d vector into a 256-channel 8x12 feature map.

    A bias-free linear layer expands the vector to 512x2x3 and two
    ``DecoderBlock`` stages follow.  The shape checks in ``forward`` require
    8x12 (4x the 2x3 seed), so each ``DecoderBlock`` is presumably a 2x
    upsampler -- confirm against its definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_eye, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*2*3, bias=False))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 256), (256, 256)]
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
            for c_in, c_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to a feature map and check it is (B, 256, 8, 12)."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 2, 3)
        features = self.conv(hidden)
        assert features.size(1) == 256
        assert features.size(2) == 8
        assert features.size(3) == 12
        return features

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_eye, self).__call__(*args, **kwargs)


class DecoderGenerator_mask_skin(nn.Module):
    """Decode a 512-d vector into a 256-channel feature map of height 64.

    A linear layer expands the vector to 512x2x2 and five ``DecoderBlock``
    stages follow.  ``forward`` checks only the channel count (256) and the
    height (64); reaching 64 from 2 in five stages implies each
    ``DecoderBlock`` is presumably a 2x upsampler -- confirm against its
    definition.

    Args:
        norm_layer: accepted by the constructor but not used by this class.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_skin, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*2*2))

        # (channel_in, channel_out) of every DecoderBlock, in execution order.
        channel_plan = [(512, 256)] + [(256, 256)] * 4
        stages = [
            DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
            for c_in, c_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Map ``ten`` to a feature map and check channels == 256, height == 64."""
        hidden = self.fc(ten)
        hidden = hidden.view(hidden.size(0), 512, 2, 2)
        features = self.conv(hidden)
        assert features.size(1) == 256
        assert features.size(2) == 64
        return features

    def __call__(self, *args, **kwargs):
        """Delegate unchanged to ``nn.Module.__call__``."""
        return super(DecoderGenerator_mask_skin, self).__call__(*args, **kwargs)

class FurryGan(nn.Module):
    """Top-level container that instantiates the generator, the discriminator
    and the per-region decoder networks.

    Registered sub-modules:

    * ``gen_net`` -- ``GeneratorNetwork`` (initialised with ``weights_init``).
    * ``dis_net`` -- ``DiscriminatorNetwork`` fed ``input_nc + output_nc``
      channels.
    * ``decoder_skin_*`` -- feature decoders for net/hair/left eye/right
      eye/mouth.
    * ``decoder_skin_image_*`` -- image decoders for the same five regions.

    Args:
        isTrain: whether the model is built for training. It controls the
            ``use_sigmoid`` flag handed to the discriminator.
    """

    def __init__(self, isTrain=True):
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise the first ``self.<module> = ...`` raises AttributeError.
        super(FurryGan, self).__init__()

        self.input_nc = 11             # number of input channels
        self.output_nc = 3             # number of output channels
        self.isTrain = isTrain         # whether to train
        self.dis_net_input_nc = self.input_nc + self.output_nc
        self.dis_n_layers = 3

        # Pass the base filter count explicitly (64 is the width the
        # generator is built around); the call previously omitted it.
        self.gen_net = GeneratorNetwork(self.input_nc, self.output_nc, ngf=64)
        self.gen_net.apply(weights_init)

        # ``use_sigmoid`` used to be bound only when training, so building
        # the model with isTrain=False raised a NameError.
        # NOTE(review): outside training the discriminator is still built,
        # with the sigmoid disabled -- confirm whether it should instead be
        # skipped entirely at inference time.
        use_sigmoid = bool(self.isTrain)
        self.dis_net = DiscriminatorNetwork(self.dis_net_input_nc, self.dis_n_layers, use_sigmoid)

        # TODO: embed_feature_size

        # The decoders all receive the same (stateless) norm-layer factory.
        batch_norm = functools.partial(nn.BatchNorm2d, affine=True)

        self.decoder_skin_net = DecoderGenerator_mask_skin(batch_norm)
        self.decoder_skin_hair = DecoderGenerator_mask_skin(batch_norm)
        self.decoder_skin_left_eye = DecoderGenerator_mask_eye(batch_norm)
        self.decoder_skin_right_eye = DecoderGenerator_mask_eye(batch_norm)
        self.decoder_skin_mouth = DecoderGenerator_mask_mouth(batch_norm)

        self.decoder_skin_image_net = DecoderGenerator_mask_skin_image(batch_norm)
        self.decoder_skin_image_hair = DecoderGenerator_mask_skin_image(batch_norm)
        self.decoder_skin_image_left_eye = DecoderGenerator_mask_eye_image(batch_norm)
        self.decoder_skin_image_right_eye = DecoderGenerator_mask_eye_image(batch_norm)
        self.decoder_skin_image_mouth = DecoderGenerator_mask_mouth_image(batch_norm)