yeah

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import datetime

import json
import math
import os
from pathlib import Path
from PIL import Image
import gc
import time
import struct

from typing import List, NamedTuple, Optional, Tuple, Any

import cv2
import imageio
from tqdm import tqdm

import numpy as np

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils._pytree import tree_map
from torchmetrics import Metric
import roma
import torch_levenberg_marquardt as tlm
import matplotlib.pyplot as plt
import visu3d as v3d

import sys
sys.path.append(os.path.abspath(os.path.join("external", "bop_toolkit")))
from bop_toolkit_lib import inout
sys.path.append(os.path.abspath(os.path.join("external", "dinov2")))

from utils.misc import array_to_tensor, tensor_to_array, tensors_to_arrays


from utils import (
    corresp_util,
    config_util,
    eval_errors,
    eval_util,
    feature_util,
    infer_pose_util,
    knn_util,
    misc as misc_util,
    pnp_util,
    projector_util,
    repre_util,
    vis_util,
    data_util,
    renderer_builder,
    json_util, 
    logging,
    misc,
    structs,
    geometry
)

from utils.structs import AlignedBox2f, PinholePlaneCameraModel, CameraModel
from utils.misc import warp_depth_image, warp_image

In [None]:
# --- Run configuration -------------------------------------------------------
# Name passed to feature_util.make_feature_extractor below.
extractor_name = "dinov2_version=vitl14-reg_stride=14_facet=token_layer=18_logbin=0_norm=1"
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
output_path = Path("/scratch/jeyan/foundpose/output_barrelddt1_raw_vitl_layer18")
dataset_path = Path("/scratch/jeyan/barreldata/divedata/dive8/barrelddt1/rgb")
mask_path = Path("/scratch/jeyan/barreldata/results/barrelddt1/masks")
cam_json_path = Path("/scratch/jeyan/barreldata/divedata/dive8/barrelddt1/camera.json")
# Viewport size of the virtual crop camera; also used as the grid size for the sampling points.
crop_size = (420, 420)
# Cell size of the grid of points at which query features are sampled.
grid_cell_size = 14.0
# Relative padding passed to the crop-camera construction.
crop_rel_pad = 0.2

# Images and masks are each sorted by path and later paired purely by list index
# (imgpaths[i] with maskpaths[i]).
# NOTE(review): nothing here checks that both lists have the same length or matching names.
imgpaths = sorted(list(dataset_path.glob("*.jpg")) + list(dataset_path.glob("*.png")))
maskpaths = sorted(list(mask_path.glob("*.png")))

# Prepare feature extractor.
extractor = feature_util.make_feature_extractor(extractor_name)
# Prepare a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor.to(device)

# Create a pose evaluator.
pose_evaluator = eval_util.EvaluatorPose([0])

# Load the object representation.
# logger.info(
#     f"Loading representation for object {0} from dataset {opts.object_dataset}..."
# )
# The representation is read from "<output_path>/object_repre" with its tensors on `device`.
base_repre_dir = Path(output_path, "object_repre")
repre_dir = base_repre_dir
repre: repre_util.FeatureBasedObjectRepre = repre_util.load_object_repre(
    repre_dir=repre_dir,
    tensor_device=device,
)

# NumPy copy of the representation, used by the inspection cells below.
repre_np = repre_util.convert_object_repre_to_numpy(repre)

In [None]:
# Inspect the shape of the feature-vector array stored in the object representation.
repre_np.feat_vectors.shape

In [None]:
# Boolean mask over features selecting the ones assigned to template 0.
feattempmask = repre_np.feat_to_template_ids == 0
# all p_i values
# (bare expression that is not the last line of the cell, so Jupyter does not display it)
repre_np.feat_vectors[feattempmask].shape
# all x_i values
# Vertices are looked up through feat_to_vertex_ids, i.e. feature index -> vertex index.
repre_np.vertices[repre_np.feat_to_vertex_ids[feattempmask]].shape

In [None]:
# Shape of the vertices associated with the features of the template selected by `feattempmask`.
repre_np.vertices[repre_np.feat_to_vertex_ids[feattempmask]].shape

In [None]:
# Index of the frame to process. The same index is used later to pick the mask from `maskpaths`.
i = 0
imgpath = imgpaths[i]

