# Find the best pose from FoundPose results

Fit a single object pose from multi-view data.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
# Put the notebook's parent directory on sys.path (once), presumably so the
# local `burybarrel` package imported below can be resolved.
module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import json
from pathlib import Path
from typing import List

import dataclass_array as dca
import jax.numpy as jnp
from jax import grad
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pycolmap
import quaternion
import transforms3d as t3d
import trimesh
import visu3d as v3d
import yaml

sys.path.append(os.path.abspath(os.path.join("..", "bop_toolkit")))
from bop_toolkit.bop_toolkit_lib.misc import get_symmetry_transformations

import burybarrel.colmap_util as cutil
from burybarrel.image import render_v3d, render_model, to_contour
from burybarrel.plotting import get_axes_traces
from burybarrel.camera import load_v3dcams

In [None]:
# Notebook configuration.
# use_coarse: when True the "R_coarse"/"t_coarse" fields of each foundpose
# result are used instead of "R"/"t".
use_coarse = False
# Camera poses file passed to load_v3dcams.
camposes_path = Path("/scratch/jeyan/barreldata/results/dive3-depthcharge-03-04/cam_poses.json")
# Reconstructed scene, loaded with trimesh for its vertices and vertex colors.
scene_path = Path("/scratch/jeyan/barreldata/results/dive3-depthcharge-03-04/openmvs-out/scene_dense_fixed.ply")
# Per-image pose hypotheses (JSON list) produced by foundpose.
foundpose_res_path = Path("/scratch/jeyan/foundpose/output_dive3-depthcharge-03-04/inference/estimated-poses.json")
# Directory holding the RGB images.
imgdir = Path("/scratch/jeyan/barreldata/divedata/dive3/dive3-depthcharge-03-04/rgb")
# Output directory: overlays and the fitted poses are written here.
resdir = Path("/scratch/jeyan/barreldata/results/dive3-depthcharge-03-04")

# 3D model of the object whose pose is being fitted.
obj_path = Path("/scratch/jeyan/barreldata/models3d/depth_charge_mark_9_mod_1-scaled.ply")

In [None]:
# Load the reconstructed scene as a colored point cloud (RGB channels only).
trimeshpc = trimesh.load(scene_path)
scenevtxs = trimeshpc.vertices
scenecols = trimeshpc.visual.vertex_colors[:, :3]
scenepts = v3d.Point3d(p=scenevtxs, rgb=scenecols)

# Load the cameras and index them by image file name.
cams, campaths = load_v3dcams(camposes_path, img_parent=imgdir)
camnames = [campath.name for campath in campaths]
imgpaths = [imgdir / camname for camname in camnames]
name2cam = dict(zip(camnames, cams))
# v3d.make_fig([cams, scenepts])

In [None]:
with open(foundpose_res_path, "rt") as f:
    foundpose_res = json.load(f)
# foundpose results are used as reference for image ids
name2imgid = {}
# One object->camera transform per usable hypothesis.
obj2cam_hyps: List[v3d.Transform] = []
# camera for each hypothesis. If multiple hypotheses for each image, there will be
# duplicate cameras in here
camhyp_list: List[v3d.Camera] = []
# Which result fields hold the pose: coarse estimate or the default one.
R_key, t_key = ("R_coarse", "t_coarse") if use_coarse else ("R", "t")
# colmap usually can't reconstruct every camera pose, so we can only fit the
# foundpose results with a camera pose
for fres in foundpose_res:
    name = Path(fres["img_path"]).name
    # Keep the first image id seen for each image name.
    name2imgid.setdefault(name, fres["img_id"])
    # could just use the "best" hypothesis (fres["hypothesis_id"] == "0"). pretty
    # often though, this hypothesis sucks, so every hypothesis with a camera is kept.
    if name not in name2cam:
        continue
    T = np.eye(4)
    T[:3, :3] = fres[R_key]
    T[:3, 3] = np.reshape(fres[t_key], -1)
    camhyp_list.append(name2cam[name])
    obj2cam_hyps.append(v3d.Transform.from_matrix(T))
if not obj2cam_hyps:
    # Stacking an empty list would fail with a much less helpful message.
    raise ValueError(
        f"No foundpose result in {foundpose_res_path} matches a camera from {camposes_path}"
    )
obj2cams = dca.stack(obj2cam_hyps)
camhyps = dca.stack(camhyp_list)

In [None]:
# Load the object model and wrap its vertices as a red point cloud for plotting.
mesh = trimesh.load(obj_path)
meshvtxs = np.array(mesh.vertices)
meshpts = v3d.Point3d(p=meshvtxs, rgb=[255, 0, 0])

