In [1]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline 

In [2]:
import os

charts_path = "Data/chart/seen_continuous/"
output_path = "Output/seen/"

list_chart = [i for i in os.listdir(charts_path) if i not in ['no_legend', '.DS_Store', 'useless']]

In [3]:
def find_most_frequent_color(img):
    data = np.reshape(img, (-1,3))
    dict_count = {}

    for i in range(data.shape[0]):
        if str(data[i,:].tolist()) not in dict_count.keys():
            dict_count[str(data[i,:].tolist())] = 1
        else:
            dict_count[str(data[i,:].tolist())] += 1
            
    color_str = sorted(dict_count.items(), key=lambda x:x[1], reverse=1)[0][0]
    return np.fromstring(color_str[1:-1], dtype=int, sep=',')



In [4]:
def mask_img(img, mfc, T_bg=5):
    
    mask = np.ones((img.shape[0], img.shape[1]), dtype='uint8')

    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
        # mask for background
            if np.max(abs(img[i,j]-mfc)) <= T_bg:
                img[i,j] = [255,128,128]
                mask[i,j] = 0    
    
    return img, mask

In [5]:
def binarize(img):
    img_gray = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_LAB2BGR), cv2.COLOR_BGR2GRAY).reshape(-1,)
    mask_above_127 = (img_gray > 127)
    mask_below_127 = (img_gray <= 127)
    img_gray[mask_above_127] = 0
    img_gray[mask_below_127] = 1
    
    return img_gray.reshape((img.shape[0], img.shape[1]))

In [6]:
# we may not
def flood_fill(img, seed_point=(20, 100)):
    return cv2.floodFill(img, None, seed_point, 1)[1]

In [7]:
def erode(img, kernel_size=5, iteration=1):
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    img_eroded = cv2.erode(img, kernel, iterations = iteration)
    return img_eroded

In [8]:
def imshow_components(labels):
    # Map component labels to hue val
    label_hue = np.uint8(179*labels/np.max(labels))
    blank_ch = 255*np.ones_like(label_hue)
    labeled_img = cv2.merge([label_hue, blank_ch, blank_ch])

    # cvt to BGR for display
    labeled_img = cv2.cvtColor(labeled_img, cv2.COLOR_LAB2BGR)

    # set bg label to black
    labeled_img[label_hue==0] = 0

    plt.imshow(labeled_img, cmap='gray', vmin=0, vmax=1)
    plt.axis('off')
    plt.show()
#     cv2.waitKey()

# imshow_components(labels_im)

In [9]:
def find_largest_connected_component(img):
    num_labels, labels = cv2.connectedComponents(img)
    dict_area = {}
    for i in range(num_labels):
        dict_area[i] = area_i = np.sum(labels == i)
                  
    return sorted(dict_area.items(), key=lambda x:x[1], reverse=1)[1][0], labels

In [10]:
def get_colors(img, idx, labels):
    positions = np.argwhere(labels==idx)
    position_lefttop = positions[0,:]
    position_rightbottom = positions[-1,:]
    
    # width <= height
    if (position_lefttop[0]-position_rightbottom[0]) <= (position_lefttop[1]-position_rightbottom[1]):
        # most_top_color, most_bottom_color
        return [img[position_lefttop[0], int((position_lefttop[1]+position_rightbottom[1])/2)], \
                img[position_rightbottom[0], int((position_lefttop[1]+position_rightbottom[1])/2)]]
    else:
        return [img[int((position_lefttop[0]+position_rightbottom[0])/2), position_lefttop[1]], \
                img[int((position_lefttop[1]+position_rightbottom[1])/2), position_lefttop[0]]]

In [11]:
def plot_color(colors):
    color_matrices = []
    for i in colors:
        color_matrices.append(np.full((20, 20, 3), i.tolist(), dtype='uint8'))
        
    img = np.concatenate(color_matrices, axis=1)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_LAB2RGB))
    plt.axis('off')
    plt.show()   
    
#     return img

In [12]:
def get_area_plot_and_save(img, idx, labels, output_path, chart_name):
    positions = np.argwhere(labels==idx)
    position_lefttop = positions[0,:]
    position_rightbottom = positions[-1,:]
    
    area = img[position_lefttop[0]:position_rightbottom[0], position_lefttop[1]:position_rightbottom[1], :]
    
    try:
        cv2.imwrite(output_path+chart_name,cv2.cvtColor(area, cv2.COLOR_LAB2BGR))
    except:
        return
    
#     plt.imshow(cv2.cvtColor(area, cv2.COLOR_LAB2RGB))
#     plt.axis('off')
#     plt.show() 
    
#     return cv2.imwrite(output_path+chart_name,cv2.cvtColor(area, cv2.COLOR_LAB2BGR))

In [13]:
def get_result(charts_path, output_path, chart_name):
    
    
#     best_seed_point = (0, 0)
#     largest_area = 0
    
#     for i in range(0, img_lab.shape[1], 10):
#         for j in range(0, img_lab.shape[0], 10):
#     sp = (i, j)

    previous_pos_lt = []
    previous_pos_rb = []
    
    dict_pos_lt = {}
    dict_pos_rb = {}

    for k in range(10, 100, 10):
        for p in range(10, 100, 10):
            try:
#                 print((int(img_lab.shape[1]*(k/100.0)), int(img_lab.shape[0]*(p/100.0))))
                img_bgr = cv2.imread(charts_path+chart_name, cv2.IMREAD_COLOR) 
                img_lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
                mfc = find_most_frequent_color(img_lab)
                masked_img_lab,mask = mask_img(img_lab, mfc, T_bg=0)
                img_binary = binarize(masked_img_lab)
                img_filled = flood_fill(img_binary, seed_point=(int(img_lab.shape[1]*(k/100.0)), int(img_lab.shape[0]*(p/100.0))))
                img_eroded = erode(img_filled, kernel_size=5)
#                 num_labels, labels_im = cv2.connectedComponents(img_eroded)
                idx, labels = find_largest_connected_component(img_eroded)
                n = chart_name.split(".")[0]+f"_{k}_{p}.png"
                positions = np.argwhere(labels==idx)
                position_lefttop = positions[0,:]
                position_rightbottom = positions[-1,:]
                
#                 print(position_lefttop, position_rightbottom)

                if not list(position_lefttop) in previous_pos_lt or\
                not list(position_rightbottom) in previous_pos_rb:
                    previous_pos_lt.append(list(position_lefttop))
                    previous_pos_rb.append(list(position_rightbottom))
                    get_area_plot_and_save(img_lab, idx, labels, output_path, n)
                    dict_pos_lt[f"{k}_{p}"] = list(position_lefttop)
                    dict_pos_rb[f"{k}_{p}"] = list(position_rightbottom)
            except:
                print(chart_name, k, p)
                
    return {
        "lt": dict_pos_lt,
        "rb": dict_pos_rb
    }
                

#     for k in range(2,5):
#         for p in range(2,5):
#             try:
# #                 print((int(img_lab.shape[1]/k), int(img_lab.shape[0]/p)))
#                 img_filled = flood_fill(img_binary, seed_point=(int(img_lab.shape[1]/k), int(img_lab.shape[0]/p)))
#                 img_eroded = erode(img_filled, kernel_size=5)
# #                 num_labels, labels_im = cv2.connectedComponents(img_eroded)
#                 idx, labels = find_largest_connected_component(img_eroded)
#                 n = chart_name.split(".")[0]+f"_{k}_{p}.png"
#                 positions = np.argwhere(labels==idx)
#                 position_lefttop = positions[0,:]
#                 position_rightbottom = positions[-1,:]
#                 print(position_lefttop, position_rightbottom)
    
#                 if not np.array_equal(position_lefttop, previous_pos_lt) or\
#                 not np.array_equal(position_rightbottom, previous_pos_rb):
#                     previous_pos_lt = position_lefttop
#                     previous_pos_rb = position_rightbottom
                
#                 get_area_plot_and_save(img_lab, idx, labels, output_path, n)
#             except:
#                 print(chart_name, k, p)



#     positions = np.argwhere(labels==idx)
#     position_lefttop = positions[0,:]
#     position_rightbottom = positions[-1,:]

#     area = abs((position_lefttop[0]-position_rightbottom[0])*(position_lefttop[1]-position_rightbottom[1]))

#     if area > largest_area:
#         n = chart_name.split(".")[0]+f"({i} {j}).png"
#         largest_area = area
#         best_seed_point = sp
#         print(n)
#         get_area_plot_and_save(img_lab, idx, labels, output_path, n)

                

In [14]:
from tqdm import tqdm

In [15]:
dict_point = {}

for i in tqdm(list_chart):
    dict_point[i.split(".")[0]] = get_result(charts_path, output_path, i)

 80%|████████  | 16/20 [06:30<02:34, 38.71s/it]

C2.png 10 10
C2.png 10 20
C2.png 10 30
C2.png 10 40
C2.png 20 10
C2.png 20 20
C2.png 20 30
C2.png 20 40
C2.png 30 10
C2.png 30 20
C2.png 30 30
C2.png 30 40
C2.png 40 10
C2.png 40 20
C2.png 40 30
C2.png 40 40
C2.png 50 10
C2.png 50 20
C2.png 50 30
C2.png 50 40
C2.png 60 10
C2.png 60 20
C2.png 60 30
C2.png 60 40
C2.png 70 10
C2.png 70 20
C2.png 70 30
C2.png 70 40
C2.png 70 50
C2.png 70 60
C2.png 70 70
C2.png 70 80
C2.png 70 90
C2.png 80 10
C2.png 80 20
C2.png 80 30
C2.png 80 40
C2.png 80 50
C2.png 80 60
C2.png 80 70
C2.png 80 80
C2.png 80 90
C2.png 90 10
C2.png 90 20
C2.png 90 30
C2.png 90 50
C2.png 90 70


 85%|████████▌ | 17/20 [06:43<01:32, 30.97s/it]

C2.png 90 90


100%|██████████| 20/20 [07:56<00:00, 23.84s/it]
