In [None]:
import math
import numpy as np
from torch import cos, sin
import scipy.optimize as opt
import torch
import torch.nn as nn
%matplotlib ipympl
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append("../../../ddn/")
from ddn.pytorch.node import *

from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import sample_farthest_points
from descartes import PolygonPatch
from pytorch3d.io import IO, load_obj, save_obj,load_objs_as_meshes
from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds

from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from alpha_shapes import Alpha_Shaper, plot_alpha_shape
from torch import Tensor, tensor

import os

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Debugging aid: make autograd report the forward operation behind NaN/inf values in backward.
# It slows down every backward pass, so switch it off once the pipeline is stable.
torch.autograd.set_detect_anomaly(mode=True)

def least_squares(u0, tgt_vtxs):
    """
    Sum of squared differences between a flat vertex vector and target vertices.

    Args:
        u0: flattened vertex coordinates (numpy array, sequence or tensor).
        tgt_vtxs: target vertices (numpy array, sequence or tensor); flattened before the
            comparison, so its element count must match u0.

    Returns:
        0-d float64 tensor equal to sum((u0 - tgt_vtxs.flatten()) ** 2).
    """
    u0 = torch.as_tensor(u0)
    # Follow u0's device (scipy hands over CPU numpy arrays) so a GPU-resident target still works
    tgt = torch.as_tensor(tgt_vtxs).to(device=u0.device)
    res = torch.square(u0 - tgt.flatten()).sum()
    return res.double()

def least_squares_grad(u0, tgt_vtxs):
    """
    Analytical gradient of `least_squares` with respect to u0.

    Args:
        u0: flattened vertex coordinates (numpy array, sequence or tensor).
        tgt_vtxs: target vertices (numpy array, sequence or tensor), flattened before use.

    Returns:
        float64 tensor 2 * (u0 - tgt_vtxs.flatten()), with no autograd history.
    """
    # Detach: the result is consumed as plain values (e.g. as scipy's `jac`), never differentiated
    u0 = torch.as_tensor(u0).detach()
    tgt = torch.as_tensor(tgt_vtxs).detach().to(device=u0.device)
    gradient = 2 * (u0 - tgt.flatten())
    return gradient.double()


def calculate_volume(vertices, faces):
    """
    Absolute enclosed volume of a triangle mesh.

    Each face (v0, v1, v2) contributes the signed volume v0 . (v1 x v2) / 6 of the
    tetrahedron it forms with the origin; the absolute value of the sum is returned.

    Args:
        vertices: (V, 3) tensor of vertex positions.
        faces: (F, 3) vertex indices (tensor or array, any integer dtype).

    Returns:
        0-d tensor with the absolute total volume.
    """
    # Indexing wants an integral index tensor on the same device; int64 is always accepted
    faces = torch.as_tensor(faces, device=vertices.device).long()
    face_vertices = vertices[faces]  # (F, 3, 3)
    v0, v1, v2 = face_vertices.unbind(dim=1)  # each (F, 3)
    # Scalar triple product: same value as det([v0, v1, v2]) but computed directly
    face_volumes = (torch.cross(v0, v1, dim=-1) * v2).sum(dim=-1) / 6.0  # (F,)
    return face_volumes.sum().abs()


def volume_constraint(x, faces, tgt_vol):
    """
    Equality-constraint residual: absolute mesh volume minus the target volume.

    Args:
        x: flattened vertex coordinates (numpy array or tensor), reshaped to (V, 3).
        faces: triangle vertex indices (numpy array or tensor), reshaped to (F, 3).
        tgt_vol: target volume (number or tensor).

    Returns:
        0-d float64 tensor |volume(x, faces)| - tgt_vol; zero when the constraint holds.
    """
    x = torch.as_tensor(x)
    # int64 indices, matching volume_constraint_grad (int32 indexing is not accepted everywhere)
    faces = torch.as_tensor(faces).reshape(-1, 3).long().to(x.device)
    tgt_vol = torch.as_tensor(tgt_vol).to(x.device)

    vertices = x.reshape(-1, 3)
    volume = calculate_volume(vertices, faces)
    res = volume.abs() - tgt_vol
    return res.double()

