<a href="https://colab.research.google.com/github/detektor777/colab_list_video/blob/main/restore_white_balance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://drive.google.com/drive

In [None]:
#@title ##**Select Video File** { display-mode: "form" }
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
from google.colab import drive

# Where the input video comes from (Colab form dropdown).
upload_option = "Load from Google Drive Root"  #@param ["Upload from PC", "Load from Google Drive Root", "Load from Google Drive"]

# Path of the chosen video; set below (upload) or later by a picker button click.
file_name = None
# Picker button clicked most recently (set by the click handler).
last_selected_button = None

def reset_button_colors(buttons):
    """Clear any custom background colour from every widget in *buttons*."""
    for widget in buttons:
        widget.style.button_color = None

# Extensions offered by the Google Drive pickers (compared case-insensitively).
video_extensions = ['.mp4', '.mkv', '.avi', '.mov']


def _show_file_picker(files_list):
    """Show one button per path in *files_list*; clicking a button selects that file.

    A click sets the globals ``file_name`` (``root_dir`` joined with the button
    label) and ``last_selected_button``, colours the clicked button green if the
    path exists (red otherwise) and prints the selection.
    """
    output = widgets.Output()
    buttons = []

    def on_button_clicked(b):
        global file_name, last_selected_button
        with output:
            clear_output()
            reset_button_colors(buttons)
            file_name = os.path.join(root_dir, b.description)
            b.style.button_color = 'green' if os.path.exists(file_name) else 'red'
            last_selected_button = b
            print(f"Selected file: {file_name}")

    for file in files_list:
        button = widgets.Button(description=file, layout=widgets.Layout(width='500px', overflow='hidden', text_overflow='ellipsis'))
        button.on_click(on_button_clicked)
        buttons.append(button)

    display(widgets.VBox(buttons), output)


# True when buttons were displayed: the actual choice happens after this cell ends.
picker_shown = False

if upload_option == "Upload from PC":
    print("Please upload a video file.")
    root_dir = '/content/'
    uploaded = files.upload()
    if uploaded:
        file_name = list(uploaded.keys())[0]
    else:
        print("No file uploaded.")
        file_name = None

elif upload_option == "Load from Google Drive Root":
    drive.mount('/content/drive')
    root_dir = '/content/drive/MyDrive/'

    # Regular files directly in MyDrive only; sorted so the button order is stable.
    files_list = sorted(
        f for f in os.listdir(root_dir)
        if os.path.isfile(os.path.join(root_dir, f)) and os.path.splitext(f)[1].lower() in video_extensions
    )

    if not files_list:
        print("No video files found in Google Drive root.")
        file_name = None
    else:
        print("Select a video file from Google Drive root:")
        _show_file_picker(files_list)
        picker_shown = True

elif upload_option == "Load from Google Drive":
    drive.mount('/content/drive')
    root_dir = '/content/drive/MyDrive/'

    # Walk all of MyDrive; button labels are paths relative to root_dir.
    files_list = sorted(
        os.path.relpath(os.path.join(dirpath, f), root_dir)
        for dirpath, _, filenames in os.walk(root_dir)
        for f in filenames
        if os.path.splitext(f)[1].lower() in video_extensions
    )

    if not files_list:
        print("No video files found in Google Drive or its subfolders.")
        file_name = None
    else:
        print("Select a video file from Google Drive (including subfolders):")
        _show_file_picker(files_list)
        picker_shown = True

if file_name:
    print(f"Video file path set to: {file_name}")
elif picker_shown:
    # Button clicks are handled after the cell finishes, so nothing is chosen yet.
    print("Click a file above to select it.")
else:
    print("Video file path not set. Please select a file.")

In [None]:
#@title ##**Files Config** { display-mode: "form" }
import os
from google.colab import files
import shutil
from google.colab import drive
output_folder = "google_drive" #@param ["google_drive","root"]

# Relative names of scratch folders (the Run cell later uses absolute /content paths).
upload_folder = 'upload'
result_folder = 'results'

if output_folder == "google_drive":
    # Test for MyDrive rather than the mount point itself: '/content/drive' may
    # exist as a plain directory while Drive is not (or no longer) mounted.
    if not os.path.isdir('/content/drive/MyDrive'):
        print("Google Drive не подключён. Подключаем...")
        drive.mount('/content/drive')
    root_folder = '/content/drive/MyDrive/'
elif output_folder == "root":
    root_folder = '/content/'
else:
    raise ValueError(f"Unknown output_folder: {output_folder!r}")

# Extracted input frames and corrected output frames live under root_folder.
real_output_folder = os.path.join(root_folder, 'real_output')
real_input_folder = os.path.join(root_folder, 'real_input')

os.makedirs(real_output_folder, exist_ok=True)
os.makedirs(real_input_folder, exist_ok=True)

# Clear folders: Colab form fields controlling the input-folder cleanup below.
clear_input_folder = False #@param {type:"boolean"}
# Frames numbered below this are deleted ("" or "0" = no lower bound).
up_to_frame = "" #@param {type:"string"}
# Frames numbered above this are deleted ("" or "0" = no upper bound).
from_frame = "" #@param {type:"string"}

def clean_folder(folder_path, up_to=None, from_frame=None):
    """Delete extracted frames from *folder_path*, optionally only outside a range.

    With neither bound given, the whole folder is removed and recreated empty.
    Otherwise only files ending in ``.jpg`` whose name (up to the first dot) is
    an integer frame number are considered: frames numbered below ``up_to``
    and/or above ``from_frame`` are deleted, the rest are kept.

    Args:
        folder_path: Directory holding the frames; created if it does not exist.
        up_to: Lowest frame number to keep (str or int); falsy = no lower bound.
        from_frame: Highest frame number to keep (str or int); falsy = no upper bound.
    """
    print("\nCurrent parameters:")
    print(f"Delete frames up to: {up_to if up_to else 'not specified'}")
    print(f"Delete frames after: {from_frame if from_frame else 'not specified'}")

    if not os.path.isdir(folder_path):
        print(f"\nFolder {folder_path} does not exist!")
        print("Creating a new folder...")
        os.makedirs(folder_path)
        return

    if not up_to and not from_frame:
        print("\nNo parameters specified - deleting all folder content...")
        shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        print(f"Folder {folder_path} cleared and recreated")
        return

    # Parse the bounds once. Previously a non-numeric bound raised inside the
    # per-file try and was misreported as an invalid file name for every file.
    try:
        lower = int(up_to) if up_to else None
        upper = int(from_frame) if from_frame else None
    except ValueError:
        print(f"\nInvalid frame bounds: up_to={up_to!r}, from_frame={from_frame!r}. Nothing deleted.")
        return

    print("\nStarting file processing...")
    jpg_files = [name for name in os.listdir(folder_path) if name.endswith('.jpg')]

    if not jpg_files:
        print("No JPG files to process in the folder")
        return

    deleted_count = 0
    processed_count = 0

    for filename in jpg_files:
        try:
            frame_number = int(filename.split('.')[0])
        except ValueError:
            print(f'Skipped file with invalid name: {filename}')
            continue

        outside_range = (lower is not None and frame_number < lower) or (
            upper is not None and frame_number > upper
        )
        if not outside_range:
            processed_count += 1
            continue

        os.remove(os.path.join(folder_path, filename))
        deleted_count += 1
        # Only echo the first few deletions to keep the cell output short.
        if deleted_count <= 5:
            print(f'File deleted: {filename}')
        elif deleted_count == 6:
            print('...')

    print('\nProcessing complete:')
    print(f'Total files: {len(jpg_files)}')
    print(f'Files deleted: {deleted_count}')
    print(f'Files retained: {processed_count}')

