Refactor: FrameDataBuilder is more extensible.
Summary:
This is mostly a refactoring diff to reduce friction in extending the frame data.

Slight functional change: the dataset `__getitem__` now accepts `(seq_name, frame_number_as_singleton_tensor)` as a non-advertised feature. Without it, the following code crashes:
```
item = dataset[0]
dataset[item.sequence_name, item.frame_number]
```
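
For context on why the original code crashes: a singleton `LongTensor` never matches an integer key in the `(sequence_name, frame_number)` index, because tensors hash by identity rather than by value; this is what the added `.item()` conversion in `sql_dataset.py` works around. A minimal, self-contained illustration (a plain dict stands in for the dataset's index; the values are made up):
```
import torch

# stand-in for the dataset's (sequence_name, frame_number) -> annotation index
index = {("seq_1", 5): "frame annotation"}

# frame_number as it comes back from a loaded FrameData: a singleton LongTensor
frame = torch.tensor(5)

print(("seq_1", frame) in index)         # False: the tensor key hashes by identity
print(("seq_1", frame.item()) in index)  # True: a Python int matches the stored key
```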

Reviewed By: bottler

Differential Revision: D45780175

fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
shapovalov authored and facebook-github-bot committed May 17, 2023
1 parent d08fe6d commit b046259
Showing 5 changed files with 102 additions and 41 deletions.
5 changes: 5 additions & 0 deletions projects/implicitron_trainer/tests/test_experiment.py
@@ -132,6 +132,11 @@ def test_yaml_contents(self):
# Check that the default config values, defined by Experiment and its
# members, is what we expect it to be.
cfg = OmegaConf.structured(experiment.Experiment)
# the following removes the possible effect of env variables
ds_arg = cfg.data_source_ImplicitronDataSource_args
ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
if DEBUG:
(DATA_DIR / "experiment.yaml").write_text(yaml)
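
For context on the test change above: defaults read from environment variables get baked into the structured config, so `OmegaConf.to_yaml` would otherwise snapshot whatever happens to be set on the machine running the test; the test therefore pins those fields to fixed values. A hedged sketch of the pattern with hypothetical names (`ProviderArgs`, `MY_DATASET_ROOT`), not the actual Implicitron config classes:
```
import os
from dataclasses import dataclass

from omegaconf import OmegaConf

# read at import time, so the default depends on this process's environment
_DATASET_ROOT_FROM_ENV = os.environ.get("MY_DATASET_ROOT", "")


@dataclass
class ProviderArgs:
    dataset_root: str = _DATASET_ROOT_FROM_ENV


cfg = OmegaConf.structured(ProviderArgs)
cfg.dataset_root = ""  # pin the field so the YAML snapshot is machine-independent
print(OmegaConf.to_yaml(cfg))
```
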
77 changes: 48 additions & 29 deletions pytorch3d/implicitron/dataset/frame_data.py
@@ -203,7 +203,10 @@ def crop_by_metadata_bbox_(
when no image has been loaded)
"""
if self.bbox_xywh is None:
raise ValueError("Attempted cropping by metadata with empty bounding box")
raise ValueError(
"Attempted cropping by metadata with empty bounding box. Consider either"
" to remove_empty_masks or turn off box_crop in the dataset config."
)

if not self._uncropped:
raise ValueError(
@@ -528,12 +531,7 @@ def __post_init__(self) -> None:
"Make sure it is set in either FrameDataBuilder or Dataset params."
)

if self.path_manager is None:
dataset_root_exists = os.path.isdir(self.dataset_root) # pyre-ignore
else:
dataset_root_exists = self.path_manager.isdir(self.dataset_root)

if load_any_blob and not dataset_root_exists:
if load_any_blob and not self._exists_in_dataset_root(""):
raise ValueError(
f"dataset_root is passed but {self.dataset_root} does not exist."
)
@@ -604,14 +602,27 @@ def build(
frame_data.image_size_hw = image_size_hw # original image size
# image size after crop/resize
frame_data.effective_image_size_hw = image_size_hw
image_path = None
dataset_root = self.dataset_root
if frame_annotation.image.path is not None and dataset_root is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path

if load_blobs and self.load_images:
(
frame_data.image_rgb,
frame_data.image_path,
) = self._load_images(frame_annotation, frame_data.fg_probability)
if image_path is None:
raise ValueError("Image path is required to load images.")

image_np = load_image(self._local_path(image_path))
frame_data.image_rgb = self._postprocess_image(
image_np, frame_annotation.image.size, frame_data.fg_probability
)

if load_blobs and self.load_depths and frame_annotation.depth is not None:
if (
load_blobs
and self.load_depths
and frame_annotation.depth is not None
and frame_annotation.depth.path is not None
):
(
frame_data.depth_map,
frame_data.depth_path,
@@ -652,44 +663,42 @@ def _load_fg_probability(

return fg_probability, full_path

def _load_images(
def _postprocess_image(
self,
entry: types.FrameAnnotation,
image_np: np.ndarray,
image_size: Tuple[int, int],
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = load_image(self._local_path(path))
) -> torch.Tensor:
image_rgb = safe_as_tensor(image_np, torch.float)

if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if image_rgb.shape[-2:] != image_size:
raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")

if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability

return image_rgb, path
return image_rgb

def _load_mask_depth(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert self.dataset_root is not None and entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
dataset_root = self.dataset_root
assert dataset_root is not None
assert entry_depth is not None and entry_depth.path is not None
path = os.path.join(dataset_root, entry_depth.path)
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability

if self.load_depth_masks:
assert entry_depth.mask_path is not None
# pyre-ignore
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
mask_path = entry_depth.mask_path
if self.load_depth_masks and mask_path is not None:
mask_path = os.path.join(dataset_root, mask_path)
depth_mask = load_depth_mask(self._local_path(mask_path))
else:
depth_mask = torch.ones_like(depth_map)
Expand Down Expand Up @@ -745,6 +754,16 @@ def _local_path(self, path: str) -> str:
return path
return self.path_manager.get_local_path(path)

def _exists_in_dataset_root(self, relpath) -> bool:
if not self.dataset_root:
return False

full_path = os.path.join(self.dataset_root, relpath)
if self.path_manager is None:
return os.path.exists(full_path)
else:
return self.path_manager.exists(full_path)


@registry.register
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
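
To illustrate the extensibility goal from the summary: with image loading split into a standalone `_postprocess_image` step, a builder subclass can change post-processing without re-implementing path handling or blob loading. A minimal sketch, not part of this diff; `GrayscaleFrameDataBuilder` and its behaviour are hypothetical:
```
from typing import Optional, Tuple

import numpy as np
import torch

from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.tools.config import registry


@registry.register
class GrayscaleFrameDataBuilder(FrameDataBuilder):
    """Hypothetical builder that averages RGB into a single luminance value."""

    def _postprocess_image(
        self,
        image_np: np.ndarray,
        image_size: Tuple[int, int],
        fg_probability: Optional[torch.Tensor],
    ) -> torch.Tensor:
        image_rgb = super()._postprocess_image(image_np, image_size, fg_probability)
        # keep the 3-channel layout the rest of the pipeline expects,
        # but make every channel carry the mean intensity
        return image_rgb.mean(dim=-3, keepdim=True).expand_as(image_rgb)
```
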
3 changes: 3 additions & 0 deletions pytorch3d/implicitron/dataset/sql_dataset.py
@@ -210,6 +210,9 @@ def _get_item(
seq, frame = self._index.index[frame_idx]
else:
seq, frame, *rest = frame_idx
if isinstance(frame, torch.LongTensor):
frame = frame.item()

if (seq, frame) not in self._index.index:
raise IndexError(
f"Sequence-frame index {frame_idx} not found; was it filtered out?"
14 changes: 9 additions & 5 deletions pytorch3d/implicitron/dataset/utils.py
@@ -225,19 +225,23 @@ def resize_image(
return imre_, minscale, mask


def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
im = np.atleast_3d(image).transpose((2, 0, 1))
return im.astype(np.float32) / 255.0


def load_image(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im

return transpose_normalize_image(im)


def load_mask(path: str) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel

return transpose_normalize_image(mask)


def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
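
The new `transpose_normalize_image` helper centralises the HWC-to-CHW transpose and the `/ 255` normalisation that `load_image` and `load_mask` now share, with `np.atleast_3d` giving single-channel masks a fake channel dimension. A tiny shape-level illustration, assuming the behaviour shown in the hunk above:
```
import numpy as np

from pytorch3d.implicitron.dataset.utils import transpose_normalize_image

rgb = np.zeros((4, 6, 3), dtype=np.uint8)   # H x W x C image, as returned by PIL
mask = np.zeros((4, 6), dtype=np.uint8)     # H x W single-channel mask

print(transpose_normalize_image(rgb).shape)   # (3, 4, 6), float32 in [0, 1]
print(transpose_normalize_image(mask).shape)  # (1, 4, 6) via np.atleast_3d
```
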
44 changes: 37 additions & 7 deletions tests/implicitron/test_frame_data_builder.py
@@ -25,6 +25,7 @@
load_image,
load_mask,
safe_as_tensor,
transpose_normalize_image,
)
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer.cameras import PerspectiveCameras
@@ -123,14 +124,15 @@ def test_load_and_adjust_frame_data(self):
# assert bboxes shape
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))

(
self.frame_data.image_rgb,
self.frame_data.image_path,
) = self.frame_data_builder._load_images(
self.frame_annotation, self.frame_data.fg_probability
image_path = os.path.join(
self.frame_data_builder.dataset_root, self.frame_annotation.image.path
)
image_np = load_image(self.frame_data_builder._local_path(image_path))
self.assertIsInstance(image_np, np.ndarray)
self.frame_data.image_rgb = self.frame_data_builder._postprocess_image(
image_np, self.frame_annotation.image.size, self.frame_data.fg_probability
)
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
self.assertIsNotNone(self.frame_data.image_path)
self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)

(
self.frame_data.depth_map,
@@ -184,6 +186,34 @@ def test_load_and_adjust_frame_data(self):
)
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)

def test_transpose_normalize_image(self):
def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray:
im = image * 255.0
return im.transpose((1, 2, 0)).astype(np.uint8)

# Test 2D input
input_image = np.array(
[[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8
)
expected_input = inverse_transpose_normalize_image(
transpose_normalize_image(input_image)
)
self.assertClose(input_image[..., None], expected_input)

# Test 3D input
input_image = np.array(
[
[[10, 20, 30], [40, 50, 60], [70, 80, 90]],
[[100, 110, 120], [130, 140, 150], [160, 170, 180]],
[[190, 200, 210], [220, 230, 240], [250, 255, 255]],
],
dtype=np.uint8,
)
expected_input = inverse_transpose_normalize_image(
transpose_normalize_image(input_image)
)
self.assertClose(input_image, expected_input)

def test_load_image(self):
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
local_path = self.path_manager.get_local_path(path)