# LOAD DATA

In [1]:
import os
import numpy as np
from skimage.color import rgb2lab
from skimage import io

import torch
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import re
import pickle

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

class Dataset(data.Dataset):

    def __init__(self, root_dir, pal_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_list = os.listdir(root_dir)
        self.img_list.sort(key=natural_keys)
        self.palette_dir = pal_dir
        self.data = rgb2lab(np.load(self.palette_dir)
                            .reshape(-1,5,3)/255, 
                            illuminant='D50')
    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.img_list[idx])
        image = io.imread(img_name)
        if self.transform:
            image = self.transform(image)

        return (image, self.data[idx])


class LoadImagenet(data.Dataset):

    def __init__(self, image_dir, pal_dir):

        with open(image_dir,'rb') as f:
            self.image_data = np.asarray(pickle.load(f)[:1000]) / 255

        with open(pal_dir,'rb') as f:
            self.pal_data = rgb2lab(np.asarray(pickle.load(f)[:1000])
                                    .reshape(-1,5,3) / 256
                                    ,illuminant='D50')
                                      
        self.data_size = self.image_data.shape[0]

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        return self.image_data[idx], self.pal_data[idx]


def Color_Dataloader(dataset, batch_size, idx=0):

    if dataset == 'bird256':

        traindir = '/kaggle/input/data-text2color/data/bird256/test_palette/test_images_origin.txt'
        pal_traindir = '/kaggle/input/data-text2color/data/bird256/test_palette/test_palette_origin.txt'
        
        train_dataset = LoadImagenet(traindir, pal_traindir)
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=2)

        imsize = 256


    return (train_dataset, train_loader, imsize)


def process_palette_ab(pal_data, batch_size):

    img_a_scale = (pal_data[:, :, 1:2] + 88) / 185
    img_b_scale = (pal_data[:, :, 2:3] + 127) / 212
    img_ab_scale = np.concatenate((img_a_scale,img_b_scale),axis=2)
    ab_for_global = torch.from_numpy(img_ab_scale).float()
    ab_for_global = ab_for_global.view(batch_size, 10).unsqueeze(2).unsqueeze(2)

    return ab_for_global

def process_palette_lab(pal_data, batch_size):

    img_l = pal_data[:, :, 0:1] / 100
    img_a_scale = (pal_data[:, :, 1:2] + 88) / 185
    img_b_scale = (pal_data[:, :, 2:3] + 127) / 212
    img_lab_scale = np.concatenate((img_l, img_a_scale, img_b_scale),axis=2)
    lab_for_global = torch.from_numpy(img_lab_scale).float()
    lab_for_global = lab_for_global.view(batch_size, 15).unsqueeze(2).unsqueeze(2)

    return lab_for_global

def process_data(image_data, batch_size, imsize):
    input = torch.zeros(batch_size, 1, imsize, imsize)
    labels = torch.zeros(batch_size, 2, imsize, imsize)
    images_np = image_data.numpy().transpose((0, 2, 3, 1))

    for k in range(batch_size):

        img_lab = rgb2lab(images_np[k], illuminant='D50')
        img_l = img_lab[:, :,0] / 100
        input[k] = torch.from_numpy(np.expand_dims(img_l, 0))

        img_a_scale = (img_lab[:, :, 1:2] + 88) / 185
        img_b_scale = (img_lab[:, :, 2:3] + 127) / 212

        img_ab_scale = np.concatenate((img_a_scale,img_b_scale),axis=2)
        labels[k] = torch.from_numpy(img_ab_scale.transpose((2, 0, 1)))

    return input, labels

# Global Hint

In [2]:
import torch
import numpy as np
from torch.autograd import Variable

