In [None]:
import cv2
import tqdm
import h5py
import torch
import pycolmap
import numpy as np
from pathlib import Path
from PIL import Image
from torch.func import jacfwd
from torch.autograd.functional import jacobian

import sfm
import loader


def normalize(unattenuated, low=1, high=99):
    """Convert a float image that may contain invalid pixels into an 8-bit PIL image.

    A pixel is valid when all of its channels are finite. Over the valid pixels each
    channel is clipped to its [``low``, ``high``] percentile range and then rescaled
    linearly to [0, 1]. Invalid pixels are set to 0 (black).

    Args:
        unattenuated: image as an (H, W, C) or (H, W) torch tensor (any device) or
            NumPy array. It is copied and never modified.
        low: lower percentile used for the per-channel clipping (default 1).
        high: upper percentile used for the per-channel clipping (default 99).

    Returns:
        ``PIL.Image.Image`` built from the rescaled image cast to ``uint8``.
    """
    if isinstance(unattenuated, torch.Tensor):
        # detach() so tensors that still track gradients can be converted too.
        unattenuated = unattenuated.detach().cpu().numpy()
    restored = np.array(unattenuated, copy=True)
    if not np.issubdtype(restored.dtype, np.floating):
        restored = restored.astype(np.float64)

    # NaN marks pixels without data (e.g. 0 / 0). inf is rejected too, so it cannot
    # poison the percentiles or the min/max rescaling.
    finite = np.isfinite(restored)
    valid = finite if restored.ndim == 2 else np.all(finite, axis=-1)

    restored[~valid] = 0
    if valid.any():
        values = restored[valid]
        values = np.clip(
            values,
            np.percentile(values, low, axis=0),
            np.percentile(values, high, axis=0),
        )
        values = values - np.min(values, axis=0)
        span = np.max(values, axis=0)
        # A constant channel has span 0: map it to 0 instead of dividing by zero (NaN).
        values = np.divide(values, span, out=np.zeros_like(values), where=span > 0)
        restored[valid] = values
    return Image.fromarray(np.uint8(restored * 255))


def sparse_norm(P):
    return torch.sparse_coo_tensor(
        indices=P.indices(),
        values=P.values().norm(dim=1, keepdim=True),
        size=(*P.shape[:-1], 1),
        is_coalesced=True
    )


def sparse_exp(P):
    return torch.sparse_coo_tensor(
        indices=P.indices(),
        values=P.values().exp(),
        size=P.shape,
        is_coalesced=True
    )


def sparse_add(P, value):
    return torch.sparse_coo_tensor(
        indices=P.indices(),
        values=P.values() + value,
        size=P.shape,
        is_coalesced=True
    )


# Use the GPU when present, but fall back to the CPU so the notebook still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_workers = 6  # worker count passed to the matching and match-preparation steps

In [None]:
# Dataset root and reference image, defined once instead of repeating the literals.
# NOTE(review): absolute local path — consider reading it from configuration.
DATA_ROOT = Path('/media/clementin/data/Dehazing/2015')
IMAGE_NAME = '20150418T030314.000Z.png'

colmap_model = sfm.COLMAPModel(
    model_dir=DATA_ROOT / 'sparse',
    image_dir=DATA_ROOT / 'images',
    depth_dir=DATA_ROOT / 'depth_maps',
    image_scale=0.25,
)

image = colmap_model[IMAGE_NAME]
# Matches for the reference image are cached as test/<image stem>.h5.
matches_file = loader.MatchesFile(
    path=Path('test') / Path(IMAGE_NAME).with_suffix('.h5'),
    colmap_model=colmap_model,
    overwrite=False,
)

In [None]:
# Match the reference image against every image of the COLMAP model, then
# post-process the matches stored in the matches file.
image.match_images(
    matches_file=matches_file,
    image_list=[*colmap_model.images.values()],
    device=device,
    num_workers=num_workers,
    min_cover=0.01,
)
matches_file.prepare_matches(num_workers=num_workers)

In [None]:
matches_file.check_integrity()
# Load the matches of the reference image. `cP` and `I` are presumably the sparse
# per-observation points and observed intensities — TODO confirm against loader.MatchesFile.
cP, I = matches_file.load_matches(image, device=device)
# `compute_cost` only needs the per-entry norms of `cP` (kept as `z`).
z = sparse_norm(cP)
# NOTE(review): the `compute_residuals` cells further down still iterate over `cP`, so
# they fail after this `del` unless the kernel holds stale state.
del cP

