Skip to content

Commit

Permalink
Filtering outlier input cameras in trajectory estimation
Browse files Browse the repository at this point in the history
Summary: Useful for visualising colmap output where some frames are not correctly registered.

Reviewed By: bottler

Differential Revision: D38743191

fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
  • Loading branch information
shapovalov authored and facebook-github-bot committed Aug 17, 2022
1 parent b7c826b commit d281f8e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
27 changes: 27 additions & 0 deletions pytorch3d/implicitron/tools/eval_video_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from typing import Optional, Tuple

import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools import utils
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
from pytorch3d.transforms import Scale


logger = logging.getLogger(__name__)


def generate_eval_video_cameras(
train_cameras,
n_eval_cams: int = 100,
Expand All @@ -27,6 +32,7 @@ def generate_eval_video_cameras(
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
remove_outliers_rate: float = 0.0,
) -> PerspectiveCameras:
"""
Generate a camera trajectory rendering a scene from multiple viewpoints.
Expand All @@ -50,9 +56,16 @@ def generate_eval_video_cameras(
Active for the `trajectory_type="circular"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
remove_outliers_rate: the number between 0 and 1; if > 0,
some outlier train_cameras will be removed from trajectory estimation;
the filtering is based on camera center coordinates; top and
bottom `remove_outliers_rate` cameras on each dimension are removed.
Returns:
Dictionary of camera instances which can be used as the test dataset
"""
if remove_outliers_rate > 0.0:
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)

if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
cam_centers = train_cameras.get_camera_center()
# get the nearest camera center to the mean of centers
Expand Down Expand Up @@ -167,6 +180,20 @@ def generate_eval_video_cameras(
return test_cameras


def _remove_outlier_cameras(
cameras: PerspectiveCameras, outlier_rate: float
) -> PerspectiveCameras:
keep_indices = utils.get_inlier_indicators(
cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
)
clean_cameras = cameras[keep_indices]
logger.info(
"Filtered outlier cameras when estimating the trajectory: "
f"{len(cameras)}{len(clean_cameras)}"
)
return clean_cameras


def _disambiguate_normal(normal, up):
up_t = torch.tensor(up).to(normal)
flip = (up_t * normal).sum().sign()
Expand Down
22 changes: 21 additions & 1 deletion pytorch3d/implicitron/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import dataclasses
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Iterable, Iterator

import torch

Expand Down Expand Up @@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
return type(elem)(**collated)


def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
for x in it:
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
yield from recursive_visitor(x)
else:
yield x


def get_inlier_indicators(
tensor: torch.Tensor, dim: int, outlier_rate: float
) -> torch.Tensor:
remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
remove_indices = set(recursive_visitor([hi, lo]))
keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
keep_indices[list(remove_indices)] = False
return keep_indices


class Timer:
"""
A simple class for timing execution.
Expand Down

0 comments on commit d281f8e

Please sign in to comment.