In [None]:
import cv2
import se3
import tqdm
import h5py
import torch
import pycolmap
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from torch.func import jacfwd
from torch.autograd.functional import jacobian
from IPython.display import display
from torch import Tensor

import sfm
import loader

class SUCRe(torch.nn.Module):
    """Image formation model with learnable per-channel parameters ``B``, ``beta`` and ``gamma``.

    An observed intensity is modelled (see :meth:`forward`) as::

        I_hat = l * (J * exp(-beta * z) + B * (1 - exp(-gamma * z)))

    where ``J`` is a per-pixel image of the reference view, ``z`` a per-observation distance and
    ``l`` an optional illumination factor (enabled with ``light_model=True``).

    All parameters are created on the default device; call ``.to(device)`` on the module to move
    them together.
    """

    def __init__(self, image: sfm.Image, light_model=False):
        """
        Args:
            image: reference image; its ``camera.height`` / ``camera.width`` define the size of
                ``J`` and its depth map is used by :meth:`plot_l`.
            light_model: when True, also learn ``cam2light`` (6 values passed to ``se3.exp``) and
                ``sigma`` (2x2) and use them to compute the illumination factor ``l``.
        """
        super().__init__()
        self.image = image
        self.light_model = light_model
        # One value per colour channel, stored as (3, 1) so they broadcast against (n,) distances.
        self.B = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]]))
        self.beta = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]]))
        self.gamma = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]]))
        if self.light_model:
            # Deliberately NOT tied to the notebook-level `device` global: every parameter starts
            # on the same device so the module is never left in a mixed-device state.
            self.cam2light = torch.nn.Parameter(torch.zeros(6))
            self.sigma = torch.nn.Parameter(torch.eye(2))

    def compute_l_z(self, cP: Tensor) -> tuple[float | Tensor, Tensor]:
        """Return the illumination factor ``l`` and the distance ``z`` for camera-frame points.

        Args:
            cP: points in the camera frame; indexed as rows x/y/z, so presumably (3, n) — confirm
                against ``loader.MatchesData.iter``.

        Returns:
            ``(l, z)``. Without the light model ``l`` is the float ``1.0`` and ``z = ||cP||``.
            With it, ``z = ||cP|| + ||lP||`` where ``lP = R @ cP + t`` and
            ``l = exp(-p^T Sigma^-1 p / 2)`` with ``p = lP[:2] / lP[2]`` and
            ``Sigma = sigma^T sigma``.
        """
        z = cP.norm(dim=0)
        if self.light_model:
            # presumably a rotation matrix and a translation — confirm against se3.exp
            R, t = se3.exp(self.cam2light)
            # sigma^T sigma keeps the 2x2 matrix symmetric positive semi-definite.
            Sigma = self.sigma.T @ self.sigma
            lP = R @ cP + t
            lp = lP[:2] / lP[2]
            lp = lp.T.unsqueeze(dim=2)  # (n, 2, 1): one column vector per point
            l = torch.exp(-torch.flatten(lp.transpose(1, 2) @ Sigma.inverse() @ lp) / 2)
            # Out-of-place add: same value as `z += ...` without mutating a tensor autograd tracks.
            z = z + lP.norm(dim=0)
        else:
            l = 1.0
        return l, z

    def compute_J(self, matches_data: loader.MatchesData) -> Tensor:
        """Estimate ``J`` for every pixel from all observations, given the current parameters.

        For each pixel this is ``sum((I - backscatter) * absorption) / sum(absorption ** 2)``,
        i.e. the closed-form least-squares ``J`` of the model in :meth:`forward`.

        Returns:
            Tensor of shape (height, width, 3). Pixels without any observation are ``nan`` (0 / 0).
        """
        shape = (self.image.camera.height, self.image.camera.width, 3)
        J_numerator = torch.zeros(shape, device=self.B.device)
        J_denominator = torch.zeros(shape, device=self.B.device)

        for u, v, cP, I in matches_data.iter(device=self.B.device):
            l, z = self.compute_l_z(cP)
            absorption = l * torch.exp(-self.beta * z)
            backscatter = l * self.B * (1 - torch.exp(-self.gamma * z))
            # `t[v, u] += x` keeps only one contribution when (v, u) repeats within a batch;
            # index_put_ with accumulate=True sums all of them (identical when indices are unique).
            J_numerator.index_put_(
                (v, u), ((I - backscatter) * absorption).T.to(J_numerator.dtype), accumulate=True
            )
            J_denominator.index_put_(
                (v, u), absorption.square().T.to(J_denominator.dtype), accumulate=True
            )

        J = J_numerator / J_denominator
        return J

    def forward(self, J: Tensor, u: Tensor, v: Tensor, cP: Tensor) -> Tensor:
        """Predict the intensities of observations of pixels ``(u, v)`` seen at points ``cP``.

        Args:
            J: (height, width, 3) image, e.g. the output of :meth:`compute_J`.
            u, v: integer column / row indices into ``J``.
            cP: camera-frame points matching ``u`` / ``v`` (see :meth:`compute_l_z`).

        Returns:
            Predicted intensities, channels first (3, n).
        """
        l, z = self.compute_l_z(cP)
        I_hat = l * (J[v, u].T * torch.exp(-self.beta * z) + self.B * (1 - torch.exp(-self.gamma * z)))
        return I_hat

    def plot_l(self):
        """Render the illumination factor ``l`` over the reference image as a jet-coloured PIL image.

        Works with or without the light model; without it every pixel returned by
        ``unproject_depth_map`` gets ``l = 1``. Other pixels stay at 0.
        """
        with torch.no_grad():
            # self.B always exists, unlike self.cam2light which is only created with the light model.
            device = self.B.device
            u, v, cP = self.image.unproject_depth_map(self.image.get_depth_map().to(device), to_world=False)
            l, _ = self.compute_l_z(cP)
            l_map = torch.zeros((self.image.camera.height, self.image.camera.width), device=device)
            l_map[v, u] = l  # `l` may be the plain float 1.0; scalar assignment handles that
        return Image.fromarray(np.uint8(plt.colormaps['jet'](l_map.cpu().numpy())[:, :, :3] * 255))


