### Setup

* pip3 install opencv-python
* pip3 install git+https://github.com/whitews/cv-color-features
* pip3 install git+https://github.com/whitews/cv2-extras

In [None]:
from glob import glob
import os
import cv2
import cv_color_features.utils as color_utils
import cv2_extras as cv2x
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Target hue for blue; the reference-selection cell scores each image by how
# close its mean blue hue is to this value.
true_blue = 120

# Directory holding the source .tif images (note the trailing slash).
img_dir = '/home/swhite/git/ihc-image-data/data/mm_e16.5_20x_sox9_sftpc_acta2/'

# Sorted so that every later cell visits the images in the same order and
# per-image result lists can be paired with `images` by index.
images = sorted(glob(os.path.join(img_dir, '*.tif')))

### First, perform lightness non-uniformity correction 

In [None]:
# Lightness-corrected HSV images, one per entry of `images` and in the same
# order, so later cells can pair them with `images` by index.
light_corr_imgs_hsv = []

for img in images:
    img_bgr = cv2.imread(img)
    if img_bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising. Fail here with the offending path rather than
        # with an opaque error inside cvtColor. Raising (instead of skipping
        # the file) keeps light_corr_imgs_hsv index-aligned with `images`.
        raise IOError('Unable to read image: {}'.format(img))

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    # third HSV channel (V), the one being corrected
    img_v = img_hsv[:, :, 2]

    # get the blue mask
    b_mask = color_utils.create_mask(img_hsv, colors=['blue'])

    # this part is not required to do the correction, as it's
    # included in the correction function. But, we want to
    # plot the non-uniformity field for QC
    img_v_b_filtered = cv2.bitwise_and(img_v, img_v, mask=b_mask)
    non_uni_field = cv2x.calculate_nonuniform_field(img_v_b_filtered)

    # correct the V channel and put it back into a copy of the HSV image,
    # leaving the original img_hsv untouched
    img_v_corr = cv2x.correct_nonuniformity(img_v, mask=b_mask)
    img_hsv_corr = img_hsv.copy()
    img_hsv_corr[:, :, 2] = img_v_corr

    # repair black regions: wherever the original image is masked as black,
    # restore the original (uncorrected) V value
    black_mask = color_utils.create_mask(img_hsv, colors=['black'])
    img_hsv_corr[black_mask > 0, 2] = img_hsv[black_mask > 0, 2]

    light_corr_imgs_hsv.append(img_hsv_corr)

    # QC figure, left to right:
    #   1. V channel restricted to the blue mask
    #   2. non-uniformity field, auto-scaled gray levels
    #   3. non-uniformity field on a fixed 0-255 gray scale
    #   4. original image
    #   5. lightness-corrected image
    fig = plt.figure(figsize=(15, 3))
    plt.subplot(151)
    plt.imshow(img_v_b_filtered, cmap='gray')
    plt.subplot(152)
    plt.imshow(non_uni_field, cmap='gray')
    plt.subplot(153)
    plt.imshow(non_uni_field, cmap='gray', vmin=0, vmax=255)
    plt.subplot(154)
    plt.imshow(img_rgb)
    plt.subplot(155)
    plt.imshow(cv2.cvtColor(img_hsv_corr, cv2.COLOR_HSV2RGB))
    plt.show()

In [None]:
# Larger side-by-side view for each image: the original as read from disk
# (left) next to its lightness-corrected version (right).
for idx, img_path in enumerate(images):
    original_rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    corrected_rgb = cv2.cvtColor(light_corr_imgs_hsv[idx], cv2.COLOR_HSV2RGB)

    fig = plt.figure(figsize=(16, 8))
    plt.subplot(1, 2, 1)
    plt.imshow(original_rgb)
    plt.subplot(1, 2, 2)
    plt.imshow(corrected_rgb)
    plt.show()

### Now, select the reference image for color correction

In [None]:
# Per-image statistics of the hue channel inside the blue mask, index-aligned
# with `images` / `light_corr_imgs_hsv`:
#   b_h_value_counts[i] -> number of pixels selected by the blue mask
#   b_h_means[i]        -> mean hue of those pixels (NaN when there are none)
b_h_value_counts = []
b_h_means = []

