In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import datetime

import json
import os
from pathlib import Path
from PIL import Image
import gc
import time
import struct

from typing import List, NamedTuple, Optional, Tuple

import cv2
import imageio

import numpy as np

import torch

import sys
sys.path.append(os.path.abspath(os.path.join("external", "bop_toolkit")))
from bop_toolkit_lib import inout
sys.path.append(os.path.abspath(os.path.join("external", "dinov2")))

from utils.misc import array_to_tensor, tensor_to_array, tensors_to_arrays


from utils import (
    corresp_util,
    config_util,
    eval_errors,
    eval_util,
    feature_util,
    infer_pose_util,
    knn_util,
    misc as misc_util,
    pnp_util,
    projector_util,
    repre_util,
    vis_util,
    data_util,
    renderer_builder,
    json_util, 
    logging,
    misc,
    structs,
)

from utils.structs import AlignedBox2f, PinholePlaneCameraModel, CameraModel
from utils.misc import warp_depth_image, warp_image

In [None]:
# --- Experiment configuration -------------------------------------------------
# Spec string handed to feature_util.make_feature_extractor below.
extractor_name = "dinov2_version=vitl14-reg_stride=14_facet=token_layer=18_logbin=0_norm=1"
# NOTE(review): the paths below are absolute and machine/user specific; the
# notebook only runs where this /scratch layout exists.
output_path = Path("/scratch/jeyan/foundpose/output_barrelddt1_raw_vitl_layer18")
dataset_path = Path("/scratch/jeyan/barreldata/divedata/dive8/barrelddt1/rgb")
mask_path = Path("/scratch/jeyan/barreldata/results/barrelddt1/masks")
cam_json_path = Path("/scratch/jeyan/barreldata/divedata/dive8/barrelddt1/camera.json")
# Viewport size of the crop camera; also reused as the sampling-grid size.
crop_size = (420, 420)
# Cell size of the sampling grid (presumably equal to the extractor stride of 14 — confirm).
grid_cell_size = 14.0
# Passed as viewport_rel_pad when the crop camera is constructed.
crop_rel_pad = 0.2

# Images and masks are paired by list index in later cells (imgpaths[i] with
# maskpaths[i]), so the two sorted listings must line up; this is not checked here.
imgpaths = sorted(list(dataset_path.glob("*.jpg")) + list(dataset_path.glob("*.png")))
maskpaths = sorted(list(mask_path.glob("*.png")))

# Prepare feature extractor.
extractor = feature_util.make_feature_extractor(extractor_name)
# Prepare a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor.to(device)

# Create a pose evaluator.
pose_evaluator = eval_util.EvaluatorPose([0])

# Load the object representation from "<output_path>/object_repre" onto `device`.
base_repre_dir = Path(output_path, "object_repre")
repre_dir = base_repre_dir
repre: repre_util.FeatureBasedObjectRepre = repre_util.load_object_repre(
    repre_dir=repre_dir,
    tensor_device=device,
)

# NumPy copy of the representation, used by the NumPy/SciPy cells further down.
repre_np = repre_util.convert_object_repre_to_numpy(repre)

In [None]:
# Sanity check: display the shape of the stored template feature vectors.
repre_np.feat_vectors.shape

In [None]:
# Boolean mask selecting the features that belong to template 0.
feattempmask = repre_np.feat_to_template_ids == 0
# all p_i values
# (this expression is evaluated but NOT displayed: only the last expression of a cell is shown)
repre_np.feat_vectors[feattempmask].shape
# all x_i values
repre_np.vertices[repre_np.feat_to_vertex_ids[feattempmask]].shape

In [None]:
# Same expression as the last line of the previous cell (shape of the x_i vertices).
repre_np.vertices[repre_np.feat_to_vertex_ids[feattempmask]].shape

In [None]:
# Index of the frame to process; later cells also use `i` to pick maskpaths[i].
i = 0
imgpath = imgpaths[i]

