In [None]:
import os
import json
import torch
import numpy as np
import pandas as pd
import open3d as o3d
import matplotlib.pyplot as plt
from utils import project
from pathlib import Path
from kornia.geometry.conversions import quaternion_to_rotation_matrix, QuaternionCoeffOrder, normalize_quaternion

def compute_s(chat_q_w, c_R_w, w_t_chat, w_t_c, c_P, xmin, xmax):
    """Collect one homography scale value per point inside a depth band.

    For every column of ``c_P`` whose third coordinate lies in
    ``[xmin, xmax]`` (bounds included), a plane-induced homography from the
    reference camera ``c`` to the estimated camera ``chat`` is built for the
    plane with normal ``(0, 0, -1)`` placed at that point's depth. The
    homography is rescaled so that its bottom-right entry equals 1, and its
    last row is applied to the point divided by its depth.

    Args:
        chat_q_w: Quaternion of the estimated camera, in WXYZ order.
        c_R_w: 3x3 rotation of the reference camera.
        w_t_chat: Position of the estimated camera (used as a 3x1 column).
        w_t_c: Position of the reference camera (used as a 3x1 column).
        c_P: Points in the reference camera frame, one point per column.
        xmin: Lower bound on the third coordinate of accepted points.
        xmax: Upper bound on the third coordinate of accepted points.

    Returns:
        A 1-D NumPy array holding one value per accepted point.
    """
    plane_normal = torch.tensor([0, 0, -1], dtype=torch.float64).view(3, 1)
    chat_R_w = quaternion_to_rotation_matrix(chat_q_w, order=QuaternionCoeffOrder.WXYZ)
    # Relative pose taking reference-camera coordinates to the estimated camera.
    chat_t_c = chat_R_w @ (w_t_c - w_t_chat)
    chat_R_c = chat_R_w @ c_R_w.T

    scales = []
    for point in c_P.T:
        depth = point[2]
        if not (depth >= xmin and depth <= xmax):
            continue
        homography = chat_R_c - (chat_t_c @ plane_normal.T) / depth
        # Fix the projective scale so the bottom-right entry is 1.
        homography = homography / homography[2, 2]
        normalised_point = point / depth
        scales.append((homography[2] @ normalised_point).item())
    return np.array(scales)

