In [1]:
import torch
import warnings
from typing import Callable, List, Optional
from torch.library import Library
from tqdm import tqdm
from torchvision.transforms import Compose

# diag is already mature 
from pyro_slam.sparse.py_ops import *
from pyro_slam.sparse.bsr_cuda import *
from pyro_slam.sparse.solve import *


  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::mm(Tensor self, Tensor mat2) -> Tensor
    registered at /home/zitongzhan/base/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: SparseCsrCPU
  previous kernel: registered at /home/zitongzhan/base/pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1079
       new kernel: registered at /media/zitongzhan/New_Volume/pyro_slam/pyro_slam/sparse/sparse_op_cpp.cpp:206 (function operator())


# Bundle Adjustment Example using SparsePyBA and the BAL dataset

```
The dataset is from the following paper:  
Sameer Agarwal, Noah Snavely, Steven M. Seitz, and Richard Szeliski.  
Bundle adjustment in the large.  
In European Conference on Computer Vision (ECCV), 2010.  
```

Link to the dataset: https://grail.cs.washington.edu/projects/bal/

# Fetch data

In [2]:
from datapipes.bal_loader import get_problem, read_bal_data

# --- Problem selection -------------------------------------------------------
TARGET_DATASET, TARGET_PROBLEM = "ladybug", "problem-49-7776-pre"
# Other problems that have been tried:
#   ladybug:   "problem-1723-156502-pre", "problem-1695-155710-pre", "problem-969-105826-pre"
#   trafalgar: "problem-21-11315-pre"

# --- Runtime options ---------------------------------------------------------
DEVICE = 'cuda'  # change device to CPU if needed
DTYPE = torch.float64
USE_QUATERNIONS = True
OPTIMIZE_INTRINSICS = True

# Alternative loader for a local file:
# dataset = read_bal_data(file_name='Data/dubrovnik-3-7-pre.txt', use_quat=USE_QUATERNIONS)
dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET, use_quat=USE_QUATERNIONS)

# Number of leading camera columns that are optimised: the pose columns
# (7 with quaternions, 6 otherwise) plus the 3 trailing intrinsics read by
# project() (f, k1, k2) when those are optimised as well.
NUM_CAMERA_PARAMS = (7 if USE_QUATERNIONS else 6) + (3 if OPTIMIZE_INTRINSICS else 0)

print(f'Fetched {TARGET_PROBLEM} from {TARGET_DATASET}')

import torch
import torch.nn as nn
import pypose as pp

# Keep only the entries that are plain tensors and move them to the compute device.
trimmed_dataset = {}
for name, value in dataset.items():
    if type(value) is torch.Tensor:
        trimmed_dataset[name] = value.to(DEVICE)


def convert_to(type):
    """Cast every entry of ``trimmed_dataset`` whose key lacks 'index' to ``type``.

    The dictionary is updated in place; entries whose key contains 'index'
    are left untouched.
    """
    value_keys = [key for key in trimmed_dataset if 'index' not in key]
    for key in value_keys:
        trimmed_dataset[key] = trimmed_dataset[key].to(type)


convert_to(DTYPE)
# Tensors created later without an explicit dtype will also use DTYPE.
torch.set_default_dtype(DTYPE)

Streaming data for ladybug...


  camera_params = torch.tensor(camera_params).to(torch.float32)


Fetched problem-49-7776-pre from ladybug


# Declare helper functions

In [3]:
import torch

def rotate_euler(points, rot_vecs):
    """Rotate ``points`` by the axis-angle vectors ``rot_vecs``.

    Implements Rodrigues' rotation formula. The rotation angle is the norm of
    each vector along the last dimension and the axis is the normalised vector.
    """
    angle = torch.norm(rot_vecs, dim=-1, keepdim=True)
    # A zero rotation vector gives 0/0 = NaN here; nan_to_num turns that into a zero axis.
    axis = torch.nan_to_num(rot_vecs / angle)
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    along_axis = torch.sum(points * axis, dim=-1, keepdim=True)

    rotated = cos_a * points
    rotated = rotated + sin_a * torch.cross(axis, points, dim=-1)
    rotated = rotated + along_axis * (1 - cos_a) * axis
    return rotated

def rotate_quat(points, rot_vecs):
    """Wrap ``rot_vecs`` as a PyPose ``SE3`` and apply it to ``points`` via ``Act``."""
    return pp.SE3(rot_vecs).Act(points)