In [None]:
# Camera parameters.
# Only intrinsics (width, height, focal lengths, principal point) are read from the
# JSON; no extrinsic transform is passed to the camera model.
# Open question from the author: the transform normally comes from GT — is leaving
# it at the model's default acceptable? TODO confirm the default is identity.
with open(Path(cam_json_path), "r") as f:
    camjson = json.load(f)
orig_camera_c2w = PinholePlaneCameraModel(
    camjson["width"], camjson["height"],
    (camjson["fx"], camjson["fy"]), (camjson["cx"], camjson["cy"])
)
orig_image_size = (
    orig_camera_c2w.width,
    orig_camera_c2w.height,
)

# Generate grid points at which to sample the feature vectors.
# The grid spans the crop viewport (crop_size) with cells of grid_cell_size.
grid_size = crop_size
grid_points = feature_util.generate_grid_points(
    grid_size=grid_size,
    cell_size=grid_cell_size,
)
grid_points = grid_points.to(device)


# Estimate pose for each object instance.
times = {}

# Get the input image, scaled from 8-bit values to [0, 1].
orig_image_np_hwc = np.array(Image.open(imgpath)) / 255.0

# Get the modal mask of the instance as a float image in [0, 1]
# (converted to single-channel "L" first).
orig_mask_modal = np.array(Image.open(maskpaths[i]).convert("L")) / 255.0

# Column-wise / row-wise sums tell us which columns / rows contain mask pixels.
sumvert = np.sum(orig_mask_modal, axis=0)
sumhor = np.sum(orig_mask_modal, axis=1)
mask_cols = np.where(sumvert > 0)[0]
mask_rows = np.where(sumhor > 0)[0]
if mask_cols.size == 0 or mask_rows.size == 0:
    raise ValueError(f"Mask {maskpaths[i]} is empty; cannot compute a bounding box.")

left = mask_cols[0]
right = mask_cols[-1]
# Image rows grow downwards, so the FIRST non-empty row is the top edge and the
# LAST one is the bottom edge. This matches how `box_amodal` is built from
# calc_2d_box later in this cell (top = ymin, bottom = ymax). The previous code
# had the two swapped, which yields a box with negative height.
top = mask_rows[0]
bottom = mask_rows[-1]

# Bounding box of the mask (used in place of a true amodal box).
orig_box_amodal = AlignedBox2f(
    left=left,
    top=top,
    right=right,
    bottom=bottom,
)

# Get box for cropping.
crop_box = misc_util.calc_crop_box(
    box=orig_box_amodal,
    make_square=True,
)

# Construct a virtual camera focused on the crop.
# The viewport is crop_size, with crop_rel_pad passed as relative padding.
crop_camera_model_c2w = misc_util.construct_crop_camera(
    box=crop_box,
    camera_model_c2w=orig_camera_c2w,
    viewport_size=crop_size,
    viewport_rel_pad=crop_rel_pad,
)

# Map images to the virtual camera.
# Use area interpolation when the crop is at least as wide as the viewport
# (i.e. we are shrinking), otherwise bilinear interpolation.
interpolation = (
    cv2.INTER_AREA
    if crop_box.width >= crop_camera_model_c2w.width
    else cv2.INTER_LINEAR
)
image_np_hwc = warp_image(
    src_camera=orig_camera_c2w,
    dst_camera=crop_camera_model_c2w,
    src_image=orig_image_np_hwc,
    interpolation=interpolation,
)
# The mask is warped with nearest-neighbour interpolation so no new
# intermediate values are blended in.
mask_modal = warp_image(
    src_camera=orig_camera_c2w,
    dst_camera=crop_camera_model_c2w,
    src_image=orig_mask_modal,
    interpolation=cv2.INTER_NEAREST,
)

# Recalculate the object bounding box (it changed if we constructed the virtual camera).
# nonzero() returns (row, col) indices, hence the (ys, xs) unpacking order.
ys, xs = mask_modal.nonzero()
box = np.array(misc_util.calc_2d_box(xs, ys))
# NOTE(review): box is indexed as (left, top, right, bottom); presumably
# calc_2d_box returns (xmin, ymin, xmax, ymax) — confirm.
box_amodal = AlignedBox2f(
    left=box[0],
    top=box[1],
    right=box[2],
    bottom=box[3],
)

