In [None]:
from collections.abc import Iterator
from pathlib import Path

import sfm
import tqdm
import h5py
import torch
import loader
import matplotlib
import normalization
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor


class VignettingData:
    """Collects groups of correspondences and serves them concatenated in batches.

    Each ``append`` call stores one group as separate tensors. Concatenation with
    ``torch.hstack`` only happens when a batch is requested: 1-D tensors are joined
    along dim 0 and 2-D tensors along dim 1.
    """

    def __init__(self):
        # One dict per appended group, keyed by 'u', 'v', 'z', 'I', 'cP'.
        self.data: list[dict[str, Tensor]] = []

    def append(self, u: Tensor, v: Tensor, cP: Tensor, z: Tensor, I: Tensor):
        """Store one group of correspondences.

        Args:
            u, v: pixel column/row indices; cast to ``long`` when batched.
            cP: per-correspondence point tensor (joined along its last axis when 2-D).
            z: per-correspondence distances.
            I: observed intensities; the last axis indexes correspondences.
        """
        self.data.append({'u': u, 'v': v, 'z': z, 'I': I, 'cP': cP})

    def iterbatch(self, batch_size: int) -> Iterator[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
        """Yield ``(u, v, z, I, cP)`` for consecutive runs of ``batch_size`` stored groups.

        ``u`` and ``v`` are returned as ``long`` so they can be used directly as indices.

        Raises:
            ValueError: if ``batch_size`` is not positive (raised on first iteration).
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        for start in range(0, len(self.data), batch_size):
            chunk = self.data[start:start + batch_size]
            yield (
                torch.hstack([sample['u'] for sample in chunk]).long(),
                torch.hstack([sample['v'] for sample in chunk]).long(),
                torch.hstack([sample['z'] for sample in chunk]),
                torch.hstack([sample['I'] for sample in chunk]),
                torch.hstack([sample['cP'] for sample in chunk])
            )

    def __len__(self):
        """Total number of correspondences across all stored groups.

        Counted along the last axis of ``I``, so both a flat ``(N,)`` and a
        channel-first ``(C, N)`` intensity tensor count N correspondences.
        """
        return sum(sample['I'].shape[-1] for sample in self.data)


def se3_exp(pose: Tensor) -> tuple[Tensor, Tensor]:
    """Turn a 6-element twist into a rotation and translation via the matrix exponential.

    The twist is ordered ``(w1, w2, w3, p1, p2, p3)``. The ``w`` entries fill the
    skew-symmetric upper-left 3x3 block of a 4x4 matrix, and the ``p`` entries fill
    its last column. The bottom row is zero.

    Returns:
        ``(R, t)``: the 3x3 upper-left block and the 3x1 upper-right column of
        ``exp(twist)``. The computation stays differentiable w.r.t. ``pose``.
    """
    wx, wy, wz, px, py, pz = pose
    o = torch.zeros_like(wx)
    rows = (
        (o, -wz, wy, px),
        (wz, o, -wx, py),
        (-wy, wx, o, pz),
        (o, o, o, o),
    )
    # Flatten row-major, then reshape, so that autograd flows through every entry.
    twist = torch.stack([entry for row in rows for entry in row]).view(4, 4)
    transform = torch.matrix_exp(twist)
    return transform[:3, :3], transform[:3, 3:4]

In [None]:
# Where per-image .h5 match files and rendered PNGs are written.
output_dir = Path('viewsynthesis/')
# Create it up front: later cells save images into it directly.
output_dir.mkdir(parents=True, exist_ok=True)
num_workers = 32
device = 'cuda:1'

# Dataset root shared by every input path below.
dataset_root = Path('/workspace/TourEiffelClean/2015/dehazing')

# Names of the images to process, one per line. Kept as an ordered list because later cells slice it.
filter_image_names = (dataset_root / 'viewsynthesis/imagelist.txt').read_text().splitlines()

colmap_model = sfm.COLMAPModel(
    image_dir=dataset_root / 'undistort/images',
    depth_dir=dataset_root / 'viewsynthesis/depth_maps',
    model_dir=dataset_root / 'viewsynthesis/colmap'
)

# Matching pool: every model image that is not itself in the processing list.
# A set makes each membership test O(1) instead of a scan of the list.
excluded_image_names = set(filter_image_names)
image_list = [im for im in colmap_model.images.values() if im.name not in excluded_image_names]

In [None]:
# Fit the global model on one image (filter_image_names[51]): per-pixel radiance J together with
# B, beta, gamma, the rigid transform s_T_c and a 2-D Gaussian halo (halostdx, halostdy, halocovxy).
num_iterations = 1000

for image_name in filter_image_names[51:52]:
    image = colmap_model[image_name]

    matches_path = (output_dir / image_name).with_suffix('.h5')
    matches_file = loader.MatchesFile(matches_path, overwrite=True)

    image.match_images(
        image_list=image_list,
        matches_file=matches_file,
        min_cover=0.00001,
        num_workers=num_workers,
        device=device
    )

    matches_file.prepare_matches(colmap_model=colmap_model, num_workers=num_workers, device=device)

    data = VignettingData()
    # K_inv is the same for every group; move it to the device once.
    K_inv = image.camera.K_inv.to(device)

    with h5py.File(matches_path, 'r', libver='latest') as f:
        for group in f.values():
            z = torch.tensor(group['z'][()], device=device)
            u2 = torch.tensor(group['u2'][()], device=device) + 0.5
            v2 = torch.tensor(group['v2'][()], device=device) + 0.5
            # Back-project (u2, v2) and rescale each ray so its Euclidean length equals z.
            cP = K_inv @ torch.vstack([u2, v2, torch.ones_like(u2)])
            cP = cP / cP.norm(dim=0) * z

            data.append(
                u=torch.tensor(group['u1'][()], device=device),
                v=torch.tensor(group['v1'][()], device=device),
                z=z,
                I=torch.tensor(group['I'][()], device=device),
                cP=cP
            )

    J = torch.nn.Parameter(torch.ones((image.camera.height, image.camera.width, 3), dtype=torch.float32, device=device))
    B = torch.nn.Parameter(torch.tensor([[0.25], [0.25], [0.25]], dtype=torch.float32, device=device))
    beta = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]], dtype=torch.float32, device=device))
    gamma = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]], dtype=torch.float32, device=device))

    s_T_c = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.float32, device=device))
    halostdx = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
    halostdy = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
    halocovxy = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))

    optimizer = torch.optim.Adam([
        {'params': [J], 'lr': 0.05},
        {'params': [B, beta, gamma, s_T_c, halostdx, halostdy, halocovxy], 'lr': 0.05}
    ])

    size = len(data)

    for _ in tqdm.tqdm(range(num_iterations)):

        optimizer.zero_grad()

        s_R_c, s_t_c = se3_exp(s_T_c)
        halocov = torch.stack([halostdx.square(), halocovxy, halocovxy, halostdy.square()]).view(2, 2)
        # Every batch of this iteration uses the same covariance, so invert it once.
        halocov_inv = halocov.inverse()

        for ui, vi, zi, Ii, ciP in data.iterbatch(batch_size=5):

            # Map the points with (s_R_c, s_t_c) and divide by the third coordinate.
            siP = s_R_c @ ciP + s_t_c
            sip = (siP[:2] / siP[2]).T.unsqueeze(dim=2)
            # Unnormalized Gaussian weight exp(-x^T halocov^-1 x / 2) of each projected point.
            halo = torch.exp(-(sip.transpose(1, 2) @ halocov_inv @ sip).flatten() / 2)

            # Path length: |cP| (= z) plus the point's distance from the origin of frame s.
            zi = zi + siP.norm(dim=0)

            loss = torch.square(
                Ii - halo * (J[vi, ui].T * torch.exp(-beta * zi) + B * (1 - torch.exp(-gamma * zi)))
            ).sum() / size / 3
            # retain_graph: s_R_c, s_t_c and halocov_inv are reused by the next batch's backward.
            loss.backward(retain_graph=True)

        optimizer.step()

    with torch.no_grad():
        s_R_c, s_t_c = se3_exp(s_T_c)
        halocov = torch.stack([halostdx.square(), halocovxy, halocovxy, halostdy.square()]).view(2, 2)
        s_R_c_cpu, s_t_c_cpu = s_R_c.cpu(), s_t_c.cpu()
        halocov_inv_cpu = halocov.cpu().inverse()

        # Load the depth map once; fresh clones are handed to unproject_depth_map below.
        depth = loader.load_depth(image.depth_path)
        valid = depth.numpy() > 0

        # Save J with a per-channel 1st-99th percentile contrast stretch over pixels with depth.
        # detach(): on a CPU device .cpu() returns the Parameter itself, and numpy() rejects
        # tensors that require grad.
        im = J.detach().cpu().numpy().copy()
        image_valid = im[valid]
        if image_valid.size:
            image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
            image_valid = image_valid - np.min(image_valid, axis=0)
            span = np.max(image_valid, axis=0)
            # A constant channel has span 0; avoid 0/0 = NaN.
            image_valid = image_valid / np.where(span > 0, span, 1)
            im[valid] = image_valid
        im[~valid] = 0
        Image.fromarray(np.uint8(im * 255)).save(output_dir / f'sucre_{image_name}')

        # Halo weight at the pixels returned by unproject_depth_map, rendered with the 'jet' colormap.
        u, v, cP = image.unproject_depth_map(depth.clone(), transform=False)
        sP = s_R_c_cpu @ cP + s_t_c_cpu
        sp = (sP[:2] / sP[2]).T.unsqueeze(dim=2)
        hl = torch.exp(-(sp.transpose(1, 2) @ halocov_inv_cpu @ sp).flatten() / 2)
        hl_image = torch.zeros((image.camera.height, image.camera.width))
        hl_image[v, u] = hl
        Image.fromarray(np.uint8(plt.colormaps['jet'](hl_image.numpy())[:, :, :3] * 255)).save(output_dir / f'halo_{image_name}')

        # Re-render the fitted model. Pixels without depth are given distance 100000, so the
        # exp(-beta * z) factor suppresses the J term there, leaving mainly the halo * B term.
        depth_filled = depth.clone()
        depth_filled[~valid] = 100000
        u, v, cP = image.unproject_depth_map(depth_filled, transform=False)
        sP = s_R_c_cpu @ cP + s_t_c_cpu
        sp = (sP[:2] / sP[2]).T.unsqueeze(dim=2)
        hl = torch.exp(-(sp.transpose(1, 2) @ halocov_inv_cpu @ sp).flatten() / 2)
        zih = cP.norm(dim=0) + sP.norm(dim=0)
        rec = (hl * (J[v, u].T.cpu() * torch.exp(-beta.cpu() * zih) + B.cpu() * (1 - torch.exp(-gamma.cpu() * zih)))).T
        rec_image = torch.zeros((image.camera.height, image.camera.width, 3))
        rec_image[v, u] = rec.clip(0, 1)
        Image.fromarray(np.uint8(rec_image.numpy() * 255)).save(output_dir / f'reconstructed_{image_name}')

    matches_path.unlink()

In [None]:
# Persist the globally fitted parameters from the cell above (the per-image J is not saved).
# The following cell reloads them from 'frame000052.pt'.
fitted_parameters = {
    'B': B,
    'beta': beta,
    'gamma': gamma,
    's_T_c': s_T_c,
    'halostdx': halostdx,
    'halostdy': halostdy,
    'halocovxy': halocovxy,
}
torch.save({name: param.detach().cpu() for name, param in fitted_parameters.items()}, 'frame000052.pt')

In [None]:
# Apply the saved parameters to every listed image. With B, beta, gamma, s_T_c and the halo
# fixed, J is solved per pixel and channel in closed form.
parameters = torch.load('frame000052.pt')
B = parameters['B'].to(device)
beta = parameters['beta'].to(device)
gamma = parameters['gamma'].to(device)
s_T_c = parameters['s_T_c'].to(device)
halostdx = parameters['halostdx'].to(device)
halostdy = parameters['halostdy'].to(device)
halocovxy = parameters['halocovxy'].to(device)

for image_name in filter_image_names:
    image = colmap_model[image_name]

    matches_path = (output_dir / image_name).with_suffix('.h5')
    matches_file = loader.MatchesFile(matches_path, overwrite=True)

    image.match_images(
        image_list=image_list,
        matches_file=matches_file,
        min_cover=0.00001,
        num_workers=num_workers,
        device=device
    )

    matches_file.prepare_matches(colmap_model=colmap_model, num_workers=num_workers, device=device)

    data = VignettingData()
    # K_inv is the same for every group; move it to the device once.
    K_inv = image.camera.K_inv.to(device)

    with h5py.File(matches_path, 'r', libver='latest') as f:
        for group in f.values():
            z = torch.tensor(group['z'][()], device=device)
            u2 = torch.tensor(group['u2'][()], device=device) + 0.5
            v2 = torch.tensor(group['v2'][()], device=device) + 0.5
            # Back-project (u2, v2) and rescale each ray so its Euclidean length equals z.
            cP = K_inv @ torch.vstack([u2, v2, torch.ones_like(u2)])
            cP = cP / cP.norm(dim=0) * z

            data.append(
                u=torch.tensor(group['u1'][()], device=device),
                v=torch.tensor(group['v1'][()], device=device),
                z=z,
                I=torch.tensor(group['I'][()], device=device),
                cP=cP
            )

    with torch.no_grad():
        s_R_c, s_t_c = se3_exp(s_T_c)
        halocov = torch.stack([halostdx.square(), halocovxy, halocovxy, halostdy.square()]).view(2, 2)
        # Loop-invariant: invert once rather than once per batch.
        halocov_inv = halocov.inverse()

        # With everything but J fixed, the model is I = alpha * J + D_offset, where
        # alpha = halo * exp(-beta z) and D_offset = halo * B * (1 - exp(-gamma z)).
        # The per-pixel least-squares solution is J = sum(alpha * D) / sum(alpha^2), with D = I - D_offset.
        numerator = torch.zeros(image.camera.height, image.camera.width, 3, dtype=torch.float32, device=device)
        denominator = torch.zeros(image.camera.height, image.camera.width, 3, dtype=torch.float32, device=device)

        for ui, vi, zi, Ii, ciP in data.iterbatch(batch_size=1):

            siP = s_R_c @ ciP + s_t_c
            sip = (siP[:2] / siP[2]).T.unsqueeze(dim=2)
            halo = torch.exp(-(sip.transpose(1, 2) @ halocov_inv @ sip).flatten() / 2)

            zi = zi + siP.norm(dim=0)

            Di = Ii - halo * B * (1 - torch.exp(-gamma * zi))
            alphai = halo * torch.exp(-beta * zi)

            # accumulate=True sums repeated (v, u) pairs. `numerator[vi, ui] += x` does not:
            # with duplicate indices only one contribution survives.
            numerator.index_put_((vi, ui), (alphai * Di).T.to(numerator.dtype), accumulate=True)
            denominator.index_put_((vi, ui), torch.square(alphai).T.to(denominator.dtype), accumulate=True)

        # Pixels that were never observed come out as 0 / 0 = NaN.
        J = numerator / denominator
        observed = ~torch.isnan(J).any(dim=2)
        observed_np = observed.cpu().numpy()

        # Save J with a per-channel 1st-99th percentile contrast stretch over observed pixels.
        im = J.cpu().numpy().copy()
        image_valid = im[observed_np]
        if image_valid.size:
            image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
            image_valid = image_valid - np.min(image_valid, axis=0)
            span = np.max(image_valid, axis=0)
            # A constant channel has span 0; avoid 0/0 = NaN.
            image_valid = image_valid / np.where(span > 0, span, 1)
            im[observed_np] = image_valid
        im[~observed_np] = 0
        Image.fromarray(np.uint8(im * 255)).save(output_dir / f'sucre_{image_name}')

        J[~observed] = 0

        # Re-render the model. Pixels without depth are given distance 100000, so the
        # exp(-beta * z) factor suppresses the J term there, leaving mainly the halo * B term.
        depth_map = loader.load_depth(image.depth_path)
        has_depth = depth_map > 0
        depth_map[~has_depth] = 100000
        u, v, cP = image.unproject_depth_map(depth_map, transform=False)
        sP = s_R_c.cpu() @ cP + s_t_c.cpu()
        sp = (sP[:2] / sP[2]).T.unsqueeze(dim=2)
        hl = torch.exp(-(sp.transpose(1, 2) @ halocov.cpu().inverse() @ sp).flatten() / 2)
        zih = cP.norm(dim=0) + sP.norm(dim=0)
        rec = (hl * (J[v, u].T.cpu() * torch.exp(-beta.cpu() * zih) + B.cpu() * (1 - torch.exp(-gamma.cpu() * zih)))).T
        rec_image = torch.zeros((image.camera.height, image.camera.width, 3))
        rec_image[v, u] = rec.clip(0, 1)
        Image.fromarray(np.uint8(rec_image.numpy() * 255)).save(output_dir / f'reconstructed_{image_name}')

    matches_path.unlink()