def project(points, camera_params):
    """Project 3-D ``points`` to 2-D image coordinates.

    With ``USE_QUATERNIONS`` the first 7 camera columns are applied as an SE3
    transform; otherwise columns 3:6 are an axis-angle rotation and columns
    :3 a translation. The last three camera columns are read as the focal
    length ``f`` and two radial distortion coefficients ``k1`` and ``k2``.
    """
    if USE_QUATERNIONS:
        cam_pts = rotate_quat(points, camera_params[..., :7])
    else:
        cam_pts = rotate_euler(points, camera_params[..., 3:6]) + camera_params[..., :3]

    # Perspective division (negated); unsqueeze restores the trailing axis for broadcasting.
    xy = -cam_pts[..., :2] / cam_pts[..., 2].unsqueeze(-1)

    focal, k1, k2 = (camera_params[..., col].unsqueeze(-1) for col in (-3, -2, -1))

    radius_sq = torch.sum(xy**2, dim=-1, keepdim=True)
    distortion = 1 + k1 * radius_sq + k2 * radius_sq**2
    return xy * distortion * focal


# sparse version
class ReprojNonBatched(nn.Module):
    """Reprojection residual with cameras and 3-D points stored as parameters.

    ``pose`` holds the per-camera parameters and ``points_3d`` the per-point
    coordinates; both are registered as ``nn.Parameter`` in that order.
    """

    def __init__(self, camera_params, points_3d):
        super().__init__()
        self.pose = nn.Parameter(camera_params)
        self.points_3d = nn.Parameter(points_3d)

    def forward(self, points_2d, intr, camera_indices, point_indices):
        """Return ``project(...) - points_2d`` for every observation.

        ``intr`` (may be ``None``) is appended to ``pose`` along the last
        dimension before the per-observation gather.
        """
        cameras = self.pose if intr is None else torch.cat([self.pose, intr], dim=-1)
        predicted = project(self.points_3d[point_indices], cameras[camera_indices])
        return predicted - points_2d

from functools import partial

def least_square_error(camera_params, points_3d, camera_indices, point_indices, points_2d, intr=None):
    """Return half the sum of squared reprojection residuals.

    A throw-away ``ReprojNonBatched`` is built from the given parameters and
    evaluated on the observations.
    """
    residuals = ReprojNonBatched(camera_params, points_3d)(
        points_2d, intr, camera_indices, point_indices
    )
    return residuals.pow(2).sum() / 2

In [4]:
from torch.func import jacrev, jacfwd


def construct_sbt(jac_from_vmap, num, index):
    """Assemble per-observation Jacobian blocks into a block-sparse (BSC) matrix.

    ``jac_from_vmap[k]`` is the dense block of observation ``k``; it is placed
    at block-row ``k`` and block-column ``index[k]``. ``num`` is the number of
    block columns. The result is created with the notebook-wide ``DTYPE``.
    """
    n_obs = index.shape[0]  # one block row per 2-D observation
    block_shape = jac_from_vmap.shape[1:]

    # Build an integer sparsity pattern whose stored value is the observation
    # number. After coalescing and converting to CSC, those values say which
    # dense block belongs at each stored position.
    obs_ids = torch.arange(n_obs, device=index.device, dtype=torch.int64)
    pattern = torch.sparse_coo_tensor(
        torch.stack([obs_ids, index]),
        obs_ids,
        size=(n_obs, num),
        device=index.device,
        dtype=torch.int64,
    )
    pattern = pattern.coalesce().to_sparse_csc()

    return torch.sparse_bsc_tensor(
        ccol_indices=pattern.ccol_indices(),
        row_indices=pattern.row_indices(),
        values=jac_from_vmap[pattern.values()],
        size=(n_obs * block_shape[0], num * block_shape[1]),
        device=index.device,
        dtype=DTYPE,
    )

