In [None]:
import os
import cv2
import tqdm
import torch
import rawpy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from scipy.optimize import least_squares


def stretch_histogram(image, clip_percentile=0):
    """Contrast-stretch each of the first three channels to the range [0, 1].

    Each channel is clipped to its [clip_percentile, 100 - clip_percentile]
    percentile range, shifted so its minimum is 0, then divided by its maximum.

    Args:
        image: array-like of shape (H, W, C) with C >= 3. It is copied, not modified.
        clip_percentile: percentage clipped from each tail before stretching.

    Returns:
        A new array. Integer inputs are promoted to float64; floating inputs keep
        their dtype. A constant channel becomes all zeros instead of NaN.
    """
    im = np.array(image)
    if not np.issubdtype(im.dtype, np.floating):
        # In-place true division (`/=`) on an integer array raises a casting error.
        im = im.astype(np.float64)
    for c in range(3):
        channel = im[:, :, c]  # a view: in-place edits below write back into `im`
        channel[...] = channel.clip(
            np.percentile(channel, clip_percentile),
            np.percentile(channel, 100 - clip_percentile)
        )
        channel -= channel.min()
        peak = channel.max()
        if peak > 0:
            # Skip constant channels, which would otherwise become 0/0 = NaN.
            channel /= peak
    return im


def process_depth(image, depth):
    """Fill pixels without depth (depth == 0) using the per-channel minimum of valid pixels.

    Only the first three channels are filled.

    Args:
        image: array of shape (H, W, C). It is copied, not modified.
        depth: array of shape (H, W); zero marks a missing depth value.

    Returns:
        A copy of `image` where every zero-depth pixel takes, per channel, the
        minimum over pixels with non-zero depth. If no pixel has valid depth,
        the copy is returned unchanged, since there is nothing to fill with.
    """
    im = np.array(image)
    # Previously this read the notebook-global `z` instead of the `depth` argument.
    invalid = depth == 0
    if not invalid.any() or invalid.all():
        return im
    im[invalid, :3] = im[~invalid, :3].min(axis=0)
    return im


def plot_image(image, **kwargs):
    """Min-max normalize `image` to [0, 1] and show it in a new 20x20-inch figure.

    Extra keyword arguments are forwarded to `plt.imshow`. A constant image is
    shown as all zeros rather than NaN.
    """
    im = image - image.min()
    peak = im.max()
    if peak > 0:
        # Guard: a constant image would otherwise divide 0 by 0.
        im = im / peak
    plt.figure(figsize=(20, 20))
    plt.imshow(im, **kwargs)


def save_image(image, image_name):
    """Min-max normalize `image` to 0..255 and save it as an 8-bit image file.

    Args:
        image: numeric array accepted by `PIL.Image.fromarray` once cast to uint8.
        image_name: destination path; PIL picks the format from the extension.
    """
    im = image - image.min()
    peak = im.max()
    if peak > 0:
        # Guard: a constant image would otherwise become NaN, which is undefined when cast to uint8.
        im = im / peak
    # Round to nearest instead of truncating, so values just below an integer are not biased down.
    im = np.rint(im * 255)
    Image.fromarray(im.astype(np.uint8)).save(image_name)


def load_image(data_path, image_name):
    """Load a raw capture as a 3-channel array built from its Bayer mosaic, plus its depth map.

    Args:
        data_path: pathlib.Path containing `Raw/<image_name>` and
            `depthMaps/depth<stem>.tif`.
        image_name: file name of the raw capture.

    Returns:
        (raw, I, z):
        - raw: the rawpy object, left open so callers can `postprocess` it.
        - I: float64 array with one sample per 2x2 mosaic cell and channels (R, G, B).
        - z: float64 depth map resized to I's height and width. Its non-zero
          entries are clipped to their 1st-99th percentile; zeros are left as-is.
    """
    stem = os.path.splitext(image_name)[0]
    raw = rawpy.imread(str(data_path / 'Raw' / image_name))
    mosaic = raw.raw_image_visible.astype(np.float64)
    # NOTE(review): assumes an RGGB layout (R top-left, B bottom-right); not checked against raw.raw_pattern.
    red = mosaic[::2, ::2]
    green = (mosaic[::2, 1::2] + mosaic[1::2, ::2]) / 2
    blue = mosaic[1::2, 1::2]
    I = np.dstack((red, green, blue))

    depth_path = data_path / 'depthMaps' / f'depth{stem}.tif'
    z = np.array(Image.open(depth_path), dtype=np.float64)
    height, width = I.shape[:2]
    z = cv2.resize(z, (width, height))
    # Zero depth marks a missing value, so only non-zero samples are clipped.
    valid = z != 0
    low, high = np.percentile(z[valid], [1, 99])
    z[valid] = np.clip(z[valid], low, high)
    return raw, I, z


