# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import os
import tqdm
from skimage.metrics import structural_similarity as ssim
# from sewar.full_ref import msssim
import time
import copy
import imutils

def list_contours(contours):
    """
    Flatten a sequence of contours into one list of points
    :param contours: iterable of contours, each contour being an iterable of points
    :return: list with every point of every contour, in iteration order
    """
    # A comprehension avoids the original local named `list`, which shadowed the builtin.
    return [point for cnt in contours for point in cnt]


def filter_much_less_cnts(contours, min_area=20):
    """
    Drop small contours and flatten the survivors into a list of points
    :param contours: list of contours (outputs of cv2.findContours)
    :param min_area: a contour is kept only if its area is strictly greater than this
    :return: flat list of the points of all kept contours (via list_contours)
    """
    big_enough = [cnt for cnt in contours if cv2.contourArea(cnt) > min_area]
    return list_contours(big_enough)

def get_cnt_area(bin_image):
    """
    Sum the areas of all contours found in a binary image
    :param bin_image: binary image passed to cv2.findContours (RETR_TREE, CHAIN_APPROX_NONE)
    :return: sum of cv2.contourArea over every returned contour (0 if none)
    """
    found, _hierarchy = cv2.findContours(bin_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    return sum(cv2.contourArea(cnt) for cnt in found)


def rotate_img(image, angle):
    """
    Rotate an image about its center, keeping the original canvas size
    :param image: cv2 image
    :param angle: rotation angle in degrees (counter clockwise if angle > 0)
    :return: rotated image with the same width and height as the input
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)
    # scale factor 1.0: pure rotation, no zoom
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(image, rotation, (width, height))


def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """
    Resize an image to a target width or height while preserving its aspect ratio
    :param image: cv2 image
    :param width: None or target width (takes precedence when both are given)
    :param height: None or target height (used only when width is None)
    :param inter: cv2 interpolation flag
    :return: resized image, or the input itself when both width and height are None
    """
    src_h, src_w = image.shape[:2]

    # nothing requested: hand the input back untouched
    if width is None and height is None:
        return image

    if width is not None:
        # scale driven by the requested width
        scale = width / float(src_w)
        target = (width, int(src_h * scale))
    else:
        # scale driven by the requested height
        scale = height / float(src_h)
        target = (int(src_w * scale), height)

    return cv2.resize(image, target, interpolation=inter)


def resize_image_to_square(image, dst=256, color=(0, 0, 0)):
    """
    Resize with padding: scale the image so its longer edge equals dst, then pad
    the shorter edge on both sides with a constant color up to dst
    :param image: cv2 image
    :param dst: target edge length of the square output
    :param color: constant pad value (any sequence accepted by cv2.copyMakeBorder)
    :return: square image of size dst x dst
    """
    old_h, old_w = image.shape[:2]  # shape is (height, width)

    ratio = float(dst) / max(old_h, old_w)
    # Clamp to 1 so a very elongated input cannot truncate to a zero-sized
    # dimension, which cv2.resize rejects.
    new_h = max(1, int(old_h * ratio))
    new_w = max(1, int(old_w * ratio))

    # cv2.resize expects (width, height)
    resized = cv2.resize(image, (new_w, new_h))

    delta_w = dst - new_w
    delta_h = dst - new_h
    # split the padding as evenly as possible; the extra pixel goes bottom/right
    top = delta_h // 2
    bottom = delta_h - top
    left = delta_w // 2
    right = delta_w - left

    return cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT,
                              value=color)


def find_center(list2, mask):
    """
    Find the center of the character as the mean of its contour points
    :param list2: flat list of contour points, each indexed as point[0][0] (x) and point[0][1] (y)
    :param mask: image on which the points are drawn (modified by cv2.drawContours)
    :return: (x_center, y_center) as the arithmetic mean of the point coordinates
    """
    # NOTE(review): drawContours writes into `mask` itself; the return value was
    # never used by the original code, so this looks like debugging residue -
    # confirm whether callers rely on the drawn pixels before removing it.
    cv2.drawContours(mask, list2, -1, (0, 255, 0), 1)

    point_count = len(list2)
    sum_x = sum(kp[0][0] for kp in list2)
    sum_y = sum(kp[0][1] for kp in list2)

    return sum_x / point_count, sum_y / point_count  # x_center, y_center

# def shift_image(image, x, y):
#   M = np.float32([[1, 0, x], [0, 1, y]])
#   return cv2.warpAffine(image, M, (image.shape[1], image.shape[0]))

def shift_image(img, x, y):
    """
    Zero-pad an image so that the character center lands on the image center
    :param img: cv2 binary image (numpy array, 2-D or with trailing channel axes)
    :param x: x_center_of_image - x_center_of_character
    :param y: y_center_of_image - y_center_of_character
    :return: padded copy of img in which the character center coincides with the
             center of the (larger) image
    """
    # Padding one side by p moves the image center by p / 2 relative to the
    # content, so an offset of |x| needs 2 * |x| pixels of padding.
    pad_x = int(round(2 * abs(x)))
    pad_y = int(round(2 * abs(y)))

    # x > 0: the character sits left of the image center, so grow the image on
    # the left; otherwise grow it on the right. Same reasoning vertically.
    # (The original chose the side inconsistently between its four branches.)
    left, right = (pad_x, 0) if x > 0 else (0, pad_x)
    top, bottom = (pad_y, 0) if y > 0 else (0, pad_y)

    # leave any channel axes untouched
    pad_width = [(top, bottom), (left, right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode="constant", constant_values=0)


def crop_char(thresh, contours, depth_image=None):
    """
    Crop the image to the bounding box of the (area-filtered) contour points and
    resize the crop so that its longer side becomes 256
    :param thresh: binary image
    :param contours: list of contours
    :param depth_image: optional second image cropped and resized the same way
    :return: (cropped thresh, cropped depth_image or None)
    """
    points = filter_much_less_cnts(contours)
    xs = [p[0][0] for p in points]
    ys = [p[0][1] for p in points]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)

    # pin whichever side of the box is longer to 256
    if (xmax - xmin) > (ymax - ymin):
        target = {'width': 256}
    else:
        target = {'height': 256}

    image = image_resize(thresh[ymin:ymax, xmin:xmax], **target)
    new_depth_image = None
    if depth_image is not None:
        new_depth_image = image_resize(depth_image[ymin:ymax, xmin:xmax], **target)
    return image, new_depth_image

def normalize_mask(thresh, raw_image=None):
    """
    Crop the character out of a binary image and pad it so the character center
    sits near the image center, then pad the result to a square
    :param thresh: binary image
    :param raw_image: optional companion image cropped with the same bounding box
    :return: (square character mask, cropped raw_image or None when raw_image is None)
    """
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # crop_char returns None as its second item when raw_image is None, so
    # new_depth_image is always bound (the original left it undefined in that
    # case and crashed with UnboundLocalError at the return).
    thresh, new_depth_image = crop_char(thresh, contours, raw_image)

    # recompute contours on the cropped mask to locate the character center
    contours, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    x_center, y_center = find_center(list_contours(contours), thresh)
    shifted_img = shift_image(thresh, thresh.shape[1] / 2 - x_center, thresh.shape[0] / 2 - y_center)
    return resize_image_to_square(shifted_img, max(shifted_img.shape[:2])), new_depth_image

def remove_noise(bin_depth_image):
    """
    Remove small blobs from a binary image, first on the image itself and then
    on its inverse
    :param bin_depth_image: single-channel binary image
    :return: cleaned binary image
    """
    def fill_noise(thresh1):
        # Blank out every contour whose area is below 1/1200 of the image area.
        rows, cols = thresh1.shape
        area_limit = rows * cols / 1200
        found = imutils.grab_contours(
            cv2.findContours(thresh1, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE))
        keep = np.full(thresh1.shape[:2], 255, dtype="uint8")
        for contour in found:
            if cv2.contourArea(contour) < area_limit:
                # filled draw with 0 removes the blob from the keep-mask
                cv2.drawContours(keep, [contour], -1, 0, -1)
        return cv2.bitwise_and(thresh1, thresh1, mask=keep)

    first_pass = fill_noise(bin_depth_image)
    second_pass = fill_noise(255 - first_pass)
    return 255 - second_pass

def normalize_print_image(print_image, image_size=256):
    """
    Binarize a print image and normalize the character inside it
    :param print_image: cv2 color image
    :param image_size: edge length the image is padded/resized to before thresholding
    :return: (normalized character mask, matching crop of the padded color image)
    """
    padded = resize_image_to_square(print_image, image_size, [255, 255, 255])
    gray = cv2.cvtColor(padded, cv2.COLOR_BGR2GRAY)
    # Otsu chooses the threshold; INV makes pixels below it 255
    _, inverted_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    character_image, character_raw_image = normalize_mask(inverted_mask, padded)
    return character_image, character_raw_image

def normalize_depth_image(depth_image, thresh=255):
    """
    Get the normalized binary mask of a depth image
    :param depth_image: cv2 color (BGR) image
    :param thresh: max value passed to cv2.adaptiveThreshold
    :return: binarized, denoised and normalized character mask
    """
    depth_image = cv2.resize(depth_image, (256, 256))
    hsv_img = cv2.cvtColor(depth_image, cv2.COLOR_BGR2HSV)

    # Raise the V channel by `value`, saturating at 255 instead of wrapping.
    value = 90
    v_channel = hsv_img[..., 2]
    hsv_img[..., 2] = np.where((255 - v_channel) < value, 255, v_channel + value)

    # Convert back to BGR before going to gray. The original discarded this
    # result and fed the HSV array to BGR2GRAY, mixing H/S/V as if they were B/G/R.
    brightened = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
    gray = cv2.cvtColor(brightened, cv2.COLOR_BGR2GRAY)

    binary = 255 - cv2.adaptiveThreshold(gray, thresh, cv2.ADAPTIVE_THRESH_MEAN_C,
                                         cv2.THRESH_BINARY, 11, 2)
    return normalize_mask(remove_noise(binary))[0]


def estimate_dense_ratio(bin_image):
    """
    Compute the dense ratio of an image: non-zero elements / (height * width)
    :param bin_image: binary image (numpy array)
    :return: ratio between the number of non-zero elements and height * width
    """
    rows, cols = bin_image.shape[:2]
    filled = np.count_nonzero(bin_image)
    return filled / (rows * cols)

def estimate_dense_character_ratio(bin_image):
    """
    Compute the dense ratio of the character region: the image is cropped to the
    character, padded to a square, and the share of non-zero pixels is returned
    :param bin_image: binary image
    :return: non-zero pixels / total pixels of the squared character crop
    """
    contours, _ = cv2.findContours(bin_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    # crop_char returns (image, depth_image); only the image is needed here.
    # The original passed the whole tuple on, which fails on `.shape`.
    cropped = crop_char(bin_image, contours)[0]
    square = resize_image_to_square(cropped)
    return np.count_nonzero(square) / (square.shape[0] * square.shape[1])

def estimate_dense_ratio_wrap(bin_image):
    """
    Compute the dense ratio inside the minimum-area (rotated) rectangle that
    encloses the area-filtered contour points of the image
    :param bin_image: binary image
    :return: non-zero pixels / total pixels of the rotated-rectangle crop
    """
    def crop_rect(img, rect):
        # rect is (center, size, angle) as returned by cv2.minAreaRect
        center, size, angle = rect[0], rect[1], rect[2]
        center, size = tuple(map(int, center)), tuple(map(int, size))

        height, width = img.shape[0], img.shape[1]

        # rotate the whole image about the rectangle center, then cut the
        # rectangle out of the rotated image
        M = cv2.getRotationMatrix2D(center, angle, 1)
        img_rot = cv2.warpAffine(img, M, (width, height))
        return cv2.getRectSubPix(img_rot, size, center)

    contours, _ = cv2.findContours(bin_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    cnt = np.array(filter_much_less_cnts(contours))

    rect = cv2.minAreaRect(cnt)
    # The original also computed `box = np.int0(cv2.boxPoints(rect))` and never
    # used it; np.int0 no longer exists in NumPy 2.0, so the dead code is dropped.
    img_crop = crop_rect(bin_image, rect)

    return np.count_nonzero(img_crop) / (img_crop.shape[0] * img_crop.shape[1])

def normalize_depth_image_v2(depth_image, print_image):
    """
    Find the binarization threshold of the depth image whose dense ratio is
    closest to the dense ratio of the (Otsu-binarized) print image
    :param depth_image: cv2 color depth image
    :param print_image: cv2 color image
    :return: (best binary depth image, binary print image, smallest dense-ratio
             difference, chosen threshold)
    """
    print_image = resize_image_to_square(print_image, dst=256, color=[255, 255, 255])
    print_image = cv2.cvtColor(print_image, cv2.COLOR_BGR2GRAY)
    _, bin_print_image = cv2.threshold(print_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    print_dense = estimate_dense_ratio_wrap(bin_print_image)

    depth_image = cv2.resize(depth_image, (256, 256))
    depth_image = cv2.cvtColor(depth_image, cv2.COLOR_BGR2GRAY)

    best_diff = 1
    # Fallback if every candidate fails (same default as normalize_depth_image_v3);
    # the original left best_thresh unbound and crashed after the loop.
    best_thresh = 50
    for th in range(50, 140):
        try:
            _, bin_depth_image = cv2.threshold(depth_image, th, 255, cv2.THRESH_BINARY_INV)
            depth_dense = estimate_dense_ratio_wrap(bin_depth_image)
        except Exception:
            # best effort: a threshold that cannot be evaluated is skipped
            continue
        diff = abs(depth_dense - print_dense)
        if diff < best_diff:
            best_diff = diff
            best_thresh = th

    _, best_depth_image = cv2.threshold(depth_image, best_thresh, 255, cv2.THRESH_BINARY_INV)
    return best_depth_image, bin_print_image, best_diff, best_thresh

def normalize_depth_image_v3(depth_image, print_image):
    """
    Find the binarization threshold of the depth image that maximizes the score
    returned by sim_score against the binarized print image
    :param depth_image: cv2 color depth image
    :param print_image: cv2 color image
    :return: (denoised best binary depth image, chosen threshold, best score)
    """
    gray_print = cv2.cvtColor(cv2.resize(print_image, (256, 256)), cv2.COLOR_BGR2GRAY)
    gray_print = resize_image_to_square(gray_print, 256, [255, 255, 255])
    _, bin_print_image = cv2.threshold(gray_print, 127, 255, cv2.THRESH_BINARY_INV)
    gray_depth = cv2.cvtColor(cv2.resize(depth_image, (256, 256)), cv2.COLOR_BGR2GRAY)

    best_score, best_thresh = -1, 50
    for candidate in range(50, 200):
        try:
            _, candidate_mask = cv2.threshold(gray_depth, candidate, 255, cv2.THRESH_BINARY_INV)
            # sim_score is expected to return three values; only the first is used
            score, _square, _dense_diff = sim_score(bin_print_image, candidate_mask)
            if score > best_score:
                best_score, best_thresh = score, candidate
        except Exception:
            # candidates that cannot be scored are skipped
            continue

    print(best_thresh, best_score)
    _, best_depth_image = cv2.threshold(gray_depth, best_thresh, 255, cv2.THRESH_BINARY_INV)

    return remove_noise(best_depth_image), best_thresh, best_score

def normalize_depth_image_v4(depth_image, print_image):
    """
    Find the binarization threshold of the depth image (searched in [90, 150))
    that maximizes the first value returned by sim_score against the binarized
    print image
    :param depth_image: cv2 color depth image
    :param print_image: cv2 color image
    :return: (denoised best binary depth image, chosen threshold, chosen threshold)
    """
    gray1 = cv2.cvtColor(cv2.resize(print_image, (256, 256)), cv2.COLOR_BGR2GRAY)
    gray1 = resize_image_to_square(gray1, 256, [255, 255, 255])
    _, bin_print_image = cv2.threshold(gray1, 127, 255, cv2.THRESH_BINARY_INV)
    gray2 = cv2.cvtColor(cv2.resize(depth_image, (256, 256)), cv2.COLOR_BGR2GRAY)

    # Keep each score next to the threshold that produced it. The original
    # rebuilt the threshold as "failure count + list index", which is only
    # right when every failure happens before the first success.
    scores = []
    thresholds = []
    for th in range(90, 150):
        try:
            _, bin_depth_image = cv2.threshold(gray2, th, 255, cv2.THRESH_BINARY_INV)
            score, _square, _dense_diff = sim_score(bin_print_image, bin_depth_image)
        except Exception:
            # thresholds that cannot be scored are skipped
            continue
        scores.append(score)
        thresholds.append(th)

    # Fall back to the start of the range when nothing could be scored; the
    # original raised ValueError there and also when no score was positive.
    best_thresh = 90
    if scores:
        # max() returns the first maximal index, i.e. the lowest such threshold
        best_index = max(range(len(scores)), key=scores.__getitem__)
        best_thresh = thresholds[best_index]

    _, best_bin_depth_image = cv2.threshold(gray2, best_thresh, 255, cv2.THRESH_BINARY_INV)
    print(best_thresh)

    # NOTE(review): the third item repeats best_thresh, while v3 returns the best
    # score in that position - kept as-is for callers; confirm the intent.
    return remove_noise(best_bin_depth_image), best_thresh, best_thresh

def rotate_bound(image, angle, x_center, y_center):
    """
    Rotate an image about (x_center, y_center) onto a canvas enlarged to the
    bounding size of the rotated image
    :param image: cv2 image
    :param angle: counter-clockwise angle
    :param x_center: x-coordinate of the rotation center
    :param y_center: y-coordinate of the rotation center
    :return: rotated image
    """
    height, width = image.shape[:2]

    matrix = cv2.getRotationMatrix2D((x_center, y_center), angle, 1.0)
    # magnitudes of the rotation components of the matrix
    abs_cos = np.abs(matrix[0, 0])
    abs_sin = np.abs(matrix[0, 1])

    # size of the axis-aligned box that contains the rotated image
    new_width = int((height * abs_sin) + (width * abs_cos))
    new_height = int((height * abs_cos) + (width * abs_sin))

    # translate so the rotation center maps to the middle of the new canvas
    matrix[0, 2] += (new_width / 2) - x_center
    matrix[1, 2] += (new_height / 2) - y_center

    return cv2.warpAffine(image, matrix, (new_width, new_height))

def pad_img(img, pad_value=20):
    """
    Shrink a single-channel image by pad_value pixels per dimension and paste it
    centered on a zero canvas of the original size
    :param img: 2-D image
    :param pad_value: total number of pixels removed from each dimension
    :return: float array (np.zeros default dtype) of the original shape holding
             the shrunken image in its center
    """
    h, w = img.shape
    # cv2.resize takes (width, height); the original passed (h - pad, w - pad)
    # and mixed the axes in the offsets, which only worked for square images.
    shrunk = cv2.resize(img, (w - pad_value, h - pad_value))
    ht, wd = shrunk.shape
    canvas = np.zeros((h, w))
    # offsets that center the shrunken image
    xx = (w - wd) // 2
    yy = (h - ht) // 2
    canvas[yy:yy + ht, xx:xx + wd] = shrunk

    return canvas


def make_border(valueX, valueY, name):
    """
    Translate a signed (x, y) shift into border widths
    :param valueX: signed horizontal shift; negative selects the left side, otherwise the right
    :param valueY: signed vertical shift; negative selects the top side, otherwise the bottom
    :param name: 'depth' returns the widths as computed; any other value returns
                 them mirrored (top<->bottom, left<->right)
    :return: (top, bottom, left, right) as non-negative widths
    """
    width_x, width_y = abs(valueX), abs(valueY)
    left, right = (width_x, 0) if valueX < 0 else (0, width_x)
    top, bottom = (width_y, 0) if valueY < 0 else (0, width_y)
    if name != 'depth':
        # the other image gets its padding on the opposite sides
        top, bottom = bottom, top
        left, right = right, left
    return top, bottom, left, right



def shift_character_in_canvas(bin_print_image, bin_depth_image):
    """
    Center both binary images on square canvases of equal size, then greedily
    shift one against the other (one pixel at a time, left/right/top/bottom)
    for as long as get_iou_metric keeps increasing
    :param bin_print_image: 2-D binary print image
    :param bin_depth_image: 2-D binary depth image
    :return: (best depth canvas, best print canvas, changeX, changeY, final metric)
    """
    h_print, w_print = bin_print_image.shape
    h_depth, w_depth = bin_depth_image.shape
    max_size = max(h_print, w_print, h_depth, w_depth)
    center = max_size // 2

    def centered_canvas(image, height, width):
        # Paste `image` in the middle of a zero canvas. Rows are indexed by the
        # height and columns by the width; the original swapped the two, which
        # only worked for square inputs.
        canvas = np.zeros((max_size, max_size), np.uint8)
        row0 = center - height // 2
        col0 = center - width // 2
        canvas[row0:row0 + height, col0:col0 + width] = image
        return canvas

    canvas_print_image = centered_canvas(bin_print_image, h_print, w_print)
    canvas_depth_image = centered_canvas(bin_depth_image, h_depth, w_depth)

    # Start from the unshifted pair so a shift is accepted only if it really
    # improves on it, and so the result is defined when no shift helps (the
    # original started from 0 and could end with None canvases).
    best_fit_canvas_depth = canvas_depth_image
    best_fit_canvas_print = canvas_print_image
    prevIOU = get_iou_metric(canvas_depth_image, canvas_print_image)

    step = 1
    movements = ((-step, 0), (step, 0), (0, -step), (0, step))  # left, right, top, bottom
    changeX = 0  # accumulated shift in the x direction
    changeY = 0  # accumulated shift in the y direction
    stillChanging = True
    while stillChanging:
        stillChanging = False
        for deltaX, deltaY in movements:
            while True:
                # Pad the two canvases on opposite sides: both grow to the same
                # size while their contents move relative to each other.
                top, bottom, left, right = make_border(changeX + deltaX, changeY + deltaY, 'depth')
                new_canvas_depth_image = cv2.copyMakeBorder(canvas_depth_image, top, bottom, left, right,
                                                            cv2.BORDER_CONSTANT, value=0)
                top, bottom, left, right = make_border(changeX + deltaX, changeY + deltaY, 'print')
                new_canvas_print_image = cv2.copyMakeBorder(canvas_print_image, top, bottom, left, right,
                                                            cv2.BORDER_CONSTANT, value=0)

                currentIOU = get_iou_metric(new_canvas_depth_image, new_canvas_print_image)
                if currentIOU - prevIOU <= 0:
                    break
                # the move improved the metric: keep it and try the same direction again
                prevIOU = currentIOU
                changeX += deltaX
                changeY += deltaY
                best_fit_canvas_depth = new_canvas_depth_image
                best_fit_canvas_print = new_canvas_print_image
                stillChanging = True

    sim = get_iou_metric(best_fit_canvas_depth, best_fit_canvas_print)
    return best_fit_canvas_depth, best_fit_canvas_print, changeX, changeY, sim

def get_IOU(canvas1, canvas2):
    """
    Intersection over union of the non-zero pixels of two equally shaped images
    :param canvas1: first binary image (numpy array)
    :param canvas2: second binary image (numpy array)
    :return: |both non-zero| / |either non-zero|, or 0.0 when both images are empty
    """
    set1 = canvas1 != 0
    set2 = canvas2 != 0
    # Logical or instead of `canvas1 + canvas2`: a uint8 sum can wrap around to
    # 0 (e.g. 128 + 128) and silently drop pixels from the union.
    union = np.count_nonzero(np.logical_or(set1, set2))
    if union == 0:
        return 0.0
    return np.count_nonzero(np.logical_and(set1, set2)) / union

def find_best_angle(bin_depth_img, min_angle, max_angle, num_step, bin_print_img, depth_image):
    """
    Try the angles of np.linspace(min_angle, max_angle, num_step) and keep the
    one whose rotated depth mask aligns best with the print mask.
    This function supports match_rotation
    :param bin_depth_img: cv2 binary depth image
    :param min_angle: first angle to try
    :param max_angle: last angle to try
    :param num_step: number of angles sampled between min_angle and max_angle
    :param bin_print_img: cv2 binary print image
    :param depth_image: cv2 depth image, rotated and cropped alongside the mask
    :return: (best angle, best score, best depth canvas, best print canvas,
             changeX, changeY, depth image crop for the best angle)
    """
    depth_image = cv2.resize(depth_image, (256, 256))

    maxiou = 0
    best_fit_changeX = 0
    best_fit_changeY = 0
    canvas_depth_best_fit_image = 0
    canvas_print_best_fit_image = 0
    # Defaults so the return statement never hits unbound names when no angle
    # scores above 0 (the original raised UnboundLocalError in that case).
    rotate_angle = min_angle
    best_depth_image = None
    resized_print_img = copy.copy(bin_print_img)
    if min_angle == max_angle:
        angle_range = [min_angle]
    else:
        angle_range = np.linspace(min_angle, max_angle, num=num_step)

    center_x = bin_depth_img.shape[1] / 2
    center_y = bin_depth_img.shape[0] / 2
    for angle in angle_range:
        rotated_img = rotate_bound(bin_depth_img, angle, center_x, center_y)
        rotated_depth_image = rotate_bound(depth_image, angle, center_x, center_y)

        rotated_img, new_depth_image = normalize_mask(rotated_img, rotated_depth_image)

        # Bring both masks to the smaller of the two sizes. cv2.resize takes
        # (width, height), so the shape is reversed explicitly.
        if rotated_img.shape[0] < resized_print_img.shape[0]:
            resized_print_img = cv2.resize(resized_print_img,
                                           (rotated_img.shape[1], rotated_img.shape[0]))
        else:
            rotated_img = cv2.resize(rotated_img,
                                     (resized_print_img.shape[1], resized_print_img.shape[0]))

        canvas_depth_new_image, canvas_print_image, changeX, changeY, iou = shift_character_in_canvas(
            resized_print_img, rotated_img)
        if iou > maxiou:
            rotate_angle = angle
            maxiou = iou
            best_fit_changeX = changeX
            best_fit_changeY = changeY
            canvas_depth_best_fit_image = canvas_depth_new_image
            canvas_print_best_fit_image = canvas_print_image
            best_depth_image = new_depth_image

    return (rotate_angle, maxiou, canvas_depth_best_fit_image, canvas_print_best_fit_image,
            best_fit_changeX, best_fit_changeY, best_depth_image)


def convert_color_img(img, color):
    """
    Tint the zero-valued pixels of a binary image [0, 255] in one RGB channel
    :param img: cv2 binary image
    :param color: 'r' -> channel 0, 'g' -> channel 1, anything else -> channel 2
    :return: numpy RGB image in which pixels that were 0 in the chosen channel are 255 there
    """
    channel = 0 if color == 'r' else (1 if color == 'g' else 2)
    rgb = np.array(cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_GRAY2RGB))
    was_zero = rgb[:, :, channel] == 0
    rgb[was_zero, channel] = 255
    return rgb


def match_rotation(depth_image, print_image, normalize_depth_image_type):
    """
    Find the rotation angle that best aligns the depth character with the print character
    :param depth_image: cv2 color depth image
    :param print_image: cv2 color print image
    :param normalize_depth_image_type: 'v2' or 'v3' selects that binarization
                                       routine; any other value uses v4
    :return: dict with keys rotate_angle, sim_score, best_depth_rotated_img,
             best_print_img, stacked_img, bin_depth_img, changeX, changeY,
             best_depth_image, print_image
    """
    if normalize_depth_image_type == 'v2':
        bin_depth_img, _, _, _ = normalize_depth_image_v2(depth_image, print_image)
    elif normalize_depth_image_type == 'v3':
        bin_depth_img, _, _ = normalize_depth_image_v3(depth_image, print_image)
    else:
        bin_depth_img, _, _ = normalize_depth_image_v4(depth_image, print_image)
    bin_print_img, print_image = normalize_print_image(print_image)

    min_angle = 0
    max_angle = 359
    max_depth = 2  # number of coarse-to-fine passes
    num_step = 60

    for _ in range(max_depth):
        (rotate_angle, score, best_depth_rotated_img, best_print_img,
         changeX, changeY, best_depth_image) = find_best_angle(
            bin_depth_img, min_angle, max_angle, num_step,
            copy.copy(bin_print_img), depth_image)
        # Narrow the search to +/- 3 steps around the best angle. The half
        # window is computed once: the original derived max_angle from the
        # already-updated min_angle, which made the window lopsided.
        half_window = 3 * ((max_angle - min_angle) / num_step)
        min_angle = rotate_angle - half_window
        max_angle = rotate_angle + half_window

    # overlay for visual inspection: depth in the green channel, print in the red channel
    g_best_depth_rotated_img = 255 - convert_color_img(255 - best_depth_rotated_img, 'g')
    r_best_print_img = 255 - convert_color_img(best_print_img, 'r')
    stacked_img = g_best_depth_rotated_img + r_best_print_img

    return dict(rotate_angle=rotate_angle, sim_score=score,
                best_depth_rotated_img=best_depth_rotated_img, best_print_img=best_print_img,
                stacked_img=stacked_img, bin_depth_img=bin_depth_img,
                changeX=changeX, changeY=changeY,
                best_depth_image=best_depth_image, print_image=print_image)


def concatenate_image(depth_image, print_image, rotate_angle, dst=256):
    """
    Put the depth image and the print image side by side.

    Both images go through resize_image_to_square with a target of 256 and a
    white ([255, 255, 255]) colour argument, then they are joined along axis 1
    with the depth image first. The shape of the resized depth image is
    printed.

    :param depth_image: cv2 depth image
    :param print_image: cv2 print image
    :param rotate_angle: accepted but not used by this function
    :param dst: accepted but not used by this function (the size is fixed at 256)
    :return: merged image
    """
    left_half, right_half = (
        resize_image_to_square(source, 256, [255, 255, 255])
        for source in (depth_image, print_image)
    )
    print(left_half.shape)
    return np.concatenate((left_half, right_half), axis=1)
def change_contrast(img, level):
    """
    Apply a contrast curve to an image through its ``point`` method.

    Every value c is mapped to 128 + factor * (c - 128), where
    factor = 259 * (level + 255) / (255 * (259 - level)).

    :param img: object exposing ``point(function)`` (the file imports PIL.Image)
    :param level: contrast level; 0 gives factor 1 (no change)
    :return: whatever ``img.point`` returns
    """
    factor = (259 * (level + 255)) / (255 * (259 - level))
    return img.point(lambda c: 128 + factor * (c - 128))
def register_v1(id):
    """
    Register the depth image and the print image of one sample.

    Runs match_rotation with the 'v2' and the 'v3' depth normalization, keeps
    the result with the higher sim_score ('v3' wins a tie), writes
    result['bin_depth_img'] to disk and appends one line
    ("name angle ssim iou dice") to the register log.

    :param id: sample identifier (string) used to build both image paths
    :return: (concatenated image, best rotated depth image, best print image,
        stacked image, rotate angle, ssim score)
    :raises FileNotFoundError: if the depth or the print image cannot be read
    """
    depth_image_path = '/content/drive/MyDrive/DATA_WORK/24539_register/depth_imgs/' + id + '_r.png'
    print_image_path = '/content/drive/MyDrive/DATA_WORK/24539_register/24539_2D/' + id + '.png'

    # cv2.imread does not raise on a missing/unreadable file, it returns None;
    # fail here with the offending path instead of deep inside the pipeline.
    depth_image = cv2.imread(depth_image_path)
    if depth_image is None:
        raise FileNotFoundError('Cannot read depth image: ' + depth_image_path)
    print_image = cv2.imread(print_image_path)
    if print_image is None:
        raise FileNotFoundError('Cannot read print image: ' + print_image_path)
    print_image = resize_image_to_square(print_image, dst=256, color=[255, 255, 255])

    result1 = match_rotation(depth_image, print_image, 'v2')
    result2 = match_rotation(depth_image, print_image, 'v3')

    print(result2['sim_score'])
    result = result1 if result1['sim_score'] > result2['sim_score'] else result2

    concated_img = concatenate_image(result['best_depth_image'], result['print_image'], result['rotate_angle'], (256, 256))

    img_name = os.path.basename(print_image_path).replace('.png', '')
    iou = get_iou_metric(result['best_depth_rotated_img'], result['best_print_img'])
    ssim_score = ssim(result['best_depth_rotated_img'], result['best_print_img'])
    dice = get_dice_metric(result['best_depth_rotated_img'], result['best_print_img'])

    save_path = '/content/drive/MyDrive/DATA_WORK/GAN_pix2pix/register_outputs/9/'
    cv2.imwrite('/content/drive/MyDrive/DATA_WORK/GAN_pix2pix/register_outputs/good_thresh_depth_1/'+img_name+'.png', result['bin_depth_img'])
    # Append mode: every processed sample adds one line to the shared log.
    with open(save_path+'register_log_4k.txt', 'a') as f:
        f.write(img_name + ' ' +str(result['rotate_angle'])+' '+str(round(ssim_score, 5))+' '+str(round(iou, 5))+' '+str(round(dice, 5))+'\n')
    print('SSIM: ', ssim_score)
    return concated_img, result['best_depth_rotated_img'], result['best_print_img'], result['stacked_img'], result['rotate_angle'], ssim_score

def register_with_angle(depth_image_path, print_image_path, rotate_angle):
    """
    Register two images using an already known rotation angle.

    The print image is flipped horizontally, both images are normalized, the
    normalized depth image is rotated by ``rotate_angle`` and the larger of
    the two binary images is resized to the size of the smaller one before
    the ssim / iou / dice metrics are printed.

    :param depth_image_path: path to depth image
    :param print_image_path: path to print image
    :param rotate_angle: angle passed to rotate_img for the normalized depth image
    :return: (concatenated image, rotated depth image, print image, stacked image)
    :raises FileNotFoundError: if either image cannot be read
    """
    # cv2.imread returns None instead of raising when a file cannot be read.
    depth_image = cv2.imread(depth_image_path)
    if depth_image is None:
        raise FileNotFoundError('Cannot read depth image: ' + depth_image_path)
    print_image = cv2.imread(print_image_path)
    if print_image is None:
        raise FileNotFoundError('Cannot read print image: ' + print_image_path)
    print_image = cv2.flip(print_image, 1)

    # Only the first return value is needed. Star-unpacking keeps this call
    # working whatever the number of extra values is (match_rotation unpacks
    # four values from the same function, this code used to unpack five).
    bin_depth_img, *_ = normalize_depth_image_v2(depth_image, print_image)

    # normalize_print_image returns a pair (binary image, print image), as
    # unpacked in match_rotation; using the pair itself as an image fails on
    # the first .shape access.
    best_print_img, _ = normalize_print_image(print_image)
    best_depth_rotated_img = rotate_img(copy.deepcopy(bin_depth_img), rotate_angle)

    # cv2.resize expects dsize as (width, height) while ndarray.shape is
    # (height, width, ...), so the two values must be swapped.
    if best_print_img.shape[0] < best_depth_rotated_img.shape[0]:
        target_h, target_w = best_print_img.shape[:2]
        best_depth_rotated_img = cv2.resize(best_depth_rotated_img, (target_w, target_h))
    else:
        target_h, target_w = best_depth_rotated_img.shape[:2]
        best_print_img = cv2.resize(best_print_img, (target_w, target_h))

    g_best_depth_rotated_img = 255 - convert_color_img(255 - best_depth_rotated_img, 'g')
    r_best_print_img = 255 - convert_color_img(best_print_img, 'r')
    stacked_img = g_best_depth_rotated_img + r_best_print_img
    concated_img = concatenate_image(depth_image, print_image, rotate_angle, (256, 512))

    iou = get_iou_metric(best_depth_rotated_img, best_print_img)
    ssim_score = ssim(best_depth_rotated_img, best_print_img)
    dice = get_dice_metric(best_depth_rotated_img, best_print_img)
    print('ssim: ',ssim_score,' iou: ',iou,' dice: ', dice)
    return concated_img, best_depth_rotated_img, best_print_img, stacked_img


def register(img_name, normalize_depth_image_type='v2'):
    """
    Register the depth image and the (horizontally flipped) print image of one sample.

    Writes the concatenated and the stacked image to the output directory and
    appends one line ("name angle ssim iou dice") to the register log. Any
    failure is reported on stdout and the function returns None, so a batch
    run keeps going.

    :param img_name: sample name used to build both image paths
    :param normalize_depth_image_type: forwarded to match_rotation
        NOTE(review): match_rotation requires this argument and the old call
        omitted it; 'v2' is an assumed default - confirm the intended variant
    :return: (concatenated image, best rotated depth image, best print image,
        stacked image, rotate angle, ssim score), or None on failure
    """
    try:
        print_dir = '/content/drive/MyDrive/DATA_WORK/woodblock_labels/full_2d_pad/'
        depth_dir = '/content/drive/MyDrive/DATA_WORK/woodblock_labels/21072021_mocban_sokhop/full_depthmap/'
        print_image_path = print_dir + img_name + '.png'
        depth_image_path = depth_dir + img_name + '_r_depth.png'

        # cv2.imread returns None instead of raising when a file cannot be read.
        depth_image = cv2.imread(depth_image_path)
        if depth_image is None:
            raise FileNotFoundError('Cannot read depth image: ' + depth_image_path)
        print_image = cv2.imread(print_image_path)
        if print_image is None:
            raise FileNotFoundError('Cannot read print image: ' + print_image_path)
        print_image = cv2.flip(print_image, 1)

        # match_rotation returns a dict; read the fields by key.
        result = match_rotation(depth_image, print_image, normalize_depth_image_type)
        rotate_angle = result['rotate_angle']
        best_depth_rotated_img = result['best_depth_rotated_img']
        best_print_img = result['best_print_img']
        stacked_img = result['stacked_img']
        concated_img = concatenate_image(depth_image, print_image, rotate_angle, (256, 512))

        img_name = os.path.basename(print_image_path).replace('.png', '')
        iou = get_iou_metric(best_depth_rotated_img, best_print_img)
        # NOTE(review): m_ssim is not defined in the part of the file reviewed
        # here - confirm it exists (the other register functions use ssim).
        ssim_score = m_ssim(best_depth_rotated_img, best_print_img)
        dice = get_dice_metric(best_depth_rotated_img, best_print_img)

        save_path = '/content/drive/MyDrive/DATA_WORK/woodblock_labels/test17_iou/'
        cv2.imwrite(save_path + img_name + '_registed.jpg', concated_img)
        cv2.imwrite(save_path + img_name.replace('.png', '') + '_stacked.jpg', stacked_img)
        # Two images are written per sample, hence the division by two.
        # num_imgs is a module-level global set by the batch driver.
        files = os.listdir(save_path)
        print('Registering ', len(files) // 2,'/',num_imgs)
        with open(save_path+'register_log_lastest.txt', 'a') as f:
            f.write(img_name + ' ' +str(rotate_angle)+' '+str(round(ssim_score, 5))+' '+str(round(iou, 5))+' '+str(round(dice, 5))+'\n')
        return concated_img, best_depth_rotated_img, best_print_img, stacked_img, rotate_angle, ssim_score
    except Exception as e:
        # Best effort on purpose (batch processing), but report the cause so
        # failures can be diagnosed.
        print(f'Error at ID: {img_name} ({type(e).__name__}: {e})')
        return None

def get_iou_metric(bin_depth_img, bin_print_img, smooth=1e-7):
    """
    Intersection over union of the 255-valued pixels of two same-size images.

    iou = (|A and B| + smooth) / (|A or B| + smooth), where A and B are the
    sets of pixels equal to 255 in each image.

    :param bin_depth_img: best rotated cv2 binary depth image
    :param bin_print_img: binary print image
    :param smooth: added to numerator and denominator so an empty union does
        not divide by zero
    :return: iou metric
    """
    depth_on = bin_depth_img == 255
    print_on = bin_print_img == 255
    overlap = np.sum(np.logical_and(depth_on, print_on))
    combined = np.sum(np.logical_or(depth_on, print_on))
    return (overlap + smooth) / (combined + smooth)

def get_intersection(bin_depth_img, bin_print_img):
    """
    Count the pixels that are equal to 255 in both images.

    :param bin_depth_img: binary depth image
    :param bin_print_img: binary print image
    :return: number of positions where both images are 255
    """
    both_on = np.logical_and(bin_depth_img == 255, bin_print_img == 255)
    return np.sum(both_on)


def get_dice_metric(bin_depth_img, bin_print_img, smooth=1e-7):
    """
    Dice coefficient of the 255-valued pixels of two same-size images.

    dice = (2 * |A and B| + smooth) / (|A| + |B| + smooth), where A and B are
    the sets of pixels equal to 255 in each image.

    :param bin_depth_img: best rotated cv2 binary depth image
    :param bin_print_img: binary print image
    :param smooth: added to numerator and denominator so two empty masks do
        not divide by zero
    :return: dice metric in [0, 1]
    """
    depth_on = bin_depth_img == 255
    print_on = bin_print_img == 255
    intersection = np.sum(np.logical_and(depth_on, print_on))
    mask_sum = np.sum(depth_on) + np.sum(print_on)
    # The smoothing term is added once to each side. Doubling it in the
    # numerator only (2 * (I + smooth)) pushes the score above 1 and yields
    # 2.0 for two empty masks.
    dice_score = (2 * intersection + smooth) / (mask_sum + smooth)
    return dice_score
def wrap(bin_image):
    """
    Crop a binary image to the minimum-area rectangle around its contours.

    Contours are found with cv2.findContours, reduced by
    filter_much_less_cnts (which keeps contours above its minimum area and
    flattens them into one point list), and enclosed with cv2.minAreaRect.
    The image is rotated by the rectangle's angle around the rectangle's
    center and the rectangle is then cut out with cv2.getRectSubPix.

    :param bin_image: single-channel binary image accepted by cv2.findContours
    :return: cropped image
    """
    found_contours, _ = cv2.findContours(bin_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    points = np.array(filter_much_less_cnts(found_contours))

    box_center, box_size, box_angle = cv2.minAreaRect(points)
    # The rectangle comes back with float coordinates; truncate to integers.
    box_center = tuple(int(value) for value in box_center)
    box_size = tuple(int(value) for value in box_size)

    # Rotate the whole image so the rectangle becomes axis-aligned, keeping
    # the original canvas size.
    rows, cols = bin_image.shape[0], bin_image.shape[1]
    rotation = cv2.getRotationMatrix2D(box_center, box_angle, 1)
    straightened = cv2.warpAffine(bin_image, rotation, (cols, rows))

    return cv2.getRectSubPix(straightened, box_size, box_center)

def sim_score(bin_print_image, bin_depth_image):
    """
    Best SSIM between the cropped print image and the cropped depth image
    over four successive 90-degree clockwise rotations of the depth crop.

    Both inputs are passed through remove_noise and wrap first. A rotation is
    only accepted when its estimate_dense_character_ratio differs from the
    print crop's by less than 0.05.

    :param bin_print_image: binary print image
    :param bin_depth_image: binary depth image
    :return: (score, best_square2, score) - score is -1 when no rotation was
        accepted, in which case best_square2 is the last rotation tried
    """
    square1 = wrap(remove_noise(bin_print_image))
    square2 = wrap(remove_noise(bin_depth_image))

    # Loop-invariant values, computed once.
    target_size = (square1.shape[1], square1.shape[0])  # cv2.resize wants (width, height)
    print_density = estimate_dense_character_ratio(square1)

    score = -1
    best_square2 = None
    for _ in range(4):
        # cv2.ROTATE_90_CLOCKWISE is the public constant; the cv2.cv2 alias is
        # not available in every OpenCV build.
        square2 = cv2.rotate(square2, cv2.ROTATE_90_CLOCKWISE)
        square2 = cv2.resize(square2, target_size)

        ssim_score = ssim(square1, square2)
        if score < ssim_score and abs(print_density - estimate_dense_character_ratio(square2)) < 0.05:
            score = ssim_score
            # Remember the image only when it wins; assigning it on every
            # iteration would always return the last rotation.
            best_square2 = square2

    if best_square2 is None:
        # No rotation was accepted: keep returning the last candidate, as before.
        best_square2 = square2

    return score, best_square2, score
def _ssim(img1, img2):
    """
    Mean SSIM of two same-shape images with values in [0, 255].

    Local statistics are computed with an 11x11 Gaussian window (sigma 1.5);
    five pixels are cropped from every border so only positions where the
    window fits entirely inside the image contribute.

    :param img1: first image (converted to float64)
    :param img2: second image (converted to float64)
    :return: mean of the SSIM map
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    # The window must be 11 wide to match the 5-pixel "valid" crop below.
    # A size-1 kernel is the identity filter: every variance/covariance term
    # becomes zero and only the luminance term of SSIM is left.
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def local_mean(values):
        # Gaussian-weighted local mean, restricted to the valid region.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = local_mean(img1**2) - mu1_sq
    sigma2_sq = local_mean(img2**2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))

    return ssim_map.mean()


