Commit ebdbfde

Extract BlobLoader class from JsonIndexDataset and move crop_by_bbox to FrameData

Summary:
- Extracted the blob loader into a separate BlobLoader class.
- Added documentation for blob_loader.
- Did some refactoring on fields.

For detailed steps and discussions see:
#1463
fairinternal/pixar_replay#160

Reviewed By: bottler

Differential Revision: D44061728

fbshipit-source-id: eefb21e9679003045d73729f96e6a93a1d4d2d51

Ildar Salakhiev authored and facebook-github-bot committed Apr 4, 2023
1 parent c759fc5 commit ebdbfde
Showing 15 changed files with 1,421 additions and 694 deletions.
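
For downstream code, the visible effect of the refactor is that FrameData moved out of dataset_base into its own frame_data module. A minimal sketch of the resulting import change (module paths taken from the diffs below):

# Before this commit:
# from pytorch3d.implicitron.dataset.dataset_base import FrameData
# After this commit:
from pytorch3d.implicitron.dataset.frame_data import FrameData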
3 changes: 2 additions & 1 deletion pytorch3d/implicitron/dataset/data_loader_map_provider.py
@@ -18,8 +18,9 @@
     Sampler,
 )

-from .dataset_base import DatasetBase, FrameData
+from .dataset_base import DatasetBase
 from .dataset_map_provider import DatasetMap
+from .frame_data import FrameData
 from .scene_batch_sampler import SceneBatchSampler
 from .utils import is_known_frame_scalar

198 changes: 4 additions & 194 deletions pytorch3d/implicitron/dataset/dataset_base.py
@@ -5,217 +5,27 @@
 # LICENSE file in the root directory of this source tree.

 from collections import defaultdict
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass
 from typing import (
     Any,
     ClassVar,
     Dict,
     Iterable,
     Iterator,
     List,
     Mapping,
     Optional,
     Sequence,
     Tuple,
     Type,
     Union,
 )

-import numpy as np
 import torch
-from pytorch3d.renderer.camera_utils import join_cameras_as_batch
-from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
-from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds


-@dataclass
-class FrameData(Mapping[str, Any]):
-    """
-    A type of the elements returned by indexing the dataset object.
-    It can represent both individual frames and batches of thereof;
-    in this documentation, the sizes of tensors refer to single frames;
-    add the first batch dimension for the collation result.
-    Args:
-        frame_number: The number of the frame within its sequence.
-            0-based continuous integers.
-        sequence_name: The unique name of the frame's sequence.
-        sequence_category: The object category of the sequence.
-        frame_timestamp: The time elapsed since the start of a sequence in sec.
-        image_size_hw: The size of the image in pixels; (height, width) tensor
-            of shape (2,).
-        image_path: The qualified path to the loaded image (with dataset_root).
-        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
-            of the frame; elements are floats in [0, 1].
-        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
-            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
-            are a result of zero-padding of the image after cropping around
-            the object bounding box; elements are floats in {0.0, 1.0}.
-        depth_path: The qualified path to the frame's depth map.
-        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
-            of the frame; values correspond to distances from the camera;
-            use `depth_mask` and `mask_crop` to filter for valid pixels.
-        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
-            depth map that are valid for evaluation, they have been checked for
-            consistency across views; elements are floats in {0.0, 1.0}.
-        mask_path: A qualified path to the foreground probability mask.
-        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
-            pixels belonging to the captured object; elements are floats
-            in [0, 1].
-        bbox_xywh: The bounding box tightly enclosing the foreground object in the
-            format (x0, y0, width, height). The convention assumes that
-            `x0+width` and `y0+height` includes the boundary of the box.
-            I.e., to slice out the corresponding crop from an image tensor `I`
-            we execute `crop = I[..., y0:y0+height, x0:x0+width]`
-        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
-            in the original image coordinates in the format (x0, y0, width, height).
-            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
-            from `bbox_xywh` due to padding (which can happen e.g. due to
-            setting `JsonIndexDataset.box_crop_context > 0`)
-        camera: A PyTorch3D camera object corresponding the frame's viewpoint,
-            corrected for cropping if it happened.
-        camera_quality_score: The score proportional to the confidence of the
-            frame's camera estimation (the higher the more accurate).
-        point_cloud_quality_score: The score proportional to the accuracy of the
-            frame's sequence point cloud (the higher the more accurate).
-        sequence_point_cloud_path: The path to the sequence's point cloud.
-        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
-            point cloud corresponding to the frame's sequence. When the object
-            represents a batch of frames, point clouds may be deduplicated;
-            see `sequence_point_cloud_idx`.
-        sequence_point_cloud_idx: Integer indices mapping frame indices to the
-            corresponding point clouds in `sequence_point_cloud`; to get the
-            corresponding point cloud to `image_rgb[i]`, use
-            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
-        frame_type: The type of the loaded frame specified in
-            `subset_lists_file`, if provided.
-        meta: A dict for storing additional frame information.
-    """
-
-    frame_number: Optional[torch.LongTensor]
-    sequence_name: Union[str, List[str]]
-    sequence_category: Union[str, List[str]]
-    frame_timestamp: Optional[torch.Tensor] = None
-    image_size_hw: Optional[torch.Tensor] = None
-    image_path: Union[str, List[str], None] = None
-    image_rgb: Optional[torch.Tensor] = None
-    # masks out padding added due to cropping the square bit
-    mask_crop: Optional[torch.Tensor] = None
-    depth_path: Union[str, List[str], None] = None
-    depth_map: Optional[torch.Tensor] = None
-    depth_mask: Optional[torch.Tensor] = None
-    mask_path: Union[str, List[str], None] = None
-    fg_probability: Optional[torch.Tensor] = None
-    bbox_xywh: Optional[torch.Tensor] = None
-    crop_bbox_xywh: Optional[torch.Tensor] = None
-    camera: Optional[PerspectiveCameras] = None
-    camera_quality_score: Optional[torch.Tensor] = None
-    point_cloud_quality_score: Optional[torch.Tensor] = None
-    sequence_point_cloud_path: Union[str, List[str], None] = None
-    sequence_point_cloud: Optional[Pointclouds] = None
-    sequence_point_cloud_idx: Optional[torch.Tensor] = None
-    frame_type: Union[str, List[str], None] = None  # known | unseen
-    meta: dict = field(default_factory=lambda: {})
-
-    def to(self, *args, **kwargs):
-        new_params = {}
-        for f in fields(self):
-            value = getattr(self, f.name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
-                new_params[f.name] = value.to(*args, **kwargs)
-            else:
-                new_params[f.name] = value
-        return type(self)(**new_params)
-
-    def cpu(self):
-        return self.to(device=torch.device("cpu"))
-
-    def cuda(self):
-        return self.to(device=torch.device("cuda"))
-
-    # the following functions make sure **frame_data can be passed to functions
-    def __iter__(self):
-        for f in fields(self):
-            yield f.name
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __len__(self):
-        return len(fields(self))
-
-    @classmethod
-    def collate(cls, batch):
-        """
-        Given a list objects `batch` of class `cls`, collates them into a batched
-        representation suitable for processing with deep networks.
-        """
-
-        elem = batch[0]
-
-        if isinstance(elem, cls):
-            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
-            id_to_idx = defaultdict(list)
-            for i, pc_id in enumerate(pointcloud_ids):
-                id_to_idx[pc_id].append(i)
-
-            sequence_point_cloud = []
-            sequence_point_cloud_idx = -np.ones((len(batch),))
-            for i, ind in enumerate(id_to_idx.values()):
-                sequence_point_cloud_idx[ind] = i
-                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
-            assert (sequence_point_cloud_idx >= 0).all()
-
-            override_fields = {
-                "sequence_point_cloud": sequence_point_cloud,
-                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
-            }
-            # note that the pre-collate value of sequence_point_cloud_idx is unused
-
-            collated = {}
-            for f in fields(elem):
-                list_values = override_fields.get(
-                    f.name, [getattr(d, f.name) for d in batch]
-                )
-                collated[f.name] = (
-                    cls.collate(list_values)
-                    if all(list_value is not None for list_value in list_values)
-                    else None
-                )
-            return cls(**collated)
-
-        elif isinstance(elem, Pointclouds):
-            return join_pointclouds_as_batch(batch)
-
-        elif isinstance(elem, CamerasBase):
-            # TODO: don't store K; enforce working in NDC space
-            return join_cameras_as_batch(batch)
-        else:
-            return torch.utils.data._utils.collate.default_collate(batch)
-
-
-class _GenericWorkaround:
-    """
-    OmegaConf.structured has a weirdness when you try to apply
-    it to a dataclass whose first base class is a Generic which is not
-    Dict. The issue is with a function called get_dict_key_value_types
-    in omegaconf/_utils.py.
-    For example this fails:
-    @dataclass(eq=False)
-    class D(torch.utils.data.Dataset[int]):
-        a: int = 3
-    OmegaConf.structured(D)
-    We avoid the problem by adding this class as an extra base class.
-    """
-
-    pass
+from pytorch3d.implicitron.dataset.frame_data import FrameData
+from pytorch3d.implicitron.dataset.utils import GenericWorkaround


 @dataclass(eq=False)
-class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
+class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):
     """
     Base class to describe a dataset to be used with Implicitron.
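
Since FrameData keeps its Mapping interface and its collate classmethod after the move, it can still be collated directly or plugged into a torch DataLoader as collate_fn. A minimal sketch based only on the code shown above; the frames below are toy values invented for illustration, with most optional fields left at their None defaults:

import torch

from pytorch3d.implicitron.dataset.frame_data import FrameData

# Two toy frames from the same sequence (values invented for illustration).
frames = [
    FrameData(
        frame_number=torch.tensor(i, dtype=torch.long),
        sequence_name="seq_0",
        sequence_category="chair",
        image_rgb=torch.rand(3, 8, 8),
    )
    for i in range(2)
]

# Collate into a batch: tensors gain a leading batch dimension; per-sequence
# point clouds (None here) would be deduplicated via sequence_point_cloud_idx.
batch = FrameData.collate(frames)
print(batch.image_rgb.shape)  # torch.Size([2, 3, 8, 8])
print(batch.sequence_name)    # ['seq_0', 'seq_0']

# FrameData.to()/cuda()/cpu() move all tensor-like members at once.
if torch.cuda.is_available():
    batch = batch.cuda()

# The same classmethod works as a DataLoader collate_fn, e.g.:
#   DataLoader(dataset, batch_size=4, collate_fn=FrameData.collate)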