In [12]:
import torch
import warnings
from typing import Callable, List, Optional
from torch.library import Library
from tqdm import tqdm
from torchvision.transforms import Compose
from datapipes.bal_loader import get_problem, read_bal_data

TARGET_DATASET = "ladybug"
TARGET_PROBLEM = "problem-1723-156502-pre"

DEVICE = 'cuda' # change device to CPU if needed
DTYPE = torch.float64
USE_QUATERNIONS = True
OPTIMIZE_INTRINSICS = False
#dataset = read_bal_data(file_name='Data/dubrovnik-3-7-pre.txt', use_quat=USE_QUATERNIONS)
dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS)

if OPTIMIZE_INTRINSICS:
    NUM_CAMERA_PARAMS = 10 if USE_QUATERNIONS else 9
else:
    NUM_CAMERA_PARAMS = 7 if USE_QUATERNIONS else 6

print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')

import torch
import torch.nn as nn
import pypose as pp

trimmed_dataset = dataset
trimmed_dataset = {k: v.to(DEVICE) for k, v in trimmed_dataset.items() if type(v) == torch.Tensor}

def convert_to(type):
    for k, v in trimmed_dataset.items():
        if 'index' not in k:
            trimmed_dataset[k] = v.to(type)


convert_to(DTYPE)
torch.set_default_dtype(DTYPE)

Streaming data for ladybug...
Fetched problem-1723-156502-pre from ladybug


In [None]:
import os
os.environ['LD_LIBRARY_PATH'] = '/global/common/software/nersc9/pytorch/2.3.1/lib/python3.11/site-packages/torch/lib:' + os.environ.get('LD_LIBRARY_PATH', '')
os.environ['LD_LIBRARY_PATH'] = '/global/homes/h/hxu398/bal_exp/pyro_slam/cudss:' + os.environ.get('LD_LIBRARY_PATH', '')

In [1]:
from pyro_slam.sparse.py_ops import *
from pyro_slam.sparse.bsr_cuda import *
from pyro_slam.sparse.solve import *

  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::mm(Tensor self, Tensor mat2) -> Tensor
    registered at /pscratch/sd/s/swowner/pytorch-build/pytorch/2.3.1/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: SparseCsrCPU
  previous kernel: registered at /pscratch/sd/s/swowner/pytorch-build/pytorch/2.3.1/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1079
       new kernel: registered at /global/u1/h/hxu398/bal_exp/pyro_slam/pyro_slam/sparse/sparse_op_cpp.cpp:206 (function operator())


In [2]:
import torch

def rotate_euler(points, rot_vecs):
    """Rotate points by given rotation vectors.

    Rodrigues' rotation formula is used.
    """
    theta = torch.norm(rot_vecs, dim=1, keepdim=True) # 3 element-wise ops, square, sum, divide by norm
    v = torch.nan_to_num(rot_vecs / theta) # 1 element-wise op, division
    dot = torch.sum(points * v, dim=1, keepdim=True) # 2 element-wise ops
    cos_theta = torch.cos(theta) # 1 element-wise op
    sin_theta = torch.sin(theta) # 1 element-wise op
    return cos_theta * points + sin_theta * torch.cross(v, points, dim=1) + dot * (1 - cos_theta) * v # 6 element-wise ops

def rotate_quat(points, rot_vecs):
    rot_vecs = pp.SE3(rot_vecs)
    return rot_vecs.Act(points)

def project(points, camera_params):
    """Convert 3-D points to 2-D by projecting onto images."""
    if USE_QUATERNIONS:
        points_proj = rotate_quat(points, camera_params[..., :7])
    else:
        points_proj = rotate_euler(points, camera_params[..., 3:6])
        points_proj = points_proj + camera_params[..., 3:6]
    points_proj = -points_proj[..., :2] / points_proj[..., 2].unsqueeze(-1)  # add dimension for broadcasting
    f = camera_params[..., -3].unsqueeze(-1)
    k1 = camera_params[..., -2].unsqueeze(-1)
    k2 = camera_params[..., -1].unsqueeze(-1)
    
    n = torch.sum(points_proj**2, axis=-1, keepdim=True)
    r = 1 + k1 * n + k2 * n**2
    points_proj = points_proj * r * f  # broadcasting will take care of the shape

    return points_proj