def modjacrev_vmap(model, input, argnums=0, *, has_aux=False):
    """Compute the block-sparse Jacobians of ``project`` for every observation.

    Args:
        model: module whose ``named_parameters()`` contains ``'model.pose'``
            and ``'model.points_3d'``.
        input: dict with ``'camera_indices'`` and ``'point_indices'`` and,
            optionally, ``'intr'`` (fixed per-camera columns appended to the
            pose before projecting, exactly as ``ReprojNonBatched.forward``
            does).
        argnums: unused; kept for signature compatibility.
        has_aux: forwarded to ``jacrev``.

    Returns:
        ``[J_camera, J_points]`` as block-sparse tensors built by
        ``construct_sbt``.
    """
    params = dict(model.named_parameters())
    pose = params['model.pose']
    points_3d = params['model.points_3d']

    cameras_num = pose.shape[0]
    points_3d_num = points_3d.shape[0]
    pose_dim = pose.shape[-1]

    camera_indices = input['camera_indices']
    point_indices = input['point_indices']

    # The forward pass projects with [pose | intr]. Differentiating `project`
    # on the pose alone would make it read f, k1, k2 from the last three pose
    # columns whenever the intrinsics are supplied separately.
    intr = input.get('intr')
    camera_params = pose if intr is None else torch.cat([pose, intr], dim=-1)

    # Gather one (point, camera) pair per observation and differentiate each pair.
    jac_points_3d, jac_pose = torch.vmap(jacrev(project, argnums=(0, 1), has_aux=has_aux))(
        points_3d[point_indices], camera_params[camera_indices]
    )

    # Only the columns that belong to the optimised `pose` parameter are kept;
    # columns of the appended fixed intrinsics are dropped.
    jac_pose = jac_pose[..., :pose_dim]

    if USE_QUATERNIONS:
        # The pose is handed to pp.SE3, whose storage is
        # [tx, ty, tz, qx, qy, qz, qw]; index 6 (qw) is removed because the
        # update in LM.update_parameter is a 6-element se3 increment.
        useful_idx = [0, 1, 2, 3, 4, 5, 7, 8, 9] if OPTIMIZE_INTRINSICS else [0, 1, 2, 3, 4, 5]
        jac_pose = jac_pose[..., useful_idx]

    return [construct_sbt(jac_pose, cameras_num, camera_indices),
            construct_sbt(jac_points_3d, points_3d_num, point_indices)]

# Run optimization

In [5]:
import time
from typing import Optional

from torch import Tensor
    
class TrustRegion(pp.optim.strategy.TrustRegion):
    """Trust-region damping strategy that accepts a list of sparse Jacobian blocks.

    ``J`` and ``D`` are parallel sequences (one Jacobian block and one step
    per parameter group); the predicted change is ``sum_i J[i] @ D[i]``.
    """

    def update(self, pg, last, loss, J, D, R, *args, **kwargs):
        """Update ``pg['radius']``, ``pg['down']`` and ``pg['damping']`` in place."""
        blocks = [jac.to_sparse_coo() for jac in J]
        JD = None
        for k, delta in enumerate(D):
            term = blocks[k] @ delta
            JD = term if JD is None else JD + term
        JD = JD[..., None]

        # Ratio of the actual loss decrease to the decrease predicted by the linear model.
        predicted = -(JD.mT @ (2 * R.view_as(JD) + JD)).squeeze()
        quality = (last - loss) / predicted

        pg['radius'] = 1. / pg['damping']
        if quality > pg['high']:
            # Good step: enlarge the region and reset the shrink factor.
            pg['radius'] *= pg['up']
            pg['down'] = self.down
        elif quality > pg['low']:
            # Acceptable step: radius unchanged, shrink factor reset.
            pg['down'] = self.down
        else:
            # Poor step: shrink the region and make the next shrink stronger.
            pg['radius'] *= pg['down']
            pg['down'] *= pg['factor']

        pg['down'] = max(self.min, min(pg['down'], self.max))
        pg['radius'] = max(self.min, min(pg['radius'], self.max))
        pg['damping'] = 1. / pg['radius']

class Adaptive(pp.optim.strategy.Adaptive):
    """Adaptive damping strategy that accepts a list of sparse Jacobian blocks.

    ``J`` and ``D`` are parallel sequences; the predicted change is
    ``sum_i J[i] @ D[i]``.
    """

    def update(self, pg, last, loss, J, D, R, *args, **kwargs):
        """Scale ``pg['damping']`` in place according to the step quality."""
        blocks = [jac.to_sparse_coo() for jac in J]
        JD = None
        for k, delta in enumerate(D):
            term = blocks[k] @ delta
            JD = term if JD is None else JD + term
        JD = JD[..., None]

        # Ratio of the actual loss decrease to the decrease predicted by the linear model.
        predicted = -(JD.mT @ (2 * R.view_as(JD) + JD)).squeeze()
        quality = (last - loss) / predicted

        if quality > pg['high']:
            scale = pg['down']
        elif quality > pg['low']:
            scale = None  # damping stays as it is
        else:
            scale = pg['up']
        if scale is not None:
            pg['damping'] = pg['damping'] * scale

        pg['damping'] = max(self.min, min(pg['damping'], self.max))