def normalize(unattenuated):
    """Convert a float (H, W, C) tensor into an 8-bit PIL image with per-channel contrast stretch.

    Pixels containing a NaN in any channel are treated as invalid and rendered black. Valid pixels
    are clipped per channel to their 1st-99th percentile range and rescaled to [0, 1].

    Args:
        unattenuated: tensor of shape (H, W, C); may require grad and live on any device.

    Returns:
        ``PIL.Image`` built from the uint8 array.
    """
    restored = unattenuated.detach().cpu().numpy().copy()
    valid = np.all(~np.isnan(restored), axis=2)
    restored[~valid] = 0
    if valid.any():  # an all-invalid image would otherwise feed an empty array to np.percentile
        restored_valid = restored[valid]
        low = np.percentile(restored_valid, 1, axis=0)
        high = np.percentile(restored_valid, 99, axis=0)
        restored_valid = np.clip(restored_valid, low, high)
        restored_valid = restored_valid - np.min(restored_valid, axis=0)
        span = np.max(restored_valid, axis=0)
        # A constant channel has span 0; leave it at 0 instead of producing 0/0 = NaN, whose
        # conversion to uint8 is undefined.
        restored_valid = np.divide(
            restored_valid, span, out=np.zeros_like(restored_valid), where=span > 0
        )
        restored[valid] = restored_valid
    return Image.fromarray(np.uint8(restored * 255))


# Run on the GPU when one is available, otherwise fall back to the CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Worker count passed to the matching / match-preparation calls below.
num_workers = 6

In [None]:
# Single place to point the notebook at another dataset or reference image.
DATA_ROOT = Path('/media/clementin/data/Dehazing/2015')
IMAGE_NAME = '20150418T030314.000Z.png'

colmap_model = sfm.COLMAPModel(
    model_dir=DATA_ROOT / 'sparse',
    image_dir=DATA_ROOT / 'images',
    depth_dir=DATA_ROOT / 'depth_maps',
    image_scale=0.25
)

# Reference image processed by the rest of the notebook, and the HDF5 file holding its matches.
image = colmap_model[IMAGE_NAME]
matches_file = loader.MatchesFile(
    path=Path('test') / Path(IMAGE_NAME).with_suffix('.h5'),
    colmap_model=colmap_model,
    overwrite=False
)