# The virtual camera is becoming the main camera.
camera_c2w = crop_camera_model_c2w

# Extract feature map from the crop.
# HWC float image -> CHW float32 tensor on the compute device, plus a batch axis.
image_tensor_chw = array_to_tensor(image_np_hwc).to(torch.float32).permute(2, 0, 1).to(device)
image_tensor_bchw = image_tensor_chw.unsqueeze(0)
# BxDxHxW
# Inference only: disable autograd so no graph is kept alive for the forward pass
# and the resulting feature maps can be converted to NumPy without detaching.
with torch.no_grad():
    extractor_output = extractor(image_tensor_bchw)
feature_map_chw = extractor_output["feature_maps"][0]

# Keep only points inside the object mask.
mask_modal_tensor = array_to_tensor(mask_modal).to(device)
query_points = feature_util.filter_points_by_mask(
    grid_points, mask_modal_tensor
)

# Extract features at the selected points, of shape (num_points, feat_dims).
# image_size is given as (width, height) of the cropped image.
query_features = feature_util.sample_feature_map_at_points(
    feature_map_chw=feature_map_chw,
    points=query_points,
    image_size=(image_np_hwc.shape[1], image_np_hwc.shape[0]),
).contiguous()

# Potentially project features to a PCA space.
# Projection is applied only when the query feature dimension differs from the
# representation's feature dimension AND the representation carries projectors.
if (
    query_features.shape[1] != repre.feat_vectors.shape[1]
    and len(repre.feat_raw_projectors) != 0
):
    query_features_proj = projector_util.project_features(
        feat_vectors=query_features,
        projectors=repre.feat_raw_projectors,
    ).contiguous()

    # Project the dense feature map as well: flatten CHW -> (H*W, C), project,
    # then restore the (C', H, W) layout.
    # NOTE(review): .view() on the permuted tensor depends on feature_map_chw's
    # memory layout; .reshape() would be layout-agnostic — confirm before changing.
    _c, _h, _w = feature_map_chw.shape
    feature_map_chw_proj = (
        projector_util.project_features(
            feat_vectors=feature_map_chw.permute(1, 2, 0).view(-1, _c),
            projectors=repre.feat_raw_projectors,
        )
        .view(_h, _w, -1)
        .permute(2, 0, 1)
    )
else:
    # No projection needed: the "projected" names alias the raw tensors.
    query_features_proj = query_features
    feature_map_chw_proj = feature_map_chw

In [None]:
# Display the (channels, height, width) shape of the raw feature map.
feature_map_chw.shape

In [None]:
# Display the shapes of the (possibly projected) query features and feature map.
query_features_proj.shape, feature_map_chw_proj.shape

In [None]:
# Query feature map as an (H, W, C) NumPy array on the CPU.
# The same assignment is repeated in a later cell next to the template selection.
Fq = feature_map_chw_proj.permute(1, 2, 0).cpu().numpy()

In [None]:
def bilinear_interpolate(im, x, y):
    """Bilinearly sample `im` at (possibly fractional) coordinates.

    Sample positions outside the image are clamped to the nearest border
    pixel, so the function never extrapolates.

    Args:
        im: array of shape (H, W) or (H, W, ...); the first axis is indexed
            by `y` (row) and the second by `x` (column).
        x: scalar or array of column coordinates.
        y: scalar or array of row coordinates, broadcast-compatible with `x`.

    Returns:
        The interpolated value(s). For scalar coordinates the result has the
        shape of one pixel (`im.shape[2:]`); for array coordinates it has
        shape `x.shape + im.shape[2:]`.
    """
    im = np.asarray(im)
    x = np.asarray(x)
    y = np.asarray(y)

    # Integer corners surrounding each sample; kept UNCLIPPED for the weights.
    x0 = np.floor(x).astype(np.int64)
    x1 = x0 + 1
    y0 = np.floor(y).astype(np.int64)
    y1 = y0 + 1

    # The weights must come from the unclipped corners. Deriving them from the
    # clipped indices (as before) makes x1 == x0 on the last column/row, which
    # zeroes every weight and returns 0 instead of the border value.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # Clamp only the indices used for the lookups (border replication).
    max_col = im.shape[1] - 1
    max_row = im.shape[0] - 1
    x0 = np.clip(x0, 0, max_col)
    x1 = np.clip(x1, 0, max_col)
    y0 = np.clip(y0, 0, max_row)
    y1 = np.clip(y1, 0, max_row)

    Ia = im[y0, x0]
    Ib = im[y1, x0]
    Ic = im[y0, x1]
    Id = im[y1, x1]

    # For multi-channel images give the weights trailing singleton axes so that
    # array-valued x/y broadcast against the channel dimension(s).
    if im.ndim > 2:
        pad = (1,) * (im.ndim - 2)
        wa, wb, wc, wd = (np.reshape(w, np.shape(w) + pad) for w in (wa, wb, wc, wd))

    return (wa * Ia + wb * Ib + wc * Ic + wd * Id)

