# In [1]:
from skimage import color
from skimage.transform import resize
from skimage.io import imread
import numpy as np
import os
import sklearn.neighbors as nn
import warnings
import configparser

# In [4]:
class NNEncode():
    ''' Encode points using NN search and Gaussian kernel

    Each input point is described by Gaussian-weighted memberships of its
    NN nearest cluster centres (rows of self.cc); all other entries are zero.

    Attributes set by __init__:
    cc          : cluster centres, one centre per row
    K           : number of centres (cc.shape[0])
    NN          : number of neighbours used per point (int)
    sigma       : standard deviation of the Gaussian weighting kernel
    nbrs        : fitted sklearn NearestNeighbors index over cc
    alreadyUsed : True once an encoding buffer has been allocated
    '''
    def __init__(self,NN,sigma,km_filepath='',cc=-1):
        '''
        Args:
        NN          : number of nearest centres per point (coerced with int())
        sigma       : Gaussian kernel standard deviation
        km_filepath : .npy file holding the centres; loaded when
                      check_value(cc,-1) is truthy
        cc          : centres array used directly otherwise
        '''
        # NOTE(review): check_value is defined outside this block; presumably
        # it tests whether cc still holds the -1 sentinel -- confirm.
        if(check_value(cc,-1)):
            self.cc = np.load(km_filepath)
        else:
            self.cc = cc
        self.K = self.cc.shape[0]
        self.NN = int(NN)
        self.sigma = sigma
        # Pass the int-coerced count: callers hand in floats such as 10.0 and
        # sklearn expects an integer n_neighbors.
        self.nbrs = nn.NearestNeighbors(n_neighbors=self.NN, algorithm='ball_tree').fit(self.cc)

        self.alreadyUsed = False

    def encode_points_mtx_nd(self,pts_nd,axis=1,returnSparse=False,sameBlock=True):
        '''Soft-encode pts_nd over the K centres.

        Args:
        pts_nd       : array of points; `axis` is the coordinate axis handed
                       to flatten_nd_array
        axis         : axis of pts_nd that holds the point coordinates
        returnSparse : accepted for interface compatibility; not used here
        sameBlock    : reuse the buffer allocated by a previous call when it
                       has the right number of rows
        Returns:
        pts_enc_nd   : result of unflatten_2d_array on the (P, K) encoding
        '''
        pts_flt = flatten_nd_array(pts_nd,axis=axis)
        P = pts_flt.shape[0]

        # The cached buffer and row indices are only valid for the same number
        # of points; a different P needs a fresh allocation, otherwise the
        # indexed assignment below would not match the weights' shape.
        reuse = (sameBlock and self.alreadyUsed
                 and self.pts_enc_flt.shape[0] == P)
        if(reuse):
            self.pts_enc_flt[...] = 0 # already pre-allocated
        else:
            self.alreadyUsed = True
            self.pts_enc_flt = np.zeros((P,self.K))
            self.p_inds = np.arange(0,P,dtype='int')[:,na()]

        (dists,inds) = self.nbrs.kneighbors(pts_flt)

        # Gaussian weights on the neighbour distances, normalised per point
        wts = np.exp(-dists**2/(2*self.sigma**2))
        wts = wts/np.sum(wts,axis=1)[:,na()]

        # scatter each point's weights into the columns of its neighbours
        self.pts_enc_flt[self.p_inds,inds] = wts
        # NOTE(review): depending on unflatten_2d_array the result may share
        # memory with self.pts_enc_flt, which the next call overwrites -- confirm.
        pts_enc_nd = unflatten_2d_array(self.pts_enc_flt,pts_nd,axis=axis)

        return pts_enc_nd

    def decode_points_mtx_nd(self,pts_enc_nd,axis=1):
        '''Map encodings back to point space as a weighted sum of the centres.

        Args:
        pts_enc_nd : encoded array; `axis` is the axis over the K centres
        axis       : axis handed to flatten_nd_array / unflatten_2d_array
        Returns:
        pts_dec_nd : np.dot(flattened encoding, self.cc), unflattened again
        '''
        pts_enc_flt = flatten_nd_array(pts_enc_nd,axis=axis)
        pts_dec_flt = np.dot(pts_enc_flt,self.cc)
        pts_dec_nd = unflatten_2d_array(pts_dec_flt,pts_enc_nd,axis=axis)
        return pts_dec_nd

    def decode_1hot_mtx_nd(self,pts_enc_nd,axis=1,returnEncode=False):
        '''Decode after passing the encoding through nd_argmax_1hot.

        Args:
        pts_enc_nd   : encoded array
        axis         : axis handed to nd_argmax_1hot and the decoder
        returnEncode : when true, also return the nd_argmax_1hot output
        Returns:
        pts_dec_nd, or (pts_dec_nd, pts_1hot_nd) when returnEncode is true
        '''
        pts_1hot_nd = nd_argmax_1hot(pts_enc_nd,axis=axis)
        pts_dec_nd = self.decode_points_mtx_nd(pts_1hot_nd,axis=axis)
        if(returnEncode):
            return (pts_dec_nd,pts_1hot_nd)
        else:
            return pts_dec_nd

