<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/restore_white_balance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Upload** { display-mode: "form" }
from google.colab import files
# Ask the user for an image; only the first uploaded file is used later on.
uploaded = files.upload()
if not uploaded:
    # Without this check the next line fails with a bare IndexError.
    raise ValueError("No file was uploaded. Run this cell again and choose an image.")
file_name = next(iter(uploaded))

In [None]:
#@title ##**Run** { display-mode: "form" }
from google.colab import files
import cv2
import numpy as np
import PIL.Image
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import interactive

# Tuning parameters; they are the defaults of iterative_white_balance (the
# automatic weight search). light_lower / light_upper are also read by
# apply_algorithms.
max_light_deviation = 40          # drift of the light pixels' mean colour allowed before the extra penalty kicks in
tolerance = 10                    # currently unused by the search
light_penalty_weight = 0.5        # weight of the light pixels' colour drift in the score
step = 0.05                       # amount one weight is raised per search step
target_weight = 0.9               # desired sum of the three algorithm weights
white_threshold = 200             # currently unused by the search
light_lower = 150                 # "light" reference pixels have gray value > light_lower ...
light_upper = 200                 # ... and <= light_upper
max_iterations = 200              # upper bound on search steps
weight_deviation_scale = 100      # penalty per unit of |sum of weights - target_weight|
light_dev_penalty_scale = 10      # penalty per unit of drift beyond max_light_deviation
tolerance_multiplier = 2          # currently unused by the search
light_deviation_tolerance = 20    # currently unused by the search
target_weight_tolerance = 0.05    # currently unused by the search

# Reuse the file chosen in the Upload cell; ask for one if that cell was skipped.
if 'file_name' not in globals():
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No file was uploaded. Run this cell again and choose an image.")
    file_name = next(iter(uploaded))

# cv2.imread signals failure by returning None instead of raising.
img = cv2.imread(file_name)
if img is None:
    raise ValueError(f"Failed to load image from {file_name}. Check the path.")

# RGB copy of the untouched original, used for display and blending.
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def gray_world_balance(img):
    """Return a gray-world white-balanced copy of a 3-channel image.

    Every channel is multiplied by (average of the three channel means) /
    (its own mean), so all channels end up with the same mean. A channel
    whose mean is 0 is left unscaled. Results are clipped to [0, 255] and
    truncated (not rounded) to uint8; channel order is preserved.
    """
    # Channel-first view; unpacking enforces exactly three channels.
    blue, green, red = np.moveaxis(img, -1, 0)
    planes = (blue, green, red)

    plane_means = [np.mean(plane) for plane in planes]
    target = sum(plane_means) / 3

    balanced = []
    for plane, plane_mean in zip(planes, plane_means):
        gain = 1 if plane_mean == 0 else target / plane_mean
        balanced.append((plane * gain).clip(0, 255).astype(np.uint8))
    return np.stack(balanced, axis=-1)

def histogram_balance(img):
    """Equalise the histogram of each of the three channels independently.

    The image is split into its channels, ``cv2.equalizeHist`` is applied to
    each one separately, and the channels are merged back in the same order.
    """
    blue, green, red = cv2.split(img)
    equalized = [cv2.equalizeHist(channel) for channel in (blue, green, red)]
    return cv2.merge(equalized)