In [None]:
# Hard-coded initial pose: 3x3 rotation matrix R and translation vector t, applied
# below as R @ x + t to map model vertices into the camera frame.
# NOTE(review): presumably copied from a coarse pose estimate for this frame —
# record its provenance (frame / template / run) here.
R = np.array([
    [0.9824963517866108, -0.18626298081482784, -0.0026496611057800646],
    [0.07184429208626533, 0.36576270685794776, 0.9279310534552504],
    [-0.1718700567889119, -0.9118792377557032, 0.37274245710604814]
])
t = np.array([0.027086201383687152, 0.21964677626568319, 5.018856479067855])

In [None]:
# Template whose features are matched against the query image.
# NOTE(review): 245 is a magic number — presumably the template selected for the
# hard-coded pose above; confirm and document.
template_id = 245
feattempmask = repre_np.feat_to_template_ids == template_id
# all p_i values
allpi = repre_np.feat_vectors[feattempmask]
# all x_i values
allxi = repre_np.vertices[repre_np.feat_to_vertex_ids[feattempmask]]
# Query feature map as an (H, W, C) NumPy array (same assignment as in the earlier cell).
Fq = feature_map_chw_proj.permute(1, 2, 0).cpu().numpy()

In [None]:
def robustcost(x, a=-5, c=0.5):
    """Robust cost of a residual vector, summed over its first axis.

    Evaluates Barron's general robust loss element-wise,

        rho(x) = |a - 2| / a * (((x / c)^2 / |a - 2| + 1)^(a / 2) - 1),

    and sums the result along axis 0. The closed form is singular at
    ``a == 0`` and ``a == 2``; the finite limits are used there instead
    (``log(0.5 * (x / c)^2 + 1)`` and ``0.5 * (x / c)^2`` respectively).

    Args:
        x: array-like of residuals (at least 1-D).
        a: shape parameter; more negative values down-weight outliers more.
        c: scale parameter (> 0) setting the width of the quadratic bowl.

    Returns:
        Sum of the per-element losses along the first axis (a scalar for
        1-D input).
    """
    scaled_sq = (np.asarray(x) / c) ** 2
    if a == 2:
        rho = 0.5 * scaled_sq
    elif a == 0:
        rho = np.log1p(0.5 * scaled_sq)
    else:
        b = abs(a - 2)
        rho = (b / a) * ((scaled_sq / b + 1) ** (a / 2) - 1)
    # axis=0 keeps the reduction semantics of the builtin sum() used previously,
    # without iterating over the array element by element in Python.
    return np.sum(rho, axis=0)

In [None]:
# Spot check: robust cost for the first template patch under the hard-coded pose.
pi = allpi[0]
xi = allxi[0]
# Transform the vertex with R, t, map it through the crop camera, and divide by 14.
# NOTE(review): 14 is a magic number here (same value as grid_cell_size and the
# extractor stride); eye_to_window presumably returns pixel coordinates — confirm.
x, y = camera_c2w.eye_to_window(R @ xi + t) / 14
robustcost(pi - bilinear_interpolate(Fq, x, y))