def _nnencode(data_ab_ss):
    '''Encode to 313bin

    Soft-encodes every ab pixel over the cluster centres stored in
    ./resources/pts_in_hull.npy using NNEncode (10 neighbours, sigma 5).

    Args:
    data_ab_ss: [N, H, W, 2]
    Returns:
    gt_ab_313 : [N, H, W, 313]  (the last axis has one entry per row of
                pts_in_hull.npy)
    '''
    # Neighbour count is an integer quantity; sklearn's NearestNeighbors
    # expects an int for n_neighbors, so do not pass 10.0.
    NN = 10
    sigma = 5.0
    enc_dir = './resources/'

    # NNEncode is called with axis=1 as the coordinate axis: go channels-first
    data_ab_ss = np.transpose(data_ab_ss, (0, 3, 1, 2))
    nnenc = NNEncode(NN, sigma, km_filepath=os.path.join(enc_dir, 'pts_in_hull.npy'))
    gt_ab_313 = nnenc.encode_points_mtx_nd(data_ab_ss, axis=1)

    # back to channels-last for the caller
    gt_ab_313 = np.transpose(gt_ab_313, (0, 2, 3, 1))
    return gt_ab_313

# In [5]:
def preprocess(data):
    '''Preprocess

    Converts an RGB batch to Lab and derives the tensors returned below.

    Args: 
    data: RGB batch (N * H * W * 3)
    Return:
    data_l: L channel batch (N * H * W * 1), shifted by -50
    gt_ab_313: ab discrete channel batch (N * H/4 * W/4 * 313)
    prior_boost_nongray: output of _prior_boost multiplied by the per-image
        non-gray mask.
        NOTE(review): the original docstring gave (N * H/4 * W/4 * 1) while
        the inline comments gave [N, 1, H/4, W/4]; _prior_boost is defined
        elsewhere, so confirm the actual layout there.
    '''
    # NOTE: this installs a process-wide "ignore" filter for all warnings and
    # it stays active after the function returns.
    warnings.filterwarnings("ignore")

    #rgb2lab
    img_lab = color.rgb2lab(data)

    #l: [0, 100] -> scaled to [-50, 50]
    data_l = img_lab[:, :, :, 0:1] - 50

    #ab: [-110, 110]
    data_ab = img_lab[:, :, :, 1:]

    #subsample: keep every 4th row and column  (N * H/4 * W/4 * 2)
    data_ab_ss = data_ab[:, ::4, ::4, :]

    #NonGrayMask {N, 1, 1, 1}: True for an image when any subsampled a/b
    #value exceeds the threshold in magnitude
    thresh = 5
    nongray_mask = np.any(np.abs(data_ab_ss) > thresh, axis=(1, 2, 3), keepdims=True)

    #NNEncoder
    #gt_ab_313: [N, H/4, W/4, 313]
    gt_ab_313 = _nnencode(data_ab_ss)

    #Prior_Boost 
    prior_boost = _prior_boost(gt_ab_313)

    #Eltwise: the boost is zeroed for images the mask marks as gray
    prior_boost_nongray = prior_boost * nongray_mask

    return data_l, gt_ab_313, prior_boost_nongray

def softmax(x):
    """Return the softmax of x taken along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    scores do not overflow; the result has the same shape as x.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def decode(data_l, conv8_313, rebalance=1):
    """Turn network output plus the lightness channel into an RGB image.

    Only the first element of each batch is used.

    Args:
    data_l   : [1, height, width, 1], lightness with 50 subtracted
    conv8_313: [1, height/4, width/4, 313], scores per ab bin
    rebalance: factor the scores are multiplied by before the softmax
    Returns:
    img_rgb  : [height, width, 3]
    """
    # undo the -50 shift applied to the L channel
    shifted_l = data_l + 50
    _, height, width, _ = shifted_l.shape
    lightness = shifted_l[0, :, :, :]

    # per-pixel distribution over the ab bins
    scores = conv8_313[0, :, :, :] * rebalance
    bin_probs = softmax(scores)

    # expected ab value: probability-weighted sum of the bin centres
    hull_path = os.path.join('./resources', 'pts_in_hull.npy')
    bin_centres = np.load(hull_path)
    ab_small = np.dot(bin_probs, bin_centres)

    # bring ab up to the resolution of L, then stack into Lab
    ab_full = resize(ab_small, (height, width))
    lab = np.concatenate((lightness, ab_full), axis=-1)

    return color.lab2rgb(lab)

def get_data_l(image_path):
    """Read an image and extract its shifted lightness channel.

    Args:
    image_path : path passed to imread
    Returns:
    data   : the image with a leading batch axis, [1, H, W, 3]
    data_l : float32 L channel minus 50, [1, H, W, 1]
    """
    data = imread(image_path)

    # rgb2lab needs three colour channels. Grayscale files come back 2-D
    # (and would fail the 4-D indexing below); RGBA files carry an extra
    # alpha channel.
    if data.ndim == 2:
        data = color.gray2rgb(data)
    elif data.ndim == 3 and data.shape[-1] == 4:
        data = data[:, :, :3]  # drop alpha

    data = data[None, :, :, :]
    img_lab = color.rgb2lab(data)
    img_l = img_lab[:, :, :, 0:1]
    data_l = img_l - 50
    data_l = data_l.astype(dtype=np.float32)
    return data, data_l