In [None]:
# Build the query side of the refinement problem for frame `i`:
# camera -> object crop -> feature map -> (optionally projected) descriptors.

# Camera parameters.
# transform is from GT, can we just leave as identity?
with open(Path(cam_json_path), "r") as f:
    camjson = json.load(f)
orig_camera_c2w = PinholePlaneCameraModel(
    camjson["width"], camjson["height"],
    (camjson["fx"], camjson["fy"]), (camjson["cx"], camjson["cy"])
)
orig_image_size = (
    orig_camera_c2w.width,
    orig_camera_c2w.height,
)

# Generate grid points at which to sample the feature vectors.
grid_size = crop_size
grid_points = feature_util.generate_grid_points(
    grid_size=grid_size,
    cell_size=grid_cell_size,
)
grid_points = grid_points.to(device)


# Estimate pose for each object instance.
times = {}

# Get the input image.
# Force three channels so grayscale / RGBA files do not break the HWC -> CHW permute below.
orig_image_np_hwc = np.array(Image.open(imgpath).convert("RGB")) / 255.0

# Get the modal mask and amodal bounding box of the instance.
# binary mask
orig_mask_modal = np.array(Image.open(maskpaths[i]).convert("L")) / 255.0
# Column / row sums are positive exactly where the mask has foreground pixels.
sumvert = np.sum(orig_mask_modal, axis=0)
sumhor = np.sum(orig_mask_modal, axis=1)
mask_cols = np.where(sumvert > 0)[0]
mask_rows = np.where(sumhor > 0)[0]
if mask_cols.size == 0 or mask_rows.size == 0:
    raise ValueError(f"Mask {maskpaths[i]} contains no foreground pixels.")
left = mask_cols[0]
right = mask_cols[-1]
# Image rows grow downwards, so the first foreground row is the TOP edge and the
# last one the BOTTOM edge (same convention as the calc_2d_box-based box below).
top = mask_rows[0]
bottom = mask_rows[-1]
# bounding box of mask
orig_box_amodal = AlignedBox2f(
    left=left,
    top=top,
    right=right,
    bottom=bottom,
)

# Get box for cropping.
crop_box = misc_util.calc_crop_box(
    box=orig_box_amodal,
    make_square=True,
)

# Construct a virtual camera focused on the crop.
crop_camera_model_c2w = misc_util.construct_crop_camera(
    box=crop_box,
    camera_model_c2w=orig_camera_c2w,
    viewport_size=crop_size,
    viewport_rel_pad=crop_rel_pad,
)

# Map images to the virtual camera.
# Area interpolation when the crop is shrunk, bilinear when it is enlarged.
interpolation = (
    cv2.INTER_AREA
    if crop_box.width >= crop_camera_model_c2w.width
    else cv2.INTER_LINEAR
)
image_np_hwc = warp_image(
    src_camera=orig_camera_c2w,
    dst_camera=crop_camera_model_c2w,
    src_image=orig_image_np_hwc,
    interpolation=interpolation,
)
# Nearest-neighbour keeps the warped mask binary.
mask_modal = warp_image(
    src_camera=orig_camera_c2w,
    dst_camera=crop_camera_model_c2w,
    src_image=orig_mask_modal,
    interpolation=cv2.INTER_NEAREST,
)

# Recalculate the object bounding box (it changed if we constructed the virtual camera).
ys, xs = mask_modal.nonzero()
box = np.array(misc_util.calc_2d_box(xs, ys))
box_amodal = AlignedBox2f(
    left=box[0],
    top=box[1],
    right=box[2],
    bottom=box[3],
)

# The virtual camera is becoming the main camera.
camera_c2w = crop_camera_model_c2w

# Extract feature map from the crop.
image_tensor_chw = array_to_tensor(image_np_hwc).to(torch.float32).permute(2, 0, 1).to(device)
image_tensor_bchw = image_tensor_chw.unsqueeze(0)
# BxDxHxW
# The feature map is a constant of the pose optimization: run the extractor without
# autograd so the map is not attached to the network graph (saves memory and allows
# repeated backward passes through losses that sample from it).
with torch.no_grad():
    extractor_output = extractor(image_tensor_bchw)