from pypose.optim.solver import CG
class PCG(CG):
    """Conjugate-gradient solver with a diagonal (Jacobi) preconditioner.

    Both sides of ``A x = b`` are left-multiplied by ``M = diag(1 / diag(A))``
    before delegating to the parent ``CG.forward``.
    """
    def __init__(self, maxiter=None, tol=0.00001):
        super().__init__(maxiter, tol)
    def forward(self, A: Tensor, b: Tensor, x: Tensor | None = None, M: Tensor | None = None) -> Tensor:
        """Solve ``A x = b``; ``x`` is forwarded to the parent as the initial guess argument.

        NOTE(review): the ``M`` argument is ignored -- it is overwritten below
        by the preconditioner built from ``A``'s diagonal.
        """
        lhs = A
        rhs = b
        # NOTE(review): this rebinds `b` only; `rhs` was bound above and `b` is
        # not read again, so the reshape has no effect on what is solved.
        if b.dim() == 1:
            b = b[..., None]
        l_diag = lhs.diagonal()
        # Floor small-magnitude diagonal entries so the reciprocal below stays bounded.
        # NOTE(review): if `diagonal()` returns a view of A's storage this also
        # modifies A in place -- confirm against the sparse implementation.
        l_diag[l_diag.abs() < 1e-6] = 1e-6
        # Diagonal preconditioner, built on the CPU and then converted/moved.
        M = torch.sparse.spdiags(1 / l_diag[None].cpu(), offsets=torch.zeros(1, dtype=int), shape=lhs.shape)
        # Block size is taken from A's stored values; assumes A is a blocked
        # sparse layout whose values() has trailing block dims -- TODO confirm.
        M = M.to_sparse_bsr(blocksize=A.values().shape[-2:]).to(DEVICE)
        rhs = M @ rhs
        lhs = M @ lhs.to_sparse_bsc(blocksize=lhs.values().shape[-2:])

        return super().forward(lhs, rhs, x)

class SciPySpSolver(nn.Module):
    """Sparse direct solver that hands the system to ``scipy.sparse.linalg.spsolve``.

    The matrix and right-hand side are copied to the CPU, solved there, and
    the solution is moved back to the device of the (CSR) matrix.
    """

    def __init__(self, ):
        super().__init__()

    def forward(self, A, b):
        """Return ``x`` with ``A x = b``; asserts that the solution has no NaNs."""
        import numpy as np
        import scipy.sparse as sp
        import scipy.sparse.linalg as spla

        # SciPy is fed CSR triplets, so convert any other layout first.
        csr = A if A.layout == torch.sparse_csr else A.to_sparse_coo().to_sparse_csr()
        data, col_idx, row_ptr = (
            part.cpu().numpy()
            for part in (csr.values(), csr.col_indices(), csr.crow_indices())
        )
        matrix = sp.csr_matrix((data, col_idx, row_ptr), shape=csr.shape)

        solution = spla.spsolve(matrix, b.cpu().numpy(), use_umfpack=False)
        assert not np.isnan(solution).any()
        # Optional diagnostic: np.linalg.norm(matrix @ solution - rhs) / np.linalg.norm(rhs)
        return torch.from_numpy(solution).to(csr.device)

from pyro_slam.sparse.solve import cusolvesp, cudss

class cuSolverSP(nn.Module):
    """Linear-solver module that forwards to ``cudss`` with a flattened right-hand side.

    NOTE(review): despite the class name, the call goes to ``cudss``, not ``cusolvesp``.
    """

    def __init__(self, ):
        super().__init__()

    def forward(self, A, b):
        """Return the result of ``cudss(A, b.flatten())``."""
        return cudss(A, b.flatten())

