# Find the best object pose from FoundPose results

In [None]:
# Notebook setup: auto-reload edited modules and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import sys
# Put the parent directory on sys.path so project packages (e.g. `burybarrel`)
# import when the notebook runs from a subdirectory — presumably; verify layout.
module_path = os.path.abspath(os.path.join(".."))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import json
from pathlib import Path
from typing import List

import dataclass_array as dca
import jax.numpy as jnp
from jax import grad
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pycolmap
import quaternion
import transforms3d as t3d
import trimesh
import visu3d as v3d

sys.path.append(os.path.abspath(os.path.join("..", "bop_toolkit")))
from bop_toolkit.bop_toolkit_lib.misc import get_symmetry_transformations

import burybarrel.colmap_util as cutil
from burybarrel.image import render_v3d
from burybarrel.plotting import get_axes_traces

In [None]:
# Inputs: COLMAP sparse reconstruction, FoundPose pose estimates, and the object mesh.
reconstr_path = Path("/scratch/jeyan/barreldata/results/barrelddt1/colmap-out/0")
foundpose_res_path = Path("/scratch/jeyan/foundpose/output_barrelddt1_raw_vitl_layer18/inference/estimated-poses.json")
obj_path = Path("/scratch/jeyan/barreldata/models3d/barrelsingle-scaled.ply")

# Load the reconstruction and print its summary for a quick sanity check.
reconstruction = pycolmap.Reconstruction(reconstr_path)
print(reconstruction.summary())

In [None]:
# Scene point cloud (positions + colors) and COLMAP cameras as visu3d objects.
scenevtxs, scenecols = cutil.get_pc(reconstruction)
scenepts = v3d.Point3d(p=scenevtxs, rgb=scenecols)
cams = cutil.get_cams_v3d(reconstruction)
# v3d.make_fig([cams, scenepts])

In [None]:
# Load FoundPose estimates: a JSON list of dicts. Keys read in this notebook are
# "hypothesis_id", "R" and "t".
with open(foundpose_res_path, "rt") as f:
    foundpose_res = json.load(f)
# For now, keep only hypothesis "0" (treated as the "best" hypothesis).
# This hypothesis is often wrong, which is why RANSAC is used further below.
foundpose_res = list(filter(lambda x: x["hypothesis_id"] == "0", foundpose_res))

In [None]:
# One 4x4 homogeneous object-to-camera transform per kept estimate.
obj2cams = []
for res in foundpose_res:
    R = res["R"]
    t = res["t"]
    T = np.eye(4)
    T[:3, :3] = R
    # flatten t: it may be stored as a nested column (3x1) list
    T[:3, 3] = np.reshape(t, -1)
    obj2cams.append(T)
obj2cams = np.array(obj2cams)  # (n, 4, 4)

In [None]:
# Load the object mesh and wrap its vertices as a red point cloud for plotting.
mesh = trimesh.load(obj_path)
meshvtxs = np.array(mesh.vertices)
meshpts = v3d.Point3d(p=meshvtxs, rgb=[255, 0, 0])

In [None]:
def scale_reconstr(scale: float, cam2worlds, obj2cams, scenevtxs):
    """
    Apply a global scale to a 3D reconstruction and place the objects in it.

    The camera translations and scene vertices are multiplied by ``scale``
    (camera rotations are untouched). The object-to-camera transforms are NOT
    scaled: they are composed as-is with the scaled cameras to give the object
    poses in the scaled world frame.

    Args:
        scale (float): factor applied to the reconstruction.
        cam2worlds (nx4x4): camera-to-world transforms; not modified.
        obj2cams (nx4x4): object-to-camera transforms.
        scenevtxs (nx3): scene point positions.

    Returns:
        obj2worlds (nx4x4), cam2worldscaled (nx4x4), scenevtxsscaled (nx3)
    """
    scaled_cams = np.copy(cam2worlds)
    scaled_cams[:, :3, 3] *= scale
    placed_objects = scaled_cams @ obj2cams
    return placed_objects, scaled_cams, scenevtxs * scale