In [None]:
def scale_cams(scale: float, cams: v3d.Camera):
    """
    Return `cams` with the translation of every camera-to-world pose
    multiplied by `scale`. Rotations are left untouched.

    Args:
        scale (float): factor applied to the camera positions
        cams (v3d.Camera): cameras whose `world_from_cam` is rescaled

    Returns:
        v3d.Camera built via `cams.replace` with the rescaled poses
    """
    cam2world = cams.world_from_cam.matrix4x4
    # Only the translation column changes.
    cam2world[..., :3, 3] *= scale
    scaled_pose = v3d.Transform.from_matrix(cam2world)
    return cams.replace(world_from_cam=scaled_pose)

def scale_reconstr(scale: float, cam2worlds, obj2cams, scenevtxs):
    """
    Apply a global scale to a reconstruction and place the objects in it.

    Camera translations and scene vertices are multiplied by `scale`; the
    object-to-camera transforms are used as given and composed with the
    rescaled camera poses. The inputs are not modified.

    Args:
        scale (float)
        cam2worlds (nx4x4)
        obj2cams (nx4x4)
        scenevtxs (nx3)

    Returns:
        obj2worlds (nx4x4), cam2worldscaled (nx4x4), scenevtxsscaled (nx3)
    """
    # Work on a copy so the caller's camera poses stay intact.
    cam2worldscaled = np.array(cam2worlds, copy=True)
    cam2worldscaled[:, :3, 3] *= scale
    return cam2worldscaled @ obj2cams, cam2worldscaled, scenevtxs * scale

In [None]:
def ransac(*data, fit_func=None, loss_func=None, cost_func=None, samp_min=10, inlier_min=10, inlier_thres=0.1, max_iter=1000, seed=None):
    """
    Generic RANSAC over one or more index-aligned data arrays.

    Each iteration fits a model on `samp_min` randomly chosen data points,
    scores every data point with `loss_func`, and keeps the model with the
    most inliers (ties broken by the lower `cost_func` value on the inliers).
    The returned model is the one fitted on the random sample; it is not
    refitted on its inliers.

    Args:
        *data: arrays sharing the same first dimension; each must support
            indexing with an integer index array
        fit_func (data -> model): receives a list with the sampled subset of
            each array in `data`
        loss_func ((model, data) -> array): vectorized loss for individual data points
        cost_func ((model, data) -> scalar): total cost to try to minimize; only
            evaluated on inlier sets with at least `inlier_min` points
        samp_min (int): number of points drawn (without replacement) per iteration
        inlier_min (int): minimum inlier count for a model to be accepted
        inlier_thres (float): points with loss strictly below this are inliers
        max_iter (int): number of random samples to try
        seed: seed forwarded to `np.random.default_rng`

    Returns:
        (best_model, best_inlier_idxs): the winning model and the indices of
        its inliers

    Raises:
        TypeError: if any of fit_func, loss_func, cost_func is missing
        ValueError: if there are fewer than `samp_min` data points, or no
            model reaches `inlier_min` inliers
    """
    if fit_func is None or loss_func is None or cost_func is None:
        raise TypeError("fit_func, loss_func and cost_func must all be provided")
    n_points = len(data[0])
    if samp_min > n_points:
        raise ValueError(
            f"Cannot draw samp_min={samp_min} points from only {n_points} data points"
        )

    rng = np.random.default_rng(seed)
    best_model = None
    best_inlier_idxs = []
    best_error = float("inf")

    for _ in range(max_iter):
        sample_indices = rng.choice(n_points, samp_min, replace=False)
        sample = [singledata[sample_indices] for singledata in data]

        model = fit_func(sample)
        errors = loss_func(model, data)

        inlier_idxs = np.where(errors < inlier_thres)[0]
        n_inliers = len(inlier_idxs)
        if n_inliers < inlier_min:
            # Rejected model: skip the cost evaluation entirely (it would
            # otherwise run on empty / too-small inlier sets).
            continue

        inliers = [singledata[inlier_idxs] for singledata in data]
        total_error = cost_func(model, inliers)

        n_best = len(best_inlier_idxs)
        if n_inliers > n_best or (n_inliers == n_best and total_error < best_error):
            best_model = model
            best_inlier_idxs = inlier_idxs
            best_error = total_error
    # TODO add condition to relax constraints if this is reached
    if best_model is None:
        raise ValueError("No valid model found after RANSAC")
    return best_model, best_inlier_idxs