feature_map_chw = extractor_output["feature_maps"][0]

# Keep only points inside the object mask.
mask_modal_tensor = array_to_tensor(mask_modal).to(device)
query_points = feature_util.filter_points_by_mask(
    grid_points, mask_modal_tensor
)

# Extract features at the selected points, of shape (num_points, feat_dims).
query_features = feature_util.sample_feature_map_at_points(
    feature_map_chw=feature_map_chw,
    points=query_points,
    image_size=(image_np_hwc.shape[1], image_np_hwc.shape[0]),
).contiguous()

# Potentially project features to a PCA space.
# Only needed when the raw feature dimension differs from the one stored in the
# representation and the representation ships projectors.
if (
    query_features.shape[1] != repre.feat_vectors.shape[1]
    and len(repre.feat_raw_projectors) != 0
):
    query_features_proj = projector_util.project_features(
        feat_vectors=query_features,
        projectors=repre.feat_raw_projectors,
    ).contiguous()

    # Project the dense map as well: flatten to (H*W, C), project, restore to (C', H, W).
    _c, _h, _w = feature_map_chw.shape
    feature_map_chw_proj = (
        projector_util.project_features(
            feat_vectors=feature_map_chw.permute(1, 2, 0).view(-1, _c),
            projectors=repre.feat_raw_projectors,
        )
        .view(_h, _w, -1)
        .permute(2, 0, 1)
    )
else:
    query_features_proj = query_features
    feature_map_chw_proj = feature_map_chw

In [None]:
# Compare the shape of the raw feature map with the shape of its (possibly projected) version.
feature_map_chw.shape, feature_map_chw_proj.shape

In [None]:
# Shapes of the per-point query descriptors and of the dense (possibly projected) feature map.
query_features_proj.shape, feature_map_chw_proj.shape

In [None]:
# Initial pose (3x3 rotation matrix R and translation t) that seeds the refinement cells below.
# NOTE(review): the provenance of these numbers is not recorded in this notebook —
# presumably a coarse pose estimate for this frame; confirm.
R = np.array([
    [0.9824963517866108, -0.18626298081482784, -0.0026496611057800646],
    [0.07184429208626533, 0.36576270685794776, 0.9279310534552504],
    [-0.1718700567889119, -0.9118792377557032, 0.37274245710604814]
])
t = np.array([0.027086201383687152, 0.21964677626568319, 5.018856479067855])

In [None]:
# Template whose patches (descriptor p_i + 3D point x_i) are aligned to the query image.
template_id = 245
# Boolean mask over FEATURES selecting the ones that belong to this template.
feattempmask = repre.feat_to_template_ids == template_id
# all p_i values
allpi = repre.feat_vectors[feattempmask]
# all x_i values
# The mask is per feature, so vertices must be looked up through the
# feature -> vertex index map (as in the NumPy inspection cell above) rather than
# by masking `repre.vertices` directly.
allxi = repre.vertices[repre.feat_to_vertex_ids[feattempmask]]
# Dense query feature map F_q sampled during refinement.
Fq = feature_map_chw_proj
templatecam = repre.template_cameras_cam_from_model[template_id]

