-
Notifications
You must be signed in to change notification settings - Fork 493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added center cropping and resize ops for PPO agents #365
Changes from all commits
bd98e86
1689a12
c8df9d7
32f25ed
0884a8d
7518d05
1988d8d
514ee1b
27ad573
d495ee2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,15 +4,19 @@ | |
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import copy | ||
import glob | ||
import numbers | ||
import os | ||
from collections import defaultdict | ||
from typing import Dict, List, Optional | ||
from typing import Any, Dict, List, Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from gym.spaces import Box | ||
|
||
from habitat import logger | ||
from habitat.utils.visualizations.utils import images_to_video | ||
from habitat_baselines.common.tensorboard_utils import TensorboardWriter | ||
|
||
|
@@ -53,20 +57,68 @@ def forward(self, x): | |
return CustomFixedCategorical(logits=x) | ||
|
||
|
||
class ResizeCenterCropper(nn.Module):
    def __init__(self, size, channels_last: bool = False):
        r"""An nn module that resizes and center-crops its input.

        Args:
            size: A sequence (w, h) or int of the size you wish to
                resize/center-crop. If int, assumes a square crop.
            channels_last: indicates if channels is the last dimension.
                (Docstring previously said ``channels_list`` — typo fixed.)
        """
        super().__init__()
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        assert len(size) == 2, "forced input size must be len of 2 (w, h)"
        self._size = size
        self.channels_last = channels_last

    def transform_observation_space(
        self, observation_space, trans_keys=("rgb", "depth", "semantic")
    ):
        r"""Return a deep copy of ``observation_space`` with each space named
        in ``trans_keys`` reshaped to ``self._size``.

        Args:
            observation_space: a gym-style space with a ``.spaces`` dict.
            trans_keys: keys whose spaces should be resized. BUGFIX: default
                is now a tuple instead of a list to avoid the shared
                mutable-default-argument pitfall (callers may still pass a
                list — only membership tests are performed).

        Returns:
            The transformed (copied) observation space; also stored on
            ``self.observation_space``.
        """
        size = self._size
        observation_space = copy.deepcopy(observation_space)
        if size:
            for key in observation_space.spaces:
                if (
                    key in trans_keys
                    and observation_space.spaces[key].shape != size
                ):
                    logger.info(
                        "Overwriting CNN input size of %s: %s" % (key, size)
                    )
                    observation_space.spaces[key] = overwrite_gym_box_shape(
                        observation_space.spaces[key], size
                    )
        self.observation_space = observation_space
        return observation_space

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # No-op when no target size is configured (``_size`` may be cleared
        # externally; __init__ always sets a 2-tuple).
        if self._size is None:
            return input

        # Resize so the shortest edge covers the crop, then center-crop.
        return center_crop(
            image_resize_shortest_edge(
                input, max(self._size), channels_last=self.channels_last
            ),
            self._size,
            channels_last=self.channels_last,
        )
|
||
|
||
def linear_decay(epoch: int, total_num_updates: int) -> float:
    r"""Multiplicative factor for linearly decaying a parameter value.

    Args:
        epoch: current epoch number
        total_num_updates: total number of updates in the schedule

    Returns:
        A factor that decreases linearly from 1 (epoch 0) toward 0
        (epoch == total_num_updates).
    """
    fraction_complete = epoch / float(total_num_updates)
    return 1.0 - fraction_complete