In [None]:
# NOTE: the objective below is quadratic in `scale`, so its minimizer has a
# closed form. It is written with jax.numpy so it can be differentiated.
def variance_from_scale(scale, data):
    """
    Spread of the object centers in world space for a given reconstruction scale.

    Args:
        scale: factor applied to the camera translations
        data: (cam2world nx4x4, obj2cam nx4x4)

    Returns:
        scalar: sum over x, y, z of the variance of the object centers
        (the trace of their covariance matrix)
    """
    cam2worlds, obj2cams = jnp.array(data[0]), jnp.array(data[1])
    origin = jnp.array([0, 0, 0, 1.0])
    # Scale only the translation column of each camera pose.
    cam2worlds = cam2worlds.at[:, 0:3, 3].multiply(scale)
    # Object origin mapped into the world frame, homogeneous coordinate dropped.
    obj_centers = (cam2worlds @ obj2cams @ origin)[:, :3]
    return jnp.var(obj_centers, axis=0).sum()

class ScaleCentroidModel():
    """
    Reconstruction scale that makes the per-hypothesis object centers agree.

    For hypothesis i the object center in world space is
    ``R_cam_i @ t_obj_i + scale * t_cam_i`` (camera translations are scaled,
    object-to-camera translations are not). `fit` picks the scale minimizing
    the total variance of these centers, i.e. the same objective that
    `variance_from_scale` computes.

    Attributes set by `fit`:
        scale (float): fitted scale factor
        mean: mean object center (length 3) at the fitted scale
    """

    def __init__(self):
        self.scale = None
        self.mean = None

    def __call__(self, data):
        return self.predict(data)

    def fit(self, data):
        """
        Fit the scale and the mean object center.

        The objective is quadratic in the scale, so it is minimized exactly
        with the least-squares closed form instead of fixed-step gradient
        descent (which had no iteration cap and diverged to NaN whenever the
        camera spread was large relative to the step size).

        Args:
            data: (cam2world nx4x4, obj2cam nx4x4)

        Returns:
            self
        """
        cam2worlds = np.asarray(data[0], dtype=float)
        obj2cams = np.asarray(data[1], dtype=float)
        # Object origin in homogeneous camera coordinates = last column of obj2cam.
        originshom = obj2cams[:, :, 3]
        # center_i(scale) = fixed_i + scale * scaled_i
        fixed = np.einsum("nij,nj->ni", cam2worlds[:, :3, :3], originshom[:, :3])
        scaled = cam2worlds[:, :3, 3] * originshom[:, 3:4]
        fixed_c = fixed - fixed.mean(axis=0)
        scaled_c = scaled - scaled.mean(axis=0)
        denom = np.sum(scaled_c ** 2)
        if denom > np.finfo(float).eps:
            self.scale = float(-np.sum(fixed_c * scaled_c) / denom)
        else:
            # All cameras coincide: the objective does not depend on the scale.
            # Keep the neutral value the gradient descent used to start from.
            self.scale = 1.0
        centroids = self.predict(data)
        self.mean = np.mean(centroids, axis=0)
        return self

    def predict(self, data):
        """
        Object centers in world space at the fitted scale.

        Args:
            data: (cam2world nx4x4, obj2cam nx4x4)

        Returns:
            nx3 array of object centers
        """
        cam2worlds = data[0]
        obj2cams = data[1]
        # Copy so the caller's camera poses are not scaled in place.
        scaledcamTs = np.copy(cam2worlds)
        scaledcamTs[:, 0:3, 3] *= self.scale
        centershom = scaledcamTs @ obj2cams @ jnp.array([0, 0, 0, 1.0])
        centers = centershom[:, :3]
        return centers

# data = (cam2world nx4x4, obj2cam nx4x4)
def fitcams(data):
    """RANSAC fit function: return a ScaleCentroidModel fitted on `data`."""
    return ScaleCentroidModel().fit(data)

def camloss(model, data):
    """Per-hypothesis loss: distance of each predicted center to `model.mean`."""
    offsets = model(data) - model.mean
    return np.linalg.norm(offsets, axis=1)

def camcost(model, data):
    """Total cost: variance of the predicted centers summed over the three axes."""
    per_axis_var = jnp.var(model(data), axis=0)
    return jnp.sum(per_axis_var)