In [6]:
class LM(pp.optim.LevenbergMarquardt):
    """Levenberg-Marquardt optimiser using the sparse Jacobians from ``modjacrev_vmap``.

    The normal equations ``(J^T J) D = -J^T R`` are assembled as a sparse CSR
    matrix and handed to ``self.solver``.
    """

    @torch.no_grad()
    def step(self, input, target=None, weight=None):
        """Run one LM step for every parameter group and return the current loss.

        Args:
            input: dict passed to the model and to ``modjacrev_vmap``.
            target: forwarded to ``self.model.loss``.
            weight: optional weight applied as ``J.T @ weight``; defaults to ``self.weight``.
        """
        for pg in self.param_groups:
            weight = self.weight if weight is None else weight
            R = list(self.model(input))
            J = modjacrev_vmap(self.model, input)
            # Note: no robust-kernel corrector is applied to R / J here.
            R = R[0]
            # Concatenate the camera and point blocks column-wise into one Jacobian.
            J = torch.cat([j.to_sparse_coo() for j in J], dim=-1)

            self.last = self.loss = self.loss if hasattr(self, 'loss') \
                                    else self.model.loss(input, target)
            J_T = J.T @ weight if weight is not None else J.T
            A, self.reject_count = J_T @ J, 0
            A = A.to_sparse_csr()
            diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))

            while self.last <= self.loss:
                # A is not rebuilt between attempts, so the damping factor is
                # applied again on every retry.
                diagonal_op_(A, op=partial(torch.mul, other=1+pg['damping']))
                try:
                    D = self.solver(A = A, b = -J_T @ R.view(-1, 1))
                    D = D[:, None]
                except Exception as e:
                    print(e, "\nLinear solver failed. Breaking optimization step...")
                    break
                self.update_parameter(pg['params'], D)
                self.loss = self.model.loss(input, target)
                print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping'])
                self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=D, R=R.view(-1, 1))
                if self.last < self.loss and self.reject_count < self.reject: # reject step
                    # Undo the step and retry with the damping chosen by the strategy.
                    self.update_parameter(params = pg['params'], step = -D)
                    self.loss, self.reject_count = self.last, self.reject_count + 1
                else:
                    break
        return self.loss

    def update_parameter(self, params, step):
        """Apply the stacked update ``step`` to the trainable entries of ``params``.

        ``step`` is split into one chunk per trainable parameter. The parameter
        at index 0 contributes 9 (with intrinsics) or 6 values per row; every
        other parameter contributes ``numel()`` values. With quaternions, the
        first 6 values of each row of the index-0 chunk are applied as an se3
        increment to the leading 7 columns and the rest is added to the
        remaining columns.
        """
        # Pair chunks with trainable parameters only. Zipping the chunks against
        # the full list would shift every update after a frozen parameter.
        trainable = [(i, p) for i, p in enumerate(params) if p.requires_grad]
        pose_dof = 9 if OPTIMIZE_INTRINSICS else 6
        numels = [p.shape[0] * pose_dof if i == 0 else p.numel() for i, p in trainable]
        steps = step.split(numels)
        for (i, p), d in zip(trainable, steps):
            if i == 0 and USE_QUATERNIONS:
                d = d.view(p.shape[0], -1)
                p[..., :7] = pp.SE3(p[..., :7]).add_(pp.se3(d[..., :6]))
                if OPTIMIZE_INTRINSICS:
                    p[:, 7:] += d[:, 6:]
            else:
                p.add_(d.view(p.shape))