In [None]:
# Match the reference image against every image registered in the COLMAP model, then prepare the
# matches held by `matches_file`.
# NOTE(review): `matches_file` is presumably where match_images stores its results, and
# `min_cover=0.01` presumably skips images covering less than 1 % of the reference view —
# confirm both against sfm.Image.match_images.
image.match_images(
    image_list=list(colmap_model.images.values()),
    matches_file=matches_file,
    min_cover=0.01,
    num_workers=num_workers,
    device=device
)
matches_file.prepare_matches(num_workers=num_workers)

In [None]:
# Check the matches file before loading it; `matches_data` feeds the optimisation cell below.
# NOTE(review): check_integrity presumably raises on a corrupt or incomplete file, and
# pin_memory=True presumably requests page-locked host memory — confirm in loader.MatchesFile.
matches_file.check_integrity()

matches_data = matches_file.load_matches(pin_memory=True)

In [None]:
# Fit the SUCRe model (with the light model enabled) to the loaded matches using Adam.
sucre = SUCRe(image, light_model=True)
sucre = sucre.to(device)

# Used to scale the loss below; later cells also read this global.
# NOTE(review): assumes len(matches_data) is the total number of observations — confirm.
n_obs = len(matches_data)
optimizer = torch.optim.Adam(sucre.parameters(), lr=0.05)

# Summed squared error of each pass over the data.
costs = []

for iteration in tqdm.tqdm(range(1000)):
    cost = 0
    optimizer.zero_grad()

    # J is re-estimated from the current parameters without tracking gradients, so it acts as a
    # constant in the loss below.
    with torch.no_grad():
        J = sucre.compute_J(matches_data)
    
    # Gradients accumulate over all batches; a single optimizer step is taken per pass.
    for u, v, cP, I in matches_data.iter(batch_size=5, device=device):
        loss = torch.square(I - sucre(J=J, u=u, v=v, cP=cP)).sum()
        # Divide by n_obs and by the 3 channels so the accumulated gradient is that of a mean.
        (loss / n_obs / 3).backward()
        cost += loss.item()

    optimizer.step()
    costs.append(cost)
    # Show the current illumination map every 25 iterations.
    if iteration % 25 == 0:
        display(sucre.plot_l())

## Jacobian

In [None]:
def compute_residuals(Br, Bg, Bb, betar, betag, betab, gammar, gammag, gammab):
    """Return the flattened residuals ``I - J * exp(-beta * z) - B * (1 - exp(-gamma * z))``.

    The nine scalar arguments are the per-channel (r, g, b) values of ``B``, ``beta`` and
    ``gamma``. ``J`` is first estimated per pixel as
    ``sum((I - backscatter) * absorption) / sum(absorption ** 2)`` from the same observations.

    Reads the notebook globals ``u``, ``v``, ``z``, ``I`` (parallel sequences of per-batch
    tensors), ``n_obs``, ``image`` and ``device``.
    NOTE(review): a later cell redefines ``compute_residuals`` with a different signature; this
    version is only valid until that cell runs.

    Returns:
        1-D tensor of length ``3 * n_obs`` (channel-major).
    """
    B = torch.stack([Br, Bg, Bb]).view(3, 1)
    beta = torch.stack([betar, betag, betab]).view(3, 1)
    gamma = torch.stack([gammar, gammag, gammab]).view(3, 1)

    numerator = torch.zeros((image.camera.height, image.camera.width, 3), device=device)
    denominator = torch.zeros((image.camera.height, image.camera.width, 3), device=device)
    for ui, vi, zi, Ii in zip(u, v, z, I):
        absorption = torch.exp(-beta * zi)
        backscatter = B * (1 - torch.exp(-gamma * zi))
        Di = Ii - backscatter
        numerator[vi.long(), ui.long()] += (Di * absorption).T
        denominator[vi.long(), ui.long()] += absorption.square().T
    J = numerator / denominator
    # Best-effort preview of J: the conversion to an image can fail (e.g. when this function is
    # being traced by torch.func transforms), which must not abort the computation. Catch
    # Exception rather than using a bare `except` so Ctrl-C / kernel interrupts still propagate.
    try:
        display(normalize(J))
    except Exception:
        pass
    cursor = 0
    residuals = torch.zeros((3, n_obs), dtype=torch.float32, device=device)
    for ui, vi, zi, Ii in zip(u, v, z, I):
        length = ui.shape[0]
        # Each batch fills its own column range of the (3, n_obs) buffer.
        residuals[:, cursor: cursor + length] = Ii - J[vi.long(), ui.long()].T * torch.exp(-beta * zi) - B * (1 - torch.exp(-gamma * zi))
        cursor += length
    return residuals.flatten()