# RANSAC over all pose hypotheses: each iteration fits the scale on 5 random
# hypotheses; a hypothesis is an inlier when its object center lies within
# 0.15 of the sample's mean center (camloss). At least 5 inliers are required.
model, inlieridxs = ransac(camhyps.world_from_cam.matrix4x4, obj2cams.matrix4x4, fit_func=fitcams, loss_func=camloss, cost_func=camcost, samp_min=5, inlier_min=5, inlier_thres=0.15, max_iter=50)
# Notebook display of the fitted scale and the inlier indices.
model.scale, inlieridxs

In [None]:
# Apply the fitted scale to the cameras, the hypothesis poses and the scene.
scalefactor = model.scale
camscaled = scale_cams(scalefactor, cams)
obj2worlds, cam2worldscaled, scenevtxsscaled = scale_reconstr(
    scalefactor,
    camhyps.world_from_cam.matrix4x4,
    obj2cams.matrix4x4,
    scenevtxs,
)
sceneptsscaled = scenepts.replace(p=scenevtxsscaled)

# Keep only the hypotheses RANSAC accepted as inliers.
camsinlier = camhyps.replace(world_from_cam=v3d.Transform.from_matrix(cam2worldscaled))[inlieridxs]
obj2worldsinlier = obj2worlds[inlieridxs]

# Object model placed in the world once per inlier hypothesis.
barrelpts_trf = [
    v3d.Transform.from_matrix(obj2world) @ meshpts
    for obj2world in obj2worldsinlier
]
tofig = [sceneptsscaled, camsinlier, *barrelpts_trf]
v3d.make_fig(*tofig)

In [None]:
# Plot the inlier object->world poses (presumably one set of coordinate axes
# per pose; see burybarrel.plotting.get_axes_traces).
v3d.make_fig(*get_axes_traces(obj2worldsinlier))

In [None]:
def qangle(q1, q2):
    """
    Angular error in radians between 2 quaternions.
    https://math.stackexchange.com/questions/3572459/how-to-compute-the-orientation-error-between-two-3d-coordinate-frames

    Computed as atan2(|vector part|, |scalar part|) of q1 * conj(q2). The
    sign of the scalar part is ignored so q and -q (the same rotation) give
    the same result. For unit quaternions the value lies in [0, pi/2] and is
    HALF the rotation angle between the two orientations.

    Args:
        q1, q2: quaternions supporting `*`, `.conjugate()` and `.w/.x/.y/.z`

    Returns:
        float: angular error in radians
    """
    qerr = q1 * q2.conjugate()
    vecnorm = np.sqrt(qerr.x ** 2 + qerr.y ** 2 + qerr.z ** 2)
    # np.arctan2 exists in every NumPy release; np.atan2 is only an alias
    # added in NumPy 2.0.
    return np.arctan2(vecnorm, abs(qerr.w))

def closest_quat_sym(q1, q2, syms):
    """
    Among all symmetry-equivalent versions of q2, return the one closest to q1.

    Every symmetry rotation is applied to q2 by right-multiplication and the
    candidate with the smallest `qangle` to the reference q1 is returned.

    Args:
        q1: reference quaternion
        q2: quaternion to re-express through the object's symmetries
        syms: iterable of symmetry transformations, each a dictionary with:
            - 'R': 3x3 ndarray with the rotation matrix.
            - 't': 3x1 ndarray with the translation vector.
            Only 'R' is used here.
    """
    candidates = [q2 * quaternion.from_rotation_matrix(sym["R"]) for sym in syms]
    angles = [qangle(q1, candidate) for candidate in candidates]
    return candidates[np.argmin(np.abs(angles))]

In [None]:
# Symmetries of the object model, looked up by the model's file name.
with open(Path("/scratch/jeyan/barreldata/models3d/model_info.json"), "rt") as f:
    objinfo = json.load(f)
symTs = get_symmetry_transformations(objinfo[obj_path.name], 0.01)

# Bring every inlier rotation to the symmetry-equivalent orientation that is
# closest to the first inlier, which serves as the reference.
quatsinlier = quaternion.from_rotation_matrix(obj2worldsinlier[..., :3, :3])
ref = quatsinlier[0]
quatssymd = np.array(
    [ref] + [closest_quat_sym(ref, otherquat, symTs) for otherquat in quatsinlier[1:]]
)

# Same poses as the inliers, with the symmetry-resolved rotations swapped in.
obj2worldsinliersym = np.copy(obj2worldsinlier)
obj2worldsinliersym[..., :3, :3] = quaternion.as_rotation_matrix(quatssymd)
v3d.make_fig(*get_axes_traces(obj2worldsinliersym))

