Skip to content

Commit

Permalink
Control MST init with flag plus fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushbaid committed May 14, 2024
1 parent 429177a commit beef33b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 137 deletions.
24 changes: 14 additions & 10 deletions gtsfm/averaging/rotation/shonan.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import gtsfm.utils.rotation as rotation_util
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.common.pose_prior import PosePrior
from gtsfm.common.two_view_estimation_report import TwoViewEstimationReport

POSE3_DOF = 6

Expand All @@ -42,7 +41,9 @@
class ShonanRotationAveraging(RotationAveragingBase):
"""Performs Shonan rotation averaging."""

def __init__(self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA) -> None:
def __init__(
self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, use_mst_init: bool = False
) -> None:
"""Initializes module.
Note: `p_min` and `p_max` describe the minimum and maximum relaxation rank.
Expand All @@ -52,6 +53,7 @@ def __init__(self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_S
"""
super().__init__()
self._two_view_rotation_sigma = two_view_rotation_sigma
self._use_mst_init = use_mst_init
self._p_min = 5
self._p_max = 64

Expand Down Expand Up @@ -198,14 +200,16 @@ def run_rotation_averaging(
if (i1, i2) in i2Ri1_dict
}
# Use negative of the number of correspondences as the edge weight.
wRi_initial_ = rotation_util.initialize_global_rotations_using_mst(
len(nodes_with_edges),
i2Ri1_dict_,
edge_weights={(i1, i2): -num_correspondences_dict.get((i1, i2), 0) for i1, i2 in i2Ri1_dict_.keys()},
)
initial_values = Values()
for i, wRi_initial_ in enumerate(wRi_initial_):
initial_values.insert(i, wRi_initial_)
initial_values: Optional[Values] = None
if self._use_mst_init:
wRi_initial_ = rotation_util.initialize_global_rotations_using_mst(
len(nodes_with_edges),
i2Ri1_dict_,
edge_weights={(i1, i2): -num_correspondences_dict.get((i1, i2), 0) for i1, i2 in i2Ri1_dict_.keys()},
)
initial_values = Values()
for i, wRi in enumerate(wRi_initial_):
initial_values.insert(i, wRi)