# Observation lists read as globals by `compute_residuals` and by the cells below. They were never
# defined anywhere in the notebook (leftover kernel state), so a fresh "Restart & Run All" failed.
# z is the camera-to-point distance, as in SUCRe.compute_l_z without the light model.
# NOTE(review): assumes matches_data.iter yields (u, v, cP, I) as in SUCRe.compute_J and that all
# observations fit on `device` at once — confirm.
u, v, z, I = [], [], [], []
for ui, vi, cPi, Ii in matches_data.iter(device=device):
    u.append(ui)
    v.append(vi)
    z.append(cPi.norm(dim=0))
    I.append(Ii)

# Initial per-channel (r, g, b) estimates; distinct 0-dim tensors because the optimisation loop
# updates each of them in place.
Br_hat, Bg_hat, Bb_hat = (torch.tensor(0.1, device=device) for _ in range(3))
betar_hat, betag_hat, betab_hat = (torch.tensor(0.1, device=device) for _ in range(3))
gammar_hat, gammag_hat, gammab_hat = (torch.tensor(0.1, device=device) for _ in range(3))

In [None]:
# Damped Gauss-Newton (Levenberg-Marquardt-style) refinement of the nine scalar parameters:
# each iteration solves (J^T J + damping * I) delta = J^T r and subtracts delta.
# Note that the step is always applied; `damping` is only adapted from whether the residual cost
# measured at the start of this iteration improved on the previous one.
damping = 0.1
eye = torch.eye(9, dtype=torch.float32, device=device)
previous_cost = torch.inf

# Same tensor objects as the *_hat globals, so the in-place updates below are visible through them.
params = (Br_hat, Bg_hat, Bb_hat, betar_hat, betag_hat, betab_hat, gammar_hat, gammag_hat, gammab_hat)

for _ in range(100):
    res = compute_residuals(*params)
    # jacfwd returns one (n_residuals,) tensor per argument; stack them as columns -> (n_residuals, 9).
    jac = torch.stack(
        torch.func.jacfwd(compute_residuals, argnums=tuple(range(len(params))))(*params),
        dim=1,
    )
    # Solve the linear system directly instead of forming an explicit inverse (cheaper and
    # numerically more stable).
    delta = torch.linalg.solve(jac.T @ jac + damping * eye, jac.T @ res)
    for param, step in zip(params, delta):
        param -= step
    cost = torch.square(res).sum().item()
    if cost < previous_cost:
        damping = max(damping / 10, 1e-32)
    else:
        damping *= 10
    previous_cost = cost
    print(cost)

In [None]:
def compute_residuals(B, beta, gamma):
    """Sum of squared residuals of channel index 2 for scalar ``B``, ``beta`` and ``gamma``.

    Despite its name this returns a single cost value, not a residual vector. A per-pixel ``J``
    is first estimated as ``sum((I - backscatter) * absorption) / sum(absorption ** 2)`` and the
    cost is then ``sum((I - J * exp(-beta * z) - B * (1 - exp(-gamma * z))) ** 2)``.

    Reads the notebook globals ``u``, ``v``, ``z``, ``I``, ``image`` and ``device``.
    """
    grid_shape = (image.camera.height, image.camera.width)
    numerator = torch.zeros(grid_shape, device=device)
    denominator = torch.zeros(grid_shape, device=device)

    # Pass 1: accumulate the per-pixel estimate of J.
    for cols, rows, dist, rgb in zip(u, v, z, I):
        rows_idx, cols_idx = rows.long(), cols.long()
        absorption = torch.exp(-beta * dist)
        backscatter = B * (1 - torch.exp(-gamma * dist))
        observed = rgb[2] - backscatter
        numerator[rows_idx, cols_idx] += observed * absorption
        denominator[rows_idx, cols_idx] += absorption.square()
    J = numerator / denominator

    # Pass 2: accumulate the squared difference between observations and the model prediction.
    total = 0
    for cols, rows, dist, rgb in zip(u, v, z, I):
        total += torch.square(
            rgb[2] - J[rows.long(), cols.long()] * torch.exp(-beta * dist) - B * (1 - torch.exp(-gamma * dist))
        ).sum()
    return total

