Skip to content

Commit

Permalink
Working
Browse files Browse the repository at this point in the history
  • Loading branch information
travisdriver committed Apr 11, 2023
1 parent dda4af0 commit 4ec82dd
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions gtsfm/data_association/data_assoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __validate_track(self, sfm_track: Optional[SfmTrack]) -> bool:
"""Validate the track by checking its length."""
return sfm_track is not None and sfm_track.numberMeasurements() >= self.min_track_len

def run(
def run_da(
self,
num_images: int,
cameras: Dict[int, gtsfm_types.CAMERA_TYPE],
Expand All @@ -80,12 +80,15 @@ def run(
relative_pose_priors: Dict[Tuple[int, int], Optional[PosePrior]],
images: Optional[List[Image]] = None,
) -> Tuple[GtsfmData, GtsfmMetricsGroup]:
"""Perform the data association.
"""Perform the data association and compute metrics.
Args:
num_images: Number of images in the scene.
cameras: dictionary, with image index -> camera mapping.
tracks_2d: list of 2D tracks.
sfm_tracks: List of triangulated tracks.
avg_track_repoj_errors: List of average reprojection errors per track.
triangulation_exit_codes: exit codes for each triangulation call.
cameras_gt: list of GT cameras, to be used for benchmarking the tracks.
images: a list of all images in scene (optional and only for track patch visualization)
Expand Down Expand Up @@ -204,21 +207,38 @@ def run_triangulation(
cameras: Dict[int, gtsfm_types.CAMERA_TYPE],
tracks_2d: List[SfmTrack2d],
) -> Tuple[List[Delayed], List[Delayed], List[Delayed]]:
"""Performs triangulation of 2D tracks in parallel.
Ref: https://docs.dask.org/en/stable/delayed-best-practices.html#compute-on-lots-of-computation-at-once
Args:
cameras: list of cameras wrapped up as Delayed.
tracks_2d: list of tracks wrapped up as Delayed.
Returns:
sfm_tracks: List of triangulated tracks.
avg_track_repoj_errors: List of average reprojection errors per track.
triangulation_exit_codes: exit codes for each triangulation call.
"""
# Initialize 3D landmark for each track
point3d_initializer = Point3dInitializer(cameras, self.triangulation_options)

# Loop through tracks and triangulate.
delayed_sfm_tracks, delayed_avg_track_reproj_errors, delayed_triangulation_exit_codes = [], [], []
# Loop through tracks and and generate delayed triangulation tasks.
triangulation_results = []
for track_2d in tracks_2d:
# triangulate and filter based on reprojection error
(delayed_sfm_track, delayed_avg_track_reproj_error, delayed_triangulation_exit_code) = dask.delayed(
point3d_initializer.triangulate, nout=3
)(track_2d)
delayed_sfm_tracks.append(delayed_sfm_track)
delayed_avg_track_reproj_errors.append(delayed_avg_track_reproj_error)
delayed_triangulation_exit_codes.append(delayed_triangulation_exit_code)
triangulation_results.append(dask.delayed(point3d_initializer.triangulate)(track_2d))

# Perform triangulation in parallel.
triangulation_results = dask.compute(*triangulation_results)

# Unpack results.
sfm_tracks, avg_track_reproj_errors, triangulation_exit_codes = [], [], []
for result in triangulation_results:
sfm_tracks.append(result[0])
avg_track_reproj_errors.append(result[1])
triangulation_exit_codes.append(result[2])

return delayed_sfm_tracks, delayed_avg_track_reproj_errors, delayed_triangulation_exit_codes
return sfm_tracks, avg_track_reproj_errors, triangulation_exit_codes

def create_computation_graph(
self,
Expand Down Expand Up @@ -246,12 +266,12 @@ def create_computation_graph(
"""

# Triangulate 2D tracks.
sfm_tracks, avg_track_reproj_errors, triangulation_exit_codes = self.run_triangulation(
cameras=cameras.compute(), tracks_2d=tracks_2d.compute()
sfm_tracks, avg_track_reproj_errors, triangulation_exit_codes = dask.delayed(self.run_triangulation, nout=3)(
cameras=cameras, tracks_2d=tracks_2d
)

# Unpack results, create BA input, and compute metrics.
ba_input_graph, data_assoc_metrics_graph = dask.delayed(self.run, nout=2)(
# Perform DA, create BA input, and compute metrics.
ba_input_graph, data_assoc_metrics_graph = dask.delayed(self.run_da, nout=2)(
num_images=num_images,
cameras=cameras,
tracks_2d=tracks_2d,
Expand Down

0 comments on commit 4ec82dd

Please sign in to comment.