def save_view_point(geos, filename):
    """Let the user choose a viewpoint interactively and save it to disk.

    Opens an Open3D visualizer window showing every geometry in ``geos``,
    blocks in ``vis.run()`` until the user ends the interaction (pressing
    "q"), then writes the current view as pinhole camera parameters to
    ``filename``.

    Args:
        geos: Iterable of Open3D geometries to display.
        filename: Destination path passed to
            ``o3d.io.write_pinhole_camera_parameters``.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    try:
        for geo in geos:
            vis.add_geometry(geo)
        vis.run()  # user changes the view and presses "q" to terminate
        param = vis.get_view_control().convert_to_pinhole_camera_parameters()
        o3d.io.write_pinhole_camera_parameters(filename, param)
    finally:
        # Release the window even when adding geometry or saving fails, so a
        # failed call does not leave an orphaned window behind.
        vis.destroy_window()

def compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w):
    """Build the three 3x3 matrices combined by the homography-based loss.

    With ``R = chat_R_w @ c_R_w.T``, ``t = chat_R_w @ (w_t_c - w_t_chat)``
    and the fixed plane normal ``n = (0, 0, -1)^T``, and writing
    ``D = I - R`` and ``N = n @ t.T``, the function returns:

    * ``A = D @ D.T``
    * ``B = N @ D + (N @ D).T``
    * ``C = N @ N.T``

    Args:
        w_t_c: Position of the reference camera (3x1 column).
        c_R_w: 3x3 rotation of the reference camera.
        w_t_chat: Position of the estimated camera (3x1 column).
        chat_R_w: 3x3 rotation of the estimated camera.

    Returns:
        The tuple ``(A, B, C)`` of 3x3 tensors described above.
    """
    plane_normal = torch.tensor([0, 0, -1], dtype=torch.float64).view(3, 1)
    identity = torch.eye(3, dtype=torch.float64)

    # Relative pose from the reference camera to the estimated camera.
    relative_translation = chat_R_w @ (w_t_c - w_t_chat)
    relative_rotation = chat_R_w @ c_R_w.T

    rotation_residual = identity - relative_rotation
    normal_translation = plane_normal @ relative_translation.T
    cross_term = normal_translation @ rotation_residual

    return (
        rotation_residual @ rotation_residual.T,
        cross_term + cross_term.T,
        normal_translation @ normal_translation.T,
    )

def quaternion_to_R(q):
    """Convert a quaternion given as ``(w, x, y, z)`` into a 3x3 NumPy matrix.

    The quaternion is used as-is (no normalisation is performed here), so
    the result is a rotation matrix only when ``q`` has unit norm.

    Args:
        q: Sequence of four numbers unpacked as ``w, x, y, z``.

    Returns:
        A ``(3, 3)`` NumPy array.
    """
    w, x, y, z = q

    # Squares and pairwise products shared between the matrix entries.
    xx, yy, zz = np.square(x), np.square(y), np.square(z)
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    first_row = [1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (wy + xz)]
    second_row = [2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx)]
    third_row = [2 * (xz - wy), 2 * (wx + yz), 1 - 2 * (xx + yy)]
    return np.array([first_row, second_row, third_row])

def posenet(w_t_chat, w_t_c, chat_q_w, c_q_w, beta=0.2):
    """Weighted sum of the position error and the quaternion error.

    Computes ``||w_t_chat - w_t_c||_2 + beta * ||chat_q_w - c_q_w||_2``,
    where each norm is taken over all elements of the difference.

    Args:
        w_t_chat: Estimated camera position.
        w_t_c: Reference camera position.
        chat_q_w: Estimated camera quaternion.
        c_q_w: Reference camera quaternion.
        beta: Weight applied to the quaternion term. Defaults to 0.2, the
            value that was previously hard-coded.

    Returns:
        A scalar tensor holding the combined error.
    """
    position_error = torch.square(w_t_chat - w_t_c).sum().sqrt()
    orientation_error = torch.square(chat_q_w - c_q_w).sum().sqrt()
    return position_error + beta * orientation_error

def geometric(w_t_chat, w_t_c, chat_q_w, c_R_w, w_P):
    """Mean squared difference between two projections of the same points.

    ``w_P`` is projected with ``project`` once using the reference pose
    (``w_t_c``, ``c_R_w``) and once using the estimated pose (``w_t_chat``
    and the rotation matrix obtained from the WXYZ quaternion ``chat_q_w``).
    The two results are compared with ``torch.nn.functional.mse_loss``.

    Args:
        w_t_chat: Estimated camera position.
        w_t_c: Reference camera position.
        chat_q_w: Estimated camera quaternion, in WXYZ order.
        c_R_w: Reference camera rotation matrix.
        w_P: Scene points handed to ``project``.

    Returns:
        A scalar tensor holding the MSE between the two projections.
    """
    reference_projection = project(w_t_c, c_R_w, w_P)
    chat_R_w = quaternion_to_rotation_matrix(chat_q_w, order=QuaternionCoeffOrder.WXYZ)
    estimated_projection = project(w_t_chat, chat_R_w, w_P)
    # An L1 variant was tried as well:
    #   torch.abs(chat_p - c_p).sum(dim=0).mean()
    return torch.nn.functional.mse_loss(estimated_projection, reference_projection)

def homographic(w_t_c, c_R_w, w_t_chat, chat_q_w, xmin, xmax):
    """Homography-based pose error aggregated over a depth interval.

    Obtains ``A``, ``B`` and ``C`` from ``compute_ABC`` and returns
    ``trace(A + w_B * B + C / (xmin * xmax))`` with
    ``w_B = log(xmax / xmin) / (xmax - xmin)``. These two weights are the
    mean values of ``1/x`` and ``1/x**2`` over ``[xmin, xmax]``.

    Args:
        w_t_c: Reference camera position.
        c_R_w: Reference camera rotation matrix.
        w_t_chat: Estimated camera position.
        chat_q_w: Estimated camera quaternion, in WXYZ order.
        xmin: Lower end of the depth interval. ``torch.log`` is applied to
            ``xmax / xmin``, so tensors are expected.
        xmax: Upper end of the depth interval.

    Returns:
        A scalar tensor holding the trace of the weighted sum.
    """
    chat_R_w = quaternion_to_rotation_matrix(chat_q_w, order=QuaternionCoeffOrder.WXYZ)
    A, B, C = compute_ABC(w_t_c, c_R_w, w_t_chat, chat_R_w)

    mean_inverse_depth = torch.log(xmax / xmin) / (xmax - xmin)
    depth_product = xmin * xmax

    weighted_sum = A + B * mean_inverse_depth + C / depth_product
    return weighted_sum.trace()

def get_data(path):
    """Load poses, intrinsics and observed scene points for the images of a scene.

    Two files are read from ``path``:

    * ``reconstruction.nvm`` supplies, per image, a focal length and the
      indices of the scene points observed in it, plus the XYZ position and
      RGB colour of every scene point.
    * ``dataset_train.txt`` supplies, per row, an image file name (column 0),
      a camera position (columns 1-3) and a quaternion (columns 4-7). Its
      first line is skipped.

    For each row the quaternion is normalised and converted to a rotation
    matrix, the observed points are moved to the camera frame and projected
    with ``K``, and points failing the depth / extent / image-bounds checks
    are dropped. An image is skipped (with a printed message) when fewer than
    10 points survive or when any component of its position exceeds 1000 in
    absolute value.

    Args:
        path: Directory holding both files. It is combined with file names
            using ``/``, so a ``pathlib.Path``-like object is expected.

    Returns:
        A list with one dict per retained image, with keys ``image_file``,
        ``w_t_c`` (3x1), ``c_q_w`` (normalised), ``c_R_w``, ``w_P`` (inlier
        scene points), ``c_p`` (their pixel coordinates), ``K``, ``rgb``,
        ``xmin`` and ``xmax``. The last two are the sorted inlier depths at
        relative positions 0.025 and 0.975.
    """
    views = []
    scene_coordinates = []
    rgb = []

    with open(path / 'reconstruction.nvm', 'r') as file:

        # Skip first two lines
        for _ in range(2):
            file.readline()

        # `n_views` is the number of images
        n_views = int(file.readline())

        # For each image, NVM format is:
        # <File name> <focal length> <quaternion WXYZ> <camera center> <radial distortion> 0
        # Only the file name and the focal length are used from these lines.
        for _ in range(n_views):
            line = file.readline().split()

            f = float(line[1])
            # Principal point is hard-coded to the centre of a 1920x1080 image.
            K = np.array([
                [f, 0, 1920 / 2],
                [0, f, 1080 / 2],
                [0, 0, 1]
            ])
            views.append({
                'image_file': line[0],
                'K': K,
                'observations_ids': []
            })

        # Skip one line
        file.readline()

        # `n_points` is the number of scene coordinates
        n_points = int(file.readline())

        # For each scene coordinate, NVM format is:
        # <XYZ> <RGB> <number of measurements> <List of Measurements>
        for i in range(n_points):

            line = file.readline().split()

            scene_coordinates.append(np.array(list(map(float, line[:3]))))
            rgb.append(np.array(list(map(int, line[3:6]))))

            # `n_obs` is the number of images where the scene coordinate is observed
            n_obs = int(line[6])

            # Each measurement is
            # <Image index> <Feature Index> <xy>
            # i.e. four tokens, the first being the index of the observing image.
            for n in range(n_obs):
                views[int(line[7 + n * 4])]['observations_ids'].append(i)

    # Re-key the views by image file name for lookup from the pose file.
    views = {view.pop('image_file'): view for view in views}
    scene_coordinates = np.stack(scene_coordinates)
    rgb = np.stack(rgb)

    df = pd.read_csv(path / 'dataset_train.txt', sep=' ', skiprows=1)

    data = []

    for image_file, w_t_c, c_q_w in zip(df.iloc[:, 0].values, df.iloc[:, 1:4].values, df.iloc[:, 4:8].values):
        w_t_c = w_t_c.reshape(3, 1)
        c_q_w = c_q_w / np.linalg.norm(c_q_w)
        c_R_w = quaternion_to_R(c_q_w)
        # The NVM views are looked up with a '.jpg' extension, whatever the
        # extension used in the pose file.
        view = views[os.path.splitext(image_file)[0] + '.jpg']
        w_P = scene_coordinates[view['observations_ids']]
        colors = rgb[view['observations_ids']]
        # World points -> camera frame (3xN), then pinhole projection to pixels.
        c_P = c_R_w @ (w_P.T - w_t_c)
        c_p = view['K'] @ c_P
        c_p = c_p[:2] / c_p[2]
        # Keep points in front of the camera (depth in (0.2, 1000)), with
        # bounded lateral extent, that project inside the 1920x1080 image.
        args_inliers = np.argwhere(
            (c_P[2] > 0.2) & (c_P[2] < 1000) & (np.abs(c_P[0]) < 1000) & (np.abs(c_P[1]) < 1000) & \
            (c_p[0] > 0) & (c_p[0] < 1920) & (c_p[1] > 0) & (c_p[1] < 1080)
        ).flatten()
        if args_inliers.shape[0] < 10:
            print(f'Not using image {image_file}: [{args_inliers.shape[0]}/{w_P.shape[0]}] scene coordinates inliers')
        elif np.abs(w_t_c).max() > 1000:
            print(f'Not using image {image_file}: t is {w_t_c}')
        else:
            if args_inliers.shape[0] != w_P.shape[0]:
                print(f'Eliminating outliers in image {image_file}: [{args_inliers.shape[0]}/{w_P.shape[0]}] scene coordinates inliers')
            # Sorted inlier depths, used below to pick a robust depth range.
            depths = np.sort(c_P.T[args_inliers][:, 2])
            data.append({
                'image_file': image_file,
                'w_t_c': w_t_c,
                'c_q_w': c_q_w,
                'c_R_w': c_R_w,
                'w_P': w_P[args_inliers],
                'c_p': c_p.T[args_inliers],
                'K': view['K'],
                'rgb': colors[args_inliers],
                'xmin': depths[int(0.025 * (depths.shape[0] - 1))],
                'xmax': depths[int(0.975 * (depths.shape[0] - 1))]
            })
    return data

# Root directory of the scene to load; get_data() reads reconstruction.nvm
# and dataset_train.txt from it. Machine-specific absolute path.
path = Path('/media/clementin/DATA/Cambridge/KingsCollege')

# Triangle mesh read once here; copies of it are used to draw the cameras.
# Machine-specific absolute path.
camera_mesh = o3d.io.read_triangle_mesh('/home/clementin/Dev/underwater_reloc_benchmark/Hierarchical-Localization/outputs/TourEiffel/camera_insideout.obj')

In [None]:
# Parse the scene at `path` into a list of per-image records (see get_data).
data = get_data(path)

In [None]:
# Pick one record and convert the fields needed by the losses to tensors.
image = data[420]

w_P, w_t_c, c_q_w, xmin, xmax, K = (
    torch.tensor(image[key])
    for key in ('w_P', 'w_t_c', 'c_q_w', 'xmin', 'xmax', 'K')
)
c_R_w = quaternion_to_rotation_matrix(c_q_w, order=QuaternionCoeffOrder.WXYZ)
w_R_c = c_R_w.T
# Observed scene points expressed in the reference camera frame (3xN).
c_P = c_R_w @ (w_P.T - w_t_c)

# Ground-truth camera (green), moved to the homogeneous pose [w_R_c | w_t_c].
gt_pose = np.eye(4)
gt_pose[:3, :3] = w_R_c.numpy()
gt_pose[:3, 3:] = w_t_c.numpy()
gt_cam = o3d.geometry.TriangleMesh(camera_mesh)
gt_cam.paint_uniform_color([0, 0.709, 0])
gt_cam.transform(gt_pose)

# Estimated camera (red); it is left untransformed here.
es_cam = o3d.geometry.TriangleMesh(camera_mesh)
es_cam.paint_uniform_color([0.709, 0, 0])

# Coloured point cloud of the scene points kept for this image.
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(image['w_P'])
point_cloud.colors = o3d.utility.Vector3dVector(image['rgb'] / 255)

In [None]:
# Initial estimate: the ground-truth pose with the quaternion's x, y, z
# components negated and the position offset by (-2, -2, 2).
chat_q_w = c_q_w.clone()
w_t_chat = w_t_c.clone()
chat_q_w[1:] *= -1
#chat_q_w += torch.tensor([0.2, -0.08, 0.1, -0.14])
w_t_chat += torch.tensor([-2, -2, 2]).view(3, 1)
chat_q_w = torch.nn.Parameter(chat_q_w)
w_t_chat = torch.nn.Parameter(w_t_chat)

optimizer = torch.optim.Adam([chat_q_w, w_t_chat])

# Histograms go to 's/' and screenshots to 'im/'; make sure both exist so the
# first save does not fail on a fresh working directory.
os.makedirs('s', exist_ok=True)
os.makedirs('im', exist_ok=True)

# The rendering viewpoint does not change between epochs: read it once.
view_parameters = o3d.io.read_pinhole_camera_parameters('render_homography.json')

vis = o3d.visualization.Visualizer()
vis.create_window()
try:
    vis.add_geometry(gt_cam)
    vis.add_geometry(es_cam)
    vis.add_geometry(point_cloud)

    for epoch in range(500):

        # Two optimiser steps per rendered frame.
        for _ in range(2):
            optimizer.zero_grad()

            #loss = geometric(w_t_chat, w_t_c, chat_q_w, c_R_w, w_P)
            loss = homographic(w_t_c, c_R_w, w_t_chat, chat_q_w, xmin, xmax)
            #loss = posenet(w_t_chat, w_t_c, chat_q_w, c_q_w)

            loss.backward()
            optimizer.step()

        with torch.no_grad():
            # Histogram of s, with bins centred on 1 and at least 0.01 wide
            # on each side.
            s = compute_s(chat_q_w, c_R_w, w_t_chat, w_t_c, c_P, xmin, xmax)
            s_max = max(0.01, np.abs(s - 1).max())
            bins = np.linspace(-s_max + 1, s_max + 1, 10)
            plt.figure(figsize=(8, 4))
            plt.hist(s, bins=bins, color=[0.709, 0.0, 0.0])
            plt.gca().spines['right'].set_visible(False)
            plt.gca().spines['top'].set_visible(False)
            plt.title('Repartition of $s$', y=-0.18)
            plt.savefig(f's/{epoch:05d}.png', facecolor='white')
            plt.close()

            # Move the estimated-camera mesh to the current pose estimate.
            disp_chat_R_w = quaternion_to_rotation_matrix(chat_q_w, order=QuaternionCoeffOrder.WXYZ).detach().numpy()
            disp_w_t_chat = w_t_chat.detach().numpy()
            es_cam.transform(np.vstack([
                np.hstack([disp_chat_R_w.T, disp_w_t_chat]),
                [0, 0, 0, 1]
            ]))
            vis.update_geometry(es_cam)
            vis.get_view_control().convert_from_pinhole_camera_parameters(view_parameters)
            vis.poll_events()
            vis.update_renderer()
            vis.capture_screen_image(f'im/{epoch:05d}.png')
            if epoch % 100 == 0:
                # Leave a lighter-coloured copy of the camera at this pose.
                gs_cam = o3d.geometry.TriangleMesh(es_cam)
                gs_cam.paint_uniform_color([0.709, 0.5, 0.5])
                vis.add_geometry(gs_cam)
            # Undo the pose with the inverse rigid transform [R | -R t] so the
            # next epoch starts again from the untransformed mesh.
            es_cam.transform(np.vstack([
                np.hstack([disp_chat_R_w, -disp_chat_R_w @ disp_w_t_chat]),
                [0, 0, 0, 1]
            ]))
finally:
    # Close the window even if an epoch raises (e.g. a failed save).
    vis.destroy_window()