In [None]:
import os
import cv2
import tqdm
import torch
import rawpy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from scipy.optimize import least_squares


def load_image(data_path, image_name):
    """Load a raw photo and its depth map, resizing the photo to the depth map's size.

    File layout expected under ``data_path``:
        ``Raw/<image_name>``                   - the raw photo
        ``depthMaps/depth<stem>.tif``          - the depth map, where ``<stem>`` is
                                                 ``image_name`` without its extension

    Args:
        data_path: dataset directory, as ``str`` or ``pathlib.Path``.
        image_name: raw file name, e.g. ``'T_S03136.ARW'``.

    Returns:
        ``(rgb, depth)`` where ``depth`` is the depth map as float64 and ``rgb`` is the
        rawpy-developed photo divided by 255 and resized to the depth map's width/height.
    """
    data_path = Path(data_path)
    image_file = data_path / 'Raw' / image_name
    depth_file = data_path / 'depthMaps' / f'depth{os.path.splitext(image_name)[0]}.tif'

    # Context managers release both file handles as soon as the pixels are decoded,
    # instead of leaving them open until garbage collection.
    with Image.open(depth_file) as depth_image:
        depth = np.array(depth_image, dtype=np.float64)
    with rawpy.imread(str(image_file)) as raw:
        rgb = raw.postprocess()

    # NOTE(review): dividing by 255 assumes postprocess() returned 8-bit data (its
    # default output) - revisit if postprocess options are ever changed.
    # cv2.resize takes (width, height), hence the reversed (rows, cols) depth shape.
    rgb = cv2.resize(rgb / 255, depth.shape[::-1])
    return rgb, depth


def plot_image(image, **kwargs):
    """Draw *image* in a fresh, large square matplotlib figure.

    Every call opens a new 20 x 20 inch figure. Any keyword arguments are passed
    unchanged to ``matplotlib.pyplot.imshow``. Nothing is returned.
    """
    figure_inches = (20, 20)
    plt.figure(figsize=figure_inches)
    plt.imshow(image, **kwargs)


def save_image(image, image_name):
    """Min-max normalise *image* to the 0..255 range and save it as 8-bit.

    PIL picks the file format from ``image_name``'s extension. A constant image
    (max == min) is written as all zeros, rather than dividing by zero and
    casting NaNs to uint8.

    Args:
        image: numeric (or boolean) array; once cast to uint8 it must be a shape
            ``PIL.Image.fromarray`` accepts (e.g. H x W or H x W x 3 - not checked here).
        image_name: destination path.
    """
    # Work in float64 so integer/boolean inputs subtract and divide safely.
    im = np.asarray(image, dtype=np.float64)
    im = im - im.min()
    peak = im.max()
    if peak > 0:
        im = im / peak
    # astype(np.uint8) truncates toward zero, matching the previous behaviour.
    im = im * 255
    Image.fromarray(im.astype(np.uint8)).save(image_name)


def stretch_histogram(image, clip_percentile=0.0):
    """Independently contrast-stretch every channel of an H x W x C image to [0, 1].

    For each channel, values are first clipped to the
    [clip_percentile, 100 - clip_percentile] percentile range of that channel and
    then min-max normalised. A channel with no spread after clipping becomes all zeros.

    Args:
        image: H x W x C array. Integer inputs are converted to float64; floating
            inputs keep their dtype. The input array is never modified.
        clip_percentile: percentile (0-50) trimmed from each end before stretching.

    Returns:
        A new array with the same shape as ``image`` and values in [0, 1].
    """
    if np.issubdtype(image.dtype, np.floating):
        im = image.copy()
    else:
        # In-place float arithmetic on an integer array would truncate or raise.
        im = image.astype(np.float64)

    for channel in range(im.shape[2]):
        band = im[:, :, channel]
        low, high = np.percentile(band, [clip_percentile, 100 - clip_percentile])
        band = band.clip(low, high)
        band_min = band.min()
        span = band.max() - band_min
        if span > 0:
            im[:, :, channel] = (band - band_min) / span
        else:
            # Constant channel: avoid 0/0 -> NaN.
            im[:, :, channel] = 0.0
    return im