# dense version
class ReprojNonBatched(nn.Module):
    def __init__(self, camera_params, points_3d):
        super().__init__()
        self.pose = nn.Parameter(camera_params)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, points_2d, intr, camera_indices, point_indices):
        camera_params = self.pose
        points_3d = self.points_3d
        if intr is not None:
            camera_params = torch.cat([camera_params, intr], dim=-1)
        points_proj = project(points_3d[point_indices], camera_params[camera_indices])
        loss = points_proj - points_2d
        return loss
    
    def get_num_params(self):
        return self.pose.numel() + self.points_3d.numel()

from functools import partial

def least_square_error(camera_params, points_3d, camera_indices, point_indices, points_2d, intr=None):
    model = ReprojNonBatched(camera_params, points_3d)
    loss = model(points_2d, intr, camera_indices, point_indices)
    return torch.sum(loss**2) / 2

In [3]:
import torch

def estimate_forward_flops_vmap():
    """
    Estimate the number of FLOPs required for the forward pass of the model using vmap.
    Because vmap is used, this is just the cost for a single point.

    Returns:
        total_flops: Estimated total number of FLOPs for the forward pass.
    """
    flops_rotate_euler = 42 # 3 * 14; 14 ops per dimension
    flops_translation = 3 # 1 op per dimension
    flops_perspective = 4 # nagation and division for x and y (ditching z)
    flops_n = 3 # n = torch.sum(points_proj**2, axis=-1, keepdim=True), each point has 2 elements now
    flops_distortion = 4 # points_proj * r * f
    flops_loss = 2 # 2 ops for loss computation

    total_flops = (flops_rotate_euler + flops_translation + flops_perspective +
                     flops_n + flops_distortion + flops_loss)
    return total_flops

def estimate_forward_flops(input_args):
    """
    Estimate the number of FLOPs required for the forward pass of the model.

    Args:
        model: The model instance (e.g., ReprojNonBatched).
        input_args: A tuple containing (points_2d, intr, camera_indices, point_indices).

    Returns:
        total_flops: Estimated total number of FLOPs for the forward pass.
    """
    point_indices = input_args["point_indices"]
    N = point_indices.shape[0]

    # there are N 3d points in each full forward pass
    total_flops = N * estimate_forward_flops_vmap()

    return total_flops


def estimate_jacobian_flops(model, input_args, type=None):
    """
    Estimate the number of FLOPs required for computing the Jacobian of the model.

    Args:
        model: The model instance (e.g., ReprojNonBatched).
        input_args: A tuple containing (points_2d, intr, camera_indices, point_indices).

    Returns:
        total_flops: Estimated total number of FLOPs for Jacobian computation.
    """
    point_indices = input_args["point_indices"]

    # this is the number of row in the Jacobian
    N = point_indices.shape[0]

    if type == 'pose':
        M = model.model.pose.shape[0]
    elif type == 'points':
        M = model.model.points_3d.shape[0]

    flops_backward = estimate_forward_flops_vmap() * 2 # assume backward pass is twice as expensive as forward pass

    # In the dense mode, each jacobian row need to backward for *all* parameters
    total_flops = N * M * flops_backward

    return total_flops