class Schur(LM):
    """LM variant that eliminates the point block with a Schur complement.

    With ``J[0]`` the camera Jacobian and ``J[1]`` the point Jacobian (the
    order returned by ``modjacrev_vmap``):

        U = J0^T J0,  V = J1^T J1,  W = J0^T J1
        (U - W V^-1 W^T) D_c = Ic - W V^-1 Ip
                       V D_p = Ip - W^T D_c

    where ``Ic = -J0^T R`` and ``Ip = -J1^T R``.
    """

    @torch.no_grad()
    def step(self, input, target=None, weight=None):
        """Run one Schur-complement LM step per parameter group; return the loss.

        NVTX ranges are pushed around each stage for profiling.
        """
        for pg in self.param_groups:
            # NOTE(review): `weight` is resolved here but not used below.
            weight = self.weight if weight is None else weight
            R = self.model(input, target)
            J = modjacrev_vmap(self.model, input)

            R = R[0]

            self.last = self.loss = self.loss if hasattr(self, 'loss') \
                                    else self.model.loss(input, target)
            # Start every step with a fresh rejection budget, as LM.step does.
            # Without this the count carries over from previous steps.
            self.reject_count = 0

            torch.cuda.nvtx.range_push("JTJc")
            U = J[0].mT @ J[0]
            torch.cuda.nvtx.range_pop()
            # Debug aid: torch.testing.assert_close(U.to_dense(), J[0].to_dense().mT @ J[0].to_dense())
            torch.cuda.nvtx.range_push("JTJp")
            V = J[1].mT @ J[1]
            torch.cuda.nvtx.range_pop()
            # Debug aid: torch.testing.assert_close(V.to_dense(), J[1].to_dense().mT @ J[1].to_dense())

            torch.cuda.nvtx.range_push("Clamp")
            diagonal_op_(U, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
            diagonal_op_(V, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))
            torch.cuda.nvtx.range_pop()

            while self.last <= self.loss:
                damping = pg['damping']
                R = R.reshape(-1)

                # Adds damping * diag^2 to the diagonals. U and V are not rebuilt
                # between attempts, so this accumulates over retries.
                torch.cuda.nvtx.range_push("Damp")
                diagonal_op_(U, op=partial(torch.add, other=(torch.diagonal(U).pow(2)) * damping))
                diagonal_op_(V, op=partial(torch.add, other=(torch.diagonal(V).pow(2)) * damping))
                torch.cuda.nvtx.range_pop()

                torch.cuda.nvtx.range_push("W")
                W = J[0].mT @ J[1]
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("Ic")
                Ic = -J[0].mT.to_sparse_coo().to_sparse_csr() @ R
                Ip = -J[1].mT.to_sparse_coo().to_sparse_csr() @ R
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("Inv")
                V_i = inv_op(V)  # presumably a block-wise inverse of V -- provided by pyro_slam
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("WVi")
                WV_i = W @ V_i
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("rhs1")
                rhs = Ic - WV_i.to_sparse_coo().to_sparse_csr() @ Ip
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("lhs1")
                lhs = add_op(U, (-WV_i @ W.mT))  # this matrix is NOT symmetric
                torch.cuda.nvtx.range_pop()

                # Reduced camera system.
                torch.cuda.nvtx.range_push("Solve C")
                try:
                    D_c = self.solver(A = lhs.to_sparse_coo().to_sparse_csr(), b = rhs)
                except Exception as e:
                    print(e, "\nLinear solver failed. Breaking optimization step...")
                    # Close the "Solve C" range before leaving so NVTX stays balanced.
                    torch.cuda.nvtx.range_pop()
                    break
                torch.cuda.nvtx.range_pop()

                # Back-substitution for the point update.
                torch.cuda.nvtx.range_push("rhs2")
                rhs = Ip - W.mT.to_sparse_coo() @ D_c
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("solve2")
                lhs = V
                D_p = self.solver(A = lhs.to_sparse_coo().to_sparse_csr(), b = rhs)
                torch.cuda.nvtx.range_pop()
                torch.cuda.nvtx.range_push("Update")
                D = torch.cat([D_c, D_p])
                self.update_parameter(pg['params'], D)
                torch.cuda.nvtx.range_pop()
                self.loss = self.model.loss(input, target)
                print("Loss:", self.loss, "Last Loss:", self.last, "Reject Count:", self.reject_count, "Damping:", pg['damping'])
                torch.cuda.nvtx.range_push("Strategy")
                self.strategy.update(pg, last=self.last, loss=self.loss, J=J, D=[D_c, D_p], R=R.view(-1, 1))
                torch.cuda.nvtx.range_pop()
                if self.last < self.loss and self.reject_count < self.reject: # reject step
                    # Undo the step and retry with the damping chosen by the strategy.
                    self.update_parameter(params = pg['params'], step = -D)
                    self.loss, self.reject_count = self.last, self.reject_count + 1
                else:
                    break
        return self.loss
    
    # def _update_parameter(self, params, step):
        
    #     V, Ip, W = self.model.cur['V'], self.model.cur['Ip'], self.model.cur['W']
    #     rhs = Ip - W.mT.to_sparse_coo() @ step
    #     lhs = V
    #     D_p = self.solver(A = lhs, b = rhs)
    #     params[1] += D_p.view_as(params[1])
    #     return step, D_p

In [7]:

# Camera columns beyond the optimised ones are passed to the model as fixed
# extras ("intr"); when every column is optimised there is nothing left over.
if trimmed_dataset['camera_params'].shape[1] == NUM_CAMERA_PARAMS:
    camera_params_other = None
else:
    camera_params_other = trimmed_dataset['camera_params'][:, NUM_CAMERA_PARAMS:]