def compute_omega(image, depth, min_depth_percentile=1, max_depth_percentile=99, num_bins=10):
    """Select the darkest pixels inside each of ``num_bins`` equal-width depth bins.

    The span between the ``min_depth_percentile`` and ``max_depth_percentile``
    percentiles of ``depth`` is split into ``num_bins`` equal-width bins. The bins
    are half-open [lo, hi), except the last, which also includes its upper edge.
    Within each non-empty bin, a pixel is kept when its mean value over
    ``image[row, col]`` is at or below that bin's 1st percentile.

    Args:
        image: array indexed as ``image[row, col]`` giving a vector per pixel
            (e.g. H x W x 3); its first two dimensions must match ``depth``.
        depth: H x W depth map.
        min_depth_percentile: lower percentile of ``depth`` bounding the binned range.
        max_depth_percentile: upper percentile of ``depth`` bounding the binned range.
        num_bins: number of depth bins.

    Returns:
        Integer array of shape (K, 2) holding (row, col) coordinates of the selected
        pixels, grouped by bin. K is 0 if every bin is empty.
    """
    edges = np.linspace(
        np.percentile(depth, min_depth_percentile),
        np.percentile(depth, max_depth_percentile),
        num_bins + 1,
    )

    omega = []
    for bin_index, (min_depth, max_depth) in enumerate(zip(edges[:-1], edges[1:])):
        if bin_index == num_bins - 1:
            # Close the last bin so pixels exactly at the upper percentile are kept.
            in_range = (depth >= min_depth) & (depth <= max_depth)
        else:
            in_range = (depth >= min_depth) & (depth < max_depth)
        coords = np.argwhere(in_range)
        if len(coords) == 0:
            # The percentile of an empty bin is undefined; nothing to select here.
            continue
        brightness = image[coords[:, 0], coords[:, 1]].mean(axis=1)
        # `<=` rather than `<`: if many pixels tie at the minimum (e.g. clipped to 0),
        # the 1st percentile equals that minimum and a strict test selects nothing.
        omega.append(coords[brightness <= np.percentile(brightness, 1)])

    if not omega:
        return np.empty((0, 2), dtype=np.intp)
    return np.vstack(omega)


def compute_backscatter(image, depth, verbose=2):
    """Fit a per-channel backscatter model to the dark pixels chosen by ``compute_omega``.

    For every selected pixel with observed value ``B_hat`` at depth ``z``, the
    model fitted independently per channel is::

        B_hat ~= B_inf * (1 - exp(-beta_B * z)) + J_prime * exp(-beta_D_prime * z)

    The bounds are: B_inf and J_prime in [0, 1]; beta_B and beta_D_prime in [0, 5].

    Args:
        image: H x W x 3 image (indexed as ``image[row, col]`` -> 3 values).
        depth: H x W depth map aligned with ``image``.
        verbose: forwarded to ``scipy.optimize.least_squares`` (0 = silent).

    Returns:
        1-D array of 12 fitted values laid out as
        ``[B_inf(3), beta_B(3), J_prime(3), beta_D_prime(3)]``; recover the groups
        with ``np.array_split(result, 4)``.

    Raises:
        ValueError: if ``compute_omega`` selects no pixels, so there is nothing to fit.
    """
    omega = compute_omega(image, depth)
    if len(omega) == 0:
        raise ValueError('compute_omega selected no pixels; cannot fit the backscatter model')

    rows, cols = omega[:, 0], omega[:, 1]
    B_hat = image[rows, cols]              # (N, 3) observed values of the selected pixels
    z = depth[rows, cols, np.newaxis]      # (N, 1); broadcasts against per-channel params

    x0 = np.concatenate([
        B_hat.mean(axis=0),   # B_inf starts at the mean selected-pixel value
        np.full(3, 2.5),      # beta_B
        np.zeros(3),          # J_prime
        np.zeros(3),          # beta_D_prime
    ])
    lower = np.zeros(12)
    upper = np.array([1, 1, 1, 5, 5, 5] * 2, dtype=np.float64)

    def residuals(x):
        B_inf, beta_B, J_prime, beta_D_prime = np.array_split(x, 4)
        model = B_inf * (1 - np.exp(-beta_B * z)) + J_prime * np.exp(-beta_D_prime * z)
        return (B_hat - model).flatten()

    return least_squares(
        residuals,
        x0,
        bounds=(lower, upper),
        jac='3-point',
        verbose=verbose,
    ).x

In [None]:
I, z = load_image(Path('dataset/D1P1/'), 'T_S03136.ARW')

# compute_backscatter returns 12 values: 3 per-channel parameters for each of 4 groups.
B_inf, beta_B, J_prime, beta_D_prime = np.array_split(compute_backscatter(I, z), 4)

# Modelled backscatter at every pixel; z[..., np.newaxis] is H x W x 1 and broadcasts
# against the three per-channel parameters.
B = B_inf * (1 - np.exp(-beta_B * z[..., np.newaxis]))
# Image with the modelled backscatter removed (values may be negative).
D = I - B

# save_image does not create directories, so make sure the output folder exists.
Path('test').mkdir(parents=True, exist_ok=True)
save_image(stretch_histogram(D, 0), 'test/D.png')
save_image(stretch_histogram(I, 0), 'test/I.png')
save_image(I, 'test/IRaw.png')

In [None]:
# Fitted B_inf: the first 3 of the 12 values returned by compute_backscatter (bounded to [0, 1]).
B_inf

In [None]:
# Fitted beta_B: values 4-6 of compute_backscatter's output (bounded to [0, 5]).
beta_B

In [None]:
# Fitted J_prime: values 7-9 of compute_backscatter's output (bounded to [0, 1]).
J_prime

In [None]:
# Fitted beta_D_prime: values 10-12 of compute_backscatter's output (bounded to [0, 5]).
beta_D_prime