class DenseLMFlopsEstimator(pp.optim.LevenbergMarquardt):
    @torch.no_grad()
    def step(self, input_args, target=None, weight=None):
        """
        Perform a Levenberg-Marquardt optimization step, estimating the FLOPs required.

        Args:
            input_args: A tuple containing (points_2d, intr, camera_indices, point_indices).
            target: Not used in this context (optional).
            weight: Weight matrix (optional).

        Returns:
            flops_dict (dict): A dictionary containing the estimated FLOPs for each step.
        """
        point_indices = input_args["point_indices"]

        N = point_indices.shape[0] # Number of points
        n = self.model.model.get_num_params() # Number of parameters

        # model forward pass
        flops_forward = estimate_forward_flops(input_args)

        # Jacobian computation
        flops_jacobian = estimate_jacobian_flops(self.model, input_args, type="pose")
        flops_jacobian += estimate_jacobian_flops(self.model, input_args, type="points")

        # negligible FLOPs for flattening Jacobian rows, correction applied to R and J, normalize R, W, and J, and transpose of Jacobian

        # FLOPs for J_T @ J
        # Jacobian dimension after flattening is J_T (n x N) and J (N x n)
        # FLOPs for J_T @ J: n * n * (2N−1)
        flops_JT_J = n * n * (2 * N - 1)

        # FLOPs for adding damping to diagonal of A
        # A is n x n
        flops_damping = n  # one multiplication and addition per diagonal element

        # FLOPs for computing b = -J_T @ R
        # R has size N x 1, J_T is n x N
        # FLOPs for J_T @ R: n * (2N−1)
        flops_b = n * (2 * N - 1)

        # Estimate FLOPs for solving linear system A delta = b
        # A is n x n, b is n x 1
        # choleskey decomposition takes 1/3 * n^3 flops
        # back-substitution takes n^2 flops
        flops_solver = (1/3) * n**3 + n**2

        # Estimate FLOPs for updating parameters theta += delta
        flops_param_update = n  # One addition per parameter

        flops_dict = {
            "jac": flops_jacobian,
            "solver": flops_solver,
            "JT_J": flops_JT_J,
            "others": flops_forward + flops_damping + flops_b + flops_param_update,
            "total": flops_jacobian + flops_solver + flops_forward + flops_JT_J + flops_damping + flops_b + flops_param_update
        }
    

        # Return the total estimated FLOPs
        return flops_dict

def get_human_readable_flops(flops):
    """
    Convert FLOPs to a human-readable format.

    Args:
        flops: Estimated FLOPs.

    Returns:
        human_readable_flops: FLOPs in a human-readable format.
    """

    if flops < 1e3:
        human_readable_flops = f"{flops:.2f} FLOPs"
    elif flops < 1e6:
        human_readable_flops = f"{flops / 1e3:.2f} KFLOPs"
    elif flops < 1e9:
        human_readable_flops = f"{flops / 1e6:.2f} MFLOPs"
    elif flops < 1e12:
        human_readable_flops = f"{flops / 1e9:.2f} GFLOPs"
    elif flops < 1e15:
        human_readable_flops = f"{flops / 1e12:.2f} TFLOPs"
    else:
        human_readable_flops = f"{flops / 1e15:.2f} PFLOPs"

    

    return human_readable_flops

def get_human_readable_flops_dict(flops_dict):
    """
    Convert FLOPs dictionary to a human-readable format.

    Args:
        flops_dict: Dictionary containing estimated FLOPs.

    Returns:
        human_readable_flops_dict: Dictionary containing estimated FLOPs in a human-readable format.
    """

    human_readable_flops_dict = {
        key: get_human_readable_flops(flops) for key, flops in flops_dict.items()
    }

    return human_readable_flops_dict

In [8]:
from pyro_slam.sparse.py_ops import *
from pyro_slam.sparse.bsr_cuda import *
from pyro_slam.sparse.solve import *


from torch.func import jacrev, jacfwd


def construct_sbt(jac_from_vmap, num, index):
    n = index.shape[0] # num 2D points
    i = torch.stack([torch.arange(n).to(index.device), index])
    block_shape = jac_from_vmap.shape[1:]
    v = jac_from_vmap # adjust dimension to accomodate for sbt constructor
    dummy_val = torch.arange(n, device=index.device, dtype=torch.int64)
    dummy_coo = torch.sparse_coo_tensor(i, dummy_val, size=(n, num), device=index.device, dtype=torch.int64)
    dummy_csc = dummy_coo.coalesce().to_sparse_csc()
    return torch.sparse_bsc_tensor(ccol_indices = dummy_csc.ccol_indices(), 
                                   row_indices=dummy_csc.row_indices(),
                                   values = v[dummy_csc.values()],
                                   size = (n * block_shape[0], num * block_shape[1]),
                                   device=index.device, dtype=DTYPE)

