# In [5]:
import torch
import torch.nn as nn
import numpy as np
import functools

# Loss Classes

# In [8]:
class GANLoss(nn.Module):
    """Adversarial loss against constant real/fake target labels.

    Uses ``nn.MSELoss`` when ``use_lsgan`` is truthy, otherwise
    ``nn.BCEWithLogitsLoss``.  Target tensors are cached per label kind and
    rebuilt whenever the prediction's shape, dtype or device changes.

    Args:
        use_lsgan: selects the MSE criterion when truthy.
        target_real_label: fill value for targets of "real" predictions.
        target_fake_label: fill value for targets of "fake" predictions.
    """

    def __init__(self, use_lsgan, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.use_lsgan = use_lsgan
        self.target_real_label = target_real_label
        self.target_fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        # Kept as a public attribute for callers that read it; targets are
        # now built with torch.full_like so they follow the input's device.
        self.Tensor = torch.FloatTensor

        if self.use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    @staticmethod
    def _is_stale(cached, reference):
        """Return True when ``cached`` cannot serve as a target for ``reference``."""
        return (cached is None
                or cached.shape != reference.shape
                or cached.dtype != reference.dtype
                or cached.device != reference.device)

    def get_target_tensor(self, input, target_is_real):
        """Return a constant, gradient-free target tensor shaped like ``input``.

        Args:
            input: prediction tensor the target must match.
            target_is_real: truthy for the real label, falsy for the fake label.
        """
        if target_is_real:
            if self._is_stale(self.real_label_var, input):
                self.real_label_var = torch.full_like(
                    input, self.target_real_label, requires_grad=False)
            return self.real_label_var

        if self._is_stale(self.fake_label_var, input):
            self.fake_label_var = torch.full_like(
                input, self.target_fake_label, requires_grad=False)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss on the last entry of each prediction list.

        ``input`` is either a list of per-discriminator lists (the losses on
        each inner list's last element are summed) or a single list whose
        last element is the prediction.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss

        pred = input[-1]
        target_tensor = self.get_target_tensor(pred, target_is_real)
        return self.loss(pred, target_tensor)
        
class MFMLoss(nn.Module):
    """Sum of MSE losses between batch-averaged second-to-last features.

    For every index of ``x_input`` the second-to-last entry of the matching
    lists in ``x_input`` and ``y_input`` is averaged over dimension 0 and the
    two means are compared with ``nn.MSELoss``; the ``y`` side is detached.
    """

    def __init__(self):
        super(MFMLoss, self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self, x_input, y_input):
        total = 0
        for idx, x_feats in enumerate(x_input):
            x_feat = x_feats[-2]
            y_feat = y_input[idx][-2]
            # Both features must be 4-D tensors.
            assert x_feat.dim() == 4
            assert y_feat.dim() == 4
            x_avg = torch.mean(x_feat, 0)
            y_avg = torch.mean(y_feat, 0).detach()
            total += self.criterion(x_avg, y_avg)
        return total

class Vgg19(torch.nn.Module):
    """Pretrained VGG-19 feature layers split into five sequential slices.

    The slices cover feature-layer indices [0, 2), [2, 7), [7, 12), [12, 21)
    and [21, 30).  Parameters are frozen unless ``requires_grad`` is truthy.

    NOTE(review): ``models`` is not imported in the visible part of this
    file; presumably it is ``torchvision.models`` - confirm it is in scope.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        backbone = models.vgg19(pretrained=True).features
        boundaries = (0, 2, 7, 12, 21, 30)
        for slice_no in range(1, len(boundaries)):
            block = torch.nn.Sequential()
            # Sub-modules keep their original feature index as their name.
            for layer_idx in range(boundaries[slice_no - 1], boundaries[slice_no]):
                block.add_module(str(layer_idx), backbone[layer_idx])
            setattr(self, 'slice' + str(slice_no), block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X, layers_num=5):
        """Return the outputs of the first ``layers_num`` slices as a list.

        Any ``layers_num`` other than 1-4 yields all five slice outputs.
        """
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        outputs = []
        hidden = X
        for depth, stage in enumerate(stages, 1):
            hidden = stage(hidden)
            outputs.append(hidden)
            if depth == layers_num:
                break
        return outputs

class VGGLoss(nn.Module):
    """Region-masked L1 loss between VGG-19 feature maps of two images.

    Args:
        weights: optional per-layer weights; its length also selects how many
            VGG slices are evaluated.  Defaults to
            ``[1/4, 1/4, 1/4, 1/8, 1/8]``.
    """

    def __init__(self, weights = None):
        super(VGGLoss, self).__init__()
        if weights is not None:
            self.weights = weights
        else:
            self.weights = [1.0/4, 1.0/4, 1.0/4, 1.0/8, 1.0/8]
        self.vgg = Vgg19()
        self.criterion = nn.L1Loss()

    def forward(self, x, y, face_mask, mask_weights):
        """Return the weighted sum of masked L1 feature distances.

        Args:
            x: input whose features receive gradients.
            y: reference input; its masked features are detached.
            face_mask: tensor with one channel per region (dimension 1).
            mask_weights: one scalar weight per ``face_mask`` channel.
        """
        assert face_mask.size()[1] == len(mask_weights), \
            "face_mask needs one channel per entry of mask_weights"
        n_layers = len(self.weights)
        x_vgg = self.vgg(x, layers_num=n_layers)
        y_vgg = self.vgg(y, layers_num=n_layers)

        # One mask per feature level; each level halves the previous mask
        # with a 2x2 max-pool.
        downsample = nn.MaxPool2d(2)
        masks = [face_mask.detach()]
        for _ in range(len(x_vgg) - 1):
            masks.append(downsample(masks[-1]).detach())

        loss = 0
        for level, (feat_x, feat_y) in enumerate(zip(x_vgg, y_vgg)):
            for mask_index, region_weight in enumerate(mask_weights):
                # Slice (not index) so the channel dimension is kept and
                # broadcasts over the feature channels.
                region = masks[level][:, mask_index:mask_index+1, :, :]
                loss += self.weights[level] * self.criterion(
                    feat_x * region, (feat_y * region).detach()) * region_weight
        return loss

class GramMatrixLoss(nn.Module):
    """MSE between Gram matrices of masked VGG-19 features (first 3 slices).

    The mask selects the positions where ``label == 1``.  ``grammatrix`` is
    expected to be defined elsewhere in this file.
    """

    def __init__(self):
        super(GramMatrixLoss, self).__init__()
        self.weights = [1.0,1.0,1.0]
        self.vgg = Vgg19()
        self.criterion = nn.MSELoss()

    def forward(self, x, y, label):
        """Return the weighted sum of Gram-matrix MSE terms.

        Args:
            x: input whose features receive gradients.
            y: reference input; its Gram matrices are detached.
            label: tensor whose entries equal to 1 mark the masked region.
        """
        # Build the mask on x's device; casting with torch.FloatTensor would
        # force it onto the CPU and break the products below for GPU inputs.
        face_mask = (label == 1).to(device=x.device, dtype=torch.float32)

        n_layers = len(self.weights)
        x_vgg = self.vgg(x, layers_num=n_layers)
        y_vgg = self.vgg(y, layers_num=n_layers)

        # One mask per feature level, halved by a 2x2 max-pool each time.
        downsample = nn.MaxPool2d(2)
        masks = [face_mask.detach()]
        for _ in range(len(x_vgg) - 1):
            masks.append(downsample(masks[-1]).detach())

        loss = 0
        for level, (feat_x, feat_y) in enumerate(zip(x_vgg, y_vgg)):
            region = masks[level]
            loss += self.weights[level] * self.criterion(
                grammatrix(feat_x * region), grammatrix(feat_y * region).detach())
        return loss

# Network Classes

# In [10]:
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two padded 3x3 convolutions.

    Args:
        dim: channel count of both convolutions.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable building a normalisation layer from ``dim``.
        activation: module placed after the first normalisation layer.
        use_dropout: when truthy, adds ``nn.Dropout(0.5)`` after the activation.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble the block's layers into an ``nn.Sequential``.

        Raises:
            NotImplementedError: for an unknown ``padding_type``.
        """

        def padded_conv():
            # 'zero' pads inside the convolution; the other modes use an
            # explicit padding layer followed by an unpadded convolution.
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1)]
            if padding_type == 'reflect':
                pad = nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                pad = nn.ReplicationPad2d(1)
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return [pad, nn.Conv2d(dim, dim, kernel_size=3, padding=0)]

        layers = padded_conv()
        layers.append(norm_layer(dim))
        layers.append(activation)
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv())
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)

