In [None]:
from collections.abc import Iterator
from pathlib import Path

import sfm
import tqdm
import h5py
import torch
import loader
import matplotlib
import normalization
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor


class VignettingData:
    """In-memory store of training samples, one dict per appended group.

    Each group holds N samples: ``u``, ``v``, ``z`` indexed per sample, and ``I`` / ``cP``
    with one column per sample (3 x N in the training cells, which index ``I[:, k]``).
    """

    def __init__(self):
        self.data: list[dict[str, Tensor]] = []

    def append(self, u: Tensor, v: Tensor, cP: Tensor, z: Tensor, I: Tensor):
        """Store one group of samples as-is (no copy, no device move)."""
        self.data.append({'u': u, 'v': v, 'z': z, 'I': I, 'cP': cP})

    def iterbatch(self, batch_size: int) -> Iterator[tuple[Tensor, Tensor, Tensor, Tensor, Tensor]]:
        """Yield ``(u, v, z, I, cP)`` for consecutive chunks of ``batch_size`` stored groups.

        The groups of a chunk are concatenated with ``torch.hstack``. ``u`` and ``v`` are cast
        to int64 so they can be used directly as pixel indices. Note that ``batch_size``
        counts groups, not individual samples.
        """
        for start in range(0, len(self.data), batch_size):
            chunk = self.data[start:start + batch_size]
            yield (
                torch.hstack([sample['u'] for sample in chunk]).long(),
                torch.hstack([sample['v'] for sample in chunk]).long(),
                torch.hstack([sample['z'] for sample in chunk]),
                torch.hstack([sample['I'] for sample in chunk]),
                torch.hstack([sample['cP'] for sample in chunk])
            )

    def __len__(self):
        """Total number of samples over all stored groups."""
        # I is channels-first (C x N): its *last* dimension counts samples. shape[0] would
        # return the channel count (3) per group instead.
        return sum(sample['I'].shape[-1] for sample in self.data)


def se3_exp(pose: Tensor) -> tuple[Tensor, Tensor]:
    """Map a 6-vector twist to a rigid transform via the matrix exponential.

    ``pose`` is unpacked as ``(w1, w2, w3, p1, p2, p3)``. The rotational part ``w`` fills the
    skew-symmetric block and ``p`` the last column of a 4x4 matrix, which is exponentiated.
    Returns the 3x3 rotation block and the 3x1 translation column of the result.
    """
    w1, w2, w3, p1, p2, p3 = pose
    o = torch.zeros_like(w1)
    twist_hat = torch.stack([
        torch.stack([o, -w3, w2, p1]),
        torch.stack([w3, o, -w1, p2]),
        torch.stack([-w2, w1, o, p3]),
        torch.stack([o, o, o, o]),
    ])
    transform = torch.matrix_exp(twist_hat)
    rotation = transform[:3, :3]
    translation = transform[:3, 3:4]
    return rotation, translation


# Typeset figure text with LaTeX, using the Times family for text and math.
matplotlib.rc('text', usetex=True)
matplotlib.rcParams.update({
    'font.family': 'Times',
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Times',
    'mathtext.it': 'Times:italic',
    'mathtext.bf': 'Times:bold',
})

# Font used to stamp the iteration counter onto the saved animation frames.
font = ImageFont.truetype('times', 64)

In [None]:
# Dataset and reference-image selection. Alternative datasets/images are kept disabled, either as
# bare triple-quoted string literals (evaluated and discarded) or as commented-out lines; only the
# TourEiffelClean 2015 block below is active.
"""colmap = sfm.COLMAPModel(
    image_dir=Path('/workspace/Varos/images/'),
    depth_dir=Path('/workspace/Varos/depth_maps/'),
    model_dir=Path('/workspace/Varos/sparse/')
)

image = colmap['seq01_veh0_camM0_A-00000318.png']

matches_path = Path('test/seq01_veh0_camM0_A-00000318.h5')"""

"""colmap = sfm.COLMAPModel(
    image_dir=Path('/workspace/AQUALOC/raw_data/colmap_sequence_10/undistorted/images/'),
    depth_dir=Path('/workspace/AQUALOC/raw_data/colmap_sequence_10/undistorted/depth_maps/'),
    model_dir=Path('/workspace/AQUALOC/raw_data/colmap_sequence_10/undistorted/sparse/')
)"""

# image = colmap['frame012760.png']
# matches_path = Path('test/frame012760.h5')

# image = colmap['frame005980.png']
# matches_path = Path('test/frame005980.h5')

# image = colmap['frame001580.png']
# matches_path = Path('test/frame001580.h5')

# Active dataset: COLMAP model (images, depth maps, sparse reconstruction) of TourEiffelClean 2015.
colmap = sfm.COLMAPModel(
    image_dir=Path('/workspace/TourEiffelClean/2015/dehazing/undistort/images/'),
    depth_dir=Path('/workspace/TourEiffelClean/2015/dehazing/undistort/depth_maps/'),
    model_dir=Path('/workspace/TourEiffelClean/2015/dehazing/undistort/sparse/')
)

# image = colmap['20150418T014923.000Z.png']
# matches_path = Path('test/20150418T014923.000Z.h5')

# Active reference image and the HDF5 file of samples read in the next cell.
# NOTE(review): later cells write frames to directories named after other images
# (e.g. '20150419T033711/') - keep those in sync with this selection.
image = colmap['20150418T025347.000Z.png']
matches_path = Path('test/20150418T025347.000Z.h5')

# image = colmap['20150419T033711.000Z.png']
# matches_path = Path('test/20150419T033711.000Z.h5')

# image = colmap['20150419T035320.000Z.png']
# matches_path = Path('test/20150419T035320.000Z.h5')

# image = colmap['20150418T034032.000Z.png']
# matches_path = Path('test/20150418T034032.000Z.h5')

# image = colmap['20150419T033302.000Z.png']
# matches_path = Path('test/20150419T033302.000Z.h5')

"""colmap = sfm.COLMAPModel(
    image_dir=Path('/workspace/Eurydice/sfm/pixsfm/undistort/images/'),
    depth_dir=Path('/workspace/Eurydice/sfm/pixsfm/undistort/depth_maps/'),
    model_dir=Path('/workspace/Eurydice/sfm/pixsfm/undistort/sparse/')
)

image = colmap['20220420T124924.000Z.png']

matches_path = Path('test/20220420T124924.000Z.h5')"""

"""colmap = sfm.COLMAPModel(
    image_dir=Path('/workspace/Jaureguiberry/undistort/images/'),
    depth_dir=Path('/workspace/Jaureguiberry/undistort/depth_maps/'),
    model_dir=Path('/workspace/Jaureguiberry/undistort/sparse/')
)

image = colmap['frame00016124.png']
matches_path = Path('test/frame00016124.h5')"""

# NOTE(review): hard-coded second CUDA device; adjust on machines with fewer than two GPUs.
device='cuda:1'

In [None]:
# Load the training samples from the HDF5 file `matches_path`: one VignettingData entry per group.
data = VignettingData()

with h5py.File(matches_path, 'r', libver='latest') as f:
    for group in f.values():

        # Per-sample value used below as the Euclidean length of the back-projected ray.
        z = torch.tensor(group['z'][()], device=device)
        # +0.5: presumably converts integer pixel indices to pixel centres - confirm against the .h5 writer.
        u2 = torch.tensor(group['u2'][()], device=device) + 0.5
        v2 = torch.tensor(group['v2'][()], device=device) + 0.5
        # Back-project the homogeneous pixels (u2, v2, 1) through K^-1 (one column per sample) ...
        cP = image.camera.K_inv.to(device) @ torch.vstack([u2, v2, torch.ones_like(u2)])
        # ... and rescale every column to Euclidean norm z, giving camera-frame 3-D points.
        cP = cP / cP.norm(dim=0) * z

        # (u1, v1) are the pixel indices later used to look up J[v, u]; I is the target the
        # model prediction is compared against in the loss.
        data.append(
            u=torch.tensor(group['u1'][()], device=device),
            v=torch.tensor(group['v1'][()], device=device),
            z=z,
            I=torch.tensor(group['I'][()], device=device),
            cP=cP
        )

In [None]:
# J: per-pixel image estimate, initialised from the reference image and indexed as J[v, u] in the loss.
J = torch.nn.Parameter(loader.load_image(image.image_path).to(device))
# Per-channel (3x1) parameters of the formation model used in the loss:
#   I ~ halo * (J * exp(-beta * z) + B * (1 - exp(-gamma * z)))
B = torch.nn.Parameter(torch.tensor([[0.25], [0.25], [0.25]], dtype=torch.float32, device=device))
beta = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]], dtype=torch.float32, device=device))
gamma = torch.nn.Parameter(torch.tensor([[0.1], [0.1], [0.1]], dtype=torch.float32, device=device))

# s_T_c: twist (w1, w2, w3, p1, p2, p3) passed to se3_exp; maps camera-frame points into frame s,
# where the halo is evaluated. The zero twist is the identity transform.
s_T_c = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.float32, device=device))
# Halo covariance entries; the training loops assemble [[halostdx^2, halocovxy], [halocovxy, halostdy^2]].
# NOTE(review): nothing keeps that matrix positive-definite during optimisation.
halostdx = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
halostdy = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
halocovxy = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))

# Two parameter groups (image vs. model parameters), currently with the same learning rate.
optimizer = torch.optim.Adam([
    {'params': [J], 'lr': 0.05},
    {'params': [B, beta, gamma, s_T_c, halostdx, halostdy, halocovxy], 'lr': 0.05}
])

In [None]:
# First optimisation scheme: one Adam step per iteration over the whole dataset, with every batch
# randomly subsampled to 10% of its samples. At each snapshot iteration the current estimates
# (J, halo, attenuation, reconstruction, residual) are written as PNG frames to `out_dir`.
size = len(data)
costs = []

# s-frame pose used at each iteration (plotted in the "Plot spot trajectory" section).
Rs, ts = [], []

snapshot_every = 1  # write frames every N iterations (1 = every iteration, as before)
# NOTE(review): this directory is named after image 20150419T033711, not the image selected in the
# dataset cell - confirm the intended output location.
out_dir = Path('20150419T033711')
out_dir.mkdir(parents=True, exist_ok=True)  # PIL's save() does not create missing directories

# Loop-invariant inputs of the snapshot rendering: read the depth map and the observed image,
# and unproject the depth map, once instead of on every iteration.
depth = loader.load_depth(image.depth_path)
valid = depth.numpy() > 0
u, v, cP = image.unproject_depth_map(depth, transform=False)
observed = loader.load_image(image.image_path)

for iteration in tqdm.tqdm(range(1001)):

    cost = 0
    optimizer.zero_grad()

    s_R_c, s_t_c = se3_exp(s_T_c)
    Rs.append(s_R_c.detach().cpu().numpy())
    ts.append(s_t_c.detach().cpu().numpy())
    # 2x2 halo covariance [[stdx^2, covxy], [covxy, stdy^2]].
    halocov = torch.stack([halostdx.square(), halocovxy, halocovxy, halostdy.square()]).view(2, 2)

    if iteration % snapshot_every == 0:
        # Current J, contrast-stretched (1st-99th percentile, then min-max) over pixels with depth.
        im = J.detach().cpu().numpy().copy()
        image_valid = im[valid]
        image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
        image_valid = image_valid - np.min(image_valid, axis=0)
        image_valid = image_valid / np.max(image_valid, axis=0)
        im[~valid] = 0
        im[valid] = image_valid
        # Draw into a separate PIL image so that `im` stays a float array for later cells.
        frame = Image.fromarray(np.uint8(im * 255))
        draw = ImageDraw.Draw(frame)
        draw.text((15, 0), f'Iteration {iteration:02d}', (255, 255, 255), font=font, anchor='la')
        frame.save(out_dir / f'frame{iteration+1:05d}.png')

        # Halo at every pixel with depth: Gaussian of the point projected in frame s.
        sP = s_R_c.detach().cpu() @ cP + s_t_c.detach().cpu()
        sp = sP[:2] / sP[2]
        sp = sp.T.unsqueeze(dim=2)
        hl = torch.exp(-(sp.transpose(1, 2) @ halocov.detach().cpu().inverse() @ sp).flatten() / 2)
        # hl = 1.0
        hl_image = torch.zeros((image.camera.height, image.camera.width))
        hl_image[v, u] = hl
        Image.fromarray(np.uint8(plt.colormaps['jet'](hl_image)[:, :, :3] * 255)).save(out_dir / f'halo{iteration+1:05d}.png')

        # Same path-length term as in the loss: distance from the camera origin plus from the s origin.
        zih = cP.norm(dim=0) + sP.norm(dim=0)
        att_image = torch.zeros((image.camera.height, image.camera.width, 3))
        att_image[v, u] = (hl * (torch.exp(-beta.detach().cpu() * zih) + B.detach().cpu() * (1 - torch.exp(-gamma.detach().cpu() * zih)))).T.clip(0, 1)
        Image.fromarray(np.uint8(att_image * 255)).save(out_dir / f'att{iteration+1:05d}.png')

        rec = (hl * (J[v, u].T.detach().cpu() * torch.exp(-beta.detach().cpu() * zih) + B.detach().cpu() * (1 - torch.exp(-gamma.detach().cpu() * zih)))).T
        rec_image = torch.zeros((image.camera.height, image.camera.width, 3))
        rec_image[v, u] = rec.clip(0, 1)
        Image.fromarray(np.uint8(rec_image * 255)).save(out_dir / f'rec{iteration+1:05d}.png')

        res_image = torch.zeros((image.camera.height, image.camera.width, 3))
        res_image[v, u] = torch.abs(observed[v, u] - rec)
        Image.fromarray(np.uint8(plt.colormaps['jet'](res_image.mean(dim=2) * 5)[:, :, :3] * 255)).save(out_dir / f'res{iteration+1:05d}.png')

    for ui, vi, zi, Ii, ciP in data.iterbatch(batch_size=5):

        # Keep a random 10% of this batch's samples (not seeded, so runs are not reproducible).
        rand_args = torch.randperm(len(zi), device=device)[:int(len(zi) * 0.1)]
        ui = ui[rand_args]
        vi = vi[rand_args]
        zi = zi[rand_args]
        Ii = Ii[:, rand_args]
        ciP = ciP[:, rand_args]

        # Move the sample points into frame s and evaluate the Gaussian halo on their projection.
        siP = s_R_c @ ciP + s_t_c
        sip = siP[:2] / siP[2]

        sip = sip.T.unsqueeze(dim=2)
        halo = torch.exp(-(sip.transpose(1, 2) @ halocov.inverse() @ sip).flatten() / 2)

        # Total path length: stored camera range plus distance from the s-frame origin.
        zi = zi + siP.norm(dim=0)

        loss = torch.square(
            Ii - halo * (J[vi, ui].T * torch.exp(-beta * zi) + B * (1 - torch.exp(-gamma * zi)))
        ).sum()
        """loss = 2 * torch.nn.functional.huber_loss(
            input=halo * (J[vi, ui].T * torch.exp(-beta * zi) + B * (1 - torch.exp(-gamma * zi))),
            target=Ii,
            reduction='sum',
            delta=0.005
        )"""
        # Normalise by sample count and the 3 colour channels.
        cost = cost + loss / size / 3

    cost.backward()
    optimizer.step()

    costs.append(cost.item())
    # Disabled debugging output (kept as string literals).
    with np.printoptions(precision=2):
        """print(f'cost: {cost.item():.3e}, B: {B.detach().flatten().cpu().numpy()}, beta: {beta.detach().flatten().cpu().numpy()}, '
              f'gamma: {gamma.detach().flatten().cpu().numpy()}, t: {s_t_c.detach().flatten().cpu().numpy()}, '
              f'w: {s_T_c[:3].detach().flatten().cpu().numpy()}, halocov: {halocov.detach().flatten().cpu().numpy()}')"""
    
    """if (iteration + 1) % 1000 == 0:
        im = J.detach().cpu().numpy().copy()
        valid = loader.load_depth(image.depth_path).numpy() > 0
        image_valid = im[valid]
        image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
        image_valid = image_valid - np.min(image_valid, axis=0)
        image_valid = image_valid / np.max(image_valid, axis=0)
        im[~valid] = 0
        im[valid] = image_valid
        plt.imshow(im)
        plt.savefig(f'halo2/grame{iteration:05d}.png')
        plt.close()
        
        u, v, cP = image.unproject_depth_map(loader.load_depth(image.depth_path), transform=False)
        sP = s_R_c.detach().cpu() @ cP + s_t_c.detach().cpu()
        sp = sP[:2] / sP[2]
        sp = sp.T.unsqueeze(dim=2)
        hl = torch.exp(-(sp.transpose(1, 2) @ halocov.detach().cpu().inverse() @ sp).flatten() / 2)
        #hl = hl - hl.min()
        #hl = hl / hl.max()
        hl_image = torch.zeros((image.camera.height, image.camera.width))
        hl_image[v, u] = hl
        plt.imshow(hl_image, cmap='jet')
        plt.savefig(f'halo2/frame{iteration:05d}.png')
        plt.close()"""

In [None]:
# Second optimisation scheme: one Adam step per mini-batch of 50 stored groups, using all samples
# of each batch. At each snapshot iteration the current J and halo are written as PNG frames.
size = len(data)

snapshot_every = 1  # write frames every N iterations (1 = every iteration, as before)
out_dir = Path('eiffelanimation2')
out_dir.mkdir(parents=True, exist_ok=True)  # PIL's save() does not create missing directories

# Loop-invariant inputs of the snapshot rendering: read and unproject the depth map once.
depth = loader.load_depth(image.depth_path)
valid = depth.numpy() > 0
u, v, cP = image.unproject_depth_map(depth, transform=False)

for iteration in tqdm.tqdm(range(1001)):

    if iteration % snapshot_every == 0:
        with torch.no_grad():
            # Current J, contrast-stretched (1st-99th percentile, then min-max) over pixels with depth.
            # detach() so .numpy() also works when J already lives on the CPU.
            im = J.detach().cpu().numpy().copy()
            image_valid = im[valid]
            image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
            image_valid = image_valid - np.min(image_valid, axis=0)
            image_valid = image_valid / np.max(image_valid, axis=0)
            im[~valid] = 0
            im[valid] = image_valid
            # Draw into a separate PIL image so that `im` stays a float array for later cells.
            frame = Image.fromarray(np.uint8(im * 255))
            draw = ImageDraw.Draw(frame)
            draw.text((15, 0), f'Iteration {iteration:02d}', (255, 255, 255), font=font, anchor='la')
            frame.save(out_dir / f'frame{iteration+1:05d}.png')

            # Halo at every pixel with depth: Gaussian of the point projected in frame s.
            s_R_c, s_t_c = se3_exp(s_T_c.cpu())
            halocov = torch.stack([halostdx.cpu().square(), halocovxy.cpu(), halocovxy.cpu(), halostdy.cpu().square()]).view(2, 2)
            sP = s_R_c @ cP + s_t_c
            sp = sP[:2] / sP[2]
            sp = sp.T.unsqueeze(dim=2)
            hl = torch.exp(-(sp.transpose(1, 2) @ halocov.inverse() @ sp).flatten() / 2)
            hl_image = torch.zeros((image.camera.height, image.camera.width))
            hl_image[v, u] = hl
            Image.fromarray(np.uint8(plt.colormaps['jet'](hl_image)[:, :, :3] * 255)).save(out_dir / f'halo{iteration+1:05d}.png')

    for ui, vi, zi, Ii, ciP in data.iterbatch(batch_size=50):

        optimizer.zero_grad()

        s_R_c, s_t_c = se3_exp(s_T_c)
        # 2x2 halo covariance [[stdx^2, covxy], [covxy, stdy^2]].
        halocov = torch.stack([halostdx.square(), halocovxy, halocovxy, halostdy.square()]).view(2, 2)

        # Move the sample points into frame s and evaluate the Gaussian halo on their projection.
        siP = s_R_c @ ciP + s_t_c
        sip = siP[:2] / siP[2]

        sip = sip.T.unsqueeze(dim=2)
        halo = torch.exp(-(sip.transpose(1, 2) @ halocov.inverse() @ sip).flatten() / 2)

        # Total path length: stored camera range plus distance from the s-frame origin.
        zi = zi + siP.norm(dim=0)

        loss = torch.square(
            Ii - halo * (J[vi, ui].T * torch.exp(-beta * zi) + B * (1 - torch.exp(-gamma * zi)))
        ).sum() / size / 3

        loss.backward()
        optimizer.step()

    # Disabled debugging output (kept as a string literal).
    # NOTE(review): it refers to `cost`, which this cell does not define.
    with np.printoptions(precision=2):
        """print(f'cost: {cost.item():.3e}, B: {B.detach().flatten().cpu().numpy()}, beta: {beta.detach().flatten().cpu().numpy()}, '
              f'gamma: {gamma.detach().flatten().cpu().numpy()}, t: {s_t_c.detach().flatten().cpu().numpy()}, '
              f'w: {s_T_c[:3].detach().flatten().cpu().numpy()}, halocov: {halocov.detach().flatten().cpu().numpy()}')"""

In [None]:
# Publication figure of the last rendered halo map `hl_image` (from whichever training cell ran
# last) with a colour bar, saved as a tightly cropped PDF.
fig, ax = plt.subplots(1, figsize=(4, 2.1), layout='constrained')
hl_plot = ax.imshow(hl_image, vmin=0, vmax=1, cmap='jet')
ax.axis('off')
cbar = fig.colorbar(hl_plot, shrink=0.89, pad=0.02)
#cbar.set_label('Halo intensity', rotation=270, labelpad=10)
# Limit the colour bar to at most 6 tick intervals.
cbar.locator = matplotlib.ticker.MaxNLocator(nbins=6)
cbar.ax.tick_params(length=2)
cbar.update_ticks()
fig.savefig(f'halo.pdf', bbox_inches='tight', pad_inches=0, dpi=900)

In [None]:
# Save the last normalised estimate `im` as an 8-bit PNG.
# NOTE(review): expects `im` to be a float array in [0, 1]. If the last cell that assigned `im`
# rebound it to a PIL image (`im = Image.fromarray(...)`), `im * 255` raises TypeError.
# The filename refers to image 20220420T124924.000Z, not necessarily the currently selected one.
Image.fromarray(np.uint8(im * 255)).save('20220420T124924.000Z_sucre.png')

In [None]:
# 20-bin histogram of channel 1 of `im`; expects `im` to be an (H, W, C) NumPy array, not a PIL image.
plt.hist(im[:, :, 1].flatten(), bins=20)

In [None]:
# Cost per iteration as recorded by the first training loop (the mini-batch loop does not append to `costs`).
plt.plot(costs)

In [None]:
# Render a quadratic halo model for every pixel, with the back-projected points scaled by 10.
# NOTE(review): `si_T_ci` and `coeffs` are not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. They look like leftovers of an earlier halo model; the current one
# uses `s_T_c` and the Gaussian `halocov`.
# NOTE: overwrites the globals u, v, cP, s_R_c, s_t_c, sP, sp and hl.
v, u = torch.where(loader.load_depth(image.depth_path) > -1)  # presumably every pixel (depths >= 0) - confirm
uvw = torch.stack([u + 0.5, v + 0.5, torch.ones_like(u)])  # homogeneous pixel centres
cp = uvw * 10
cP = image.camera.K_inv @ cp  # presumably points at z-depth 10 for a standard pinhole K - confirm

s_R_c, s_t_c = se3_exp(si_T_ci.detach().cpu())
sP = s_R_c @ cP + s_t_c
sp = sP[:2] / sP[2]

# Quadratic halo: coeffs[0] + coeffs[1] * x^2 + coeffs[2] * y^2 on the projected coordinates.
hl = coeffs[0].detach().cpu() + coeffs[1].detach().cpu() * sp[0].square() + coeffs[2].detach().cpu() * sp[1].square()
plt.imshow(hl.reshape(image.camera.height, image.camera.width), cmap='gray')

In [None]:
# Scatter the quadratic halo model, evaluated on a 300x300 grid of points (x, y, 1) in frame s,
# after mapping those points back into the camera image via the inverse s<-c transform and K.
# NOTE(review): `coeffs` and `si_T_ci` are not defined in this notebook (NameError on a fresh
# kernel); they look like leftovers of an earlier halo model - the current one uses `s_T_c` and `halocov`.
sp = np.linspace(-0.5, 0.5, 300)
spx, spy = np.meshgrid(sp, sp)
sP = np.stack([spx.flatten(), spy.flatten(), np.ones_like(spx.flatten())])
# sP is a NumPy array, which has no `.square()` method (that is a torch API): use np.square.
hl = coeffs[0].detach().cpu().numpy() + coeffs[1].detach().cpu().numpy() * np.square(sP[0]) + coeffs[2].detach().cpu().numpy() * np.square(sP[1])

# Invert the transform: c_P = R^T s_P - R^T t.
r, t = se3_exp(si_T_ci.detach().cpu())
t = -r.T @ t
t = t.numpy()
r = r.T.numpy()

cp = image.camera.K.numpy() @ (r @ sP + t)
cp = cp[:2] / cp[2]

plt.scatter(*cp, s=0.1, c=hl, cmap='gray')

In [None]:
# Display the halo coefficients of the quadratic model.
# NOTE(review): `coeffs` is not defined in this notebook; this raises NameError on a fresh kernel.
coeffs

In [None]:
# Contrast-stretch `Jg` for display: clip to the 1st-99th percentile along axis 0 of the pixels with
# positive depth, then min-max scale to [0, 1]; pixels without depth are set to 0.
# NOTE(review): `Jg` is not defined in this notebook (NameError on a fresh kernel); presumably an
# earlier variant of `J` - confirm.
im = Jg.detach().cpu().numpy().copy()
valid = loader.load_depth(image.depth_path).numpy() > 0
image_valid = im[valid]
image_valid = np.clip(image_valid, np.percentile(image_valid, 1, axis=0), np.percentile(image_valid, 99, axis=0))
image_valid = image_valid - np.min(image_valid, axis=0)
image_valid = image_valid / np.max(image_valid, axis=0)
im[~valid] = 0
im[valid] = image_valid

In [None]:
# `cmap` only applies to single-channel data; matplotlib ignores it when `im` is RGB(A).
plt.imshow(im, cmap='gray')

## Halo analysis

In [None]:
# Visual check of the anisotropic Gaussian halo shape for a hand-picked covariance.
# NOTE: reuses (and overwrites) the global names `sp`, `halocov` and `hl`, which the training cells also assign.
plt.figure(figsize=(6, 6))
x, y = np.meshgrid(np.linspace(-5, 5, 100), np.linspace(-5, 5, 100))
x = x.flatten()
y = y.flatten()
sp = torch.tensor(np.stack([x, y]).T[..., None], dtype=torch.float32)  # (N, 2, 1) column vectors
halocov = torch.tensor([
    [1.0, -0.9],
    [-0.9, 2.0]
])
# Normalised 2-D Gaussian density: exp(-p^T S^-1 p / 2) / (2 pi sqrt(det S)).
hl = torch.exp(-(sp.transpose(1, 2) @ halocov.inverse() @ sp).flatten() / 2) / (2 * np.pi * halocov.det().sqrt())
plt.scatter(sp[:, 0, 0].numpy(), sp[:, 1, 0].numpy(), c=hl.numpy())
# Equal axis scaling so that the ellipse's shape and orientation are not distorted.
plt.gca().set_aspect('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar(label='halo density');

In [None]:
# Visual check of the rotation part of se3_exp: a green reference camera plus red copies rotated
# over a 4x4 grid of (w1, w2) in [-pi/4, pi/4], with w3 = 0 and zero translation.
camera_mesh = o3d.io.read_triangle_mesh('camera.obj')

camera = o3d.geometry.TriangleMesh(camera_mesh)  # copy of the loaded mesh
camera.paint_uniform_color([0, 0.709, 0])
cameras = [camera]

i, j = np.meshgrid(np.linspace(-np.pi / 4, np.pi / 4, 4), np.linspace(-np.pi / 4, np.pi / 4, 4))
i = i.flatten()
j = j.flatten()
for a, b in zip(i, j):
    R, t = se3_exp(torch.tensor([a, b, 0.0, 0.0, 0.0, 0.0]))
    camera = o3d.geometry.TriangleMesh(camera_mesh)
    camera.paint_uniform_color([0.709, 0, 0])
    # 4x4 homogeneous transform [[R, t], [0, 0, 0, 1]].
    camera.transform(np.vstack([
        np.hstack([R.numpy(), t.numpy()]),
        [0, 0, 0, 1]
    ]))
    cameras.append(camera)

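Note: `w3`, the rotation about the camera's z-axis, is roll; it is held at zero in the grid above, so only `w1` and `w2` are varied.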

In [None]:
# Show the reference and rotated camera meshes in an interactive Open3D window.
o3d.visualization.draw_geometries(cameras)

## Plot spot trajectory

In [None]:
# Trajectory of the s-frame pose over the first training loop, which appends to Rs/ts once per
# iteration: the first pose is drawn in green, all later poses in red.
camera_mesh = o3d.io.read_triangle_mesh('camera.obj')
cameras = []
for i, (R, t) in enumerate(zip(Rs, ts)):
    camera = o3d.geometry.TriangleMesh(camera_mesh)
    if i == 0:
        camera.paint_uniform_color([0, 0.709, 0])
    else:
        camera.paint_uniform_color([0.709, 0, 0])
    camera.transform(np.vstack([
        np.hstack([0.1 * R, t]),  # 0.1 * R also shrinks the mesh by a factor of 10
        [0, 0, 0, 1]
    ]))
    cameras.append(camera)
o3d.visualization.draw_geometries(cameras)

In [None]:
# Back-of-the-envelope calculation: 2e6 * 1e7 * 4 / 1e12 = 80.
# NOTE(review): the factors are unlabelled; possibly a memory estimate (4-byte values, in TB) - confirm.
2000000*10000000*4 / (10**12)