def modjacrev_vmap(model, input, argnums=0, *, has_aux=False):
    params = dict(model.named_parameters())

    cameras_num = params['model.pose'].shape[0]
    points_3d_num = params['model.points_3d'].shape[0]
    # need to align the indices with the parameters
    camera_indices = input['camera_indices']
    point_indices = input['point_indices']
    params['model.pose'] = params['model.pose'][camera_indices] # index using camera indices
    params['model.points_3d'] = params['model.points_3d'][point_indices] # index using point indices
    jac_points_3d, jac_pose = torch.vmap(jacrev(project, argnums=(0, 1), has_aux=has_aux))(params['model.points_3d'], params['model.pose'])
    if USE_QUATERNIONS: 
        useful_idx = [0,1,2,3,4,5,7,8,9] if OPTIMIZE_INTRINSICS else [0,1,2,3,4,5,]
        jac_pose = jac_pose[..., useful_idx] # remove the 4th element of the quaternion
                                                    # because original is [qx, qy, qz, qw, tx, ty, tz], but always dqw = 0
    return [construct_sbt(jac_pose, cameras_num, camera_indices), construct_sbt(jac_points_3d, points_3d_num, point_indices)]


import time
from typing import Optional

from torch import Tensor
    
class TrustRegion(pp.optim.strategy.TrustRegion):
    def update(self, pg, last, loss, J, D, R, *args, **kwargs):
        J = [i.to_sparse_coo() for i in J]
        JD = None
        for i in range(len(D)):
            if JD is None:
                JD = J[i] @ D[i]
            else:
                JD += J[i] @ D[i]
        JD = JD[..., None]
        quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze()
        pg['radius'] = 1. / pg['damping']
        if quality > pg['high']:
            pg['radius'] = pg['up'] * pg['radius']
            pg['down'] = self.down
        elif quality > pg['low']:
            pg['radius'] = pg['radius']
            pg['down'] = self.down
        else:
            pg['radius'] = pg['radius'] * pg['down']
            pg['down'] = pg['down'] * pg['factor']
        pg['down'] = max(self.min, min(pg['down'], self.max))
        pg['radius'] = max(self.min, min(pg['radius'], self.max))
        pg['damping'] = 1. / pg['radius']

class Adaptive(pp.optim.strategy.Adaptive):
    def update(self, pg, last, loss, J, D, R, *args, **kwargs):
        J = [i.to_sparse_coo() for i in J]
        JD = None
        for i in range(len(D)):
            if JD is None:
                JD = J[i] @ D[i]
            else:
                JD += J[i] @ D[i]
        JD = JD[..., None]
        quality = (last - loss) / -((JD).mT @ (2 * R.view_as(JD) + JD)).squeeze()
        if quality > pg['high']:
            pg['damping'] = pg['damping'] * pg['down']
        elif quality > pg['low']:
            pg['damping'] = pg['damping']
        else:
            pg['damping'] = pg['damping'] * pg['up']
        pg['damping'] = max(self.min, min(pg['damping'], self.max))


from pypose.optim.solver import CG
class PCG(CG):
    def __init__(self, maxiter=None, tol=0.00001):
        super().__init__(maxiter, tol)
    def forward(self, A: Tensor, b: Tensor, x: Tensor | None = None, M: Tensor | None = None) -> Tensor:
        lhs = A
        rhs = b
        if b.dim() == 1:
            b = b[..., None]
        l_diag = lhs.diagonal()
        l_diag[l_diag.abs() < 1e-6] = 1e-6
        M = torch.sparse.spdiags(1 / l_diag[None].cpu(), offsets=torch.zeros(1, dtype=int), shape=lhs.shape)
        M = M.to_sparse_bsr(blocksize=A.values().shape[-2:]).to(DEVICE)
        rhs = M @ rhs
        lhs = M @ lhs.to_sparse_bsc(blocksize=lhs.values().shape[-2:])

        return super().forward(lhs, rhs, x)