def demosaic(data_path, image_name, I, white_balance=False):
    """Write a 3-channel image back into a raw mosaic and render it with rawpy.

    `I` is min-max rescaled to the value range of the capture's visible raw data.
    Its channels are then written to the four 2x2 mosaic sites (R top-left,
    G on both off-diagonals, B bottom-right), and the result is rendered with
    `postprocess(half_size=True)`.

    Args:
        data_path: pathlib.Path containing `Raw/<image_name>`.
        image_name: raw file whose metadata and mosaic layout are reused.
        I: array of shape (H, W, 3) matching the half-resolution mosaic grid.
        white_balance: forwarded to rawpy as `use_auto_wb`.

    Returns:
        The array returned by `raw.postprocess` (8-bit RGB with rawpy's defaults).
    """
    image_file = data_path / 'Raw' / image_name
    # The context manager releases the LibRaw handle, which was previously leaked.
    with rawpy.imread(str(image_file)) as raw:
        visible = raw.raw_image_visible  # view into the raw buffer; writes reach postprocess
        lo, hi = visible.min(), visible.max()
        I = I - I.min()
        peak = I.max()
        if peak > 0:
            # Guard: a constant input would otherwise become 0/0 = NaN before the uint16 cast.
            I = I / peak
        # Round instead of truncating when mapping back to integer raw values.
        I = np.rint(I * (hi - lo) + lo).astype(np.uint16)
        visible[::2, ::2] = I[:, :, 0]
        visible[::2, 1::2] = I[:, :, 1]
        visible[1::2, ::2] = I[:, :, 1]
        visible[1::2, 1::2] = I[:, :, 2]
        return raw.postprocess(half_size=True, use_auto_wb=white_balance)


def compute_omega(Ic, z, percentile, min_z_percentile=1, max_z_percentile=99, num_bins=10):
    """Select extreme-intensity pixel coordinates within each depth bin.

    The range of non-zero depths is split into `num_bins` equal-width bins. The
    last bin is closed, so pixels at the maximum depth are included. Within each
    bin, if `percentile < 50` the pixels whose `Ic` is strictly below that bin's
    `percentile`-th percentile are selected; otherwise those strictly above it.

    Args:
        Ic: 2-D array of intensities for one channel.
        z: 2-D depth map with the same shape; zero marks missing depth.
        percentile: per-bin percentile threshold (0-100).
        min_z_percentile, max_z_percentile: currently unused; kept so existing
            callers keep working.
        num_bins: number of depth bins.

    Returns:
        Integer array of shape (N, 2) holding (row, col) indices. It is empty
        when no pixel is selected.
    """
    valid = z != 0
    if not valid.any():
        return np.empty((0, 2), dtype=np.intp)

    edges = np.linspace(z[valid].min(), z[valid].max(), num_bins + 1)

    omega = []
    for i, (min_z, max_z) in enumerate(zip(edges[:-1], edges[1:])):
        # Close the last bin; otherwise pixels at exactly max(z) fall in no bin.
        upper = (z <= max_z) if i == num_bins - 1 else (z < max_z)
        args_z_in_range = np.argwhere(valid & (z >= min_z) & upper)
        if len(args_z_in_range) == 0:
            # np.percentile cannot handle an empty bin.
            continue
        Ic_in_range = Ic[args_z_in_range[:, 0], args_z_in_range[:, 1]]
        threshold = np.percentile(Ic_in_range, percentile)
        if percentile < 50:
            omega_range_mask = Ic_in_range < threshold
        else:
            omega_range_mask = Ic_in_range > threshold
        omega.append(args_z_in_range[omega_range_mask])

    if not omega:
        return np.empty((0, 2), dtype=np.intp)
    return np.vstack(omega)