In [None]:
# scaled 3d
# Provisional look at the scene with a hand-picked scale (0.2), overlaying the
# mesh at every FoundPose estimate.
# NOTE(review): later cells (overlay rendering, pose export) also read
# camscaled — check whether they expect this 0.2 scale.
scalefactor = 0.2
obj2worlds, cam2worldscaled, scenevtxsscaled = scale_reconstr(scalefactor, cams.world_from_cam.matrix4x4, obj2cams, scenevtxs)
tofig = []
camscaled = cams.replace(world_from_cam=v3d.Transform.from_matrix(cam2worldscaled))
barrelpts_trf = []
sceneptsscaled = scenepts.replace(p=scenevtxsscaled)
tofig.extend([sceneptsscaled, camscaled])
# mesh vertices moved into the (scaled) world frame, one copy per estimate
for i, obj2world in enumerate(obj2worlds):
    barrelpts_trf.append(v3d.Transform.from_matrix(obj2world) @ meshpts)
    tofig.append(barrelpts_trf[-1])
# v3d.make_fig(*tofig)

In [None]:
# Index of one camera/estimate to inspect in the (commented-out) render below.
i = 9
# Image.fromarray(render_v3d(camscaled[i], dca.concat([sceneptsscaled, barrelpts_trf[i]]), radius=2))

In [None]:
def ransac(*data, fit_func=None, loss_func=None, cost_func=None, samp_min=10, inlier_min=10, inlier_thres=0.1, max_iter=1000, seed=None):
    """
    Generic RANSAC over one or more index-aligned data arrays.

    Each iteration draws ``samp_min`` distinct row indices, fits a model to
    those rows of every array, and counts the rows whose loss is below
    ``inlier_thres``. The model with the most inliers wins; ties are broken by
    the lower ``cost_func`` value on the inliers. The winning model is returned
    as fitted on its minimal sample (it is not refit on its inliers).

    Args:
        *data: arrays sharing the same first dimension; row i of every array
            belongs to observation i and all arrays are sampled together.
        fit_func (data -> model): fits a model to a list of sampled arrays
            (one entry per array in ``data``).
        loss_func ((model, data) -> array): vectorized loss for individual data points
        cost_func ((model, data) -> scalar): total cost to try to minimize;
            only evaluated for candidates that could become the new best.
        samp_min (int): rows sampled (without replacement) per iteration.
        inlier_min (int): minimum inlier count for a model to be accepted.
        inlier_thres (float): rows with loss strictly below this are inliers.
        max_iter (int): number of sampling iterations.
        seed: seed for ``np.random.default_rng`` (None -> nondeterministic).

    Returns:
        (best_model, best_inlier_idxs): the best model and the integer indices
        of its inliers.

    Raises:
        ValueError: on missing callbacks, no or mismatched data, ``samp_min``
            outside ``[1, len(data[0])]``, or when no model reaches
            ``inlier_min`` inliers.
    """
    if fit_func is None or loss_func is None or cost_func is None:
        raise ValueError("ransac requires fit_func, loss_func and cost_func")
    if not data:
        raise ValueError("ransac requires at least one data array")
    n_data = len(data[0])
    if any(len(singledata) != n_data for singledata in data):
        raise ValueError("all data arrays passed to ransac must have the same length")
    if not 1 <= samp_min <= n_data:
        raise ValueError(f"samp_min={samp_min} must be between 1 and the number of data points ({n_data})")

    rng = np.random.default_rng(seed)
    best_model = None
    best_inlier_idxs = []
    best_error = float("inf")

    for _ in range(max_iter):
        sample_indices = rng.choice(n_data, samp_min, replace=False)
        sample = [singledata[sample_indices] for singledata in data]

        model = fit_func(sample)

        errors = loss_func(model, data)
        inlier_idxs = np.where(errors < inlier_thres)[0]
        n_inliers = len(inlier_idxs)
        # Candidates that cannot win are skipped before the (possibly
        # expensive) cost evaluation, which also avoids costing empty sets.
        if n_inliers < inlier_min or n_inliers < len(best_inlier_idxs):
            continue

        inliers = [singledata[inlier_idxs] for singledata in data]
        total_error = cost_func(model, inliers)
        # more inliers always wins; equal count wins only with a lower cost
        if n_inliers > len(best_inlier_idxs) or total_error < best_error:
            best_model = model
            best_inlier_idxs = inlier_idxs
            best_error = total_error
    if best_model is None:
        raise ValueError("No valid model found after RANSAC")
    return best_model, best_inlier_idxs