In [None]:
def bilinear_interpolate_torch(im, x, y):
    """Bilinearly sample `im` at continuous pixel coordinates.

    Adapted from https://gist.github.com/peteflorence/a1da2c759ca1ac2b74af9a83f69ce20e

    The interpolation weights are computed from the *unclamped* integer corners and
    only the gather indices are clamped to the image. The four weights therefore
    always sum to one: a sample on the last row/column returns that pixel, and
    out-of-bounds samples replicate the nearest border pixel (the version with
    clamped corners returned zeros in both cases).

    Args:
        im: Tensor indexed as ``im[row, col]`` with optional trailing channel
            dimensions, e.g. ``(H, W)`` or ``(H, W, C)``.
        x: Tensor of N column coordinates.
        y: Tensor of N row coordinates.

    Returns:
        Tensor of shape ``(N, *im.shape[2:])`` holding the interpolated values.
    """
    x_floor = torch.floor(x)
    y_floor = torch.floor(y)

    # Fractional position inside the cell; floor has zero gradient, so the
    # derivative with respect to x / y is the same as for the usual formulation.
    frac_x = x - x_floor
    frac_y = y - y_floor

    wa = (1 - frac_x) * (1 - frac_y)  # weight of (y0, x0)
    wb = (1 - frac_x) * frac_y        # weight of (y1, x0)
    wc = frac_x * (1 - frac_y)        # weight of (y0, x1)
    wd = frac_x * frac_y              # weight of (y1, x1)

    x0 = x_floor.long()
    y0 = y_floor.long()

    # Clamp only the indices used for the lookup.
    max_x = im.shape[1] - 1
    max_y = im.shape[0] - 1
    x1 = torch.clamp(x0 + 1, 0, max_x)
    y1 = torch.clamp(y0 + 1, 0, max_y)
    x0 = torch.clamp(x0, 0, max_x)
    y0 = torch.clamp(y0, 0, max_y)

    Ia = im[y0, x0]
    Ib = im[y1, x0]
    Ic = im[y0, x1]
    Id = im[y1, x1]

    def _expand(weight):
        # Append singleton dims so (N,) weights broadcast over trailing channel dims.
        return weight.reshape(weight.shape + (1,) * (Ia.dim() - weight.dim()))

    return Ia * _expand(wa) + Ib * _expand(wb) + Ic * _expand(wc) + Id * _expand(wd)

def robustlosstorch(x: torch.Tensor, a=-5, c=0.5):
    """General robust loss (Barron) applied element-wise and summed over the last dim.

    For a generic shape parameter ``a`` the per-element loss is
    ``(|a-2|/a) * (((x/c)^2 / |a-2| + 1)^(a/2) - 1)``. That closed form is singular
    for ``a == 2``, ``a == 0`` and ``a == -inf``; for those the limits are used:
    ``0.5*(x/c)^2``, ``log(0.5*(x/c)^2 + 1)`` and ``1 - exp(-0.5*(x/c)^2)``.

    Args:
        x (torch.Tensor): Nxd tensor of residuals.
        a: Shape parameter controlling robustness (smaller is more robust).
        c: Scale parameter; residuals are divided by it before squaring.

    Returns:
        loss (torch.Tensor): N tensor
    """
    scaled_sq = (x / c) ** 2
    if a == 2:
        per_element = 0.5 * scaled_sq
    elif a == 0:
        per_element = torch.log(0.5 * scaled_sq + 1)
    elif a == float("-inf"):
        per_element = 1 - torch.exp(-0.5 * scaled_sq)
    else:
        b = abs(a - 2)
        per_element = (b / a) * ((scaled_sq / b + 1) ** (a / 2) - 1)
    return torch.sum(per_element, dim=-1)


def descriptor_from_pose(
    q: torch.Tensor, t: torch.Tensor, camera: CameraModel, patchvtxs: torch.Tensor,
    featmap: torch.Tensor, patchsize: int, device=None
):
    r"""Sample query descriptors at the projections of posed 3D points.

    Implements $$F_q(\pi(Rx_i+t)/s)$$: the points are moved by the pose ``(q, t)``,
    then by the inverse of ``camera.T_world_from_eye``, projected with the camera
    intrinsics, divided by ``patchsize`` and used to bilinearly sample ``featmap``.

    Args:
        q: Unit quaternion of the pose rotation (as consumed by ``roma.RigidUnitQuat``).
        t: Translation of the pose.
        camera: Camera providing ``T_world_from_eye``, ``f``, ``c`` and ``project``.
        patchvtxs: 3D points x_i to project (nx3 per the callers' docstrings).
        featmap: Feature map; permuted from channel-first to channel-last before sampling.
        patchsize: Scale s between image pixels and feature-map cells.
        device: Device for the camera tensors created here. Defaults to the device
            of ``q`` so they always match the pose tensors.

    Returns:
        One sampled descriptor per input point.
    """
    if device is None:
        # Follow the inputs instead of guessing "cuda"/"cpu": a guessed device can
        # differ from where q/t/patchvtxs actually live (e.g. "cuda:1" or CPU tensors).
        device = q.device
    T = roma.RigidUnitQuat(q, t)
    # cropped camera has a unique transform, close to identity
    world2eye = v3d.Transform.from_matrix(camera.T_world_from_eye).inv
    qcam = roma.rotmat_to_unitquat(torch.tensor(world2eye.R)).float().to(device)
    camT = roma.RigidUnitQuat(qcam, torch.tensor(world2eye.t).float().to(device))
    # need to convert (fx,fy), (cx,cy) to tensors
    camf = torch.tensor(camera.f).float().to(device)
    camc = torch.tensor(camera.c).float().to(device)
    # consider first transform to still be world space (identity camera pose)
    # then transform to cropped camera space
    camvtxs = camT[None].apply(T[None].apply(patchvtxs))
    projected = camera.project(camvtxs) * camf + camc
    # Convert pixel coordinates to feature-map cell coordinates.
    patchproj = projected / patchsize
    # torch grid_sample doesn't play nicely with jacrev for some reason
    # projfeatures = feature_util.sample_feature_map_at_points(featmap, patchproj, featmap.shape[1:])
    projfeatures = bilinear_interpolate_torch(featmap.permute(1, 2, 0), patchproj[:, 0], patchproj[:, 1])
    return projfeatures