def apply_algorithms(img, weight_histogram, weight_color_temp=0.0, weight_gray_world=0.0):
    """Blend three white-balance corrections into ``img``; return a BGR uint8 image.

    The corrections are chained in RGB float space, each mixed in as
    ``(1 - w) * previous + w * corrected``:
      1. per-channel histogram equalisation of ``img`` (``weight_histogram``);
      2. a colour-temperature gain that scales R and B so that, over the
         "light" pixels of the current result (gray value in
         (light_lower, light_upper], module-level thresholds), their means
         match the G mean (``weight_color_temp``);
      3. gray-world balance of the ORIGINAL ``img`` (``weight_gray_world``).

    A weight of 0 leaves the previous result unchanged, so that correction is
    skipped entirely. The auto search calls this function many times with
    zero weights, so skipping matters.

    Args:
        img: BGR image as loaded by ``cv2.imread`` (presumably uint8, 3 channels).
        weight_histogram: mix factor for step 1.
        weight_color_temp: mix factor for step 2 (applied only when > 0).
        weight_gray_world: mix factor for step 3.

    Returns:
        BGR uint8 image, clipped to [0, 255].
    """
    orig_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)

    if weight_histogram != 0:
        hist_rgb = cv2.cvtColor(histogram_balance(img), cv2.COLOR_BGR2RGB).astype(np.float32)
        result_rgb = (1 - weight_histogram) * orig_rgb + weight_histogram * hist_rgb
    else:
        # (1 - 0) * orig + 0 * hist == orig; nothing below modifies result_rgb in place.
        result_rgb = orig_rgb

    if weight_color_temp > 0:
        gray = cv2.cvtColor(result_rgb.astype(np.uint8), cv2.COLOR_RGB2GRAY)
        light_mask = (gray > light_lower) & (gray <= light_upper)
        if np.any(light_mask):
            light_pixels = result_rgb[light_mask]
            r_mean, g_mean, b_mean = np.mean(light_pixels, axis=0)
            # Gains that pull the light pixels' R and B means onto their G mean.
            r_gain = g_mean / r_mean if r_mean > 0 else 1.0
            b_gain = g_mean / b_mean if b_mean > 0 else 1.0
            temp_adj = result_rgb.copy()
            temp_adj[:, :, 0] = np.clip(result_rgb[:, :, 0] * r_gain, 0, 255)
            temp_adj[:, :, 2] = np.clip(result_rgb[:, :, 2] * b_gain, 0, 255)
            result_rgb = (1 - weight_color_temp) * result_rgb + weight_color_temp * temp_adj

    if weight_gray_world != 0:
        gray_rgb = cv2.cvtColor(gray_world_balance(img), cv2.COLOR_BGR2RGB).astype(np.float32)
        result_rgb = (1 - weight_gray_world) * result_rgb + weight_gray_world * gray_rgb

    return cv2.cvtColor(result_rgb.clip(0, 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

def iterative_white_balance(
    img,
    max_iterations=max_iterations,
    tolerance=tolerance,
    light_penalty_weight=light_penalty_weight,
    max_light_deviation=max_light_deviation,
    step=step,
    target_weight=target_weight,
    white_threshold=white_threshold,
    light_lower=light_lower,
    light_upper=light_upper,
    weight_deviation_scale=weight_deviation_scale,
    light_dev_penalty_scale=light_dev_penalty_scale,
    tolerance_multiplier=tolerance_multiplier,
    light_deviation_tolerance=light_deviation_tolerance,
    target_weight_tolerance=target_weight_tolerance,
    max_histogram_weight=0.4,
    max_color_temp_weight=0.4,
):
    """Greedily search the blend weights passed to ``apply_algorithms``.

    The search runs on a half-size copy of ``img`` (BGR). It raises one weight
    at a time by ``step`` and keeps a step only if it lowers the score:
    first Histogram (up to ``max_histogram_weight``), then Color Temperature
    (up to ``max_color_temp_weight``), then Gray World (up to whatever keeps
    the total at 1.0). A phase ends at its cap or at the first step that does
    not improve the score. At most ``max_iterations`` steps are attempted.

    The score is computed over the "light" reference pixels, whose grayscale
    value lies in (light_lower, light_upper]. It is the sum of:
      - the distance of their mean colour from white (255, 255, 255);
      - ``light_penalty_weight`` * the distance of that mean from its value
        before correction;
      - ``weight_deviation_scale`` * |sum of weights - target_weight|;
      - ``light_dev_penalty_scale`` * the part of that distance above
        ``max_light_deviation``.
    Lower is better.

    ``tolerance``, ``white_threshold``, ``tolerance_multiplier``,
    ``light_deviation_tolerance`` and ``target_weight_tolerance`` are accepted
    for backward compatibility but are not used.

    Returns:
        dict with keys "Histogram", "Color Temperature" and "Gray World"
        holding the best weights found. All weights are 0.0 when the image
        has no pixel in the light range, because no score can be computed.
    """
    names = ("Histogram", "Color Temperature", "Gray World")
    current = [0.0, 0.0, 0.0]  # same order as `names`

    # apply_algorithms is called many times below, so search on a half-size
    # copy. A 1-px side cannot be halved (cv2.resize would round it to 0 px).
    if min(img.shape[:2]) >= 2:
        img_small = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
    else:
        img_small = img

    # Reference ("light") pixels: gray value in (light_lower, light_upper].
    gray = cv2.cvtColor(img_small, cv2.COLOR_BGR2GRAY)
    light_mask = (gray > light_lower) & (gray <= light_upper)
    light_count = int(np.count_nonzero(light_mask))
    if light_count == 0:
        # Averaging an empty selection would make every score NaN.
        return dict(zip(names, current))

    if light_count / light_mask.size > 0.001:
        # Mean (B, G, R) of the reference pixels before any correction.
        init_light = img_small[light_mask][:, :3].mean(axis=0)
    else:
        # NOTE(review): with this few reference pixels the "before" colour is
        # taken as black, as in the original implementation; confirm intent.
        init_light = np.zeros(3)

    score_cache = {}

    def evaluate(weights):
        """Score one (histogram, color_temp, gray_world) triple; lower is better."""
        key = tuple(weights)
        if key not in score_cache:
            result = apply_algorithms(img_small, *key)
            light_mean = result[light_mask][:, :3].mean(axis=0)
            white_dev = float(np.linalg.norm(light_mean - 255.0))
            light_dev = float(np.linalg.norm(light_mean - init_light))
            weight_dev = abs(sum(key) - target_weight) * weight_deviation_scale
            light_dev_penalty = max(0.0, light_dev - max_light_deviation) * light_dev_penalty_scale
            score_cache[key] = white_dev + light_penalty_weight * light_dev + weight_dev + light_dev_penalty
        return score_cache[key]

    caps = (max_histogram_weight, max_color_temp_weight, 1.0)
    best_score = evaluate(current)
    phase = 0
    attempts = 0
    while phase < len(names) and attempts < max_iterations:
        attempts += 1
        others = sum(w for i, w in enumerate(current) if i != phase)
        # Never let the three weights add up to more than 1.0; rounding keeps
        # repeated += step from drifting (e.g. 0.39999999999999997).
        new_value = round(min(current[phase] + step, 1.0 - others), 10)
        if new_value <= current[phase] or new_value > caps[phase]:
            phase += 1  # cap reached: move on to the next algorithm
            continue
        candidate = list(current)
        candidate[phase] = new_value
        candidate_score = evaluate(candidate)
        # Only accept strict improvements, so `current` is always the best point seen.
        if candidate_score < best_score:
            current, best_score = candidate, candidate_score
        else:
            phase += 1  # no improvement: move on to the next algorithm

    return dict(zip(names, current))

def create_combined_image(weight_histogram, weight_color_temp, weight_gray_world, blend_factor):
    """Return a PIL image with the original (left) and the corrected (right) side by side.

    The right half is ``apply_algorithms`` applied to the module-level ``img``,
    mixed with the original as ``(1 - blend_factor) * original +
    blend_factor * corrected``. A half wider than 500 px is downscaled to
    500 px wide, keeping its aspect ratio.
    """
    max_width = 500

    def fit_width(picture):
        # Only shrink; images at most max_width wide are returned unchanged.
        width, height = picture.size
        if width <= max_width:
            return picture
        return picture.resize((max_width, int(height * max_width / width)))

    corrected_bgr = apply_algorithms(img, weight_histogram, weight_color_temp, weight_gray_world)
    corrected_rgb = cv2.cvtColor(corrected_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    original_rgb = img_rgb.astype(np.float32)
    mixed_rgb = (1 - blend_factor) * original_rgb + blend_factor * corrected_rgb

    left = fit_width(PIL.Image.fromarray(img_rgb))
    right = fit_width(PIL.Image.fromarray(mixed_rgb.clip(0, 255).astype(np.uint8)))
    side_by_side = np.hstack((np.array(left), np.array(right)))
    return PIL.Image.fromarray(side_by_side)

# The auto weights depend only on `img`, which is fixed for this run of the
# cell, so the (expensive) search is done once and its result reused.
_auto_weights_cache = None
# True while update_image itself moves the sliders. Their observers fire
# synchronously and would otherwise re-enter update_image once per slider set.
_updating_sliders = False

def update_image(auto, weight_histogram, weight_color_temp, weight_gray_world, blend_factor):
    """Render the side-by-side preview into the `output` widget.

    When ``auto`` is true, the three weights from ``iterative_white_balance(img)``
    replace the given ones and are written back to the sliders. Slider changes
    made here do not trigger nested re-renders.

    Args:
        auto: use the automatically searched weights.
        weight_histogram, weight_color_temp, weight_gray_world: manual weights,
            used when ``auto`` is false.
        blend_factor: mix between original (0) and corrected (1) image.
    """
    global _auto_weights_cache, _updating_sliders
    if _updating_sliders:
        # Triggered by our own slider assignment below; the outer call renders.
        return

    if auto:
        if _auto_weights_cache is None:
            _auto_weights_cache = iterative_white_balance(img)
        weight_histogram = _auto_weights_cache["Histogram"]
        weight_color_temp = _auto_weights_cache["Color Temperature"]
        weight_gray_world = _auto_weights_cache["Gray World"]
        _updating_sliders = True
        try:
            slider_histogram.value = weight_histogram
            slider_color_temp.value = weight_color_temp
            slider_gray_world.value = weight_gray_world
        finally:
            _updating_sliders = False

    combined_image = create_combined_image(weight_histogram, weight_color_temp, weight_gray_world, blend_factor)
    with output:
        clear_output(wait=True)
        display(combined_image)

def _weight_slider(description, value=0.0):
    """Build a FloatSlider over [0, 1] with 0.01 resolution."""
    return widgets.FloatSlider(value=value, min=0.0, max=1.0, step=0.01, description=description)

checkbox_auto = widgets.Checkbox(value=True, description='Auto')
slider_histogram = _weight_slider('Histogram')
slider_color_temp = _weight_slider('Color Temp')
slider_gray_world = _weight_slider('Gray World')
slider_blend = _weight_slider('Blend Factor', value=1.0)

controls = [checkbox_auto, slider_histogram, slider_color_temp, slider_gray_world, slider_blend]
display(widgets.VBox(controls))

# Preview area that update_image renders into.
output = widgets.Output()
display(output)

def on_value_change(change):
    """Re-render from the current widget state; the change event itself is ignored."""
    update_image(
        checkbox_auto.value,
        slider_histogram.value,
        slider_color_temp.value,
        slider_gray_world.value,
        slider_blend.value,
    )

for control in controls:
    control.observe(on_value_change, names='value')

# Draw the preview once so it appears without any interaction.
on_value_change(None)

In [None]:
#@title ##**Download** { display-mode: "form" }
from google.colab import files
import cv2
import numpy as np
from IPython.display import display

# Use whatever the sliders currently show (auto or manual weights).
weight_histogram = slider_histogram.value
weight_color_temp = slider_color_temp.value
weight_gray_world = slider_gray_world.value
blend_factor = slider_blend.value

# Same full-resolution correction and blend as the preview in the Run cell.
result_rgb = apply_algorithms(img, weight_histogram, weight_color_temp, weight_gray_world)
result_rgb = cv2.cvtColor(result_rgb, cv2.COLOR_BGR2RGB).astype(np.float32)
final_rgb = (1 - blend_factor) * img_rgb.astype(np.float32) + blend_factor * result_rgb
final_image = final_rgb.clip(0, 255).astype(np.uint8)
final_image_bgr = cv2.cvtColor(final_image, cv2.COLOR_RGB2BGR)

output_filename = 'processed_image.jpg'
# cv2.imwrite reports failure through its return value, not an exception.
if not cv2.imwrite(output_filename, final_image_bgr):
    raise OSError(f"Failed to write {output_filename}.")
files.download(output_filename)

print(f"Image saved as {output_filename} and downloaded")