In [None]:
# sanity test for the ransac function
# Small 1D dataset: the first ~32 points look like a noisy line y ~ x, the rest
# look like scattered outliers. Both arrays are reshaped to column vectors (n, 1).
X = np.array([-0.848,-0.800,-0.704,-0.632,-0.488,-0.472,-0.368,-0.336,-0.280,-0.200,-0.00800,-0.0840,0.0240,0.100,0.124,0.148,0.232,0.236,0.324,0.356,0.368,0.440,0.512,0.548,0.660,0.640,0.712,0.752,0.776,0.880,0.920,0.944,-0.108,-0.168,-0.720,-0.784,-0.224,-0.604,-0.740,-0.0440,0.388,-0.0200,0.752,0.416,-0.0800,-0.348,0.988,0.776,0.680,0.880,-0.816,-0.424,-0.932,0.272,-0.556,-0.568,-0.600,-0.716,-0.796,-0.880,-0.972,-0.916,0.816,0.892,0.956,0.980,0.988,0.992,0.00400]).reshape(-1,1)
y = np.array([-0.917,-0.833,-0.801,-0.665,-0.605,-0.545,-0.509,-0.433,-0.397,-0.281,-0.205,-0.169,-0.0531,-0.0651,0.0349,0.0829,0.0589,0.175,0.179,0.191,0.259,0.287,0.359,0.395,0.483,0.539,0.543,0.603,0.667,0.679,0.751,0.803,-0.265,-0.341,0.111,-0.113,0.547,0.791,0.551,0.347,0.975,0.943,-0.249,-0.769,-0.625,-0.861,-0.749,-0.945,-0.493,0.163,-0.469,0.0669,0.891,0.623,-0.609,-0.677,-0.721,-0.745,-0.885,-0.897,-0.969,-0.949,0.707,0.783,0.859,0.979,0.811,0.891,-0.137]).reshape(-1,1)