class SciPySpSolver(nn.Module):
    def __init__(self, ):
        super().__init__()
    def forward(self, A, b):
        import scipy.sparse.linalg as spla
        import scipy.sparse as sp
        import numpy as np
        if A.layout != torch.sparse_csr:
            A = A.to_sparse_coo().to_sparse_csr()
        A_csr = sp.csr_matrix((A.values().cpu().numpy(), 
                                   A.col_indices().cpu().numpy(),
                                   A.crow_indices().cpu().numpy()),
                                  shape=A.shape)
        b = b.cpu().numpy()
        x = spla.spsolve(A_csr, b, use_umfpack=False)
        assert not np.isnan(x).any()
        # a_err = np.linalg.norm(A_csr @ x - b)
        # r_err = a_err / np.linalg.norm(b)
        # print(f"Linear Solver Error: {a_err}, relative error: {r_err}")
        return torch.from_numpy(x).to(A.device)

from pyro_slam.sparse.solve import cusolvesp, cudss

class cuSolverSP(nn.Module):
    def __init__(self, ):
        super().__init__()
    def forward(self, A, b):
        x = cudss(A, b.flatten())
        return x


def estimate_sparse_jacobian_flops(input_args):
    point_indices = input_args["point_indices"]

    # this is the number of row in the Jacobian
    N = point_indices.shape[0]

    # Estimate FLOPs for forward pass
    flops_backward = estimate_forward_flops_vmap() * 2 # assume backward pass is twice as expensive as forward pass

    # In the sparse mode, each jacobian row only need to backward for the parameters that affect the output
    total_flops = N * flops_backward

    return total_flops
    
def estimate_JT_J_flops(J: torch.Tensor) -> int:
    # Get the row and column indices of the non-zeros
    rows, cols = J._indices()

    # Count how many non-zeros per column (this is efficient and avoids iteration)
    col_nnz_counts = torch.bincount(cols)

    # Compute FLOPs using vectorized operations
    flops = torch.sum(col_nnz_counts ** 2).item()

    return int(2 * flops)  # Multiply by 2 to account for multiplication and addition

import scipy.sparse as sp
import scipy.sparse.linalg as splinalg

class SparseLMFlopsEstimator(pp.optim.LevenbergMarquardt):
    @torch.no_grad()
    def step(self, input, target=None, weight=None):
        
        flops_forward = 2 * estimate_forward_flops(input)
        flops_jacobian = 2 * estimate_sparse_jacobian_flops(input) # two jacobians
        flops_param_update = self.model.model.get_num_params() # number of parameters
        
        # TBD by actual data
        flops_JT_J = 0
        flops_damping = 0
        flops_b = 0
        flops_solver = 0
        
        for pg in self.param_groups:
            weight = self.weight if weight is None else weight
            R = list(self.model(input))
            J = modjacrev_vmap(self.model, input)

            R = R[0]
            J = torch.cat([j.to_sparse_coo() for j in J], dim=-1)

            self.last = self.loss = self.loss if hasattr(self, 'loss') \
                                    else self.model.loss(input, target)
            
            J_T = J.T @ weight if weight is not None else J.T
            
            A, self.reject_count = J_T @ J, 0
            flops_JT_J += estimate_JT_J_flops(J)
            
            A = A.to_sparse_csr()
            diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
            flops_damping += A.size(0) # diagonal

            while self.last <= self.loss:
                diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
                try:
                    D = self.solver(A = A, b = -J_T @ R.view(-1, 1))
                    flops_b += J_T._nnz() * 2
                    num_nonzeros_L = A._nnz() * 2  # The non-zeros in the factor L
                    flops_factorization = 6 * num_nonzeros_L
                    flops_substitution = 2 * num_nonzeros_L
                    flops_solver += flops_factorization + flops_substitution
                    D = D[:, None]
                except Exception as e:
                    print(e, "\nLinear solver failed. Breaking optimization step...")
                    break
                self.update_parameter(pg['params'], D)
                self.loss = self.model.loss(input, target)
                print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping'])
                self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1))
                if self.last < self.loss and self.reject_count < self.reject: # reject step
                    self.update_parameter(params = pg['params'], step = -D)
                    self.loss, self.reject_count = self.last, self.reject_count + 1
                else:
                    break

        flops_dict = {
            "jac": flops_jacobian,
            "solver": flops_solver,
            "JT_J": flops_JT_J,
            "others": flops_forward + flops_damping + flops_b + flops_param_update,
            "total": flops_jacobian + flops_solver + flops_forward + flops_JT_J + flops_damping + flops_b + flops_param_update
        }
        return flops_dict
    def update_parameter(self, params, step):
        numels = []
        for i, p in enumerate(params):
            if p.requires_grad:
                if i == 0:
                    numels.append(p.shape[0] * (9 if OPTIMIZE_INTRINSICS else 6))
                else:
                    numels.append(p.numel())
        steps = step.split(numels)
        for i, (p, d) in enumerate(zip(params, steps)):
            if p.requires_grad:
                if i == 0:
                    # continue
                    if USE_QUATERNIONS:
                        p[..., :7] = pp.SE3(p[..., :7]).add_(pp.se3(d.view(p.shape[0], -1)[..., :6]))
                        if OPTIMIZE_INTRINSICS: p[:, 7:] += d.view(p.shape[0], -1)[:, 6:]
                        continue
                p.add_(d.view(p.shape))

