In [2]:

import numpy as np
import os
import sys, time
sys.path.append('../')

import matplotlib.pyplot as plt
%matplotlib inline

from model.loss import myloss
from model.metric import all_accuracy
from keras.applications.xception import preprocess_input
from keras.callbacks import Callback, EarlyStopping
from model.DepthwiseConv2D import DepthwiseConv2D
from model.switchnorm import SwitchNormalization
from keras.utils import multi_gpu_model
from keras.optimizers import *
from keras.models import load_model
from IPython.display import SVG
from keras.utils import plot_model
from keras.utils.vis_utils import model_to_dot

from dataload.data_generator import *
from model.core_model import model_v1 as matrix_model
from model.core_model import model_v6 as point_model

# GPU configuration. CUDA_VISIBLE_DEVICES only takes effect if it is set before
# TensorFlow initialises CUDA, i.e. before the first session is created below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
gpus = 2  # keep in sync with the number of devices listed in CUDA_VISIBLE_DEVICES

# NOTE(review): `tf` is not imported explicitly in this notebook; it presumably
# arrives through one of the star imports above. Confirm, or import it directly.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95
session = tf.Session(config=config)

# A standalone tf.Session is ignored by Keras unless it is registered as the
# backend session; without this the memory-fraction setting above never applies
# to load_model()/predict().
from keras import backend as K
K.set_session(session)


In [3]:
# Project layout: this notebook runs one directory below the project root,
# and the datasets live under <root>/dataspace.
base_project = os.path.abspath('../')
data_dir = os.path.join(base_project, 'dataspace')

# Checkpoint family plus network input size and heatmap grid size.
# Both sizes are square, so (h, w) vs (w, h) conventions coincide.
model_name = 'model_v1'
image_size, mask_size = (512, 512), (31, 31)

In [4]:
# Data pipeline: an augmenting generator for training and a plain
# (preprocessing-only) generator shared by the validation and test splits.
threshold = 0.0
heatmap_height, heatmap_width = mask_size[0], mask_size[1]

train_batch_size = 32
valid_batch_size = 32

train_gen = MyImageDataGenerator(
    preprocessing_function=preprocess_input,
    width_shift_range=0.05,
    height_shift_range=0.05,
    rotation_range=6,
    channel_shift_range=15,
    zoom_range=(.9, 1.1),
    horizontal_flip=True,
)
gen = MyImageDataGenerator(preprocessing_function=preprocess_input)


def _flow(image_gen, split, batch_size):
    """Create the paired-image iterator over ``data_dir/<split>`` with the shared settings."""
    return image_gen.myflow_from_directory(
        os.path.join(data_dir, split),
        target_size=image_size,
        x_threshold=threshold,
        y_threshold=threshold,
        # NOTE(review): the train split also uses dataset_mode='valid' — confirm intended.
        dataset_mode='valid',
        return_index_array=True,
        heatmap_height=heatmap_height,
        heatmap_width=heatmap_width,
        batch_size=batch_size,
    )


def _report(flow, samples_fmt, steps_fmt):
    """Print and return (samples per epoch, steps per epoch) for one iterator."""
    samples = flow.data_num
    print(samples_fmt.format(samples))
    steps = len(flow)
    print(steps_fmt.format(steps))
    return samples, steps


train_generator = _flow(train_gen, 'train', train_batch_size)
valid_generator = _flow(gen, 'valid', valid_batch_size)
test_generator = _flow(gen, 'test', valid_batch_size)

train_samples_epoch, steps_train = _report(train_generator, "samples_train_epoch = {}", "steps_train = {}")
valid_samples_epoch, steps_valid = _report(valid_generator, "samples_valid_epoch = {}", "steps_valid = {}")
test_samples_epoch, steps_test = _report(test_generator, "samples_test_epoch = {}", "steps_test = {}")


Found 257107 image pairs.
Found 88145 image pairs.
Found 15907 image pairs.
samples_train_epoch = 257107
steps_train = 8035
samples_valid_epoch = 88145
steps_valid = 2755
samples_test_epoch = 15907
steps_test = 498


In [5]:
# Restore the trained matrix model from checkpoints/<model_name>/.
CHECKPOINT_FILE = 'matrix_model_v1_201901240447_508.2626_0.5644_model.h5'
model_path = os.path.join(base_project, 'checkpoints/' + model_name)
model = load_model(os.path.join(model_path, CHECKPOINT_FILE))



In [19]:
import cv2
    
def NptoImg(x):
    """Undo the ``x / 127.5 - 1`` input scaling and return a uint8 image.

    The scaled values are rounded to the nearest integer and clipped to
    [0, 255] before casting. A bare ``np.uint8`` cast truncates, so float
    error such as 2.9999997 became 2. It also wraps out-of-range values
    instead of saturating them.

    Args:
        x: array-like, nominally in [-1, 1]; singleton dimensions are squeezed.

    Returns:
        uint8 array (or uint8 scalar for a single element).
    """
    pixels = (np.squeeze(np.asarray(x)) + 1) * 127.5
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