In [None]:
import math
def axangle2mat(axis, angle=None):
    """
    Rotation matrix for rotation angle `angle` around `axis`.
    Ripped from transforms3d with some tweaks.

    From: http://en.wikipedia.org/wiki/Rotation_matrix#Axis_and_angle

    Args:
        axis (3 element sequence): vector specifying axis for rotation. It is
            normalized internally, so it does not need to be a unit vector.
        angle (scalar): angle of rotation in radians. Default is the norm of `axis`
            (i.e. `axis` is interpreted as a rotation vector).

    Returns
        mat (array shape (3,3)): rotation matrix for specified rotation

    Notes:
        A zero rotation vector (`axis` of zero length with `angle=None`) is the
        identity rotation and returns `np.eye(3)`. A zero-length axis combined
        with an explicit angle has no defined direction and is still an error.
    """
    x, y, z = axis
    n = math.sqrt(x*x + y*y + z*z)
    if n == 0.0 and angle is None:
        # Rotation vector of length 0: no rotation. Without this guard the
        # normalization below divides by zero for an identity pose.
        return np.eye(3)
    x = x / n
    y = y / n
    z = z / n
    if angle is None:
        angle = n
    c = math.cos(angle)
    s = math.sin(angle)
    C = 1 - c
    xs = x*s;   ys = y*s;   zs = z*s
    xC = x*C;   yC = y*C;   zC = z*C
    xyC = x*yC; yzC = y*zC; zxC = z*xC
    return np.array([
        [x*xC+c,   xyC-zs,   zxC+ys],
        [xyC+zs,   y*yC+c,   yzC-xs],
        [zxC-ys,   yzC+xs,   z*zC+c],
    ])

def mat2axangle(mat, unit_thresh=1e-5):
    """Recover the rotation axis and angle from a (3, 3) rotation matrix.

    Parameters
    ----------
    mat : array-like shape (3, 3)
        Rotation matrix
    unit_thresh : float, optional
        Largest allowed deviation from 1 for an eigenvalue to be accepted as
        the unit eigenvalue that identifies the rotation axis.

    Returns
    -------
    axis : array shape (3,)
       vector giving axis of rotation
    angle : scalar
       angle of rotation in radians.

    Raises
    ------
    ValueError
        If no eigenvalue of `mat` lies within `unit_thresh` of 1.

    Examples
    --------
    >>> direc = np.random.random(3) - 0.5
    >>> angle = (np.random.random() - 0.5) * (2*math.pi)
    >>> R0 = axangle2mat(direc, angle)
    >>> direc, angle = mat2axangle(R0)
    >>> R1 = axangle2mat(direc, angle)
    >>> np.allclose(R0, R1)
    True

    Notes
    -----
    http://en.wikipedia.org/wiki/Rotation_matrix#Axis_of_a_rotation
    """
    rot = np.asarray(mat, dtype=np.float64)

    # The rotation axis is the eigenvector belonging to the eigenvalue 1.
    eigvals, eigvecs = np.linalg.eig(rot.T)
    unit_idx = np.where(np.abs(eigvals - 1.0) < unit_thresh)[0]
    if len(unit_idx) == 0:
        raise ValueError("no unit eigenvector corresponding to eigenvalue 1")
    direction = np.real(eigvecs[:, unit_idx[-1]]).squeeze()
    dx, dy, dz = direction

    # cos(angle) from the trace; sin(angle) from whichever off-diagonal entry
    # has a usable (non-vanishing) axis component as divisor.
    cosa = (np.trace(rot) - 1.0) / 2.0
    cos_minus_one = cosa - 1.0
    if abs(dz) > 1e-8:
        numer, denom = rot[1, 0] + cos_minus_one * dx * dy, dz
    elif abs(dy) > 1e-8:
        numer, denom = rot[0, 2] + cos_minus_one * dx * dz, dy
    else:
        numer, denom = rot[2, 1] + cos_minus_one * dy * dz, dx
    sina = numer / denom

    return direction, math.atan2(sina, cosa)

