In [1]:
from __future__ import print_function
import argparse
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torchvision.utils import save_image
from net import MaskGenerator, ResiduePredictor
from mydataset import MyDataset
import cv2
import os
import time

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device("cpu")
device

device(type='cuda', index=0)

In [2]:
run_name = 'sample'
num_primary_color = 7 
csv_path = 'sample.csv' 
resize_scale_factor = 1 

#img_name = 'image_test_walid_2.jpeg'; manual_color_0 = [98, 12, 15]; manual_color_1 = [138, 206, 225]; manual_color_2 = [226, 179, 159]; manual_color_3 = [69, 173, 198]; manual_color_4 = [213, 215, 221]; manual_color_5 = [85,26,20]; manual_color_6 = [160,217,214]; 

#img_name = 'palette1.jpeg'; manual_color_0 = [98, 12, 15]; manual_color_1 = [138, 206, 225]; manual_color_2 = [226, 179, 159]; manual_color_3 = [69, 173, 198]; manual_color_4 = [213, 215, 221]; manual_color_5 = [85,26,20]; manual_color_6 = [160,217,214]; 

#img_name = 'palette2.jpeg'; manual_color_0 = [98, 12, 15]; manual_color_1 = [138, 206, 225]; manual_color_2 = [226, 179, 159]; manual_color_3 = [69, 173, 198]; manual_color_4 = [213, 215, 221]; manual_color_5 = [85,26,20]; manual_color_6 = [160,217,214]; 

img_name = 'palette3.jpeg'; manual_color_0 = [98, 12, 15]; manual_color_1 = [138, 206, 225]; manual_color_2 = [226, 179, 159]; manual_color_3 = [69, 173, 198]; manual_color_4 = [213, 215, 221]; manual_color_5 = [85,26,20]; manual_color_6 = [160,217,214]; 


img_path = '../dataset/test/' + img_name

path_mask_generator = 'results/' + run_name + '/mask_generator.pth'
path_residue_predictor = 'results/' + run_name + '/residue_predictor.pth'

if num_primary_color == 7:
    manual_colors = np.array([manual_color_0, manual_color_1, manual_color_2, manual_color_3,\
                                               manual_color_4, manual_color_5, manual_color_6]) /255
elif num_primary_color == 6:
    manual_colors = np.array([manual_color_0, manual_color_1, manual_color_2, manual_color_3,\
                                               manual_color_4, manual_color_5]) /255
elif num_primary_color == 4:
    manual_colors = np.array([manual_color_0, manual_color_1, manual_color_2, manual_color_3])/255

In [3]:
try:
    os.makedirs('results/%s/%s' % (run_name, img_name))
except OSError:
    pass

In [4]:
test_dataset = MyDataset(csv_path, num_primary_color, mode='test')
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    )

if torch.cuda.is_available():
    map_location=lambda storage, loc: storage.cuda()
else:
    map_location='cpu'
    
mask_generator = MaskGenerator(num_primary_color).to(device)
residue_predictor = ResiduePredictor(num_primary_color).to(device)

mask_generator.load_state_dict(torch.load(path_mask_generator, map_location=map_location))
residue_predictor.load_state_dict(torch.load(path_residue_predictor, map_location=map_location))

mask_generator.eval()