def featuremetric_loss(
    q: torch.Tensor, t: torch.Tensor, camera: CameraModel, patchdescs: torch.Tensor,
    patchvtxs: torch.Tensor, featmap: torch.Tensor, patchsize: int, a: float=-5, c: float=0.5,
    device=None
):
    """Per-patch robust distance between template and sampled query descriptors.

    The query descriptors are obtained by projecting the template's 3D points with
    the candidate pose and sampling the query feature map there.

    Args:
        q: quaternion input (rotation of the candidate pose)
        t: translation input
        camera: camera used for the projection
        patchdescs (nxd): template descriptors
        patchvtxs (nx3): 3D points belonging to the template descriptors
        featmap (dxaxa): query feature map
        patchsize: scale between image pixels and feature-map cells
        a, c: shape and scale parameters forwarded to the robust loss
        device: forwarded to descriptor_from_pose

    Returns:
        loss (n tensor)
    """
    sampled_descs = descriptor_from_pose(
        q, t, camera, patchvtxs, featmap, patchsize, device=device
    )
    return robustlosstorch(patchdescs - sampled_descs, a=a, c=c)


def featuremetric_cost(
    q: torch.Tensor, t: torch.Tensor, camera: CameraModel, patchdescs: torch.Tensor,
    patchvtxs: torch.Tensor, featmap: torch.Tensor, patchsize: int, a: float=-5, c: float=0.5,
    device=None
):
    """Scalar featuremetric cost: the per-patch losses added up over all patches.

    Args:
        q: quaternion input
        t: translation input
        (remaining arguments are forwarded unchanged to featuremetric_loss)

    Returns:
        0-dim tensor with the summed loss.
    """
    per_patch = featuremetric_loss(
        q, t, camera, patchdescs, patchvtxs, featmap, patchsize, a=a, c=c, device=device
    )
    return per_patch.sum()

