diff --git a/.circleci/config.yml b/.circleci/config.yml index a27838d52a..3fee9ffeab 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -235,6 +235,7 @@ jobs: while [ ! -f ~/miniconda/pytorch_installed ]; do sleep 2; done # wait for Pytorch pip install -e habitat-lab pip install -e habitat-baselines + pip install -e habitat-hitl - save_cache: key: conda-{{ checksum "habitat-lab/.circleci/config.yml" }}-{{ checksum "./date" }} background: true @@ -270,8 +271,18 @@ jobs: . activate habitat; cd habitat-lab export PYTHONPATH=.:$PYTHONPATH export MULTI_PROC_OFFSET=0 && export MAGNUM_LOG=quiet && export HABITAT_SIM_LOG=quiet - python -m pytest --cov-report=xml --cov-report term --cov=./ + python -m pytest test/ --cov-report=xml --cov-report term --cov=./ - codecov/upload + - run: + name: Run HITL tests + no_output_timeout: 60m + command: | + export PATH=$HOME/miniconda/bin:/usr/local/cuda/bin:$PATH + . activate habitat; cd habitat-lab + export PYTHONPATH=.:$PYTHONPATH + export MULTI_PROC_OFFSET=0 && export MAGNUM_LOG=quiet && export HABITAT_SIM_LOG=quiet + python -m habitat_sim.utils.datasets_download --uids hab3-episodes hab3_bench_assets habitat_humanoids hab_spot_arm ycb --data-path data/ --no-replace --no-prune + python -m pytest habitat-hitl/test - run: name: Run baseline training tests no_output_timeout: 30m @@ -280,9 +291,9 @@ jobs: . activate habitat; cd habitat-lab export PYTHONPATH=.:$PYTHONPATH export MULTI_PROC_OFFSET=0 && export MAGNUM_LOG=quiet && export HABITAT_SIM_LOG=quiet - # This is a flag that enables test_test_baseline_training to work + # This is a flag that enables test_baseline_training to work export TEST_BASELINE_SMALL=1 - python -m pytest test/test_baseline_training.py -s + python -m pytest test/test_baseline_training.py -s - run: name: Run Hab2.0 benchmark no_output_timeout: 30m @@ -324,6 +335,8 @@ jobs: python -c 'import habitat; print("habitat version:", habitat.__version__)' pip install habitat-baselines/ python -c 'import habitat_baselines; print("habitat_baselines version:", habitat_baselines.__version__)' + pip install habitat-hitl/ + python -c 'import habitat_hitl; print("habitat_hitl version:", habitat_hitl.__version__)' - run: &build_sdist_and_bdist name: Build sdist and bdist command: | diff --git a/docs/conf.py b/docs/conf.py index c8f876a9be..bb4c4b2728 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,6 +103,7 @@ "pages/habitat-lab-tdmap-viz.rst", "pages/habitat2.rst", "pages/view-transform-warp.rst", + "pages/metadata-taxonomy.rst", ] PLUGINS = [ @@ -137,6 +138,7 @@ ("Habitat Lab TopdownMap Visualization", "habitat-lab-tdmap-viz"), ("Habitat 2.0 Overview", "habitat2"), ("View, Transform and Warp", "view-transform-warp"), + ("'user_defined' Metadata Taxonomy", "metadata-taxonomy"), ], ), ("Classes", "classes", []), diff --git a/docs/pages/metadata-taxonomy.rst b/docs/pages/metadata-taxonomy.rst new file mode 100644 index 0000000000..2a80380262 --- /dev/null +++ b/docs/pages/metadata-taxonomy.rst @@ -0,0 +1,113 @@ +Taxonomy of "user_defined" configurations in habitat-lab +######################################################## + +This resource page outlines the expected taxonomy of expected metadata fields and systems in habitat-lab leveraging the non-official "user_defined" Configuration fields for objects, stages, and scenes. + +As outlined on the `Using JSON Files to configure Attributes `_ doc page, "user_defined" attributes provide a generic, reserved JSON configuration node which can be filled with user data. The intent was that no "officially supported" metadata would use this field, leaving it open for arbitrary user metadata. However, several prototype and bleeding-edge features are actively leveraging this system. The purpose of this doc page is to enumerate those known uses and their taxonomy to guide further development and avoid potential conflict with ongoing/future development. + + +`Receptacles`_ +============== + +Who: Stages, RigidObjects, and ArticulatedObjects. +Where: stage_config.json, object_config.json, ao_config.json, scene_instance.json (overrides) + +What: sub_config with key string containing "receptacle\_". "receptacle_mesh\_" defines a TriangleMeshReceptacle while "receptacle_aabb\_" defines a bounding box (AABB) Receptacle. See the `parse_receptacles_from_user_config `_ function. + +Example: + +.. code:: python + + "user_defined": { + "receptacle_mesh_table0001_receptacle_mesh": { + "name": "table0001_receptacle_mesh", + "parent_object": "0a5df6da61cd2e78e972690b501452152309e56b", #handle of the parent ManagedObject's template + "parent_link": "table0001", #if attached to an ArticulatedLink, this is the local index + "position": [0,0,0], # position of the receptacle in parent's local space + "rotation": [1,0,0,0],#orientation (quaternion) of the receptacle in parent's local space + "scale": [1,1,1], #scale of the receptacles in parent's local space + "up": [0,0,1], #up vector for the receptacle in parent's local space (for tilt culling and placement snapping) + "mesh_filepath": "table0001_receptacle_mesh.glb" #filepath for the receptacle's mesh asset (.glb with triangulated faces expected) + } + } + +`Scene Receptacle Filter Files`_ +================================ + +Who: Scene Instances +Where: scene_instance.json +What: filepath (relative to dataset root directory) to the file containing Receptacle filter strings for the scene. + +Example: + +.. code:: python + + "user_defined": { + "scene_filter_file": "scene_filter_files/102344022.rec_filter.json" + } + + +`Object States`_ +================ + +Who: RigidObjects and ArticulatedObjects +Where: object_config.json, ao_config.json, scene_instance.json (overrides) + +What: sub_config containing any fields which pertain to the ObjectStateMachine and ObjectStateSpec logic. Exact taxonomy in flux. Consider this key reserved. + +.. code:: python + + "user_defined": { + "object_states": { + + } + } + +`Marker Sets`_ +============== + +Who: RigidObjects and ArticulatedObjects +Where: object_config.json, ao_config.json, scene_instance.json (overrides) + +What: sub_config containing any 3D point sets which must be defined for various purposes. + +.. code:: python + + "user_defined": { + "marker_sets": { + + "handle_marker_sets":{ #these are handles for opening an ArticulatedObject's links. + 0: { # these marker sets are attached to link_id "0". + "handle_0": { #this is a set of 3D points. + 0: [x,y,z] #we index because JSON needs a dict and Configuration cannot digest lists + 1: [x,y,z] + 2: [x,y,z] + }, + ... + }, + ... + }, + + "faucet_marker_set":{ #these are faucet points on sinks in object local space + 0: { # these marker sets are attached to link_id "0". "-1" implies base link or rigid object. + 0: [x,y,z] #this is a faucet + ... + }, + ... + } + } + } + +`ArticulatedObject "default link"`_ +====================================== + +Who: ArticulatedObjects +Where: ao_config.json + +What: The "default" link (integer index) is the one link which should be used if only one joint can be actuated. For example, the largest or most accessible drawer or door. Cannot be base link (-1). + +.. code:: python + + "user_defined": { + "default_link": 5 #the link id which is "default" + } diff --git a/examples/hitl/basic_viewer/basic_viewer.py b/examples/hitl/basic_viewer/basic_viewer.py index 7a0f6c9075..f4dcf5906a 100644 --- a/examples/hitl/basic_viewer/basic_viewer.py +++ b/examples/hitl/basic_viewer/basic_viewer.py @@ -73,7 +73,6 @@ def _update_lookat_pos(self): self._get_camera_lookat_pos(), radius, mn.Color3(255 / 255, 0 / 255, 0 / 255), - 24, ) @property diff --git a/examples/hitl/pick_throw_vr/pick_throw_vr.py b/examples/hitl/pick_throw_vr/pick_throw_vr.py index 6184e41efa..f3e105d91f 100644 --- a/examples/hitl/pick_throw_vr/pick_throw_vr.py +++ b/examples/hitl/pick_throw_vr/pick_throw_vr.py @@ -79,7 +79,9 @@ def __init__(self, app_service: AppService): assert not self._app_service.hitl_config.camera.first_person_mode self._nav_helper = GuiNavigationHelper( - self._app_service, self.get_gui_controlled_agent_index() + self._app_service, + self.get_gui_controlled_agent_index(), + user_index=0, ) self._throw_helper = GuiThrowHelper( self._app_service, self.get_gui_controlled_agent_index() @@ -483,12 +485,10 @@ def _get_target_object_positions(self): ) def _draw_circle(self, pos, color, radius): - num_segments = 24 self._app_service.gui_drawer.draw_circle( pos, radius, color, - num_segments, ) def _add_target_object_highlight_ring( diff --git a/examples/hitl/rearrange/rearrange.py b/examples/hitl/rearrange/rearrange.py index c5aee71131..cab259d7c2 100644 --- a/examples/hitl/rearrange/rearrange.py +++ b/examples/hitl/rearrange/rearrange.py @@ -84,7 +84,9 @@ def __init__( ) self._nav_helper = GuiNavigationHelper( - self._app_service, self.get_gui_controlled_agent_index() + self._app_service, + self.get_gui_controlled_agent_index(), + user_index=0, ) self._episode_helper = self._app_service.episode_helper @@ -121,7 +123,7 @@ def _update_grasping_and_set_act_hints(self): color = mn.Color3(0, 255 / 255, 0) # green goal_position = self._goal_positions[self._held_target_obj_idx] self._app_service.gui_drawer.draw_circle( - goal_position, end_radius, color, 24 + goal_position, end_radius, color ) self._nav_helper.draw_nav_hint_from_agent( @@ -138,7 +140,6 @@ def _update_grasping_and_set_act_hints(self): can_place_position, self._can_grasp_place_threshold, mn.Color3(255 / 255, 255 / 255, 0), - 24, ) if self._app_service.gui_input.get_key_down(GuiInput.KeyNS.SPACE): @@ -272,7 +273,6 @@ def _update_task(self): can_grasp_position, self._can_grasp_place_threshold, mn.Color3(255 / 255, 255 / 255, 0), - 24, ) def get_gui_controlled_agent_index(self): diff --git a/examples/hitl/rearrange_v2/README.md b/examples/hitl/rearrange_v2/README.md index f57a409da8..a9dc38317b 100644 --- a/examples/hitl/rearrange_v2/README.md +++ b/examples/hitl/rearrange_v2/README.md @@ -12,6 +12,15 @@ git clone --branch articulated-scenes --single-branch --depth 1 https://huggingf mv fphab fpss ``` +To test the Habitat-LLM episodes in `rearrange_v2` you'll need to download and unzip the following [episode dataset](https://drive.google.com/file/d/1zFCBiWE_XFY0Ry9CZOV_NF_rfxBw1y-F/view?usp=sharing) in Habitat-Lab root directory. In addition, you'll need YCB, GSO, AI2THOR, and ABO object assets. To download these assets use the following commands: + +``` +cd data +git clone https://huggingface.co/datasets/ai-habitat/OVMM_objects objects --recursive +cd objects +git checkout 3893a735352b92d46505f35d759553f5fc82a39b +``` + ## Data directory Run `rearrange_v2` from the Habitat-lab root directory. It will expect `data/` for Habitat-lab data, and it will also look for `examples/hitl/rearrange_v2/app_data/demo.json.gz` (included alongside source files in our git repo). @@ -33,6 +42,11 @@ Headless server: python examples/hitl/rearrange_v2/rearrange_v2.py +experiment=headless_server ``` +To test Habitat-LLM episodes using a user-controlled humanoid use: +```bash +python examples/hitl/rearrange_v2/rearrange_v2.py --config-name lang_rearrange_humanoid_only +``` + ## Controls See on-screen help text. In addition, press `1` or `2` to select an episode. diff --git a/examples/hitl/rearrange_v2/config/lang_rearrange_humanoid_only.yaml b/examples/hitl/rearrange_v2/config/lang_rearrange_humanoid_only.yaml new file mode 100644 index 0000000000..ce17cef51f --- /dev/null +++ b/examples/hitl/rearrange_v2/config/lang_rearrange_humanoid_only.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +defaults: + - language_rearrange + - hitl_defaults + - _self_ + +habitat: + # various config args to ensure the episode never ends + environment: + max_episode_steps: 0 + iterator_options: + # For the demo, we want to showcase the episodes in the specified order + shuffle: False + +habitat_baselines: + # todo: document these choices + eval: + should_load_ckpt: False + rl: + agent: + num_pool_agents_per_type: [1, 1] + policy: + + +habitat_hitl: + window: + title: "Rearrange" + width: 1300 + height: 1000 + gui_controlled_agents: + - agent_index: 0 + lin_speed: 10.0 + ang_speed: 300 + hide_humanoid_in_gui: True + camera: + first_person_mode: True + data_collection: + save_filepath_base: my_session + save_episode_record: True diff --git a/examples/hitl/rearrange_v2/config/language_rearrange.yaml b/examples/hitl/rearrange_v2/config/language_rearrange.yaml new file mode 100644 index 0000000000..349f016cb3 --- /dev/null +++ b/examples/hitl/rearrange_v2/config/language_rearrange.yaml @@ -0,0 +1,87 @@ +# This config is derived from habitat-lab/habitat/config/benchmark/multi_agent/hssd_spot_human.yaml +# @package _global_ + +defaults: + - /habitat: habitat_config_base + - /habitat/task: task_config_base + + - /habitat/simulator/sensor_setups@habitat.simulator.agents.main_agent: rgbd_head_agent + - /habitat/simulator/agents@habitat.simulator.agents.main_agent: human + + - /habitat/dataset/rearrangement: hssd + + - /habitat/task/actions@habitat.task.actions.base_velocity: base_velocity + - /habitat/task/actions@habitat.task.actions.rearrange_stop: rearrange_stop + + - /habitat/task/measurements: + - num_steps + - /habitat/task/lab_sensors: + - relative_resting_pos_sensor + - target_start_sensor + - goal_sensor + - joint_sensor + - is_holding_sensor + - end_effector_sensor + - target_start_gps_compass_sensor + - target_goal_gps_compass_sensor + - localization_sensor + + - _self_ + +habitat: + task: + type: RearrangeEmptyTask-v0 + reward_measure: num_steps + success_measure: num_steps + success_reward: 10.0 + min_distance_start_agents: 5.0 + slack_reward: -0.0005 + end_on_success: True + constraint_violation_ends_episode: False + constraint_violation_drops_object: True + task_spec_base_path: benchmark/multi_agent/ + task_spec: pddl/multi_agent_tidy_house + pddl_domain_def: fp + actions: + base_velocity: + lin_speed: 40.0 + ang_speed: 20.0 + + robot_at_thresh: 3.0 + gym: + obs_keys: + - head_depth + - relative_resting_position + - obj_start_sensor + - obj_goal_sensor + - obj_start_gps_compass + - obj_goal_gps_compass + - is_holding + - ee_pos + - localization_sensor + - has_finished_oracle_nav + environment: + max_episode_steps: 750 + simulator: + type: RearrangeSim-v0 + seed: 100 + additional_object_paths: + - "data/objects/ycb/configs/" + - "data/objects_ovmm/train_val/ai2thorhab/configs/objects/" + - "data/objects_ovmm/train_val/amazon_berkeley/configs/" + - "data/objects_ovmm/train_val/google_scanned/configs/" + - "data/objects_ovmm/train_val/hssd/configs/objects/" + concur_render: True + auto_sleep: True + agents_order: + - main_agent + + kinematic_mode: True + ac_freq_ratio: 1 + step_physics: False + + habitat_sim_v0: + allow_sliding: True + enable_physics: True + dataset: + data_path: data/datasets/hssd/llm_rearrange/v2/60scenes_dataset_776eps_with_eval.json.gz diff --git a/examples/hitl/rearrange_v2/config/rearrange_v2.yaml b/examples/hitl/rearrange_v2/config/rearrange_v2.yaml index 40ed567553..a5de6659a6 100644 --- a/examples/hitl/rearrange_v2/config/rearrange_v2.yaml +++ b/examples/hitl/rearrange_v2/config/rearrange_v2.yaml @@ -23,7 +23,6 @@ habitat: data_path: examples/hitl/rearrange_v2/app_data/demo.json.gz - habitat_baselines: # todo: document these choices eval: diff --git a/examples/hitl/rearrange_v2/rearrange_v2.py b/examples/hitl/rearrange_v2/rearrange_v2.py index 3b619a4977..7bb1bad8f5 100644 --- a/examples/hitl/rearrange_v2/rearrange_v2.py +++ b/examples/hitl/rearrange_v2/rearrange_v2.py @@ -9,6 +9,7 @@ import hydra import magnum as mn +import numpy as np import habitat_sim from habitat.sims.habitat_simulator import sim_utilities @@ -29,12 +30,70 @@ from habitat_hitl.environment.gui_pick_helper import GuiPickHelper from habitat_hitl.environment.gui_placement_helper import GuiPlacementHelper from habitat_hitl.environment.hablab_utils import get_agent_art_obj_transform +from habitat_sim.utils.common import quat_from_magnum, quat_to_coeffs ENABLE_ARTICULATED_OPEN_CLOSE = False # Visually snap picked objects into the humanoid's hand. May be useful in third-person mode. Beware that this conflicts with GuiPlacementHelper. DO_HUMANOID_GRASP_OBJECTS = False +class DataLogger: + def __init__(self, app_service): + self._app_service = app_service + self._sim = app_service.sim + + def get_num_agents(self): + return len(self._sim.agents_mgr._all_agent_data) + + def get_agents_state(self): + agent_states = [] + for agent_idx in range(self.get_num_agents()): + agent_root = get_agent_art_obj_transform(self._sim, agent_idx) + position = np.array(agent_root.translation).tolist() + rotation = mn.Quaternion.from_matrix(agent_root.rotation()) + rotation = quat_to_coeffs(quat_from_magnum(rotation)).tolist() + + snap_idx = self._sim.agents_mgr._all_agent_data[ + agent_idx + ].grasp_mgr.snap_idx + agent_states.append( + { + "position": position, + "rotation": rotation, + "grasp_mgr_snap_idx": snap_idx, + } + ) + return agent_states + + def get_objects_state(self): + object_states = [] + rom = self._sim.get_rigid_object_manager() + for object_handle, rel_idx in self._sim._handle_to_object_id.items(): + obj_id = self._sim._scene_obj_ids[rel_idx] + ro = rom.get_object_by_id(obj_id) + position = np.array(ro.translation).tolist() + rotation = quat_to_coeffs(quat_from_magnum(ro.rotation)).tolist() + object_states.append( + { + "position": position, + "rotation": rotation, + "object_handle": object_handle, + "object_id": obj_id, + } + ) + return object_states + + def record_state(self, task_completed: bool = False): + agent_states = self.get_agents_state() + object_states = self.get_objects_state() + + self._app_service.step_recorder.record("agent_states", agent_states) + self._app_service.step_recorder.record("object_states", object_states) + self._app_service.step_recorder.record( + "task_completed", task_completed + ) + + class AppStateRearrangeV2(AppState): """ Todo @@ -58,16 +117,17 @@ def __init__(self, app_service): self._recent_reach_pos = None self._paused = False self._hide_gui_text = False + self._can_place_object = False self._camera_helper = CameraHelper( self._app_service.hitl_config, self._app_service.gui_input, ) - self._pick_helper = GuiPickHelper( - self._app_service, + self._pick_helper = GuiPickHelper(self._app_service, user_index=0) + self._placement_helper = GuiPlacementHelper( + self._app_service, user_index=0 ) - self._placement_helper = GuiPlacementHelper(self._app_service) self._client_helper = None if self._app_service.hitl_config.networking.enable: self._client_helper = ClientHelper(self._app_service) @@ -76,6 +136,9 @@ def __init__(self, app_service): self._frame_counter = 0 self._sps_tracker = AverageRateTracker(2.0) + self._task_instruction = "" + self._data_logger = DataLogger(app_service=self._app_service) + # needed to avoid spurious mypy attr-defined errors @staticmethod def get_sim_utilities() -> Any: @@ -171,6 +234,13 @@ def on_environment_reset(self, episode_recorder_dict): self._camera_helper.update(self._get_camera_lookat_pos(), dt=0) + # Set the task instruction + current_episode = self._app_service.env.current_episode + if current_episode.info.get("extra_info") is not None: + self._task_instruction = current_episode.info["extra_info"][ + "instruction" + ] + client_message_manager = self._app_service.client_message_manager if client_message_manager: client_message_manager.signal_scene_change() @@ -203,14 +273,17 @@ def _update_grasping_and_set_act_hints(self, user_index): # todo: implement grasping properly for each user. _held_obj_id, _has_grasp_preview, etc. must be tracked per user. if self._held_obj_id is not None: - if self._get_user_key_down(user_index, GuiInput.KeyNS.SPACE): + if ( + self._get_user_key_down(user_index, GuiInput.KeyNS.SPACE) + and self._can_place_object + ): if DO_HUMANOID_GRASP_OBJECTS: # todo: better drop pos drop_pos = self._get_gui_agent_translation( user_index ) # self._gui_agent_controllers.get_base_translation() else: - # GuiPlacementHelper has already placed this object, so nothing to do here + # GuiPlacementHelper has already placed this object. pass self._held_obj_id = None else: @@ -281,6 +354,7 @@ def get_grasp_release_controls_text(): controls_str += "I, K: look up, down\n" controls_str += "A, D: turn\n" controls_str += "W/F, S/V: walk\n" + controls_str += "N: next episode\n" if ENABLE_ARTICULATED_OPEN_CLOSE: controls_str += "Z/X: open/close receptacle\n" controls_str += get_grasp_release_controls_text() @@ -292,6 +366,8 @@ def get_grasp_release_controls_text(): def _get_status_text(self): status_str = "" + if len(self._task_instruction) > 0: + status_str += "\nInstruction: " + self._task_instruction + "\n" if self._paused: status_str += "\n\npaused\n" if ( @@ -316,6 +392,8 @@ def _update_help_text(self): self._app_service.text_drawer.add_text( status_str, TextOnScreenAlignment.TOP_CENTER, + text_delta_x=-280, + text_delta_y=-50, ) def _get_camera_lookat_pos(self): @@ -331,40 +409,13 @@ def is_user_idle_this_frame(self): return not self._app_service.gui_input.get_any_key_down() def _check_change_episode(self): - if self._paused: + if self._paused or not self._app_service.gui_input.get_key_down( + GuiInput.KeyNS.N + ): return - # episode_id should be a string, e.g. "5" - episode_ids_by_dataset = { - "data/datasets/hssd/rearrange/{split}/social_rearrange.json.gz": [ - "23775", - "23776", - ] - } - fallback_episode_ids = ["0", "1"] - dataset_key = self._app_service.config.habitat.dataset.data_path - episode_ids = ( - episode_ids_by_dataset[dataset_key] - if dataset_key in episode_ids_by_dataset - else fallback_episode_ids - ) - - # use number keys to select episode - episode_index_by_key = { - GuiInput.KeyNS.ONE: 0, - GuiInput.KeyNS.TWO: 1, - } - assert len(episode_index_by_key) == len(episode_ids) - - for key in episode_index_by_key: - if self._app_service.gui_input.get_key_down(key): - episode_id = episode_ids[episode_index_by_key[key]] - # episode_id should be a string, e.g. "5" - assert isinstance(episode_id, str) - self._app_service.episode_helper.set_next_episode_by_id( - episode_id - ) - self._app_service.end_episode(do_reset=True) + if self._app_service.episode_helper.next_episode_exists(): + self._app_service.end_episode(do_reset=True) def _update_held_object_placement(self): if not self._held_obj_id: @@ -380,6 +431,9 @@ def _update_held_object_placement(self): if self._placement_helper.update(ray, self._held_obj_id): # sloppy: save another keyframe here since we just moved the held object self.get_sim().gfx_replay_manager.save_keyframe() + self._can_place_object = True + else: + self._can_place_object = False def sim_update(self, dt, post_sim_update_dict): if ( @@ -446,6 +500,12 @@ def sim_update(self, dt, post_sim_update_dict): self._update_help_text() + def record_state(self): + task_completed = self._app_service.gui_input.get_key_down( + GuiInput.KeyNS.N + ) + self._data_logger.record_state(task_completed=task_completed) + @hydra.main( version_base=None, config_path="config", config_name="rearrange_v2" diff --git a/habitat-hitl/habitat_hitl/__init__.py b/habitat-hitl/habitat_hitl/__init__.py index 0f0db8cad5..4aba23fc53 100644 --- a/habitat-hitl/habitat_hitl/__init__.py +++ b/habitat-hitl/habitat_hitl/__init__.py @@ -3,3 +3,5 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from habitat_hitl.version import VERSION as __version__ # noqa: F401 diff --git a/habitat-hitl/habitat_hitl/_internal/gui_application.py b/habitat-hitl/habitat_hitl/_internal/gui_application.py index 8b85744251..086dfeb442 100644 --- a/habitat-hitl/habitat_hitl/_internal/gui_application.py +++ b/habitat-hitl/habitat_hitl/_internal/gui_application.py @@ -7,6 +7,7 @@ import abc import math import time +from typing import List import magnum as mn from magnum.platform.glfw import Application @@ -32,14 +33,15 @@ def unproject(self, viewport_pos): class InputHandlerApplication(Application): def __init__(self, config): super().__init__(config) - self._gui_inputs = [] + self._gui_inputs: List[GuiInput] = [] + self._mouse_ray = None - def add_gui_input(self, gui_input): + def add_gui_input(self, gui_input: GuiInput) -> None: self._gui_inputs.append(gui_input) def key_press_event(self, event: Application.KeyEvent) -> None: - key = MagnumKeyConverter.convert(event.key) - if key: + key = MagnumKeyConverter.convert_key(event.key) + if key is not None: for wrapper in self._gui_inputs: # If the key is already held, this is a repeat press event and we should # ignore it. @@ -48,30 +50,30 @@ def key_press_event(self, event: Application.KeyEvent) -> None: wrapper._key_down.add(key) def key_release_event(self, event: Application.KeyEvent) -> None: - key = MagnumKeyConverter.convert(event.key) - if key: + key = MagnumKeyConverter.convert_key(event.key) + if key is not None: for wrapper in self._gui_inputs: if key in wrapper._key_held: wrapper._key_held.remove(key) wrapper._key_up.add(key) def mouse_press_event(self, event: Application.MouseEvent) -> None: - mouse_button = event.button - GuiInput.validate_mouse_button(mouse_button) - for wrapper in self._gui_inputs: - wrapper._mouse_button_held.add(mouse_button) - wrapper._mouse_button_down.add(mouse_button) + key = MagnumKeyConverter.convert_mouse_button(event.button) + if key is not None: + for wrapper in self._gui_inputs: + # If the key is already held, this is a repeat press event and we should + # ignore it. + if key not in wrapper._mouse_button_held: + wrapper._mouse_button_held.add(key) + wrapper._mouse_button_down.add(key) def mouse_release_event(self, event: Application.MouseEvent) -> None: - mouse_button = event.button - GuiInput.validate_mouse_button(mouse_button) - for wrapper in self._gui_inputs: - # In theory, mouse_button should always be present in _mouse_button_held. - # In practice, we seem to get spurious release events due to the app - # losing focus (e.g. switching to VS code debugger while mouse-clicking) - if mouse_button in wrapper._mouse_button_held: - wrapper._mouse_button_held.remove(mouse_button) - wrapper._mouse_button_up.add(mouse_button) + key = MagnumKeyConverter.convert_mouse_button(event.button) + if key is not None: + for wrapper in self._gui_inputs: + if key in wrapper._mouse_button_held: + wrapper._mouse_button_held.remove(key) + wrapper._mouse_button_up.add(key) def mouse_scroll_event(self, event: Application.MouseEvent) -> None: # shift+scroll is forced into x direction on mac, seemingly at OS level, @@ -109,10 +111,13 @@ def mouse_move_event(self, event: Application.MouseMoveEvent) -> None: wrapper._mouse_position = mouse_pos wrapper._relative_mouse_position[0] += relative_mouse_position[0] wrapper._relative_mouse_position[1] += relative_mouse_position[1] + if self._mouse_ray: + wrapper._mouse_ray = self._mouse_ray def update_mouse_ray(self, unproject_fn): - for wrapper in self._gui_inputs: - wrapper._mouse_ray = unproject_fn(wrapper._mouse_position) + if len(self._gui_inputs) > 0: + gui_input = self._gui_inputs[0] + self._mouse_ray = unproject_fn(gui_input._mouse_position) class GuiApplication(InputHandlerApplication): diff --git a/habitat-hitl/habitat_hitl/_internal/hitl_driver.py b/habitat-hitl/habitat_hitl/_internal/hitl_driver.py index 11f8591115..10e2a0e1ec 100644 --- a/habitat-hitl/habitat_hitl/_internal/hitl_driver.py +++ b/habitat-hitl/habitat_hitl/_internal/hitl_driver.py @@ -171,16 +171,17 @@ def __init__( self._episode_helper = EpisodeHelper(self.habitat_env) + # TODO: Only one user is currently supported. + users = Users(1) + self._client_message_manager = None if self.network_server_enabled: - # TODO: Only one user is currently supported. - users = Users(1) self._client_message_manager = ClientMessageManager(users) gui_drawer = GuiDrawer(debug_line_drawer, self._client_message_manager) gui_drawer.set_line_width(self._hitl_config.debug_line_width) - self._check_init_server(gui_drawer, gui_input) + self._check_init_server(gui_drawer, gui_input, users) def local_end_episode(do_reset=False): self._end_episode(do_reset) @@ -231,22 +232,22 @@ def close(self): def network_server_enabled(self) -> bool: return self._hitl_config.networking.enable - def _check_init_server(self, gui_drawer: GuiDrawer, gui_input: GuiInput): + def _check_init_server( + self, gui_drawer: GuiDrawer, server_gui_input: GuiInput, users: Users + ): self._remote_client_state = None self._interprocess_record = None if self.network_server_enabled: - # How many frames we can simulate "ahead" of what keyframes have been sent. - # A larger value increases lag on the client, while ensuring a more reliable - # simulation rate in the presence of unreliable network comms. - # See also server.py max_send_rate - max_steps_ahead = 5 self._interprocess_record = InterprocessRecord( - self._hitl_config.networking, max_steps_ahead + self._hitl_config.networking ) launch_networking_process(self._interprocess_record) self._remote_client_state = RemoteClientState( - self._interprocess_record, gui_drawer, gui_input + self._interprocess_record, gui_drawer, users ) + # Bind the server input to user 0 + if self._hitl_config.networking.client_sync.server_input: + self._remote_client_state.bind_gui_input(server_gui_input, 0) def _check_terminate_server(self): if self.network_server_enabled: diff --git a/habitat-hitl/habitat_hitl/_internal/networking/interprocess_record.py b/habitat-hitl/habitat_hitl/_internal/networking/interprocess_record.py index 878df6089f..17ecce4df7 100644 --- a/habitat-hitl/habitat_hitl/_internal/networking/interprocess_record.py +++ b/habitat-hitl/habitat_hitl/_internal/networking/interprocess_record.py @@ -5,9 +5,14 @@ # LICENSE file in the root directory of this source tree. from multiprocessing import Queue -from multiprocessing import Semaphore as create_semaphore -from multiprocessing.synchronize import Semaphore -from typing import Any, List, Optional +from typing import List, Optional + +from habitat_hitl.core.types import ( + ClientState, + ConnectionRecord, + DataDict, + Keyframe, +) class InterprocessRecord: @@ -15,51 +20,41 @@ class InterprocessRecord: Utility that stores incoming (client state) and outgoing (keyframe) data such as it can be used by concurrent threads. """ - def __init__(self, networking_config, max_steps_ahead: int) -> None: + def __init__(self, networking_config) -> None: self._networking_config = networking_config - self._keyframe_queue: Queue = Queue() - self._client_state_queue: Queue = Queue() - self._step_semaphore: Semaphore = create_semaphore(max_steps_ahead) - self._connection_record_queue: Queue = Queue() + self._keyframe_queue: Queue[Keyframe] = Queue() + self._client_state_queue: Queue[ClientState] = Queue() + self._connection_record_queue: Queue[ConnectionRecord] = Queue() - def send_keyframe_to_networking_thread(self, keyframe) -> None: + def send_keyframe_to_networking_thread(self, keyframe: Keyframe) -> None: """Send a keyframe (outgoing data) to the networking thread.""" # Acquire the semaphore to ensure the simulation doesn't advance too far ahead - self._step_semaphore.acquire() self._keyframe_queue.put(keyframe) - def send_client_state_to_main_thread(self, client_state) -> None: + def send_client_state_to_main_thread( + self, client_state: ClientState + ) -> None: """Send a client state (incoming data) to the main thread.""" self._client_state_queue.put(client_state) - def send_connection_record_to_main_thread(self, connection_record) -> None: + def send_connection_record_to_main_thread( + self, connection_record: ConnectionRecord + ) -> None: """Send a connection record to the main thread.""" assert "connectionId" in connection_record assert "isClientReady" in connection_record self._connection_record_queue.put(connection_record) - def get_queued_keyframes(self) -> List[Any]: - """Dequeue all keyframes.""" - keyframes = [] - - while not self._keyframe_queue.empty(): - keyframe = self._keyframe_queue.get(block=False) - keyframes.append(keyframe) - self._step_semaphore.release() - - return keyframes - - def get_single_queued_keyframe(self) -> Optional[Any]: + def get_single_queued_keyframe(self) -> Optional[Keyframe]: """Dequeue one keyframe.""" if self._keyframe_queue.empty(): return None keyframe = self._keyframe_queue.get(block=False) - self._step_semaphore.release() return keyframe @staticmethod - def _dequeue_all(queue: Queue) -> List[Any]: + def _dequeue_all(queue: Queue) -> List[DataDict]: """Dequeue all items from a queue.""" items = [] @@ -69,10 +64,14 @@ def _dequeue_all(queue: Queue) -> List[Any]: return items - def get_queued_client_states(self) -> List[Any]: + def get_queued_keyframes(self) -> List[Keyframe]: + """Dequeue all keyframes.""" + return self._dequeue_all(self._keyframe_queue) + + def get_queued_client_states(self) -> List[ClientState]: """Dequeue all client states.""" return self._dequeue_all(self._client_state_queue) - def get_queued_connection_records(self) -> List[Any]: + def get_queued_connection_records(self) -> List[ConnectionRecord]: """Dequeue all connection records.""" return self._dequeue_all(self._connection_record_queue) diff --git a/habitat-hitl/habitat_hitl/_internal/networking/keyframe_utils.py b/habitat-hitl/habitat_hitl/_internal/networking/keyframe_utils.py index 6971796f99..2200183be9 100644 --- a/habitat-hitl/habitat_hitl/_internal/networking/keyframe_utils.py +++ b/habitat-hitl/habitat_hitl/_internal/networking/keyframe_utils.py @@ -4,10 +4,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any +from habitat_hitl.core.types import Keyframe -def update_consolidated_keyframe(consolidated_keyframe, inc_keyframe): +def update_consolidated_keyframe( + consolidated_keyframe: Keyframe, inc_keyframe: Keyframe +) -> None: """ A "consolidated" keyframe is several incremental keyframes merged together. See nearly duplicate logic in habitat-sim Recorder::addLoadsCreationsDeletions. @@ -37,26 +39,28 @@ def ensure_dict(keyframe, key): for state_update in inc_keyframe["stateUpdates"]: key = state_update["instanceKey"] state = state_update["state"] - found = False - for con_state_update in consolidated_keyframe["stateUpdates"]: - if con_state_update["instanceKey"] == key: - con_state_update["state"] = state - found = True - if not found: - consolidated_keyframe["stateUpdates"].append(state_update) + if "stateUpdates" in consolidated_keyframe: + found = False + for con_state_update in consolidated_keyframe["stateUpdates"]: + if con_state_update["instanceKey"] == key: + con_state_update["state"] = state + found = True + if not found: + consolidated_keyframe["stateUpdates"].append(state_update) # add or update rigUpdates if "rigUpdates" in inc_keyframe: for rig_update in inc_keyframe["rigUpdates"]: key = rig_update["id"] pose = rig_update["pose"] - found = False - for con_rig_update in consolidated_keyframe["rigUpdates"]: - if con_rig_update["id"] == key: - con_rig_update["pose"] = pose - found = True - if not found: - consolidated_keyframe["rigUpdates"].append(rig_update) + if "rigUpdates" in consolidated_keyframe: + found = False + for con_rig_update in consolidated_keyframe["rigUpdates"]: + if con_rig_update["id"] == key: + con_rig_update["pose"] = pose + found = True + if not found: + consolidated_keyframe["rigUpdates"].append(rig_update) # append creations if "creations" in inc_keyframe: @@ -76,17 +80,18 @@ def ensure_dict(keyframe, key): # the creation and otherwise skip this deletion. This logic ensures # consolidated keyframes don't get bloated as many items are added # and removed over time. - con_creations = consolidated_keyframe["creations"] - found = False - for entry in con_creations: - if entry["instanceKey"] == key: - con_creations.remove(entry) - found = True - break - if not found: - # if we didn't find the creation, then we should still include the deletion - ensure_list(consolidated_keyframe, "deletions") - consolidated_keyframe["deletions"].append(key) + if "creations" in consolidated_keyframe: + con_creations = consolidated_keyframe["creations"] + found = False + for entry in con_creations: + if entry["instanceKey"] == key: + con_creations.remove(entry) + found = True + break + if not found: + # if we didn't find the creation, then we should still include the deletion + ensure_list(consolidated_keyframe, "deletions") + consolidated_keyframe["deletions"].append(key) # remove stateUpdates for the deleted keys if "stateUpdates" in consolidated_keyframe: @@ -108,8 +113,8 @@ def ensure_dict(keyframe, key): # todo: lights, userTransforms -def get_empty_keyframe(): - keyframe: Any = dict() +def get_empty_keyframe() -> Keyframe: + keyframe: Keyframe = dict() keyframe["loads"] = [] keyframe["creations"] = [] keyframe["rigCreations"] = [] diff --git a/habitat-hitl/habitat_hitl/_internal/networking/networking_process.py b/habitat-hitl/habitat_hitl/_internal/networking/networking_process.py index 736bba6bf9..5e0ae9eb2f 100644 --- a/habitat-hitl/habitat_hitl/_internal/networking/networking_process.py +++ b/habitat-hitl/habitat_hitl/_internal/networking/networking_process.py @@ -9,9 +9,10 @@ import os import signal import ssl +import traceback from datetime import datetime, timedelta from multiprocessing import Process -from typing import Any, Dict, Optional +from typing import Dict, List, Optional import aiohttp.web import websockets @@ -28,6 +29,7 @@ get_empty_keyframe, update_consolidated_keyframe, ) +from habitat_hitl.core.types import ClientState, ConnectionRecord, Keyframe # Boolean variable to indicate whether to use SSL use_ssl = False @@ -107,7 +109,7 @@ def __init__(self, interprocess_record: InterprocessRecord): self._waiting_for_app_ready = False self._recent_connection_activity_timestamp: Optional[datetime] = None - def update_consolidated_keyframes(self, keyframes) -> None: + def update_consolidated_keyframes(self, keyframes: List[Keyframe]) -> None: for inc_keyframe in keyframes: update_consolidated_keyframe( self._consolidated_keyframe, inc_keyframe @@ -119,7 +121,7 @@ async def receive_client_states(self, websocket: ClientConnection) -> None: self._recent_connection_activity_timestamp = datetime.now() try: # Parse the received message as a JSON object - client_state = json.loads(message) + client_state: ClientState = json.loads(message) client_state["connectionId"] = connection_id @@ -232,8 +234,8 @@ def handle_disconnect(self) -> None: print(f"Closed connection to client {websocket.remote_address}") del self._connected_clients[websocket_id] - def parse_connection_record(self, message: str) -> Any: - connection_record = None + def parse_connection_record(self, message: str) -> ConnectionRecord: + connection_record: ConnectionRecord if message == "client ready!": # legacy message format for initial client message connection_record = {"isClientReady": True} @@ -365,7 +367,7 @@ async def networking_main_async( network_mgr = NetworkManager(interprocess_record) - # Start servers + # Start servers. websocket_server = await start_websocket_server( network_mgr, networking_config ) @@ -375,36 +377,65 @@ async def networking_main_async( else None ) - check_keyframe_queue_task = asyncio.ensure_future( - network_mgr.check_keyframe_queue() + # Define tasks (concurrent looping coroutines). + tasks: List[asyncio.Future] = [] + tasks.append(asyncio.create_task(network_mgr.check_keyframe_queue())) + tasks.append( + asyncio.create_task(network_mgr.check_close_broken_connection()) ) - check_close_broken_connection_task = asyncio.ensure_future( - network_mgr.check_close_broken_connection() - ) - - # Handle SIGTERM. We should get this signal when we do networking_process.terminate(). See terminate_networking_process. + # Handle termination signals. + # We should get SIGTERM when we do networking_process.terminate(). See terminate_networking_process. stop: asyncio.Future = asyncio.Future() loop = asyncio.get_event_loop() - loop.add_signal_handler(signal.SIGTERM, stop.set_result, None) - - # This await essentially means "wait forever" (or until we get SIGTERM). Meanwhile, the other tasks we've started above (websocket server, http server, check_keyframe_queue_task) will also run forever in the asyncio event loop. - await stop - - # Do cleanup code after we've received SIGTERM: close both servers and cancel check_keyframe_queue_task. + stop_signals = [ + signal.SIGTERM, + signal.SIGQUIT, + signal.SIGINT, + signal.SIGHUP, + ] + for stop_signal in stop_signals: + loop.add_signal_handler(stop_signal, stop.set_result, None) + # Add the stop signal as a task. + tasks.append(stop) + + # Run tasks. + abort = False + while tasks: + # Execute tasks until one is done (or fails). + done_tasks, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + # Print exception for failed tasks. + try: + await task + except Exception as e: + print(f"Exception raised in network process. Aborting: {e}.") + traceback.print_exc() + abort = True + # Abort if exception was raised, or if a termination signal was caught. + if abort or stop.done(): + if stop.done(): + print(f"Caught termination signal: {stop.result}.") + break + # Resume pending tasks. + tasks = pending + + # Terminate network process. + print("Networking process terminating...") + + # Close servers. websocket_server.close() await websocket_server.wait_closed() if http_runner: await http_runner.cleanup() - check_keyframe_queue_task.cancel() - check_close_broken_connection_task.cancel() - def networking_main(interprocess_record: InterprocessRecord) -> None: # Set up the event loop and run the main coroutine loop = asyncio.get_event_loop() loop.run_until_complete(networking_main_async(interprocess_record)) loop.close() - print("networking_main finished") + print("Networking process terminated.") diff --git a/habitat-hitl/habitat_hitl/config/hitl_defaults.yaml b/habitat-hitl/habitat_hitl/config/hitl_defaults.yaml index 645f9e9879..01b4882619 100644 --- a/habitat-hitl/habitat_hitl/config/hitl_defaults.yaml +++ b/habitat-hitl/habitat_hitl/config/hitl_defaults.yaml @@ -40,6 +40,8 @@ habitat_hitl: client_sync: # If enabled, the server main camera transform will be sent to the client. Disable if the client should control its own camera (e.g. VR), or if clients must use different camera transforms (e.g. multiplayer). server_camera: True + # If enabled, the first client input is relayed to the server's GuiInput. Disable if clients have independent controls from the server. + server_input: True # Enable transmission of skinned mesh poses. If 'camera.first_person_mode' is enabled, you should generally disable this as well as enable `hide_humanoid_in_gui` because the humanoid will occlude the camera. skinning: True diff --git a/habitat-hitl/habitat_hitl/core/client_helper.py b/habitat-hitl/habitat_hitl/core/client_helper.py index d837178309..b74498df6a 100644 --- a/habitat-hitl/habitat_hitl/core/client_helper.py +++ b/habitat-hitl/habitat_hitl/core/client_helper.py @@ -5,6 +5,9 @@ # LICENSE file in the root directory of this source tree. +from typing import Optional + +from habitat_hitl.app_states.app_service import AppService from habitat_hitl.core.average_helper import AverageHelper @@ -13,26 +16,30 @@ class ClientHelper: Tracks connected remote clients. Displays client latency and kicks idle clients. """ - def __init__(self, app_service): + def __init__(self, app_service: AppService): self._app_service = app_service - self._show_idle_kick_warning = False - self._idle_frame_counter = None self._frame_counter = 0 self._client_frame_latency_avg_helper = AverageHelper( window_size=10, output_rate=10 ) - self._display_latency_ms = None + + self._show_idle_kick_warning: Optional[bool] = False + self._idle_frame_counter: Optional[int] = None + self._display_latency_ms: Optional[float] = None @property - def display_latency_ms(self): + def display_latency_ms(self) -> Optional[float]: + """Returns the display latency.""" return self._display_latency_ms @property - def do_show_idle_kick_warning(self): + def do_show_idle_kick_warning(self) -> Optional[bool]: + """Indicates that the user should be warned that they will be kicked imminently.""" return self._show_idle_kick_warning - def _update_idle_kick(self, is_user_idle_this_frame): + def _update_idle_kick(self, is_user_idle_this_frame: bool) -> None: + """Keeps tracks of whether the user is AFK. After some time, they will be kicked.""" hitl_config = self._app_service.hitl_config self._show_idle_kick_warning = False @@ -74,7 +81,10 @@ def _update_idle_kick(self, is_user_idle_this_frame): # reset counter whenever the client isn't idle self._idle_frame_counter = 0 - def _update_frame_counter_and_display_latency(self, server_sps): + def _update_frame_counter_and_display_latency( + self, server_sps: float + ) -> None: + """Update the frame counter.""" recent_server_keyframe_id = ( self._app_service.remote_client_state.pop_recent_server_keyframe_id() ) @@ -92,6 +102,7 @@ def _update_frame_counter_and_display_latency(self, server_sps): ) self._frame_counter += 1 - def update(self, is_user_idle_this_frame, server_sps): + def update(self, is_user_idle_this_frame: bool, server_sps: float) -> None: + """Update the client helper.""" self._update_idle_kick(is_user_idle_this_frame) self._update_frame_counter_and_display_latency(server_sps) diff --git a/habitat-hitl/habitat_hitl/core/client_message_manager.py b/habitat-hitl/habitat_hitl/core/client_message_manager.py index 880d65c8bb..98d40dd3f1 100644 --- a/habitat-hitl/habitat_hitl/core/client_message_manager.py +++ b/habitat-hitl/habitat_hitl/core/client_message_manager.py @@ -4,12 +4,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Final, List, Optional, Union import magnum as mn from habitat_hitl.core.user_mask import Mask, Users +DEFAULT_NORMAL: Final[List[float]] = [0.0, 1.0, 0.0] + class ClientMessageManager: r""" @@ -41,6 +43,7 @@ def add_highlight( self, pos: List[float], radius: float, + normal: List[float] = DEFAULT_NORMAL, billboard: bool = True, color: Optional[Union[mn.Color4, mn.Color3]] = None, destination_mask: Mask = Mask.ALL, @@ -53,9 +56,13 @@ def add_highlight( for user_index in self._users.indices(destination_mask): message = self._messages[user_index] - if "highlights" not in message: - message["highlights"] = [] - highlight_dict = {"t": [pos[0], pos[1], pos[2]], "r": radius} + if "circles" not in message: + message["circles"] = [] + highlight_dict = { + "t": [pos[0], pos[1], pos[2]], + "r": radius, + "n": normal, + } if billboard: highlight_dict["b"] = 1 if color is not None: @@ -71,7 +78,48 @@ def conv(channel): conv(color.b), conv(alpha), ] - message["highlights"].append(highlight_dict) + message["circles"].append(highlight_dict) + + def add_line( + self, + a: List[float], + b: List[float], + from_color: Optional[Union[mn.Color4, mn.Color3]] = None, + to_color: Optional[Union[mn.Color4, mn.Color3]] = None, + destination_mask: Mask = Mask.ALL, + ) -> None: + r""" + Draw a line from the two specified world positions. + """ + assert len(a) == 3 + assert len(b) == 3 + + for user_index in self._users.indices(destination_mask): + message = self._messages[user_index] + + if "lines" not in message: + message["lines"] = [] + lines_dict = {"a": [a[0], a[1], a[2]], "b": [b[0], b[1], b[2]]} + + if from_color is not None: + + def conv(channel): + # sloppy: using int 0-255 to reduce serialized data size + return int(channel * 255.0) + + alpha = ( + 1.0 if isinstance(from_color, mn.Color3) else from_color.a + ) + lines_dict["c"] = [ + conv(from_color.r), + conv(from_color.g), + conv(from_color.b), + conv(alpha), + ] + + # TODO: Implement "to_color". + + message["lines"].append(lines_dict) def add_text( self, text: str, pos: list[float], destination_mask: Mask = Mask.ALL diff --git a/habitat-hitl/habitat_hitl/core/gui_drawer.py b/habitat-hitl/habitat_hitl/core/gui_drawer.py index d676eca2fa..a93ca460dc 100644 --- a/habitat-hitl/habitat_hitl/core/gui_drawer.py +++ b/habitat-hitl/habitat_hitl/core/gui_drawer.py @@ -9,6 +9,7 @@ import magnum as mn from habitat_hitl.core.client_message_manager import ClientMessageManager +from habitat_hitl.core.user_mask import Mask from habitat_sim.gfx import DebugLineRender @@ -33,6 +34,9 @@ def __init__( self._sim_debug_line_render = sim_debug_line_render self._client_message_manager = client_message_manager + # TODO: Implement per-user. + self._local_transforms: List[mn.Matrix4] = [] + def get_sim_debug_line_render(self) -> Optional[DebugLineRender]: """ Set the internal 'sim_debug_line_render' object, used for rendering lines onto the server. @@ -43,6 +47,7 @@ def get_sim_debug_line_render(self) -> Optional[DebugLineRender]: def set_line_width( self, line_width: float, + destination_mask: Mask = Mask.ALL, ) -> None: """ Set global line width for all lines rendered by GuiDrawer. @@ -59,6 +64,7 @@ def set_line_width( def push_transform( self, transform: mn.Matrix4, + destination_mask: Mask = Mask.ALL, ) -> None: """ Push (multiply) a transform onto the transform stack, affecting all line-drawing until popped. @@ -70,11 +76,12 @@ def push_transform( # If remote rendering is enabled: if self._client_message_manager: - # Networking not implemented - pass + # TODO: Implement per-user. + self._local_transforms.append(transform) def pop_transform( self, + destination_mask: Mask = Mask.ALL, ) -> None: """ See push_transform. @@ -85,26 +92,54 @@ def pop_transform( # If remote rendering is enabled: if self._client_message_manager: - # Networking not implemented - pass + # TODO: Implement per-user. + self._local_transforms.pop() def draw_box( self, min_extent: mn.Vector3, max_extent: mn.Vector3, color: mn.Color4, + destination_mask: Mask = Mask.ALL, ) -> None: """ Draw a box in world-space or local-space (see pushTransform). """ # If server rendering is enabled: if self._sim_debug_line_render: - self._sim_debug_line_render.draw_box(min, max, color) + self._sim_debug_line_render.draw_box(min_extent, max_extent, color) # If remote rendering is enabled: if self._client_message_manager: - # Networking not implemented - pass + + def vec(x, y, z) -> mn.Vector3: + return mn.Vector3(x, y, z) + + def draw_line(a: mn.Vector3, b: mn.Vector3) -> None: + self.draw_transformed_line( + a, b, from_color=color, destination_mask=destination_mask + ) + + e0 = min_extent + e1 = max_extent + + # 4 lines along x axis + draw_line(vec(e0.x, e0.y, e0.z), vec(e1.x, e0.y, e0.z)) + draw_line(vec(e0.x, e0.y, e1.z), vec(e1.x, e0.y, e1.z)) + draw_line(vec(e0.x, e1.y, e0.z), vec(e1.x, e1.y, e0.z)) + draw_line(vec(e0.x, e1.y, e1.z), vec(e1.x, e1.y, e1.z)) + + # 4 lines along y axis + draw_line(vec(e0.x, e0.y, e0.z), vec(e0.x, e1.y, e0.z)) + draw_line(vec(e1.x, e0.y, e0.z), vec(e1.x, e1.y, e0.z)) + draw_line(vec(e0.x, e0.y, e1.z), vec(e0.x, e1.y, e1.z)) + draw_line(vec(e1.x, e0.y, e1.z), vec(e1.x, e1.y, e1.z)) + + # 4 lines along z axis + draw_line(vec(e0.x, e0.y, e0.z), vec(e0.x, e0.y, e1.z)) + draw_line(vec(e1.x, e0.y, e0.z), vec(e1.x, e0.y, e1.z)) + draw_line(vec(e0.x, e1.y, e0.z), vec(e0.x, e1.y, e1.z)) + draw_line(vec(e1.x, e1.y, e0.z), vec(e1.x, e1.y, e1.z)) def draw_circle( self, @@ -114,10 +149,13 @@ def draw_circle( num_segments: int = DEFAULT_SEGMENT_COUNT, normal: mn.Vector3 = DEFAULT_NORMAL, billboard: bool = False, + destination_mask: Mask = Mask.ALL, ) -> None: """ Draw a circle in world-space or local-space (see pushTransform). The circle is an approximation; see numSegments. + + The normal is always in world-space. """ # If server rendering is enabled: if self._sim_debug_line_render: @@ -127,8 +165,16 @@ def draw_circle( # If remote rendering is enabled: if self._client_message_manager: + parent_transform = self._compute_parent_transform() + global_translation = parent_transform.transform_point(translation) + self._client_message_manager.add_highlight( - translation, radius, billboard=billboard, color=color + pos=_vec_to_list(global_translation), + radius=radius, + normal=_vec_to_list(normal), + billboard=billboard, + color=color, + destination_mask=destination_mask, ) def draw_transformed_line( @@ -137,6 +183,7 @@ def draw_transformed_line( to_pos: mn.Vector3, from_color: mn.Color4, to_color: mn.Color4 = None, + destination_mask: Mask = Mask.ALL, ) -> None: """ Draw a line segment in world-space or local-space (see pushTransform) with interpolated color. @@ -155,8 +202,17 @@ def draw_transformed_line( # If remote rendering is enabled: if self._client_message_manager: - # Networking not implemented - pass + parent_transform = self._compute_parent_transform() + global_from_pos = parent_transform.transform_point(from_pos) + global_to_pos = parent_transform.transform_point(to_pos) + + self._client_message_manager.add_line( + _vec_to_list(global_from_pos), + _vec_to_list(global_to_pos), + from_color=from_color, + to_color=to_color, + destination_mask=destination_mask, + ) def draw_path_with_endpoint_circles( self, @@ -165,6 +221,7 @@ def draw_path_with_endpoint_circles( color: mn.Color4, num_segments: int = DEFAULT_SEGMENT_COUNT, normal: mn.Vector3 = DEFAULT_NORMAL, + destination_mask: Mask = Mask.ALL, ) -> None: """ Draw a sequence of line segments with circles at the two endpoints. @@ -180,3 +237,17 @@ def draw_path_with_endpoint_circles( if self._client_message_manager: # Networking not implemented pass + + def _compute_parent_transform(self) -> mn.Matrix4: + """ + Resolve the transform resulting from the push/pop_transform calls. + To apply to a point, use {ret_val}.transform_point(from_pos). + """ + parent_transform = mn.Matrix4.identity_init() + for local_transform in self._local_transforms: + parent_transform = parent_transform @ local_transform + return parent_transform + + +def _vec_to_list(vec: mn.Vector3) -> List[float]: + return [vec.x, vec.y, vec.z] diff --git a/habitat-hitl/habitat_hitl/core/gui_input.py b/habitat-hitl/habitat_hitl/core/gui_input.py index 184525759f..caf3eb660c 100644 --- a/habitat-hitl/habitat_hitl/core/gui_input.py +++ b/habitat-hitl/habitat_hitl/core/gui_input.py @@ -4,17 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from habitat_hitl.core.key_mapping import KeyCode - - -class StubNSMeta(type): - def __getattr__(cls, name): - return None - - -# Stub version of Application.MouseEvent.Button -class StubMouseNS(metaclass=StubNSMeta): - pass +from habitat_hitl.core.key_mapping import KeyCode, MouseButton class GuiInput: @@ -25,7 +15,7 @@ class GuiInput: """ KeyNS = KeyCode - MouseNS = StubMouseNS + MouseNS = MouseButton def __init__(self): self._key_held = set() @@ -37,7 +27,7 @@ def __init__(self): self._mouse_button_down = set() self._mouse_button_up = set() self._relative_mouse_position = [0, 0] - self._mouse_scroll_offset = 0 + self._mouse_scroll_offset = 0.0 self._mouse_ray = None def validate_key(key): @@ -59,9 +49,7 @@ def get_key_up(self, key): return key in self._key_up def validate_mouse_button(mouse_button): - # if not do_agnostic_gui_input: - # assert isinstance(mouse_button, Application.MouseEvent.Button) - pass + assert isinstance(mouse_button, MouseButton) def get_mouse_button(self, mouse_button): GuiInput.validate_mouse_button(mouse_button) @@ -99,4 +87,4 @@ def on_frame_end(self): self._mouse_button_down.clear() self._mouse_button_up.clear() self._relative_mouse_position = [0, 0] - self._mouse_scroll_offset = 0 + self._mouse_scroll_offset = 0.0 diff --git a/habitat-hitl/habitat_hitl/core/key_mapping.py b/habitat-hitl/habitat_hitl/core/key_mapping.py index bf2e641b60..53f023e627 100644 --- a/habitat-hitl/habitat_hitl/core/key_mapping.py +++ b/habitat-hitl/habitat_hitl/core/key_mapping.py @@ -69,6 +69,28 @@ class KeyCode(IntEnum, metaclass=KeyCodeMetaEnum): # fmt: on +class MouseButtonMetaEnum(EnumMeta): + keycode_value_cache: Set[int] = None + + # Override 'in' keyword to check whether the specified integer exists in 'MouseButton'. + def __contains__(cls, value) -> bool: + if MouseButtonMetaEnum.keycode_value_cache == None: + MouseButtonMetaEnum.keycode_value_cache = set(MouseButton) + return value in MouseButtonMetaEnum.keycode_value_cache + + +class MouseButton(IntEnum, metaclass=MouseButtonMetaEnum): + """ + Mouse buttons available to control habitat-hitl. + """ + + # fmt: off + LEFT = 0 + RIGHT = 1 + MIDDLE = 2 + # fmt: on + + # On headless systems, we may be unable to import magnum.platform.glfw.Application. try: from magnum.platform.glfw import Application @@ -124,9 +146,22 @@ class KeyCode(IntEnum, metaclass=KeyCodeMetaEnum): # fmt: on } + magnum_mouse_keymap: Dict[Application.KeyEvent.Key, MouseButton] = { + # fmt: off + Application.MouseEvent.Button.LEFT : MouseButton.LEFT , + Application.MouseEvent.Button.RIGHT : MouseButton.RIGHT , + Application.MouseEvent.Button.MIDDLE : MouseButton.MIDDLE, + # fmt: on + } + class MagnumKeyConverter: - def convert(key: Any) -> Optional[KeyCode]: + def convert_key(key: Any) -> Optional[KeyCode]: if magnum_enabled and key in magnum_keymap: return magnum_keymap[key] return None + + def convert_mouse_button(button: Any) -> Optional[MouseButton]: + if magnum_enabled and button in magnum_mouse_keymap: + return magnum_mouse_keymap[button] + return None diff --git a/habitat-hitl/habitat_hitl/core/remote_client_state.py b/habitat-hitl/habitat_hitl/core/remote_client_state.py index 690cf1ac8c..b6dac7fbe4 100644 --- a/habitat-hitl/habitat_hitl/core/remote_client_state.py +++ b/habitat-hitl/habitat_hitl/core/remote_client_state.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Any, List +from typing import Any, List, Optional, Tuple import magnum as mn @@ -17,7 +17,10 @@ ) from habitat_hitl.core.gui_drawer import GuiDrawer from habitat_hitl.core.gui_input import GuiInput -from habitat_hitl.core.key_mapping import KeyCode +from habitat_hitl.core.key_mapping import KeyCode, MouseButton +from habitat_hitl.core.types import ClientState, ConnectionRecord +from habitat_hitl.core.user_mask import Mask, Users +from habitat_sim.geo import Ray class RemoteClientState: @@ -30,16 +33,21 @@ def __init__( self, interprocess_record: InterprocessRecord, gui_drawer: GuiDrawer, - gui_input: GuiInput, + users: Users, ): - self._gui_input = gui_input - self._recent_client_states: List[Any] = [] self._interprocess_record = interprocess_record self._gui_drawer = gui_drawer + self._users = users self._receive_rate_tracker = AverageRateTracker(2.0) - self._new_connection_records: List[Any] = [] + self._recent_client_states: List[ClientState] = [] + self._new_connection_records: List[ConnectionRecord] = [] + + # Create one GuiInput per user to be controlled by remote clients. + self._gui_inputs: List[GuiInput] = [] + for _ in users.indices(Mask.ALL): + self._gui_inputs.append(GuiInput()) # temp map VR button to key self._button_map = { @@ -49,19 +57,31 @@ def __init__( 3: GuiInput.KeyNS.THREE, } - def get_gui_input(self): - """Internal GuiInput class.""" - return self._gui_input + def get_gui_input(self, user_index: int = 0) -> GuiInput: + """Get the GuiInput for a specified user index.""" + return self._gui_inputs[user_index] + + def get_gui_inputs(self) -> List[GuiInput]: + """Get a list of all GuiInputs indexed by user index.""" + return self._gui_inputs + + def bind_gui_input(self, gui_input: GuiInput, user_index: int) -> None: + """ + Bind the specified GuiInput to a specified user, allowing the associated remote client to control it. + Erases the previous GuiInput. + """ + assert user_index < len(self._gui_inputs) + self._gui_inputs[user_index] = gui_input - def get_history_length(self): + def get_history_length(self) -> int: """Length of client state history preserved. Anything beyond this horizon is discarded.""" return 4 - def get_history_timestep(self): + def get_history_timestep(self) -> float: """Frequency at which client states are read.""" return 1 / 60 - def pop_recent_server_keyframe_id(self): + def pop_recent_server_keyframe_id(self) -> Optional[int]: """ Removes and returns ("pops") the recentServerKeyframeId included in the latest client state. @@ -78,14 +98,18 @@ def pop_recent_server_keyframe_id(self): del latest_client_state["recentServerKeyframeId"] return retval - def get_recent_client_state_by_history_index(self, history_index): + def get_recent_client_state_by_history_index( + self, history_index: int + ) -> Optional[ClientState]: assert history_index >= 0 if history_index >= len(self._recent_client_states): return None return self._recent_client_states[-(1 + history_index)] - def get_head_pose(self, history_index=0): + def get_head_pose( + self, history_index: int = 0 + ) -> Optional[Tuple[mn.Vector3, mn.Quaternion]]: """ Get the latest head transform. Beware that this is in agent-space. Agents are flipped 180 degrees on the y-axis such as their z-axis faces forward. @@ -114,7 +138,9 @@ def get_head_pose(self, history_index=0): ) return pos, rot_quat - def get_hand_pose(self, hand_idx, history_index=0): + def get_hand_pose( + self, hand_idx: int, history_index: int = 0 + ) -> Optional[Tuple[mn.Vector3, mn.Quaternion]]: """ Get the latest hand transforms. Beware that this is in agent-space. Agents are flipped 180 degrees on the y-axis such as their z-axis faces forward. @@ -146,30 +172,78 @@ def get_hand_pose(self, hand_idx, history_index=0): ) return pos, rot_quat - def _update_input_state(self, client_states): + def _update_input_state(self, client_states: List[ClientState]) -> None: """Update mouse/keyboard input based on new client states.""" - if not len(client_states): + if not len(client_states) or not len(self._gui_inputs): return + # TODO: Only one user supported for now. + # In multiplayer, there will be client_state per user_index. + user_index = 0 + assert user_index < len(self._gui_inputs) + gui_input = self._gui_inputs[user_index] + # Gather all recent keyDown and keyUp events for client_state in client_states: input_json = ( client_state["input"] if "input" in client_state else None ) - # TODO: Add mouse support - # mouse_json = ( - # client_state["mouse"] if "mouse" in client_state else None - # ) + mouse_json = ( + client_state["mouse"] if "mouse" in client_state else None + ) if input_json is not None: for button in input_json["buttonDown"]: if button not in KeyCode: continue - self._gui_input._key_down.add(KeyCode(button)) + gui_input._key_down.add(KeyCode(button)) for button in input_json["buttonUp"]: if button not in KeyCode: continue - self._gui_input._key_up.add(KeyCode(button)) + gui_input._key_up.add(KeyCode(button)) + + if mouse_json is not None: + mouse_buttons = mouse_json["buttons"] + for button in mouse_buttons["buttonDown"]: + if button not in MouseButton: + continue + gui_input._mouse_button_down.add(MouseButton(button)) + for button in mouse_buttons["buttonUp"]: + if button not in MouseButton: + continue + gui_input._mouse_button_up.add(MouseButton(button)) + + if "scrollDelta" in mouse_json: + delta: List[Any] = mouse_json["scrollDelta"] + if len(delta) == 2: + gui_input._mouse_scroll_offset += ( + delta[0] + if abs(delta[0]) > abs(delta[1]) + else delta[1] + ) + + if "mousePositionDelta" in mouse_json: + pos_delta: List[Any] = mouse_json["mousePositionDelta"] + if len(pos_delta) == 2: + gui_input._relative_mouse_position = [ + pos_delta[0], + pos_delta[1], + ] + + if "rayOrigin" in mouse_json: + ray_origin: List[float] = mouse_json["rayOrigin"] + ray_direction: List[float] = mouse_json["rayDirection"] + if len(ray_origin) == 3 and len(ray_direction) == 3: + ray = Ray() + ray.origin = mn.Vector3( + ray_origin[0], ray_origin[1], ray_origin[2] + ) + ray.direction = mn.Vector3( + ray_direction[0], + ray_direction[1], + ray_direction[2], + ).normalized() + gui_input._mouse_ray = ray # todo: think about ambiguous GuiInput states (key-down and key-up events in the same # frame and other ways that keyHeld, keyDown, and keyUp can be inconsistent. @@ -180,30 +254,42 @@ def _update_input_state(self, client_states): if "input" in last_client_state else None ) - # TODO: Add mouse support - # mouse_json = last_client_state["mouse"] if "mouse" in last_client_state else None + mouse_json = ( + last_client_state["mouse"] + if "mouse" in last_client_state + else None + ) - self._gui_input._key_held.clear() + gui_input._key_held.clear() + gui_input._mouse_button_held.clear() if input_json is not None: for button in input_json["buttonHeld"]: if button not in KeyCode: continue - self._gui_input._key_held.add(KeyCode(button)) + gui_input._key_held.add(KeyCode(button)) - def debug_visualize_client(self): + if mouse_json is not None: + mouse_buttons = mouse_json["buttons"] + for button in mouse_buttons["buttonHeld"]: + if button not in MouseButton: + continue + gui_input._mouse_button_held.add(MouseButton(button)) + + def debug_visualize_client(self) -> None: """Visualize the received VR inputs (head and hands).""" - # Sloppy: Use internal debug_line_render to render on server only. - line_renderer = self._gui_drawer.get_sim_debug_line_render() - if not line_renderer: + if not self._gui_drawer: return + server_only = Mask.NONE # Render on the server only. avatar_color = mn.Color3(0.3, 1, 0.3) pos, rot_quat = self.get_head_pose() if pos is not None and rot_quat is not None: trans = mn.Matrix4.from_(rot_quat.to_matrix(), pos) - line_renderer.push_transform(trans) + self._gui_drawer.push_transform( + trans, destination_mask=server_only + ) color0 = avatar_color color1 = mn.Color4( avatar_color.r, avatar_color.g, avatar_color.b, 0 @@ -211,49 +297,57 @@ def debug_visualize_client(self): size = 0.5 # Draw a frustum (forward is flipped (z+)) - line_renderer.draw_transformed_line( + self._gui_drawer.draw_transformed_line( mn.Vector3(0, 0, 0), mn.Vector3(size, size, size), color0, color1, ) - line_renderer.draw_transformed_line( + self._gui_drawer.draw_transformed_line( mn.Vector3(0, 0, 0), mn.Vector3(-size, size, size), color0, color1, + destination_mask=server_only, ) - line_renderer.draw_transformed_line( + self._gui_drawer.draw_transformed_line( mn.Vector3(0, 0, 0), mn.Vector3(size, -size, size), color0, color1, + destination_mask=server_only, ) - line_renderer.draw_transformed_line( + self._gui_drawer.draw_transformed_line( mn.Vector3(0, 0, 0), mn.Vector3(-size, -size, size), color0, color1, + destination_mask=server_only, ) - line_renderer.pop_transform() + self._gui_drawer.pop_transform(destination_mask=server_only) # Draw controller rays (forward is flipped (z+)) for hand_idx in range(2): hand_pos, hand_rot_quat = self.get_hand_pose(hand_idx) if hand_pos is not None and hand_rot_quat is not None: trans = mn.Matrix4.from_(hand_rot_quat.to_matrix(), hand_pos) - line_renderer.push_transform(trans) + self._gui_drawer.push_transform( + trans, destination_mask=server_only + ) pointer_len = 0.5 - line_renderer.draw_transformed_line( + self._gui_drawer.draw_transformed_line( mn.Vector3(0, 0, 0), mn.Vector3(0, 0, pointer_len), color0, color1, + destination_mask=server_only, ) - line_renderer.pop_transform() + self._gui_drawer.pop_transform(destination_mask=server_only) - def _clean_history_by_connection_id(self, client_states): + def _clean_history_by_connection_id( + self, client_states: List[ClientState] + ) -> None: """ Clear history by connection id. Typically done after a client disconnect. @@ -262,6 +356,8 @@ def _clean_history_by_connection_id(self, client_states): return latest_client_state = client_states[-1] + if "connectionId" not in latest_client_state: + return latest_connection_id = latest_client_state["connectionId"] # discard older states that don't match the latest connection id @@ -282,7 +378,7 @@ def _clean_history_by_connection_id(self, client_states): ): self.clear_history() - def update(self): + def update(self) -> None: """Get the latest received remote client states.""" self._new_connection_records = ( self._interprocess_record.get_queued_connection_records() @@ -307,12 +403,13 @@ def update(self): self.debug_visualize_client() - def get_new_connection_records(self): + def get_new_connection_records(self) -> List[ConnectionRecord]: return self._new_connection_records - def on_frame_end(self): - self._gui_input.on_frame_end() + def on_frame_end(self) -> None: + for user_index in self._users.indices(Mask.ALL): + self._gui_inputs[user_index].on_frame_end() self._new_connection_records = None - def clear_history(self): + def clear_history(self) -> None: self._recent_client_states.clear() diff --git a/habitat-hitl/habitat_hitl/core/types.py b/habitat-hitl/habitat_hitl/core/types.py new file mode 100644 index 0000000000..262011142a --- /dev/null +++ b/habitat-hitl/habitat_hitl/core/types.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +# Dictionary that is serialized to or from JSON. +DataDict = Dict[str, Any] + +# Server -> Client communication dictionary originating from habitat-sim (loads, updates, deletions, ...). +Keyframe = DataDict + +# Client -> Server communication dictionary (inputs, etc.). +ClientState = DataDict + +# Dictionary that contains data about a new connection. +ConnectionRecord = DataDict diff --git a/habitat-hitl/habitat_hitl/core/user_mask.py b/habitat-hitl/habitat_hitl/core/user_mask.py index 0983b614fe..4e28727f70 100644 --- a/habitat-hitl/habitat_hitl/core/user_mask.py +++ b/habitat-hitl/habitat_hitl/core/user_mask.py @@ -92,54 +92,3 @@ def to_index_list(self, user_mask: Mask) -> List[int]: def max_user_count(self) -> int: """Returns the size of the user set.""" return self._max_user_count - - -if __name__ == "__main__": - # TODO: Move this to a unit test when testing is added habitat-hitl. - four_users = Users(4) - assert four_users.max_user_count == 4 - assert len(four_users.to_index_list(Mask.ALL)) == 4 - assert len(four_users.to_index_list(Mask.NONE)) == 0 - user_indices = four_users.to_index_list( - Mask.from_index(1) | Mask.from_index(2) | Mask.from_index(11) - ) - assert 1 in user_indices - assert 2 in user_indices - assert 11 not in user_indices - user_indices = four_users.to_index_list(Mask.all_except_index(1)) - assert 0 in user_indices - assert 1 not in user_indices - assert 2 in user_indices - assert 3 in user_indices - assert 4 not in user_indices - - six_users = Users(6) - assert six_users.max_user_count == 6 - assert len(six_users.to_index_list(Mask.ALL)) == 6 - assert len(six_users.to_index_list(Mask.NONE)) == 0 - user_indices = six_users.to_index_list(Mask.all_except_indices([0, 2])) - assert 0 not in user_indices - assert 1 in user_indices - assert 2 not in user_indices - assert 3 in user_indices - assert 4 in user_indices - assert 5 in user_indices - assert 6 not in user_indices - - two_users = Users(2) - assert two_users.max_user_count == 2 - assert len(two_users.to_index_list(Mask.ALL)) == 2 - assert len(two_users.to_index_list(Mask.NONE)) == 0 - user_indices = two_users.to_index_list(Mask.from_indices([1, 2])) - assert 0 not in user_indices - assert 1 in user_indices - assert 2 not in user_indices - - max_users = Users(32) - assert max_users.max_user_count == 32 - assert len(max_users.to_index_list(Mask.ALL)) == 32 - assert len(max_users.to_index_list(Mask.NONE)) == 0 - assert ( - len(max_users.to_index_list(Mask.all_except_indices([17, 22]))) == 30 - ) - assert len(max_users.to_index_list(Mask.from_indices([3, 15]))) == 2 diff --git a/habitat-hitl/habitat_hitl/environment/gui_navigation_helper.py b/habitat-hitl/habitat_hitl/environment/gui_navigation_helper.py index 4c4cec0fbe..8612688bb8 100644 --- a/habitat-hitl/habitat_hitl/environment/gui_navigation_helper.py +++ b/habitat-hitl/habitat_hitl/environment/gui_navigation_helper.py @@ -12,6 +12,7 @@ from habitat.datasets.rearrange.navmesh_utils import get_largest_island_index from habitat.tasks.rearrange.rearrange_sim import RearrangeSim from habitat_hitl.app_states.app_service import AppService +from habitat_hitl.core.user_mask import Mask from habitat_hitl.environment.hablab_utils import get_agent_art_obj_transform from habitat_sim.nav import ShortestPath @@ -19,9 +20,12 @@ class GuiNavigationHelper: """Helper for controlling an agent from the GUI.""" - def __init__(self, gui_service: AppService, agent_idx: int) -> None: + def __init__( + self, gui_service: AppService, agent_idx: int, user_index: int + ) -> None: self._app_service = gui_service self._agent_idx = agent_idx + self._user_index = user_index self._largest_island_idx: Optional[int] = None def _get_sim(self) -> RearrangeSim: @@ -108,7 +112,10 @@ def _viz_humanoid_walk_path(self, path: ShortestPath) -> None: path_points.append(adjusted_point) self._app_service.gui_drawer.draw_path_with_endpoint_circles( - path_points, path_endpoint_radius, path_color + path_points, + path_endpoint_radius, + path_color, + destination_mask=Mask.from_index(self._user_index), ) def get_humanoid_walk_hints_from_remote_client_state( @@ -298,6 +305,11 @@ def _draw_nav_hint( color_with_alpha = mn.Color4(color) color_with_alpha[3] *= alpha self._app_service.gui_drawer.draw_circle( - pos, radius, color_with_alpha, num_segments, normal + pos, + radius, + color_with_alpha, + num_segments, + normal, + destination_mask=Mask.from_index(self._user_index), ) prev_pos = pos diff --git a/habitat-hitl/habitat_hitl/environment/gui_pick_helper.py b/habitat-hitl/habitat_hitl/environment/gui_pick_helper.py index 81f974ceca..afea05cb3a 100644 --- a/habitat-hitl/habitat_hitl/environment/gui_pick_helper.py +++ b/habitat-hitl/habitat_hitl/environment/gui_pick_helper.py @@ -4,11 +4,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Final +from typing import Final, List import magnum as mn import numpy as np +from habitat_hitl.app_states.app_service import AppService +from habitat_hitl.core.user_mask import Mask + DIST_HIGHLIGHT: Final[float] = 0.15 COLOR_GRASPABLE: Final[mn.Color3] = mn.Color3(1, 0.75, 0) COLOR_GRASP_PREVIEW: Final[mn.Color3] = mn.Color3(0.5, 1, 0) @@ -21,15 +24,15 @@ class GuiPickHelper: """Helper for picking up objects from the GUI.""" - def __init__(self, gui_service): - self._app_service = gui_service - self._rom = self._get_sim().get_rigid_object_manager() - self._obj_ids = self._get_sim()._scene_obj_ids - self._dist_to_highlight_obj = DIST_HIGHLIGHT - self._pick_candidate_indices = [] + def __init__(self, app_service: AppService, user_index: int): + self._app_service = app_service + self._user_index = user_index + self._sim = self._app_service.sim - def _get_sim(self): - return self._app_service.sim + self._rom = self._sim.get_rigid_object_manager() + self._obj_ids = self._sim._scene_obj_ids + self._dist_to_highlight_obj = DIST_HIGHLIGHT + self._pick_candidate_indices: List[int] = [] def _closest_point_and_dist_to_ray( self, ray_origin, ray_direction_vector, points @@ -46,9 +49,8 @@ def _closest_point_and_dist_to_ray( return np.argmin(distances), np.min(distances) def on_environment_reset(self): - sim = self._get_sim() - self._rom = sim.get_rigid_object_manager() - self._obj_ids = sim._scene_obj_ids + self._rom = self._sim.get_rigid_object_manager() + self._obj_ids = self._sim._scene_obj_ids self._pick_candidate_indices = [] def _closest_point_and_dist_to_query_position(self, points, query_pos): @@ -77,18 +79,23 @@ def get_pick_object_near_query_position(self, query_pos): else: return None - def _draw_circle(self, pos, color, radius, billboard): - num_segments = 24 - self._app_service.gui_drawer.draw_circle( - pos, radius, color, num_segments, billboard=billboard - ) - def _add_highlight_ring( - self, pos, color, radius, do_pulse=False, billboard=True + self, + pos: mn.Vector3, + radius: float, + color: mn.Color3, + do_pulse: bool = False, + billboard: bool = True, ): if do_pulse: radius += self._app_service.get_anim_fraction() * RING_PULSE_SIZE - self._draw_circle(pos, color, radius, billboard) + self._app_service.gui_drawer.draw_circle( + pos, + radius, + color, + billboard=billboard, + destination_mask=Mask.from_index(self._user_index), + ) def viz_objects(self): obj_positions = self._get_object_positions() @@ -101,8 +108,8 @@ def viz_objects(self): ).transformation.translation self._add_highlight_ring( pos, - COLOR_GRASP_PREVIEW, RADIUS_GRASP_PREVIEW, + COLOR_GRASP_PREVIEW, do_pulse=False, ) self._pick_candidate_indices = [] @@ -113,7 +120,7 @@ def viz_objects(self): obj_id ).transformation.translation self._add_highlight_ring( - pos, COLOR_GRASPABLE, RADIUS_GRASPABLE, do_pulse=True + pos, RADIUS_GRASPABLE, COLOR_GRASPABLE, do_pulse=True ) # Reference code diff --git a/habitat-hitl/habitat_hitl/environment/gui_placement_helper.py b/habitat-hitl/habitat_hitl/environment/gui_placement_helper.py index b02cfc1a6c..d04e7e3825 100644 --- a/habitat-hitl/habitat_hitl/environment/gui_placement_helper.py +++ b/habitat-hitl/habitat_hitl/environment/gui_placement_helper.py @@ -9,6 +9,8 @@ import magnum as mn +from habitat_hitl.app_states.app_service import AppService +from habitat_hitl.core.user_mask import Mask from habitat_sim.physics import CollisionGroups COLOR_PLACE_PREVIEW_VALID: Final[mn.Color3] = mn.Color3(1, 1, 1) @@ -23,8 +25,14 @@ class GuiPlacementHelper: """Helper for placing objects from the GUI.""" - def __init__(self, app_service, gravity_dir=DEFAULT_GRAVITY): + def __init__( + self, + app_service: AppService, + user_index: int, + gravity_dir: mn.Vector3 = DEFAULT_GRAVITY, + ): self._app_service = app_service + self._user_index = user_index self._gravity_dir = gravity_dir def _snap_or_hide_object(self, ray, query_obj) -> tuple[bool, mn.Vector3]: @@ -101,29 +109,21 @@ def update(self, ray, query_obj_id): query_obj.collidable = cached_is_collidable if success: - self._draw_circle( + self._app_service.gui_drawer.draw_circle( hint_pos, - COLOR_PLACE_PREVIEW_VALID, RADIUS_PLACE_PREVIEW_VALID, + COLOR_PLACE_PREVIEW_VALID, billboard=False, + destination_mask=Mask.from_index(self._user_index), ) else: query_obj.translation = FAR_AWAY_HIDDEN_POSITION - self._draw_circle( + self._app_service.gui_drawer.draw_circle( hint_pos, - COLOR_PLACE_PREVIEW_INVALID, RADIUS_PLACE_PREVIEW_INVALID, + COLOR_PLACE_PREVIEW_INVALID, billboard=True, + destination_mask=Mask.from_index(self._user_index), ) return hint_pos if success else None - - def _draw_circle(self, pos, color, radius, billboard): - num_segments = 24 - self._app_service.gui_drawer.draw_circle( - pos, - radius, - color, - num_segments, - billboard=billboard, - ) diff --git a/habitat-hitl/habitat_hitl/environment/gui_throw_helper.py b/habitat-hitl/habitat_hitl/environment/gui_throw_helper.py index 9d59e4fc0c..38604b0cc2 100644 --- a/habitat-hitl/habitat_hitl/environment/gui_throw_helper.py +++ b/habitat-hitl/habitat_hitl/environment/gui_throw_helper.py @@ -9,11 +9,14 @@ import magnum as mn import numpy as np +from habitat_hitl.app_states.app_service import AppService +from habitat_hitl.core.user_mask import Mask + class GuiThrowHelper: """Helper for throwing objects from the GUI.""" - def __init__(self, gui_service, agent_idx): + def __init__(self, gui_service: AppService, agent_idx: int): self._app_service = gui_service self._agent_idx = agent_idx self._largest_island_idx = None @@ -66,16 +69,20 @@ def viz_and_get_humanoid_throw(self): vel_vector, path_points = self.compute_velocity_throw( robot_root, target_on_floor ) - # Sloppy: Use internal debug_line_render to render on server only. - line_renderer = ( - self._app_service.gui_drawer.get_sim_debug_line_render() - ) - if line_renderer is not None: - line_renderer.draw_path_with_endpoint_circles( - path_points, path_endpoint_radius, path_color + gui_drawer = self._app_service.gui_drawer + server_only = Mask.NONE # Render on the server only. + if gui_drawer is not None: + gui_drawer.draw_path_with_endpoint_circles( + path_points, + path_endpoint_radius, + path_color, + destination_mask=server_only, ) - line_renderer.draw_path_with_endpoint_circles( - path_points, path_endpoint_radius, path_color + gui_drawer.draw_path_with_endpoint_circles( + path_points, + path_endpoint_radius, + path_color, + destination_mask=server_only, ) return vel_vector diff --git a/habitat-hitl/test/config/base_test_cfg.yaml b/habitat-hitl/test/config/base_test_cfg.yaml new file mode 100644 index 0000000000..4567ab1889 --- /dev/null +++ b/habitat-hitl/test/config/base_test_cfg.yaml @@ -0,0 +1,6 @@ +defaults: + # We load the `pop_play` Habitat baseline featuring a Spot robot and a humanoid in HSSD scenes. See habitat-baselines/README.md. + - social_rearrange: pop_play + # Load default parameters for the HITL framework. See habitat-hitl/habitat_hitl/config/hitl_defaults.yaml. + - hitl_defaults + - _self_ diff --git a/habitat-hitl/test/config/experiment/smoke_test.yaml b/habitat-hitl/test/config/experiment/smoke_test.yaml new file mode 100644 index 0000000000..51cbd8d1b7 --- /dev/null +++ b/habitat-hitl/test/config/experiment/smoke_test.yaml @@ -0,0 +1,16 @@ +# @package _global_ + +# HITL smoke test configuration. +# Starts a headless HITL app and runs one frame. +habitat_hitl: + window: ~ + experimental: + headless: + do_headless: True + exit_after: 1 + test: + testing: True + +habitat: + dataset: + data_path: data/hab3_bench_assets/episode_datasets/small_small.json.gz diff --git a/habitat-hitl/test/test_example_apps.py b/habitat-hitl/test/test_example_apps.py new file mode 100644 index 0000000000..53d7301bbc --- /dev/null +++ b/habitat-hitl/test/test_example_apps.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing +import runpy +import sys +from os import path + +import pytest + + +def run_main(*args): + sys.argv = list(args) + target = args[0] + if path.isfile(target): + sys.path.insert(0, path.dirname(target)) + runpy.run_path(target, run_name="__main__") + + +def run_main_as_subprocess(args): + context = multiprocessing.get_context("spawn") + process = context.Process(target=run_main, args=args) + process.start() + process.join() + assert process.exitcode == 0 + + +@pytest.mark.parametrize( + "args", + [ + ( + "examples/hitl/basic_viewer/basic_viewer.py", + "--config-dir", + "habitat-hitl/test/config", + "+experiment=smoke_test", + ), + ], +) +def test_hitl_example_basic_viewer(args): + run_main_as_subprocess(args) + + +@pytest.mark.parametrize( + "args", + [ + ( + "examples/hitl/minimal/minimal.py", + "--config-dir", + "habitat-hitl/test/config", + "+experiment=smoke_test", + ), + ], +) +def test_hitl_example_minimal(args): + run_main_as_subprocess(args) + + +@pytest.mark.parametrize( + "args", + [ + ( + "examples/hitl/pick_throw_vr/pick_throw_vr.py", + "--config-dir", + "habitat-hitl/test/config", + "+experiment=smoke_test", + ), + ], +) +def test_hitl_example_pick_throw_vr(args): + run_main_as_subprocess(args) + + +@pytest.mark.parametrize( + "args", + [ + ( + "examples/hitl/rearrange/rearrange.py", + "--config-dir", + "habitat-hitl/test/config", + "+experiment=smoke_test", + ), + ], +) +def test_hitl_example_rearrange(args): + run_main_as_subprocess(args) + + +@pytest.mark.parametrize( + "args", + [ + ( + "examples/hitl/rearrange_v2/rearrange_v2.py", + "--config-dir", + "habitat-hitl/test/config", + "+experiment=smoke_test", + ), + ], +) +def test_hitl_example_rearrange_v2(args): + run_main_as_subprocess(args) diff --git a/habitat-hitl/test/test_main.py b/habitat-hitl/test/test_main.py new file mode 100644 index 0000000000..b7685d93a5 --- /dev/null +++ b/habitat-hitl/test/test_main.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import magnum +from hydra import compose, initialize + +from habitat_hitl.app_states.app_state_abc import AppState +from habitat_hitl.core.hitl_main import hitl_main +from habitat_hitl.core.hydra_utils import register_hydra_plugins + + +class AppStateTest(AppState): + """ + A minimal HITL test app that loads and steps a Habitat environment, with + a fixed overhead camera. + """ + + def __init__(self, app_service): + self._app_service = app_service + + def sim_update(self, dt, post_sim_update_dict): + assert not self._app_service.env.episode_over + self._app_service.compute_action_and_step_env() + + # set the camera for the main 3D viewport + post_sim_update_dict["cam_transform"] = magnum.Matrix4.look_at( + eye=magnum.Vector3(-20, 20, -20), + target=magnum.Vector3(0, 0, 0), + up=magnum.Vector3(0, 1, 0), + ) + + +def main(config) -> None: + hitl_main(config, lambda app_service: AppStateTest(app_service)) + + +def test_hitl_main(): + register_hydra_plugins() + with initialize(version_base=None, config_path="config"): + cfg = compose( + config_name="base_test_cfg", + overrides=[ + "+experiment=smoke_test", + ], + ) + + main(cfg) diff --git a/habitat-hitl/test/test_user_mask.py b/habitat-hitl/test/test_user_mask.py new file mode 100644 index 0000000000..a3bc319d8e --- /dev/null +++ b/habitat-hitl/test/test_user_mask.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from habitat_hitl.core.user_mask import Mask, Users + + +def test_hitl_user_mask(): + four_users = Users(4) + assert four_users.max_user_count == 4 + assert len(four_users.to_index_list(Mask.ALL)) == 4 + assert len(four_users.to_index_list(Mask.NONE)) == 0 + user_indices = four_users.to_index_list( + Mask.from_index(1) | Mask.from_index(2) | Mask.from_index(11) + ) + assert 1 in user_indices + assert 2 in user_indices + assert 11 not in user_indices + user_indices = four_users.to_index_list(Mask.all_except_index(1)) + assert 0 in user_indices + assert 1 not in user_indices + assert 2 in user_indices + assert 3 in user_indices + assert 4 not in user_indices + + six_users = Users(6) + assert six_users.max_user_count == 6 + assert len(six_users.to_index_list(Mask.ALL)) == 6 + assert len(six_users.to_index_list(Mask.NONE)) == 0 + user_indices = six_users.to_index_list(Mask.all_except_indices([0, 2])) + assert 0 not in user_indices + assert 1 in user_indices + assert 2 not in user_indices + assert 3 in user_indices + assert 4 in user_indices + assert 5 in user_indices + assert 6 not in user_indices + + two_users = Users(2) + assert two_users.max_user_count == 2 + assert len(two_users.to_index_list(Mask.ALL)) == 2 + assert len(two_users.to_index_list(Mask.NONE)) == 0 + user_indices = two_users.to_index_list(Mask.from_indices([1, 2])) + assert 0 not in user_indices + assert 1 in user_indices + assert 2 not in user_indices + + max_users = Users(32) + assert max_users.max_user_count == 32 + assert len(max_users.to_index_list(Mask.ALL)) == 32 + assert len(max_users.to_index_list(Mask.NONE)) == 0 + assert ( + len(max_users.to_index_list(Mask.all_except_indices([17, 22]))) == 30 + ) + assert len(max_users.to_index_list(Mask.from_indices([3, 15]))) == 2 diff --git a/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_social_nav.yaml b/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_social_nav.yaml index ae6ce016bb..6af63864ca 100644 --- a/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_social_nav.yaml +++ b/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_social_nav.yaml @@ -16,8 +16,8 @@ objects: goal: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) - not_holding(robot_0) - not_holding(robot_1) @@ -35,7 +35,7 @@ stage_goals: stage_1_2: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|0,TARGET_any_targets|0) stage_2_1: expr_type: AND sub_exprs: @@ -49,7 +49,7 @@ stage_goals: stage_2_2: expr_type: AND sub_exprs: - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|1,TARGET_any_targets|1) solution: - nav_to_goal(any_targets|1, robot_1) diff --git a/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_tidy_house.yaml b/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_tidy_house.yaml index 03a67cfd2d..59db184284 100644 --- a/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_tidy_house.yaml +++ b/habitat-lab/habitat/config/benchmark/multi_agent/pddl/multi_agent_tidy_house.yaml @@ -16,8 +16,8 @@ objects: goal: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) - not_holding(robot_0) - not_holding(robot_1) @@ -35,7 +35,7 @@ stage_goals: stage_1_2: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|0,TARGET_any_targets|0) stage_2_1: expr_type: AND sub_exprs: @@ -49,7 +49,7 @@ stage_goals: stage_2_2: expr_type: AND sub_exprs: - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|1,TARGET_any_targets|1) solution: - nav_to_goal(any_targets|0, robot_0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/prepare_groceries.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/prepare_groceries.yaml index a68c350339..fd9f011cbb 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/prepare_groceries.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/prepare_groceries.yaml @@ -8,9 +8,9 @@ init: goal: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) - - at(obj1_target|1,TARGET_obj1_target|1) - - at(obj2_target|2,TARGET_obj2_target|2) + - object_at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj1_target|1,TARGET_obj1_target|1) + - object_at(obj2_target|2,TARGET_obj2_target|2) - not_holding(robot_0) objects: @@ -40,31 +40,31 @@ stage_goals: stage_1: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj0_target|0,TARGET_obj0_target|0) - not_holding(robot_0) stage_1_5: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj0_target|0,TARGET_obj0_target|0) - holding(obj1_target|1, robot_0) stage_2: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) - - at(obj1_target|1,TARGET_obj1_target|1) + - object_at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj1_target|1,TARGET_obj1_target|1) - not_holding(robot_0) stage_2_5: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) - - at(obj1_target|1,TARGET_obj1_target|1) + - object_at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj1_target|1,TARGET_obj1_target|1) - holding(obj2_target|2, robot_0) stage_3: expr_type: AND sub_exprs: - - at(obj0_target|0,TARGET_obj0_target|0) - - at(obj1_target|1,TARGET_obj1_target|1) - - at(obj2_target|2,TARGET_obj2_target|2) + - object_at(obj0_target|0,TARGET_obj0_target|0) + - object_at(obj1_target|1,TARGET_obj1_target|1) + - object_at(obj2_target|2,TARGET_obj2_target|2) solution: - nav(obj0_target|0, robot_0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange.yaml index e8637ecc66..473930a03a 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange.yaml @@ -9,7 +9,7 @@ objects: goal: expr_type: AND sub_exprs: - - at(goal0|0,TARGET_goal0|0) + - object_at(goal0|0,TARGET_goal0|0) - not_holding(robot_0) stage_goals: @@ -20,4 +20,4 @@ stage_goals: stage_1: expr_type: AND sub_exprs: - - at(goal0|0,TARGET_goal0|0) + - object_at(goal0|0,TARGET_goal0|0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange_easy.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange_easy.yaml index d9cc34be2d..da78ce73b0 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange_easy.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/rearrange_easy.yaml @@ -9,7 +9,7 @@ objects: goal: expr_type: AND sub_exprs: - - at(goal0|0,TARGET_goal0|0) + - object_at(goal0|0,TARGET_goal0|0) - not_holding(robot_0) stage_goals: @@ -20,7 +20,7 @@ stage_goals: stage_1: expr_type: AND sub_exprs: - - at(goal0|0,TARGET_goal0|0) + - object_at(goal0|0,TARGET_goal0|0) solution: - nav(goal0|0, robot_0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/set_table.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/set_table.yaml index 336354d858..651c75f672 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/set_table.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/set_table.yaml @@ -20,8 +20,8 @@ init: goal: expr_type: AND sub_exprs: - - at(bowl_target|0,TARGET_bowl_target|0) - - at(fruit_target|1,TARGET_fruit_target|1) + - object_at(bowl_target|0,TARGET_bowl_target|0) + - object_at(fruit_target|1,TARGET_fruit_target|1) - not_holding(robot_0) stage_goals: stage_0_5: @@ -31,18 +31,18 @@ stage_goals: stage_1: expr_type: AND sub_exprs: - - at(bowl_target|0,TARGET_bowl_target|0) + - object_at(bowl_target|0,TARGET_bowl_target|0) - not_holding(robot_0) stage_1_5: expr_type: AND sub_exprs: - - at(bowl_target|0,TARGET_bowl_target|0) + - object_at(bowl_target|0,TARGET_bowl_target|0) - holding(fruit_target|1, robot_0) stage_2: expr_type: AND sub_exprs: - - at(bowl_target|0,TARGET_bowl_target|0) - - at(fruit_target|1,TARGET_fruit_target|1) + - object_at(bowl_target|0,TARGET_bowl_target|0) + - object_at(fruit_target|1,TARGET_fruit_target|1) solution: - nav_to_receptacle(cab_push_point_5,bowl_target|0, robot_0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house.yaml index de9e684da1..0c8a788f2a 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house.yaml @@ -30,11 +30,11 @@ init: goal: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) - - at(any_targets|2,TARGET_any_targets|2) - - at(any_targets|3,TARGET_any_targets|3) - - at(any_targets|4,TARGET_any_targets|4) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|2,TARGET_any_targets|2) + - object_at(any_targets|3,TARGET_any_targets|3) + - object_at(any_targets|4,TARGET_any_targets|4) - not_holding(robot_0) stage_goals: @@ -45,55 +45,55 @@ stage_goals: stage_1: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|0,TARGET_any_targets|0) - not_holding(robot_0) stage_1_5: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|0,TARGET_any_targets|0) - holding(any_targets|1, robot_0) stage_2: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) - not_holding(robot_0) stage_2_5: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) - holding(any_targets|2, robot_0) stage_3: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) - - at(any_targets|2,TARGET_any_targets|2) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|2,TARGET_any_targets|2) - not_holding(robot_0) stage_3_5: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) - - at(any_targets|2,TARGET_any_targets|2) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|2,TARGET_any_targets|2) - holding(any_targets|3, robot_0) stage_4: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) - - at(any_targets|2,TARGET_any_targets|2) - - at(any_targets|3,TARGET_any_targets|3) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|2,TARGET_any_targets|2) + - object_at(any_targets|3,TARGET_any_targets|3) - not_holding(robot_0) stage_4_5: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) - - at(any_targets|2,TARGET_any_targets|2) - - at(any_targets|3,TARGET_any_targets|3) - - at(any_targets|4,TARGET_any_targets|4) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|2,TARGET_any_targets|2) + - object_at(any_targets|3,TARGET_any_targets|3) + - object_at(any_targets|4,TARGET_any_targets|4) solution: - nav(any_targets|0, robot_0) diff --git a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house_2obj.yaml b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house_2obj.yaml index 2bb2cb196b..58d3b4128d 100644 --- a/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house_2obj.yaml +++ b/habitat-lab/habitat/config/habitat/task/rearrange/pddl/tidy_house_2obj.yaml @@ -13,8 +13,8 @@ objects: goal: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|1,TARGET_any_targets|1) - not_holding(robot_0) stage_goals: @@ -25,7 +25,7 @@ stage_goals: stage_1_2: expr_type: AND sub_exprs: - - at(any_targets|0,TARGET_any_targets|0) + - object_at(any_targets|0,TARGET_any_targets|0) stage_2_1: expr_type: AND sub_exprs: @@ -33,7 +33,7 @@ stage_goals: stage_2_2: expr_type: AND sub_exprs: - - at(any_targets|1,TARGET_any_targets|1) + - object_at(any_targets|1,TARGET_any_targets|1) solution: - nav_to_goal(any_targets|0, robot_0) - pick(any_targets|0, robot_0) diff --git a/habitat-lab/habitat/datasets/rearrange/navmesh_utils.py b/habitat-lab/habitat/datasets/rearrange/navmesh_utils.py index a2d9f46d7d..e3fa54456e 100644 --- a/habitat-lab/habitat/datasets/rearrange/navmesh_utils.py +++ b/habitat-lab/habitat/datasets/rearrange/navmesh_utils.py @@ -72,7 +72,7 @@ def unoccluded_navmesh_snap( min_sample_dist: float = 0.5, ) -> Optional[mn.Vector3]: """ - Snap a point to the navmesh considering point visibilty via raycasting. + Snap a point to the navmesh considering point visibility via raycasting. :property pos: The 3D position to snap. :property height: The height of the agent above the navmesh. Assumes the navmesh snap point is on the ground. Should be the maximum relative distance from navmesh ground to which a visibility check should indicate non-occlusion. The first check starts from this height. (E.g. agent_eyes_y - agent_base_y) @@ -85,14 +85,15 @@ def unoccluded_navmesh_snap( :property max_samples: The maximum number of attempts to sample navmesh points for the test batch. :property min_sample_dist: The minimum allowed L2 distance between samples in the test batch. - NOTE: this function is based on smapling and does not guarantee the closest point. + NOTE: this function is based on sampling and does not guarantee the closest point. :return: An approximation of the closest unoccluded snap point to pos or None if an unoccluded point could not be found. """ # first try the closest snap point snap_point = pathfinder.snap_point(pos, island_id) - is_occluded = snap_point_is_occluded( + + is_occluded = np.isnan(snap_point[0]) or snap_point_is_occluded( target=pos, snap_point=snap_point, height=height, @@ -114,15 +115,19 @@ def unoccluded_navmesh_snap( sample = pathfinder.get_random_navigable_point_near( circle_center=pos, radius=search_radius, island_index=island_id ) - reject = False - for batch_sample in test_batch: - if np.linalg.norm(sample - batch_sample[0]) < min_sample_dist: - reject = True - break - if not reject: - test_batch.append( - (sample, float(np.linalg.norm(sample - pos))) - ) + if not np.isnan(sample[0]): + reject = False + for batch_sample in test_batch: + if ( + np.linalg.norm(sample - batch_sample[0]) + < min_sample_dist + ): + reject = True + break + if not reject: + test_batch.append( + (sample, float(np.linalg.norm(sample - pos))) + ) sample_count += 1 # sort the test batch points by distance to the target diff --git a/habitat-lab/habitat/datasets/rearrange/samplers/receptacle.py b/habitat-lab/habitat/datasets/rearrange/samplers/receptacle.py index 8b0ddb10a7..1b76a99c8a 100644 --- a/habitat-lab/habitat/datasets/rearrange/samplers/receptacle.py +++ b/habitat-lab/habitat/datasets/rearrange/samplers/receptacle.py @@ -304,6 +304,7 @@ def __init__( parent_object_handle: str = None, parent_link: Optional[int] = None, up: Optional[mn.Vector3] = None, + scale: Union[float, mn.Vector3] = None, ) -> None: """ Initialize the TriangleMeshReceptacle from mesh data and pre-compute the area weighted accumulator. @@ -313,9 +314,19 @@ def __init__( :param parent_object_handle: The rigid or articulated object instance handle for the parent object to which the Receptacle is attached. None for globally defined stage Receptacles. :param parent_link: Index of the link to which the Receptacle is attached if the parent is an ArticulatedObject. -1 denotes the base link. None for rigid objects and stage Receptacles. :param up: The "up" direction of the Receptacle in local AABB space. Used for optionally culling receptacles in un-supportive states such as inverted surfaces. + :param scale: The scaling vector (or uniform scaling float) to be applied to the mesh. """ super().__init__(name, parent_object_handle, parent_link, up) self.mesh_data = mesh_data + + # apply the scale + if scale is not None: + m_verts = self.mesh_data.mutable_attribute( + mn.trade.MeshAttribute.POSITION + ) + for vix, v in enumerate(m_verts): + m_verts[vix] = v * scale + self.area_weighted_accumulator = ( [] ) # normalized float weights for each triangle for sampling @@ -695,6 +706,7 @@ def parse_receptacles_from_user_config( up=up, parent_object_handle=parent_object_handle, parent_link=parent_link_ix, + scale=ao_uniform_scaling, ) ) else: @@ -705,13 +717,42 @@ def parse_receptacles_from_user_config( return receptacles +def cull_filtered_receptacles( + receptacles: List[Receptacle], exclude_filter_strings: List[str] +) -> List[Receptacle]: + """ + Filter a list of Receptacles to exclude any which are matched to the provided exclude_filter_strings. + Each string in filter strings is checked against each receptacle's unique_name. If the unique_name contains any filter string as a substring, that Receptacle is filtered. + + :param receptacles: The initial list of Receptacle objects. + :param exclude_filter_strings: The list of filter substrings defining receptacles which should not be active in the current scene. + + :return: The filtered list of Receptacle objects. Those which contain none of the filter substrings in their unqiue_name. + """ + + filtered_receptacles = [] + for receptacle in receptacles: + culled = False + for filter_substring in exclude_filter_strings: + if filter_substring in receptacle.unique_name: + culled = True + break + if not culled: + filtered_receptacles.append(receptacle) + return filtered_receptacles + + def find_receptacles( - sim: habitat_sim.Simulator, ignore_handles: Optional[List[str]] = None + sim: habitat_sim.Simulator, + ignore_handles: Optional[List[str]] = None, + exclude_filter_strings: Optional[List[str]] = None, ) -> List[Union[Receptacle, AABBReceptacle, TriangleMeshReceptacle]]: """ Scrape and return a list of all Receptacles defined in the metadata belonging to the scene's currently instanced objects. :param sim: Simulator must be provided. + :param ignore_handles: An optional list of handles for ManagedObjects which should be skipped. No Receptacles for matching objects will be returned. + :param exclude_filter_strings: An optional list of excluded Receptacle substrings. Any Receptacle which contains any excluded filter substring in its unique_name will not be included in the returned set. """ obj_mgr = sim.get_rigid_object_manager() @@ -773,6 +814,12 @@ def find_receptacles( ) ) + # filter out individual Receptacles with excluded substrings + if exclude_filter_strings is not None: + receptacles = cull_filtered_receptacles( + receptacles, exclude_filter_strings + ) + # check for non-unique naming mistakes in user dataset for rec_ix in range(len(receptacles)): rec1_unique_name = receptacles[rec_ix].unique_name @@ -795,6 +842,59 @@ class ReceptacleSet: comment: str = "" +def get_scene_rec_filter_filepath( + mm: habitat_sim.metadata.MetadataMediator, scene_handle: str +) -> str: + """ + Look in the user_defined metadata for a scene to find the configured filepath for the scene's Receptacle filter file. + + :return: Filter filepath or None if not found. + """ + scene_user_defined = mm.get_scene_user_defined(scene_handle) + if scene_user_defined is not None and scene_user_defined.has_value( + "scene_filter_file" + ): + scene_filter_file = scene_user_defined.get("scene_filter_file") + scene_filter_file = os.path.join( + os.path.dirname(mm.active_dataset), scene_filter_file + ) + return scene_filter_file + return None + + +def get_excluded_recs_from_filter_file( + rec_filter_filepath: str, filter_types: Optional[List[str]] = None +) -> List[str]: + """ + Load and digest a Receptacle filter file to generate a list of strings which should be excluded from the active ReceptacleSet. + + :param filter_types: Optionally specify a particular set of filter types to scrape. Default is all filters. + """ + + possible_filter_types = [ + "manually_filtered", + "access_filtered", + "stability_filtered", + "height_filtered", + ] + + if filter_types is None: + filter_types = possible_filter_types + else: + for filter_type in filter_types: + assert ( + filter_type in possible_filter_types + ), f"Specified filter type '{filter_type}' is not in supported set: {possible_filter_types}" + + filtered_unique_names = [] + with open(rec_filter_filepath, "r") as f: + filter_json = json.load(f) + for filter_type in filter_types: + for filtered_unique_name in filter_json[filter_type]: + filtered_unique_names.append(filtered_unique_name) + return filtered_unique_names + + class ReceptacleTracker: def __init__( self, @@ -825,34 +925,19 @@ def init_scene_filters( :param mm: The active MetadataMediator instance from which to load the filter data. :param scene_handle: The handle of the currently instantiated scene. """ - scene_user_defined = mm.get_scene_user_defined(scene_handle) - filtered_unique_names = [] - if scene_user_defined is not None and scene_user_defined.has_value( - "scene_filter_file" - ): - scene_filter_file = scene_user_defined.get("scene_filter_file") - # construct the dataset level path for the filter data file - scene_filter_file = os.path.join( - os.path.dirname(mm.active_dataset), scene_filter_file + scene_filter_filepath = get_scene_rec_filter_filepath(mm, scene_handle) + if scene_filter_filepath is not None: + filtered_unique_names = get_excluded_recs_from_filter_file( + scene_filter_filepath ) - with open(scene_filter_file, "r") as f: - filter_json = json.load(f) - for filter_type in [ - "manually_filtered", - "access_filtered", - "stability_filtered", - "height_filtered", - ]: - for filtered_unique_name in filter_json[filter_type]: - filtered_unique_names.append(filtered_unique_name) # add exclusion filters to all receptacles sets for r_set in self._receptacle_sets.values(): r_set.excluded_receptacle_substrings.extend( filtered_unique_names ) - logger.info( - f"Loaded receptacle filter data for scene '{scene_handle}' from configured filter file '{scene_filter_file}'." - ) + logger.info( + f"Loaded receptacle filter data for scene '{scene_handle}' from configured filter file '{scene_filter_filepath}'." + ) else: logger.info( f"Loaded receptacle filter data for scene '{scene_handle}' does not have configured filter file." diff --git a/habitat-lab/habitat/sims/habitat_simulator/debug_visualizer.py b/habitat-lab/habitat/sims/habitat_simulator/debug_visualizer.py index bb6266110d..29feeec659 100644 --- a/habitat-lab/habitat/sims/habitat_simulator/debug_visualizer.py +++ b/habitat-lab/habitat/sims/habitat_simulator/debug_visualizer.py @@ -14,6 +14,7 @@ import habitat_sim from habitat.core.logging import logger from habitat.utils.common import check_make_dir +from habitat_sim.physics import ManagedArticulatedObject, ManagedRigidObject class DebugObservation: @@ -77,6 +78,43 @@ def save(self, output_path: str, prefix: str = "") -> str: return file_path +def draw_object_highlight( + obj: Union[ManagedRigidObject, ManagedArticulatedObject], + debug_line_render: habitat_sim.gfx.DebugLineRender, + camera_transform: mn.Matrix4, + color: mn.Color4 = None, +) -> None: + """ + Draw a circle around the object to highlight it. The circle normal is oriented toward the camera_transform. + + :param obj: The ManagedObject + :param debug_line_render: The DebugLineRender instance for the Simulator. + :param camera_transform: The Matrix4 transform of the camera. Used to orient the circle normal. + :param color: The color of the circle. Default magenta. + """ + + if color is None: + color = mn.Color4.magenta() + + obj_bb = None + if isinstance(obj, ManagedArticulatedObject): + from habitat.sims.habitat_simulator.sim_utilities import get_ao_root_bb + + obj_bb = get_ao_root_bb(obj) + else: + obj_bb = obj.root_scene_node.cumulative_bb + + obj_center = obj.transformation.transform_point(obj_bb.center()) + obj_size = obj_bb.size().max() / 2 + + debug_line_render.draw_circle( + translation=obj_center, + radius=obj_size, + color=color, + normal=camera_transform.translation - obj_center, + ) + + class DebugVisualizer: """ Support class for simple visual debugging of a Simulator instance. diff --git a/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py b/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py new file mode 100644 index 0000000000..730d1dc9d3 --- /dev/null +++ b/habitat-lab/habitat/sims/habitat_simulator/object_state_machine.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from typing import Any, Dict, List, Union + +import magnum as mn + +import habitat.sims.habitat_simulator.sim_utilities as sutils +import habitat_sim +from habitat.sims.habitat_simulator.debug_visualizer import ( + draw_object_highlight, +) +from habitat_sim.physics import ManagedArticulatedObject, ManagedRigidObject + +################################################## +# Supporting utilities for getting and setting metadata values in ManagedObject "user_defined" Configurations. +################################################## + + +def get_state_of_obj( + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + state_name: str, +) -> Any: + """ + Try to get the specified state from an object's "object_states" user_defined metadata. + + :param obj: The ManagedObject. + :param state_name: The name/key of the object state property to query. + + :return: The state value (variable type) or None if not found. + """ + + if "object_states" in obj.user_attributes.get_subconfig_keys(): + obj_states_config = obj.user_attributes.get_subconfig("object_states") + if obj_states_config.has_value(state_name): + return obj_states_config.get(state_name) + return None + + +def set_state_of_obj( + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + state_name: str, + state_val: Any, +) -> None: + """ + Set the specified state in an object's "object_states" user_defined metadata. + + :param obj: The ManagedObject. + :param state_name: The name/key of the object state property to set. + :param state_val: The value of the object state property to set. + """ + + user_attr = obj.user_attributes + obj_state_config = user_attr.get_subconfig("object_states") + obj_state_config.set(state_name, state_val) + user_attr.save_subconfig("object_states", obj_state_config) + + +################################################## +# Object state machine implementation +################################################## + + +class ObjectStateSpec: + """ + Abstract base class for object states specifications. Defines the API for inherited and extended states. + + An ObjectStateSpec is a singleton instance defining the interface and dynamics of a particular metadata state. + + Many ManagedObject instances can share an ObjectStateSpec, but there should be only one for each active Simulator since the state may compute and pivot on global internal caches and variables. + """ + + def __init__(self): + # Each ObjectStateSpec should have a unique name string + self.name = "AbstractState" + # What type of data describes this state + self.type = None + # S list of semantic classes labels with pre-define membership in the state set. All objects in these classes are assumed to have this state, whether or not a value is defined in metadata. + self.accepted_semantic_classes = [] + + def is_affordance_of_obj( + self, obj: Union[ManagedArticulatedObject, ManagedRigidObject] + ) -> bool: + """ + Determine whether or not an object instance can have this ObjectStateSpec by checking semantic class against the configured set. + + :param obj: The ManagedObject instance. + + :return: Whether or not the object has this state affordance. + """ + + # TODO: This is a placeholder until semantic_class can be officially supported or replaced by something else + if ( + get_state_of_obj(obj, "semantic_class") + in self.accepted_semantic_classes + ): + return True + + return False + + def update_state_context(self, sim: habitat_sim.Simulator) -> None: + """ + Update internal state context independent of individual objects' states. + + :param sim: The Simulator instance. + """ + + def update_state( + self, + sim: habitat_sim.Simulator, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + dt: float, + ) -> None: + """ + Add state machine logic to modify the state of an object given access to the Simulator and timestep. + Meant to be called from within the simulation or step loop to continuously update the state. + + :param sim: The Simulator instance. + :param obj: The ManagedObject instance. + :param dt: The timestep over which to update. + """ + + def default_value(self) -> Any: + """ + If an object does not have a value for this state defined, return a default value. + """ + + def draw_context( + self, + debug_line_render: habitat_sim.gfx.DebugLineRender, + camera_transform: mn.Matrix4, + ) -> None: + """ + Draw any context cues which are independent of individual objects' state. + Meant to be called once per draw per ObjectStateSpec singleton. + + :param debug_line_render: The DebugLineRender instance for the Simulator. + :param camera_transform: The Matrix4 camera transform. + """ + + def draw_state( + self, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + debug_line_render: habitat_sim.gfx.DebugLineRender, + camera_transform: mn.Matrix4, + ) -> None: + """ + Logic to draw debug lines visualizing this state for the object. + + :param obj: The ManagedObject instance. + :param debug_line_render: The DebugLineRender instance for the Simulator. + :param camera_transform: The Matrix4 camera transform. + """ + + +class BooleanObjectState(ObjectStateSpec): + """ + Abstract ObjectStateSpec base class for boolean type states. + Defines some standard handling for boolean states. + """ + + def __init__(self): + self.name = "BooleanState" + self.type = bool + + def default_value(self) -> Any: + """ + If an object does not have a value for this state defined, return a default value. + """ + + return True + + def draw_state( + self, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + debug_line_render: habitat_sim.gfx.DebugLineRender, + camera_transform: mn.Matrix4, + ) -> None: + """ + Logic to draw debug lines visualizing this state for the object. + Draws a circle highlight around the object color by state value: green if True, red if False. + + :param obj: The ManagedObject instance. + :param debug_line_render: The DebugLineRender instance for the Simulator. + :param camera_transform: The Matrix4 camera transform. + """ + + obj_state = get_state_of_obj(obj, self.name) + obj_state = self.default_value() if (obj_state is None) else obj_state + + color = mn.Color4.red() + if obj_state: + color = mn.Color4.green() + + draw_object_highlight(obj, debug_line_render, camera_transform, color) + + def toggle( + self, obj: Union[ManagedArticulatedObject, ManagedRigidObject] + ) -> bool: + """ + Toggles a boolean state, returning the newly set value. + + :param obj: The ManagedObject instance. + + :return: The new value of the state. + """ + + cur_state = get_state_of_obj(obj, self.name) + new_state = not cur_state + set_state_of_obj(obj, self.name, new_state) + return new_state + + +class ObjectIsClean(BooleanObjectState): + """ + ObjectIsClean state specifies whether an object is clean or dirty. + """ + + def __init__(self): + super().__init__() + self.name = "is_clean" + # TODO: set the semantic class membership list + self.accepted_semantic_classes = [] + + +class ObjectIsPoweredOn(BooleanObjectState): + """ + State specifies whether an appliance object is powered on or off. + """ + + def __init__(self): + super().__init__() + self.name = "is_powered_on" + # TODO: set the semantic class membership list + self.accepted_semantic_classes = [] + + def default_value(self) -> Any: + """ + Default value for power is off. + """ + + return False + + +class ObjectStateMachine: + """ + Defines the logic for managing multiple states across all objects in the scene. + """ + + def __init__(self, active_states: List[ObjectStateSpec] = None) -> None: + # a list of ObjectStateSpec singleton instances which are active in the current scene + self.active_states = active_states if active_states is not None else [] + # map tracked objects to their set of state properies + self.objects_with_states: Dict[ + str, List[ObjectStateSpec] + ] = defaultdict(lambda: []) + + def initialize_object_state_map(self, sim: habitat_sim.Simulator) -> None: + """ + Reset the objects_with_states dict and re-initializes it by parsing all objects from the scene and checking is_affordance_of_obj for all active ObjectStateSpecs. + + :param sim: The Simulator instance. + """ + + self.objects_with_states = defaultdict(lambda: []) + all_objects = sutils.get_all_objects(sim) + for obj in all_objects: + self.register_object(obj) + + def register_object( + self, obj: Union[ManagedArticulatedObject, ManagedRigidObject] + ) -> None: + """ + Register a single object in the 'objects_with_states' dict by checking 'is_affordance_of_obj' for all active ObjectStateSpecs. + Use this when a new object is added to the scene and needs to be registered. + + :param obj: The ManagedObject instance to register. + """ + + for state in self.active_states: + if state.is_affordance_of_obj(obj): + self.objects_with_states[obj.handle].append(state) + print(f"registered state {state} for object {obj.handle}") + + def update_states(self, sim: habitat_sim.Simulator, dt: float) -> None: + """ + Update all tracked object states for a simulation step. + + :param sim: The Simulator instance. + """ + + # first update any state context + for state in self.active_states: + state.update_state_context(sim) + # then update the individual object states + for obj_handle, states in self.objects_with_states.items(): + if len(states) > 0: + obj = sutils.get_obj_from_handle(sim, obj_handle) + for state in states: + state.update_state(sim, obj, dt) + + def get_snapshot_dict( + self, sim: habitat_sim.Simulator + ) -> Dict[str, Dict[str, Any]]: + """ + Scrape all active ObjectStateSpecs to collect a snapshot of the current state of all objects. + + :return: The state snapshot as a Dict keyed by object state unique name, value is another dict mapping object instance handles to state values. + + Example: + { + "is_powered_on": { + "my_lamp.0001": True, + "my_oven": False, + ... + }, + "is_clean": { + "my_dish.0002:" False, + ... + }, + ... + } + """ + snapshot: Dict[str, Dict[str, Any]] = defaultdict(lambda: {}) + for object_handle, states in self.objects_with_states.items(): + obj = sutils.get_obj_from_handle(sim, object_handle) + for state in states: + obj_state = get_state_of_obj(obj, state.name) + snapshot[state.name][object_handle] = ( + obj_state + if obj_state is not None + else state.default_value() + ) + return dict(snapshot) diff --git a/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py b/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py index 29cc1e927f..9c281d570a 100644 --- a/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py +++ b/habitat-lab/habitat/sims/habitat_simulator/sim_utilities.py @@ -13,6 +13,18 @@ from habitat.sims.habitat_simulator.debug_visualizer import DebugVisualizer +def object_shortname_from_handle(object_handle: str) -> str: + """ + Splits any path directory and instance increment from the handle. + + :param object_handle: The raw object template or instance handle. + + :return: the shortened name string. + """ + + return object_handle.split("/")[-1].split(".")[0].split("_:")[0] + + def register_custom_wireframe_box_template( sim: habitat_sim.Simulator, size: mn.Vector3, @@ -219,7 +231,49 @@ def get_obj_size_along( return local_vec_size, center -def size_regularized_distance( +def size_regularized_bb_distance( + bb_a: mn.Range3D, + bb_b: mn.Range3D, + transform_a: mn.Matrix4 = None, + transform_b: mn.Matrix4 = None, +) -> float: + """ + Get the heuristic surface-to-surface distance between two bounding boxes (regularized by their individual heuristic sizes). + Estimate the distance from center to boundary along the line between bb centers. These sizes are then subtracted from the center-to-center distance as a heuristic for surface-to-surface distance. + + :param bb_a: local bounding box of one object + :param bb_b: local bounding box of another object + :param transform_a: local to global transform for the first object. Default is identity. + :param transform_b: local to global transform for the second object. Default is identity. + + :return: heuristic surface-to-surface distance. + """ + + if transform_a is None: + transform_a = mn.Matrix4.identity_init() + if transform_b is None: + transform_b = mn.Matrix4.identity_init() + + a_center = transform_a.transform_point(bb_a.center()) + b_center = transform_b.transform_point(bb_b.center()) + + disp = a_center - b_center + dist = disp.length() + disp_dir = disp / dist + + local_scale_a = mn.Matrix4.scaling(bb_a.size() / 2.0) + local_vec_a = transform_a.inverted().transform_vector(disp_dir) + local_vec_size_a = local_scale_a.transform_vector(local_vec_a).length() + + local_scale_b = mn.Matrix4.scaling(bb_b.size() / 2.0) + local_vec_b = transform_b.inverted().transform_vector(disp_dir) + local_vec_size_b = local_scale_b.transform_vector(local_vec_b).length() + + # if object bounding boxes are significantly overlapping then distance may be negative, clamp to 0 + return max(0, dist - local_vec_size_a - local_vec_size_b) + + +def size_regularized_object_distance( sim: habitat_sim.Simulator, object_id_a: int, object_id_b: int, @@ -255,23 +309,9 @@ def size_regularized_distance( sim, object_id_b, ao_link_map, ao_aabbs ) - a_center = transform_a.transform_point(obja_bb.center()) - b_center = transform_b.transform_point(objb_bb.center()) - - disp = a_center - b_center - dist = disp.length() - disp_dir = disp / dist - - local_scale_a = mn.Matrix4.scaling(obja_bb.size() / 2.0) - local_vec_a = transform_a.inverted().transform_vector(disp_dir) - local_vec_size_a = local_scale_a.transform_vector(local_vec_a).length() - - local_scale_b = mn.Matrix4.scaling(objb_bb.size() / 2.0) - local_vec_b = transform_b.inverted().transform_vector(disp_dir) - local_vec_size_b = local_scale_b.transform_vector(local_vec_b).length() - - # if object bounding boxes are significantly overlapping then distance may be negative, clamp to 0 - return max(0, dist - local_vec_size_a - local_vec_size_b) + return size_regularized_bb_distance( + obja_bb, objb_bb, transform_a, transform_b + ) def bb_ray_prescreen( @@ -379,6 +419,7 @@ def snap_down( obj: habitat_sim.physics.ManagedRigidObject, support_obj_ids: Optional[List[int]] = None, dbv: Optional[DebugVisualizer] = None, + max_collision_depth: float = 0.01, ) -> bool: """ Attempt to project an object in the gravity direction onto the surface below it. @@ -387,6 +428,7 @@ def snap_down( :param obj: The RigidObject instance. :param support_obj_ids: A list of object ids designated as valid support surfaces for object placement. Contact with other objects is a criteria for placement rejection. If none provided, default support surface is the stage/ground mesh (0). :param dbv: Optionally provide a DebugVisualizer (dbv) to render debug images of each object's computed snap position before collision culling. + :param max_collision_depth: The maximum contact penetration depth between the object and the support surface. Higher values are easier to sample, but result in less dynamically stabile states. :return: boolean placement success. @@ -424,7 +466,9 @@ def snap_down( cp.object_id_a == obj.object_id or cp.object_id_b == obj.object_id ) and ( - (cp.contact_distance < -0.05) + ( + cp.contact_distance < (-1 * max_collision_depth) + ) # contact depth is negative distance or not ( cp.object_id_a in support_obj_ids or cp.object_id_b in support_obj_ids @@ -577,6 +621,54 @@ def get_ao_link_id_map(sim: habitat_sim.Simulator) -> Dict[int, int]: return ao_link_map +def get_ao_default_link( + ao: habitat_sim.physics.ManagedArticulatedObject, + compute_if_not_found: bool = False, +) -> Optional[int]: + """ + Get the "default" link index for a ManagedArticulatedObject. + The "default" link is the one link which should be used if only one joint can be actuated. For example, the largest or most accessible drawer or door. + The default link is determined by: + - must be "prismatic" or "revolute" joint type + - first look in the metadata Configuration for an annotated link. + - (if compute_if_not_found) - if not annotated, it is programmatically computed from a heuristic. + + Default link heuristic: the link with the lowest Y value in the bounding box with appropriate joint type. + + :param compute_if_not_found: If true, try to compute the default link if it isn't found. + + :return: The default link index or None if not found. Cannot be base link (-1). + """ + + # first look in metadata + default_link = ao.user_attributes.get("default_link") + + if default_link is None and compute_if_not_found: + valid_joint_types = [ + habitat_sim.physics.JointType.Revolute, + habitat_sim.physics.JointType.Prismatic, + ] + lowest_link = None + lowest_y: int = None + # compute the default link + for link_id in ao.get_link_ids(): + if ao.get_link_joint_type(link_id) in valid_joint_types: + # use minimum global keypoint Y value + link_lowest_y = min( + get_articulated_link_global_keypoints(ao, link_id), + key=lambda x: x[1], + )[1] + if lowest_y is None or link_lowest_y < lowest_y: + lowest_y = link_lowest_y + lowest_link = link_id + if lowest_link is not None: + default_link = lowest_link + # if found, set in metadata for next time + ao.user_attributes.set("default_link", default_link) + + return default_link + + def get_obj_from_id( sim: habitat_sim.Simulator, obj_id: int, @@ -623,7 +715,7 @@ def get_obj_from_handle( Get a ManagedRigidObject or ManagedArticulatedObject from its instance handle. :param sim: The Simulator instance. - :param obj_handle: object istance handle for which ManagedObject is desired. + :param obj_handle: object instance handle for which ManagedObject is desired. :return: a ManagedObject or None """ @@ -870,7 +962,7 @@ def within( first_voting_keypoint = 0 if center_ensures_containment: - # initialize the list from keypoint 0 (center of bounding box) which gaurantees containment + # initialize the list from keypoint 0 (center of bounding box) which guarantees containment containment_ids = list(keypoint_intersect_set[0]) first_voting_keypoint = 1 @@ -907,7 +999,7 @@ def ontop( ) -> List[int]: """ Get a list of all object ids or objects that are "ontop" of a particular object_a. - Concretely, 'ontop' is defined as: contact points between object_a and objectB have vertical normals "upward" relative to object_a. + Concretely, 'ontop' is defined as: contact points between object_a and object_b have vertical normals "upward" relative to object_a. This function uses collision points to determine which objects are resting on or contacting the surface of object_a. :param sim: The Simulator instance. @@ -1091,25 +1183,25 @@ def get_object_regions( def get_link_normalized_joint_position( - objectA: habitat_sim.physics.ManagedArticulatedObject, link_ix: int + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int ) -> float: """ Normalize the joint limit range [min, max] -> [0,1] and return the current joint state in this range. - :param objectA: The parent ArticulatedObject of the link. + :param object_a: The parent ArticulatedObject of the link. :param link_ix: The index of the link within the parent object. Not the link's object_id. :return: normalized joint position [0,1] """ - assert objectA.get_link_joint_type(link_ix) in [ + assert object_a.get_link_joint_type(link_ix) in [ habitat_sim.physics.JointType.Revolute, habitat_sim.physics.JointType.Prismatic, - ], f"Invalid joint type '{objectA.get_link_joint_type(link_ix)}'. Open/closed not a valid check for multi-dimensional or fixed joints." + ], f"Invalid joint type '{object_a.get_link_joint_type(link_ix)}'. Open/closed not a valid check for multi-dimensional or fixed joints." - joint_pos_ix = objectA.get_link_joint_pos_offset(link_ix) - joint_pos = objectA.joint_positions[joint_pos_ix] - limits = objectA.joint_position_limits + joint_pos_ix = object_a.get_link_joint_pos_offset(link_ix) + joint_pos = object_a.joint_positions[joint_pos_ix] + limits = object_a.joint_position_limits # compute the normalized position [0,1] n_pos = (joint_pos - limits[0][joint_pos_ix]) / ( @@ -1119,7 +1211,7 @@ def get_link_normalized_joint_position( def set_link_normalized_joint_position( - objectA: habitat_sim.physics.ManagedArticulatedObject, + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int, normalized_pos: float, ) -> None: @@ -1128,31 +1220,31 @@ def set_link_normalized_joint_position( Assumes the joint has valid joint limits. - :param objectA: The parent ArticulatedObject of the link. + :param object_a: The parent ArticulatedObject of the link. :param link_ix: The index of the link within the parent object. Not the link's object_id. :param normalized_pos: The normalized position [0,1] to set. """ - assert objectA.get_link_joint_type(link_ix) in [ + assert object_a.get_link_joint_type(link_ix) in [ habitat_sim.physics.JointType.Revolute, habitat_sim.physics.JointType.Prismatic, - ], f"Invalid joint type '{objectA.get_link_joint_type(link_ix)}'. Open/closed not a valid check for multi-dimensional or fixed joints." + ], f"Invalid joint type '{object_a.get_link_joint_type(link_ix)}'. Open/closed not a valid check for multi-dimensional or fixed joints." assert ( normalized_pos <= 1.0 and normalized_pos >= 0 ), "values outside the range [0,1] are by definition beyond the joint limits." - joint_pos_ix = objectA.get_link_joint_pos_offset(link_ix) - limits = objectA.joint_position_limits - joint_positions = objectA.joint_positions + joint_pos_ix = object_a.get_link_joint_pos_offset(link_ix) + limits = object_a.joint_position_limits + joint_positions = object_a.joint_positions joint_positions[joint_pos_ix] = limits[0][joint_pos_ix] + ( normalized_pos * (limits[1][joint_pos_ix] - limits[0][joint_pos_ix]) ) - objectA.joint_positions = joint_positions + object_a.joint_positions = joint_positions def link_is_open( - objectA: habitat_sim.physics.ManagedArticulatedObject, + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int, threshold: float = 0.4, ) -> bool: @@ -1160,18 +1252,18 @@ def link_is_open( Check whether a particular AO link is in the "open" state. We assume that joint limits define the closed state (min) and open state (max). - :param objectA: The parent ArticulatedObject of the link to check. + :param object_a: The parent ArticulatedObject of the link to check. :param link_ix: The index of the link within the parent object. Not the link's object_id. :param threshold: The normalized threshold ratio of joint ranges which are considered "open". E.g. 0.8 = 80% :return: Whether or not the link is considered "open". """ - return get_link_normalized_joint_position(objectA, link_ix) >= threshold + return get_link_normalized_joint_position(object_a, link_ix) >= threshold def link_is_closed( - objectA: habitat_sim.physics.ManagedArticulatedObject, + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int, threshold: float = 0.1, ) -> bool: @@ -1179,41 +1271,138 @@ def link_is_closed( Check whether a particular AO link is in the "closed" state. We assume that joint limits define the closed state (min) and open state (max). - :param objectA: The parent ArticulatedObject of the link to check. + :param object_a: The parent ArticulatedObject of the link to check. :param link_ix: The index of the link within the parent object. Not the link's object_id. :param threshold: The normalized threshold ratio of joint ranges which are considered "closed". E.g. 0.1 = 10% :return: Whether or not the link is considered "closed". """ - return get_link_normalized_joint_position(objectA, link_ix) <= threshold + return get_link_normalized_joint_position(object_a, link_ix) <= threshold def close_link( - objectA: habitat_sim.physics.ManagedArticulatedObject, link_ix: int + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int ) -> None: """ Set a link to the "closed" state. Sets the joint position to the minimum joint limit. TODO: does not do any collision checking to validate the state or move any other objects which may be contained in or supported by this link. - :param objectA: The parent ArticulatedObject of the link to check. + :param object_a: The parent ArticulatedObject of the link to check. :param link_ix: The index of the link within the parent object. Not the link's object_id. """ - set_link_normalized_joint_position(objectA, link_ix, 0) + set_link_normalized_joint_position(object_a, link_ix, 0) def open_link( - objectA: habitat_sim.physics.ManagedArticulatedObject, link_ix: int + object_a: habitat_sim.physics.ManagedArticulatedObject, link_ix: int ) -> None: """ Set a link to the "open" state. Sets the joint position to the maximum joint limit. TODO: does not do any collision checking to validate the state or move any other objects which may be contained in or supported by this link. - :param objectA: The parent ArticulatedObject of the link to check. + :param object_a: The parent ArticulatedObject of the link to check. :param link_ix: The index of the link within the parent object. Not the link's object_id. """ - set_link_normalized_joint_position(objectA, link_ix, 1.0) + set_link_normalized_joint_position(object_a, link_ix, 1.0) + + +def bb_next_to( + bb_a: mn.Range3D, + bb_b: mn.Range3D, + transform_a: mn.Matrix4 = None, + transform_b: mn.Matrix4 = None, + vertical_threshold=0.1, + l2_threshold=0.3, +) -> bool: + """ + Check whether or not two bounding boxes should be considered "next to" one another. + Concretely, consists of two checks: + 1. height difference between the lowest points on the two objects to check that they are approximately resting on the same surface. + 2. regularized L2 distance between object centers. Regularized in this case means displacement vector is truncted by each object's heuristic size. + + :param bb_a: local bounding box of one object + :param bb_b: local bounding box of another object + :param transform_a: local to global transform for the first object. Default is identity. + :param transform_b: local to global transform for the second object. Default is identity. + :param vertical_threshold: vertical distance allowed between objects' lowest points. + :param l2_threshold: regularized L2 distance allow between the objects' centers. + + :return: Whether or not the objects are heuristically "next to" one another. + """ + + if transform_a is None: + transform_a = mn.Matrix4.identity_init() + if transform_b is None: + transform_b = mn.Matrix4.identity_init() + + keypoints_a = get_global_keypoints_from_bb(bb_a, transform_a) + keypoints_b = get_global_keypoints_from_bb(bb_b, transform_b) + + lowest_height_a = min([p[1] for p in keypoints_a]) + lowest_height_b = min([p[1] for p in keypoints_b]) + + if abs(lowest_height_a - lowest_height_b) > vertical_threshold: + return False + + if ( + size_regularized_bb_distance(bb_a, bb_b, transform_a, transform_b) + > l2_threshold + ): + return False + + return True + + +def obj_next_to( + sim: habitat_sim.Simulator, + object_id_a: int, + object_id_b: int, + vertical_threshold=0.1, + l2_threshold=0.5, + ao_link_map: Dict[int, int] = None, + ao_aabbs: Dict[int, mn.Range3D] = None, +) -> bool: + """ + Check whether or not two objects should be considered "next to" one another. + Concretely, consists of two checks: + 1. height difference between the lowest points on the two objects to check that they are approximately resting on the same surface. + 2. regularized L2 distance between object centers. Regularized in this case means displacement vector is truncted by each object's heuristic size. + + :param sim: The Simulator instance. + :param object_id_a: object_id of the first ManagedObject or link. + :param object_id_b: object_id of the second ManagedObject or link. + :param vertical_threshold: vertical distance allowed between objects' lowest points. + :param l2_threshold: regularized L2 distance allow between the objects' centers. This should be tailored to the scenario. + :param ao_link_map: A pre-computed map from link object ids to their parent ArticulatedObject's object id. + :param ao_aabbs: A pre-computed map from ArticulatedObject object_ids to their local bounding boxes. If not provided, recomputed as necessary. + + :return: Whether or not the objects are heuristically "next to" one another. + """ + + assert object_id_a != object_id_b, "Object cannot be 'next to' itself." + + assert ( + object_id_a != habitat_sim.stage_id + and object_id_b != habitat_sim.stage_id + ), "Cannot compute distance between the stage and its contents." + + obja_bb, transform_a = get_bb_for_object_id( + sim, object_id_a, ao_link_map, ao_aabbs + ) + objb_bb, transform_b = get_bb_for_object_id( + sim, object_id_b, ao_link_map, ao_aabbs + ) + + return bb_next_to( + obja_bb, + objb_bb, + transform_a, + transform_b, + vertical_threshold, + l2_threshold, + ) diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/fp.yaml b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/fp.yaml index babc0f3459..7ecf71a787 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/fp.yaml +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/fp.yaml @@ -14,56 +14,66 @@ constants: {} predicates: - - name: in - args: - - name: obj - expr_type: obj_type - - name: receptacle - expr_type: art_receptacle_entity_type - set_state: - obj_states: - obj: receptacle - - name: holding args: - name: obj expr_type: movable_entity_type - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - holding: obj + is_valid_fn: + # Function that checks if the predicate is satisfied. See + # `pddl_defined_predicates.py` for the pre-defined instances of these + # functions. + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_hold_match + hold_state: True + set_state_fn: + # Funtion that sets simulator state based on predicate arguments. + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_holding + hold_state: True + + - name: in + args: + - name: obj + expr_type: obj_type + - name: recep + expr_type: art_receptacle_entity_type + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_inside - name: not_holding args: - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - should_drop: True - + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_hold_match + hold_state: False + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_holding + hold_state: False - name: robot_at args: - - name: Y + - name: at_entity expr_type: static_obj_type - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - pos: Y - - - name: at + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_at_position + dist_thresh: 2.0 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_position + dist_thresh: 2.0 + + - name: object_at args: - name: obj - expr_type: obj_type + expr_type: movable_entity_type - name: at_entity expr_type: static_obj_type - set_state: - obj_states: - obj: at_entity + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_object_at + dist_thresh: 0.3 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_object_at actions: - name: nav_to_goal @@ -197,7 +207,7 @@ actions: - robot_at(obj, robot) postcondition: - not_holding(robot) - - at(place_obj, obj) + - object_at(place_obj, obj) task_info: task: RearrangePlaceTask-v0 task_def: "place" diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml index 847941f8e5..69edab0cab 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/domain_configs/replica_cad.yaml @@ -9,7 +9,6 @@ types: - cab_type - fridge_type - constants: - name: cab_push_point_7 expr_type: cab_type @@ -23,100 +22,116 @@ constants: expr_type: fridge_type predicates: - - name: in - args: - - name: obj - expr_type: obj_type - - name: receptacle - expr_type: art_receptacle_entity_type - set_state: - obj_states: - obj: receptacle - - name: holding args: - name: obj expr_type: movable_entity_type - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - holding: obj + is_valid_fn: + # Function that checks if the predicate is satisfied. See + # `pddl_defined_predicates.py` for the pre-defined instances of these + # functions. + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_hold_match + hold_state: True + set_state_fn: + # Funtion that sets simulator state based on predicate arguments. + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_holding + hold_state: True + + - name: in + args: + - name: obj + expr_type: obj_type + - name: recep + expr_type: art_receptacle_entity_type + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_inside - name: not_holding args: - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - should_drop: True + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_hold_match + hold_state: False + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_holding + hold_state: False - name: opened_cab args: - - name: cab_id + - name: art_obj expr_type: cab_type - set_state: - art_states: - cab_id: - value: 0.45 - cmp: 'greater' - override_thresh: 0.1 + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_articulated_object_at_state + target_val: 0.45 + cmp: 'greater' + joint_dist_thresh: 0.1 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_articulated_object_at_state + target_val: 0.45 - name: closed_cab args: - - name: cab_id + - name: art_obj expr_type: cab_type - set_state: - arg_spec: - name_match: "cab" - art_states: - cab_id: - value: 0.0 - cmp: 'close' - + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_articulated_object_at_state + target_val: 0.0 + cmp: 'close' + joint_dist_thresh: 0.15 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_articulated_object_at_state + target_val: 0.0 - name: opened_fridge args: - - name: fridge_id + - name: art_obj expr_type: fridge_type - set_state: - art_states: - fridge_id: - value: 1.22 - cmp: 'greater' + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_articulated_object_at_state + target_val: 1.22 + cmp: 'greater' + joint_dist_thresh: 0.15 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_articulated_object_at_state + target_val: 1.22 - name: closed_fridge args: - - name: fridge_id + - name: art_obj expr_type: fridge_type - set_state: - art_states: - fridge_id: - value: 0.0 - cmp: 'close' + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_articulated_object_at_state + target_val: 0.0 + cmp: 'close' + joint_dist_thresh: 0.15 - name: robot_at args: - - name: Y + - name: at_entity expr_type: static_obj_type - - name: robot_id + - name: robot expr_type: robot_entity_type - set_state: - robot_states: - robot_id: - pos: Y + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_robot_at_position + dist_thresh: 2.0 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_robot_position + dist_thresh: 2.0 - - name: at + - name: object_at args: - name: obj expr_type: movable_entity_type - name: at_entity expr_type: static_obj_type - set_state: - obj_states: - obj: at_entity + is_valid_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.is_object_at + dist_thresh: 0.3 + set_state_fn: + _target_: habitat.tasks.rearrange.multi_task.pddl_defined_predicates.set_object_at actions: - name: nav @@ -180,7 +195,7 @@ actions: - robot_at(obj, robot) postcondition: - not_holding(robot) - - at(place_obj, obj) + - object_at(place_obj, obj) - name: open_fridge parameters: diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_defined_predicates.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_defined_predicates.py new file mode 100644 index 0000000000..42e104e54a --- /dev/null +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_defined_predicates.py @@ -0,0 +1,391 @@ +from typing import Optional, cast + +import magnum as mn +import numpy as np + +import habitat_sim +from habitat.sims.habitat_simulator.sim_utilities import get_ao_global_bb +from habitat.tasks.rearrange.marker_info import MarkerInfo +from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( + PddlEntity, + PddlSimInfo, + SimulatorObjectType, +) +from habitat.tasks.rearrange.utils import ( + place_agent_at_dist_from_pos, + rearrange_logger, +) + +# TODO: Deprecate these and instead represent them as articulated object entity type. +CAB_TYPE = "cab_type" +FRIDGE_TYPE = "fridge_type" + + +def is_robot_hold_match( + robot: PddlEntity, + sim_info: PddlSimInfo, + hold_state: bool, + obj: Optional[PddlEntity] = None, +) -> bool: + """ + Check if the robot is holding the desired object in the desired hold state. + :param hold_state: True if the robot should be holding the object. + """ + + robot_id = cast( + int, + sim_info.search_for_entity(robot), + ) + grasp_mgr = sim_info.sim.get_agent_data(robot_id).grasp_mgr + + if hold_state: + if obj is not None: + # Robot must hold specific object. + obj_idx = cast(int, sim_info.search_for_entity(obj)) + abs_obj_id = sim_info.sim.scene_obj_ids[obj_idx] + return grasp_mgr.snap_idx == abs_obj_id + else: + # Robot can hold any object. + return grasp_mgr.snap_idx != None + else: + # Robot must hold no object. + return grasp_mgr.snap_idx == None + + +def set_robot_holding( + robot: PddlEntity, + sim_info: PddlSimInfo, + hold_state: bool, + obj: Optional[PddlEntity] = None, +) -> None: + robot_id = cast( + int, + sim_info.search_for_entity(robot), + ) + sim = sim_info.sim + agent_data = sim.get_agent_data(robot_id) + # Set the snapped object information + if not hold_state and agent_data.grasp_mgr.is_grasped: + agent_data.grasp_mgr.desnap(True) + elif hold_state: + if obj is None: + raise ValueError( + f"If setting hold state {hold_state=}, must set object" + ) + # Swap objects to the desired object. + obj_idx = cast(int, sim_info.search_for_entity(obj)) + agent_data.grasp_mgr.desnap(True) + sim.internal_step(-1) + agent_data.grasp_mgr.snap_to_obj(sim.scene_obj_ids[obj_idx]) + sim.internal_step(-1) + + +def is_inside( + obj: PddlEntity, recep: PddlEntity, sim_info: PddlSimInfo +) -> bool: + """ + Check if an entity is inside the receptacle. + """ + + assert sim_info.check_type_matches( + recep, SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value + ), f"Bad type {recep=}" + + entity_pos = sim_info.get_entity_pos(obj) + check_marker = cast( + MarkerInfo, + sim_info.search_for_entity(recep), + ) + # Hack to see if an object is inside the fridge. + if sim_info.check_type_matches(recep, FRIDGE_TYPE): + global_bb = get_ao_global_bb(check_marker.ao_parent) + else: + bb = check_marker.link_node.cumulative_bb + global_bb = habitat_sim.geo.get_transformed_bb( + bb, check_marker.link_node.transformation + ) + + return global_bb.contains(entity_pos) + + +def is_robot_at_position( + at_entity, + sim_info, + dist_thresh: float, + robot=None, + angle_thresh: Optional[float] = None, +): + if robot is None: + robot = sim_info.sim.get_agent_data(None).articulated_agent + else: + robot_id = cast( + int, + sim_info.search_for_entity(robot), + ) + robot = sim_info.sim.get_agent_data(robot_id).articulated_agent + targ_pos = sim_info.get_entity_pos(at_entity) + + # Get the base transformation + T = robot.base_transformation + # Do transformation + pos = T.inverted().transform_point(targ_pos) + # Project to 2D plane (x,y,z=0) + pos[2] = 0.0 + + # Compute distance + dist = np.linalg.norm(pos) + + # Unit vector of the pos + pos = pos.normalized() + # Define the coordinate of the robot + pos_robot = np.array([1.0, 0.0, 0.0]) + # Get the angle + angle = np.arccos(np.dot(pos, pos_robot)) + + # Check the distance threshold. + if dist > dist_thresh: + return False + + # Check for the angle threshold + if angle_thresh is not None and np.abs(angle) > angle_thresh: + return False + + return True + + +def set_robot_position( + at_entity: PddlEntity, + sim_info: PddlSimInfo, + dist_thresh: float, + robot: Optional[PddlEntity] = None, + filter_colliding_states: bool = True, + angle_noise: float = 0.0, + num_spawn_attempts: int = 200, +): + """ + Set the robot transformation to be within `dist_thresh` of `at_entity`. + """ + + sim = sim_info.sim + if robot is None: + agent_data = sim.get_agent_data(None) + else: + robot_id = cast( + int, + sim_info.search_for_entity(robot), + ) + agent_data = sim.get_agent_data(robot_id) + targ_pos = sim_info.get_entity_pos(at_entity) + + # Place some distance away from the object. + start_pos, start_rot, was_fail = place_agent_at_dist_from_pos( + target_position=targ_pos, + rotation_perturbation_noise=angle_noise, + distance_threshold=dist_thresh, + sim=sim, + num_spawn_attempts=num_spawn_attempts, + filter_colliding_states=filter_colliding_states, + agent=agent_data.articulated_agent, + ) + agent_data.articulated_agent.base_pos = start_pos + agent_data.articulated_agent.base_rot = start_rot + if was_fail: + rearrange_logger.error("Failed to place the robot.") + + # We teleported the agent. We also need to teleport the object the agent was holding. + agent_data.grasp_mgr.update_object_to_grasp() + + +def is_object_at( + obj: PddlEntity, + at_entity: PddlEntity, + sim_info: PddlSimInfo, + dist_thresh: float, +) -> bool: + """ + Checks if an object entity is logically at another entity. At an object + means within a threshold of that object. At a receptacle means on the + receptacle. At a articulated receptacle means inside of it. + """ + + entity_pos = sim_info.get_entity_pos(obj) + + if sim_info.check_type_matches( + at_entity, SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value + ): + # Object is rigid and target is receptacle, we are checking if + # an object is inside of a receptacle. + return is_inside(obj, at_entity, sim_info) + elif sim_info.check_type_matches( + at_entity, SimulatorObjectType.GOAL_ENTITY.value + ) or sim_info.check_type_matches( + at_entity, SimulatorObjectType.MOVABLE_ENTITY.value + ): + # Is the target `at_entity` a movable or goal entity? + targ_idx = cast( + int, + sim_info.search_for_entity(at_entity), + ) + idxs, pos_targs = sim_info.sim.get_targets() + targ_pos = pos_targs[list(idxs).index(targ_idx)] + + dist = float(np.linalg.norm(entity_pos - targ_pos)) + return dist < dist_thresh + elif sim_info.check_type_matches( + at_entity, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value + ): + # TODO: Fix this logic to be using + # habitat/sims/habitat_simulator/sim_utilities.py + recep = cast(mn.Range3D, sim_info.search_for_entity(at_entity)) + return recep.contains(entity_pos) + else: + raise ValueError( + f"Got unexpected combination of {obj} and {at_entity}" + ) + + +def set_object_at( + obj: PddlEntity, + at_entity: PddlEntity, + sim_info: PddlSimInfo, + recep_place_shrink_factor: float = 0.8, +) -> None: + """ + Sets a movable PDDL entity to match the transformation of a desired + `at_entity` which can be a receptacle or goal. + + :param recep_place_shrink_factor: How much to shrink the size of the + receptacle by when placing the entity on a receptacle. + """ + + sim = sim_info.sim + + # The source object must be movable. + if not sim_info.check_type_matches( + obj, SimulatorObjectType.MOVABLE_ENTITY.value + ): + raise ValueError(f"Got unexpected obj {obj}") + + if sim_info.check_type_matches( + at_entity, SimulatorObjectType.GOAL_ENTITY.value + ): + targ_idx = cast( + int, + sim_info.search_for_entity(at_entity), + ) + all_targ_idxs, pos_targs = sim.get_targets() + targ_pos = pos_targs[list(all_targ_idxs).index(targ_idx)] + set_T = mn.Matrix4.translation(targ_pos) + elif sim_info.check_type_matches( + at_entity, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value + ): + # Place object on top of receptacle. + recep = cast(mn.Range3D, sim_info.search_for_entity(at_entity)) + + # Divide by 2 because the `from_center` creates from the half size. + shrunk_recep = mn.Range3D.from_center( + recep.center(), + (recep.size() / 2.0) * recep_place_shrink_factor, + ) + pos = np.random.uniform(shrunk_recep.min, shrunk_recep.max) + set_T = mn.Matrix4.translation(pos) + else: + raise ValueError(f"Got unexpected at_entity {at_entity}") + + obj_idx = cast(int, sim_info.search_for_entity(obj)) + abs_obj_id = sim.scene_obj_ids[obj_idx] + + # Get the object id corresponding to this name + rom = sim.get_rigid_object_manager() + set_obj = rom.get_object_by_id(abs_obj_id) + set_obj.transformation = set_T + set_obj.angular_velocity = mn.Vector3.zero_init() + set_obj.linear_velocity = mn.Vector3.zero_init() + sim.internal_step(-1) + set_obj.angular_velocity = mn.Vector3.zero_init() + set_obj.linear_velocity = mn.Vector3.zero_init() + + +def is_articulated_object_at_state( + art_obj: PddlEntity, + sim_info: PddlSimInfo, + target_val: float, + cmp: str, + joint_dist_thresh: float = 0.1, +) -> bool: + """ + Checks if an articulated object matches a joint state condition. + + :param cmp: The comparison to use. Can be "greater", "lesser", or "close". + """ + + if not sim_info.check_type_matches( + art_obj, + SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value, + ): + raise ValueError(f"Got unexpected entity {art_obj}") + marker = cast( + MarkerInfo, + sim_info.search_for_entity( + art_obj, + ), + ) + cur_value = marker.get_targ_js() + if cmp == "greater": + return cur_value > target_val - joint_dist_thresh + elif cmp == "lesser": + return cur_value < target_val + joint_dist_thresh + elif cmp == "close": + return abs(cur_value - target_val) < joint_dist_thresh + else: + raise ValueError(f"Unrecognized comparison {cmp}") + + +def set_articulated_object_at_state( + art_obj: PddlEntity, sim_info: PddlSimInfo, target_val: float +) -> None: + """ + Sets an articulated object joint state to `target_val`. + """ + + sim = sim_info.sim + rom = sim.get_rigid_object_manager() + + in_pred = sim_info.get_predicate("in") + poss_entities = [ + e + for e in sim_info.all_entities.values() + if e.expr_type.is_subtype_of( + sim_info.expr_types[SimulatorObjectType.MOVABLE_ENTITY.value] + ) + ] + + move_objs = [] + for poss_entity in poss_entities: + bound_in_pred = in_pred.clone() + bound_in_pred.set_param_values([poss_entity, art_obj]) + if not bound_in_pred.is_true(sim_info): + continue + obj_idx = cast( + int, + sim_info.search_for_entity(poss_entity), + ) + abs_obj_id = sim.scene_obj_ids[obj_idx] + set_obj = rom.get_object_by_id(abs_obj_id) + move_objs.append(set_obj) + + marker = cast( + MarkerInfo, + sim_info.search_for_entity( + art_obj, + ), + ) + pre_link_pos = marker.link_node.transformation.translation + marker.set_targ_js(target_val) + post_link_pos = marker.link_node.transformation.translation + + if art_obj.expr_type.is_subtype_of(sim_info.expr_types[CAB_TYPE]): + # Also move all objects that were in the drawer + diff_pos = post_link_pos - pre_link_pos + for move_obj in move_objs: + move_obj.translation += diff_pos diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_domain.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_domain.py index 36ccad8987..ccfdd173a7 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_domain.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_domain.py @@ -4,9 +4,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import importlib import itertools import os.path as osp import time +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -22,8 +24,6 @@ import yaml # type: ignore[import] from habitat.config.default import get_full_habitat_config_path -from habitat.core.dataset import Episode -from habitat.datasets.rearrange.rearrange_dataset import RearrangeDatasetV0 from habitat.tasks.rearrange.multi_task.pddl_action import PddlAction from habitat.tasks.rearrange.multi_task.pddl_logical_expr import ( LogicalExpr, @@ -31,11 +31,6 @@ LogicalQuantifierType, ) from habitat.tasks.rearrange.multi_task.pddl_predicate import Predicate -from habitat.tasks.rearrange.multi_task.pddl_sim_state import ( - ArtSampler, - PddlRobotState, - PddlSimState, -) from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( ExprType, PddlEntity, @@ -71,19 +66,6 @@ def __init__( self._config = cur_task_config self._orig_actions: Dict[str, PddlAction] = {} - if read_config: - # Setup config properties - self._obj_succ_thresh = self._config.obj_succ_thresh - self._art_succ_thresh = self._config.art_succ_thresh - self._robot_at_thresh = self._config.robot_at_thresh - self._num_spawn_attempts = self._config.num_spawn_attempts - self._filter_colliding_states = ( - self._config.filter_colliding_states - ) - self._recep_place_shrink_factor = ( - self._config.recep_place_shrink_factor - ) - if not osp.isabs(domain_file_path): parent_dir = osp.dirname(__file__) domain_file_path = osp.join( @@ -155,41 +137,23 @@ def _parse_predicates(self, domain_def) -> None: PddlEntity(arg["name"], self.expr_types[arg["expr_type"]]) for arg in pred_d["args"] ] - pred_entities = {e.name: e for e in arg_entities} - art_states = pred_d["set_state"].get("art_states", {}) - obj_states = pred_d["set_state"].get("obj_states", {}) - robot_states = pred_d["set_state"].get("robot_states", {}) - - all_entities = {**self.all_entities, **pred_entities} - - art_states = { - all_entities[k]: ArtSampler(**v) for k, v in art_states.items() - } - obj_states = { - all_entities[k]: all_entities[v] for k, v in obj_states.items() - } - - use_robot_states = {} - - def fetch_entity(s): - # Fetches the corresponding entity if the argument is a string - # referring to an entity. - if isinstance(s, str): - return all_entities.get(s, s) - else: - return s - - for k, v in robot_states.items(): - use_k = all_entities[k] - # Sub in any referred entities. - v = {sub_k: fetch_entity(sub_v) for sub_k, sub_v in v.items()} - - use_robot_states[use_k] = PddlRobotState(**v) + if "set_state_fn" not in pred_d: + set_state_fn = None + else: + set_state_fn = _parse_callable(pred_d["set_state_fn"]) - set_state = PddlSimState(art_states, obj_states, use_robot_states) + if "is_valid_fn" not in pred_d: + is_valid_fn = None + else: + is_valid_fn = _parse_callable(pred_d["is_valid_fn"]) - pred = Predicate(pred_d["name"], set_state, arg_entities) + pred = Predicate( + pred_d["name"], + is_valid_fn, + set_state_fn, + arg_entities, + ) self.predicates[pred.name] = pred def _parse_constants(self, domain_def) -> None: @@ -215,6 +179,8 @@ def register_type(self, expr_type: ExprType): def register_episode_entity(self, pddl_entity: PddlEntity) -> None: """ Add an entity to appear in `self.all_entities`. Clears every episode. + Note that `pddl_entity.name` should be unique. Otherwise, it will + overide the existing object with that name. """ self._added_entities[pddl_entity.name] = pddl_entity @@ -224,14 +190,23 @@ def _parse_expr_types(self, domain_def): """ # Always add the default `expr_types` from the simulator. + base_entity = ExprType(SimulatorObjectType.BASE_ENTITY.value, None) self._expr_types: Dict[str, ExprType] = { - obj_type.value: ExprType(obj_type.value, None) - for obj_type in SimulatorObjectType + SimulatorObjectType.BASE_ENTITY.value: base_entity } + self._expr_types.update( + { + obj_type.value: ExprType(obj_type.value, base_entity) + for obj_type in SimulatorObjectType + if obj_type.value != SimulatorObjectType.BASE_ENTITY.value + } + ) for parent_type, sub_types in domain_def["types"].items(): if parent_type not in self._expr_types: - self._expr_types[parent_type] = ExprType(parent_type, None) + self._expr_types[parent_type] = ExprType( + parent_type, base_entity + ) for sub_type in sub_types: if sub_type in self._expr_types: self._expr_types[sub_type].parent = self._expr_types[ @@ -250,13 +225,18 @@ def expr_types(self) -> Dict[str, ExprType]: return {**self._expr_types, **self._added_expr_types} def parse_predicate( - self, pred_str: str, existing_entities: Dict[str, PddlEntity] + self, + pred_str: str, + existing_entities: Optional[Dict[str, PddlEntity]] = None, ) -> Predicate: """ Instantiates a predicate from call in string such as "in(X,Y)". :param pred_str: The string to parse such as "in(X,Y)". - :param existing_entities: The valid entities for arguments in the predicate. + :param existing_entities: The valid entities for arguments in the + predicate. If not specified, uses all defined entities. """ + if existing_entities is None: + existing_entities = {} func_name, func_args = parse_func(pred_str) pred = self.predicates[func_name].clone() @@ -336,9 +316,7 @@ def _parse_expr( def bind_to_instance( self, sim: RearrangeSim, - dataset: RearrangeDatasetV0, env: RearrangeTask, - episode: Episode, ) -> None: """ Attach the domain to the simulator. This does not bind any entity @@ -355,12 +333,7 @@ def bind_to_instance( self._sim_info = PddlSimInfo( sim=sim, - dataset=dataset, env=env, - episode=episode, - obj_thresh=self._obj_succ_thresh, - art_thresh=self._art_succ_thresh, - robot_at_thresh=self._robot_at_thresh, expr_types=self.expr_types, obj_ids=sim.handle_to_object_id, target_ids={ @@ -375,10 +348,7 @@ def bind_to_instance( }, all_entities=self.all_entities, predicates=self.predicates, - num_spawn_attempts=self._num_spawn_attempts, - filter_colliding_states=self._filter_colliding_states, receptacles=sim.receptacles, - recep_place_shrink_factor=self._recep_place_shrink_factor, ) # Ensure that all objects are accounted for. for entity in self.all_entities.values(): @@ -465,8 +435,7 @@ def get_possible_predicates(self) -> List[Predicate]: use_pred = pred.clone() use_pred.set_param_values(entity_input) - if use_pred.are_types_compatible(self.expr_types): - poss_preds.append(use_pred) + poss_preds.append(use_pred) return sorted(poss_preds, key=lambda pred: pred.compact_str) def get_possible_actions( @@ -582,9 +551,11 @@ def expand_quantifiers( """ expr.sub_exprs = [ - self.expand_quantifiers(subexpr)[0] - if isinstance(subexpr, LogicalExpr) - else subexpr + ( + self.expand_quantifiers(subexpr)[0] + if isinstance(subexpr, LogicalExpr) + else subexpr + ) for subexpr in expr.sub_exprs ] @@ -706,3 +677,12 @@ def solution(self): @property def all_entities(self) -> Dict[str, PddlEntity]: return {**self._objects, **super().all_entities} + + +def _parse_callable(callable_d): + full_fn_name = callable_d.pop("_target_") + module_name, _, function_name = full_fn_name.rpartition(".") + module = importlib.import_module(module_name) + fn = getattr(module, function_name) + + return partial(fn, **callable_d) diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_logical_expr.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_logical_expr.py index f865328171..b58d67a91d 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_logical_expr.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_logical_expr.py @@ -93,7 +93,10 @@ def _is_true(self, is_true_fn) -> bool: result = True for i, sub_expr in enumerate(self._sub_exprs): truth_val = is_true_fn(sub_expr) - assert isinstance(truth_val, bool) + if not isinstance(truth_val, bool): + raise ValueError( + f"Predicate returned non truth value: {sub_expr=}, {truth_val=}" + ) self._truth_vals[i] = truth_val result = result and truth_val if not result: diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_predicate.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_predicate.py index 7f96367ecd..e81de559c4 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_predicate.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_predicate.py @@ -3,11 +3,9 @@ # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional -from habitat.tasks.rearrange.multi_task.pddl_sim_state import PddlSimState from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( - ExprType, PddlEntity, PddlSimInfo, do_entity_lists_match, @@ -21,21 +19,31 @@ class Predicate: def __init__( self, name: str, - pddl_sim_state: Optional[PddlSimState], + is_valid_fn: Optional[Callable], + set_state_fn: Optional[Callable], args: List[PddlEntity], ): """ :param name: Predicate identifier. Does not need to be unique because predicates have the same name but different arguments. - :param pddl_sim_state: Optionally specifies conditions that must be - true in the simulator for the predicate to be true. If None is - specified, no simulator state will force the Predicate to be true. + :param is_valid_fn: Function that returns if the predicate is true in + the current state. This function must return a bool and + take as input the predicate parameters specified by `args`. If + None, then this always returns True. + :param set_state_fn: Function that sets the state to satisfy the + predicate. This function must return nothing and take as input the + values set in the predicate parameters specified by `args`. If + None, then no simulator state is set. + :param args: The names of the arguments to the predicate. Note that + these are only placeholders. Actual entities are substituted in later + via `self.set_param_values`. """ self._name = name - self._pddl_sim_state = pddl_sim_state self._args = args self._arg_values = None + self._is_valid_fn = is_valid_fn + self._set_state_fn = set_state_fn def are_args_compatible(self, arg_values: List[PddlEntity]): """ @@ -45,15 +53,6 @@ def are_args_compatible(self, arg_values: List[PddlEntity]): return do_entity_lists_match(self._args, arg_values) - def are_types_compatible(self, expr_types: Dict[str, ExprType]) -> bool: - """ - Returns if the argument types match the underlying simulator state. - """ - if self._pddl_sim_state is None: - return True - - return self._pddl_sim_state.is_compatible(expr_types) - def set_param_values(self, arg_values: List[PddlEntity]) -> None: arg_values = list(arg_values) if self._arg_values is not None: @@ -62,7 +61,6 @@ def set_param_values(self, arg_values: List[PddlEntity]) -> None: ) ensure_entity_lists_match(self._args, arg_values) self._arg_values = arg_values - self._pddl_sim_state.sub_in(dict(zip(self._args, self._arg_values))) @property def n_args(self): @@ -77,13 +75,13 @@ def sub_in(self, sub_dict: Dict[PddlEntity, PddlEntity]) -> "Predicate": sub_dict.get(entity, entity) for entity in self._arg_values ] ensure_entity_lists_match(self._args, self._arg_values) - self._pddl_sim_state.sub_in(sub_dict) return self def sub_in_clone(self, sub_dict: Dict[PddlEntity, PddlEntity]): p = Predicate( self._name, - self._pddl_sim_state.sub_in_clone(sub_dict), + self._is_valid_fn, + self._set_state_fn, self._args, ) if self._arg_values is not None: @@ -107,7 +105,12 @@ def is_true(self, sim_info: PddlSimInfo) -> bool: return sim_info.pred_truth_cache[self_repr] # Recompute and potentially cache the result. - result = self._pddl_sim_state.is_true(sim_info) + if self._is_valid_fn is None: + result = True + else: + result = self._is_valid_fn( + sim_info=sim_info, **self._create_kwargs() + ) if sim_info.pred_truth_cache is not None: sim_info.pred_truth_cache[self_repr] = result return result @@ -116,10 +119,18 @@ def set_state(self, sim_info: PddlSimInfo) -> None: """ Sets the simulator state to satisfy the predicate. """ - return self._pddl_sim_state.set_state(sim_info) + if self._set_state_fn is not None: + self._set_state_fn(sim_info=sim_info, **self._create_kwargs()) + + def _create_kwargs(self): + return { + arg.name: val for arg, val in zip(self._args, self._arg_values) + } def clone(self): - p = Predicate(self._name, self._pddl_sim_state.clone(), self._args) + p = Predicate( + self._name, self._is_valid_fn, self._set_state_fn, self._args + ) if self._arg_values is not None: p.set_param_values(self._arg_values) return p diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_sim_state.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_sim_state.py deleted file mode 100644 index 04f5b38039..0000000000 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_sim_state.py +++ /dev/null @@ -1,608 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Meta Platforms, Inc. and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass, replace -from typing import Any, Dict, Optional, cast - -import magnum as mn -import numpy as np - -import habitat_sim -from habitat.sims.habitat_simulator.sim_utilities import get_ao_global_bb -from habitat.tasks.rearrange.marker_info import MarkerInfo -from habitat.tasks.rearrange.multi_task.rearrange_pddl import ( - PddlEntity, - PddlSimInfo, - SimulatorObjectType, -) -from habitat.tasks.rearrange.utils import ( - place_agent_at_dist_from_pos, - rearrange_logger, -) - -# TODO: Deprecate these and instead represent them as articulated object entity type. -CAB_TYPE = "cab_type" -FRIDGE_TYPE = "fridge_type" - - -class ArtSampler: - """ - Desired simulator state for a articulated object. Expresses a range of - allowable joint values. - """ - - def __init__( - self, value: float, cmp: str, override_thresh: Optional[float] = None - ): - self.value = value - self.cmp = cmp - self.override_thresh = override_thresh - - def is_satisfied(self, cur_value: float, thresh: float) -> bool: - if self.override_thresh is not None: - thresh = self.override_thresh - - if self.cmp == "greater": - return cur_value > self.value - thresh - elif self.cmp == "less": - return cur_value < self.value + thresh - elif self.cmp == "close": - return abs(cur_value - self.value) < thresh - else: - raise ValueError(f"Unrecognized comparison {self.cmp}") - - def sample(self) -> float: - return self.value - - -@dataclass -class PddlRobotState: - """ - Specifies the configuration of the robot. - - :property place_at_pos_dist: If -1.0, this will place the robot as close - as possible to the entity. Otherwise, it will place the robot within X - meters of the entity. If unset, sets to task default. - :property base_angle_noise: How much noise to add to the robot base angle - when setting the robot base position. If not set, sets to task default. - :property place_at_angle_thresh: The required maximum angle to the target - entity in the robot's local frame. Specified in radains. If not specified, - no angle is considered. - :property filter_colliding_states: Whether or not to filter colliding states when placing the - robot. If not set, sets to task default. - """ - - holding: Optional[PddlEntity] = None - should_drop: bool = False - pos: Optional[Any] = None - place_at_pos_dist: Optional[float] = None - place_at_angle_thresh: Optional[float] = None - base_angle_noise: Optional[float] = None - filter_colliding_states: Optional[bool] = None - - def get_place_at_pos_dist(self, sim_info) -> float: - if self.place_at_pos_dist is None: - return sim_info.robot_at_thresh - else: - return self.place_at_pos_dist - - def get_base_angle_noise(self, sim_info) -> float: - if self.base_angle_noise is None: - return 0.0 - return self.base_angle_noise - - def get_filter_colliding_states(self, sim_info) -> Optional[bool]: - if self.filter_colliding_states is None: - return sim_info.filter_colliding_states - else: - return self.filter_colliding_states - - def sub_in( - self, sub_dict: Dict[PddlEntity, PddlEntity] - ) -> "PddlRobotState": - self.holding = sub_dict.get(self.holding, self.holding) - self.pos = sub_dict.get(self.pos, self.pos) - return self - - def sub_in_clone( - self, sub_dict: Dict[PddlEntity, PddlEntity] - ) -> "PddlRobotState": - other = replace(self) - other.holding = sub_dict.get(self.holding, self.holding) - other.pos = sub_dict.get(self.pos, self.pos) - return other - - def clone(self) -> "PddlRobotState": - """ - Returns a shallow copy - """ - return replace(self) - - def is_true(self, sim_info: PddlSimInfo, robot_entity: PddlEntity) -> bool: - """ - Returns if the desired robot state is currently true in the simulator state. - """ - robot_id = cast( - int, - sim_info.search_for_entity(robot_entity), - ) - grasp_mgr = sim_info.sim.get_agent_data(robot_id).grasp_mgr - - assert not (self.holding is not None and self.should_drop) - - if self.holding is not None: - # Robot must be holding desired object. - obj_idx = cast(int, sim_info.search_for_entity(self.holding)) - abs_obj_id = sim_info.sim.scene_obj_ids[obj_idx] - if grasp_mgr.snap_idx != abs_obj_id: - return False - elif self.should_drop and grasp_mgr.snap_idx != None: - return False - - if isinstance(self.pos, PddlEntity): - targ_pos = sim_info.get_entity_pos(self.pos) - robot = sim_info.sim.get_agent_data(robot_id).articulated_agent - - # Get the base transformation - T = robot.base_transformation - # Do transformation - pos = T.inverted().transform_point(targ_pos) - # Project to 2D plane (x,y,z=0) - pos[2] = 0.0 - - # Compute distance - dist = np.linalg.norm(pos) - - # Unit vector of the pos - pos = pos.normalized() - # Define the coordinate of the robot - pos_robot = np.array([1.0, 0.0, 0.0]) - # Get the angle - angle = np.arccos(np.dot(pos, pos_robot)) - - # Check the distance threshold. - if dist > self.get_place_at_pos_dist(sim_info): - return False - - # Check for the angle threshold - if ( - self.place_at_angle_thresh is not None - and np.abs(angle) > self.place_at_angle_thresh - ): - return False - - return True - - def set_state( - self, sim_info: PddlSimInfo, robot_entity: PddlEntity - ) -> None: - """ - Sets the robot state in the simulator. - """ - robot_id = cast( - int, - sim_info.search_for_entity(robot_entity), - ) - sim = sim_info.sim - agent_data = sim.get_agent_data(robot_id) - # Set the snapped object information - if self.should_drop and agent_data.grasp_mgr.is_grasped: - agent_data.grasp_mgr.desnap(True) - elif self.holding is not None: - # Swap objects to the desired object. - obj_idx = cast(int, sim_info.search_for_entity(self.holding)) - agent_data.grasp_mgr.desnap(True) - sim.internal_step(-1) - agent_data.grasp_mgr.snap_to_obj(sim.scene_obj_ids[obj_idx]) - sim.internal_step(-1) - - # Set the robot starting position - if isinstance(self.pos, PddlEntity): - targ_pos = sim_info.get_entity_pos(self.pos) - - # Place some distance away from the object. - start_pos, start_rot, was_fail = place_agent_at_dist_from_pos( - target_position=targ_pos, - rotation_perturbation_noise=self.get_base_angle_noise( - sim_info - ), - distance_threshold=self.get_place_at_pos_dist(sim_info), - sim=sim, - num_spawn_attempts=sim_info.num_spawn_attempts, - filter_colliding_states=self.get_filter_colliding_states( - sim_info - ), - agent=agent_data.articulated_agent, - ) - agent_data.articulated_agent.base_pos = start_pos - agent_data.articulated_agent.base_rot = start_rot - if was_fail: - rearrange_logger.error("Failed to place the robot.") - - # We teleported the agent. We also need to teleport the object the agent was holding. - agent_data.grasp_mgr.update_object_to_grasp() - - elif self.pos is not None: - raise ValueError(f"Unrecongized set position {self.pos}") - - -class PddlSimState: - """ - The "building block" for predicates. This checks if a particular simulator state is satisfied. - """ - - def __init__( - self, - art_states: Dict[PddlEntity, ArtSampler], - obj_states: Dict[PddlEntity, PddlEntity], - robot_states: Dict[PddlEntity, PddlRobotState], - ): - for k, v in obj_states.items(): - if not isinstance(k, PddlEntity) or not isinstance(v, PddlEntity): - raise TypeError(f"Unexpected types {obj_states}") - - for k, v in art_states.items(): - if not isinstance(k, PddlEntity) or not isinstance(v, ArtSampler): - raise TypeError(f"Unexpected types {art_states}") - - for k, v in robot_states.items(): - if not isinstance(k, PddlEntity) or not isinstance( - v, PddlRobotState - ): - raise TypeError(f"Unexpected types {robot_states}") - - self._art_states = art_states - self._obj_states = obj_states - self._robot_states = robot_states - - def __repr__(self): - return f"{self._art_states}, {self._obj_states}, {self._robot_states}" - - def clone(self) -> "PddlSimState": - return PddlSimState( - self._art_states, - self._obj_states, - {k: v.clone() for k, v in self._robot_states.items()}, - ) - - def sub_in_clone( - self, sub_dict: Dict[PddlEntity, PddlEntity] - ) -> "PddlSimState": - return PddlSimState( - {sub_dict.get(k, k): v for k, v in self._art_states.items()}, - { - sub_dict.get(k, k): sub_dict.get(v, v) - for k, v in self._obj_states.items() - }, - { - sub_dict.get(k, k): robot_state.sub_in_clone(sub_dict) - for k, robot_state in self._robot_states.items() - }, - ) - - def sub_in(self, sub_dict: Dict[PddlEntity, PddlEntity]) -> "PddlSimState": - self._robot_states = { - sub_dict.get(k, k): robot_state.sub_in(sub_dict) - for k, robot_state in self._robot_states.items() - } - self._art_states = { - sub_dict.get(k, k): v for k, v in self._art_states.items() - } - self._obj_states = { - sub_dict.get(k, k): sub_dict.get(v, v) - for k, v in self._obj_states.items() - } - return self - - def is_compatible(self, expr_types) -> bool: - def type_matches(entity, match_names): - return any( - entity.expr_type.is_subtype_of(expr_types[match_name]) - for match_name in match_names - ) - - for entity, target in self._obj_states.items(): - # We have to be able to move the source object. - if not type_matches( - entity, [SimulatorObjectType.MOVABLE_ENTITY.value] - ): - return False - - # All targets must refer to 1 of the predefined types. - if not ( - type_matches( - target, - [ - SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value, - SimulatorObjectType.GOAL_ENTITY.value, - SimulatorObjectType.MOVABLE_ENTITY.value, - SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value, - ], - ) - ): - return False - - if entity.expr_type.name == target.expr_type.name: - return False - - # All the receptacle state entities must refer to receptacles. - return all( - type_matches( - art_entity, - [SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value], - ) - for art_entity in self._art_states - ) - - def is_true( - self, - sim_info: PddlSimInfo, - ) -> bool: - """ - Returns True if the grounded state is present in the current simulator state. - Throws exception if the arguments are not compatible. - """ - - # Check object states are true. - if not all( - _is_obj_state_true(entity, target, sim_info) - for entity, target in self._obj_states.items() - ): - return False - - # Check articulated object states are true. - if not all( - _is_art_state_true(art_entity, set_art, sim_info) - for art_entity, set_art in self._art_states.items() - ): - return False - - # Check robot states are true. - if not all( - robot_state.is_true(sim_info, robot_entity) - for robot_entity, robot_state in self._robot_states.items() - ): - return False - return True - - def set_state(self, sim_info: PddlSimInfo) -> None: - """ - Set this state in the simulator. Warning, this steps the simulator. - """ - # Set all desired object states. - for entity, target in self._obj_states.items(): - _set_obj_state(entity, target, sim_info) - - # Set all desired articulated object states. - for art_entity, set_art in self._art_states.items(): - sim = sim_info.sim - rom = sim.get_rigid_object_manager() - - in_pred = sim_info.get_predicate("in") - poss_entities = [ - e - for e in sim_info.all_entities.values() - if e.expr_type.is_subtype_of( - sim_info.expr_types[ - SimulatorObjectType.MOVABLE_ENTITY.value - ] - ) - ] - - move_objs = [] - for poss_entity in poss_entities: - bound_in_pred = in_pred.clone() - bound_in_pred.set_param_values([poss_entity, art_entity]) - if not bound_in_pred.is_true(sim_info): - continue - obj_idx = cast( - int, - sim_info.search_for_entity(poss_entity), - ) - abs_obj_id = sim.scene_obj_ids[obj_idx] - set_obj = rom.get_object_by_id(abs_obj_id) - move_objs.append(set_obj) - - marker = cast( - MarkerInfo, - sim_info.search_for_entity( - art_entity, - ), - ) - pre_link_pos = marker.link_node.transformation.translation - marker.set_targ_js(set_art.sample()) - post_link_pos = marker.link_node.transformation.translation - - if art_entity.expr_type.is_subtype_of( - sim_info.expr_types[CAB_TYPE] - ): - # Also move all objects that were in the drawer - diff_pos = post_link_pos - pre_link_pos - for move_obj in move_objs: - move_obj.translation += diff_pos - - # Set all desired robot states. - for robot_entity, robot_state in self._robot_states.items(): - robot_state.set_state(sim_info, robot_entity) - - -def _is_object_inside( - entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo -) -> bool: - """ - Returns if `entity` is inside of `target` in the CURRENT simulator state, NOT at the start of the episode. - """ - entity_pos = sim_info.get_entity_pos(entity) - check_marker = cast( - MarkerInfo, - sim_info.search_for_entity(target), - ) - if sim_info.check_type_matches(target, FRIDGE_TYPE): - global_bb = get_ao_global_bb(check_marker.ao_parent) - else: - bb = check_marker.link_node.cumulative_bb - global_bb = habitat_sim.geo.get_transformed_bb( - bb, check_marker.link_node.transformation - ) - - return global_bb.contains(entity_pos) - - -def _is_obj_state_true( - entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo -) -> bool: - entity_pos = sim_info.get_entity_pos(entity) - - if sim_info.check_type_matches( - target, SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value - ): - # object is rigid and target is receptacle, we are checking if - # an object is inside of a receptacle. - if not _is_object_inside(entity, target, sim_info): - return False - elif sim_info.check_type_matches( - target, SimulatorObjectType.GOAL_ENTITY.value - ): - targ_idx = cast( - int, - sim_info.search_for_entity(target), - ) - idxs, pos_targs = sim_info.sim.get_targets() - targ_pos = pos_targs[list(idxs).index(targ_idx)] - - dist = np.linalg.norm(entity_pos - targ_pos) - if dist >= sim_info.obj_thresh: - return False - elif sim_info.check_type_matches( - target, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value - ): - recep = cast(mn.Range3D, sim_info.search_for_entity(target)) - return recep.contains(entity_pos) - elif sim_info.check_type_matches( - target, SimulatorObjectType.MOVABLE_ENTITY.value - ): - raise NotImplementedError() - else: - raise ValueError( - f"Got unexpected combination of {entity} and {target}" - ) - return True - - -def _is_art_state_true( - art_entity: PddlEntity, set_art: ArtSampler, sim_info: PddlSimInfo -) -> bool: - """ - Checks if an articulated object entity matches a condition specified by - `set_art`. - """ - - if not sim_info.check_type_matches( - art_entity, - SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value, - ): - raise ValueError(f"Got unexpected entity {set_art}") - - marker = cast( - MarkerInfo, - sim_info.search_for_entity( - art_entity, - ), - ) - prev_art_pos = marker.get_targ_js() - if not set_art.is_satisfied(prev_art_pos, sim_info.art_thresh): - return False - return True - - -def _place_obj_on_goal( - target: PddlEntity, sim_info: PddlSimInfo -) -> mn.Matrix4: - """ - Place an object at a goal position. - """ - - sim = sim_info.sim - targ_idx = cast( - int, - sim_info.search_for_entity(target), - ) - all_targ_idxs, pos_targs = sim.get_targets() - targ_pos = pos_targs[list(all_targ_idxs).index(targ_idx)] - return mn.Matrix4.translation(targ_pos) - - -def _place_obj_on_obj( - entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo -) -> mn.Matrix4: - """ - This is intended to implement placing an object on top of another object. - """ - - raise NotImplementedError() - - -def _place_obj_on_recep(target: PddlEntity, sim_info) -> mn.Matrix4: - # Place object on top of receptacle. - recep = cast(mn.Range3D, sim_info.search_for_entity(target)) - - # Divide by 2 because the `from_center` creates from the half size. - shrunk_recep = mn.Range3D.from_center( - recep.center(), - (recep.size() / 2.0) * sim_info.recep_place_shrink_factor, - ) - pos = np.random.uniform(shrunk_recep.min, shrunk_recep.max) - return mn.Matrix4.translation(pos) - - -def _set_obj_state( - entity: PddlEntity, target: PddlEntity, sim_info: PddlSimInfo -) -> None: - """ - Sets an object state to match the state specified by `target`. The context - of this will vary on the type of the source and target entity (like if we - are placing an object on a receptacle). - """ - - sim = sim_info.sim - - # The source object must be movable. - if not sim_info.check_type_matches( - entity, SimulatorObjectType.MOVABLE_ENTITY.value - ): - raise ValueError(f"Got unexpected entity {entity}") - - if sim_info.check_type_matches( - target, SimulatorObjectType.ARTICULATED_RECEPTACLE_ENTITY.value - ): - raise NotImplementedError() - elif sim_info.check_type_matches( - target, SimulatorObjectType.GOAL_ENTITY.value - ): - set_T = _place_obj_on_goal(target, sim_info) - elif sim_info.check_type_matches( - target, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value - ): - set_T = _place_obj_on_recep(target, sim_info) - elif sim_info.check_type_matches( - target, SimulatorObjectType.MOVABLE_ENTITY.value - ): - set_T = _place_obj_on_obj(entity, target, sim_info) - else: - raise ValueError(f"Got unexpected target {target}") - - obj_idx = cast(int, sim_info.search_for_entity(entity)) - abs_obj_id = sim.scene_obj_ids[obj_idx] - - # Get the object id corresponding to this name - rom = sim.get_rigid_object_manager() - set_obj = rom.get_object_by_id(abs_obj_id) - set_obj.transformation = set_T - set_obj.angular_velocity = mn.Vector3.zero_init() - set_obj.linear_velocity = mn.Vector3.zero_init() - sim.internal_step(-1) - set_obj.angular_velocity = mn.Vector3.zero_init() - set_obj.linear_velocity = mn.Vector3.zero_init() diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_task.py b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_task.py index fd74f37df5..616c9a7afc 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_task.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/pddl_task.py @@ -5,11 +5,9 @@ # LICENSE file in the root directory of this source tree. import os.path as osp -from typing import cast from habitat.core.dataset import Episode from habitat.core.registry import registry -from habitat.datasets.rearrange.rearrange_dataset import RearrangeDatasetV0 from habitat.tasks.rearrange.multi_task.pddl_domain import PddlProblem from habitat.tasks.rearrange.rearrange_task import RearrangeTask @@ -36,8 +34,6 @@ def __init__(self, *args, config, **kwargs): def reset(self, episode: Episode): super().reset(episode, fetch_observations=False) - self.pddl_problem.bind_to_instance( - self._sim, cast(RearrangeDatasetV0, self._dataset), self, episode - ) + self.pddl_problem.bind_to_instance(self._sim, self) self._sim.maybe_update_articulated_agent() return self._get_observations(episode) diff --git a/habitat-lab/habitat/tasks/rearrange/multi_task/rearrange_pddl.py b/habitat-lab/habitat/tasks/rearrange/multi_task/rearrange_pddl.py index 542f1382f7..77de396b7f 100644 --- a/habitat-lab/habitat/tasks/rearrange/multi_task/rearrange_pddl.py +++ b/habitat-lab/habitat/tasks/rearrange/multi_task/rearrange_pddl.py @@ -8,11 +8,9 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -import magnum as mn import numpy as np -from habitat.core.dataset import Episode -from habitat.datasets.rearrange.rearrange_dataset import RearrangeDatasetV0 +from habitat.datasets.rearrange.samplers.receptacle import Receptacle from habitat.tasks.rearrange.marker_info import MarkerInfo from habitat.tasks.rearrange.rearrange_sim import RearrangeSim from habitat.tasks.rearrange.rearrange_task import RearrangeTask @@ -27,6 +25,7 @@ class SimulatorObjectType(Enum): Predefined entity types for which default predicate behavior is defined. """ + BASE_ENTITY = "entity_type" MOVABLE_ENTITY = "movable_entity_type" STATIC_RECEPTACLE_ENTITY = "static_receptacle_entity_type" ARTICULATED_RECEPTACLE_ENTITY = "art_receptacle_entity_type" @@ -169,20 +168,11 @@ class PddlSimInfo: robot_ids: Dict[str, int] sim: RearrangeSim - dataset: RearrangeDatasetV0 env: RearrangeTask - episode: Episode - obj_thresh: float - art_thresh: float - robot_at_thresh: float expr_types: Dict[str, ExprType] predicates: Dict[str, Any] all_entities: Dict[str, PddlEntity] - receptacles: Dict[str, mn.Range3D] - - num_spawn_attempts: int - filter_colliding_states: bool - recep_place_shrink_factor: float + receptacles: Dict[str, Receptacle] pred_truth_cache: Optional[Dict[str, bool]] = None @@ -231,7 +221,7 @@ def get_entity_pos(self, entity: PddlEntity) -> np.ndarray: entity, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value ): recep = self.receptacles[ename] - return np.array(recep.center()) + return np.array(recep.get_global_transform(self.sim).translation) if self.check_type_matches( entity, SimulatorObjectType.MOVABLE_ENTITY.value ): @@ -246,7 +236,7 @@ def get_entity_pos(self, entity: PddlEntity) -> np.ndarray: def search_for_entity( self, entity: PddlEntity - ) -> Union[int, str, MarkerInfo, mn.Range3D]: + ) -> Union[int, str, MarkerInfo, Receptacle]: """ Returns underlying simulator information associated with a PDDL entity. Helper to match the PDDL entity to something from the simulator. @@ -273,7 +263,6 @@ def search_for_entity( elif self.check_type_matches( entity, SimulatorObjectType.STATIC_RECEPTACLE_ENTITY.value ): - asset_name = ename.split("_:")[0] - return self.receptacles[asset_name] + return self.receptacles[ename] else: raise ValueError(f"No type match for {entity}") diff --git a/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py b/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py index 442326bd58..56f339c901 100644 --- a/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py +++ b/habitat-lab/habitat/tasks/rearrange/rearrange_sim.py @@ -17,7 +17,6 @@ Optional, Tuple, Union, - cast, ) import magnum as mn @@ -34,7 +33,7 @@ from habitat.datasets.rearrange.navmesh_utils import get_largest_island_index from habitat.datasets.rearrange.rearrange_dataset import RearrangeEpisode from habitat.datasets.rearrange.samplers.receptacle import ( - AABBReceptacle, + Receptacle, find_receptacles, ) from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim @@ -102,9 +101,9 @@ def __init__(self, config: "DictConfig"): self._prev_obj_names: Optional[List[str]] = None self._scene_obj_ids: List[int] = [] # The receptacle information cached between all scenes. - self._receptacles_cache: Dict[str, Dict[str, mn.Range3D]] = {} + self._receptacles_cache: Dict[str, Dict[str, Receptacle]] = {} # The per episode receptacle information. - self._receptacles: Dict[str, mn.Range3D] = {} + self._receptacles: Dict[str, Receptacle] = {} # Used to get data from the RL environment class to sensors. self._goal_pos = None self.viz_ids: Dict[Any, Any] = defaultdict(lambda: None) @@ -155,7 +154,7 @@ def enable_perf_logging(self): self._perf_logging_enabled = True @property - def receptacles(self) -> Dict[str, AABBReceptacle]: + def receptacles(self) -> Dict[str, Receptacle]: return self._receptacles @property @@ -631,7 +630,6 @@ def _add_objs( ) if self._kinematic_mode: ro.motion_type = habitat_sim.physics.MotionType.KINEMATIC - ro.collidable = False if should_add_objects: self._scene_obj_ids.append(ro.object_id) @@ -662,30 +660,15 @@ def _add_objs( def _create_recep_info( self, scene_id: str, ignore_handles: List[str] - ) -> Dict[str, mn.Range3D]: + ) -> Dict[str, Receptacle]: if scene_id not in self._receptacles_cache: - receps = {} all_receps = find_receptacles( self, ignore_handles=ignore_handles, ) - for recep in all_receps: - recep = cast(AABBReceptacle, recep) - local_bounds = recep.bounds - global_T = recep.get_global_transform(self) - # Some coordinates may be flipped by the global transformation, - # mixing the minimum and maximum bound coordinates. - bounds = np.stack( - [ - global_T.transform_point(local_bounds.min), - global_T.transform_point(local_bounds.max), - ], - axis=0, - ) - receps[recep.unique_name] = mn.Range3D( - np.min(bounds, axis=0), np.max(bounds, axis=0) - ) - self._receptacles_cache[scene_id] = receps + self._receptacles_cache[scene_id] = { + recep.unique_name: recep for recep in all_receps + } return self._receptacles_cache[scene_id] def _create_obj_viz(self): diff --git a/habitat-lab/habitat/tasks/rearrange/social_nav/social_nav_task.py b/habitat-lab/habitat/tasks/rearrange/social_nav/social_nav_task.py index 2be4659aee..d1efc4f823 100644 --- a/habitat-lab/habitat/tasks/rearrange/social_nav/social_nav_task.py +++ b/habitat-lab/habitat/tasks/rearrange/social_nav/social_nav_task.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import random -from typing import Optional, cast +from typing import Optional import numpy as np @@ -14,7 +14,6 @@ ) from habitat.core.dataset import Episode from habitat.core.registry import registry -from habitat.datasets.rearrange.rearrange_dataset import RearrangeDatasetV0 from habitat.tasks.rearrange.multi_task.pddl_task import PddlTask from habitat.tasks.rearrange.sub_tasks.nav_to_obj_task import NavToInfo @@ -121,9 +120,7 @@ def reset(self, episode: Episode): super().reset(episode) - self.pddl_problem.bind_to_instance( - self._sim, cast(RearrangeDatasetV0, self._dataset), self, episode - ) + self.pddl_problem.bind_to_instance(self._sim, self) if self._sim.habitat_config.debug_render: # Visualize the position the agent is navigating to. diff --git a/test/test_object_state_machine.py b/test/test_object_state_machine.py new file mode 100644 index 0000000000..55975ccb30 --- /dev/null +++ b/test/test_object_state_machine.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Union + +import magnum as mn + +from habitat.sims.habitat_simulator.object_state_machine import ( + BooleanObjectState, + ObjectStateMachine, + get_state_of_obj, + set_state_of_obj, +) +from habitat_sim import Simulator +from habitat_sim.physics import ManagedArticulatedObject, ManagedRigidObject +from habitat_sim.utils.settings import default_sim_settings, make_cfg + + +def test_state_getter_setter(): + """ + Test getting and setting state metadata to the object user_defined fields. + """ + + sim_settings = default_sim_settings.copy() + hab_cfg = make_cfg(sim_settings) + + with Simulator(hab_cfg) as sim: + obj_template_mngr = sim.get_object_template_manager() + cube_obj_template = ( + obj_template_mngr.get_first_matching_template_by_handle("cube") + ) + rom = sim.get_rigid_object_manager() + new_obj = rom.add_object_by_template_handle(cube_obj_template.handle) + + test_state_values = ["string", 99, mn.Vector3(1.0, 2.0, 3.0)] + + assert get_state_of_obj(new_obj, "test_state") is None + for test_state_val in test_state_values: + set_state_of_obj(new_obj, "test_state", test_state_val) + assert get_state_of_obj(new_obj, "test_state") == test_state_val + + +class TestObjectState(BooleanObjectState): + def __init__(self): + super().__init__() + self.name = "TestState" + # NOTE: This is contrived + self.accepted_semantic_classes = ["test_class"] + + def update_state( + self, + sim: Simulator, + obj: Union[ManagedArticulatedObject, ManagedRigidObject], + dt: float, + ) -> None: + """ + Overwrite the update for a contrived unit test. + Caches the time the object has been alive and when that time exceeds 1 second, sets the state to false. + """ + time_alive = get_state_of_obj(obj, "time_alive") + if time_alive is None: + time_alive = 0 + time_alive += dt + set_state_of_obj(obj, "time_alive", time_alive) + if time_alive > 1.0: + set_state_of_obj(obj, self.name, False) + + +def test_object_state_machine(): + """ + Test initializing and assigning a state to the state machine. + Test contrived mechanics to proive and example of using the API. + """ + + # use an empty scene + sim_settings = default_sim_settings.copy() + hab_cfg = make_cfg(sim_settings) + + with Simulator(hab_cfg) as sim: + obj_template_mngr = sim.get_object_template_manager() + cube_obj_template = ( + obj_template_mngr.get_first_matching_template_by_handle("cube") + ) + rom = sim.get_rigid_object_manager() + new_obj = rom.add_object_by_template_handle(cube_obj_template.handle) + + # TODO: this is currently a contrived location to cache semantic state for category-based affordance logic. + set_state_of_obj(new_obj, "semantic_class", "test_class") + assert get_state_of_obj(new_obj, "semantic_class") == "test_class" + + # initialize the ObjectStateMachine + osm = ObjectStateMachine(active_states=[TestObjectState()]) + osm.initialize_object_state_map(sim) + + # now the cube should be registered for TestObjectState because it has the correct semantic_class + assert isinstance(osm.active_states[0], TestObjectState) + assert new_obj.handle in osm.objects_with_states + assert isinstance( + osm.objects_with_states[new_obj.handle][0], TestObjectState + ) + + state_report_dict = osm.get_snapshot_dict(sim) + assert "TestState" in state_report_dict + assert new_obj.handle in state_report_dict["TestState"] + assert ( + state_report_dict["TestState"][new_obj.handle] + == TestObjectState().default_value() + ) + + # update the object state machine over time + dt = 0.1 + while sim.get_world_time() < 2.0: + sim.step_world(dt) + osm.update_states(sim, dt) + state_report_dict = osm.get_snapshot_dict(sim) + if sim.get_world_time() < 1.0: + assert state_report_dict["TestState"][new_obj.handle] == True + else: + assert state_report_dict["TestState"][new_obj.handle] == False diff --git a/test/test_rearrange_task.py b/test/test_rearrange_task.py index 4fcac2b736..5556a010e1 100644 --- a/test/test_rearrange_task.py +++ b/test/test_rearrange_task.py @@ -6,6 +6,7 @@ import json import os.path as osp +import random import time from glob import glob @@ -97,7 +98,11 @@ def test_rearrange_dataset(): check_binary_serialization(dataset) -def test_pddl(): +def _get_test_pddl(): + """ + Helper to get a test PDDL instance. + """ + config = get_config( "habitat-lab/habitat/config/benchmark/rearrange/multi_task/rearrange_easy.yaml", [ @@ -111,7 +116,29 @@ def test_pddl(): env_class=env_class, config=config ) env.reset() - pddl = env.env.env._env.task.pddl_problem # type: ignore + return env.env.env._env.task.pddl_problem # type: ignore + + +def test_pddl_actions(): + """ + Checks we can execute all PDDL actions. + """ + + pddl = _get_test_pddl() + sim_info = pddl.sim_info + + poss_actions = pddl.get_possible_actions() + for action in poss_actions: + action.apply_if_true(sim_info) + + +def test_pddl_action_postconds(): + """ + Tests the PDDL system action post conditions have the expected outcome in + the simulator. + """ + + pddl = _get_test_pddl() sim_info = pddl.sim_info # Check that the predicates are registering that the robot is not holding @@ -142,7 +169,8 @@ def test_pddl(): # Check the object registers at the goal now. true_preds = pddl.get_true_predicates() assert any( - x.compact_str == "at(goal0|0,TARGET_goal0|0)" for x in true_preds + x.compact_str == "object_at(goal0|0,TARGET_goal0|0)" + for x in true_preds ) @@ -312,6 +340,24 @@ def randomize_obj_state(): # parse the metadata into Receptacle objects test_receptacles = hab_receptacle.find_receptacles(sim) + # test receptacle filtering in find + random_receptacle = random.choice(test_receptacles) + exclude_unique_name = random_receptacle.unique_name + test_filtered_receptacles = hab_receptacle.find_receptacles( + sim, exclude_filter_strings=[exclude_unique_name] + ) + assert len(test_filtered_receptacles) == len(test_receptacles) - 1 + assert ( + len( + [ + rec + for rec in test_filtered_receptacles + if rec.unique_name == random_receptacle.unique_name + ] + ) + == 0 + ) + # test the Receptacle instances num_test_samples = 10 for receptacle in test_receptacles: diff --git a/test/test_sim_utils.py b/test/test_sim_utils.py index a1264c0319..248afe1627 100644 --- a/test/test_sim_utils.py +++ b/test/test_sim_utils.py @@ -315,7 +315,7 @@ def test_keypoint_cast_prepositions(): canister_within = sutils.within(sim, canister_object) assert len(canister_within) == 1 assert basket_object.object_id in canister_within - # now make the check more strict, requring 6 keypoints + # now make the check more strict, requiring 6 keypoints canister_within = sutils.within( sim, canister_object, keypoint_vote_threshold=6 ) @@ -323,7 +323,7 @@ def test_keypoint_cast_prepositions(): # further lower the canister such that the center is contained canister_object.translation = mn.Vector3(-2.01639, 1.2, 0.0410867) - # when center ensures contaiment this state is "within" + # when center ensures containment this state is "within" canister_within = sutils.within( sim, canister_object, keypoint_vote_threshold=6 ) @@ -522,6 +522,56 @@ def test_ao_open_close_queries(): ) sutils.close_link(obj, link_id) # debug reset state + ################################ + # test default link functionality + + # test computing the default link + default_link = sutils.get_ao_default_link(fridge) + assert default_link is None + default_link = sutils.get_ao_default_link( + fridge, compute_if_not_found=True + ) + assert default_link == 1 + assert fridge.user_attributes.get("default_link") == 1 + default_link = sutils.get_ao_default_link( + kitchen_counter, compute_if_not_found=True + ) + assert default_link == 6 + + # NOTE: sim bug here doesn't break the feature + # test setting the default link in template metadata + fridge_template = fridge.creation_attributes + assert fridge_template.get_user_config().get("default_link") is None + fridge_template.get_user_config().set("default_link", 0) + assert fridge_template.get_user_config().get("default_link") == 0 + sim.metadata_mediator.ao_template_manager.register_template( + fridge_template, "new_fridge_template" + ) + new_fridge = sim.get_articulated_object_manager().add_articulated_object_by_template_handle( + "new_fridge_template" + ) + assert new_fridge is not None + default_link = sutils.get_ao_default_link( + fridge, compute_if_not_found=True + ) + assert default_link == 1 + new_default_link = sutils.get_ao_default_link( + new_fridge, compute_if_not_found=True + ) + print( + f" new_default_link (== {new_default_link}) should be 0, waiting on sim bug fix." + ) + # TODO: habitat-sim bug. "default_link" does not get copied over after instantiation if set in the template programmatically. + # assert new_default_link == 0 + + # test setting the default link in instance metadata + fridge.user_attributes.set("default_link", 0) + assert fridge.user_attributes.get("default_link") == 0 + default_link = sutils.get_ao_default_link( + fridge, compute_if_not_found=True + ) + assert fridge.user_attributes.get("default_link") == 0 + @pytest.mark.skipif( not built_with_bullet, @@ -642,7 +692,7 @@ def test_ontop_util(): not osp.exists("data/replica_cad/"), reason="Requires ReplicaCAD dataset.", ) -def test_on_floor_util(): +def test_on_floor_and_next_to(): sim_settings = default_sim_settings.copy() sim_settings[ "scene_dataset_config_file" @@ -701,7 +751,7 @@ def test_on_floor_util(): for obj_handle in objects_in_table: obj = sutils.get_obj_from_handle(sim, obj_handle) l2_dist = (obj.translation - table_object.translation).length() - reg_dist = sutils.size_regularized_distance( + reg_dist = sutils.size_regularized_object_distance( sim, table_object.object_id, obj.object_id, @@ -720,7 +770,7 @@ def test_on_floor_util(): for obj_handle in objects_on_table: obj = sutils.get_obj_from_handle(sim, obj_handle) l2_dist = (obj.translation - table_object.translation).length() - reg_dist = sutils.size_regularized_distance( + reg_dist = sutils.size_regularized_object_distance( sim, table_object.object_id, obj.object_id, @@ -739,7 +789,7 @@ def test_on_floor_util(): shelf = sutils.get_obj_from_handle( sim, "frl_apartment_wall_cabinet_01_:0000" ) - reg_dist = sutils.size_regularized_distance( + reg_dist = sutils.size_regularized_object_distance( sim, sofa.object_id, shelf.object_id, ao_link_map, ao_aabbs ) assert ( @@ -765,3 +815,45 @@ def test_on_floor_util(): sim, sofa.object_id, vec, ao_link_map, ao_aabbs ) assert axis_size_along == sofa_bb.size()[axis] / 2.0 + + # test next_to logics + + # NOTE: using ids because they can represent links also, providing handles for readability + next_to_object_pairs = [ + (3, 4), # neighboring trashcans + (102, 103), # lamps on the table + (145, 50), # table and cabinet furniture + (40, 38), # books on the same shelf + (22, 23), # two neighboring lounge chairs + (11, 13), # two neighboring Sofa pillows + (51, 52), # two neighboring objects on the table + (141, 142), # two neighboring drawers in the chest of drawers + (131, 132), # two neighboring cabinet doors + (77, 78), # two neighboring spice jars + (77, 79), # two skip neighboring spice jars + (77, 80), # two double-skip neighboring spice jars + ] + not_next_to_object_pairs = [ + (36, 38), # books on different shelves + (11, 14), # sofa pillows on opposite sides + (51, 53), # two objects on different table shelves + (141, 140), # two non-neighboring drawers in the chest of drawers + (129, 132), # two non-neighboring cabinet doors + (17, 20), # potted plant and coffee table + ] + for ix, (obj_a_id, obj_b_id) in enumerate(next_to_object_pairs): + assert sutils.obj_next_to( + sim, + obj_a_id, + obj_b_id, + ao_aabbs=ao_aabbs, + ao_link_map=ao_link_map, + ), f"Objects with ids {obj_a_id} and {obj_b_id} at test pair index {ix} should be 'next to' one another." + for ix, (obj_a_id, obj_b_id) in enumerate(not_next_to_object_pairs): + assert not sutils.obj_next_to( + sim, + obj_a_id, + obj_b_id, + ao_aabbs=ao_aabbs, + ao_link_map=ao_link_map, + ), f"Objects with ids {obj_a_id} and {obj_b_id} at test pair index {ix} should not be 'next to' one another."