In [None]:
def compute_cost(B, beta, gamma):
    """Sum of squared residuals of the per-channel image-formation model.

    For every stored observation the model is
        I = J * exp(-beta * z) + B * (1 - exp(-gamma * z)).
    For the given B, beta and gamma, J is obtained in closed form:
        J = sum(D * a) / sum(a ** 2),
    where a = exp(-beta * z), D = I - backscatter, and the sums run over dim 0.
    The sparse helpers only transform stored entries, so unobserved entries contribute
    nothing.

    Reads the notebook globals ``z`` and ``I``. B, beta and gamma each hold 3 values,
    which are reshaped to (1, 1, 3).

    Returns:
        Scalar tensor with the total squared error.
    """
    B_, beta_, gamma_ = (param.view(1, 1, 3) for param in (B, beta, gamma))

    absorption = sparse_exp(-beta_ * z)
    backscatter = B_ * sparse_add(-sparse_exp(-gamma_ * z), 1)

    # Closed-form least-squares estimate of J, reduced over dim 0.
    numerator = torch.sum((I - backscatter) * absorption, dim=0).to_dense()
    denominator = torch.sum(absorption.square(), dim=0).to_dense()
    J = numerator / denominator

    residual = I - (J * absorption + backscatter)
    return torch.square(residual).sum()

# Newton's method on the per-channel parameters. B, beta and gamma are 3-vectors, so
# the joint gradient has 9 entries and the joint Hessian is 9x9.
B_hat = torch.full((3,), 0.01, device=device)
beta_hat = torch.full((3,), 0.01, device=device)
gamma_hat = torch.full((3,), 0.01, device=device)

for _ in range(100):
    # `jacobian` returns one (3,) gradient per argument: concatenate to a (9,) vector.
    jac = torch.cat(jacobian(compute_cost, (B_hat, beta_hat, gamma_hat)))
    # With argnums=(0, 1, 2), `hessian` returns a 3x3 nested tuple of (3, 3) blocks, where
    # block [i][j] differentiates w.r.t. parameter i then j. torch.tensor() cannot
    # assemble that, so stitch the blocks into the 9x9 matrix explicitly.
    hess_blocks = torch.func.hessian(compute_cost, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)
    hess = torch.cat([torch.cat(row, dim=1) for row in hess_blocks], dim=0)
    # Solve H @ diff = g instead of forming H^-1 explicitly.
    diff = torch.linalg.solve(hess, jac)
    B_hat -= diff[0:3]
    beta_hat -= diff[3:6]
    gamma_hat -= diff[6:9]
    print(B_hat.tolist(), beta_hat.tolist(), gamma_hat.tolist())

In [None]:
# `cost` only exists inside `compute_cost`; evaluate it at the current estimates instead.
compute_cost(B_hat, beta_hat, gamma_hat)

In [None]:
# `J` is a local of `compute_cost`, so recompute the restored image at the current estimates
# using the same closed form.
absorption_hat = sparse_exp(-beta_hat.view(1, 1, 3) * z)
backscatter_hat = B_hat.view(1, 1, 3) * sparse_add(-sparse_exp(-gamma_hat.view(1, 1, 3) * z), 1)
J = (
    torch.sum((I - backscatter_hat) * absorption_hat, dim=0).to_dense()
    / torch.sum(absorption_hat.square(), dim=0).to_dense()
)
normalize(J.view(image.camera.height, image.camera.width, 3))

In [None]:
# NOTE(review): duplicate of the previous cell. No visible cell assigns a module-level
# `J` (it is only a local inside `compute_cost` / `compute_residuals`), so this relies on
# kernel state — TODO delete this cell.
normalize(J.view(image.camera.height, image.camera.width, 3))

In [None]:
def compute_residuals(B, beta, gamma):
    """Flattened per-observation residuals of the image-formation model.

    Model per channel: I = J * exp(-beta * z) + B * (1 - exp(-gamma * z)).
    For the given B, beta and gamma, the per-pixel J is the least-squares solution
        J = sum(D * a) / sum(a ** 2),
    with a = exp(-beta * z) and D = I - B * (1 - exp(-gamma * z)), accumulated over
    every observation of that pixel.

    Reads the notebook globals u, v, cP, I, image and device.
    NOTE(review): no visible cell defines `u` or `v`, and `cP` is deleted after `z` is
    computed — this only runs with leftover kernel state.

    Returns:
        1-D tensor: the (3, n_observations) residual matrix, flattened row-major.
    """
    shape = (image.camera.height, image.camera.width, 3)
    numerator = torch.zeros(shape, device=device)
    denominator = torch.zeros(shape, device=device)
    for ui, vi, cPi, Ii in zip(u, v, cP, I):
        zi = cPi.norm(dim=0)
        ai = torch.exp(-beta * zi)
        bi = 1 - torch.exp(-gamma * zi)
        Di = Ii - B * bi
        pixel = (vi.long(), ui.long())
        # Least squares weights the data by the absorption term `ai`, not by the
        # backscatter term (matches `compute_cost`). Out-of-place index_put instead of
        # `+=` so torch.func transforms (jacfwd/hessian) can trace it; accumulate=True
        # sums repeated pixels.
        numerator = numerator.index_put(pixel, (Di * ai).T.to(numerator.dtype), accumulate=True)
        denominator = denominator.index_put(pixel, ai.square().T.to(denominator.dtype), accumulate=True)
    J = numerator / denominator

    # Collect per-image residual blocks instead of writing into a buffer sized by the
    # external `n_obs`.
    residuals = []
    for ui, vi, cPi, Ii in zip(u, v, cP, I):
        zi = cPi.norm(dim=0)
        residuals.append(
            Ii - J[vi.long(), ui.long()].T * torch.exp(-beta * zi) - B * (1 - torch.exp(-gamma * zi))
        )
    return torch.cat(residuals, dim=1).flatten()