if clear_input_folder:
    # "0" in the form means "no bound", just like leaving the field empty.
    if up_to_frame == "0":
        up_to_frame = None
    if from_frame == "0":
        from_frame = None
    clean_folder(real_input_folder, up_to_frame, from_frame)

clear_output_folder = False #@param {type:"boolean"}

# On request, wipe the output folder and recreate it empty.
if clear_output_folder:
    if os.path.isdir(real_output_folder):
        shutil.rmtree(real_output_folder)
    os.makedirs(real_output_folder)

In [None]:
#@title ##**Video Config** { display-mode: "form" }
from google.colab import files
import cv2
import numpy as np
import PIL.Image
from IPython.display import display, clear_output, HTML
import ipywidgets as widgets
from ipywidgets import interactive
import base64
import io

# Default tuning for iterative_white_balance(): these values are bound as its
# keyword defaults; light_lower/light_upper are also read as globals by apply_algorithms().
max_light_deviation = 40         # light-region mean-colour shift (RGB distance) allowed before the extra penalty applies
tolerance = 10                   # not read by iterative_white_balance (kept for its signature)
light_penalty_weight = 0.5       # weight of the light-region colour shift in the search score
step = 0.05                      # increment applied to one blend weight per search step
target_weight = 0.9              # desired sum of the three blend weights
white_threshold = 200            # not read by iterative_white_balance
light_lower = 150                # "light" pixels: light_lower < gray level <= light_upper
light_upper = 200
max_iterations = 200             # upper bound on search-loop iterations
weight_deviation_scale = 100     # multiplier for |sum of weights - target_weight| in the score
light_dev_penalty_scale = 10     # multiplier for the shift beyond max_light_deviation
tolerance_multiplier = 2         # not read by iterative_white_balance
light_deviation_tolerance = 20   # not read by iterative_white_balance
target_weight_tolerance = 0.05   # not read by iterative_white_balance

import os

# Fall back to an upload when no video was chosen. The selection cell always
# initialises file_name to None, so testing only for the name's existence
# (as before) never triggered this fallback.
if not globals().get('file_name'):
    print("No video file selected. Please run the video selection cell first.")
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No video file uploaded.")
    # files.upload() is assumed to save into the current working directory.
    # Store an absolute path so later cells that join it with root_dir still
    # resolve it whatever root_dir the selection cell left behind.
    file_name = os.path.abspath(next(iter(uploaded)))

cap = cv2.VideoCapture(file_name)
if not cap.isOpened():
    raise ValueError(f"Failed to open video from {file_name}. Check the path.")

fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if width <= 0 or height <= 0:
    raise ValueError(f"Could not read the frame size of {file_name}.")

# Preview size: fixed width, height scaled to preserve the aspect ratio.
display_width = 500
display_height = int(height * (display_width / width))

# Frame currently shown in the player (set by update_frame).
current_frame = None
current_frame_num = 0

def gray_world_balance(img):
    """Gray-world white balance: scale each channel so its mean matches the mean of all three.

    *img* must split into exactly three channels. Because the operation is
    symmetric, the channel order does not matter. A channel whose mean is 0 is
    left unscaled. Results are clipped to [0, 255] and returned as uint8.
    """
    blue, green, red = cv2.split(img)
    planes = (blue, green, red)
    plane_means = [np.mean(plane) for plane in planes]
    gray_level = (plane_means[0] + plane_means[1] + plane_means[2]) / 3

    balanced = []
    for plane, plane_mean in zip(planes, plane_means):
        gain = gray_level / plane_mean if plane_mean != 0 else 1
        balanced.append((plane * gain).clip(0, 255).astype(np.uint8))
    return cv2.merge(balanced)

def histogram_balance(img):
    """Equalise *img* with CLAHE (clip limit 2.0, 8x8 tiles), applied to each of its three channels separately."""
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    blue, green, red = cv2.split(img)
    return cv2.merge([equalizer.apply(channel) for channel in (blue, green, red)])