def volume_constraint_grad(x, faces):
    """
    Analytical gradient of `volume_constraint` with respect to the flattened vertices x.

    The signed volume is V = sum_f p0 . (p1 x p2) / 6, whose per-face partial derivatives
    are (p1 x p2)/6, (p2 x p0)/6 and (p0 x p1)/6. The constraint uses |V|, so the
    signed-volume gradient is multiplied by sign(V); otherwise the gradient would point
    the wrong way for meshes whose faces are oriented inwards.

    Args:
        x: flattened vertex coordinates (numpy array or tensor), length 3*V.
        faces: triangle vertex indices (numpy array or tensor), reshaped to (F, 3).

    Returns:
        Tensor of length 3*V holding d|V|/dx, with no autograd history.
    """
    x = torch.as_tensor(x).detach()
    faces = torch.as_tensor(faces).detach().reshape(-1, 3).to(device=x.device, dtype=torch.int64)

    vertices = x.reshape(-1, 3)
    p0 = vertices[faces[:, 0]]  # (F, 3)
    p1 = vertices[faces[:, 1]]  # (F, 3)
    p2 = vertices[faces[:, 2]]  # (F, 3)

    grad_p0 = torch.cross(p1, p2, dim=1) / 6.0
    grad_p1 = torch.cross(p2, p0, dim=1) / 6.0
    grad_p2 = torch.cross(p0, p1, dim=1) / 6.0

    # d|V|/dV = sign(V); V == 0 is treated as positive so the gradient is never zeroed out
    signed_volume = (p0 * grad_p0).sum()
    sign = -1.0 if signed_volume < 0 else 1.0

    grad_verts = torch.zeros_like(vertices)
    for idx, grad in ((faces[:, 0], grad_p0), (faces[:, 1], grad_p1), (faces[:, 2], grad_p2)):
        grad_verts.index_add_(0, idx, grad)
    return sign * grad_verts.flatten()