class GeneratorNetwork(nn.Module):
    """Generator made of four sub-networks selected through ``forward(type=...)``.

    Sub-networks:
        * ``downsample_model`` - 7x7 conv followed by ``n_downsample``
          stride-2 convolutions (``type="label_encoder"``).
        * ``model`` - 1x1 conv over ``ngf * 2**n_downsample + embed_nc``
          channels, ``n_blocks_global`` ResnetBlocks, then ``n_downsample``
          transposed convolutions (``type="image_G"``).
        * ``bg_encoder`` - 7x7 conv on a 3-channel input (``type="bg_encoder"``).
        * ``bg_decoder`` - 1x1 conv from ``2 * ngf`` channels, 7x7 conv to
          ``output_nc`` channels and ``Tanh`` (``type="bg_decoder"``).

    Args:
        input_nc: channels of the ``label_encoder`` input.
        output_nc: channels produced by ``bg_decoder``.
        ngf: base number of filters.  The original code ignored this argument
            everywhere except ``bg_decoder``; it is now used consistently
            (identical to the old layout when ``ngf == 64``).

    ``get_norm_layer`` is expected to be defined elsewhere in this file.
    """

    def __init__(self, input_nc, output_nc, ngf):
        # Module.__init__ must run before any attribute or sub-module is set.
        super(GeneratorNetwork, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.norm_type = 'batch'
        self.n_downsample = 2
        self.n_blocks_global = 9
        self.n_local_enhancers = 1
        self.n_blocks_local = 3
        self.embed_nc = 256*5
        self.padding_type='reflect'

        norm_layer = get_norm_layer(norm_type=self.norm_type)
        activation = nn.ReLU(True)

        downsample_model = [nn.ReflectionPad2d(3), nn.Conv2d(self.input_nc, self.ngf, kernel_size=7, padding=0), norm_layer(self.ngf), activation]
        for i in range(self.n_downsample):
            mult = 2**i
            downsample_model += [nn.Conv2d(self.ngf * mult, self.ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                                 norm_layer(self.ngf * mult * 2), activation]
        self.downsample_model = nn.Sequential(*downsample_model)

        mult = 2**self.n_downsample
        # The 1x1 conv fuses the encoded features with the embed_nc extra
        # channels back down to ngf * mult channels.
        model = [nn.Conv2d(in_channels=self.ngf*mult+self.embed_nc, out_channels=self.ngf*mult, kernel_size=1, padding=0, stride=1, bias=True)]
        for i in range(self.n_blocks_global):
            model += [ResnetBlock(self.ngf * mult, padding_type=self.padding_type, activation=activation, norm_layer=norm_layer)]

        ### upsample
        for i in range(self.n_downsample):
            mult = 2**(self.n_downsample - i)
            model += [nn.ConvTranspose2d(self.ngf * mult, int(self.ngf * mult / 2), kernel_size=4, stride=2, padding=1, output_padding=0),
                      norm_layer(int(self.ngf * mult / 2)), activation]
        self.model = nn.Sequential(*model)

        bg_encoder = [nn.ReflectionPad2d(3), nn.Conv2d(3, self.ngf, kernel_size=7, padding=0), norm_layer(self.ngf), activation]
        self.bg_encoder = nn.Sequential(*bg_encoder)

        bg_decoder = [nn.Conv2d(in_channels=self.ngf*2, out_channels=self.ngf, kernel_size=1, padding=0, stride=1, bias=True)]
        bg_decoder += [nn.ReflectionPad2d(3), nn.Conv2d(self.ngf, self.output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.bg_decoder = nn.Sequential(*bg_decoder)

    def forward(self, input, type="label_encoder"):
        """Run ``input`` through the sub-network named by ``type``.

        Raises:
            ValueError: if ``type`` is not one of ``"label_encoder"``,
                ``"image_G"``, ``"bg_encoder"`` or ``"bg_decoder"``.
        """
        if type=="label_encoder":
            return self.downsample_model(input)
        elif type=="image_G":
            return self.model(input)
        elif type=="bg_encoder":
            return self.bg_encoder(input)
        elif type=="bg_decoder":
            # Before bg_decoder the caller must concatenate the feature maps
            # from G and bg_encoder.
            return self.bg_decoder(input)
        # Fail loudly instead of printing and silently returning None.
        raise ValueError("wrong type in generator network - forward: %r" % (type,))

class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator built from ``dis_n_layers + 2`` conv groups.

    Args:
        input_nc: channels of the input.
        ndf: filters of the first convolution; doubled per group, capped at 512.
        dis_n_layers: number of stride-2 convolution groups.
        norm_layer: callable building a normalisation layer from a channel count.
        use_sigoid: when truthy, a ``Sigmoid`` group is appended.  The
            misspelled parameter name is kept for keyword callers.
        getIntermFeat: when truthy, each group is registered separately as
            ``model0``, ``model1``, ... and ``forward`` returns the outputs of
            the last two of the first ``dis_n_layers + 2`` groups.
    """

    def __init__(self, input_nc, ndf, dis_n_layers, norm_layer, use_sigoid, getIntermFeat):
        super(NLayerDiscriminator, self).__init__()
        self.input_nc = input_nc
        self.ndf = ndf
        self.dis_n_layers = dis_n_layers
        self.norm_layer = norm_layer
        # Store under both spellings: the original only set ``use_sigoid``
        # but read ``use_sigmoid``, which raised AttributeError.
        self.use_sigoid = use_sigoid
        self.use_sigmoid = use_sigoid
        self.getIntermFeat = getIntermFeat

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[nn.Conv2d(self.input_nc, self.ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = self.ndf
        for n in range(1, self.dis_n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if self.use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if self.getIntermFeat:
            for n, group in enumerate(sequence):
                setattr(self, 'model'+str(n), nn.Sequential(*group))
        else:
            sequence_stream = []
            for group in sequence:
                sequence_stream += group
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return the prediction, or the last two group outputs.

        With ``getIntermFeat`` only the first ``dis_n_layers + 2`` groups are
        run, so an appended ``Sigmoid`` group is not applied in that mode.
        """
        if self.getIntermFeat:
            res = [input]
            for n in range(self.dis_n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[-2:]
        else:
            return self.model(input)
        
        
class DiscriminatorNetwork(nn.Module):
    """Multi-scale discriminator: ``numD`` NLayerDiscriminators on an
    average-pooled input pyramid.

    Args:
        input_nc: channels of the input.
        dis_n_layers: ``dis_n_layers`` passed to every NLayerDiscriminator.
        numD: number of scales / discriminators (stored as ``self.num_D``).
        use_sigmoid: forwarded to every NLayerDiscriminator.

    ``get_norm_layer`` is expected to be defined elsewhere in this file.
    """

    def __init__(self, input_nc, dis_n_layers, numD, use_sigmoid):
        super(DiscriminatorNetwork, self).__init__()
        self.input_nc = input_nc
        self.dis_n_layers = dis_n_layers
        # The parameter is ``numD``; the original read an undefined ``num_D``.
        self.num_D = numD
        self.use_sigmoid = use_sigmoid
        self.ndf = 64
        self.norm_type = 'batch'
        self.getIntermFeat = True

        norm_layer = get_norm_layer(norm_type=self.norm_type)

        for i in range(self.num_D):
            netD = NLayerDiscriminator(self.input_nc, self.ndf, self.dis_n_layers, norm_layer, self.use_sigmoid, self.getIntermFeat)
            if self.getIntermFeat:
                # Re-register each conv group on this module as
                # ``scale<i>_layer<j>``.
                for j in range(self.dis_n_layers+2):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
            else:
                setattr(self, 'layer'+str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator.

        With ``getIntermFeat`` ``model`` is a list of groups applied in order
        and the last two outputs are returned; otherwise ``model`` is a single
        module and its output is returned in a one-element list.
        """
        if self.getIntermFeat:
            result = [input]
            for layer in model:
                result.append(layer(result[-1]))
            return result[-2:]
        else:
            return [model(input)]

    def forward(self, input):
        """Return one result list per scale.

        The discriminator with the highest index sees the full-resolution
        input; each following one sees the input downsampled once more.
        """
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.dis_n_layers+2)]
            else:
                model = getattr(self, 'layer'+str(num_D-1-i))

            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result

class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, optional ``submodule``, upsample.

    Except for the outermost level, ``forward`` concatenates the block's input
    with its centre-cropped output along dimension 1 (the skip connection).

    Args:
        outer_nc: channels produced by the up-convolution.
        inner_nc: channels produced by the down-convolution.
        input_nc: channels of the input; defaults to ``outer_nc``.
        submodule: nested block placed between the down and up paths.
        outermost: build the outermost level (no norm, ``Tanh`` output).
        innermost: build the innermost level (no submodule).
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: append ``nn.Dropout(0.5)`` on intermediate levels.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, output_padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias, output_padding=1)
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias, output_padding=1)
            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        y = self.model(x)
        in_h, in_w = x.size()[2], x.size()[3]
        # Crop the output to the input's spatial size, centred.
        top = int((y.size()[2] - in_h) / 2)
        left = int((y.size()[3] - in_w) / 2)
        y = y[:, :, top:top + in_h, left:left + in_w]
        if self.outermost:
            return y
        # Skip connection: stack input and output along the channel dimension.
        return torch.cat([x, y], 1)
    
class UnetGenerator(nn.Module):
    """U-Net assembled from nested UnetSkipConnectionBlocks, ending in a
    softmax over dimension 1.

    Args:
        segment_classes: number of output channels.
        input_nc: channels of the input.
        num_downs: ``num_downs - 5`` extra ``ngf * 8`` levels are inserted
            next to the innermost block.
        ngf: base number of filters.
        norm_layer: normalisation layer passed to every block.
        use_dropout: enables dropout on the extra ``ngf * 8`` levels only.
    """

    def __init__(self, segment_classes, input_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        output_nc = segment_classes

        # Build from the innermost level outwards.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        raw = self.model(input)
        return torch.nn.Softmax(dim = 1)(raw)

class PNetwork(nn.Module):
    """Wrapper that builds a 6-level UnetGenerator with ``label_nc`` output
    channels, batch norm and dropout enabled.

    Args:
        label_nc: number of output channels (``segment_classes`` of the U-Net).
        output_nc: stored, and also used as the U-Net's input channel count.

    ``get_norm_layer`` is expected to be defined elsewhere in this file.
    """

    def __init__(self, label_nc, output_nc):
        # The original never initialised nn.Module, so sub-modules could not
        # be registered.
        super(PNetwork, self).__init__()
        self.label_nc = label_nc
        self.output_nc = output_nc
        # NOTE(review): the original read an undefined ``self.input_nc``.
        # Presumably the network consumes an ``output_nc``-channel input -
        # confirm against the caller.
        self.input_nc = output_nc
        self.ngf = 64
        self.norm_type = 'batch'
        self.use_dropout = True
        # The original passed an undefined name ``norm`` here.
        norm_layer = get_norm_layer(norm_type=self.norm_type)

        # Keep the network on the module; the original bound it to a local
        # variable and discarded it.
        self.netP = UnetGenerator(self.label_nc, self.input_nc, 6, self.ngf, norm_layer=norm_layer, use_dropout=self.use_dropout)

    def forward(self, input):
        """Return the wrapped U-Net's output for ``input``."""
        return self.netP(input)

        
class EncoderBlock(nn.Module):
    """Conv2d -> BatchNorm2d (momentum 0.9) -> in-place ReLU.

    Args:
        channel_in: input channels of the convolution.
        channel_out: output channels of the convolution and the batch norm.
        kernel_size, padding, stride: passed to the convolution.
    """

    def __init__(self, channel_in, channel_out, kernel_size=7, padding=3, stride=4):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=kernel_size, padding=padding, stride=stride)
        self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9)
        self.relu = nn.ReLU(True)

    def forward(self, ten, out=False,t = False):
        """Return the activated output; with ``out`` truthy, return
        ``(activated, pre_norm)`` where ``pre_norm`` is the raw convolution
        output.  ``t`` is unused.
        """
        pre_norm = self.conv(ten)
        activated = self.relu(self.bn(pre_norm))
        if out:
            return activated, pre_norm
        return activated
        
class  EncoderGenerator_mask_skin(nn.Module):
    """Seven stride-2 EncoderBlocks followed by two fully connected heads.

    The heads expect a flattened ``512*2*2`` feature and each return a
    512-dimensional vector; ``forward`` returns them as ``(mu, logvar)``.
    ``norm_layer`` is accepted but unused.
    """

    def __init__(self, norm_layer):
        super( EncoderGenerator_mask_skin, self).__init__()

        # Channel plan; every block halves the spatial size.  For a
        # 3*256*256 input: 64*128*128 -> 128*64*64 -> 256*32*32 ->
        # 512*16*16 -> 512*8*8 -> 512*4*4 -> 512*2*2.
        channels = [3, 64, 128, 256, 512, 512, 512, 512]
        stages = [EncoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.conv = nn.Sequential(*stages)

        def make_head():
            return nn.Sequential(nn.Linear(in_features=512*2*2, out_features=1024),
                                 nn.ReLU(True),
                                 nn.Linear(in_features=1024, out_features=512))

        self.fc_mu = make_head()
        self.fc_var = make_head()

    def forward(self, ten):
        features = self.conv(ten)
        flat = features.view(features.size()[0], -1)
        return self.fc_mu(flat), self.fc_var(flat)

    def __call__(self, *args, **kwargs):
        return super(EncoderGenerator_mask_skin, self).__call__(*args, **kwargs)
    
class  EncoderGenerator_mask_mouth(nn.Module):
    """Four stride-2 EncoderBlocks followed by two fully connected heads.

    The heads expect a flattened ``512*5*9`` feature and each return a
    512-dimensional vector; ``forward`` returns them as ``(mu, logvar)``.
    ``norm_layer`` is accepted but unused.
    """

    def __init__(self, norm_layer):
        super( EncoderGenerator_mask_mouth, self).__init__()

        # Channel plan; every block halves the spatial size.  For a
        # 3*80*144 input: 40*72 -> 20*36 -> 10*18 -> 5*9.
        channels = [3, 64, 128, 256, 512]
        stages = [EncoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.conv = nn.Sequential(*stages)

        def make_head():
            return nn.Sequential(nn.Linear(in_features=512*5*9, out_features=1024),
                                 nn.ReLU(True),
                                 nn.Linear(in_features=1024, out_features=512))

        self.fc_mu = make_head()
        self.fc_var = make_head()

    def forward(self, ten):
        features = self.conv(ten)
        flat = features.view(features.size()[0], -1)
        return self.fc_mu(flat), self.fc_var(flat)

    def __call__(self, *args, **kwargs):
        return super(EncoderGenerator_mask_mouth, self).__call__(*args, **kwargs)

class  EncoderGenerator_mask_eye(nn.Module):
    """Four stride-2 EncoderBlocks followed by two fully connected heads.

    The heads expect a flattened ``512*2*3`` feature and each return a
    512-dimensional vector; ``forward`` returns them as ``(mu, logvar)``.
    ``norm_layer`` is accepted but unused.
    """

    def __init__(self, norm_layer):
        super( EncoderGenerator_mask_eye, self).__init__()

        # Channel plan; every block halves the spatial size.  For a
        # 3*32*48 input: 16*24 -> 8*12 -> 4*6 -> 2*3.
        channels = [3, 64, 128, 256, 512]
        stages = [EncoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.conv = nn.Sequential(*stages)

        def make_head():
            return nn.Sequential(nn.Linear(in_features=512*2*3, out_features=1024),
                                 nn.ReLU(True),
                                 nn.Linear(in_features=1024, out_features=512))

        self.fc_mu = make_head()
        self.fc_var = make_head()

    def forward(self, ten):
        features = self.conv(ten)
        flat = features.view(features.size()[0], -1)
        return self.fc_mu(flat), self.fc_var(flat)

    def __call__(self, *args, **kwargs):
        return super(EncoderGenerator_mask_eye, self).__call__(*args, **kwargs)
    
        
class DecoderBlock(nn.Module):
    """ConvTranspose2d -> BatchNorm2d (momentum 0.9) -> optional in-place ReLU.

    Args:
        channel_in: input channels of the transposed convolution.
        channel_out: output channels of the convolution and the batch norm.
        kernel_size, padding, stride, output_padding: passed to the
            transposed convolution.
        norelu: the ReLU is appended only when this compares equal to False.
    """

    def __init__(self, channel_in, channel_out, kernel_size=4, padding=1, stride=2, output_padding=0, norelu=False):
        super(DecoderBlock, self).__init__()
        upconv = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=kernel_size, padding=padding, stride=stride, output_padding=output_padding)
        stages = [upconv, nn.BatchNorm2d(channel_out, momentum=0.9)]
        # Equality (not truthiness) is deliberate: only values equal to
        # False enable the activation.
        if norelu == False:
            stages.append(nn.ReLU(True))
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        return self.conv(ten)

class DecoderGenerator_mask_skin_image(nn.Module):
    """Decode a 512-dimensional vector into a 3*256*256 tensor.

    A linear layer expands the vector to ``512*2*2``; seven stride-2
    DecoderBlocks, a reflection-padded 5x5 convolution and ``Tanh`` follow.
    ``norm_layer`` is accepted but unused.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_skin_image, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*2*2))

        # Channel plan; every block doubles the spatial size:
        # 512*2*2 -> 512*4*4 -> 512*8*8 -> 512*16*16 -> 512*32*32 ->
        # 256*64*64 -> 128*128*128 -> 64*256*256.
        channels = [512, 512, 512, 512, 512, 256, 128, 64]
        stages = [DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2, output_padding=0)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        # Output head: keeps the spatial size and maps 64 -> 3 channels.
        stages += [nn.ReflectionPad2d(2),
                   nn.Conv2d(64,3,kernel_size=5,padding=0),
                   nn.Tanh()]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        expanded = self.fc(ten)
        grid = expanded.view(expanded.size()[0], 512, 2, 2)
        image = self.conv(grid)
        assert image.size()[1] == 3
        assert image.size()[2] == 256
        assert image.size()[3] == 256
        return image

    def __call__(self, *args, **kwargs):
        return super(DecoderGenerator_mask_skin_image, self).__call__(*args, **kwargs)
    
class DecoderGenerator_mask_mouth_image(nn.Module):
    """Decode a 512-dimensional vector into a 3*80*144 tensor.

    A linear layer expands the vector to ``512*5*9``; four stride-2
    DecoderBlocks, a reflection-padded 5x5 convolution and ``Tanh`` follow.
    ``norm_layer`` is accepted but unused.
    """

    def __init__(self, norm_layer):
        super(DecoderGenerator_mask_mouth_image, self).__init__()
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512*5*9))

        # Channel plan; every block doubles the spatial size:
        # 512*5*9 -> 256*10*18 -> 128*20*36 -> 64*40*72 -> 64*80*144.
        channels = [512, 256, 128, 64, 64]
        stages = [DecoderBlock(channel_in=c_in, channel_out=c_out, kernel_size=4, padding=1, stride=2, output_padding=0)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        # Output head: keeps the spatial size and maps 64 -> 3 channels.
        stages += [nn.ReflectionPad2d(2),
                   nn.Conv2d(64,3,kernel_size=5,padding=0),
                   nn.Tanh()]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        expanded = self.fc(ten)
        grid = expanded.view(expanded.size()[0], 512, 5, 9)
        image = self.conv(grid)
        assert image.size()[1] == 3
        assert image.size()[2] == 80
        assert image.size()[3] == 144
        return image

    def __call__(self, *args, **kwargs):
        return super(DecoderGenerator_mask_mouth_image, self).__call__(*args, **kwargs)


class DecoderGenerator_mask_eye_image(nn.Module):
    """Decoder that turns a 512-d code into a 3-channel 32 x 48 eye image."""

    def __init__(self, norm_layer):
        """Create the projection layer and the up-sampling stack.

        Args:
            norm_layer: accepted so the constructor signature matches the
                other generators; it is not used here.
        """
        super(DecoderGenerator_mask_eye_image, self).__init__()

        # Bias-free projection; forward() reshapes it to 512 x 2 x 3.
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512 * 2 * 3, bias=False))

        # Four stride-2 DecoderBlocks. forward() asserts a 32 x 48 output, so
        # each block is expected to double height and width.
        channel_plan = [(512, 256), (256, 128), (128, 64), (64, 64)]
        stages = [
            DecoderBlock(channel_in=n_in, channel_out=n_out, kernel_size=4,
                         padding=1, stride=2, output_padding=0)
            for n_in, n_out in channel_plan
        ]

        # Output head: reflection-pad by 2 then an unpadded 5x5 conv (spatial
        # size unchanged), followed by Tanh.
        stages.append(nn.ReflectionPad2d(2))
        stages.append(nn.Conv2d(64, 3, kernel_size=5, padding=0))
        stages.append(nn.Tanh())

        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Decode ``ten`` and check the result is 3 x 32 x 48 per sample."""
        projected = self.fc(ten)
        feature_map = projected.view(projected.size()[0], 512, 2, 3)
        image = self.conv(feature_map)

        out_shape = image.size()
        assert out_shape[1] == 3
        assert out_shape[2] == 32
        assert out_shape[3] == 48
        return image

    def __call__(self, *args, **kwargs):
        """Forward every argument to the parent ``__call__``."""
        parent = super(DecoderGenerator_mask_eye_image, self)
        return parent.__call__(*args, **kwargs)


class DecoderGenerator_mask_mouth(nn.Module):
    """Decoder that turns a 512-d code into a 256-channel 20 x 36 mouth feature map."""

    def __init__(self, norm_layer):
        """Create the projection layer and two up-sampling blocks.

        Args:
            norm_layer: accepted so the constructor signature matches the
                other generators; it is not used here.
        """
        super(DecoderGenerator_mask_mouth, self).__init__()

        # Projection; forward() reshapes it to 512 x 5 x 9.
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512 * 5 * 9))

        # Two stride-2 DecoderBlocks. forward() asserts a 20 x 36 output, so
        # each block is expected to double height and width.
        channel_plan = [(512, 256), (256, 256)]
        stages = [
            DecoderBlock(channel_in=n_in, channel_out=n_out, kernel_size=4, padding=1, stride=2)
            for n_in, n_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Decode ``ten`` and check the result is 256 x 20 x 36 per sample."""
        projected = self.fc(ten)
        feature_map = projected.view(projected.size()[0], 512, 5, 9)
        decoded = self.conv(feature_map)

        out_shape = decoded.size()
        assert out_shape[1] == 256
        assert out_shape[2] == 20
        assert out_shape[3] == 36
        return decoded

    def __call__(self, *args, **kwargs):
        """Forward every argument to the parent ``__call__``."""
        parent = super(DecoderGenerator_mask_mouth, self)
        return parent.__call__(*args, **kwargs)


class DecoderGenerator_mask_eye(nn.Module):
    """Decoder that turns a 512-d code into a 256-channel 8 x 12 eye feature map."""

    def __init__(self, norm_layer):
        """Create the projection layer and two up-sampling blocks.

        Args:
            norm_layer: accepted so the constructor signature matches the
                other generators; it is not used here.
        """
        super(DecoderGenerator_mask_eye, self).__init__()

        # Bias-free projection; forward() reshapes it to 512 x 2 x 3.
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512 * 2 * 3, bias=False))

        # Two stride-2 DecoderBlocks. forward() asserts an 8 x 12 output, so
        # each block is expected to double height and width.
        channel_plan = [(512, 256), (256, 256)]
        stages = [
            DecoderBlock(channel_in=n_in, channel_out=n_out, kernel_size=4, padding=1, stride=2)
            for n_in, n_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Decode ``ten`` and check the result is 256 x 8 x 12 per sample."""
        projected = self.fc(ten)
        feature_map = projected.view(projected.size()[0], 512, 2, 3)
        decoded = self.conv(feature_map)

        out_shape = decoded.size()
        assert out_shape[1] == 256
        assert out_shape[2] == 8
        assert out_shape[3] == 12
        return decoded

    def __call__(self, *args, **kwargs):
        """Forward every argument to the parent ``__call__``."""
        parent = super(DecoderGenerator_mask_eye, self)
        return parent.__call__(*args, **kwargs)


class DecoderGenerator_mask_skin(nn.Module):
    """Decoder that turns a 512-d code into a 256-channel feature map of height 64."""

    def __init__(self, norm_layer):
        """Create the projection layer and five up-sampling blocks.

        Args:
            norm_layer: accepted so the constructor signature matches the
                other generators; it is not used here.
        """
        super(DecoderGenerator_mask_skin, self).__init__()

        # Projection; forward() reshapes it to 512 x 2 x 2.
        self.fc = nn.Sequential(nn.Linear(in_features=512, out_features=512 * 2 * 2))

        # Five stride-2 DecoderBlocks. forward() asserts a height of 64, so
        # each block is expected to double the spatial size.
        channel_plan = [(512, 256)] + [(256, 256)] * 4
        stages = [
            DecoderBlock(channel_in=n_in, channel_out=n_out, kernel_size=4, padding=1, stride=2)
            for n_in, n_out in channel_plan
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, ten):
        """Decode ``ten`` and check the result has 256 channels and height 64."""
        projected = self.fc(ten)
        feature_map = projected.view(projected.size()[0], 512, 2, 2)
        decoded = self.conv(feature_map)

        out_shape = decoded.size()
        assert out_shape[1] == 256
        assert out_shape[2] == 64
        return decoded

    def __call__(self, *args, **kwargs):
        """Forward every argument to the parent ``__call__``."""
        parent = super(DecoderGenerator_mask_skin, self)
        return parent.__call__(*args, **kwargs)
    
def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    Intended for use with ``Module.apply``.

    * Class name contains ``'Conv'``: weight drawn from N(0, 0.02).
    * Otherwise, class name contains ``'BatchNorm2d'``: weight drawn from
      N(1, 0.02) and bias set to 0.
    * Any other module is left untouched.
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm2d' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2-D normalisation layer.

    Args:
        norm_type: ``'batch'`` for ``nn.BatchNorm2d`` with ``affine=True`` or
            ``'instance'`` for ``nn.InstanceNorm2d`` with ``affine=False``.

    Returns:
        A ``functools.partial`` wrapping the layer class.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [11]:
class FurryGan(nn.Module):
    def __init__(self, isTrain=True):
        """Build every sub-network and, when training, the losses and optimizers.

        Args:
            isTrain: when truthy, also create the loss criteria, the loss
                filter / loss names and the three Adam optimizers.
        """
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first `self.<net> = ...` raises AttributeError.
        super(FurryGan, self).__init__()

        # Hyperparams
        self.lr = 0.0002
        self.beta1 = 0.5

        self.input_nc = 11             # number of input channels
        self.output_nc = 3             # number of output channels
        self.label_nc = 11             # number of mask channels
        self.isTrain = isTrain         # whether to build the training state
        self.dis_net_input_nc = self.input_nc + self.output_nc
        self.dis_n_layers = 3
        self.num_D = 2
        self.lambda_feat = 10.0

        # Loss switches. Despite their names, the `no_*` flags are used as
        # *enable* switches: they are handed to init_loss_filter's `use_*`
        # parameters and forward() computes the matching term when truthy.
        self.use_gan_feat_loss = True
        self.no_vgg_loss = True
        self.no_l2_loss = True
        self.no_ganFeat_loss = True

        # Optimization parameters
        self.use_lsgan = False

        # Tensor constructor handed to GANLoss; was read below without ever
        # being assigned. The rest of the class assumes CUDA.
        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

        # Shared factory for the norm-layer argument of every encoder/decoder
        # (the image decoders previously referenced an undefined `norm_layer`).
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)

        self.gen_net = GeneratorNetwork(self.input_nc, self.output_nc)
        self.gen_net.apply(weights_init)

        # Was only bound when isTrain was true, which made inference-mode
        # construction fail with NameError; the discriminators are built the
        # same way in both modes.
        use_sigmoid = True

        self.dis_net = DiscriminatorNetwork(self.dis_net_input_nc, self.dis_n_layers, self.num_D, use_sigmoid)
        self.dis_net.apply(weights_init)

        # Second discriminator: forward() runs it on the part-swapped
        # reconstruction, while dis_net judges the plain reconstruction.
        self.dis_net2 = DiscriminatorNetwork(self.dis_net_input_nc, self.dis_n_layers, self.num_D, use_sigmoid)
        self.dis_net2.apply(weights_init)

        self.p_net = PNetwork(self.label_nc, self.output_nc)
        self.p_net.apply(weights_init)

        # Per-part encoders (hair reuses the skin architecture).
        self.encoder_skin_net = EncoderGenerator_mask_skin(norm_layer)
        self.encoder_hair_net = EncoderGenerator_mask_skin(norm_layer)
        self.encoder_left_eye_net = EncoderGenerator_mask_eye(norm_layer)
        self.encoder_right_eye_net = EncoderGenerator_mask_eye(norm_layer)
        self.encoder_mouth_net = EncoderGenerator_mask_mouth(norm_layer)

        # Per-part feature decoders.
        self.decoder_skin_net = DecoderGenerator_mask_skin(norm_layer)
        self.decoder_hair_net = DecoderGenerator_mask_skin(norm_layer)
        self.decoder_left_eye_net = DecoderGenerator_mask_eye(norm_layer)
        self.decoder_right_eye_net = DecoderGenerator_mask_eye(norm_layer)
        self.decoder_mouth_net = DecoderGenerator_mask_mouth(norm_layer)

        # Per-part image decoders.
        self.decoder_skin_image_net = DecoderGenerator_mask_skin_image(norm_layer)
        self.decoder_hair_image_net = DecoderGenerator_mask_skin_image(norm_layer)
        self.decoder_left_eye_image_net = DecoderGenerator_mask_eye_image(norm_layer)
        self.decoder_right_eye_image_net = DecoderGenerator_mask_eye_image(norm_layer)
        self.decoder_mouth_image_net = DecoderGenerator_mask_mouth_image(norm_layer)

        if not self.isTrain:
            return

        self.old_lr = self.lr

        self.loss_filter = self.init_loss_filter(self.no_ganFeat_loss, self.no_vgg_loss, self.no_l2_loss)

        # NOTE(review): GANLoss as defined at the top of this file declares no
        # `tensor` parameter even though its body reads one; its signature has
        # to be reconciled with this call before training can start.
        self.criterionGAN = GANLoss(use_lsgan=self.use_lsgan, tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionL2 = torch.nn.MSELoss()
        self.criterionL1 = torch.nn.L1Loss()
        self.criterionMFM = MFMLoss()

        # One class weight per label channel (label_nc == 11).
        weight_list = [0.2, 1, 5, 5, 5, 5, 3, 8, 8, 8, 1]
        self.criterionCrossEntropy = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor(weight_list))

        if self.no_vgg_loss:
            self.criterionVGG = VGGLoss(weights=None)

        self.criterionGM = GramMatrixLoss()
        self.loss_names = self.loss_filter('KL_embed', 'L2_mask_image', 'G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake', 'L2_image', 'ParsingLoss', 'G2_GAN', 'D2_real', 'D2_fake')

        # Generator-side optimizer: generator + feature decoders + encoders +
        # image decoders, in that order.
        decoder_nets = (self.decoder_skin_net, self.decoder_hair_net, self.decoder_left_eye_net,
                        self.decoder_right_eye_net, self.decoder_mouth_net)
        image_decoder_nets = (self.decoder_skin_image_net, self.decoder_hair_image_net,
                              self.decoder_left_eye_image_net, self.decoder_right_eye_image_net,
                              self.decoder_mouth_image_net)
        encoder_nets = (self.encoder_skin_net, self.encoder_hair_net, self.encoder_left_eye_net,
                        self.encoder_right_eye_net, self.encoder_mouth_net)

        params_decoder = [p for net in decoder_nets for p in net.parameters()]
        # Previously stored as `params_image_decoder_params` but read as
        # `params_image_decoder`, which raised NameError.
        params_image_decoder = [p for net in image_decoder_nets for p in net.parameters()]
        params_encoder = [p for net in encoder_nets for p in net.parameters()]

        params_together = list(self.gen_net.parameters()) + params_decoder + params_encoder + params_image_decoder
        self.optimizer_G_together = torch.optim.Adam(params_together, lr=self.lr, betas=(self.beta1, 0.999))

        # optimizer D
        self.optimizer_D = torch.optim.Adam(list(self.dis_net.parameters()), lr=self.lr, betas=(self.beta1, 0.999))

        # optimizer D2
        self.optimizer_D2 = torch.optim.Adam(list(self.dis_net2.parameters()), lr=self.lr, betas=(self.beta1, 0.999))
                
        
    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss, use_l2_loss):
        """Return a function that keeps only the enabled entries of a 12-item loss tuple.

        The slots for ``g_gan_feat``, ``g_vgg`` and ``l2_image`` are kept only
        when the corresponding argument is truthy; the other nine are always
        kept.

        Args:
            use_gan_feat_loss: keep the ``g_gan_feat`` slot.
            use_vgg_loss: keep the ``g_vgg`` slot.
            use_l2_loss: keep the ``l2_image`` slot.

        Returns:
            ``loss_filter(...)``, which takes the twelve values positionally
            and returns the kept ones as a list in their original order.
        """
        keep = (True, True, True, use_gan_feat_loss, use_vgg_loss, True, True, use_l2_loss, True, True, True, True)

        def loss_filter(kl_loss, l2_mask_image, g_gan, g_gan_feat, g_vgg, d_real, d_fake, l2_image, loss_parsing, g2_gan, d2_real, d2_fake):
            candidates = (kl_loss, l2_mask_image, g_gan, g_gan_feat, g_vgg, d_real,
                          d_fake, l2_image, loss_parsing, g2_gan, d2_real, d2_fake)
            selected = []
            for value, enabled in zip(candidates, keep):
                if enabled:
                    selected.append(value)
            return selected

        return loss_filter
    
    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, image_affine=None, infer=False):             
        size = label_map.size()
        oneHot_size = (size[0], self.label_nc, size[2], size[3])
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
#         if self.opt.data_type == 16:
#             input_label = input_label.half()

        # get edges from instance map
#         if not self.opt.no_instance:
#             inst_map = inst_map.data.cuda()
#             edge_map = self.get_edges(inst_map)
#             input_label = torch.cat((input_label, edge_map), dim=1) 
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # affine real images for training
        if image_affine is not None:
            image_affine = Variable(image_affine.data.cuda())

        return input_label, inst_map, real_image, feat_map, image_affine
    
    def forward(self, bg_image, label, inst, image, feat, image_affine, mask_list, ori_label, infer=False):
        """Run one full pass and return every loss term (plus images when ``infer``).

        Requires CUDA and the training state created by ``__init__``
        (``loss_filter`` and the ``criterion*`` attributes).

        Args:
            bg_image: input image batch; used both as the reconstruction
                target and as the input of the background encoder.
            label: integer label map; compared against the values 0-10 to
                build the per-part masks.
            inst: forwarded to ``encode_input``; the result is not used.
            image: not read.
            feat: forwarded to ``encode_input``; the result is not used.
            image_affine: not read.
            mask_list: per-sample sequence whose entries 0/1, 2/3 and 4/5 are
                the (row, column) centres of the left-eye, right-eye and mouth
                crops in image coordinates.
            ori_label: label map used as cross-entropy target for the parsing
                network (cast to long and squeezed on dim 1).
            infer: when truthy the generated images are returned as well.

        Returns:
            A 9-tuple: the list produced by ``self.loss_filter`` followed by
            the reconstructed image, the five part reconstructions (left eye,
            right eye, skin, hair, mouth), the part-swapped image and the
            parsing label. The eight trailing entries are ``None`` unless
            ``infer`` is truthy.
        """
        batch_size = label.size()[0]

        # bg_image is deliberately passed twice: once as the "real" image and
        # once as the background-encoder input.
        input_label, _, real_image, _, real_bg_image = self.encode_input(label, inst, bg_image, feat, bg_image)

        # ---- per-part masks and image crops -------------------------------
        mask_skin = ((label == 1) + (label == 2) + (label == 3) + (label == 6)).type(torch.cuda.FloatTensor)
        mask_skin_image = mask_skin * real_image

        mask_hair = (label == 10).type(torch.cuda.FloatTensor)
        mask_hair_image = mask_hair * real_image

        mask_mouth_whole = ((label == 7) + (label == 8) + (label == 9)).type(torch.cuda.FloatTensor)

        mask4_image = torch.zeros(batch_size, 3, 32, 48).cuda()
        mask5_image = torch.zeros(batch_size, 3, 32, 48).cuda()
        mask_mouth_image = torch.zeros(batch_size, 3, 80, 144).cuda()
        mask_mouth_crop = torch.zeros(batch_size, 3, 80, 144).cuda()

        for b in range(batch_size):
            le_r, le_c, re_r, re_c, mo_r, mo_c = [int(mask_list[b][k]) for k in range(6)]
            # 32 x 48 eye crops and an 80 x 144 mouth crop centred on mask_list.
            mask4_image[b] = real_image[b, :, le_r - 16:le_r + 16, le_c - 24:le_c + 24]
            mask5_image[b] = real_image[b, :, re_r - 16:re_r + 16, re_c - 24:re_c + 24]
            mask_mouth_image[b] = real_image[b, :, mo_r - 40:mo_r + 40, mo_c - 72:mo_c + 72]
            mask_mouth_crop[b] = mask_mouth_whole[b, :, mo_r - 40:mo_r + 40, mo_c - 72:mo_c + 72]

        mask_mouth_image = mask_mouth_crop * mask_mouth_image

        # ---- label / background features ----------------------------------
        encode_label_feature = self.gen_net.forward(input_label, type="label_encoder")
        bg_feature = self.gen_net.forward(real_bg_image, type="bg_encoder")
        mask_bg = (label == 0).type(torch.cuda.FloatTensor)
        mask_bg_feature = mask_bg * bg_feature

        # ---- per-part VAE branches ----------------------------------------
        # Order matters: each branch draws noise from the global RNG.
        part_specs = (
            (self.encoder_left_eye_net, self.decoder_left_eye_image_net, self.decoder_left_eye_net, mask4_image, None),
            (self.encoder_right_eye_net, self.decoder_right_eye_image_net, self.decoder_right_eye_net, mask5_image, None),
            (self.encoder_skin_net, self.decoder_skin_image_net, self.decoder_skin_net, mask_skin_image, mask_skin),
            (self.encoder_hair_net, self.decoder_hair_image_net, self.decoder_hair_net, mask_hair_image, mask_hair),
            (self.encoder_mouth_net, self.decoder_mouth_image_net, self.decoder_mouth_net, mask_mouth_image, mask_mouth_crop),
        )
        loss_mask_image = 0
        loss_KL = 0
        reconstructions = []
        part_features = []
        for encoder, image_decoder, feature_decoder, part_image, part_mask in part_specs:
            recon, feature, loss_rec, loss_kl = self._encode_part(encoder, image_decoder, feature_decoder, part_image, part_mask)
            loss_mask_image += loss_rec
            loss_KL += loss_kl
            reconstructions.append(recon)
            part_features.append(feature)

        (reconstruce_mask4_image, reconstruce_mask5_image, reconstruce_mask_skin_image,
         reconstruce_mask_hair_image, reconstruce_mask_mouth_image) = reconstructions
        (decode_embed_feature4, decode_embed_feature5, decode_embed_feature_skin,
         decode_embed_feature_hair, decode_embed_feature_mouth) = part_features

        # ---- part-swapped ("transfer") reconstruction ---------------------
        canvas_size = encode_label_feature.size()
        left_eye_tensor = torch.zeros(canvas_size).cuda()
        right_eye_tensor = torch.zeros(canvas_size).cuda()
        mouth_tensor = torch.zeros(canvas_size).cuda()
        reorder_left_eye_tensor = torch.zeros(canvas_size).cuda()
        reorder_right_eye_tensor = torch.zeros(canvas_size).cuda()
        reorder_mouth_tensor = torch.zeros(canvas_size).cuda()

        # Shuffle the part features across the batch so every sample receives
        # parts decoded from another sample.
        new_order = torch.randperm(batch_size)
        reorder_decode_embed_feature4 = decode_embed_feature4[new_order]
        reorder_decode_embed_feature5 = decode_embed_feature5[new_order]
        reorder_decode_embed_feature_mouth = decode_embed_feature_mouth[new_order]
        reorder_decode_embed_feature_skin = decode_embed_feature_skin[new_order]
        reorder_decode_embed_feature_hair = decode_embed_feature_hair[new_order]

        self._paste_part_features(
            mask_list, batch_size,
            (reorder_left_eye_tensor, reorder_right_eye_tensor, reorder_mouth_tensor),
            (reorder_decode_embed_feature4, reorder_decode_embed_feature5, reorder_decode_embed_feature_mouth),
            "wrong0 ! ")

        reconstruct_transfer_face = self.gen_net.forward(
            torch.cat((encode_label_feature, reorder_left_eye_tensor, reorder_right_eye_tensor,
                       reorder_decode_embed_feature_skin, reorder_decode_embed_feature_hair,
                       reorder_mouth_tensor), 1),
            type="image_G")
        reconstruct_transfer_image = self.gen_net.forward(
            torch.cat((reconstruct_transfer_face, mask_bg_feature), 1), type="bg_decoder")

        # Parsing loss on the swapped image. `self.lambda_feat` replaces the
        # former `self.opt.lambda_feat`: this class never sets `self.opt`.
        parsing_label_feature = self.p_net(reconstruct_transfer_image)
        parsing_label = softmax2label(parsing_label_feature)
        gt_label = torch.squeeze(ori_label.type(torch.cuda.LongTensor), 1)
        loss_parsing = self.criterionCrossEntropy(parsing_label_feature, gt_label) * self.lambda_feat

        loss_D2_fake, loss_D2_real, loss_G2_GAN, _, _ = self._adversarial_losses(
            self.dis_net2, input_label, reconstruct_transfer_image, real_image)

        # ---- plain reconstruction -----------------------------------------
        self._paste_part_features(
            mask_list, batch_size,
            (left_eye_tensor, right_eye_tensor, mouth_tensor),
            (decode_embed_feature4, decode_embed_feature5, decode_embed_feature_mouth),
            "wrong ! ")

        reconstruct_face = self.gen_net.forward(
            torch.cat((encode_label_feature, left_eye_tensor, right_eye_tensor,
                       decode_embed_feature_skin, decode_embed_feature_hair, mouth_tensor), 1),
            type="image_G")
        reconstruct_image = self.gen_net.forward(
            torch.cat((reconstruct_face, mask_bg_feature), 1), type="bg_decoder")

        # ---- masked L2 image loss -----------------------------------------
        mask_left_eye = (label == 4).type(torch.cuda.FloatTensor)
        mask_right_eye = (label == 5).type(torch.cuda.FloatTensor)
        mask_mouth = mask_mouth_whole  # full-resolution mouth mask

        weighted_regions = ((mask_left_eye, 10), (mask_right_eye, 10), (mask_skin, 5),
                            (mask_hair, 5), (mask_mouth, 10))
        l2_terms = [self.criterionL2(region * reconstruct_image, region * real_image) * weight
                    for region, weight in weighted_regions]
        l2_terms.append(self.criterionL2(reconstruct_image, real_bg_image) * 10)

        # NOTE(review): the terms are whole-batch losses, yet they are summed
        # once per batch element, which scales the loss by the batch size.
        # That scaling is preserved; the terms are just computed once instead
        # of being recomputed on every iteration.
        loss_L2_image = 0
        for _ in range(batch_size):
            for term in l2_terms:
                loss_L2_image += term

        # ---- adversarial + feature-matching losses ------------------------
        loss_D_fake, loss_D_real, loss_G_GAN, pred_fake, pred_real = self._adversarial_losses(
            self.dis_net, input_label, reconstruct_image, real_image)

        loss_G_GAN_Feat = 0
        if self.no_ganFeat_loss:  # acts as an enable flag despite its name
            feat_weights = 4.0 / (self.dis_n_layers + 1)
            D_weights = 1.0 / self.num_D
            for i in range(self.num_D):
                # The last entry of each scale is the final prediction, not a feature.
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.lambda_feat

        # ---- VGG feature matching loss ------------------------------------
        all_mask_tensor = torch.cat((mask_left_eye, mask_right_eye, mask_skin, mask_hair, mask_mouth), 1)
        mask_weight_list = [10, 10, 5, 5, 10]
        loss_G_VGG = 0
        if self.no_vgg_loss:  # acts as an enable flag despite its name
            loss_G_VGG += self.criterionVGG(reconstruct_image, real_image, all_mask_tensor,
                                            mask_weights=mask_weight_list) * self.lambda_feat * 3

        losses = self.loss_filter(loss_KL, loss_mask_image, loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG,
                                  loss_D_real, loss_D_fake, loss_L2_image, loss_parsing,
                                  loss_G2_GAN, loss_D2_real, loss_D2_fake)
        if not infer:
            return (losses,) + (None,) * 8
        return (losses, reconstruct_image, reconstruce_mask4_image, reconstruce_mask5_image,
                reconstruce_mask_skin_image, reconstruce_mask_hair_image, reconstruce_mask_mouth_image,
                reconstruct_transfer_image, parsing_label)

    def _encode_part(self, encoder, image_decoder, feature_decoder, part_image, mask=None):
        """Run one part through its VAE branch.

        Encodes ``part_image`` to (mean, log-variance), draws a
        reparameterised sample, decodes it to an image and to a feature map.
        Returns ``(reconstruction, feature, reconstruction_loss, kl_loss)``.
        When ``mask`` is given the reconstruction is multiplied by it before
        the loss is taken.
        """
        mus, log_variances = encoder(part_image)
        std = torch.exp(log_variances * 0.5)
        noise = torch.randn(mus.size()).cuda().requires_grad_(True)
        sample = noise * std + mus
        loss_kl = -0.5 * torch.sum(-log_variances.exp() - torch.pow(mus, 2) + log_variances + 1)

        reconstruction = image_decoder(sample)
        if mask is not None:
            reconstruction = mask * reconstruction
        loss_reconstruction = self.criterionL2(reconstruction, part_image.detach()) * 10

        feature = feature_decoder(sample)
        return reconstruction, feature, loss_reconstruction, loss_kl

    @staticmethod
    def _paste_part_features(mask_list, batch_size, canvases, features, failure_message):
        """Add the left-eye, right-eye and mouth feature patches into their canvases.

        Centres come from ``mask_list`` divided by 4 and rounded. If any paste
        for a sample fails, ``failure_message`` is printed and the remaining
        pastes for that sample are skipped.
        """
        # (row index, column index, half height, half width) per part.
        layout = ((0, 1, 4, 6), (2, 3, 4, 6), (4, 5, 10, 18))
        for b in range(batch_size):
            try:
                for canvas, feature, (row_idx, col_idx, half_h, half_w) in zip(canvases, features, layout):
                    row = int(mask_list[b][row_idx] / 4 + 0.5)
                    col = int(mask_list[b][col_idx] / 4 + 0.5)
                    canvas[b, :, row - half_h:row + half_h, col - half_w:col + half_w] += feature[b]
            except Exception:
                # Best effort, e.g. a patch that would fall outside the canvas.
                print(failure_message)

    def _adversarial_losses(self, discriminator, input_label, fake_image, real_image):
        """Compute the GAN losses for one discriminator.

        Returns ``(loss_d_fake, loss_d_real, loss_g, pred_fake, pred_real)``.
        ``loss_d_fake`` uses a detached fake so it does not reach the
        generator; ``loss_g`` scores the attached fake against the real label.
        """
        pred_fake_detached = discriminator.forward(torch.cat((input_label, fake_image.detach()), dim=1))
        loss_d_fake = self.criterionGAN(pred_fake_detached, False)

        pred_real = discriminator.forward(torch.cat((input_label, real_image.detach()), dim=1))
        loss_d_real = self.criterionGAN(pred_real, True)

        pred_fake = discriminator.forward(torch.cat((input_label, fake_image), dim=1))
        loss_g = self.criterionGAN(pred_fake, True)
        return loss_d_fake, loss_d_real, loss_g, pred_fake, pred_real