# Initial scalar estimates for the single-channel experiments below; three distinct 0-dim tensors
# because the following cells update each of them in place.
B_hat, beta_hat, gamma_hat = (torch.tensor(0.1, device=device) for _ in range(3))

In [None]:
# Experiment: update each parameter by (cost value) / (partial derivative of the cost).
# `res` is the scalar returned by compute_residuals and `jac` the tuple of its three partial
# derivatives obtained with reverse-mode autodiff.
# NOTE(review): value / derivative is the Newton-Raphson *root-finding* step applied independently
# to each parameter; it drives the cost towards zero rather than towards a stationary point, and
# divides by zero where a partial derivative vanishes — confirm this is intended.
for _ in range(100):
    res = compute_residuals(B_hat, beta_hat, gamma_hat)
    print(res.item())
    jac = torch.func.jacrev(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)
    delta = res / torch.stack(jac)
    # In-place updates: the *_hat tensors keep their identity across iterations and cells.
    B_hat -= delta[0]
    beta_hat -= delta[1]
    gamma_hat -= delta[2]

In [None]:
# Newton's method on the scalar cost: x <- x - H^{-1} g, with g the gradient and H the Hessian of
# compute_residuals with respect to (B, beta, gamma).
for _ in range(100):
    print(compute_residuals(B_hat, beta_hat, gamma_hat).item())
    # `jac` / `hess` keep the raw (nested) tuples returned by torch.func so the inspection cells
    # below can still use them.
    jac = torch.func.jacrev(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)
    hess = torch.func.hessian(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat)
    # Assemble dense tensors on the parameters' device instead of round-tripping through
    # torch.tensor(<nested tuples>).
    gradient = torch.stack(jac)
    hessian_matrix = torch.stack([torch.stack(row) for row in hess])
    delta = torch.linalg.solve(hessian_matrix, gradient)
    # The Newton step is subtracted; adding it moves away from the stationary point.
    B_hat -= delta[0]
    beta_hat -= delta[1]
    gamma_hat -= delta[2]

In [None]:
# Inspect the transposed Hessian left over from the last iteration of the loop above.
torch.tensor(hess).T

In [None]:
# Inspect the Newton step H^{-T} g computed from the last `hess` / `jac` of the loop above.
torch.tensor(hess).T.inverse() @ torch.tensor(jac)

In [None]:
# Damped Gauss-Newton (Levenberg-Marquardt-style) refinement of (B, beta, gamma): each iteration
# solves (J^T J + damping * I) delta = J^T r and subtracts delta. The step is always applied;
# `damping` is only adapted from whether the cost measured at the start of the iteration improved.
damping = 0.1
eye = torch.eye(3, dtype=torch.float32, device=device)
previous_cost = torch.inf

for _ in range(100):
    # Force explicit shapes: r is (m,) and J is (m, 3). The compute_residuals defined earlier in
    # this notebook returns a 0-dim cost, for which the original `jac.T @ res` was a matmul with a
    # 0-dim tensor and raised; with these reshapes m == 1 in that case, and a vector-valued
    # residual function works unchanged.
    res = compute_residuals(B_hat, beta_hat, gamma_hat).reshape(-1)
    jac = torch.stack(
        torch.func.jacfwd(compute_residuals, argnums=(0, 1, 2))(B_hat, beta_hat, gamma_hat),
        dim=-1,
    ).reshape(res.numel(), 3)
    # Solve the linear system directly instead of forming an explicit inverse.
    delta = torch.linalg.solve(jac.T @ jac + damping * eye, jac.T @ res)
    B_hat -= delta[0]
    beta_hat -= delta[1]
    gamma_hat -= delta[2]
    cost = torch.square(res).sum().item()
    if cost < previous_cost:
        damping = max(damping / 10, 1e-32)
    else:
        damping *= 10
    previous_cost = cost
    print(cost, B_hat.item(), beta_hat.item(), gamma_hat.item())