def compute_backscatter(Ic, z):
    """Fit Bc_hat = Bc_inf * (1 - exp(-beta_Bc * z)) for one channel.

    The fit uses the darkest pixels per depth bin (below the 1st percentile, via
    `compute_omega`) as backscatter samples. Bounds are Bc_inf in [0, 1] and
    beta_Bc in [0, 5].

    Args:
        Ic: 2-D intensities for one channel.
        z: 2-D depth map of the same shape.

    Returns:
        np.ndarray [Bc_inf, beta_Bc].

    Raises:
        ValueError: if no backscatter samples were selected.
    """
    print('Computing backscatter')

    omega = compute_omega(Ic, z, 1)

    Bc_hat = Ic[omega[:, 0], omega[:, 1]]
    z_low = z[omega[:, 0], omega[:, 1]]
    if Bc_hat.size == 0:
        raise ValueError('No low-intensity pixels selected; cannot fit backscatter.')

    def residuals(x):
        Bc_inf, beta_Bc = x
        return Bc_hat - Bc_inf * (1 - np.exp(-beta_Bc * z_low))

    # least_squares rejects an x0 outside `bounds`, so keep the Bc_inf guess in [0, 1]
    # (e.g. when Ic is not normalized to that range).
    x0 = [np.clip(Bc_hat.mean(), 0, 1), 2.5]
    return least_squares(
        residuals,
        x0,
        bounds=([0, 0], [1, 5]),
        jac='3-point',
        verbose=2
    ).x


def compute_backscatter_with_residuals(Ic, z):
    """Fit backscatter with an extra residual direct-signal term for one channel.

    The model is
        Bc_hat = Bc_inf * (1 - exp(-beta_Bc * z)) + Jc_prime * exp(-beta_Dc_prime * z)
    on the darkest pixels per depth bin (below the 1st percentile, via
    `compute_omega`).

    Returns:
        np.ndarray [Bc_inf, beta_Bc]. The residual-term parameters are fitted but discarded.

    Raises:
        ValueError: if no backscatter samples were selected.
    """
    print('Computing backscatter')

    omega = compute_omega(Ic, z, 1)

    Bc_hat = Ic[omega[:, 0], omega[:, 1]]
    z_low = z[omega[:, 0], omega[:, 1]]
    if Bc_hat.size == 0:
        raise ValueError('No low-intensity pixels selected; cannot fit backscatter.')

    def residuals(x):
        Bc_inf, beta_Bc, Jc_prime, beta_Dc_prime = x
        backscatter = Bc_inf * (1 - np.exp(-beta_Bc * z_low))
        residual_signal = Jc_prime * np.exp(-beta_Dc_prime * z_low)
        return Bc_hat - backscatter - residual_signal

    # least_squares rejects an x0 outside `bounds`, so keep the Bc_inf guess in [0, 1].
    x0 = [np.clip(Bc_hat.mean(), 0, 1), 2.5, 0, 5]
    return least_squares(
        residuals,
        x0,
        bounds=([0, 0, 0, 0], [1, 5, 1, 5]),
        jac='3-point',
        verbose=2
    ).x[:2]


def beta_D(beta_D_a, beta_D_b, beta_D_c, beta_D_d, z):
    """Two-term exponential model beta_D(z) = a * exp(b * z) + c * exp(d * z).

    All arguments broadcast with ordinary NumPy arithmetic, so per-channel
    coefficient vectors can be combined with a depth map of shape (H, W, 1).
    """
    first_term = beta_D_a * np.exp(beta_D_b * z)
    second_term = beta_D_c * np.exp(beta_D_d * z)
    return first_term + second_term


