From d281f8efd1e52172256ecdf21e82c7547f235ef2 Mon Sep 17 00:00:00 2001
From: Roman Shapovalov
Date: Wed, 17 Aug 2022 03:47:31 -0700
Subject: [PATCH] Filtering outlier input cameras in trajectory estimation

Summary: Useful for visualising COLMAP output where some frames are not
correctly registered.

Reviewed By: bottler

Differential Revision: D38743191

fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
---
 .../tools/eval_video_trajectory.py   | 27 +++++++++++++++++++
 pytorch3d/implicitron/tools/utils.py | 22 ++++++++++++++-
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/pytorch3d/implicitron/tools/eval_video_trajectory.py b/pytorch3d/implicitron/tools/eval_video_trajectory.py
index e3d86d920..e540a3452 100644
--- a/pytorch3d/implicitron/tools/eval_video_trajectory.py
+++ b/pytorch3d/implicitron/tools/eval_video_trajectory.py
@@ -4,16 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import math
 from typing import Optional, Tuple
 
 import torch
 from pytorch3d.common.compat import eigh
+from pytorch3d.implicitron.tools import utils
 from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
 from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
 from pytorch3d.transforms import Scale
 
 
+logger = logging.getLogger(__name__)
+
+
 def generate_eval_video_cameras(
     train_cameras,
     n_eval_cams: int = 100,
@@ -27,6 +32,7 @@ def generate_eval_video_cameras(
     infer_up_as_plane_normal: bool = True,
     traj_offset: Optional[Tuple[float, float, float]] = None,
     traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
+    remove_outliers_rate: float = 0.0,
 ) -> PerspectiveCameras:
     """
     Generate a camera trajectory rendering a scene from multiple viewpoints.
@@ -50,9 +56,16 @@ def generate_eval_video_cameras(
             Active for the `trajectory_type="circular"`.
         scene_center: The center of the scene in world coordinates which all
             the cameras from the generated trajectory look at.
+        remove_outliers_rate: A fraction between 0 and 1; if > 0, outlier
+            train_cameras are removed before estimating the trajectory. The
+            filtering is based on camera center coordinates: the top and bottom
+            `remove_outliers_rate / 2` fraction along each coordinate is dropped.
     Returns:
         Dictionary of camera instances which can be used as the test dataset
     """
+    if remove_outliers_rate > 0.0:
+        train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
+
     if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
         cam_centers = train_cameras.get_camera_center()
         # get the nearest camera center to the mean of centers
@@ -167,6 +180,20 @@ def generate_eval_video_cameras(
     return test_cameras
 
 
+def _remove_outlier_cameras(
+    cameras: PerspectiveCameras, outlier_rate: float
+) -> PerspectiveCameras:
+    keep_indices = utils.get_inlier_indicators(
+        cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
+    )
+    clean_cameras = cameras[keep_indices]
+    logger.info(
+        "Filtered outlier cameras when estimating the trajectory: "
+        f"{len(cameras)} → {len(clean_cameras)}"
+    )
+    return clean_cameras
+
+
 def _disambiguate_normal(normal, up):
     up_t = torch.tensor(up).to(normal)
     flip = (up_t * normal).sum().sign()
diff --git a/pytorch3d/implicitron/tools/utils.py b/pytorch3d/implicitron/tools/utils.py
index 5e70c1c59..430c117c1 100644
--- a/pytorch3d/implicitron/tools/utils.py
+++ b/pytorch3d/implicitron/tools/utils.py
@@ -9,7 +9,7 @@
 import dataclasses
 import time
 from contextlib import contextmanager
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Iterable, Iterator
 
 import torch
 
@@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
     return type(elem)(**collated)
 
 
+def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
+    for x in it:
+        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
+            yield from recursive_visitor(x)
+        else:
+            yield x
+
+
+def get_inlier_indicators(
+    tensor: torch.Tensor, dim: int, outlier_rate: float
+) -> torch.Tensor:
+    remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
+    hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
+    lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
+    remove_indices = set(recursive_visitor([hi, lo]))
+    keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
+    keep_indices[list(remove_indices)] = False
+    return keep_indices
+
+
 class Timer:
     """
     A simple class for timing execution.
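
Not part of the patch: a minimal, self-contained sketch of how the new filtering behaves, for reviewers. The spiral camera centres, the `0.2` rate, and the identity rotations below are invented for illustration; the boolean indexing of the camera batch mirrors what `_remove_outlier_cameras` does.

    import math

    import torch
    from pytorch3d.implicitron.tools.utils import get_inlier_indicators
    from pytorch3d.renderer import PerspectiveCameras

    # Ten camera centres along a gentle spiral, plus two badly registered ones.
    t = torch.linspace(0.0, 2.0 * math.pi, 10)
    centers = torch.stack([t.cos(), t.sin(), 0.1 * t], dim=1)
    centers = torch.cat(
        [centers, torch.tensor([[50.0, 0.0, 0.0], [0.0, -40.0, 2.0]])]
    )

    # With outlier_rate=0.2, int(0.2 * 12 / 2) = 1 camera is trimmed from each
    # end of every coordinate, which drops both bad cameras (and possibly a few
    # extreme inliers).
    keep = get_inlier_indicators(centers, dim=0, outlier_rate=0.2)

    # With identity rotations, get_camera_center() is -T, so these cameras sit
    # at `centers`; indexing by the boolean mask mirrors _remove_outlier_cameras.
    cameras = PerspectiveCameras(R=torch.eye(3).repeat(12, 1, 1), T=-centers)
    clean = cameras[keep]
    print(f"{len(cameras)} -> {len(clean)} cameras kept for trajectory fitting")

Per the summary, this per-coordinate trimming is intentionally simple: it only needs to stop a handful of misregistered COLMAP frames from skewing the trajectory estimation.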