In [None]:
# data = (n quaternions)
def qmean(qs, weights=None):
    """
    Weighted average of quaternions via the eigenvector method.
    https://stackoverflow.com/questions/12374087/average-of-multiple-quaternions

    The average is the eigenvector with the largest eigenvalue of
    sum_i w_i * q_i q_i^T. The matrix is symmetric, so `np.linalg.eigh` is
    used, which guarantees real eigenvalues and eigenvectors.

    Args:
        qs: quaternions (array of quaternion dtype, or a list wrapping such
            an array as passed by `ransac`); flattened before averaging
        weights: optional per-quaternion weights, applied linearly (previously
            they entered squared). Defaults to equal weights.

    Returns:
        quaternion: unit-norm average, sign chosen so that w >= 0
    """
    qs = np.reshape(qs, -1)
    Q = quaternion.as_float_array(qs).T  # 4 x n
    if weights is None:
        weights = np.ones(Q.shape[1])
    weights = np.reshape(np.asarray(weights, dtype=float), -1)
    QQ = (Q * weights) @ Q.T
    vals, vecs = np.linalg.eigh(QQ)
    avg = vecs[:, np.argmax(vals)]
    avg = avg / np.linalg.norm(avg)
    # q and -q are the same rotation; pick a deterministic representative.
    if avg[0] < 0:
        avg = -avg
    return quaternion.from_float_array(avg)

def qloss(model, qs):
    """Per-quaternion loss: `qangle` between the model quaternion and each entry of `qs`."""
    flat = np.reshape(qs, -1)
    return np.fromiter((qangle(model, q) for q in flat), dtype=float, count=flat.size)

def qcost(model, qs):
    """Total cost: sum of the per-quaternion `qloss` values."""
    return qloss(model, qs).sum()

# RANSAC over the symmetry-aligned rotations: average 5 random quaternions per
# iteration; inliers are rotations whose qangle to that average is below 0.2.
qmeanransac, qinliers = ransac(quatssymd, fit_func=qmean, loss_func=qloss, cost_func=qcost, samp_min=5, inlier_min=5, inlier_thres=0.2, max_iter=50)
# Notebook display of the averaged rotation and its inlier indices.
qmeanransac, qinliers

In [None]:
# Final object->world pose: rotation from the RANSAC quaternion average,
# translation from the mean over ALL scale-RANSAC inliers (not restricted to
# the rotation inliers in `qinliers`).
meanT = v3d.Transform(R=quaternion.as_rotation_matrix(qmeanransac), t=np.mean(v3d.Transform.from_matrix(obj2worldsinliersym).t, axis=0))
# Individual inlier poses together with the fitted pose, drawn with a thicker line.
v3d.make_fig(*get_axes_traces(obj2worldsinliersym, scale=0.5), *get_axes_traces(meanT, linewidth=10))

In [None]:
# Scaled cameras, the object model at the fitted pose, and the scaled scene.
v3d.make_fig(camscaled, meanT @ meshpts, sceneptsscaled)

In [None]:
# Render the object at the fitted pose from every camera and save its contour
# drawn over the corresponding RGB image.
imgs = [np.array(Image.open(imgpath).convert("RGB")) for imgpath in imgpaths]
overlaydir = resdir / "fit-overlays"
overlaydir.mkdir(exist_ok=True)
for i, (imgpath, img) in enumerate(zip(imgpaths, imgs)):
    rgb, _, _ = render_model(camscaled[i], mesh, meanT, light_intensity=20.0)
    overlayimg = to_contour(rgb, color=(255, 0, 0), background=img)
    outpath = overlaydir / f"{imgpath.stem}.png"
    Image.fromarray(overlayimg).save(outpath)
    # alternative: project the model points instead of rendering the mesh
    # Image.fromarray(render_v3d(camscaled[i], meanT @ meshpts, radius=4, background=img)).save(overlaydir / f"{imgpaths[i].stem}.png")

In [None]:
# Express the fitted object->world pose in each scaled camera's frame and
# write one record per camera, using the same field names that are read from
# the foundpose results.
obj2camfit = camscaled.world_from_cam.inv @ meanT[..., None]
estposes = [
    {
        "img_path": str(imgpaths[i]),
        "img_id": name2imgid[camnames[i]],
        "hypothesis_id": "0",
        "R": obj2cam.R.tolist(),
        "t": obj2cam.t[..., None].tolist(),
    }
    for i, obj2cam in enumerate(obj2camfit)
]
with open(resdir / "estimated-poses.json", "wt") as f:
    json.dump(estposes, f)