class PoseDescriptorModel(nn.Module):
    """Model whose trainable parameters are an object pose (quaternion + translation).

    ``forward`` maps 3D points to the query descriptors sampled at their
    projections under the current pose.
    """

    def __init__(self, R, t, camera: CameraModel, featmap: torch.Tensor, patchsize: int, device=None):
        """
        Args:
            R: Initial rotation matrix (converted to a unit quaternion).
            t: Initial translation.
            camera: Camera forwarded to descriptor_from_pose.
            featmap: Query feature map; moved to ``device``.
            patchsize: Scale between image pixels and feature-map cells.
            device: Device for the parameters and the feature map; defaults to
                CUDA when available, otherwise CPU.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Always record the device: forward() reads self.device, so leaving it unset
        # for the default case made forward() fail with AttributeError.
        self.device = device
        self.qtorch = nn.Parameter(roma.rotmat_to_unitquat(torch.tensor(R)).float().to(device), requires_grad=True)
        self.ttorch = nn.Parameter(torch.tensor(t).float().to(device), requires_grad=True)
        # NOTE(review): `test` is a trainable 256->256 linear layer applied to the sampled
        # descriptors; its weights are part of model.parameters() and are therefore
        # optimized together with the pose. It also fixes the descriptor size to 256.
        # The commented-out `return out` in forward() suggests this is experimental — confirm.
        self.test = nn.Linear(256, 256)
        self.camera = camera
        # Plain attribute (not a buffer): a later .to(...) on the module does not move it.
        self.featmap = featmap.to(device)
        self.patchsize = patchsize

    def forward(self, patchvtxs: torch.Tensor):
        """Return the (linearly transformed) descriptors sampled for `patchvtxs`."""
        out = descriptor_from_pose(self.qtorch, self.ttorch, self.camera, patchvtxs, self.featmap, self.patchsize, device=self.device)
        return self.test(out)
        # return out

class RobustLoss(tlm.loss.Loss):
    """Robust descriptor loss for the Levenberg-Marquardt trainer.

    ``residuals`` returns r_i = sqrt(rho_i), where rho_i is the robust loss of the
    i-th descriptor difference, and ``forward`` returns sum_i r_i^2 = sum_i rho_i.
    Keeping the scalar loss equal to the sum of squared residuals makes the value
    used to judge a step the same objective that the least-squares step minimizes.
    The robust loss squares its input, so the argument order does not affect the value.
    """

    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Scalar loss: the sum of the per-sample robust losses."""
        return torch.sum(robustlosstorch(y_true - y_pred))

    def residuals(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
        """Per-sample residuals whose squares add up to ``forward``."""
        return torch.sqrt(robustlosstorch(y_true - y_pred))

class PatchDataset(Dataset):
    """Dataset of (3D point, descriptor) pairs, one item per template patch."""

    def __init__(self, patchvtxs: torch.Tensor, patchdescs: torch.Tensor, device=None):
        """
        Args:
            patchvtxs: nx3 tensor of 3D points.
            patchdescs: nxd tensor of descriptors, row-aligned with `patchvtxs`.
            device: Device both tensors are moved to; defaults to CUDA when
                available, otherwise CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Set unconditionally so `self.device` also exists when the default is used.
        self.device = device
        self.patchvtxs = patchvtxs.to(device)  # nx3
        self.patchdescs = patchdescs.to(device)  # nxd

    def __len__(self):
        """Number of patches."""
        return len(self.patchvtxs)

    def __getitem__(self, idx):
        """Return the (point, descriptor) pair at `idx`."""
        return self.patchvtxs[idx], self.patchdescs[idx]

def tree_to_device(tree: Any, device: torch.device | str) -> Any:
    """Recursively move all tensor leaves in a pytree to the specified device.

    Args:
        tree: The pytree containing tensors and non-tensor leaves.
        device: The target device (e.g. CPU or GPU) to move the tensors to.

    Returns:
        A new pytree with every tensor moved to the specified device.
    """
    return tree_map(lambda x: x.to(device) if isinstance(x, torch.Tensor) else x, tree)

def fit(
    training_module,
    dataloader: DataLoader,
    epochs: int,
    metrics: dict[str, Metric] | None = None,
    overwrite_progress_bar: bool = True,
    update_every_n_steps: int = 1,
) -> list[float]:
    """Fit function with support for TrainingModule and torchmetrics.

    Trains the model for a specified number of epochs. It supports logging metrics using
    `torchmetrics` and provides detailed progress tracking using `tqdm`. After every
    epoch, a `qtorch` quaternion parameter on `training_module.model` (if present) is
    re-normalized to unit length.

    Args:
        training_module: A `TrainingModule` encapsulating the training logic. Must
            expose `device` and `training_step(inputs, targets)` returning
            `(outputs, loss, stop_training, logs)`.
        dataloader: A PyTorch DataLoader yielding `(inputs, targets)` batches.
        epochs: The number of epochs.
        metrics: Optional dict of torchmetrics.Metric objects.
        overwrite_progress_bar: If True, mimic a single-line progress bar similar
            to PyTorch Lightning (old bars overwritten).
        update_every_n_steps: Update the progress bar and displayed logs every n steps.

    Returns:
        List with the summed loss of each epoch that was run.
    """
    assert update_every_n_steps > 0
    device = training_module.device
    steps = len(dataloader)
    stop_training = False
    # Defined up front so the final summary also works when no epoch is run.
    avg_loss = float("nan")

    if metrics:
        metrics = {name: metric.to(device) for name, metric in metrics.items()}

    losses = []
    for epoch in range(epochs):
        if stop_training:
            break

        # Create a new progress bar for this epoch
        progress_bar = tqdm(
            total=steps,
            desc=f'Epoch {epoch + 1}/{epochs}',
            leave=not overwrite_progress_bar,  # Leave bar if overwrite is False
            dynamic_ncols=True,
        )
        total_loss = 0.0
        steps_done = 0
        steps_since_update = 0

        for step, (inputs, targets) in enumerate(dataloader):
            # Ensure that inputs and targets are on the same device as the model
            inputs = tree_to_device(inputs, device)
            targets = tree_to_device(targets, device)

            # Perform a training step
            outputs, loss, stop_training, logs = training_module.training_step(
                inputs, targets
            )

            total_loss += loss.item()
            steps_done += 1

            # Update metrics if provided
            if metrics:
                for name, metric in metrics.items():
                    metric(outputs, targets)

            # Format logs
            formatted_logs = {'loss': f'{loss:.4e}'}
            if metrics:
                for name, metric in metrics.items():
                    formatted_logs[name] = metric.compute().item()
            for key, value in logs.items():
                if isinstance(value, torch.Tensor):
                    value = value.item()
                formatted_logs[key] = (
                    f'{value:.4e}' if isinstance(value, float) else str(value)
                )

            steps_since_update += 1
            if (
                steps_since_update == update_every_n_steps
                or step == steps - 1
                or stop_training
            ):
                # Update the progress bar and logs
                progress_bar.update(steps_since_update)
                progress_bar.set_postfix(formatted_logs)
                steps_since_update = 0

            if stop_training:
                # End early, ensure progress bar remains visible
                progress_bar.leave = True
                break

        # Project the quaternion back onto the unit sphere; the optimizer updates it
        # as a free 4-vector. Skipped for models without a `qtorch` parameter.
        model = getattr(training_module, "model", None)
        qparam = getattr(model, "qtorch", None)
        if qparam is not None:
            with torch.no_grad():
                qraw = qparam.data
                qparam.data = qraw / torch.norm(qraw)
        losses.append(total_loss)
        # Reset metrics at the end of the epoch
        if metrics:
            for metric in metrics.values():
                metric.reset()

        # Epoch summary
        # Average over the steps that actually ran (an early stop runs fewer than
        # `steps`); an empty dataloader yields NaN instead of a ZeroDivisionError.
        avg_loss = total_loss / steps_done if steps_done else float("nan")
        if overwrite_progress_bar:
            progress_bar.set_postfix({'epoch_avg_loss': f'{avg_loss:.4e}'})
        else:
            progress_bar.write(
                f'Epoch {epoch + 1} complete. Average loss: {avg_loss:.4e}'
            )

        # Ensure the final progress bar is left visible
        if epoch == epochs - 1 or stop_training:
            progress_bar.leave = True

        progress_bar.close()

    # Final training summary
    if overwrite_progress_bar:
        print(f'Training complete. Final epoch average loss: {avg_loss:.4e}')
    return losses

In [None]:
# Refine the initial pose (R, t) with Levenberg-Marquardt on the template patches.
# Prefer the second GPU as before, but fall back when it does not exist so the
# cell also runs on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
# One batch holding (up to) 9999 patches; the order is irrelevant, so no shuffling.
patchloader = DataLoader(PatchDataset(allxi, allpi, device=device), batch_size=9999, shuffle=False)
# 14 is the patch size: the scale between image pixels and feature-map cells.
posemodel = PoseDescriptorModel(R, t, camera_c2w, Fq, 14, device=device).to(device)
# The list of per-epoch losses returned by fit() is displayed as the cell output.
fit(
  tlm.training.LevenbergMarquardtModule(
    model=posemodel,
    loss_fn=RobustLoss(),
    learning_rate=0.1,
    attempts_per_step=10,
    solve_method="qr",
    use_vmap=False,
  ),
  patchloader,
  epochs=50,
)

In [None]:
# Initial pose as (unit quaternion, translation), for comparison with the refined values.
roma.rotmat_to_unitquat(torch.tensor(R)), t

In [None]:
# Current values of the pose parameters held by the model (quaternion, translation).
posemodel.qtorch.data, posemodel.ttorch.data

In [None]:
# Sanity check of torch.func.jacrev: the Jacobian of element-wise sin is diag(cos(x)).
# NOTE(review): this import lives in the middle of the notebook rather than in the import cell.
from torch.func import jacrev
x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)

In [None]:
# Display the Jacobian computed in the jacrev sanity check.
jacobian

In [None]:
# List every trainable parameter registered on the pose model.
list(posemodel.parameters())

In [None]:
# First-order (Adam) refinement of the same pose, as an alternative to Levenberg-Marquardt.
# Keep the pose on the device of the template data so all operands of the cost agree.
adam_device = allpi.device
qtorch = roma.rotmat_to_unitquat(torch.tensor(R)).float().to(adam_device).requires_grad_()
ttorch = torch.tensor(t).float().to(adam_device).requires_grad_()
print(qtorch, ttorch)
optimizer = torch.optim.Adam([qtorch, ttorch], lr=0.1)

# The query feature map is a constant here; detaching it keeps backward() from
# walking into (and freeing) any graph it may still be attached to.
Fq_const = Fq.detach()

losses = []
for _ in range(100):
    # Gradients accumulate across backward() calls unless they are cleared.
    optimizer.zero_grad()
    loss = featuremetric_cost(qtorch, ttorch, camera_c2w, allpi, allxi, Fq_const, 14, device=adam_device)
    losses.append(loss.cpu().detach().numpy())
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        # Re-normalize IN PLACE: rebinding `qtorch` to a new tensor would leave the
        # optimizer updating the old leaf while the loss uses a constant quaternion.
        qtorch.copy_(roma.quat_normalize(qtorch))

In [None]:
# Pose tensors after the Adam loop (quaternion, translation).
qtorch, ttorch

In [None]:
# Vertex tensor stored in the object representation.
# NOTE(review): this displays the whole tensor; `.shape` may be enough.
repre.vertices

In [None]:
# Camera model built from the camera JSON (before cropping).
orig_camera_c2w

In [None]:
# Debug plot: project all model vertices with the initial pose (R, t) into the crop camera.
q = roma.rotmat_to_unitquat(torch.tensor(R)).float().cuda()
# Keep the NumPy `t` intact (it seeds the refinement cells); overwriting it with a
# CUDA tensor made this cell non-idempotent and changed the input of earlier cells on re-run.
t_cuda = torch.tensor(t).float().cuda()
T = roma.RigidUnitQuat(q, t_cuda)
# cam = orig_camera_c2w
cam = camera_c2w
camf = torch.tensor(cam.f).float().cuda()
camc = torch.tensor(cam.c).float().cuda()
# NOTE(review): unlike descriptor_from_pose, the camera's T_world_from_eye is not applied here.
projected = cam.project(T[None].apply(repre.vertices)) * camf + camc
projected = projected.cpu().numpy()
# NOTE(review): column 1 is drawn on the horizontal axis and column 0 on the vertical
# one — confirm this orientation is intended.
plt.plot(projected[:, 1], projected[:, 0])
plt.show()

In [None]:
# Rebuild, outside of descriptor_from_pose, the rigid transform it applies for the camera:
# the inverse of camera_c2w.T_world_from_eye as (unit quaternion, translation).
world2eye = v3d.Transform.from_matrix(camera_c2w.T_world_from_eye).inv
qcam = roma.rotmat_to_unitquat(torch.tensor(world2eye.R)).float().cuda()
# (bare expression that is not the last line of the cell, so it is not displayed)
qcam
camT = roma.RigidUnitQuat(qcam, torch.tensor(world2eye.t).float().cuda())

In [None]:
# Display the camera transform built from the inverse of T_world_from_eye.
camT

In [None]:
# Display the crop camera's T_world_from_eye matrix.
camera_c2w.T_world_from_eye

In [None]:
# Shape of the raw feature map returned by the extractor for the crop.
feature_map_chw.shape

In [None]:
# Shape of the warped crop image (height, width, channels).
image_np_hwc.shape