In [13]:
from pypose.optim.solver import CG
from pypose.optim import LevenbergMarquardt

camera_params_other = None if NUM_CAMERA_PARAMS == trimmed_dataset['camera_params'].shape[1] else trimmed_dataset['camera_params'][:, NUM_CAMERA_PARAMS:]
input = {"points_2d": trimmed_dataset['points_2d'],
         "intr": camera_params_other,
         "camera_indices": trimmed_dataset['camera_index_of_observations'],
         "point_indices": trimmed_dataset['point_index_of_observations']}

model_non_batched = ReprojNonBatched(trimmed_dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(),
                                     trimmed_dataset['points_3d'].clone())
model_non_batched = model_non_batched.to(DEVICE)

strategy = pp.optim.strategy.Adaptive(damping=0.0001, min=1.5e-6)

#optimizer_flops_estimator = DenseLMFlopsEstimator(model_non_batched, strategy=strategy, reject=30)
sparse_solver = cuSolverSP()
optimizer_flops_estimator = SparseLMFlopsEstimator(model_non_batched, strategy=strategy, solver=sparse_solver, reject=0)

print('Starting loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, trimmed_dataset['camera_index_of_observations'], trimmed_dataset['point_index_of_observations'], trimmed_dataset['points_2d'], intr=camera_params_other).item())
for idx in range(1):
    flops_dict = optimizer_flops_estimator.step(input)
    print(f"FLOPs for step {idx + 1}: {get_human_readable_flops_dict(flops_dict)}")
    #print('Loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, trimmed_dataset['camera_index_of_observations'], trimmed_dataset['point_index_of_observations'], trimmed_dataset['points_2d'], intr=camera_params_other).item())
print('Ending loss:', least_square_error(model_non_batched.pose, model_non_batched.points_3d, trimmed_dataset['camera_index_of_observations'], trimmed_dataset['point_index_of_observations'], trimmed_dataset['points_2d'], intr=camera_params_other).item())

Starting loss: 85279364.20754541
called_count: 4
Fill-in factor: 0.19957
Loss: tensor(4.8047e+106, device='cuda:0') Last Loss: tensor(1.7056e+08, device='cuda:0') Reject Count: 0 Damping: 0.0001
FLOPs for step 1: {'jac': '157.46 MFLOPs', 'solver': '414.47 MFLOPs', 'JT_J': '16.17 GFLOPs', 'others': '104.13 MFLOPs', 'total': '16.85 GFLOPs'}
Ending loss: 2.402367371033183e+106