MaskGenerator(
  (conv1): Conv2d(24, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (deconv1): ConvTranspose2d(192, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
  (deconv2): ConvTranspose2d(192, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
  (deconv3): ConvTranspose2d(96, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
  (conv4): Conv2d(51, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(24, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchN

In [5]:
def replace_color(primary_color_layers, manual_colors):
    temp_primary_color_layers = primary_color_layers.clone()
    for layer in range(len(manual_colors)):
        for color in range(3):
                temp_primary_color_layers[:,layer,color,:,:].fill_(manual_colors[layer][color])
    return temp_primary_color_layers


def cut_edge(target_img):
    target_img = F.interpolate(target_img, scale_factor=resize_scale_factor, mode='area')
    h = target_img.size(2)
    w = target_img.size(3)
    h = h - (h % 8)
    w = w - (w % 8)
    target_img = target_img[:,:,:h,:w]
    return target_img

def alpha_normalize(alpha_layers):
    return alpha_layers / (alpha_layers.sum(dim=1, keepdim=True) + 1e-8)

def read_backimage():
    img = cv2.imread('../dataset/backimage.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.transpose((2,0,1))
    img = img/255
    img = torch.from_numpy(img.astype(np.float32))

    return img.view(1,3,256,256).to(device)

backimage = read_backimage()

In [6]:
from guided_filter_pytorch.guided_filter import GuidedFilter
def proc_guidedfilter(alpha_layers, guide_img):
    guide_img = (guide_img[:, 0, :, :]*0.299 + guide_img[:, 1, :, :]*0.587 + guide_img[:, 2, :, :]*0.114).unsqueeze(1)
    for i in range(alpha_layers.size(1)):
        layer = alpha_layers[:, i, :, :, :]
        processed_layer = GuidedFilter(3, 1*1e-6)(guide_img, layer)
        if i == 0: 
            processed_alpha_layers = processed_layer.unsqueeze(1)
        else:
            processed_alpha_layers = torch.cat((processed_alpha_layers, processed_layer.unsqueeze(1)), dim=1)
    
    return processed_alpha_layers

In [7]:
target_layer_number = [0, 1]
mask_path = 'path/to/mask.image'

def load_mask(mask_path):
    mask = cv2.imread(mask_path, 0) 
    mask[mask<128] = 0.
    mask[mask >= 128] = 1.
    mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).float().cuda()
    return mask

def mask_operate(alpha_layers, target_layer_number, mask_path):
    layer_A = alpha_layers[:, target_layer_number[0], :, :, :]
    layer_B = alpha_layers[:, target_layer_number[1], :, :, :]
    
    layer_AB = layer_A + layer_B
    mask = load_mask(mask_path)
    mask = cut_edge(mask)
    
    layer_A = layer_AB * mask
    layer_B = layer_AB * (1. - mask)
    
    return_alpha_layers = alpha_layers.clone()
    return_alpha_layers[:, target_layer_number[0], :, :, :] = layer_A
    return_alpha_layers[:, target_layer_number[1], :, :, :] = layer_B
    
    return return_alpha_layers
    

In [8]:
test_dataset.imgs_path[0] = img_path

In [9]:
img_number = 0
mean_estimation_time = 0
with torch.no_grad():
    for batch_idx, (target_img, primary_color_layers) in enumerate(test_loader):
        if batch_idx != img_number:
            print('Skip ', batch_idx)
            continue
        print('img #', batch_idx)
        target_img = cut_edge(target_img)
        target_img = target_img.to(device) 
        primary_color_layers = primary_color_layers.to(device)
        primary_color_layers = replace_color(primary_color_layers, manual_colors)
        start_time = time.time()
        primary_color_pack = primary_color_layers.view(primary_color_layers.size(0), -1 , primary_color_layers.size(3), primary_color_layers.size(4))
        primary_color_pack = cut_edge(primary_color_pack)
        primary_color_layers = primary_color_pack.view(primary_color_pack.size(0),-1,3,primary_color_pack.size(2), primary_color_pack.size(3))
        pred_alpha_layers_pack = mask_generator(target_img, primary_color_pack)
        pred_alpha_layers = pred_alpha_layers_pack.view(target_img.size(0), -1, 1, target_img.size(2), target_img.size(3))
        
        processed_alpha_layers = alpha_normalize(pred_alpha_layers) 
        processed_alpha_layers = proc_guidedfilter(processed_alpha_layers, target_img) 
        processed_alpha_layers = alpha_normalize(processed_alpha_layers)  
        
        mono_color_layers = torch.cat((primary_color_layers, processed_alpha_layers), 2) #shape: bn, ln, 4, h, w
        mono_color_layers_pack = mono_color_layers.view(target_img.size(0), -1 , target_img.size(2), target_img.size(3))
        residue_pack  = residue_predictor(target_img, mono_color_layers_pack)
        residue = residue_pack.view(target_img.size(0), -1, 3, target_img.size(2), target_img.size(3))
        pred_unmixed_rgb_layers = torch.clamp((primary_color_layers + residue), min=0., max=1.0)
        reconst_img = (pred_unmixed_rgb_layers * processed_alpha_layers).sum(dim=1)
        end_time = time.time()
        estimation_time = end_time - start_time
        print(estimation_time)
        mean_estimation_time += estimation_time
        
        if True:
            save_layer_number = 0
            save_image(primary_color_layers[save_layer_number,:,:,:,:],
                   'results/%s/%s/test' % (run_name, img_name) + '_img-%02d_primary_color_layers.png' % batch_idx)
            save_image(reconst_img[save_layer_number,:,:,:].unsqueeze(0),
                   'results/%s/%s/test' % (run_name, img_name)  + '_img-%02d_reconst_img.png' % batch_idx)
            save_image(target_img[save_layer_number,:,:,:].unsqueeze(0),
                   'results/%s/%s/test' % (run_name, img_name)  + '_img-%02d_target_img.png' % batch_idx)

            RGBA_layers = torch.cat((pred_unmixed_rgb_layers, processed_alpha_layers), dim=2) 
            RGBA_layers = RGBA_layers[0]
            for i in range(len(RGBA_layers)):
                save_image(RGBA_layers[i, :, :, :], 'results/%s/%s/img-%02d_layer-%02d.png' % (run_name, img_name, batch_idx, i) )
            print('Saved to results/%s/%s/...' % (run_name, img_name))
            
        if False:
            mono_RGBA_layers = torch.cat((primary_color_layers, processed_alpha_layers), dim=2) 
            mono_RGBA_layers = mono_RGBA_layers[0] 
            for i in range(len(mono_RGBA_layers)):
                save_image(mono_RGBA_layers[i, :, :, :], 'results/%s/%s/mono_img-%02d_layer-%02d.png' % (run_name, img_name, batch_idx, i) )

            save_image((primary_color_layers * processed_alpha_layers).sum(dim=1)[save_layer_number,:,:,:].unsqueeze(0),
                   'results/%s/%s/test' % (run_name, img_name)  + '_mono_img-%02d_reconst_img.png' % batch_idx)   
        
        
        if batch_idx == 0:
            break 

img # 0
0.059426307678222656
Saved to results/sample/palette3.jpeg/...


In [10]:
import numpy as np
import cv2
from sklearn.cluster import KMeans
import pandas as pd

num_clusters = 7
img_name = 'palette3.jpeg'
img_path = '../dataset/test/' + img_name

img = cv2.imread(img_path)[:, :, [2, 1, 0]]
size = img.shape[:2]
vec_img = img.reshape(-1, 3)
model = KMeans(n_clusters=num_clusters, n_jobs=-1)
pred = model.fit_predict(vec_img)
pred_img = np.tile(pred.reshape(*size,1), (1,1,3))

center = model.cluster_centers_.reshape(-1)
print(center)

[ 37.66605876  37.08373478  38.28576344 105.66786723  74.02425252
  62.12334395 210.88639369 193.62756526 171.38110398   4.61467189
   3.83550286   5.81568787 214.91317886 163.61534939 104.29410508
 147.61077284 130.65422581 125.53918248 192.19831804  98.48629969
  61.04275229]


In [11]:
print('img_name = \'%s\';' % img_name, end=" ")
for k, i in enumerate(model.cluster_centers_):
    print('manual_color_%d = [' % k + str(i[0].astype('int')) +', '+ str(i[1].astype('int'))+  ', '+ str(i[2].astype('int')) + '];', end=" ")

img_name = 'palette3.jpeg'; manual_color_0 = [37, 37, 38]; manual_color_1 = [105, 74, 62]; manual_color_2 = [210, 193, 171]; manual_color_3 = [4, 3, 5]; manual_color_4 = [214, 163, 104]; manual_color_5 = [147, 130, 125]; manual_color_6 = [192, 98, 61]; 