def ImgtoInput(x):
    """Map 8-bit pixel values from [0, 255] onto [-1, 1] as float64 (inverse of NptoImg)."""
    as_float = np.float64(x)
    return as_float / 127.5 - 1

def MptoImg(x):
    """Min-max normalise a map to [0, 255] and return it as uint8 (values truncated)."""
    values = np.squeeze(x)
    offset = values - np.min(values)
    # The tiny epsilon keeps a constant map at 0 instead of dividing by zero.
    normalized = offset / (np.max(offset) + 1e-18)
    return np.uint8(normalized * 255)

def post_process_prob_argmax(y_pred, y_count, threshold=0.1, filter_size=3):
    """Keep the ``y_count`` strongest local maxima of a 2-D score map.

    Steps:
      1. Zero every score below ``threshold``. This works on a copy; the
         caller's array is left untouched.
      2. Non-maximum suppression. A cell survives if no cell in its
         ``(2*sz+1) x (2*sz+1)`` neighbourhood is strictly larger, where
         ``sz = int((filter_size - 1) / 2)``. Indices are clamped at the
         borders.
      3. Select the ``y_count`` largest surviving values.

    Args:
        y_pred: 2-D array of per-cell scores.
        y_count: how many cells to keep; rounded to the nearest int. A Python
            number, NumPy scalar or size-1 array is accepted.
        threshold: scores below this are treated as 0.
        filter_size: NMS window size. Even values behave like ``filter_size - 1``.

    Returns:
        (mask, probs):
        - ``mask`` has the shape and dtype of ``y_pred``. Exactly
          ``min(max(y_count, 0), y_pred.size)`` cells are set to 1, all others 0.
        - ``probs`` holds the post-NMS values at those cells, in ascending order.
    """
    pred = np.array(y_pred, copy=True)
    h, w = pred.shape
    pred[pred < threshold] = 0

    sz = max(int((filter_size - 1) / 2), 0)
    # Edge replication is equivalent to clamping neighbour indices at the border.
    padded = np.pad(pred, sz, mode='edge')
    neighborhood_max = pred.copy()
    for dr in range(2 * sz + 1):
        for dc in range(2 * sz + 1):
            neighborhood_max = np.maximum(neighborhood_max, padded[dr:dr + h, dc:dc + w])

    # The window includes the cell itself, so ">=" means "no neighbour is larger".
    peaks = pred >= neighborhood_max
    y_hat = np.zeros_like(pred)
    y_hat[peaks] = pred[peaks]

    flat = y_hat.ravel()
    n_keep = min(max(int(np.round(y_count)), 0), flat.size)
    order = np.argsort(flat)                 # ascending order
    # Explicit start index: a negative slice [-0:] would select everything when n_keep == 0.
    top = order[order.size - n_keep:]        # the n_keep largest values
    n_largest_prob = flat[top]

    mask = np.zeros_like(flat)
    mask[top] = 1
    return mask.reshape(h, w), n_largest_prob


def add_weight(img, in_x, weight=0.4, if_filter=False, black=False):
    """Overlay a 2-D map on an image and return the blend with channels reversed.

    Args:
        img: HxWx3 uint8 image. Callers in this notebook pass BGR images from
            ``cv2.imread``, so the reversed output is RGB for ``plt.imshow``.
        in_x: 2-D map, presumably with values in [0, 1] (it is scaled by 255
            and cast to uint8). It is resized to HxW with nearest-neighbour
            interpolation.
        weight: blending weight of the overlay; the image keeps weight 1.0.
        if_filter: True gives a grey overlay of the raw map; False applies
            the JET colormap.
        black: if True, zero every output element whose overlay value is >= 250.

    Returns:
        HxWx3 uint8 array with the channel order reversed.
    """
    img_h, img_w = img.shape[:2]
    # cv2.resize expects dsize as (width, height); passing (rows, cols) only
    # worked because the images used so far are square.
    cam = cv2.resize(in_x, (img_w, img_h), interpolation=cv2.INTER_NEAREST)
    cam_u8 = np.uint8(255 * cam)
    if if_filter:
        heatmap = np.tile(cam_u8[:, :, np.newaxis], (1, 1, 3))
    else:
        heatmap = cv2.applyColorMap(cam_u8, cv2.COLORMAP_JET)
    out = cv2.addWeighted(img, 1.0, heatmap, weight, 0)
    if black:
        out[heatmap >= 250] = 0
    return out[:, :, ::-1]