class ConstrainedProjectionNode(EqConstDeclarativeNode):
    """
    Declarative node for a volume-constrained least-squares vertex projection.

    For each mesh b, solve() finds vertices y_b (using the faces of self.src[b]) that minimise
    ||y_b - tgt_verts_b||^2 subject to |volume(y_b)| = |volume(self.tgt[b])|.
    objective() and equality_constraints() express the same problem in torch; presumably
    EqConstDeclarativeNode differentiates through them in the backward pass.

    Assumes every source mesh has the same number of vertices as its target mesh
    (source vertex i is pulled towards target vertex i).
    """

    def __init__(self, src: Meshes, tgt: Meshes):
        super().__init__(eps=1.0e-6) # relax tolerance on optimality test
        self.src = src # source meshes (B,)
        self.tgt = tgt # target meshes (B,)

    @staticmethod
    def _mesh_volumes(verts_packed, meshes, scatter_add=True):
        """Absolute enclosed volume of each mesh in `meshes` evaluated at `verts_packed`; shape (B,)."""
        faces_packed = meshes.faces_packed().to(verts_packed.device)  # (sum(F_i), 3)
        v0, v1, v2 = verts_packed[faces_packed].unbind(dim=1)  # each (sum(F_i), 3)
        # Signed volume of the tetrahedron spanned by the origin and each face
        face_volumes = (torch.cross(v0, v1, dim=-1) * v2).sum(dim=-1) / 6.0  # (sum(F_i),)

        n_batches = len(meshes)
        if scatter_add:
            face_to_mesh = meshes.faces_packed_to_mesh_idx().to(face_volumes.device)
            volumes = face_volumes.new_zeros(n_batches).scatter_add(0, face_to_mesh, face_volumes)
        else:
            first_face = meshes.mesh_to_faces_packed_first_idx()  # index of first face per mesh
            n_faces = meshes.num_faces_per_mesh()
            volumes = face_volumes.new_zeros(n_batches)
            for i in range(n_batches):
                start = first_face[i]
                volumes[i] = face_volumes[start:start + n_faces[i]].sum()
        return volumes.abs()

    def objective(self, xs: torch.Tensor, y: torch.Tensor, scatter_add=True):
        """
        Sum of squared differences between projected and target vertices, per mesh.

        Args:
            xs (tensor): vertices of original mesh, sum(V_i) x 3 (unused)
            y (tensor): vertices of projected mesh, sum(V_i) x 3
            scatter_add (bool): reduce with scatter_add (True) or a per-mesh loop (False)

        Returns:
            (B,) tensor of per-mesh sums of squared errors.
        """
        src_verts = y.reshape(-1, 3)  # (sum(V_i), 3)
        tgt_verts = self.tgt.verts_packed().detach().to(src_verts.device)  # (sum(V_i), 3)
        # Reduce over xyz first: scatter_add needs src and index to have the same rank
        sqr_diffs = torch.square(src_verts - tgt_verts).sum(dim=1)  # (sum(V_i),)

        n_batches = len(self.src)
        if scatter_add:
            vert_to_mesh = self.src.verts_packed_to_mesh_idx().to(sqr_diffs.device)
            return sqr_diffs.new_zeros(n_batches).scatter_add(0, vert_to_mesh, sqr_diffs)

        first_vert = self.src.mesh_to_verts_packed_first_idx()  # index of first vertex per mesh
        n_verts = self.src.num_verts_per_mesh()
        sse = sqr_diffs.new_zeros(n_batches)
        for i in range(n_batches):
            start = first_vert[i]
            sse[i] = sqr_diffs[start:start + n_verts[i]].sum()
        return sse

    def equality_constraints(self, xs, y, scatter_add=True):
        """
        Volume constraint residual per mesh: |volume(y_b)| - |volume(tgt_b)|.

        This is the same residual solve() drives to zero through volume_constraint. Meshes
        may have different sizes because the reduction uses the packed-to-mesh indices.

        Args:
            xs (tensor): vertices of original mesh, sum(V_i) x 3 (unused)
            y (tensor): vertices of projected mesh, sum(V_i) x 3
            scatter_add (bool): reduce with scatter_add (True) or a per-mesh loop (False)

        Returns:
            (B,) tensor of constraint residuals.
        """
        verts_packed = y.reshape(-1, 3)  # (sum(V_i), 3)
        volumes = self._mesh_volumes(verts_packed, self.src, scatter_add)
        tgt_verts = self.tgt.verts_packed().detach().to(verts_packed.device)
        tgt_volumes = self._mesh_volumes(tgt_verts, self.tgt, scatter_add)
        # Subtract the target so h(x, y) = 0 exactly when the volumes match
        return volumes - tgt_volumes.to(volumes.dtype)  # Shape: (B,)

    def solve(self, xs: torch.Tensor):
        """Projects the vertices onto the target mesh vertices across batches.

        Each mesh is solved independently with SLSQP: minimise the squared distance to the
        target vertices subject to |volume| equal to the target mesh's volume.

        Args:
            xs (torch.Tensor): a (sum Vi, 3) packed tensor of vertices in the batched meshes

        Returns:
            results (torch.Tensor): a (sum Vi, 3) float64 packed tensor of the projected
                vertices, on the same device as xs
            None: no solver context is returned
        """
        first_vtx = self.src.mesh_to_verts_packed_first_idx()
        n_verts = self.src.num_verts_per_mesh()

        results = torch.zeros(len(xs), 3, dtype=torch.double, device=xs.device)
        for batch in range(len(self.src)):
            start = int(first_vtx[batch])
            end = start + int(n_verts[batch])
            verts = xs[start:end].flatten().detach().double().cpu().numpy()
            faces = self.src[batch].faces_packed().detach().int().cpu().numpy()
            # scipy evaluates the callbacks on CPU numpy arrays, so keep the target on the CPU too
            tgt_vtx = self.tgt[batch].verts_packed().detach().cpu()
            tgt_faces = self.tgt[batch].faces_packed().detach().cpu()
            with torch.no_grad():
                tgt_vol = calculate_volume(tgt_vtx, tgt_faces)

            eq_constraint = {
                'type': 'eq',
                'fun' : lambda u: volume_constraint(u, faces, tgt_vol).cpu().numpy(),
                'jac' : lambda u: volume_constraint_grad(u, faces).cpu().numpy()
            }

            res = opt.minimize(
                lambda u: least_squares(u, tgt_vtx),
                verts,
                method='SLSQP',
                jac=lambda u: least_squares_grad(u, tgt_vtx),
                constraints=[eq_constraint],
                options={'ftol': 1e-6, 'iprint': 2, 'maxiter': 100}
            )

            if not res.success:
                print("FAILED:", res.message)
            # No requires_grad here: gradients presumably come from the declarative node's backward
            results[start:end, :] = torch.as_tensor(res.x, dtype=torch.double).view(-1, 3).to(results.device)
        return results, None
    

