# SSE pose fitting using PyTorch3D differentiable rendering (render-and-compare)

---
## Overview

This program estimates the pose of an object relative to the camera using PyTorch3D differentiable rendering. The cost function is the difference between the rendered silhouette and the reference silhouette. It is a simple version of the render-and-compare approach.
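As a rough illustration of a silhouette-difference cost, the sketch below sums the squared per-pixel differences between two small binary masks. The helper name `silhouette_sq_diff` and the 4x4 masks are made up for illustration; the notebook itself computes its loss on rendered tensors, and its exact loss may differ.

```python
def silhouette_sq_diff(rendered, reference):
    """Sum of squared per-pixel differences between two same-size silhouettes."""
    if len(rendered) != len(reference):
        raise ValueError("silhouettes must have the same height")
    total = 0.0
    for row_r, row_t in zip(rendered, reference):
        if len(row_r) != len(row_t):
            raise ValueError("silhouettes must have the same width")
        for a, b in zip(row_r, row_t):
            total += (a - b) ** 2
    return total


# Hypothetical 4x4 masks: a 2x2 square, and the same square shifted one pixel right.
reference = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
shifted = [
    [0, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
]

loss_same = silhouette_sq_diff(reference, reference)
loss_shifted = silhouette_sq_diff(shifted, reference)
print(loss_same, loss_shifted)
```

The cost is zero when the silhouettes coincide and grows with the non-overlapping area, which is what makes it usable as an optimization target.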

We assume that the SSE mesh file, in `.ply` (or `.obj`) format, is present in the directory `assets/`. The code creates this directory and downloads its files from a Dropbox shared link.


The output of the demo is a `.gif` showing the pose at successive iterations of the optimization.

This program was adapted from the example provided here: https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb

In this program, we will learn the $[x, y, z]$ position of a camera given a reference image using differentiable rendering.

We will first initialize a renderer with a starting position for the camera. We will then use this to generate an image, compute a loss with the reference image, and finally backpropagate through the entire pipeline to update the position of the camera.

## PyTorch3D

The `torch` and `torchvision` modules are required. If `pytorch3d` is not installed, install it using the following cell. Here, I modified the installation to use my own pre-built PyTorch3D wheel, which installs faster than building from source. Installing from source takes a few minutes to complete.

The pre-built PyTorch3D wheel is downloaded from my Dropbox (shared link). Another copy of the wheel is also stored in my Google Drive, and is located at: `/content/drive/MyDrive/research/projects/slosh_project/slosh_project_team_files/Colab_wheels/pytorch3d-0.7.8-cp311-cp311-linux_x86_64.whl`
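The wheel filename encodes which interpreter it was built for. A minimal sketch, assuming the standard wheel naming scheme (`name-version-pythontag-abitag-platformtag.whl`, with an optional build tag), shows how one might check the tag against the running Python before installing. The helper names `parse_wheel_filename` and `matches_running_python` are hypothetical and not part of this notebook.

```python
import sys


def parse_wheel_filename(filename):
    """Split a wheel filename into (name, version, python_tag, abi_tag, platform_tag)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel filename")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, py_tag, abi_tag, plat_tag = parts
    elif len(parts) == 6:
        # An optional build tag sits between the version and the Python tag.
        name, version, _build, py_tag, abi_tag, plat_tag = parts
    else:
        raise ValueError("unexpected wheel filename layout")
    return name, version, py_tag, abi_tag, plat_tag


def matches_running_python(filename):
    """True if the wheel's CPython tag matches the running interpreter version."""
    py_tag = parse_wheel_filename(filename)[2]
    return py_tag == f"cp{sys.version_info.major}{sys.version_info.minor}"


wheel = "pytorch3d-0.7.8-cp311-cp311-linux_x86_64.whl"
tags = parse_wheel_filename(wheel)
print(tags, matches_running_python(wheel))
```

This only compares the Python tag. A pre-built PyTorch3D wheel must also match the installed PyTorch and CUDA versions, which the filename above does not record.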



---
## Main steps


### Load the CAD model file

We will load a CAD model file (in `.ply` format) and create a **Meshes** object. **Meshes** is a data structure provided by PyTorch3D for working with **batches of meshes of different sizes**. It has several useful class methods which are used in the rendering pipeline.
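To see why batching meshes of different sizes needs a dedicated structure, the sketch below builds two common layouts for a batch of vertex lists: a packed layout (everything concatenated, plus the start index of each mesh) and a padded layout (every mesh padded to the largest size). This is plain Python for illustration only; `pack_vertices` and `pad_vertices` are hypothetical names. As far as I know, **Meshes** exposes comparable views through methods such as `verts_packed()` and `verts_padded()`.

```python
def pack_vertices(verts_list):
    """Concatenate per-mesh vertex lists; return (packed, first_index_per_mesh)."""
    packed, first_idx = [], []
    for verts in verts_list:
        first_idx.append(len(packed))  # where this mesh starts in the packed list
        packed.extend(verts)
    return packed, first_idx


def pad_vertices(verts_list, fill=(0.0, 0.0, 0.0)):
    """Pad every mesh to the largest vertex count so the batch is rectangular."""
    max_v = max(len(v) for v in verts_list)
    return [list(v) + [fill] * (max_v - len(v)) for v in verts_list]


# Two toy "meshes" with 3 and 4 vertices.
triangle = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
quad = triangle + [(1.0, 1.0, 0.0)]

packed, first_idx = pack_vertices([triangle, quad])
padded = pad_vertices([triangle, quad])
print(len(packed), first_idx, [len(p) for p in padded])
```

The packed layout wastes no memory, while the padded layout is convenient for batched tensor operations.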

### Create a renderer

A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader**, each of which has a number of subcomponents such as a **camera** (orthographic/perspective). Here, we initialize some of these components and use default values for the rest.

To optimize the camera position, we will use a renderer that produces only a **silhouette** of the object and does not apply any **lighting** or **shading**. We will also initialize a second renderer that applies full **Phong shading** and use it to visualize the outputs.
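A silhouette renderer is useful for optimization because its output varies smoothly near the object boundary. The sketch below shows one way such a soft silhouette value can be computed for a single pixel: each nearby face contributes a sigmoid of its signed distance to the pixel, and the contributions are combined as a probabilistic union. The sign convention (negative means inside), the default `sigma`, and the function name `pixel_alpha` are assumptions made for illustration; check the PyTorch3D blending documentation for the formula it actually uses.

```python
import math


def pixel_alpha(signed_dists, sigma=1e-4):
    """Soft coverage of one pixel from signed distances to nearby faces.

    Assumed convention: a negative distance means the pixel lies inside the face.
    """
    not_covered = 1.0
    for d in signed_dists:
        # Clamp the sigmoid argument so math.exp cannot overflow.
        z = max(-60.0, min(60.0, -d / sigma))
        prob = 1.0 / (1.0 + math.exp(-z))  # sigmoid(-d / sigma)
        not_covered *= 1.0 - prob
    return 1.0 - not_covered


alpha_inside = pixel_alpha([-1e-3])      # well inside one face
alpha_edge = pixel_alpha([0.0])          # exactly on the boundary
alpha_outside = pixel_alpha([1e-3])      # well outside
alpha_two_edges = pixel_alpha([0.0, 0.0])  # on the boundary of two faces
print(alpha_inside, alpha_edge, alpha_outside, alpha_two_edges)
```

A smaller `sigma` gives a sharper edge, and a larger one spreads the gradient over more pixels.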

### Create a reference image

We will first position the teapot and generate an image. We use helper functions to rotate the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and full shaded image.

The world coordinate system is defined as +Y up, +X left and +Z in.

We define a camera positioned on the positive z-axis, so it sees the spout on the right.
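Camera positions in this coordinate system are often specified by distance, elevation, and azimuth. The sketch below converts between spherical angles and a Cartesian camera position in plain Python, assuming +Y is up, azimuth 0 points along +Z, and positive azimuth rotates toward +X. The function names are hypothetical. The notebook's own utilities rely on PyTorch3D's `camera_position_from_spherical_angles` for the forward direction.

```python
import math


def camera_position(distance, elev_deg, azim_deg):
    """Spherical (distance, elevation, azimuth) in degrees -> Cartesian (x, y, z)."""
    elev, azim = math.radians(elev_deg), math.radians(azim_deg)
    x = distance * math.cos(elev) * math.sin(azim)
    y = distance * math.sin(elev)
    z = distance * math.cos(elev) * math.cos(azim)
    return x, y, z


def spherical_from_position(x, y, z):
    """Cartesian camera position -> (distance, elevation_deg, azimuth_deg)."""
    dist = math.sqrt(x * x + y * y + z * z)
    elev = math.degrees(math.atan2(y, math.hypot(x, z)))  # in [-90, 90]
    azim = math.degrees(math.atan2(x, z))                 # in (-180, 180]
    return dist, elev, azim


on_axis = camera_position(3.0, 0.0, 0.0)  # camera on the positive z-axis
round_trip = spherical_from_position(*camera_position(2.5, 30.0, 45.0))
print(on_axis, round_trip)
```

With zero elevation and azimuth, the camera sits on the positive z-axis at the given distance, which is the setup described above.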

### Initialize the model and optimizer

Create an instance of the **model** above and set up an **optimizer** for the camera position parameter.

### Run the optimization

We run several iterations of the forward and backward pass and save outputs every 10 iterations. When this has finished, take a look at `./optimization_sequence.gif` for a cool gif of the optimization process!
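The structure of this loop can be sketched with a toy one-dimensional problem that needs no GPU or PyTorch3D. A soft "silhouette" is rendered along a scanline, compared with a reference, and the camera offset is updated by gradient descent, recording a snapshot every 10 iterations. Everything here is an illustrative assumption: the renderer, the learning rate, and the use of a central finite difference in place of the autograd backward pass that the notebook uses.

```python
import math


def render_scanline(center, pixels, radius=0.5, sigma=0.1):
    """Soft silhouette of a 1D 'disc' sampled at the given pixel coordinates."""
    return [1.0 / (1.0 + math.exp(-(radius - abs(p - center)) / sigma)) for p in pixels]


def loss(center, pixels, reference):
    """Sum of squared differences between rendered and reference silhouettes."""
    rendered = render_scanline(center, pixels)
    return sum((r - t) ** 2 for r, t in zip(rendered, reference))


def optimize(init, target, steps=100, lr=0.005, eps=1e-4, save_every=10):
    pixels = [-2.0 + 0.1 * i for i in range(41)]
    reference = render_scanline(target, pixels)
    x, snapshots = init, []
    for step in range(steps):
        # Central finite difference stands in for the autograd backward pass.
        grad = (loss(x + eps, pixels, reference) - loss(x - eps, pixels, reference)) / (2 * eps)
        x -= lr * grad
        if step % save_every == 0:
            snapshots.append((step, x))  # the notebook would render a GIF frame here
    return x, snapshots


final_x, snapshots = optimize(init=-0.4, target=0.3)
print(round(final_x, 3), len(snapshots))
```

The toy problem converges because the initial and reference silhouettes overlap. If they did not overlap, the silhouette loss would have almost no gradient, which is also a practical concern when choosing the starting camera position for the real renderer.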

---


# Installation and environment setup


## Functions to install libraries and dependencies

In [None]:
# -------------------- Platform Handling --------------------
class PlatformManager:
    def __init__(self):
        self.platform, self.local_path = self.detect_platform()

    @staticmethod
    def mount_gdrive():
        from google.colab import drive
        drive.mount('/content/drive')

    @staticmethod
    def detect_platform() -> tuple[str, str]:
        """Detect platform and return its name and the local path"""
        import os

        computing_platform = 'LocalPC'

        if os.getenv('RUNPOD_POD_ID'):
            computing_platform = "RunPod"
            print("Running on RunPod.")
            local_path = "/workspace/"
        elif 'content' in str(os.getcwd()):
            computing_platform = "Colab"
            print("Running on Colab.")
            local_path = "/content/"
        elif os.getenv("LIGHTNING_ARTIFACTS_DIR"):
            computing_platform = "LightningAI"
            print("Running on Lightning AI Studio")
            local_path = os.getenv("LIGHTNING_ARTIFACTS_DIR") + '/'
        else:
            local_path = os.getcwd() + '/'

        return computing_platform, local_path


# -------------------- Installation Helpers --------------------
class DependencyInstaller:
    @staticmethod
    def install_glut(computing_platform):
        !pip install --upgrade pip
        if computing_platform == "LightningAI":
            !sudo apt -qq update
            !sudo apt install -y freeglut3-dev libglew-dev libsdl2-dev
        elif computing_platform == "RunPod":
            !apt -qq update
            !apt install -y freeglut3-dev libglew-dev libsdl2-dev

    @staticmethod
    def install_opengl():
        import subprocess
        import sys
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "PyOpenGL"])
            subprocess.check_call([sys.executable, "-m", "pip", "install", "PyOpenGL_accelerate"])
        except subprocess.CalledProcessError as e:
            print(f"Failed to install OpenGL: {e}")

    @staticmethod
    def get_pytorch3d_version_string():
        import torch, sys
        pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
        version_str = "".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".", ""),
            f"_pyt{pyt_version_str}"
        ])
        return version_str


# -------------------- PyTorch3D Installer --------------------
class PyTorch3DInstaller:
    def __init__(self, computing_platform, local_path):
        self.platform = computing_platform
        self.local_path = local_path

    def install(self):
        import os, sys, torch
        need_pytorch3d = False

        !pip install --upgrade pip

        DependencyInstaller.install_glut(self.platform)
        DependencyInstaller.install_opengl()

        version_str = DependencyInstaller.get_pytorch3d_version_string()
        print(f"\nPyTorch3D to be installed: {version_str}\n")

        %pip install iopath

        if sys.platform.startswith("linux"):
            print(f"Trying to install wheel for PyTorch3D. Running on a {self.platform} instance.")

            if self.platform in {"RunPod", "LightningAI"}:
                %pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html

            elif self.platform == "Colab":
                os.chdir(self.local_path)
                print("Downloading my pre-built PyTorch3D Wheel from Dropbox")
                # Quote the URL so the shell does not treat '&' as a background operator
                !wget -O pytorch3d-0.7.8-cp311-cp311-linux_x86_64.whl "https://www.dropbox.com/scl/fi/qfv89iszrtwrbkkkvfzzp/pytorch3d-0.7.8-cp311-cp311-linux_x86_64.whl?rlkey=hr1tqcsczgblln2zvtscq4wy0&dl=0"
                my_pytorch3D_path = self.local_path + "pytorch3d-0.7.8-cp311-cp311-linux_x86_64.whl"
                %pip install pytorch3d -f $my_pytorch3D_path
                print("Deleting the wheel file to save space.")
                !rm $my_pytorch3D_path

        # Check whether the wheel install succeeded; fall back to a source build if not
        try:
            import pytorch3d
        except ImportError:
            need_pytorch3d = True

        if need_pytorch3d:
            print(f"Failed to find/install a wheel for {version_str}")
            print("Installing PyTorch3D from source")
            %pip install ninja --root-user-action ignore
            %pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

        # Verify the final result, whichever install path was taken
        import importlib
        importlib.invalidate_caches()
        try:
            import pytorch3d
        except ImportError:
            print("❌ PyTorch3D failed to install.")
            print("🤷 I don't know what happened.")
        else:
            print("✅ PyTorch3D successfully installed!")

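The official wheel index is keyed by a tag built from the Python, CUDA, and PyTorch versions. The sketch below reproduces the string construction used by `get_pytorch3d_version_string` as a pure function, so it can be tried without `torch` installed. The function name `pytorch3d_wheel_tag`, the example version numbers, and the explicit error for CPU-only builds are illustrative additions, not part of the installer above.

```python
def pytorch3d_wheel_tag(python_minor, torch_version, cuda_version):
    """Build a tag such as 'py310_cu121_pyt210' from version strings."""
    if cuda_version is None:
        # torch.version.cuda is None on CPU-only builds of PyTorch.
        raise ValueError("no CUDA version available (CPU-only PyTorch build?)")
    pyt = torch_version.split("+")[0].replace(".", "")  # drop a local suffix like '+cu121'
    return f"py3{python_minor}_cu{cuda_version.replace('.', '')}_pyt{pyt}"


# Hypothetical environment: Python 3.10, PyTorch 2.1.0 built against CUDA 12.1.
tag = pytorch3d_wheel_tag(10, "2.1.0+cu121", "12.1")
print(tag)
```

If no wheel exists for the resulting tag, the installer falls back to building from source.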

## Install required libraries and setup environment

In [None]:
# Detect platform
platform_mgr = PlatformManager()
platform = platform_mgr.platform
local_path = platform_mgr.local_path

# # Optional: Mount GDrive if on Colab
# if platform == "Colab":
#     platform_mgr.mount_gdrive()

# Install PyTorch3D
installer = PyTorch3DInstaller(platform, local_path)
installer.install()

# Image grid visualization (Meta)
import os
filename = "plot_image_grid.py"
url = "https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py"
if not os.path.exists(filename):
    !wget {url}


!pip -q install trimesh pyrender opencv-python matplotlib pytorch-lightning #==1.8.1


## Main imports

In [None]:
###-------------------------------------------------------------------###
#                                Imports
###-------------------------------------------------------------------###
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte
import requests
import shutil
from pathlib import Path

import cv2

# io utils
from pytorch3d.io import load_obj, load_ply

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes


# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

from pytorch3d.renderer.cameras import look_at_rotation
from pytorch3d.transforms import RotateAxisAngle

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
    PerspectiveCameras, camera_position_from_spherical_angles, SoftPhongShader
)


from plot_image_grid import image_grid


from typing import Optional, Literal, Tuple, Dict, Any

from PIL import Image, ImageOps

# Utility class

## Old utilities

In [None]:
# #---------------------------- IMPORTS -----------------------------------------
# import os
# import torch
# import numpy as np
# from tqdm.notebook import tqdm
# import imageio
# import torch.nn as nn
# import torch.nn.functional as F
# import matplotlib.pyplot as plt
# from skimage import img_as_ubyte
# import requests
# import shutil
# from pathlib import Path

# import cv2

# # io utils
# from pytorch3d.io import load_obj, load_ply

# # Util function for loading meshes
# from pytorch3d.io import load_objs_as_meshes


# # datastructures
# from pytorch3d.structures import Meshes

# # 3D transformations functions
# from pytorch3d.transforms import Rotate, Translate

# from pytorch3d.renderer.cameras import look_at_rotation
# from pytorch3d.transforms import RotateAxisAngle

# # rendering components
# from pytorch3d.renderer import (
#     FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
#     RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
#     SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
#     PerspectiveCameras, camera_position_from_spherical_angles, SoftPhongShader
# )


# from plot_image_grid import image_grid


# from typing import Optional, Literal, Tuple, Dict, Any

# from PIL import Image, ImageOps

# # #------------------------------------------------------------------------------


# # class SSE_Util:
# #     def __init__(self, local_path: str):
# #         self.local_path = local_path
# #         self.assets_dir = os.path.join(self.local_path, "assets/")


# #     def create_gif_writer(self, filepath, duration=0.5):
# #         return imageio.get_writer(filepath, mode='I', duration=duration)


# #     def get_camera_position(self, distance, elevation, azimuth, degrees=True, device="cpu"):
# #         camera_pos = camera_position_from_spherical_angles(
# #             distance=distance,
# #             elevation=elevation,
# #             azimuth=azimuth,
# #             degrees=degrees
# #         )
# #         return camera_pos.to(device)



# #     def camera_center_to_dist_elev_azim(self, C: torch.Tensor):
# #         """
# #         Convert camera center C (…,3) to (dist, elev_deg, azim_deg) as PyTorch3D expects.
# #         Convention: Y is up; azim=0 points along +Z; positive azim rotates toward +X.
# #         Returns tensors shaped like C[...,0].
# #         """
# #         if C.ndim == 1:
# #             C = C[None, :]  # (1,3)
# #             squeeze = True
# #         else:
# #             squeeze = False

# #         x, y, z = C[..., 0], C[..., 1], C[..., 2]
# #         dist = torch.linalg.norm(C, dim=-1)

# #         # avoid divide-by-zero at the poles/origin
# #         rho = torch.sqrt(torch.clamp(x*x + z*z, min=1e-12))

# #         elev = torch.rad2deg(torch.atan2(y, rho))     # [-90, 90]
# #         azim = torch.rad2deg(torch.atan2(x, z))       # (-180, 180]

# #         if squeeze:
# #             dist, elev, azim = dist[0], elev[0], azim[0]
# #         return dist, elev, azim

# #     # Possible fix
# #     # def add_camera_roll_to_RT(self, R, T, roll_deg, device=None, mode="world"):
# #     #     """
# #     #     Compose a Z-axis roll into (R, T) and keep the same camera center C.
# #     #     Grad-safe: preserves autograd paths from both R/T and roll_deg.
# #     #     """
# #     #     import torch

# #     #     # tensors / shapes (preserve graph)
# #     #     if not torch.is_tensor(R): R = torch.as_tensor(R)
# #     #     if not torch.is_tensor(T): T = torch.as_tensor(T)
# #     #     dev   = device or R.device
# #     #     dtype = torch.float32

# #     #     R = R.to(dev, dtype)
# #     #     T = T.to(dev, dtype)

# #     #     unbatched = (R.ndim == 2)
# #     #     if unbatched:
# #     #         R = R[None, ...]   # (1,3,3)
# #     #         T = T[None, ...]   # (1,3)

# #     #     # camera center C from (R,T):  T = -R^T C  =>  C = -R @ T
# #     #     C = -torch.matmul(R, T[..., None]).squeeze(-1)  # (B,3)

# #     #     # roll matrix about camera Z  (NO float(...), NO torch.tensor(...))
# #     #     theta = torch.as_tensor(roll_deg, dtype=dtype, device=dev).reshape(1)  # keeps requires_grad if Parameter
# #     #     c = torch.cos(torch.deg2rad(theta))
# #     #     s = torch.sin(torch.deg2rad(theta))

# #     #     # constants tied to same dtype/device
# #     #     z = torch.zeros_like(c)
# #     #     o = torch.ones_like(c)

# #     #     Rz = torch.stack([
# #     #         torch.stack([ c, -s, z], dim=-1),
# #     #         torch.stack([ s,  c, z], dim=-1),
# #     #         torch.stack([ z,  z,  o], dim=-1),
# #     #     ], dim=1)  # (1,3,3)

# #     #     # compose
# #     #     if mode == "camera":
# #     #         R_new = torch.matmul(Rz, R)      # roll in camera coords
# #     #     elif mode == "world":
# #     #         R_new = torch.matmul(R, Rz)      # roll in world coords
# #     #     else:
# #     #         raise ValueError("mode must be 'camera' or 'world'")

# #     #     # keep the same camera center
# #     #     T_new = -torch.matmul(R_new.transpose(1, 2), C[..., None]).squeeze(-1)  # (B,3)

# #     #     if unbatched:
# #     #         R_new, T_new = R_new[0], T_new[0]
# #     #     return R_new, T_new



# #     def add_camera_roll_to_RT(self, R, T, roll_deg, device=None, mode="camera"):
# #         """
# #         Compose a Z-axis roll into (R, T) and keep the same camera center C.

# #         Args
# #           R: (3,3) or (1,3,3) world->cam rotation
# #           T: (3,)  or (1,3)   world->cam translation (PyTorch3D row-vector convention: X_cam = X_world @ R + T)
# #           roll_deg: float degrees
# #           mode: "camera" -> R' = Rz @ R   (roll in camera frame)
# #                 "world"  -> R' = R  @ Rz  (roll in world frame)

# #         Returns
# #           R_new, T_new with the same rank as inputs
# #         """
# #         # tensors / shapes
# #         if not torch.is_tensor(R): R = torch.as_tensor(R)
# #         if not torch.is_tensor(T): T = torch.as_tensor(T)
# #         dev   = device or R.device
# #         dtype = torch.float32
# #         R = R.to(dev, dtype)
# #         T = T.to(dev, dtype)

# #         unbatched = (R.ndim == 2)
# #         if unbatched:
# #             R = R[None, ...]   # (1,3,3)
# #             T = T[None, ...]   # (1,3)

# #         # camera center C from (R,T):  T = -R^T C  =>  C = -R @ T
# #         C = -torch.matmul(R, T[..., None]).squeeze(-1)  # (B,3)

# #         # roll matrix about camera Z
# #         theta = torch.tensor(float(roll_deg), dtype=dtype, device=dev)
# #         c, s = torch.cos(torch.deg2rad(theta)), torch.sin(torch.deg2rad(theta))
# #         z = torch.tensor(0.0, dtype=dtype, device=dev)
# #         o = torch.tensor(1.0, dtype=dtype, device=dev)
# #         Rz = torch.stack([
# #             torch.stack([ c, -s, z]),
# #             torch.stack([ s,  c, z]),
# #             torch.stack([ z,  z,  o]),
# #         ], dim=0).expand(R.shape[0], -1, -1)  # (B,3,3)

# #         # compose
# #         if mode == "camera":
# #             R_new = torch.matmul(Rz, R)      # roll in camera coords
# #         elif mode == "world":
# #             R_new = torch.matmul(R, Rz)      # roll in world coords
# #         else:
# #             raise ValueError("mode must be 'camera' or 'world'")

# #         # keep the same camera center
# #         T_new = -torch.matmul(R_new.transpose(1, 2), C[..., None]).squeeze(-1)  # (B,3)

# #         if unbatched:
# #             R_new, T_new = R_new[0], T_new[0]
# #         return R_new, T_new


# #     def read_rgb_cutout_black_bg(self, path: str) -> np.ndarray:
# #         """
# #         Read an image file as an RGB cutout composited on a BLACK background.

# #         Args:
# #             path: Path to the image (PNG/JPEG/etc.). Alpha is respected if present.

# #         Returns:
# #             np.ndarray of shape (H, W, 3).
# #         """
# #         # Load and fix EXIF orientation, ensure RGBA so we have an alpha channel to composite
# #         img = Image.open(path)
# #         img = ImageOps.exif_transpose(img).convert("RGBA")

# #         rgba = np.asarray(img).astype(np.float32)  # (H,W,4), 0..255
# #         rgb  = rgba[..., :3]
# #         a    = rgba[..., 3:4] / 255.0             # (H,W,1) in [0,1]

# #         # Composite on BLACK: out = rgb * alpha + black * (1 - alpha) == rgb * alpha
# #         rgb_black = rgb * a                        # still 0..255 range (float)

# #         return np.clip(rgb_black, 0, 255).astype(np.uint8)  # (H,W,3), uint8


# #     def center_cutout_rgb_uint8(self, img: np.ndarray, black_thresh: int = 5) -> np.ndarray:
# #         """
# #         Center an RGB cutout (black background) on a same-size black canvas.
# #         Input / Output: (H, W, 3) uint8 [0..255]. No scaling.

# #         black_thresh: pixels with any channel > black_thresh are treated as foreground.
# #         """
# #         if img.ndim != 3 or img.shape[2] != 3 or img.dtype != np.uint8:
# #             raise ValueError("Expected (H, W, 3) uint8 input.")

# #         H, W, _ = img.shape
# #         # Foreground mask: "non-black"
# #         mask = (img.max(axis=2) > black_thresh)
# #         if not mask.any():
# #             return img.copy()  # nothing to center

# #         ys, xs = np.where(mask)
# #         y0, y1 = int(ys.min()), int(ys.max()) + 1
# #         x0, x1 = int(xs.min()), int(xs.max()) + 1

# #         roi = img[y0:y1, x0:x1, :]        # (h, w, 3)
# #         h, w = roi.shape[:2]

# #         # Top-left to center
# #         sy = (H - h) // 2
# #         sx = (W - w) // 2

# #         out = np.zeros_like(img)          # black canvas
# #         out[sy:sy+h, sx:sx+w, :] = roi
# #         return out


# #     def crop_center_to_size_uint8(self, img: np.ndarray, out_size: tuple[int, int], fill_color=(0,0,0)) -> np.ndarray:
# #         """
# #         Crop a centered window of size (W_out, H_out) from an RGB cutout (black background),
# #         preserving scale & centering. No resizing. Pads with black if the crop exceeds bounds.

# #         Args:
# #             img: (H, W, 3) uint8.
# #             out_size: (W_out, H_out) desired output size.
# #             fill_color: RGB tuple for padding if needed (default black).

# #         Returns:
# #             (H_out, W_out, 3) uint8.
# #         """
# #         if img.ndim != 3 or img.shape[2] != 3 or img.dtype != np.uint8:
# #             raise ValueError("Expected (H, W, 3) uint8 input.")

# #         W_out, H_out = map(int, out_size)
# #         H, W = img.shape[:2]

# #         # Centered crop box in source image
# #         left   = (W - W_out) // 2
# #         top    = (H - H_out) // 2
# #         right  = left + W_out
# #         bottom = top + H_out

# #         # Compute overlap with source (clip to bounds)
# #         src_x0 = max(0, left)
# #         src_y0 = max(0, top)
# #         src_x1 = min(W, right)
# #         src_y1 = min(H, bottom)

# #         # Destination positions (where to paste inside output canvas)
# #         dst_x0 = max(0, -left)
# #         dst_y0 = max(0, -top)
# #         dst_x1 = dst_x0 + (src_x1 - src_x0)
# #         dst_y1 = dst_y0 + (src_y1 - src_y0)

# #         # Prepare output canvas
# #         out = np.empty((H_out, W_out, 3), dtype=np.uint8)
# #         out[...] = np.asarray(fill_color, dtype=np.uint8)

# #         # Paste the overlapping region
# #         if src_x1 > src_x0 and src_y1 > src_y0:
# #             out[dst_y0:dst_y1, dst_x0:dst_x1] = img[src_y0:src_y1, src_x0:src_x1]

# #         return out



# #     def add_alpha_from_black_bg_uint8(self,
# #         img_rgb: np.ndarray,
# #         *,
# #         black_thresh: int = 5,     # pixel is FG if any channel > black_thresh
# #         mode: str = "binary",      # "binary" or "soft"
# #         close_kernel: int = 3      # 0/1 to disable; else use 3,5,7...
# #     ) -> np.ndarray:
# #         """
# #         Add alpha to an RGB cutout (black background), closing holes in the mask.

# #         Input:  img_rgb (H, W, 3) uint8
# #         Output: rgba     (H, W, 4) uint8
# #         """
# #         if img_rgb.ndim != 3 or img_rgb.shape[2] != 3 or img_rgb.dtype != np.uint8:
# #             raise ValueError("Expected (H, W, 3) uint8.")

# #         # 1) Initial mask (0/255)
# #         mask = (img_rgb.max(axis=2) > black_thresh).astype(np.uint8) * 255  # (H,W)

# #         # 2) Fill holes (classic OpenCV recipe)
# #         im_flood = mask.copy()
# #         h, w = mask.shape
# #         ff_mask = np.zeros((h + 2, w + 2), np.uint8)
# #         cv2.floodFill(im_flood, ff_mask, (0, 0), 255)       # fill background from (0,0)
# #         holes = cv2.bitwise_not(im_flood)                   # holes are 255
# #         mask_filled = cv2.bitwise_or(mask, holes)           # FG + holes

# #         # 3) Optional morphological closing to seal tiny gaps
# #         if close_kernel and close_kernel > 1:
# #             k = np.ones((close_kernel, close_kernel), np.uint8)
# #             mask_filled = cv2.morphologyEx(mask_filled, cv2.MORPH_CLOSE, k)

# #         # 4) Alpha channel
# #         if mode == "binary":
# #             alpha = mask_filled
# #         elif mode == "soft":
# #             # Keep soft edges from brightness, but force filled mask as minimum
# #             alpha_soft = img_rgb.max(axis=2)                # 0..255
# #             alpha = np.maximum(alpha_soft, mask_filled).astype(np.uint8)
# #         else:
# #             raise ValueError("mode must be 'binary' or 'soft'")

# #         # 5) Assemble RGBA
# #         rgba = np.concatenate([img_rgb, alpha[..., None]], axis=2)
# #         return rgba


# #     def to_float_batched_rgba_white_bg_preserve_alpha(self,
# #         img: np.ndarray,
# #         *,
# #         derive_alpha: str = "soft",  # "soft" (max channel) or "binary"
# #         black_thresh: int = 5        # for "binary" when deriving alpha from RGB uint8
# #     ) -> np.ndarray:
# #         """
# #         Convert (H,W,3/4) or (1,H,W,3/4) to (1,H,W,4) float32 in [0,1].
# #         - RGB is composited over WHITE: rgb' = rgb * a + (1 - a) * 1
# #         - Alpha channel 'a' is PRESERVED (or derived if missing).

# #         If input has no alpha:
# #           - soft: a = max(rgb) (normalized to [0,1])
# #           - binary: a = 1 where any channel > black_thresh (uint8) else 0
# #         """
# #         arr = np.asarray(img)

# #         # Accept optional batch dim of 1
# #         if arr.ndim == 4 and arr.shape[0] == 1:
# #             arr = arr[0]
# #         if arr.ndim != 3 or arr.shape[-1] not in (3, 4):
# #             raise ValueError("Expected (H,W,3) or (H,W,4) (optionally with leading batch of 1).")

# #         # Normalize to float32 in [0,1]
# #         arr = arr.astype(np.float32, copy=False)
# #         if np.nanmax(arr) > 1.0 + 1e-6:
# #             arr = np.clip(arr, 0, 255) / 255.0

# #         H, W, C = arr.shape
# #         if C == 4:
# #             rgb = arr[..., :3]
# #             a   = np.clip(arr[..., 3], 0.0, 1.0)
# #         else:
# #             rgb = arr
# #             if derive_alpha == "soft":
# #                 a = np.clip(rgb.max(axis=-1), 0.0, 1.0)            # soft mask from brightness
# #             elif derive_alpha == "binary":
# #                 # original likely uint8->float; emulate threshold in 0..1
# #                 thr = black_thresh / 255.0
# #                 a = (rgb.max(axis=-1) > thr).astype(np.float32)
# #             else:
# #                 raise ValueError("derive_alpha must be 'soft' or 'binary'")

# #         # Composite RGB over WHITE but keep alpha as the mask
# #         a3 = a[..., None]                          # (H,W,1)
# #         rgb_on_white = rgb * a3 + (1.0 - a3) * 1.0 # white background appearance

# #         rgba_out = np.concatenate([rgb_on_white, a3], axis=-1).astype(np.float32)  # (H,W,4)
# #         return rgba_out[None, ...]  # (1,H,W,4)


# #     def center_mask_keep_scale(self,
# #         mask_image: np.ndarray,
# #         output_size=None,          # (W, H); if None, keep original canvas size
# #         pad: int = 0,              # padding inside the canvas (center within the inner area)
# #         threshold: int = 127,      # binarization threshold (0..255)
# #         overflow: str = "crop",    # 'crop' | 'error' | 'shrink'
# #         return_transform: bool = False
# #     ):
# #         """
# #         Center the largest region of a binary mask on a canvas, preserving original scale.

# #         Args:
# #             mask_image: 2D numpy array (grayscale/binary).
# #             output_size: (W, H) of output canvas; if None, use input image size.
# #             pad: pixels of padding from each border; the mask is centered in the inner area.
# #             threshold: threshold to binarize if needed.
# #             overflow:
# #                 - 'crop'  : keep scale and crop parts that would fall outside canvas
# #                 - 'error' : raise if the mask doesn't fit
# #                 - 'shrink': ONLY if needed, uniformly shrink to fit (keeps aspect ratio)
# #             return_transform: if True, returns a dict with offset/scale/bbox info.

# #         Returns:
# #             binary_mask: (H,W) uint8 {0,255}
# #             centered   : (H_out,W_out) uint8 {0,255}
# #             (optional) info: {'scale', 'bbox', 'offset', 'pasted_box', 'source_box'}
# #         """
# #         if mask_image is None or mask_image.ndim != 2:
# #             raise ValueError("Input must be a 2D mask image.")

# #         # Normalize to uint8
# #         img = mask_image
# #         if img.dtype != np.uint8:
# #             if np.nanmax(img) <= 1.0:
# #                 img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
# #             else:
# #                 img = np.clip(img, 0, 255).astype(np.uint8)

# #         # Binarize
# #         _, binary_mask = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY)

# #         # Find largest contour
# #         contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# #         if not contours:
# #             # No foreground—return empty canvas
# #             if output_size is None:
# #                 out_w, out_h = binary_mask.shape[1], binary_mask.shape[0]
# #             else:
# #                 out_w, out_h = int(output_size[0]), int(output_size[1])
# #             centered = np.zeros((out_h, out_w), dtype=np.uint8)
# #             info = {'scale': 1.0, 'bbox': (0,0,0,0), 'offset': (0,0),
# #                     'pasted_box': (0,0,0,0), 'source_box': (0,0,0,0)}
# #             return (binary_mask, centered, info) if return_transform else (binary_mask, centered)

# #         # Bounding box of largest region (original scale)
# #         largest = max(contours, key=cv2.contourArea)
# #         x, y, w, h = cv2.boundingRect(largest)
# #         roi = binary_mask[y:y+h, x:x+w]  # 0/255

# #         # Output canvas
# #         if output_size is None:
# #             out_w, out_h = binary_mask.shape[1], binary_mask.shape[0]
# #         else:
# #             out_w, out_h = int(output_size[0]), int(output_size[1])

# #         pad = max(0, int(pad))
# #         inner_w = max(1, out_w - 2*pad)
# #         inner_h = max(1, out_h - 2*pad)

# #         # Decide if it fits at original scale
# #         fits = (w <= inner_w) and (h <= inner_h)

# #         scale = 1.0
# #         roi_to_paste = roi

# #         if not fits:
# #             if overflow == "error":
# #                 raise ValueError(
# #                     f"Mask bbox ({w}x{h}) does not fit into inner area ({inner_w}x{inner_h})."
# #                 )
# #             elif overflow == "shrink":
# #                 # Shrink uniformly just enough to fit (still "preserve" aspect ratio, but not the exact size)
# #                 scale = min(inner_w / w, inner_h / h)
# #                 new_w = max(1, int(round(w * scale)))
# #                 new_h = max(1, int(round(h * scale)))
# #                 roi_to_paste = cv2.resize(roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
# #                 w, h = new_w, new_h
# #             elif overflow == "crop":
# #                 # Keep original scale; we'll crop at paste time
# #                 pass
# #             else:
# #                 raise ValueError("overflow must be one of: 'crop', 'error', 'shrink'")

# #         # Center within inner area
# #         start_x = pad + (inner_w - w) // 2
# #         start_y = pad + (inner_h - h) // 2

# #         centered = np.zeros((out_h, out_w), dtype=np.uint8)

# #         # Compute destination box and crop if needed (for 'crop' overflow)
# #         dst_x0 = max(0, start_x)
# #         dst_y0 = max(0, start_y)
# #         dst_x1 = min(out_w, start_x + w)
# #         dst_y1 = min(out_h, start_y + h)

# #         # Corresponding source crop
# #         src_x0 = max(0, -start_x)
# #         src_y0 = max(0, -start_y)
# #         src_x1 = src_x0 + max(0, dst_x1 - dst_x0)
# #         src_y1 = src_y0 + max(0, dst_y1 - dst_y0)

# #         # Paste if there is an overlap
# #         if dst_x1 > dst_x0 and dst_y1 > dst_y0:
# #             centered[dst_y0:dst_y1, dst_x0:dst_x1] = roi_to_paste[src_y0:src_y1, src_x0:src_x1]

# #         if return_transform:
# #             info = {
# #                 'scale': scale,           # 1.0 if not shrunk
# #                 'bbox': (x, y, roi.shape[1], roi.shape[0]),  # original bbox size (w,h) before any shrink
# #                 'offset': (start_x, start_y),                # where top-left would be without cropping
# #                 'pasted_box': (dst_x0, dst_y0, dst_x1 - dst_x0, dst_y1 - dst_y0),
# #                 'source_box': (src_x0, src_y0, src_x1 - src_x0, src_y1 - src_y0),
# #             }
# #             return binary_mask, centered, info
# #         else:
# #             return binary_mask, centered




# #     def center_mask_preserve_ratio(self,
# #         mask_image: np.ndarray,
# #         output_size=(256, 256),      # (width, height); if None, use original size
# #         pad: int = 0,                # optional padding (pixels) around the mask on the canvas
# #         allow_upscale: bool = True,  # if False, only shrink; don't enlarge small masks
# #         threshold: int = 127,        # binarization threshold (0..255)
# #         return_transform: bool = False
# #     ):
# #         """
# #         Center the main (largest) region of a binary mask onto a new canvas,
# #         scaling uniformly to fit while preserving aspect ratio.

# #         Args:
# #             mask_image: 2D numpy array (grayscale or binary).
# #             output_size: (W, H). If None, uses original image size.
# #             pad: padding (pixels) to keep around the scaled mask (applied on all sides).
# #             allow_upscale: if False, the mask won't be enlarged beyond its original size.
# #             threshold: threshold for binarization if input isn't already {0,255}.
# #             return_transform: if True, also returns a dict with scale/offset/bbox.

# #         Returns:
# #             binary_mask: (H,W) uint8 in {0,255}: thresholded input, inverted on return
# #                 (foreground 0, background 255) when a foreground region is found.
# #             centered_image: (H_out,W_out) uint8 in {0,255}: mask centered on the new canvas,
# #                 inverted in the same way.
# #             (optional) info: dict with 'scale', 'bbox'=(x,y,w,h), 'offset'=(start_x,start_y)
# #         """
# #         if mask_image is None or mask_image.ndim != 2:
# #             raise ValueError("Input must be a 2D mask image.")

# #         # Normalize dtype/range to uint8
# #         img = mask_image
# #         if img.dtype != np.uint8:
# #             if np.nanmax(img) <= 1.0:
# #                 img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
# #             else:
# #                 img = np.clip(img, 0, 255).astype(np.uint8)

# #         # Binarize (ensure {0,255})
# #         _, binary_mask = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY)

# #         # Find largest contour
# #         contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# #         if not contours:
# #             # No foreground; return empty canvas
# #             if output_size is None:
# #                 output_size = (binary_mask.shape[1], binary_mask.shape[0])
# #             centered = np.zeros((output_size[1], output_size[0]), dtype=np.uint8)
# #             return (binary_mask, centered, {'scale': 0.0, 'bbox': (0,0,0,0), 'offset': (0,0)}) if return_transform else (binary_mask, centered)

# #         largest = max(contours, key=cv2.contourArea)
# #         x, y, w, h = cv2.boundingRect(largest)
# #         roi = binary_mask[y:y+h, x:x+w]  # values are 0 or 255

# #         # Output canvas size
# #         if output_size is None:
# #             out_w, out_h = binary_mask.shape[1], binary_mask.shape[0]
# #         else:
# #             out_w, out_h = int(output_size[0]), int(output_size[1])

# #         # Effective area after padding
# #         pad = max(0, int(pad))
# #         eff_w = max(1, out_w - 2*pad)
# #         eff_h = max(1, out_h - 2*pad)

# #         # Uniform scale to fit while preserving aspect ratio
# #         scale = min(eff_w / max(1, w), eff_h / max(1, h))
# #         if not allow_upscale:
# #             scale = min(1.0, scale)

# #         new_w = max(1, int(round(w * scale)))
# #         new_h = max(1, int(round(h * scale)))

# #         # Resize with nearest to keep binary values intact
# #         roi_resized = cv2.resize(roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

# #         # Create centered canvas
# #         centered = np.zeros((out_h, out_w), dtype=np.uint8)
# #         start_x = (out_w - new_w) // 2
# #         start_y = (out_h - new_h) // 2

# #         # Paste
# #         centered[start_y:start_y+new_h, start_x:start_x+new_w] = roi_resized

# #         if return_transform:
# #             info = {'scale': scale, 'bbox': (x, y, w, h), 'offset': (start_x, start_y)}
# #             return 255 - binary_mask, 255 - centered, info
# #         else:
# #             return 255 - binary_mask, 255 - centered




# #     def center_mask_in_image_from_array(self, mask_image: np.ndarray, output_size=None):
# #         """
# #         Centers the masked region of a binary mask image.

# #         Args:
# #             mask_image (np.ndarray): Input binary mask as a NumPy array (single-channel).
# #             output_size (tuple or None): Desired output size as (width, height).
# #                                         If None, uses original image size.

# #         Returns:
# #             original_mask (np.ndarray): Thresholded binary mask.
# #             centered_image (np.ndarray): New image with mask centered.
# #         """
# #         if mask_image is None or len(mask_image.shape) != 2:
# #             raise ValueError("Input must be a 2D binary mask image (grayscale).")

# #         if mask_image.dtype != np.uint8:
# #             mask_image = (mask_image * 255 if mask_image.max() <= 1.0 else mask_image).astype(np.uint8)

# #         # Ensure it's binary
# #         _, binary_mask = cv2.threshold(mask_image, 127, 255, cv2.THRESH_BINARY)

# #         # Find contours and bounding box
# #         contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# #         if not contours:
# #             raise ValueError("No masked region found in the image.")

# #         # Bounding box of the largest contour
# #         largest_contour = max(contours, key=cv2.contourArea)
# #         x, y, w, h = cv2.boundingRect(largest_contour)

# #         # Crop the region of interest
# #         cropped = binary_mask[y:y+h, x:x+w]

# #         # Define output canvas
# #         if output_size is None:
# #             output_size = binary_mask.shape[::-1]  # (width, height)

# #         centered_image = np.zeros((output_size[1], output_size[0]), dtype=np.uint8)

# #         # Center coordinates
# #         center_x, center_y = output_size[0] // 2, output_size[1] // 2
# #         start_x, start_y = center_x - w // 2, center_y - h // 2

# #         # Place cropped mask into center of canvas
# #         centered_image[start_y:start_y+h, start_x:start_x+w] = cropped

# #         return binary_mask, centered_image


# #     # import numpy as np
# #     # import cv2
# #     # from typing import Optional, Tuple, Literal, Dict, Any

# #     def center_cutout_keep_scale_2(
# #         self,
# #         cutout_rgba: np.ndarray,
# #         output_size: Optional[Tuple[int, int]] = None,  # (W_out, H_out). If None, keep original size
# #         pad: int = 0,
# #         overflow: Literal["crop", "error", "shrink"] = "crop",
# #         alpha_threshold: int = 10,
# #         white_thresh: int = 245,
# #         return_transform: bool = False
# #     ) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
# #         """
# #         Re-center an RGBA cutout on a WHITE canvas, preserving scale.
# #         Also creates a centered RGBA silhouette (white FG, alpha=mask, transparent BG).
# #         Returns (centered_rgba, centered_silhouette_rgba, [info]) with shapes (1,H,W,4).
# #         """
# #         if cutout_rgba is None:
# #             raise ValueError("cutout_rgba is None")

# #         # Accept (H,W,4) or (1,H,W,4); ensure uint8 0..255
# #         img = cutout_rgba[0] if (cutout_rgba.ndim == 4 and cutout_rgba.shape[0] == 1) else cutout_rgba
# #         if img.ndim != 3 or img.shape[2] != 4:
# #             raise ValueError("cutout_rgba must have shape (H,W,4) or (1,H,W,4).")

# #         if img.dtype != np.uint8:
# #             maxv = float(np.nanmax(img))
# #             img = ((np.clip(img, 0.0, 1.0) * 255.0) if maxv <= 1.0 else np.clip(img, 0, 255)).astype(np.uint8)

# #         H, W, _ = img.shape
# #         if output_size is None:
# #             out_w, out_h = W, H
# #         else:
# #             out_w, out_h = int(output_size[0]), int(output_size[1])

# #         # ---- Foreground mask (use alpha if informative; fallback to non-white RGB)
# #         alpha = img[..., 3]
# #         use_alpha = not (np.all(alpha <= alpha_threshold) or np.all(alpha >= 255))
# #         if use_alpha:
# #             hard_mask = (alpha > alpha_threshold).astype(np.uint8) * 1       # for bbox finding
# #         else:
# #             rgb = img[..., :3]
# #             non_white = (rgb < white_thresh).any(axis=2)
# #             hard_mask = (non_white.astype(np.uint8)) * 1

# #         cnts, _ = cv2.findContours(hard_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# #         if not cnts:
# #             centered = np.full((out_h, out_w, 4), 255, dtype=np.uint8)         # white, opaque
# #             sil_centered = np.zeros((out_h, out_w, 4), dtype=np.uint8)         # transparent
# #             centered, sil_centered = centered[None, ...], sil_centered[None, ...]
# #             info = {'scale': 1.0, 'bbox': (0,0,0,0), 'offset': (0,0),
# #                     'pasted_box': (0,0,0,0), 'source_box': (0,0,0,0)}
# #             return (centered, sil_centered, info) if return_transform else (centered, sil_centered)

# #         # ---- Crop ROI for image and for mask/alpha
# #         x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
# #         roi_img  = img[y:y+h, x:x+w, :]                         # (h,w,4)
# #         # hard_mask holds 0/1, so scale it to 0..255 before using it as an alpha channel
# #         roi_alpha_soft = roi_img[..., 3].astype(np.uint8) if use_alpha else hard_mask[y:y+h, x:x+w] * 255  # (h,w), 0..255

# #         # ---- Prepare canvases
# #         centered       = np.full((out_h, out_w, 4), 255, dtype=np.uint8)       # white RGB, alpha=255
# #         sil_centered   = np.zeros((out_h, out_w, 4), dtype=np.uint8)           # transparent

# #         # ---- Fit into inner area with optional shrink
# #         pad = max(0, int(pad))
# #         inner_w, inner_h = max(1, out_w - 2*pad), max(1, out_h - 2*pad)
# #         fits = (w <= inner_w) and (h <= inner_h)
# #         scale = 1.0
# #         roi_img_paste   = roi_img
# #         roi_alpha_paste = roi_alpha_soft

# #         if not fits:
# #             if overflow == "error":
# #                 raise ValueError("Cutout too large to fit canvas.")
# #             elif overflow == "shrink":
# #                 scale = min(inner_w / max(1, w), inner_h / max(1, h))
# #                 new_w = max(1, int(round(w * scale)))
# #                 new_h = max(1, int(round(h * scale)))
# #                 interp_img  = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
# #                 roi_img_paste   = cv2.resize(roi_img, (new_w, new_h), interpolation=interp_img)
# #                 # For alpha/mask -> keep edges crisp
# #                 roi_alpha_paste = cv2.resize(roi_alpha_soft, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
# #                 w, h = new_w, new_h
# #             # overflow == "crop": keep size; paste region will be clipped

# #         # ---- Center positions
# #         start_x = pad + (inner_w - w) // 2
# #         start_y = pad + (inner_h - h) // 2

# #         # Destination & source windows (clip to canvas)
# #         dx0, dy0 = max(0, start_x), max(0, start_y)
# #         dx1, dy1 = min(out_w, start_x + w), min(out_h, start_y + h)
# #         sx0, sy0 = max(0, -start_x), max(0, -start_y)
# #         sx1, sy1 = sx0 + max(0, dx1 - dx0), sy0 + max(0, dy1 - dy0)

# #         if dx1 > dx0 and dy1 > dy0:
# #             src_rgba  = roi_img_paste[sy0:sy1, sx0:sx1, :]           # (ph,pw,4)
# #             src_alpha = roi_alpha_paste[sy0:sy1, sx0:sx1]            # (ph,pw) 0..255

# #             # ---- Paste onto centered (white) with alpha compositing for RGB; alpha stays 255
# #             a = (src_alpha.astype(np.float32) / 255.0)[..., None]    # (ph,pw,1)
# #             dst = centered[dy0:dy1, dx0:dx1, :].astype(np.float32)
# #             dst[..., :3] = a * src_rgba[..., :3].astype(np.float32) + (1.0 - a) * dst[..., :3]
# #             centered[dy0:dy1, dx0:dx1, :3] = dst[..., :3].astype(np.uint8)
# #             centered[dy0:dy1, dx0:dx1, 3]  = 255

# #             # ---- Build silhouette RGBA: white FG, transparent BG; alpha = src_alpha
# #             sil_patch = sil_centered[dy0:dy1, dx0:dx1, :]
# #             sil_patch[..., 0:3] = (src_alpha > 0).astype(np.uint8)[..., None] * 255
# #             sil_patch[..., 3]   = src_alpha
# #             sil_centered[dy0:dy1, dx0:dx1, :] = sil_patch



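# #         # Add batch dimension
# #         centered     = centered[None, ...]       # (1,H,W,4)
# #         sil_centered = sil_centered[None, ...]   # (1,H,W,4)

# #         # Honor return_transform, as the docstring and the empty-mask branch already do
# #         info = {
# #             'scale': scale,
# #             'bbox': (x, y, roi_img.shape[1], roi_img.shape[0]),
# #             'offset': (start_x, start_y),
# #             'pasted_box': (dx0, dy0, dx1 - dx0, dy1 - dy0),
# #             'source_box': (sx0, sy0, sx1 - sx0, sy1 - sy0),
# #         }
# #         return (centered, sil_centered, info) if return_transform else (centered, sil_centered)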




# #     def center_cutout_keep_scale(self,
# #         cutout_rgba: np.ndarray,
# #         output_size: Optional[Tuple[int, int]] = None,  # (W_out, H_out). If None, keep original size
# #         pad: int = 0,
# #         overflow: Literal["crop", "error", "shrink"] = "crop",
# #         alpha_threshold: int = 10,
# #         white_thresh: int = 245,
# #         return_transform: bool = False
# #     ) -> Tuple[np.ndarray, Dict[str, Any]]:
# #         """
# #         Re-center an RGBA cutout on a white canvas, preserving scale.
# #         Returns (1,H,W,4) batch format.
# #         """
# #         # --- [same preprocessing as before] ---
# #         if cutout_rgba is None:
# #             raise ValueError("cutout_rgba is None")
# #         img = cutout_rgba
# #         if img.ndim == 4 and img.shape[0] == 1:
# #             img = img[0]
# #         if img.ndim != 3 or img.shape[2] != 4:
# #             raise ValueError("cutout_rgba must have shape (H,W,4) or (1,H,W,4).")

# #         if img.dtype != np.uint8:
# #             maxv = np.nanmax(img)
# #             if maxv <= 1.0:
# #                 img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)
# #             else:
# #                 img = np.clip(img, 0, 255).astype(np.uint8)

# #         H, W, _ = img.shape
# #         if output_size is None:
# #             out_w, out_h = W, H
# #         else:
# #             out_w, out_h = int(output_size[0]), int(output_size[1])

# #         # Foreground mask detection (alpha or RGB fallback)
# #         alpha = img[..., 3]
# #         use_alpha = not (np.all(alpha <= alpha_threshold) or np.all(alpha >= 255))
# #         if use_alpha:
# #             mask = (alpha > alpha_threshold).astype(np.uint8) * 255
# #         else:
# #             rgb = img[..., :3]
# #             non_white = (rgb < white_thresh).any(axis=2)
# #             mask = (non_white.astype(np.uint8)) * 255

# #         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# #         if not contours:
# #             centered = np.full((out_h, out_w, 4), 255, dtype=np.uint8)
# #             centered = centered[None, ...]   # add batch dim
# #             info = {'scale': 1.0, 'bbox': (0,0,0,0), 'offset': (0,0),
# #                     'pasted_box': (0,0,0,0), 'source_box': (0,0,0,0)}
# #             return (centered, info) if return_transform else (centered,)

# #         # Bounding box of largest contour
# #         largest = max(contours, key=cv2.contourArea)
# #         x, y, w, h = cv2.boundingRect(largest)
# #         roi = img[y:y+h, x:x+w, :]

# #         centered = np.full((out_h, out_w, 4), 255, dtype=np.uint8)

# #         pad = max(0, int(pad))
# #         inner_w, inner_h = max(1, out_w - 2*pad), max(1, out_h - 2*pad)
# #         fits = (w <= inner_w) and (h <= inner_h)
# #         scale, roi_to_paste = 1.0, roi

# #         if not fits:
# #             if overflow == "error":
# #                 raise ValueError("Cutout too large to fit canvas.")
# #             elif overflow == "shrink":
# #                 scale = min(inner_w / max(1, w), inner_h / max(1, h))
# #                 new_w = max(1, int(round(w * scale)))
# #                 new_h = max(1, int(round(h * scale)))
# #                 interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
# #                 roi_to_paste = cv2.resize(roi, (new_w, new_h), interpolation=interp)
# #                 w, h = new_w, new_h

# #         start_x = pad + (inner_w - w) // 2
# #         start_y = pad + (inner_h - h) // 2

# #         dst_x0, dst_y0 = max(0, start_x), max(0, start_y)
# #         dst_x1, dst_y1 = min(out_w, start_x + w), min(out_h, start_y + h)
# #         src_x0, src_y0 = max(0, -start_x), max(0, -start_y)
# #         src_x1, src_y1 = src_x0 + max(0, dst_x1 - dst_x0), src_y0 + max(0, dst_y1 - dst_y0)

# #         if dst_x1 > dst_x0 and dst_y1 > dst_y0:
# #             src = roi_to_paste[src_y0:src_y1, src_x0:src_x1, :]
# #             dst = centered[dst_y0:dst_y1, dst_x0:dst_x1, :]
# #             a = (src[..., 3:4].astype(np.float32)) / 255.0
# #             dst[..., :3] = (a * src[..., :3].astype(np.float32) +
# #                             (1.0 - a) * dst[..., :3].astype(np.float32)).astype(np.uint8)
# #             centered[dst_y0:dst_y1, dst_x0:dst_x1, :3] = dst[..., :3]

# #         info = {
# #             'scale': scale,
# #             'bbox': (x, y, roi.shape[1], roi.shape[0]),
# #             'offset': (start_x, start_y),
# #             'pasted_box': (dst_x0, dst_y0, dst_x1 - dst_x0, dst_y1 - dst_y0),
# #             'source_box': (src_x0, src_y0, src_x1 - src_x0, src_y1 - src_y0),
# #         }

# #         # ✅ Add batch dimension at the end
# #         centered = centered[None, ...]   # shape (1,H,W,4)
# #         return (centered, info) if return_transform else (centered,)





# #     # def get_camera_pose(self, distance=3, elevation=40, azimuth=40,
# #     #                     device='cpu', eye_override=None):
# #     #     """
# #     #     Computes the camera rotation and translation matrices using
# #     #     spherical coordinates or a fixed eye location.

# #     #     Args:
# #     #         distance (float): Distance from camera to the origin (default: 3)
# #     #         elevation (float): Elevation angle in degrees (default: 40)
# #     #         azimuth (float): Azimuth angle in degrees (default: 40)
# #     #         device (str): PyTorch device ('cpu' or 'cuda') for tensor operations
# #     #         eye_override (list/tuple): Optional fixed camera position [x, y, z]
# #     #                                 If provided, overrides spherical coordinates

# #     #     Returns:
# #     #         tuple: (R, T) where R is rotation matrix and T is translation vector
# #     #     """

# #     #     # Check if a specific camera position is provided
# #     #     if eye_override is not None:
# #     #         # Convert eye position to PyTorch tensor and add batch dimension
# #     #         eye = torch.tensor(eye_override, device=device).unsqueeze(0)  # Shape: (1, 3)

# #     #         # Compute view transform using fixed eye position
# #     #         # Camera looks at origin (0,0,0) by default
# #     #         R, T = look_at_view_transform(eye=eye, device=device)
# #     #     else:
# #     #         # Use spherical coordinates to position camera
# #     #         # dist: radial distance from origin
# #     #         # elev: elevation angle (up/down rotation)
# #     #         # azim: azimuth angle (left/right rotation)
# #     #         R, T = look_at_view_transform(dist=distance, elev=elevation,
# #     #                                       azim=azimuth, degrees=True,
# #     #                                       device=device)

# #     #     # # Debug output - print the computed matrices
# #     #     # print(f"Rotation R:\n{R}\n")      # 3x3 rotation matrix
# #     #     # print(f"Translation T:\n{T}\n")   # 3D translation vector

# #     #     # Return the camera extrinsic parameters
# #     #     return R, T

# #     def get_camera_pose(self, distance=3, elevation=40, azimuth=40,
# #                         device="cpu", eye_override=None):
# #         """
# #         Returns (R, T) with shapes (1,3,3) and (1,3).
# #         - If eye_override is provided (3,) or (1,3), gradients flow through it.
# #         - Otherwise uses spherical (dist, elev, azim).
# #         """
# #         dev = device  # str or torch.device both work with .to(...)

# #         if eye_override is not None:
# #             # ✅ Preserve autograd: do NOT use torch.tensor(...) or float(...)
# #             if isinstance(eye_override, torch.Tensor):
# #                 eye = eye_override.to(device=dev, dtype=torch.float32).reshape(1, 3)
# #             else:
# #                 eye = torch.as_tensor(eye_override, dtype=torch.float32, device=dev).reshape(1, 3)

# #             # Differentiable pose from camera center (PyTorch3D row-vector convention):
# #             # X_cam = X_world @ R + T, so the camera center maps to the origin when T = -R^T @ eye
# #             R = look_at_rotation(eye, device=dev)                               # (1,3,3)
# #             T = -torch.bmm(R.transpose(1, 2), eye[..., None])[:, :, 0]          # (1,3)
# #             return R, T

# #         # Spherical path (angles → pose). 'degrees=True' only matters in this branch.
# #         R, T = look_at_view_transform(dist=distance, elev=elevation,
# #                                       azim=azimuth, degrees=True, device=dev)
# #         return R, T


# #     def render_mesh(self, mesh, silhouette_renderer, phong_renderer, R, T):
# #         """
# #         Render a mesh using both silhouette and Phong renderers.
# #         Returns both render outputs.
# #         """
# #         silhouette = silhouette_renderer(meshes_world=mesh, R=R, T=T)
# #         image_ref = phong_renderer(meshes_world=mesh, R=R, T=T)
# #         return silhouette, image_ref



# #     def get_camera_intrinsics(self, scale_factor = 1):
# #         """
# #         Returns the camera intrinsics matrix K.
# #         """
# #         # if cuda_available:
# #         # K = np.array([
# #         #     [672.*scale_factor,                  0.,   358.*scale_factor],
# #         #     [               0.,   770.*scale_factor,   232.*scale_factor],
# #         #     [               0.,                  0.,                  1.]
# #         # ], dtype="double")

# #         # Principal point is fixed at (128, 128); only the focal lengths scale with scale_factor
# #         K = np.array([
# #             [1400.*scale_factor,                  0.,   128],
# #             [               0.,   1800.*scale_factor,   128],
# #             [               0.,                  0.,                  1.]
# #         ], dtype="double")

# #         return K


# #     def create_perspective_camera(self, K, image_size, device):
# #         """
# #         Create a PerspectiveCameras object from intrinsics.
# #         """
# #         fcl_screen = ((K[0, 0], K[1, 1]),)  # fx, fy
# #         prp_screen = ((K[0, 2], K[1, 2]),)  # cx, cy
# #         return PerspectiveCameras(
# #             focal_length=fcl_screen,
# #             principal_point=prp_screen,
# #             in_ndc=False,
# #             image_size=image_size,
# #             device=device
# #         )


# #     def create_silhouette_renderer(self, cameras, image_size, device):
# #         """
# #         Create a mesh renderer that outputs silhouettes.
# #         """
# #         blend_params = BlendParams(sigma=1e-5, gamma=1e-4)

# #         raster_settings = RasterizationSettings(
# #             image_size=image_size,
# #             blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
# #             faces_per_pixel=100,
# #             max_faces_per_bin=200000
# #         )

# #         renderer = MeshRenderer(
# #             rasterizer=MeshRasterizer(
# #                 cameras=cameras,
# #                 raster_settings=raster_settings
# #             ),
# #             shader=SoftSilhouetteShader(blend_params=blend_params)
# #         )
# #         return renderer


# #     def create_phong_renderer(self, cameras, image_size, device):
# #         """
# #         Create a mesh renderer that outputs shaded RGB images using Phong lighting.
# #         """
# #         raster_settings = RasterizationSettings(
# #             image_size=image_size,
# #             blur_radius=0.0,
# #             faces_per_pixel=10,
# #             max_faces_per_bin=200000
# #         )

# #         lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))


# #         # Tune sigma/gamma for edge softness; background_color only affects RGB where alpha==0
# #         blend_params = BlendParams(
# #             sigma=1e-4,
# #             gamma=1e-4,
# #             background_color=(1.0, 1.0, 1.0)  # choose any; alpha controls transparency
# #         )


# #         renderer = MeshRenderer(
# #             rasterizer=MeshRasterizer(
# #                 cameras=cameras,
# #                 raster_settings=raster_settings
# #             ),
# #             shader=SoftPhongShader(device=device, cameras=cameras, lights=lights, blend_params=blend_params)
# #             # shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
# #         )
# #         return renderer

# #     def get_cad_model_files(self):

# #         # Create the assets directory if it doesn't already exist.
# #         os.makedirs(self.assets_dir, exist_ok=True)

# #         # Locations of the CAD model and texture file for the object (Dropbox shared links)
# #         cad_url = "https://www.dropbox.com/scl/fi/y55irvvd58yw4ucu3igup/sse_only_jun9_scaled100.ply?rlkey=6io3u20qpupumdfp18i1ssmr9&dl=1"
# #         tex_url = "https://www.dropbox.com/scl/fi/rj1thjtu16o2344vh5mjn/sse_texture.png?rlkey=503sm6dwzv1f8gyy1jgg6tnxl&dl=1"

# #         # Simplified CAD model (no texture, yet)
# #         cad_small_url = "https://www.dropbox.com/scl/fi/kg5zutt2deamc4ksi17f7/final_model.obj?rlkey=3uenujpffgqk1t0edx5plj5oe&dl=1"
# #         cad_small_path = os.path.join(self.assets_dir, "final_model.obj")
# #         print("📥 Downloading simplified CAD model...")
# #         self._download_file(cad_small_url, cad_small_path)

# #         # Detailed model
# #         cad_detailed_url = "https://www.dropbox.com/scl/fi/ubit9vw23fphcwpu0w88v/obj_000001.ply?rlkey=hx9f12a9ka8160mhxq8vnz7yk&dl=1"
# #         cad_detailed_path = os.path.join(self.assets_dir, "obj_000001.ply")
# #         print("📥 Downloading detailed CAD model...")
# #         self._download_file(cad_detailed_url, cad_detailed_path)

# #         # Test image
# #         print("📥 Downloading a test image...")
# #         image_detailed_path = os.path.join(self.assets_dir, "00002.png")
# #         self._download_file("https://www.dropbox.com/scl/fi/b4bjgg4w6n6c6ufx1krhg/00002.png?rlkey=2h6yikdk1liz4b0ojpw4dcz50&dl=1", image_detailed_path)

# #         # Second test image
# #         print("📥 Downloading a second test image...")
# #         image_detailed_path = os.path.join(self.assets_dir, "00016.png")
# #         self._download_file("https://www.dropbox.com/scl/fi/0nf86esq8723iirfz35uj/00016.png?rlkey=9yac94rwxj7uppz4oorgntpqz&dl=1", image_detailed_path)


# #         # Download the files into the assets directory
# #         cad_path = os.path.join(self.assets_dir, "sse_only_jun9_scaled100.ply")
# #         tex_path = os.path.join(self.assets_dir, "sse_texture.png")
# #         print("📥 Downloading CAD model and texture...")
# #         self._download_file(cad_url, cad_path)
# #         self._download_file(tex_url, tex_path)


# #         # SSE model in .obj format
# #         print("📥 Downloading SSE model in .obj format...")
# #         sse_obj_path = os.path.join(self.assets_dir, "sse_only.obj")
# #         self._download_file("https://www.dropbox.com/scl/fi/nwfswb9yqaox5guzxedyp/sse_only.obj?rlkey=ub48myfosw690btzg7gfoxert&dl=1", sse_obj_path)

# #         # SSE model .mtl material file
# #         print("📥 Downloading SSE model .mtl material file...")
# #         sse_mtl_path = os.path.join(self.assets_dir, "sse_only.mtl")
# #         self._download_file("https://www.dropbox.com/scl/fi/3mz2zgppszxj6cvv2mgqm/sse_only.mtl?rlkey=ckrw1ko30dk9d8unvsl659vmr&dl=1", sse_mtl_path)



# #         # New SSE model with real texture, in .obj format
# #         print("📥 Downloading new SSE model with real texture in .obj format...")
# #         sse_real_obj_path = os.path.join(self.assets_dir, "sse_only_real_texture.obj")
# #         self._download_file("https://www.dropbox.com/scl/fi/8oo2ofzz6gtxn3ve2ue6r/sse_only_real_texture.obj?rlkey=a13utsbxbzib2c61j5gcasblt&dl=1", sse_real_obj_path)

# #         # .mtl material file for the new SSE model with real texture
# #         print("📥 Downloading .mtl material file for the new SSE model with real texture...")
# #         sse_real_mtl_path = os.path.join(self.assets_dir, "sse_only_real_texture.mtl")
# #         self._download_file("https://www.dropbox.com/scl/fi/e9ld55c4cyq4e6bguwti6/sse_only_real_texture.mtl?rlkey=0lvgesvk9yoi245gndcw0arqm&dl=1", sse_real_mtl_path)



# #         print(f"✅ CAD and texture files placed in directory: {self.assets_dir}")


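# #     def _download_file(self, url, dest_path):
# #         """Stream the file at `url` to `dest_path`, raising on HTTP errors."""
# #         with requests.get(url, stream=True) as r:
# #             r.raise_for_status()          # fail loudly instead of saving an error page
# #             r.raw.decode_content = True   # undo gzip/deflate content encoding, if any
# #             with open(dest_path, "wb") as f:
# #                 shutil.copyfileobj(r.raw, f)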

# #     def get_device(self):
# #         """Returns the appropriate torch.device and sets CUDA device if available."""
# #         if torch.cuda.is_available():
# #             torch.cuda.set_device(0)
# #             device = torch.device("cuda:0")
# #             print(f"[INFO]: Using CUDA device: {device}.")
# #         else:
# #             device = torch.device("cpu")
# #             print("[INFO]: Using CPU device.")

# #         return device


# #     def load_cad_mesh(self, cad_path: str, device: torch.device = torch.device("cuda:0")) -> Meshes:
# #         """
# #         Load a CAD mesh (.obj or .ply) and return it as a PyTorch3D Meshes object with vertex colors.

# #         Args:
# #             cad_path (str): Full path to the CAD mesh file.
# #             device (torch.device): The device to load the mesh onto.

# #         Returns:
# #             Meshes: A PyTorch3D Meshes object with yellow vertex color.
# #         """
# #         cad_path = Path(cad_path)
# #         ext = cad_path.suffix.lower()

# #         if ext == ".obj":
# #             verts, faces, _ = load_obj(str(cad_path), load_textures=False)
# #             verts = verts.to(device)
# #             faces = faces.verts_idx.to(device)
# #         elif ext == ".ply":
# #             verts, faces = load_ply(str(cad_path))
# #             verts = (verts / 1000.0).to(device)  # Convert mm to meters
# #             faces = faces.to(device)
# #         else:
# #             raise ValueError(f"Unsupported CAD file format: {ext}")

# #         # Set vertex color to yellow
# #         verts_rgb = torch.tensor([1.0, 1.0, 0.0], device=device).repeat(verts.shape[0], 1)  # (V, 3)
# #         textures = TexturesVertex(verts_features=verts_rgb[None])  # (1, V, 3)

# #         # Create and return the mesh
# #         mesh = Meshes(
# #             verts=[verts],
# #             faces=[faces],
# #             textures=textures
# #         )
# #         return mesh



# #     def load_cad_mesh_with_texture(self, cad_path: str, device: torch.device = torch.device("cuda:0")) -> Meshes:
# #         """
# #         Load a textured CAD mesh (.obj) and return it as a PyTorch3D Meshes object.

# #         Args:
# #             cad_path (str): Full path to the .obj file (its .mtl and texture image are expected next to it).
# #             device (torch.device): The device to load the mesh onto.

# #         Returns:
# #             Meshes: A PyTorch3D Meshes object carrying the texture referenced by the .obj file.
# #         """
# #         cad_path = Path(cad_path)

# #         # Load obj file
# #         mesh = load_objs_as_meshes([str(cad_path)], device=device)


# #         return mesh



# #     def compute_camera_pose(self, camera_position, device):
# #         """
# #         Compute camera rotation R and translation T from a given camera position.
# #         """
# #         R = look_at_rotation(camera_position[None, :], device=device)
# #         T = -torch.bmm(R.transpose(1, 2), camera_position[None, :, None])[:, :, 0]
# #         return R, T


# #     def create_background_from_mask(self, mask):
# #         """
# #         Convert a single-channel silhouette mask into an RGB background.
# #         """
# #         return np.stack([mask] * 3, axis=-1)  # (H, W, 3)


# #     def alpha_blend(self, rgb_rendered, alpha_rendered, background_img):
# #         """
# #         Alpha-blend the rendered RGB image over the background.
# #         """
# #         return rgb_rendered * alpha_rendered + background_img * (1.0 - alpha_rendered)



# #     def render_overlay(self, model, phong_renderer, mask_background):
# #         """
# #         Renders an overlay image of the current model pose and background.

# #         Args:
# #             model: The pose model with .camera_position, .device, and .meshes.
# #             phong_renderer: A renderer used for generating the RGB + alpha image.
# #             mask_background: (H, W, 3) background image to blend with.

# #         Returns:
# #             image_uint8: Blended uint8 image with the rendered object overlaid on the background.
# #         """

# #         R, T = self.compute_camera_pose(model.camera_position, model.device)


# #         # print("R, T before roll")
# #         # print(R)
# #         # print(T)

# #         # print("self.roll_deg = ", model.roll_deg)

# #         # Add roll (try mode="camera" first; if no visible spin, try mode="world")
# #         R, T = self.add_camera_roll_to_RT(R, T, roll_deg=model.roll_deg, device=model.device, mode="world")

# #         # print("R, T after roll")
# #         # print(R)
# #         # print(T)



# #         rendered = phong_renderer(
# #             meshes_world=model.meshes.clone(), R=R, T=T
# #         )[0]  # (H, W, 4)

# #         rgb_rendered = rendered[..., :3].detach().cpu().numpy()
# #         alpha_rendered = rendered[..., 3].detach().cpu().numpy()[..., None]

# #         overlay = self.alpha_blend(rgb_rendered, alpha_rendered, mask_background)
# #         image_uint8 = img_as_ubyte(overlay)

# #         return image_uint8


# #     def display_cad_trimesh(self, mesh_path):
# #         # Load a mesh from file (format is inferred from extension)
# #         mesh = trimesh.load(mesh_path)

# #         # Print basic info
# #         print("Number of vertices:", len(mesh.vertices))
# #         print("Number of faces:", len(mesh.faces))


# #         # Visualize the mesh
# #         mesh.show()




# #     def optimize_camera_pose_best_loss(self,
# #         model,
# #         optimizer,
# #         mask_background,
# #         phong_renderer,
# #         writer,
# #         num_iter=100,
# #         log_every=10,
# #         patience=10,
# #         tolerance=10.0
# #     ):
# #         """
# #         Runs pose optimization and saves rendered frames to a GIF.

# #         Args:
# #             model: A torch.nn.Module representing the pose model.
# #             optimizer: Optimizer for model parameters.
# #             mask_background: (H, W, 3) RGB numpy array of the background.
# #             phong_renderer: Renderer used for rendering the mesh.
# #             writer: An imageio writer to save rendered frames.
# #             num_iter: Maximum number of iterations.
# #             log_every: Interval to render and save frames.
# #             patience: Early stopping patience.
# #             tolerance: Unused by the active early-stopping rule (only by the commented-out variant below).
# #         Returns:
# #             Tuple (R, T) of the camera rotation and translation recorded at the best loss.
# #         """
# #         best_loss = float("inf")
# #         wait = 0
# #         final_R, final_T = None, None

# #         loop = tqdm(range(num_iter))
# #         for i in loop:
# #             torch.cuda.empty_cache()
# #             optimizer.zero_grad()
# #             loss, im, _, _ = model()
# #             loss.backward()
# #             optimizer.step()

# #             current_loss = loss.item()
# #             loop.set_description(f"Optimizing (loss {current_loss:.4f})")

# #             # Early stopping with best pose tracking
# #             if current_loss < best_loss:
# #                 best_loss = current_loss
# #                 wait = 0
# #                 final_R, final_T = self.compute_camera_pose(model.camera_position, model.device)
# #             else:
# #                 wait += 1
# #                 if wait >= patience:
# #                     print(f"[INFO] Early stopping at step {i}: loss plateaued.")
# #                     break


# #             # if abs(best_loss - current_loss) < tolerance:
# #             #     wait += 1
# #             #     if wait >= patience:
# #             #         print(f"[INFO] Early stopping at step {i}: loss plateaued.")
# #             #         break
# #             # else:
# #             #     best_loss = current_loss
# #             #     wait = 0

# #             #     # Save current best pose when new best loss is found
# #             #     final_R, final_T = self.compute_camera_pose(model.camera_position, model.device)

# #             # Periodically render and save for visualization
# #             if i % log_every == 0:
# #                 image_uint8 = self.render_overlay(model, phong_renderer, mask_background)
# #                 writer.append_data(image_uint8)

# #                 print(f"Image from model: {im.shape}")


# #         writer.close()

# #         print(f"[RESULT] Best loss: {best_loss:.6f}")
# #         return final_R, final_T





# #     def optimize_camera_pose(self,
# #         model,
# #         optimizer,
# #         mask_background,
# #         phong_renderer,
# #         writer,
# #         num_iter=100,
# #         log_every=10,
# #         patience=10,
# #         tolerance=10.0
# #     ):
# #         """
# #         Runs pose optimization and saves rendered frames to a GIF.

# #         Args:
# #             model: A torch.nn.Module representing the pose model.
# #             optimizer: Optimizer for model parameters.
# #             mask_background: (H, W, 3) RGB numpy array of the background.
# #             phong_renderer: Renderer used for rendering the mesh.
# #             writer: An imageio writer to save rendered frames.
# #             num_iter: Maximum number of iterations.
# #             log_every: Interval to render and save frames.
# #             patience: Early stopping patience.
# #             tolerance: Minimum loss change to reset patience.
# #         Returns:
# #             Tuple (R, T) of the final camera rotation and translation.
# #         """
# #         best_loss = float("inf")
# #         wait = 0
# #         final_R, final_T = None, None

# #         loop = tqdm(range(num_iter))
# #         for i in loop:
# #             torch.cuda.empty_cache()
# #             optimizer.zero_grad()
# #             loss, im, _, _= model()
# #             loss.backward()
# #             optimizer.step()

# #             current_loss = loss.item()
# #             loop.set_description(f"Optimizing (loss {current_loss:.4f})")
# #             # print(f"Step {i}, Loss = {current_loss}")

# #             # Early stopping
# #             if abs(best_loss - current_loss) < tolerance:
# #                 wait += 1
# #                 if wait >= patience:
# #                     print(f"[INFO] Early stopping at step {i}: loss plateaued.")
# #                     break
# #             else:
# #                 best_loss = current_loss
# #                 wait = 0

# #             # Periodically render and save
# #             if i % log_every == 0:


# #                 R, T = self.compute_camera_pose(model.camera_position, model.device)

# #                 # Create an overlay image of the rendered model in its current
# #                 # pose and the reference silhouette.
# #                 image_uint8 = self.render_overlay(model, phong_renderer, mask_background)

# #                 print(f"Image from model: {im.shape}")
# #                 writer.append_data(image_uint8)

# #                 final_R, final_T = R, T

# #                 # Consider returning the pose with the best loss as the final one (not the one after the plateau is reached)

# #         writer.close()

# #         return final_R, final_T


## New utilities using namespaces and stateless classes
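
The cell below groups pure helpers into classes used as namespaces (every method is a `@staticmethod`) and re-exports them through a thin facade so that older call sites keep working. The following is a minimal, dependency-free sketch of that pattern, not the notebook's actual code: `CamSketch` and `FacadeSketch` are hypothetical names, and the spherical-angle formulas assume the Y-up convention noted in the `Cam` docstring (azimuth 0 along +Z, positive azimuth toward +X).

```python
import math


class CamSketch:
    """Hypothetical namespace: pure helpers only, no instance state."""

    @staticmethod
    def center_from_angles(dist, elev_deg, azim_deg):
        # Assumed convention: Y-up, azim=0 points along +Z, +azim turns toward +X.
        e, a = math.radians(elev_deg), math.radians(azim_deg)
        return (dist * math.cos(e) * math.sin(a),
                dist * math.sin(e),
                dist * math.cos(e) * math.cos(a))

    @staticmethod
    def angles_from_center(x, y, z):
        # Inverse of center_from_angles; angles are returned in degrees.
        dist = math.sqrt(x * x + y * y + z * z)
        elev = math.degrees(math.atan2(y, math.hypot(x, z)))
        azim = math.degrees(math.atan2(x, z))
        return dist, elev, azim


class FacadeSketch:
    """Hypothetical facade: holds path state, re-exports namespace helpers."""

    def __init__(self, local_path):
        self.local_path = local_path

    # Re-export the helpers as static attributes so that both
    # instance.method(...) and Class.method(...) call the same function.
    center_from_angles = staticmethod(CamSketch.center_from_angles)
    angles_from_center = staticmethod(CamSketch.angles_from_center)


util = FacadeSketch("/tmp/demo")
center = util.center_from_angles(3.0, 40.0, 40.0)            # call via an instance
dist, elev, azim = FacadeSketch.angles_from_center(*center)  # call via the class
print(round(dist, 6), round(elev, 6), round(azim, 6))
```

Because the facade stores no per-call state, each helper can be tested in isolation through its namespace while existing code keeps calling it through the facade object.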

In [None]:
# ========= Namespaces (pure helpers live here) =================================

class Cam:
    """Camera math (stateless)."""
    @staticmethod
    def get_camera_position(distance, elevation, azimuth, *, degrees=True, device="cpu"):
        from pytorch3d.renderer import camera_position_from_spherical_angles
        return camera_position_from_spherical_angles(distance=distance,
                                                     elevation=elevation,
                                                     azimuth=azimuth,
                                                     degrees=degrees).to(device)

    @staticmethod
    def camera_center_to_dist_elev_azim(C: torch.Tensor):
        """C (...,3) -> (dist,elev,azim) (PyTorch3D conv: Y-up; azim=0 -> +Z; +azim toward +X)."""
        squeeze = False
        if C.ndim == 1:
            C = C[None, :]
            squeeze = True
        x, y, z = C[..., 0], C[..., 1], C[..., 2]
        dist = torch.linalg.norm(C, dim=-1)
        rho  = torch.sqrt(torch.clamp(x*x + z*z, min=1e-12))
        elev = torch.rad2deg(torch.atan2(y, rho))
        azim = torch.rad2deg(torch.atan2(x, z))
        if squeeze:
            dist, elev, azim = dist[0], elev[0], azim[0]
        return dist, elev, azim

    @staticmethod
    def get_camera_pose(distance=3, elevation=40, azimuth=40, *, device="cpu", eye_override=None):
        """
        Returns (R,T) with shapes (1,3,3),(1,3). If eye_override is provided (3,) or (1,3),
        autograd paths are preserved.
        """
        from pytorch3d.renderer import look_at_view_transform
        from pytorch3d.renderer.cameras import look_at_rotation

        dev = device
        if eye_override is not None:
            if isinstance(eye_override, torch.Tensor):
                eye = eye_override.to(device=dev, dtype=torch.float32).reshape(1, 3)
            else:
                eye = torch.as_tensor(eye_override, dtype=torch.float32, device=dev).reshape(1, 3)
            R = look_at_rotation(eye, device=dev)
            T = -torch.bmm(R.transpose(1, 2), eye[..., None])[:, :, 0]
            return R, T

        R, T = look_at_view_transform(dist=distance, elev=elevation, azim=azimuth,
                                      degrees=True, device=dev)
        return R, T

    @staticmethod
    def add_camera_roll_to_RT(R, T, roll_deg, *, device=None, mode="camera"):
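        """
        Compose a Z-axis roll into (R,T), keeping the same camera center C.
        Grad-safe: roll_deg can be a Tensor/Parameter.
        mode: "camera" -> R' = Rz @ R ; "world" -> R' = R @ Rz
        Note: PyTorch3D uses row vectors (X_cam = X_world @ R + T), so R @ Rz
        (mode="world") applies Rz in the camera frame, i.e. an in-plane roll about
        the optical axis, while Rz @ R (mode="camera") rotates about the world Z axis.
        """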
        if not torch.is_tensor(R): R = torch.as_tensor(R)
        if not torch.is_tensor(T): T = torch.as_tensor(T)
        dev   = device or R.device
        dtype = torch.float32
        R = R.to(dev, dtype)
        T = T.to(dev, dtype)

        unbatched = (R.ndim == 2)
        if unbatched:
            R = R[None, ...]
            T = T[None, ...]

        # C from T = -R^T C  =>  C = -R T
        C = -torch.matmul(R, T[..., None]).squeeze(-1)

        theta = torch.as_tensor(roll_deg, dtype=dtype, device=dev).reshape(1)  # keep grad
        c, s = torch.cos(torch.deg2rad(theta)), torch.sin(torch.deg2rad(theta))
        z = torch.zeros_like(c); o = torch.ones_like(c)
        Rz = torch.stack([
            torch.stack([ c, -s, z], dim=-1),
            torch.stack([ s,  c, z], dim=-1),
            torch.stack([ z,  z,  o], dim=-1),
        ], dim=1)  # (1,3,3)

        R_new = torch.matmul(Rz, R) if mode == "camera" else torch.matmul(R, Rz)
        T_new = -torch.matmul(R_new.transpose(1, 2), C[..., None]).squeeze(-1)

        if unbatched:
            R_new, T_new = R_new[0], T_new[0]
        return R_new, T_new

    @staticmethod
    def compute_camera_pose_from_center(C: torch.Tensor, *, device):
        """C (3,) -> (R,T)."""
        from pytorch3d.renderer.cameras import look_at_rotation
        R = look_at_rotation(C[None, :], device=device)
        T = -torch.bmm(R.transpose(1, 2), C[None, :, None])[:, :, 0]
        return R, T


class Img:
    """Cutout/mask utilities (stateless)."""
    @staticmethod
    def read_rgb_cutout_black_bg(path: str) -> np.ndarray:
        img = Image.open(path)
        img = ImageOps.exif_transpose(img).convert("RGBA")
        rgba = np.asarray(img).astype(np.float32)
        rgb  = rgba[..., :3]
        a    = rgba[..., 3:4] / 255.0
        rgb_black = rgb * a
        return np.clip(rgb_black, 0, 255).astype(np.uint8)

    @staticmethod
    def center_cutout_rgb_uint8(img: np.ndarray, black_thresh: int = 5) -> np.ndarray:
        if img.ndim != 3 or img.shape[2] != 3 or img.dtype != np.uint8:
            raise ValueError("Expected (H,W,3) uint8.")
        H, W, _ = img.shape
        mask = (img.max(axis=2) > black_thresh)
        if not mask.any():
            return img.copy()
        ys, xs = np.where(mask)
        y0, y1 = int(ys.min()), int(ys.max()) + 1
        x0, x1 = int(xs.min()), int(xs.max()) + 1
        roi = img[y0:y1, x0:x1, :]
        h, w = roi.shape[:2]
        sy, sx = (H - h) // 2, (W - w) // 2
        out = np.zeros_like(img)
        out[sy:sy+h, sx:sx+w] = roi
        return out

    @staticmethod
    def crop_center_to_size_uint8(img: np.ndarray, out_size: tuple[int, int], fill_color=(0,0,0)) -> np.ndarray:
        if img.ndim != 3 or img.shape[2] != 3 or img.dtype != np.uint8:
            raise ValueError("Expected (H,W,3) uint8.")
        # out_size is (width, height); the image is cropped or padded with fill_color around its center.
        W_out, H_out = map(int, out_size)
        H, W = img.shape[:2]
        left = (W - W_out) // 2; top = (H - H_out) // 2
        right, bottom = left + W_out, top + H_out
        src_x0, src_y0 = max(0, left), max(0, top)
        src_x1, src_y1 = min(W, right), min(H, bottom)
        dst_x0, dst_y0 = max(0, -left), max(0, -top)
        dst_x1, dst_y1 = dst_x0 + (src_x1 - src_x0), dst_y0 + (src_y1 - src_y0)
        out = np.empty((H_out, W_out, 3), dtype=np.uint8); out[...] = np.asarray(fill_color, np.uint8)
        if src_x1 > src_x0 and src_y1 > src_y0:
            out[dst_y0:dst_y1, dst_x0:dst_x1] = img[src_y0:src_y1, src_x0:src_x1]
        return out

    @staticmethod
    def add_alpha_from_black_bg_uint8(img_rgb: np.ndarray, *, black_thresh: int = 5,
                                      mode: str = "binary", close_kernel: int = 3) -> np.ndarray:
        if img_rgb.ndim != 3 or img_rgb.shape[2] != 3 or img_rgb.dtype != np.uint8:
            raise ValueError("Expected (H,W,3) uint8.")
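        mask = (img_rgb.max(axis=2) > black_thresh).astype(np.uint8) * 255
        # Fill interior holes: flood the background from the top-left corner
        # (assumed to be background), then OR the unreached pixels back into the mask.
        im_flood = mask.copy()
        h, w = mask.shape
        ff_mask = np.zeros((h+2, w+2), np.uint8)
        cv2.floodFill(im_flood, ff_mask, (0,0), 255)
        holes = cv2.bitwise_not(im_flood)
        mask_filled = cv2.bitwise_or(mask, holes)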
        if close_kernel and close_kernel > 1:
            k = np.ones((close_kernel, close_kernel), np.uint8)
            mask_filled = cv2.morphologyEx(mask_filled, cv2.MORPH_CLOSE, k)
        if mode == "binary":
            alpha = mask_filled
        elif mode == "soft":
            alpha = np.maximum(img_rgb.max(axis=2), mask_filled).astype(np.uint8)
        else:
            raise ValueError("mode must be 'binary' or 'soft'")
        return np.concatenate([img_rgb, alpha[..., None]], axis=2)

    @staticmethod
    def to_float_batched_rgba_white_bg_preserve_alpha(img: np.ndarray, *,
                                                      derive_alpha: str = "soft",
                                                      black_thresh: int = 5) -> np.ndarray:
        arr = np.asarray(img)
        if arr.ndim == 4 and arr.shape[0] == 1:
            arr = arr[0]
        if arr.ndim != 3 or arr.shape[-1] not in (3,4):
            raise ValueError("Expected (H,W,3/4) or (1,H,W,3/4).")
        arr = arr.astype(np.float32, copy=False)
        if np.nanmax(arr) > 1.0 + 1e-6:
            arr = np.clip(arr, 0, 255) / 255.0
        if arr.shape[-1] == 4:
            rgb, a = arr[..., :3], np.clip(arr[..., 3], 0, 1)
        else:
            rgb = arr
            if derive_alpha == "soft":
                a = np.clip(rgb.max(axis=-1), 0, 1)
            elif derive_alpha == "binary":
                a = (rgb.max(axis=-1) > (black_thresh/255.0)).astype(np.float32)
            else:
                raise ValueError("derive_alpha must be 'soft' or 'binary'")
        a3 = a[..., None]
        rgb_on_white = rgb * a3 + (1.0 - a3) * 1.0
        return np.concatenate([rgb_on_white, a3], axis=-1)[None, ...].astype(np.float32)

    @staticmethod
    def center_cutout_keep_scale_2(cutout_rgba: np.ndarray,
                                   output_size: Optional[Tuple[int,int]] = None,
                                   pad: int = 0,
                                   overflow: Literal["crop","error","shrink"] = "crop",
                                   alpha_threshold: int = 10,
                                   white_thresh: int = 245,
                                   return_transform: bool = False):
        """Re-center RGBA on white; ALSO produce centered silhouette RGBA. Returns (1,H,W,4) each."""
        if cutout_rgba.ndim == 4 and cutout_rgba.shape[0] == 1:
            img = cutout_rgba[0]
        else:
            img = cutout_rgba
        if img.ndim != 3 or img.shape[2] != 4:
            raise ValueError("Expected (H,W,4) or (1,H,W,4).")
        if img.dtype != np.uint8:
            img = (np.clip(img, 0, 1)*255.0 if np.nanmax(img) <= 1.0 else np.clip(img,0,255)).astype(np.uint8)
        H, W, _ = img.shape
        out_w, out_h = (W, H) if output_size is None else (int(output_size[0]), int(output_size[1]))
        alpha = img[..., 3]
        use_alpha = not (np.all(alpha <= alpha_threshold) or np.all(alpha >= 255))
        if use_alpha:
            hard_mask = (alpha > alpha_threshold).astype(np.uint8)
        else:
            hard_mask = (img[..., :3] < white_thresh).any(axis=2).astype(np.uint8)
        cnts, _ = cv2.findContours(hard_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not cnts:
            centered = np.full((out_h, out_w, 4), 255, np.uint8)[None, ...]
            sil      = np.zeros((out_h, out_w, 4), np.uint8)[None, ...]
            info = {'scale':1.0,'bbox':(0,0,0,0),'offset':(0,0),
                    'pasted_box':(0,0,0,0),'source_box':(0,0,0,0)}
            return (centered, sil, info) if return_transform else (centered, sil)

        x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
        roi = img[y:y+h, x:x+w, :]
        roi_alpha = roi[..., 3] if use_alpha else (hard_mask[y:y+h, x:x+w]*255).astype(np.uint8)

        centered = np.full((out_h, out_w, 4), 255, np.uint8)
        sil      = np.zeros((out_h, out_w, 4), np.uint8)

        pad = max(0, int(pad))
        inner_w, inner_h = max(1, out_w-2*pad), max(1, out_h-2*pad)
        scale = 1.0
        if not (w <= inner_w and h <= inner_h):
            if overflow == "error":
                raise ValueError("Cutout too large.")
            if overflow == "shrink":
                scale = min(inner_w/max(1,w), inner_h/max(1,h))
                new_w, new_h = max(1,int(round(w*scale))), max(1,int(round(h*scale)))
                roi = cv2.resize(roi, (new_w,new_h), interpolation=cv2.INTER_AREA if scale<1 else cv2.INTER_LINEAR)
                roi_alpha = cv2.resize(roi_alpha, (new_w,new_h), interpolation=cv2.INTER_NEAREST)
                w, h = new_w, new_h
        start_x = pad + (inner_w - w)//2
        start_y = pad + (inner_h - h)//2

        dx0, dy0 = max(0,start_x), max(0,start_y)
        dx1, dy1 = min(out_w, start_x+w), min(out_h, start_y+h)
        sx0, sy0 = max(0,-start_x), max(0,-start_y)
        sx1, sy1 = sx0 + max(0,dx1-dx0), sy0 + max(0,dy1-dy0)

        if dx1>dx0 and dy1>dy0:
            src_rgba  = roi[sy0:sy1, sx0:sx1]
            src_alpha = roi_alpha[sy0:sy1, sx0:sx1]
            a = (src_alpha.astype(np.float32)/255.0)[..., None]
            dst = centered[dy0:dy1, dx0:dx1].astype(np.float32)
            dst[..., :3] = a*src_rgba[..., :3].astype(np.float32) + (1.0-a)*dst[..., :3]
            centered[dy0:dy1, dx0:dx1, :3] = dst[..., :3].astype(np.uint8)
            centered[dy0:dy1, dx0:dx1, 3]  = 255
            patch = sil[dy0:dy1, dx0:dx1]
            patch[..., 0:3] = (src_alpha>0).astype(np.uint8)[..., None]*255
            patch[..., 3]   = src_alpha
            sil[dy0:dy1, dx0:dx1] = patch

        centered = centered[None, ...]
        sil      = sil[None, ...]
        info = {'scale': scale, 'bbox': (x,y,w,h), 'offset': (start_x,start_y),
                'pasted_box': (dx0,dy0,dx1-dx0,dy1-dy0), 'source_box': (sx0,sy0,sx1-sx0,sy1-sy0)}
        return (centered, sil, info) if return_transform else (centered, sil)


class Render:
    """Render setup & overlay (stateless)."""
    @staticmethod
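    def get_camera_intrinsics(scale_factor=1):
        """Hard-coded pinhole intrinsics K (3x3). scale_factor scales the focal lengths only;
        the principal point stays at (128, 128), the center of a 256x256 image."""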
        K = np.array([
            [1400.*scale_factor, 0., 128],
            [0., 1800.*scale_factor, 128],
            [0., 0., 1.]
        ], dtype="double")
        return K

    @staticmethod
    def create_perspective_camera(K, image_size, device):
        from pytorch3d.renderer import PerspectiveCameras
        fcl = ((K[0,0], K[1,1]),)
        prp = ((K[0,2], K[1,2]),)
        return PerspectiveCameras(focal_length=fcl, principal_point=prp,
                                  in_ndc=False, image_size=image_size, device=device)

    @staticmethod
    def create_silhouette_renderer(cameras, image_size, device):
        from pytorch3d.renderer import (RasterizationSettings, MeshRenderer, MeshRasterizer,
                                        BlendParams, SoftSilhouetteShader)
        blend = BlendParams(sigma=1e-5, gamma=1e-4)
        rast  = RasterizationSettings(image_size=image_size,
                                      blur_radius=np.log(1./1e-4 - 1.)*blend.sigma,
                                      faces_per_pixel=100,
                                      max_faces_per_bin=200000)
        return MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rast),
            shader=SoftSilhouetteShader(blend_params=blend)
        )

    @staticmethod
    def create_phong_renderer(cameras, image_size, device):
        from pytorch3d.renderer import (RasterizationSettings, MeshRenderer, MeshRasterizer,
                                        BlendParams, SoftPhongShader, PointLights)
        rast = RasterizationSettings(image_size=image_size, blur_radius=0.0,
                                     faces_per_pixel=10, max_faces_per_bin=200000)
        lights = PointLights(device=device, location=((2.0,2.0,-2.0),))
        blend  = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(1.0,1.0,1.0))
        return MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=rast),
            shader=SoftPhongShader(device=device, cameras=cameras, lights=lights, blend_params=blend)
        )

    @staticmethod
    def render_mesh(mesh, silhouette_renderer, phong_renderer, R, T):
        sil = silhouette_renderer(meshes_world=mesh, R=R, T=T)
        img = phong_renderer(meshes_world=mesh, R=R, T=T)
        return sil, img

    @staticmethod
    def alpha_blend(rgb_rendered, alpha_rendered, background_img):
        return rgb_rendered * alpha_rendered + background_img * (1.0 - alpha_rendered)

    @staticmethod
    def render_overlay(model, phong_renderer, mask_background):
        R, T = Cam.compute_camera_pose_from_center(model.camera_position, device=model.device)
        R, T = Cam.add_camera_roll_to_RT(R, T, roll_deg=model.roll_deg, device=model.device, mode="world")
        rgba = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)[0]
        rgb  = rgba[..., :3].detach().cpu().numpy()
        a    = rgba[..., 3].detach().cpu().numpy()[..., None]
        overlay = Render.alpha_blend(rgb, a, mask_background)
        return img_as_ubyte(overlay)


class IO:
    """I/O & assets (stateful paths are passed in)."""
    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            dev = torch.device("cuda:0")
            print(f"[INFO] Using CUDA: {dev}.")
        else:
            dev = torch.device("cpu")
            print("[INFO] Using CPU.")
        return dev

    @staticmethod
    def _download_file(url, dest_path):
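        with requests.get(url, stream=True) as r:
            r.raise_for_status()         # fail loudly instead of saving an HTTP error page
            r.raw.decode_content = True  # undo gzip/deflate content encoding while copying
            with open(dest_path, "wb") as f:
                shutil.copyfileobj(r.raw, f)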

    @staticmethod
    def get_cad_model_files(assets_dir: str):
        os.makedirs(assets_dir, exist_ok=True)
        # Asset key -> (Dropbox direct-download URL, local destination path)
        files = {
            "final_obj": ("https://www.dropbox.com/scl/fi/kg5zutt2deamc4ksi17f7/final_model.obj?rlkey=3uenujpffgqk1t0edx5plj5oe&dl=1",
                          os.path.join(assets_dir, "final_model.obj")),
            "detailed_ply": ("https://www.dropbox.com/scl/fi/ubit9vw23fphcwpu0w88v/obj_000001.ply?rlkey=hx9f12a9ka8160mhxq8vnz7yk&dl=1",
                             os.path.join(assets_dir, "obj_000001.ply")),
            "cad_ply": ("https://www.dropbox.com/scl/fi/y55irvvd58yw4ucu3igup/sse_only_jun9_scaled100.ply?rlkey=6io3u20qpupumdfp18i1ssmr9&dl=1",
                        os.path.join(assets_dir, "sse_only_jun9_scaled100.ply")),
            "cad_tex": ("https://www.dropbox.com/scl/fi/rj1thjtu16o2344vh5mjn/sse_texture.png?rlkey=503sm6dwzv1f8gyy1jgg6tnxl&dl=1",
                        os.path.join(assets_dir, "sse_texture.png")),
            "sse_obj": ("https://www.dropbox.com/scl/fi/nwfswb9yqaox5guzxedyp/sse_only.obj?rlkey=ub48myfosw690btzg7gfoxert&dl=1",
                        os.path.join(assets_dir, "sse_only.obj")),
            "sse_mtl": ("https://www.dropbox.com/scl/fi/3mz2zgppszxj6cvv2mgqm/sse_only.mtl?rlkey=ckrw1ko30dk9d8unvsl659vmr&dl=1",
                        os.path.join(assets_dir, "sse_only.mtl")),
            "sse_rt_obj": ("https://www.dropbox.com/scl/fi/8oo2ofzz6gtxn3ve2ue6r/sse_only_real_texture.obj?rlkey=a13utsbxbzib2c61j5gcasblt&dl=1",
                           os.path.join(assets_dir, "sse_only_real_texture.obj")),
            "sse_rt_mtl": ("https://www.dropbox.com/scl/fi/e9ld55c4cyq4e6bguwti6/sse_only_real_texture.mtl?rlkey=0lvgesvk9yoi245gndcw0arqm&dl=1",
                           os.path.join(assets_dir, "sse_only_real_texture.mtl")),
            "img_00002": ("https://www.dropbox.com/scl/fi/b4bjgg4w6n6c6ufx1krhg/00002.png?rlkey=2h6yikdk1liz4b0ojpw4dcz50&dl=1",
                          os.path.join(assets_dir, "00002.png")),
            "img_00016": ("https://www.dropbox.com/scl/fi/0nf86esq8723iirfz35uj/00016.png?rlkey=9yac94rwxj7uppz4oorgntpqz&dl=1",
                          os.path.join(assets_dir, "00016.png")),
        }
        for url, path in files.values():
            print(f"📥 Downloading -> {path}")
            IO._download_file(url, path)
        print(f"✅ Files in: {assets_dir}")
        return {k: p for k, (_, p) in files.items()}

    @staticmethod
    def load_cad_mesh(cad_path: str, device: torch.device) -> Meshes:
        """Load a .obj/.ply mesh as a Meshes object with a uniform yellow vertex color."""
        cad_path = Path(cad_path)
        ext = cad_path.suffix.lower()
        if ext == ".obj":
            verts, faces, _ = load_obj(str(cad_path), load_textures=False)
            verts = verts.to(device)
            faces = faces.verts_idx.to(device)
        elif ext == ".ply":
            verts, faces = load_ply(str(cad_path))
            verts = (verts / 1000.0).to(device)  # convert mm to meters
            faces = faces.to(device)
        else:
            raise ValueError(f"Unsupported CAD file: {ext}")
        # Uniform yellow vertex color: (V, 3), batched to (1, V, 3) for TexturesVertex
        verts_rgb = torch.tensor([1.0, 1.0, 0.0], device=device).repeat(verts.shape[0], 1)
        textures = TexturesVertex(verts_features=verts_rgb[None])
        return Meshes(verts=[verts], faces=[faces], textures=textures)

    @staticmethod
    def load_cad_mesh_with_texture(cad_path: str, device: torch.device) -> Meshes:
        return load_objs_as_meshes([str(cad_path)], device=device)


class Optim:
    """Optimization loops (stateless)."""
    @staticmethod
    def optimize_camera_pose_best_loss(model, optimizer, mask_background, phong_renderer, writer,
                                       num_iter=100, log_every=10, patience=10):
        best_loss, wait = float("inf"), 0
        final_R, final_T = None, None
        loop = tqdm(range(num_iter))
        for i in loop:
            torch.cuda.empty_cache()
            optimizer.zero_grad()
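            loss, _, _, _ = model()
            # Snapshot the camera center that produced this loss, before the optimizer moves it.
            eye = model.camera_position.detach().clone()
            loss.backward()
            optimizer.step()
            cur = float(loss.item())
            loop.set_description(f"Optimizing (loss {cur:.4f})")
            if cur < best_loss:
                best_loss, wait = cur, 0
                # Pose from the camera center only; model.roll_deg is not applied here.
                final_R, final_T = Cam.compute_camera_pose_from_center(eye, device=model.device)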
            else:
                wait += 1
                if wait >= patience:
                    print(f"[INFO] Early stopping at step {i}.")
                    break
            if i % log_every == 0:
                frame = Render.render_overlay(model, phong_renderer, mask_background)
                writer.append_data(frame)
        writer.close()
        print(f"[RESULT] Best loss: {best_loss:.6f}")
        return final_R, final_T

    @staticmethod
    def optimize_camera_pose(model, optimizer, mask_background, phong_renderer, writer,
                             num_iter=100, log_every=10, patience=10, tolerance=10.0):
        best_loss, wait = float("inf"), 0
        final_R, final_T = None, None
        loop = tqdm(range(num_iter))
        for i in loop:
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            loss, im, _, _ = model()
            loss.backward()
            optimizer.step()
            cur = float(loss.item())
            loop.set_description(f"Optimizing (loss {cur:.4f})")
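            # NOTE: `tolerance` is in the same units as the loss. If the loss is a
            # mean-based value in [0, 1], a tolerance >= 1 stops after ~`patience` steps.
            if abs(best_loss - cur) < tolerance:
                wait += 1
                if wait >= patience:
                    print(f"[INFO] Early stopping at step {i}.")
                    break
            else:
                best_loss, wait = cur, 0
            if i % log_every == 0:
                frame = Render.render_overlay(model, phong_renderer, mask_background)
                writer.append_data(frame)
        writer.close()
        # Report the pose at the final camera position, not at the last logged step.
        final_R, final_T = Cam.compute_camera_pose_from_center(model.camera_position, device=model.device)
        return final_R, final_T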


# ========= Facade (keeps your current calls working) ==========================

class SSE_Util:
    """
    Thin facade that keeps your existing calls intact.
    Only path state lives here; everything else delegates to namespaces above.
    """
    def __init__(self, local_path: str):
        self.local_path = local_path
        self.assets_dir = os.path.join(self.local_path, "assets/")

    # --- convenience / backwards-compat wrappers ---
    create_gif_writer = staticmethod(lambda filepath, duration=0.5: imageio.get_writer(filepath, mode="I", duration=duration))

    # Camera
    get_camera_position = staticmethod(Cam.get_camera_position)
    camera_center_to_dist_elev_azim = staticmethod(Cam.camera_center_to_dist_elev_azim)
    get_camera_pose = staticmethod(Cam.get_camera_pose)
    add_camera_roll_to_RT = staticmethod(Cam.add_camera_roll_to_RT)
    compute_camera_pose = staticmethod(Cam.compute_camera_pose_from_center)

    # Image
    read_rgb_cutout_black_bg = staticmethod(Img.read_rgb_cutout_black_bg)
    center_cutout_rgb_uint8 = staticmethod(Img.center_cutout_rgb_uint8)
    crop_center_to_size_uint8 = staticmethod(Img.crop_center_to_size_uint8)
    add_alpha_from_black_bg_uint8 = staticmethod(Img.add_alpha_from_black_bg_uint8)
    to_float_batched_rgba_white_bg_preserve_alpha = staticmethod(Img.to_float_batched_rgba_white_bg_preserve_alpha)
    center_cutout_keep_scale_2 = staticmethod(Img.center_cutout_keep_scale_2)

    # Render
    get_camera_intrinsics = staticmethod(Render.get_camera_intrinsics)
    create_perspective_camera = staticmethod(Render.create_perspective_camera)
    create_silhouette_renderer = staticmethod(Render.create_silhouette_renderer)
    create_phong_renderer = staticmethod(Render.create_phong_renderer)
    render_mesh = staticmethod(Render.render_mesh)
    alpha_blend = staticmethod(Render.alpha_blend)
    render_overlay = staticmethod(Render.render_overlay)

    # IO / assets
    get_device = staticmethod(IO.get_device)
    _download_file = staticmethod(IO._download_file)
    get_cad_model_files = lambda self: IO.get_cad_model_files(self.assets_dir)
    load_cad_mesh = staticmethod(IO.load_cad_mesh)
    load_cad_mesh_with_texture = staticmethod(IO.load_cad_mesh_with_texture)

    # Optimization
    optimize_camera_pose_best_loss = staticmethod(Optim.optimize_camera_pose_best_loss)
    optimize_camera_pose = staticmethod(Optim.optimize_camera_pose)
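
The two loops in `Optim` stop early under different rules: one waits for `patience` steps without a new best loss, the other for `patience` consecutive steps whose loss stays within `tolerance` of a reference value. The sketch below replays both rules on a made-up loss trace, without any rendering. The function names and the trace are hypothetical and only illustrate the stopping logic.

```python
def stop_step_best_loss(losses, patience):
    """Step at which a 'no new best for `patience` steps' rule stops, or None."""
    best, wait = float("inf"), 0
    for i, cur in enumerate(losses):
        if cur < best:
            best, wait = cur, 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


def stop_step_plateau(losses, patience, tolerance):
    """Step at which a 'within `tolerance` of the reference loss' rule stops, or None."""
    ref, wait = float("inf"), 0
    for i, cur in enumerate(losses):
        if abs(ref - cur) < tolerance:
            wait += 1
            if wait >= patience:
                return i
        else:
            ref, wait = cur, 0
    return None


# Hypothetical loss trace: a fast decrease followed by slow, steady improvement.
losses = [10.0, 5.0, 2.0, 1.0, 0.9, 0.8, 0.7, 0.6]

best_stop = stop_step_best_loss(losses, patience=3)
plateau_stop = stop_step_plateau(losses, patience=3, tolerance=0.5)
print(best_stop, plateau_stop)
```

On this trace the best-loss rule never fires because every step improves, while the plateau rule stops once three consecutive steps change the loss by less than `tolerance`. The plateau rule is therefore sensitive to the scale of the loss, so `tolerance` should be chosen relative to the values the loss actually takes.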


In [None]:
# Helper function to prepare video data for processing
import os, shutil, zipfile

def prepare_video_data(local_path: str,
                       video_name: str,
                       *,
                       delete_zips: bool = True,
                       overwrite: bool = True,
                       verbose: bool = True):
    """
    Prepare video data under <local_path>/video_data.

    Looks for these ZIPs in local_path:
      - {video_name}.zip           -> extracted to 'images'
      - masks_{video_name}.zip     -> extracted to 'masks'
      - cutouts_{video_name}.zip   -> extracted to 'cutouts'

    Moves found ZIPs to <local_path>/video_data, extracts them there,
    renames extracted content to the target dir names, cleans __MACOSX,
    and optionally deletes the ZIPs.

    Returns:
        dict with keys: created (mapping), missing (list),
        zips_moved (list), zips_used (list)
    """
    base = os.path.abspath(local_path)
    video_data = os.path.join(base, "video_data")
    os.makedirs(video_data, exist_ok=True)

    def vprint(*a, **k):
        if verbose:
            print(*a, **k)

    zip_specs = [
        (f"{video_name}.zip",          "images"),
        (f"masks_{video_name}.zip",    "masks"),
        (f"cutouts_{video_name}.zip",  "cutouts"),
    ]

    results = {"created": {}, "missing": [], "zips_moved": [], "zips_used": []}

    for zip_name, target_dirname in zip_specs:
        src_zip = os.path.join(base, zip_name)
        vd_zip  = os.path.join(video_data, zip_name)

        # Pick which ZIP path to use, moving if necessary
        if os.path.exists(src_zip):
            if os.path.exists(vd_zip) and overwrite:
                os.remove(vd_zip)
            shutil.move(src_zip, video_data)  # keeps filename
            zip_path = vd_zip
            results["zips_moved"].append(zip_name)
            vprint(f"Moved: {zip_name} -> video_data/")
        elif os.path.exists(vd_zip):
            zip_path = vd_zip
            results["zips_used"].append(zip_name)
            vprint(f"Using existing in video_data/: {zip_name}")
        else:
            results["missing"].append(zip_name)
            vprint(f"Missing: {zip_name} (skipping)")
            continue

        target_dir = os.path.join(video_data, target_dirname)

        # Handle an existing target BEFORE extracting. Otherwise an archive whose
        # top-level folder is already named like the target would be extracted
        # into it and then deleted by the overwrite step.
        if os.path.exists(target_dir):
            if overwrite:
                shutil.rmtree(target_dir)
            else:
                raise FileExistsError(
                    f"Target directory already exists: {target_dir}. "
                    "Set overwrite=True to replace."
                )

        # Track directory contents before extraction
        before = set(os.listdir(video_data))

        # Extract
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(video_data)

        # Identify what was created
        after = set(os.listdir(video_data))
        new_entries = sorted(after - before)

        # Clean macOS metadata
        mac_dir = os.path.join(video_data, "__MACOSX")
        if os.path.isdir(mac_dir):
            shutil.rmtree(mac_dir, ignore_errors=True)
            if "__MACOSX" in new_entries:
                new_entries.remove("__MACOSX")

        # Decide how to move the extracted content into target_dir
        new_dirs  = [e for e in new_entries if os.path.isdir(os.path.join(video_data, e))]
        new_files = [e for e in new_entries if os.path.isfile(os.path.join(video_data, e))]

        if len(new_dirs) == 1 and not new_files:
            src_dir = os.path.join(video_data, new_dirs[0])
            # If the extracted dir is already named like the target, nothing to move
            if os.path.abspath(src_dir) != os.path.abspath(target_dir):
                shutil.move(src_dir, target_dir)
            else:
                vprint(f"Extracted folder already named '{target_dirname}'")
        else:
            # Mixed content or flat files: create target and move all new entries into it
            os.makedirs(target_dir, exist_ok=True)
            for name in new_entries:
                p = os.path.join(video_data, name)
                # Skip if it was already moved (defensive)
                if not os.path.exists(p):
                    continue
                shutil.move(p, os.path.join(target_dir, os.path.basename(p)))

        # Optional: delete the ZIP after successful extraction
        if delete_zips:
            try:
                os.remove(zip_path)
            except FileNotFoundError:
                pass

        results["created"][target_dirname] = target_dir
        vprint(f"Prepared: {target_dirname} -> {target_dir}")

    return results
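
`prepare_video_data` finds out what an archive created by listing the output directory before and after extraction and taking the difference. The following minimal sketch shows that step in isolation on a temporary directory. The archive name `demo.zip` and the folder name `frames` are made up for the example.

```python
import os
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as root:
    # Build a small archive whose single top-level folder is named "frames" (hypothetical).
    zip_path = os.path.join(root, "demo.zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("frames/00000.png", b"fake image bytes")
        zf.writestr("frames/00001.png", b"fake image bytes")

    out_dir = os.path.join(root, "video_data")
    os.makedirs(out_dir)

    # Snapshot the directory, extract, and diff to find what the archive created.
    before = set(os.listdir(out_dir))
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(out_dir)
    new_entries = sorted(set(os.listdir(out_dir)) - before)

    # A single new folder is renamed to the target name ("images" here).
    os.rename(os.path.join(out_dir, new_entries[0]), os.path.join(out_dir, "images"))
    extracted_files = sorted(os.listdir(os.path.join(out_dir, "images")))

print(new_entries, extracted_files)
```

The difference only contains entries that did not exist before extraction, so a folder that is already present in the output directory is invisible to it.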


In [None]:
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

def show_mask_grid(local_path: str,
                   pattern: str = r'^\d{5}\.png$',
                   max_images: int = 20,
                   cols: int = 5,
                   figsize_cell: float = 2.0,
                   show: bool = True):
    """
    Read masks from <local_path>/video_data/masks/ and display as a grid.

    Args:
        local_path: Base directory (e.g., "/content/").
        pattern: Regex for mask filenames. Default '00000.png' style.
        max_images: Max number of masks to display.
        cols: Number of columns in the grid.
        figsize_cell: Figure size per cell (inches).
        show: If True, calls plt.show().

    Returns:
        (fig, grid, files_used, masks_np)
        - fig: Matplotlib Figure (or None if nothing to show)
        - grid: ImageGrid instance (or None)
        - files_used: list of filenames shown
        - masks_np: list of numpy arrays (grayscale)
    """
    mask_dir = os.path.join(local_path, "video_data", "masks")
    if not os.path.isdir(mask_dir):
        print(f"[show_mask_grid] Directory not found: {mask_dir}")
        return None, None, [], []

    rx = re.compile(pattern)
    files = [f for f in os.listdir(mask_dir) if rx.match(f)]

    # Sort numerically if filenames are like '00012.png'
    def _num_key(name):
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            # Non-numeric names sort after numeric ones; tuples avoid int/str comparison errors
            return (1, 0, stem)
    files = sorted(files, key=_num_key)

    if not files:
        print(f"[show_mask_grid] No files matching pattern {pattern} in {mask_dir}")
        return None, None, [], []

    # Load grayscale masks
    files = files[:max_images]
    masks = []
    used = []
    for f in files:
        img = cv2.imread(os.path.join(mask_dir, f), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue
        masks.append(img)
        used.append(f)

    if not masks:
        print("[show_mask_grid] No masks could be read.")
        return None, None, [], []

    num_images = len(masks)
    cols = max(1, cols)
    rows = (num_images + cols - 1) // cols

    fig = plt.figure(figsize=(cols * figsize_cell, rows * figsize_cell))
    grid = ImageGrid(fig, 111,
                     nrows_ncols=(rows, cols),
                     axes_pad=0.0,
                     share_all=True)

    for ax, img in zip(grid, masks):
        ax.imshow(img, cmap='gray')
        ax.axis('off')

    # Turn off any extra axes if grid > num_images
    for ax in grid[len(masks):]:
        ax.axis('off')

    if show:
        plt.show()

    return fig, grid, used, masks

In [None]:
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

def show_cutout_grid(local_path: str,
                     pattern: str = r'^\d{5}\.png$',
                     max_images: int = 20,
                     cols: int = 5,
                     figsize_cell: float = 2.0,
                     show: bool = True):
    """
    Read cutout images from <local_path>/video_data/cutouts/ and display as a grid.

    Args:
        local_path: Base directory (e.g., "/content/").
        pattern: Regex for filenames (default: '00000.png' style).
        max_images: Maximum number of images to display.
        cols: Number of columns in the grid.
        figsize_cell: Figure size per cell (inches).
        show: If True, calls plt.show().

    Returns:
        (fig, grid, files_used, images_rgb)
          - fig: Matplotlib Figure (or None if nothing to show)
          - grid: ImageGrid instance (or None)
          - files_used: list of filenames shown
          - images_rgb: list of numpy arrays in RGB (H, W, 3, uint8)
    """
    cutout_dir = os.path.join(local_path, "video_data", "cutouts")
    if not os.path.isdir(cutout_dir):
        print(f"[show_cutout_grid] Directory not found: {cutout_dir}")
        return None, None, [], []

    rx = re.compile(pattern)
    files = [f for f in os.listdir(cutout_dir) if rx.match(f)]

    # Sort numerically if filenames look like '00012.png'
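    def _num_key(name):
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            # Non-numeric names sort after numeric ones; tuples avoid int/str comparison errors
            return (1, 0, stem)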
    files = sorted(files, key=_num_key)

    if not files:
        print(f"[show_cutout_grid] No files matching pattern {pattern} in {cutout_dir}")
        return None, None, [], []

    # Load color images (BGR) and convert to RGB
    files = files[:max_images]
    images_rgb, used = [], []
    for f in files:
        bgr = cv2.imread(os.path.join(cutout_dir, f), cv2.IMREAD_COLOR)
        if bgr is None:
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        images_rgb.append(rgb)
        used.append(f)

    if not images_rgb:
        print("[show_cutout_grid] No images could be read.")
        return None, None, [], []

    num_images = len(images_rgb)
    cols = max(1, cols)
    rows = (num_images + cols - 1) // cols

    fig = plt.figure(figsize=(cols * figsize_cell, rows * figsize_cell))
    grid = ImageGrid(fig, 111, nrows_ncols=(rows, cols), axes_pad=0.0, share_all=True)

    for ax, img in zip(grid, images_rgb):
        ax.imshow(img)
        ax.axis('off')

    # Hide any extra axes
    for ax in grid[len(images_rgb):]:
        ax.axis('off')

    if show:
        plt.show()

    return fig, grid, used, images_rgb


In [None]:
import os
from typing import List, Optional, Dict, Any

def estimate_poses_for_masks(
    pose_est,
    mask_files: List[str],
    local_path: str,
    *,
    N: Optional[int] = None,
    iterations: int = 200,
    learning_rate: float = 0.05,
    tol: float = 1.0,
    show_overlay: bool = True,
    print_results: bool = True,
    stop_on_error: bool = False,
    output_size: tuple,
) -> List[Dict[str, Any]]:
    """
    Run pose estimation over a list of mask files.

    Args:
        pose_est: Your pose estimator object (e.g., pose_est01) exposing:
                  read_rgb_cutout_black_bg_from_file, init_model, run_optimization,
                  display_current_and_reference_as_overlay, printout_results,
                  and attribute .model.camera_position (Tensor). The model must
                  already be initialized, because its current camera position
                  seeds the first file.
        mask_files: List of mask file paths (relative to local_path or absolute).
        local_path: Base directory containing the mask files (e.g., "/content/").
        N: If provided, process only the first N files; otherwise process all.
        iterations: Number of optimization iterations. Note: this value is
                    currently not forwarded to the estimator by this function.
        learning_rate: Learning rate passed to init_model.
        tol: Tolerance value passed to run_optimization(tol=...).
        show_overlay: If True, calls display_current_and_reference_as_overlay() per file.
        print_results: If True, calls printout_results() per file.
        stop_on_error: If True, raises on first file error; otherwise continues.
        output_size: Output size passed to read_rgb_cutout_black_bg_from_file
                     (required keyword argument).

    Returns:
        A list of dicts, one per processed file:
        [
          {
            "file": "<filename>",
            "path": "<full_path>",
            "R": Tensor (detached clone),
            "T": Tensor (detached clone),
            "camera_position": Tensor (detached clone)
          },
          ...
        ]
    """

    selected = mask_files[:N] if (N is not None) else mask_files
    results = []

    for mask_file in selected:
        try:
            print(f"\nProcessing: {os.path.basename(mask_file)}...")
            image_path = os.path.join(local_path, mask_file)

            # Set the current reference mask
            # pose_est.set_reference_mask(image_path)
            _ = pose_est.read_rgb_cutout_black_bg_from_file(image_path, output_size=output_size)


            # Use current camera position as eye_override (batchify)
            camera_pos = pose_est.model.camera_position.detach().unsqueeze(0)

            # Initialize model for this image
            pose_est.init_model(learning_rate=learning_rate, eye_override=camera_pos)

            # Optimize
            R, T = pose_est.run_optimization(tol=tol)

            # Show overlay & print per-image results
            if show_overlay:
                pose_est.display_current_and_reference_as_overlay()
            if print_results:
                pose_est.printout_results()

            # Snapshot final camera position
            camera_position = pose_est.model.camera_position.detach().clone()

            results.append({
                "file": os.path.basename(mask_file),
                "path": image_path,
                "R": R.detach().clone(),
                "T": T.detach().clone(),
                "camera_position": camera_position
            })

        except Exception as e:
            msg = f"[estimate_poses_for_masks] Error processing {mask_file}: {e}"
            if stop_on_error:
                raise RuntimeError(msg) from e
            else:
                print(msg)
                continue

    print("\nDone. Collected results for", len(results), "file(s).")
    return results


In [None]:
import os
import numpy as np
import imageio
import matplotlib.pyplot as plt

def make_pose_gif(
    pose_est,
    results,
    out_path: str = "./result.gif",
    duration: float = 0.5,          # seconds per frame
    apply_alpha: bool = True,
    preview: bool = False,
    verbose: bool = True,
):
    """
    Create an animated GIF by re-rendering frames at the camera positions in `results`.

    Args:
        pose_est: your estimator (e.g., pose_est01) exposing render_sse_image(eye_override=pos)
        results: list of tuples (R, T, pos) OR dicts with key "camera_position" (or "pos")
        out_path: where to save the GIF
        duration: per-frame duration in seconds
        apply_alpha: if True and frames are RGBA, multiply RGB by alpha
        preview: if True, show each frame with matplotlib (slow)
        verbose: log progress

    Returns:
        (out_path, num_frames)
    """

    def _get_pos(entry):
        # tuple (R, T, pos)
        if isinstance(entry, (tuple, list)) and len(entry) >= 3:
            return entry[2]
        # dict with camera position
        if isinstance(entry, dict):
            if "camera_position" in entry:
                return entry["camera_position"]
            if "pos" in entry:
                return entry["pos"]
        raise ValueError("Unrecognized results entry format. Expected (R,T,pos) or dict with 'camera_position'/'pos'.")

    def _to_numpy(arr_or_tensor):
        # torch tensor -> numpy
        try:
            import torch
            if isinstance(arr_or_tensor, torch.Tensor):
                return arr_or_tensor.detach().cpu().numpy()
        except Exception:
            pass
        return np.asarray(arr_or_tensor)

    def _ensure_ch_last(img):
        # Move (C,H,W) -> (H,W,C)
        if img.ndim == 3 and img.shape[0] in (1,3,4) and img.shape[-1] not in (1,3,4):
            img = np.transpose(img, (1,2,0))
        return img

    def _rgba_to_rgb_uint8(img_rgba):
        rgb = img_rgba[..., :3]
        if not apply_alpha or img_rgba.shape[-1] < 4:
            return _to_uint8(rgb)
        alpha = img_rgba[..., 3]
        # normalize alpha if needed
        if alpha.max() > 1.0 + 1e-6:
            alpha = alpha / 255.0
        rgb = rgb * alpha[..., None]
        return _to_uint8(rgb)

    def _to_uint8(img):
        img = np.asarray(img)
        if np.issubdtype(img.dtype, np.floating):
            if img.max() <= 1.0 + 1e-6 and img.min() >= -1e-6:
                img = (np.clip(img, 0.0, 1.0) * 255.0).round()
            else:
                img = np.clip(img, 0.0, 255.0).round()
            img = img.astype(np.uint8)
        elif img.dtype != np.uint8:
            img = np.clip(img, 0, 255).astype(np.uint8)
        return img

    os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True)
    frame_count = 0

    with imageio.get_writer(out_path, mode="I", duration=duration, loop=0) as writer:
        for i, entry in enumerate(results):
            pos = _get_pos(entry)

            # Many setups store pos as torch tensor; allow shape (3,) or (1,3)
            try:
                import torch
                if isinstance(pos, torch.Tensor):
                    pos_np = pos.detach().cpu()
                    # Leave shape as provided; your renderer already accepts (1,3) or (3,)
                    eye_override = pos_np
                else:
                    eye_override = pos
            except Exception:
                eye_override = pos

            if verbose:
                print(f"[INFO] Rendering frame {i}...")

            # Render with the given camera position
            frame = pose_est.render_sse_image(eye_override=eye_override)

            # Convert to numpy, channel-last
            frame = _ensure_ch_last(_to_numpy(frame))

            # If RGBA, apply alpha; otherwise just convert
            if frame.ndim == 3 and frame.shape[-1] == 4:
                rgb_uint8 = _rgba_to_rgb_uint8(frame)
            else:
                rgb_uint8 = _to_uint8(frame)

            writer.append_data(rgb_uint8)
            frame_count += 1

            if preview:
                plt.imshow(rgb_uint8)
                plt.axis('off')
                plt.title(f"Frame {i}")
                plt.show()

    if verbose:
        print(f"[INFO] GIF saved to {out_path} with {frame_count} frame(s).")
    return out_path, frame_count
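
Before a frame is written to the GIF, `make_pose_gif` multiplies the RGB channels by alpha and converts the result to `uint8`. Here is a minimal NumPy sketch of that conversion for a float image in $[0, 1]$. The helper name `rgba_float_to_rgb_uint8` and the two-pixel test image are made up for illustration.

```python
import numpy as np


def rgba_float_to_rgb_uint8(rgba: np.ndarray) -> np.ndarray:
    """Premultiply RGB by alpha and convert a float image in [0, 1] to uint8."""
    rgb = rgba[..., :3] * rgba[..., 3:4]  # alpha broadcasts over the channel axis
    return (np.clip(rgb, 0.0, 1.0) * 255.0).round().astype(np.uint8)


# Hypothetical 1x2 image: an opaque red pixel and a mostly transparent white pixel.
frame = np.array([[[1.0, 0.0, 0.0, 1.0],
                   [1.0, 1.0, 1.0, 0.2]]], dtype=np.float32)
rgb = rgba_float_to_rgb_uint8(frame)
print(rgb.shape, rgb.dtype)
print(rgb)
```

Multiplying by alpha composites the render over a black background, so fully transparent pixels become black in the GIF.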


### Set up a basic model

Here we define the model class and initialize a learnable parameter for the camera position.
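
The model turns the camera position into a rotation `R` and translation `T` with a look-at construction. The NumPy sketch below shows one way to build such a pose. It assumes the camera looks at the world origin with `up = (0, 1, 0)` and uses the row-vector convention `X_cam = X_world @ R + T` with `T = -R^T C`. The helper `look_at_pose` is hypothetical and is not the PyTorch3D implementation.

```python
import numpy as np


def look_at_pose(camera_center, at=(0.0, 0.0, 0.0), up=(0.0, 1.0, 0.0)):
    """Return (R, T) such that X_cam = X_world @ R + T (row-vector convention)."""
    C = np.asarray(camera_center, dtype=np.float64)
    z = np.asarray(at, dtype=np.float64) - C   # the camera looks along +z
    z = z / np.linalg.norm(z)
    x = np.cross(np.asarray(up, dtype=np.float64), z)
    x = x / np.linalg.norm(x)                  # breaks down if `up` is parallel to the view axis
    y = np.cross(z, x)
    R = np.stack([x, y, z], axis=1)            # camera axes as columns
    T = -C @ R                                 # same vector as -R^T C
    return R, T


C = np.array([1.0, 2.0, 3.0])
R, T = look_at_pose(C)

center_in_cam = C @ R + T            # the camera center maps to the camera origin
origin_in_cam = np.zeros(3) @ R + T  # the look-at target lies on the optical axis
print(center_in_cam, origin_in_cam)
```

Because `R` and `T` are both functions of the camera position, optimizing the three position coordinates moves and re-aims the camera at the same time.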

In [None]:
# class Model(nn.Module):
#     def __init__(self, meshes, renderer, image_ref, initial_position):
#         super().__init__()
#         self.meshes = meshes
#         self.device = meshes.device
#         self.renderer = renderer
#         self.initial_position = initial_position

#         # Get the silhouette of the reference RGB image by finding all non-white pixel values.
#         image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 1).astype(np.float32))
#         # image_ref = (image_ref[..., :3].amax(dim=-1) != 1.0).float()

#         self.register_buffer('image_ref', image_ref)

#         # Create an optimizable parameter for the x, y, z position of the camera.
#         # This is the (x,y,z)_world camera translation vector
#         self.camera_position = nn.Parameter(
#             torch.from_numpy(initial_position).to(meshes.device)
#             )

#     def forward(self):

#         # Render the image using the updated camera position. Based on the new position of the
#         # camera we calculate the rotation and translation matrices.
#         R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
#         T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)

#         image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

#         # Convert image to silhouette (just a mask here)
#         image = image[..., 3]

#         # Loss function
#         # loss = torch.sum((image - self.image_ref) ** 2)

#         loss1 = torch.sum((image - self.image_ref) ** 2)

#         loss2 = self.dice_loss(image, self.image_ref)

#         loss = 0.5 *loss1 + 0.5 * loss2

#         return loss, image, R, T


#     def dice_loss(self, pred, target, eps=1e-6):
#         pred_bin = (pred > 0.5).float()
#         target_bin = (target > 0.5).float()

#         intersection = (pred_bin * target_bin).sum()
#         union = pred_bin.sum() + target_bin.sum()

#         dice = (2. * intersection + eps) / (union + eps)
#         return 1. - dice


### New model with roll angle
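
A look-at pose fixes where the camera points but leaves its rotation about the optical axis free, which is why a roll angle is added as a second learnable parameter. The sketch below shows one way a roll about the camera z-axis can be composed with an existing pose in the row-vector convention `X_cam = X_world @ R + T`. The helper `roll_about_camera_z` is hypothetical, and the actual `add_camera_roll_to_RT` may use a different sign or composition order.

```python
import numpy as np


def roll_about_camera_z(R, T, roll_deg):
    """Rotate the camera frame about its optical (z) axis; row-vector convention."""
    theta = np.deg2rad(roll_deg)
    c, s = np.cos(theta), np.sin(theta)
    Rz = np.array([[c, -s, 0.0],
                   [s,  c, 0.0],
                   [0.0, 0.0, 1.0]])
    # In column form p' = Rz @ p, so row vectors use p' = p @ Rz.T.
    return R @ Rz.T, T @ Rz.T


# Hypothetical pose: camera axes aligned with the world, object 5 units ahead.
R0, T0 = np.eye(3), np.array([0.0, 0.0, 5.0])
R1, T1 = roll_about_camera_z(R0, T0, 90.0)

p_world = np.array([1.0, 0.0, 0.0])
p_before = p_world @ R0 + T0
p_after = p_world @ R1 + T1
print(p_before, p_after)
```

The roll changes only the in-image orientation: the depth along the optical axis and the distance to the camera are unchanged.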

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pytorch3d.renderer.cameras import look_at_rotation


class Model(nn.Module):
    def __init__(
        self,
        meshes,
        renderer,
        image_ref,              # numpy or torch; HxW (mask) or HxWx3/4 (RGB/A)
        initial_position,       # (3,)
        optimize_roll: bool = True,
        roll_init_deg: float = 0.0,
        white_thresh: float = 0.99,   # for RGB refs: non-white => foreground
        auto_resize_ref: bool = True, # resize ref mask to render size
        local_path: str = "/content/",
    ):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer
        self.white_thresh = white_thresh
        self.auto_resize_ref = auto_resize_ref

        self.util = SSE_Util(local_path)

        # ---- Build a 2D silhouette reference (H, W) in [0,1]
        ref = torch.as_tensor(image_ref).to(torch.float32)
        if ref.max() > 1.0 + 1e-6:
            ref = ref / 255.0
        # If a batch dim sneaked in, squeeze it (e.g., (1,H,W,4))
        if ref.ndim == 4 and ref.shape[0] == 1:
            ref = ref[0]
        # If CHW, permute to HWC
        if ref.ndim == 3 and ref.shape[0] in (1, 3, 4) and ref.shape[-1] not in (1, 3, 4):
            ref = ref.permute(1, 2, 0)

        if ref.ndim == 2:
            sil_ref = ref.clamp(0, 1)
        elif ref.ndim == 3 and ref.shape[-1] == 4:
            sil_ref = ref[..., 3].clamp(0, 1)                             # use alpha
        elif ref.ndim == 3 and ref.shape[-1] == 3:
            sil_ref = (ref.max(dim=-1).values < self.white_thresh).float() # non-white
        else:
            raise ValueError(f"Unsupported image_ref shape {tuple(ref.shape)}")



        self.register_buffer("image_ref", sil_ref.to(self.device))   # (H,W) float


        # ---- Learnable camera translation C = (x,y,z)
        C0 = torch.as_tensor(initial_position, dtype=torch.float32, device=self.device)
        if C0.ndim == 2 and C0.shape[0] == 1:
            C0 = C0[0]
        self.camera_position = nn.Parameter(C0)

        # ---- Learnable camera roll (deg) around camera Z
        roll0 = torch.tensor(roll_init_deg, dtype=torch.float32, device=self.device)
        self.roll_deg = nn.Parameter(roll0) if optimize_roll else None



    def _roll_Rz_cam(self, deg: torch.Tensor) -> torch.Tensor:
        """(1,3,3) rotation about camera Z by 'deg' degrees (pre-multiplies world->cam)."""
        theta = torch.deg2rad(deg.view(1))
        c, s = torch.cos(theta), torch.sin(theta)
        z = torch.zeros_like(c); o = torch.ones_like(c)
        Rz = torch.stack([
            torch.stack([ c, -s, z], dim=-1),
            torch.stack([ s,  c, z], dim=-1),
            torch.stack([ z,  z,  o], dim=-1),
        ], dim=1)  # (1,3,3)
        return Rz



    def forward(self):
        C = self.camera_position

        # Get Rotation and Translation from camera position
        R, T = self.util.get_camera_pose(eye_override=C, device=self.device)


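        # Add roll only when it is being optimized (roll_deg is None otherwise).
        # mode="world" is used here; mode="camera" is the alternative to try if no spin is visible.
        if self.roll_deg is not None:
            R, T = self.util.add_camera_roll_to_RT(R, T, roll_deg=self.roll_deg, device=self.device, mode="world")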



        # R0 = look_at_rotation(C[None, :], device=self.device)  # (1,3,3)

        # # Camera roll in camera coords: R = R_roll @ R0
        # if self.roll_deg is not None:
        #     Rroll = self._roll_Rz_cam(self.roll_deg)
        #     R = torch.bmm(Rroll, R0)
        # else:
        #     R = R0

        # # T = -R^T * C
        # T = -torch.bmm(R.transpose(1, 2), C[None, :, None])[:, :, 0]  # (1,3)




        # Render -> RGBA, take alpha as silhouette
        rgba = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        if rgba.ndim == 4:
            rgba = rgba[0]                 # (H,W,4)
        pred = rgba[..., 3].clamp(0, 1)    # (H,W)

        # Match ref size if needed
        ref = self.image_ref
        if self.auto_resize_ref and ref.shape != pred.shape:
            ref = F.interpolate(ref[None, None, ...], size=pred.shape, mode="nearest")[0, 0]

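        # Loss: MSE + thresholded Dice, equally weighted.
        # NOTE: dice_loss thresholds at 0.5, which is piecewise constant, so it adds no
        # gradient; only the MSE term drives the update. Dice still shifts the loss value.
        loss1 = F.mse_loss(pred, ref)
        loss2 = self.dice_loss(pred, ref)
        loss = 0.5 * loss1 + 0.5 * loss2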

        return loss, pred, R, T

    ## @staticmethod
    # def dice_loss(pred, target, eps=1e-6):
    #     pred_bin = (pred > 0.5).float()
    #     target_bin = (target > 0.5).float()
    #     inter = (pred_bin * target_bin).sum()
    #     denom = pred_bin.sum() + target_bin.sum()
    #     return 1.0 - (2.0 * inter + eps) / (denom + eps)

    def dice_loss(self, pred, target, eps=1e-6):
        pred_bin = (pred > 0.5).float()
        target_bin = (target > 0.5).float()

        intersection = (pred_bin * target_bin).sum()
        union = pred_bin.sum() + target_bin.sum()

        dice = (2. * intersection + eps) / (union + eps)
        return 1. - dice
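
The Dice term in `Model` binarizes both silhouettes at 0.5 before comparing them. A thresholded score does not change when the prediction moves slightly without crossing the threshold, so it cannot tell the optimizer which direction to move. The NumPy sketch below compares it with a soft Dice that uses the raw values. `soft_dice_loss` is an illustrative alternative and is not part of this notebook, and the arrays are made up.

```python
import numpy as np


def hard_dice_loss(pred, target, eps=1e-6):
    """Dice loss on masks thresholded at 0.5 (piecewise constant in `pred`)."""
    p = (pred > 0.5).astype(float)
    t = (target > 0.5).astype(float)
    return 1.0 - (2.0 * (p * t).sum() + eps) / (p.sum() + t.sum() + eps)


def soft_dice_loss(pred, target, eps=1e-6):
    """Dice loss on the raw values in [0, 1] (varies smoothly with `pred`)."""
    return 1.0 - (2.0 * (pred * target).sum() + eps) / (pred.sum() + target.sum() + eps)


target = np.array([[0.0, 1.0], [1.0, 1.0]])
pred_a = np.array([[0.1, 0.6], [0.7, 0.4]])
pred_b = pred_a + 0.05  # small nudge that crosses no 0.5 threshold

hard_a, hard_b = hard_dice_loss(pred_a, target), hard_dice_loss(pred_b, target)
soft_a, soft_b = soft_dice_loss(pred_a, target), soft_dice_loss(pred_b, target)
print(hard_a, hard_b)
print(soft_a, soft_b)
```

In this example the hard Dice is identical for both predictions, while the soft Dice decreases as the prediction moves toward the target. In the model above, the MSE term is therefore the part of the loss that supplies the gradient.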



### Pose-Estimator Class
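
The estimator class below recenters the CAD mesh on its center of mass, with a choice of vertex, surface, or volume weighting. The NumPy sketch below computes all three for a small tetrahedron, assuming a closed mesh whose faces are wound with outward normals. The helper `mesh_centroids` and the test mesh are made up for illustration.

```python
import numpy as np


def mesh_centroids(verts, faces):
    """Return (vertex, surface, volume) centroids and the signed volume of a triangle mesh."""
    a, b, c = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
    vertex_com = verts.mean(axis=0)
    # Surface: area-weighted mean of the triangle centroids
    areas = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
    surface_com = (areas[:, None] * (a + b + c) / 3.0).sum(axis=0) / areas.sum()
    # Volume: signed tetrahedra (origin, a, b, c), weighted by their volumes
    vols = np.einsum("ij,ij->i", a, np.cross(b, c)) / 6.0
    volume_com = (vols[:, None] * (a + b + c) / 4.0).sum(axis=0) / vols.sum()
    return vertex_com, surface_com, volume_com, vols.sum()


# Right-angle tetrahedron shifted away from the origin; faces wound outward.
offset = np.array([2.0, 3.0, 4.0])
verts = offset + np.array([[0.0, 0.0, 0.0],
                           [1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0],
                           [0.0, 0.0, 1.0]])
faces = np.array([[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]])

vertex_com, surface_com, volume_com, volume = mesh_centroids(verts, faces)
centered = verts - volume_com  # shift so the volume centroid sits at the origin
print(vertex_com, surface_com, volume_com, volume)
```

For a tetrahedron the vertex mean and the volume centroid coincide, while the surface centroid is pulled toward the largest face. On a real CAD mesh the vertex mean also depends on how densely each region is tessellated, so the three choices can differ noticeably.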

In [None]:

import torch
from pytorch3d.structures import Meshes

class PoseEstimator:
    def __init__(self, local_path, use_light_model=True, image_size=(480, 720), scale_factor=1, num_iter=50):

        # Local path of server or local computer (e.g., /content/ for Colab)
        self.local_path = local_path

        # True if using the light-weight model
        self.use_light_model = use_light_model

        # Scale factor to resize image and intrinsics calibration matrix
        self.scale_factor = scale_factor

        # (Re-scaled) image size (probably it will eventually become a square image, e.g., 256x256)
        self.im_size = (int(image_size[0] * scale_factor), int(image_size[1] * scale_factor))

        # Maximum number of iterations to perform
        self.num_iter = num_iter

        # PyTorch3D mesh from CAD model. Set later by _load_cad_model()
        self.mesh = None

        # Model (Will be set later)
        self.model = None

        # Optimizer (will be set later)
        self.optimizer = None

        # These are the PyTorch3D renderers used to create images of the CAD model given a camera.
        # Set later by _setup_camera_and_renderer()
        self.silhouette_renderer = None
        self.phong_renderer = None

        # Reference silhouette.
        self.silhouette_ref = None

        # Reference image
        self.image_ref = None

        # CAD file path
        self.cad_file = None

        # Camera intrinsics
        self.K = None

        # Camera position
        self.camera_pos = None

        # Reference camera pose
        self.R_ref = None
        self.T_ref = None

        # Roll angle
        self.roll_deg = None


        #-------------------------------------------------

        # Declare the utilities object
        self.util = SSE_Util(local_path)

        # Set the device (i.e., cuda, cpu, and eventually mps)
        self.device = self.util.get_device()

        # Private method: Load cad models into memory
        self._load_cad_model()

        # Change center of CAD model to its centroid (vertex mean; see _center_model_to_centroid)
        self._center_model_to_centroid()

        # Private method: Set up camera and some renderers (i.e., Phong, Silhouette)
        self._setup_camera_and_renderer()


    @torch.no_grad()
    def recenter_meshes_to_com(self, meshes: Meshes, method: str = "surface"):
        """
        Translate each mesh in a batched PyTorch3D Meshes so its center-of-mass is at the origin.

        Args:
            meshes: PyTorch3D Meshes (can be a batch).
            method: 'vertex' | 'surface' | 'volume'

        Returns:
            centered_meshes: new Meshes with verts translated so COM == (0,0,0)
            translations: (N, 3) tensor with the per-mesh translation that was SUBTRACTED
                          (i.e., original_verts - translations -> centered_verts)
        """
        if method not in {"vertex", "surface", "volume"}:
            raise ValueError("method must be 'vertex', 'surface', or 'volume'")

        verts_list = meshes.verts_list()
        faces_list = meshes.faces_list()
        device = verts_list[0].device

        translations = []
        new_verts_list = []

        for V, F in zip(verts_list, faces_list):
            # V: (Vn, 3) float; F: (Fn, 3) long
            if method == "vertex":
                com = V.mean(0)

            else:
                tri = V[F]                    # (F, 3, 3): triangle vertices
                a, b, c = tri[:, 0, :], tri[:, 1, :], tri[:, 2, :]

                if method == "surface":
                    # Triangle areas and centroids
                    ab = b - a
                    ac = c - a
                    areas = 0.5 * torch.linalg.norm(torch.cross(ab, ac, dim=1), dim=1)  # (F,)
                    centroids = (a + b + c) / 3.0                                       # (F, 3)
                    A = areas.sum()
                    if A.abs() < 1e-12:
                        # Degenerate surface -> fall back
                        com = V.mean(0)
                    else:
                        com = (areas[:, None] * centroids).sum(0) / A

                elif method == "volume":
                    # Volume of tetrahedra (0, a, b, c): V = dot(a, cross(b, c)) / 6
                    vols = torch.einsum('ij,ij->i', a, torch.cross(b, c, dim=1)) / 6.0   # signed (F,)
                    # Tetra centroid = (0 + a + b + c)/4 = (a+b+c)/4
                    tet_centroids = (a + b + c) / 4.0                                    # (F, 3)
                    Vtot = vols.sum()
                    if Vtot.abs() < 1e-12:
                        # Likely not closed or inconsistent winding; fall back
                        com = V.mean(0)
                    else:
                        com = (vols[:, None] * tet_centroids).sum(0) / Vtot

            translations.append(com)
            new_verts_list.append(V - com)   # shift so COM -> origin

        translations = torch.stack(translations, dim=0).to(device)

        # Rebuild a Meshes; keep textures/materials if present
        centered = Meshes(verts=new_verts_list, faces=faces_list, textures=meshes.textures)
        return centered, translations


    def _center_model_to_centroid(self):
        # self.mesh is a PyTorch3D Meshes (batched or single); recenter it on the vertex mean.
        centered_meshes, com = self.recenter_meshes_to_com(self.mesh, method="vertex")
        print("Translation subtracted (per mesh):", com)  # original COM positions

        # If you ever want to move it back later:
        # original_verts = centered_meshes.verts_list()[i] + com[i]

        self.mesh = centered_meshes


    def _load_cad_model(self):
        """
        Loads the cad model into memory.
        Here, we can select the light model or the detailed model.
        """
        if self.use_light_model:
            self.cad_file = self.local_path + "assets/final_model.obj"
        else:
            # cad_file = self.local_path + "assets/obj_000001.ply"
            self.cad_file = self.local_path + "assets/sse_only.obj"
            # self.cad_file = self.local_path + "assets/sse_only_real_texture.obj"

        # self.mesh = self.util.load_cad_mesh(self.cad_file, device=self.device)
        self.mesh = self.util.load_cad_mesh_with_texture(self.cad_file, device=self.device)


    def set_number_of_iterations(self, num_iter):
        """
        Sets the number of iterations for the optimization.
        """
        self.num_iter = num_iter
        print(f"[INFO]: Number of iterations set to {self.num_iter}")

    def _setup_camera_and_renderer(self):
        """
        Sets up the camera and some renderers (i.e., Phong, Silhouette).
        """

        # Get the intrinsic matrix to be used for rendering.
        self.K = self.util.get_camera_intrinsics(scale_factor=self.scale_factor)

        # Create the PyTorch3D perspective camera based on the K matrix.
        self.cameras = self.util.create_perspective_camera(self.K, image_size=(self.im_size,), device=self.device)

        # These are the PyTorch3D renderers used to create images of the CAD model given a camera.
        self.silhouette_renderer = self.util.create_silhouette_renderer(self.cameras, self.im_size, self.device)
        self.phong_renderer = self.util.create_phong_renderer(self.cameras, self.im_size, self.device)




    def add_camera_roll(self, R, T, roll_deg, device=None):
        """
        Add a Z-axis camera roll (degrees) to an existing (R,T) in PyTorch3D convention.
        Returns R_new, T_new with the same rank as inputs.
        """
        # to tensors
        R = torch.as_tensor(R)
        T = torch.as_tensor(T)
        dev   = device if device is not None else R.device
        dtype = R.dtype if R.is_floating_point() else torch.float32
        R = R.to(dev, dtype)
        T = T.to(dev, dtype)

        # ensure batch
        unbatched = (R.ndim == 2)  # (3,3) vs (1,3,3)
        if unbatched:
            R = R[None, ...]   # (1,3,3)
            T = T[None, ...]   # (1,3)

        # recover camera center C from (R, T):  T = -R^T C  =>  C = -R T
        C = -torch.matmul(R, T[..., None]).squeeze(-1)   # (B,3)

        # build Rz(roll) about camera Z
        theta = torch.tensor(float(roll_deg), dtype=dtype, device=dev)
        c, s = torch.cos(torch.deg2rad(theta)), torch.sin(torch.deg2rad(theta))
        z = torch.tensor(0.0, dtype=dtype, device=dev)
        o = torch.tensor(1.0, dtype=dtype, device=dev)
        Rz = torch.stack([
            torch.stack([ c, -s, z]),   # [ [c,-s,0],
            torch.stack([ s,  c, z]),   #   [s, c,0],
            torch.stack([ z,  z,  o])   #   [0, 0,1] ]
        ], dim=0)                        # (3,3)
        Rz = Rz.expand(R.shape[0], -1, -1)  # (B,3,3)

        # compose and recompute T to keep C fixed
        R_new = torch.matmul(Rz, R)                                  # (B,3,3)
        T_new = -torch.matmul(R_new.transpose(1, 2), C[..., None]).squeeze(-1)  # (B,3)

        if unbatched:
            R_new = R_new[0]
            T_new = T_new[0]
        return R_new, T_new





    def create_test_reference_image(self, distance=3, elevation=50, azimuth=-40, roll=0):
        """
        Renders a reference image for testing
        """

        # Get Rotation and Translation from camera position
        R, T = self.util.get_camera_pose(distance, elevation, azimuth, device=self.device)


        # Add roll (try mode="camera" first; if no visible spin, try mode="world")
        R, T = self.util.add_camera_roll_to_RT(R, T, roll_deg=roll, device=self.device, mode="world")


        self.R_ref, self.T_ref = R, T

        # Render a silhouette and an image of the CAD model to use as reference for tests.
        self.silhouette_ref, image_ref = self.util.render_mesh(
            self.mesh,
            self.silhouette_renderer,
            self.phong_renderer,
            R, T
        )
        self.image_ref = image_ref.cpu().numpy()

        # Check dimensions
        print(self.image_ref.shape)
        print(self.silhouette_ref.shape)


        return self.image_ref


    def create_reference_image_from_cutout(
        self,
        cutout_path: str,
        *,
        use_alpha_if_present: bool = True,
        white_thresh: float = 0.9,   # used only if we need to infer a mask from RGB
    ):
        """
        Load a PNG cutout and prepare:
          - self.image_ref: (1, H, W, 4) float32 NumPy in [0,1], composited on WHITE (alpha=1)
          - self.silhouette_ref: (1, H, W, 4) float32 Torch on self.device, white FG, alpha=mask

        No resizing is performed here.
        """
        import numpy as np
        import torch
        from PIL import Image

        # Load as RGBA
        img = Image.open(cutout_path).convert("RGBA")
        rgba = np.asarray(img).astype(np.float32) / 255.0          # (H, W, 4)
        rgb   = rgba[..., :3]
        alpha = rgba[..., 3]

        # Build mask (sil) in [0,1]
        if use_alpha_if_present and (alpha.max() - alpha.min()) > 1e-6:
            sil = alpha.astype(np.float32)
        else:
            # Infer mask from RGB: treat near-white as foreground
            # (If your object isn’t white, swap this to a "non-black" test or provide a mask.)
            sil = (rgb.max(axis=-1) > white_thresh).astype(np.float32)

        # -------- image_ref: composite over WHITE, alpha = 1
        rgb_on_white = rgb * sil[..., None] + (1.0 - sil[..., None]) * 1.0
        rgba_white = np.concatenate(
            [rgb_on_white, np.ones_like(sil, dtype=np.float32)[..., None]], axis=-1
        )  # (H, W, 4), alpha=1

        self.image_ref = rgba_white[None, ...].astype(np.float32)  # (1, H, W, 4) NumPy

        # -------- silhouette_ref: white FG (1,1,1), transparent BG (alpha = mask)
        sil_rgb = np.repeat(sil[..., None], 3, axis=-1)            # (H, W, 3)
        sil_rgba = np.concatenate([sil_rgb, sil[..., None]], axis=-1).astype(np.float32)  # (H, W, 4)


        self.silhouette_ref = torch.from_numpy(sil_rgba[None, ...]).to(self.device)      # (1, H, W, 4)

        # self.silhouette_ref = sil_rgba

        return self.image_ref


    def create_reference_image_from_cutout_old(
        self,
        cutout_path: str,
        *,
        use_alpha_if_present: bool = True,
        white_thresh: float = 0.9,   # used only if we need to infer a mask from RGB
    ):
        """
        Load a PNG cutout and prepare:
          - self.image_ref: (1, H, W, 4) float32 in [0,1], with a WHITE background
          - self.silhouette_ref: (H, W) float32 torch tensor in [0,1]

        No resizing is performed here.
        """
        import numpy as np
        import torch
        from PIL import Image

        # Load RGBA (preserve alpha if present)
        img = Image.open(cutout_path).convert("RGBA")
        rgba = np.asarray(img).astype(np.float32) / 255.0      # (H, W, 4)
        rgb   = rgba[..., :3]
        alpha = rgba[..., 3]

        # Build silhouette (H, W) in [0,1]
        # Prefer alpha if it's informative; otherwise infer: white (foreground) vs black (background)
        if use_alpha_if_present and (alpha.max() - alpha.min()) > 1e-6:
            sil = alpha.astype(np.float32)
        else:
            # If there's no useful alpha, assume 'white foreground on black background'
            # Foreground where any channel is near white
            sil = (rgb.max(axis=-1) > white_thresh).astype(np.float32)

        # Compose onto a WHITE background (so background is visibly white even if input was black)
        # rgb_out = sil * rgb + (1 - sil) * white
        rgb_out = rgb * sil[..., None] + (1.0 - sil[..., None]) * 1.0  # white = 1.0

        # Keep an alpha channel; using the silhouette as alpha is often convenient downstream
        alpha_out = sil

        rgba_out = np.concatenate([rgb_out, alpha_out[..., None]], axis=-1).astype(np.float32)  # (H, W, 4)

        # Store with a batch dimension: (1, H, W, 4)
        self.image_ref = rgba_out[None, ...]
        self.silhouette_ref = torch.from_numpy(sil).to(self.device)

        return self.image_ref





    def display_test_reference_image(self):
        """
        Displays the reference image for testing
        """

        plt.figure(figsize=(5, 5))


        # Reference image is numpy (1, 480, 720, 4). Look into keeping Ref image as a torch tensor
        # until we need to plot it.
        image_ref_np = self.image_ref.squeeze()
        # Plot reference image
        plt.subplot(1, 2, 1)
        plt.imshow(image_ref_np)
        plt.title("Phong Rendering")
        plt.axis('off')
        plt.tight_layout()

        # Silhouette is torch tensor torch.Size([1, 480, 720, 4])
        silhouette_np = self.silhouette_ref.cpu().numpy().squeeze()[..., 3]
        # Plot Silhouette
        plt.subplot(1, 2, 2)
        plt.imshow(silhouette_np, cmap="gray")  # Alpha channel
        plt.title("Silhouette (alpha)")
        plt.axis('off')

        plt.show()

    def display_initial_and_reference(self):

        plt.figure()

        # At the moment, the image returned by model() is just a mask, not the
        # rendering. To get the rendering, we need to call the Phong renderer
        # explicitly. Maybe, we should not return any image or mask and simply
        # recalculate them from R, t, Camera whenever needed.
        _, image_init, R, T = self.model()

        # Plot initial image
        plt.subplot(1, 2, 1)
        initial_pose_img = image_init.detach().squeeze().cpu().numpy()
        plt.imshow(initial_pose_img, cmap="gray")
        plt.grid(False)
        plt.title("Starting position")
        plt.axis('off')

        # Silhouette is torch tensor torch.Size([1, 480, 720, 4])
        silhouette_np = self.silhouette_ref.cpu().numpy().squeeze()[..., 3]
        # Plot Silhouette
        plt.subplot(1, 2, 2)
        plt.imshow(silhouette_np, cmap="gray")  # Alpha channel
        plt.title("Reference mask")
        plt.axis('off')

        plt.show()


    def display_current_image(self):

        plt.figure(figsize=(5, 5))

        # At the moment, the image returned by model() is just a mask, not the
        # rendering. To get the rendering, we need to call the Phong renderer
        # explicitly. Maybe, we should not return any image or mask and simply
        # recalculate them from R, t, Camera whenever needed.
        _, image_init, R, T = self.model()


        # R, T = self.compute_camera_pose(model.camera_position, model.device)

        rendered = self.phong_renderer(
            meshes_world=self.model.meshes.clone(),
            R=R,
            T=T
        )[0]  # (H, W, 4)


        rgb_rendered = rendered[..., :3].detach().cpu().numpy()


        # Plot the Phong rendering of the current pose
        # plt.subplot(1, 2, 1)
        plt.imshow(rgb_rendered)
        plt.title("Phong Rendering")
        plt.axis('off')
        # plt.tight_layout()


    def display_image(self, distance=2, elevation=50, azimuth=45, eye_override=None):

        plt.figure(figsize=(5, 5))

        # At the moment, the image returned by model() is just a mask, not the
        # rendering. To get the rendering, we need to call the Phong renderer
        # explicitly. Maybe, we should not return any image or mask and simply
        # recalculate them from R, t, Camera whenever needed.
        # _, image_init, R, T = self.model()


        if eye_override is not None:
            # Get Rotation and Translation from camera position
            R, T = self.util.get_camera_pose(eye_override=eye_override, device=self.device)
        else:
            # Get Rotation and Translation from camera position
            R, T = self.util.get_camera_pose(distance, elevation, azimuth, device=self.device)


        rendered = self.phong_renderer(
            meshes_world=self.mesh,
            R=R,
            T=T
        )[0]  # (H, W, 4)


        rgb_rendered = rendered[..., :3].detach().cpu().numpy()


        # Plot the Phong rendering of the requested pose
        # plt.subplot(1, 2, 1)
        plt.imshow(rgb_rendered)
        plt.title("Phong Rendering")
        plt.axis('off')
        # plt.tight_layout()


    def render_sse_image(self, distance=2, elevation=0, azimuth=0, eye_override=None):

        if eye_override is not None:
            # Get Rotation and Translation from camera position
            R, T = self.util.get_camera_pose(eye_override=eye_override, device=self.device)
        else:
            # Get Rotation and Translation from camera position
            R, T = self.util.get_camera_pose(distance, elevation, azimuth, device=self.device)


        rendered = self.phong_renderer(
            meshes_world=self.mesh,
            R=R,
            T=T
        )[0]  # (H, W, 4)


        rgb_rendered = rendered[..., :3].detach().cpu().numpy()


        return rgb_rendered



    def display_grid_of_images(self):

        # Set batch size - this is the number of different viewpoints from which we want to render the mesh.
        batch_size = 20

        # Create a batch of meshes by repeating the CAD mesh and its textures.
        # Meshes has an `extend` method that does this and also extends the textures.
        meshes = self.mesh.extend(batch_size)

        # Get a batch of viewing angles.
        elev = torch.linspace(0, 180, batch_size)
        azim = torch.linspace(-180, 180, batch_size)

        # The camera helper methods support mixed-type inputs and broadcasting, so we can
        # pass a single distance (dist=1.0) as a float and the elevation and azimuth
        # angles for each viewpoint as tensors. R and T are created on the same device
        # as the mesh to avoid a CPU/GPU mismatch in the renderer.
        R, T = look_at_view_transform(dist=1.0, elev=elev, azim=azim, device=self.device)



        # cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

        # Place a point light in front of the object
        # lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])

        # Move the light back in front of the object (in the original tutorial the mesh faces the -z direction).
        # lights.location = torch.tensor([[0.0, 0.0, -3.0]], device=device)

        # We can pass arbitrary keyword arguments to the rasterizer/shader via the renderer
        # so the renderer does not need to be reinitialized if any of the settings change.
        # images = renderer(meshes, cameras=cameras, lights=lights)
        images = self.phong_renderer(
            meshes_world=meshes,
            R=R,
            T=T
        )  # (batch_size, H, W, 4)

        print(images.shape)

        # Display grid of images
        image_grid(images.cpu().numpy(), rows=4, cols=5, rgb=True)




    def display_current_and_reference_as_overlay(self, show_mask = False):

        if show_mask:
            sil2 = (self.image_ref[..., :3].max(-1) != 1).astype(np.float32).squeeze()
            bg_img = self.util.create_background_from_mask(sil2)
            image_uint8 = self.util.render_overlay(self.model, self.phong_renderer, bg_img)
        else:
            image_uint8 = self.render_current_and_reference_as_overlay()


        plt.figure(figsize=(5, 5))
        plt.imshow(image_uint8)
        plt.show()

    def get_reference_image(self):
        return self.image_ref


    def get_reference_silhouette(self):
        return self.silhouette_ref


    def render_current_and_reference_as_overlay(self):

        # sil2 = (self.image_ref[..., :3].max(-1) != 1).astype(np.float32).squeeze()
        # bg_img = self.util.create_background_from_mask(sil2)

        bg_img = self.image_ref[..., 0:3].squeeze()

        image_uint8 = self.util.render_overlay(self.model, self.phong_renderer, bg_img)

        return image_uint8



    def set_reference_image_from_file(self,
                                      image_path,
                                      output_size: Optional[Tuple[int, int]] = None,  # (W_out, H_out). If None, keep original canvas size.
    ):
        """
        Sets the reference image for the pose alignment (reads the image from a file).
        """

        # No resize or centering
        _ = self.create_reference_image_from_cutout(
          image_path,
          use_alpha_if_present=True,
          white_thresh = 0.0)

        # self.image_ref is (1, H, W, 4) float32 in [0, 1]: RGB composited on white, alpha = 1
        cutout = self.image_ref

        _, H, W, _ = cutout.shape
        if output_size is None:
            out_w, out_h = W, H
        else:
            out_w, out_h = int(output_size[0]), int(output_size[1])

        # cutout: (1,480,720,4) or (480,720,4) uint8 / float

        print("ref image from inside read cutout function:", cutout.shape)

        # # centering and resize
        # centered_rgba, info = self.util.center_cutout_keep_scale(
        #     cutout_rgba=cutout,
        #     output_size=(out_w, out_h),   # (width,height) keep same size; or change to e.g. (800, 600)
        #     pad=20,
        #     overflow="shrink",         # or "crop"/"error"
        #     return_transform=True
        # )


        # centered_rgba, centered_mask = self.util.center_cutout_keep_scale_2(
        #     cutout_rgba=cutout,
        #     output_size=(out_w, out_h),
        #     pad=20,
        #     overflow="shrink",
        #     return_transform=True)



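        # create_reference_image_from_cutout() has already set both references:
        # self.image_ref (NumPy, white background) and self.silhouette_ref (torch, alpha = mask).
        # The silhouette is not overwritten with the image, whose alpha channel is 1 everywhere
        # and which the display methods cannot call .cpu() on.
        self.image_ref = cutout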



    def read_rgb_cutout_black_bg_from_file(self, image_path: str,
                                          output_size: tuple[int, int]) -> np.ndarray:

        # Read cutout
        cutout = self.util.read_rgb_cutout_black_bg(image_path)  # (H,W,3) float32 in [0,1]

        # Centers the cutout
        centered_cutout = self.util.center_cutout_rgb_uint8(cutout)  # -> (H, W, 3) uint8

        # Crop the centered cutout preserving the scale
        crop_cutout = self.util.crop_center_to_size_uint8(centered_cutout, out_size=output_size)

        # cropped is (H, W, 3) uint8
        rgba = self.util.add_alpha_from_black_bg_uint8(crop_cutout, black_thresh=0, mode="binary")  # -> (H, W, 4) uint8

        # img: (H,W,3) or (H,W,4), any numeric dtype
        batched_rgba = self.util.to_float_batched_rgba_white_bg_preserve_alpha(rgba)  # -> (1,H,W,4) float in [0,1]

        # Set ref_image
        self.image_ref = batched_rgba


        return batched_rgba



    def set_reference_mask(self, image_path):
        """
        Sets the reference mask for the pose alignment (reads the mask from a file).
        """

        # Load mask image
        img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise IOError(f"Failed to load image: {image_path}")

        # Threshold mask to make into (0,255) range
        _, binary_mask = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY)

        self.prepare_mask_from_numpy(binary_mask)


    def prepare_mask_from_numpy(self, mask_np):

        from PIL import Image
        import numpy as np


        # mask_centered = self.util.center_mask_in_image_from_array(mask_np)[1]

        mask_centered = self.util.center_mask_preserve_ratio(mask_np, output_size=(256,256))[1]

        # _, mask_centered = self.util.center_mask_keep_scale(mask_np, output_size=(256, 256), overflow="crop", pad=8)


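        # cv2.resize takes dsize as (width, height), hence the reversed im_size.
        mask_resized = cv2.resize(mask_centered, self.im_size[::-1], interpolation=cv2.INTER_LANCZOS4)
        mask_rgba = np.stack([mask_resized]*4, axis=-1)[None, ...]  # (1, H, W, 4)
        # Invert so that nonzero (foreground) mask pixels become dark on a white background.
        mask_rgba = 255 - mask_rgba
        self.image_ref = mask_rgba.astype(np.float32) / 255.0

        self.silhouette_ref = mask_rgba.astype(np.float32) / 255.0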


    def init_model(self, distance=2, elevation=110, azimuth=50, roll_deg=0.0, learning_rate=0.05, eye_override=None):
        """
        Initializes the model with a given camera position.
        """

        if eye_override is not None:
            # Set the camera's 3-D location from the input location (3-D position)
            self.camera_pos = eye_override
            print(f"  Camera position (x, y, z) = {self.camera_pos.cpu().numpy()[0]}")
        else:
            # Set the camera's 3-D location from the input spherical angles
            self.camera_pos = self.util.get_camera_position(distance, elevation, azimuth, device=self.device)
            print(f"  distance = {distance}, elevation = {elevation}, azimuth = {azimuth}")


        # Initialize the model
        self.model = Model(
            meshes=self.mesh,
            renderer=self.silhouette_renderer,
            image_ref=self.image_ref,
            initial_position=self.camera_pos.cpu().numpy()[0],
            roll_init_deg = roll_deg
            # camera_pos=self.camera_pos.detach().cpu().numpy()[0],
            # device=self.device
        ).to(self.device)




        # Set the optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
            )


    def set_learning_rate(self, new_lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        print(f"[INFO]: Learning rate set to {new_lr}")


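    def run_optimization(self, tol = 0.1, show_mask=False):
        """
        Runs the camera pose optimization and returns the estimated (R, T).
        A GIF writer (./optimization_sequence.gif) is passed to the optimizer.
        Note: `tol` is currently unused.
        """
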
        # print("Shape image_ref before conversion to silhouette:", self.image_ref.shape)

        if show_mask:
            # Make a mask out of image_ref
            sil2 = (self.image_ref[..., :3].max(-1) != 1).astype(np.float32).squeeze()
            bg_img = self.util.create_background_from_mask(sil2)
        else:
            bg_img = self.image_ref[..., 0:3].squeeze()

        # print("Shape image_ref after conversion to silhouette:", bg_img.shape)
        # print("This is the bg_mask that is passed to util.optimize_camera_pose_best_loss()",)
        # plt.figure()
        # plt.imshow(bg_img)
        # plt.show()



        self.writer = imageio.get_writer("./optimization_sequence.gif", mode='I', duration=0.5)

        # self.R, self.T = self.util.optimize_camera_pose(
        self.R, self.T = self.util.optimize_camera_pose_best_loss(
            model=self.model,
            optimizer=self.optimizer,
            mask_background=bg_img,
            phong_renderer=self.phong_renderer,
            writer=self.writer,
            num_iter=self.num_iter,
            log_every=10,
            patience=10
            )
        return self.R, self.T

    def get_results(self):
        return self.R, self.T, self.model.camera_position


    def printout_results(self):
        print("\nCamera position from model:\n", self.model.camera_position.detach().cpu().numpy())

        # Convert to NumPy
        R = self.R.detach().cpu().numpy()
        T = self.T.detach().cpu().numpy()

        # Print results
        np.set_printoptions(precision=4, suppress=True)
        print("\nEstimated Rotation Matrix (R):")
        print(R)

        print("\nEstimated Translation Vector (T):")
        print(T.reshape(-1))
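
As a side note on `recenter_meshes_to_com`: the `'volume'` option treats each triangle as the base of a tetrahedron whose apex is the origin, and averages the tetrahedron centroids weighted by signed volume. The sketch below reproduces that computation in plain Python on a small hand-built tetrahedron. The names `volume_centroid`, `VERTS`, and `FACES` are illustrative and not part of the notebook, and the mesh is assumed to be closed with consistent outward winding.

```python
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]


def cross(u: Vec3, v: Vec3) -> Vec3:
    """Cross product of two 3-D vectors."""
    return (
        u[1] * v[2] - u[2] * v[1],
        u[2] * v[0] - u[0] * v[2],
        u[0] * v[1] - u[1] * v[0],
    )


def dot(u: Vec3, v: Vec3) -> float:
    """Dot product of two 3-D vectors."""
    return u[0] * v[0] + u[1] * v[1] + u[2] * v[2]


def volume_centroid(
    verts: Sequence[Vec3], faces: Sequence[Tuple[int, int, int]]
) -> Tuple[Vec3, float]:
    """Return (centroid, signed volume) of a closed triangle mesh.

    Hypothetical helper: mirrors the 'volume' branch of recenter_meshes_to_com,
    but raises instead of falling back to the vertex mean.
    """
    total = 0.0
    acc: List[float] = [0.0, 0.0, 0.0]
    for i, j, k in faces:
        a, b, c = verts[i], verts[j], verts[k]
        # Signed volume of the tetrahedron (origin, a, b, c).
        vol = dot(a, cross(b, c)) / 6.0
        total += vol
        # Centroid of that tetrahedron is (0 + a + b + c) / 4.
        for axis in range(3):
            acc[axis] += vol * (a[axis] + b[axis] + c[axis]) / 4.0
    if abs(total) < 1e-12:
        raise ValueError("zero total volume: mesh may be open or inconsistently wound")
    return (acc[0] / total, acc[1] / total, acc[2] / total), total


# A tetrahedron offset from the origin, with outward-facing triangles.
VERTS = [(1.0, 1.0, 1.0), (2.0, 1.0, 1.0), (1.0, 2.0, 1.0), (1.0, 1.0, 2.0)]
FACES = [(0, 2, 1), (0, 1, 3), (0, 3, 2), (1, 2, 3)]

centroid, volume = volume_centroid(VERTS, FACES)
print("centroid:", centroid, "volume:", volume)
```

For a single tetrahedron the volume centroid coincides with the vertex mean, which makes the result easy to check by hand. On a real CAD mesh with uneven vertex density the two can differ, which is the reason `recenter_meshes_to_com` offers more than one method.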




### Gradio UI function for pose selection
Call this function whenever manual pose selection is needed.
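
The sliders describe the camera by distance, elevation, and azimuth, while the model stores a Cartesian camera position. Here is a minimal sketch of the conversion in both directions, assuming the PyTorch3D convention (Y up, azimuth 0 on +Z, positive azimuth toward +X) and an object at the origin. The function names are illustrative, and `self.util.get_camera_position` is assumed, not confirmed, to follow the same convention.

```python
import math
from typing import Tuple


def spherical_to_camera_center(
    dist: float, elev_deg: float, azim_deg: float
) -> Tuple[float, float, float]:
    """Camera center (x, y, z) from distance, elevation, and azimuth in degrees."""
    elev = math.radians(elev_deg)
    azim = math.radians(azim_deg)
    x = dist * math.cos(elev) * math.sin(azim)
    y = dist * math.sin(elev)
    z = dist * math.cos(elev) * math.cos(azim)
    return x, y, z


def camera_center_to_spherical(
    x: float, y: float, z: float
) -> Tuple[float, float, float]:
    """Inverse mapping: (distance, elevation, azimuth) in degrees."""
    dist = math.sqrt(x * x + y * y + z * z)
    # Clamp the horizontal radius to avoid a degenerate atan2 at the poles.
    rho = math.sqrt(max(x * x + z * z, 1e-12))
    elev = math.degrees(math.atan2(y, rho))   # in [-90, 90]
    azim = math.degrees(math.atan2(x, z))     # in (-180, 180]
    return dist, elev, azim


def wrap_degrees(angle: float) -> float:
    """Wrap an angle into [-180, 180) so it fits a slider range."""
    return ((angle + 180.0) % 360.0) - 180.0


# Round trip: slider values -> camera center -> slider values.
center = spherical_to_camera_center(6.0, 30.0, 45.0)
recovered = camera_center_to_spherical(*center)
print("center:", center)
print("recovered (dist, elev, azim):", recovered)
```

The elevation sliders in this notebook stop at ±89 degrees, which keeps the camera away from the poles where the azimuth is undefined.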

In [None]:
# def gradio_UI_select_pose(pose_estimator):

#     import gradio as gr
#     from PIL import Image
#     import numpy as np
#     import threading


#     def render_pose_image(distance, azimuth, elevation, roll):
#         try:
#             # Update the model with new pose
#             pose_estimator.init_model(distance=distance, elevation=elevation, azimuth=azimuth, roll_deg=float(roll), learning_rate=0.0)

#             # Display the new overlay image
#             overlay_img = pose_estimator.render_current_and_reference_as_overlay()

#             # Convert to PIL if needed
#             if isinstance(overlay_img, np.ndarray):
#                 if overlay_img.dtype in [np.float32, np.float64]:
#                     overlay_img = (overlay_img * 255).clip(0, 255).astype(np.uint8)
#                 image_pil = Image.fromarray(overlay_img)
#                 return image_pil
#             else:
#                 return overlay_img  # If already PIL Image

#         except Exception as e:
#             print(f"Error rendering overlay: {e}")
#             return None


#     app = None  # launch handle is stored here

#     def on_done():
#         def _close():
#             try:
#                 app.close()
#             except Exception as e:
#                 print("Close error:", e)
#         threading.Thread(target=_close, daemon=True).start()
#         return "✅ Closing the interface…"

#     # Gradio UI


#     with gr.Blocks(title="CAD Pose Renderer") as demo:
#         gr.Markdown("Adjust the camera pose to render the CAD model. Click **Done** to stop the app.")
#         with gr.Row():
#             distance  = gr.Slider(minimum=0.5, maximum=20.0, step=0.1, value=6, label="Distance")
#             elevation = gr.Slider(minimum=-89.0, maximum=89.0, step=1, value=0, label="Elevation")
#             azimuth   = gr.Slider(minimum=-179.0, maximum=180.0, step=1, value=0, label="Azimuth")
#             roll      = gr.Slider(minimum=-179.0, maximum=180.0, step=1, value=0, label="Roll")
#             # done_btn = gr.Button("✅ Done", variant="primary")
#         out_img = gr.Image(type="pil", label="Rendered Image")
#         status  = gr.Markdown("")

#         # Update image whenever a slider changes (use a button instead if you prefer manual updates)
#         distance.change(render_pose_image, [distance, azimuth, elevation, roll], out_img)
#         azimuth.change(render_pose_image,  [distance, azimuth, elevation, roll], out_img)
#         elevation.change(render_pose_image,[distance, azimuth, elevation, roll], out_img)
#         roll.change(render_pose_image,     [distance, azimuth, elevation, roll], out_img)

#         # Initial render on load
#         demo.load(render_pose_image, [distance, azimuth, elevation, roll], out_img)

#         # Done -> close server
#         # done_btn.click(on_done, outputs=status)


#     app = demo.launch(inline=True, prevent_thread_lock=True, debug=False)

#     return app





# Gradio interface for selecting pose

## (Old) Not capturing values

In [None]:
# def gradio_UI_select_pose(pose_estimator):
#     import gradio as gr
#     from PIL import Image
#     import numpy as np
#     import threading
#     import torch

#     # ---- helpers ------------------------------------------------------------
#     def camera_center_to_dist_elev_azim(C: torch.Tensor):
#         """PyTorch3D conv: Y-up; azim=0 -> +Z; +azim toward +X."""
#         if C.ndim == 2 and C.shape[0] == 1:
#             C = C[0]
#         x, y, z = C[0], C[1], C[2]
#         dist = torch.linalg.norm(C)
#         rho  = torch.sqrt(torch.clamp(x*x + z*z, min=1e-12))
#         elev = torch.rad2deg(torch.atan2(y, rho))   # [-90, 90]
#         azim = torch.rad2deg(torch.atan2(x, z))     # (-180, 180]
#         return float(dist.detach().cpu()), float(elev.detach().cpu()), float(azim.detach().cpu())

#     def get_initial_slider_values_from_model(pose_estimator):
#         """Return (distance, elevation, azimuth, roll) as floats for slider init."""
#         C = pose_estimator.model.camera_position.detach()
#         d, e, a = camera_center_to_dist_elev_azim(C)
#         roll_param = getattr(pose_estimator.model, "roll_deg", None)
#         r = float(roll_param.detach().cpu()) if roll_param is not None else 0.0
#         # Clamp/wrap into slider ranges
#         e = max(-89.0, min(89.0, e))
#         a = ((a + 180.0) % 360.0) - 180.0
#         r = ((r + 180.0) % 360.0) - 180.0
#         return d, e, a, r

#     # Keep the argument order consistent: (distance, elevation, azimuth, roll)
#     def render_pose_image(distance, elevation, azimuth, roll):
#         try:
#             pose_estimator.init_model(
#                 distance=distance,
#                 elevation=elevation,
#                 azimuth=azimuth,
#                 roll_deg=float(roll),
#                 learning_rate=0.0
#             )
#             overlay_img = pose_estimator.render_current_and_reference_as_overlay()
#             if isinstance(overlay_img, np.ndarray):
#                 if overlay_img.dtype in (np.float32, np.float64):
#                     overlay_img = (overlay_img * 255).clip(0, 255).astype(np.uint8)
#                 return Image.fromarray(overlay_img)
#             return overlay_img
#         except Exception as e:
#             print(f"Error rendering overlay: {e}")
#             return None

#     # Do both: set slider values and render once
#     def init_and_render():
#         d, e, a, r = get_initial_slider_values_from_model(pose_estimator)
#         img = render_pose_image(d, e, a, r)
#         return d, e, a, r, img

#     app = None
#     def on_done():
#         def _close():
#             try:
#                 app.close()
#             except Exception as e:
#                 print("Close error:", e)
#         threading.Thread(target=_close, daemon=True).start()
#         return "✅ Closing the interface…"

#     # ---- UI -----------------------------------------------------------------
#     with gr.Blocks(title="CAD Pose Renderer") as demo:
#         gr.Markdown("Adjust the camera pose to render the CAD model. Click **Done** to stop the app.")
#         with gr.Row():
#             distance  = gr.Slider(minimum=0.5,  maximum=20.0,  step=0.1, value=2.0, label="Distance")
#             elevation = gr.Slider(minimum=-89.0, maximum=89.0,  step=1,   value=0.0, label="Elevation")
#             azimuth   = gr.Slider(minimum=-180.0, maximum=180.0, step=1,   value=0.0, label="Azimuth")
#             roll      = gr.Slider(minimum=-180.0, maximum=180.0, step=1,   value=0.0, label="Roll")
#             # done_btn = gr.Button("✅ Done", variant="primary")

#         out_img = gr.Image(type="pil", label="Rendered Image")
#         status  = gr.Markdown("")

#         # Initialize sliders and render immediately on load (single callback)
#         demo.load(
#             init_and_render,
#             inputs=None,
#             outputs=[distance, elevation, azimuth, roll, out_img]
#         )

#         # Live updates when sliders move (match arg order!)
#         distance.change( render_pose_image, [distance, elevation, azimuth, roll], out_img)
#         elevation.change(render_pose_image, [distance, elevation, azimuth, roll], out_img)
#         azimuth.change(  render_pose_image, [distance, elevation, azimuth, roll], out_img)
#         roll.change(     render_pose_image, [distance, elevation, azimuth, roll], out_img)

#         # done_btn.click(on_done, outputs=status)

#     app = demo.launch(inline=True, prevent_thread_lock=True, debug=False)
#     return app


## (New) Capturing values

In [None]:
def gradio_UI_select_pose(pose_estimator):
    import gradio as gr
    from PIL import Image
    import numpy as np
    import threading
    import torch
    import json

    # ---- helpers ------------------------------------------------------------
    def camera_center_to_dist_elev_azim(C: torch.Tensor):
        if C.ndim == 2 and C.shape[0] == 1:
            C = C[0]
        x, y, z = C[0], C[1], C[2]
        dist = torch.linalg.norm(C)
        rho  = torch.sqrt(torch.clamp(x*x + z*z, min=1e-12))
        elev = torch.rad2deg(torch.atan2(y, rho))   # [-90, 90]
        azim = torch.rad2deg(torch.atan2(x, z))     # (-180, 180]
        return float(dist.detach().cpu()), float(elev.detach().cpu()), float(azim.detach().cpu())

    def get_initial_slider_values_from_model(pose_estimator):
        C = pose_estimator.model.camera_position.detach()
        d, e, a = camera_center_to_dist_elev_azim(C)
        roll_param = getattr(pose_estimator.model, "roll_deg", None)
        r = float(roll_param.detach().cpu()) if roll_param is not None else 0.0
        # Clamp/wrap into slider ranges
        e = max(-89.0, min(89.0, e))
        a = ((a + 180.0) % 360.0) - 180.0
        r = ((r + 180.0) % 360.0) - 180.0
        return d, e, a, r

    # Keep arg order: (distance, elevation, azimuth, roll)
    def render_pose_image(distance, elevation, azimuth, roll):
        try:
            pose_estimator.init_model(
                distance=distance,
                elevation=elevation,
                azimuth=azimuth,
                roll_deg=float(roll),
                learning_rate=0.0
            )
            overlay_img = pose_estimator.render_current_and_reference_as_overlay()
            if isinstance(overlay_img, np.ndarray):
                if overlay_img.dtype in (np.float32, np.float64):
                    overlay_img = (overlay_img * 255).clip(0, 255).astype(np.uint8)
                return Image.fromarray(overlay_img)
            return overlay_img
        except Exception as e:
            print(f"Error rendering overlay: {e}")
            return None

    # init on load: set sliders + render once
    def init_and_render():
        d, e, a, r = get_initial_slider_values_from_model(pose_estimator)
        img = render_pose_image(d, e, a, r)
        return d, e, a, r, img

    app = None

    # Save current values to notebook vars and close UI
    def save_pose_and_close(distance, elevation, azimuth, roll):
        vals = {
            "distance": float(distance),
            "elevation": float(elevation),
            "azimuth": float(azimuth),
            "roll_deg": float(roll),
        }
        # Store into the notebook's namespace
        try:
            ip = get_ipython()
            if ip is not None:
                ip.user_ns["selected_pose"] = vals
                ip.user_ns["selected_distance"]  = vals["distance"]
                ip.user_ns["selected_elevation"] = vals["elevation"]
                ip.user_ns["selected_azimuth"]   = vals["azimuth"]
                ip.user_ns["selected_roll"]      = vals["roll_deg"]
        except Exception as e:
            print("Failed to save pose to notebook vars:", e)

        # Close the app
        def _close():
            try:
                app.close()
            except Exception as e:
                print("Close error:", e)
        threading.Thread(target=_close, daemon=True).start()

        return f"✅ Saved to variables: selected_pose (and individual scalars). Values: {json.dumps(vals)}"

    # ---- UI -----------------------------------------------------------------
    with gr.Blocks(title="CAD Pose Renderer") as demo:
        gr.Markdown("Adjust the camera pose to render the CAD model, then click **Done** to save the values to the notebook.")
        with gr.Row():
            distance  = gr.Slider(minimum=0.5,  maximum=20.0,  step=0.1, value=2.0, label="Distance")
            elevation = gr.Slider(minimum=-89.0, maximum=89.0,  step=1,   value=0.0, label="Elevation")
            azimuth   = gr.Slider(minimum=-180.0, maximum=180.0, step=1,   value=0.0, label="Azimuth")
            roll      = gr.Slider(minimum=-180.0, maximum=180.0, step=1,   value=0.0, label="Roll")
        out_img = gr.Image(type="pil", label="Rendered Image")
        with gr.Row():
            done_btn = gr.Button("✅ Done", variant="primary")
        status  = gr.Markdown("")

        # Initialize sliders and render immediately on load
        demo.load(init_and_render, inputs=None, outputs=[distance, elevation, azimuth, roll, out_img])

        # Live updates when sliders move (match arg order!)
        distance.change( render_pose_image, [distance, elevation, azimuth, roll], out_img)
        elevation.change(render_pose_image, [distance, elevation, azimuth, roll], out_img)
        azimuth.change(  render_pose_image, [distance, elevation, azimuth, roll], out_img)
        roll.change(     render_pose_image, [distance, elevation, azimuth, roll], out_img)

        # Save + close
        done_btn.click(save_pose_and_close, inputs=[distance, elevation, azimuth, roll], outputs=status)

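    # launch() returns a (server_app, local_url, share_url) tuple, which has no
    # close() method. Keep the Blocks object instead, since it exposes close().
    demo.launch(inline=True, prevent_thread_lock=True, debug=False)
    app = demo
    return app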
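
The helper `camera_center_to_dist_elev_azim` above inverts a spherical parameterisation of the camera center. As a dependency-free sketch, assuming the y-up convention implied by its `atan2` calls (x = d·cos(elev)·sin(azim), y = d·sin(elev), z = d·cos(elev)·cos(azim)), the round trip and the slider wrapping can be checked with plain `math`. The function names below are illustrative and are not part of the notebook.

```python
import math

def spherical_to_camera_center(dist, elev_deg, azim_deg):
    """Camera center (x, y, z) from distance, elevation, and azimuth in degrees."""
    elev, azim = math.radians(elev_deg), math.radians(azim_deg)
    x = dist * math.cos(elev) * math.sin(azim)
    y = dist * math.sin(elev)
    z = dist * math.cos(elev) * math.cos(azim)
    return x, y, z

def camera_center_to_spherical(x, y, z):
    """Inverse mapping, mirroring the atan2 formulas used by the UI helper."""
    dist = math.sqrt(x * x + y * y + z * z)
    rho = math.sqrt(max(x * x + z * z, 1e-12))
    elev = math.degrees(math.atan2(y, rho))
    azim = math.degrees(math.atan2(x, z))
    return dist, elev, azim

def wrap_degrees(angle):
    """Wrap an angle into [-180, 180), as done before setting the sliders."""
    return ((angle + 180.0) % 360.0) - 180.0

# Round trip: spherical -> Cartesian -> spherical should return the inputs.
center = spherical_to_camera_center(6.6, 15.0, -28.0)
recovered = camera_center_to_spherical(*center)
print(recovered)
print(wrap_degrees(190.0), wrap_degrees(-180.0))
```

The elevation sliders stop at ±89 degrees because the azimuth becomes ill-defined when the camera sits directly above or below the object (x and z both approach zero).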


### Download CAD files and textures

In [None]:
# Declare the utilities object
util = SSE_Util(local_path)

# Download cad models and texture
util.get_cad_model_files()

---
# Test 1: Estimate pose using a synthetic silhouette

In this example, we render a silhouette of the object at a known pose and use it as the reference image for the pose estimator.
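
As a rough illustration of how a silhouette can drive the fit, the sketch below computes a sum-of-squared-differences loss between two tiny binary masks. It assumes a pixel-wise squared difference, as in the PyTorch3D camera-position tutorial this notebook was adapted from; the loss actually used inside `PoseEstimator` may differ. `silhouette_loss` is a hypothetical helper written with plain lists so that it runs without any dependencies.

```python
def silhouette_loss(rendered, reference):
    """Sum of squared pixel differences between two same-sized masks."""
    if len(rendered) != len(reference):
        raise ValueError("masks must have the same height")
    total = 0.0
    for row_r, row_ref in zip(rendered, reference):
        if len(row_r) != len(row_ref):
            raise ValueError("masks must have the same width")
        for p, q in zip(row_r, row_ref):
            total += (p - q) ** 2
    return total

reference = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
# Same square shifted one pixel to the right.
shifted = [
    [0, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
]

loss_aligned = silhouette_loss(reference, reference)
loss_shifted = silhouette_loss(shifted, reference)
print(loss_aligned, loss_shifted)
```

The loss is zero only when the two silhouettes coincide and grows with the number of mismatched pixels, which is what lets gradient descent pull the rendered pose toward the reference.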

## Create a pose estimator

In [None]:
# Create a pose estimator
pose_est = PoseEstimator(
    local_path=local_path,
    use_light_model = False,
    image_size=(256, 256),
    scale_factor=1,
    num_iter=200
    )

## Use the model to create a test reference image

In [None]:
# Create a test reference image
#
# elevation = [-89, 89] degrees.
# azimuth   = [0,360] or [-180,180]
#
pose_est.create_test_reference_image(
    distance=6.6,
    elevation=0.0,
    azimuth=0.0,
    roll = 30.0
    )

# Display the reference image and its silhouette
pose_est.display_test_reference_image()






In [None]:
# Initialize pose-estimation model
pose_est.init_model(distance=6.6, elevation=0, azimuth=0, roll_deg=-30.0, learning_rate=0.15)


overlay_img = pose_est.render_current_and_reference_as_overlay()

print(overlay_img.shape)
# Plot the overlay of the current render and the reference silhouette
plt.figure()
plt.imshow(overlay_img, cmap='gray')
plt.title("Overlay: current render vs. reference silhouette")
plt.axis('off')
plt.show()


## Select an initial pose using the UI


In [None]:
app = gradio_UI_select_pose(pose_est)   # launches inline, non-blocking
# Close app programmatically
# app.close()

In [None]:
selected_pose         # dict with keys: distance, elevation, azimuth, roll_deg
selected_distance
selected_elevation
selected_azimuth
selected_roll

print(selected_pose)
pose_est.init_model(**selected_pose, learning_rate=0.05)
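
Because the UI is launched non-blocking, `selected_pose` only exists after **Done** has been clicked. A small guard, sketched below with a hypothetical helper `pose_or_default`, avoids a `NameError` if this cell runs too early. It takes the namespace as an argument (in the notebook, `globals()`), and the key names mirror those written by `save_pose_and_close`.

```python
REQUIRED_KEYS = ("distance", "elevation", "azimuth", "roll_deg")

def pose_or_default(namespace, default):
    """Return the pose saved by the UI, or `default` if none is available yet."""
    pose = namespace.get("selected_pose")
    if pose is None:
        return dict(default)
    missing = [k for k in REQUIRED_KEYS if k not in pose]
    if missing:
        raise KeyError(f"selected_pose is missing keys: {missing}")
    return {k: float(pose[k]) for k in REQUIRED_KEYS}

fallback = {"distance": 6.6, "elevation": 0.0, "azimuth": 0.0, "roll_deg": 0.0}

# Before clicking Done: nothing has been saved, so the fallback is used.
print(pose_or_default({}, fallback))

# After clicking Done: the saved values win.
saved = {"selected_pose": {"distance": 8, "elevation": 15, "azimuth": -28, "roll_deg": 25}}
print(pose_or_default(saved, fallback))
```

In the notebook, this would be called as `pose_or_default(globals(), fallback)`, and the returned dict could be unpacked into `init_model(**pose, learning_rate=0.05)`.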

## Use the selected pose as the initial pose of the optimization step

In [None]:
# Display an overlay with the current pose and reference mask
pose_est.display_current_and_reference_as_overlay()

## Run the optimization and display the resulting pose

In [None]:
# Set the number of iterations
pose_est.set_number_of_iterations(400)

# Estimate pose
R, T = pose_est.run_optimization(tol=0.001)

# Display an overlay with the current pose and reference mask
pose_est.display_current_and_reference_as_overlay()
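
The `tol` argument suggests that the optimization stops early once the loss is small enough. The exact criterion used by `run_optimization` is not shown in this notebook section, so the following is only a sketch under the assumption that iteration stops when the loss falls below `tol` or the iteration budget runs out. It uses plain gradient descent on a toy one-parameter loss, not the renderer, and both function names are hypothetical.

```python
def minimize_with_tolerance(loss_and_grad, x0, learning_rate, num_iter, tol):
    """Gradient descent that stops when loss < tol or after num_iter steps."""
    x = x0
    loss, _ = loss_and_grad(x)
    steps = 0
    while steps < num_iter and loss >= tol:
        _, grad = loss_and_grad(x)
        x -= learning_rate * grad
        loss, _ = loss_and_grad(x)
        steps += 1
    return x, loss, steps

def toy_loss(roll_deg, target=30.0):
    """Quadratic stand-in for a silhouette loss, minimal at the target roll."""
    diff = roll_deg - target
    return diff * diff, 2.0 * diff

x, loss, steps = minimize_with_tolerance(
    toy_loss, x0=-30.0, learning_rate=0.15, num_iter=400, tol=0.001
)
print(x, loss, steps)
```

With a tolerance in place, the iteration count set by `set_number_of_iterations` acts as an upper bound rather than a fixed cost, so a good initial pose from the UI can shorten the run considerably.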

In [None]:
print(pose_est.image_ref.shape)
print(pose_est.silhouette_ref.shape)

# Drop the alpha channel and the batch dimension to get an RGB image
m = pose_est.image_ref[..., :3].squeeze()
# m = 255 - m

# Plot reference image
plt.figure()
plt.imshow(m, cmap='gray')
plt.plot(128, 128, 'o', markersize=8, color='red')
plt.title("Reference Image")
plt.axis('off')
plt.show()


# Test 2: Estimate pose using an actual silhouette

In [None]:
# Create a pose estimator
pose_est_mask = PoseEstimator(
    local_path=local_path,
    use_light_model = False,
    image_size=(256, 256),
    scale_factor=1,
    num_iter=200
    )

## Set the reference image
This is an actual mask from the segmented object.
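
The method name `read_rgb_cutout_black_bg_from_file` suggests that the cutout is an RGB image on a black background from which a mask is derived. Its implementation is not shown here, so the sketch below only illustrates one plausible way to do this: mark a pixel as foreground when any channel exceeds a threshold. `alpha_from_black_background` and its `threshold` parameter are hypothetical.

```python
def alpha_from_black_background(rgb, threshold=0):
    """Return a binary mask: 1 where any channel exceeds `threshold`, else 0."""
    return [[1 if max(pixel) > threshold else 0 for pixel in row] for row in rgb]

# A 2x3 RGB image with 8-bit channel values; black pixels are background.
rgb = [
    [(0, 0, 0), (120, 80, 40), (0, 0, 0)],
    [(0, 0, 5), (255, 255, 255), (0, 0, 0)],
]

mask = alpha_from_black_background(rgb)
mask_denoised = alpha_from_black_background(rgb, threshold=10)
print(mask)
print(mask_denoised)
```

A non-zero threshold trades a slightly eroded silhouette for robustness against near-black compression noise in the background.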

In [None]:
# image_path = local_path + "assets/00002.png"   # alternative test image
image_path = local_path + "assets/00006.png"


_ = pose_est_mask.read_rgb_cutout_black_bg_from_file(image_path, output_size=(256, 256))


# Plot reference image
plt.figure()
plt.imshow(pose_est_mask.image_ref.squeeze(), cmap='gray')
plt.title("cutout")
plt.axis('off')
plt.show()

# Plot the alpha channel of the reference image (the mask)
plt.figure()
plt.imshow(pose_est_mask.image_ref.squeeze()[:,:,3], cmap='gray')
plt.title("mask")
plt.axis('off')
plt.show()




In [None]:
pose_est_mask.init_model(distance=10, elevation=110, azimuth=50, roll_deg=0.0, learning_rate=0.15)

overlay_img = pose_est_mask.render_current_and_reference_as_overlay()

print(overlay_img.shape)
# Plot the overlay of the current render and the reference silhouette
plt.figure()
plt.imshow(overlay_img, cmap='gray')
plt.title("Overlay: current render vs. reference silhouette")
plt.axis('off')
plt.show()

## Choose the initial pose
Use the UI to select the initial pose.

In [None]:
# Launch the UI to choose the initial pose for the optimization step.
# The learning rate is set in the next cell, when the model is initialized.



app = gradio_UI_select_pose(pose_est_mask)   # launches inline, non-blocking
# Close app programmatically
# app.close()

In [None]:
selected_pose         # dict with keys: distance, elevation, azimuth, roll_deg
selected_distance
selected_elevation
selected_azimuth
selected_roll

print(selected_pose)
pose_est_mask.init_model(**selected_pose, learning_rate=0.05)

## Set the initial pose

In [None]:
# Get the camera position
camera_pos = pose_est_mask.model.camera_position.detach()
camera_pos = camera_pos.unsqueeze(0)

# Set this position as the initial position for the optimization step.
# Also set the learning rate.
pose_est_mask.init_model(learning_rate=0.05, eye_override=camera_pos)

# Display an overlay with the current pose and reference mask
pose_est_mask.display_current_and_reference_as_overlay()

## Run the optimization and display the resulting pose


In [None]:

pose_est_mask.init_model(distance=8, elevation=15, azimuth=-28, roll_deg=25.0, learning_rate=0.15)


# Set the number of iterations
pose_est_mask.set_number_of_iterations(300)

# Estimate pose
R, T = pose_est_mask.run_optimization(tol=0.01)


# Display an overlay with the current pose and reference mask
pose_est_mask.display_current_and_reference_as_overlay()

# Display the model with the current pose
pose_est_mask.display_current_image()

print("\nCam position from model =\n", pose_est_mask.model.camera_position)

cam_pos = pose_est_mask.model.camera_position.detach()
cam_pos = cam_pos.unsqueeze(0)

print("\nCam position from model (NumPy) =\n", cam_pos.cpu().numpy())


---
# Test 3: Process a sequence of masks from a video

## Create directory to hold video data

## Processing video: `sse_subset`
Mostly in-plane translation across the image.

In [None]:
# video_name = "sse_subset"
# res = prepare_video_data(local_path, video_name, delete_zips=True, overwrite=True, verbose=True)
# print(res)

## Processing video: `test14_node_small`
Spinning about the vertical direction


In [None]:
video_name = "test14_node_small"
res = prepare_video_data(local_path, video_name, delete_zips=True, overwrite=True, verbose=True)
print(res)

## Processing video: `node2_turning`
Some in-plane rotation


In [None]:
# video_name = "node2_turning"
# res = prepare_video_data(local_path, video_name, delete_zips=True, overwrite=True, verbose=True)
# print(res)

### Make a list of all mask files

In [None]:
# Process all masks
import os
import glob
import re

# Directory with the cutout files (the alpha channel of each cutout is the mask)
mask_dir = "video_data/cutouts/"

# Regex pattern for exactly 5-digit filenames ending in .png
pattern = re.compile(r"^\d{5}\.png$")

# Get all .png files
all_files = glob.glob(os.path.join(mask_dir, "*.png"))

# Filter using regex
mask_files = sorted([
    f for f in all_files if pattern.match(os.path.basename(f))
])

print(mask_files)
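
To make the filter concrete, the sketch below applies the same pattern to a handful of made-up file names (they are illustrative, not files from the dataset). Only names made of exactly five digits followed by `.png` survive, and because the names are zero-padded, lexicographic sorting matches frame order.

```python
import os
import re

pattern = re.compile(r"^\d{5}\.png$")

candidates = [
    "video_data/cutouts/00010.png",
    "video_data/cutouts/00002.png",
    "video_data/cutouts/00002_mask.png",  # extra suffix: rejected
    "video_data/cutouts/123.png",         # too few digits: rejected
    "video_data/cutouts/00003.jpg",       # wrong extension: rejected
]

# Match against the base name only, so the directory part is ignored.
kept = sorted(f for f in candidates if pattern.match(os.path.basename(f)))
print(kept)
```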

### Display masks

In [None]:
fig, grid, names, masks = show_mask_grid(local_path=local_path, max_images=20, cols=5)
print("Shown files:", names[:5], "...")

### Display cutouts

In [None]:
fig, grid, names, imgs = show_cutout_grid(local_path=local_path, max_images=20, cols=5)
print("Shown files:", names[:5], "...")


### Run the pose estimation on the mask sequence

#### Initialize pose estimator

In [None]:
# Create a pose estimator
pose_est01 = PoseEstimator(
    local_path=local_path,
    use_light_model = False,
    image_size=(256, 256),
    scale_factor=1,
    num_iter=200
    )

#### Set the initial pose for the first frame

In [None]:
# Cutout used to set the initial pose (its alpha channel is the mask)
# image_path = local_path + "video_data/masks/00001.png"   # alternative: plain mask
image_path = local_path + "video_data/cutouts/00001.png"

# pose_est01.set_reference_mask(image_path)


_ = pose_est01.read_rgb_cutout_black_bg_from_file(image_path, output_size=(256, 256))

# Plot reference image
plt.figure()
plt.imshow(pose_est01.image_ref.squeeze(), cmap='gray')
plt.title("cutout")
plt.axis('off')
plt.show()

# Plot reference image
plt.figure()
plt.imshow(pose_est01.image_ref.squeeze()[:,:,3], cmap='gray')
plt.title("mask")
plt.axis('off')
plt.show()



#### GRADIO: Choose initial pose

Use the UI to select the initial pose

In [None]:
app = gradio_UI_select_pose(pose_est01)   # launches inline, non-blocking
# Close app programmatically
# app.close()

In [None]:
selected_pose         # dict with keys: distance, elevation, azimuth, roll_deg
selected_distance
selected_elevation
selected_azimuth
selected_roll

print(selected_pose)
pose_est01.init_model(**selected_pose, learning_rate=0.05)

In [None]:
# pose_est01.init_model(distance=8, elevation=15, azimuth=-28, roll_deg=25.0, learning_rate=0.15)



# Get the camera position
camera_pos = pose_est01.model.camera_position.detach()
camera_pos = camera_pos.unsqueeze(0)

# Set this position as the initial position for the optimization step.
# Also set the learning rate.
pose_est01.init_model(learning_rate=0.05, eye_override=camera_pos)

# Display an overlay with the current pose and reference mask
pose_est01.display_current_and_reference_as_overlay()

### Fit initial pose
Estimate pose and display the overlay

In [None]:
pose_est01.set_number_of_iterations(400)

R, T = pose_est01.run_optimization(tol=1)

# Display an overlay with the current pose and reference mask
pose_est01.display_current_and_reference_as_overlay()

pose_est01.printout_results()

### Now, estimate pose for all images

In [None]:
pose_est01.set_number_of_iterations(200)
pose_est01.set_learning_rate(0.05)

results = estimate_poses_for_masks(
    pose_est=pose_est01,
    mask_files=mask_files,
    local_path=local_path,
    N=None,                 # None or an int like 3
    iterations=200,
    learning_rate=0.05,
    tol=1.0,
    show_overlay=False,
    print_results=False,
    stop_on_error=False,
    output_size = (256, 256)
)

# Access a result
if results:
    r0 = results[0]
    print("First file:", r0["file"])
    print("R shape:", tuple(r0["R"].shape), "T shape:", tuple(r0["T"].shape))
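
Each entry in `results` carries a rotation `R` and a translation `T`. If these follow the PyTorch3D convention, where a world point is mapped to camera coordinates as `X_cam = X_world @ R + T` with row vectors, then the camera center in world coordinates is `C = -T @ R^T`. The sketch below checks this with plain Python lists. It is an illustration under that assumption, and `camera_center` and `world_to_camera` are hypothetical helpers rather than part of the notebook; batched tensors would first need their batch dimension dropped.

```python
import math

def camera_center(R, T):
    """Camera center C = -T @ R^T for row-vector extrinsics X_cam = X_world @ R + T."""
    return [-sum(T[j] * R[i][j] for j in range(3)) for i in range(3)]

def world_to_camera(X, R, T):
    """Apply X_cam = X_world @ R + T to a single 3D point."""
    return [sum(X[i] * R[i][j] for i in range(3)) + T[j] for j in range(3)]

# Example extrinsics: a 90-degree rotation about the y-axis plus a translation.
angle = math.radians(90.0)
c, s = math.cos(angle), math.sin(angle)
R = [[c, 0.0, s],
     [0.0, 1.0, 0.0],
     [-s, 0.0, c]]
T = [0.5, -1.0, 6.6]

C = camera_center(R, T)
# The camera center must map to the origin of the camera frame.
origin = world_to_camera(C, R, T)
print(C)
print(origin)
```

Plotting the camera centers of consecutive frames gives a quick sanity check on a sequence: the estimated trajectory should be smooth if the per-frame fits are consistent.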


In [None]:
gif_path, n = make_pose_gif(pose_est01, results, out_path="./result.gif", duration=0.5)
print(gif_path, n)