def apply_algorithms(img, weight_histogram, weight_color_temp=0.0, weight_gray_world=0.0):
    """Blend three white-balance corrections of the BGR image *img*; return a BGR uint8 image.

    Steps, each a linear mix with the previous result:
      1. CLAHE equalisation (histogram_balance) mixed in with *weight_histogram*.
      2. If *weight_color_temp* > 0: red and blue gains that make the mean of
         the "light" pixels (global light_lower < gray <= light_upper) neutral
         relative to green, mixed in with *weight_color_temp*.
      3. gray_world_balance of the ORIGINAL *img*, mixed in with *weight_gray_world*.
    """
    def to_rgb_float(bgr):
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

    equalized = to_rgb_float(histogram_balance(img))
    original = to_rgb_float(img)
    blended = (1 - weight_histogram) * original + weight_histogram * equalized

    if weight_color_temp > 0:
        luminance = cv2.cvtColor(blended.astype(np.uint8), cv2.COLOR_RGB2GRAY)
        in_band = (luminance > light_lower) & (luminance <= light_upper)
        if np.any(in_band):
            r_avg, g_avg, b_avg = np.mean(blended[in_band], axis=0)
            # Gains pull red and blue of the light region towards its green level.
            red_gain = g_avg / r_avg if r_avg > 0 else 1.0
            blue_gain = g_avg / b_avg if b_avg > 0 else 1.0
            warmed = blended.copy()
            warmed[:, :, 0] = np.clip(blended[:, :, 0] * red_gain, 0, 255)
            warmed[:, :, 2] = np.clip(blended[:, :, 2] * blue_gain, 0, 255)
            blended = (1 - weight_color_temp) * blended + weight_color_temp * warmed

    gray_world = to_rgb_float(gray_world_balance(img))
    mixed = (1 - weight_gray_world) * blended + weight_gray_world * gray_world

    return cv2.cvtColor(mixed.clip(0, 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

def iterative_white_balance(img, max_iterations=max_iterations, tolerance=tolerance, light_penalty_weight=light_penalty_weight, max_light_deviation=max_light_deviation, step=step, target_weight=target_weight, white_threshold=white_threshold, light_lower=light_lower, light_upper=light_upper, weight_deviation_scale=weight_deviation_scale, light_dev_penalty_scale=light_dev_penalty_scale, tolerance_multiplier=tolerance_multiplier, light_deviation_tolerance=light_deviation_tolerance, target_weight_tolerance=target_weight_tolerance):
    """Greedy search for the blend weights used by apply_algorithms().

    Works on a half-size copy of the BGR image *img*. Three phases share one
    budget of *max_iterations* loop iterations and each step changes one weight
    by *step*:
      1. histogram weight is raised while the score improves (capped at 0.4);
      2. colour-temperature weight is raised up to 0.40 (see NOTE below);
      3. gray-world weight is raised while the score improves.

    Score (lower is better), computed on the mean RGB of the "light" pixels,
    i.e. light_lower < gray level <= light_upper in the original image:
        distance to (255, 255, 255)
        + light_penalty_weight * distance to the original light mean
        + |sum of weights - target_weight| * weight_deviation_scale
        + max(0, that distance - max_light_deviation) * light_dev_penalty_scale

    ``tolerance``, ``white_threshold``, ``tolerance_multiplier``,
    ``light_deviation_tolerance`` and ``target_weight_tolerance`` are accepted
    for compatibility but are not used. apply_algorithms() reads the global
    light_lower/light_upper, not the values passed here.

    Returns:
        dict with keys "Histogram", "Color Temperature" and "Gray World".
    """
    scale_factor = 0.5
    img_small = cv2.resize(img, (0, 0), fx=scale_factor, fy=scale_factor)
    channels = img_small.shape[2]
    max_histogram_weight = 0.4
    max_color_temp_weight = 0.40

    gray = cv2.cvtColor(img_small, cv2.COLOR_BGR2GRAY)
    _, light_mask_lower = cv2.threshold(gray, light_lower, 255, cv2.THRESH_BINARY)
    _, light_mask_upper = cv2.threshold(gray, light_upper, 255, cv2.THRESH_BINARY)
    is_light = (light_mask_lower & ~light_mask_upper) > 0

    light_ratio = np.sum(is_light) / (img_small.shape[0] * img_small.shape[1])
    if light_ratio > 0.001:
        light_pixels = img_small[is_light].reshape(-1, channels)
        r_init_light = np.mean(light_pixels[:, 2])
        g_init_light = np.mean(light_pixels[:, 1])
        b_init_light = np.mean(light_pixels[:, 0])
    else:
        # NOTE(review): with (almost) no light pixels the reference colour is
        # black, so the light-deviation terms are large for every candidate.
        r_init_light = g_init_light = b_init_light = 0

    cache = {}

    def score_weights(w_hist, w_temp, w_gray):
        """Score one weight triple, memoising the rendered image per triple."""
        key = (w_hist, w_temp, w_gray)
        if key not in cache:
            cache[key] = apply_algorithms(img_small, w_hist, w_temp, w_gray)
        light_result = cache[key][is_light].reshape(-1, channels)
        b_light = np.mean(light_result[:, 0])
        g_light = np.mean(light_result[:, 1])
        r_light = np.mean(light_result[:, 2])
        white_dev = np.sqrt((r_light - 255)**2 + (g_light - 255)**2 + (b_light - 255)**2)
        light_dev = np.sqrt((r_light - r_init_light)**2 + (g_light - g_init_light)**2 + (b_light - b_init_light)**2)
        weight_dev = abs(w_hist + w_temp + w_gray - target_weight) * weight_deviation_scale
        light_dev_penalty = max(0, light_dev - max_light_deviation) * light_dev_penalty_scale
        return white_dev + light_penalty_weight * light_dev + weight_dev + light_dev_penalty

    weight_histogram = 0.0
    weight_color_temp = 0.0
    weight_gray_world = 0.0
    histogram_done = False
    color_temp_done = False
    gray_world_done = False

    best_weight_histogram = 0.0
    best_weight_color_temp = 0.0
    best_weight_gray_world = 0.0
    best_score = float('inf')

    iteration = 0
    while iteration < max_iterations and not gray_world_done:
        iteration += 1

        # Re-score the current weights and remember them if they are the best so far.
        score = score_weights(weight_histogram, weight_color_temp, weight_gray_world)
        if score < best_score:
            best_score = score
            best_weight_histogram = weight_histogram
            best_weight_color_temp = weight_color_temp
            best_weight_gray_world = weight_gray_world

        if not histogram_done and weight_histogram < max_histogram_weight:
            # Phase 1: raise the histogram weight while that improves the score.
            new_histogram_weight = min(1.0, weight_histogram + step)
            if new_histogram_weight <= max_histogram_weight:
                new_score = score_weights(new_histogram_weight, weight_color_temp, weight_gray_world)
                if new_score < best_score:
                    weight_histogram = new_histogram_weight
                    best_score = new_score
                    best_weight_histogram = weight_histogram
                else:
                    histogram_done = True
            else:
                histogram_done = True

        elif not color_temp_done:
            # Phase 2: raise the colour-temperature weight up to its cap.
            # NOTE(review): unlike phases 1 and 3, each step is accepted without
            # comparing against best_score, and best_score is overwritten even if
            # the new score is worse. Kept as-is; confirm this is intended.
            new_weight_color_temp = min(1.0 - weight_histogram, weight_color_temp + step)
            if new_weight_color_temp <= max_color_temp_weight:
                weight_color_temp = new_weight_color_temp
                best_score = score_weights(weight_histogram, new_weight_color_temp, weight_gray_world)
                best_weight_color_temp = weight_color_temp
            else:
                color_temp_done = True

        else:
            # Phase 3: raise the gray-world weight while that improves the score.
            new_weight_gray_world = min(1.0 - weight_histogram - weight_color_temp, weight_gray_world + step)
            new_score = score_weights(weight_histogram, weight_color_temp, new_weight_gray_world)
            if new_score < best_score:
                weight_gray_world = new_weight_gray_world
                best_score = new_score
                best_weight_gray_world = weight_gray_world
            else:
                gray_world_done = True

    return {
        "Histogram": best_weight_histogram,
        "Color Temperature": best_weight_color_temp,
        "Gray World": best_weight_gray_world,
    }

def create_combined_image(frame, weight_histogram, weight_color_temp, weight_gray_world, blend_factor=1.0):
    """Return a PIL image with the original BGR *frame* (left) beside its corrected version (right).

    The corrected image is apply_algorithms() with the given weights, mixed
    with the original by *blend_factor* (1.0 = fully corrected). Each half is
    downscaled to at most 500 px wide, keeping its aspect ratio.
    """
    def fit_width(image, limit=500):
        w, h = image.size
        if w <= limit:
            return image
        return image.resize((limit, int(h * limit / w)))

    corrected = cv2.cvtColor(
        apply_algorithms(frame, weight_histogram, weight_color_temp, weight_gray_world),
        cv2.COLOR_BGR2RGB,
    ).astype(np.float32)
    source = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32)
    mixed = (1 - blend_factor) * source + blend_factor * corrected

    left = fit_width(PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    right = fit_width(PIL.Image.fromarray(mixed.clip(0, 255).astype(np.uint8)))
    return PIL.Image.fromarray(np.hstack((np.array(left), np.array(right))))

# Function to create a base64 encoded video frame for display
def get_frame_data_url(frame):
    """Return the BGR image *frame* as a ``data:image/jpeg;base64,...`` URL.

    Raises:
        ValueError: If OpenCV fails to encode the frame.
    """
    # cv2.imencode expects BGR input, so the frame is passed through unchanged;
    # the previous BGR->RGB conversion produced red/blue-swapped JPEGs.
    ok, buffer = cv2.imencode('.jpg', frame)
    if not ok:
        raise ValueError("Failed to encode frame as JPEG")
    encoded_frame = base64.b64encode(buffer).decode('ascii')
    return f"data:image/jpeg;base64,{encoded_frame}"

# Function to update the displayed frame
def update_frame(frame_num):
    """Seek to *frame_num*, make it the current frame and preview it in ``video_output``.

    On success this sets the globals ``current_frame`` (BGR array from
    cap.read()) and ``current_frame_num``. It then shows a JPEG scaled to
    display_width x display_height plus a frame counter. If the frame cannot
    be read, a message is shown and the previous current frame is kept.
    """
    global current_frame, current_frame_num
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ret, frame = cap.read()
    if not ret:
        # Previously a failed read left the old preview with no feedback.
        with video_output:
            clear_output(wait=True)
            display(HTML(f'<p>Could not read frame {frame_num}.</p>'))
        return

    current_frame = frame
    current_frame_num = frame_num
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_resized = cv2.resize(frame_rgb, (display_width, display_height))

    # Encode as an inline base64 JPEG for the HTML <img> tag.
    buffered = io.BytesIO()
    PIL.Image.fromarray(frame_resized).save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    with video_output:
        clear_output(wait=True)
        display(HTML(f'<img src="data:image/jpeg;base64,{img_str}" />'))
        display(HTML(f'<p>Frame: {frame_num}/{frame_count-1}</p>'))

# Function to handle the "Test" button click - process the current frame
def on_test_button_clicked(b):
    """'Test Frame' button handler: preview the correction on the current frame."""
    if current_frame is None:
        with result_output:
            clear_output(wait=True)
            print("No frame loaded. Please select a frame first.")
        return

    preview_frame(current_frame)


# The preview is deliberately NOT named process_frame: the Run cell defines its
# own process_frame(img) for batch processing. Sharing that global name made
# these callbacks call the batch version (which displays nothing) after that cell ran.
def preview_frame(frame):
    """Correct *frame* and show it beside the original in ``result_output``.

    In Auto mode the weights come from iterative_white_balance() and are
    copied onto the (disabled) manual sliders. In manual mode they are read
    from the sliders. The Blend Factor slider is applied in both modes.
    """
    if checkbox_auto.value:
        with result_output:
            clear_output(wait=True)
            display(HTML("<p>Processing auto white balance... Please wait.</p>"))

        weights = iterative_white_balance(frame)
        weight_histogram = weights["Histogram"]
        weight_color_temp = weights["Color Temperature"]
        weight_gray_world = weights["Gray World"]

        # Mirror the chosen weights on the sliders. This fires on_slider_change,
        # but that handler ignores changes while Auto is checked, so nothing is
        # re-processed. The old Output contexts here did not suppress callbacks.
        slider_histogram.value = weight_histogram
        slider_color_temp.value = weight_color_temp
        slider_gray_world.value = weight_gray_world
    else:
        weight_histogram = slider_histogram.value
        weight_color_temp = slider_color_temp.value
        weight_gray_world = slider_gray_world.value

    combined_image = create_combined_image(frame, weight_histogram, weight_color_temp, weight_gray_world, slider_blend.value)

    with result_output:
        clear_output(wait=True)
        display(combined_image)
        display(HTML(f"<p>Applied settings - Histogram: {weight_histogram:.2f}, Color Temp: {weight_color_temp:.2f}, Gray World: {weight_gray_world:.2f}, Blend: {slider_blend.value:.2f}</p>"))


def on_auto_checkbox_change(change):
    """Lock the manual weight sliders in Auto mode; re-run the preview when Auto is switched on."""
    auto_on = change['new']
    slider_histogram.disabled = auto_on
    slider_color_temp.disabled = auto_on
    slider_gray_world.disabled = auto_on

    if auto_on and current_frame is not None:
        preview_frame(current_frame)


def on_slider_change(change):
    """Re-run the preview after a slider moves, but only in manual mode with a frame loaded."""
    if not checkbox_auto.value and current_frame is not None:
        preview_frame(current_frame)

# Frame scrubber over [0, frame_count-1]. With continuous_update=False the value
# (and its observer) updates only when the handle is released.
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=frame_count-1,
    step=1,
    description='Frame:',
    continuous_update=False,
    orientation='horizontal',
    layout=widgets.Layout(width='500px')
)

# Create video navigation buttons
prev_button = widgets.Button(description='Previous', icon='backward')
next_button = widgets.Button(description='Next', icon='forward')
test_button = widgets.Button(description='Test Frame', button_style='success')

# Create white balance correction controls. The three weight sliders start
# disabled because Auto starts checked.
checkbox_auto = widgets.Checkbox(value=True, description='Auto')
slider_histogram = widgets.FloatSlider(value=0.0, min=0.0, max=1.0, step=0.01, description='Histogram', disabled=True)
slider_color_temp = widgets.FloatSlider(value=0.0, min=0.0, max=1.0, step=0.01, description='Color Temp', disabled=True)
slider_gray_world = widgets.FloatSlider(value=0.0, min=0.0, max=1.0, step=0.01, description='Gray World', disabled=True)
slider_blend = widgets.FloatSlider(value=1.0, min=0.0, max=1.0, step=0.01, description='Blend Factor')

# Output widgets used as contexts around the weight-slider observers below.
# NOTE(review): they are never displayed, and entering an Output context most
# likely neither suppresses nor redirects later observer callbacks, so they
# probably do not prevent recursive calls as intended — confirm.
slider_hist_output = widgets.Output()
slider_color_output = widgets.Output()
slider_gray_output = widgets.Output()

# Define button click handlers: step the frame slider by one, clamped to the
# valid range; the slider's observer then loads the frame.
def on_prev_button_clicked(b):
    new_frame = max(0, frame_slider.value - 1)
    frame_slider.value = new_frame

def on_next_button_clicked(b):
    new_frame = min(frame_count-1, frame_slider.value + 1)
    frame_slider.value = new_frame

# Connect event handlers
frame_slider.observe(lambda change: update_frame(change['new']), names='value')
prev_button.on_click(on_prev_button_clicked)
next_button.on_click(on_next_button_clicked)
test_button.on_click(on_test_button_clicked)

# Connect Auto checkbox and slider handlers
checkbox_auto.observe(on_auto_checkbox_change, names='value')

with slider_hist_output:
    slider_histogram.observe(on_slider_change, names='value')
with slider_color_output:
    slider_color_temp.observe(on_slider_change, names='value')
with slider_gray_output:
    slider_gray_world.observe(on_slider_change, names='value')
slider_blend.observe(on_slider_change, names='value')

# Create output areas (only used by callbacks, so creating them after the
# observers are registered is fine).
video_controls = widgets.HBox([prev_button, frame_slider, next_button])
video_output = widgets.Output()
result_output = widgets.Output()

# Display all widgets
display(widgets.HTML("<h3>Video Player</h3>"))
display(video_controls)
display(video_output)
display(widgets.HTML("<h3>White Balance Correction</h3>"))
display(widgets.VBox([checkbox_auto, slider_histogram, slider_color_temp, slider_gray_world, slider_blend, test_button]))
display(result_output)

# Initialize with the first frame
update_frame(0)

In [None]:
#@title ##**Run sequence** { display-mode: "form" }
import cv2
import imageio
import os
import tqdm
import subprocess
import numpy as np
import time


library = "ffmpeg" #@param ["cv2","imageio","ffmpeg","skvideo","scipy","moviepy"]
delay = "0.1" #@param [0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]


if (library == "ffmpeg"):
    !pip install ffmpeg-python
    import ffmpeg
    path = root_dir
    full_path = os.path.join(path, file_name)

    probe = ffmpeg.probe(full_path)
    video_info = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
    fps = video_info['r_frame_rate']
    duration = float(video_info['duration'])
    frame_count = int(video_info['nb_frames'])

    print("FPS: ", fps)
    print("Duration: ", duration)
    print("Frames: ", frame_count)

    pbar_ffmpeg = tqdm.tqdm(total=frame_count, ncols=100, position=0, leave=True)
    process = (
        ffmpeg
        .input(full_path)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24', qscale=0)
        .run_async(pipe_stdout=True)
    )

    for i in range(frame_count):
        try:
            raw_video = process.stdout.read(video_info['width'] * video_info['height'] * 3)
            frame = np.frombuffer(raw_video, dtype='uint8').reshape((video_info['height'], video_info['width'], 3))
            frame_path = f"{real_input_folder}/{i:09d}.jpg"
            if os.path.isfile(frame_path):
              pbar_ffmpeg.update(1)
              continue
            imageio.imwrite(frame_path, frame)
        except Exception as e:
            print(f"Error writing to disk: {str(e)}. Retrying...")
            continue
        pbar_ffmpeg.update(1)
        time.sleep(float(delay))

    pbar_ffmpeg.close()
    process.wait()

#check
import os

def check_frames(frame_dir=None):
    """Report gaps in the numbered ``.jpg`` frame sequence of *frame_dir*.

    Prints the lowest and highest frame number found, then either the missing
    numbers in between or "All frames present". File names may carry a
    ``frame`` prefix. Non-numeric ``.jpg`` names are skipped with a message.

    Args:
        frame_dir: Directory to scan; defaults to the global ``real_input_folder``.

    Returns:
        Sorted list of missing frame numbers (empty when none are missing or
        no numbered frames were found).
    """
    if frame_dir is None:
        frame_dir = real_input_folder

    frames = set()
    for name in os.listdir(frame_dir):
        if not name.endswith('.jpg'):
            continue
        try:
            frames.add(int(name.split('.')[0].replace('frame', '')))
        except ValueError:
            print(f"Skipping file with non-numeric name: {name}")

    if not frames:
        # Previously min() of an empty list raised and the caller retried pointlessly.
        print(f"No numbered .jpg frames found in {frame_dir}")
        return []

    min_frame = min(frames)
    max_frame = max(frames)
    print(min_frame)
    print(max_frame)

    # Set membership keeps this linear; the old list scan was O(n^2) on long videos.
    missing_frames = [i for i in range(min_frame, max_frame + 1) if i not in frames]
    if missing_frames:
        print(f"Missing frames: {missing_frames}")
    else:
        print("All frames present")
    return missing_frames

import time

# Retry the frame check, presumably to ride out transient Drive (FUSE) I/O errors.
attempts = 0
max_attempts = 10

while attempts < max_attempts:
    try:
        check_frames()
        break
    except Exception as e:
        attempts += 1
        print(f"Attempt {attempts} failed with error: {str(e)}")
        if attempts == max_attempts:
            print("Maximum attempts reached. Execution failed.")
        else:
            print("Retrying...")
            # Pause briefly; immediate back-to-back retries give a flaky mount no time to recover.
            time.sleep(1)


In [None]:
#@title ##**Run restore white balance** { display-mode: "form" }
import shutil
from tqdm import tqdm
import os
import re
import cv2
import numpy as np
import glob

def process_frame(img):
    """Return a white-balanced copy of the BGR frame *img* using the current UI settings.

    With Auto checked, the weights come from iterative_white_balance(img);
    otherwise they are read from the manual sliders. The correction is then
    mixed with the original by the Blend Factor slider.

    NOTE: this redefines the global name ``process_frame`` that the Video Config
    cell also defines.
    """
    if checkbox_auto.value:
        found = iterative_white_balance(img)
        w_hist = found["Histogram"]
        w_temp = found["Color Temperature"]
        w_gray = found["Gray World"]
    else:
        w_hist, w_temp, w_gray = slider_histogram.value, slider_color_temp.value, slider_gray_world.value

    def as_rgb_float(bgr):
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

    corrected = as_rgb_float(apply_algorithms(img, w_hist, w_temp, w_gray))
    original = as_rgb_float(img)

    mix = slider_blend.value
    blended = (1 - mix) * original + mix * corrected

    return cv2.cvtColor(blended.clip(0, 255).astype(np.uint8), cv2.COLOR_RGB2BGR)

def check_frames(frame_dir=None):
    """Report gaps in the numbered ``.jpg`` frame sequence of a folder.

    Frame numbers are parsed from ``.jpg`` names such as ``000000042.jpg``
    (any ``frame`` text in the stem is removed first). The smallest and
    largest frame numbers are printed, then either the missing numbers or
    "All frames present".

    Args:
        frame_dir: folder to inspect. Defaults to the global
            ``real_input_folder`` so existing ``check_frames()`` calls work.

    Returns:
        Sorted list of frame numbers missing between the min and max frame.

    Raises:
        ValueError: if the folder has no ``.jpg`` files or a ``.jpg`` stem is
            not numeric (the retry loop below catches this).
        OSError: if the folder cannot be listed.
    """
    if frame_dir is None:
        frame_dir = real_input_folder
    frames = {
        int(name.split('.')[0].replace('frame', ''))
        for name in os.listdir(frame_dir)
        if name.endswith('.jpg')
    }
    if not frames:
        raise ValueError(f"No .jpg frames found in {frame_dir}")
    min_frame = min(frames)
    max_frame = max(frames)
    print(min_frame)
    print(max_frame)

    # Set membership keeps the scan linear; a list lookup made it O(n^2).
    missing_frames = [i for i in range(min_frame, max_frame + 1) if i not in frames]

    if missing_frames:
        print(f"Missing frames: {missing_frames}")
    else:
        print("All frames present")
    return missing_frames

import time

attempts = 0
max_attempts = 10
# Pause between attempts; presumably the retry exists to ride out a Drive
# mount that is still syncing, which instant retries cannot do.
retry_delay_s = 1.0

# Only filesystem / parse errors are retried. Anything else (e.g. NameError
# because an earlier cell was not run) is not fixed by retrying, so it is
# allowed to propagate instead of being printed ten times.
while attempts < max_attempts:
    try:
        check_frames()
        break
    except (OSError, ValueError) as e:
        attempts += 1
        print(f"Attempt {attempts} failed with error: {str(e)}")
        if attempts == max_attempts:
            print("Maximum attempts reached. Execution failed.")
        else:
            print("Retrying...")
            time.sleep(retry_delay_s)

import time

upload_folder = "/content/upload"
result_folder = "/content/results"


def _reset_dir(path):
    """Remove ``path`` if it exists and recreate it as an empty directory."""
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _retry_copy(copy_fn, retry_message, attempts=10, delay_s=1.0):
    """Call ``copy_fn()`` until it succeeds, at most ``attempts`` times.

    Only ``OSError`` is retried (``shutil.Error`` derives from it) and the
    final failure is re-raised. The previous bare ``except: continue`` loop
    never gave up on a permanent error and also swallowed
    ``KeyboardInterrupt``, so the cell could not be interrupted.
    """
    for attempt in range(1, attempts + 1):
        try:
            copy_fn()
            return
        except OSError:
            if attempt == attempts:
                raise
            print(retry_message)
            time.sleep(delay_s)


_reset_dir(upload_folder)
_reset_dir(result_folder)

file_list = os.listdir(real_input_folder)
file_list.sort()
# Only names like 000000042.jpg are frames; other .jpg files no longer crash int().
frame_name_re = re.compile(r'(\d+)\.jpg')
frames = [int(m.group(1)) for m in map(frame_name_re.fullmatch, file_list) if m]
if not frames:
    raise ValueError(f"No numbered .jpg frames found in {real_input_folder}")
min_frame = min(frames)

# Resume after the highest frame already written to the output folder. An
# output folder holding only non-frame files used to crash on max([]).
real_files = os.listdir(real_output_folder)
real_frames = [int(re.findall(r'(\d+)\.jpg', f)[0]) for f in real_files if re.match(r'\d+\.jpg', f)]
start_frame = max(real_frames) + 1 if real_frames else min_frame

# max() instead of frames[-1]: does not depend on names sorting numerically.
max_frame = max(frames)
print(f"max frame: {max_frame}")
# Set lookup: testing membership against the list was O(n^2) in frame count.
available_files = set(file_list)
files_to_copy = [f"{real_input_folder}/{frame:09d}.jpg" for frame in range(start_frame, max_frame+1) if f"{frame:09d}.jpg" in available_files]

total_files = len(files_to_copy)
batch_size = 10
num_iterations = (total_files + batch_size - 1) // batch_size

print(f"start frame: {start_frame}")
print(f"min frame: {min_frame}")
print(f"total: {total_files}")
print(f"iterations: {num_iterations}")

with tqdm(total=num_iterations) as pbar:
    for i in range(0, total_files, batch_size):
        batch_files = files_to_copy[i:i+batch_size]

        # Stage the batch locally, retrying flaky (e.g. Drive) reads.
        for src_path in batch_files:
            _retry_copy(lambda src=src_path: shutil.copy(src, upload_folder),
                        f"File {src_path} failed to copy. Retrying...")

        for img_path in glob.glob(os.path.join(upload_folder, "*.jpg")):
            img = cv2.imread(img_path)
            if img is None:
                print(f"Failed to load image: {img_path}")
                continue

            # Not named "output": that global holds the file picker's
            # ipywidgets Output, which its button callbacks use as a context
            # manager.
            processed_frame = process_frame(img)

            imgname = os.path.basename(img_path)
            save_path = os.path.join(result_folder, imgname)
            if not cv2.imwrite(save_path, processed_frame):
                print(f"Failed to save image: {save_path}")

        _retry_copy(lambda: shutil.copytree(result_folder, real_output_folder, dirs_exist_ok=True),
                    "Failed to copy result folder. Retrying...")

        _reset_dir(upload_folder)
        _reset_dir(result_folder)

        pbar.update(1)

total_upload_files = len(os.listdir(upload_folder))
print(f"total upload files: {total_upload_files}")

def check_frames(frame_dir=None):
    """Report gaps in the numbered ``.jpg`` frame sequence of a folder.

    Frame numbers are parsed from ``.jpg`` names such as ``000000042.jpg``
    (any ``frame`` text in the stem is removed first). The smallest and
    largest frame numbers are printed, then either the missing numbers or
    "All frames present".

    Args:
        frame_dir: folder to inspect. Defaults to the global
            ``real_output_folder`` (the processed frames), so existing
            ``check_frames()`` calls work.

    Returns:
        Sorted list of frame numbers missing between the min and max frame.

    Raises:
        ValueError: if the folder has no ``.jpg`` files or a ``.jpg`` stem is
            not numeric (the retry loop below catches this).
        OSError: if the folder cannot be listed.
    """
    if frame_dir is None:
        frame_dir = real_output_folder
    frames = {
        int(name.split('.')[0].replace('frame', ''))
        for name in os.listdir(frame_dir)
        if name.endswith('.jpg')
    }
    if not frames:
        raise ValueError(f"No .jpg frames found in {frame_dir}")
    min_frame = min(frames)
    max_frame = max(frames)
    print(min_frame)
    print(max_frame)

    # Set membership keeps the scan linear; a list lookup made it O(n^2).
    missing_frames = [i for i in range(min_frame, max_frame + 1) if i not in frames]

    if missing_frames:
        print(f"Missing frames: {missing_frames}")
    else:
        print("All frames present")
    return missing_frames

import time

attempts = 0
max_attempts = 10
# Pause between attempts; presumably the retry exists to ride out a Drive
# mount that is still syncing, which instant retries cannot do.
retry_delay_s = 1.0

# Only filesystem / parse errors are retried. Anything else (e.g. NameError
# because an earlier cell was not run) is not fixed by retrying, so it is
# allowed to propagate instead of being printed ten times.
while attempts < max_attempts:
    try:
        check_frames()
        break
    except (OSError, ValueError) as e:
        attempts += 1
        print(f"Attempt {attempts} failed with error: {str(e)}")
        if attempts == max_attempts:
            print("Maximum attempts reached. Execution failed.")
        else:
            print("Retrying...")
            time.sleep(retry_delay_s)

In [None]:
#@title ##**Create video** { display-mode: "form" }
import cv2
import os
import subprocess
import time
from tqdm.notebook import tqdm
import torch
import gc

gc.collect()


def log_time(start, message):
    """Print the seconds elapsed since ``start`` and return a fresh timestamp.

    Args:
        start: ``time.time()`` value taken when the step began.
        message: label printed in front of the elapsed time.

    Returns:
        The current ``time.time()``, meant as ``start`` for the next step.
    """
    seconds_taken = time.time() - start
    elapsed = seconds_taken
    print(f"{message}: {elapsed:.2f} seconds")
    now = time.time()
    return now

start_time = time.time()

upscaled_image = 100 #@param {type:"slider", min:0, max:100, step:1}

print(f"output_folder: {output_folder}")

# file_name starts as None in the selection cell, and os.path.exists(None)
# raises TypeError, so test truthiness before touching the filesystem.
if 'file_name' in locals() and file_name and os.path.exists(file_name):
    base_file_name = os.path.basename(file_name)
else:
    raise ValueError("file_name is not defined or the file does not exist")

# "google_drive" saves to Drive; "root" and any other value save to /content.
save_path = '/content/drive/MyDrive/' if output_folder == "google_drive" else '/content/'

# file_name was verified to exist above, so it is the source video path.
full_path = file_name
output_file_name = base_file_name.rsplit('.', 1)[0] + f'_upscale_{upscaled_image}.mp4'
output_file = os.path.join(save_path, output_file_name)
temp_video = "/content/temp_video.mp4"

start_time = log_time(start_time, "Initial setup")

cap = cv2.VideoCapture(full_path)
# Keep the fractional rate: int() turned e.g. 29.97 or 23.976 fps into 29 or
# 23, so the rebuilt video ran longer than the audio copied from the source.
fps_of_video = cap.get(cv2.CAP_PROP_FPS)
cap.release()
if not fps_of_video or fps_of_video <= 0:
    raise ValueError(f"Could not read the frame rate of {full_path}")

upscaled_img_files = [os.path.join(real_output_folder, img) for img in os.listdir(real_output_folder) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
upscaled_img_files.sort()

if upscaled_image < 100:
    original_img_files = [os.path.join(real_input_folder, img) for img in os.listdir(real_input_folder) if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
    original_img_files.sort()
    if len(upscaled_img_files) != len(original_img_files):
        raise ValueError("Number of upscaled and original frames does not match")

if not upscaled_img_files:
    raise ValueError("No images found in the upscaled frames folder")

first_frame = cv2.imread(upscaled_img_files[0], cv2.IMREAD_COLOR)
if first_frame is None:
    raise ValueError(f"Could not read frame: {upscaled_img_files[0]}")
# The output video takes its size from the first processed frame.
height, width = first_frame.shape[:2]
del first_frame

# Sample the first 10 frames for size mismatches.
needs_resize = False
for img in upscaled_img_files[:10]:
    frame = cv2.imread(img, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError(f"Could not read frame: {img}")
    if frame.shape[:2] != (height, width):
        needs_resize = True
        break
    del frame

start_time = log_time(start_time, "Frame list preparation")

def get_video_bitrate(file_path):
    """Return the bit rate (bits/s) of the first video stream, or ``None``.

    Queries ``ffprobe`` for ``stream=bit_rate``. ``None`` is returned when the
    reported value is empty or non-numeric (e.g. ``N/A``), or when
    ``ffprobe`` cannot be launched. The caller then uses its default bit
    rate instead of crashing.

    Args:
        file_path: path of the video to probe.
    """
    cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=bit_rate', '-of', 'default=noprint_wrappers=1:nokey=1', file_path]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except OSError:
        # ffprobe missing or not executable: the bit rate is simply unknown.
        return None
    bitrate = result.stdout.strip()
    try:
        return int(bitrate)
    except ValueError:
        return None

# Encode at 1.5x the source video bit rate, or 7500k when it is unknown/zero.
bitrate = get_video_bitrate(full_path)
if not bitrate:
    bitrate_str = '7500k'
else:
    bitrate = int(bitrate * 1.5)
    bitrate_str = f'{bitrate // 1000}k'

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_video, fourcc, fps_of_video, (width, height))
# VideoWriter signals failure only through isOpened(); without this check the
# writes below would be discarded and the error would surface much later.
if not out.isOpened():
    raise RuntimeError(f"Could not open video writer for {temp_video}")


def _read_frame(path):
    """Load ``path`` as a BGR image, raising if it cannot be read.

    Skipping an unreadable frame would shorten the video relative to its
    audio, so it is treated as an error.
    """
    frame = cv2.imread(path, cv2.IMREAD_COLOR)
    if frame is None:
        raise ValueError(f"Could not read frame: {path}")
    return frame


def _fit_frame(frame):
    """Return ``frame`` resized to the writer's ``(width, height)`` if it differs.

    Every frame is checked (``needs_resize`` only samples the first 10), since
    the writer expects all frames at the size it was opened with.
    """
    if frame.shape[:2] != (height, width):
        frame = cv2.resize(frame, (width, height))
    return frame


try:
    if upscaled_image == 100:
        for img_file in tqdm(upscaled_img_files, desc="Processing frames"):
            out.write(_fit_frame(_read_frame(img_file)))
    else:
        # Mix processed and original frames in the ratio set by the slider.
        alpha = upscaled_image / 100.0
        beta = 1 - alpha
        for upscaled_img, original_img in tqdm(zip(upscaled_img_files, original_img_files), total=len(upscaled_img_files), desc="Processing frames"):
            upscaled_frame = _fit_frame(_read_frame(upscaled_img))
            original_frame = _fit_frame(_read_frame(original_img))
            out.write(cv2.addWeighted(upscaled_frame, alpha, original_frame, beta, 0))
finally:
    # Release even if a frame fails, so the file handle is not leaked.
    out.release()
gc.collect()

start_time = log_time(start_time, "Frame processing and writing")


temp_converted = "/content/temp_converted.mp4"
# Re-encode the OpenCV mp4v output as H.264 at the bit rate chosen earlier.
cmd = [
    'ffmpeg',
    '-i', temp_video,
    '-c:v', 'libx264',
    '-b:v', bitrate_str,
    '-y', temp_converted,
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode:
    print(f"FFmpeg conversion failed: {result.stderr}")
    raise RuntimeError("Conversion to libx264 failed")
# Put the H.264 file in temp_video's place for the muxing step.
os.remove(temp_video)
os.rename(temp_converted, temp_video)

start_time = log_time(start_time, "FFmpeg conversion to libx264")

# Video comes from the re-encoded temp file; audio and subtitles (if any) come
# from the source. Audio is stream-copied, but MP4 cannot hold e.g. SRT/ASS
# subtitle streams as-is (MKV input is allowed by the picker), so subtitles
# are converted to mov_text.
cmd = ['ffmpeg', '-i', temp_video, '-i', full_path, '-map', '0:v', '-map', '1:a?', '-map', '1:s?', '-c', 'copy', '-c:s', 'mov_text', '-y', output_file]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
    # Image-based subtitles cannot become mov_text; keep video and audio.
    print("FFmpeg muxing with subtitles failed, retrying without subtitles")
    cmd = ['ffmpeg', '-i', temp_video, '-i', full_path, '-map', '0:v', '-map', '1:a?', '-c', 'copy', '-y', output_file]
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
    print(f"FFmpeg audio muxing failed: {result.stderr}")
    raise RuntimeError("Audio muxing failed")

start_time = log_time(start_time, "Final audio and subtitles muxing")

# Report the outcome; the temp video is deleted only when the final file exists.
if not os.path.exists(output_file):
    print("Failed to create video")
    print(f"Expected save path: {output_file}")
    print(f"FFmpeg error output: {result.stderr}")
else:
    if os.path.exists(temp_video):
        os.remove(temp_video)
    print("Video created successfully")
    print(f"Video saved at: {output_file}")

start_time = log_time(start_time, "Cleanup")

In [None]:
#@title ##**Compare videos (optional)** { display-mode: "form" }
from IPython.display import display, HTML
import os
import base64

original_video_path = file_name
# output_file exists only after the "Create video" cell has run; read it
# defensively so a missing value reaches the error message below instead of
# raising NameError.
processed_video_path = locals().get('output_file')

# Check for None first: os.path.exists(None) raises TypeError rather than
# returning False, which would hide the intended error.
if not original_video_path or not os.path.exists(original_video_path):
    raise ValueError(f"Оригинальное видео не найдено по пути: {original_video_path}")
if not processed_video_path or not os.path.exists(processed_video_path):
    raise ValueError(f"Обработанное видео не найдено по пути: {processed_video_path}")

# File sizes in MiB.
original_size = os.path.getsize(original_video_path) / (1024 * 1024)
processed_size = os.path.getsize(processed_video_path) / (1024 * 1024)
print(f"Размер оригинального видео: {original_size:.2f} МБ")
print(f"Размер обработанного видео: {processed_size:.2f} МБ")


def video_to_base64(video_path):
    """Read ``video_path`` fully and return its bytes base64-encoded as ``str``.

    The result is embedded in the comparison HTML as a ``data:`` URI.
    """
    with open(video_path, "rb") as fh:
        return base64.b64encode(fh.read()).decode('utf-8')

# Embed both videos inline as base64 data: URIs. Each file is read fully into
# memory and its encoded text (about 4/3 of the file size) goes into the
# notebook output.
# NOTE(review): both sources are labelled video/mp4, but the original may be an
# .mkv/.avi/.mov picked in the first cell; browsers may refuse to play those.
original_base64 = video_to_base64(original_video_path)
processed_base64 = video_to_base64(processed_video_path)

# Side-by-side players with a shared Play/Pause button. The inline script
# mirrors play/pause between the two <video> elements and re-seeks one of them
# when their currentTime values drift more than 0.5 s apart. Doubled braces
# ({{ }}) are f-string escapes for literal JavaScript braces.
html_code = f"""
<div style="display: flex; justify-content: center; flex-direction: column; align-items: center;">
    <div style="display: flex; justify-content: center;">
        <div style="margin-right: 10px;">
            <video id="originalVideo" width="400" controls preload="auto">
                <source src="data:video/mp4;base64,{original_base64}" type="video/mp4">
                Ваш браузер не поддерживает видео.
            </video>
            <p>Оригинальное видео</p>
        </div>
        <div>
            <video id="processedVideo" width="400" controls preload="auto">
                <source src="data:video/mp4;base64,{processed_base64}" type="video/mp4">
                Ваш браузер не поддерживает видео.
            </video>
            <p>Обработанное видео</p>
        </div>
    </div>
    <button id="playPauseBtn" style="margin-top: 10px; padding: 10px 20px; font-size: 16px;">Play</button>
</div>
<script>
(function() {{
    var originalVideo = document.getElementById("originalVideo");
    var processedVideo = document.getElementById("processedVideo");
    var playPauseBtn = document.getElementById("playPauseBtn");
    var isPlaying = false;

    playPauseBtn.disabled = false;

    function playBoth() {{
        Promise.all([
            originalVideo.play().catch(function(error) {{
                console.log("Ошибка воспроизведения оригинального видео:", error);
            }}),
            processedVideo.play().catch(function(error) {{
                console.log("Ошибка воспроизведения обработанного видео:", error);
            }})
        ]).then(function() {{
            playPauseBtn.textContent = "Pause";
            isPlaying = true;
        }}).catch(function(error) {{
            console.log("Не удалось воспроизвести видео:", error);
        }});
    }}

    function pauseBoth() {{
        originalVideo.pause();
        processedVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }}

    playPauseBtn.addEventListener("click", function() {{
        if (isPlaying) {{
            pauseBoth();
        }} else {{
            playBoth();
        }}
    }});

    originalVideo.addEventListener("play", function() {{
        if (processedVideo.paused) processedVideo.play();
        playPauseBtn.textContent = "Pause";
        isPlaying = true;
    }});
    processedVideo.addEventListener("play", function() {{
        if (originalVideo.paused) originalVideo.play();
        playPauseBtn.textContent = "Pause";
        isPlaying = true;
    }});
    originalVideo.addEventListener("pause", function() {{
        if (!processedVideo.paused) processedVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }});
    processedVideo.addEventListener("pause", function() {{
        if (!originalVideo.paused) originalVideo.pause();
        playPauseBtn.textContent = "Play";
        isPlaying = false;
    }});

    originalVideo.addEventListener("timeupdate", function() {{
        if (Math.abs(originalVideo.currentTime - processedVideo.currentTime) > 0.5) {{
            processedVideo.currentTime = originalVideo.currentTime;
        }}
    }});
    processedVideo.addEventListener("timeupdate", function() {{
        if (Math.abs(processedVideo.currentTime - originalVideo.currentTime) > 0.5) {{
            originalVideo.currentTime = processedVideo.currentTime;
        }}
    }});

    console.log("Скрипт синхронизации видео инициализирован");
}})();
</script>
"""

# Render the two players in the cell output.
display(HTML(html_code))

In [None]:
#@title ##**Download video** { display-mode: "form" }
from google.colab import files
import os

# Offer the finished video for download once the "Create video" cell has produced it.
if 'output_file' in locals() and os.path.exists(output_file):
    files.download(output_file)
else:
    print("Video not found")