class ConstrainedProjectionFunction(DeclarativeFunction):
    """
    Autograd entry point for the constrained projection.

    Adds no behaviour of its own to DeclarativeFunction; it is invoked as
    ConstrainedProjectionFunction.apply(node, verts) with a ConstrainedProjectionNode.
    """


Pseudocode:
- Load the meshes.
- The inner problem needs the vertices, the number of meshes, the faces, and the per-mesh vertex indexing.
- The outer problem needs the projected vertices, the number of meshes, and the indexing. It also needs the projection matrices and the edge maps of the renders, so run edge detection on the renders beforehand.

Simplest approach: pass the meshes themselves to both problems.

Projection:
Get the per-mesh vertex indexing right, then project those vertices into each view.

In [None]:
# outer problem
def create_padded_tensor(vertices, vert2mesh, max_V, B):
    """
    Scatter packed per-mesh vertices into a zero-padded (B, max_V, 3) tensor.

    Args:
        vertices: (sum V_i, 3) packed vertex tensor.
        vert2mesh: (sum V_i,) mesh index of each packed vertex.
        max_V: padded vertex dimension; must be >= the largest per-mesh vertex count.
        B: number of meshes.

    Returns:
        (B, max_V, 3) tensor with the dtype and device of `vertices`. Mesh i's vertices fill
        padded[i, :V_i] in packed order; the remaining rows are zeros. Gradients flow back to
        `vertices` through the assignments.
    """
    # Match the input dtype: a default (float32) buffer would silently truncate float64 vertices
    padded = vertices.new_zeros((B, max_V, 3))
    for i in range(B):
        mesh_vertices = vertices[vert2mesh == i]
        padded[i, :mesh_vertices.shape[0]] = mesh_vertices
    return padded

