From ebdbfde0cee9d6adca2c0508f1a664c13d3cd65a Mon Sep 17 00:00:00 2001 From: Ildar Salakhiev Date: Tue, 4 Apr 2023 07:17:43 -0700 Subject: [PATCH] Extract BlobLoader class from JsonIndexDataset and moving crop_by_bbox to FrameData Summary: extracted blob loader added documentation for blob_loader did some refactoring on fields for detailed steps and discussions see: https://github.com/facebookresearch/pytorch3d/pull/1463 https://github.com/fairinternal/pixar_replay/pull/160 Reviewed By: bottler Differential Revision: D44061728 fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51 --- .../dataset/data_loader_map_provider.py | 3 +- pytorch3d/implicitron/dataset/dataset_base.py | 198 +---- pytorch3d/implicitron/dataset/frame_data.py | 728 ++++++++++++++++++ .../implicitron/dataset/json_index_dataset.py | 538 ++----------- .../dataset/single_sequence_dataset.py | 6 +- pytorch3d/implicitron/dataset/types.py | 2 + pytorch3d/implicitron/dataset/utils.py | 313 +++++++- pytorch3d/implicitron/dataset/visualize.py | 2 +- .../evaluation/evaluate_new_view_synthesis.py | 2 +- tests/implicitron/test_batch_sampler.py | 3 +- tests/implicitron/test_bbox.py | 86 ++- tests/implicitron/test_data_cow.py | 2 +- tests/implicitron/test_evaluation.py | 6 +- tests/implicitron/test_frame_data_builder.py | 224 ++++++ .../test_json_index_dataset_provider_v2.py | 2 +- 15 files changed, 1421 insertions(+), 694 deletions(-) create mode 100644 pytorch3d/implicitron/dataset/frame_data.py create mode 100644 tests/implicitron/test_frame_data_builder.py diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py index 50a792183..6c0436adf 100644 --- a/pytorch3d/implicitron/dataset/data_loader_map_provider.py +++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py @@ -18,8 +18,9 @@ Sampler, ) -from .dataset_base import DatasetBase, FrameData +from .dataset_base import DatasetBase from .dataset_map_provider import DatasetMap +from .frame_data import FrameData from .scene_batch_sampler import SceneBatchSampler from .utils import is_known_frame_scalar diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 283ef3dcd..033b170c0 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -5,217 +5,27 @@ # LICENSE file in the root directory of this source tree. from collections import defaultdict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass from typing import ( - Any, ClassVar, Dict, Iterable, Iterator, List, - Mapping, Optional, Sequence, Tuple, Type, - Union, ) -import numpy as np import torch -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds - -@dataclass -class FrameData(Mapping[str, Any]): - """ - A type of the elements returned by indexing the dataset object. - It can represent both individual frames and batches of thereof; - in this documentation, the sizes of tensors refer to single frames; - add the first batch dimension for the collation result. - - Args: - frame_number: The number of the frame within its sequence. - 0-based continuous integers. - sequence_name: The unique name of the frame's sequence. - sequence_category: The object category of the sequence. 
- frame_timestamp: The time elapsed since the start of a sequence in sec. - image_size_hw: The size of the image in pixels; (height, width) tensor - of shape (2,). - image_path: The qualified path to the loaded image (with dataset_root). - image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image - of the frame; elements are floats in [0, 1]. - mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image - regions. Regions can be invalid (mask_crop[i,j]=0) in case they - are a result of zero-padding of the image after cropping around - the object bounding box; elements are floats in {0.0, 1.0}. - depth_path: The qualified path to the frame's depth map. - depth_map: A float Tensor of shape `(1, H, W)` holding the depth map - of the frame; values correspond to distances from the camera; - use `depth_mask` and `mask_crop` to filter for valid pixels. - depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the - depth map that are valid for evaluation, they have been checked for - consistency across views; elements are floats in {0.0, 1.0}. - mask_path: A qualified path to the foreground probability mask. - fg_probability: A Tensor of `(1, H, W)` denoting the probability of the - pixels belonging to the captured object; elements are floats - in [0, 1]. - bbox_xywh: The bounding box tightly enclosing the foreground object in the - format (x0, y0, width, height). The convention assumes that - `x0+width` and `y0+height` includes the boundary of the box. - I.e., to slice out the corresponding crop from an image tensor `I` - we execute `crop = I[..., y0:y0+height, x0:x0+width]` - crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` - in the original image coordinates in the format (x0, y0, width, height). - The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs - from `bbox_xywh` due to padding (which can happen e.g. due to - setting `JsonIndexDataset.box_crop_context > 0`) - camera: A PyTorch3D camera object corresponding the frame's viewpoint, - corrected for cropping if it happened. - camera_quality_score: The score proportional to the confidence of the - frame's camera estimation (the higher the more accurate). - point_cloud_quality_score: The score proportional to the accuracy of the - frame's sequence point cloud (the higher the more accurate). - sequence_point_cloud_path: The path to the sequence's point cloud. - sequence_point_cloud: A PyTorch3D Pointclouds object holding the - point cloud corresponding to the frame's sequence. When the object - represents a batch of frames, point clouds may be deduplicated; - see `sequence_point_cloud_idx`. - sequence_point_cloud_idx: Integer indices mapping frame indices to the - corresponding point clouds in `sequence_point_cloud`; to get the - corresponding point cloud to `image_rgb[i]`, use - `sequence_point_cloud[sequence_point_cloud_idx[i]]`. - frame_type: The type of the loaded frame specified in - `subset_lists_file`, if provided. - meta: A dict for storing additional frame information. 
- """ - - frame_number: Optional[torch.LongTensor] - sequence_name: Union[str, List[str]] - sequence_category: Union[str, List[str]] - frame_timestamp: Optional[torch.Tensor] = None - image_size_hw: Optional[torch.Tensor] = None - image_path: Union[str, List[str], None] = None - image_rgb: Optional[torch.Tensor] = None - # masks out padding added due to cropping the square bit - mask_crop: Optional[torch.Tensor] = None - depth_path: Union[str, List[str], None] = None - depth_map: Optional[torch.Tensor] = None - depth_mask: Optional[torch.Tensor] = None - mask_path: Union[str, List[str], None] = None - fg_probability: Optional[torch.Tensor] = None - bbox_xywh: Optional[torch.Tensor] = None - crop_bbox_xywh: Optional[torch.Tensor] = None - camera: Optional[PerspectiveCameras] = None - camera_quality_score: Optional[torch.Tensor] = None - point_cloud_quality_score: Optional[torch.Tensor] = None - sequence_point_cloud_path: Union[str, List[str], None] = None - sequence_point_cloud: Optional[Pointclouds] = None - sequence_point_cloud_idx: Optional[torch.Tensor] = None - frame_type: Union[str, List[str], None] = None # known | unseen - meta: dict = field(default_factory=lambda: {}) - - def to(self, *args, **kwargs): - new_params = {} - for f in fields(self): - value = getattr(self, f.name) - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): - new_params[f.name] = value.to(*args, **kwargs) - else: - new_params[f.name] = value - return type(self)(**new_params) - - def cpu(self): - return self.to(device=torch.device("cpu")) - - def cuda(self): - return self.to(device=torch.device("cuda")) - - # the following functions make sure **frame_data can be passed to functions - def __iter__(self): - for f in fields(self): - yield f.name - - def __getitem__(self, key): - return getattr(self, key) - - def __len__(self): - return len(fields(self)) - - @classmethod - def collate(cls, batch): - """ - Given a list objects `batch` of class `cls`, collates them into a batched - representation suitable for processing with deep networks. - """ - - elem = batch[0] - - if isinstance(elem, cls): - pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] - id_to_idx = defaultdict(list) - for i, pc_id in enumerate(pointcloud_ids): - id_to_idx[pc_id].append(i) - - sequence_point_cloud = [] - sequence_point_cloud_idx = -np.ones((len(batch),)) - for i, ind in enumerate(id_to_idx.values()): - sequence_point_cloud_idx[ind] = i - sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) - assert (sequence_point_cloud_idx >= 0).all() - - override_fields = { - "sequence_point_cloud": sequence_point_cloud, - "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), - } - # note that the pre-collate value of sequence_point_cloud_idx is unused - - collated = {} - for f in fields(elem): - list_values = override_fields.get( - f.name, [getattr(d, f.name) for d in batch] - ) - collated[f.name] = ( - cls.collate(list_values) - if all(list_value is not None for list_value in list_values) - else None - ) - return cls(**collated) - - elif isinstance(elem, Pointclouds): - return join_pointclouds_as_batch(batch) - - elif isinstance(elem, CamerasBase): - # TODO: don't store K; enforce working in NDC space - return join_cameras_as_batch(batch) - else: - return torch.utils.data._utils.collate.default_collate(batch) - - -class _GenericWorkaround: - """ - OmegaConf.structured has a weirdness when you try to apply - it to a dataclass whose first base class is a Generic which is not - Dict. 
The issue is with a function called get_dict_key_value_types - in omegaconf/_utils.py. - For example this fails: - - @dataclass(eq=False) - class D(torch.utils.data.Dataset[int]): - a: int = 3 - - OmegaConf.structured(D) - - We avoid the problem by adding this class as an extra base class. - """ - - pass +from pytorch3d.implicitron.dataset.frame_data import FrameData +from pytorch3d.implicitron.dataset.utils import GenericWorkaround @dataclass(eq=False) -class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): +class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]): """ Base class to describe a dataset to be used with Implicitron. diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py new file mode 100644 index 000000000..1a4e1b5c6 --- /dev/null +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -0,0 +1,728 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field, fields +from typing import ( + Any, + ClassVar, + Generic, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import numpy as np +import torch + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.utils import ( + adjust_camera_to_bbox_crop_, + adjust_camera_to_image_scale_, + bbox_xyxy_to_xywh, + clamp_box_to_image_bounds_and_round, + crop_around_box, + GenericWorkaround, + get_bbox_from_mask, + get_clamp_bbox, + load_depth, + load_depth_mask, + load_image, + load_mask, + load_pointcloud, + rescale_bbox, + resize_image, + safe_as_tensor, +) +from pytorch3d.implicitron.tools.config import registry, ReplaceableBase +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds + + +@dataclass +class FrameData(Mapping[str, Any]): + """ + A type of the elements returned by indexing the dataset object. + It can represent both individual frames and batches of thereof; + in this documentation, the sizes of tensors refer to single frames; + add the first batch dimension for the collation result. + + Args: + frame_number: The number of the frame within its sequence. + 0-based continuous integers. + sequence_name: The unique name of the frame's sequence. + sequence_category: The object category of the sequence. + frame_timestamp: The time elapsed since the start of a sequence in sec. + image_size_hw: The size of the original image in pixels; (height, width) + tensor of shape (2,). Note that it is optional, e.g. it can be `None` + if the frame annotation has no size ans image_rgb has not [yet] been + loaded. Image-less FrameData is valid but mutators like crop/resize + may fail if the original image size cannot be deduced. + effective_image_size_hw: The size of the image after mutations such as + crop/resize in pixels; (height, width). if the image has not been mutated, + it is equal to `image_size_hw`. Note that it is also optional, for the + same reason as `image_size_hw`. + image_path: The qualified path to the loaded image (with dataset_root). + image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image + of the frame; elements are floats in [0, 1]. 
+ mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image + regions. Regions can be invalid (mask_crop[i,j]=0) in case they + are a result of zero-padding of the image after cropping around + the object bounding box; elements are floats in {0.0, 1.0}. + depth_path: The qualified path to the frame's depth map. + depth_map: A float Tensor of shape `(1, H, W)` holding the depth map + of the frame; values correspond to distances from the camera; + use `depth_mask` and `mask_crop` to filter for valid pixels. + depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the + depth map that are valid for evaluation, they have been checked for + consistency across views; elements are floats in {0.0, 1.0}. + mask_path: A qualified path to the foreground probability mask. + fg_probability: A Tensor of `(1, H, W)` denoting the probability of the + pixels belonging to the captured object; elements are floats + in [0, 1]. + bbox_xywh: The bounding box tightly enclosing the foreground object in the + format (x0, y0, width, height). The convention assumes that + `x0+width` and `y0+height` includes the boundary of the box. + I.e., to slice out the corresponding crop from an image tensor `I` + we execute `crop = I[..., y0:y0+height, x0:x0+width]` + crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` + in the original image coordinates in the format (x0, y0, width, height). + The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs + from `bbox_xywh` due to padding (which can happen e.g. due to + setting `JsonIndexDataset.box_crop_context > 0`) + camera: A PyTorch3D camera object corresponding the frame's viewpoint, + corrected for cropping if it happened. + camera_quality_score: The score proportional to the confidence of the + frame's camera estimation (the higher the more accurate). + point_cloud_quality_score: The score proportional to the accuracy of the + frame's sequence point cloud (the higher the more accurate). + sequence_point_cloud_path: The path to the sequence's point cloud. + sequence_point_cloud: A PyTorch3D Pointclouds object holding the + point cloud corresponding to the frame's sequence. When the object + represents a batch of frames, point clouds may be deduplicated; + see `sequence_point_cloud_idx`. + sequence_point_cloud_idx: Integer indices mapping frame indices to the + corresponding point clouds in `sequence_point_cloud`; to get the + corresponding point cloud to `image_rgb[i]`, use + `sequence_point_cloud[sequence_point_cloud_idx[i]]`. + frame_type: The type of the loaded frame specified in + `subset_lists_file`, if provided. + meta: A dict for storing additional frame information. 
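        For illustration, the batch conventions above can be used as in the sketch
        below (it assumes `dataset` is some dataset yielding FrameData objects;
        the variable names are illustrative, not part of the API):

            batch = FrameData.collate([dataset[0], dataset[1]])
            # point clouds (when loaded) may be deduplicated across the batch;
            # recover the one matching batch.image_rgb[i] via sequence_point_cloud_idx
            i = 0
            pcl = batch.sequence_point_cloud[int(batch.sequence_point_cloud_idx[i])]
            # FrameData is a Mapping, so **batch can be passed to functions
            as_dict = dict(**batch)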
+ """ + + frame_number: Optional[torch.LongTensor] + sequence_name: Union[str, List[str]] + sequence_category: Union[str, List[str]] + frame_timestamp: Optional[torch.Tensor] = None + image_size_hw: Optional[torch.LongTensor] = None + effective_image_size_hw: Optional[torch.LongTensor] = None + image_path: Union[str, List[str], None] = None + image_rgb: Optional[torch.Tensor] = None + # masks out padding added due to cropping the square bit + mask_crop: Optional[torch.Tensor] = None + depth_path: Union[str, List[str], None] = None + depth_map: Optional[torch.Tensor] = None + depth_mask: Optional[torch.Tensor] = None + mask_path: Union[str, List[str], None] = None + fg_probability: Optional[torch.Tensor] = None + bbox_xywh: Optional[torch.Tensor] = None + crop_bbox_xywh: Optional[torch.Tensor] = None + camera: Optional[PerspectiveCameras] = None + camera_quality_score: Optional[torch.Tensor] = None + point_cloud_quality_score: Optional[torch.Tensor] = None + sequence_point_cloud_path: Union[str, List[str], None] = None + sequence_point_cloud: Optional[Pointclouds] = None + sequence_point_cloud_idx: Optional[torch.Tensor] = None + frame_type: Union[str, List[str], None] = None # known | unseen + meta: dict = field(default_factory=lambda: {}) + + # NOTE that batching resets this attribute + _uncropped: bool = field(init=False, default=True) + + def to(self, *args, **kwargs): + new_params = {} + for field_name in iter(self): + value = getattr(self, field_name) + if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): + new_params[field_name] = value.to(*args, **kwargs) + else: + new_params[field_name] = value + frame_data = type(self)(**new_params) + frame_data._uncropped = self._uncropped + return frame_data + + def cpu(self): + return self.to(device=torch.device("cpu")) + + def cuda(self): + return self.to(device=torch.device("cuda")) + + # the following functions make sure **frame_data can be passed to functions + def __iter__(self): + for f in fields(self): + if f.name.startswith("_"): + continue + + yield f.name + + def __getitem__(self, key): + return getattr(self, key) + + def __len__(self): + return sum(1 for f in iter(self)) + + def crop_by_metadata_bbox_( + self, + box_crop_context: float, + ) -> None: + """Crops the frame data in-place by (possibly expanded) bounding box. + The bounding box is taken from the object state (usually taken from + the frame annotation or estimated from the foregroubnd mask). + If the expanded bounding box does not fit the image, it is clamped, + i.e. the image is *not* padded. + + Args: + box_crop_context: rate of expansion for bbox; 0 means no expansion, + + Raises: + ValueError: If the object does not contain a bounding box (usually when no + mask annotation is provided) + ValueError: If the frame data have been cropped or resized, thus the intrinsic + bounding box is not valid for the current image size. + ValueError: If the frame does not have an image size (usually a corner case + when no image has been loaded) + """ + if self.bbox_xywh is None: + raise ValueError("Attempted cropping by metadata with empty bounding box") + + if not self._uncropped: + raise ValueError( + "Trying to apply the metadata bounding box to already cropped " + "or resized image; coordinates have changed." + ) + + self._crop_by_bbox_( + box_crop_context, + self.bbox_xywh, + ) + + def crop_by_given_bbox_( + self, + box_crop_context: float, + bbox_xywh: torch.Tensor, + ) -> None: + """Crops the frame data in-place by (possibly expanded) bounding box. 
+ If the expanded bounding box does not fit the image, it is clamped, + i.e. the image is *not* padded. + + Args: + box_crop_context: rate of expansion for bbox; 0 means no expansion, + bbox_xywh: bounding box in [x0, y0, width, height] format. If float + tensor, values are floored (after converting to [x0, y0, x1, y1]). + + Raises: + ValueError: If the frame does not have an image size (usually a corner case + when no image has been loaded) + """ + self._crop_by_bbox_( + box_crop_context, + bbox_xywh, + ) + + def _crop_by_bbox_( + self, + box_crop_context: float, + bbox_xywh: torch.Tensor, + ) -> None: + """Crops the frame data in-place by (possibly expanded) bounding box. + If the expanded bounding box does not fit the image, it is clamped, + i.e. the image is *not* padded. + + Args: + box_crop_context: rate of expansion for bbox; 0 means no expansion, + bbox_xywh: bounding box in [x0, y0, width, height] format. If float + tensor, values are floored (after converting to [x0, y0, x1, y1]). + + Raises: + ValueError: If the frame does not have an image size (usually a corner case + when no image has been loaded) + """ + effective_image_size_hw = self.effective_image_size_hw + if effective_image_size_hw is None: + raise ValueError("Calling crop on image-less FrameData") + + bbox_xyxy = get_clamp_bbox( + bbox_xywh, + image_path=self.image_path, # pyre-ignore + box_crop_context=box_crop_context, + ) + clamp_bbox_xyxy = clamp_box_to_image_bounds_and_round( + bbox_xyxy, + image_size_hw=tuple(self.effective_image_size_hw), # pyre-ignore + ) + crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy) + + if self.fg_probability is not None: + self.fg_probability = crop_around_box( + self.fg_probability, + clamp_bbox_xyxy, + self.mask_path, # pyre-ignore + ) + if self.image_rgb is not None: + self.image_rgb = crop_around_box( + self.image_rgb, + clamp_bbox_xyxy, + self.image_path, # pyre-ignore + ) + + depth_map = self.depth_map + if depth_map is not None: + clamp_bbox_xyxy_depth = rescale_bbox( + clamp_bbox_xyxy, tuple(depth_map.shape[-2:]), effective_image_size_hw + ).long() + self.depth_map = crop_around_box( + depth_map, + clamp_bbox_xyxy_depth, + self.depth_path, # pyre-ignore + ) + + depth_mask = self.depth_mask + if depth_mask is not None: + clamp_bbox_xyxy_depth = rescale_bbox( + clamp_bbox_xyxy, tuple(depth_mask.shape[-2:]), effective_image_size_hw + ).long() + self.depth_mask = crop_around_box( + depth_mask, + clamp_bbox_xyxy_depth, + self.mask_path, # pyre-ignore + ) + + # changing principal_point according to bbox_crop + if self.camera is not None: + adjust_camera_to_bbox_crop_( + camera=self.camera, + image_size_wh=effective_image_size_hw.flip(dims=[-1]), + clamp_bbox_xywh=crop_bbox_xywh, + ) + + # pyre-ignore + self.effective_image_size_hw = crop_bbox_xywh[..., 2:].flip(dims=[-1]) + self._uncropped = False + + def resize_frame_(self, new_size_hw: torch.LongTensor) -> None: + """Resizes frame data in-place according to given dimensions. 
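        The aspect ratio is preserved: the content is rescaled to fit within
        `new_size_hw` and pasted into the top-left corner, with zero padding
        elsewhere (reflected in `mask_crop` when an image is present).
        An illustrative call (the target size is arbitrary):

            frame_data.resize_frame_(torch.tensor([800, 800], dtype=torch.long))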
+ + Args: + new_size_hw: target image size [height, width], a LongTensor of shape (2,) + + Raises: + ValueError: If the frame does not have an image size (usually a corner case + when no image has been loaded) + """ + + effective_image_size_hw = self.effective_image_size_hw + if effective_image_size_hw is None: + raise ValueError("Calling resize on image-less FrameData") + + image_height, image_width = new_size_hw.tolist() + + if self.fg_probability is not None: + self.fg_probability, _, _ = resize_image( + self.fg_probability, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.image_rgb is not None: + self.image_rgb, _, self.mask_crop = resize_image( + self.image_rgb, image_height=image_height, image_width=image_width + ) + + if self.depth_map is not None: + self.depth_map, _, _ = resize_image( + self.depth_map, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.depth_mask is not None: + self.depth_mask, _, _ = resize_image( + self.depth_mask, + image_height=image_height, + image_width=image_width, + mode="nearest", + ) + + if self.camera is not None: + if self.image_size_hw is None: + raise ValueError( + "image_size_hw has to be defined for resizing FrameData with cameras." + ) + adjust_camera_to_image_scale_( + camera=self.camera, + original_size_wh=effective_image_size_hw.flip(dims=[-1]), + new_size_wh=new_size_hw.flip(dims=[-1]), # pyre-ignore + ) + + self.effective_image_size_hw = new_size_hw + self._uncropped = False + + @classmethod + def collate(cls, batch): + """ + Given a list objects `batch` of class `cls`, collates them into a batched + representation suitable for processing with deep networks. + """ + + elem = batch[0] + + if isinstance(elem, cls): + pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] + id_to_idx = defaultdict(list) + for i, pc_id in enumerate(pointcloud_ids): + id_to_idx[pc_id].append(i) + + sequence_point_cloud = [] + sequence_point_cloud_idx = -np.ones((len(batch),)) + for i, ind in enumerate(id_to_idx.values()): + sequence_point_cloud_idx[ind] = i + sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) + assert (sequence_point_cloud_idx >= 0).all() + + override_fields = { + "sequence_point_cloud": sequence_point_cloud, + "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), + } + # note that the pre-collate value of sequence_point_cloud_idx is unused + + collated = {} + for f in fields(elem): + if not f.init: + continue + + list_values = override_fields.get( + f.name, [getattr(d, f.name) for d in batch] + ) + collated[f.name] = ( + cls.collate(list_values) + if all(list_value is not None for list_value in list_values) + else None + ) + return cls(**collated) + + elif isinstance(elem, Pointclouds): + return join_pointclouds_as_batch(batch) + + elif isinstance(elem, CamerasBase): + # TODO: don't store K; enforce working in NDC space + return join_cameras_as_batch(batch) + else: + return torch.utils.data._utils.collate.default_collate(batch) + + +FrameDataSubtype = TypeVar("FrameDataSubtype", bound=FrameData) + + +class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC): + """A base class for FrameDataBuilders that build a FrameData object, load and + process the binary data (crop and resize). Implementations should parametrize + the class with a subtype of FrameData and set frame_data_type class variable to + that type. They have to also implement `build` method. 
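    For illustration, a hypothetical extension could look roughly like the sketch
    below (an extra tensor field and a builder bound to the new type;
    `safe_as_tensor` is the helper imported at the top of this module):

        @dataclass
        class ExtendedFrameData(FrameData):
            extra_signal: Optional[torch.Tensor] = None

        class ExtendedFrameDataBuilder(FrameDataBuilderBase[ExtendedFrameData]):
            frame_data_type: ClassVar[Type[ExtendedFrameData]] = ExtendedFrameData

            def build(self, frame_annotation, sequence_annotation):
                # minimal sketch: only the required FrameData fields are filled
                return ExtendedFrameData(
                    frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long),
                    sequence_name=frame_annotation.sequence_name,
                    sequence_category=sequence_annotation.category,
                )

    Registering the builder with `@registry.register` may additionally be needed
    to use it through the config system.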
+ """ + + # To be initialised to FrameDataSubtype + frame_data_type: ClassVar[Type[FrameDataSubtype]] + + @abstractmethod + def build( + self, + frame_annotation: types.FrameAnnotation, + sequence_annotation: types.SequenceAnnotation, + ) -> FrameDataSubtype: + """An abstract method to build the frame data based on raw frame/sequence + annotations, load the binary data and adjust them according to the metadata. + """ + raise NotImplementedError() + + +class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): + """ + A class to build a FrameData object, load and process the binary data (crop and + resize). This is an abstract class for extending to build FrameData subtypes. Most + users need to use concrete `FrameDataBuilder` class instead. + Beware that modifications of frame data are done in-place. + + Args: + dataset_root: The root folder of the dataset; all the paths in jsons are + specified relative to this root (but not json paths themselves). + load_images: Enable loading the frame RGB data. + load_depths: Enable loading the frame depth maps. + load_depth_masks: Enable loading the frame depth map masks denoting the + depth values used for evaluation (the points consistent across views). + load_masks: Enable loading frame foreground masks. + load_point_clouds: Enable loading sequence-level point clouds. + max_points: Cap on the number of loaded points in the point cloud; + if reached, they are randomly sampled without replacement. + mask_images: Whether to mask the images with the loaded foreground masks; + 0 value is used for background. + mask_depths: Whether to mask the depth maps with the loaded foreground + masks; 0 value is used for background. + image_height: The height of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + image_width: The width of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + box_crop: Enable cropping of the image around the bounding box inferred + from the foreground region of the loaded segmentation mask; masks + and depth maps are cropped accordingly; cameras are corrected. + box_crop_mask_thr: The threshold used to separate pixels into foreground + and background based on the foreground_probability mask; if no value + is greater than this threshold, the loader lowers it and repeats. + box_crop_context: The amount of additional padding added to each + dimension of the cropping bounding box, relative to box size. + path_manager: Optionally a PathManager for interpreting paths in a special way. + """ + + dataset_root: str = "" + load_images: bool = True + load_depths: bool = True + load_depth_masks: bool = True + load_masks: bool = True + load_point_clouds: bool = False + max_points: int = 0 + mask_images: bool = False + mask_depths: bool = False + image_height: Optional[int] = 800 + image_width: Optional[int] = 800 + box_crop: bool = True + box_crop_mask_thr: float = 0.4 + box_crop_context: float = 0.3 + path_manager: Any = None + + def build( + self, + frame_annotation: types.FrameAnnotation, + sequence_annotation: types.SequenceAnnotation, + load_blobs: bool = True, + ) -> FrameDataSubtype: + """Builds the frame data based on raw frame/sequence annotations, loads the + binary data and adjust them according to the metadata. 
The processing includes: + * if box_crop is set, the image/mask/depth are cropped with the bounding + box provided or estimated from MaskAnnotation, + * if image_height/image_width are set, the image/mask/depth are resized to + fit that resolution. Note that the aspect ratio is preserved, and the + (possibly cropped) image is pasted into the top-left corner. In the + resulting frame_data, mask_crop field corresponds to the mask of the + pasted image. + + Args: + frame_annotation: frame annotation + sequence_annotation: sequence annotation + load_blobs: if the function should attempt loading the image, depth map + and mask, and foreground mask + + Returns: + The constructed FrameData object. + """ + + point_cloud = sequence_annotation.point_cloud + + frame_data = self.frame_data_type( + frame_number=safe_as_tensor(frame_annotation.frame_number, torch.long), + frame_timestamp=safe_as_tensor( + frame_annotation.frame_timestamp, torch.float + ), + sequence_name=frame_annotation.sequence_name, + sequence_category=sequence_annotation.category, + camera_quality_score=safe_as_tensor( + sequence_annotation.viewpoint_quality_score, torch.float + ), + point_cloud_quality_score=safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + + if load_blobs and self.load_masks and frame_annotation.mask is not None: + ( + frame_data.fg_probability, + frame_data.mask_path, + frame_data.bbox_xywh, + ) = self._load_fg_probability(frame_annotation) + + if frame_annotation.image is not None: + image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long) + frame_data.image_size_hw = image_size_hw # original image size + # image size after crop/resize + frame_data.effective_image_size_hw = image_size_hw + + if load_blobs and self.load_images: + ( + frame_data.image_rgb, + frame_data.image_path, + ) = self._load_images(frame_annotation, frame_data.fg_probability) + + if load_blobs and self.load_depths and frame_annotation.depth is not None: + ( + frame_data.depth_map, + frame_data.depth_path, + frame_data.depth_mask, + ) = self._load_mask_depth(frame_annotation, frame_data.fg_probability) + + if load_blobs and self.load_point_clouds and point_cloud is not None: + pcl_path = self._fix_point_cloud_path(point_cloud.path) + frame_data.sequence_point_cloud = load_pointcloud( + self._local_path(pcl_path), max_points=self.max_points + ) + frame_data.sequence_point_cloud_path = pcl_path + + if frame_annotation.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera(frame_annotation) + + if self.box_crop: + frame_data.crop_by_metadata_bbox_(self.box_crop_context) + + if self.image_height is not None and self.image_width is not None: + new_size = (self.image_height, self.image_width) + frame_data.resize_frame_( + new_size_hw=torch.tensor(new_size, dtype=torch.long), # pyre-ignore + ) + + return frame_data + + def _load_fg_probability( + self, entry: types.FrameAnnotation + ) -> Tuple[Optional[torch.Tensor], Optional[str], Optional[torch.Tensor]]: + + full_path = os.path.join(self.dataset_root, entry.mask.path) # pyre-ignore + fg_probability = load_mask(self._local_path(full_path)) + # we can use provided bbox_xywh or calculate it based on mask + # saves time to skip bbox calculation + # pyre-ignore + bbox_xywh = entry.mask.bounding_box_xywh or get_bbox_from_mask( + fg_probability, self.box_crop_mask_thr + ) + if fg_probability.shape[-2:] != entry.image.size: + raise ValueError( + f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" 
+ ) + return ( + safe_as_tensor(fg_probability, torch.float), + full_path, + safe_as_tensor(bbox_xywh, torch.long), + ) + + def _load_images( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str]: + assert self.dataset_root is not None and entry.image is not None + path = os.path.join(self.dataset_root, entry.image.path) + image_rgb = load_image(self._local_path(path)) + + if image_rgb.shape[-2:] != entry.image.size: + raise ValueError( + f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" + ) + + if self.mask_images: + assert fg_probability is not None + image_rgb *= fg_probability + + return image_rgb, path + + def _load_mask_depth( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor]: + entry_depth = entry.depth + assert entry_depth is not None + path = os.path.join(self.dataset_root, entry_depth.path) + depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment) + + if self.mask_depths: + assert fg_probability is not None + depth_map *= fg_probability + + if self.load_depth_masks: + assert entry_depth.mask_path is not None + mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + depth_mask = load_depth_mask(self._local_path(mask_path)) + else: + depth_mask = torch.ones_like(depth_map) + + return torch.tensor(depth_map), path, torch.tensor(depth_mask) + + def _get_pytorch3d_camera( + self, + entry: types.FrameAnnotation, + ) -> PerspectiveCameras: + entry_viewpoint = entry.viewpoint + assert entry_viewpoint is not None + # principal point and focal length + principal_point = torch.tensor( + entry_viewpoint.principal_point, dtype=torch.float + ) + focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) + + format = entry_viewpoint.intrinsics_format + if entry_viewpoint.intrinsics_format == "ndc_norm_image_bounds": + # legacy PyTorch3D NDC format + # convert to pixels unequally and convert to ndc equally + image_size_as_list = list(reversed(entry.image.size)) + image_size_wh = torch.tensor(image_size_as_list, dtype=torch.float) + per_axis_scale = image_size_wh / image_size_wh.min() + focal_length = focal_length * per_axis_scale + principal_point = principal_point * per_axis_scale + elif entry_viewpoint.intrinsics_format != "ndc_isotropic": + raise ValueError(f"Unknown intrinsics format: {format}") + + return PerspectiveCameras( + focal_length=focal_length[None], + principal_point=principal_point[None], + R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], + T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], + ) + + def _fix_point_cloud_path(self, path: str) -> str: + """ + Fix up a point cloud path from the dataset. + Some files in Co3Dv2 have an accidental absolute path stored. + """ + unwanted_prefix = ( + "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" + ) + if path.startswith(unwanted_prefix): + path = path[len(unwanted_prefix) :] + return os.path.join(self.dataset_root, path) + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + +@registry.register +class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]): + """ + A concrete class to build a FrameData object, load and process the binary data (crop + and resize). Beware that modifications of frame data are done in-place. 
Please see + the documentation for `GenericFrameDataBuilder` for the description of parameters + and methods. + """ + + frame_data_type: ClassVar[Type[FrameData]] = FrameData diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 669f4e9b6..caa016d2d 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -15,7 +15,6 @@ import warnings from collections import defaultdict from itertools import islice -from pathlib import Path from typing import ( Any, ClassVar, @@ -30,19 +29,15 @@ Union, ) -import numpy as np -import torch -from PIL import Image +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from pytorch3d.implicitron.dataset.frame_data import FrameData, FrameDataBuilder +from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from pytorch3d.io import IO from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import Pointclouds -from tqdm import tqdm +from pytorch3d.renderer.cameras import CamerasBase -from . import types -from .dataset_base import DatasetBase, FrameData -from .utils import is_known_frame_scalar +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -65,7 +60,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): A dataset with annotations in json files like the Common Objects in 3D (CO3D) dataset. - Args: + Metadata-related args:: frame_annotations_file: A zipped json file containing metadata of the frames in the dataset, serialized List[types.FrameAnnotation]. sequence_annotations_file: A zipped json file containing metadata of the @@ -83,6 +78,24 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): pick_sequence: A list of sequence names to restrict the dataset to. exclude_sequence: A list of the names of the sequences to exclude. limit_category_to: Restrict the dataset to the given list of categories. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask after thresholding (see box_crop_mask_thr). + n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence + frames in each sequences uniformly without replacement if it has + more frames than that; applied before other frame-level filters. + seed: The seed of the random generator sampling #n_frames_per_sequence + random frames per sequence. + sort_frames: Enable frame annotations sorting to group frames from the + same sequences together and order them by timestamps + eval_batches: A list of batches that form the evaluation set; + list of batch-sized lists of indices corresponding to __getitem__ + of this class, thus it can be used directly as a batch sampler. + eval_batch_index: + ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) + A list of batches of frames described as (sequence_name, frame_idx) + that can form the evaluation set, `eval_batches` will be set from this. + + Blob-loading parameters: dataset_root: The root folder of the dataset; all the paths in jsons are specified relative to this root (but not json paths themselves). load_images: Enable loading the frame RGB data. 
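For illustration only, a direct construction could look roughly like the sketch below; the class is normally instantiated through the Implicitron config system, and the paths here are placeholders rather than real files:

    from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # the configurable class needs its dataclass fields expanded before direct use
    expand_args_fields(JsonIndexDataset)

    dataset = JsonIndexDataset(
        frame_annotations_file="<dataset_root>/<category>/frame_annotations.jgz",
        sequence_annotations_file="<dataset_root>/<category>/sequence_annotations.jgz",
        dataset_root="<dataset_root>",
        image_height=800,
        image_width=800,
        box_crop=True,
    )
    frame_data = dataset[0]  # a FrameData assembled by the internal FrameDataBuilder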
@@ -109,23 +122,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): is greater than this threshold, the loader lowers it and repeats. box_crop_context: The amount of additional padding added to each dimension of the cropping bounding box, relative to box size. - remove_empty_masks: Removes the frames with no active foreground pixels - in the segmentation mask after thresholding (see box_crop_mask_thr). - n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence - frames in each sequences uniformly without replacement if it has - more frames than that; applied before other frame-level filters. - seed: The seed of the random generator sampling #n_frames_per_sequence - random frames per sequence. - sort_frames: Enable frame annotations sorting to group frames from the - same sequences together and order them by timestamps - eval_batches: A list of batches that form the evaluation set; - list of batch-sized lists of indices corresponding to __getitem__ - of this class, thus it can be used directly as a batch sampler. - eval_batch_index: - ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) - A list of batches of frames described as (sequence_name, frame_idx) - that can form the evaluation set, `eval_batches` will be set from this. - """ frame_annotations_type: ClassVar[ @@ -162,12 +158,14 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): sort_frames: bool = False eval_batches: Any = None eval_batch_index: Any = None + # initialised in __post_init__ + # commented because of OmegaConf (for tests to pass) + # _frame_data_builder: FrameDataBuilder = field(init=False) # frame_annots: List[FrameAnnotsEntry] = field(init=False) # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + # _seq_to_idx: Dict[str, List[int]] = field(init=False) def __post_init__(self) -> None: - # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`. - self.subset_to_image_path = None self._load_frames() self._load_sequences() if self.sort_frames: @@ -175,9 +173,27 @@ def __post_init__(self) -> None: self._load_subset_lists() self._filter_db() # also computes sequence indices self._extract_and_set_eval_batches() + + # pyre-ignore + self._frame_data_builder = FrameDataBuilder( + dataset_root=self.dataset_root, + load_images=self.load_images, + load_depths=self.load_depths, + load_depth_masks=self.load_depth_masks, + load_masks=self.load_masks, + load_point_clouds=self.load_point_clouds, + max_points=self.max_points, + mask_images=self.mask_images, + mask_depths=self.mask_depths, + image_height=self.image_height, + image_width=self.image_width, + box_crop=self.box_crop, + box_crop_mask_thr=self.box_crop_mask_thr, + box_crop_context=self.box_crop_context, + ) logger.info(str(self)) - def _extract_and_set_eval_batches(self): + def _extract_and_set_eval_batches(self) -> None: """ Sets eval_batches based on input eval_batch_index. 
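        Each batch in `eval_batch_index` is a list of (sequence_name, frame_idx)
        tuples, optionally followed by a third string element (see the type of
        `eval_batch_index` above); for example, with hypothetical sequence names:

            eval_batch_index = [
                [("seq_a", 0), ("seq_a", 5), ("seq_a", 10)],
                [("seq_b", 3), ("seq_b", 7)],
            ]

        The tuples are resolved to indices into this dataset (compatible with
        `__getitem__`) and stored in `eval_batches`.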
""" @@ -207,13 +223,13 @@ def join(self, other_datasets: Iterable[DatasetBase]) -> None: # https://gist.github.com/treyhunner/f35292e676efa0be1728 functools.reduce( lambda a, b: {**a, **b}, - [d.seq_annots for d in other_datasets], # pyre-ignore[16] + # pyre-ignore[16] + [d.seq_annots for d in other_datasets], ) ) all_eval_batches = [ self.eval_batches, - # pyre-ignore - *[d.eval_batches for d in other_datasets], + *[d.eval_batches for d in other_datasets], # pyre-ignore[16] ] if not ( all(ba is None for ba in all_eval_batches) @@ -251,7 +267,7 @@ def seq_frame_index_to_dataset_index( allow_missing_indices: bool = False, remove_missing_indices: bool = False, suppress_missing_index_warning: bool = True, - ) -> List[List[Union[Optional[int], int]]]: + ) -> Union[List[List[Optional[int]]], List[List[int]]]: """ Obtain indices into the dataset object given a list of frame ids. @@ -323,9 +339,7 @@ def _get_dataset_idx( valid_dataset_idx = [ [b for b in batch if b is not None] for batch in dataset_idx ] - return [ # pyre-ignore[7] - batch for batch in valid_dataset_idx if len(batch) > 0 - ] + return [batch for batch in valid_dataset_idx if len(batch) > 0] return dataset_idx @@ -417,255 +431,18 @@ def __getitem__(self, index) -> FrameData: raise IndexError(f"index {index} out of range {len(self.frame_annots)}") entry = self.frame_annots[index]["frame_annotation"] - # pyre-ignore[16] - point_cloud = self.seq_annots[entry.sequence_name].point_cloud - frame_data = FrameData( - frame_number=_safe_as_tensor(entry.frame_number, torch.long), - frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), - sequence_name=entry.sequence_name, - sequence_category=self.seq_annots[entry.sequence_name].category, - camera_quality_score=_safe_as_tensor( - self.seq_annots[entry.sequence_name].viewpoint_quality_score, - torch.float, - ), - point_cloud_quality_score=_safe_as_tensor( - point_cloud.quality_score, torch.float - ) - if point_cloud is not None - else None, - ) - # The rest of the fields are optional + # pyre-ignore + frame_data = self._frame_data_builder.build( + entry, + # pyre-ignore + self.seq_annots[entry.sequence_name], + ) + # Optional field frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) - ( - frame_data.fg_probability, - frame_data.mask_path, - frame_data.bbox_xywh, - clamp_bbox_xyxy, - frame_data.crop_bbox_xywh, - ) = self._load_crop_fg_probability(entry) - - scale = 1.0 - if self.load_images and entry.image is not None: - # original image size - frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) - - ( - frame_data.image_rgb, - frame_data.image_path, - frame_data.mask_crop, - scale, - ) = self._load_crop_images( - entry, frame_data.fg_probability, clamp_bbox_xyxy - ) - - if self.load_depths and entry.depth is not None: - ( - frame_data.depth_map, - frame_data.depth_path, - frame_data.depth_mask, - ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) - - if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera( - entry, - scale, - clamp_bbox_xyxy, - ) - - if self.load_point_clouds and point_cloud is not None: - pcl_path = self._fix_point_cloud_path(point_cloud.path) - frame_data.sequence_point_cloud = _load_pointcloud( - self._local_path(pcl_path), max_points=self.max_points - ) - frame_data.sequence_point_cloud_path = pcl_path - return frame_data - def _fix_point_cloud_path(self, path: str) -> str: - """ - Fix up a point cloud path from the dataset. 
- Some files in Co3Dv2 have an accidental absolute path stored. - """ - unwanted_prefix = ( - "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" - ) - if path.startswith(unwanted_prefix): - path = path[len(unwanted_prefix) :] - return os.path.join(self.dataset_root, path) - - def _load_crop_fg_probability( - self, entry: types.FrameAnnotation - ) -> Tuple[ - Optional[torch.Tensor], - Optional[str], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: - fg_probability = None - full_path = None - bbox_xywh = None - clamp_bbox_xyxy = None - crop_box_xywh = None - - if (self.load_masks or self.box_crop) and entry.mask is not None: - full_path = os.path.join(self.dataset_root, entry.mask.path) - mask = _load_mask(self._local_path(full_path)) - - if mask.shape[-2:] != entry.image.size: - raise ValueError( - f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" - ) - - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if self.box_crop: - clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - bbox_xywh, - image_path=entry.image.path, - box_crop_context=self.box_crop_context, - ), - image_size_hw=tuple(mask.shape[-2:]), - ) - crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) - - mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) - - fg_probability, _, _ = self._resize_image(mask, mode="nearest") - - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh - - def _load_crop_images( - self, - entry: types.FrameAnnotation, - fg_probability: Optional[torch.Tensor], - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: - assert self.dataset_root is not None and entry.image is not None - path = os.path.join(self.dataset_root, entry.image.path) - image_rgb = _load_image(self._local_path(path)) - - if image_rgb.shape[-2:] != entry.image.size: - raise ValueError( - f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" 
- ) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) - - image_rgb, scale, mask_crop = self._resize_image(image_rgb) - - if self.mask_images: - assert fg_probability is not None - image_rgb *= fg_probability - - return image_rgb, path, mask_crop, scale - - def _load_mask_depth( - self, - entry: types.FrameAnnotation, - clamp_bbox_xyxy: Optional[torch.Tensor], - fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor]: - entry_depth = entry.depth - assert entry_depth is not None - path = os.path.join(self.dataset_root, entry_depth.path) - depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] - ) - depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) - - depth_map, _, _ = self._resize_image(depth_map, mode="nearest") - - if self.mask_depths: - assert fg_probability is not None - depth_map *= fg_probability - - if self.load_depth_masks: - assert entry_depth.mask_path is not None - mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) - depth_mask = _load_depth_mask(self._local_path(mask_path)) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_mask_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] - ) - depth_mask = _crop_around_box( - depth_mask, depth_mask_bbox_xyxy, mask_path - ) - - depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") - else: - depth_mask = torch.ones_like(depth_map) - - return depth_map, path, depth_mask - - def _get_pytorch3d_camera( - self, - entry: types.FrameAnnotation, - scale: float, - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> PerspectiveCameras: - entry_viewpoint = entry.viewpoint - assert entry_viewpoint is not None - # principal point and focal length - principal_point = torch.tensor( - entry_viewpoint.principal_point, dtype=torch.float - ) - focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) - - half_image_size_wh_orig = ( - torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 - ) - - # first, we convert from the dataset's NDC convention to pixels - format = entry_viewpoint.intrinsics_format - if format.lower() == "ndc_norm_image_bounds": - # this is e.g. 
currently used in CO3D for storing intrinsics - rescale = half_image_size_wh_orig - elif format.lower() == "ndc_isotropic": - rescale = half_image_size_wh_orig.min() - else: - raise ValueError(f"Unknown intrinsics format: {format}") - - # principal point and focal length in pixels - principal_point_px = half_image_size_wh_orig - principal_point * rescale - focal_length_px = focal_length * rescale - if self.box_crop: - assert clamp_bbox_xyxy is not None - principal_point_px -= clamp_bbox_xyxy[:2] - - # now, convert from pixels to PyTorch3D v0.5+ NDC convention - if self.image_height is None or self.image_width is None: - out_size = list(reversed(entry.image.size)) - else: - out_size = [self.image_width, self.image_height] - - half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 - half_min_image_size_output = half_image_size_output.min() - - # rescaled principal point and focal length in ndc - principal_point = ( - half_image_size_output - principal_point_px * scale - ) / half_min_image_size_output - focal_length = focal_length_px * scale / half_min_image_size_output - - return PerspectiveCameras( - focal_length=focal_length[None], - principal_point=principal_point[None], - R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], - T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], - ) - def _load_frames(self) -> None: logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.") local_file = self._local_path(self.frame_annotations_file) @@ -853,35 +630,6 @@ def _invalidate_seq_to_idx(self) -> None: # pyre-ignore[16] self._seq_to_idx = seq_to_idx - def _resize_image( - self, image, mode="bilinear" - ) -> Tuple[torch.Tensor, float, torch.Tensor]: - image_height, image_width = self.image_height, self.image_width - if image_height is None or image_width is None: - # skip the resizing - imre_ = torch.from_numpy(image) - return imre_, 1.0, torch.ones_like(imre_[:1]) - # takes numpy array, returns pytorch tensor - minscale = min( - image_height / image.shape[-2], - image_width / image.shape[-1], - ) - imre = torch.nn.functional.interpolate( - torch.from_numpy(image)[None], - scale_factor=minscale, - mode=mode, - align_corners=False if mode == "bilinear" else None, - recompute_scale_factor=True, - )[0] - # pyre-fixme[19]: Expected 1 positional argument. - imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) - imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre - # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. - # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. - mask = torch.zeros(1, self.image_height, self.image_width) - mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 - return imre_, minscale, mask - def _local_path(self, path: str) -> str: if self.path_manager is None: return path @@ -918,169 +666,3 @@ def get_eval_batches(self) -> Optional[List[List[int]]]: def _seq_name_to_seed(seq_name) -> int: return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16) - - -def _load_image(path) -> np.ndarray: - with Image.open(path) as pil_im: - im = np.array(pil_im.convert("RGB")) - im = im.transpose((2, 0, 1)) - im = im.astype(np.float32) / 255.0 - return im - - -def _load_16big_png_depth(depth_png) -> np.ndarray: - with Image.open(depth_png) as depth_pil: - # the image is stored with 16-bit depth but PIL reads it as I (32 bit). 
- # we cast it to uint16, then reinterpret as float16, then cast to float32 - depth = ( - np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) - .astype(np.float32) - .reshape((depth_pil.size[1], depth_pil.size[0])) - ) - return depth - - -def _load_1bit_png_mask(file: str) -> np.ndarray: - with Image.open(file) as pil_im: - mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) - return mask - - -def _load_depth_mask(path: str) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth mask file name "%s"' % path) - m = _load_1bit_png_mask(path) - return m[None] # fake feature channel - - -def _load_depth(path, scale_adjustment) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth file name "%s"' % path) - - d = _load_16big_png_depth(path) * scale_adjustment - d[~np.isfinite(d)] = 0.0 - return d[None] # fake feature channel - - -def _load_mask(path) -> np.ndarray: - with Image.open(path) as pil_im: - mask = np.array(pil_im) - mask = mask.astype(np.float32) / 255.0 - return mask[None] # fake feature channel - - -def _get_1d_bounds(arr) -> Tuple[int, int]: - nz = np.flatnonzero(arr) - return nz[0], nz[-1] + 1 - - -def _get_bbox_from_mask( - mask, thr, decrease_quant: float = 0.05 -) -> Tuple[int, int, int, int]: - # bbox in xywh - masks_for_box = np.zeros_like(mask) - while masks_for_box.sum() <= 1.0: - masks_for_box = (mask > thr).astype(np.float32) - thr -= decrease_quant - if thr <= 0.0: - warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.") - - x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) - y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) - - return x0, y0, x1 - x0, y1 - y0 - - -def _get_clamp_bbox( - bbox: torch.Tensor, - box_crop_context: float = 0.0, - image_path: str = "", -) -> torch.Tensor: - # box_crop_context: rate of expansion for bbox - # returns possibly expanded bbox xyxy as float - - bbox = bbox.clone() # do not edit bbox in place - - # increase box size - if box_crop_context > 0.0: - c = box_crop_context - bbox = bbox.float() - bbox[0] -= bbox[2] * c / 2 - bbox[1] -= bbox[3] * c / 2 - bbox[2] += bbox[2] * c - bbox[3] += bbox[3] * c - - if (bbox[2:] <= 1.0).any(): - raise ValueError( - f"squashed image {image_path}!! The bounding box contains no pixels." 
- ) - - bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes - bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) - - return bbox_xyxy - - -def _crop_around_box(tensor, bbox, impath: str = ""): - # bbox is xyxy, where the upper bound is corrected with +1 - bbox = _clamp_box_to_image_bounds_and_round( - bbox, - image_size_hw=tensor.shape[-2:], - ) - tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] - assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" - return tensor - - -def _clamp_box_to_image_bounds_and_round( - bbox_xyxy: torch.Tensor, - image_size_hw: Tuple[int, int], -) -> torch.LongTensor: - bbox_xyxy = bbox_xyxy.clone() - bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) - bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) - if not isinstance(bbox_xyxy, torch.LongTensor): - bbox_xyxy = bbox_xyxy.round().long() - return bbox_xyxy # pyre-ignore [7] - - -def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: - assert bbox is not None - assert np.prod(orig_res) > 1e-8 - # average ratio of dimensions - rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 - return bbox * rel_size - - -def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: - wh = xyxy[2:] - xyxy[:2] - xywh = torch.cat([xyxy[:2], wh]) - return xywh - - -def _bbox_xywh_to_xyxy( - xywh: torch.Tensor, clamp_size: Optional[int] = None -) -> torch.Tensor: - xyxy = xywh.clone() - if clamp_size is not None: - xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) - xyxy[2:] += xyxy[:2] - return xyxy - - -def _safe_as_tensor(data, dtype): - if data is None: - return None - return torch.tensor(data, dtype=dtype) - - -# NOTE this cache is per-worker; they are implemented as processes. -# each batch is loaded and collated by a single worker; -# since sequences tend to co-occur within batches, this is useful. 
-@functools.lru_cache(maxsize=256) -def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: - pcl = IO().load_pointcloud(pcl_path) - if max_points > 0: - pcl = pcl.subsample(max_points) - - return pcl diff --git a/pytorch3d/implicitron/dataset/single_sequence_dataset.py b/pytorch3d/implicitron/dataset/single_sequence_dataset.py index 6a0a028b5..16972c6cc 100644 --- a/pytorch3d/implicitron/dataset/single_sequence_dataset.py +++ b/pytorch3d/implicitron/dataset/single_sequence_dataset.py @@ -20,8 +20,9 @@ ) from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras -from .dataset_base import DatasetBase, FrameData +from .dataset_base import DatasetBase from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory +from .frame_data import FrameData from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN _SINGLE_SEQUENCE_NAME: str = "one_sequence" @@ -69,7 +70,8 @@ def __getitem__(self, index) -> FrameData: sequence_name=_SINGLE_SEQUENCE_NAME, sequence_category=self.object_name, camera=pose, - image_size_hw=torch.tensor(image.shape[1:]), + # pyre-ignore + image_size_hw=torch.tensor(image.shape[1:], dtype=torch.long), image_rgb=image, fg_probability=fg_probability, frame_type=frame_type, diff --git a/pytorch3d/implicitron/dataset/types.py b/pytorch3d/implicitron/dataset/types.py index 174f9f5d3..421cbc345 100644 --- a/pytorch3d/implicitron/dataset/types.py +++ b/pytorch3d/implicitron/dataset/types.py @@ -55,6 +55,8 @@ class MaskAnnotation: path: str # (soft) number of pixels in the mask; sum(Prob(fg | pixel)) mass: Optional[float] = None + # tight bounding box around the foreground mask + bounding_box_xywh: Optional[Tuple[float, float, float, float]] = None @dataclass diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 05252aff1..0982fbc05 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -5,10 +5,18 @@ # LICENSE file in the root directory of this source tree. -from typing import List, Optional +import functools +import warnings +from pathlib import Path +from typing import List, Optional, Tuple, TypeVar, Union +import numpy as np import torch +from PIL import Image +from pytorch3d.io import IO +from pytorch3d.renderer.cameras import PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds DATASET_TYPE_TRAIN = "train" DATASET_TYPE_TEST = "test" @@ -16,6 +24,26 @@ DATASET_TYPE_UNKNOWN = "unseen" +class GenericWorkaround: + """ + OmegaConf.structured has a weirdness when you try to apply + it to a dataclass whose first base class is a Generic which is not + Dict. The issue is with a function called get_dict_key_value_types + in omegaconf/_utils.py. + For example this fails: + + @dataclass(eq=False) + class D(torch.utils.data.Dataset[int]): + a: int = 3 + + OmegaConf.structured(D) + + We avoid the problem by adding this class as an extra base class. 
+ """ + + pass + + def is_known_frame_scalar(frame_type: str) -> bool: """ Given a single frame type corresponding to a single frame, return whether @@ -52,3 +80,286 @@ def is_train_frame( dtype=torch.bool, device=device, ) + + +def get_bbox_from_mask( + mask: np.ndarray, thr: float, decrease_quant: float = 0.05 +) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn( + f"Empty masks_for_bbox (thr={thr}) => using full image.", stacklevel=1 + ) + + x0, x1 = get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 + + +def crop_around_box( + tensor: torch.Tensor, bbox: torch.Tensor, impath: str = "" +) -> torch.Tensor: + # bbox is xyxy, where the upper bound is corrected with +1 + bbox = clamp_box_to_image_bounds_and_round( + bbox, + image_size_hw=tuple(tensor.shape[-2:]), + ) + tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] + assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" + return tensor + + +def clamp_box_to_image_bounds_and_round( + bbox_xyxy: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.LongTensor: + bbox_xyxy = bbox_xyxy.clone() + bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) + bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) + if not isinstance(bbox_xyxy, torch.LongTensor): + bbox_xyxy = bbox_xyxy.round().long() + return bbox_xyxy # pyre-ignore [7] + + +T = TypeVar("T", bound=torch.Tensor) + + +def bbox_xyxy_to_xywh(xyxy: T) -> T: + wh = xyxy[2:] - xyxy[:2] + xywh = torch.cat([xyxy[:2], wh]) + return xywh # pyre-ignore + + +def get_clamp_bbox( + bbox: torch.Tensor, + box_crop_context: float = 0.0, + image_path: str = "", +) -> torch.Tensor: + # box_crop_context: rate of expansion for bbox + # returns possibly expanded bbox xyxy as float + + bbox = bbox.clone() # do not edit bbox in place + + # increase box size + if box_crop_context > 0.0: + c = box_crop_context + bbox = bbox.float() + bbox[0] -= bbox[2] * c / 2 + bbox[1] -= bbox[3] * c / 2 + bbox[2] += bbox[2] * c + bbox[3] += bbox[3] * c + + if (bbox[2:] <= 1.0).any(): + raise ValueError( + f"squashed image {image_path}!! The bounding box contains no pixels." 
+ ) + + bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes + bbox_xyxy = bbox_xywh_to_xyxy(bbox, clamp_size=2) + + return bbox_xyxy + + +def rescale_bbox( + bbox: torch.Tensor, + orig_res: Union[Tuple[int, int], torch.LongTensor], + new_res: Union[Tuple[int, int], torch.LongTensor], +) -> torch.Tensor: + assert bbox is not None + assert np.prod(orig_res) > 1e-8 + # average ratio of dimensions + # pyre-ignore + rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 + return bbox * rel_size + + +def bbox_xywh_to_xyxy( + xywh: torch.Tensor, clamp_size: Optional[int] = None +) -> torch.Tensor: + xyxy = xywh.clone() + if clamp_size is not None: + xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) + xyxy[2:] += xyxy[:2] + return xyxy + + +def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]: + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + 1 + + +def resize_image( + image: Union[np.ndarray, torch.Tensor], + image_height: Optional[int], + image_width: Optional[int], + mode: str = "bilinear", +) -> Tuple[torch.Tensor, float, torch.Tensor]: + + if type(image) == np.ndarray: + image = torch.from_numpy(image) + + if image_height is None or image_width is None: + # skip the resizing + return image, 1.0, torch.ones_like(image[:1]) + # takes numpy array or tensor, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + image[None], + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + imre_ = torch.zeros(image.shape[0], image_height, image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + mask = torch.zeros(1, image_height, image_width) + mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 + return imre_, minscale, mask + + +def load_image(path: str) -> np.ndarray: + with Image.open(path) as pil_im: + im = np.array(pil_im.convert("RGB")) + im = im.transpose((2, 0, 1)) + im = im.astype(np.float32) / 255.0 + return im + + +def load_mask(path: str) -> np.ndarray: + with Image.open(path) as pil_im: + mask = np.array(pil_im) + mask = mask.astype(np.float32) / 255.0 + return mask[None] # fake feature channel + + +def load_depth(path: str, scale_adjustment: float) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth file name "%s"' % path) + + d = load_16big_png_depth(path) * scale_adjustment + d[~np.isfinite(d)] = 0.0 + return d[None] # fake feature channel + + +def load_16big_png_depth(depth_png: str) -> np.ndarray: + with Image.open(depth_png) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). 
+ # we cast it to uint16, then reinterpret as float16, then cast to float32 + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth + + +def load_1bit_png_mask(file: str) -> np.ndarray: + with Image.open(file) as pil_im: + mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) + return mask + + +def load_depth_mask(path: str) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth mask file name "%s"' % path) + m = load_1bit_png_mask(path) + return m[None] # fake feature channel + + +def safe_as_tensor(data, dtype): + return torch.tensor(data, dtype=dtype) if data is not None else None + + +def _convert_ndc_to_pixels( + focal_length: torch.Tensor, + principal_point: torch.Tensor, + image_size_wh: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point_px = half_image_size - principal_point * rescale + focal_length_px = focal_length * rescale + return focal_length_px, principal_point_px + + +def _convert_pixels_to_ndc( + focal_length_px: torch.Tensor, + principal_point_px: torch.Tensor, + image_size_wh: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + half_image_size = image_size_wh / 2 + rescale = half_image_size.min() + principal_point = (half_image_size - principal_point_px) / rescale + focal_length = focal_length_px / rescale + return focal_length, principal_point + + +def adjust_camera_to_bbox_crop_( + camera: PerspectiveCameras, + image_size_wh: torch.Tensor, + clamp_bbox_xywh: torch.Tensor, +) -> None: + if len(camera) != 1: + raise ValueError("Adjusting currently works with singleton cameras camera only") + + focal_length_px, principal_point_px = _convert_ndc_to_pixels( + camera.focal_length[0], + camera.principal_point[0], # pyre-ignore + image_size_wh, + ) + principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2] + + focal_length, principal_point_cropped = _convert_pixels_to_ndc( + focal_length_px, + principal_point_px_cropped, + clamp_bbox_xywh[2:], + ) + + camera.focal_length = focal_length[None] + camera.principal_point = principal_point_cropped[None] # pyre-ignore + + +def adjust_camera_to_image_scale_( + camera: PerspectiveCameras, + original_size_wh: torch.Tensor, + new_size_wh: torch.LongTensor, +) -> PerspectiveCameras: + focal_length_px, principal_point_px = _convert_ndc_to_pixels( + camera.focal_length[0], + camera.principal_point[0], # pyre-ignore + original_size_wh, + ) + + # now scale and convert from pixels to NDC + image_size_wh_output = new_size_wh.float() + scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values + focal_length_px_scaled = focal_length_px * scale + principal_point_px_scaled = principal_point_px * scale + + focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc( + focal_length_px_scaled, + principal_point_px_scaled, + image_size_wh_output, + ) + camera.focal_length = focal_length_scaled[None] + camera.principal_point = principal_point_scaled[None] # pyre-ignore + + +# NOTE this cache is per-worker; they are implemented as processes. +# each batch is loaded and collated by a single worker; +# since sequences tend to co-occur within batches, this is useful. 
+@functools.lru_cache(maxsize=256) +def load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: + pcl = IO().load_pointcloud(pcl_path) + if max_points > 0: + pcl = pcl.subsample(max_points) + + return pcl diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py index 6d0be0362..8c2775322 100644 --- a/pytorch3d/implicitron/dataset/visualize.py +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -10,7 +10,7 @@ from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud from pytorch3d.structures import Pointclouds -from .dataset_base import FrameData +from .frame_data import FrameData from .json_index_dataset import JsonIndexDataset diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py index 2f739852d..e380208b8 100644 --- a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -14,7 +14,7 @@ import numpy as np import torch import torch.nn.functional as F -from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.frame_data import FrameData from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.tools import vis_utils diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index 9f3732a7b..f2ac4a965 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -17,7 +17,8 @@ DoublePoolBatchSampler, ) -from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from pytorch3d.implicitron.dataset.frame_data import FrameData from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler diff --git a/tests/implicitron/test_bbox.py b/tests/implicitron/test_bbox.py index 999dfc924..08dc119fe 100644 --- a/tests/implicitron/test_bbox.py +++ b/tests/implicitron/test_bbox.py @@ -9,11 +9,19 @@ import numpy as np import torch -from pytorch3d.implicitron.dataset.json_index_dataset import ( - _bbox_xywh_to_xyxy, - _bbox_xyxy_to_xywh, - _get_bbox_from_mask, + +from pytorch3d.implicitron.dataset.utils import ( + bbox_xywh_to_xyxy, + bbox_xyxy_to_xywh, + clamp_box_to_image_bounds_and_round, + crop_around_box, + get_1d_bounds, + get_bbox_from_mask, + get_clamp_bbox, + rescale_bbox, + resize_image, ) + from tests.common_testing import TestCaseMixin @@ -31,9 +39,9 @@ def test_bbox_conversion(self): ] ) for bbox_xywh in bbox_xywh_list: - bbox_xyxy = _bbox_xywh_to_xyxy(bbox_xywh) - bbox_xywh_ = _bbox_xyxy_to_xywh(bbox_xyxy) - bbox_xyxy_ = _bbox_xywh_to_xyxy(bbox_xywh_) + bbox_xyxy = bbox_xywh_to_xyxy(bbox_xywh) + bbox_xywh_ = bbox_xyxy_to_xywh(bbox_xyxy) + bbox_xyxy_ = bbox_xywh_to_xyxy(bbox_xywh_) self.assertClose(bbox_xywh_, bbox_xywh) self.assertClose(bbox_xyxy, bbox_xyxy_) @@ -47,8 +55,8 @@ def test_compare_to_expected(self): ] ) for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_expected: - self.assertClose(_bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected) - self.assertClose(_bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh) + self.assertClose(bbox_xywh_to_xyxy(bbox_xywh), bbox_xyxy_expected) + self.assertClose(bbox_xyxy_to_xywh(bbox_xyxy_expected), bbox_xywh) clamp_amnt = 3 bbox_xywh_to_xyxy_clamped_expected = 
torch.LongTensor( @@ -61,7 +69,7 @@ def test_compare_to_expected(self): ) for bbox_xywh, bbox_xyxy_expected in bbox_xywh_to_xyxy_clamped_expected: self.assertClose( - _bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt), + bbox_xywh_to_xyxy(bbox_xywh, clamp_size=clamp_amnt), bbox_xyxy_expected, ) @@ -74,5 +82,61 @@ def test_mask_to_bbox(self): ] ).astype(np.float32) expected_bbox_xywh = [2, 1, 2, 1] - bbox_xywh = _get_bbox_from_mask(mask, 0.5) + bbox_xywh = get_bbox_from_mask(mask, 0.5) self.assertClose(bbox_xywh, expected_bbox_xywh) + + def test_crop_around_box(self): + bbox = torch.LongTensor([0, 1, 2, 3]) # (x_min, y_min, x_max, y_max) + image = torch.LongTensor( + [ + [0, 0, 10, 20], + [10, 20, 5, 1], + [10, 20, 1, 1], + [5, 4, 0, 1], + ] + ) + cropped = crop_around_box(image, bbox) + self.assertClose(cropped, image[1:3, 0:2]) + + def test_clamp_box_to_image_bounds_and_round(self): + bbox = torch.LongTensor([0, 1, 10, 12]) + image_size = (5, 6) + expected_clamped_bbox = torch.LongTensor([0, 1, image_size[1], image_size[0]]) + clamped_bbox = clamp_box_to_image_bounds_and_round(bbox, image_size) + self.assertClose(clamped_bbox, expected_clamped_bbox) + + def test_get_clamp_bbox(self): + bbox_xywh = torch.LongTensor([1, 1, 4, 5]) + clamped_bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=2) + # size multiplied by 2 and added coordinates + self.assertClose(clamped_bbox_xyxy, torch.Tensor([-3, -4, 9, 11])) + + def test_rescale_bbox(self): + bbox = torch.Tensor([0.0, 1.0, 3.0, 4.0]) + original_resolution = (4, 4) + new_resolution = (8, 8) # twice bigger + rescaled_bbox = rescale_bbox(bbox, original_resolution, new_resolution) + self.assertClose(bbox * 2, rescaled_bbox) + + def test_get_1d_bounds(self): + array = [0, 1, 2] + bounds = get_1d_bounds(array) + # make nonzero 1d bounds of image + self.assertClose(bounds, [1, 3]) + + def test_resize_image(self): + image = np.random.rand(3, 300, 500) # rgb image 300x500 + expected_shape = (150, 250) + + resized_image, scale, mask_crop = resize_image( + image, image_height=expected_shape[0], image_width=expected_shape[1] + ) + + original_shape = image.shape[-2:] + expected_scale = min( + expected_shape[0] / original_shape[0], expected_shape[1] / original_shape[1] + ) + + self.assertEqual(scale, expected_scale) + self.assertEqual(resized_image.shape[-2:], expected_shape) + self.assertEqual(mask_crop.shape[-2:], expected_shape) diff --git a/tests/implicitron/test_data_cow.py b/tests/implicitron/test_data_cow.py index 07b0b339a..801863e9b 100644 --- a/tests/implicitron/test_data_cow.py +++ b/tests/implicitron/test_data_cow.py @@ -8,7 +8,7 @@ import unittest import torch -from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.frame_data import FrameData from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import ( RenderedMeshDatasetMapProvider, ) diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index 1ac9db3b6..400d7835f 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -13,8 +13,10 @@ import unittest import lpips +import numpy as np import torch -from pytorch3d.implicitron.dataset.dataset_base import FrameData + +from pytorch3d.implicitron.dataset.frame_data import FrameData from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch from pytorch3d.implicitron.models.base_model import 
ImplicitronModelBase @@ -268,7 +270,7 @@ def _check_metrics(self, frame_data, implicitron_render, eval_result): for metric in lower_better: m_better = eval_result[metric] m_worse = eval_result_bad[metric] - if m_better != m_better or m_worse != m_worse: + if np.isnan(m_better) or np.isnan(m_worse): continue # metric is missing, i.e. NaN _assert = ( self.assertLessEqual diff --git a/tests/implicitron/test_frame_data_builder.py b/tests/implicitron/test_frame_data_builder.py new file mode 100644 index 000000000..f150081b2 --- /dev/null +++ b/tests/implicitron/test_frame_data_builder.py @@ -0,0 +1,224 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import gzip +import os +import unittest +from typing import List + +import numpy as np +import torch + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder +from pytorch3d.implicitron.dataset.utils import ( + load_16big_png_depth, + load_1bit_png_mask, + load_depth, + load_depth_mask, + load_image, + load_mask, + safe_as_tensor, +) +from pytorch3d.implicitron.tools.config import get_default_args +from pytorch3d.renderer.cameras import PerspectiveCameras + +from tests.common_testing import TestCaseMixin +from tests.implicitron.common_resources import get_skateboard_data + + +class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + category = "skateboard" + stack = contextlib.ExitStack() + self.dataset_root, self.path_manager = stack.enter_context( + get_skateboard_data() + ) + self.addCleanup(stack.close) + self.image_height = 768 + self.image_width = 512 + + self.frame_data_builder = FrameDataBuilder( + image_height=self.image_height, + image_width=self.image_width, + dataset_root=self.dataset_root, + path_manager=self.path_manager, + ) + + # loading single frame annotation of dataset (see JsonIndexDataset._load_frames()) + frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz") + local_file = self.path_manager.get_local_path(frame_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + frame_annots_list = types.load_dataclass( + zipfile, List[types.FrameAnnotation] + ) + self.frame_annotation = frame_annots_list[0] + + sequence_annotations_file = os.path.join( + self.dataset_root, category, "sequence_annotations.jgz" + ) + local_file = self.path_manager.get_local_path(sequence_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + seq_annots_list = types.load_dataclass( + zipfile, List[types.SequenceAnnotation] + ) + seq_annots = {entry.sequence_name: entry for entry in seq_annots_list} + self.seq_annotation = seq_annots[self.frame_annotation.sequence_name] + + point_cloud = self.seq_annotation.point_cloud + self.frame_data = FrameData( + frame_number=safe_as_tensor(self.frame_annotation.frame_number, torch.long), + frame_timestamp=safe_as_tensor( + self.frame_annotation.frame_timestamp, torch.float + ), + sequence_name=self.frame_annotation.sequence_name, + sequence_category=self.seq_annotation.category, + camera_quality_score=safe_as_tensor( + self.seq_annotation.viewpoint_quality_score, torch.float + ), + point_cloud_quality_score=safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is 
not None + else None, + ) + + def test_frame_data_builder_args(self): + # test that FrameDataBuilder works with get_default_args + get_default_args(FrameDataBuilder) + + def test_fix_point_cloud_path(self): + """Some files in Co3Dv2 have an accidental absolute path stored.""" + original_path = "some_file_path" + modified_path = self.frame_data_builder._fix_point_cloud_path(original_path) + self.assertIn(original_path, modified_path) + self.assertIn(self.frame_data_builder.dataset_root, modified_path) + + def test_load_and_adjust_frame_data(self): + self.frame_data.image_size_hw = safe_as_tensor( + self.frame_annotation.image.size, torch.long + ) + self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw + + ( + self.frame_data.fg_probability, + self.frame_data.mask_path, + self.frame_data.bbox_xywh, + ) = self.frame_data_builder._load_fg_probability(self.frame_annotation) + + self.assertIsNotNone(self.frame_data.mask_path) + self.assertTrue(torch.is_tensor(self.frame_data.fg_probability)) + self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh)) + # assert bboxes shape + self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4])) + + ( + self.frame_data.image_rgb, + self.frame_data.image_path, + ) = self.frame_data_builder._load_images( + self.frame_annotation, self.frame_data.fg_probability + ) + self.assertEqual(type(self.frame_data.image_rgb), np.ndarray) + self.assertIsNotNone(self.frame_data.image_path) + + ( + self.frame_data.depth_map, + depth_path, + self.frame_data.depth_mask, + ) = self.frame_data_builder._load_mask_depth( + self.frame_annotation, + self.frame_data.fg_probability, + ) + self.assertTrue(torch.is_tensor(self.frame_data.depth_map)) + self.assertIsNotNone(depth_path) + self.assertTrue(torch.is_tensor(self.frame_data.depth_mask)) + + new_size = (self.image_height, self.image_width) + + if self.frame_data_builder.box_crop: + self.frame_data.crop_by_metadata_bbox_( + self.frame_data_builder.box_crop_context, + ) + + # assert image and mask shapes after resize + self.frame_data.resize_frame_( + new_size_hw=torch.tensor(new_size, dtype=torch.long), + ) + self.assertEqual( + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.image_rgb.shape, + torch.Size([3, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.mask_crop.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.fg_probability.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.depth_map.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.assertEqual( + self.frame_data.depth_mask.shape, + torch.Size([1, self.image_height, self.image_width]), + ) + self.frame_data.camera = self.frame_data_builder._get_pytorch3d_camera( + self.frame_annotation, + ) + self.assertEqual(type(self.frame_data.camera), PerspectiveCameras) + + def test_load_image(self): + path = os.path.join(self.dataset_root, self.frame_annotation.image.path) + local_path = self.path_manager.get_local_path(path) + image = load_image(local_path) + self.assertEqual(image.dtype, np.float32) + self.assertLessEqual(np.max(image), 1.0) + self.assertGreaterEqual(np.min(image), 0.0) + + def test_load_mask(self): + path = os.path.join(self.dataset_root, self.frame_annotation.mask.path) + mask = load_mask(path) + self.assertEqual(mask.dtype, np.float32) + self.assertLessEqual(np.max(mask), 1.0) + 
self.assertGreaterEqual(np.min(mask), 0.0) + + def test_load_depth(self): + path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) + depth_map = load_depth(path, self.frame_annotation.depth.scale_adjustment) + self.assertEqual(depth_map.dtype, np.float32) + self.assertEqual(len(depth_map.shape), 3) + + def test_load_16big_png_depth(self): + path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) + depth_map = load_16big_png_depth(path) + self.assertEqual(depth_map.dtype, np.float32) + self.assertEqual(len(depth_map.shape), 2) + + def test_load_1bit_png_mask(self): + mask_path = os.path.join( + self.dataset_root, self.frame_annotation.depth.mask_path + ) + mask = load_1bit_png_mask(mask_path) + self.assertEqual(mask.dtype, np.float32) + self.assertEqual(len(mask.shape), 2) + + def test_load_depth_mask(self): + mask_path = os.path.join( + self.dataset_root, self.frame_annotation.depth.mask_path + ) + mask = load_depth_mask(mask_path) + self.assertEqual(mask.dtype, np.float32) + self.assertEqual(len(mask.shape), 3) diff --git a/tests/implicitron/test_json_index_dataset_provider_v2.py b/tests/implicitron/test_json_index_dataset_provider_v2.py index 3191c0ee6..c99481a48 100644 --- a/tests/implicitron/test_json_index_dataset_provider_v2.py +++ b/tests/implicitron/test_json_index_dataset_provider_v2.py @@ -17,7 +17,7 @@ import torch import torchvision from PIL import Image -from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.frame_data import FrameData from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( JsonIndexDatasetMapProviderV2, )
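
For reference, a minimal sketch of how the bounding-box helpers made public in pytorch3d/implicitron/dataset/utils.py by this patch compose, roughly mirroring the crop-and-resize path exercised above through FrameDataBuilder, FrameData.crop_by_metadata_bbox_() and resize_frame_(). The file names, mask threshold, crop context, and output resolution below are illustrative assumptions, not values taken from this diff:

    # Illustrative sketch only; paths and numeric parameters are hypothetical.
    import torch

    from pytorch3d.implicitron.dataset.utils import (
        clamp_box_to_image_bounds_and_round,
        crop_around_box,
        get_bbox_from_mask,
        get_clamp_bbox,
        load_image,
        load_mask,
        resize_image,
    )

    # Load an RGB frame and its foreground probability mask as CHW float tensors.
    image = torch.from_numpy(load_image("frame.png"))            # (3, H, W), floats in [0, 1]
    fg_probability = torch.from_numpy(load_mask("mask.png"))     # (1, H, W), floats in [0, 1]

    # Tight xywh box around the thresholded foreground, expanded by a context
    # margin and converted to xyxy, then clamped to the image bounds and rounded.
    bbox_xywh = torch.tensor(get_bbox_from_mask(fg_probability[0].numpy(), thr=0.4))
    bbox_xyxy = get_clamp_bbox(bbox_xywh, box_crop_context=0.3)
    bbox_xyxy = clamp_box_to_image_bounds_and_round(bbox_xyxy, tuple(image.shape[-2:]))

    # Crop both tensors around the box, then resize (with zero padding) to a
    # fixed resolution; resize_image also returns the scale and the valid-pixel mask.
    image_crop = crop_around_box(image, bbox_xyxy)
    fg_crop = crop_around_box(fg_probability, bbox_xyxy)
    image_resized, scale, mask_crop = resize_image(image_crop, image_height=800, image_width=800)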