In [46]:
import sys
!{sys.executable} -m pip install opencv-python



In [47]:
import cv2
import numpy as np
import math 

# Read the input image. cv2.imread does not raise on a missing or unreadable
# file -- it silently returns None -- so fail loudly here instead of letting
# the first cv2.imshow call die with an opaque assertion error.
IMAGE_PATH = "./data/sky.jpg"
image = cv2.imread(IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")

cv2.imshow("Original", image)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [48]:
def get_swatch(img):
    """Let the user interactively select rectangular swatches on ``img``.

    Opens an OpenCV window named "Swatch". Each left click records a point:
    an unpaired click is marked with a red dot, and every second click closes
    a green rectangle between the last two points. Pressing ``z`` undoes the
    most recent click (and never ends the selection); pressing any other key
    finishes the selection and closes the window.

    Args:
        img: image array to display. It is copied for drawing; the original
            is never modified.

    Returns:
        A list of ``(x, y, w, h)`` tuples, one per completed rectangle, where
        ``(x, y)`` is the first click and ``w``/``h`` are the signed offsets
        to the second click (they may be negative). A trailing unpaired click
        is ignored. Returns ``None`` if no rectangle was completed.
    """
    points = []
    canvas = img.copy()

    def draw_selection():
        # Rebuild the preview from a clean copy so undo and click share one path.
        nonlocal canvas
        canvas = img.copy()
        for start, end in zip(points[::2], points[1::2]):
            cv2.rectangle(canvas, start, end, (0, 255, 0), 1)
        if len(points) % 2 == 1:
            cv2.circle(canvas, points[-1], 1, (0, 0, 255), 1)

    def on_mouse(event, x, y, flags, param):
        if event == cv2.EVENT_LBUTTONDOWN:
            points.append((x, y))
            draw_selection()

    cv2.namedWindow("Swatch")
    cv2.setMouseCallback("Swatch", on_mouse)
    while True:
        cv2.imshow("Swatch", canvas)
        key = cv2.waitKey(1)
        if key == -1:
            continue  # no key pressed during this frame
        # Some backends set high bits in the key code; keep only the low byte.
        if (key & 0xFF) == ord("z"):
            # Undo must not also end the selection (it previously fell
            # through to the exit check and closed the window).
            if points:
                points.pop()
                draw_selection()
            continue
        break  # any other key confirms the selection
    cv2.destroyWindow("Swatch")

    rects = [
        (start[0], start[1], end[0] - start[0], end[1] - start[1])
        for start, end in zip(points[::2], points[1::2])
    ]
    return rects or None


def get_swatch_img(img):
    """Ask the user for a swatch on ``img`` and return the cropped region.

    Only the first rectangle returned by ``get_swatch`` is used. Its corners
    are normalized so the crop works regardless of which corner was clicked
    first. The end coordinates are exclusive, as in normal slicing.

    Args:
        img: image array (2-D or 3-D) to select from.

    Returns:
        A view of ``img`` covering the selected rectangle. Returns ``None`` if
        the user selected nothing or if the rectangle has zero area, because
        an empty crop cannot be analysed downstream.
    """
    rects = get_swatch(img)
    if rects is None:
        return None
    x, y, w, h = rects[0]
    left, right = min(x, x + w), max(x, x + w)
    top, bottom = min(y, y + h), max(y, y + h)
    swatch = img[top:bottom, left:right]
    if swatch.size == 0:
        return None
    return swatch

In [49]:
# NOTE(review): `%pip install` targets the running kernel's environment;
# `!pip` may install into a different interpreter.
!pip install colorthief
# NOTE(review): pickletools.uint8 is a pickle-opcode argument descriptor, not
# numpy's uint8, and no visible code uses the bare name `uint8` (the code uses
# np.uint8). This looks like an accidental IDE auto-import -- confirm and drop.
from pickletools import uint8
import time
from colorthief import ColorThief

def dominant_color(img):
    """Return the dominant colour of a BGR swatch as a ``[b, g, r]`` list.

    The swatch is PNG-encoded in memory and handed to ColorThief with
    ``quality=1`` (its most thorough sampling). Encoding in memory avoids a
    scratch file on disk. It also avoids the JPEG artifacts that would shift
    the measured colour, and needs no sleep: encoding is synchronous.

    NOTE(review): ColorThief is reported to skip near-white pixels, so a
    swatch that is entirely white may fail to quantize -- verify with
    cloud swatches.

    Args:
        img: BGR image array as produced by OpenCV, or ``None``.

    Returns:
        ``[b, g, r]`` integers, or ``None`` when ``img`` is ``None`` or empty.

    Raises:
        ValueError: if OpenCV cannot encode the image.
    """
    if img is None or img.size == 0:
        return None
    import io

    ok, encoded = cv2.imencode(".png", img)
    if not ok:
        raise ValueError("could not encode swatch image")
    color_thief = ColorThief(io.BytesIO(encoded.tobytes()))
    # ColorThief reports RGB; the rest of the pipeline works in BGR.
    red, green, blue = color_thief.get_color(quality=1)
    return [blue, green, red]

def get_binnary_mask(channel, x, y, c_min, c_max):
    """Return an int 0/1 mask marking where ``channel`` lies in ``[c_min, c_max]``.

    ``channel`` is first passed through ``np.resize`` to shape ``(x, y)``.
    When the shape already matches, this is just a copy. A different shape
    would repeat or truncate the data rather than raise.
    """
    reshaped = np.resize(channel, (x, y))
    in_range = np.logical_and(reshaped >= c_min, reshaped <= c_max)
    return in_range.astype(int)

def sky_mask(blue, blur, dev=30):
    """Mark pixels of ``blur`` whose colour is close to a reference colour.

    A pixel is selected when each of its first three channels lies within
    ``dev`` (inclusive) of the matching channel of ``blue``.

    Args:
        blue: reference colour in the same channel order as ``blur``
            (``[b, g, r]`` for OpenCV images), or ``None``.
        blur: image array of shape ``(rows, cols, channels)`` with at least
            three channels, or ``None``.
        dev: per-channel tolerance; defaults to 30.

    Returns:
        An int 0/1 array of shape ``(rows, cols)``, or ``None`` if either
        input is ``None``.
    """
    if blue is None or blur is None:
        return None
    # Cast to Python int first. If ``blue`` holds numpy uint8 values (e.g. a
    # pixel read straight from an image), ``b - dev`` would wrap around
    # (10 - 30 -> 236) instead of going negative, which silently empties
    # the range.
    target = np.array([int(c) for c in blue[:3]], dtype=np.int64)
    lower = target - dev
    upper = target + dev
    # Comparing uint8 pixels with int64 bounds promotes to int64, so negative
    # lower bounds and bounds above 255 behave correctly.
    pixels = blur[:, :, :3]
    within = (pixels >= lower) & (pixels <= upper)
    return np.all(within, axis=2).astype(int)

def refined_mask(mask):
    """Turn a 0/1 mask into per-pixel weights using its Laplacian.

    The Laplacian is scaled so that its largest positive value becomes 1. The
    weight ``exp(-laplacian)`` is then kept only where ``mask == 1``; every
    other pixel gets weight 0. Inside the mask, pixels next to its boundary
    have a negative Laplacian, so their weight can exceed 1.

    Args:
        mask: 2-D array-like mask; only entries equal to 1 are kept.

    Returns:
        A float array with the same shape as ``mask``.
    """
    mask_f = mask.astype(float)
    laplace = cv2.Laplacian(mask_f, cv2.CV_64F)
    peak = np.max(laplace)
    # A uniform mask (all 0 or all 1) has an all-zero Laplacian. Dividing by
    # its zero maximum would turn every weight into NaN, so skip scaling then.
    if peak > 0:
        laplace /= peak
    return np.multiply(np.exp(-laplace), (mask_f == 1).astype(int))
    




In [50]:
def sky_enhancement(img):
    """Interactively mask the sky in ``img`` and apply the mask to a smoothed copy.

    Steps:
      1. Smooth ``img`` with ``cv2.bilateralFilter`` (diameter 9,
         sigmaColor 75, sigmaSpace 75).
      2. The user selects a blue-sky swatch. If none is selected, return ``[]``.
      3. The user may select a cloud swatch; skipping it is allowed.
      4. Pixels close to either swatch's dominant colour form the mask.
      5. ``refined_mask`` weights are multiplied into every channel.

    Args:
        img: BGR image with at least three channels.

    Returns:
        A uint8 image with the same shape as ``img``, or ``[]`` when no sky
        swatch was obtained (kept as ``[]`` for backward compatibility).
    """
    blur = cv2.bilateralFilter(img, 9, 75, 75)

    blue_swatch = get_swatch_img(img)
    if blue_swatch is None:
        return []
    sky_m = sky_mask(dominant_color(blue_swatch), blur)
    if sky_m is None:
        return []

    cloud_swatch = get_swatch_img(img)
    cloud_m = sky_mask(dominant_color(cloud_swatch), blur)

    # Take the union of the two masks. Adding them gives 2 wherever a pixel
    # matches both colours, and refined_mask keeps only entries equal to 1.
    binary_mask = sky_m if cloud_m is None else np.maximum(sky_m, cloud_m)

    weights = refined_mask(binary_mask)
    # Broadcast the weights over the channel axis. Weights can exceed 1 near
    # mask edges, so clip before casting to avoid uint8 wrap-around.
    final = weights[:, :, np.newaxis] * blur
    return np.clip(final, 0, 255).astype(np.uint8)


In [51]:
colored_img = sky_enhancement(image)
cv2.imshow("Original", image)
if len(colored_img) == 0:
    # sky_enhancement returns [] when no sky swatch was selected, and
    # cv2.imshow cannot display an empty list.
    print("No sky swatch selected; nothing to enhance.")
else:
    cv2.imshow("Final", colored_img)
cv2.waitKey(0)
cv2.destroyAllWindows()