def calculate_ssim(img1, img2):
    '''Calculate SSIM with _ssim for 2-D, 1-channel or 3-channel images.

    img1, img2: same shape, values in [0, 255]. 3-channel images are scored
    per channel and the channel scores are averaged; 1-channel images are
    scored as 2-D images.
    NOTE(review): the original comment claimed "the same outputs as
    MATLAB's"; that depends on the window used inside _ssim - confirm.

    Raises ValueError if the shapes differ or the image is not 2-D or
    3-D with 1 or 3 channels.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return _ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 3:
            # Score each channel on its own slice instead of re-scoring the
            # whole image three times.
            ssims = [_ssim(img1[..., i], img2[..., i]) for i in range(3)]
            return np.array(ssims).mean()
        if channels == 1:
            # Same metric as the 2-D branch, so (H, W) and (H, W, 1) inputs
            # with the same data get the same score.
            return _ssim(np.squeeze(img1, axis=2), np.squeeze(img2, axis=2))
    # Also reached for 3-D inputs with an unsupported channel count, which
    # used to fall through and return None.
    raise ValueError('Wrong input image dimensions.')


In [None]:
# Demo cell: register a single sample with register_v1, print the score,
# the angle and the elapsed time, then show the four returned images.
import time
# Sample to register; the commented values below are other sample ids /
# output names that were tried.
id = '245392010117'
#291411090108
#291412050102
#291411110105
#084652090115
#084652060112_stacked
# 245391030104_stacked
# 245392130114_stacked
# 291411090105_stacked dd
# 291411110109_stacked
# 291412100201_stacked
#084652060107_stacked
#245392070105
# Wall-clock timing of one full registration.
start = time.time()
concated_img, best_depth_rotated_img, best_print_img, stacked_img, rotate_angle, ssim_score = register_v1(id)
print(ssim_score, rotate_angle)
print(time.time()-start)
# plt.imshow(normalize_depth_image(cv2.imread(depth_image_path)))
# plt.imshow(normalize_print_image(cv2.flip(cv2.imread(print_image_path), 1))[1])
# One row of four panels: concatenated pair, colour overlay, rotated depth
# image, print image.
plt.figure(figsize=(12, 10))
plt.subplot(1, 4, 1), plt.imshow(concated_img, cmap='gray'), plt.title("result")
plt.subplot(1, 4, 2), plt.imshow(stacked_img, cmap='gray'), plt.title("stack")
plt.subplot(1, 4, 3), plt.imshow(best_depth_rotated_img, cmap='gray'), plt.title("best_depth_rotated_img")
plt.subplot(1, 4, 4), plt.imshow(best_print_img, cmap='gray'), plt.title("best_print_img")
plt.show()

In [None]:
import tqdm
import functools
import sys
import os
import concurrent.futures 
# Read .txt file by lines and put to an array
def get_list_from_txt(fileName):
    """
    Read a text file and return its lines as a list of strings.

    :param fileName: path of the text file
    :return: list with one entry per line, line endings removed
    """
    # The context manager closes the file even if reading raises.
    with open(fileName, "r") as fileObj:
        return fileObj.read().splitlines()

# Batch driver: register every id listed in the text file, in parallel
# worker processes.
dirpaths = get_list_from_txt('/content/beauty_ids.txt')
num_imgs = len(dirpaths)

list_of_valid_id = [str(image) for image in dirpaths]
num_of_thread_at_once = 20
with concurrent.futures.ProcessPoolExecutor(max_workers=num_of_thread_at_once) as executor:
    # One task per id. The futures are kept so that failures can be reported;
    # an executor.map() whose result is never iterated silently discards
    # every exception raised in a worker.
    pending = {executor.submit(register_v1, img_name): img_name for img_name in list_of_valid_id}
    # The progress bar advances as registrations finish, not while the id
    # list is being built.
    for future in tqdm.tqdm(concurrent.futures.as_completed(pending), total=len(pending)):
        # Drop our reference so the returned images can be freed.
        img_name = pending.pop(future)
        try:
            future.result()
        except Exception as exc:
            # Keep the batch going, but make the failure visible.
            print(f'Error at ID: {img_name} ({type(exc).__name__}: {exc})')