# Keyword bundle consumed by the model's forward pass and by modjacrev_vmap.
input = dict(
    points_2d=trimmed_dataset['points_2d'],
    intr=camera_params_other,
    camera_indices=trimmed_dataset['camera_index_of_observations'],
    point_indices=trimmed_dataset['point_index_of_observations'],
)

# inverse quat
def openGL2gtsam(pose):
    """Invert ``pose`` and flip its y and z axes, returning a new ``pp.SE3``.

    The rotation becomes ``R^-1 * diag(1, -1, -1)`` and the translation
    ``R^-1 * (-t)``; the result is packed translation-first.
    Only referenced from a commented-out line in this notebook.
    """
    rot_inv = pose.rotation().Inv()
    # diag(1, -1, -1): keeps x and negates y and z.
    flip_yz = torch.diag(torch.tensor([1., -1., -1.], device=pose.device, dtype=pose.dtype))
    wRc = rot_inv @ pp.mat2SO3(flip_yz)
    # Camera-to-world translation: wTc = -R' * t
    wtc = rot_inv @ -pose.translation()
    return pp.SE3(torch.cat([wtc, wRc], dim=-1))

# gtsam coord
# trimmed_dataset['camera_params'][:, :7] = Compose([pp.SE3, openGL2gtsam])(trimmed_dataset['camera_params'][:, :7])
# trimmed_dataset['points_2d'][:, 1] = -trimmed_dataset['points_2d'][:, 1]

# Model initialised from copies of the dataset so the raw tensors stay untouched.
model_non_batched = ReprojNonBatched(trimmed_dataset['camera_params'][:, :NUM_CAMERA_PARAMS].clone(),
                                     trimmed_dataset['points_3d'].clone())

model_non_batched = model_non_batched.to(DEVICE)

# strategy_sparse = TrustRegion()
strategy_sparse = pp.optim.strategy.Adaptive(damping=0.0001, min=1.5e-6)
# sparse_solver = PCG(tol=1e-3, maxiter=10000)
# sparse_solver = SciPySpSolver()
sparse_solver = cuSolverSP()
optimizer_sparse = LM(model_non_batched, strategy=strategy_sparse, solver=sparse_solver, reject=30)


def _current_loss():
    """Half the sum of squared reprojection residuals for the model's current state."""
    return least_square_error(model_non_batched.pose,
                              model_non_batched.points_3d,
                              trimmed_dataset['camera_index_of_observations'],
                              trimmed_dataset['point_index_of_observations'],
                              trimmed_dataset['points_2d'],
                              intr=camera_params_other).item()


print('Starting loss:', _current_loss())
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for idx in range(10):
    loss = optimizer_sparse.step(input)

# Record the end event first, then wait for it: elapsed_time() is only valid
# once both events have actually completed on the device.
end.record()
torch.cuda.synchronize()
print('Time', start.elapsed_time(end) / 1000)  # elapsed_time() is in milliseconds

final_loss = _current_loss()
print('Loss:', final_loss)
print('Ending loss:', final_loss)

Starting loss: 850912.7770917793


  dummy_csc = dummy_coo.coalesce().to_sparse_csc()


called_count: 0
Loss: tensor(91744.2051, device='cuda:0') Last Loss: tensor(1701825.5542, device='cuda:0') Reject Count: 0 Damping: 0.0001
Time 2.229585693359375
Loss: 45872.10257305857
called_count: 1
Loss: tensor(29565.0478, device='cuda:0') Last Loss: tensor(91744.2051, device='cuda:0') Reject Count: 0 Damping: 5e-05
Time 0.1874469757080078
Loss: 14782.523915175754
called_count: 2
Loss: tensor(26947.8921, device='cuda:0') Last Loss: tensor(29565.0478, device='cuda:0') Reject Count: 0 Damping: 2.5e-05
Time 0.18658268737792968
Loss: 13473.946054467091
called_count: 3
Loss: tensor(26843.7234, device='cuda:0') Last Loss: tensor(26947.8921, device='cuda:0') Reject Count: 0 Damping: 1.25e-05
Time 0.17814694213867188
Loss: 13421.861690113663
called_count: 4
Loss: tensor(26815.0757, device='cuda:0') Last Loss: tensor(26843.7234, device='cuda:0') Reject Count: 0 Damping: 6.25e-06
Time 0.17733056640625
Loss: 13407.537843024096
called_count: 5
Loss: tensor(26820.8826, device='cuda:0') Last Los