class PyTorchChamferLoss(nn.Module):
    """
    Chamfer loss between projected mesh silhouettes and target edge maps.

    For every mesh b and view p, the mesh's vertices are projected with projmatrices[p], the
    alpha-shape boundary of the projected points is extracted, and a 2D chamfer distance to
    the target edge points of that view is computed.
    """

    def __init__(self, src: Meshes, tgt: Meshes, projmatrices, edgemap_info):
        super().__init__()
        self.src = src  # (B meshes)
        self.tgt = tgt  # (B meshes)
        self.projmatrices = projmatrices # (P, 3, 4)
        self.edgemaps = edgemap_info[0] # indexed by mesh: (P, max_Ni, 2) padded edge points
        self.edgemaps_len = edgemap_info[1] # indexed by mesh: (P,) valid edge points per view

    def project_vertices(self, vertices):
        """
        Projects a set of vertices into multiple views using different projection matrices.

        Args:
            vertices: Tensor of shape (N, 3), representing 3D vertex positions.

        Returns:
            float64 tensor of shape (P, N, 2), containing projected 2D points in each view.
        """
        ones = torch.ones((vertices.shape[0], 1), dtype=vertices.dtype, device=vertices.device)
        vertices_homogeneous = torch.cat([vertices, ones], dim=1).double()  # (N, 4)
        # Bring the matrices to the vertices' dtype/device to avoid a mismatch inside einsum
        projection_matrices = self.projmatrices.to(vertices_homogeneous)
        # (P, 3, 4) x (N, 4) -> (P, N, 3)
        projected = torch.einsum("pij,vj->pvi", projection_matrices, vertices_homogeneous)
        return projected[:, :, :2] / projected[:, :, 2:3]  # (P, N, 2)

    def get_boundary(self, projected_pts, alpha=10.0):
        """
        Return the rows of `projected_pts` (N, 2) that lie on the exterior of their alpha shape.

        The rows are selected by coordinate matching, so the result stays connected to the
        autograd graph of `projected_pts`.
        NOTE(review): assumes the alpha shape is a single polygon (uses `.exterior`).
        """
        shaper = Alpha_Shaper(projected_pts.detach().cpu().numpy())
        alpha_shape = shaper.get_shape(alpha)
        boundary = torch.as_tensor(
            np.asarray(alpha_shape.exterior.coords.xy),
            dtype=projected_pts.dtype,
            device=projected_pts.device,
        )  # (2, M)
        on_boundary = torch.isclose(projected_pts[:, None], boundary.T, atol=1e-6).all(dim=-1).any(dim=1)
        return projected_pts[on_boundary]

    def forward(self, y):
        """
        Args:
            y: (sum V_i, 3) packed projected vertices, stored contiguously mesh by mesh.

        Returns:
            (B,) float64 tensor of chamfer losses, one per mesh.
        """
        num_verts_per_mesh = self.src.num_verts_per_mesh().tolist()
        # Splitting by the per-mesh vertex counts recovers each mesh without a padded
        # round trip, so y's dtype (float64 from solve) is preserved
        verts_per_mesh = y.reshape(-1, 3).split(num_verts_per_mesh, dim=0)

        losses = []
        for b, verts in enumerate(verts_per_mesh):
            projverts = self.project_vertices(verts)  # (P, V_b, 2)
            dev = projverts.device
            boundaries_b = [self.get_boundary(view_pts) for view_pts in projverts]
            boundary_lengths_b = torch.tensor([len(pts) for pts in boundaries_b], dtype=torch.long, device=dev)
            padded_boundaries = torch.nn.utils.rnn.pad_sequence(boundaries_b, batch_first=True, padding_value=0.0)
            res, _ = chamfer_distance(
                x=padded_boundaries.float(),
                y=self.edgemaps[b].float().to(dev),
                x_lengths=boundary_lengths_b,
                y_lengths=self.edgemaps_len[b].long().to(dev),
                batch_reduction="mean",
                point_reduction="mean",
            )
            losses.append(res.sum())

        if not losses:
            return y.new_zeros(0, dtype=torch.double)
        # torch.stack keeps every per-mesh loss in the autograd graph, on the loss's device
        return torch.stack(losses).double()



In [None]:

mesh_names = ["sphere", "balloon", "parabola", "rstrawberry"]
paths = [os.path.join("../../../Blender/", f"{name}_2.obj") for name in mesh_names]
# One mesh per file, unpacked in the order of mesh_names
sphere, balloon, parabola, rstrawberry = load_objs_as_meshes(paths)

def outer_problem(src: Meshes, tgt: Meshes, projmats, edgemap_info, n_iters, lr, moment, verbose=True,
                  return_history=False):
    """
    Optimise source vertices so their volume-constrained projection matches target edge maps.

    Each iteration projects the current vertices with ConstrainedProjectionFunction /
    ConstrainedProjectionNode (least squares towards `tgt` under a volume constraint), scores
    the projected vertices with PyTorchChamferLoss against `edgemap_info`, backpropagates, and
    takes an SGD step.

    Args:
        src, tgt: source and target Meshes batches.
        projmats: (P, 3, 4) camera projection matrices.
        edgemap_info: (edgemaps, edgemap_lengths) as consumed by PyTorchChamferLoss.
        n_iters: number of SGD iterations.
        lr, moment: SGD learning rate and momentum.
        verbose: print the loss and gradient norm after every iteration.
        return_history: also return the detached projected vertices of every iteration.

    Returns:
        verts, the optimised (sum V_i, 3) leaf tensor; or (verts, history) if return_history.
    """
    node = ConstrainedProjectionNode(src, tgt)
    chamfer_loss = PyTorchChamferLoss(src, tgt, projmats, edgemap_info)

    # Optimise a private copy: setting requires_grad on src.verts_packed() directly would
    # alter the tensor held by the mesh
    verts = src.verts_packed().detach().clone().requires_grad_(True)
    optimiser = torch.optim.SGD([verts], lr=lr, momentum=moment)

    history = []
    for i in range(n_iters):
        optimiser.zero_grad()
        projverts = ConstrainedProjectionFunction.apply(node, verts)  # (sum Vi, 3)
        history.append(projverts.detach().clone())

        # The loss has one entry per mesh; reduce it so backward() and item() work for any batch size
        loss = chamfer_loss(projverts).sum()
        loss.backward()
        optimiser.step()

        if verbose:
            grad_norm = verts.grad.norm().item() if verts.grad is not None else float("nan")
            print(f"{i:4d} Loss: {loss.item()} Gradient norm: {grad_norm}")

    if return_history:
        return verts, history
    return verts