class LinearRegressor:
    """
    Ordinary least-squares linear fit with an intercept term.

    Attributes:
        params (np.ndarray | None): coefficients of shape (d + 1, k), intercept
            first; None until ``fit`` has been called.
    """

    def __init__(self):
        self.params = None

    @staticmethod
    def _with_intercept(X: np.ndarray) -> np.ndarray:
        """Prepend a column of ones so that params[0] acts as the intercept."""
        r, _ = X.shape  # requires a 2D (n, d) design matrix
        return np.hstack([np.ones((r, 1)), X])

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit by least squares and return self.

        Uses ``np.linalg.lstsq`` rather than explicitly inverting ``X^T X``:
        the normal-equation inverse squares the condition number and raises on
        rank-deficient inputs (e.g. a RANSAC sample whose X values coincide),
        whereas lstsq returns the minimum-norm solution.
        """
        self.params, *_ = np.linalg.lstsq(self._with_intercept(X), y, rcond=None)
        return self

    def predict(self, X: np.ndarray):
        """Predicted targets for an (n, d) X using the fitted params."""
        return self._with_intercept(X) @ self.params

def fitfunc(data):
    """ransac fit_func: least-squares line through the sampled (X, y) pair."""
    features, targets = data[0], data[1]
    return LinearRegressor().fit(features, targets)

def lossfunc(model, data):
    """ransac loss_func: absolute residual |y - model.predict(X)| per sample, flattened."""
    residuals = data[1] - model.predict(data[0])
    return np.abs(residuals).reshape(-1)

def costfunc(model, data):
    """ransac cost_func: Euclidean norm of the residual vector y - model.predict(X)."""
    residuals = data[1] - model.predict(data[0])
    squared_total = np.sum(residuals ** 2)
    return squared_total ** 0.5

# Robust line fit on the synthetic data; the inliers should follow the linear
# trend. No seed is passed, so the result varies between runs.
model, inlieridxs = ransac(X, y, fit_func=fitfunc, loss_func=lossfunc, cost_func=costfunc, samp_min=10, inlier_min=10, inlier_thres=0.2)
# Uncomment to plot the inliers and the fitted line.
# plt.scatter(X[inlieridxs], y[inlieridxs])
# line = np.linspace(np.min(X), np.max(X), num=100).reshape(-1, 1)
# plt.plot(line, model.predict(line), c="peru")

In [None]:
# Objective for the reconstruction-scale fit: total variance of the world-frame
# object centers as a function of the factor applied to the camera translations.
# Written with jax.numpy so it can be differentiated with jax.grad.
def variance_from_scale(scale, data):
    """
    Args:
        scale: scalar multiplier for the translation part of every cam2world.
        data: (cam2worlds nx4x4, obj2cams nx4x4).

    Returns:
        jax scalar: sum over x, y, z of the variance of the object centers
        (the trace of their covariance matrix).
    """
    cam2worlds = jnp.array(data[0])
    obj2cams = jnp.array(data[1])
    # elementwise mask that multiplies only rows 0-2 of column 3 by `scale`
    translation_mask = jnp.ones((4, 4)).at[0:3, 3].set(scale)
    scaled = cam2worlds * translation_mask
    origin = jnp.array([0, 0, 0, 1.0])
    centers = (scaled @ obj2cams @ origin)[:, :3]
    return jnp.var(centers, axis=0).sum()

class ScaleCentroidModel():
    """
    RANSAC model for the unknown scale of a reconstruction.

    Every detection is assumed to show the same static object. With the camera
    translations multiplied by a scale ``s``, detection i places the object
    center at ``c_i(s) = a_i + s * b_i`` in the world frame, where
    ``a_i = R_cam_i @ t_obj_i`` and ``b_i = t_cam_i``. ``fit`` chooses the ``s``
    that minimizes the total variance of those centers, since the centers should
    coincide. It also stores their mean.

    Attributes:
        scale (float | None): fitted factor for the camera translations. It can
            come out negative for inconsistent data; callers may want to check.
        mean (np.ndarray | None): mean object center (3,) at the fitted scale.
    """

    def __init__(self):
        self.scale = None
        self.mean = None

    def __call__(self, data):
        return self.predict(data)

    @staticmethod
    def _center_terms(data):
        """Split each world-frame object center into c_i(s) = a_i + s * b_i."""
        cam2worlds = np.asarray(data[0], dtype=float)
        obj2cams = np.asarray(data[1], dtype=float)
        # object origin in homogeneous camera coordinates, shape (n, 4)
        objorigin = obj2cams @ np.array([0, 0, 0, 1.0])
        a = np.einsum("nij,nj->ni", cam2worlds[:, :3, :3], objorigin[:, :3])
        b = cam2worlds[:, :3, 3] * objorigin[:, 3:4]
        return a, b

    def fit(self, data):
        """
        Fit the scale in closed form and return self.

        The total variance sum_k Var(a_k + s * b_k) is quadratic in s and is
        minimized at s = -sum_k Cov(a_k, b_k) / sum_k Var(b_k). The closed form
        avoids iterative descent, which with a fixed step has no iteration cap
        and can diverge when the camera translations are widely spread. If the
        camera positions have no spread (the variance does not depend on s),
        the scale defaults to 1.0.

        Args:
            data: (cam2worlds nx4x4, obj2cams nx4x4).
        """
        a, b = self._center_terms(data)
        a_c = a - a.mean(axis=0)
        b_c = b - b.mean(axis=0)
        # the common 1/n factor of Cov and Var cancels in the ratio
        denom = np.sum(b_c ** 2)
        if np.isclose(denom, 0.0):
            self.scale = 1.0
        else:
            self.scale = float(-np.sum(a_c * b_c) / denom)
        centroids = self.predict(data)
        self.mean = np.mean(centroids, axis=0)
        return self

    def predict(self, data):
        """World-frame object centers (n, 3) with camera translations scaled by self.scale."""
        cam2worlds = data[0]
        obj2cams = data[1]
        scaledcamTs = np.array(cam2worlds, dtype=float, copy=True)
        scaledcamTs[:, 0:3, 3] *= self.scale
        centershom = scaledcamTs @ obj2cams @ np.array([0, 0, 0, 1.0])
        return centershom[:, :3]

# data = (cam2world nx4x4, obj2cam nx4x4)
def fitcams(data):
    """ransac fit_func: fit a ScaleCentroidModel to sampled (cam2world, obj2cam) pairs."""
    return ScaleCentroidModel().fit(data)

def camloss(model, data):
    """
    ransac loss_func: Euclidean distance of each predicted object center from
    the model's mean center (model(data) presumably returns an (n, 3) array).
    """
    offsets = model(data) - model.mean
    return np.linalg.norm(offsets, axis=1)

def camcost(model, data):
    """ransac cost_func: total variance (summed over the axes) of the predicted centers."""
    centers = model(data)
    per_axis_var = jnp.var(centers, axis=0)
    return per_axis_var.sum()

# Joint RANSAC over (cam2world, obj2cam) pairs: find the reconstruction scale
# under which the most FoundPose estimates agree on one object center.
# Rebinds `model` / `inlieridxs` from the line-fit sanity test; no seed is
# passed, so results vary between runs. inlier_thres is a distance in the
# scaled world frame.
model, inlieridxs = ransac(cams.world_from_cam.matrix4x4, obj2cams, fit_func=fitcams, loss_func=camloss, cost_func=camcost, samp_min=5, inlier_min=5, inlier_thres=0.15, max_iter=50)

In [None]:
# Fitted reconstruction scale and the indices of the inlier estimates.
model.scale, inlieridxs

In [None]:
# Re-scale the reconstruction with the RANSAC scale estimate and keep only the
# inlier estimates.
scalefactor = model.scale
obj2worlds, cam2worldscaled, scenevtxsscaled = scale_reconstr(scalefactor, cams.world_from_cam.matrix4x4, obj2cams, scenevtxs)
tofig = []
# Rebind camscaled at the new scale. Later cells (overlay rendering, pose export)
# read camscaled; otherwise they would keep cameras scaled by the provisional
# 0.2 factor while the object poses live at model.scale.
camscaled = cams.replace(world_from_cam=v3d.Transform.from_matrix(cam2worldscaled))
camsinlier = camscaled[inlieridxs]
obj2worldsinlier = obj2worlds[inlieridxs]
barrelpts_trf = []
sceneptsscaled = scenepts.replace(p=scenevtxsscaled)
tofig.extend([sceneptsscaled, camsinlier])
# mesh vertices placed at each inlier object pose
for i, obj2world in enumerate(obj2worldsinlier):
    barrelpts_trf.append(v3d.Transform.from_matrix(obj2world) @ meshpts)
    tofig.append(barrelpts_trf[-1])
# v3d.make_fig(*tofig)

In [None]:
# Coordinate axes of every inlier object pose in the scaled world frame.
v3d.make_fig(*get_axes_traces(obj2worldsinlier))

In [None]:
def qangle(q1, q2):
    """
    Angle in radians between 2 quaternions.

    Computed from the relative quaternion ``q1 * conj(q2)`` as
    ``atan2(|vector part|, |scalar part|)``. Using ``|w|`` makes the result
    invariant to the q / -q sign ambiguity, so it lies in [0, pi/2].

    Note: this is the angle between the quaternions as unit 4-vectors, which is
    HALF of the relative rotation angle between the two frames. The geodesic
    rotation distance is ``2 * qangle(q1, q2)``. Thresholds compared against
    this value are in half-angle units.
    https://math.stackexchange.com/questions/3572459/how-to-compute-the-orientation-error-between-two-3d-coordinate-frames

    ``np.arctan2`` is used because the ``np.atan2`` alias only exists in
    NumPy >= 2.0.
    """
    qerr = q1 * q2.conjugate()
    vec_norm = np.sqrt(qerr.x ** 2 + qerr.y ** 2 + qerr.z ** 2)
    return np.arctan2(vec_norm, abs(qerr.w))

def closest_quat_sym(q1, q2, syms):
    """
    Use q1 as reference and return the symmetry-equivalent version of q2 closest to it.

    Each symmetry rotation ``S`` is applied on the right (``q2 * S``, i.e. in
    the object frame). The candidate with the smallest ``qangle`` to ``q1`` is
    returned.

    Args:
        q1: reference quaternion.
        q2: quaternion to re-express.
        syms (list of dict): symmetry transformations, each a dictionary with:
            - 'R': 3x3 ndarray with the rotation matrix (used).
            - 't': 3x1 ndarray with the translation vector (ignored here).

    Returns:
        The candidate ``q2 * S`` with the smallest angle to ``q1``.
    """
    candidates = [q2 * quaternion.from_rotation_matrix(sym["R"]) for sym in syms]
    angles = np.abs([qangle(q1, cand) for cand in candidates])
    return candidates[np.argmin(angles)]

In [None]:
# Load the object's symmetry description (BOP-style model_info.json) and
# discretize its symmetries. 0.01 is presumably get_symmetry_transformations'
# max_sym_disc_step (discretization step for continuous symmetries) — verify.
with open(Path("/scratch/jeyan/barreldata/models3d/model_info.json"), "rt") as f:
    objinfo = json.load(f)
symTs = get_symmetry_transformations(objinfo["barrelsingle-scaled.ply"], 0.01)
# symTs = get_symmetry_transformations({"symmetries_continuous":[{"axis":[0,0,1],"offset":[0,0,0]}]}, 0.01)
# Sanity check: R2 differs from R1 by a pi rotation through the third Euler
# angle. If that is a symmetry of the object, `best` should come back as
# (approximately) the same rotation as R1.
R1 = t3d.euler.euler2mat(0, 0, 0)
R2 = t3d.euler.euler2mat(0, 0, np.pi)
best = closest_quat_sym(quaternion.from_rotation_matrix(R1), quaternion.from_rotation_matrix(R2), symTs)
best, quaternion.as_rotation_matrix(best)

In [None]:
# Resolve symmetry ambiguity before averaging: re-express every inlier
# orientation as the symmetry-equivalent rotation closest to the first one.
quatsinlier = quaternion.from_rotation_matrix(obj2worldsinlier[..., :3, :3])
ref = quatsinlier[0]
quatssymd = [ref]
for otherquat in quatsinlier[1:]:
    best = closest_quat_sym(ref, otherquat, symTs)
    quatssymd.append(best)
quatssymd = np.array(quatssymd)
# same translations as the inlier poses, with the symmetry-aligned rotations
obj2worldsinliersym = np.copy(obj2worldsinlier)
obj2worldsinliersym[..., :3, :3] = quaternion.as_rotation_matrix(quatssymd)
v3d.make_fig(*get_axes_traces(obj2worldsinliersym))

In [None]:
# data = (n quaternions)
def qmean(qs, weights=None):
    """
    Weighted average of rotations given as quaternions.

    Uses the eigenvector method: the average is the eigenvector of
    ``M = sum_i w_i q_i q_i^T`` with the largest eigenvalue. Because ``q`` and
    ``-q`` give the same outer product, the result does not depend on the signs
    of the inputs. The sign of the returned quaternion is arbitrary (both signs
    are the same rotation).
    https://stackoverflow.com/questions/12374087/average-of-multiple-quaternions

    Args:
        qs: array-like of quaternions. Singleton dimensions are squeezed, so the
            ``[array]`` list that ``ransac`` passes to ``fit_func`` works.
        weights: optional non-negative per-quaternion weights (default:
            uniform). Each weight multiplies its outer product once. Scaling the
            quaternions themselves by w_i would effectively weight by w_i**2.

    Returns:
        Unit quaternion.

    Raises:
        ValueError: if ``qs`` is empty or ``weights`` has the wrong length.
    """
    qs = np.atleast_1d(np.squeeze(qs))
    if qs.size == 0:
        raise ValueError("qmean needs at least one quaternion")
    # default weights are sized after squeezing so their length matches qs
    if weights is None:
        weights = np.ones(len(qs))
    weights = np.asarray(weights, dtype=float).reshape(-1)
    if len(weights) != len(qs):
        raise ValueError(f"got {len(weights)} weights for {len(qs)} quaternions")
    Q = quaternion.as_float_array(qs)  # (n, 4)
    M = (Q * weights[:, None]).T @ Q  # 4x4, symmetric positive semi-definite
    # eigh: symmetric solver with real output and ascending eigenvalues,
    # so the last column is the dominant eigenvector
    _, vecs = np.linalg.eigh(M)
    avg = vecs[:, -1]
    avg = avg / np.linalg.norm(avg)
    return quaternion.from_float_array(avg)

def qloss(model, qs):
    """ransac loss_func: qangle between the model quaternion and each element of qs."""
    flat = np.squeeze(qs)
    angles = [qangle(model, q) for q in flat]
    return np.array(angles)

def qcost(model, qs):
    """ransac cost_func: total qangle of qs relative to the model quaternion."""
    per_quat = qloss(model, np.squeeze(qs))
    return np.sum(per_quat)

# RANSAC on the symmetry-aligned rotations: consensus (mean) orientation and the
# indices of rotations within inlier_thres (in qangle units) of it.
qmeanransac, qinliers = ransac(quatssymd, fit_func=qmean, loss_func=qloss, cost_func=qcost, samp_min=5, inlier_min=5, inlier_thres=0.1, max_iter=50)
qmeanransac, qinliers

In [None]:
# Final object pose: RANSAC mean rotation plus the mean translation of the
# symmetry-aligned inlier poses.
# NOTE(review): the translation mean uses all of obj2worldsinliersym, not only
# the rotation inliers in qinliers — confirm this is intended.
meanT = v3d.Transform(R=quaternion.as_rotation_matrix(qmeanransac), t=np.mean(v3d.Transform.from_matrix(obj2worldsinliersym).t, axis=0))
v3d.make_fig(*get_axes_traces(obj2worldsinliersym, scale=0.5), *get_axes_traces(meanT, linewidth=10))

In [None]:
# Cameras, fitted object mesh and scene point cloud together.
# NOTE(review): camscaled should be at the same scale as meanT — verify.
v3d.make_fig(camscaled, meanT @ meshpts, sceneptsscaled)

In [None]:
# Render the fitted mesh over every input image and save the overlays.
# NOTE(review): assumes imgpaths (sorted by filename) line up index-for-index
# with the cameras in camscaled. If COLMAP did not register every image, or
# orders them differently, the overlays will be mismatched — verify.
imgdir = Path("/scratch/jeyan/barreldata/divedata/dive8/barrelddt1/rgb")
imgpaths = list(sorted(imgdir.glob("*.png")))
imgs = [np.array(Image.open(imgpath)) for imgpath in imgpaths]
resdir = Path("/scratch/jeyan/barreldata/results/barrelddt1")
overlaydir = resdir / "fit-overlays"
overlaydir.mkdir(exist_ok=True)  # resdir itself must already exist
for i, img in enumerate(imgs):
    imgpath = imgpaths[i]  # not used below
    Image.fromarray(render_v3d(camscaled[i], meanT @ meshpts, radius=4, background=img)).save(overlaydir / f"{imgpaths[i].stem}.png")

In [None]:
# Express the fitted world pose in every camera frame (object-to-camera) and
# export it with the same keys this notebook reads from FoundPose output
# (hypothesis_id, R as 3x3, t as a 3x1 nested list).
# NOTE(review): camscaled must be at the same scale as meanT — verify.
obj2camfit = camscaled.world_from_cam.inv @ meanT[..., None]
estposes = []
for i, obj2cam in enumerate(obj2camfit):
    posedata = {
        "img_path": str(imgpaths[i]),
        "img_id": str(i),
        "hypothesis_id": "0",
        "R": obj2cam.R.tolist(),
        "t": obj2cam.t[..., None].tolist(),
    }
    estposes.append(posedata)
with open(resdir / "estimated-poses.json", "wt") as f:
    json.dump(estposes, f)