def process_global_ab(input_ab, batch_size, always_give_global_hint):
    X_hist = input_ab

    if always_give_global_hint:
        B_hist = torch.ones(batch_size, 1, 1 ,1)
    else:
        B_hist = torch.round(torch.rand(batch_size, 1, 1 ,1))
        for l in range(batch_size):
            if B_hist[l].numpy() == 0:
                X_hist[l] = torch.rand(10)

    
    global_input = torch.cat([X_hist, B_hist], 1)

    return global_input

def process_global_lab(input_lab, batch_size, always_give_global_hint):
    
    X_hist = input_lab
    if always_give_global_hint:
        B_hist = torch.ones(batch_size, 1, 1 ,1)
    else:
        B_hist = torch.round(torch.rand(batch_size, 1, 1 ,1))
        for l in range(batch_size):
            if B_hist[l].numpy() == 0:
                X_hist[l] = torch.rand(15)

    global_input = torch.cat([X_hist, B_hist], 1)

    return global_input

def process_global_sampling_ab(palette, batch_size, imsize, hist_mean, hist_std):

    X_hist = palette
    B_hist = torch.ones(batch_size, 1, 1, 1)

    X_hist = Variable(X_hist).cuda()
    B_hist = Variable(B_hist).cuda()
    
    global_input = torch.cat([X_hist, B_hist], 1)

    return global_input

def process_global_sampling_lab(palette, batch_size, imsize, hist_mean, hist_std):

    X_hist = palette
    B_hist = torch.ones(batch_size, 1, 1, 1)

    X_hist = Variable(X_hist).cuda()
    B_hist = Variable(B_hist).cuda()
    
    global_input = torch.cat([X_hist, B_hist], 1)

    return global_input

# Utils

In [3]:
import os
import numpy as np
import time
import datetime
import torch
import warnings
from skimage.color import rgb2lab, lab2rgb, rgb2gray

def check_value(inds, val):
    if (np.array(inds).size == 1):
        if (inds == val):
            return True
    return False

def flatten_nd_array(pts_nd, axis=1):

    NDIM = pts_nd.ndim
    SHP = np.array(pts_nd.shape)
    nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis)))
    NPTS = np.prod(SHP[nax])
    axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0)
    pts_flt = pts_nd.transpose((axorder))
    pts_flt = pts_flt.reshape(NPTS, SHP[axis])
    return pts_flt

def unflatten_2d_array(pts_flt, pts_nd, axis=1, squeeze=False):

    NDIM = pts_nd.ndim
    SHP = np.array(pts_nd.shape)
    nax = np.setdiff1d(np.arange(0, NDIM), np.array((axis)))
    NPTS = np.prod(SHP[nax])

    if (squeeze):
        axorder = nax
        axorder_rev = np.argsort(axorder)
        M = pts_flt.shape[1]
        NEW_SHP = SHP[nax].tolist()
        pts_out = pts_flt.reshape(NEW_SHP)
        pts_out = pts_out.transpose(axorder_rev)
    else:
        axorder = np.concatenate((nax, np.array(axis).flatten()), axis=0)
        axorder_rev = np.argsort(axorder)
        M = pts_flt.shape[1]
        NEW_SHP = SHP[nax].tolist()
        NEW_SHP.append(M)
        pts_out = pts_flt.reshape(NEW_SHP)
        pts_out = pts_out.transpose(axorder_rev)

    return pts_out

def na():
    return np.newaxis


class Timer():
    def __init__(self):
        self.cur_t = time.time()

    def tic(self):
        self.cur_t = time.time()

    def toc(self):
        return time.time() - self.cur_t

    def tocStr(self, t=-1):
        if (t == -1):
            return str(datetime.timedelta(seconds=np.round(time.time() - self.cur_t, 3)))[:-4]
        else:
            return str(datetime.timedelta(seconds=np.round(t, 3)))[:-4]

def distribution(tensor):

    tensor = torch.div(tensor, expand(tensor.sum(dim=1).unsqueeze(-1), tensor))
    if (tensor.sum(dim=1).data.cpu().numpy()==0).any():
        print ("")
        print ("")
        print ("division by zero")
        print ("")
        print ("")
    return tensor.unsqueeze(-1)


