In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
# Put the parent of the current working directory on the import path
# (presumably so the local `burybarrel` package imports when the notebook is
# run from its own folder); skip it if it is already there.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import json
from pathlib import Path
from typing import List

import cv2
import dataclass_array as dca
import matplotlib.pyplot as plt
import mitsuba as mi
import numpy as np
from PIL import Image
import pycolmap
import pyrender
import transforms3d as t3d
import trimesh
import visu3d as v3d

import burybarrel.colmap_util as cutil
from burybarrel.image import render_v3d, imgs_from_dir

In [None]:
def temporal_window(imgs, idx, window_size):
    """Return the run of consecutive frames centred on ``imgs[idx]``.

    An even ``window_size`` is bumped up to the next odd number so the window
    is symmetric around ``idx``. Near either end of the sequence the window is
    shifted inwards (the first or last ``window_size`` frames are returned)
    instead of being truncated. If the sequence is shorter than the window,
    the whole sequence is returned.

    Args:
        imgs: any sliceable sequence (list, numpy array, ...).
        idx: index of the centre frame.
        window_size: requested number of frames in the window.

    Returns:
        A slice of ``imgs``.
    """
    size = window_size if window_size % 2 else window_size + 1
    radius = size // 2
    if idx >= len(imgs) - radius:
        # Too close to the end: fall back to the trailing window.
        return imgs[-size:]
    if idx < radius:
        # Too close to the start: fall back to the leading window.
        return imgs[:size]
    return imgs[idx - radius:idx + radius + 1]

In [114]:
def find_img_quantiles(in_img, in_mask, quantiles, max_pixels=1e6):
    """Compute intensity quantiles of an image, optionally ignoring masked pixels.

    Images with more than ``max_pixels`` pixels are first downscaled so that
    they contain roughly ``max_pixels`` pixels, which keeps ``np.quantile``
    cheap. Images at or below that budget are used as-is.

    Args:
        in_img: image array. Callers in this notebook pass a single channel
            (``img[:, :, c]``).
        in_mask: optional mask with the same height/width as ``in_img``. Only
            pixels where the mask equals 0 contribute. ``None`` uses every pixel.
        quantiles: quantile level(s) in [0, 1], forwarded to ``np.quantile``.
        max_pixels: pixel budget above which the image is downscaled
            (default 1e6, the previously hard-coded value).

    Returns:
        The requested quantile(s) as returned by ``np.quantile``, cast
        (truncated) to ``int``.
    """
    # Uniform scale factor that brings width * height down to ~max_pixels.
    alpha = np.sqrt(max_pixels / (in_img.shape[1] * in_img.shape[0]))

    reduced_img = in_img
    reduced_mask = in_mask
    if alpha < 1.0:
        reduced_img = cv2.resize(in_img, (0, 0), fx=alpha, fy=alpha)
        if in_mask is not None:
            # Nearest-neighbour so the mask keeps its original label values.
            reduced_mask = cv2.resize(
                in_mask, (0, 0), fx=alpha, fy=alpha, interpolation=cv2.INTER_NEAREST
            )

    if reduced_mask is None:
        ch_values = reduced_img.ravel()
    else:
        ch_values = reduced_img[reduced_mask == 0].ravel()

    return np.quantile(ch_values, quantiles).astype(int)

def stretch_color_img(in_img, ch1_lim, ch2_lim, ch3_lim, gamma_undo):
    """Histogram-stretch the first three channels of ``in_img`` independently.

    Each ``chK_lim`` supplies a ``(low, high)`` pair (0-255 scale) via its
    first two elements. ``histogram_stretch`` maps that pair onto the full
    ``(0, 255)`` output range for channel K.

    Returns:
        The three stretched channels merged back into a single image.
    """
    channels = cv2.split(in_img)
    full_range = (0, 255)
    limits = (ch1_lim, ch2_lim, ch3_lim)
    # Index explicitly (not zip) so an image with fewer than three channels
    # still raises rather than silently dropping channels.
    stretched = [
        histogram_stretch(channels[k], (lim[0], lim[1]), full_range, gamma_undo)
        for k, lim in enumerate(limits)
    ]
    return cv2.merge(stretched)

def histogram_stretch(in_img, low_high_in, low_high_out, gamma_undo):
    """Linearly remap image intensities through a 256-entry lookup table.

    ``low_high_in[0]`` is mapped to ``low_high_out[0]`` and ``low_high_in[1]``
    to ``low_high_out[1]`` (all on a 0-255 scale). Values in between are
    interpolated linearly. Values outside the input range are extrapolated and
    the result is clipped to [0, 255].

    Args:
        in_img: image with 8-bit values (applied with ``cv2.LUT``).
        low_high_in: ``(low, high)`` input limits, 0-255 scale.
        low_high_out: ``(low, high)`` output limits, 0-255 scale.
        gamma_undo: if True, the linear mapping is performed on intensities
            passed through ``rgb2linf``, and the result is re-encoded with
            ``lin2rgbf``.

    Returns:
        The remapped image. If the input limits are less than one intensity
        level apart, a copy of ``in_img`` is returned.
    """
    # Degenerate input range: nothing sensible to stretch.
    if abs(low_high_in[1] - low_high_in[0]) < 1:
        return in_img.copy()

    # Normalise the limits and every possible 8-bit intensity to [0, 1].
    inv_max_val = 1.0 / 255.0
    low_in, high_in = low_high_in[0] * inv_max_val, low_high_in[1] * inv_max_val
    low_out, high_out = low_high_out[0] * inv_max_val, low_high_out[1] * inv_max_val
    intensities = np.arange(256) * inv_max_val

    if gamma_undo:
        low_in, high_in = rgb2linf(low_in), rgb2linf(high_in)
        low_out, high_out = rgb2linf(low_out), rgb2linf(high_out)
        intensities = rgb2linf(intensities)

    # The slope must be computed in the same space the mapping is applied in
    # (i.e. after linearisation). Otherwise, with gamma_undo=True, high_in
    # would not be mapped to high_out.
    coef = (high_out - low_out) / (high_in - low_in)
    out_intensity = coef * (intensities - low_in) + low_out

    if gamma_undo:
        # Clamp before re-applying gamma: a fractional power of a negative
        # value (intensities below low_in) is NaN/complex, not a valid level.
        out_intensity = lin2rgbf(np.clip(out_intensity, 0.0, 1.0))

    look_up_table = np.clip(255.0 * out_intensity, 0, 255).astype(np.uint8)

    # Apply the intensity mapping.
    return cv2.LUT(in_img, look_up_table)

def rgb2linf(value):
    """Undo display gamma by raising ``value`` to the power 2.19921875.

    Used by ``histogram_stretch`` on intensities normalised to [0, 1].
    Accepts scalars or numpy arrays.
    """
    return pow(value, 2.19921875)

def lin2rgbf(value):
    """Re-apply display gamma by raising ``value`` to the power 0.45470692.

    0.45470692 approximates 1 / 2.19921875, the exponent used by
    ``rgb2linf``. Input should be non-negative: a negative value yields NaN
    for numpy floats or a complex number for Python floats.
    """
    return pow(value, 0.45470692)

In [None]:
# Sanity check: at the end of the sequence the window is shifted inwards
# rather than truncated, so it still holds `window_size` frames.
window_check = temporal_window(list(range(10)), 9, 5)
assert window_check == [5, 6, 7, 8, 9], window_check
window_check

In [None]:
# Output directory for the corrected frames and input clip directory.
outdir = Path("../results/dive3-depthcharge-03-04-trimmed-reconstr/corrected")
imgdir = Path("../data/dive-data/Dive3/clips/dive3-depthcharge-03-04-trimmed")
# Create the output directory up front: saving into a missing directory fails.
outdir.mkdir(parents=True, exist_ok=True)
imgpaths, imgs = imgs_from_dir(imgdir, asarray=True)

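Per-pixel illumination-correction gain $R(x,y)$, applied to each frame as $f'(x,y) = f(x,y)\,R(x,y)$:

$$R(x,y)=\overline{\bar{f}(x,y)}\cdot\min\left(\frac{1}{\bar{f}(x,y)},\ \frac{\text{Maxscale}}{\bar{f}(x,y)+4\sigma(x,y)}\right)$$

The terms are defined as follows:
- $\bar{f}(x,y)$ and $\sigma(x,y)$ are the per-pixel statistics over a temporal window of neighbouring frames. The code below uses the temporal median for $\bar{f}$ and the standard deviation for $\sigma$.
- $\overline{\bar{f}(x,y)}$ is the global level of that window. The code uses the median of all its values.
- $\text{Maxscale}$ bounds the gain. It is set to 255 below.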

In [125]:
# Light-correct each frame with temporal statistics from its neighbours,
# colour-stretch it, and save it into `outdir`.
window_size = 13   # frames per temporal window (odd, so centred on the frame)
maxscale = 255     # "Maxscale" in the gain formula above
m_sat_thres = 0.001  # fraction of pixels allowed to saturate at each end
quantiles = [m_sat_thres, 1.0 - m_sat_thres]

assert len(imgs) > 0, "no input frames loaded"

# Stretch limits come from the first raw frame and are shared by every output
# frame. They do not depend on the loop, so they are computed once.
# NOTE(review): the stretch is applied to the light-corrected frame, whose
# intensity distribution differs from raw imgs[0]. Confirm this reference is
# intended, rather than per-frame quantiles of `lightcorrected`.
img = imgs[0]
ch1_lim = find_img_quantiles(img[:, :, 0], None, quantiles)
ch2_lim = find_img_quantiles(img[:, :, 1], None, quantiles)
ch3_lim = find_img_quantiles(img[:, :, 2], None, quantiles)

for i, frame in enumerate(imgs):
    windowimg = temporal_window(imgs, i, window_size)
    avgpx = np.median(windowimg, axis=0)  # per-pixel temporal median
    stdpx = np.std(windowimg, axis=0)     # per-pixel temporal std
    avgall = np.median(windowimg)         # median over the whole window
    # An all-zero pixel history gives 1/0 -> inf and then 0 * inf -> NaN.
    # Those pixels are mapped to 0, so silence the expected warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        correction = avgall * np.minimum(1 / avgpx, maxscale / (avgpx + 4 * stdpx))
        corrected = np.nan_to_num(frame * correction, nan=0.0, posinf=255.0)
    # Clip before casting: out-of-range floats cast to uint8 would wrap around
    # (or be undefined) instead of saturating at 255.
    lightcorrected = np.clip(corrected, 0, 255).astype(np.uint8)
    colcorrected = stretch_color_img(lightcorrected, ch1_lim, ch2_lim, ch3_lim, False)
    Image.fromarray(colcorrected).save(outdir / imgpaths[i].name)