def compute_beta_D(Dc, z, max_nfev=None):
    """Fit the four coefficients of `beta_D(z)` for one channel of the direct signal.

    The samples are the brightest pixels per depth bin (above the 99th
    percentile, via `compute_omega`). The model is Dc = Ac * exp(-beta_D(z) * z).
    For each candidate coefficient set, the scale Ac is solved in closed form
    (the least-squares optimum for fixed beta_D). Bounds are a, c >= 0 and
    b, d <= 0.

    Args:
        Dc: 2-D direct-signal intensities for one channel.
        z: 2-D depth map of the same shape.
        max_nfev: forwarded to `scipy.optimize.least_squares` (None = scipy's default).

    Returns:
        np.ndarray [a, b, c, d] for `beta_D`.

    Raises:
        ValueError: if no samples were selected.
    """
    print('Computing beta_D')

    omega = compute_omega(Dc, z, 99)

    Dc_top = Dc[omega[:, 0], omega[:, 1]]
    z_top = z[omega[:, 0], omega[:, 1]]
    if Dc_top.size == 0:
        raise ValueError('No high-intensity pixels selected; cannot fit beta_D.')

    def residuals(x):
        beta_Dc_a, beta_Dc_b, beta_Dc_c, beta_Dc_d = x
        beta_Dc = beta_D(beta_Dc_a, beta_Dc_b, beta_Dc_c, beta_Dc_d, z_top)
        attenuation = np.exp(-beta_Dc * z_top)
        denominator = np.sum(np.exp(-2 * beta_Dc * z_top))
        # The denominator can underflow to 0 for large beta_Dc * z; avoid a 0/0 NaN residual.
        Ac = np.sum(Dc_top * attenuation) / denominator if denominator > 0 else 0.0
        return Dc_top - Ac * attenuation

    return least_squares(
        residuals,
        [0.05, -0.05, 0.05, -0.05],
        bounds=([0, -np.inf, 0, -np.inf], [np.inf, 0, np.inf, 0]),
        jac='3-point',
        verbose=2,
        max_nfev=max_nfev
    ).x

In [None]:
# Input capture. Another capture used previously:
#   data_path = Path('dataset/D1P1/'); image_name = 'T_S03075.ARW'
data_path = Path('dataset/D5/')
image_name = 'LFT_3400.NEF'

# Load the image and depth map, then scale intensities so the peak is 1.
raw, I, z = load_image(data_path, image_name)
I = I / I.max()

# Per-channel backscatter fit: rows are channels, columns are (B_inf, beta_B).
backscatter_fits = np.array([compute_backscatter(I[:, :, c], z) for c in range(3)])
B_inf, beta_B = backscatter_fits.T.copy()

# Backscatter image and the remaining direct signal.
B = B_inf * (1 - np.exp(-beta_B * z[:, :, None]))
D = I - B

# Per-channel fit of the four beta_D coefficients: rows are channels, columns are (a, b, c, d).
attenuation_fits = np.array([compute_beta_D(D[:, :, c], z) for c in range(3)])
beta_D_a, beta_D_b, beta_D_c, beta_D_d = attenuation_fits.T.copy()

# Multiply the direct signal by exp(beta_D(z) * z).
J = D * np.exp(beta_D(beta_D_a, beta_D_b, beta_D_c, beta_D_d, z[:, :, None]) * z[:, :, None])

In [None]:
wb = True
# PIL's Image.save does not create missing directories.
os.makedirs('test', exist_ok=True)
Image.fromarray(raw.postprocess(half_size=True, use_auto_wb=wb)).save('test/raw.png')
# Render each pipeline stage through the raw mosaic, with missing-depth pixels filled and contrast stretched.
for label, stage in (('I', I), ('D', D), ('J', J)):
    rendered = demosaic(data_path, image_name, stretch_histogram(process_depth(stage, z)), white_balance=wb)
    Image.fromarray(rendered).save(f'test/{label}.png')

In [None]:
# NOTE(review): `J_wb` was never defined in this notebook (NameError on a fresh kernel).
# Reconstructed as the white-balanced, demosaiced J rendered in the previous cell; confirm this matches the intent.
J_wb = demosaic(data_path, image_name, stretch_histogram(process_depth(J, z)), white_balance=True)
plt.hist([J_wb[:, :, c].flatten() for c in range(3)], color=['r', 'g', 'b'])
plt.yscale('log')

In [None]:
plot_image(image=process_depth(image=B, depth=z))

In [None]:
plot_image(stretch_histogram(process_depth(I, z), clip_percentile=1))

In [None]:
plot_image(stretch_histogram(process_depth(J, z), clip_percentile=1))

In [None]:
np.max(B)

In [None]:
plot_image(image=I)

In [None]:
np.min(I)