## Adam

In [None]:
# Per-channel parameters of shape (3, 1), all initialised to 0.1, optimised jointly with Adam.
B, beta, gamma = (
    torch.nn.Parameter(torch.full((3, 1), 0.1, device=device)) for _ in range(3)
)

optimizer = torch.optim.Adam(params=[B, beta, gamma], lr=0.05)

In [None]:
# Alternating scheme: re-estimate J in closed form from the current parameters (no gradients),
# then take one Adam step on B, beta and gamma with J held fixed.
# Reads the notebook globals u, v, z, I (parallel sequences of per-batch tensors) and n_obs.
for iteration in range(1000):
    cost = 0
    optimizer.zero_grad()

    with torch.no_grad():
        numerator = torch.zeros((image.camera.height, image.camera.width, 3), device=device)
        denominator = torch.zeros((image.camera.height, image.camera.width, 3), device=device)
        for ui, vi, zi, Ii in zip(u, v, z, I):
            ui, vi, zi, Ii = ui.to(device).long(), vi.to(device).long(), zi.to(device), Ii.to(device)
            absorption = torch.exp(-beta * zi)
            backscatter = B * (1 - torch.exp(-gamma * zi))
            Di = Ii - backscatter
            # `t[vi, ui] += x` keeps only one contribution when a pixel repeats within a batch;
            # accumulate=True sums all of them (identical when indices are unique).
            numerator.index_put_((vi, ui), (Di * absorption).T.to(numerator.dtype), accumulate=True)
            denominator.index_put_((vi, ui), absorption.square().T.to(denominator.dtype), accumulate=True)
        J = numerator / denominator

    # Gradients accumulate over all batches; a single optimizer step is taken per pass.
    for ui, vi, zi, Ii in zip(u, v, z, I):
        ui, vi, zi, Ii = ui.to(device).long(), vi.to(device).long(), zi.to(device), Ii.to(device)
        loss = torch.square(
            Ii - J[vi, ui].T * torch.exp(-beta * zi) - B * (1 - torch.exp(-gamma * zi))
        ).sum()
        # Divide by n_obs and by the 3 channels so the accumulated gradient is that of a mean.
        (loss / n_obs / 3).backward()
        cost += loss.item()

    optimizer.step()
    # Report every iteration (the former `iteration % 1 == 0` guard was always true).
    print(cost)
    display(normalize(J))

In [None]:
# Variant in which J itself is a learnable parameter, initialised from the reference image and
# optimised jointly with B, beta and gamma.
# Reads the notebook globals u, v, z, I (parallel sequences of per-batch tensors) and n_obs.
J = torch.nn.Parameter(image.get_rgb().to(device))
with torch.no_grad():
    # Pixels whose depth is not strictly positive are marked as NaN.
    J[image.get_depth_map() <= 0] = torch.nan
B, beta, gamma = (
    torch.nn.Parameter(torch.full((3, 1), 0.1, device=device)) for _ in range(3)
)

optimizer = torch.optim.Adam(params=[J, B, beta, gamma], lr=0.05)

for iteration in range(100):
    cost = 0
    optimizer.zero_grad()

    # Gradients accumulate over all batches; a single optimizer step is taken per pass.
    for ui, vi, zi, Ii in zip(u, v, z, I):
        ui = ui.to(device).long()
        vi = vi.to(device).long()
        zi = zi.to(device)
        Ii = Ii.to(device)
        loss = torch.square(
            Ii - J[vi, ui].T * torch.exp(-beta * zi) - B * (1 - torch.exp(-gamma * zi))
        ).sum()
        (loss / n_obs / 3).backward()
        cost += loss.item()

    optimizer.step()
    # Reported on every iteration.
    print(cost)
    display(normalize(J))