def featuremetric_loss(tfm, camera, patchdescs, patchvtxs, featmap, patchsize, a=-5, c=0.5):
    """Per-patch robust featuremetric costs for a 6-DoF pose.

    For every template patch, its 3D point is transformed by the pose,
    mapped through `camera.eye_to_window`, scaled by `1 / patchsize`, and the
    query feature map is bilinearly sampled at that location. The robust cost
    of the descriptor difference is returned per patch.

    Args:
        tfm: 6-vector `[rx, ry, rz, tx, ty, tz]`; the first three entries are a
            rotation vector (axis scaled by angle in radians, see
            `axangle2mat`), the last three the translation.
        camera: camera model exposing `eye_to_window(point)`.
            NOTE(review): presumably returns 2D pixel coordinates of a
            camera-frame 3D point — confirm.
        patchdescs: sequence of template patch descriptors (the p_i).
        patchvtxs: sequence of the corresponding 3D points (the x_i), indexed
            in parallel with `patchdescs`.
        featmap: query feature map indexed as `[row, col, channel]`.
        patchsize: divisor converting projected coordinates to feature-map
            coordinates.
            NOTE(review): no half-cell offset is applied, so feature-map index k
            is treated as located at pixel k * patchsize — verify this matches
            the convention used when the template features were sampled.
        a: shape parameter forwarded to `robustcost`.
        c: scale parameter forwarded to `robustcost`.

    Returns:
        1-D float array with one robust cost per patch.
    """
    R = axangle2mat(tfm[:3])
    t = tfm[3:]
    losses = np.zeros(len(patchdescs))
    for i, pi in enumerate(patchdescs):
        xi = patchvtxs[i]
        x, y = camera.eye_to_window(R @ xi + t) / patchsize
        losses[i] = robustcost(pi - bilinear_interpolate(featmap, x, y), a=a, c=c)
    return losses

def featuremetric_error(tfm, camera, patchdescs, patchvtxs, featmap, patchsize, a=-5, c=0.5):
    """Total robust featuremetric cost of a 6-DoF pose.

    Scalar counterpart of `featuremetric_loss`: the per-patch costs are
    computed there (single source of truth for the projection and sampling
    logic) and summed here.

    Args:
        tfm: 6-vector `[rx, ry, rz, tx, ty, tz]` (rotation vector followed by
            translation).
        camera: camera model exposing `eye_to_window(point)`.
        patchdescs: sequence of template patch descriptors (the p_i).
        patchvtxs: sequence of the corresponding 3D points (the x_i).
        featmap: query feature map indexed as `[row, col, channel]`.
        patchsize: divisor converting projected coordinates to feature-map
            coordinates.
        a: shape parameter forwarded to `robustcost`.
        c: scale parameter forwarded to `robustcost`.

    Returns:
        Sum of the per-patch robust costs (0.0 when there are no patches).
    """
    losses = featuremetric_loss(
        tfm, camera, patchdescs, patchvtxs, featmap, patchsize, a=a, c=c
    )
    return losses.sum()

In [None]:
# Pack the hard-coded pose into the 6-vector used by the featuremetric functions:
# rotation vector (axis scaled by its angle) followed by the translation.
axang, theta = mat2axangle(R)
axang = axang * theta
tfm = np.concatenate([axang, t])

In [None]:
from scipy.optimize import least_squares, minimize
# Total robust cost of the initial pose over all patches of the selected template.
# The trailing 14 is the patchsize argument.
featuremetric_error(tfm, camera_c2w, allpi, allxi, Fq, 14)

In [None]:
# Refine the pose with Levenberg-Marquardt, starting from tfm.
# The result object is only displayed, not stored; tfm itself is not updated.
# NOTE(review): featuremetric_loss already returns robustified per-patch costs, and
# least_squares minimizes the sum of their squares — confirm that is the intended objective.
least_squares(featuremetric_loss, tfm, args=(camera_c2w, allpi, allxi, Fq, 14), method="lm")

In [None]:
# Displays the INITIAL pose vector; the optimizer's solution was not assigned to tfm.
tfm