def make_folder(path, dataset):
    try:
        os.makedirs(os.path.join(path, dataset))
    except OSError:
        pass

def num_param(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

def print_log(idx, num_idx, epoch, mini_batch, num_epochs, num_batches, sL1_loss, tell_time, iter):
    if (mini_batch + 1) % 10 == 0:
        print('Epoch [%d/%d], IDX [%d/%d], Iter [%d/%d], sL1_loss: %.10f, iter_time: %2.2f, aggregate_time: %6.2f'
              % (epoch + 1, num_epochs, idx, num_idx, mini_batch + 1, num_batches, sL1_loss,
                 (tell_time.toc() - iter), tell_time.toc()))
        iter = tell_time.toc()

def resume(resume_, log_path, dataset, G, G_optimizer, D, D_optimizer):
    start_idx=1
    start_epoch=0
    if resume_:
        ckpt_path = os.path.join(log_path, dataset, 'ckpt/model_origin.ckpt')
        if os.path.isfile(ckpt_path):
            print("Loading checkpoint...")
            checkpoint = torch.load(ckpt_path)
            try:
                start_idx = checkpoint['idx']
            except:
                start_idx = 0
            start_epoch = checkpoint['epoch']
            G.load_state_dict(checkpoint['G_state_dict'])
            G_optimizer.load_state_dict(checkpoint['G_optimizer'])
            D.load_state_dict(checkpoint['D_state_dict'])
            D_optimizer.load_state_dict(checkpoint['D_optimizer'])
            print("Start training from epoch {}.".format(checkpoint['epoch']+1))
        else:
            print("Sorry, no checkpoint found.")

    return G, G_optimizer, D, D_optimizer, start_idx, start_epoch

def lab2rgb_1d(in_lab, clip=True):
    warnings.filterwarnings("ignore")
    tmp_rgb = lab2rgb(in_lab[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
    if clip:
        tmp_rgb = np.clip(tmp_rgb, 0, 1)
    return tmp_rgb

# Model

In [4]:
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.optim.lr_scheduler as scheduler
import torch.nn.functional as F
# import torch.nn.sigmoid as F


class UNetConvBlock1_1(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3):
        super(UNetConvBlock1_1, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out

class UNetConvBlock1_2(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock1_2, self).__init__()
        self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)
        self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        out = self.conv3(out)
        return out

class UNetConvBlock1_2_2(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock1_2_2, self).__init__()
        self.conv2 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(x)
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock2(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock2, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)
        self.conv3 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        out = self.conv3(out)
        return out

class UNetConvBlock2_2(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock2_2, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock3(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock3, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)
        self.conv4 = nn.Conv2d(out_size, out_size, 1, stride=2, groups=out_size, bias=False)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        out = self.conv4(out)
        return out

class UNetConvBlock3_2(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock3_2, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock4(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock4, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock5(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock5, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock6(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock6, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=2, dilation=2)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=2, dilation=2)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock7(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock7, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv3 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out = self.activation(self.conv(x))
        out = self.activation(self.conv2(out))
        out = self.activation(self.conv3(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock8(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetConvBlock8, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
        self.bridge = nn.Conv2d(256, 256, kernel_size, padding=1)

        self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x, bridge):
        up = self.up(x)
        out = self.activation(self.bridge(bridge) + up)
        out = self.activation(self.conv(out))
        out = self.activation(self.conv2(out))
        out = self.batchnorm(out)
        return out

class UNetConvBlock9(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetConvBlock9, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
        self.bridge = nn.Conv2d(128, 128, kernel_size, padding=1)
        self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.batchnorm = nn.BatchNorm2d(out_size)

    def forward(self, x, bridge):
        up = self.up(x)
        out = self.activation(self.bridge(bridge) + up)
        out = self.activation(self.conv(out))
        out = self.batchnorm(out)

        return out

class UNetConvBlock10(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetConvBlock10, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, dilation=1)
        self.bridge = nn.Conv2d(64, 128, kernel_size, padding=1)
        self.conv = nn.Conv2d(out_size, out_size, kernel_size, padding=1, dilation=1)
        self.activation = activation
        self.activation2 = nn.LeakyReLU(negative_slope=0.02)

    def forward(self, x, bridge):
        up = self.up(x)
        out = self.activation(self.bridge(bridge) + up)
        out = self.activation2(self.conv(out))
        return out

class prediction(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=1, activation=F.sigmoid, space_dropout=False):
        super(prediction, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, dilation=1)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        return out

class convrelu(nn.Module):

    def __init__(self, in_size, out_size, kernel_size=1, activation=F.relu, space_dropout=False):
        super(convrelu, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=0)
        self.activation = activation

    def forward(self, x):
        out = self.activation(self.conv(x))
        return out

class global_network(nn.Module):
    def __init__(self, image_size, add_L):
        super(global_network, self).__init__()
        if add_L:
            self.oneD = convrelu(16, 128)
        else:
            self.oneD = convrelu(11, 128)
        self.twoD = convrelu(128, 256)
        self.threeD = convrelu(256, 512)
        self.fourD = convrelu(512, 512)
        self.image_size = image_size

    def forward(self, x, dim):
        n = 2
        out = self.oneD(x)
        
        if dim >= 256:
            n = 4
            out = self.twoD(out)
        if dim == 512:
            n = 8
            out = self.threeD(out)
            out = self.fourD(out)

        out = out.repeat(1,1, int(self.image_size/n), int(self.image_size/n))
        return out


class UNet(nn.Module):
    def __init__(self, imsize, multi_injection, add_L):
        super(UNet, self).__init__()
        self.imsize = imsize
        self.multi_injection = multi_injection
        self.globalnet512 = global_network(self.imsize, add_L)

        if multi_injection:
            self.globalnet256 = global_network(self.imsize, add_L)
            self.globalnet128 = global_network(self.imsize, add_L)

        self.convlayer1_1 = UNetConvBlock1_1(1, 64)
        self.convlayer1_2 = UNetConvBlock1_2(64, 64)
        self.convlayer1_2_2 = UNetConvBlock1_2_2(64, 64)
        self.convlayer2 = UNetConvBlock2(64, 128)
        self.convlayer2_2 = UNetConvBlock2_2(64, 128)
        self.convlayer3 = UNetConvBlock3(128, 256)
        self.convlayer3_2 = UNetConvBlock3_2(128, 256)
        self.convlayer4 = UNetConvBlock4(256, 512)
        self.convlayer5 = UNetConvBlock5(512, 512)
        self.convlayer6 = UNetConvBlock6(512, 512)
        self.convlayer7 = UNetConvBlock7(512, 512)
        self.convlayer8 = UNetConvBlock8(512, 256)
        self.convlayer9 = UNetConvBlock9(256, 128)
        self.convlayer10 = UNetConvBlock10(128, 128)

        self.prediction = prediction(128, 2)

    def forward(self, x, side_input):
        layer1_1 = self.convlayer1_1(x)
        layer1_2 = self.convlayer1_2(layer1_1)
        layer1_2_2 = self.convlayer1_2_2(layer1_1)
        layer2 = self.convlayer2(layer1_2)
        layer2_2 = self.convlayer2_2(layer1_2)
        layer3 = self.convlayer3(layer2)
        layer3_2 = self.convlayer3_2(layer2)
        layer4 = self.convlayer4(layer3)

        global_net512 = self.globalnet512(side_input, 512)
        layer4 = layer4 + global_net512
        layer5 = self.convlayer5(layer4)
        layer6 = self.convlayer6(layer5)
        layer7 = self.convlayer7(layer6)

        layer8 = self.convlayer8(layer7, layer3_2)
        if self.multi_injection:
            global_net256 = self.globalnet256(side_input, 256)
            layer8 = layer8 + global_net256
        
        layer9 = self.convlayer9(layer8, layer2_2)
        if self.multi_injection:
            global_net128 = self.globalnet128(side_input, 128)
            layer9 = layer9 + global_net128
        
        layer10 = self.convlayer10(layer9, layer1_2_2)

        prediction = self.prediction(layer10)

        return prediction


class Discriminator(nn.Module):

    def __init__(self, add_L, imsize, conv_dim=64, repeat_num=5):
        super(Discriminator, self).__init__()

        input_dim = 2 + 10
        if add_L:
            input_dim = 3 + 15

        layers = []
        layers.append(nn.Conv2d(input_dim, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        k_size = int(imsize / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, curr_dim, kernel_size=3, stride=1, padding=1, bias=False)

        self.fc = nn.Sequential(
            nn.BatchNorm1d(k_size*k_size*curr_dim),
            nn.Linear(k_size*k_size*curr_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch_size = x.size(0)
        h = self.main(x)
        out = self.conv1(h)
        out = out.view(batch_size, -1)
        out = self.fc(out)
        return out


def init_models(batch_size, imsize, dropout_ep, learning_rate, multi_injection, 
                add_L, weight_decay=1e-7):

    G = UNet(imsize, multi_injection, add_L).cuda()
    print('# parameters of Generator : ',num_param(G))
    D = Discriminator(add_L, imsize).cuda()
    print('# parameters of Discriminator : ',num_param(D))
    G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, weight_decay=weight_decay)
    D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, weight_decay=weight_decay)

    G_scheduler = scheduler.ReduceLROnPlateau(G_optimizer, 'min', patience=5, factor = 0.1)
    D_scheduler = scheduler.ReduceLROnPlateau(D_optimizer, 'min', patience=5, factor = 0.1)

    return (G, D, G_optimizer, D_optimizer, G_scheduler, D_scheduler)

# GAN

In [5]:
import torch
from torch.autograd import Variable

def train(gm, images, pals, G, D, G_optimizer, D_optimizer, criterion_bce, 
        criterion_sL1, always_give_global_hint, add_L, gan_loss=0.1, isTrain=True):

    D_loss = Variable(torch.zeros(1)).data
    gm.image_process(images, pals, add_L, always_give_global_hint)

    gm.init(G, D)
    gm.g_forward()
    gm.d_forward(True)
    D_loss = gm.d_backward(D_optimizer, criterion_bce, isTrain, gan_loss)

    gm.init(G, D)
    gm.g_forward()
    gm.d_forward(False)
    sL1_loss, G_loss = gm.g_backward(G_optimizer, D_optimizer,
                    criterion_bce, criterion_sL1, isTrain, gan_loss)

    output, ground_truth = gm.getImage()

    loss = sL1_loss + G_loss + D_loss

    return output, ground_truth, loss, sL1_loss


class GanModel(nn.Module):

    def init(self, unet, discriminator):

        self.fake_image = None
        self.true = None
        self.false = None

        self.G = unet
        self.D = discriminator

    def image_process(self, images, pals, always_give_global_hint, add_L):
        
        batch = images.size(0)
        imsize = images.size(3)

        inputs, labels = process_data(images, batch, imsize)
        if add_L:
            for_global = process_palette_lab(pals, batch)
            global_hint = process_global_lab(for_global, batch, always_give_global_hint)
        else:
            for_global = process_palette_ab(pals, batch)
            global_hint = process_global_ab(for_global, batch, always_give_global_hint)

        self.L_image = Variable(inputs).cuda()
        self.real_image = Variable(labels).cuda()
        self.global_hint = Variable(global_hint).cuda()

    def g_forward(self):

        batch = self.global_hint.size(0)
        self.fake_image = self.G(self.L_image, self.global_hint)

    def d_forward(self, isD):
        true = None
        imsize = self.real_image.size(3)
        global_hint = (self.global_hint).expand(-1,-1,imsize,imsize)
        if isD:
            true = self.D(torch.cat((self.real_image, global_hint), dim=1))
        false = self.D(torch.cat((self.fake_image, global_hint), dim=1))

        self.true = true
        self.false = false

    def d_backward(self, D_optimizer, criterion_bce, isTrain, gan_loss):

        batch_size =  self.global_hint.size(0)
        y_ones, y_zeros = (Variable(torch.ones(batch_size, 1), requires_grad=False).cuda(),
                            Variable(torch.zeros(batch_size, 1), requires_grad=False).cuda())

        real_loss = criterion_bce(self.true, y_ones)
        fake_loss = criterion_bce(self.false, y_zeros)
        D_loss = real_loss + fake_loss
        loss = gan_loss * D_loss

        if isTrain:
            D_optimizer.zero_grad()
            loss.backward()
            D_optimizer.step()

#         return loss.data[0]
        return loss.item()

    def g_backward(self, G_optimizer, D_optimizer,
                    criterion_bce, criterion_sL1, isTrain, gan_loss):

        batch_size = self.L_image.size(0)
        G_loss = Variable(torch.zeros(1)).cuda()
        y_ones = Variable(torch.ones(batch_size, 1), requires_grad=False).cuda()
        G_loss = gan_loss * criterion_bce(self.false, y_ones)

        outputs = self.fake_image.view(batch_size, -1)
        labels = self.real_image.contiguous().view(batch_size, -1)

        sL1_loss = criterion_sL1(outputs, labels)
        loss = sL1_loss + G_loss

        if isTrain:
            G_optimizer.zero_grad()
            loss.backward()
            G_optimizer.step()

#         return sL1_loss.data[0], G_loss.data[0]
        return sL1_loss.item(), G_loss.item()
    def getImage(self):
        if self.real_image is not None:
            return self.fake_image.clone(), self.real_image.clone()
        return self.fake_image.clone()


# Train and Test

In [6]:
from __future__ import division
import os
import torch
import argparse
import torch.nn as nn
from torch import cuda
from torch.autograd import Variable
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd
import numpy as np

import sys
sys.argv=['']
del sys

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='bird256', choices=['imagenet','bird256'])
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--model_path', type=str, default='./pal2color/models/')
    parser.add_argument('--log_path', type=str, default='./pal2color/logs')
    parser.add_argument('--image_save', type=str, default='./pal2color/images')
    parser.add_argument('--learning_rate', type=int, default=0.0002)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--start_epoch', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--dropout_p', type=int, default=0.2)
    parser.add_argument('--resume', type=bool, default=False,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--gan_loss', type=float, default=0.1)

    parser.add_argument('--always_give_global_hint', type=int, default=1)
    parser.add_argument('--multi_injection', type=int, default=1)
    parser.add_argument('--add_L', type=int, default=1)
    return parser.parse_args()


def main(args):
    dataset = args.data
    gpu = args.gpu
    batch_size = args.batch_size
    dropout_p = args.dropout_p
    model_path = args.model_path
    log_path = args.log_path
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    start_epoch = args.start_epoch
    gan_loss = args.gan_loss
    always_give_global_hint = args.always_give_global_hint
    multi_injection = args.multi_injection
    add_L = args.add_L
    g_loss = []
    d_loss = []
    epochs = []
    
    print("Running on gpu : ", gpu)
    cuda.set_device(gpu)

    make_folder(model_path, dataset)
    make_folder(log_path, dataset +'/ckpt')

    (train_dataset, train_loader, imsize) = Color_Dataloader(dataset, batch_size, 0)
    (G, D, G_optimizer, D_optimizer, G_scheduler, D_scheduler) = init_models(batch_size, imsize, dropout_p, learning_rate, multi_injection, add_L)
        
    criterion_sL1 = nn.SmoothL1Loss().cuda()
    criterion_bce = nn.BCELoss().cuda()

    (G, G_optimizer, D, D_optimizer, _, start_epoch) = resume(args.resume, log_path, dataset, G, G_optimizer, D, D_optimizer)
    tell_time = Timer()
    iter = 0
    gm = GanModel()
    for epoch in range(start_epoch, num_epochs):

        G.train()
        for i, (images, pals) in enumerate(train_loader):

            (_, _, loss, sL1_loss) = train(gm, images, pals, G, D, G_optimizer, D_optimizer,
                                            criterion_bce, criterion_sL1, always_give_global_hint, 
                                            add_L, gan_loss, True)

            num_batches = (len(train_dataset) // batch_size)
            print_log(0, 0, epoch, i, num_epochs, num_batches, sL1_loss, tell_time, iter)
            g_loss.append(loss)
            d_loss.append(sL1_loss)
            epochs.append(epoch)

        checkpoint = {
            'epoch': epoch + 1,
            'args': args,
            'G_state_dict': G.state_dict(),
            'G_optimizer': G_optimizer.state_dict(),
            'D_state_dict': D.state_dict(),
            'D_optimizer': D_optimizer.state_dict()
        }

        torch.save(checkpoint, os.path.join(log_path, dataset, 'ckpt/model.ckpt'))
        msg = "epoch: %d" % (epoch)
        if (epoch + 1) % 10 == 0:
            print ('Saved model')
            torch.save(G.state_dict(), os.path.join(
                model_path, dataset, '%s_cGAN-unet_bird256.pkl' % (msg)))
            
#     fig, ax = plt.subplots()
#     plt.plot(d_loss, label='Discriminator', alpha=0.5)
#     plt.plot(g_loss, label='Generator', alpha=0.5)
#     plt.title("Training Losses")
#     plt.legend()
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=epochs, y=d_loss, mode='lines', name='Discriminator Loss', marker_color='orange'))
    fig.add_trace(go.Scatter(x=epochs, y=g_loss, mode='lines', name='Generator Loss', marker_color='seagreen'))
    fig.show()

if __name__ == '__main__':
    args = parse_args()
    main(args)

Running on gpu :  0
# parameters of Generator :  37072706
# parameters of Discriminator :  20795329




Epoch [1/100], IDX [0/0], Iter [10/125], sL1_loss: 0.0032000672, iter_time: 33.40, aggregate_time:  33.40
Epoch [1/100], IDX [0/0], Iter [20/125], sL1_loss: 0.0030785191, iter_time: 55.09, aggregate_time:  55.09
Epoch [1/100], IDX [0/0], Iter [30/125], sL1_loss: 0.0015163614, iter_time: 77.06, aggregate_time:  77.06
Epoch [1/100], IDX [0/0], Iter [40/125], sL1_loss: 0.0012527242, iter_time: 99.39, aggregate_time:  99.39
Epoch [1/100], IDX [0/0], Iter [50/125], sL1_loss: 0.0018031462, iter_time: 122.29, aggregate_time: 122.29
Epoch [1/100], IDX [0/0], Iter [60/125], sL1_loss: 0.0038612261, iter_time: 145.80, aggregate_time: 145.80
Epoch [1/100], IDX [0/0], Iter [70/125], sL1_loss: 0.0033345211, iter_time: 170.29, aggregate_time: 170.29
Epoch [1/100], IDX [0/0], Iter [80/125], sL1_loss: 0.0031713755, iter_time: 194.85, aggregate_time: 194.85
Epoch [1/100], IDX [0/0], Iter [90/125], sL1_loss: 0.0015495808, iter_time: 218.87, aggregate_time: 218.87
Epoch [1/100], IDX [0/0], Iter [100/125],