B_hat = torch.full((3, 1), 0.01, device=device)
beta_hat = torch.full((3, 1), 0.01, device=device)
gamma_hat = torch.full((3, 1), 0.01, device=device)

# jacfwd only differentiates argument 0 unless told otherwise; request all three parameter
# groups so `jac` is the tuple (d r / d B, d r / d beta, d r / d gamma).
jac = torch.func.jacfwd(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)

In [None]:
def compute_residuals(B, beta, gamma):
    """Scalar sum of squared residuals of the model on channel index 2 only.

    Despite its name this returns a cost, not a residual vector. B, beta and gamma are
    scalars here. The model is I[2] = J * exp(-beta * z) + B * (1 - exp(-gamma * z)),
    and for the given parameters the per-pixel J is the least-squares solution
        J = sum(D * a) / sum(a ** 2),
    with a = exp(-beta * z).

    NOTE(review): this redefinition shadows the vector `compute_residuals` above. It
    reads the notebook globals u, v, cP, I, image and device; `u` and `v` are not defined
    in any visible cell, and `cP` is deleted after `z` is computed.
    """
    shape = (image.camera.height, image.camera.width)
    numerator = torch.zeros(shape, device=device)
    denominator = torch.zeros(shape, device=device)
    for ui, vi, cPi, Ii in zip(u, v, cP, I):
        zi = cPi.norm(dim=0)
        ai = torch.exp(-beta * zi)
        bi = 1 - torch.exp(-gamma * zi)
        Di = Ii[2] - B * bi
        pixel = (vi.long(), ui.long())
        # Weight by the absorption term `ai` (least squares), not by the backscatter term.
        # Out-of-place index_put so torch.func.hessian / jacfwd can trace it.
        numerator = numerator.index_put(pixel, (Di * ai).to(numerator.dtype), accumulate=True)
        denominator = denominator.index_put(pixel, ai.square().to(denominator.dtype), accumulate=True)
    J = numerator / denominator

    cost = 0
    for ui, vi, cPi, Ii in zip(u, v, cP, I):
        zi = cPi.norm(dim=0)
        residual = Ii[2] - J[vi.long(), ui.long()] * torch.exp(-beta * zi) - B * (1 - torch.exp(-gamma * zi))
        cost += torch.square(residual).sum()
    return cost


# Scalar starting point (B, beta, gamma) for the single-channel cost above.
B_hat, beta_hat, gamma_hat = (torch.tensor(start, device=device) for start in (0.25, 0.1, 0.1))

In [None]:
# Newton's method on the three scalar parameters of the single-channel cost.
for _ in range(100):
    jac = jacobian(compute_residuals, (B_hat, beta_hat, gamma_hat))
    hess = torch.func.hessian(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)
    print(compute_residuals(B_hat, beta_hat, gamma_hat))
    # `jac` is a tuple of three 0-d tensors and `hess` a 3x3 nested tuple of 0-d tensors.
    # Assemble them with torch.stack, then solve H @ diff = g rather than forming H^-1.
    grad = torch.stack(jac)
    hess_matrix = torch.stack([torch.stack(row) for row in hess])
    diff = torch.linalg.solve(hess_matrix, grad)
    B_hat -= diff[0]
    beta_hat -= diff[1]
    gamma_hat -= diff[2]
    print(B_hat.item(), beta_hat.item(), gamma_hat.item())

In [None]:
# First and second derivatives with respect to B only: jacfwd differentiates argument 0
# when no argnums is given (kept that way because `hess.max()` below expects a tensor).
jac = jacfwd(compute_residuals)(B_hat, beta_hat, gamma_hat)
hess = jacfwd(jacfwd(compute_residuals))(B_hat, beta_hat, gamma_hat)

In [None]:
# torch.tensor(t) on an existing tensor warns; clone().detach() is the documented equivalent.
hess.clone().detach()

In [None]:
# Largest entry of the Hessian computed above.
torch.max(hess)

In [None]:
# FIXME(review): truncated expression — torch.Tensor does not define an attribute `re`,
# so this cell raises AttributeError. Finish it (e.g. `.reshape(...)`?) or delete it.
jac.re