between_factors: BetweenFactorPose3s = self.__between_factors_from_2view_relative_rotations(
i2Ri1_dict, old_to_new_idxes
Expand Down
66 changes: 1 addition & 65 deletions gtsfm/utils/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def initialize_global_rotations_using_mst(
G = nx.Graph()
G.add_edges_from(mst.edges)

wRi_list = [None] * num_images
wRi_list: List[Rot3] = [Rot3()] * num_images
# Choose origin node.
origin_node = list(G.nodes)[0]
wRi_list[origin_node] = Rot3()
Expand All @@ -61,67 +61,3 @@ def initialize_global_rotations_using_mst(
wRi_list[dst_node] = wRi1

return wRi_list


# def initialize_mst(
# num_images: int,
# i2Ri1_dict: Dict[Tuple[int, int], Optional[Rot3]],
# corr_idxs: Dict[Tuple[int, int], np.ndarray],
# old_to_new_idxs: Dict[int, int],
# ) -> gtsam.Values:
# """Initialize global rotations using the minimum spanning tree (MST).

# Args:
# num_images: Number of images in the scene.
# i2Ri1_dict: Dictionary of relative rotations (i1, i2): i2Ri1.
# corr_idxs:
# old_to_new_idxs:

# Returns:
# Initialization of global rotations for Values.
# """
# # Compute MST.
# row, col, data = [], [], []
# for (i1, i2), i2Ri1 in i2Ri1_dict.items():
# if i2Ri1 is None:
# continue
# row.append(i1)
# col.append(i2)
# data.append(-corr_idxs[(i1, i2)].shape[0])
# logger.info(corr_idxs[(i1, i2)])
# corr_adjacency = scipy.sparse.coo_array((data, (row, col)), shape=(num_images, num_images))
# Tcsr = scipy.sparse.csgraph.minimum_spanning_tree(corr_adjacency)
# logger.info(Tcsr.toarray().astype(int))

# # Build global rotations from MST.
# # TODO (travisdriver): This is simple but very inefficient. Use something else.
# i_mst, j_mst = Tcsr.nonzero()
# logger.info(i_mst)
# logger.info(j_mst)
# edges_mst = [(i, j) for (i, j) in zip(i_mst, j_mst)]
# iR0_dict = {i_mst[0]: np.eye(3)} # pick the left index of the first edge as the seed
# # max_iters = num_images * 10
# iter = 0
# while len(edges_mst) > 0:
# i, j = edges_mst.pop(0)
# if i in iR0_dict:
# jRi = i2Ri1_dict[(i, j)].matrix()
# iR0 = iR0_dict[i]
# iR0_dict[j] = jRi @ iR0
# elif j in iR0_dict:
# iRj = i2Ri1_dict[(i, j)].matrix().T
# jR0 = iR0_dict[j]
# iR0_dict[i] = iRj @ jR0
# else:
# edges_mst.append((i, j))
# iter += 1
# # if iter >= max_iters:
# # logger.info("Reached max MST iters.")
# # assert False

# # Add to Values object.
# initial = gtsam.Values()
# for i, iR0 in iR0_dict.items():
# initial.insert(old_to_new_idxs[i], Rot3(iR0))

# return initial
103 changes: 44 additions & 59 deletions tests/averaging/rotation/test_shonan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.averaging.rotation.shonan import ShonanRotationAveraging
from gtsfm.common.pose_prior import PosePrior, PosePriorType
from gtsfm.common.two_view_estimation_report import TwoViewEstimationReport

ROTATION_ANGLE_ERROR_THRESHOLD_DEG = 2

Expand All @@ -41,13 +40,8 @@ def __execute_test(self, i2Ri1_input: Dict[Tuple[int, int], Rot3], wRi_expected:
wRi_expected: Expected global rotations.
"""
i1Ti2_priors: Dict[Tuple[int, int], PosePrior] = {}
two_view_estimation_reports = {
(i1, i2): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=random.randint(0, 100))
for i1, i2 in i2Ri1_input.keys()
}
wRi_computed = self.obj.run_rotation_averaging(
len(wRi_expected), i2Ri1_input, i1Ti2_priors, two_view_estimation_reports
)
v_corr_idxs = {(i1, i2): _generate_corr_idxs(random.randint(0, 100)) for i1, i2 in i2Ri1_input.keys()}
wRi_computed = self.obj.run_rotation_averaging(len(wRi_expected), i2Ri1_input, i1Ti2_priors, v_corr_idxs)
self.assertTrue(
geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, ROTATION_ANGLE_ERROR_THRESHOLD_DEG)
)
Expand Down Expand Up @@ -87,47 +81,36 @@ def test_simple_three_nodes_two_measurements(self):
i1Ri2 = Rot3.RzRyRx(0, 0, np.deg2rad(20))
i0Ri2 = i0Ri1.compose(i1Ri2)

i2Ri1_dict = {
(0, 1): i0Ri1.inverse(),
(1, 2): i1Ri2.inverse()
}
i2Ri1_dict = {(0, 1): i0Ri1.inverse(), (1, 2): i1Ri2.inverse()}

expected_wRi_list = [
Rot3(),
i0Ri1,
i0Ri2
]
expected_wRi_list = [Rot3(), i0Ri1, i0Ri2]

self.__execute_test(i2Ri1_dict, expected_wRi_list)

# def test_simple_with_prior(self):
# """Test a simple case with 1 measurement and a single pose prior."""
# expected_wRi_list = [Rot3.RzRyRx(0, 0, 0), Rot3.RzRyRx(0, np.deg2rad(30), 0), Rot3.RzRyRx(np.deg2rad(30), 0, 0)]

# i2Ri1_dict = {
# (0, 1): expected_wRi_list[1].between(expected_wRi_list[0])
# }

# expected_0R2 = expected_wRi_list[0].between(expected_wRi_list[2])
# i1Ti2_priors = {
# (0, 2): PosePrior(
# value=Pose3(expected_0R2, np.zeros((3,))),
# covariance=np.eye(6) * 1e-5,
# type=PosePriorType.SOFT_CONSTRAINT,
# )
# }

# two_view_estimation_reports = {
# (0, 1): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=1),
# (0, 2): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=1),
# }

# wRi_computed = self.obj.run_rotation_averaging(
# len(expected_wRi_list), i2Ri1_dict, i1Ti2_priors, two_view_estimation_reports
# )
# self.assertTrue(
# geometry_comparisons.compare_rotations(wRi_computed, expected_wRi_list, ROTATION_ANGLE_ERROR_THRESHOLD_DEG)
# )
def test_simple_with_prior(self):
"""Test a simple case with 1 measurement and a single pose prior."""
expected_wRi_list = [Rot3.RzRyRx(0, 0, 0), Rot3.RzRyRx(0, np.deg2rad(30), 0), Rot3.RzRyRx(np.deg2rad(30), 0, 0)]

i2Ri1_dict = {(0, 1): expected_wRi_list[1].between(expected_wRi_list[0])}

expected_0R2 = expected_wRi_list[0].between(expected_wRi_list[2])
i1Ti2_priors = {
(0, 2): PosePrior(
value=Pose3(expected_0R2, np.zeros((3,))),
covariance=np.eye(6) * 1e-5,
type=PosePriorType.SOFT_CONSTRAINT,
)
}

v_corr_idxs = {
(0, 1): _generate_corr_idxs(1),
(0, 2): _generate_corr_idxs(1),
}

wRi_computed = self.obj.run_rotation_averaging(len(expected_wRi_list), i2Ri1_dict, i1Ti2_priors, v_corr_idxs)
self.assertTrue(
geometry_comparisons.compare_rotations(wRi_computed, expected_wRi_list, ROTATION_ANGLE_ERROR_THRESHOLD_DEG)
)

def test_computation_graph(self):
"""Test the dask computation graph execution using a valid collection of relative poses."""
Expand All @@ -138,23 +121,21 @@ def test_computation_graph(self):
(0, 1): Rot3.RzRyRx(0, np.deg2rad(30), 0),
(1, 2): Rot3.RzRyRx(0, 0, np.deg2rad(20)),
}
two_view_estimation_reports = {
(0, 1): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=200),
(1, 2): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=500),
v_corr_idxs = {
(0, 1): _generate_corr_idxs(200),
(1, 2): _generate_corr_idxs(500),
}

i2Ri1_graph = dask.delayed(i2Ri1_dict)

# use the GTSAM API directly (without dask) for rotation averaging
i1Ti2_priors: Dict[Tuple[int, int], PosePrior] = {}
expected_wRi_list = self.obj.run_rotation_averaging(
num_poses, i2Ri1_dict, i1Ti2_priors, two_view_estimation_reports
)
expected_wRi_list = self.obj.run_rotation_averaging(num_poses, i2Ri1_dict, i1Ti2_priors, v_corr_idxs)

# use dask's computation graph
gt_wTi_list = [None] * len(expected_wRi_list)
rotations_graph, _ = self.obj.create_computation_graph(
num_poses, i2Ri1_graph, i1Ti2_priors, two_view_estimation_reports, gt_wTi_list
num_poses, i2Ri1_graph, i1Ti2_priors, gt_wTi_list, v_corr_idxs
)

with dask.config.set(scheduler="single-threaded"):
Expand Down Expand Up @@ -196,21 +177,25 @@ def test_nonconsecutive_indices(self):
}

# Keys do not overlap with i2Ri1_dict.
two_view_estimation_reports = {
(1, 2): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=200),
(1, 3): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=500),
(0, 2): TwoViewEstimationReport(v_corr_idxs=np.array([]), num_inliers_est_model=0),
v_corr_idxs = {
(1, 2): _generate_corr_idxs(200),
(1, 3): _generate_corr_idxs(500),
(0, 2): _generate_corr_idxs(0),
}

relative_pose_priors: Dict[Tuple[int, int], PosePrior] = {}
wRi_computed = self.obj.run_rotation_averaging(
num_images, i2Ri1_input, relative_pose_priors, two_view_estimation_reports
)
wRi_computed = self.obj.run_rotation_averaging(num_images, i2Ri1_input, relative_pose_priors, v_corr_idxs)
wRi_expected = [None, wTi1.rotation(), wTi2.rotation(), wTi3.rotation()]
self.assertTrue(
geometry_comparisons.compare_rotations(wRi_computed, wRi_expected, angular_error_threshold_degrees=0.1)
)

def _test_initialization(self, )


def _generate_corr_idxs(num_corrs: int) -> np.ndarray:
return np.random.randint(low=0, high=10000, size=(num_corrs, 2))


if __name__ == "__main__":
    # Allow running this test module directly (e.g. `python test_shonan.py`)
    # outside of a pytest invocation.
    unittest.main()
7 changes: 4 additions & 3 deletions tests/utils/test_rotation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Authors: Ayush Baid
"""

import unittest
from typing import Dict, List, Tuple

Expand All @@ -18,7 +19,7 @@
RELATIVE_ROTATION_DICT = Dict[Tuple[int, int], Rot3]


def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, List[float]]:
def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]:
"""Return data for a scenario with 5 camera poses, with ordering that follows their connectivity.
Accordingly, we specify i1 < i2 for all edges (i1,i2).
Expand Down Expand Up @@ -48,7 +49,7 @@ def _get_ordered_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, List[float]]
return i2Ri1_dict, wRi_list_euler_deg_expected


def _get_mixed_order_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, List[float]]:
def _get_mixed_order_chain_pose_data() -> Tuple[RELATIVE_ROTATION_DICT, np.ndarray]:
"""Return data for a scenario with 5 camera poses, with ordering that does NOT follow their connectivity.
Below, we do NOT specify i1 < i2 for all edges (i1,i2).
Expand Down Expand Up @@ -116,7 +117,7 @@ def _wrap_angles(angles: np.ndarray) -> np.ndarray:

class TestRotationUtil(unittest.TestCase):
def test_mst_initialization(self):
"""Test for 4 poses in a circle, with a pose connected all others."""
"""Test for 4 poses in a circle, with a pose connected to all others."""
i2Ri1_dict, wRi_expected = sample_poses.convert_data_for_rotation_averaging(
sample_poses.CIRCLE_ALL_EDGES_GLOBAL_POSES, sample_poses.CIRCLE_ALL_EDGES_RELATIVE_POSES
)
Expand Down

0 comments on commit beef33b

Please sign in to comment.