In [None]:
from utils import load_renders, load_camera_matrices
import cv2
from cv2.typing import MatLike

# Apply Canny edge detection
def canny_edge_map(img: MatLike, threshold1=50, threshold2=250):
    """
    Return the Canny edge map of an image (non-zero pixels mark edges).

    Args:
        img: (H, W) greyscale or (H, W, C) RGB image.
            NOTE(review): assumes a numpy-backed image (reads `.ndim`).
        threshold1, threshold2: hysteresis thresholds passed to cv2.Canny
            (defaults keep the previously hard-coded 50 / 250).

    Returns:
        Single-channel edge map produced by cv2.Canny.
    """
    if img.ndim == 2:
        # Already single-channel; a colour conversion would reject it
        img_greyscale = img
    else:
        img_greyscale = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return cv2.Canny(img_greyscale, threshold1=threshold1, threshold2=threshold2)

def get_edgemaps(renders):
    """
    Run Canny edge detection on every rendered view.

    Args:
        renders: nested dict, presumably {object name: {view number: image}}.

    Returns:
        (edgemaps, edgemaps_len): dicts with the same two-level keys, holding a tensor of
        edge-pixel coordinates and the number of edge pixels for each view.

    NOTE(review): np.argwhere yields (row, col) = (y, x) coordinates; confirm this matches the
    (u, v) order of the projected vertices they are compared against.
    """
    edgemaps, edgemaps_len = {}, {}
    for obj_name, views_by_num in renders.items():
        coords_by_view = {
            view_num: np.argwhere(canny_edge_map(image) > 0)
            for view_num, image in views_by_num.items()
        }
        edgemaps[obj_name] = {view_num: torch.tensor(coords) for view_num, coords in coords_by_view.items()}
        edgemaps_len[obj_name] = {view_num: len(coords) for view_num, coords in coords_by_view.items()}
    return edgemaps, edgemaps_len

renders_path = "../../../Blender/renders/"
matrices_path = "../../../Blender/cameras"

# Rendered views -> per-view edge points; each camera entry provides its projection matrix under "P"
renders = load_renders(renders_path)
edgemaps, edgemaps_len = get_edgemaps(renders)
matrices = load_camera_matrices(matrices_path)

In [None]:
# Cameras and the render-view numbers whose edge maps are used alongside them.
# NOTE(review): camera numbers appear offset by one from view numbers (Camera0 -> view 1); confirm.
camera_views = [("Camera0", 1), ("Camera2", 3), ("Camera3", 4)]
projmats = torch.stack([matrices[camera]["P"] for camera, _ in camera_views])

view_idx = [view for _, view in camera_views]
balloon_views = [edgemaps["balloon"][i] for i in view_idx]
tgt_edgemaps = torch.nn.utils.rnn.pad_sequence(balloon_views, batch_first=True, padding_value=0.0)
tgt_edgemaps_len = torch.tensor([edgemaps_len["balloon"][i] for i in view_idx])
# One entry per mesh in the batch (a single mesh here)
tgt_edgemap_info = ([tgt_edgemaps], [tgt_edgemaps_len])

In [None]:
# NOTE(review): tgt_edgemap_info holds the "balloon" edge maps, but the projection target
# passed here is `parabola` — confirm this pairing is intended.
outer_problem(
    src=sphere,
    tgt=parabola,
    projmats=projmats,
    edgemap_info=tgt_edgemap_info,
    n_iters=20,
    lr=1e-4,
    moment=0,
)