|
||
|
||
def _to_tensor(v): | ||
def _to_tensor(v) -> torch.Tensor: | ||
if torch.is_tensor(v): | ||
return v | ||
elif isinstance(v, np.ndarray): | ||
|
@@ -174,3 +226,92 @@ def generate_video( | |
tb_writer.add_video_from_np_images( | ||
f"episode{episode_id}", checkpoint_idx, images, fps=fps | ||
) | ||
|
||
|
||
def image_resize_shortest_edge(
    img, size: int, channels_last: bool = False
) -> torch.Tensor:
    """Resize ``img`` so that its shortest side has length ``size`` while
    preserving the aspect ratio.

    Args:
        img: the array to resize — (HWC) or (NHWC) when ``channels_last``,
            otherwise any (..HW) layout with 3 to 5 dims total.
        size: target length of the shortest edge.
        channels_last: whether the channel dimension is last.
    Returns:
        The resized image as a torch tensor.
    """
    img = _to_tensor(img)
    ndim = len(img.shape)
    if ndim < 3 or ndim > 5:
        raise NotImplementedError()
    had_no_batch = ndim == 3
    if had_no_batch:
        # Add a batch dimension so interpolate sees a batched input.
        img = img.unsqueeze(0)
    if channels_last:
        h, w = img.shape[-3:-1]
        # torch.nn.functional.interpolate expects channels-first layouts,
        # so move channels ahead of the spatial dims before resizing.
        if len(img.shape) == 4:
            img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW
        else:
            # NOTE(review): the 5-D path feeds a 2-D ``size`` to
            # interpolate — confirm this branch is exercised/valid upstream.
            img = img.permute(0, 1, 4, 2, 3)  # NDHWC -> NDCHW
    else:
        h, w = img.shape[-2:]  # ..HW

    # Scale both sides by the ratio that maps the shortest edge to ``size``.
    scale = size / min(h, w)
    new_h = int(h * scale)
    new_w = int(w * scale)
    # ``img.dtype`` on the right-hand side is the pre-resize dtype, so the
    # output is cast back to the input's dtype.
    img = torch.nn.functional.interpolate(
        img.float(), size=(new_h, new_w), mode="area"
    ).to(dtype=img.dtype)
    if channels_last:
        if len(img.shape) == 4:
            img = img.permute(0, 2, 3, 1)  # NCHW -> NHWC
        else:
            img = img.permute(0, 1, 3, 4, 2)  # NDCHW -> NDHWC
    if had_no_batch:
        img = img.squeeze(dim=0)  # drop the batch dimension we added
    return img
|
||
|
||
def center_crop(img, size, channels_last: bool = False):
    """Performs a center crop on an image.

    Args:
        img: the array to crop (batched or unbatched); indexed with
            ``...`` so any leading dims are preserved.
        size: A sequence (w, h) or an int of the crop size. If int,
            assumes a square crop.
        channels_last: If the channels are the last dimension.
    Returns:
        the center-cropped array
    """
    if channels_last:
        # ..HWC
        h, w = img.shape[-3:-1]
    else:
        # ..HW
        h, w = img.shape[-2:]

    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    # BUGFIX: message previously said "(h,w)" but the code unpacks
    # ``cropx, cropy = size`` i.e. (w, h) — message now matches usage.
    assert len(size) == 2, "size should be (w, h) you wish to resize to"
    cropx, cropy = size

    startx = w // 2 - (cropx // 2)
    starty = h // 2 - (cropy // 2)
    if channels_last:
        return img[..., starty : starty + cropy, startx : startx + cropx, :]
    else:
        return img[..., starty : starty + cropy, startx : startx + cropx]
|
||
|
||
def overwrite_gym_box_shape(box: Box, shape) -> Box:
    """Return ``box`` rebuilt with ``shape``, or ``box`` itself if the
    shape already matches.

    Any trailing dims of the original shape not covered by ``shape``
    (e.g. a channel dim) are kept. Scalar low/high bounds are passed
    through; array bounds are collapsed to their min/max.
    """
    if box.shape == shape:
        return box
    new_shape = list(shape) + list(box.shape[len(shape) :])
    low_bound = box.low if np.isscalar(box.low) else np.min(box.low)
    high_bound = box.high if np.isscalar(box.high) else np.max(box.high)
    return Box(
        low=low_bound, high=high_bound, shape=new_shape, dtype=box.dtype
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,7 +144,10 @@ def get_config( | |
|
||
for config_path in config_paths: | ||
config.merge_from_file(config_path) | ||
|
||
if opts: | ||
for k, v in zip(opts[0::2], opts[1::2]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the logic behind this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's impossible to overwrite the BASE_TASK_CONFIG from the command line without since the BASE_TASK_CONFIG is used before its args are overwritten by the command line. Likewise, moving the code for that to this point would make it impossible to overwrite TASK_CONFIG variables from the command line. As such, BASE_TASK_CONFIG must be extracted and overwritten and then the remaining config parameters can be overwritten. |
||
if k == "BASE_TASK_CONFIG_PATH": | ||
config.BASE_TASK_CONFIG_PATH = v | ||
config.TASK_CONFIG = get_task_config(config.BASE_TASK_CONFIG_PATH) | ||
if opts: | ||
config.CMD_TRAILING_OPTS = opts | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import pytest | ||
|
||
import habitat | ||
from habitat.config import Config as CN | ||
|
||
try: | ||
from habitat_baselines.agents import ppo_agents | ||
|
@@ -25,28 +26,39 @@ | |
not baseline_installed, reason="baseline sub-module not installed" | ||
) | ||
def test_ppo_agents():
    """Smoke-test PPOAgent across input types and sensor resolutions
    against the benchmark environment (reconstructed from the PR's added
    diff lines)."""
    agent_config = ppo_agents.get_default_config()
    agent_config.MODEL_PATH = ""
    agent_config.defrost()
    config_env = habitat.get_config(config_paths=CFG_TEST)
    if not os.path.exists(config_env.SIMULATOR.SCENE):
        pytest.skip("Please download Habitat test data to data folder.")

    benchmark = habitat.Benchmark(config_paths=CFG_TEST)

    for input_type in ["blind", "rgb", "depth", "rgbd"]:
        for resolution in [256, 384]:
            config_env.defrost()
            config_env.SIMULATOR.AGENT_0.SENSORS = []
            if input_type in ["rgb", "rgbd"]:
                config_env.SIMULATOR.AGENT_0.SENSORS += ["RGB_SENSOR"]
                agent_config.RESOLUTION = resolution
                config_env.SIMULATOR.RGB_SENSOR.WIDTH = resolution
                config_env.SIMULATOR.RGB_SENSOR.HEIGHT = resolution
            if input_type in ["depth", "rgbd"]:
                config_env.SIMULATOR.AGENT_0.SENSORS += ["DEPTH_SENSOR"]
                agent_config.RESOLUTION = resolution
                config_env.SIMULATOR.DEPTH_SENSOR.WIDTH = resolution
                config_env.SIMULATOR.DEPTH_SENSOR.HEIGHT = resolution

            config_env.freeze()

            # Rebuild the env for the new sensor configuration.
            del benchmark._env
            benchmark._env = habitat.Env(config=config_env)
            agent_config.INPUT_TYPE = input_type

            agent = ppo_agents.PPOAgent(agent_config)
            habitat.logger.info(benchmark.evaluate(agent, num_episodes=10))
|
||
|
||
@pytest.mark.skipif( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How can the
image_resize_shortest_edge
functionality be disabled in the current setup?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There isn't a way, that's why it's ResizeCenterCropper not just CenterCropper. Currently there is no way to disable it.