In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from white_brush.io import read_image
from white_brush.colors.color_balance import balance_color
from white_brush.colors.color_extraction import extract_background_colors, hsv_distance_threshold
from white_brush.colors.utils import rgb_to_hsv, hsv_to_rgb

In [None]:
def show_img(img, cmap="Greys_r", figsize=(20, 10)):
    """Display an image on the current matplotlib figure with the axes hidden.

    Parameters
    ----------
    img : numpy.ndarray
        A 2-D array (rendered with ``cmap``) or a 3-D ``(H, W, C)`` color
        array (rendered as-is, ``cmap`` ignored). A ``(H, W, 1)`` array is
        treated as 2-D so that ``cmap`` is honoured instead of being skipped
        by the color branch.
    cmap : str, optional
        Colormap used for 2-D input. Defaults to ``"Greys_r"``.
    figsize : tuple of float, optional
        ``(width, height)`` in inches applied to the current figure.
        Defaults to ``(20, 10)``.
    """
    # A trailing singleton channel would otherwise fall into the color branch
    # below and lose the requested colormap.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim >= 3:
        plt.imshow(img)
    else:
        plt.imshow(img, cmap=cmap)
    plt.axis("off")
    plt.gcf().set_size_inches(*figsize)

In [None]:
# Input photo to clean up. Later cells need an (H, W, 3) array: it is
# boolean-indexed with a 2-D mask into per-pixel color rows, channels 0..2 are
# plotted, and background pixels are assigned a 3-element color. Fail fast here
# rather than deep inside the pipeline.
IMAGE_PATH = "../test_images/01.jpg"
orig_img = read_image(IMAGE_PATH)
# NOTE(review): guards against a loader that returns None on a missing file
# (as cv2.imread does) -- confirm read_image's failure behavior.
assert orig_img is not None, f"could not read image at {IMAGE_PATH}"
assert orig_img.ndim == 3 and orig_img.shape[-1] == 3, f"expected an (H, W, 3) image, got shape {orig_img.shape}"

In [None]:
show_img(orig_img, cmap="Greys_r")  # raw input, before balance_color is applied

In [None]:
# Color-balance the raw photo with white_brush; every thresholding step below works on `img`.
img = balance_color(orig_img)
show_img(img, cmap="Greys_r")

### HSV Thresholding

In [None]:
# Estimate the background color(s) of the balanced image, then build a mask from
# each pixel's HSV distance to them. True pixels are treated as background further
# down (the combine step negates the eroded mask to obtain foreground).
bg_colors = extract_background_colors(img, thresh=0.2)
hsv_thresh = hsv_distance_threshold(img, bg_colors)
show_img(hsv_thresh, cmap="Greys_r")

#### Erode HSV result

In [None]:
# Erode the HSV background mask once with a 3x3 square kernel. Shrinking the True
# (background) region slightly widens the foreground strokes.
hsv_thresh_img = hsv_thresh.astype(np.uint8) * 255
hsv_eroded = cv2.erode(
    hsv_thresh_img,
    np.ones((3, 3), dtype=np.uint8),
    iterations=1,
)
hsv_eroded_mask = np.equal(hsv_eroded, 255)
show_img(hsv_eroded_mask, cmap="Greys_r")

### Adaptive Thresholding

In [None]:
# Local-mean adaptive threshold on the grayscale image. With THRESH_BINARY a pixel
# becomes 255 when it is brighter than the mean of its 9x9 neighbourhood minus C=2.
# True therefore marks background; the combine step negates it to get foreground.
adaptive_thresh = cv2.adaptiveThreshold(
    cv2.cvtColor(img, cv2.COLOR_RGB2GRAY),
    255,
    cv2.ADAPTIVE_THRESH_MEAN_C,
    cv2.THRESH_BINARY,
    9,
    2,
) == 255
show_img(adaptive_thresh, cmap="Greys_r")

### Combine HSV distance and adaptive thresholds

A pixel counts as foreground only if both methods agree that it is not background.

In [None]:
# Per-method foreground masks (both inputs are boolean background masks).
hsv_fg = np.logical_not(hsv_eroded_mask)
adaptive_fg = np.logical_not(adaptive_thresh)
# De Morgan: ~(hsv_fg & adaptive_fg) == hsv_eroded_mask | adaptive_thresh, i.e. a
# pixel is background if EITHER method says so, and foreground only if both agree.
combined_thresh = np.logical_or(hsv_eroded_mask, adaptive_thresh)
show_img(combined_thresh, cmap="Greys_r")

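## Recolor the foreground

Quantize the foreground colors, reduce them to a few representative colors, and repaint each foreground pixel with its assigned representative on a white background.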

In [None]:
foreground_mask = np.logical_not(combined_thresh)  # everything not marked as background

In [None]:
# Gather the ORIGINAL (not color-balanced) pixel values under the foreground mask.
# Boolean indexing returns a copy, so the in-place masking below leaves orig_img
# untouched. Keeping only the top 4 bits of each channel coarsens the palette
# (16 levels per channel, assuming 8-bit data) before representative colors are chosen.
colors = orig_img[foreground_mask]
np.bitwise_and(colors, 0b11110000, out=colors)

In [None]:
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# 3-D scatter of the quantized foreground colors in RGB space, each point drawn
# in its own color.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
# `np.float` was removed in NumPy 1.24 -- use the explicit float64 dtype. The /255
# maps to matplotlib's [0, 1] color range (assumes 8-bit channel values).
ax.scatter(colors[:, 0], colors[:, 1], colors[:, 2], c=colors.astype(np.float64) / 255)
ax.set(xlabel="R", ylabel="G", zlabel="B", title="Quantized foreground colors");


In [None]:
from white_brush.colors.calc_colors import choose_representative_colors

In [None]:
# Reduce the quantized foreground colors to 8 representatives. `color_mapping`
# is used later to index `rep_colors` once per foreground pixel.
rep_colors, color_mapping = choose_representative_colors(colors, n=8)
# Rescale to [0, 1] for matplotlib and the float output image (assumes the
# representatives are on the same 0-255 scale as `colors`). Done out of place:
# in-place `/=` raises on an integer array and would mutate the helper's result.
rep_colors = rep_colors / 255
assert len(color_mapping) == len(colors), "expected one representative index per foreground pixel"

In [None]:
# 3-D scatter of the representative colors in RGB space (already scaled to
# [0, 1]), each point drawn in its own color.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(rep_colors[:, 0], rep_colors[:, 1], rep_colors[:, 2], c=rep_colors)
ax.set(xlabel="R", ylabel="G", zlabel="B", title="Representative colors");


In [None]:
# Compose the cleaned image as floats in [0, 1]. `astype` already returns a new
# array, so no separate copy is needed; `np.float` was removed in NumPy 1.24.
out_img = orig_img.astype(np.float64) / 255
out_img[combined_thresh] = [1.0, 1.0, 1.0]  # set background white
# Repaint each foreground pixel with its assigned representative color.
out_img[foreground_mask, :] = rep_colors[color_mapping]
show_img(out_img)