for i, img_hsv in enumerate(light_corr_imgs_hsv):
    # get the blue mask
    b_mask = color_utils.create_mask(img_hsv, colors=['blue'])

    # Pick the hue values straight through the mask. Testing the masked hue
    # image for "> 0" instead would use a non-zero hue as a stand-in for mask
    # membership and silently drop any masked pixel whose hue is 0.
    b_h_values = img_hsv[:, :, 0][b_mask > 0]

    b_h_value_counts.append(b_h_values.shape[0])
    # np.mean of an empty selection is NaN plus a RuntimeWarning; record the
    # NaN explicitly for images without any blue pixels.
    b_h_means.append(np.mean(b_h_values) if b_h_values.size else np.nan)

    # hue histogram of the blue pixels, titled with the image file name
    fig = plt.figure(figsize=(16, 1))
    plt.title(os.path.basename(images[i]))
    plt.hist(b_h_values, bins=39, align='left')
    # show each histogram as it is produced, like the other plotting cells
    plt.show()

In [None]:
# Select the reference image for color correction.
#
# Each image gets a score that is the mean of two components in [0, 1]:
#   * its blue pixel count relative to the largest blue pixel count, and
#   * 1 - (|mean blue hue - true_blue| / largest such deviation),
# so images with many blue pixels whose mean hue is close to `true_blue`
# score highest. The first image with the highest score wins.
b_h_means_arr = np.asarray(b_h_means, dtype=float)
b_h_counts_arr = np.asarray(b_h_value_counts, dtype=float)

if b_h_means_arr.size == 0:
    raise ValueError('No images available to select a reference image from')

b_h_mean_dev = np.abs(b_h_means_arr - true_blue)

# An image without blue pixels has a NaN mean hue and a zero count; it cannot
# serve as the reference and must not poison the max() / scores below.
usable = np.isfinite(b_h_mean_dev) & (b_h_counts_arr > 0)
if not usable.any():
    raise ValueError('No image contains blue pixels; cannot select a reference image')

max_b_dev = b_h_mean_dev[usable].max()
upper_count = b_h_counts_arr[usable].max()

if max_b_dev > 0:
    b_center_devs = 1 - (b_h_mean_dev / max_b_dev)
else:
    # Every usable image sits exactly on the target hue: avoid 0 / 0 and let
    # the pixel count alone decide.
    b_center_devs = np.ones_like(b_h_mean_dev)

scores = (b_h_counts_arr / upper_count + b_center_devs) / 2
# exclude unusable images from the selection
scores[~usable] = -np.inf

# argmax returns the first maximum, matching a strict ">" comparison loop
best_idx = int(np.argmax(scores))
max_score = scores[best_idx]

print(os.path.basename(images[best_idx]), max_score)

In [None]:
# Use the best-scoring lightness-corrected image as the color reference.
ref_img_hsv = light_corr_imgs_hsv[best_idx]
# BGR version of the reference; the color transfer cell works on BGR images.
ref_img_bgr = cv2.cvtColor(ref_img_hsv, cv2.COLOR_HSV2BGR)

In [None]:
# Fully corrected RGB images, index-aligned with `light_corr_imgs_hsv`.
final_corr_images_rgb = []

# The RGB view of the reference image does not change between iterations, so
# convert it once instead of once per image.
ref_img_rgb = cv2.cvtColor(ref_img_bgr, cv2.COLOR_BGR2RGB)

for i, img_hsv in enumerate(light_corr_imgs_hsv):
    if i == best_idx:
        # the reference image is kept as is (no transfer, no figure)
        final_corr_images_rgb.append(ref_img_rgb)
        continue

    # color transfer is done in BGR: reference image first, then the image
    # being corrected
    tar_img_bgr = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)
    cor_img_bgr = cv2x.color_transfer(ref_img_bgr, tar_img_bgr, clip=True, preserve_paper=True)
    cor_img_rgb = cv2.cvtColor(cor_img_bgr, cv2.COLOR_BGR2RGB)
    final_corr_images_rgb.append(cor_img_rgb)

    # QC figure, left to right: reference image, lightness-corrected input,
    # color-corrected result
    fig = plt.figure(figsize=(15, 5))

    plt.subplot(131)
    plt.imshow(ref_img_rgb)
    plt.subplot(132)
    plt.imshow(cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB))
    plt.subplot(133)
    plt.imshow(cor_img_rgb)
    plt.show()

### Comparison of originals vs final corrections

In [None]:
# Final QC: each original image as read from disk (left) next to its fully
# corrected version (right).
for idx, corrected_rgb in enumerate(final_corr_images_rgb):
    original_rgb = cv2.cvtColor(cv2.imread(images[idx]), cv2.COLOR_BGR2RGB)

    fig = plt.figure(figsize=(16, 8))

    plt.subplot(1, 2, 1)
    plt.imshow(original_rgb)
    plt.subplot(1, 2, 2)
    plt.imshow(corrected_rgb)
    plt.show()