def display_masking_image(imgs, names=None, sz=20):
    """Show images side by side in a single row, optionally with titles.

    Args:
        imgs: sequence of images accepted by ``plt.imshow``. Nothing is drawn
            when it is empty.
        names: optional sequence of titles, one per image.
        sz: panel size in inches. The figure is ``sz * len(imgs)`` wide and
            ``sz`` tall.
    """
    n_imgs = len(imgs)
    if n_imgs == 0:
        return
    # figsize is (width, height). The panels form one row, so the width (not the
    # height) must grow with the number of images. The single-image case goes
    # through the same path instead of the undefined display_image_one helper.
    plt.figure(figsize=(sz * n_imgs, sz))
    for i, img in enumerate(imgs):
        plt.subplot(1, n_imgs, i + 1)
        plt.imshow(img)
        if names:
            plt.title(names[i])

def generate_matrix(a_map, b_map):
    """Build the pairwise mask between the active cells of two point maps.

    ``matrix_mask[b_h, b_w, a_h, a_w, 0]`` is 1.0 when both ``b_map[b_h, b_w]``
    and ``a_map[a_h, a_w]`` are non-zero, and 0.0 otherwise.

    The original 4-deep Python loop iterated B's columns with A's width. That
    silently skipped or overran columns whenever the two maps differed in width.

    Args:
        a_map: 2-D array of shape (a_hei, a_wid).
        b_map: 2-D array of shape (b_hei, b_wid).

    Returns:
        float64 array of shape (b_hei, b_wid, a_hei, a_wid, 1).
    """
    a_active = np.asarray(a_map) != 0
    b_active = np.asarray(b_map) != 0
    # Outer AND of the two boolean masks.
    pair = b_active[:, :, None, None] & a_active[None, None, :, :]
    return pair[..., None].astype(np.float64)

def mash_sku_points(sku_points, heatmap_size=(31, 31)):
    """Rasterise labelled points onto a heatmap grid.

    Args:
        sku_points: mapping ``sku -> iterable of points``. ``point[0]`` is the
            horizontal (column) fraction and ``point[1]`` the vertical (row)
            fraction. Only points strictly inside (0, 1) on both axes are kept.
        heatmap_size: grid shape ``(rows, cols)``.

    Returns:
        map_sku: ``{flat_index: sku}`` with ``flat_index = row * cols + col``.
            This is row-major, matching ``ndarray.argmax`` on the grid. When
            several points land in one cell, the last one wins.
        sku_map: float array of shape ``(rows, cols)`` with 1 at occupied cells.
        point_lists: ``[(row, col), ...]`` for every accepted point, in input
            order (duplicates kept).
    """
    rows, cols = heatmap_size[0], heatmap_size[1]
    map_sku = {}       # {flat index: sku}
    point_lists = []   # [(row, col)]
    sku_map = np.zeros((rows, cols))

    for sku, point_list in sku_points.items():
        for point in point_list:
            w_ratio, h_ratio = point[0], point[1]
            if not (0 < h_ratio < 1 and 0 < w_ratio < 1):
                continue
            map_row, map_col = int(h_ratio * rows), int(w_ratio * cols)
            # The row stride is the number of columns. The previous code used the
            # row count, which is only correct for square grids.
            map_sku[map_row * cols + map_col] = sku
            point_lists.append((map_row, map_col))
            sku_map[map_row, map_col] = 1
    return map_sku, sku_map, point_lists

def index2coord(index, heatmap_size=31):
    """Turn a row-major flat index into ``(row, col)`` for rows ``heatmap_size`` cells long."""
    row = index // heatmap_size
    col = index % heatmap_size
    return (row, col)

def coord2index(map_row, map_col, im_h=31):
    """Row-major flat index of ``(map_row, map_col)``.

    ``im_h`` is used as the row length (number of columns). Despite its name
    suggesting a height, the result is only correct for square grids or when
    the caller passes the grid width.
    """
    return map_col + map_row * im_h

def merge_sku(sku_point_dict):
    """Count points per SKU as ``{sku: number_of_points}``; SKUs with no points are omitted."""
    counts = {}
    for sku, point_list in sku_point_dict.items():
        n_points = sum(1 for _ in point_list)
        if n_points:
            counts[sku] = n_points
    return counts

def cmp_sku(a_sku_dict, b_sku_dict):
    """Return True when both SKU-count dicts have the same keys with equal counts."""
    if len(a_sku_dict) != len(b_sku_dict):
        return False
    return all(sku in b_sku_dict and a_sku_dict[sku] == b_sku_dict[sku]
               for sku in a_sku_dict)

def count_sku_dict(dict_list):
    """Total number of items across all SKUs in a ``{sku: count}`` dict."""
    return sum(dict_list.values())

In [None]:
# Pull one batch from the test iterator (built with return_index_array=True).
# The next cell uses x[0] / x[1] as the A / B image inputs, img_path[i] as the
# (A path, B path) pair, and a_sku_points[i] / b_sku_points[i] as
# {sku: [points]} dicts. The meaning of the remaining outputs is presumed
# from their names — TODO confirm against dataload.data_generator.
x, y, index_array, a_sku_points, b_sku_points, img_path, batch_input_a_all, batch_input_b_all = test_generator.next()

In [None]:
# Inspect one test pair. For every labelled point in image B, take the most
# similar labelled point in image A from the predicted similarity matrix. Use
# A's SKU at that point as the prediction, then compare per-SKU counts with
# B's ground truth.
idx = 0                 # which sample of the batch to inspect
black = True            # black out the highlighted cell instead of brightening it
MAX_SHOWN = 15          # number of B/A match pairs to plot
MATCH_THRESHOLD = 0.5   # similarities below this are treated as "no match"

x_img_A          = x[0][idx]
x_img_B          = x[1][idx]
a_path           = img_path[idx][0]
b_path           = img_path[idx][1]
a_sku_point_dict = a_sku_points[idx]
b_sku_point_dict = b_sku_points[idx]

# {flat index: sku}, occupancy map, [(row, col)] for each image.
point_sku_A, point_mask_A, sku_coord_A = mash_sku_points(a_sku_point_dict, heatmap_size=mask_size)
point_sku_B, point_mask_B, sku_coord_B = mash_sku_points(b_sku_point_dict, heatmap_size=mask_size)

# NOTE(review): merge_sku counts every labelled point, including points that
# mash_sku_points drops for lying outside (0, 1). Such points can never be
# predicted, so they make the comparison below fail.
b_true_sku_number = merge_sku(b_sku_point_dict)  # {sku: count}
b_true_count      = count_sku_dict(b_true_sku_number)
# The pair mask is rebuilt from the labelled points (x[2] is not used).
x_matrix_mask     = generate_matrix(point_mask_A, point_mask_B)

x_input_a           = np.expand_dims(x_img_A, 0)
x_input_b           = np.expand_dims(x_img_B, 0)
x_input_matrix_mask = np.expand_dims(x_matrix_mask, 0)

# Presumably (b_hei, b_wid, a_hei, a_wid) after squeezing — TODO confirm; the
# loop below needs y_pred_matrix[x_b, y_b] to be a 2-D map over A.
y_pred_matrix = model.predict([x_input_a, x_input_b, x_input_matrix_mask]).squeeze()

img_A = cv2.resize(cv2.imread(a_path), image_size)  # BGR
img_B = cv2.resize(cv2.imread(b_path), image_size)  # BGR

a_hei, a_wid = point_mask_A.shape
b_hei, b_wid = point_mask_B.shape
b_pred_sku_number = {}
showing_num = 0
for x_b in range(b_hei):
    for y_b in range(b_wid):
        if not point_mask_B[x_b, y_b]:
            continue
        # Similarity of this B cell to every A cell, restricted to labelled A cells.
        pred_similiar_points_prob = y_pred_matrix[x_b, y_b] * point_mask_A
        pred_similiar_points_prob[pred_similiar_points_prob < MATCH_THRESHOLD] = 0

        max_index_A = pred_similiar_points_prob.argmax()
        # argmax() returns 0 both for "no candidate" and for a genuine match at
        # A's top-left cell, so test the value rather than the index.
        if pred_similiar_points_prob.flat[max_index_A] <= 0:
            continue

        point_pred_sku_id = point_sku_A[max_index_A]
        point_true_sku_id = point_sku_B[coord2index(x_b, y_b, im_h=b_wid)]
        b_pred_sku_number[point_pred_sku_id] = b_pred_sku_number.get(point_pred_sku_id, 0) + 1

        if showing_num < MAX_SHOWN:
            # max_index_A is a flat index into A's grid, so its row stride is A's width.
            max_hei, max_wid = index2coord(index=max_index_A, heatmap_size=a_wid)
            A_item_point = np.zeros((a_hei, a_wid))
            B_item_point = np.zeros((b_hei, b_wid))
            A_item_point[max_hei, max_wid] = 1
            B_item_point[x_b, y_b] = 1
            masking_A = add_weight(img_A, A_item_point, if_filter=True, weight=1, black=black)
            masking_B = add_weight(img_B, B_item_point, if_filter=True, weight=1, black=black)
            display_masking_image([masking_B, masking_A], names=[point_true_sku_id, point_pred_sku_id])
            showing_num += 1

b_pred_count = count_sku_dict(b_pred_sku_number)
if cmp_sku(b_pred_sku_number, b_true_sku_number):
    print('Correction!')
else:
    print('Wrong!')

print(b_true_sku_number)
print(b_pred_sku_number)
print(b_true_count, b_pred_count)
