From 7fdec15f7e842bce4c17f4f3328d9d6fdc79d7fc Mon Sep 17 00:00:00 2001 From: Shagun Sodhani Date: Thu, 11 Feb 2021 14:57:00 -0800 Subject: [PATCH] Initial commit --- .circleci/config.yml | 116 ++++ .github/CODE_OF_CONDUCT.md | 45 ++ .github/CONTRIBUTING.md | 109 ++++ .github/ISSUE_TEMPLATE.md | 30 + .github/PULL_REQUEST_TEMPLATE.md | 13 + .gitignore | 140 +++++ .pre-commit-config.yaml | 21 + .readthedocs.yaml | 20 + LICENSE | 21 + MANIFEST.in | 1 + README.md | 128 +++++ docs_src/Makefile | 20 + docs_src/make.bat | 35 ++ docs_src/source/conf.py | 68 +++ docs_src/source/index.rst | 52 ++ docs_src/source/pages/api/modules.rst | 7 + .../source/pages/api/mtenv.envs.control.rst | 37 ++ .../source/pages/api/mtenv.envs.hipbmdp.rst | 45 ++ .../pages/api/mtenv.envs.hipbmdp.wrappers.rst | 37 ++ .../source/pages/api/mtenv.envs.metaworld.rst | 37 ++ .../api/mtenv.envs.metaworld.wrappers.rst | 21 + docs_src/source/pages/api/mtenv.envs.mpte.rst | 29 + docs_src/source/pages/api/mtenv.envs.rst | 34 ++ .../source/pages/api/mtenv.envs.shared.rst | 18 + .../pages/api/mtenv.envs.shared.wrappers.rst | 21 + .../pages/api/mtenv.envs.tabular_mdp.rst | 29 + docs_src/source/pages/api/mtenv.rst | 31 ++ docs_src/source/pages/api/mtenv.utils.rst | 37 ++ docs_src/source/pages/api/mtenv.wrappers.rst | 53 ++ docs_src/source/pages/bib/refs.bib | 27 + docs_src/source/pages/envs/create.rst | 22 + docs_src/source/pages/envs/supported.rst | 75 +++ docs_src/source/pages/readme.rst | 140 +++++ examples/bandit.py | 56 ++ examples/finite_mtenv_bandit.py | 109 ++++ examples/mtenv_bandit.py | 70 +++ examples/wrapped_bandit.py | 61 +++ local_dm_control_suite/README.md | 56 ++ local_dm_control_suite/__init__.py | 167 ++++++ local_dm_control_suite/acrobot.py | 131 +++++ local_dm_control_suite/acrobot.xml | 43 ++ local_dm_control_suite/ball_in_cup.py | 104 ++++ local_dm_control_suite/ball_in_cup.xml | 54 ++ local_dm_control_suite/base.py | 112 ++++ local_dm_control_suite/cartpole.py | 252 +++++++++ local_dm_control_suite/cartpole.xml | 37 ++ .../cartpole_cart_mass_1.xml | 37 ++ .../cartpole_cart_mass_10.xml | 37 ++ .../cartpole_cart_mass_2.xml | 37 ++ .../cartpole_cart_mass_3.xml | 37 ++ .../cartpole_cart_mass_4.xml | 37 ++ .../cartpole_cart_mass_5.xml | 37 ++ .../cartpole_cart_mass_6.xml | 37 ++ .../cartpole_cart_mass_7.xml | 37 ++ .../cartpole_cart_mass_8.xml | 37 ++ .../cartpole_cart_mass_9.xml | 37 ++ .../cartpole_pole_mass_1.xml | 37 ++ .../cartpole_pole_mass_10.xml | 37 ++ .../cartpole_pole_mass_2.xml | 37 ++ .../cartpole_pole_mass_3.xml | 37 ++ .../cartpole_pole_mass_4.xml | 37 ++ .../cartpole_pole_mass_5.xml | 37 ++ .../cartpole_pole_mass_6.xml | 37 ++ .../cartpole_pole_mass_7.xml | 37 ++ .../cartpole_pole_mass_8.xml | 37 ++ .../cartpole_pole_mass_9.xml | 37 ++ local_dm_control_suite/cheetah.py | 105 ++++ local_dm_control_suite/cheetah.xml | 73 +++ .../cheetah_bfoot_len_1.xml | 73 +++ .../cheetah_bfoot_len_10.xml | 73 +++ .../cheetah_bfoot_len_2.xml | 73 +++ .../cheetah_bfoot_len_3.xml | 73 +++ .../cheetah_bfoot_len_4.xml | 73 +++ .../cheetah_bfoot_len_5.xml | 73 +++ .../cheetah_bfoot_len_6.xml | 73 +++ .../cheetah_bfoot_len_7.xml | 73 +++ .../cheetah_bfoot_len_8.xml | 73 +++ .../cheetah_bfoot_len_9.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_1.xml | 73 +++ .../cheetah_foot_pos_10.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_2.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_3.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_4.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_5.xml | 73 +++ 
local_dm_control_suite/cheetah_foot_pos_6.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_7.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_8.xml | 73 +++ local_dm_control_suite/cheetah_foot_pos_9.xml | 73 +++ .../cheetah_foot_size_1.xml | 73 +++ .../cheetah_foot_size_10.xml | 73 +++ .../cheetah_foot_size_2.xml | 73 +++ .../cheetah_foot_size_3.xml | 73 +++ .../cheetah_foot_size_4.xml | 73 +++ .../cheetah_foot_size_5.xml | 73 +++ .../cheetah_foot_size_6.xml | 73 +++ .../cheetah_foot_size_7.xml | 73 +++ .../cheetah_foot_size_8.xml | 73 +++ .../cheetah_foot_size_9.xml | 73 +++ .../cheetah_torso_length_1.xml | 73 +++ .../cheetah_torso_length_10.xml | 73 +++ .../cheetah_torso_length_2.xml | 73 +++ .../cheetah_torso_length_3.xml | 73 +++ .../cheetah_torso_length_4.xml | 73 +++ .../cheetah_torso_length_5.xml | 73 +++ .../cheetah_torso_length_6.xml | 73 +++ .../cheetah_torso_length_7.xml | 73 +++ .../cheetah_torso_length_8.xml | 73 +++ .../cheetah_torso_length_9.xml | 73 +++ local_dm_control_suite/common/__init__.py | 41 ++ local_dm_control_suite/common/materials.xml | 23 + .../common/materials_white_floor.xml | 23 + local_dm_control_suite/common/skybox.xml | 6 + local_dm_control_suite/common/visual.xml | 7 + local_dm_control_suite/demos/mocap_demo.py | 89 +++ local_dm_control_suite/demos/zeros.amc | 213 ++++++++ local_dm_control_suite/explore.py | 95 ++++ local_dm_control_suite/finger.py | 242 +++++++++ local_dm_control_suite/finger.xml | 71 +++ local_dm_control_suite/finger_size_1.xml | 71 +++ local_dm_control_suite/finger_size_10.xml | 71 +++ local_dm_control_suite/finger_size_2.xml | 71 +++ local_dm_control_suite/finger_size_3.xml | 71 +++ local_dm_control_suite/finger_size_4.xml | 71 +++ local_dm_control_suite/finger_size_5.xml | 71 +++ local_dm_control_suite/finger_size_6.xml | 71 +++ local_dm_control_suite/finger_size_7.xml | 71 +++ local_dm_control_suite/finger_size_8.xml | 71 +++ local_dm_control_suite/finger_size_9.xml | 71 +++ local_dm_control_suite/fish.py | 188 +++++++ local_dm_control_suite/fish.xml | 85 +++ local_dm_control_suite/hopper.py | 147 +++++ local_dm_control_suite/hopper.xml | 66 +++ local_dm_control_suite/humanoid.py | 237 ++++++++ local_dm_control_suite/humanoid.xml | 202 +++++++ local_dm_control_suite/humanoid_CMU.py | 195 +++++++ local_dm_control_suite/humanoid_CMU.xml | 289 ++++++++++ local_dm_control_suite/lqr.py | 271 +++++++++ local_dm_control_suite/lqr.xml | 26 + local_dm_control_suite/lqr_solver.py | 146 +++++ local_dm_control_suite/manipulator.py | 329 +++++++++++ local_dm_control_suite/manipulator.xml | 211 +++++++ local_dm_control_suite/pendulum.py | 114 ++++ local_dm_control_suite/pendulum.xml | 26 + local_dm_control_suite/point_mass.py | 134 +++++ local_dm_control_suite/point_mass.xml | 49 ++ local_dm_control_suite/quadruped.py | 514 ++++++++++++++++++ local_dm_control_suite/quadruped.xml | 329 +++++++++++ local_dm_control_suite/reacher.py | 120 ++++ local_dm_control_suite/reacher.xml | 47 ++ local_dm_control_suite/stacker.py | 224 ++++++++ local_dm_control_suite/stacker.xml | 193 +++++++ local_dm_control_suite/swimmer.py | 225 ++++++++ local_dm_control_suite/swimmer.xml | 57 ++ local_dm_control_suite/tests/domains_test.py | 319 +++++++++++ local_dm_control_suite/tests/loader_test.py | 51 ++ local_dm_control_suite/tests/lqr_test.py | 87 +++ local_dm_control_suite/utils/__init__.py | 16 + local_dm_control_suite/utils/parse_amc.py | 301 ++++++++++ .../utils/parse_amc_test.py | 68 +++ local_dm_control_suite/utils/randomizers.py | 90 +++ 
.../utils/randomizers_test.py | 177 ++++++ local_dm_control_suite/walker.py | 190 +++++++ local_dm_control_suite/walker.xml | 70 +++ local_dm_control_suite/walker_friction_1.xml | 70 +++ local_dm_control_suite/walker_friction_10.xml | 70 +++ local_dm_control_suite/walker_friction_2.xml | 70 +++ local_dm_control_suite/walker_friction_3.xml | 70 +++ local_dm_control_suite/walker_friction_4.xml | 70 +++ local_dm_control_suite/walker_friction_5.xml | 70 +++ local_dm_control_suite/walker_friction_6.xml | 70 +++ local_dm_control_suite/walker_friction_7.xml | 70 +++ local_dm_control_suite/walker_friction_8.xml | 70 +++ local_dm_control_suite/walker_friction_9.xml | 70 +++ local_dm_control_suite/walker_len_1.xml | 70 +++ local_dm_control_suite/walker_len_10.xml | 70 +++ local_dm_control_suite/walker_len_2.xml | 70 +++ local_dm_control_suite/walker_len_3.xml | 70 +++ local_dm_control_suite/walker_len_4.xml | 70 +++ local_dm_control_suite/walker_len_5.xml | 70 +++ local_dm_control_suite/walker_len_6.xml | 70 +++ local_dm_control_suite/walker_len_7.xml | 70 +++ local_dm_control_suite/walker_len_8.xml | 70 +++ local_dm_control_suite/walker_len_9.xml | 70 +++ local_dm_control_suite/wrappers/__init__.py | 16 + .../wrappers/action_noise.py | 77 +++ .../wrappers/action_noise_test.py | 143 +++++ local_dm_control_suite/wrappers/pixels.py | 123 +++++ .../wrappers/pixels_test.py | 135 +++++ mtenv/__init__.py | 7 + mtenv/core.py | 212 ++++++++ mtenv/envs/__init__.py | 124 +++++ mtenv/envs/control/README.md | 1 + mtenv/envs/control/__init__.py | 2 + mtenv/envs/control/acrobot.py | 330 +++++++++++ mtenv/envs/control/cartpole.py | 202 +++++++ mtenv/envs/control/requirements.txt | 0 mtenv/envs/control/setup.py | 28 + mtenv/envs/hipbmdp/README.md | 0 mtenv/envs/hipbmdp/__init__.py | 0 mtenv/envs/hipbmdp/dmc_env.py | 115 ++++ mtenv/envs/hipbmdp/env.py | 81 +++ mtenv/envs/hipbmdp/requirements.txt | 1 + mtenv/envs/hipbmdp/setup.py | 28 + mtenv/envs/hipbmdp/wrappers/__init__.py | 0 mtenv/envs/hipbmdp/wrappers/dmc_wrapper.py | 80 +++ mtenv/envs/hipbmdp/wrappers/framestack.py | 47 ++ .../hipbmdp/wrappers/sticky_observation.py | 56 ++ mtenv/envs/metaworld/README.md | 0 mtenv/envs/metaworld/__init__.py | 0 mtenv/envs/metaworld/env.py | 197 +++++++ mtenv/envs/metaworld/requirements.txt | 1 + mtenv/envs/metaworld/setup.py | 28 + mtenv/envs/metaworld/wrappers/__init__.py | 0 .../envs/metaworld/wrappers/normalized_env.py | 169 ++++++ mtenv/envs/mpte/README.md | 0 mtenv/envs/mpte/__init__.py | 0 mtenv/envs/mpte/requirements.txt | 1 + mtenv/envs/mpte/setup.py | 27 + mtenv/envs/mpte/two_goal_maze_env.py | 343 ++++++++++++ mtenv/envs/registration.py | 86 +++ mtenv/envs/shared/__init__.py | 0 mtenv/envs/shared/wrappers/__init__.py | 0 mtenv/envs/shared/wrappers/multienv.py | 98 ++++ mtenv/envs/tabular_mdp/__init__.py | 0 mtenv/envs/tabular_mdp/requirements.txt | 1 + mtenv/envs/tabular_mdp/setup.py | 26 + mtenv/envs/tabular_mdp/tmdp.py | 121 +++++ mtenv/utils/__init__.py | 1 + mtenv/utils/seeding.py | 19 + mtenv/utils/setup_utils.py | 30 + mtenv/utils/types.py | 15 + mtenv/wrappers/__init__.py | 4 + mtenv/wrappers/env_to_mtenv.py | 109 ++++ mtenv/wrappers/multitask.py | 69 +++ mtenv/wrappers/ntasks.py | 58 ++ mtenv/wrappers/ntasks_id.py | 67 +++ mtenv/wrappers/sample_random_task.py | 22 + news/.gitignore | 2 + news/_template.rst | 19 + noxfile.py | 175 ++++++ requirements/base.txt | 2 + requirements/dev.txt | 21 + requirements/docs.txt | 8 + setup.cfg | 56 ++ setup.py | 86 +++ tests/__init__.py | 1 + tests/envs/__init__.py | 1 + 
tests/envs/registered_env_test.py | 74 +++ tests/examples/__init__.py | 1 + tests/examples/bandit_test.py | 32 ++ tests/examples/finite_mtenv_bandit_test.py | 28 + tests/examples/mtenv_bandit_test.py | 28 + tests/examples/wrapped_bandit_test.py | 38 ++ tests/utils/utils.py | 69 +++ tests/wrappers/__init__.py | 1 + tests/wrappers/ntasks_id_test.py | 33 ++ tests/wrappers/ntasks_test.py | 33 ++ towncrier.toml | 41 ++ 258 files changed, 19371 insertions(+) create mode 100644 .circleci/config.yml create mode 100644 .github/CODE_OF_CONDUCT.md create mode 100644 .github/CONTRIBUTING.md create mode 100644 .github/ISSUE_TEMPLATE.md create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 .readthedocs.yaml create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 docs_src/Makefile create mode 100644 docs_src/make.bat create mode 100644 docs_src/source/conf.py create mode 100644 docs_src/source/index.rst create mode 100644 docs_src/source/pages/api/modules.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.control.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.hipbmdp.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.hipbmdp.wrappers.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.metaworld.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.metaworld.wrappers.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.mpte.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.shared.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.shared.wrappers.rst create mode 100644 docs_src/source/pages/api/mtenv.envs.tabular_mdp.rst create mode 100644 docs_src/source/pages/api/mtenv.rst create mode 100644 docs_src/source/pages/api/mtenv.utils.rst create mode 100644 docs_src/source/pages/api/mtenv.wrappers.rst create mode 100644 docs_src/source/pages/bib/refs.bib create mode 100644 docs_src/source/pages/envs/create.rst create mode 100644 docs_src/source/pages/envs/supported.rst create mode 100644 docs_src/source/pages/readme.rst create mode 100644 examples/bandit.py create mode 100644 examples/finite_mtenv_bandit.py create mode 100644 examples/mtenv_bandit.py create mode 100644 examples/wrapped_bandit.py create mode 100755 local_dm_control_suite/README.md create mode 100755 local_dm_control_suite/__init__.py create mode 100755 local_dm_control_suite/acrobot.py create mode 100755 local_dm_control_suite/acrobot.xml create mode 100755 local_dm_control_suite/ball_in_cup.py create mode 100755 local_dm_control_suite/ball_in_cup.xml create mode 100755 local_dm_control_suite/base.py create mode 100755 local_dm_control_suite/cartpole.py create mode 100755 local_dm_control_suite/cartpole.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_1.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_10.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_2.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_3.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_4.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_5.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_6.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_7.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_8.xml create mode 100755 local_dm_control_suite/cartpole_cart_mass_9.xml create mode 
100755 local_dm_control_suite/cartpole_pole_mass_1.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_10.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_2.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_3.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_4.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_5.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_6.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_7.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_8.xml create mode 100755 local_dm_control_suite/cartpole_pole_mass_9.xml create mode 100755 local_dm_control_suite/cheetah.py create mode 100755 local_dm_control_suite/cheetah.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_1.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_10.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_2.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_3.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_4.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_5.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_6.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_7.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_8.xml create mode 100755 local_dm_control_suite/cheetah_bfoot_len_9.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_1.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_10.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_2.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_3.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_4.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_5.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_6.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_7.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_8.xml create mode 100755 local_dm_control_suite/cheetah_foot_pos_9.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_1.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_10.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_2.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_3.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_4.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_5.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_6.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_7.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_8.xml create mode 100755 local_dm_control_suite/cheetah_foot_size_9.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_1.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_10.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_2.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_3.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_4.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_5.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_6.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_7.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_8.xml create mode 100755 local_dm_control_suite/cheetah_torso_length_9.xml create mode 100755 local_dm_control_suite/common/__init__.py create mode 100755 local_dm_control_suite/common/materials.xml 
create mode 100755 local_dm_control_suite/common/materials_white_floor.xml create mode 100755 local_dm_control_suite/common/skybox.xml create mode 100755 local_dm_control_suite/common/visual.xml create mode 100755 local_dm_control_suite/demos/mocap_demo.py create mode 100755 local_dm_control_suite/demos/zeros.amc create mode 100755 local_dm_control_suite/explore.py create mode 100755 local_dm_control_suite/finger.py create mode 100755 local_dm_control_suite/finger.xml create mode 100755 local_dm_control_suite/finger_size_1.xml create mode 100755 local_dm_control_suite/finger_size_10.xml create mode 100755 local_dm_control_suite/finger_size_2.xml create mode 100755 local_dm_control_suite/finger_size_3.xml create mode 100755 local_dm_control_suite/finger_size_4.xml create mode 100755 local_dm_control_suite/finger_size_5.xml create mode 100755 local_dm_control_suite/finger_size_6.xml create mode 100755 local_dm_control_suite/finger_size_7.xml create mode 100755 local_dm_control_suite/finger_size_8.xml create mode 100755 local_dm_control_suite/finger_size_9.xml create mode 100755 local_dm_control_suite/fish.py create mode 100755 local_dm_control_suite/fish.xml create mode 100755 local_dm_control_suite/hopper.py create mode 100755 local_dm_control_suite/hopper.xml create mode 100755 local_dm_control_suite/humanoid.py create mode 100755 local_dm_control_suite/humanoid.xml create mode 100755 local_dm_control_suite/humanoid_CMU.py create mode 100755 local_dm_control_suite/humanoid_CMU.xml create mode 100755 local_dm_control_suite/lqr.py create mode 100755 local_dm_control_suite/lqr.xml create mode 100755 local_dm_control_suite/lqr_solver.py create mode 100755 local_dm_control_suite/manipulator.py create mode 100755 local_dm_control_suite/manipulator.xml create mode 100755 local_dm_control_suite/pendulum.py create mode 100755 local_dm_control_suite/pendulum.xml create mode 100755 local_dm_control_suite/point_mass.py create mode 100755 local_dm_control_suite/point_mass.xml create mode 100755 local_dm_control_suite/quadruped.py create mode 100755 local_dm_control_suite/quadruped.xml create mode 100755 local_dm_control_suite/reacher.py create mode 100755 local_dm_control_suite/reacher.xml create mode 100755 local_dm_control_suite/stacker.py create mode 100755 local_dm_control_suite/stacker.xml create mode 100755 local_dm_control_suite/swimmer.py create mode 100755 local_dm_control_suite/swimmer.xml create mode 100755 local_dm_control_suite/tests/domains_test.py create mode 100755 local_dm_control_suite/tests/loader_test.py create mode 100755 local_dm_control_suite/tests/lqr_test.py create mode 100755 local_dm_control_suite/utils/__init__.py create mode 100755 local_dm_control_suite/utils/parse_amc.py create mode 100755 local_dm_control_suite/utils/parse_amc_test.py create mode 100755 local_dm_control_suite/utils/randomizers.py create mode 100755 local_dm_control_suite/utils/randomizers_test.py create mode 100755 local_dm_control_suite/walker.py create mode 100755 local_dm_control_suite/walker.xml create mode 100755 local_dm_control_suite/walker_friction_1.xml create mode 100755 local_dm_control_suite/walker_friction_10.xml create mode 100755 local_dm_control_suite/walker_friction_2.xml create mode 100755 local_dm_control_suite/walker_friction_3.xml create mode 100755 local_dm_control_suite/walker_friction_4.xml create mode 100755 local_dm_control_suite/walker_friction_5.xml create mode 100755 local_dm_control_suite/walker_friction_6.xml create mode 100755 local_dm_control_suite/walker_friction_7.xml 
create mode 100755 local_dm_control_suite/walker_friction_8.xml create mode 100755 local_dm_control_suite/walker_friction_9.xml create mode 100755 local_dm_control_suite/walker_len_1.xml create mode 100755 local_dm_control_suite/walker_len_10.xml create mode 100755 local_dm_control_suite/walker_len_2.xml create mode 100755 local_dm_control_suite/walker_len_3.xml create mode 100755 local_dm_control_suite/walker_len_4.xml create mode 100755 local_dm_control_suite/walker_len_5.xml create mode 100755 local_dm_control_suite/walker_len_6.xml create mode 100755 local_dm_control_suite/walker_len_7.xml create mode 100755 local_dm_control_suite/walker_len_8.xml create mode 100755 local_dm_control_suite/walker_len_9.xml create mode 100755 local_dm_control_suite/wrappers/__init__.py create mode 100755 local_dm_control_suite/wrappers/action_noise.py create mode 100755 local_dm_control_suite/wrappers/action_noise_test.py create mode 100755 local_dm_control_suite/wrappers/pixels.py create mode 100755 local_dm_control_suite/wrappers/pixels_test.py create mode 100644 mtenv/__init__.py create mode 100644 mtenv/core.py create mode 100644 mtenv/envs/__init__.py create mode 100644 mtenv/envs/control/README.md create mode 100644 mtenv/envs/control/__init__.py create mode 100644 mtenv/envs/control/acrobot.py create mode 100644 mtenv/envs/control/cartpole.py create mode 100644 mtenv/envs/control/requirements.txt create mode 100644 mtenv/envs/control/setup.py create mode 100644 mtenv/envs/hipbmdp/README.md create mode 100644 mtenv/envs/hipbmdp/__init__.py create mode 100644 mtenv/envs/hipbmdp/dmc_env.py create mode 100644 mtenv/envs/hipbmdp/env.py create mode 100644 mtenv/envs/hipbmdp/requirements.txt create mode 100644 mtenv/envs/hipbmdp/setup.py create mode 100644 mtenv/envs/hipbmdp/wrappers/__init__.py create mode 100644 mtenv/envs/hipbmdp/wrappers/dmc_wrapper.py create mode 100644 mtenv/envs/hipbmdp/wrappers/framestack.py create mode 100644 mtenv/envs/hipbmdp/wrappers/sticky_observation.py create mode 100644 mtenv/envs/metaworld/README.md create mode 100644 mtenv/envs/metaworld/__init__.py create mode 100644 mtenv/envs/metaworld/env.py create mode 100644 mtenv/envs/metaworld/requirements.txt create mode 100644 mtenv/envs/metaworld/setup.py create mode 100644 mtenv/envs/metaworld/wrappers/__init__.py create mode 100644 mtenv/envs/metaworld/wrappers/normalized_env.py create mode 100644 mtenv/envs/mpte/README.md create mode 100644 mtenv/envs/mpte/__init__.py create mode 100644 mtenv/envs/mpte/requirements.txt create mode 100644 mtenv/envs/mpte/setup.py create mode 100644 mtenv/envs/mpte/two_goal_maze_env.py create mode 100644 mtenv/envs/registration.py create mode 100644 mtenv/envs/shared/__init__.py create mode 100644 mtenv/envs/shared/wrappers/__init__.py create mode 100644 mtenv/envs/shared/wrappers/multienv.py create mode 100644 mtenv/envs/tabular_mdp/__init__.py create mode 100644 mtenv/envs/tabular_mdp/requirements.txt create mode 100644 mtenv/envs/tabular_mdp/setup.py create mode 100644 mtenv/envs/tabular_mdp/tmdp.py create mode 100644 mtenv/utils/__init__.py create mode 100644 mtenv/utils/seeding.py create mode 100644 mtenv/utils/setup_utils.py create mode 100644 mtenv/utils/types.py create mode 100644 mtenv/wrappers/__init__.py create mode 100644 mtenv/wrappers/env_to_mtenv.py create mode 100644 mtenv/wrappers/multitask.py create mode 100644 mtenv/wrappers/ntasks.py create mode 100644 mtenv/wrappers/ntasks_id.py create mode 100644 mtenv/wrappers/sample_random_task.py create mode 100644 news/.gitignore 
create mode 100644 news/_template.rst create mode 100644 noxfile.py create mode 100644 requirements/base.txt create mode 100644 requirements/dev.txt create mode 100644 requirements/docs.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/envs/__init__.py create mode 100644 tests/envs/registered_env_test.py create mode 100644 tests/examples/__init__.py create mode 100644 tests/examples/bandit_test.py create mode 100644 tests/examples/finite_mtenv_bandit_test.py create mode 100644 tests/examples/mtenv_bandit_test.py create mode 100644 tests/examples/wrapped_bandit_test.py create mode 100644 tests/utils/utils.py create mode 100644 tests/wrappers/__init__.py create mode 100644 tests/wrappers/ntasks_id_test.py create mode 100644 tests/wrappers/ntasks_test.py create mode 100644 towncrier.toml diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..8fec392 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,116 @@ +# Python CircleCI 2.0 configuration file +# +# Check https://circleci.com/docs/2.0/language-python/ for more details +# +version: 2.1 +jobs: + # Linux + py36_linux: + docker: + - image: circleci/python:3.6 + steps: + - checkout + - run: + name: "Mujoco setup" + command: | + wget https://www.roboti.us/download/mujoco200_linux.zip + unzip mujoco200_linux.zip -d ~/.mujoco + cp -r ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200_linux/bin + sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libglew-dev libglfw3-dev patchelf + - run: + name: "Preparing environment" + command: | + sudo apt-get install -y expect + sudo pip install nox + - run: + name: "Testing mtenv" + command: | + export NOX_PYTHON_VERSIONS=3.6 + pip install nox + python3 -m nox + + py37_linux: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + name: "Mujoco setup" + command: | + wget https://www.roboti.us/download/mujoco200_linux.zip + unzip mujoco200_linux.zip -d ~/.mujoco + cp -r ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200_linux/bin + sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libglew-dev libglfw3-dev patchelf + - run: + name: "Preparing environment" + command: | + sudo apt-get install -y expect + sudo pip install nox + - run: + name: "Testing mtenv" + command: | + export NOX_PYTHON_VERSIONS=3.7 + pip install nox + python3 -m nox + + py38_linux: + docker: + - image: circleci/python:3.8 + steps: + - checkout + - run: + name: "Mujoco setup" + command: | + wget https://www.roboti.us/download/mujoco200_linux.zip + unzip mujoco200_linux.zip -d ~/.mujoco + cp -r ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200_linux/bin + sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libglew-dev libglfw3-dev patchelf + - run: + name: "Preparing environment" + command: | + sudo apt-get install -y expect + sudo pip install nox + - run: + name: "Testing mtenv" + command: | + export NOX_PYTHON_VERSIONS=3.8 + pip install nox + python3 -m nox + + py39_linux: + docker: + - image: circleci/python:3.9 + steps: + - checkout + - run: + name: "Mujoco setup" + command: | + wget https://www.roboti.us/download/mujoco200_linux.zip + unzip mujoco200_linux.zip -d ~/.mujoco + cp -r ~/.mujoco/mujoco200_linux ~/.mujoco/mujoco200 + export 
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200_linux/bin + sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libglew-dev libglfw3-dev patchelf + - run: + name: "Preparing environment" + command: | + sudo apt-get install -y expect + sudo pip install nox + - run: + name: "Testing mtenv" + command: | + export NOX_PYTHON_VERSIONS=3.9 + pip install nox + python3 -m nox + +workflows: + version: 2.0 + build: + jobs: + - py36_linux + - py37_linux + - py38_linux + - py39_linux + diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..c7540fd --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,45 @@ +# Open Source Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment include: + +Using welcoming and inclusive language +Being respectful of differing viewpoints and experiences +Gracefully accepting constructive criticism +Focusing on what is best for the community +Showing empathy towards other community members +Examples of unacceptable behavior by participants include: + +The use of sexualized language or imagery and unwelcome sexual attention or advances +Trolling, insulting/derogatory comments, and personal or political attacks +Public or private harassment +Publishing others’ private information, such as a physical or electronic address, without explicit permission +Other conduct which could reasonably be considered inappropriate in a professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the project or its community in public spaces. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at opensource-conduct@fb.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 
+
+Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project’s leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
\ No newline at end of file
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
new file mode 100644
index 0000000..a4e89a8
--- /dev/null
+++ b/.github/CONTRIBUTING.md
@@ -0,0 +1,109 @@
+# Contributing to MTEnv
+
+We are glad that you want to contribute to MTEnv.
+
+## Local Setup
+
+Follow these instructions to set up MTEnv locally:
+
+* Clone locally - `git clone git@github.com:facebookresearch/mtenv.git`.
+* *cd* into the directory - `cd mtenv`.
+* Install MTEnv in the dev mode - `pip install -e ".[dev]"`
+* Tests can be run locally using `nox`. The code is linted using:
+    * `black`
+    * `flake8`
+    * `mypy`
+* Install pre-commit hooks - `pre-commit install`. It will execute some
+of the checks when you commit the code. You can skip them by adding the
+`-n` (`--no-verify`) flag to the git command. For example, `git commit -n -m "your commit message"`.
+
+
+### Documentation
+
+We use [Sphinx](https://www.sphinx-doc.org/en/master/) to build the documentation.
+Follow these steps to build/update the documentation locally:
+
+* rm -rf docs/*
+* rm -rf docs_src/source/pages/api
+* rm -rf docs_src/build
+* sphinx-apidoc -o docs_src/source/pages/api mtenv
+* cd docs_src
+* make html
+* cd ..
+* cp -r docs_src/build/html/* docs/
+
+Or run all the commands at once: `rm -rf docs/* && rm -rf docs_src/source/pages/api && rm -rf docs_src/build && sphinx-apidoc -o docs_src/source/pages/api mtenv && cd docs_src && make html && cd .. && cp -r docs_src/build/html/* docs/`
+
+## Pull Requests
+
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. Set up the code using instructions from above.
+3. If you are adding a new environment, check out the guide on [how to contribute new environments](#How-To-Contribute-New-Environments).
+4. If you've added code that should be tested, add tests.
+5. If you've changed APIs, update the documentation.
+6. Ensure the test suite passes. This is tested via CI when you make a PR.
+7. Add a news entry as described [here](#News-Entry).
+8. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+#### How To Contribute New Environments
+
+1. We recommend that you first open an issue to discuss the feasibility of
+adding a new environment. This will eliminate the possibility of duplication
+of work.
+2. Check out the guide on [how to create new environments](https://mtenv.readthedocs.io/en/latest/pages/envs/create.html).
+3. Create a new folder in `mtenv/envs`.
+4. Add the following files, along with the implementation of the environment.
+You can refer to existing environments.
+    * `__init__.py`
+    * `setup.py`
+    * `requirements.txt`
+    * `README.md`
+5. Register your environment in `/mtenv/envs/__init__.py`.
+    * `test_kwargs` are optional, but specify some values (both valid and
+    invalid configurations) if you can, so that they are used for automated testing.
+6. We run some basic tests on the environment (to make sure it can be
+instantiated). You should add more tests to `tests/envs`, along the lines of
+the sketch below.
+7. Add your environment to the list of supported environments at
+`docs_src/source/pages/envs/supported.rst`.
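For reference, a smoke test along the following lines could live under `tests/envs`. This is only a sketch: `MT-MyEnv-v0` is a placeholder for whatever id you register in `/mtenv/envs/__init__.py`, and only the `make` call and the dictionary observation format come from the MTEnv API.

```python
# tests/envs/my_env_test.py -- sketch; "MT-MyEnv-v0" is a placeholder id.
from mtenv import make


def test_my_env_can_be_instantiated_and_stepped() -> None:
    env = make("MT-MyEnv-v0")  # replace with your registered environment id
    obs = env.reset()
    # Every MTEnv observation is a dict with an `env_obs` and a `task_obs` key.
    assert set(obs.keys()) == {"env_obs", "task_obs"}
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    assert set(obs.keys()) == {"env_obs", "task_obs"}
```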
+
+#### News Entry
+
+* Open an issue describing the problem that the PR fixes.
+
+* Create a file, with the name `issue_number.xxx`, in the `news` folder using
+the issue number from the previous step.
+
+* The extension (i.e., the `xxx` part) can be one of the following:
+
+    * api_change: API Changes
+    * bugfix: Bug Fixes
+    * doc: Documentation Changes
+    * environment: Environment Changes (addition or removal)
+    * feature: Features
+    * misc: Miscellaneous Changes
+
+* Add a crisp one-line summary of the change. The summary should complete
+the sentence "This change will ...".
+
+## Contributor License Agreement ("CLA")
+
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Facebook's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+
+By contributing to MTEnv, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
new file mode 100644
index 0000000..c78a0d4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE.md
@@ -0,0 +1,30 @@
+---
+name: "Issue"
+about:
+title: "Issue"
+labels:
+assignees: ''
+
+---
+
+# Description
+
+What issue are you facing?
+
+## How to reproduce
+
+**Add a minimal example to reproduce the issue.**
+
+**Stack trace / error message**
+
+Paste the stack trace/error message into a [Gist](https://gist.github.com)
+and paste the link here.
+
+## System information
+- **MTEnv Version** :
+- **MTEnv environment Name** :
+- **Python version** :
+
+## Any other information
+
+Add any other information here.
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..3ba3867
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,13 @@
+Thank you for contributing to MTEnv.
+
+## Proposed Change
+
+What are you proposing? What is the motivation for the change?
+
+### Have you read the [Contributing Guidelines](https://github.com/facebookresearch/mtenv/blob/main/.github/CONTRIBUTING.md)?
+
+Yes/No
+
+## Related Issues/PRs
+
+Any issues or PRs that are related to this?
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5f03ff0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,140 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +docs/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..068f441 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + + - repo: https://github.com/psf/black + rev: stable + hooks: + - id: black + language_version: python3.6 + + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.9 + hooks: + - id: flake8 + additional_dependencies: [-e, "git+git://github.com/pycqa/pyflakes.git@1911c20#egg=pyflakes"] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.761 + hooks: + - id: mypy + args: [--strict] + exclude: noxfile.py + exclude: setup.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..8a32ff0 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,20 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs_src/source/conf.py + +# Optionally build your docs in additional formats such as PDF +formats: + - pdf + +# Optionally set the version of Python and requirements required to build your docs +python: + version: 3.8 + install: + - requirements: requirements/docs.txt \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..87cbf53 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..74531ed --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include local_dm_control_suite * diff --git a/README.md b/README.md new file mode 100644 index 0000000..5490970 --- /dev/null +++ b/README.md @@ -0,0 +1,128 @@ +[![CircleCI](https://circleci.com/gh/facebookresearch/mtenv.svg?style=svg&circle-token=d507c3b95e80c67d6d4daf2d43785df654af36d1)](https://circleci.com/gh/facebookresearch/mtenv) +![PyPI - License](https://img.shields.io/pypi/l/mtenv) +![PyPI - Python Version](https://img.shields.io/pypi/pyversions/mtenv) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Zulip Chat](https://img.shields.io/badge/zulip-join_chat-brightgreen.svg)](https://mtenv.zulipchat.com) + + +# MTEnv +MultiTask Environments for Reinforcement Learning. + +## Contents + +1. [Introduction](#Introduction) + +2. [Installation](#Installation) + +3. [Usage](#Usage) + +4. [Documentation](#Documentation) + +5. [Contributing to MTEnv](#Contributing-to-MTEnv) + +6. [Community](#Community) + +7. [Acknowledgements](#Acknowledgements) + +## Introduction + +MTEnv is a library to interface with environments for multi-task reinforcement learning. It has two main components: + +* A core API/interface that extends the [gym interface](https://gym.openai.com/) and adds first-class support for multi-task RL. + +* A [collection of environments](https://mtenv.readthedocs.io/en/latest/pages/envs/supported.html) that implement the API. + +Together, these two components should provide a standard interface for multi-task RL environments and make it easier to reuse components and tools across environments. + +You can read more about the difference between `MTEnv` and single-task environments [here.](https://mtenv.readthedocs.io/en/latest/pages/readme.html#multitask-observation) + +### List of publications & submissions using MTEnv (please create a pull request to add the missing entries): + +* [Learning Adaptive Exploration Strategies in Dynamic Environments Through Informed Policy Regularization](https://arxiv.org/abs/2005.02934) + +* [Learning Robust State Abstractions for Hidden-Parameter Block MDPs](https://arxiv.org/abs/2007.07206) + +### License + +* MTEnv uses [MIT License](https://github.com/facebookresearch/mtenv/blob/main/LICENSE). 
+
+* [Terms of Use](https://opensource.facebook.com/legal/terms)
+
+* [Privacy Policy](https://opensource.facebook.com/legal/privacy)
+
+### Citing MTEnv
+
+If you use MTEnv in your research, please use the following BibTeX entry:
+```bibtex
+@Misc{Sodhani2021MTEnv,
+  author = {Shagun Sodhani and Ludovic Denoyer and Pierre-Alexandre Kamienny and Olivier Delalleau},
+  title = {MTEnv - Environment interface for multi-task reinforcement learning},
+  howpublished = {Github},
+  year = {2021},
+  url = {https://github.com/facebookresearch/mtenv}
+}
+```
+
+## Installation
+
+MTEnv has two components - a core API and environments that implement the API.
+
+The **Core API** can be installed via `pip install mtenv` or `pip install git+https://github.com/facebookresearch/mtenv.git@main#egg=mtenv`
+
+The **list of environments** that implement the API is available [here](https://mtenv.readthedocs.io/en/latest/pages/envs/supported.html). Any of these environments can be installed via `pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[env_name]"`. For example, the `MetaWorld` environment can be installed via `pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[metaworld]"`.
+
+All the environments can be installed at once using `pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[all]"`. However, note that some environments may have incompatible dependencies.
+
+MTEnv can also be installed from the source by first cloning the repo (`git clone git@github.com:facebookresearch/mtenv.git`), *cd*-ing into the directory (`cd mtenv`), and then using the pip commands as described above. For example, `pip install mtenv` to install the core API, and `pip install "mtenv[env_name]"` to install a particular environment.
+
+## Usage
+
+MTEnv provides an interface very similar to the standard gym environments. One key difference between multi-task environments (that implement the MTEnv interface) and single-task environments is the observation that they return.
+
+### MultiTask Observation
+
+The multi-task environments return a dictionary as the observation. This dictionary has two keys: (i) `env_obs`, which maps to the observation from the environment (i.e., the observation that a single-task environment returns), and (ii) `task_obs`, which maps to the task-specific information from the environment. In the simplest case, `task_obs` can be an integer denoting the task index. In other cases, `task_obs` can provide richer information.
+
+```python
+from mtenv import make
+env = make("MT-MetaWorld-MT10-v0")
+obs = env.reset()
+print(obs)
+# {'env_obs': array([-0.03265039, 0.51487777, 0.2368754 , -0.06968209, 0.6235982 ,
+#        0.01492813, 0.        , 0.        , 0.        , 0.03933976,
+#        0.89743189, 0.01492813]), 'task_obs': 1}
+action = env.action_space.sample()
+print(action)
+# array([-0.76422   , -0.15384133, 0.74575615, -0.11724994], dtype=float32)
+obs, reward, done, info = env.step(action)
+print(obs)
+# {'env_obs': array([-0.02583682, 0.54065546, 0.22773503, -0.06968209, 0.6235982 ,
+#        0.01494118, 0.        , 0.        , 0.        , 0.03933976,
+#        0.89743189, 0.01492813]), 'task_obs': 1}
+```
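The dictionary observation makes it easy to write task-agnostic training loops. The sketch below relies only on the `make` call and the observation format shown above; the random action is a stand-in for whatever multi-task agent you use, which would typically condition on both `env_obs` and `task_obs`.

```python
from mtenv import make

env = make("MT-MetaWorld-MT10-v0")
obs = env.reset()
for _ in range(1000):
    env_obs, task_obs = obs["env_obs"], obs["task_obs"]
    # A multi-task agent would condition on both parts, e.g.
    # action = agent.act(env_obs, task_obs). Here we just sample randomly.
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
```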
+
+## Documentation
+
+[https://mtenv.readthedocs.io](https://mtenv.readthedocs.io)
+
+## Contributing to MTEnv
+
+There are several ways to contribute to MTEnv.
+
+1. Use MTEnv in your research.
+
+2. Contribute a new environment. We support [many environments](https://mtenv.readthedocs.io/en/latest/pages/envs/supported.html) via MTEnv and are looking forward to adding more environments. Contributors will be added as authors of the library. You can learn more about the workflow of adding an environment [here](https://mtenv.readthedocs.io/en/latest/pages/envs/create.html).
+
+3. Check out the [good-first-issues](https://github.com/facebookresearch/mtenv/pulls?q=is%3Apr+is%3Aopen+label%3A%22good+first+issue%22) on GitHub and contribute to fixing those issues.
+
+4. Check out additional details [here](https://github.com/facebookresearch/mtenv/blob/main/.github/CONTRIBUTING.md).
+
+## Community
+
+Ask questions in the chat or GitHub issues:
+* [Chat](https://mtenv.zulipchat.com)
+* [Issues](https://github.com/facebookresearch/mtenv/issues)
+
+## Acknowledgements
+
+* Project files like the pre-commit config, mypy config, towncrier config, and CircleCI config are based on the same files from [Hydra](https://github.com/facebookresearch/hydra).
\ No newline at end of file
diff --git a/docs_src/Makefile b/docs_src/Makefile
new file mode 100644
index 0000000..d0c3cbf
--- /dev/null
+++ b/docs_src/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs_src/make.bat b/docs_src/make.bat
new file mode 100644
index 0000000..6247f7e
--- /dev/null
+++ b/docs_src/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs_src/source/conf.py b/docs_src/source/conf.py
new file mode 100644
index 0000000..af037f6
--- /dev/null
+++ b/docs_src/source/conf.py
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+# +import os +import sys + +sys.path.insert(0, os.path.abspath("../..")) + +# -- Project information ----------------------------------------------------- +import mtenv + +project = "mtenv" +copyright = "2021, Facebook AI" +author = "Shagun Sodhani, Ludovic Denoyer, Pierre-Alexandre Kamienny, Olivier Delalleau" + +# The full version, including alpha/beta/rc tags +release = mtenv.__version__ + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosectionlabel", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_copybutton", + "sphinxcontrib.bibtex", +] + +bibtex_bibfiles = ["pages/bib/refs.bib"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# https://github.com/sphinx-doc/sphinx/issues/2374 +autoclass_content = "both" diff --git a/docs_src/source/index.rst b/docs_src/source/index.rst new file mode 100644 index 0000000..6395b44 --- /dev/null +++ b/docs_src/source/index.rst @@ -0,0 +1,52 @@ +.. mtenv documentation master file, created by + sphinx-quickstart on Tue Jan 5 11:07:41 2021. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +MTEnv: MultiTask Environments for Reinforcement Learning +======================================================== + +|CircleCI| |PyPI - License| |PyPI - Python Version| |Code style: black| |Zulip Chat| + +.. |CircleCI| image:: https://circleci.com/gh/facebookresearch/mtenv.svg?style=svg&circle-token=d507c3b95e80c67d6d4daf2d43785df654af36d1 + :target: https://circleci.com/gh/facebookresearch/mtenv +.. |PyPI - License| image:: https://img.shields.io/pypi/l/mtenv +.. |PyPI - Python Version| image:: https://img.shields.io/pypi/pyversions/mtenv +.. |Code style: black| image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black +.. |Zulip Chat| image:: https://img.shields.io/badge/zulip-join_chat-brightgreen.svg + :target: https://mtenv.zulipchat.com + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + pages/readme + +.. toctree:: + :maxdepth: 3 + :caption: Environments + + pages/envs/supported + pages/envs/create + +.. toctree:: + :maxdepth: 3 + :caption: API + + pages/api/modules + + +Community +========= + +Ask questions in the `chat `_ or `GitHub issues `_. 
+ +To contribute, `open a Pull Request (PR) `__ + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` \ No newline at end of file diff --git a/docs_src/source/pages/api/modules.rst b/docs_src/source/pages/api/modules.rst new file mode 100644 index 0000000..c0cf5e0 --- /dev/null +++ b/docs_src/source/pages/api/modules.rst @@ -0,0 +1,7 @@ +mtenv +===== + +.. toctree:: + :maxdepth: 4 + + mtenv diff --git a/docs_src/source/pages/api/mtenv.envs.control.rst b/docs_src/source/pages/api/mtenv.envs.control.rst new file mode 100644 index 0000000..c72c3ae --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.control.rst @@ -0,0 +1,37 @@ +mtenv.envs.control package +========================== + +Submodules +---------- + +mtenv.envs.control.acrobot module +--------------------------------- + +.. automodule:: mtenv.envs.control.acrobot + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.control.cartpole module +---------------------------------- + +.. automodule:: mtenv.envs.control.cartpole + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.control.setup module +------------------------------- + +.. automodule:: mtenv.envs.control.setup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.control + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.hipbmdp.rst b/docs_src/source/pages/api/mtenv.envs.hipbmdp.rst new file mode 100644 index 0000000..029d20f --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.hipbmdp.rst @@ -0,0 +1,45 @@ +mtenv.envs.hipbmdp package +========================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mtenv.envs.hipbmdp.wrappers + +Submodules +---------- + +mtenv.envs.hipbmdp.dmc\_env module +---------------------------------- + +.. automodule:: mtenv.envs.hipbmdp.dmc_env + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.hipbmdp.env module +----------------------------- + +.. automodule:: mtenv.envs.hipbmdp.env + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.hipbmdp.setup module +------------------------------- + +.. automodule:: mtenv.envs.hipbmdp.setup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.hipbmdp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.hipbmdp.wrappers.rst b/docs_src/source/pages/api/mtenv.envs.hipbmdp.wrappers.rst new file mode 100644 index 0000000..65a5319 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.hipbmdp.wrappers.rst @@ -0,0 +1,37 @@ +mtenv.envs.hipbmdp.wrappers package +=================================== + +Submodules +---------- + +mtenv.envs.hipbmdp.wrappers.dmc\_wrapper module +----------------------------------------------- + +.. automodule:: mtenv.envs.hipbmdp.wrappers.dmc_wrapper + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.hipbmdp.wrappers.framestack module +--------------------------------------------- + +.. automodule:: mtenv.envs.hipbmdp.wrappers.framestack + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.hipbmdp.wrappers.sticky\_observation module +------------------------------------------------------ + +.. automodule:: mtenv.envs.hipbmdp.wrappers.sticky_observation + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mtenv.envs.hipbmdp.wrappers + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.metaworld.rst b/docs_src/source/pages/api/mtenv.envs.metaworld.rst new file mode 100644 index 0000000..4f84b7a --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.metaworld.rst @@ -0,0 +1,37 @@ +mtenv.envs.metaworld package +============================ + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mtenv.envs.metaworld.wrappers + +Submodules +---------- + +mtenv.envs.metaworld.env module +------------------------------- + +.. automodule:: mtenv.envs.metaworld.env + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.metaworld.setup module +--------------------------------- + +.. automodule:: mtenv.envs.metaworld.setup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.metaworld + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.metaworld.wrappers.rst b/docs_src/source/pages/api/mtenv.envs.metaworld.wrappers.rst new file mode 100644 index 0000000..395fa64 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.metaworld.wrappers.rst @@ -0,0 +1,21 @@ +mtenv.envs.metaworld.wrappers package +===================================== + +Submodules +---------- + +mtenv.envs.metaworld.wrappers.normalized\_env module +---------------------------------------------------- + +.. automodule:: mtenv.envs.metaworld.wrappers.normalized_env + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.metaworld.wrappers + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.mpte.rst b/docs_src/source/pages/api/mtenv.envs.mpte.rst new file mode 100644 index 0000000..0fb00b5 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.mpte.rst @@ -0,0 +1,29 @@ +mtenv.envs.mpte package +======================= + +Submodules +---------- + +mtenv.envs.mpte.setup module +---------------------------- + +.. automodule:: mtenv.envs.mpte.setup + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.mpte.two\_goal\_maze\_env module +------------------------------------------- + +.. automodule:: mtenv.envs.mpte.two_goal_maze_env + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.mpte + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.rst b/docs_src/source/pages/api/mtenv.envs.rst new file mode 100644 index 0000000..570acdf --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.rst @@ -0,0 +1,34 @@ +mtenv.envs package +================== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mtenv.envs.control + mtenv.envs.hipbmdp + mtenv.envs.metaworld + mtenv.envs.mpte + mtenv.envs.shared + mtenv.envs.tabular_mdp + +Submodules +---------- + +mtenv.envs.registration module +------------------------------ + +.. automodule:: mtenv.envs.registration + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mtenv.envs + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.shared.rst b/docs_src/source/pages/api/mtenv.envs.shared.rst new file mode 100644 index 0000000..db12324 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.shared.rst @@ -0,0 +1,18 @@ +mtenv.envs.shared package +========================= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mtenv.envs.shared.wrappers + +Module contents +--------------- + +.. automodule:: mtenv.envs.shared + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.shared.wrappers.rst b/docs_src/source/pages/api/mtenv.envs.shared.wrappers.rst new file mode 100644 index 0000000..78f38c8 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.shared.wrappers.rst @@ -0,0 +1,21 @@ +mtenv.envs.shared.wrappers package +================================== + +Submodules +---------- + +mtenv.envs.shared.wrappers.multienv module +------------------------------------------ + +.. automodule:: mtenv.envs.shared.wrappers.multienv + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.shared.wrappers + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.envs.tabular_mdp.rst b/docs_src/source/pages/api/mtenv.envs.tabular_mdp.rst new file mode 100644 index 0000000..019c9c8 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.envs.tabular_mdp.rst @@ -0,0 +1,29 @@ +mtenv.envs.tabular\_mdp package +=============================== + +Submodules +---------- + +mtenv.envs.tabular\_mdp.setup module +------------------------------------ + +.. automodule:: mtenv.envs.tabular_mdp.setup + :members: + :undoc-members: + :show-inheritance: + +mtenv.envs.tabular\_mdp.tmdp module +----------------------------------- + +.. automodule:: mtenv.envs.tabular_mdp.tmdp + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.envs.tabular_mdp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.rst b/docs_src/source/pages/api/mtenv.rst new file mode 100644 index 0000000..e1e809c --- /dev/null +++ b/docs_src/source/pages/api/mtenv.rst @@ -0,0 +1,31 @@ +mtenv package +============= + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + mtenv.envs + mtenv.utils + mtenv.wrappers + +Submodules +---------- + +mtenv.core module +----------------- + +.. automodule:: mtenv.core + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.utils.rst b/docs_src/source/pages/api/mtenv.utils.rst new file mode 100644 index 0000000..5dabad5 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.utils.rst @@ -0,0 +1,37 @@ +mtenv.utils package +=================== + +Submodules +---------- + +mtenv.utils.seeding module +-------------------------- + +.. automodule:: mtenv.utils.seeding + :members: + :undoc-members: + :show-inheritance: + +mtenv.utils.setup\_utils module +------------------------------- + +.. automodule:: mtenv.utils.setup_utils + :members: + :undoc-members: + :show-inheritance: + +mtenv.utils.types module +------------------------ + +.. automodule:: mtenv.utils.types + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. 
automodule:: mtenv.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/api/mtenv.wrappers.rst b/docs_src/source/pages/api/mtenv.wrappers.rst new file mode 100644 index 0000000..601ff26 --- /dev/null +++ b/docs_src/source/pages/api/mtenv.wrappers.rst @@ -0,0 +1,53 @@ +mtenv.wrappers package +====================== + +Submodules +---------- + +mtenv.wrappers.env\_to\_mtenv module +------------------------------------ + +.. automodule:: mtenv.wrappers.env_to_mtenv + :members: + :undoc-members: + :show-inheritance: + +mtenv.wrappers.multitask module +------------------------------- + +.. automodule:: mtenv.wrappers.multitask + :members: + :undoc-members: + :show-inheritance: + +mtenv.wrappers.ntasks module +---------------------------- + +.. automodule:: mtenv.wrappers.ntasks + :members: + :undoc-members: + :show-inheritance: + +mtenv.wrappers.ntasks\_id module +-------------------------------- + +.. automodule:: mtenv.wrappers.ntasks_id + :members: + :undoc-members: + :show-inheritance: + +mtenv.wrappers.sample\_random\_task module +------------------------------------------ + +.. automodule:: mtenv.wrappers.sample_random_task + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: mtenv.wrappers + :members: + :undoc-members: + :show-inheritance: diff --git a/docs_src/source/pages/bib/refs.bib b/docs_src/source/pages/bib/refs.bib new file mode 100644 index 0000000..66657e5 --- /dev/null +++ b/docs_src/source/pages/bib/refs.bib @@ -0,0 +1,27 @@ +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +@article{mtrl_as_a_hidden_block_mdp, + title = {Multi-Task Reinforcement Learning as a Hidden-Parameter Block MDP}, + author = {Zhang, Amy and Sodhani, Shagun and Khetarpal, Khimya and Pineau, Joelle}, + journal = {arXiv preprint arXiv:2007.07206}, + year = {2020} +} + +@misc{tassa2020dmcontrol, + title = {dm_control: Software and Tasks for Continuous Control}, + author = {Yuval Tassa and Saran Tunyasuvunakool and Alistair Muldal and + Yotam Doron and Siqi Liu and Steven Bohez and Josh Merel and + Tom Erez and Timothy Lillicrap and Nicolas Heess}, + year = {2020}, + eprint = {2006.12983}, + archiveprefix = {arXiv}, + primaryclass = {cs.RO} +} + +@inproceedings{yu2020meta, + title = {Meta-world: A benchmark and evaluation for multi-task and meta reinforcement learning}, + author = {Yu, Tianhe and Quillen, Deirdre and He, Zhanpeng and Julian, Ryan and Hausman, Karol and Finn, Chelsea and Levine, Sergey}, + booktitle = {Conference on Robot Learning}, + pages = {1094--1100}, + year = {2020}, + organization = {PMLR} +} diff --git a/docs_src/source/pages/envs/create.rst b/docs_src/source/pages/envs/create.rst new file mode 100644 index 0000000..1d21fff --- /dev/null +++ b/docs_src/source/pages/envs/create.rst @@ -0,0 +1,22 @@ + +How to create new environments +--------------------------- + +There are two workflows: + + +#. + You have a standard gym environment, which you want to convert into a multitask environment. For example, ``examples/bandit.py`` implements ``BanditEnv`` which is a standard multi-arm bandit, without an explicit notion of task. The user has the following options: + + + * + Write a new subclass, say ``MTBanditEnv`` (which subclasses ``MTEnv``\ ) as shown in ``examples/mtenv_bandit.py``. + + * + Use the ``EnvToMTEnv`` wrapper and wrap the existing single task environment. In some cases, the wrapper may have to be extended, as is done in ``examples/wrapped_bandit.py``. + +#. 
+ If you do not have a single-task gym environment to start with, it is recommended that you directly extend the ``MTEnv`` class. Implementations in ``mtenv/envs`` can be seen as a reference. + +If you want to contribute an environment to the repo, checkout the `Contribution Guide `_. + diff --git a/docs_src/source/pages/envs/supported.rst b/docs_src/source/pages/envs/supported.rst new file mode 100644 index 0000000..8656e21 --- /dev/null +++ b/docs_src/source/pages/envs/supported.rst @@ -0,0 +1,75 @@ + +Supported Environments +====================== + +The following environments are supported: + +Control +------- + +**Installation** + +.. code-block:: bash + + pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[control]" + +HiPBMDP +------- + +:cite:`mtrl_as_a_hidden_block_mdp` create a family of MDPs using the +existing environment-task pairs from DeepMind Control Suite :cite:`tassa2020dmcontrol` +and change one environment parameter to sample different MDPs. For more details, +refer :cite:`mtrl_as_a_hidden_block_mdp`. + + +**Installation** + +.. code-block:: bash + + pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[hipbmdp]" + +**Usage** + +.. code-block:: python + + from mtenv import make + env = make("MT-HiPBMDP-Finger-Spin-vary-size-v0") + env.reset() + + +MetaWorld +--------- + +:cite:`yu2020meta` proposed an open-source simulated benchmark for +meta-reinforcement learning and multi-task learning consisting of 50 distinct +robotic manipulation tasks. For more details, refer :cite:`yu2020meta`. +MTEnv provides a wrapper for the multi-task learning environments. + +**Installation** + +.. code-block:: bash + + pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[metaworld]" + +**Usage** + +.. code-block:: python + + from mtenv import make + env = make("MT-MetaWorld-MT10-v0") # or MT-MetaWorld-MT50-v0 or MT-MetaWorld-MT1-v0 + env.reset() + +MPTE +---- + +**Installation** + +.. code-block:: bash + + pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[mpte]" + + +References +------------- + +.. bibliography:: \ No newline at end of file diff --git a/docs_src/source/pages/readme.rst b/docs_src/source/pages/readme.rst new file mode 100644 index 0000000..d92cbcc --- /dev/null +++ b/docs_src/source/pages/readme.rst @@ -0,0 +1,140 @@ +MTEnv +===== + +MultiTask Environments for Reinforcement Learning. + +Introduction +------------ + +MTEnv is a library to interface with environments for multi-task reinforcement learning. It has two main components: + + +* A core API/interface that extends the `gym interface `_ and adds first-class support for multi-task RL. + +* A `collection of environments `_ that implement the API. + +Together, these two components should provide a standard interface for multi-task RL environments and make it easier to reuse components and tools across environments. + +You can read more about the difference between ``MTEnv`` and single-task environments `here. `_ + +List of publications & submissions using MTEnv (please create a pull request to add the missing entries): +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +* `Learning Adaptive Exploration Strategies in Dynamic Environments Through Informed Policy Regularization `_ + +* `Learning Robust State Abstractions for Hidden-Parameter Block MDPs `_ + +License +^^^^^^^ + +* MTEnv uses `MIT License `_. 
+ +* `Terms of Use `_ + +* `Privacy Policy `_ + +Citing MTEnv +^^^^^^^^^^^^ + +If you use MTEnv in your research, please use the following BibTeX entry: + +.. code-block:: + + @Misc{Sodhani2021MTEnv, + author = {Shagun Sodhani and Ludovic Denoyer and Pierre-Alexandre Kamienny and Olivier Delalleau}, + title = {MTEnv - Environment interface for mulit-task reinforcement learning}, + howpublished = {Github}, + year = {2021}, + url = {https://github.com/facebookresearch/mtenv} + } + +Installation +------------ + +MTEnv has two components - a core API and environments that implement the API. + +The **Core API** can be installed via ``pip install mtenv`` or ``pip install git+https://github.com/facebookresearch/mtenv.git@main#egg=mtenv`` + +The **list of environments**\ , that implement the API, is available `here `_. Any of these environments can be installed via ``pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[env_name]"``. For example, the ``MetaWorld`` environment can be installed via ``pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[metaworld]"``. + +All the environments can be installed at once using ``pip install git+https://github.com/facebookresearch/mtenv.git@main#egg="mtenv[all]"``. However, note that some environments may have incompatible dependencies. + +MTEnv can also be installed from the source by first cloning the repo (\ ``git clone git@github.com:facebookresearch/mtenv.git``\ ), *cding* into the directory ``cd mtenv``\ , and then using the pip commands as described above. For example, ``pip install mtenv`` to install the core API, and ``pip install "mtenv[env_name]"`` to install a particular environment. + +Usage +----- + +MTEnv provides an interface very similar to the standard gym environments. +One key difference between multitask environments (that implement the MTEnv +interface and single tasks environments is in terms of observation that +they return. + +.. _multitask_observation: + +MultiTask Observation +^^^^^^^^^^^^^^^^^^^^^ + +The multitask environments returns a dictionary as the observation. This +dictionary has two keys: (i) `env_obs` which maps to the observation from +the environment (i.e. the observation that a single task environments return) +and (ii) `task_obs` which maps to the task-specific information from the +environment. In the simplest case, `task_obs` can be an integer denoting +the task index. In other cases, `task_obs` can provide richer information. + +.. code-block:: python + + from mtenv import make + env = make("MT-MetaWorld-MT10-v0") + obs = env.reset() + print(obs) + # {'env_obs': array([-0.03265039, 0.51487777, 0.2368754 , -0.06968209, 0.6235982 , + # 0.01492813, 0. , 0. , 0. , 0.03933976, + # 0.89743189, 0.01492813]), 'task_obs': 1} + action = env.action_space.sample() + print(action) + # array([-0.76422 , -0.15384133, 0.74575615, -0.11724994], dtype=float32) + obs, reward, done, info = env.step(action) + print(obs) + # {'env_obs': array([-0.02583682, 0.54065546, 0.22773503, -0.06968209, 0.6235982 , + # 0.01494118, 0. , 0. , 0. , 0.03933976, + # 0.89743189, 0.01492813]), 'task_obs': 1} + +Documentation +------------- + +`https://mtenv.readthedocs.io `_ + +Contributing to MTEnv +--------------------- + +There are several ways to contribute to MTEnv. + + +#. Use MTEnv in your research. + +#. Contribute a new environment. We support `many environments `_ via MTEnv and are looking forward to adding more environments. Contributors will be added as authors of the library. 
You can learn more about the workflow of adding an environment `here. `_ + +#. Check out the `good-first-issues `_ on GitHub and contribute to fixing those issues. + +#. Check out additional details `here `_. + +Community +--------- + +Ask questions in the chat or github issues: + + +* `Chat `_ +* `Issues `_ + +Glossary +-------- + +.. _task_state: + +Task State +^^^^^^^^^^ + +Task State contains all the information that the environment needs to +switch to any other task. diff --git a/examples/bandit.py b/examples/bandit.py new file mode 100644 index 0000000..88641ec --- /dev/null +++ b/examples/bandit.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import List, Optional, Tuple + +import numpy as np +from gym import spaces +from gym.core import Env + +from mtenv.utils import seeding +from mtenv.utils.types import ActionType, DoneType, EnvObsType, InfoType, RewardType + +StepReturnType = Tuple[EnvObsType, RewardType, DoneType, InfoType] + + +class BanditEnv(Env): # type: ignore[misc] + # Class cannot subclass 'Env' (has type 'Any') + + def __init__(self, n_arms: int): + self.n_arms = n_arms + self.action_space = spaces.Discrete(n_arms) + self.observation_space = spaces.Box( + low=0.0, high=1.0, shape=(1,), dtype=np.float32 + ) + self.reward_probability = spaces.Box( + low=0.0, high=1.0, shape=(self.n_arms,) + ).sample() + + def seed(self, seed: Optional[int] = None) -> List[int]: + self.np_random_env, seed = seeding.np_random(seed) + assert isinstance(seed, int) + return [seed] + + def reset(self) -> EnvObsType: + return np.asarray([0.0]) + + def step(self, action: ActionType) -> StepReturnType: + sample = self.np_random_env.rand() + reward = 0.0 + if sample < self.reward_probability[action]: + reward = 1.0 + + return np.asarray([0.0]), reward, False, {} + + +def run() -> None: + env = BanditEnv(5) + env.seed(seed=5) + for episode in range(3): + print("=== episode " + str(episode)) + print(env.reset()) + for _ in range(5): + action = env.action_space.sample() + print(env.step(action)) + + +if __name__ == "__main__": + run() diff --git a/examples/finite_mtenv_bandit.py b/examples/finite_mtenv_bandit.py new file mode 100644 index 0000000..c08d8a3 --- /dev/null +++ b/examples/finite_mtenv_bandit.py @@ -0,0 +1,109 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import Any, Dict, List, Optional + +import numpy as np +from gym import spaces + +from mtenv import MTEnv +from mtenv.utils import seeding +from mtenv.utils.types import ActionType, ObsType, StepReturnType + +TaskStateType = int + + +class FiniteMTBanditEnv(MTEnv): + """Multitask Bandit Env where the task_state is sampled from a finite list of states""" + + def __init__(self, n_tasks: int, n_arms: int): + super().__init__( + action_space=spaces.Discrete(n_arms), + env_observation_space=spaces.Box( + low=0.0, high=1.0, shape=(1,), dtype=np.float32 + ), + task_observation_space=spaces.Box(low=0.0, high=1.0, shape=(n_arms,)), + ) + self.n_arms = n_arms + self.n_tasks = n_tasks + self.observation_space["task_obs"].seed(0) + self.possible_task_observations = np.asarray( + [self.observation_space["task_obs"].sample() for _ in range(self.n_tasks)] + ) + # possible_task_observations is assumed to be part of the environment definition ie + # everytime we instantiate the env, we get the same `possible_task_observations`. 
+ self._should_reset_env = True + + def reset(self, **kwargs: Dict[str, Any]) -> ObsType: + self.assert_env_seed_is_set() + self._should_reset_env = False + return {"env_obs": [0.0], "task_obs": self.task_obs} + + def sample_task_state(self) -> TaskStateType: + """Sample a `task_state` that contains all the information needed to revert to any + other task. For examples, refer to TBD""" + self.assert_task_seed_is_set() + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + + return self.np_random_task.randint(0, self.n_tasks) # type: ignore[no-any-return, union-attr] + + def set_task_state(self, task_state: TaskStateType) -> None: + self.task_state = task_state + self.task_obs = self.possible_task_observations[task_state] + + def step(self, action: ActionType) -> StepReturnType: + if self._should_reset_env: + raise RuntimeError("Call `env.reset()` before calling `env.step()`") + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + sample = self.np_random_env.rand() # type: ignore[union-attr] + reward = 0.0 + if sample < self.task_obs[action]: # type: ignore[index] + reward = 1.0 + + return ( + {"env_obs": [0.0], "task_obs": self.task_obs}, + reward, + False, + {}, + ) + + def seed_task(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for task information""" + self.np_random_task, seed = seeding.np_random(seed) + # in this function, we do not need the self.np_random_task + return [seed] + + def get_task_state(self) -> TaskStateType: + """Return all the information needed to execute the current task again. + For examples, refer to TBD""" + return self.task_state + + +def run() -> None: + env = FiniteMTBanditEnv(n_tasks=10, n_arms=5) + env.seed(seed=1) + env.seed_task(seed=2) + + for task in range(3): + print("=== Task " + str(task % 2)) + env.set_task_state(task % 2) + print(env.reset()) + for _ in range(5): + action = env.action_space.sample() + print(env.step(action)) + + new_env = FiniteMTBanditEnv(n_tasks=10, n_arms=5) + new_env.seed(seed=1) + new_env.seed_task(seed=2) + + print("=== Executing the current task (from old env) in new env ") + + new_env.set_task_state(task_state=env.get_task_state()) + print(new_env.reset()) + for _ in range(5): + action = new_env.action_space.sample() + print(new_env.step(action)) + + +if __name__ == "__main__": + run() diff --git a/examples/mtenv_bandit.py b/examples/mtenv_bandit.py new file mode 100644 index 0000000..d1ccc7f --- /dev/null +++ b/examples/mtenv_bandit.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +import numpy as np +from gym import spaces + +from mtenv import MTEnv +from mtenv.utils.types import ActionType, ObsType, StepReturnType, TaskStateType + + +class MTBanditEnv(MTEnv): + def __init__(self, n_arms: int): + super().__init__( + action_space=spaces.Discrete(n_arms), + env_observation_space=spaces.Box( + low=0.0, high=1.0, shape=(1,), dtype=np.float32 + ), + task_observation_space=spaces.Box(low=0.0, high=1.0, shape=(n_arms,)), + ) + self.n_arms = n_arms + self._should_reset_env = True + + def reset(self) -> ObsType: + self.assert_env_seed_is_set() + self._should_reset_env = False + return {"env_obs": [0.0], "task_obs": self.task_observation} + + def sample_task_state(self) -> TaskStateType: + self.assert_task_seed_is_set() + return self.observation_space["task_obs"].sample() + + def get_task_state(self) -> TaskStateType: + return self.task_observation + + def set_task_state(self, task_state: TaskStateType) -> None: + self.task_observation = task_state + + def step(self, action: ActionType) -> StepReturnType: + if self._should_reset_env: + raise RuntimeError("Call `env.reset()` before calling `env.step()`") + + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + sample = self.np_random_env.rand() # type: ignore[union-attr] + reward = 0.0 + if sample < self.task_observation[action]: + reward = 1.0 + + return ( + {"env_obs": [0.0], "task_obs": self.task_observation}, + reward, + False, + {}, + ) + + +def run() -> None: + env = MTBanditEnv(5) + env.seed(seed=1) + env.seed_task(seed=2) + + for task in range(3): + print("=== Task " + str(task)) + env.reset_task_state() + print(env.reset()) + for _ in range(5): + action = env.action_space.sample() + print(env.step(action)) + + +if __name__ == "__main__": + run() diff --git a/examples/wrapped_bandit.py b/examples/wrapped_bandit.py new file mode 100644 index 0000000..67be693 --- /dev/null +++ b/examples/wrapped_bandit.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import List, Optional + +from gym import spaces + +from examples.bandit import BanditEnv # type: ignore[import] +from mtenv.utils import seeding +from mtenv.utils.types import TaskObsType, TaskStateType +from mtenv.wrappers.env_to_mtenv import EnvToMTEnv + + +class MTBanditWrapper(EnvToMTEnv): + def set_task_observation(self, task_obs: TaskObsType) -> None: + self._task_obs = task_obs + self.env.reward_probability = self._task_obs + self._is_task_seed_set = False + + def get_task_state(self) -> TaskStateType: + return self._task_obs + + def set_task_state(self, task_state: TaskStateType) -> None: + self._task_obs = task_state + self.env.reward_probability = self._task_obs + + def sample_task_state(self) -> TaskStateType: + """Sample a `task_state` that contains all the information needed to revert to any + other task. 
For examples, refer to TBD""" + return self.observation_space["task_obs"].sample() + + def seed_task(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for task information""" + self._is_task_seed_set = True + _, seed = seeding.np_random(seed) + self.observation_space["task_obs"].seed(seed) + return [seed] + + def assert_task_seed_is_set(self) -> None: + """Check that the task seed is set.""" + assert self._is_task_seed_set, "please call `seed_task()` first" + + +def run() -> None: + n_arms = 5 + env = MTBanditWrapper( + env=BanditEnv(n_arms), + task_observation_space=spaces.Box(low=0.0, high=1.0, shape=(n_arms,)), + ) + env.seed(1) + env.seed_task(seed=2) + for task in range(3): + print("=== task " + str(task)) + env.reset_task_state() + print(env.reset()) + for _ in range(5): + action = env.action_space.sample() + print(env.step(action)) + print(f"reward_probability: {env.unwrapped.reward_probability}") + + +if __name__ == "__main__": + run() diff --git a/local_dm_control_suite/README.md b/local_dm_control_suite/README.md new file mode 100755 index 0000000..135ab42 --- /dev/null +++ b/local_dm_control_suite/README.md @@ -0,0 +1,56 @@ +# DeepMind Control Suite. + +This submodule contains the domains and tasks described in the +[DeepMind Control Suite tech report](https://arxiv.org/abs/1801.00690). + +## Quickstart + +```python +from dm_control import suite +import numpy as np + +# Load one task: +env = suite.load(domain_name="cartpole", task_name="swingup") + +# Iterate over a task set: +for domain_name, task_name in suite.BENCHMARKING: + env = suite.load(domain_name, task_name) + +# Step through an episode and print out reward, discount and observation. +action_spec = env.action_spec() +time_step = env.reset() +while not time_step.last(): + action = np.random.uniform(action_spec.minimum, + action_spec.maximum, + size=action_spec.shape) + time_step = env.step(action) + print(time_step.reward, time_step.discount, time_step.observation) +``` + +## Illustration video + +Below is a video montage of solved Control Suite tasks, with reward +visualisation enabled. + +[![Video montage](https://img.youtube.com/vi/rAai4QzcYbs/0.jpg)](https://www.youtube.com/watch?v=rAai4QzcYbs) + + +### Quadruped domain [April 2019] + +Roughly based on the 'ant' model introduced by [Schulman et al. 2015](https://arxiv.org/abs/1506.02438). Main modifications to the body are: + +- 4 DoFs per leg, 1 constraining tendon. +- 3 actuators per leg: 'yaw', 'lift', 'extend'. +- Filtered position actuators with timescale of 100ms. +- Sensors include an IMU, force/torque sensors, and rangefinders. + +Four tasks: + +- `walk` and `run`: self-right the body then move forward at a desired speed. +- `escape`: escape a bowl-shaped random terrain (uses rangefinders). +- `fetch`, go to a moving ball and bring it to a target. + +All behaviors in the video below were trained with [Abdolmaleki et al's +MPO](https://arxiv.org/abs/1806.06920). + +[![Video montage](https://img.youtube.com/vi/RhRLjbb7pBE/0.jpg)](https://www.youtube.com/watch?v=RhRLjbb7pBE) diff --git a/local_dm_control_suite/__init__.py b/local_dm_control_suite/__init__.py new file mode 100755 index 0000000..1e45dc2 --- /dev/null +++ b/local_dm_control_suite/__init__.py @@ -0,0 +1,167 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A collection of MuJoCo-based Reinforcement Learning environments.""" + +from __future__ import absolute_import, division, print_function + +import collections +import inspect +import itertools +import sys + +from dm_control.rl import control + +from . import ( + acrobot, + ball_in_cup, + cartpole, + cheetah, + finger, + fish, + hopper, + humanoid, + humanoid_CMU, + lqr, + manipulator, + pendulum, + point_mass, + quadruped, + reacher, + stacker, + swimmer, + walker, +) + +# Find all domains imported. +_DOMAINS = { + name: module + for name, module in locals().items() + if inspect.ismodule(module) and hasattr(module, "SUITE") +} + + +def _get_tasks(tag): + """Returns a sequence of (domain name, task name) pairs for the given tag.""" + result = [] + + for domain_name in sorted(_DOMAINS.keys()): + domain = _DOMAINS[domain_name] + + if tag is None: + tasks_in_domain = domain.SUITE + else: + tasks_in_domain = domain.SUITE.tagged(tag) + + for task_name in tasks_in_domain.keys(): + result.append((domain_name, task_name)) + + return tuple(result) + + +def _get_tasks_by_domain(tasks): + """Returns a dict mapping from task name to a tuple of domain names.""" + result = collections.defaultdict(list) + + for domain_name, task_name in tasks: + result[domain_name].append(task_name) + + return {k: tuple(v) for k, v in result.items()} + + +# A sequence containing all (domain name, task name) pairs. +ALL_TASKS = _get_tasks(tag=None) + +# Subsets of ALL_TASKS, generated via the tag mechanism. +BENCHMARKING = _get_tasks("benchmarking") +EASY = _get_tasks("easy") +HARD = _get_tasks("hard") +EXTRA = tuple(sorted(set(ALL_TASKS) - set(BENCHMARKING))) + +# A mapping from each domain name to a sequence of its task names. +TASKS_BY_DOMAIN = _get_tasks_by_domain(ALL_TASKS) + + +def load( + domain_name, + task_name, + task_kwargs=None, + environment_kwargs=None, + visualize_reward=False, +): + """Returns an environment from a domain name, task name and optional settings. + + ```python + env = suite.load('cartpole', 'balance') + ``` + + Args: + domain_name: A string containing the name of a domain. + task_name: A string containing the name of a task. + task_kwargs: Optional `dict` of keyword arguments for the task. + environment_kwargs: Optional `dict` specifying keyword arguments for the + environment. + visualize_reward: Optional `bool`. If `True`, object colours in rendered + frames are set to indicate the reward at each step. Default `False`. + + Returns: + The requested environment. + """ + return build_environment( + domain_name, task_name, task_kwargs, environment_kwargs, visualize_reward, + ) + + +def build_environment( + domain_name, + task_name, + task_kwargs=None, + environment_kwargs=None, + visualize_reward=False, +): + """Returns an environment from the suite given a domain name and a task name. + + Args: + domain_name: A string containing the name of a domain. + task_name: A string containing the name of a task. + task_kwargs: Optional `dict` specifying keyword arguments for the task. 
+ environment_kwargs: Optional `dict` specifying keyword arguments for the + environment. + visualize_reward: Optional `bool`. If `True`, object colours in rendered + frames are set to indicate the reward at each step. Default `False`. + + Raises: + ValueError: If the domain or task doesn't exist. + + Returns: + An instance of the requested environment. + """ + if domain_name not in _DOMAINS: + raise ValueError("Domain {!r} does not exist.".format(domain_name)) + + domain = _DOMAINS[domain_name] + + if task_name not in domain.SUITE: + raise ValueError( + "Level {!r} does not exist in domain {!r}.".format(task_name, domain_name) + ) + + task_kwargs = task_kwargs or {} + if environment_kwargs is not None: + task_kwargs = task_kwargs.copy() + task_kwargs["environment_kwargs"] = environment_kwargs + env = domain.SUITE[task_name](**task_kwargs) + env.task.visualize_reward = visualize_reward + return env diff --git a/local_dm_control_suite/acrobot.py b/local_dm_control_suite/acrobot.py new file mode 100755 index 0000000..0adfc1c --- /dev/null +++ b/local_dm_control_suite/acrobot.py @@ -0,0 +1,131 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Acrobot domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . 
import common +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + +_DEFAULT_TIME_LIMIT = 10 +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("acrobot.xml"), common.ASSETS + + +@SUITE.add("benchmarking") +def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns Acrobot balance task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Balance(sparse=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def swingup_sparse( + time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns Acrobot sparse balance.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Balance(sparse=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Acrobot domain.""" + + def horizontal(self): + """Returns horizontal (x) component of body frame z-axes.""" + return self.named.data.xmat[["upper_arm", "lower_arm"], "xz"] + + def vertical(self): + """Returns vertical (z) component of body frame z-axes.""" + return self.named.data.xmat[["upper_arm", "lower_arm"], "zz"] + + def to_target(self): + """Returns the distance from the tip to the target.""" + tip_to_target = ( + self.named.data.site_xpos["target"] - self.named.data.site_xpos["tip"] + ) + return np.linalg.norm(tip_to_target) + + def orientations(self): + """Returns the sines and cosines of the pole angles.""" + return np.concatenate((self.horizontal(), self.vertical())) + + +class Balance(base.Task): + """An Acrobot `Task` to swing up and balance the pole.""" + + def __init__(self, sparse, random=None): + """Initializes an instance of `Balance`. + + Args: + sparse: A `bool` specifying whether to use a sparse (indicator) reward. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._sparse = sparse + super(Balance, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Shoulder and elbow are set to a random position between [-pi, pi). + + Args: + physics: An instance of `Physics`. 
+ """ + physics.named.data.qpos[["shoulder", "elbow"]] = self.random.uniform( + -np.pi, np.pi, 2 + ) + super(Balance, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of pole orientation and angular velocities.""" + obs = collections.OrderedDict() + obs["orientations"] = physics.orientations() + obs["velocity"] = physics.velocity() + return obs + + def _get_reward(self, physics, sparse): + target_radius = physics.named.model.site_size["target", 0] + return rewards.tolerance( + physics.to_target(), bounds=(0, target_radius), margin=0 if sparse else 1 + ) + + def get_reward(self, physics): + """Returns a sparse or a smooth reward, as specified in the constructor.""" + return self._get_reward(physics, sparse=self._sparse) diff --git a/local_dm_control_suite/acrobot.xml b/local_dm_control_suite/acrobot.xml new file mode 100755 index 0000000..79b76d9 --- /dev/null +++ b/local_dm_control_suite/acrobot.xml @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/ball_in_cup.py b/local_dm_control_suite/ball_in_cup.py new file mode 100755 index 0000000..daf479c --- /dev/null +++ b/local_dm_control_suite/ball_in_cup.py @@ -0,0 +1,104 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Ball-in-Cup Domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . 
import common +from dm_control.utils import containers + +_DEFAULT_TIME_LIMIT = 20 # (seconds) +_CONTROL_TIMESTEP = 0.02 # (seconds) + + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("ball_in_cup.xml"), common.ASSETS + + +@SUITE.add("benchmarking", "easy") +def catch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Ball-in-Cup task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = BallInCup(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics with additional features for the Ball-in-Cup domain.""" + + def ball_to_target(self): + """Returns the vector from the ball to the target.""" + target = self.named.data.site_xpos["target", ["x", "z"]] + ball = self.named.data.xpos["ball", ["x", "z"]] + return target - ball + + def in_target(self): + """Returns 1 if the ball is in the target, 0 otherwise.""" + ball_to_target = abs(self.ball_to_target()) + target_size = self.named.model.site_size["target", [0, 2]] + ball_size = self.named.model.geom_size["ball", 0] + return float(all(ball_to_target < target_size - ball_size)) + + +class BallInCup(base.Task): + """The Ball-in-Cup task. Put the ball in the cup.""" + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Args: + physics: An instance of `Physics`. + + """ + # Find a collision-free random initial position of the ball. + penetrating = True + while penetrating: + # Assign a random ball position. + physics.named.data.qpos["ball_x"] = self.random.uniform(-0.2, 0.2) + physics.named.data.qpos["ball_z"] = self.random.uniform(0.2, 0.5) + # Check for collisions. + physics.after_reset() + penetrating = physics.data.ncon > 0 + super(BallInCup, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the state.""" + obs = collections.OrderedDict() + obs["position"] = physics.position() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a sparse reward.""" + return physics.in_target() diff --git a/local_dm_control_suite/ball_in_cup.xml b/local_dm_control_suite/ball_in_cup.xml new file mode 100755 index 0000000..792073f --- /dev/null +++ b/local_dm_control_suite/ball_in_cup.xml @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/base.py b/local_dm_control_suite/base.py new file mode 100755 index 0000000..07aaf95 --- /dev/null +++ b/local_dm_control_suite/base.py @@ -0,0 +1,112 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Base class for tasks in the Control Suite.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import mujoco +from dm_control.rl import control + +import numpy as np + + +class Task(control.Task): + """Base class for tasks in the Control Suite. + + Actions are mapped directly to the states of MuJoCo actuators: each element of + the action array is used to set the control input for a single actuator. The + ordering of the actuators is the same as in the corresponding MJCF XML file. + + Attributes: + random: A `numpy.random.RandomState` instance. This should be used to + generate all random variables associated with the task, such as random + starting states, observation noise* etc. + + *If sensor noise is enabled in the MuJoCo model then this will be generated + using MuJoCo's internal RNG, which has its own independent state. + """ + + def __init__(self, random=None): + """Initializes a new continuous control task. + + Args: + random: Optional, either a `numpy.random.RandomState` instance, an integer + seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + if not isinstance(random, np.random.RandomState): + random = np.random.RandomState(random) + self._random = random + self._visualize_reward = False + + @property + def random(self): + """Task-specific `numpy.random.RandomState` instance.""" + return self._random + + def action_spec(self, physics): + """Returns a `BoundedArraySpec` matching the `physics` actuators.""" + return mujoco.action_spec(physics) + + def initialize_episode(self, physics): + """Resets geom colors to their defaults after starting a new episode. + + Subclasses of `base.Task` must delegate to this method after performing + their own initialization. + + Args: + physics: An instance of `mujoco.Physics`. + """ + self.after_step(physics) + + def before_step(self, action, physics): + """Sets the control signal for the actuators to values in `action`.""" + # Support legacy internal code. + action = getattr(action, "continuous_actions", action) + physics.set_control(action) + + def after_step(self, physics): + """Modifies colors according to the reward.""" + if self._visualize_reward: + reward = np.clip(self.get_reward(physics), 0.0, 1.0) + _set_reward_colors(physics, reward) + + @property + def visualize_reward(self): + return self._visualize_reward + + @visualize_reward.setter + def visualize_reward(self, value): + if not isinstance(value, bool): + raise ValueError("Expected a boolean, got {}.".format(type(value))) + self._visualize_reward = value + + +_MATERIALS = ["self", "effector", "target"] +_DEFAULT = [name + "_default" for name in _MATERIALS] +_HIGHLIGHT = [name + "_highlight" for name in _MATERIALS] + + +def _set_reward_colors(physics, reward): + """Sets the highlight, effector and target colors according to the reward.""" + assert 0.0 <= reward <= 1.0 + colors = physics.named.model.mat_rgba + default = colors[_DEFAULT] + highlight = colors[_HIGHLIGHT] + blend_coef = reward ** 4 # Better color distinction near high rewards. + colors[_MATERIALS] = blend_coef * highlight + (1.0 - blend_coef) * default diff --git a/local_dm_control_suite/cartpole.py b/local_dm_control_suite/cartpole.py new file mode 100755 index 0000000..778bbd7 --- /dev/null +++ b/local_dm_control_suite/cartpole.py @@ -0,0 +1,252 @@ +# Copyright 2017 The dm_control Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cartpole domain.""" + +from __future__ import absolute_import, division, print_function + +import collections + +import numpy as np +from dm_control import mujoco +from dm_control.rl import control +from dm_control.utils import containers, rewards +from . import base, common +from lxml import etree +from six.moves import range + +_DEFAULT_TIME_LIMIT = 10 +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(num_poles=1, xml_file_id=None): + """Returns a tuple containing the model XML string and a dict of assets.""" + return _make_model(num_poles, xml_file_id), common.ASSETS + + +@SUITE.add("benchmarking") +def balance( + time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None, +): + """Returns the Cartpole Balance task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Balance(swing_up=False, sparse=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def balance_sparse( + time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns the sparse reward variant of the Cartpole Balance task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Balance(swing_up=False, sparse=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def swingup( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the Cartpole Swing-Up task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id=xml_file_id)) + task = Balance(swing_up=True, sparse=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def swingup_sparse( + time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns the sparse reward variant of teh Cartpole Swing-Up task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Balance(swing_up=True, sparse=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add() +def two_poles(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Cartpole Balance task with two poles.""" + physics = Physics.from_xml_string(*get_model_and_assets(num_poles=2)) + task = Balance(swing_up=True, sparse=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add() +def three_poles( + 
time_limit=_DEFAULT_TIME_LIMIT, + random=None, + num_poles=3, + sparse=False, + environment_kwargs=None, +): + """Returns the Cartpole Balance task with three or more poles.""" + physics = Physics.from_xml_string(*get_model_and_assets(num_poles=num_poles)) + task = Balance(swing_up=True, sparse=sparse, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +def _make_model(n_poles, xml_file_id=None): + """Generates an xml string defining a cart with `n_poles` bodies.""" + if xml_file_id is not None: + filename = f"cartpole_{xml_file_id}.xml" + print(filename) + else: + filename = f"cartpole.xml" + xml_string = common.read_model(filename) + if n_poles == 1: + return xml_string + mjcf = etree.fromstring(xml_string) + parent = mjcf.find("./worldbody/body/body") # Find first pole. + # Make chain of poles. + for pole_index in range(2, n_poles + 1): + child = etree.Element( + "body", name="pole_{}".format(pole_index), pos="0 0 1", childclass="pole" + ) + etree.SubElement(child, "joint", name="hinge_{}".format(pole_index)) + etree.SubElement(child, "geom", name="pole_{}".format(pole_index)) + parent.append(child) + parent = child + # Move plane down. + floor = mjcf.find("./worldbody/geom") + floor.set("pos", "0 0 {}".format(1 - n_poles - 0.05)) + # Move cameras back. + cameras = mjcf.findall("./worldbody/camera") + cameras[0].set("pos", "0 {} 1".format(-1 - 2 * n_poles)) + cameras[1].set("pos", "0 {} 2".format(-2 * n_poles)) + return etree.tostring(mjcf, pretty_print=True) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Cartpole domain.""" + + def cart_position(self): + """Returns the position of the cart.""" + return self.named.data.qpos["slider"][0] + + def angular_vel(self): + """Returns the angular velocity of the pole.""" + return self.data.qvel[1:] + + def pole_angle_cosine(self): + """Returns the cosine of the pole angle.""" + return self.named.data.xmat[2:, "zz"] + + def bounded_position(self): + """Returns the state, with pole angle split into sin/cos.""" + return np.hstack( + (self.cart_position(), self.named.data.xmat[2:, ["zz", "xz"]].ravel()) + ) + + +class Balance(base.Task): + """A Cartpole `Task` to balance the pole. + + State is initialized either close to the target configuration or at a random + configuration. + """ + + _CART_RANGE = (-0.25, 0.25) + _ANGLE_COSINE_RANGE = (0.995, 1) + + def __init__(self, swing_up, sparse, random=None): + """Initializes an instance of `Balance`. + + Args: + swing_up: A `bool`, which if `True` sets the cart to the middle of the + slider and the pole pointing towards the ground. Otherwise, sets the + cart to a random position on the slider and the pole to a random + near-vertical position. + sparse: A `bool`, whether to return a sparse or a smooth reward. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._sparse = sparse + self._swing_up = swing_up + super(Balance, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Initializes the cart and pole according to `swing_up`, and in both cases + adds a small random initial velocity to break symmetry. + + Args: + physics: An instance of `Physics`. 
+ """ + nv = physics.model.nv + if self._swing_up: + physics.named.data.qpos["slider"] = 0.01 * self.random.randn() + physics.named.data.qpos["hinge_1"] = np.pi + 0.01 * self.random.randn() + physics.named.data.qpos[2:] = 0.1 * self.random.randn(nv - 2) + else: + physics.named.data.qpos["slider"] = self.random.uniform(-0.1, 0.1) + physics.named.data.qpos[1:] = self.random.uniform(-0.034, 0.034, nv - 1) + physics.named.data.qvel[:] = 0.01 * self.random.randn(physics.model.nv) + super(Balance, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the (bounded) physics state.""" + obs = collections.OrderedDict() + obs["position"] = physics.bounded_position() + obs["velocity"] = physics.velocity() + return obs + + def _get_reward(self, physics, sparse): + if sparse: + cart_in_bounds = rewards.tolerance( + physics.cart_position(), self._CART_RANGE + ) + angle_in_bounds = rewards.tolerance( + physics.pole_angle_cosine(), self._ANGLE_COSINE_RANGE + ).prod() + return cart_in_bounds * angle_in_bounds + else: + upright = (physics.pole_angle_cosine() + 1) / 2 + centered = rewards.tolerance(physics.cart_position(), margin=2) + centered = (1 + centered) / 2 + small_control = rewards.tolerance( + physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic" + )[0] + small_control = (4 + small_control) / 5 + small_velocity = rewards.tolerance(physics.angular_vel(), margin=5).min() + small_velocity = (1 + small_velocity) / 2 + return upright.mean() * small_control * small_velocity * centered + + def get_reward(self, physics): + """Returns a sparse or a smooth reward, as specified in the constructor.""" + return self._get_reward(physics, sparse=self._sparse) diff --git a/local_dm_control_suite/cartpole.xml b/local_dm_control_suite/cartpole.xml new file mode 100755 index 0000000..e01869d --- /dev/null +++ b/local_dm_control_suite/cartpole.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_1.xml b/local_dm_control_suite/cartpole_cart_mass_1.xml new file mode 100755 index 0000000..e01869d --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_1.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_10.xml b/local_dm_control_suite/cartpole_cart_mass_10.xml new file mode 100755 index 0000000..1ffa772 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_10.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_2.xml b/local_dm_control_suite/cartpole_cart_mass_2.xml new file mode 100755 index 0000000..dd61503 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_2.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_3.xml b/local_dm_control_suite/cartpole_cart_mass_3.xml new file mode 100755 index 0000000..0da06e1 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_3.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_4.xml b/local_dm_control_suite/cartpole_cart_mass_4.xml new file mode 100755 index 0000000..cd8ca56 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_4.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
diff --git a/local_dm_control_suite/cartpole_cart_mass_5.xml b/local_dm_control_suite/cartpole_cart_mass_5.xml new file mode 100755 index 0000000..4b93083 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_5.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_6.xml b/local_dm_control_suite/cartpole_cart_mass_6.xml new file mode 100755 index 0000000..2fb0060 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_6.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_7.xml b/local_dm_control_suite/cartpole_cart_mass_7.xml new file mode 100755 index 0000000..5df5129 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_7.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_8.xml b/local_dm_control_suite/cartpole_cart_mass_8.xml new file mode 100755 index 0000000..001546a --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_8.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_cart_mass_9.xml b/local_dm_control_suite/cartpole_cart_mass_9.xml new file mode 100755 index 0000000..0e01ff0 --- /dev/null +++ b/local_dm_control_suite/cartpole_cart_mass_9.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_1.xml b/local_dm_control_suite/cartpole_pole_mass_1.xml new file mode 100755 index 0000000..e01869d --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_1.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_10.xml b/local_dm_control_suite/cartpole_pole_mass_10.xml new file mode 100755 index 0000000..61bfeab --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_10.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_2.xml b/local_dm_control_suite/cartpole_pole_mass_2.xml new file mode 100755 index 0000000..156f090 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_2.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_3.xml b/local_dm_control_suite/cartpole_pole_mass_3.xml new file mode 100755 index 0000000..3f2128c --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_3.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_4.xml b/local_dm_control_suite/cartpole_pole_mass_4.xml new file mode 100755 index 0000000..8968b0b --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_4.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_5.xml b/local_dm_control_suite/cartpole_pole_mass_5.xml new file mode 100755 index 0000000..d448817 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_5.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_6.xml b/local_dm_control_suite/cartpole_pole_mass_6.xml new file mode 100755 index 
0000000..0b8e8b1 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_6.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_7.xml b/local_dm_control_suite/cartpole_pole_mass_7.xml new file mode 100755 index 0000000..cfb0f85 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_7.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_8.xml b/local_dm_control_suite/cartpole_pole_mass_8.xml new file mode 100755 index 0000000..c92cb62 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_8.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cartpole_pole_mass_9.xml b/local_dm_control_suite/cartpole_pole_mass_9.xml new file mode 100755 index 0000000..86b4414 --- /dev/null +++ b/local_dm_control_suite/cartpole_pole_mass_9.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/cheetah.py b/local_dm_control_suite/cheetah.py new file mode 100755 index 0000000..78bd562 --- /dev/null +++ b/local_dm_control_suite/cheetah.py @@ -0,0 +1,105 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Cheetah Domain.""" + +from __future__ import absolute_import, division, print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from dm_control.utils import containers, rewards +from . import base, common + +# How long the simulation will run, in seconds. +_DEFAULT_TIME_LIMIT = 10 + +# Running speed above which reward is 1. 
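+# With the linear sigmoid used in Cheetah.get_reward below, the reward ramps
+# from 0 at standstill up to 1 at this speed and stays saturated beyond it.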
+_RUN_SPEED = 10 + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(xml_file_id): + """Returns a tuple containing the model XML string and a dict of assets.""" + if xml_file_id is not None: + filename = f"cheetah_{xml_file_id}.xml" + print(filename) + else: + filename = f"cheetah.xml" + return common.read_model(filename), common.ASSETS + + +@SUITE.add("benchmarking") +def run( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the run task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id)) + task = Cheetah(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Cheetah domain.""" + + def speed(self): + """Returns the horizontal speed of the Cheetah.""" + return self.named.data.sensordata["torso_subtreelinvel"][0] + + +class Cheetah(base.Task): + """A `Task` to train a running Cheetah.""" + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # The indexing below assumes that all joints have a single DOF. + assert physics.model.nq == physics.model.njnt + is_limited = physics.model.jnt_limited == 1 + lower, upper = physics.model.jnt_range[is_limited].T + physics.data.qpos[is_limited] = self.random.uniform(lower, upper) + + # Stabilize the model before the actual simulation. + for _ in range(200): + physics.step() + + physics.data.time = 0 + self._timeout_progress = 0 + super(Cheetah, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the state, ignoring horizontal position.""" + obs = collections.OrderedDict() + # Ignores horizontal position to maintain translational invariance. 
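+ # qpos[0] is the root's horizontal slide joint, so dropping it keeps the
+ # observation the same no matter how far the cheetah has travelled.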
+ obs["position"] = physics.data.qpos[1:].copy() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + return rewards.tolerance( + physics.speed(), + bounds=(_RUN_SPEED, float("inf")), + margin=_RUN_SPEED, + value_at_margin=0, + sigmoid="linear", + ) diff --git a/local_dm_control_suite/cheetah.xml b/local_dm_control_suite/cheetah.xml new file mode 100755 index 0000000..26076af --- /dev/null +++ b/local_dm_control_suite/cheetah.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_1.xml b/local_dm_control_suite/cheetah_bfoot_len_1.xml new file mode 100755 index 0000000..26076af --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_1.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_10.xml b/local_dm_control_suite/cheetah_bfoot_len_10.xml new file mode 100755 index 0000000..2f8ea14 --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_10.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_2.xml b/local_dm_control_suite/cheetah_bfoot_len_2.xml new file mode 100755 index 0000000..f67ea67 --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_2.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_3.xml b/local_dm_control_suite/cheetah_bfoot_len_3.xml new file mode 100755 index 0000000..5842e28 --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_3.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_4.xml b/local_dm_control_suite/cheetah_bfoot_len_4.xml new file mode 100755 index 0000000..bbb280a --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_4.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_5.xml b/local_dm_control_suite/cheetah_bfoot_len_5.xml new file mode 100755 index 0000000..89c2bfa --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_5.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_6.xml b/local_dm_control_suite/cheetah_bfoot_len_6.xml new file mode 100755 index 0000000..c3b6683 --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_6.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_7.xml b/local_dm_control_suite/cheetah_bfoot_len_7.xml new file mode 100755 index 0000000..cc2c0b0 --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_7.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_8.xml b/local_dm_control_suite/cheetah_bfoot_len_8.xml new file mode 100755 index 0000000..02d684b --- /dev/null +++ b/local_dm_control_suite/cheetah_bfoot_len_8.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_bfoot_len_9.xml b/local_dm_control_suite/cheetah_bfoot_len_9.xml new file mode 100755 index 0000000..2470117 --- /dev/null +++ 
b/local_dm_control_suite/cheetah_bfoot_len_9.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_1.xml b/local_dm_control_suite/cheetah_foot_pos_1.xml new file mode 100755 index 0000000..26076af --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_1.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_10.xml b/local_dm_control_suite/cheetah_foot_pos_10.xml new file mode 100755 index 0000000..30dc057 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_10.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_2.xml b/local_dm_control_suite/cheetah_foot_pos_2.xml new file mode 100755 index 0000000..0a52c9e --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_2.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_3.xml b/local_dm_control_suite/cheetah_foot_pos_3.xml new file mode 100755 index 0000000..42145bb --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_3.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_4.xml b/local_dm_control_suite/cheetah_foot_pos_4.xml new file mode 100755 index 0000000..907d91a --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_4.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_5.xml b/local_dm_control_suite/cheetah_foot_pos_5.xml new file mode 100755 index 0000000..25ca76d --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_5.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_6.xml b/local_dm_control_suite/cheetah_foot_pos_6.xml new file mode 100755 index 0000000..fabf5a1 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_6.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_7.xml b/local_dm_control_suite/cheetah_foot_pos_7.xml new file mode 100755 index 0000000..85e27fc --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_7.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_8.xml b/local_dm_control_suite/cheetah_foot_pos_8.xml new file mode 100755 index 0000000..6961f09 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_8.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_pos_9.xml b/local_dm_control_suite/cheetah_foot_pos_9.xml new file mode 100755 index 0000000..398b5a6 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_pos_9.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_1.xml b/local_dm_control_suite/cheetah_foot_size_1.xml new file mode 100755 index 0000000..26076af --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_1.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_10.xml 
b/local_dm_control_suite/cheetah_foot_size_10.xml new file mode 100755 index 0000000..5eb0309 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_10.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_2.xml b/local_dm_control_suite/cheetah_foot_size_2.xml new file mode 100755 index 0000000..32b6c30 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_2.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_3.xml b/local_dm_control_suite/cheetah_foot_size_3.xml new file mode 100755 index 0000000..debeae0 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_3.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_4.xml b/local_dm_control_suite/cheetah_foot_size_4.xml new file mode 100755 index 0000000..0d0571a --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_4.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_5.xml b/local_dm_control_suite/cheetah_foot_size_5.xml new file mode 100755 index 0000000..8e90ad4 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_5.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_6.xml b/local_dm_control_suite/cheetah_foot_size_6.xml new file mode 100755 index 0000000..8cf2c48 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_6.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_7.xml b/local_dm_control_suite/cheetah_foot_size_7.xml new file mode 100755 index 0000000..e408b00 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_7.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_8.xml b/local_dm_control_suite/cheetah_foot_size_8.xml new file mode 100755 index 0000000..d419233 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_8.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_foot_size_9.xml b/local_dm_control_suite/cheetah_foot_size_9.xml new file mode 100755 index 0000000..d383a94 --- /dev/null +++ b/local_dm_control_suite/cheetah_foot_size_9.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_1.xml b/local_dm_control_suite/cheetah_torso_length_1.xml new file mode 100755 index 0000000..26076af --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_1.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_10.xml b/local_dm_control_suite/cheetah_torso_length_10.xml new file mode 100755 index 0000000..19cd5e9 --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_10.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_2.xml b/local_dm_control_suite/cheetah_torso_length_2.xml new file mode 100755 index 0000000..524aeb8 --- /dev/null +++ 
b/local_dm_control_suite/cheetah_torso_length_2.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_3.xml b/local_dm_control_suite/cheetah_torso_length_3.xml new file mode 100755 index 0000000..0030d46 --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_3.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_4.xml b/local_dm_control_suite/cheetah_torso_length_4.xml new file mode 100755 index 0000000..d1acb5e --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_4.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_5.xml b/local_dm_control_suite/cheetah_torso_length_5.xml new file mode 100755 index 0000000..b9a569c --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_5.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_6.xml b/local_dm_control_suite/cheetah_torso_length_6.xml new file mode 100755 index 0000000..ad1c29a --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_6.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_7.xml b/local_dm_control_suite/cheetah_torso_length_7.xml new file mode 100755 index 0000000..c43cbb6 --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_7.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_8.xml b/local_dm_control_suite/cheetah_torso_length_8.xml new file mode 100755 index 0000000..166ec46 --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_8.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/cheetah_torso_length_9.xml b/local_dm_control_suite/cheetah_torso_length_9.xml new file mode 100755 index 0000000..b8149fe --- /dev/null +++ b/local_dm_control_suite/cheetah_torso_length_9.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/local_dm_control_suite/common/__init__.py b/local_dm_control_suite/common/__init__.py new file mode 100755 index 0000000..a997524 --- /dev/null +++ b/local_dm_control_suite/common/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Functions to manage the common assets for domains.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from dm_control.utils import io as resources + +_SUITE_DIR = os.path.dirname(os.path.dirname(__file__)) +_FILENAMES = [ + "./common/materials.xml", + "./common/materials_white_floor.xml", + "./common/skybox.xml", + "./common/visual.xml", +] + +ASSETS = { + filename: resources.GetResource(os.path.join(_SUITE_DIR, filename)) + for filename in _FILENAMES +} + + +def read_model(model_filename): + """Reads a model XML file and returns its contents as a string.""" + return resources.GetResource(os.path.join(_SUITE_DIR, model_filename)) diff --git a/local_dm_control_suite/common/materials.xml b/local_dm_control_suite/common/materials.xml new file mode 100755 index 0000000..5a3b169 --- /dev/null +++ b/local_dm_control_suite/common/materials.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/common/materials_white_floor.xml b/local_dm_control_suite/common/materials_white_floor.xml new file mode 100755 index 0000000..a1e35c2 --- /dev/null +++ b/local_dm_control_suite/common/materials_white_floor.xml @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/common/skybox.xml b/local_dm_control_suite/common/skybox.xml new file mode 100755 index 0000000..b888692 --- /dev/null +++ b/local_dm_control_suite/common/skybox.xml @@ -0,0 +1,6 @@ + + + + + diff --git a/local_dm_control_suite/common/visual.xml b/local_dm_control_suite/common/visual.xml new file mode 100755 index 0000000..ede15ad --- /dev/null +++ b/local_dm_control_suite/common/visual.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/local_dm_control_suite/demos/mocap_demo.py b/local_dm_control_suite/demos/mocap_demo.py new file mode 100755 index 0000000..535492a --- /dev/null +++ b/local_dm_control_suite/demos/mocap_demo.py @@ -0,0 +1,89 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Demonstration of amc parsing for CMU mocap database. + +To run the demo, supply a path to a `.amc` file: + + python mocap_demo --filename='path/to/mocap.amc' + +CMU motion capture clips are available at mocap.cs.cmu.edu +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +# Internal dependencies. + +from absl import app +from absl import flags + +from . 
import humanoid_CMU +from dm_control.suite.utils import parse_amc + +import matplotlib.pyplot as plt +import numpy as np + +FLAGS = flags.FLAGS +flags.DEFINE_string("filename", None, "amc file to be converted.") +flags.DEFINE_integer( + "max_num_frames", 90, "Maximum number of frames for plotting/playback" +) + + +def main(unused_argv): + env = humanoid_CMU.stand() + + # Parse and convert specified clip. + converted = parse_amc.convert(FLAGS.filename, env.physics, env.control_timestep()) + + max_frame = min(FLAGS.max_num_frames, converted.qpos.shape[1] - 1) + + width = 480 + height = 480 + video = np.zeros((max_frame, height, 2 * width, 3), dtype=np.uint8) + + for i in range(max_frame): + p_i = converted.qpos[:, i] + with env.physics.reset_context(): + env.physics.data.qpos[:] = p_i + video[i] = np.hstack( + [ + env.physics.render(height, width, camera_id=0), + env.physics.render(height, width, camera_id=1), + ] + ) + + tic = time.time() + for i in range(max_frame): + if i == 0: + img = plt.imshow(video[i]) + else: + img.set_data(video[i]) + toc = time.time() + clock_dt = toc - tic + tic = time.time() + # Real-time playback not always possible as clock_dt > .03 + plt.pause(max(0.01, 0.03 - clock_dt)) # Need min display time > 0.0. + plt.draw() + plt.waitforbuttonpress() + + +if __name__ == "__main__": + flags.mark_flag_as_required("filename") + app.run(main) diff --git a/local_dm_control_suite/demos/zeros.amc b/local_dm_control_suite/demos/zeros.amc new file mode 100755 index 0000000..b4590a4 --- /dev/null +++ b/local_dm_control_suite/demos/zeros.amc @@ -0,0 +1,213 @@ +#DUMMY AMC for testing +:FULLY-SPECIFIED +:DEGREES +1 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +2 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +3 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +4 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +5 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +6 +root 0 0 0 0 0 0 
+lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 +7 +root 0 0 0 0 0 0 +lowerback 0 0 0 +upperback 0 0 0 +thorax 0 0 0 +lowerneck 0 0 0 +upperneck 0 0 0 +head 0 0 0 +rclavicle 0 0 +rhumerus 0 0 0 +rradius 0 +rwrist 0 +rhand 0 0 +rfingers 0 +rthumb 0 0 +lclavicle 0 0 +lhumerus 0 0 0 +lradius 0 +lwrist 0 +lhand 0 0 +lfingers 0 +lthumb 0 0 +rfemur 0 0 0 +rtibia 0 +rfoot 0 0 +rtoes 0 +lfemur 0 0 0 +ltibia 0 +lfoot 0 0 +ltoes 0 diff --git a/local_dm_control_suite/explore.py b/local_dm_control_suite/explore.py new file mode 100755 index 0000000..c552e97 --- /dev/null +++ b/local_dm_control_suite/explore.py @@ -0,0 +1,95 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Control suite environments explorer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import app +from absl import flags +from dm_control import suite +from dm_control.suite.wrappers import action_noise +from six.moves import input + +from dm_control import viewer + + +_ALL_NAMES = [".".join(domain_task) for domain_task in suite.ALL_TASKS] + +flags.DEFINE_enum( + "environment_name", + None, + _ALL_NAMES, + "Optional 'domain_name.task_name' pair specifying the " + "environment to load. If unspecified a prompt will appear to " + "select one.", +) +flags.DEFINE_bool("timeout", True, "Whether episodes should have a time limit.") +flags.DEFINE_bool( + "visualize_reward", + True, + "Whether to vary the colors of geoms according to the " "current reward value.", +) +flags.DEFINE_float( + "action_noise", + 0.0, + "Standard deviation of Gaussian noise to apply to actions, " + "expressed as a fraction of the max-min range for each " + "action dimension. Defaults to 0, i.e. no noise.", +) +FLAGS = flags.FLAGS + + +def prompt_environment_name(prompt, values): + environment_name = None + while not environment_name: + environment_name = input(prompt) + if not environment_name or values.index(environment_name) < 0: + print('"%s" is not a valid environment name.' 
% environment_name) + environment_name = None + return environment_name + + +def main(argv): + del argv + environment_name = FLAGS.environment_name + if environment_name is None: + print("\n ".join(["Available environments:"] + _ALL_NAMES)) + environment_name = prompt_environment_name( + "Please select an environment name: ", _ALL_NAMES + ) + + index = _ALL_NAMES.index(environment_name) + domain_name, task_name = suite.ALL_TASKS[index] + + task_kwargs = {} + if not FLAGS.timeout: + task_kwargs["time_limit"] = float("inf") + + def loader(): + env = suite.load( + domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs + ) + env.task.visualize_reward = FLAGS.visualize_reward + if FLAGS.action_noise > 0: + env = action_noise.Wrapper(env, scale=FLAGS.action_noise) + return env + + viewer.launch(loader) + + +if __name__ == "__main__": + app.run(main) diff --git a/local_dm_control_suite/finger.py b/local_dm_control_suite/finger.py new file mode 100755 index 0000000..fa35971 --- /dev/null +++ b/local_dm_control_suite/finger.py @@ -0,0 +1,242 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Finger Domain.""" + +from __future__ import absolute_import, division, print_function + +import collections + +import numpy as np +from dm_control import mujoco +from dm_control.rl import control +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from six.moves import range + +from . import base, common + +_DEFAULT_TIME_LIMIT = 20 # (seconds) +_CONTROL_TIMESTEP = 0.02 # (seconds) +# For TURN tasks, the 'tip' geom needs to enter a spherical target of sizes: +_EASY_TARGET_SIZE = 0.07 +_HARD_TARGET_SIZE = 0.03 +# Initial spin velocity for the Stop task. +_INITIAL_SPIN_VELOCITY = 100 +# Spinning slower than this value (radian/second) is considered stopped. +_STOP_VELOCITY = 1e-6 +# Spinning faster than this value (radian/second) is considered spinning. 
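+# Spin.get_reward below returns 1 only while the hinge velocity is at or below
+# -_SPIN_VELOCITY, i.e. while the spinner turns in the negative direction at
+# least this fast; otherwise the sparse reward is 0.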
+_SPIN_VELOCITY = 15.0 + + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(xml_file_id): + """Returns a tuple containing the model XML string and a dict of assets.""" + if xml_file_id is not None: + filename = f"finger_{xml_file_id}.xml" + else: + filename = f"finger.xml" + return common.read_model(filename), common.ASSETS + + +@SUITE.add("benchmarking") +def spin( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the Spin task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id)) + task = Spin(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +@SUITE.add("benchmarking") +def turn_easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the easy Turn task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Turn(target_radius=_EASY_TARGET_SIZE, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +@SUITE.add("benchmarking") +def turn_hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the hard Turn task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Turn(target_radius=_HARD_TARGET_SIZE, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Finger domain.""" + + def touch(self): + """Returns logarithmically scaled signals from the two touch sensors.""" + return np.log1p(self.named.data.sensordata[["touchtop", "touchbottom"]]) + + def hinge_velocity(self): + """Returns the velocity of the hinge joint.""" + return self.named.data.sensordata["hinge_velocity"] + + def tip_position(self): + """Returns the (x,z) position of the tip relative to the hinge.""" + return ( + self.named.data.sensordata["tip"][[0, 2]] + - self.named.data.sensordata["spinner"][[0, 2]] + ) + + def bounded_position(self): + """Returns the positions, with the hinge angle replaced by tip position.""" + return np.hstack( + (self.named.data.sensordata[["proximal", "distal"]], self.tip_position()) + ) + + def velocity(self): + """Returns the velocities (extracted from sensordata).""" + return self.named.data.sensordata[ + ["proximal_velocity", "distal_velocity", "hinge_velocity"] + ] + + def target_position(self): + """Returns the (x,z) position of the target relative to the hinge.""" + return ( + self.named.data.sensordata["target"][[0, 2]] + - self.named.data.sensordata["spinner"][[0, 2]] + ) + + def to_target(self): + """Returns the vector from the tip to the target.""" + return self.target_position() - self.tip_position() + + def dist_to_target(self): + """Returns the signed distance to the target surface, negative is inside.""" + return ( + np.linalg.norm(self.to_target()) - self.named.model.site_size["target", 0] + ) + + +class Spin(base.Task): + """A Finger `Task` to spin the stopped body.""" + + def __init__(self, random=None): + """Initializes a new `Spin` instance. 
+ + Args: + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + super(Spin, self).__init__(random=random) + + def initialize_episode(self, physics): + physics.named.model.site_rgba["target", 3] = 0 + physics.named.model.site_rgba["tip", 3] = 0 + physics.named.model.dof_damping["hinge"] = 0.03 + _set_random_joint_angles(physics, self.random) + super(Spin, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns state and touch sensors, and target info.""" + obs = collections.OrderedDict() + obs["position"] = physics.bounded_position() + obs["velocity"] = physics.velocity() + obs["touch"] = physics.touch() + return obs + + def get_reward(self, physics): + """Returns a sparse reward.""" + return float(physics.hinge_velocity() <= -_SPIN_VELOCITY) + + +class Turn(base.Task): + """A Finger `Task` to turn the body to a target angle.""" + + def __init__(self, target_radius, random=None): + """Initializes a new `Turn` instance. + + Args: + target_radius: Radius of the target site, which specifies the goal angle. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._target_radius = target_radius + super(Turn, self).__init__(random=random) + + def initialize_episode(self, physics): + target_angle = self.random.uniform(-np.pi, np.pi) + hinge_x, hinge_z = physics.named.data.xanchor["hinge", ["x", "z"]] + radius = physics.named.model.geom_size["cap1"].sum() + target_x = hinge_x + radius * np.sin(target_angle) + target_z = hinge_z + radius * np.cos(target_angle) + physics.named.model.site_pos["target", ["x", "z"]] = target_x, target_z + physics.named.model.site_size["target", 0] = self._target_radius + + _set_random_joint_angles(physics, self.random) + + super(Turn, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns state, touch sensors, and target info.""" + obs = collections.OrderedDict() + obs["position"] = physics.bounded_position() + obs["velocity"] = physics.velocity() + obs["touch"] = physics.touch() + obs["target_position"] = physics.target_position() + obs["dist_to_target"] = physics.dist_to_target() + return obs + + def get_reward(self, physics): + return float(physics.dist_to_target() <= 0) + + +def _set_random_joint_angles(physics, random, max_attempts=1000): + """Sets the joints to a random collision-free state.""" + + for _ in range(max_attempts): + randomizers.randomize_limited_and_rotational_joints(physics, random) + # Check for collisions. 
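+ # after_reset() re-runs the forward dynamics for the newly sampled pose, so
+ # data.ncon below counts the contacts of this configuration.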
+ physics.after_reset() + if physics.data.ncon == 0: + break + else: + raise RuntimeError( + "Could not find a collision-free state " + "after {} attempts".format(max_attempts) + ) diff --git a/local_dm_control_suite/finger.xml b/local_dm_control_suite/finger.xml new file mode 100755 index 0000000..37a3914 --- /dev/null +++ b/local_dm_control_suite/finger.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_1.xml b/local_dm_control_suite/finger_size_1.xml new file mode 100755 index 0000000..37a3914 --- /dev/null +++ b/local_dm_control_suite/finger_size_1.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_10.xml b/local_dm_control_suite/finger_size_10.xml new file mode 100755 index 0000000..c1490ea --- /dev/null +++ b/local_dm_control_suite/finger_size_10.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_2.xml b/local_dm_control_suite/finger_size_2.xml new file mode 100755 index 0000000..b488b6b --- /dev/null +++ b/local_dm_control_suite/finger_size_2.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_3.xml b/local_dm_control_suite/finger_size_3.xml new file mode 100755 index 0000000..cad2772 --- /dev/null +++ b/local_dm_control_suite/finger_size_3.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_4.xml b/local_dm_control_suite/finger_size_4.xml new file mode 100755 index 0000000..fa73efe --- /dev/null +++ b/local_dm_control_suite/finger_size_4.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_5.xml b/local_dm_control_suite/finger_size_5.xml new file mode 100755 index 0000000..447fe13 --- /dev/null +++ b/local_dm_control_suite/finger_size_5.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_6.xml b/local_dm_control_suite/finger_size_6.xml new file mode 100755 index 0000000..12e466f --- /dev/null +++ b/local_dm_control_suite/finger_size_6.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_7.xml b/local_dm_control_suite/finger_size_7.xml new file mode 100755 index 0000000..710f8c6 --- /dev/null +++ b/local_dm_control_suite/finger_size_7.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_8.xml b/local_dm_control_suite/finger_size_8.xml new file mode 100755 index 0000000..5407ae5 --- /dev/null +++ 
b/local_dm_control_suite/finger_size_8.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/finger_size_9.xml b/local_dm_control_suite/finger_size_9.xml new file mode 100755 index 0000000..197d95f --- /dev/null +++ b/local_dm_control_suite/finger_size_9.xml @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/fish.py b/local_dm_control_suite/fish.py new file mode 100755 index 0000000..859ee38 --- /dev/null +++ b/local_dm_control_suite/fish.py @@ -0,0 +1,188 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Fish Domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + + +_DEFAULT_TIME_LIMIT = 40 +_CONTROL_TIMESTEP = 0.04 +_JOINTS = [ + "tail1", + "tail_twist", + "tail2", + "finright_roll", + "finright_pitch", + "finleft_roll", + "finleft_pitch", +] +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("fish.xml"), common.ASSETS + + +@SUITE.add("benchmarking") +def upright(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Fish Upright task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Upright(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def swim(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Fish Swim task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Swim(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Fish domain.""" + + def upright(self): + """Returns projection from z-axes of torso to the z-axes of worldbody.""" + return self.named.data.xmat["torso", "zz"] + + def torso_velocity(self): + """Returns velocities and angular velocities of the torso.""" + return self.data.sensordata + + def joint_velocities(self): + """Returns the joint velocities.""" + return self.named.data.qvel[_JOINTS] + + def joint_angles(self): + 
"""Returns the joint positions.""" + return self.named.data.qpos[_JOINTS] + + def mouth_to_target(self): + """Returns a vector, from mouth to target in local coordinate of mouth.""" + data = self.named.data + mouth_to_target_global = data.geom_xpos["target"] - data.geom_xpos["mouth"] + return mouth_to_target_global.dot(data.geom_xmat["mouth"].reshape(3, 3)) + + +class Upright(base.Task): + """A Fish `Task` for getting the torso upright with smooth reward.""" + + def __init__(self, random=None): + """Initializes an instance of `Upright`. + + Args: + random: Either an existing `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically. + """ + super(Upright, self).__init__(random=random) + + def initialize_episode(self, physics): + """Randomizes the tail and fin angles and the orientation of the Fish.""" + quat = self.random.randn(4) + physics.named.data.qpos["root"][3:7] = quat / np.linalg.norm(quat) + for joint in _JOINTS: + physics.named.data.qpos[joint] = self.random.uniform(-0.2, 0.2) + # Hide the target. It's irrelevant for this task. + physics.named.model.geom_rgba["target", 3] = 0 + super(Upright, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of joint angles, velocities and uprightness.""" + obs = collections.OrderedDict() + obs["joint_angles"] = physics.joint_angles() + obs["upright"] = physics.upright() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a smooth reward.""" + return rewards.tolerance(physics.upright(), bounds=(1, 1), margin=1) + + +class Swim(base.Task): + """A Fish `Task` for swimming with smooth reward.""" + + def __init__(self, random=None): + """Initializes an instance of `Swim`. + + Args: + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + super(Swim, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + + quat = self.random.randn(4) + physics.named.data.qpos["root"][3:7] = quat / np.linalg.norm(quat) + for joint in _JOINTS: + physics.named.data.qpos[joint] = self.random.uniform(-0.2, 0.2) + # Randomize target position. 
+ physics.named.model.geom_pos["target", "x"] = self.random.uniform(-0.4, 0.4) + physics.named.model.geom_pos["target", "y"] = self.random.uniform(-0.4, 0.4) + physics.named.model.geom_pos["target", "z"] = self.random.uniform(0.1, 0.3) + super(Swim, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of joints, target direction and velocities.""" + obs = collections.OrderedDict() + obs["joint_angles"] = physics.joint_angles() + obs["upright"] = physics.upright() + obs["target"] = physics.mouth_to_target() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a smooth reward.""" + radii = physics.named.model.geom_size[["mouth", "target"], 0].sum() + in_target = rewards.tolerance( + np.linalg.norm(physics.mouth_to_target()), + bounds=(0, radii), + margin=2 * radii, + ) + is_upright = 0.5 * (physics.upright() + 1) + return (7 * in_target + is_upright) / 8 diff --git a/local_dm_control_suite/fish.xml b/local_dm_control_suite/fish.xml new file mode 100755 index 0000000..43de56d --- /dev/null +++ b/local_dm_control_suite/fish.xml @@ -0,0 +1,85 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/hopper.py b/local_dm_control_suite/hopper.py new file mode 100755 index 0000000..fe253ac --- /dev/null +++ b/local_dm_control_suite/hopper.py @@ -0,0 +1,147 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Hopper domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + + +SUITE = containers.TaggedTasks() + +_CONTROL_TIMESTEP = 0.02 # (Seconds) + +# Default duration of an episode, in seconds. +_DEFAULT_TIME_LIMIT = 20 + +# Minimal height of torso over foot above which stand reward is 1. +_STAND_HEIGHT = 0.6 + +# Hopping speed above which hop reward is 1. 
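+# The hop reward uses a linear sigmoid with margin _HOP_SPEED / 2, so the speed
+# term is 0 at standstill, 0.5 at half this speed, and saturates at 1 from
+# _HOP_SPEED onwards.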
+_HOP_SPEED = 2 + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("hopper.xml"), common.ASSETS + + +@SUITE.add("benchmarking") +def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns a Hopper that strives to stand upright, balancing its pose.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Hopper(hopping=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def hop(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns a Hopper that strives to hop forward.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Hopper(hopping=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Hopper domain.""" + + def height(self): + """Returns height of torso with respect to foot.""" + return self.named.data.xipos["torso", "z"] - self.named.data.xipos["foot", "z"] + + def speed(self): + """Returns horizontal speed of the Hopper.""" + return self.named.data.sensordata["torso_subtreelinvel"][0] + + def touch(self): + """Returns the signals from two foot touch sensors.""" + return np.log1p(self.named.data.sensordata[["touch_toe", "touch_heel"]]) + + +class Hopper(base.Task): + """A Hopper's `Task` to train a standing and a jumping Hopper.""" + + def __init__(self, hopping, random=None): + """Initialize an instance of `Hopper`. + + Args: + hopping: Boolean, if True the task is to hop forwards, otherwise it is to + balance upright. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). 
+ """ + self._hopping = hopping + super(Hopper, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + self._timeout_progress = 0 + super(Hopper, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of positions, velocities and touch sensors.""" + obs = collections.OrderedDict() + # Ignores horizontal position to maintain translational invariance: + obs["position"] = physics.data.qpos[1:].copy() + obs["velocity"] = physics.velocity() + obs["touch"] = physics.touch() + return obs + + def get_reward(self, physics): + """Returns a reward applicable to the performed task.""" + standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2)) + if self._hopping: + hopping = rewards.tolerance( + physics.speed(), + bounds=(_HOP_SPEED, float("inf")), + margin=_HOP_SPEED / 2, + value_at_margin=0.5, + sigmoid="linear", + ) + return standing * hopping + else: + small_control = rewards.tolerance( + physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic" + ).mean() + small_control = (small_control + 4) / 5 + return standing * small_control diff --git a/local_dm_control_suite/hopper.xml b/local_dm_control_suite/hopper.xml new file mode 100755 index 0000000..0c8ec28 --- /dev/null +++ b/local_dm_control_suite/hopper.xml @@ -0,0 +1,66 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/humanoid.py b/local_dm_control_suite/humanoid.py new file mode 100755 index 0000000..75101c7 --- /dev/null +++ b/local_dm_control_suite/humanoid.py @@ -0,0 +1,237 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Humanoid Domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + +_DEFAULT_TIME_LIMIT = 25 +_CONTROL_TIMESTEP = 0.025 + +# Height of head above which stand reward is 1. +_STAND_HEIGHT = 1.4 + +# Horizontal speeds above which move reward is 1. 
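+# These feed the `move_speed` argument of the Humanoid task via the walk/run
+# factories below; the stand task passes move_speed=0 and instead rewards
+# keeping the horizontal centre-of-mass velocity near zero.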
+_WALK_SPEED = 1 +_RUN_SPEED = 10 + + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("humanoid.xml"), common.ASSETS + + +@SUITE.add("benchmarking") +def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Stand task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Humanoid(move_speed=0, pure_state=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Walk task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Humanoid(move_speed=_WALK_SPEED, pure_state=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Run task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Humanoid(move_speed=_RUN_SPEED, pure_state=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add() +def run_pure_state( + time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns the Run task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Humanoid(move_speed=_RUN_SPEED, pure_state=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Walker domain.""" + + def torso_upright(self): + """Returns projection from z-axes of torso to the z-axes of world.""" + return self.named.data.xmat["torso", "zz"] + + def head_height(self): + """Returns the height of the torso.""" + return self.named.data.xpos["head", "z"] + + def center_of_mass_position(self): + """Returns position of the center-of-mass.""" + return self.named.data.subtree_com["torso"].copy() + + def center_of_mass_velocity(self): + """Returns the velocity of the center-of-mass.""" + return self.named.data.sensordata["torso_subtreelinvel"].copy() + + def torso_vertical_orientation(self): + """Returns the z-projection of the torso orientation matrix.""" + return self.named.data.xmat["torso", ["zx", "zy", "zz"]] + + def joint_angles(self): + """Returns the state without global orientation or position.""" + return self.data.qpos[7:].copy() # Skip the 7 DoFs of the free root joint. 
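A minimal usage sketch (not part of the patch) of how the task factories above are driven; it assumes this module is importable as `local_dm_control_suite.humanoid`, which follows from the file layout in this diff but is otherwise an assumption:

    import numpy as np
    from local_dm_control_suite import humanoid  # assumed import path

    env = humanoid.walk(random=0)          # Walk task factory defined above
    spec = env.action_spec()
    timestep = env.reset()
    for _ in range(5):
        # Random actions, just to exercise the environment loop.
        action = np.random.uniform(spec.minimum, spec.maximum, size=spec.shape)
        timestep = env.step(action)
        head_height = timestep.observation["head_height"]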
+ + def extremities(self): + """Returns end effector positions in egocentric frame.""" + torso_frame = self.named.data.xmat["torso"].reshape(3, 3) + torso_pos = self.named.data.xpos["torso"] + positions = [] + for side in ("left_", "right_"): + for limb in ("hand", "foot"): + torso_to_limb = self.named.data.xpos[side + limb] - torso_pos + positions.append(torso_to_limb.dot(torso_frame)) + return np.hstack(positions) + + +class Humanoid(base.Task): + """A humanoid task.""" + + def __init__(self, move_speed, pure_state, random=None): + """Initializes an instance of `Humanoid`. + + Args: + move_speed: A float. If this value is zero, reward is given simply for + standing up. Otherwise this specifies a target horizontal velocity for + the walking task. + pure_state: A bool. Whether the observations consist of the pure MuJoCo + state or includes some useful features thereof. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._move_speed = move_speed + self._pure_state = pure_state + super(Humanoid, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Args: + physics: An instance of `Physics`. + + """ + # Find a collision-free random initial configuration. + penetrating = True + while penetrating: + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + # Check for collisions. + physics.after_reset() + penetrating = physics.data.ncon > 0 + super(Humanoid, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns either the pure state or a set of egocentric features.""" + obs = collections.OrderedDict() + if self._pure_state: + obs["position"] = physics.position() + obs["velocity"] = physics.velocity() + else: + obs["joint_angles"] = physics.joint_angles() + obs["head_height"] = physics.head_height() + obs["extremities"] = physics.extremities() + obs["torso_vertical"] = physics.torso_vertical_orientation() + obs["com_velocity"] = physics.center_of_mass_velocity() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + standing = rewards.tolerance( + physics.head_height(), + bounds=(_STAND_HEIGHT, float("inf")), + margin=_STAND_HEIGHT / 4, + ) + upright = rewards.tolerance( + physics.torso_upright(), + bounds=(0.9, float("inf")), + sigmoid="linear", + margin=1.9, + value_at_margin=0, + ) + stand_reward = standing * upright + small_control = rewards.tolerance( + physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic" + ).mean() + small_control = (4 + small_control) / 5 + if self._move_speed == 0: + horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]] + dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean() + return small_control * stand_reward * dont_move + else: + com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]]) + move = rewards.tolerance( + com_velocity, + bounds=(self._move_speed, float("inf")), + margin=self._move_speed, + value_at_margin=0, + sigmoid="linear", + ) + move = (5 * move + 1) / 6 + return small_control * stand_reward * move diff --git a/local_dm_control_suite/humanoid.xml b/local_dm_control_suite/humanoid.xml new file mode 100755 index 0000000..32b84c5 --- /dev/null +++ b/local_dm_control_suite/humanoid.xml @@ -0,0 +1,202 @@ + + + + + + + + + diff --git 
a/local_dm_control_suite/humanoid_CMU.py b/local_dm_control_suite/humanoid_CMU.py new file mode 100755 index 0000000..d4663bb --- /dev/null +++ b/local_dm_control_suite/humanoid_CMU.py @@ -0,0 +1,195 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Humanoid_CMU Domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + +_DEFAULT_TIME_LIMIT = 20 +_CONTROL_TIMESTEP = 0.02 + +# Height of head above which stand reward is 1. +_STAND_HEIGHT = 1.4 + +# Horizontal speeds above which move reward is 1. +_WALK_SPEED = 1 +_RUN_SPEED = 10 + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("humanoid_CMU.xml"), common.ASSETS + + +@SUITE.add() +def stand(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Stand task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = HumanoidCMU(move_speed=0, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add() +def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Run task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = HumanoidCMU(move_speed=_RUN_SPEED, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the humanoid_CMU domain.""" + + def thorax_upright(self): + """Returns projection from y-axes of thorax to the z-axes of world.""" + return self.named.data.xmat["thorax", "zy"] + + def head_height(self): + """Returns the height of the head.""" + return self.named.data.xpos["head", "z"] + + def center_of_mass_position(self): + """Returns position of the center-of-mass.""" + return self.named.data.subtree_com["thorax"] + + def center_of_mass_velocity(self): + """Returns the velocity of the center-of-mass.""" + return self.named.data.sensordata["thorax_subtreelinvel"].copy() + + def torso_vertical_orientation(self): + """Returns the z-projection of the thorax orientation matrix.""" + return self.named.data.xmat["thorax", ["zx", "zy", "zz"]] + + def joint_angles(self): + """Returns the state without global orientation or position.""" + return self.data.qpos[7:].copy() 
# Skip the 7 DoFs of the free root joint. + + def extremities(self): + """Returns end effector positions in egocentric frame.""" + torso_frame = self.named.data.xmat["thorax"].reshape(3, 3) + torso_pos = self.named.data.xpos["thorax"] + positions = [] + for side in ("l", "r"): + for limb in ("hand", "foot"): + torso_to_limb = self.named.data.xpos[side + limb] - torso_pos + positions.append(torso_to_limb.dot(torso_frame)) + return np.hstack(positions) + + +class HumanoidCMU(base.Task): + """A task for the CMU Humanoid.""" + + def __init__(self, move_speed, random=None): + """Initializes an instance of `Humanoid_CMU`. + + Args: + move_speed: A float. If this value is zero, reward is given simply for + standing up. Otherwise this specifies a target horizontal velocity for + the walking task. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._move_speed = move_speed + super(HumanoidCMU, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets a random collision-free configuration at the start of each episode. + + Args: + physics: An instance of `Physics`. + """ + penetrating = True + while penetrating: + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + # Check for collisions. + physics.after_reset() + penetrating = physics.data.ncon > 0 + super(HumanoidCMU, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns a set of egocentric features.""" + obs = collections.OrderedDict() + obs["joint_angles"] = physics.joint_angles() + obs["head_height"] = physics.head_height() + obs["extremities"] = physics.extremities() + obs["torso_vertical"] = physics.torso_vertical_orientation() + obs["com_velocity"] = physics.center_of_mass_velocity() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + standing = rewards.tolerance( + physics.head_height(), + bounds=(_STAND_HEIGHT, float("inf")), + margin=_STAND_HEIGHT / 4, + ) + upright = rewards.tolerance( + physics.thorax_upright(), + bounds=(0.9, float("inf")), + sigmoid="linear", + margin=1.9, + value_at_margin=0, + ) + stand_reward = standing * upright + small_control = rewards.tolerance( + physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic" + ).mean() + small_control = (4 + small_control) / 5 + if self._move_speed == 0: + horizontal_velocity = physics.center_of_mass_velocity()[[0, 1]] + dont_move = rewards.tolerance(horizontal_velocity, margin=2).mean() + return small_control * stand_reward * dont_move + else: + com_velocity = np.linalg.norm(physics.center_of_mass_velocity()[[0, 1]]) + move = rewards.tolerance( + com_velocity, + bounds=(self._move_speed, float("inf")), + margin=self._move_speed, + value_at_margin=0, + sigmoid="linear", + ) + move = (5 * move + 1) / 6 + return small_control * stand_reward * move diff --git a/local_dm_control_suite/humanoid_CMU.xml b/local_dm_control_suite/humanoid_CMU.xml new file mode 100755 index 0000000..9a41a16 --- /dev/null +++ b/local_dm_control_suite/humanoid_CMU.xml @@ -0,0 +1,289 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/lqr.py b/local_dm_control_suite/lqr.py new file mode 100755 index 0000000..92e065d --- /dev/null +++ b/local_dm_control_suite/lqr.py @@ -0,0 +1,271 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Procedurally generated LQR domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import xml_tools +from lxml import etree +import numpy as np +from six.moves import range + +from dm_control.utils import io as resources + +_DEFAULT_TIME_LIMIT = float("inf") +_CONTROL_COST_COEF = 0.1 +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(n_bodies, n_actuators, random): + """Returns the model description as an XML string and a dict of assets. + + Args: + n_bodies: An int, number of bodies of the LQR. + n_actuators: An int, number of actuated bodies of the LQR. `n_actuators` + should be less or equal than `n_bodies`. + random: A `numpy.random.RandomState` instance. + + Returns: + A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of + `{filename: contents_string}` pairs. + """ + return _make_model(n_bodies, n_actuators, random), common.ASSETS + + +@SUITE.add() +def lqr_2_1(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns an LQR environment with 2 bodies of which the first is actuated.""" + return _make_lqr( + n_bodies=2, + n_actuators=1, + control_cost_coef=_CONTROL_COST_COEF, + time_limit=time_limit, + random=random, + environment_kwargs=environment_kwargs, + ) + + +@SUITE.add() +def lqr_6_2(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns an LQR environment with 6 bodies of which first 2 are actuated.""" + return _make_lqr( + n_bodies=6, + n_actuators=2, + control_cost_coef=_CONTROL_COST_COEF, + time_limit=time_limit, + random=random, + environment_kwargs=environment_kwargs, + ) + + +def _make_lqr( + n_bodies, n_actuators, control_cost_coef, time_limit, random, environment_kwargs +): + """Returns a LQR environment. + + Args: + n_bodies: An int, number of bodies of the LQR. + n_actuators: An int, number of actuated bodies of the LQR. `n_actuators` + should be less or equal than `n_bodies`. + control_cost_coef: A number, the coefficient of the control cost. + time_limit: An int, maximum time for each episode in seconds. 
+ random: Either an existing `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically. + environment_kwargs: A `dict` specifying keyword arguments for the + environment, or None. + + Returns: + A LQR environment with `n_bodies` bodies of which first `n_actuators` are + actuated. + """ + + if not isinstance(random, np.random.RandomState): + random = np.random.RandomState(random) + + model_string, assets = get_model_and_assets(n_bodies, n_actuators, random=random) + physics = Physics.from_xml_string(model_string, assets=assets) + task = LQRLevel(control_cost_coef, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +def _make_body(body_id, stiffness_range, damping_range, random): + """Returns an `etree.Element` defining a body. + + Args: + body_id: Id of the created body. + stiffness_range: A tuple of (stiffness_lower_bound, stiffness_uppder_bound). + The stiffness of the joint is drawn uniformly from this range. + damping_range: A tuple of (damping_lower_bound, damping_upper_bound). The + damping of the joint is drawn uniformly from this range. + random: A `numpy.random.RandomState` instance. + + Returns: + A new instance of `etree.Element`. A body element with two children: joint + and geom. + """ + body_name = "body_{}".format(body_id) + joint_name = "joint_{}".format(body_id) + geom_name = "geom_{}".format(body_id) + + body = etree.Element("body", name=body_name) + body.set("pos", ".25 0 0") + joint = etree.SubElement(body, "joint", name=joint_name) + body.append(etree.Element("geom", name=geom_name)) + joint.set("stiffness", str(random.uniform(stiffness_range[0], stiffness_range[1]))) + joint.set("damping", str(random.uniform(damping_range[0], damping_range[1]))) + return body + + +def _make_model( + n_bodies, n_actuators, random, stiffness_range=(15, 25), damping_range=(0, 0) +): + """Returns an MJCF XML string defining a model of springs and dampers. + + Args: + n_bodies: An integer, the number of bodies (DoFs) in the system. + n_actuators: An integer, the number of actuated bodies. + random: A `numpy.random.RandomState` instance. + stiffness_range: A tuple containing minimum and maximum stiffness. Each + joint's stiffness is sampled uniformly from this interval. + damping_range: A tuple containing minimum and maximum damping. Each joint's + damping is sampled uniformly from this interval. + + Returns: + An MJCF string describing the linear system. + + Raises: + ValueError: If the number of bodies or actuators is erronous. + """ + if n_bodies < 1 or n_actuators < 1: + raise ValueError("At least 1 body and 1 actuator required.") + if n_actuators > n_bodies: + raise ValueError("At most 1 actuator per body.") + + file_path = os.path.join(os.path.dirname(__file__), "lqr.xml") + with resources.GetResourceAsFile(file_path) as xml_file: + mjcf = xml_tools.parse(xml_file) + parent = mjcf.find("./worldbody") + actuator = etree.SubElement(mjcf.getroot(), "actuator") + tendon = etree.SubElement(mjcf.getroot(), "tendon") + + for body in range(n_bodies): + # Inserting body. + child = _make_body(body, stiffness_range, damping_range, random) + site_name = "site_{}".format(body) + child.append(etree.Element("site", name=site_name)) + + if body == 0: + child.set("pos", ".25 0 .1") + # Add actuators to the first n_actuators bodies. + if body < n_actuators: + # Adding actuator. 
+ joint_name = "joint_{}".format(body) + motor_name = "motor_{}".format(body) + child.find("joint").set("name", joint_name) + actuator.append(etree.Element("motor", name=motor_name, joint=joint_name)) + + # Add a tendon between consecutive bodies (for visualisation purposes only). + if body < n_bodies - 1: + child_site_name = "site_{}".format(body + 1) + tendon_name = "tendon_{}".format(body) + spatial = etree.SubElement(tendon, "spatial", name=tendon_name) + spatial.append(etree.Element("site", site=site_name)) + spatial.append(etree.Element("site", site=child_site_name)) + parent.append(child) + parent = child + + return etree.tostring(mjcf, pretty_print=True) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the LQR domain.""" + + def state_norm(self): + """Returns the norm of the physics state.""" + return np.linalg.norm(self.state()) + + +class LQRLevel(base.Task): + """A Linear Quadratic Regulator `Task`.""" + + _TERMINAL_TOL = 1e-6 + + def __init__(self, control_cost_coef, random=None): + """Initializes an LQR level with cost = sum(states^2) + c*sum(controls^2). + + Args: + control_cost_coef: The coefficient of the control cost. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + + Raises: + ValueError: If the control cost coefficient is not positive. + """ + if control_cost_coef <= 0: + raise ValueError("control_cost_coef must be positive.") + + self._control_cost_coef = control_cost_coef + super(LQRLevel, self).__init__(random=random) + + @property + def control_cost_coef(self): + return self._control_cost_coef + + def initialize_episode(self, physics): + """Random state sampled from a unit sphere.""" + ndof = physics.model.nq + unit = self.random.randn(ndof) + physics.data.qpos[:] = np.sqrt(2) * unit / np.linalg.norm(unit) + super(LQRLevel, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the state.""" + obs = collections.OrderedDict() + obs["position"] = physics.position() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a quadratic state and control reward.""" + position = physics.position() + state_cost = 0.5 * np.dot(position, position) + control_signal = physics.control() + control_l2_norm = 0.5 * np.dot(control_signal, control_signal) + return 1 - (state_cost + control_l2_norm * self._control_cost_coef) + + def get_evaluation(self, physics): + """Returns a sparse evaluation reward that is not used for learning.""" + return float(physics.state_norm() <= 0.01) + + def get_termination(self, physics): + """Terminates when the state norm is smaller than epsilon.""" + if physics.state_norm() < self._TERMINAL_TOL: + return 0.0 diff --git a/local_dm_control_suite/lqr.xml b/local_dm_control_suite/lqr.xml new file mode 100755 index 0000000..d403532 --- /dev/null +++ b/local_dm_control_suite/lqr.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/lqr_solver.py b/local_dm_control_suite/lqr_solver.py new file mode 100755 index 0000000..9376eca --- /dev/null +++ b/local_dm_control_suite/lqr_solver.py @@ -0,0 +1,146 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +r"""Optimal policy for LQR levels. + +LQR control problem is described in +https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import logging +from dm_control.mujoco import wrapper +import numpy as np +from six.moves import range + +try: + import scipy.linalg as sp # pylint: disable=g-import-not-at-top +except ImportError: + sp = None + + +def _solve_dare(a, b, q, r): + """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration. + + Algebraic Riccati Equation: + ```none + P_{t-1} = Q + A' * P_{t} * A - + A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A + ``` + + Args: + a: A 2 dimensional numpy array, transition matrix A. + b: A 2 dimensional numpy array, control matrix B. + q: A 2 dimensional numpy array, symmetric positive definite cost matrix. + r: A 2 dimensional numpy array, symmetric positive definite cost matrix + + Returns: + A numpy array, a real symmetric matrix P which is the solution to DARE. + + Raises: + RuntimeError: If the computed P matrix is not symmetric and + positive-definite. + """ + p = np.eye(len(a)) + for _ in range(1000000): + a_p = a.T.dot(p) # A' * P_t + a_p_b = np.dot(a_p, b) # A' * P_t * B + # Algebraic Riccati Equation. + p_next = ( + q + + np.dot(a_p, a) + - a_p_b.dot(np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T)) + ) + p_next += p_next.T + p_next *= 0.5 + if np.abs(p - p_next).max() < 1e-12: + break + p = p_next + else: + logging.warning("DARE solver did not converge") + try: + # Check that the result is symmetric and positive-definite. + np.linalg.cholesky(p_next) + except np.linalg.LinAlgError: + raise RuntimeError( + "ARE solver failed: P matrix is not symmetric and " "positive-definite." + ) + return p_next + + +def solve(env): + """Returns the optimal value and policy for LQR problem. + + Args: + env: An instance of `control.EnvironmentV2` with LQR level. + + Returns: + p: A numpy array, the Hessian of the optimal total cost-to-go (value + function at state x) is V(x) = .5 * x' * p * x. + k: A numpy array which gives the optimal linear policy u = k * x. + beta: The maximum eigenvalue of (a + b * k). Under optimal policy, at + timestep n the state tends to 0 like beta^n. + + Raises: + RuntimeError: If the controlled system is unstable. + """ + n = env.physics.model.nq # number of DoFs + m = env.physics.model.nu # number of controls + + # Compute the mass matrix. + mass = np.zeros((n, n)) + wrapper.mjbindings.mjlib.mj_fullM(env.physics.model.ptr, mass, env.physics.data.qM) + + # Compute input matrices a, b, q and r to the DARE solvers. + # State transition matrix a. 
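+    # The state used by the solver is x = [qpos; qvel]. With mass matrix M, joint
+    # stiffness K and damping D, the continuous dynamics are
+    # qacc = -M^{-1} (K * qpos + D * qvel); below, j = -M^{-1} [K D], and `a` is
+    # the resulting discrete-time transition matrix over one timestep dt.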
+ stiffness = np.diag(env.physics.model.jnt_stiffness.ravel()) + damping = np.diag(env.physics.model.dof_damping.ravel()) + dt = env.physics.model.opt.timestep + + j = np.linalg.solve(-mass, np.hstack((stiffness, damping))) + a = np.eye(2 * n) + dt * np.vstack( + (dt * j + np.hstack((np.zeros((n, n)), np.eye(n))), j) + ) + + # Control transition matrix b. + b = env.physics.data.actuator_moment.T + bc = np.linalg.solve(mass, b) + b = dt * np.vstack((dt * bc, bc)) + + # State cost Hessian q. + q = np.diag(np.hstack([np.ones(n), np.zeros(n)])) + + # Control cost Hessian r. + r = env.task.control_cost_coef * np.eye(m) + + if sp: + # Use scipy's faster DARE solver if available. + solve_dare = sp.solve_discrete_are + else: + # Otherwise fall back on a slower internal implementation. + solve_dare = _solve_dare + + # Solve the discrete algebraic Riccati equation. + p = solve_dare(a, b, q, r) + k = -np.linalg.solve(b.T.dot(p.dot(b)) + r, b.T.dot(p.dot(a))) + + # Under optimal policy, state tends to 0 like beta^n_timesteps + beta = np.abs(np.linalg.eigvals(a + b.dot(k))).max() + if beta >= 1.0: + raise RuntimeError("Controlled system is unstable.") + return p, k, beta diff --git a/local_dm_control_suite/manipulator.py b/local_dm_control_suite/manipulator.py new file mode 100755 index 0000000..5a9e28e --- /dev/null +++ b/local_dm_control_suite/manipulator.py @@ -0,0 +1,329 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Planar Manipulator domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import rewards +from dm_control.utils import xml_tools + +from lxml import etree +import numpy as np + +_CLOSE = 0.01 # (Meters) Distance below which a thing is considered close. +_CONTROL_TIMESTEP = 0.01 # (Seconds) +_TIME_LIMIT = 10 # (Seconds) +_P_IN_HAND = 0.1 # Probabillity of object-in-hand initial state +_P_IN_TARGET = 0.1 # Probabillity of object-in-target initial state +_ARM_JOINTS = [ + "arm_root", + "arm_shoulder", + "arm_elbow", + "arm_wrist", + "finger", + "fingertip", + "thumb", + "thumbtip", +] +_ALL_PROPS = frozenset(["ball", "target_ball", "cup", "peg", "target_peg", "slot"]) + +SUITE = containers.TaggedTasks() + + +def make_model(use_peg, insert): + """Returns a tuple containing the model XML string and a dict of assets.""" + xml_string = common.read_model("manipulator.xml") + parser = etree.XMLParser(remove_blank_text=True) + mjcf = etree.XML(xml_string, parser) + + # Select the desired prop. 
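+    # The bring tasks keep only the moved prop and its target; the insert variants
+    # additionally keep the receptacle (the slot for the peg, the cup for the ball).
+    # Every other prop in _ALL_PROPS is removed from the model below.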
+ if use_peg: + required_props = ["peg", "target_peg"] + if insert: + required_props += ["slot"] + else: + required_props = ["ball", "target_ball"] + if insert: + required_props += ["cup"] + + # Remove unused props + for unused_prop in _ALL_PROPS.difference(required_props): + prop = xml_tools.find_element(mjcf, "body", unused_prop) + prop.getparent().remove(prop) + + return etree.tostring(mjcf, pretty_print=True), common.ASSETS + + +@SUITE.add("benchmarking", "hard") +def bring_ball( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns manipulator bring task with the ball prop.""" + use_peg = False + insert = False + physics = Physics.from_xml_string(*make_model(use_peg, insert)) + task = Bring( + use_peg=use_peg, insert=insert, fully_observable=fully_observable, random=random + ) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +@SUITE.add("hard") +def bring_peg( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns manipulator bring task with the peg prop.""" + use_peg = True + insert = False + physics = Physics.from_xml_string(*make_model(use_peg, insert)) + task = Bring( + use_peg=use_peg, insert=insert, fully_observable=fully_observable, random=random + ) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +@SUITE.add("hard") +def insert_ball( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns manipulator insert task with the ball prop.""" + use_peg = False + insert = True + physics = Physics.from_xml_string(*make_model(use_peg, insert)) + task = Bring( + use_peg=use_peg, insert=insert, fully_observable=fully_observable, random=random + ) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +@SUITE.add("hard") +def insert_peg( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns manipulator insert task with the peg prop.""" + use_peg = True + insert = True + physics = Physics.from_xml_string(*make_model(use_peg, insert)) + task = Bring( + use_peg=use_peg, insert=insert, fully_observable=fully_observable, random=random + ) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics with additional features for the Planar Manipulator domain.""" + + def bounded_joint_pos(self, joint_names): + """Returns joint positions as (sin, cos) values.""" + joint_pos = self.named.data.qpos[joint_names] + return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T + + def joint_vel(self, joint_names): + """Returns joint velocities.""" + return self.named.data.qvel[joint_names] + + def body_2d_pose(self, body_names, orientation=True): + """Returns positions and/or orientations of bodies.""" + if not isinstance(body_names, str): + body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices. 
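+        # In this planar (x-z) domain a body's 2D pose is its (x, z) position plus,
+        # when `orientation` is requested, the (qw, qy) quaternion components,
+        # which encode rotation about the y-axis.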
+ pos = self.named.data.xpos[body_names, ["x", "z"]] + if orientation: + ori = self.named.data.xquat[body_names, ["qw", "qy"]] + return np.hstack([pos, ori]) + else: + return pos + + def touch(self): + return np.log1p(self.data.sensordata) + + def site_distance(self, site1, site2): + site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0) + return np.linalg.norm(site1_to_site2) + + +class Bring(base.Task): + """A Bring `Task`: bring the prop to the target.""" + + def __init__(self, use_peg, insert, fully_observable, random=None): + """Initialize an instance of the `Bring` task. + + Args: + use_peg: A `bool`, whether to replace the ball prop with the peg prop. + insert: A `bool`, whether to insert the prop in a receptacle. + fully_observable: A `bool`, whether the observation should contain the + position and velocity of the object being manipulated and the target + location. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._use_peg = use_peg + self._target = "target_peg" if use_peg else "target_ball" + self._object = "peg" if self._use_peg else "ball" + self._object_joints = ["_".join([self._object, dim]) for dim in "xzy"] + self._receptacle = "slot" if self._use_peg else "cup" + self._insert = insert + self._fully_observable = fully_observable + super(Bring, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # Local aliases + choice = self.random.choice + uniform = self.random.uniform + model = physics.named.model + data = physics.named.data + + # Find a collision-free random initial configuration. + penetrating = True + while penetrating: + + # Randomise angles of arm joints. + is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool) + joint_range = model.jnt_range[_ARM_JOINTS] + lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi) + upper_limits = np.where(is_limited, joint_range[:, 1], np.pi) + angles = uniform(lower_limits, upper_limits) + data.qpos[_ARM_JOINTS] = angles + + # Symmetrize hand. + data.qpos["finger"] = data.qpos["thumb"] + + # Randomise target location. + target_x = uniform(-0.4, 0.4) + target_z = uniform(0.1, 0.4) + if self._insert: + target_angle = uniform(-np.pi / 3, np.pi / 3) + model.body_pos[self._receptacle, ["x", "z"]] = target_x, target_z + model.body_quat[self._receptacle, ["qw", "qy"]] = [ + np.cos(target_angle / 2), + np.sin(target_angle / 2), + ] + else: + target_angle = uniform(-np.pi, np.pi) + + model.body_pos[self._target, ["x", "z"]] = target_x, target_z + model.body_quat[self._target, ["qw", "qy"]] = [ + np.cos(target_angle / 2), + np.sin(target_angle / 2), + ] + + # Randomise object location. 
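+            # The prop starts in the gripper with probability _P_IN_HAND, at the
+            # target with probability _P_IN_TARGET, and otherwise at a uniformly
+            # random pose within the workspace.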
+ object_init_probs = [ + _P_IN_HAND, + _P_IN_TARGET, + 1 - _P_IN_HAND - _P_IN_TARGET, + ] + init_type = choice(["in_hand", "in_target", "uniform"], p=object_init_probs) + if init_type == "in_target": + object_x = target_x + object_z = target_z + object_angle = target_angle + elif init_type == "in_hand": + physics.after_reset() + object_x = data.site_xpos["grasp", "x"] + object_z = data.site_xpos["grasp", "z"] + grasp_direction = data.site_xmat["grasp", ["xx", "zx"]] + object_angle = np.pi - np.arctan2( + grasp_direction[1], grasp_direction[0] + ) + else: + object_x = uniform(-0.5, 0.5) + object_z = uniform(0, 0.7) + object_angle = uniform(0, 2 * np.pi) + data.qvel[self._object + "_x"] = uniform(-5, 5) + + data.qpos[self._object_joints] = object_x, object_z, object_angle + + # Check for collisions. + physics.after_reset() + penetrating = physics.data.ncon > 0 + + super(Bring, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns either features or only sensors (to be used with pixels).""" + obs = collections.OrderedDict() + obs["arm_pos"] = physics.bounded_joint_pos(_ARM_JOINTS) + obs["arm_vel"] = physics.joint_vel(_ARM_JOINTS) + obs["touch"] = physics.touch() + if self._fully_observable: + obs["hand_pos"] = physics.body_2d_pose("hand") + obs["object_pos"] = physics.body_2d_pose(self._object) + obs["object_vel"] = physics.joint_vel(self._object_joints) + obs["target_pos"] = physics.body_2d_pose(self._target) + return obs + + def _is_close(self, distance): + return rewards.tolerance(distance, (0, _CLOSE), _CLOSE * 2) + + def _peg_reward(self, physics): + """Returns a reward for bringing the peg prop to the target.""" + grasp = self._is_close(physics.site_distance("peg_grasp", "grasp")) + pinch = self._is_close(physics.site_distance("peg_pinch", "pinch")) + grasping = (grasp + pinch) / 2 + bring = self._is_close(physics.site_distance("peg", "target_peg")) + bring_tip = self._is_close(physics.site_distance("target_peg_tip", "peg_tip")) + bringing = (bring + bring_tip) / 2 + return max(bringing, grasping / 3) + + def _ball_reward(self, physics): + """Returns a reward for bringing the ball prop to the target.""" + return self._is_close(physics.site_distance("ball", "target_ball")) + + def get_reward(self, physics): + """Returns a reward to the agent.""" + if self._use_peg: + return self._peg_reward(physics) + else: + return self._ball_reward(physics) diff --git a/local_dm_control_suite/manipulator.xml b/local_dm_control_suite/manipulator.xml new file mode 100755 index 0000000..d6d1767 --- /dev/null +++ b/local_dm_control_suite/manipulator.xml @@ -0,0 +1,211 @@ + + + + + + + + + + + + + + > + + diff --git a/local_dm_control_suite/pendulum.py b/local_dm_control_suite/pendulum.py new file mode 100755 index 0000000..07806ae --- /dev/null +++ b/local_dm_control_suite/pendulum.py @@ -0,0 +1,114 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Pendulum domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + + +_DEFAULT_TIME_LIMIT = 20 +_ANGLE_BOUND = 8 +_COSINE_BOUND = np.cos(np.deg2rad(_ANGLE_BOUND)) +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("pendulum.xml"), common.ASSETS + + +@SUITE.add("benchmarking") +def swingup(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns pendulum swingup task .""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = SwingUp(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Pendulum domain.""" + + def pole_vertical(self): + """Returns vertical (z) component of pole frame.""" + return self.named.data.xmat["pole", "zz"] + + def angular_velocity(self): + """Returns the angular velocity of the pole.""" + return self.named.data.qvel["hinge"].copy() + + def pole_orientation(self): + """Returns both horizontal and vertical components of pole frame.""" + return self.named.data.xmat["pole", ["zz", "xz"]] + + +class SwingUp(base.Task): + """A Pendulum `Task` to swing up and balance the pole.""" + + def __init__(self, random=None): + """Initialize an instance of `Pendulum`. + + Args: + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + super(SwingUp, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Pole is set to a random angle between [-pi, pi). + + Args: + physics: An instance of `Physics`. + + """ + physics.named.data.qpos["hinge"] = self.random.uniform(-np.pi, np.pi) + super(SwingUp, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation. + + Observations are states concatenating pole orientation and angular velocity + and pixels from fixed camera. + + Args: + physics: An instance of `physics`, Pendulum physics. + + Returns: + A `dict` of observation. + """ + obs = collections.OrderedDict() + obs["orientation"] = physics.pole_orientation() + obs["velocity"] = physics.angular_velocity() + return obs + + def get_reward(self, physics): + return rewards.tolerance(physics.pole_vertical(), (_COSINE_BOUND, 1)) diff --git a/local_dm_control_suite/pendulum.xml b/local_dm_control_suite/pendulum.xml new file mode 100755 index 0000000..14377ae --- /dev/null +++ b/local_dm_control_suite/pendulum.xml @@ -0,0 +1,26 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/point_mass.py b/local_dm_control_suite/point_mass.py new file mode 100755 index 0000000..bbd4078 --- /dev/null +++ b/local_dm_control_suite/point_mass.py @@ -0,0 +1,134 @@ +# Copyright 2017 The dm_control Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Point-mass domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + +_DEFAULT_TIME_LIMIT = 20 +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("point_mass.xml"), common.ASSETS + + +@SUITE.add("benchmarking", "easy") +def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the easy point_mass task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = PointMass(randomize_gains=False, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add() +def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the hard point_mass task.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = PointMass(randomize_gains=True, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """physics for the point_mass domain.""" + + def mass_to_target(self): + """Returns the vector from mass to target in global coordinate.""" + return ( + self.named.data.geom_xpos["target"] - self.named.data.geom_xpos["pointmass"] + ) + + def mass_to_target_dist(self): + """Returns the distance from mass to the target.""" + return np.linalg.norm(self.mass_to_target()) + + +class PointMass(base.Task): + """A point_mass `Task` to reach target with smooth reward.""" + + def __init__(self, randomize_gains, random=None): + """Initialize an instance of `PointMass`. + + Args: + randomize_gains: A `bool`, whether to randomize the actuator gains. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._randomize_gains = randomize_gains + super(PointMass, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + If _randomize_gains is True, the relationship between the controls and + the joints is randomized, so that each control actuates a random linear + combination of joints. + + Args: + physics: An instance of `mujoco.Physics`. 
+ """ + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + if self._randomize_gains: + dir1 = self.random.randn(2) + dir1 /= np.linalg.norm(dir1) + # Find another actuation direction that is not 'too parallel' to dir1. + parallel = True + while parallel: + dir2 = self.random.randn(2) + dir2 /= np.linalg.norm(dir2) + parallel = abs(np.dot(dir1, dir2)) > 0.9 + physics.model.wrap_prm[[0, 1]] = dir1 + physics.model.wrap_prm[[2, 3]] = dir2 + super(PointMass, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the state.""" + obs = collections.OrderedDict() + obs["position"] = physics.position() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + target_size = physics.named.model.geom_size["target", 0] + near_target = rewards.tolerance( + physics.mass_to_target_dist(), bounds=(0, target_size), margin=target_size + ) + control_reward = rewards.tolerance( + physics.control(), margin=1, value_at_margin=0, sigmoid="quadratic" + ).mean() + small_control = (control_reward + 4) / 5 + return near_target * small_control diff --git a/local_dm_control_suite/point_mass.xml b/local_dm_control_suite/point_mass.xml new file mode 100755 index 0000000..c447cf6 --- /dev/null +++ b/local_dm_control_suite/point_mass.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/quadruped.py b/local_dm_control_suite/quadruped.py new file mode 100755 index 0000000..422d70d --- /dev/null +++ b/local_dm_control_suite/quadruped.py @@ -0,0 +1,514 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Quadruped Domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.mujoco.wrapper import mjbindings +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import rewards +from dm_control.utils import xml_tools + +from lxml import etree +import numpy as np +from scipy import ndimage + +enums = mjbindings.enums +mjlib = mjbindings.mjlib + + +_DEFAULT_TIME_LIMIT = 20 +_CONTROL_TIMESTEP = 0.02 + +# Horizontal speeds above which the move reward is 1. +_RUN_SPEED = 5 +_WALK_SPEED = 0.5 + +# Constants related to terrain generation. +_HEIGHTFIELD_ID = 0 +_TERRAIN_SMOOTHNESS = 0.15 # 0.0: maximally bumpy; 1.0: completely smooth. +_TERRAIN_BUMP_SCALE = 2 # Spatial scale of terrain bumps (in meters). + +# Named model elements. 
+_TOES = ["toe_front_left", "toe_back_left", "toe_back_right", "toe_front_right"] +_WALLS = ["wall_px", "wall_py", "wall_nx", "wall_ny"] + +SUITE = containers.TaggedTasks() + + +def make_model( + floor_size=None, terrain=False, rangefinders=False, walls_and_ball=False +): + """Returns the model XML string.""" + xml_string = common.read_model("quadruped.xml") + parser = etree.XMLParser(remove_blank_text=True) + mjcf = etree.XML(xml_string, parser) + + # Set floor size. + if floor_size is not None: + floor_geom = mjcf.find(".//geom[@name={!r}]".format("floor")) + floor_geom.attrib["size"] = "{} {} .5".format(floor_size, floor_size) + + # Remove walls, ball and target. + if not walls_and_ball: + for wall in _WALLS: + wall_geom = xml_tools.find_element(mjcf, "geom", wall) + wall_geom.getparent().remove(wall_geom) + + # Remove ball. + ball_body = xml_tools.find_element(mjcf, "body", "ball") + ball_body.getparent().remove(ball_body) + + # Remove target. + target_site = xml_tools.find_element(mjcf, "site", "target") + target_site.getparent().remove(target_site) + + # Remove terrain. + if not terrain: + terrain_geom = xml_tools.find_element(mjcf, "geom", "terrain") + terrain_geom.getparent().remove(terrain_geom) + + # Remove rangefinders if they're not used, as range computations can be + # expensive, especially in a scene with heightfields. + if not rangefinders: + rangefinder_sensors = mjcf.findall(".//rangefinder") + for rf in rangefinder_sensors: + rf.getparent().remove(rf) + + return etree.tostring(mjcf, pretty_print=True) + + +@SUITE.add() +def walk(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Walk task.""" + xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED) + physics = Physics.from_xml_string(xml_string, common.ASSETS) + task = Move(desired_speed=_WALK_SPEED, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add() +def run(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Run task.""" + xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _RUN_SPEED) + physics = Physics.from_xml_string(xml_string, common.ASSETS) + task = Move(desired_speed=_RUN_SPEED, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add() +def escape(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Escape task.""" + xml_string = make_model(floor_size=40, terrain=True, rangefinders=True) + physics = Physics.from_xml_string(xml_string, common.ASSETS) + task = Escape(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +@SUITE.add() +def fetch(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns the Fetch task.""" + xml_string = make_model(walls_and_ball=True) + physics = Physics.from_xml_string(xml_string, common.ASSETS) + task = Fetch(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics 
simulation with additional features for the Quadruped domain.""" + + def _reload_from_data(self, data): + super(Physics, self)._reload_from_data(data) + # Clear cached sensor names when the physics is reloaded. + self._sensor_types_to_names = {} + self._hinge_names = [] + + def _get_sensor_names(self, *sensor_types): + try: + sensor_names = self._sensor_types_to_names[sensor_types] + except KeyError: + [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types)) + sensor_names = [self.model.id2name(s_id, "sensor") for s_id in sensor_ids] + self._sensor_types_to_names[sensor_types] = sensor_names + return sensor_names + + def torso_upright(self): + """Returns the dot-product of the torso z-axis and the global z-axis.""" + return np.asarray(self.named.data.xmat["torso", "zz"]) + + def torso_velocity(self): + """Returns the velocity of the torso, in the local frame.""" + return self.named.data.sensordata["velocimeter"].copy() + + def egocentric_state(self): + """Returns the state without global orientation or position.""" + if not self._hinge_names: + [hinge_ids] = np.nonzero(self.model.jnt_type == enums.mjtJoint.mjJNT_HINGE) + self._hinge_names = [ + self.model.id2name(j_id, "joint") for j_id in hinge_ids + ] + return np.hstack( + ( + self.named.data.qpos[self._hinge_names], + self.named.data.qvel[self._hinge_names], + self.data.act, + ) + ) + + def toe_positions(self): + """Returns toe positions in egocentric frame.""" + torso_frame = self.named.data.xmat["torso"].reshape(3, 3) + torso_pos = self.named.data.xpos["torso"] + torso_to_toe = self.named.data.xpos[_TOES] - torso_pos + return torso_to_toe.dot(torso_frame) + + def force_torque(self): + """Returns scaled force/torque sensor readings at the toes.""" + force_torque_sensors = self._get_sensor_names( + enums.mjtSensor.mjSENS_FORCE, enums.mjtSensor.mjSENS_TORQUE + ) + return np.arcsinh(self.named.data.sensordata[force_torque_sensors]) + + def imu(self): + """Returns IMU-like sensor readings.""" + imu_sensors = self._get_sensor_names( + enums.mjtSensor.mjSENS_GYRO, enums.mjtSensor.mjSENS_ACCELEROMETER + ) + return self.named.data.sensordata[imu_sensors] + + def rangefinder(self): + """Returns scaled rangefinder sensor readings.""" + rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER) + rf_readings = self.named.data.sensordata[rf_sensors] + no_intersection = -1.0 + return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings)) + + def origin_distance(self): + """Returns the distance from the origin to the workspace.""" + return np.asarray(np.linalg.norm(self.named.data.site_xpos["workspace"])) + + def origin(self): + """Returns origin position in the torso frame.""" + torso_frame = self.named.data.xmat["torso"].reshape(3, 3) + torso_pos = self.named.data.xpos["torso"] + return -torso_pos.dot(torso_frame) + + def ball_state(self): + """Returns ball position and velocity relative to the torso frame.""" + data = self.named.data + torso_frame = data.xmat["torso"].reshape(3, 3) + ball_rel_pos = data.xpos["ball"] - data.xpos["torso"] + ball_rel_vel = data.qvel["ball_root"][:3] - data.qvel["root"][:3] + ball_rot_vel = data.qvel["ball_root"][3:] + ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel)) + return ball_state.dot(torso_frame).ravel() + + def target_position(self): + """Returns target position in torso frame.""" + torso_frame = self.named.data.xmat["torso"].reshape(3, 3) + torso_pos = self.named.data.xpos["torso"] + torso_to_target = self.named.data.site_xpos["target"] - torso_pos + 
return torso_to_target.dot(torso_frame) + + def ball_to_target_distance(self): + """Returns horizontal distance from the ball to the target.""" + ball_to_target = ( + self.named.data.site_xpos["target"] - self.named.data.xpos["ball"] + ) + return np.linalg.norm(ball_to_target[:2]) + + def self_to_ball_distance(self): + """Returns horizontal distance from the quadruped workspace to the ball.""" + self_to_ball = ( + self.named.data.site_xpos["workspace"] - self.named.data.xpos["ball"] + ) + return np.linalg.norm(self_to_ball[:2]) + + +def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0): + """Find a height with no contacts given a body orientation. + + Args: + physics: An instance of `Physics`. + orientation: A quaternion. + x_pos: A float. Position along global x-axis. + y_pos: A float. Position along global y-axis. + Raises: + RuntimeError: If a non-contacting configuration has not been found after + 10,000 attempts. + """ + z_pos = 0.0 # Start embedded in the floor. + num_contacts = 1 + num_attempts = 0 + # Move up in 1cm increments until no contacts. + while num_contacts > 0: + try: + with physics.reset_context(): + physics.named.data.qpos["root"][:3] = x_pos, y_pos, z_pos + physics.named.data.qpos["root"][3:] = orientation + except control.PhysicsError: + # We may encounter a PhysicsError here due to filling the contact + # buffer, in which case we simply increment the height and continue. + pass + num_contacts = physics.data.ncon + z_pos += 0.01 + num_attempts += 1 + if num_attempts > 10000: + raise RuntimeError("Failed to find a non-contacting configuration.") + + +def _common_observations(physics): + """Returns the observations common to all tasks.""" + obs = collections.OrderedDict() + obs["egocentric_state"] = physics.egocentric_state() + obs["torso_velocity"] = physics.torso_velocity() + obs["torso_upright"] = physics.torso_upright() + obs["imu"] = physics.imu() + obs["force_torque"] = physics.force_torque() + return obs + + +def _upright_reward(physics, deviation_angle=0): + """Returns a reward proportional to how upright the torso is. + + Args: + physics: an instance of `Physics`. + deviation_angle: A float, in degrees. The reward is 0 when the torso is + exactly upside-down and 1 when the torso's z-axis is less than + `deviation_angle` away from the global z-axis. + """ + deviation = np.cos(np.deg2rad(deviation_angle)) + return rewards.tolerance( + physics.torso_upright(), + bounds=(deviation, float("inf")), + sigmoid="linear", + margin=1 + deviation, + value_at_margin=0, + ) + + +class Move(base.Task): + """A quadruped task solved by moving forward at a designated speed.""" + + def __init__(self, desired_speed, random=None): + """Initializes an instance of `Move`. + + Args: + desired_speed: A float. If this value is zero, reward is given simply + for standing upright. Otherwise this specifies the horizontal velocity + at which the velocity-dependent reward component is maximized. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._desired_speed = desired_speed + super(Move, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Args: + physics: An instance of `Physics`. + + """ + # Initial configuration. 
+ orientation = self.random.randn(4) + orientation /= np.linalg.norm(orientation) + _find_non_contacting_height(physics, orientation) + super(Move, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation to the agent.""" + return _common_observations(physics) + + def get_reward(self, physics): + """Returns a reward to the agent.""" + + # Move reward term. + move_reward = rewards.tolerance( + physics.torso_velocity()[0], + bounds=(self._desired_speed, float("inf")), + margin=self._desired_speed, + value_at_margin=0.5, + sigmoid="linear", + ) + + return _upright_reward(physics) * move_reward + + +class Escape(base.Task): + """A quadruped task solved by escaping a bowl-shaped terrain.""" + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Args: + physics: An instance of `Physics`. + + """ + # Get heightfield resolution, assert that it is square. + res = physics.model.hfield_nrow[_HEIGHTFIELD_ID] + assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID] + # Sinusoidal bowl shape. + row_grid, col_grid = np.ogrid[-1 : 1 : res * 1j, -1 : 1 : res * 1j] + radius = np.clip(np.sqrt(col_grid ** 2 + row_grid ** 2), 0.04, 1) + bowl_shape = 0.5 - np.cos(2 * np.pi * radius) / 2 + # Random smooth bumps. + terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0] + bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE) + bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res)) + smooth_bumps = ndimage.zoom(bumps, res / float(bump_res)) + # Terrain is elementwise product. + terrain = bowl_shape * smooth_bumps + start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID] + physics.model.hfield_data[start_idx : start_idx + res ** 2] = terrain.ravel() + super(Escape, self).initialize_episode(physics) + + # If we have a rendering context, we need to re-upload the modified + # heightfield data. + if physics.contexts: + with physics.contexts.gl.make_current() as ctx: + ctx.call( + mjlib.mjr_uploadHField, + physics.model.ptr, + physics.contexts.mujoco.ptr, + _HEIGHTFIELD_ID, + ) + + # Initial configuration. + orientation = self.random.randn(4) + orientation /= np.linalg.norm(orientation) + _find_non_contacting_height(physics, orientation) + + def get_observation(self, physics): + """Returns an observation to the agent.""" + obs = _common_observations(physics) + obs["origin"] = physics.origin() + obs["rangefinder"] = physics.rangefinder() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + + # Escape reward term. + terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0] + escape_reward = rewards.tolerance( + physics.origin_distance(), + bounds=(terrain_size, float("inf")), + margin=terrain_size, + value_at_margin=0, + sigmoid="linear", + ) + + return _upright_reward(physics, deviation_angle=20) * escape_reward + + +class Fetch(base.Task): + """A quadruped task solved by bringing a ball to the origin.""" + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + Args: + physics: An instance of `Physics`. + + """ + # Initial configuration, random azimuth and horizontal position. 
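
A side note on the reward shaping used by `Move.get_reward` above: the sketch below is illustrative only (not part of the patch), assumes `dm_control` is installed, and uses a hypothetical desired speed of 0.5 to show how `rewards.tolerance` maps torso speed into [0, 1].

from dm_control.utils import rewards

desired_speed = 0.5  # hypothetical value, for illustration only
for speed in (0.0, 0.25, 0.5, 1.0):
    move_reward = rewards.tolerance(
        speed,
        bounds=(desired_speed, float("inf")),
        margin=desired_speed,
        value_at_margin=0.5,
        sigmoid="linear",
    )
    # Expected: 0.5 at speed 0, 0.75 at 0.25, and 1.0 at or above desired_speed.
    print(speed, move_reward)
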
+ azimuth = self.random.uniform(0, 2 * np.pi) + orientation = np.array((np.cos(azimuth / 2), 0, 0, np.sin(azimuth / 2))) + spawn_radius = 0.9 * physics.named.model.geom_size["floor", 0] + x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,)) + _find_non_contacting_height(physics, orientation, x_pos, y_pos) + + # Initial ball state. + physics.named.data.qpos["ball_root"][:2] = self.random.uniform( + -spawn_radius, spawn_radius, size=(2,) + ) + physics.named.data.qpos["ball_root"][2] = 2 + physics.named.data.qvel["ball_root"][:2] = 5 * self.random.randn(2) + super(Fetch, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation to the agent.""" + obs = _common_observations(physics) + obs["ball_state"] = physics.ball_state() + obs["target_position"] = physics.target_position() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + + # Reward for moving close to the ball. + arena_radius = physics.named.model.geom_size["floor", 0] * np.sqrt(2) + workspace_radius = physics.named.model.site_size["workspace", 0] + ball_radius = physics.named.model.geom_size["ball", 0] + reach_reward = rewards.tolerance( + physics.self_to_ball_distance(), + bounds=(0, workspace_radius + ball_radius), + sigmoid="linear", + margin=arena_radius, + value_at_margin=0, + ) + + # Reward for bringing the ball to the target. + target_radius = physics.named.model.site_size["target", 0] + fetch_reward = rewards.tolerance( + physics.ball_to_target_distance(), + bounds=(0, target_radius), + sigmoid="linear", + margin=arena_radius, + value_at_margin=0, + ) + + reach_then_fetch = reach_reward * (0.5 + 0.5 * fetch_reward) + + return _upright_reward(physics) * reach_then_fetch diff --git a/local_dm_control_suite/quadruped.xml b/local_dm_control_suite/quadruped.xml new file mode 100755 index 0000000..958d2c0 --- /dev/null +++ b/local_dm_control_suite/quadruped.xml @@ -0,0 +1,329 @@ + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/reacher.py b/local_dm_control_suite/reacher.py new file mode 100755 index 0000000..f18c792 --- /dev/null +++ b/local_dm_control_suite/reacher.py @@ -0,0 +1,120 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Reacher domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . 
import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +import numpy as np + +SUITE = containers.TaggedTasks() +_DEFAULT_TIME_LIMIT = 20 +_BIG_TARGET = 0.05 +_SMALL_TARGET = 0.015 + + +def get_model_and_assets(): + """Returns a tuple containing the model XML string and a dict of assets.""" + return common.read_model("reacher.xml"), common.ASSETS + + +@SUITE.add("benchmarking", "easy") +def easy(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns reacher with sparse reward with 5e-2 tol and randomized target.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Reacher(target_size=_BIG_TARGET, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +@SUITE.add("benchmarking") +def hard(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns reacher with sparse reward with 1e-2 tol and randomized target.""" + physics = Physics.from_xml_string(*get_model_and_assets()) + task = Reacher(target_size=_SMALL_TARGET, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, task, time_limit=time_limit, **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Reacher domain.""" + + def finger_to_target(self): + """Returns the vector from target to finger in global coordinates.""" + return ( + self.named.data.geom_xpos["target", :2] + - self.named.data.geom_xpos["finger", :2] + ) + + def finger_to_target_dist(self): + """Returns the signed distance between the finger and target surface.""" + return np.linalg.norm(self.finger_to_target()) + + +class Reacher(base.Task): + """A reacher `Task` to reach the target.""" + + def __init__(self, target_size, random=None): + """Initialize an instance of `Reacher`. + + Args: + target_size: A `float`, tolerance to determine whether finger reached the + target. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). 
+ """ + self._target_size = target_size + super(Reacher, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + physics.named.model.geom_size["target", 0] = self._target_size + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + + # Randomize target position + angle = self.random.uniform(0, 2 * np.pi) + radius = self.random.uniform(0.05, 0.20) + physics.named.model.geom_pos["target", "x"] = radius * np.sin(angle) + physics.named.model.geom_pos["target", "y"] = radius * np.cos(angle) + + super(Reacher, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of the state and the target position.""" + obs = collections.OrderedDict() + obs["position"] = physics.position() + obs["to_target"] = physics.finger_to_target() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + radii = physics.named.model.geom_size[["target", "finger"], 0].sum() + return rewards.tolerance(physics.finger_to_target_dist(), (0, radii)) diff --git a/local_dm_control_suite/reacher.xml b/local_dm_control_suite/reacher.xml new file mode 100755 index 0000000..343f799 --- /dev/null +++ b/local_dm_control_suite/reacher.xml @@ -0,0 +1,47 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/stacker.py b/local_dm_control_suite/stacker.py new file mode 100755 index 0000000..a25609b --- /dev/null +++ b/local_dm_control_suite/stacker.py @@ -0,0 +1,224 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Planar Stacker domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.utils import containers +from dm_control.utils import rewards +from dm_control.utils import xml_tools + +from lxml import etree +import numpy as np + + +_CLOSE = 0.01 # (Meters) Distance below which a thing is considered close. 
+_CONTROL_TIMESTEP = 0.01 # (Seconds) +_TIME_LIMIT = 10 # (Seconds) +_ARM_JOINTS = [ + "arm_root", + "arm_shoulder", + "arm_elbow", + "arm_wrist", + "finger", + "fingertip", + "thumb", + "thumbtip", +] + +SUITE = containers.TaggedTasks() + + +def make_model(n_boxes): + """Returns a tuple containing the model XML string and a dict of assets.""" + xml_string = common.read_model("stacker.xml") + parser = etree.XMLParser(remove_blank_text=True) + mjcf = etree.XML(xml_string, parser) + + # Remove unused boxes + for b in range(n_boxes, 4): + box = xml_tools.find_element(mjcf, "body", "box" + str(b)) + box.getparent().remove(box) + + return etree.tostring(mjcf, pretty_print=True), common.ASSETS + + +@SUITE.add("hard") +def stack_2( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns stacker task with 2 boxes.""" + n_boxes = 2 + physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes)) + task = Stack(n_boxes=n_boxes, fully_observable=fully_observable, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +@SUITE.add("hard") +def stack_4( + fully_observable=True, time_limit=_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns stacker task with 4 boxes.""" + n_boxes = 4 + physics = Physics.from_xml_string(*make_model(n_boxes=n_boxes)) + task = Stack(n_boxes=n_boxes, fully_observable=fully_observable, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + control_timestep=_CONTROL_TIMESTEP, + time_limit=time_limit, + **environment_kwargs + ) + + +class Physics(mujoco.Physics): + """Physics with additional features for the Planar Manipulator domain.""" + + def bounded_joint_pos(self, joint_names): + """Returns joint positions as (sin, cos) values.""" + joint_pos = self.named.data.qpos[joint_names] + return np.vstack([np.sin(joint_pos), np.cos(joint_pos)]).T + + def joint_vel(self, joint_names): + """Returns joint velocities.""" + return self.named.data.qvel[joint_names] + + def body_2d_pose(self, body_names, orientation=True): + """Returns positions and/or orientations of bodies.""" + if not isinstance(body_names, str): + body_names = np.array(body_names).reshape(-1, 1) # Broadcast indices. + pos = self.named.data.xpos[body_names, ["x", "z"]] + if orientation: + ori = self.named.data.xquat[body_names, ["qw", "qy"]] + return np.hstack([pos, ori]) + else: + return pos + + def touch(self): + return np.log1p(self.data.sensordata) + + def site_distance(self, site1, site2): + site1_to_site2 = np.diff(self.named.data.site_xpos[[site2, site1]], axis=0) + return np.linalg.norm(site1_to_site2) + + +class Stack(base.Task): + """A Stack `Task`: stack the boxes.""" + + def __init__(self, n_boxes, fully_observable, random=None): + """Initialize an instance of the `Stack` task. + + Args: + n_boxes: An `int`, number of boxes to stack. + fully_observable: A `bool`, whether the observation should contain the + positions and velocities of the boxes and the location of the target. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). 
+ """ + self._n_boxes = n_boxes + self._box_names = ["box" + str(b) for b in range(n_boxes)] + self._box_joint_names = [] + for name in self._box_names: + for dim in "xyz": + self._box_joint_names.append("_".join([name, dim])) + self._fully_observable = fully_observable + super(Stack, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode.""" + # Local aliases + randint = self.random.randint + uniform = self.random.uniform + model = physics.named.model + data = physics.named.data + + # Find a collision-free random initial configuration. + penetrating = True + while penetrating: + + # Randomise angles of arm joints. + is_limited = model.jnt_limited[_ARM_JOINTS].astype(np.bool) + joint_range = model.jnt_range[_ARM_JOINTS] + lower_limits = np.where(is_limited, joint_range[:, 0], -np.pi) + upper_limits = np.where(is_limited, joint_range[:, 1], np.pi) + angles = uniform(lower_limits, upper_limits) + data.qpos[_ARM_JOINTS] = angles + + # Symmetrize hand. + data.qpos["finger"] = data.qpos["thumb"] + + # Randomise target location. + target_height = 2 * randint(self._n_boxes) + 1 + box_size = model.geom_size["target", 0] + model.body_pos["target", "z"] = box_size * target_height + model.body_pos["target", "x"] = uniform(-0.37, 0.37) + + # Randomise box locations. + for name in self._box_names: + data.qpos[name + "_x"] = uniform(0.1, 0.3) + data.qpos[name + "_z"] = uniform(0, 0.7) + data.qpos[name + "_y"] = uniform(0, 2 * np.pi) + + # Check for collisions. + physics.after_reset() + penetrating = physics.data.ncon > 0 + + super(Stack, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns either features or only sensors (to be used with pixels).""" + obs = collections.OrderedDict() + obs["arm_pos"] = physics.bounded_joint_pos(_ARM_JOINTS) + obs["arm_vel"] = physics.joint_vel(_ARM_JOINTS) + obs["touch"] = physics.touch() + if self._fully_observable: + obs["hand_pos"] = physics.body_2d_pose("hand") + obs["box_pos"] = physics.body_2d_pose(self._box_names) + obs["box_vel"] = physics.joint_vel(self._box_joint_names) + obs["target_pos"] = physics.body_2d_pose("target", orientation=False) + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + box_size = physics.named.model.geom_size["target", 0] + min_box_to_target_distance = min( + physics.site_distance(name, "target") for name in self._box_names + ) + box_is_close = rewards.tolerance( + min_box_to_target_distance, margin=2 * box_size + ) + hand_to_target_distance = physics.site_distance("grasp", "target") + hand_is_far = rewards.tolerance( + hand_to_target_distance, bounds=(0.1, float("inf")), margin=_CLOSE + ) + return box_is_close * hand_is_far diff --git a/local_dm_control_suite/stacker.xml b/local_dm_control_suite/stacker.xml new file mode 100755 index 0000000..7af4877 --- /dev/null +++ b/local_dm_control_suite/stacker.xml @@ -0,0 +1,193 @@ + + + + + + + + + + + + + + > + + diff --git a/local_dm_control_suite/swimmer.py b/local_dm_control_suite/swimmer.py new file mode 100755 index 0000000..b915510 --- /dev/null +++ b/local_dm_control_suite/swimmer.py @@ -0,0 +1,225 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Procedurally generated Swimmer domain.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from . import base +from . import common +from dm_control.suite.utils import randomizers +from dm_control.utils import containers +from dm_control.utils import rewards +from lxml import etree +import numpy as np +from six.moves import range + +_DEFAULT_TIME_LIMIT = 30 +_CONTROL_TIMESTEP = 0.03 # (Seconds) + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(n_joints): + """Returns a tuple containing the model XML string and a dict of assets. + + Args: + n_joints: An integer specifying the number of joints in the swimmer. + + Returns: + A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of + `{filename: contents_string}` pairs. + """ + return _make_model(n_joints), common.ASSETS + + +@SUITE.add("benchmarking") +def swimmer6(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns a 6-link swimmer.""" + return _make_swimmer( + 6, time_limit, random=random, environment_kwargs=environment_kwargs + ) + + +@SUITE.add("benchmarking") +def swimmer15(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None): + """Returns a 15-link swimmer.""" + return _make_swimmer( + 15, time_limit, random=random, environment_kwargs=environment_kwargs + ) + + +def swimmer( + n_links=3, time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns a swimmer with n links.""" + return _make_swimmer( + n_links, time_limit, random=random, environment_kwargs=environment_kwargs + ) + + +def _make_swimmer( + n_joints, time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None +): + """Returns a swimmer control environment.""" + model_string, assets = get_model_and_assets(n_joints) + physics = Physics.from_xml_string(model_string, assets=assets) + task = Swimmer(random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs + ) + + +def _make_model(n_bodies): + """Generates an xml string defining a swimmer with `n_bodies` bodies.""" + if n_bodies < 3: + raise ValueError("At least 3 bodies required. 
Received {}".format(n_bodies)) + mjcf = etree.fromstring(common.read_model("swimmer.xml")) + head_body = mjcf.find("./worldbody/body") + actuator = etree.SubElement(mjcf, "actuator") + sensor = etree.SubElement(mjcf, "sensor") + + parent = head_body + for body_index in range(n_bodies - 1): + site_name = "site_{}".format(body_index) + child = _make_body(body_index=body_index) + child.append(etree.Element("site", name=site_name)) + joint_name = "joint_{}".format(body_index) + joint_limit = 360.0 / n_bodies + joint_range = "{} {}".format(-joint_limit, joint_limit) + child.append(etree.Element("joint", {"name": joint_name, "range": joint_range})) + motor_name = "motor_{}".format(body_index) + actuator.append(etree.Element("motor", name=motor_name, joint=joint_name)) + velocimeter_name = "velocimeter_{}".format(body_index) + sensor.append( + etree.Element("velocimeter", name=velocimeter_name, site=site_name) + ) + gyro_name = "gyro_{}".format(body_index) + sensor.append(etree.Element("gyro", name=gyro_name, site=site_name)) + parent.append(child) + parent = child + + # Move tracking cameras further away from the swimmer according to its length. + cameras = mjcf.findall("./worldbody/body/camera") + scale = n_bodies / 6.0 + for cam in cameras: + if cam.get("mode") == "trackcom": + old_pos = cam.get("pos").split(" ") + new_pos = " ".join([str(float(dim) * scale) for dim in old_pos]) + cam.set("pos", new_pos) + + return etree.tostring(mjcf, pretty_print=True) + + +def _make_body(body_index): + """Generates an xml string defining a single physical body.""" + body_name = "segment_{}".format(body_index) + visual_name = "visual_{}".format(body_index) + inertial_name = "inertial_{}".format(body_index) + body = etree.Element("body", name=body_name) + body.set("pos", "0 .1 0") + etree.SubElement(body, "geom", {"class": "visual", "name": visual_name}) + etree.SubElement(body, "geom", {"class": "inertial", "name": inertial_name}) + return body + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the swimmer domain.""" + + def nose_to_target(self): + """Returns a vector from nose to target in local coordinate of the head.""" + nose_to_target = ( + self.named.data.geom_xpos["target"] - self.named.data.geom_xpos["nose"] + ) + head_orientation = self.named.data.xmat["head"].reshape(3, 3) + return nose_to_target.dot(head_orientation)[:2] + + def nose_to_target_dist(self): + """Returns the distance from the nose to the target.""" + return np.linalg.norm(self.nose_to_target()) + + def body_velocities(self): + """Returns local body velocities: x,y linear, z rotational.""" + xvel_local = self.data.sensordata[12:].reshape((-1, 6)) + vx_vy_wz = [0, 1, 5] # Indices for linear x,y vels and rotational z vel. + return xvel_local[:, vx_vy_wz].ravel() + + def joints(self): + """Returns all internal joint angles (excluding root joints).""" + return self.data.qpos[3:].copy() + + +class Swimmer(base.Task): + """A swimmer `Task` to reach the target or just swim.""" + + def __init__(self, random=None): + """Initializes an instance of `Swimmer`. + + Args: + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + super(Swimmer, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. 
+ + Initializes the swimmer orientation to [-pi, pi) and the relative joint + angle of each joint uniformly within its range. + + Args: + physics: An instance of `Physics`. + """ + # Random joint angles: + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + # Random target position. + close_target = self.random.rand() < 0.2 # Probability of a close target. + target_box = 0.3 if close_target else 2 + xpos, ypos = self.random.uniform(-target_box, target_box, size=2) + physics.named.model.geom_pos["target", "x"] = xpos + physics.named.model.geom_pos["target", "y"] = ypos + physics.named.model.light_pos["target_light", "x"] = xpos + physics.named.model.light_pos["target_light", "y"] = ypos + + super(Swimmer, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of joint angles, body velocities and target.""" + obs = collections.OrderedDict() + obs["joints"] = physics.joints() + obs["to_target"] = physics.nose_to_target() + obs["body_velocities"] = physics.body_velocities() + return obs + + def get_reward(self, physics): + """Returns a smooth reward.""" + target_size = physics.named.model.geom_size["target", 0] + return rewards.tolerance( + physics.nose_to_target_dist(), + bounds=(0, target_size), + margin=5 * target_size, + sigmoid="long_tail", + ) diff --git a/local_dm_control_suite/swimmer.xml b/local_dm_control_suite/swimmer.xml new file mode 100755 index 0000000..29c7bc8 --- /dev/null +++ b/local_dm_control_suite/swimmer.xml @@ -0,0 +1,57 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/local_dm_control_suite/tests/domains_test.py b/local_dm_control_suite/tests/domains_test.py new file mode 100755 index 0000000..615401a --- /dev/null +++ b/local_dm_control_suite/tests/domains_test.py @@ -0,0 +1,319 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.suite domains.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import suite +from dm_control.rl import control +import mock +import numpy as np +import six +from six.moves import range +from six.moves import zip + + +def uniform_random_policy(action_spec, random=None): + lower_bounds = action_spec.minimum + upper_bounds = action_spec.maximum + # Draw values between -1 and 1 for unbounded actions. + lower_bounds = np.where(np.isinf(lower_bounds), -1.0, lower_bounds) + upper_bounds = np.where(np.isinf(upper_bounds), 1.0, upper_bounds) + random_state = np.random.RandomState(random) + + def policy(time_step): + del time_step # Unused. 
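
As an aside on the procedurally generated swimmer above, the sketch below (illustrative, not part of the patch; it assumes the package is importable as `local_dm_control_suite`) builds swimmers with different link counts.

from local_dm_control_suite import swimmer

# Each constructor regenerates the MJCF model with the requested number of links.
env6 = swimmer.swimmer6()
env15 = swimmer.swimmer15()
env8 = swimmer.swimmer(n_links=8)  # any n_links >= 3 is accepted
# One motor is added per internal joint, i.e. n_links - 1 actuators.
print(env8.action_spec().shape)
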
+ return random_state.uniform(lower_bounds, upper_bounds) + + return policy + + +def step_environment(env, policy, num_episodes=5, max_steps_per_episode=10): + for _ in range(num_episodes): + step_count = 0 + time_step = env.reset() + yield time_step + while not time_step.last(): + action = policy(time_step) + time_step = env.step(action) + step_count += 1 + yield time_step + if step_count >= max_steps_per_episode: + break + + +def make_trajectory(domain, task, seed, **trajectory_kwargs): + env = suite.load(domain, task, task_kwargs={"random": seed}) + policy = uniform_random_policy(env.action_spec(), random=seed) + return step_environment(env, policy, **trajectory_kwargs) + + +class DomainTest(parameterized.TestCase): + """Tests run on all the tasks registered.""" + + def test_constants(self): + num_tasks = sum(len(tasks) for tasks in six.itervalues(suite.TASKS_BY_DOMAIN)) + + self.assertLen(suite.ALL_TASKS, num_tasks) + + def _validate_observation(self, observation_dict, observation_spec): + obs = observation_dict.copy() + for name, spec in six.iteritems(observation_spec): + arr = obs.pop(name) + self.assertEqual(arr.shape, spec.shape) + self.assertEqual(arr.dtype, spec.dtype) + self.assertTrue( + np.all(np.isfinite(arr)), + msg="{!r} has non-finite value(s): {!r}".format(name, arr), + ) + self.assertEmpty( + obs, + msg="Observation contains arrays(s) that are not in the spec: {!r}".format( + obs + ), + ) + + def _validate_reward_range(self, time_step): + if time_step.first(): + self.assertIsNone(time_step.reward) + else: + self.assertIsInstance(time_step.reward, float) + self.assertBetween(time_step.reward, 0, 1) + + def _validate_discount(self, time_step): + if time_step.first(): + self.assertIsNone(time_step.discount) + else: + self.assertIsInstance(time_step.discount, float) + self.assertBetween(time_step.discount, 0, 1) + + def _validate_control_range(self, lower_bounds, upper_bounds): + for b in lower_bounds: + self.assertEqual(b, -1.0) + for b in upper_bounds: + self.assertEqual(b, 1.0) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_components_have_names(self, domain, task): + env = suite.load(domain, task) + model = env.physics.model + + object_types_and_size_fields = [ + ("body", "nbody"), + ("joint", "njnt"), + ("geom", "ngeom"), + ("site", "nsite"), + ("camera", "ncam"), + ("light", "nlight"), + ("mesh", "nmesh"), + ("hfield", "nhfield"), + ("texture", "ntex"), + ("material", "nmat"), + ("equality", "neq"), + ("tendon", "ntendon"), + ("actuator", "nu"), + ("sensor", "nsensor"), + ("numeric", "nnumeric"), + ("text", "ntext"), + ("tuple", "ntuple"), + ] + for object_type, size_field in object_types_and_size_fields: + for idx in range(getattr(model, size_field)): + object_name = model.id2name(idx, object_type) + self.assertNotEqual( + object_name, + "", + msg="Model {!r} contains unnamed {!r} with ID {}.".format( + model.name, object_type, idx + ), + ) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_model_has_at_least_2_cameras(self, domain, task): + env = suite.load(domain, task) + model = env.physics.model + self.assertGreaterEqual( + model.ncam, + 2, + "Model {!r} should have at least 2 cameras, has {}.".format( + model.name, model.ncam + ), + ) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_task_conforms_to_spec(self, domain, task): + """Tests that the environment timesteps conform to specifications.""" + is_benchmark = (domain, task) in suite.BENCHMARKING + env = suite.load(domain, task) + observation_spec = env.observation_spec() + 
action_spec = env.action_spec() + + # Check action bounds. + if is_benchmark: + self._validate_control_range(action_spec.minimum, action_spec.maximum) + + # Step through the environment, applying random actions sampled within the + # valid range and check the observations, rewards, and discounts. + policy = uniform_random_policy(action_spec) + for time_step in step_environment(env, policy): + self._validate_observation(time_step.observation, observation_spec) + self._validate_discount(time_step) + if is_benchmark: + self._validate_reward_range(time_step) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_environment_is_deterministic(self, domain, task): + """Tests that identical seeds and actions produce identical trajectories.""" + seed = 0 + # Iterate over two trajectories generated using identical sequences of + # random actions, and with identical task random states. Check that the + # observations, rewards, discounts and step types are identical. + trajectory1 = make_trajectory(domain=domain, task=task, seed=seed) + trajectory2 = make_trajectory(domain=domain, task=task, seed=seed) + for time_step1, time_step2 in zip(trajectory1, trajectory2): + self.assertEqual(time_step1.step_type, time_step2.step_type) + self.assertEqual(time_step1.reward, time_step2.reward) + self.assertEqual(time_step1.discount, time_step2.discount) + for key in six.iterkeys(time_step1.observation): + np.testing.assert_array_equal( + time_step1.observation[key], + time_step2.observation[key], + err_msg="Observation {!r} is not equal.".format(key), + ) + + def assertCorrectColors(self, physics, reward): + colors = physics.named.model.mat_rgba + for material_name in ("self", "effector", "target"): + highlight = colors[material_name + "_highlight"] + default = colors[material_name + "_default"] + blend_coef = reward ** 4 + expected = blend_coef * highlight + (1.0 - blend_coef) * default + actual = colors[material_name] + err_msg = ( + "Material {!r} has unexpected color.\nExpected: {!r}\n" + "Actual: {!r}".format(material_name, expected, actual) + ) + np.testing.assert_array_almost_equal(expected, actual, err_msg=err_msg) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_visualize_reward(self, domain, task): + env = suite.load(domain, task) + env.task.visualize_reward = True + action = np.zeros(env.action_spec().shape) + + with mock.patch.object(env.task, "get_reward") as mock_get_reward: + mock_get_reward.return_value = -3.0 # Rewards < 0 should be clipped. + env.reset() + mock_get_reward.assert_called_with(env.physics) + self.assertCorrectColors(env.physics, reward=0.0) + + mock_get_reward.reset_mock() + mock_get_reward.return_value = 0.5 + env.step(action) + mock_get_reward.assert_called_with(env.physics) + self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value) + + mock_get_reward.reset_mock() + mock_get_reward.return_value = 2.0 # Rewards > 1 should be clipped. + env.step(action) + mock_get_reward.assert_called_with(env.physics) + self.assertCorrectColors(env.physics, reward=1.0) + + mock_get_reward.reset_mock() + mock_get_reward.return_value = 0.25 + env.reset() + mock_get_reward.assert_called_with(env.physics) + self.assertCorrectColors(env.physics, reward=mock_get_reward.return_value) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_task_supports_environment_kwargs(self, domain, task): + env = suite.load(domain, task, environment_kwargs=dict(flat_observation=True)) + # Check that the kwargs are actually passed through to the environment. 
+ self.assertSetEqual(set(env.observation_spec()), {control.FLAT_OBSERVATION_KEY}) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_observation_arrays_dont_share_memory(self, domain, task): + env = suite.load(domain, task) + first_timestep = env.reset() + action = np.zeros(env.action_spec().shape) + second_timestep = env.step(action) + for name, first_array in six.iteritems(first_timestep.observation): + second_array = second_timestep.observation[name] + self.assertFalse( + np.may_share_memory(first_array, second_array), + msg="Consecutive observations of {!r} may share memory.".format(name), + ) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_observations_dont_contain_constant_elements(self, domain, task): + env = suite.load(domain, task) + trajectory = make_trajectory( + domain=domain, task=task, seed=0, num_episodes=2, max_steps_per_episode=1000 + ) + observations = {name: [] for name in env.observation_spec()} + for time_step in trajectory: + for name, array in six.iteritems(time_step.observation): + observations[name].append(array) + + failures = [] + + for name, array_list in six.iteritems(observations): + # Sampling random uniform actions generally isn't sufficient to trigger + # these touch sensors. + if ( + domain in ("manipulator", "stacker") + and name == "touch" + or domain == "quadruped" + and name == "force_torque" + ): + continue + stacked_arrays = np.array(array_list) + is_constant = np.all(stacked_arrays == stacked_arrays[0], axis=0) + has_constant_elements = ( + is_constant if np.isscalar(is_constant) else np.any(is_constant) + ) + if has_constant_elements: + failures.append((name, is_constant)) + + self.assertEmpty( + failures, + msg="The following observation(s) contain constant elements:\n{}".format( + "\n".join( + ":\t".join([name, str(is_constant)]) + for (name, is_constant) in failures + ) + ), + ) + + @parameterized.parameters(*suite.ALL_TASKS) + def test_initial_state_is_randomized(self, domain, task): + env = suite.load(domain, task, task_kwargs={"random": 42}) + obs1 = env.reset().observation + obs2 = env.reset().observation + self.assertFalse( + all(np.all(obs1[k] == obs2[k]) for k in obs1), + "Two consecutive initial states have identical observations.\n" + "First: {}\nSecond: {}".format(obs1, obs2), + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/tests/loader_test.py b/local_dm_control_suite/tests/loader_test.py new file mode 100755 index 0000000..8570bf9 --- /dev/null +++ b/local_dm_control_suite/tests/loader_test.py @@ -0,0 +1,51 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for the dm_control.suite loader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. 
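
A small illustration of the suite-level registry these tests iterate over (not part of the patch; it assumes the upstream dm_control package is installed, whose registry the local copy mirrors).

from dm_control import suite

# ALL_TASKS and BENCHMARKING are tuples of (domain_name, task_name) pairs.
print(len(suite.ALL_TASKS))
for domain, task in suite.BENCHMARKING[:3]:
    env = suite.load(domain, task, task_kwargs={"random": 0})
    print(domain, task, env.action_spec().shape)
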
+ +from absl.testing import absltest + +from dm_control import suite +from dm_control.rl import control + + +class LoaderTest(absltest.TestCase): + def test_load_without_kwargs(self): + env = suite.load("cartpole", "swingup") + self.assertIsInstance(env, control.Environment) + + def test_load_with_kwargs(self): + env = suite.load( + "cartpole", "swingup", task_kwargs={"time_limit": 40, "random": 99} + ) + self.assertIsInstance(env, control.Environment) + + +class LoaderConstantsTest(absltest.TestCase): + def testSuiteConstants(self): + self.assertNotEmpty(suite.BENCHMARKING) + self.assertNotEmpty(suite.EASY) + self.assertNotEmpty(suite.HARD) + self.assertNotEmpty(suite.EXTRA) + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/tests/lqr_test.py b/local_dm_control_suite/tests/lqr_test.py new file mode 100755 index 0000000..7d168c4 --- /dev/null +++ b/local_dm_control_suite/tests/lqr_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests specific to the LQR domain.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import unittest + +# Internal dependencies. +from absl import logging + +from absl.testing import absltest +from absl.testing import parameterized + +from . import lqr +from . import lqr_solver + +import numpy as np +from six.moves import range + + +class LqrTest(parameterized.TestCase): + @parameterized.named_parameters(("lqr_2_1", lqr.lqr_2_1), ("lqr_6_2", lqr.lqr_6_2)) + def test_lqr_optimal_policy(self, make_env): + env = make_env() + p, k, beta = lqr_solver.solve(env) + self.assertPolicyisOptimal(env, p, k, beta) + + @parameterized.named_parameters(("lqr_2_1", lqr.lqr_2_1), ("lqr_6_2", lqr.lqr_6_2)) + @unittest.skipUnless( + condition=lqr_solver.sp, + reason="scipy is not available, so non-scipy DARE solver is the default.", + ) + def test_lqr_optimal_policy_no_scipy(self, make_env): + env = make_env() + old_sp = lqr_solver.sp + try: + lqr_solver.sp = None # Force the solver to use the non-scipy code path. 
+ p, k, beta = lqr_solver.solve(env) + finally: + lqr_solver.sp = old_sp + self.assertPolicyisOptimal(env, p, k, beta) + + def assertPolicyisOptimal(self, env, p, k, beta): + tolerance = 1e-3 + n_steps = int(math.ceil(math.log10(tolerance) / math.log10(beta))) + logging.info("%d timesteps for %g convergence.", n_steps, tolerance) + total_loss = 0.0 + + timestep = env.reset() + initial_state = np.hstack( + (timestep.observation["position"], timestep.observation["velocity"]) + ) + logging.info("Measuring total cost over %d steps.", n_steps) + for _ in range(n_steps): + x = np.hstack( + (timestep.observation["position"], timestep.observation["velocity"]) + ) + # u = k*x is the optimal policy + u = k.dot(x) + total_loss += 1 - (timestep.reward or 0.0) + timestep = env.step(u) + + logging.info("Analytical expected total cost is .5*x^T*p*x.") + expected_loss = 0.5 * initial_state.T.dot(p).dot(initial_state) + logging.info("Comparing measured and predicted costs.") + np.testing.assert_allclose(expected_loss, total_loss, rtol=tolerance) + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/utils/__init__.py b/local_dm_control_suite/utils/__init__.py new file mode 100755 index 0000000..2ea19cf --- /dev/null +++ b/local_dm_control_suite/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Utility functions used in the control suite.""" diff --git a/local_dm_control_suite/utils/parse_amc.py b/local_dm_control_suite/utils/parse_amc.py new file mode 100755 index 0000000..51c314e --- /dev/null +++ b/local_dm_control_suite/utils/parse_amc.py @@ -0,0 +1,301 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Parse and convert amc motion capture data.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control.mujoco.wrapper import mjbindings +import numpy as np +from scipy import interpolate +from six.moves import range + +mjlib = mjbindings.mjlib + +MOCAP_DT = 1.0 / 120.0 +CONVERSION_LENGTH = 0.056444 + +_CMU_MOCAP_JOINT_ORDER = ( + "root0", + "root1", + "root2", + "root3", + "root4", + "root5", + "lowerbackrx", + "lowerbackry", + "lowerbackrz", + "upperbackrx", + "upperbackry", + "upperbackrz", + "thoraxrx", + "thoraxry", + "thoraxrz", + "lowerneckrx", + "lowerneckry", + "lowerneckrz", + "upperneckrx", + "upperneckry", + "upperneckrz", + "headrx", + "headry", + "headrz", + "rclaviclery", + "rclaviclerz", + "rhumerusrx", + "rhumerusry", + "rhumerusrz", + "rradiusrx", + "rwristry", + "rhandrx", + "rhandrz", + "rfingersrx", + "rthumbrx", + "rthumbrz", + "lclaviclery", + "lclaviclerz", + "lhumerusrx", + "lhumerusry", + "lhumerusrz", + "lradiusrx", + "lwristry", + "lhandrx", + "lhandrz", + "lfingersrx", + "lthumbrx", + "lthumbrz", + "rfemurrx", + "rfemurry", + "rfemurrz", + "rtibiarx", + "rfootrx", + "rfootrz", + "rtoesrx", + "lfemurrx", + "lfemurry", + "lfemurrz", + "ltibiarx", + "lfootrx", + "lfootrz", + "ltoesrx", +) + +Converted = collections.namedtuple("Converted", ["qpos", "qvel", "time"]) + + +def convert(file_name, physics, timestep): + """Converts the parsed .amc values into qpos and qvel values and resamples. + + Args: + file_name: The .amc file to be parsed and converted. + physics: The corresponding physics instance. + timestep: Desired output interval between resampled frames. + + Returns: + A namedtuple with fields: + `qpos`, a numpy array containing converted positional variables. + `qvel`, a numpy array containing converted velocity variables. + `time`, a numpy array containing the corresponding times. + """ + frame_values = parse(file_name) + joint2index = {} + for name in physics.named.data.qpos.axes.row.names: + joint2index[name] = physics.named.data.qpos.axes.row.convert_key_item(name) + index2joint = {} + for joint, index in joint2index.items(): + if isinstance(index, slice): + indices = range(index.start, index.stop) + else: + indices = [index] + for ii in indices: + index2joint[ii] = joint + + # Convert frame_values to qpos + amcvals2qpos_transformer = Amcvals2qpos(index2joint, _CMU_MOCAP_JOINT_ORDER) + qpos_values = [] + for frame_value in frame_values: + qpos_values.append(amcvals2qpos_transformer(frame_value)) + qpos_values = np.stack(qpos_values) # Time by nq + + # Interpolate/resample. + # Note: interpolate quaternions rather than euler angles (slerp). 
+ # see https://en.wikipedia.org/wiki/Slerp + qpos_values_resampled = [] + time_vals = np.arange(0, len(frame_values) * MOCAP_DT - 1e-8, MOCAP_DT) + time_vals_new = np.arange(0, len(frame_values) * MOCAP_DT, timestep) + while time_vals_new[-1] > time_vals[-1]: + time_vals_new = time_vals_new[:-1] + + for i in range(qpos_values.shape[1]): + f = interpolate.splrep(time_vals, qpos_values[:, i]) + qpos_values_resampled.append(interpolate.splev(time_vals_new, f)) + + qpos_values_resampled = np.stack(qpos_values_resampled) # nq by ntime + + qvel_list = [] + for t in range(qpos_values_resampled.shape[1] - 1): + p_tp1 = qpos_values_resampled[:, t + 1] + p_t = qpos_values_resampled[:, t] + qvel = [ + (p_tp1[:3] - p_t[:3]) / timestep, + mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep), + (p_tp1[7:] - p_t[7:]) / timestep, + ] + qvel_list.append(np.concatenate(qvel)) + + qvel_values_resampled = np.vstack(qvel_list).T + + return Converted(qpos_values_resampled, qvel_values_resampled, time_vals_new) + + +def parse(file_name): + """Parses the amc file format.""" + values = [] + fid = open(file_name, "r") + line = fid.readline().strip() + frame_ind = 1 + first_frame = True + while True: + # Parse first frame. + if first_frame and line[0] == str(frame_ind): + first_frame = False + frame_ind += 1 + frame_vals = [] + while True: + line = fid.readline().strip() + if not line or line == str(frame_ind): + values.append(np.array(frame_vals, dtype=np.float)) + break + tokens = line.split() + frame_vals.extend(tokens[1:]) + # Parse other frames. + elif line == str(frame_ind): + frame_ind += 1 + frame_vals = [] + while True: + line = fid.readline().strip() + if not line or line == str(frame_ind): + values.append(np.array(frame_vals, dtype=np.float)) + break + tokens = line.split() + frame_vals.extend(tokens[1:]) + else: + line = fid.readline().strip() + if not line: + break + return values + + +class Amcvals2qpos(object): + """Callable that converts .amc values for a frame and to MuJoCo qpos format.""" + + def __init__(self, index2joint, joint_order): + """Initializes a new Amcvals2qpos instance. + + Args: + index2joint: List of joint angles in .amc file. + joint_order: List of joint names in MuJoco MJCF. + """ + # Root is x,y,z, then quat. + # need to get indices of qpos that order for amc default order + self.qpos_root_xyz_ind = [0, 1, 2] + self.root_xyz_ransform = ( + np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH + ) + self.qpos_root_quat_ind = [3, 4, 5, 6] + amc2qpos_transform = np.zeros((len(index2joint), len(joint_order))) + for i in range(len(index2joint)): + for j in range(len(joint_order)): + if index2joint[i] == joint_order[j]: + if "rx" in index2joint[i]: + amc2qpos_transform[i][j] = 1 + elif "ry" in index2joint[i]: + amc2qpos_transform[i][j] = 1 + elif "rz" in index2joint[i]: + amc2qpos_transform[i][j] = 1 + self.amc2qpos_transform = amc2qpos_transform + + def __call__(self, amc_val): + """Converts a `.amc` frame to MuJoCo qpos format.""" + amc_val_rad = np.deg2rad(amc_val) + qpos = np.dot(self.amc2qpos_transform, amc_val_rad) + + # Root. + qpos[:3] = np.dot(self.root_xyz_ransform, amc_val[:3]) + qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5]) + qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat) + + for i, ind in enumerate(self.qpos_root_quat_ind): + qpos[ind] = qpos_quat[i] + + return qpos + + +def euler2quat(ax, ay, az): + """Converts euler angles to a quaternion. + + Note: rotation order is zyx + + Args: + ax: Roll angle (deg) + ay: Pitch angle (deg). 
+ az: Yaw angle (deg). + + Returns: + A numpy array representing the rotation as a quaternion. + """ + r1 = az + r2 = ay + r3 = ax + + c1 = np.cos(np.deg2rad(r1 / 2)) + s1 = np.sin(np.deg2rad(r1 / 2)) + c2 = np.cos(np.deg2rad(r2 / 2)) + s2 = np.sin(np.deg2rad(r2 / 2)) + c3 = np.cos(np.deg2rad(r3 / 2)) + s3 = np.sin(np.deg2rad(r3 / 2)) + + q0 = c1 * c2 * c3 + s1 * s2 * s3 + q1 = c1 * c2 * s3 - s1 * s2 * c3 + q2 = c1 * s2 * c3 + s1 * c2 * s3 + q3 = s1 * c2 * c3 - c1 * s2 * s3 + + return np.array([q0, q1, q2, q3]) + + +def mj_quatprod(q, r): + quaternion = np.zeros(4) + mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q), np.ascontiguousarray(r)) + return quaternion + + +def mj_quat2vel(q, dt): + vel = np.zeros(3) + mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt) + return vel + + +def mj_quatneg(q): + quaternion = np.zeros(4) + mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q)) + return quaternion + + +def mj_quatdiff(source, target): + return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target)) diff --git a/local_dm_control_suite/utils/parse_amc_test.py b/local_dm_control_suite/utils/parse_amc_test.py new file mode 100755 index 0000000..4d3e6c8 --- /dev/null +++ b/local_dm_control_suite/utils/parse_amc_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for parse_amc utility.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Internal dependencies. + +from absl.testing import absltest +from . import humanoid_CMU +from dm_control.suite.utils import parse_amc + +from dm_control.utils import io as resources + +_TEST_AMC_PATH = resources.GetResourceFilename( + os.path.join(os.path.dirname(__file__), "../demos/zeros.amc") +) + + +class ParseAMCTest(absltest.TestCase): + def test_sizes_of_parsed_data(self): + + # Instantiate the humanoid environment. + env = humanoid_CMU.stand() + + # Parse and convert specified clip. 
+ converted = parse_amc.convert( + _TEST_AMC_PATH, env.physics, env.control_timestep() + ) + + self.assertEqual(converted.qpos.shape[0], 63) + self.assertEqual(converted.qvel.shape[0], 62) + self.assertEqual(converted.time.shape[0], converted.qpos.shape[1]) + self.assertEqual(converted.qpos.shape[1], converted.qvel.shape[1] + 1) + + # Parse and convert specified clip -- WITH SMALLER TIMESTEP + converted2 = parse_amc.convert( + _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep() + ) + + self.assertEqual(converted2.qpos.shape[0], 63) + self.assertEqual(converted2.qvel.shape[0], 62) + self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1]) + self.assertEqual(converted.qpos.shape[1], converted.qvel.shape[1] + 1) + + # Compare sizes of parsed objects for different timesteps + self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1]) + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/utils/randomizers.py b/local_dm_control_suite/utils/randomizers.py new file mode 100755 index 0000000..2e472d2 --- /dev/null +++ b/local_dm_control_suite/utils/randomizers.py @@ -0,0 +1,90 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Randomization functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control.mujoco.wrapper import mjbindings +import numpy as np +from six.moves import range + + +def random_limited_quaternion(random, limit): + """Generates a random quaternion limited to the specified rotations.""" + axis = random.randn(3) + axis /= np.linalg.norm(axis) + angle = random.rand() * limit + + quaternion = np.zeros(4) + mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle) + + return quaternion + + +def randomize_limited_and_rotational_joints(physics, random=None): + """Randomizes the positions of joints defined in the physics body. + + The following randomization rules apply: + - Bounded joints (hinges or sliders) are sampled uniformly in the bounds. + - Unbounded hinges are samples uniformly in [-pi, pi] + - Quaternions for unlimited free joints and ball joints are sampled + uniformly on the unit 3-sphere. + - Quaternions for limited ball joints are sampled uniformly on a sector + of the unit 3-sphere. + - The linear degrees of freedom of free joints are not randomized. + + Args: + physics: Instance of 'Physics' class that holds a loaded model. + random: Optional instance of 'np.random.RandomState'. Defaults to the global + NumPy random state. 
+ """ + random = random or np.random + + hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE + slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE + ball = mjbindings.enums.mjtJoint.mjJNT_BALL + free = mjbindings.enums.mjtJoint.mjJNT_FREE + + qpos = physics.named.data.qpos + + for joint_id in range(physics.model.njnt): + joint_name = physics.model.id2name(joint_id, "joint") + joint_type = physics.model.jnt_type[joint_id] + is_limited = physics.model.jnt_limited[joint_id] + range_min, range_max = physics.model.jnt_range[joint_id] + + if is_limited: + if joint_type == hinge or joint_type == slide: + qpos[joint_name] = random.uniform(range_min, range_max) + + elif joint_type == ball: + qpos[joint_name] = random_limited_quaternion(random, range_max) + + else: + if joint_type == hinge: + qpos[joint_name] = random.uniform(-np.pi, np.pi) + + elif joint_type == ball: + quat = random.randn(4) + quat /= np.linalg.norm(quat) + qpos[joint_name] = quat + + elif joint_type == free: + quat = random.rand(4) + quat /= np.linalg.norm(quat) + qpos[joint_name][3:] = quat diff --git a/local_dm_control_suite/utils/randomizers_test.py b/local_dm_control_suite/utils/randomizers_test.py new file mode 100755 index 0000000..fcfb0dd --- /dev/null +++ b/local_dm_control_suite/utils/randomizers_test.py @@ -0,0 +1,177 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for randomizers.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import mujoco +from dm_control.mujoco.wrapper import mjbindings +from dm_control.suite.utils import randomizers +import numpy as np +from six.moves import range + +mjlib = mjbindings.mjlib + + +class RandomizeUnlimitedJointsTest(parameterized.TestCase): + def setUp(self): + self.rand = np.random.RandomState(100) + + def test_single_joint_of_each_type(self): + physics = mujoco.Physics.from_xml_string( + """ + + + + + + + + + + + + + + + + + + + + + + + + + """ + ) + + randomizers.randomize_limited_and_rotational_joints(physics, self.rand) + self.assertNotEqual(0.0, physics.named.data.qpos["hinge"]) + self.assertNotEqual(0.0, physics.named.data.qpos["limited_hinge"]) + self.assertNotEqual(0.0, physics.named.data.qpos["limited_slide"]) + + self.assertNotEqual(0.0, np.sum(physics.named.data.qpos["ball"])) + self.assertNotEqual(0.0, np.sum(physics.named.data.qpos["limited_ball"])) + + self.assertNotEqual(0.0, np.sum(physics.named.data.qpos["free"][3:])) + + # Unlimited slide and the positional part of the free joint remains + # uninitialized. 
+ self.assertEqual(0.0, physics.named.data.qpos["slide"]) + self.assertEqual(0.0, np.sum(physics.named.data.qpos["free"][:3])) + + def test_multiple_joints_of_same_type(self): + physics = mujoco.Physics.from_xml_string( + """ + + + + + + + + + """ + ) + + randomizers.randomize_limited_and_rotational_joints(physics, self.rand) + self.assertNotEqual(0.0, physics.named.data.qpos["hinge_1"]) + self.assertNotEqual(0.0, physics.named.data.qpos["hinge_2"]) + self.assertNotEqual(0.0, physics.named.data.qpos["hinge_3"]) + + self.assertNotEqual( + physics.named.data.qpos["hinge_1"], physics.named.data.qpos["hinge_2"] + ) + + self.assertNotEqual( + physics.named.data.qpos["hinge_2"], physics.named.data.qpos["hinge_3"] + ) + + self.assertNotEqual( + physics.named.data.qpos["hinge_1"], physics.named.data.qpos["hinge_3"] + ) + + def test_unlimited_hinge_randomization_range(self): + physics = mujoco.Physics.from_xml_string( + """ + + + + + + + """ + ) + + for _ in range(10): + randomizers.randomize_limited_and_rotational_joints(physics, self.rand) + self.assertBetween(physics.named.data.qpos["hinge"], -np.pi, np.pi) + + def test_limited_1d_joint_limits_are_respected(self): + physics = mujoco.Physics.from_xml_string( + """ + + + + + + + + + + + """ + ) + + for _ in range(10): + randomizers.randomize_limited_and_rotational_joints(physics, self.rand) + self.assertBetween( + physics.named.data.qpos["hinge"], np.deg2rad(0), np.deg2rad(10) + ) + self.assertBetween(physics.named.data.qpos["slide"], 30, 50) + + def test_limited_ball_joint_are_respected(self): + physics = mujoco.Physics.from_xml_string( + """ + + + + + + + """ + ) + + body_axis = np.array([1.0, 0.0, 0.0]) + joint_axis = np.zeros(3) + for _ in range(10): + randomizers.randomize_limited_and_rotational_joints(physics, self.rand) + + quat = physics.named.data.qpos["ball"] + mjlib.mju_rotVecQuat(joint_axis, body_axis, quat) + angle_cos = np.dot(body_axis, joint_axis) + self.assertGreater(angle_cos, 0.5) # cos(60) = 0.5 + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/walker.py b/local_dm_control_suite/walker.py new file mode 100755 index 0000000..b9f3199 --- /dev/null +++ b/local_dm_control_suite/walker.py @@ -0,0 +1,190 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Planar Walker Domain.""" + +from __future__ import absolute_import, division, print_function + +import collections + +from dm_control import mujoco +from dm_control.rl import control +from dm_control.suite.utils import randomizers +from dm_control.utils import containers, rewards + +from . import base, common + +_DEFAULT_TIME_LIMIT = 25 +_CONTROL_TIMESTEP = 0.025 + +# Minimal height of torso over foot above which stand reward is 1. +_STAND_HEIGHT = 1.2 + +# Horizontal speeds (meters/second) above which move reward is 1. 
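For orientation, the sketch below shows how the walker tasks defined later in this file (`stand`, `walk`, `run`) are typically constructed against one of the model variants added by this patch and stepped with a random policy. It assumes the local suite is importable as `local_dm_control_suite`; the exact package name depends on how the repository is installed.

    import numpy as np
    from local_dm_control_suite import walker

    env = walker.walk(xml_file_id="friction_3")  # loads walker_friction_3.xml
    action_spec = env.action_spec()
    time_step = env.reset()
    while not time_step.last():
        action = np.random.uniform(
            action_spec.minimum, action_spec.maximum, size=action_spec.shape)
        time_step = env.step(action)

The walk and run speed thresholds that follow feed the tolerance-based shaping used in `PlanarWalker.get_reward` below.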
+_WALK_SPEED = 1 +_RUN_SPEED = 8 + + +SUITE = containers.TaggedTasks() + + +def get_model_and_assets(xml_file_id): + """Returns a tuple containing the model XML string and a dict of assets.""" + if xml_file_id is not None: + filename = f"walker_{xml_file_id}.xml" + print(filename) + else: + filename = f"walker.xml" + return common.read_model(filename), common.ASSETS + + +@SUITE.add("benchmarking") +def stand( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the Stand task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id)) + task = PlanarWalker(move_speed=0, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +@SUITE.add("benchmarking") +def walk( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the Walk task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id)) + task = PlanarWalker(move_speed=_WALK_SPEED, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +@SUITE.add("benchmarking") +def run( + time_limit=_DEFAULT_TIME_LIMIT, + xml_file_id=None, + random=None, + environment_kwargs=None, +): + """Returns the Run task.""" + physics = Physics.from_xml_string(*get_model_and_assets(xml_file_id)) + task = PlanarWalker(move_speed=_RUN_SPEED, random=random) + environment_kwargs = environment_kwargs or {} + return control.Environment( + physics, + task, + time_limit=time_limit, + control_timestep=_CONTROL_TIMESTEP, + **environment_kwargs, + ) + + +class Physics(mujoco.Physics): + """Physics simulation with additional features for the Walker domain.""" + + def torso_upright(self): + """Returns projection from z-axes of torso to the z-axes of world.""" + return self.named.data.xmat["torso", "zz"] + + def torso_height(self): + """Returns the height of the torso.""" + return self.named.data.xpos["torso", "z"] + + def horizontal_velocity(self): + """Returns the horizontal velocity of the center-of-mass.""" + return self.named.data.sensordata["torso_subtreelinvel"][0] + + def orientations(self): + """Returns planar orientations of all bodies.""" + return self.named.data.xmat[1:, ["xx", "xz"]].ravel() + + +class PlanarWalker(base.Task): + """A planar walker task.""" + + def __init__(self, move_speed, random=None): + """Initializes an instance of `PlanarWalker`. + + Args: + move_speed: A float. If this value is zero, reward is given simply for + standing up. Otherwise this specifies a target horizontal velocity for + the walking task. + random: Optional, either a `numpy.random.RandomState` instance, an + integer seed for creating a new `RandomState`, or None to select a seed + automatically (default). + """ + self._move_speed = move_speed + super(PlanarWalker, self).__init__(random=random) + + def initialize_episode(self, physics): + """Sets the state of the environment at the start of each episode. + + In 'standing' mode, use initial orientation and small velocities. + In 'random' mode, randomize joint angles and let fall to the floor. + + Args: + physics: An instance of `Physics`. 
+ + """ + randomizers.randomize_limited_and_rotational_joints(physics, self.random) + super(PlanarWalker, self).initialize_episode(physics) + + def get_observation(self, physics): + """Returns an observation of body orientations, height and velocites.""" + obs = collections.OrderedDict() + obs["orientations"] = physics.orientations() + obs["height"] = physics.torso_height() + obs["velocity"] = physics.velocity() + return obs + + def get_reward(self, physics): + """Returns a reward to the agent.""" + standing = rewards.tolerance( + physics.torso_height(), + bounds=(_STAND_HEIGHT, float("inf")), + margin=_STAND_HEIGHT / 2, + ) + upright = (1 + physics.torso_upright()) / 2 + stand_reward = (3 * standing + upright) / 4 + if self._move_speed == 0: + return stand_reward + else: + move_reward = rewards.tolerance( + physics.horizontal_velocity(), + bounds=(self._move_speed, float("inf")), + margin=self._move_speed / 2, + value_at_margin=0.5, + sigmoid="linear", + ) + return stand_reward * (5 * move_reward + 1) / 6 diff --git a/local_dm_control_suite/walker.xml b/local_dm_control_suite/walker.xml new file mode 100755 index 0000000..9509893 --- /dev/null +++ b/local_dm_control_suite/walker.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_1.xml b/local_dm_control_suite/walker_friction_1.xml new file mode 100755 index 0000000..9509893 --- /dev/null +++ b/local_dm_control_suite/walker_friction_1.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_10.xml b/local_dm_control_suite/walker_friction_10.xml new file mode 100755 index 0000000..659d8b7 --- /dev/null +++ b/local_dm_control_suite/walker_friction_10.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_2.xml b/local_dm_control_suite/walker_friction_2.xml new file mode 100755 index 0000000..51dedbb --- /dev/null +++ b/local_dm_control_suite/walker_friction_2.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_3.xml b/local_dm_control_suite/walker_friction_3.xml new file mode 100755 index 0000000..c6b32b4 --- /dev/null +++ b/local_dm_control_suite/walker_friction_3.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_4.xml b/local_dm_control_suite/walker_friction_4.xml new file mode 100755 index 0000000..bd3be18 --- /dev/null +++ b/local_dm_control_suite/walker_friction_4.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_5.xml b/local_dm_control_suite/walker_friction_5.xml new file mode 100755 index 0000000..ba799b9 --- /dev/null +++ b/local_dm_control_suite/walker_friction_5.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_6.xml b/local_dm_control_suite/walker_friction_6.xml new file mode 100755 index 0000000..4b4e739 --- /dev/null +++ b/local_dm_control_suite/walker_friction_6.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_7.xml b/local_dm_control_suite/walker_friction_7.xml new file mode 100755 index 0000000..5108ca8 --- /dev/null +++ b/local_dm_control_suite/walker_friction_7.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_8.xml b/local_dm_control_suite/walker_friction_8.xml new file mode 100755 index 0000000..86b171d --- /dev/null +++ b/local_dm_control_suite/walker_friction_8.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_friction_9.xml b/local_dm_control_suite/walker_friction_9.xml new file mode 100755 index 
0000000..94f1651 --- /dev/null +++ b/local_dm_control_suite/walker_friction_9.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_1.xml b/local_dm_control_suite/walker_len_1.xml new file mode 100755 index 0000000..7704f34 --- /dev/null +++ b/local_dm_control_suite/walker_len_1.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_10.xml b/local_dm_control_suite/walker_len_10.xml new file mode 100755 index 0000000..1c97362 --- /dev/null +++ b/local_dm_control_suite/walker_len_10.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_2.xml b/local_dm_control_suite/walker_len_2.xml new file mode 100755 index 0000000..8ad7359 --- /dev/null +++ b/local_dm_control_suite/walker_len_2.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_3.xml b/local_dm_control_suite/walker_len_3.xml new file mode 100755 index 0000000..9509893 --- /dev/null +++ b/local_dm_control_suite/walker_len_3.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_4.xml b/local_dm_control_suite/walker_len_4.xml new file mode 100755 index 0000000..660a30a --- /dev/null +++ b/local_dm_control_suite/walker_len_4.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_5.xml b/local_dm_control_suite/walker_len_5.xml new file mode 100755 index 0000000..faca117 --- /dev/null +++ b/local_dm_control_suite/walker_len_5.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_6.xml b/local_dm_control_suite/walker_len_6.xml new file mode 100755 index 0000000..273892d --- /dev/null +++ b/local_dm_control_suite/walker_len_6.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_7.xml b/local_dm_control_suite/walker_len_7.xml new file mode 100755 index 0000000..cd98182 --- /dev/null +++ b/local_dm_control_suite/walker_len_7.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_8.xml b/local_dm_control_suite/walker_len_8.xml new file mode 100755 index 0000000..44e82bd --- /dev/null +++ b/local_dm_control_suite/walker_len_8.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/walker_len_9.xml b/local_dm_control_suite/walker_len_9.xml new file mode 100755 index 0000000..3de95b7 --- /dev/null +++ b/local_dm_control_suite/walker_len_9.xml @@ -0,0 +1,70 @@ + + + + + + diff --git a/local_dm_control_suite/wrappers/__init__.py b/local_dm_control_suite/wrappers/__init__.py new file mode 100755 index 0000000..f7e4a68 --- /dev/null +++ b/local_dm_control_suite/wrappers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Environment wrappers used to extend or modify environment behaviour.""" diff --git a/local_dm_control_suite/wrappers/action_noise.py b/local_dm_control_suite/wrappers/action_noise.py new file mode 100755 index 0000000..8799bae --- /dev/null +++ b/local_dm_control_suite/wrappers/action_noise.py @@ -0,0 +1,77 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Wrapper control suite environments that adds Gaussian noise to actions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import dm_env +import numpy as np + + +_BOUNDS_MUST_BE_FINITE = ( + "All bounds in `env.action_spec()` must be finite, got: {action_spec}" +) + + +class Wrapper(dm_env.Environment): + """Wraps a control environment and adds Gaussian noise to actions.""" + + def __init__(self, env, scale=0.01): + """Initializes a new action noise Wrapper. + + Args: + env: The control suite environment to wrap. + scale: The standard deviation of the noise, expressed as a fraction + of the max-min range for each action dimension. + + Raises: + ValueError: If any of the action dimensions of the wrapped environment are + unbounded. + """ + action_spec = env.action_spec() + if not ( + np.all(np.isfinite(action_spec.minimum)) + and np.all(np.isfinite(action_spec.maximum)) + ): + raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)) + self._minimum = action_spec.minimum + self._maximum = action_spec.maximum + self._noise_std = scale * (action_spec.maximum - action_spec.minimum) + self._env = env + + def step(self, action): + noisy_action = action + self._env.task.random.normal(scale=self._noise_std) + # Clip the noisy actions in place so that they fall within the bounds + # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of- + # bounds control inputs, but we also clip here in case the actions do not + # correspond directly to MuJoCo actuators, or if there are other wrapper + # layers that expect the actions to be within bounds. + np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action) + return self._env.step(noisy_action) + + def reset(self): + return self._env.reset() + + def observation_spec(self): + return self._env.observation_spec() + + def action_spec(self): + return self._env.action_spec() + + def __getattr__(self, name): + return getattr(self._env, name) diff --git a/local_dm_control_suite/wrappers/action_noise_test.py b/local_dm_control_suite/wrappers/action_noise_test.py new file mode 100755 index 0000000..f2de983 --- /dev/null +++ b/local_dm_control_suite/wrappers/action_noise_test.py @@ -0,0 +1,143 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for the action noise wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.rl import control +from dm_control.suite.wrappers import action_noise +from dm_env import specs +import mock +import numpy as np + + +class ActionNoiseTest(parameterized.TestCase): + def make_action_spec(self, lower=(-1.0,), upper=(1.0,)): + lower, upper = np.broadcast_arrays(lower, upper) + return specs.BoundedArray( + shape=lower.shape, dtype=float, minimum=lower, maximum=upper + ) + + def make_mock_env(self, action_spec=None): + action_spec = action_spec or self.make_action_spec() + env = mock.Mock(spec=control.Environment) + env.action_spec.return_value = action_spec + return env + + def assertStepCalledOnceWithCorrectAction(self, env, expected_action): + # NB: `assert_called_once_with()` doesn't support numpy arrays. + env.step.assert_called_once() + actual_action = env.step.call_args_list[0][0][0] + np.testing.assert_array_equal(expected_action, actual_action) + + @parameterized.parameters( + [ + dict(lower=np.r_[-1.0, 0.0], upper=np.r_[1.0, 2.0], scale=0.05), + dict(lower=np.r_[-1.0, 0.0], upper=np.r_[1.0, 2.0], scale=0.0), + dict(lower=np.r_[-1.0, 0.0], upper=np.r_[-1.0, 0.0], scale=0.05), + ] + ) + def test_step(self, lower, upper, scale): + seed = 0 + std = scale * (upper - lower) + expected_noise = np.random.RandomState(seed).normal(scale=std) + action = np.random.RandomState(seed).uniform(lower, upper) + expected_noisy_action = np.clip(action + expected_noise, lower, upper) + task = mock.Mock(spec=control.Task) + task.random = np.random.RandomState(seed) + action_spec = self.make_action_spec(lower=lower, upper=upper) + env = self.make_mock_env(action_spec=action_spec) + env.task = task + wrapped_env = action_noise.Wrapper(env, scale=scale) + time_step = wrapped_env.step(action) + self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action) + self.assertIs(time_step, env.step(expected_noisy_action)) + + @parameterized.named_parameters( + [ + dict(testcase_name="within_bounds", action=np.r_[-1.0], noise=np.r_[0.1]), + dict(testcase_name="below_lower", action=np.r_[-1.0], noise=np.r_[-0.1]), + dict(testcase_name="above_upper", action=np.r_[1.0], noise=np.r_[0.1]), + ] + ) + def test_action_clipping(self, action, noise): + lower = -1.0 + upper = 1.0 + expected_noisy_action = np.clip(action + noise, lower, upper) + task = mock.Mock(spec=control.Task) + task.random = mock.Mock(spec=np.random.RandomState) + task.random.normal.return_value = noise + action_spec = self.make_action_spec(lower=lower, upper=upper) + env = self.make_mock_env(action_spec=action_spec) + env.task = task + wrapped_env = action_noise.Wrapper(env) + time_step = wrapped_env.step(action) + self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action) + self.assertIs(time_step, env.step(expected_noisy_action)) + + 
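Outside of these tests, wrapping a control-suite environment with the action-noise wrapper from `action_noise.py` looks like the sketch below. It assumes the local suite is importable as `local_dm_control_suite` and relies on the wrapped environment having finite, bounded action bounds (otherwise the wrapper raises `ValueError`).

    from local_dm_control_suite import cartpole
    from local_dm_control_suite.wrappers import action_noise

    env = action_noise.Wrapper(cartpole.swingup(), scale=0.01)
    action_spec = env.action_spec()   # identical to the unwrapped spec
    time_step = env.reset()
    # Gaussian noise with std = scale * (max - min) is added, then clipped to bounds.
    time_step = env.step(action_spec.minimum)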
@parameterized.parameters( + [ + dict(lower=np.r_[-1.0, 0.0], upper=np.r_[1.0, np.inf]), + dict(lower=np.r_[np.nan, 0.0], upper=np.r_[1.0, 2.0]), + ] + ) + def test_error_if_action_bounds_non_finite(self, lower, upper): + action_spec = self.make_action_spec(lower=lower, upper=upper) + env = self.make_mock_env(action_spec=action_spec) + with self.assertRaisesWithLiteralMatch( + ValueError, + action_noise._BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec), + ): + _ = action_noise.Wrapper(env) + + def test_reset(self): + env = self.make_mock_env() + wrapped_env = action_noise.Wrapper(env) + time_step = wrapped_env.reset() + env.reset.assert_called_once_with() + self.assertIs(time_step, env.reset()) + + def test_observation_spec(self): + env = self.make_mock_env() + wrapped_env = action_noise.Wrapper(env) + observation_spec = wrapped_env.observation_spec() + env.observation_spec.assert_called_once_with() + self.assertIs(observation_spec, env.observation_spec()) + + def test_action_spec(self): + env = self.make_mock_env() + wrapped_env = action_noise.Wrapper(env) + # `env.action_spec()` is called in `Wrapper.__init__()` + env.action_spec.reset_mock() + action_spec = wrapped_env.action_spec() + env.action_spec.assert_called_once_with() + self.assertIs(action_spec, env.action_spec()) + + @parameterized.parameters(["task", "physics", "control_timestep"]) + def test_getattr(self, attribute_name): + env = self.make_mock_env() + wrapped_env = action_noise.Wrapper(env) + attr = getattr(wrapped_env, attribute_name) + self.assertIs(attr, getattr(env, attribute_name)) + + +if __name__ == "__main__": + absltest.main() diff --git a/local_dm_control_suite/wrappers/pixels.py b/local_dm_control_suite/wrappers/pixels.py new file mode 100755 index 0000000..6261958 --- /dev/null +++ b/local_dm_control_suite/wrappers/pixels.py @@ -0,0 +1,123 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Wrapper that adds pixel observations to a control environment.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import dm_env +from dm_env import specs + +STATE_KEY = "state" + + +class Wrapper(dm_env.Environment): + """Wraps a control environment and adds a rendered pixel observation.""" + + def __init__( + self, env, pixels_only=True, render_kwargs=None, observation_key="pixels" + ): + """Initializes a new pixel Wrapper. + + Args: + env: The environment to wrap. + pixels_only: If True (default), the original set of 'state' observations + returned by the wrapped environment will be discarded, and the + `OrderedDict` of observations will only contain pixels. If False, the + `OrderedDict` will contain the original observations as well as the + pixel observations. + render_kwargs: Optional `dict` containing keyword arguments passed to the + `mujoco.Physics.render` method. 
+ observation_key: Optional custom string specifying the pixel observation's + key in the `OrderedDict` of observations. Defaults to 'pixels'. + + Raises: + ValueError: If `env`'s observation spec is not compatible with the + wrapper. Supported formats are a single array, or a dict of arrays. + ValueError: If `env`'s observation already contains the specified + `observation_key`. + """ + if render_kwargs is None: + render_kwargs = {} + + wrapped_observation_spec = env.observation_spec() + + if isinstance(wrapped_observation_spec, specs.Array): + self._observation_is_dict = False + invalid_keys = set([STATE_KEY]) + elif isinstance(wrapped_observation_spec, collections.MutableMapping): + self._observation_is_dict = True + invalid_keys = set(wrapped_observation_spec.keys()) + else: + raise ValueError("Unsupported observation spec structure.") + + if not pixels_only and observation_key in invalid_keys: + raise ValueError( + "Duplicate or reserved observation key {!r}.".format(observation_key) + ) + + if pixels_only: + self._observation_spec = collections.OrderedDict() + elif self._observation_is_dict: + self._observation_spec = wrapped_observation_spec.copy() + else: + self._observation_spec = collections.OrderedDict() + self._observation_spec[STATE_KEY] = wrapped_observation_spec + + # Extend observation spec. + pixels = env.physics.render(**render_kwargs) + pixels_spec = specs.Array( + shape=pixels.shape, dtype=pixels.dtype, name=observation_key + ) + self._observation_spec[observation_key] = pixels_spec + + self._env = env + self._pixels_only = pixels_only + self._render_kwargs = render_kwargs + self._observation_key = observation_key + + def reset(self): + time_step = self._env.reset() + return self._add_pixel_observation(time_step) + + def step(self, action): + time_step = self._env.step(action) + return self._add_pixel_observation(time_step) + + def observation_spec(self): + return self._observation_spec + + def action_spec(self): + return self._env.action_spec() + + def _add_pixel_observation(self, time_step): + if self._pixels_only: + observation = collections.OrderedDict() + elif self._observation_is_dict: + observation = type(time_step.observation)(time_step.observation) + else: + observation = collections.OrderedDict() + observation[STATE_KEY] = time_step.observation + + pixels = self._env.physics.render(**self._render_kwargs) + observation[self._observation_key] = pixels + return time_step._replace(observation=observation) + + def __getattr__(self, name): + return getattr(self._env, name) diff --git a/local_dm_control_suite/wrappers/pixels_test.py b/local_dm_control_suite/wrappers/pixels_test.py new file mode 100755 index 0000000..473ad81 --- /dev/null +++ b/local_dm_control_suite/wrappers/pixels_test.py @@ -0,0 +1,135 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Tests for the pixel wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from . import cartpole +from dm_control.suite.wrappers import pixels +import dm_env +from dm_env import specs + +import numpy as np + + +class FakePhysics(object): + def render(self, *args, **kwargs): + del args + del kwargs + return np.zeros((4, 5, 3), dtype=np.uint8) + + +class FakeArrayObservationEnvironment(dm_env.Environment): + def __init__(self): + self.physics = FakePhysics() + + def reset(self): + return dm_env.restart(np.zeros((2,))) + + def step(self, action): + del action + return dm_env.transition(0.0, np.zeros((2,))) + + def action_spec(self): + pass + + def observation_spec(self): + return specs.Array(shape=(2,), dtype=np.float) + + +class PixelsTest(parameterized.TestCase): + @parameterized.parameters(True, False) + def test_dict_observation(self, pixels_only): + pixel_key = "rgb" + + env = cartpole.swingup() + + # Make sure we are testing the right environment for the test. + observation_spec = env.observation_spec() + self.assertIsInstance(observation_spec, collections.OrderedDict) + + width = 320 + height = 240 + + # The wrapper should only add one observation. + wrapped = pixels.Wrapper( + env, + observation_key=pixel_key, + pixels_only=pixels_only, + render_kwargs={"width": width, "height": height}, + ) + + wrapped_observation_spec = wrapped.observation_spec() + self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) + + if pixels_only: + self.assertLen(wrapped_observation_spec, 1) + self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) + else: + expected_length = len(observation_spec) + 1 + self.assertLen(wrapped_observation_spec, expected_length) + expected_keys = list(observation_spec.keys()) + [pixel_key] + self.assertEqual(expected_keys, list(wrapped_observation_spec.keys())) + + # Check that the added spec item is consistent with the added observation. 
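Outside of these tests, the pixel wrapper is used roughly as in the sketch below, assuming the local suite is importable as `local_dm_control_suite` and that MuJoCo rendering is available.

    from local_dm_control_suite import cartpole
    from local_dm_control_suite.wrappers import pixels

    env = pixels.Wrapper(
        cartpole.swingup(),
        pixels_only=False,  # keep the original state observations as well
        render_kwargs={"width": 84, "height": 84, "camera_id": 0},
    )
    time_step = env.reset()
    frame = time_step.observation["pixels"]   # uint8 array of shape (84, 84, 3)
    state = {k: v for k, v in time_step.observation.items() if k != "pixels"}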
+ time_step = wrapped.reset() + rgb_observation = time_step.observation[pixel_key] + wrapped_observation_spec[pixel_key].validate(rgb_observation) + + self.assertEqual(rgb_observation.shape, (height, width, 3)) + self.assertEqual(rgb_observation.dtype, np.uint8) + + @parameterized.parameters(True, False) + def test_single_array_observation(self, pixels_only): + pixel_key = "depth" + + env = FakeArrayObservationEnvironment() + observation_spec = env.observation_spec() + self.assertIsInstance(observation_spec, specs.Array) + + wrapped = pixels.Wrapper( + env, observation_key=pixel_key, pixels_only=pixels_only + ) + wrapped_observation_spec = wrapped.observation_spec() + self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) + + if pixels_only: + self.assertLen(wrapped_observation_spec, 1) + self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) + else: + self.assertLen(wrapped_observation_spec, 2) + self.assertEqual( + [pixels.STATE_KEY, pixel_key], list(wrapped_observation_spec.keys()) + ) + + time_step = wrapped.reset() + + depth_observation = time_step.observation[pixel_key] + wrapped_observation_spec[pixel_key].validate(depth_observation) + + self.assertEqual(depth_observation.shape, (4, 5, 3)) + self.assertEqual(depth_observation.dtype, np.uint8) + + +if __name__ == "__main__": + absltest.main() diff --git a/mtenv/__init__.py b/mtenv/__init__.py new file mode 100644 index 0000000..7a6a7e8 --- /dev/null +++ b/mtenv/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +__version__ = "1.0" + +from mtenv.core import MTEnv # noqa: F401 +from mtenv.envs.registration import make # noqa: F401 + +__all__ = ["MTEnv", "make"] diff --git a/mtenv/core.py b/mtenv/core.py new file mode 100644 index 0000000..b17285f --- /dev/null +++ b/mtenv/core.py @@ -0,0 +1,212 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +"""Core API of MultiTask Environments for Reinforcement Learning.""" +from abc import ABC, abstractmethod +from typing import List, Optional + +from gym.core import Env +from gym.spaces.dict import Dict as DictSpace +from gym.spaces.space import Space +from numpy.random import RandomState + +from mtenv.utils import seeding +from mtenv.utils.types import ( + ActionType, + ObsType, + StepReturnType, + TaskObsType, + TaskStateType, +) + + +class MTEnv(Env, ABC): # type: ignore[misc] + def __init__( + self, + action_space: Space, + env_observation_space: Space, + task_observation_space: Space, + ) -> None: + """Main class for multitask RL Environments. + + This abstract class extends the OpenAI Gym environment and adds + support for return the task-specific information from the environment. + The observation returned from the single task environments is + encoded as `env_obs` (environment observation) while the task + specific observation is encoded as the `task_obs` (task observation). + The observation returned by `mtenv` is a dictionary of `env_obs` and + `task_obs`. Since this class extends the OpenAI gym, the `mtenv` + API looks similar to the gym API. + + .. code-block:: python + + import mtenv + env = mtenv.make('xxx') + env.reset() + + Any multitask RL environment class should extend/implement this class. 
+ + Args: + action_space (Space) + env_observation_space (Space) + task_observation_space (Space) + """ + + self.action_space = action_space + self.observation_space: DictSpace = DictSpace( + spaces={ + "env_obs": env_observation_space, + "task_obs": task_observation_space, + } + ) + + self.np_random_env: Optional[RandomState] = None + self.np_random_task: Optional[RandomState] = None + + self._task_obs: TaskObsType + + @abstractmethod + def step(self, action: ActionType) -> StepReturnType: + """Execute the action in the environment. + + Args: + action (ActionType) + + Returns: + StepReturnType: Tuple of `multitask observation`, `reward`, + `done`, and `info`. For more information on `multitask observation` + returned by the environment, refer :ref:`multitask_observation`. + """ + pass + + def get_task_obs(self) -> TaskObsType: + """Get the current value of task observation. + + Environment returns task observation everytime we call `step` or + `reset`. This function is useful when the user wants to access the + task observation without acting in (or resetting) the environment. + + Returns: + TaskObsType: + """ + return self._task_obs + + @abstractmethod + def get_task_state(self) -> TaskStateType: + """Return all the information needed to execute the current task + again. + + This function is useful when we want to set the environment to a + previous task. + + Returns: + TaskStateType: For more information on `task_state`, refer :ref:`task_state`. + """ + pass + + @abstractmethod + def set_task_state(self, task_state: TaskStateType) -> None: + """Reset the environment to a particular task. + + `task_state` contains all the information that the environment + needs to switch to any other task. + + Args: + task_state (TaskStateType): For more information on `task_state`, + refer :ref:`task_state`. + """ + pass + + def assert_env_seed_is_set(self) -> None: + """Check that seed (for the environment) is set. + + `reset` function should invoke this function before resetting the + environment (for reproducibility). + + """ + assert self.np_random_env is not None, "please call `seed()` first" + + def assert_task_seed_is_set(self) -> None: + """Check that seed (for the task) is set. + + `sample_task_state` function should invoke this function before + sampling a new task state (for reproducibility). + + """ + assert self.np_random_task is not None, "please call `seed_task()` first" + + @abstractmethod + def reset(self) -> ObsType: + """Reset the environment to some initial state and return the + observation in the new state. + + The subclasses, extending this class, should ensure that the + environment seed is set (by calling `seed(int)`) before invoking this + method (for reproducibility). It can be done by invoking + `self.assert_env_seed_is_set()`. + + Returns: + ObsType: For more information on `multitask observation` + returned by the environment, refer :ref:`multitask_observation`. + """ + pass + + @abstractmethod + def sample_task_state(self) -> TaskStateType: + """Sample a `task_state`. + + `task_state` contains all the information that the environment + needs to switch to any other task. + + The subclasses, extending this class, should ensure that the task + seed is set (by calling `seed(int)`) before invoking this + method (for reproducibility). It can be done by invoking + `self.assert_task_seed_is_set()`. + + Returns: + TaskStateType: For more information on `task_state`, + refer :ref:`task_state`. 
+ """ + pass + + def reset_task_state(self) -> None: + """Sample a new task_state and set the environment to that `task_state`. + + For more information on `task_state`, refer :ref:`task_state`. + """ + self.set_task_state(task_state=self.sample_task_state()) + + def seed(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for the environment's random number generator. + + Invoke `seed_task` to set the seed for the task's + random number generator. + + Args: + seed (Optional[int], optional): Defaults to None. + + Returns: + List[int]: Returns the list of seeds used in the environment's + random number generator. The first value in the list should be + the seed that should be passed to this method for reproducibility. + """ + self.np_random_env, seed = seeding.np_random(seed) + assert isinstance(seed, int) + return [seed] + + def seed_task(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for the task's random number generator. + + Invoke `seed` to set the seed for the environment's + random number generator. + + Args: + seed (Optional[int], optional): Defaults to None. + + Returns: + List[int]: Returns the list of seeds used in the task's + random number generator. The first value in the list should be + the seed that should be passed to this method for reproducibility. + """ + self.np_random_task, seed = seeding.np_random(seed) + assert isinstance(seed, int) + self.observation_space["task_obs"].seed(seed) + return [seed] diff --git a/mtenv/envs/__init__.py b/mtenv/envs/__init__.py new file mode 100644 index 0000000..c49c584 --- /dev/null +++ b/mtenv/envs/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from copy import deepcopy + +from mtenv.envs.registration import register + +# Control Task +# ---------------------------------------- + +register( + id="MT-CartPole-v0", + entry_point="mtenv.envs.control.cartpole:MTCartPole", + test_kwargs={ + # "valid_env_kwargs": [], + "invalid_env_kwargs": [], + }, +) + + +register( + id="MT-TabularMDP-v0", + entry_point="mtenv.envs.tabular_mdp.tmdp:UniformTMDP", + kwargs={"n_states": 4, "n_actions": 5}, + test_kwargs={ + "valid_env_kwargs": [{"n_states": 3, "n_actions": 2}], + "invalid_env_kwargs": [], + }, +) + +register( + id="MT-Acrobat-v0", + entry_point="mtenv.envs.control.acrobot:MTAcrobot", + test_kwargs={ + # "valid_env_kwargs": [], + "invalid_env_kwargs": [], + }, +) + +register( + id="MT-TwoGoalMaze-v0", + entry_point="mtenv.envs.mpte.two_goal_maze_env:build_two_goal_maze_env", + kwargs={"size_x": 3, "size_y": 3, "task_seed": 169, "n_tasks": 100}, + test_kwargs={ + # "valid_env_kwargs": [], + "invalid_env_kwargs": [], + }, +) + + +# remove it before making the repo public. 
+default_kwargs = { + "seed": 1, + "visualize_reward": False, + "from_pixels": True, + "height": 84, + "width": 84, + "frame_skip": 2, + "frame_stack": 3, + "sticky_observation_cfg": {}, + "initial_task_state": 1, +} + +for domain_name, task_name, prefix in [ + ("finger", "spin", "size"), + ("cheetah", "run", "torso_length"), + ("walker", "walk", "friction"), + ("walker", "walk", "len"), +]: + file_ids = list(range(1, 11)) + kwargs = deepcopy(default_kwargs) + kwargs["domain_name"] = domain_name + kwargs["task_name"] = task_name + kwargs["xml_file_ids"] = [f"{prefix}_{i}" for i in file_ids] + register( + id=f"MT-HiPBMDP-{domain_name.capitalize()}-{task_name.capitalize()}-vary-{prefix.replace('_', '-')}-v0", + entry_point="mtenv.envs.hipbmdp.env:build", + kwargs=kwargs, + test_kwargs={ + # "valid_env_kwargs": [], + # "invalid_env_kwargs": [], + }, + ) + + +default_kwargs = { + "benchmark": None, + "benchmark_name": "MT10", + "env_id_to_task_map": None, + "should_perform_reward_normalization": True, + "num_copies_per_env": 1, + "initial_task_state": 1, +} + +for benchmark_name in [("MT10"), ("MT50")]: + kwargs = deepcopy(default_kwargs) + kwargs["benchmark_name"] = benchmark_name + register( + id=f"MT-MetaWorld-{benchmark_name}-v0", + entry_point="mtenv.envs.metaworld.env:build", + kwargs=kwargs, + test_kwargs={ + # "valid_env_kwargs": [], + # "invalid_env_kwargs": [], + }, + ) + +kwargs = { + "benchmark": None, + "benchmark_name": "MT1", + "env_id_to_task_map": None, + "should_perform_reward_normalization": True, + "task_name": "pick-place-v1", + "num_copies_per_env": 1, + "initial_task_state": 0, +} +register( + id=f'MT-MetaWorld-{kwargs["benchmark_name"]}-v0', + entry_point="mtenv.envs.metaworld.env:build", + kwargs=kwargs, + test_kwargs={ + # "valid_env_kwargs": [], + # "invalid_env_kwargs": [], + }, +) diff --git a/mtenv/envs/control/README.md b/mtenv/envs/control/README.md new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/mtenv/envs/control/README.md @@ -0,0 +1 @@ + diff --git a/mtenv/envs/control/__init__.py b/mtenv/envs/control/__init__.py new file mode 100644 index 0000000..7f01fad --- /dev/null +++ b/mtenv/envs/control/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from mtenv.envs.control.cartpole import CartPole, MTCartPole # noqa: F401 diff --git a/mtenv/envs/control/acrobot.py b/mtenv/envs/control/acrobot.py new file mode 100644 index 0000000..7ddcda2 --- /dev/null +++ b/mtenv/envs/control/acrobot.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. + + +import numpy as np +from gym import spaces +from numpy import cos, pi, sin + +from mtenv import MTEnv +from mtenv.utils import seeding + +__copyright__ = "Copyright 2013, RLPy http://acl.mit.edu/RLPy" +__credits__ = [ + "Alborz Geramifard", + "Robert H. Klein", + "Christoph Dann", + "William Dabney", + "Jonathan P. How", +] +__license__ = "BSD 3-Clause" +__author__ = "Christoph Dann " + +# SOURCE: +# https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py + + +class MTAcrobot(MTEnv): + """A acrobot environment with varying characteristics + The task descriptor is composed of values between -1 and +1 and mapped to acrobot physical characcteristics in the + self._mu_to_vars function. 
+ + + """ + + metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 15} + + dt = 0.2 + + def _mu_to_vars(self, mu): + self.LINK_LENGTH_1 = 1.0 + mu[0] * 0.5 + self.LINK_LENGTH_2 = 1.0 + mu[1] * 0.5 + self.LINK_MASS_1 = 1.0 + mu[2] * 0.5 + self.LINK_MASS_2 = 1.0 + mu[3] * 0.5 + self.LINK_COM_POS_1 = 0.5 + self.LINK_COM_POS_2 = 0.5 + if mu[6] > 0: + self.AVAIL_TORQUE = [-1.0, 0.0, 1.0] + else: + self.AVAIL_TORQUE = [1.0, 0.0, -1.0] + self.LINK_MOI = 1.0 + + torque_noise_max = 0.0 + MAX_VEL_1 = 4 * pi + pi + MAX_VEL_2 = 9 * pi + 2 * pi + + #: use dynamics equations from the nips paper or the book + book_or_nips = "book" + action_arrow = None + domain_fig = None + actions_num = 3 + + def __init__(self): + self.viewer = None + self.action_space = spaces.Discrete(3) + self.state = None + high = np.array( + [1.5, 1.5, 1.5, 1.5, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32 + ) + low = -high + observation_space = spaces.Box(low=low, high=high, dtype=np.float32) + action_space = spaces.Discrete(3) + high = np.array([1.0 for k in range(5)]) + task_space = spaces.Box(-high, high, dtype=np.float32) + super().__init__( + action_space=action_space, + env_observation_space=observation_space, + task_observation_space=task_space, + ) + + def step(self, a): + self.t += 1 + self._mu_to_vars(self.task_state) + s = self.state + torque = self.AVAIL_TORQUE[a] + + # Add noise to the force action + if self.torque_noise_max > 0: + torque += self.np_random_env.uniform( + -self.torque_noise_max, self.torque_noise_max + ) + + # Now, augment the state with our force action so it can be passed to + # _dsdt + s_augmented = np.append(s, torque) + + ns = rk4(self._dsdt, s_augmented, [0, self.dt]) + # only care about final timestep of integration returned by integrator + ns = ns[-1] + ns = ns[:4] # omit action + # ODEINT IS TOO SLOW! 
+ # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt]) + # self.s_continuous = ns_continuous[-1] # We only care about the state + # at the ''final timestep'', self.dt + + ns[0] = wrap(ns[0], -pi, pi) + ns[1] = wrap(ns[1], -pi, pi) + ns[2] = bound(ns[2], -self.MAX_VEL_1, self.MAX_VEL_1) + ns[3] = bound(ns[3], -self.MAX_VEL_2, self.MAX_VEL_2) + self.state = ns + terminal = self._terminal() + reward = -1.0 if not terminal else 0.0 + return ( + {"env_obs": self._get_obs(), "task_obs": self.get_task_obs()}, + reward, + terminal, + {}, + ) + + def reset(self): + self._mu_to_vars(self.task_state) + self.state = self.np_random_env.uniform(low=-0.1, high=0.1, size=(4,)) + self.t = 0 + return {"env_obs": self._get_obs(), "task_obs": self.get_task_obs()} + + def get_task_obs(self): + return self.task_state + + def get_task_state(self): + return self.task_state + + def set_task_state(self, task_state): + self.task_state = task_state + + def _get_obs(self): + s = self.state + return [cos(s[0]), sin(s[0]), cos(s[1]), sin(s[1]), s[2], s[3]] + + def _terminal(self): + s = self.state + return bool(-cos(s[0]) - cos(s[1] + s[0]) > 1.0) + + def _dsdt(self, s_augmented, t): + m1 = self.LINK_MASS_1 + m2 = self.LINK_MASS_2 + l1 = self.LINK_LENGTH_1 + lc1 = self.LINK_COM_POS_1 + lc2 = self.LINK_COM_POS_2 + I1 = self.LINK_MOI + I2 = self.LINK_MOI + g = 9.8 + a = s_augmented[-1] + s = s_augmented[:-1] + theta1 = s[0] + theta2 = s[1] + dtheta1 = s[2] + dtheta2 = s[3] + d1 = ( + m1 * lc1 ** 2 + + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * cos(theta2)) + + I1 + + I2 + ) + d2 = m2 * (lc2 ** 2 + l1 * lc2 * cos(theta2)) + I2 + phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0) + phi1 = ( + -m2 * l1 * lc2 * dtheta2 ** 2 * sin(theta2) + - 2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * sin(theta2) + + (m1 * lc1 + m2 * l1) * g * cos(theta1 - pi / 2) + + phi2 + ) + if self.book_or_nips == "nips": + # the following line is consistent with the description in the + # paper + ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) + else: + # the following line is consistent with the java implementation and the + # book + ddtheta2 = ( + a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * sin(theta2) - phi2 + ) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1) + ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 + return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0) + + def seed(self, env_seed): + self.np_random_env, seed = seeding.np_random(env_seed) + return [seed] + + def seed_task(self, task_seed): + self.np_random_task, seed = seeding.np_random(task_seed) + return [seed] + + def sample_task_state(self): + self.assert_task_seed_is_set() + super().sample_task_state() + new_task_state = [ + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + ] + return new_task_state + + +def wrap(x, m, M): + """ + :param x: a scalar + :param m: minimum possible value in range + :param M: maximum possible value in range + Wraps ``x`` so m <= x <= M; but unlike ``bound()`` which + truncates, ``wrap()`` wraps x around the coordinate system defined by m,M.\n + For example, m = -180, M = 180 (degrees), x = 360 --> returns 0. 
+ """ + diff = M - m + while x > M: + x = x - diff + while x < m: + x = x + diff + return x + + +def bound(x, m, M=None): + """ + :param x: scalar + Either have m as scalar, so bound(x,m,M) which returns m <= x <= M *OR* + have m as length 2 vector, bound(x,m, ) returns m[0] <= x <= m[1]. + """ + if M is None: + M = m[1] + m = m[0] + # bound x between min (m) and Max (M) + return min(max(x, m), M) + + +def rk4(derivs, y0, t, *args, **kwargs): + """ + Integrate 1D or ND system of ODEs using 4-th order Runge-Kutta. + This is a toy implementation which may be useful if you find + yourself stranded on a system w/o scipy. Otherwise use + :func:`scipy.integrate`. + *y0* + initial state vector + *t* + sample times + *derivs* + returns the derivative of the system and has the + signature ``dy = derivs(yi, ti)`` + *args* + additional arguments passed to the derivative function + *kwargs* + additional keyword arguments passed to the derivative function + Example 1 :: + ## 2D system + def derivs6(x,t): + d1 = x[0] + 2*x[1] + d2 = -3*x[0] + 4*x[1] + return (d1, d2) + dt = 0.0005 + t = arange(0.0, 2.0, dt) + y0 = (1,2) + yout = rk4(derivs6, y0, t) + Example 2:: + ## 1D system + alpha = 2 + def derivs(x,t): + return -alpha*x + exp(-t) + y0 = 1 + yout = rk4(derivs, y0, t) + If you have access to scipy, you should probably be using the + scipy.integrate tools rather than this function. + """ + + try: + Ny = len(y0) + except TypeError: + yout = np.zeros((len(t),), np.float_) + else: + yout = np.zeros((len(t), Ny), np.float_) + + yout[0] = y0 + + for i in np.arange(len(t) - 1): + + thist = t[i] + dt = t[i + 1] - thist + dt2 = dt / 2.0 + y0 = yout[i] + + k1 = np.asarray(derivs(y0, thist, *args, **kwargs)) + k2 = np.asarray(derivs(y0 + dt2 * k1, thist + dt2, *args, **kwargs)) + k3 = np.asarray(derivs(y0 + dt2 * k2, thist + dt2, *args, **kwargs)) + k4 = np.asarray(derivs(y0 + dt * k3, thist + dt, *args, **kwargs)) + yout[i + 1] = y0 + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4) + return yout + + +class Acrobot(MTAcrobot): + """The original acrobot environment in the MTEnv fashion""" + + def __init__(self): + super().__init__() + + def sample_task_state(self): + self.assert_task_seed_is_set() + super().sample_task_state() + new_task_state = [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + return new_task_state + + +if __name__ == "__main__": + env = MTAcrobot() + env.seed(5) + env.seed_task(15) + env.reset_task_state() + obs = env.reset() + print(obs) + done = False + while not done: + obs, rew, done, _ = env.step(np.random.randint(env.action_space.n)) + print(obs) diff --git a/mtenv/envs/control/cartpole.py b/mtenv/envs/control/cartpole.py new file mode 100644 index 0000000..de0dcab --- /dev/null +++ b/mtenv/envs/control/cartpole.py @@ -0,0 +1,202 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import math + +import numpy as np +from gym import logger, spaces + +from mtenv import MTEnv +from mtenv.utils import seeding + +""" +Classic cart-pole system implemented based on Rich Sutton et al. 
+Copied from http://incompleteideas.net/sutton/book/code/pole.c +permalink: https://perma.cc/C9ZM-652R +""" + + +class MTCartPole(MTEnv): + """A cartpole environment with varying physical values + (see the self._mu_to_vars function) + """ + + metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50} + + def _mu_to_vars(self, mu): + self.gravity = 9.8 + mu[0] * 5 + self.masscart = 1.0 + mu[1] * 0.5 + self.masspole = 0.1 + mu[2] * 0.09 + self.total_mass = self.masspole + self.masscart + self.length = 0.5 + mu[3] * 0.3 + self.polemass_length = self.masspole * self.length + self.force_mag = 10 * mu[4] + if mu[4] == 0: + self.force_mag = 10 + + def __init__(self): + # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds + self.x_threshold = 2.4 + self.theta_threshold_radians = 12 * 2 * math.pi / 360 + + high = np.array( + [ + self.x_threshold * 2, + np.finfo(np.float32).max, + self.theta_threshold_radians * 2, + np.finfo(np.float32).max, + ] + ) + observation_space = spaces.Box(-high, high, dtype=np.float32) + action_space = spaces.Discrete(2) + high = np.array([1.0 for k in range(5)]) + task_space = spaces.Box(-high, high, dtype=np.float32) + super().__init__( + action_space=action_space, + env_observation_space=observation_space, + task_observation_space=task_space, + ) + + self.gravity = 9.8 + self.masscart = 1.0 + self.masspole = 0.1 + self.total_mass = self.masspole + self.masscart + self.length = 0.5 # actually half the pole's length + self.polemass_length = self.masspole * self.length + self.force_mag = 10.0 + self.tau = 0.02 # seconds between state updates + self.kinematics_integrator = "euler" + # Angle at which to fail the episode + + self.state = None + self.steps_beyond_done = None + + self.task_state = None + + def step(self, action): + self.t += 1 + self._mu_to_vars(self.task_state) + + assert self.action_space.contains(action), "%r (%s) invalid" % ( + action, + type(action), + ) + state = self.state + x, x_dot, theta, theta_dot = state + force = self.force_mag if action == 1 else -self.force_mag + costheta = math.cos(theta) + sintheta = math.sin(theta) + temp = ( + force + self.polemass_length * theta_dot * theta_dot * sintheta + ) / self.total_mass + thetaacc = (self.gravity * sintheta - costheta * temp) / ( + self.length + * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass) + ) + xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass + if self.kinematics_integrator == "euler": + x = x + self.tau * x_dot + x_dot = x_dot + self.tau * xacc + theta = theta + self.tau * theta_dot + theta_dot = theta_dot + self.tau * thetaacc + else: # semi-implicit euler + x_dot = x_dot + self.tau * xacc + x = x + self.tau * x_dot + theta_dot = theta_dot + self.tau * thetaacc + theta = theta + self.tau * theta_dot + self.state = [x, x_dot, theta, theta_dot] + done = ( + x < -self.x_threshold + or x > self.x_threshold + or theta < -self.theta_threshold_radians + or theta > self.theta_threshold_radians + ) + done = bool(done) + + reward = 0 + + if not done: + reward = 1.0 + elif self.steps_beyond_done is None: + # Pole just fell! + self.steps_beyond_done = 0 + reward = 1.0 + else: + if self.steps_beyond_done == 0: + logger.warn( + "You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior." 
+ ) + print( + "You are calling 'step()' even though this environment has already returned done = True. You should always call 'reset()' once you receive 'done = True' -- any further steps are undefined behavior." + ) + self.steps_beyond_done += 1 + reward = 0.0 + + return ( + {"env_obs": self.state, "task_obs": self.get_task_obs()}, + reward, + done, + {}, + ) + + def reset(self, **args): + self.assert_env_seed_is_set() + assert self.task_state is not None + + self._mu_to_vars(self.task_state) + self.state = self.np_random_env.uniform(low=-0.05, high=0.05, size=(4,)) + self.steps_beyond_done = None + self.t = 0 + return {"env_obs": self.state, "task_obs": self.get_task_obs()} + + def get_task_obs(self): + return self.task_state + + def get_task_state(self): + return self.task_state + + def set_task_state(self, task_state): + self.task_state = task_state + + def sample_task_state(self): + self.assert_task_seed_is_set() + super().sample_task_state() + new_task_state = [ + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + self.np_random_task.uniform(-1, 1), + ] + return new_task_state + + def seed(self, env_seed): + self.np_random_env, seed = seeding.np_random(env_seed) + return [seed] + + def seed_task(self, task_seed): + self.np_random_task, seed = seeding.np_random(task_seed) + return [seed] + + +class CartPole(MTCartPole): + """The original cartpole environment in the MTEnv fashion""" + + def __init__(self): + super().__init__() + + def sample_task_state(self): + new_task_state = [0.0, 0.0, 0.0, 0.0, 0.0] + return new_task_state + + +if __name__ == "__main__": + env = MTCartPole() + env.seed(5) + env.seed_task(15) + env.reset_task_state() + obs = env.reset() + print(obs) + done = False + while not done: + obs, rew, done, _ = env.step(np.random.randint(env.action_space.n)) + print(obs) diff --git a/mtenv/envs/control/requirements.txt b/mtenv/envs/control/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/control/setup.py b/mtenv/envs/control/setup.py new file mode 100644 index 0000000..0041f21 --- /dev/null +++ b/mtenv/envs/control/setup.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pathlib import Path + +import setuptools + +from mtenv.utils.setup_utils import parse_dependency + +env_name = "control" +path = Path(__file__).parent / "requirements.txt" +requirements = parse_dependency(path) + +with (Path(__file__).parent / "README.md").open() as fh: + long_description = fh.read() + +setuptools.setup( + name=env_name, + version="0.0.1", + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) diff --git a/mtenv/envs/hipbmdp/README.md b/mtenv/envs/hipbmdp/README.md new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/hipbmdp/__init__.py b/mtenv/envs/hipbmdp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/hipbmdp/dmc_env.py b/mtenv/envs/hipbmdp/dmc_env.py new file mode 100644 index 0000000..b4b1f07 --- /dev/null +++ b/mtenv/envs/hipbmdp/dmc_env.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from typing import Any, Dict + +import gym +from gym.core import Env +from gym.envs.registration import register + +from mtenv.envs.hipbmdp.wrappers import framestack, sticky_observation + + +def _build_env( + domain_name: str, + task_name: str, + seed: int = 1, + xml_file_id: str = "none", + visualize_reward: bool = True, + from_pixels: bool = False, + height: int = 84, + width: int = 84, + camera_id: int = 0, + frame_skip: int = 1, + environment_kwargs: Any = None, + episode_length: int = 1000, +) -> Env: + if xml_file_id is None: + env_id = "dmc_%s_%s_%s-v1" % (domain_name, task_name, seed) + else: + env_id = "dmc_%s_%s_%s_%s-v1" % (domain_name, task_name, xml_file_id, seed) + + if from_pixels: + assert ( + not visualize_reward + ), "cannot use visualize reward when learning from pixels" + + # shorten episode length + max_episode_steps = (episode_length + frame_skip - 1) // frame_skip + + if env_id not in gym.envs.registry.env_specs: + register( + id=env_id, + entry_point="mtenv.envs.hipbmdp.wrappers.dmc_wrapper:DMCWrapper", + kwargs={ + "domain_name": domain_name, + "task_name": task_name, + "task_kwargs": {"random": seed, "xml_file_id": xml_file_id}, + "environment_kwargs": environment_kwargs, + "visualize_reward": visualize_reward, + "from_pixels": from_pixels, + "height": height, + "width": width, + "camera_id": camera_id, + "frame_skip": frame_skip, + }, + max_episode_steps=max_episode_steps, + ) + return gym.make(env_id) + + +def build_dmc_env( + domain_name: str, + task_name: str, + seed: int, + xml_file_id: str, + visualize_reward: bool, + from_pixels: bool, + height: int, + width: int, + frame_skip: int, + frame_stack: int, + sticky_observation_cfg: Dict[str, Any], +) -> Env: + """Build a single DMC environment as described in + :cite:`tassa2020dmcontrol`. + + Args: + domain_name (str): name of the domain. + task_name (str): name of the task. + seed (int): environment seed (for reproducibility). + xml_file_id (str): id of the xml file to use. + visualize_reward (bool): should visualize reward ? + from_pixels (bool): return pixel observations? + height (int): height of pixel frames. + width (int): width of pixel frames. + frame_skip (int): should skip frames? + frame_stack (int): should stack frames together? + sticky_observation_cfg (Dict[str, Any]): Configuration for using + sticky observations. It should be a dictionary with three + keys, `should_use` which specifies if the config should be + used, `sticky_probability` which specifies the probability of + choosing a previous task and `last_k` which specifies the + number of previous frames to choose from. + + Returns: + Env: + """ + env = _build_env( + domain_name=domain_name, + task_name=task_name, + seed=seed, + visualize_reward=visualize_reward, + from_pixels=from_pixels, + height=height, + width=width, + frame_skip=frame_skip, + xml_file_id=xml_file_id, + ) + if from_pixels: + env = framestack.FrameStack(env, k=frame_stack) + if sticky_observation_cfg and sticky_observation_cfg["should_use"]: + env = sticky_observation.StickyObservation( # type: ignore[attr-defined] + env=env, + sticky_probability=sticky_observation_cfg["sticky_probability"], + last_k=sticky_observation_cfg["last_k"], + ) + return env diff --git a/mtenv/envs/hipbmdp/env.py b/mtenv/envs/hipbmdp/env.py new file mode 100644 index 0000000..f8af155 --- /dev/null +++ b/mtenv/envs/hipbmdp/env.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from typing import Any, Callable, Dict, List + +from gym.core import Env + +from mtenv import MTEnv +from mtenv.envs.hipbmdp import dmc_env +from mtenv.envs.shared.wrappers.multienv import MultiEnvWrapper + +EnvBuilderType = Callable[[], Env] +TaskStateType = int +TaskObsType = int + + +def build( + domain_name: str, + task_name: str, + seed: int, + xml_file_ids: List[str], + visualize_reward: bool, + from_pixels: bool, + height: int, + width: int, + frame_skip: int, + frame_stack: int, + sticky_observation_cfg: Dict[str, Any], + initial_task_state: int = 1, +) -> MTEnv: + """Build multitask environment as described in HiPBMDP paper. See + :cite:`mtrl_as_a_hidden_block_mdp` for more details. + + Args: + domain_name (str): name of the domain. + task_name (str): name of the task. + seed (int): environment seed (for reproducibility). + xml_file_ids (List[str]): ids of xml files. + visualize_reward (bool): should visualize reward ? + from_pixels (bool): return pixel observations? + height (int): height of pixel frames. + width (int): width of pixel frames. + frame_skip (int): should skip frames? + frame_stack (int): should stack frames together? + sticky_observation_cfg (Dict[str, Any]): Configuration for using + sticky observations. It should be a dictionary with three + keys, `should_use` which specifies if the config should be + used, `sticky_probability` which specifies the probability of + choosing a previous task and `last_k` which specifies the + number of previous frames to choose from. + initial_task_state (int, optional): intial task/environment + to select. Defaults to 1. + + Returns: + MTEnv: + """ + + def get_func_to_make_envs(xml_file_id: str) -> EnvBuilderType: + def _func() -> Env: + return dmc_env.build_dmc_env( + domain_name=domain_name, + task_name=task_name, + seed=seed, + xml_file_id=xml_file_id, + visualize_reward=visualize_reward, + from_pixels=from_pixels, + height=height, + width=width, + frame_skip=frame_skip, + frame_stack=frame_stack, + sticky_observation_cfg=sticky_observation_cfg, + ) + + return _func + + funcs_to_make_envs = [ + get_func_to_make_envs(xml_file_id=file_id) for file_id in xml_file_ids + ] + + mtenv = MultiEnvWrapper( + funcs_to_make_envs=funcs_to_make_envs, initial_task_state=initial_task_state + ) + return mtenv diff --git a/mtenv/envs/hipbmdp/requirements.txt b/mtenv/envs/hipbmdp/requirements.txt new file mode 100644 index 0000000..67cf797 --- /dev/null +++ b/mtenv/envs/hipbmdp/requirements.txt @@ -0,0 +1 @@ +git+git://github.com/denisyarats/dmc2gym.git@62ca3d886eb59a1927720d036be210bedc6d9f48#egg=dmc2gym \ No newline at end of file diff --git a/mtenv/envs/hipbmdp/setup.py b/mtenv/envs/hipbmdp/setup.py new file mode 100644 index 0000000..1a1cdb6 --- /dev/null +++ b/mtenv/envs/hipbmdp/setup.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from pathlib import Path + +import setuptools + +from mtenv.utils.setup_utils import parse_dependency + +env_name = "hipbmdp" +path = Path(__file__).parent / "requirements.txt" +requirements = parse_dependency(path) + + +with (Path(__file__).parent / "README.md").open() as fh: + long_description = fh.read() + +setuptools.setup( + name=env_name, + version="0.0.1", + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) diff --git a/mtenv/envs/hipbmdp/wrappers/__init__.py b/mtenv/envs/hipbmdp/wrappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/hipbmdp/wrappers/dmc_wrapper.py b/mtenv/envs/hipbmdp/wrappers/dmc_wrapper.py new file mode 100644 index 0000000..7329aec --- /dev/null +++ b/mtenv/envs/hipbmdp/wrappers/dmc_wrapper.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from typing import Any, Dict, Optional + +import dmc2gym +import numpy as np +from dmc2gym.wrappers import DMCWrapper as BaseDMCWrapper +from gym import spaces + +import local_dm_control_suite as local_dmc_suite + + +class DMCWrapper(BaseDMCWrapper): + def __init__( + self, + domain_name: str, + task_name: str, + task_kwargs: Any = None, + visualize_reward: Optional[Dict[str, Any]] = None, + from_pixels: bool = False, + height=84, + width: int = 84, + camera_id: int = 0, + frame_skip: int = 1, + environment_kwargs: Any = None, + channels_first: bool = True, + ): + """This wrapper is based on implementation from + https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/wrappers.py#L37 + + We extend the wrapper so that we can use the modified version of + `dm_control_suite`. + """ + assert ( + "random" in task_kwargs # type: ignore [operator] + ), "please specify a seed, for deterministic behaviour" + self._from_pixels = from_pixels + self._height = height + self._width = width + self._camera_id = camera_id + self._frame_skip = frame_skip + self._channels_first = channels_first + if visualize_reward is None: + visualize_reward = {} + # create task + self._env = local_dmc_suite.load( + domain_name=domain_name, + task_name=task_name, + task_kwargs=task_kwargs, + visualize_reward=visualize_reward, + environment_kwargs=environment_kwargs, + ) + + # true and normalized action spaces + self._true_action_space = dmc2gym.wrappers._spec_to_box( + [self._env.action_spec()] + ) + self._norm_action_space = spaces.Box( + low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32 + ) + + # create observation space + if from_pixels: + shape = [3, height, width] if channels_first else [height, width, 3] + self._observation_space = spaces.Box( + low=0, high=255, shape=shape, dtype=np.uint8 + ) + else: + self._observation_space = dmc2gym.wrappers._spec_to_box( + self._env.observation_spec().values() + ) + + self._state_space = dmc2gym.wrappers._spec_to_box( + self._env.observation_spec().values() + ) + + self.current_state = None + + # set seed + self.seed(seed=task_kwargs["random"]) # type: ignore [index] diff --git a/mtenv/envs/hipbmdp/wrappers/framestack.py b/mtenv/envs/hipbmdp/wrappers/framestack.py new file mode 100644 index 0000000..36e0326 --- /dev/null +++ b/mtenv/envs/hipbmdp/wrappers/framestack.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved
+"""Wrapper to stack observations for single task environments."""
+
+from collections import deque
+
+import gym
+import numpy as np
+
+from mtenv.utils.types import ActionType, EnvStepReturnType
+
+
+class FrameStack(gym.Wrapper):  # type: ignore[misc]
+    # Mypy error: Class cannot subclass 'Wrapper' (has type 'Any') [misc]
+
+    def __init__(self, env: gym.core.Env, k: int):
+        """Wrapper to stack observations for single task environments.
+
+        Args:
+            env (gym.core.Env): Single Task Environment
+            k (int): number of frames to stack.
+        """
+        gym.Wrapper.__init__(self, env)
+        self._k = k
+        self._frames: deque = deque([], maxlen=k)
+        shp = env.observation_space.shape
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=1,
+            shape=((shp[0] * k,) + shp[1:]),
+            dtype=env.observation_space.dtype,
+        )
+        self._max_episode_steps = env._max_episode_steps
+
+    def reset(self) -> np.ndarray:
+        obs = self.env.reset()
+        for _ in range(self._k):
+            self._frames.append(obs)
+        return self._get_obs()
+
+    def step(self, action: ActionType) -> EnvStepReturnType:
+        obs, reward, done, info = self.env.step(action)
+        self._frames.append(obs)
+        return self._get_obs(), reward, done, info
+
+    def _get_obs(self) -> np.ndarray:
+        assert len(self._frames) == self._k
+        return np.concatenate(list(self._frames), axis=0)
diff --git a/mtenv/envs/hipbmdp/wrappers/sticky_observation.py b/mtenv/envs/hipbmdp/wrappers/sticky_observation.py
new file mode 100644
index 0000000..aac117b
--- /dev/null
+++ b/mtenv/envs/hipbmdp/wrappers/sticky_observation.py
@@ -0,0 +1,56 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""Wrapper to enable sticky observations for single task environments."""
+# type: ignore
+import random
+from collections import deque
+
+import gym
+
+
+class StickyObservation(gym.Wrapper):
+    def __init__(self, env: gym.Env, sticky_probability: float, last_k: int):
+        """Env wrapper that returns a previous observation with probability
+        `p` and the current observation with a probability `1-p`. `last_k`
+        previous observations are stored.
+
+        Args:
+            env (gym.Env): Single task environment.
+            sticky_probability (float): Probability `p` for returning a
+                previous observation.
+            last_k (int): Number of previous observations to store.
+
+        Raises:
+            ValueError: Raise a ValueError if `sticky_probability` is
+                not in range `[0, 1]`.
+        """
+        super().__init__(env)
+        if 1 >= sticky_probability >= 0:
+            self._sticky_probability = sticky_probability
+        else:
+            raise ValueError(
+                f"sticky_probability = {sticky_probability} is not in the interval [0, 1]."
+ ) + self._last_k = last_k + 1 + self._observations: deque = deque([], maxlen=self._last_k) + self.observation_space = env.observation_space + self._max_episode_steps = env._max_episode_steps + + def reset(self): + obs = self.env.reset() + for _ in range(self._last_k): + self._observations.append(obs) + return self._get_obs() + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._observations.append(obs) + return self._get_obs(), reward, done, info + + def _get_obs(self): + assert len(self._observations) == self._last_k + should_choose_old_observation = random.random() < self._sticky_probability + if should_choose_old_observation: + index = random.randint(0, self._last_k - 2) + return self._observations[index] + else: + return self._observations[-1] diff --git a/mtenv/envs/metaworld/README.md b/mtenv/envs/metaworld/README.md new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/metaworld/__init__.py b/mtenv/envs/metaworld/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/metaworld/env.py b/mtenv/envs/metaworld/env.py new file mode 100644 index 0000000..85a52d8 --- /dev/null +++ b/mtenv/envs/metaworld/env.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import random +from typing import Any, Callable, Dict, List, Optional, Tuple + +import metaworld +from gym import Env + +from mtenv import MTEnv +from mtenv.envs.metaworld.wrappers.normalized_env import ( # type: ignore[attr-defined] + NormalizedEnvWrapper, +) +from mtenv.envs.shared.wrappers.multienv import MultiEnvWrapper + +EnvBuilderType = Callable[[], Env] +TaskStateType = int +TaskObsType = int +EnvIdToTaskMapType = Dict[str, metaworld.Task] + + +class MetaWorldMTWrapper(MultiEnvWrapper): + def __init__( + self, + funcs_to_make_envs: List[EnvBuilderType], + initial_task_state: TaskStateType, + env_id_to_task_map: EnvIdToTaskMapType, + ) -> None: + """Wrapper to make MetaWorld environment compatible with Multitask + Environment API. See :cite:`yu2020meta` for more details about + MetaWorld. + + Args: + funcs_to_make_envs (List[EnvBuilderType]): list of constructor + functions to make the environments. + initial_task_state (TaskStateType): initial task/environment + to select. + env_id_to_task_map (EnvIdToTaskMapType): In MetaWorld, each + environment can be associated with multiple tasks. This + dict persists the mapping between environment ids and tasks. + """ + super().__init__( + funcs_to_make_envs=funcs_to_make_envs, + initial_task_state=initial_task_state, + ) + self.env_id_to_task_map = env_id_to_task_map + + +def get_list_of_func_to_make_envs( + benchmark: Optional[metaworld.Benchmark], + benchmark_name: str, + env_id_to_task_map: Optional[EnvIdToTaskMapType], + should_perform_reward_normalization: bool = True, + task_name: str = "pick-place-v1", + num_copies_per_env: int = 1, +) -> Tuple[List[Any], Dict[str, Any]]: + """Return a list of functions to construct the MetaWorld environments + and a mapping of environment ids to tasks. + + Args: + benchmark (Optional[metaworld.Benchmark]): `benchmark` to create + tasks from. + benchmark_name (str): name of the `benchmark`. This is used only + when the `benchmark` is None. + env_id_to_task_map (Optional[EnvIdToTaskMapType]): In MetaWorld, + each environment can be associated with multiple tasks. This + dict persists the mapping between environment ids and tasks. + should_perform_reward_normalization (bool, optional): Defaults to + True. 
+ task_name (str, optional): In case of MT1, only . Defaults to + "pick-place-v1". + num_copies_per_env (int, optional): Number of copies to create for + each environment. Defaults to 1. + + Raises: + ValueError: if `benchmark` is None and `benchmark_name` is not + MT1, MT10, or MT50. + + Returns: + Tuple[List[Any], Dict[str, Any]]: A tuple of two elements. The + first element is a list of functions to construct the MetaWorld + environments and the second is a mapping of environment ids + to tasks. + + """ + if not benchmark: + if benchmark_name == "MT1": + benchmark = metaworld.ML1(task_name) + elif benchmark_name == "MT10": + benchmark = metaworld.MT10() + elif benchmark_name == "MT50": + benchmark = metaworld.MT50() + else: + raise ValueError(f"benchmark_name={benchmark_name} is not valid.") + + env_id_list = list(benchmark.train_classes.keys()) + + def _get_class_items(current_benchmark): + return current_benchmark.train_classes.items() + + def _get_tasks(current_benchmark): + return current_benchmark.train_tasks + + def _get_env_id_to_task_map() -> EnvIdToTaskMapType: + env_id_to_task_map: EnvIdToTaskMapType = {} + current_benchmark = benchmark + for env_id in env_id_list: + for name, _ in _get_class_items(current_benchmark): + if name == env_id: + task = random.choice( + [ + task + for task in _get_tasks(current_benchmark) + if task.env_name == name + ] + ) + env_id_to_task_map[env_id] = task + return env_id_to_task_map + + if env_id_to_task_map is None: + env_id_to_task_map: EnvIdToTaskMapType = _get_env_id_to_task_map() # type: ignore[no-redef] + assert env_id_to_task_map is not None + + def get_func_to_make_envs(env_id: str): + current_benchmark = benchmark + + def _make_env(): + for name, env_cls in _get_class_items(current_benchmark): + if name == env_id: + env = env_cls() + task = env_id_to_task_map[env_id] + env.set_task(task) + if should_perform_reward_normalization: + env = NormalizedEnvWrapper(env, normalize_reward=True) + return env + + return _make_env + + if num_copies_per_env > 1: + env_id_list = [ + [env_id for _ in range(num_copies_per_env)] for env_id in env_id_list + ] + env_id_list = [ + env_id for env_id_sublist in env_id_list for env_id in env_id_sublist + ] + + funcs_to_make_envs = [get_func_to_make_envs(env_id) for env_id in env_id_list] + + return funcs_to_make_envs, env_id_to_task_map + + +def build( + benchmark: Optional[metaworld.Benchmark], + benchmark_name: str, + env_id_to_task_map: Optional[EnvIdToTaskMapType], + should_perform_reward_normalization: bool = True, + task_name: str = "pick-place-v1", + num_copies_per_env: int = 1, + initial_task_state: int = 1, +) -> MTEnv: + """Build a MTEnv comptaible variant of MetaWorld. + + Args: + benchmark (Optional[metaworld.Benchmark]): `benchmark` to create + tasks from. + benchmark_name (str): name of the `benchmark`. This is used only + when the `benchmark` is None. + env_id_to_task_map (Optional[EnvIdToTaskMapType]): In MetaWorld, + each environment can be associated with multiple tasks. This + dict persists the mapping between environment ids and tasks. + should_perform_reward_normalization (bool, optional): Defaults to + True. + task_name (str, optional): In case of MT1, only . Defaults to + "pick-place-v1". + num_copies_per_env (int, optional): Number of copies to create for + each environment. Defaults to 1. + initial_task_state (int, optional): initial task/environment to + select. Defaults to 1. 
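A hedged usage sketch for `build`, assuming MetaWorld is installed and that `build` is importable from `mtenv.envs.metaworld.env`:

    env = build(benchmark=None, benchmark_name="MT10", env_id_to_task_map=None)
    env.seed(0)
    env.seed_task(0)
    env.set_task_state(3)   # switch to the environment with id 3; it is constructed lazily
    obs = env.reset()       # {"env_obs": ..., "task_obs": 3}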
+ + Returns: + MTEnv: + """ + funcs_to_make_envs, env_id_to_task_map = get_list_of_func_to_make_envs( + benchmark=benchmark, + benchmark_name=benchmark_name, + env_id_to_task_map=env_id_to_task_map, + should_perform_reward_normalization=should_perform_reward_normalization, + task_name=task_name, + num_copies_per_env=num_copies_per_env, + ) + + assert env_id_to_task_map is not None + + mtenv = MetaWorldMTWrapper( + funcs_to_make_envs=funcs_to_make_envs, + initial_task_state=initial_task_state, + env_id_to_task_map=env_id_to_task_map, + ) + return mtenv diff --git a/mtenv/envs/metaworld/requirements.txt b/mtenv/envs/metaworld/requirements.txt new file mode 100644 index 0000000..cbbc8bc --- /dev/null +++ b/mtenv/envs/metaworld/requirements.txt @@ -0,0 +1 @@ +git+https://github.com/rlworkgroup/metaworld.git@af8417bfc82a3e249b4b02156518d775f29eb289#egg=metaworld \ No newline at end of file diff --git a/mtenv/envs/metaworld/setup.py b/mtenv/envs/metaworld/setup.py new file mode 100644 index 0000000..406705d --- /dev/null +++ b/mtenv/envs/metaworld/setup.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pathlib import Path + +import setuptools + +from mtenv.utils.setup_utils import parse_dependency + +env_name = "metaworld" +path = Path(__file__).parent / "requirements.txt" +requirements = parse_dependency(path) + + +with (Path(__file__).parent / "README.md").open() as fh: + long_description = fh.read() + +setuptools.setup( + name=env_name, + version="0.0.1", + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) diff --git a/mtenv/envs/metaworld/wrappers/__init__.py b/mtenv/envs/metaworld/wrappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/metaworld/wrappers/normalized_env.py b/mtenv/envs/metaworld/wrappers/normalized_env.py new file mode 100644 index 0000000..2569b39 --- /dev/null +++ b/mtenv/envs/metaworld/wrappers/normalized_env.py @@ -0,0 +1,169 @@ +# This code is taken from: https://raw.githubusercontent.com/rlworkgroup/garage/af57bf9c6b10cd733cb0fa9bfe3abd0ba239fd6e/src/garage/envs/normalized_env.py +# +# """"An environment wrapper that normalizes action, observation and reward.""" +# type: ignore +import gym +import gym.spaces +import gym.spaces.utils +import numpy as np + + +class NormalizedEnvWrapper(gym.Wrapper): + """An environment wrapper for normalization. + + This wrapper normalizes action, and optionally observation and reward. + + Args: + env (garage.envs.GarageEnv): An environment instance. + scale_reward (float): Scale of environment reward. + normalize_obs (bool): If True, normalize observation. + normalize_reward (bool): If True, normalize reward. scale_reward is + applied after normalization. + expected_action_scale (float): Assuming action falls in the range of + [-expected_action_scale, expected_action_scale] when normalize it. + flatten_obs (bool): Flatten observation if True. + obs_alpha (float): Update rate of moving average when estimating the + mean and variance of observations. + reward_alpha (float): Update rate of moving average when estimating the + mean and variance of rewards. 
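The statistics used by this wrapper are exponential moving averages; with update rate `alpha` (`obs_alpha` or `reward_alpha`) and a new sample `x`, the estimates and the normalization below reduce to:

    mean = (1 - alpha) * mean + alpha * x
    var = (1 - alpha) * var + alpha * (x - mean) ** 2
    x_normalized = (x - mean) / (np.sqrt(var) + 1e-8)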
+ + """ + + def __init__( + self, + env, + scale_reward=1.0, + normalize_obs=False, + normalize_reward=False, + expected_action_scale=1.0, + flatten_obs=True, + obs_alpha=0.001, + reward_alpha=0.001, + ): + super().__init__(env) + + self._scale_reward = scale_reward + self._normalize_obs = normalize_obs + self._normalize_reward = normalize_reward + self._expected_action_scale = expected_action_scale + self._flatten_obs = flatten_obs + + self._obs_alpha = obs_alpha + flat_obs_dim = gym.spaces.utils.flatdim(env.observation_space) + self._obs_mean = np.zeros(flat_obs_dim) + self._obs_var = np.ones(flat_obs_dim) + + self._reward_alpha = reward_alpha + self._reward_mean = 0.0 + self._reward_var = 1.0 + + def _update_obs_estimate(self, obs): + flat_obs = gym.spaces.utils.flatten(self.env.observation_space, obs) + self._obs_mean = ( + 1 - self._obs_alpha + ) * self._obs_mean + self._obs_alpha * flat_obs + self._obs_var = ( + 1 - self._obs_alpha + ) * self._obs_var + self._obs_alpha * np.square(flat_obs - self._obs_mean) + + def _update_reward_estimate(self, reward): + self._reward_mean = ( + 1 - self._reward_alpha + ) * self._reward_mean + self._reward_alpha * reward + self._reward_var = ( + 1 - self._reward_alpha + ) * self._reward_var + self._reward_alpha * np.square( + reward - self._reward_mean + ) + + def _apply_normalize_obs(self, obs): + """Compute normalized observation. + + Args: + obs (np.ndarray): Observation. + + Returns: + np.ndarray: Normalized observation. + + """ + self._update_obs_estimate(obs) + flat_obs = gym.spaces.utils.flatten(self.env.observation_space, obs) + normalized_obs = (flat_obs - self._obs_mean) / (np.sqrt(self._obs_var) + 1e-8) + if not self._flatten_obs: + normalized_obs = gym.spaces.utils.unflatten( + self.env.observation_space, normalized_obs + ) + return normalized_obs + + def _apply_normalize_reward(self, reward): + """Compute normalized reward. + + Args: + reward (float): Reward. + + Returns: + float: Normalized reward. + + """ + self._update_reward_estimate(reward) + return reward / (np.sqrt(self._reward_var) + 1e-8) + + def reset(self, **kwargs): + """Reset environment. + + Args: + **kwargs: Additional parameters for reset. + + Returns: + tuple: + * observation (np.ndarray): The observation of the environment. + * reward (float): The reward acquired at this time step. + * done (boolean): Whether the environment was completed at this + time step. + * infos (dict): Environment-dependent additional information. + + """ + ret = self.env.reset(**kwargs) + if self._normalize_obs: + return self._apply_normalize_obs(ret) + else: + return ret + + def step(self, action): + """Feed environment with one step of action and get result. + + Args: + action (np.ndarray): An action fed to the environment. + + Returns: + tuple: + * observation (np.ndarray): The observation of the environment. + * reward (float): The reward acquired at this time step. + * done (boolean): Whether the environment was completed at this + time step. + * infos (dict): Environment-dependent additional information. 
+ + """ + if isinstance(self.action_space, gym.spaces.Box): + # rescale the action when the bounds are not inf + lb, ub = self.action_space.low, self.action_space.high + if np.all(lb != -np.inf) and np.all(ub != -np.inf): + scaled_action = lb + (action + self._expected_action_scale) * ( + 0.5 * (ub - lb) / self._expected_action_scale + ) + scaled_action = np.clip(scaled_action, lb, ub) + else: + scaled_action = action + else: + scaled_action = action + try: + next_obs, reward, done, info = self.env.step(scaled_action) + except Exception as e: + print(e) + + if self._normalize_obs: + next_obs = self._apply_normalize_obs(next_obs) + if self._normalize_reward: + reward = self._apply_normalize_reward(reward) + + return next_obs, reward * self._scale_reward, done, info diff --git a/mtenv/envs/mpte/README.md b/mtenv/envs/mpte/README.md new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/mpte/__init__.py b/mtenv/envs/mpte/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/mpte/requirements.txt b/mtenv/envs/mpte/requirements.txt new file mode 100644 index 0000000..3781218 --- /dev/null +++ b/mtenv/envs/mpte/requirements.txt @@ -0,0 +1 @@ +gym-miniworld>=2020.1.9 \ No newline at end of file diff --git a/mtenv/envs/mpte/setup.py b/mtenv/envs/mpte/setup.py new file mode 100644 index 0000000..0a04b6a --- /dev/null +++ b/mtenv/envs/mpte/setup.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pathlib import Path + +import setuptools + +from mtenv.utils.setup_utils import parse_dependency + +env_name = "mpte" +path = Path(__file__).parent / "requirements.txt" +requirements = parse_dependency(path) + +with (Path(__file__).parent / "README.md").open() as fh: + long_description = fh.read() + +setuptools.setup( + name=env_name, + version="1.0.0", + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "License :: MIT", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) diff --git a/mtenv/envs/mpte/two_goal_maze_env.py b/mtenv/envs/mpte/two_goal_maze_env.py new file mode 100644 index 0000000..d0777df --- /dev/null +++ b/mtenv/envs/mpte/two_goal_maze_env.py @@ -0,0 +1,343 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. 
+ +import copy +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from gym import spaces +from gym.spaces.box import Box as BoxSpace +from gym.spaces.dict import Dict as DictSpace +from gym.spaces.discrete import Discrete as DiscreteSpace +from gym_miniworld.entity import Box +from gym_miniworld.miniworld import Agent, MiniWorldEnv +from numpy.random import RandomState + +from mtenv.utils import seeding +from mtenv.utils.types import DoneType, InfoType, RewardType, TaskObsType +from mtenv.wrappers.env_to_mtenv import EnvToMTEnv + +TaskStateType = List[int] + +ActionType = int + +EnvObsType = Dict[str, Union[int, List[int], List[float]]] +ObsType = Dict[str, Union[EnvObsType, TaskObsType]] +StepReturnType = Tuple[ObsType, RewardType, DoneType, InfoType] + + +class MTMiniWorldEnv(EnvToMTEnv): + def make_observation(self, env_obs: EnvObsType) -> ObsType: + raise NotImplementedError + + def get_task_obs(self) -> TaskObsType: + return self.env.get_task_obs() + + def get_task_state(self) -> TaskStateType: + return self.env.task_state + + def set_task_state(self, task_state: TaskStateType) -> None: + self.env.set_task_state(task_state) + + def sample_task_state(self) -> TaskStateType: + return self.env.sample_task_state() + + def reset(self, **kwargs: Dict[str, Any]) -> ObsType: # type: ignore[override] + # signature is incompatible with supertype. + self.assert_env_seed_is_set() + return self.env.reset(**kwargs) + + def step(self, action: ActionType) -> StepReturnType: # type: ignore + return self.env.step(action) + + def assert_env_seed_is_set(self) -> None: + assert self.env.np_random_env is not None, "please call `seed()` first" + + def assert_task_seed_is_set(self) -> None: + assert self.env.np_random_task is not None, "please call `seed_task()` first" + + def seed(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for environment observations""" + return self.env.seed(seed=seed) + + +class TwoGoalMazeEnv(MiniWorldEnv): + + metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 30} + + def __init__( + self, + size_x=5, + size_y=5, + obs_type="xy", + task_seed=0, + n_tasks=10, + p_change=0.0, + empty_mu=False, + ): + assert p_change == 0.0 + self.empty_mu = empty_mu + self.obs_type = obs_type + self.seed_task(seed=task_seed) + self.np_random_env: Optional[RandomState] = None + self.size_x, self.size_y = size_x, size_y + self.task_state = [] + + super().__init__() + # Allow only movement actions (left/right/forward) + self.action_space = spaces.Discrete(self.actions.move_forward + 1) + if self.obs_type == "xy": + _obs_space = BoxSpace( + low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32 + ) + + else: + _obs_space = BoxSpace( + low=-1.0, + high=1.0, + shape=(64, 64), + dtype=np.float32, + ) + self.observation_space = DictSpace( + { + "obs": _obs_space, + "total_reward": BoxSpace( + low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32 + ), + } + ) + + def assert_env_seed_is_set(self) -> None: + """Check that the env seed is set.""" + assert self.np_random_env is not None, "please call `seed()` first" + + def assert_task_seed_is_set(self) -> None: + """Check that the task seed is set.""" + assert self.np_random_task is not None, "please call `seed_task()` first" + + def seed_task(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for task information""" + self.np_random_task, seed = seeding.np_random(seed) + assert isinstance(seed, int) + return [seed] + + def sample_task_state(self) -> 
TaskStateType: + self.assert_task_seed_is_set() + return [self.np_random_task.randint(2)] + + def set_task_state(self, task_state: TaskStateType) -> None: + self.task_state = task_state + + def _gen_world(self): + self.reset_task_state() + room1 = self.add_rect_room( + min_x=-self.size_x, + max_x=self.size_x, + min_z=-self.size_y, + max_z=self.size_y, + wall_tex="brick_wall", + ) + self.room1 = room1 + room2 = self.add_rect_room( + min_x=-self.size_x, + max_x=self.size_x, + min_z=self.size_y, + max_z=self.size_y + 1, + wall_tex="cardboard", + ) + self.connect_rooms(room1, room2, min_x=-self.size_x, max_x=self.size_x) + + room3 = self.add_rect_room( + min_x=-self.size_x, + max_x=self.size_x, + min_z=-self.size_y - 1, + max_z=-self.size_y, + wall_tex="lava", + ) + self.connect_rooms(room1, room3, min_x=-self.size_x, max_x=self.size_x) + + room4 = None + if self.task_state[0] == 0: + room4 = self.add_rect_room( + min_x=-self.size_x - 1, + max_x=-self.size_x, + min_z=-self.size_y, + max_z=self.size_y, + wall_tex="wood_planks", + ) + else: + room4 = self.add_rect_room( + min_x=-self.size_x - 1, + max_x=-self.size_x, + min_z=-self.size_y, + max_z=self.size_y, + wall_tex="slime", + ) + + self.connect_rooms(room1, room4, min_z=-self.size_y, max_z=self.size_y) + + room5 = self.add_rect_room( + min_x=self.size_x, + max_x=self.size_x + 1, + min_z=-self.size_y, + max_z=self.size_y, + wall_tex="metal_grill", + ) + + self.connect_rooms(room1, room5, min_z=-self.size_y, max_z=self.size_y) + + self.boxes = [] + self.boxes.append(Box(color="blue")) + self.boxes.append(Box(color="red")) + self.place_entity(self.boxes[0], room=room1) + self.place_entity(self.boxes[1], room=room1) + + # Choose a random room and position to spawn at + _dir = self.np_random_env.randint(8) * (math.pi / 4) - math.pi + self.place_agent( + dir=_dir, + room=room1, + ) + while self._dist() < 2 or self._ndist() < 2: + self.place_agent( + dir=_dir, + room=room1, + ) + + def _dist(self): + bp = self.boxes[int(self.task_state[0])].pos + pos = self.agent.pos + distance = math.sqrt((bp[0] - pos[0]) ** 2 + (bp[2] - pos[2]) ** 2) + + return distance + + def _ndist(self): + bp = self.boxes[1 - int(self.task_state[0])].pos + pos = self.agent.pos + distance = math.sqrt((bp[0] - pos[0]) ** 2 + (bp[2] - pos[2]) ** 2) + + return distance + + def reset(self) -> ObsType: + self.assert_env_seed_is_set() + self.max_episode_steps = 200 + self.treward = 0.0 + self.step_count = 0 + self.agent = Agent() + self.entities: List[Any] = [] + self.rooms: List[Any] = [] + self.wall_segs: List[Any] = [] + self._gen_world() + self.blocked = False + rand = self.rand if self.domain_rand else None + self.params.sample_many( + rand, self, ["sky_color", "light_pos", "light_color", "light_ambient"] + ) + + for ent in self.entities: + ent.randomize(self.params, rand) + + # Compute the min and max x, z extents of the whole floorplan + self.min_x = min(r.min_x for r in self.rooms) + self.max_x = max(r.max_x for r in self.rooms) + self.min_z = min(r.min_z for r in self.rooms) + self.max_z = max(r.max_z for r in self.rooms) + + # Generate static data + if len(self.wall_segs) == 0: + self._gen_static_data() + + # Pre-compile static parts of the environment into a display list + self._render_static() + _pos = [ + (self.agent.pos[0] / self.size_x) * 2.1 - 1.0, + (self.agent.pos[2] / self.size_y) * 2.1 - 1.0, + ] + _dir = [self.agent.dir_vec[0], self.agent.dir_vec[2]] + + if self.obs_type == "xy": + _mu = [0.0] + at = math.atan2(_dir[0], _dir[1]) + o = copy.deepcopy(_pos + [at] 
+ _mu) + else: + o = (self.render_obs() / 255.0) * 2.0 - 1.0 + + return self.make_obs(env_obs=o, total_reward=[0.0]) + + def get_task_obs(self) -> TaskObsType: + mmu = copy.deepcopy(self.task_state) + if self.empty_mu: + mmu = [0.0] + return mmu + + def get_task_state(self) -> TaskStateType: + return self.task_state + + def reset_task_state(self) -> None: + """Sample a new task_state and set that as the new task_state""" + self.set_task_state(task_state=self.sample_task_state()) + + def make_obs(self, env_obs: Any, total_reward: List[float]) -> ObsType: + + return { + "env_obs": {"obs": env_obs, "total_reward": total_reward}, + "task_obs": self.get_task_obs(), + } + + def seed(self, seed: Optional[int] = None) -> List[int]: + """Set the seed for environment observations""" + self.np_random_env, seed = seeding.np_random(seed) + return [seed] + super().seed(seed=seed) + + def step(self, action: ActionType) -> StepReturnType: + self.step_count += 1 + if not self.blocked: + if action == 2: + self.move_agent(0.51, 0.0) # fwd_step, fwd_drift) + elif action == 0: + self.turn_agent(45) + elif action == 1: + self.turn_agent(-45) + reward = 0.0 + done = False + + distance = self._dist() + if distance < 2: + reward = +1.0 + done = True + distance = self._ndist() + if distance < 2: + reward = -1.0 + done = True + _pos = [ + (self.agent.pos[0] / self.size_x) * 2.1 - 1.0, + (self.agent.pos[2] / self.size_y) * 2.1 - 1.0, + ] + _dir = [self.agent.dir_vec[0], self.agent.dir_vec[2]] + + if self.obs_type == "xy": + at = math.atan2(_dir[0], _dir[1]) + _mu = [0.0] + if (at < -1.5 and at > -1.7) and not self.empty_mu: + _mu = [1.0] + if self.task_state[0] == 0: + _mu = [-1.0] + + o = copy.deepcopy(_pos + [at] + _mu) + else: + o = (self.render_obs() / 255.0) * 2.0 - 1.0 + + self.treward += reward + + return self.make_obs(env_obs=o, total_reward=[self.treward]), reward, done, {} + + +def build_two_goal_maze_env(size_x: int, size_y: int, task_seed: int, n_tasks: int): + env = MTMiniWorldEnv( + TwoGoalMazeEnv( + size_x=size_x, size_y=size_y, task_seed=task_seed, n_tasks=n_tasks + ), + task_observation_space=DiscreteSpace(n=1), + ) + return env diff --git a/mtenv/envs/registration.py b/mtenv/envs/registration.py new file mode 100644 index 0000000..3992a8f --- /dev/null +++ b/mtenv/envs/registration.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +from typing import Any, Dict, Optional + +from gym import error +from gym.core import Env +from gym.envs.registration import EnvRegistry, EnvSpec + + +class MultitaskEnvSpec(EnvSpec): # type: ignore[misc] + def __init__( + self, + id: str, + entry_point: Optional[str] = None, + reward_threshold: Optional[int] = None, + kwargs: Optional[Dict[str, Any]] = None, + nondeterministic: bool = False, + max_episode_steps: Optional[int] = None, + test_kwargs: Optional[Dict[str, Any]] = None, + ): + """A specification for a particular instance of the environment. + Used to register the parameters for official evaluations. + + Args: + id (str): The official environment ID + entry_point (Optional[str]): The Python entrypoint of the + environment class (e.g. 
module.name:Class) + reward_threshold (Optional[int]): The reward threshold before + the task is considered solved + kwargs (dict): The kwargs to pass to the environment class + nondeterministic (bool): Whether this environment is + non-deterministic even after seeding + max_episode_steps (Optional[int]): The maximum number of steps + that an episode can consist of + test_kwargs (Optional[Dict[str, Any]], optional): Dictionary + to specify parameters for automated testing. Defaults to + None. + + """ + super().__init__( + id=id, + entry_point=entry_point, + reward_threshold=reward_threshold, + nondeterministic=nondeterministic, + max_episode_steps=max_episode_steps, + kwargs=kwargs, + ) + self.test_kwargs = test_kwargs + + def __repr__(self) -> str: + return f"MultitaskEnvSpec({self.id})" + + @property + def kwargs(self) -> Dict[str, Any]: + return self._kwargs # type: ignore[no-any-return] + + +class MultiEnvRegistry(EnvRegistry): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def register(self, id: str, **kwargs: Any) -> None: + if id in self.env_specs: + raise error.Error("Cannot re-register id: {}".format(id)) + self.env_specs[id] = MultitaskEnvSpec(id, **kwargs) + + +# Have a global registry +mtenv_registry = MultiEnvRegistry() + + +def register(id: str, **kwargs: Any) -> None: + return mtenv_registry.register(id, **kwargs) + + +def make(id: str, **kwargs: Any) -> Env: + env = mtenv_registry.make(id, **kwargs) + assert isinstance(env, Env) + return env + + +def spec(id: str) -> MultitaskEnvSpec: + spec = mtenv_registry.spec(id) + assert isinstance(spec, MultitaskEnvSpec) + return spec diff --git a/mtenv/envs/shared/__init__.py b/mtenv/envs/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/shared/wrappers/__init__.py b/mtenv/envs/shared/wrappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/shared/wrappers/multienv.py b/mtenv/envs/shared/wrappers/multienv.py new file mode 100644 index 0000000..617c7f4 --- /dev/null +++ b/mtenv/envs/shared/wrappers/multienv.py @@ -0,0 +1,98 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +"""Wrapper to (lazily) construct a multitask environment from a list of + constructors (list of functions to construct the environments).""" + +from typing import Callable, List, Optional + +from gym.core import Env +from gym.spaces.discrete import Discrete as DiscreteSpace + +from mtenv import MTEnv +from mtenv.utils import seeding +from mtenv.utils.types import ActionType, EnvObsType, ObsType, StepReturnType + +EnvBuilderType = Callable[[], Env] +TaskStateType = int +TaskObsType = int + + +class MultiEnvWrapper(MTEnv): + def __init__( + self, + funcs_to_make_envs: List[EnvBuilderType], + initial_task_state: TaskStateType, + ) -> None: + """Wrapper to (lazily) construct a multitask environment from a + list of constructors (list of functions to construct the + environments). + + The wrapper enables activating/slecting any environment (from the + list of environments that can be created) and that environment is + treated as the current task. The environments are created lazily. + + Note that this wrapper is experimental and may change in the future. + + Args: + funcs_to_make_envs (List[EnvBuilderType]): list of constructor + functions to make the environments. + initial_task_state (TaskStateType): intial task/environment + to select. 
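A rough usage sketch for `MultiEnvWrapper` (here `make_env_a` and `make_env_b` stand in for any zero-argument constructors returning gym environments):

    wrapper = MultiEnvWrapper(
        funcs_to_make_envs=[make_env_a, make_env_b],
        initial_task_state=0,   # only make_env_a() is called at construction time
    )
    wrapper.seed(0)
    wrapper.seed_task(0)
    wrapper.set_task_state(1)   # make_env_b() is built lazily on first activation
    obs = wrapper.reset()       # {"env_obs": ..., "task_obs": 1}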
+ """ + self._num_tasks = len(funcs_to_make_envs) + self._funcs_to_make_envs = funcs_to_make_envs + self._envs = [None for _ in range(self._num_tasks)] + self._envs[initial_task_state] = funcs_to_make_envs[initial_task_state]() + self.env: Env = self._envs[initial_task_state] + super().__init__( + action_space=self.env.action_space, + env_observation_space=self.env.observation_space, + task_observation_space=DiscreteSpace(n=self._num_tasks), + ) + self.task_obs: TaskObsType = initial_task_state + + def _make_observation(self, env_obs: EnvObsType) -> ObsType: + return { + "env_obs": env_obs, + "task_obs": self.task_obs, + } + + def step(self, action: ActionType) -> StepReturnType: + env_obs, reward, done, info = self.env.step(action) + return self._make_observation(env_obs=env_obs), reward, done, info + + def get_task_state(self) -> TaskStateType: + return self.task_obs + + def set_task_state(self, task_state: TaskStateType) -> None: + self.task_obs = task_state + if self._envs[task_state] is None: + self._envs[task_state] = self._funcs_to_make_envs[task_state]() + self.env = self._envs[task_state] + + def assert_env_seed_is_set(self) -> None: + """The seed is set during the call to the constructor of self.env""" + pass + + def assert_task_seed_is_set(self) -> None: + assert self.np_random_task is not None, "please call `seed_task()` first" + + def reset(self) -> ObsType: + return self._make_observation(env_obs=self.env.reset()) + + def sample_task_state(self) -> TaskStateType: + self.assert_task_seed_is_set() + task_state = self.np_random_task.randint(self._num_tasks) # type: ignore[union-attr] + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + assert isinstance(task_state, int) + return task_state + + def reset_task_state(self) -> None: + self.set_task_state(task_state=self.sample_task_state()) + + def seed(self, seed: Optional[int] = None) -> List[int]: + self.np_random_env, seed = seeding.np_random(seed) + env_seeds = self.env.seed(seed) + if isinstance(env_seeds, list): + return [seed] + env_seeds + return [seed] diff --git a/mtenv/envs/tabular_mdp/__init__.py b/mtenv/envs/tabular_mdp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mtenv/envs/tabular_mdp/requirements.txt b/mtenv/envs/tabular_mdp/requirements.txt new file mode 100644 index 0000000..08037e7 --- /dev/null +++ b/mtenv/envs/tabular_mdp/requirements.txt @@ -0,0 +1 @@ +scipy>=1.0.0 \ No newline at end of file diff --git a/mtenv/envs/tabular_mdp/setup.py b/mtenv/envs/tabular_mdp/setup.py new file mode 100644 index 0000000..a85c37f --- /dev/null +++ b/mtenv/envs/tabular_mdp/setup.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from pathlib import Path + +import setuptools + +from mtenv.utils.setup_utils import parse_dependency + +env_name = "tabular_mdp" +path = Path(__file__).parent / "requirements.txt" +requirements = parse_dependency(path) + + +setuptools.setup( + name=env_name, + version="1.0.0", + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) diff --git a/mtenv/envs/tabular_mdp/tmdp.py b/mtenv/envs/tabular_mdp/tmdp.py new file mode 100644 index 0000000..9d2dd41 --- /dev/null +++ b/mtenv/envs/tabular_mdp/tmdp.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import scipy.special +from gym import spaces +from gym.utils import seeding + +from mtenv import MTEnv + + +class TMDP(MTEnv): + """Defines a Tabuular MDP where task_state is the reward matrix,transition matrix + reward_matrix is n_states*n_actions and gies the probability of having a reward = +1 when choosing action a in state s (matrix[s,a]) + transition_matrix is n_states*n_actions*n_states and gives the probability of moving to state s' when choosing action a in state s (matrix[s,a,s']) + Args: + MTEnv ([type]): [description] + """ + + def __init__(self, n_states, n_actions): + self.n_states = n_states + self.n_actions = n_actions + + ohigh = np.array([1.0 for n in range(n_states + 1)]) + olow = np.array([0.0 for n in range(n_states + 1)]) + observation_space = spaces.Box(olow, ohigh, dtype=np.float32) + action_space = spaces.Discrete(n_actions) + self.task_state = ( + np.zeros((n_states, n_actions)), + np.zeros((n_states, n_actions, n_states)), + ) + o = self.get_task_obs() + thigh = np.ones((len(o),)) + tlow = np.zeros((len(o),)) + task_space = spaces.Box(tlow, thigh, dtype=np.float32) + super().__init__( + action_space=action_space, + env_observation_space=observation_space, + task_observation_space=task_space, + ) + + # task state is the reward matrix and transition matrix + + def get_task_obs(self): + obs = list(self.task_state[0].flatten()) + list(self.task_state[1].flatten()) + return obs + + def get_task_state(self): + return self.task_state + + def set_task_state(self, task_state): + self.task_state = task_state + + def sample_task_state(self): + raise NotImplementedError + + def seed(self, env_seed): + self.np_random_env, seed = seeding.np_random(env_seed) + return [seed] + + def seed_task(self, task_seed): + self.np_random_task, seed = seeding.np_random(task_seed) + return [seed] + + def step(self, action): + t_reward, t_matrix = self.task_state + reward = 0.0 + + if self.np_random_env.rand() < t_reward[self.state][action]: + reward = 1.0 + self.state = self.np_random_env.multinomial( + 1, t_matrix[self.state][action] + ).argmax() + + obs = np.zeros(self.n_states + 1) + obs[self.state] = 1.0 + obs[-1] = reward + return ( + {"env_obs": list(obs), "task_obs": self.get_task_obs()}, + reward, + False, + {}, + ) + + def reset(self): + self.state = self.np_random_env.randint(self.n_states) + obs = np.zeros(self.n_states + 1) + obs[self.state] = 1.0 + return {"env_obs": list(obs), "task_obs": self.get_task_obs()} + + +class UniformTMDP(TMDP): + def 
__init__(self, n_states, n_actions): + super().__init__(n_states, n_actions) + + def sample_task_state(self): + self.assert_task_seed_is_set() + t_reward = self.np_random_task.rand(self.n_states, self.n_actions) + t_transitions = self.np_random_task.randn( + self.n_states, self.n_actions, self.n_states + ) + t_transitions = scipy.special.softmax(t_transitions, axis=2) + + new_task_state = t_reward, t_transitions + return new_task_state + + +if __name__ == "__main__": + env = UniformTMDP(3, 2) + env.seed(5) + env.seed_task(14) + env.reset_task_state() + obs = env.reset() + done = False + while not done: + action = np.random.randint(env.action_space.n) + obs, rew, done, _ = env.step(action) + print(obs["env_obs"]) diff --git a/mtenv/utils/__init__.py b/mtenv/utils/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/mtenv/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/mtenv/utils/seeding.py b/mtenv/utils/seeding.py new file mode 100644 index 0000000..90d7c6e --- /dev/null +++ b/mtenv/utils/seeding.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import Optional, Tuple + +from gym.utils import seeding +from numpy.random import RandomState + + +def np_random(seed: Optional[int]) -> Tuple[RandomState, int]: + """Set the seed for numpy's random generator. + + Args: + seed (Optional[int]): + + Returns: + Tuple[RandomState, int]: Returns a tuple of random state and seed. + """ + rng, seed = seeding.np_random(seed) + assert isinstance(seed, int) + return rng, seed diff --git a/mtenv/utils/setup_utils.py b/mtenv/utils/setup_utils.py new file mode 100644 index 0000000..a5ad70e --- /dev/null +++ b/mtenv/utils/setup_utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from pathlib import Path +from typing import List + + +def parse_dependency(filepath: Path) -> List[str]: + """Parse python dependencies from a file. + + The list of dependencies is used by `setup.py` files. Lines starting + with "#" are ingored (useful for writing comments). In case the + dependnecy is host using git, the url is parsed and modified to make + suitable for `setup.py` files. + + + Args: + filepath (Path): + + Returns: + List[str]: List of dependencies + """ + dep_list = [] + for dep in open(filepath).read().splitlines(): + if dep.startswith("#"): + continue + key = "#egg=" + if key in dep: + git_link, egg_name = dep.split(key) + dep = f"{egg_name} @ {git_link}" + dep_list.append(dep) + return dep_list diff --git a/mtenv/utils/types.py b/mtenv/utils/types.py new file mode 100644 index 0000000..a615be6 --- /dev/null +++ b/mtenv/utils/types.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import Any, Dict, Tuple, Union + +import numpy as np + +TaskObsType = Union[str, int, float, np.ndarray] +ActionType = Union[str, int, float, np.ndarray] +EnvObsType = Union[np.ndarray] +ObsType = Dict[str, Union[EnvObsType, TaskObsType]] +RewardType = float +DoneType = bool +InfoType = Dict[str, Any] +StepReturnType = Tuple[ObsType, RewardType, DoneType, InfoType] +EnvStepReturnType = Tuple[EnvObsType, RewardType, DoneType, InfoType] +TaskStateType = Any diff --git a/mtenv/wrappers/__init__.py b/mtenv/wrappers/__init__.py new file mode 100644 index 0000000..bd4566a --- /dev/null +++ b/mtenv/wrappers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from mtenv.wrappers.ntasks import NTasks # noqa: F401 +from mtenv.wrappers.ntasks_id import NTasksId # noqa: F401 +from mtenv.wrappers.sample_random_task import SampleRandomTask # noqa: F401 diff --git a/mtenv/wrappers/env_to_mtenv.py b/mtenv/wrappers/env_to_mtenv.py new file mode 100644 index 0000000..e0d3363 --- /dev/null +++ b/mtenv/wrappers/env_to_mtenv.py @@ -0,0 +1,109 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +"""Wrapper to convert an environment into multitask environment.""" +from typing import Any, Dict, List, Optional + +from gym.core import Env +from gym.spaces.space import Space + +from mtenv import MTEnv +from mtenv.utils import seeding +from mtenv.utils.types import ( + ActionType, + EnvObsType, + ObsType, + StepReturnType, + TaskObsType, + TaskStateType, +) + + +class EnvToMTEnv(MTEnv): + def __init__(self, env: Env, task_observation_space: Space) -> None: + """Wrapper to convert an environment into a multitak environment. + + Args: + env (Env): Environment to wrap over. + task_observation_space (Space): Task observation space for the + resulting multitask environment. + """ + + super().__init__( + action_space=env.action_space, + env_observation_space=env.observation_space, + task_observation_space=task_observation_space, + ) + + self.env = env + self.reward_range = self.env.reward_range + self.metadata = self.env.metadata + + @property + def spec(self) -> Any: + return self.env.spec + + @classmethod + def class_name(cls) -> str: + return cls.__name__ + + def _make_observation(self, env_obs: EnvObsType) -> ObsType: + return {"env_obs": env_obs, "task_obs": self.get_task_obs()} + + def get_task_obs(self) -> TaskObsType: + return self._task_obs + + def get_task_state(self) -> TaskStateType: + raise NotImplementedError + + def set_task_state(self, task_state: TaskStateType) -> None: + raise NotImplementedError + + def sample_task_state(self) -> TaskStateType: + raise NotImplementedError + + def reset(self, **kwargs: Dict[str, Any]) -> ObsType: + self.assert_env_seed_is_set() + env_obs = self.env.reset(**kwargs) + return self._make_observation(env_obs=env_obs) + + def reset_task_state(self) -> None: + self.set_task_state(task_state=self.sample_task_state()) + + def step(self, action: ActionType) -> StepReturnType: + env_obs, reward, done, info = self.env.step(action) + return ( + self._make_observation(env_obs=env_obs), + reward, + done, + info, + ) + + def seed(self, seed: Optional[int] = None) -> List[int]: + self.np_random_env, seed = seeding.np_random(seed) + env_seeds = self.env.seed(seed) + if isinstance(env_seeds, list): + return [seed] + env_seeds + return [seed] + + def render(self, mode: str = "human", **kwargs: Dict[str, Any]) -> Any: + """Renders the environment.""" + return self.env.render(mode, **kwargs) + + def close(self) -> Any: + return self.env.close() + + def __str__(self) -> str: + return f"{type(self).__name__}{self.env}" + + def __repr__(self) -> str: + return str(self) + + @property + def unwrapped(self) -> Env: + return self.env.unwrapped + + def __getattr__(self, name: str) -> Any: + if name.startswith("_"): + raise AttributeError( + "attempted to get missing private attribute '{}'".format(name) + ) + return getattr(self.env, name) diff --git a/mtenv/wrappers/multitask.py b/mtenv/wrappers/multitask.py new file mode 100644 index 0000000..1bed14c --- /dev/null +++ b/mtenv/wrappers/multitask.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved
+"""Wrapper to change the behaviour of an existing multitask environment."""
+
+from typing import List, Optional
+
+from numpy.random import RandomState
+
+from mtenv import MTEnv
+from mtenv.utils import seeding
+from mtenv.utils.types import (
+    ActionType,
+    ObsType,
+    StepReturnType,
+    TaskObsType,
+    TaskStateType,
+)
+
+
+class MultiTask(MTEnv):
+    def __init__(self, env: MTEnv):
+        """Wrapper to change the behaviour of an existing multitask environment.
+
+        Args:
+            env (MTEnv): Multitask environment to wrap over.
+        """
+        self.env = env
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self.np_random_env: Optional[RandomState] = None
+        self.np_random_task: Optional[RandomState] = None
+
+    def step(self, action: ActionType) -> StepReturnType:
+        return self.env.step(action)
+
+    def get_task_obs(self) -> TaskObsType:
+        return self.env.get_task_obs()
+
+    def get_task_state(self) -> TaskStateType:
+        return self.env.get_task_state()
+
+    def set_task_state(self, task_state: TaskStateType) -> None:
+        self.env.set_task_state(task_state)
+
+    def assert_env_seed_is_set(self) -> None:
+        """Check that the env seed is set."""
+        assert self.np_random_env is not None, "please call `seed()` first"
+        self.env.assert_env_seed_is_set()
+
+    def assert_task_seed_is_set(self) -> None:
+        """Check that the task seed is set."""
+        assert self.np_random_task is not None, "please call `seed_task()` first"
+        self.env.assert_task_seed_is_set()
+
+    def reset(self) -> ObsType:
+        return self.env.reset()
+
+    def sample_task_state(self) -> TaskStateType:
+        return self.env.sample_task_state()
+
+    def reset_task_state(self) -> None:
+        self.env.reset_task_state()
+
+    def seed(self, seed: Optional[int] = None) -> List[int]:
+        self.np_random_env, seed = seeding.np_random(seed)
+        return [seed] + self.env.seed(seed)
+
+    def seed_task(self, seed: Optional[int] = None) -> List[int]:
+        self.np_random_task, seed = seeding.np_random(seed)
+        return [seed] + self.env.seed_task(seed)
diff --git a/mtenv/wrappers/ntasks.py b/mtenv/wrappers/ntasks.py
new file mode 100644
index 0000000..a017041
--- /dev/null
+++ b/mtenv/wrappers/ntasks.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""Wrapper to fix the number of tasks in an existing multitask environment."""
+
+from typing import List
+
+from mtenv import MTEnv
+from mtenv.utils.types import TaskStateType
+from mtenv.wrappers.multitask import MultiTask
+
+
+class NTasks(MultiTask):
+    def __init__(self, env: MTEnv, n_tasks: int):
+        """Wrapper to fix the number of tasks in an existing multitask
+        environment to `n_tasks`.
+
+        Each task is sampled from this fixed set of `n_tasks` tasks.
+
+        Args:
+            env (MTEnv): Multitask environment to wrap over.
+            n_tasks (int): Number of tasks to sample.
+        """
+        super().__init__(env=env)
+        self.n_tasks = n_tasks
+        self.tasks: List[TaskStateType]
+        self._are_tasks_set = False
+
+    def sample_task_state(self) -> TaskStateType:
+        """Sample a `task_state` from the set of `n_tasks` tasks.
+
+        `task_state` contains all the information that the environment
+        needs to switch to any other task.
+
+        The subclasses, extending this class, should ensure that the task
+        seed is set (by calling `seed_task(int)`) before invoking this
+        method (for reproducibility). It can be done by invoking
+        `self.assert_task_seed_is_set()`.
+
+        Returns:
+            TaskStateType: For more information on `task_state`,
+            refer :ref:`task_state`.
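+
+        Example:
+            An illustrative sketch: sample one of the `n_tasks` cached
+            task states and switch to it (it assumes the `MTCartPole`
+            environment from `mtenv.envs.control.cartpole`, as used in
+            the wrapper tests)::
+
+                from mtenv.envs.control.cartpole import MTCartPole
+                from mtenv.wrappers.ntasks import NTasks
+
+                env = NTasks(env=MTCartPole(), n_tasks=10)
+                env.seed(5)
+                env.seed_task(15)
+                task_state = env.sample_task_state()
+                env.set_task_state(task_state)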
+ """ + self.assert_task_seed_is_set() + if not self._are_tasks_set: + self.tasks = [self.env.sample_task_state() for _ in range(self.n_tasks)] + self._are_tasks_set = True + + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + id_task = self.np_random_task.randint(self.n_tasks) # type: ignore[union-attr] + return self.tasks[id_task] + + def reset_task_state(self) -> None: + """Sample a new task_state from the set of `n_tasks` tasks and + set the environment to that `task_state`. + + For more information on `task_state`, refer :ref:`task_state`. + """ + self.set_task_state(task_state=self.sample_task_state()) diff --git a/mtenv/wrappers/ntasks_id.py b/mtenv/wrappers/ntasks_id.py new file mode 100644 index 0000000..91f16ee --- /dev/null +++ b/mtenv/wrappers/ntasks_id.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +"""Wrapper to fix the number of tasks in an existing multitask environment +and return the id of the task as part of the observation.""" + +from gym.spaces import Dict as DictSpace +from gym.spaces import Discrete + +from mtenv import MTEnv +from mtenv.utils.types import ActionType, ObsType, StepReturnType, TaskStateType +from mtenv.wrappers.ntasks import NTasks + + +class NTasksId(NTasks): + def __init__(self, env: MTEnv, n_tasks: int): + """Wrapper to fix the number of tasks in an existing multitask + environment to `n_tasks`. + + Each task is sampled in this fixed set of `n_tasks`. The agent + observes the id of the task. + + Args: + env (MTEnv): Multitask environment to wrap over. + n_tasks (int): Number of tasks to sample. + """ + self.env = env + + super().__init__(n_tasks=n_tasks, env=env) + self.task_state: TaskStateType + self.observation_space: DictSpace = DictSpace( + spaces={ + "env_obs": self.observation_space["env_obs"], + "task_obs": Discrete(n_tasks), + } + ) + + def _update_obs(self, obs: ObsType) -> ObsType: + obs["task_obs"] = self.get_task_obs() + return obs + + def step(self, action: ActionType) -> StepReturnType: + obs, reward, done, info = self.env.step(action) + return self._update_obs(obs), reward, done, info + + def get_task_obs(self) -> TaskStateType: + return self.task_state + + def get_task_state(self) -> TaskStateType: + return self.task_state + + def set_task_state(self, task_state: TaskStateType) -> None: + self.env.set_task_state(self.tasks[task_state]) + self.task_state = task_state + + def reset(self) -> ObsType: + obs = self.env.reset() + return self._update_obs(obs) + + def sample_task_state(self) -> TaskStateType: + self.assert_task_seed_is_set() + if not self._are_tasks_set: + self.tasks = [self.env.sample_task_state() for _ in range(self.n_tasks)] + self._are_tasks_set = True + + # The assert statement (at the start of the function) ensures that self.np_random_task + # is not None. Mypy is raising the warning incorrectly. + id_task = self.np_random_task.randint(self.n_tasks) # type: ignore[union-attr] + return id_task diff --git a/mtenv/wrappers/sample_random_task.py b/mtenv/wrappers/sample_random_task.py new file mode 100644 index 0000000..2ea3ed4 --- /dev/null +++ b/mtenv/wrappers/sample_random_task.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +"""Wrapper that samples a new task everytime the environment is reset.""" + +from mtenv import MTEnv +from mtenv.utils.types import ObsType +from mtenv.wrappers.multitask import MultiTask + + +class SampleRandomTask(MultiTask): + def __init__(self, env: MTEnv): + """Wrapper that samples a new task everytime the environment is + reset. + + Args: + env (MTEnv): Multitask environment to wrap over. + """ + + super().__init__(env=env) + + def reset(self) -> ObsType: + self.env.reset_task_state() + return self.env.reset() diff --git a/news/.gitignore b/news/.gitignore new file mode 100644 index 0000000..4397c3a --- /dev/null +++ b/news/.gitignore @@ -0,0 +1,2 @@ +!.gitignore + diff --git a/news/_template.rst b/news/_template.rst new file mode 100644 index 0000000..27b2a2c --- /dev/null +++ b/news/_template.rst @@ -0,0 +1,19 @@ +{% for section in sections %} {% if section %} {{section}} + +{% endif %} {% if sections[section] %} {% for category, val in definitions.items() if category in sections[section] and category != 'trivial' %} + +### {{ definitions[category]['name'] }} + +{% if definitions[category]['showcontent'] %} {% for text, values in sections[section][category]|dictsort(by='value') %} - {{ text }}{% if category != 'plugin' and category != 'process' %} ({{ values|sort|join(', ') }}){% endif %} + +{% endfor %} {% else %} - {{ sections[section][category]['']|sort|join(', ') }} + +{% endif %} {% if sections[section][category]|length == 0 %} + +No significant changes. + +{% else %} {% endif %} {% endfor %} {% else %} + +No significant changes. + +{% endif %} {% endfor %} \ No newline at end of file diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..c0186e0 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# type: ignore +import base64 +import os +from pathlib import Path +from typing import List, Set + +import nox +from nox.sessions import Session + +DEFAULT_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] + +PYTHON_VERSIONS = os.environ.get( + "NOX_PYTHON_VERSIONS", ",".join(DEFAULT_PYTHON_VERSIONS) +).split(",") + + +def setup_env(session: Session, name: str) -> None: + env = {} + if name in ["metaworld"]: + key = "CIRCLECI_MJKEY" + if key in os.environ: + # job is running in CI + env[ + "LD_LIBRARY_PATH" + ] = "$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200/bin" + session.install(f".[{name}]", env=env) + + +def setup_mtenv(session: Session) -> None: + key = "CIRCLECI_MJKEY" + if key in os.environ: + # job is running in CI + mjkey = base64.b64decode(os.environ[key]).decode("utf-8") + mjkey_path = "/home/circleci/.mujoco/mjkey.txt" + with open(mjkey_path, "w") as f: + # even if the mjkey exists, we can safely overwrite it. + for line in mjkey: + f.write(line) + session.install("--upgrade", "setuptools", "pip") + session.install(".[dev]") + + +def get_core_paths(root: str) -> List[str]: + """Return all the files/directories that are part of core package. 
+ + In practice, it just excludes the directories in env module""" + paths = [] + for _path in Path(root).iterdir(): + if _path.stem == "envs": + for _env_path in _path.iterdir(): + if _env_path.is_file(): + paths.append(str(_env_path)) + else: + paths.append(str(_path)) + return paths + + +class EnvSetup: + def __init__( + self, name: str, setup_path: Path, supported_python_versions: Set[str] + ) -> None: + self.name = name + self.setup_path = str(setup_path) + self.path = str(setup_path.parent) + self.supported_python_versions = supported_python_versions + + +def parse_setup_file(session: Session, setup_path: Path) -> EnvSetup: + command = ["python", str(setup_path), "--name", "--classifiers"] + classifiers = session.run(*command, silent=True).splitlines() + name = classifiers[0] + python_version_string = "Programming Language :: Python :: " + supported_python_versions = { + stmt.replace(python_version_string, "") + for stmt in classifiers[1:] + if python_version_string in stmt + } + return EnvSetup( + name=name, + setup_path=setup_path, + supported_python_versions=supported_python_versions, + ) + + +def get_all_envsetups(session: Session) -> List[EnvSetup]: + return [ + parse_setup_file(session=session, setup_path=setup_path) + for setup_path in Path("mtenv/envs").glob("**/setup.py") + ] + + +def get_all_env_setup_paths_as_nox_params(): + return [ + nox.param(setup_path, id=setup_path.parent.stem) + for setup_path in Path("mtenv/envs").glob("**/setup.py") + ] + + +def get_supported_envsetups(session: Session) -> List[EnvSetup]: + """Get the list of EnvSetups that can run in a given session.""" + return [ + env_setup + for env_setup in get_all_envsetups(session=session) + if session.python in env_setup.supported_python_versions + ] + + +def get_supported_env_paths(session: Session) -> List[str]: + """Get the list of env_paths that can run in a given session.""" + return [env_setup.path for env_setup in get_supported_envsetups(session=session)] + + +@nox.session(python=PYTHON_VERSIONS) +def lint(session: Session) -> None: + setup_mtenv(session=session) + for _path in ( + get_core_paths(root="mtenv") + + get_core_paths(root="tests") + + get_supported_env_paths(session=session) + ): + session.run("black", "--check", _path) + session.run("flake8", _path) + + +@nox.session(python=PYTHON_VERSIONS) +def mypy(session: Session) -> None: + setup_mtenv(session=session) + for _path in get_core_paths(root="mtenv"): + session.run("mypy", "--strict", _path) + for envsetup in get_supported_envsetups(session=session): + setup_env(session=session, name=envsetup.name) + session.run("mypy", envsetup.path) + + +@nox.session(python=PYTHON_VERSIONS) +def test_wrappers(session) -> None: + setup_mtenv(session=session) + session.run("pytest", "tests/wrappers") + + +@nox.session(python=PYTHON_VERSIONS) +def test_examples(session) -> None: + setup_mtenv(session=session) + session.run("pytest", "tests/examples") + + +@nox.session(python=PYTHON_VERSIONS) +@nox.parametrize("env_setup_path", get_all_env_setup_paths_as_nox_params()) +def test_envs(session, env_setup_path) -> None: + setup_mtenv(session=session) + + envsetup = parse_setup_file(session=session, setup_path=env_setup_path) + + if session.python not in envsetup.supported_python_versions: + print(f"Python {session.python} is not supported by {envsetup.name}") + return + setup_env(session=session, name=envsetup.name) + env = {"NOX_MTENV_ENV_PATH": envsetup.path} + command_for_headless_rendering = [ + "xvfb-run", + "-a", + "-s", + "-screen 0 1024x768x24 -ac 
+extension GLX +render -noreset",
+    ]
+    commands = []
+    key = "CIRCLECI_MJKEY"
+    if key in os.environ and envsetup.name in ["metaworld"]:
+        env["LD_LIBRARY_PATH"] = "$LD_LIBRARY_PATH:/home/circleci/.mujoco/mujoco200/bin"
+    if envsetup.name.startswith("MT-HiPBMDP"):
+        env["PYTHONPATH"] = "mtenv/envs/hipbmdp/local_dm_control_suite"
+    if envsetup.name in ["hipbmdp", "mpte"]:
+        commands = commands + command_for_headless_rendering
+    commands = commands + ["pytest", "tests/envs"]
+    session.run(*commands, env=env)
diff --git a/requirements/base.txt b/requirements/base.txt
new file mode 100644
index 0000000..84618eb
--- /dev/null
+++ b/requirements/base.txt
@@ -0,0 +1,2 @@
+gym>=0.16.0
+numpy>=1.10.4,<1.20
\ No newline at end of file
diff --git a/requirements/dev.txt b/requirements/dev.txt
new file mode 100644
index 0000000..910fa5c
--- /dev/null
+++ b/requirements/dev.txt
@@ -0,0 +1,21 @@
+black==20.8b1
+flake8-bugbear==20.11.1
+flake8-comprehensions==3.3.1
+flake8-docstrings==1.5.0
+flake8==3.8.4
+isort==5.7.0
+mypy==0.790
+nox==2020.8.22
+pre-commit==2.9.3
+pytest==6.2.1
+pytest-xdist==2.2.0
+setuptools==51.1.1
+sphinx-autodoc-annotation==1.0-1
+sphinx-copybutton==0.3.1
+sphinx-rtd-theme==0.5.0
+sphinx==3.4.1
+sphinxcontrib-bibtex==2.1.3
+sphinxcontrib-napoleon==0.7
+twine==3.3.0
+towncrier==19.2.0
+typing-extensions==3.7.4.3
\ No newline at end of file
diff --git a/requirements/docs.txt b/requirements/docs.txt
new file mode 100644
index 0000000..71b3df2
--- /dev/null
+++ b/requirements/docs.txt
@@ -0,0 +1,8 @@
+gym>=0.16.0
+numpy>=1.10.4,<1.20
+sphinx-autodoc-annotation==1.0-1
+sphinx-copybutton==0.3.1
+sphinx-rtd-theme==0.5.0
+sphinx==3.4.1
+sphinxcontrib-bibtex==2.1.3
+sphinxcontrib-napoleon==0.7
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..44cb154
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,56 @@
+[flake8]
+exclude = .git,.nox
+max-line-length = 119
+select = B,C,E,F,W
+ignore=B009,E203,E501,W503
+
+[isort]
+profile=black
+
+[mypy-examples.*]
+ignore_missing_imports = True
+
+[mypy-tests.*]
+ignore_missing_imports = True
+
+[mypy]
+python_version = 3.8
+warn_unused_configs = True
+mypy_path=.
+disallow_untyped_calls = False
+show_error_codes = True
+
+[mypy-dm_control.*]
+ignore_missing_imports = True
+
+[mypy-dmc2gym.*]
+ignore_missing_imports = True
+
+[mypy-gym.*]
+ignore_missing_imports = True
+disallow_subclassing_any = False
+
+[mypy-gym_miniworld.*]
+ignore_missing_imports = True
+
+[mypy-lxml.*]
+ignore_missing_imports = True
+
+[mypy-metaworld.*]
+ignore_missing_imports = True
+
+[mypy-numpy.*]
+ignore_missing_imports = True
+
+[mypy-pytest.*]
+ignore_missing_imports = True
+
+[tool:pytest]
+filterwarnings =
+    ignore:.*Box bound precision lowered by casting to float32.*:UserWarning
+
+[mypy-scipy.*]
+ignore_missing_imports = True
+
+[mypy-setuptools.*]
+ignore_missing_imports = True
\ No newline at end of file
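For reference, the `setup.py` below relies on the same `#egg=` rewriting that `mtenv/utils/setup_utils.py` documents above. A minimal sketch of what `parse_dependency` produces; the file name and the requirement lines here are made up purely for illustration:

# Illustrative only: hypothetical requirements content written to a temp file.
from pathlib import Path

from mtenv.utils.setup_utils import parse_dependency

requirements = "\n".join(
    [
        "# core dependencies",  # comment lines are skipped
        "gym>=0.16.0",  # plain specifiers pass through unchanged
        "git+https://github.com/example/foo.git#egg=foo",  # git deps are rewritten
    ]
)
tmp = Path("example_requirements.txt")
tmp.write_text(requirements)

print(parse_dependency(tmp))
# Expected: ['gym>=0.16.0', 'foo @ git+https://github.com/example/foo.git']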
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..483d8d7
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# type: ignore
+import codecs
+import os.path
+import subprocess
+from pathlib import Path
+
+import setuptools
+
+
+def read(rel_path):
+    here = os.path.abspath(os.path.dirname(__file__))
+    with codecs.open(os.path.join(here, rel_path), "r") as fp:
+        return fp.read()
+
+
+def get_version(rel_path):
+    for line in read(rel_path).splitlines():
+        if line.startswith("__version__"):
+            delim = '"' if '"' in line else "'"
+            return line.split(delim)[1]
+    raise RuntimeError("Unable to find version string.")
+
+
+def parse_dependency(filepath):
+    dep_list = []
+    for dep in open(filepath).read().splitlines():
+        if dep.startswith("#"):
+            continue
+        key = "#egg="
+        if key in dep:
+            git_link, egg_name = dep.split(key)
+            dep = f"{egg_name} @ {git_link}"
+        dep_list.append(dep)
+    return dep_list
+
+
+base_requirements = parse_dependency("requirements/base.txt")
+dev_requirements = base_requirements + parse_dependency("requirements/dev.txt")
+
+
+extras_require = {}
+
+for setup_path in Path("mtenv/envs").glob("**/setup.py"):
+    env_path = setup_path.parent
+    env_name = (
+        subprocess.run(["python", setup_path, "--name"], stdout=subprocess.PIPE)
+        .stdout.decode()
+        .strip()
+    )
+    extras_require[env_name] = base_requirements + parse_dependency(
+        f"{str(env_path)}/requirements.txt"
+    )
+
+extras_require["all"] = list(
+    set([dep for requirements in extras_require.values() for dep in requirements])
+)
+extras_require["dev"] = dev_requirements
+
+with open("README.md", "r") as fh:
+    long_description = fh.read()
+
+setuptools.setup(
+    name="mtenv",
+    version=get_version("mtenv/__init__.py"),
+    author="Shagun Sodhani, Ludovic Denoyer, Pierre-Alexandre Kamienny, Olivier Delalleau",
+    author_email="sshagunsodhani@gmail.com, denoyer@fb.com, pakamienny@fb.com, odelalleau@fb.com",
+    description="MTEnv: MultiTask Environments for Reinforcement Learning",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    install_requires=base_requirements,
+    url="https://github.com/facebookresearch/mtenv",
+    packages=setuptools.find_packages(
+        exclude=["*.tests", "*.tests.*", "tests.*", "tests", "docs", "docsrc"]
+    ),
+    classifiers=[
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires=">=3.6",
+    extras_require=extras_require,
+)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..168f997
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/tests/envs/__init__.py b/tests/envs/__init__.py
new file mode 100644
index 0000000..168f997
--- /dev/null
+++ b/tests/envs/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
diff --git a/tests/envs/registered_env_test.py b/tests/envs/registered_env_test.py
new file mode 100644
index 0000000..7c1d513
--- /dev/null
+++ b/tests/envs/registered_env_test.py
@@ -0,0 +1,74 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved + +import os +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import pytest + +from mtenv import make +from mtenv.envs.registration import MultitaskEnvSpec, mtenv_registry +from tests.utils.utils import validate_mtenv + +ConfigType = Dict[str, Any] + + +def get_env_spec() -> List[Dict[str, MultitaskEnvSpec]]: + mtenv_env_path = os.environ.get("NOX_MTENV_ENV_PATH", "") + if mtenv_env_path == "": + # test all envs + return mtenv_registry.env_specs.items() + else: + # test only those environments which are on NOX_MTENV_ENV_PATH + + mtenv_env_path = str(Path(mtenv_env_path).resolve()) + env_specs = deepcopy(mtenv_registry.env_specs) + for key in list(env_specs.keys()): + entry_point = env_specs[key].entry_point.split(":")[0].replace(".", "/") + if mtenv_env_path not in str(Path(entry_point).resolve()): + env_specs.pop(key) + return env_specs.items() + + +def get_test_kwargs_from_spec(spec: MultitaskEnvSpec, key: str) -> List[Dict[str, Any]]: + if spec.test_kwargs and key in spec.test_kwargs: + return spec.test_kwargs[key] + else: + return [] + + +def get_configs(get_valid_env_args: bool) -> Tuple[ConfigType, ConfigType]: + configs = [] + key = "valid_env_kwargs" if get_valid_env_args else "invalid_env_kwargs" + for env_name, spec in get_env_spec(): + test_config = deepcopy(spec.test_kwargs) + for key_to_pop in ["valid_env_kwargs", "invalid_env_kwargs"]: + if key_to_pop in test_config: + test_config.pop(key_to_pop) + for params in get_test_kwargs_from_spec(spec, key): + env_config = deepcopy(params) + env_config["id"] = env_name + configs.append((env_config, deepcopy(test_config))) + if get_valid_env_args: + env_config = deepcopy(spec.kwargs) + env_config["id"] = env_name + configs.append((env_config, deepcopy(test_config))) + return configs + + +@pytest.mark.parametrize( + "env_config, test_config", get_configs(get_valid_env_args=True) +) +def test_registered_env_with_valid_input(env_config, test_config): + env = make(**env_config) + validate_mtenv(env=env, **test_config) + + +@pytest.mark.parametrize( + "env_config, test_config", get_configs(get_valid_env_args=False) +) +def test_registered_env_with_invalid_input(env_config, test_config): + with pytest.raises(Exception): + env = make(**env_config) + validate_mtenv(env=env, **test_config) diff --git a/tests/examples/__init__.py b/tests/examples/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/tests/examples/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/examples/bandit_test.py b/tests/examples/bandit_test.py new file mode 100644 index 0000000..614f236 --- /dev/null +++ b/tests/examples/bandit_test.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + + +from typing import List + +import pytest + +from examples.bandit import BanditEnv # noqa: E402 +from tests.utils.utils import validate_single_task_env + + +def get_valid_n_arms() -> List[int]: + return [1, 10, 100] + + +def get_invalid_n_arms() -> List[int]: + return [-1, 0] + + +@pytest.mark.parametrize("n_arms", get_valid_n_arms()) +def test_n_arm_bandit_with_valid_input(n_arms): + env = BanditEnv(n_arms=n_arms) + env.seed(seed=5) + validate_single_task_env(env) + + +@pytest.mark.parametrize("n_arms", get_invalid_n_arms()) +def test_n_arm_bandit_with_invalid_input(n_arms): + with pytest.raises(Exception): + env = BanditEnv(n_arms=n_arms) + env.seed(seed=5) + validate_single_task_env(env) diff --git a/tests/examples/finite_mtenv_bandit_test.py b/tests/examples/finite_mtenv_bandit_test.py new file mode 100644 index 0000000..d6bbe12 --- /dev/null +++ b/tests/examples/finite_mtenv_bandit_test.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import List + +import pytest + +from examples.finite_mtenv_bandit import FiniteMTBanditEnv # noqa: E402 +from tests.utils.utils import validate_mtenv + + +def get_valid_n_tasks_and_arms() -> List[int]: + return [(1, 2), (10, 20), (100, 200)] + + +def get_invalid_n_tasks_and_arms() -> List[int]: + return [(-1, 2), (0, 3), (1, -2), (3, 0)] + + +@pytest.mark.parametrize("n_tasks, n_arms", get_valid_n_tasks_and_arms()) +def test_mtenv_bandit_with_valid_input(n_tasks, n_arms): + env = FiniteMTBanditEnv(n_tasks=n_tasks, n_arms=n_arms) + validate_mtenv(env=env) + + +@pytest.mark.parametrize("n_tasks, n_arms", get_invalid_n_tasks_and_arms()) +def test_mtenv_bandit_with_invalid_input(n_tasks, n_arms): + with pytest.raises(Exception): + env = FiniteMTBanditEnv(n_tasks=n_tasks, n_arms=n_arms) + validate_mtenv(env=env) diff --git a/tests/examples/mtenv_bandit_test.py b/tests/examples/mtenv_bandit_test.py new file mode 100644 index 0000000..e596a07 --- /dev/null +++ b/tests/examples/mtenv_bandit_test.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from typing import List + +import pytest + +from examples.mtenv_bandit import MTBanditEnv # noqa: E402 +from tests.utils.utils import validate_mtenv + + +def get_valid_n_arms() -> List[int]: + return [1, 10, 100] + + +def get_invalid_n_arms() -> List[int]: + return [-1, 0] + + +@pytest.mark.parametrize("n_arms", get_valid_n_arms()) +def test_ntasks_id_wrapper_with_valid_input(n_arms): + env = MTBanditEnv(n_arms=n_arms) + validate_mtenv(env=env) + + +@pytest.mark.parametrize("n_arms", get_invalid_n_arms()) +def test_ntasks_id_wrapper_with_invalid_input(n_arms): + with pytest.raises(Exception): + env = MTBanditEnv(n_arms=n_arms) + validate_mtenv(env=env) diff --git a/tests/examples/wrapped_bandit_test.py b/tests/examples/wrapped_bandit_test.py new file mode 100644 index 0000000..795043f --- /dev/null +++ b/tests/examples/wrapped_bandit_test.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +from typing import List + +import pytest +from gym import spaces + +from examples.bandit import BanditEnv # noqa: E402 +from examples.wrapped_bandit import MTBanditWrapper # noqa: E402 +from tests.utils.utils import validate_mtenv + + +def get_valid_n_arms() -> List[int]: + return [1, 10, 100] + + +def get_invalid_n_arms() -> List[int]: + return [-1, 0] + + +@pytest.mark.parametrize("n_arms", get_valid_n_arms()) +def test_ntasks_id_wrapper_with_valid_input(n_arms): + + env = MTBanditWrapper( + env=BanditEnv(n_arms), + task_observation_space=spaces.Box(low=0.0, high=1.0, shape=(n_arms,)), + ) + + validate_mtenv(env=env) + + +@pytest.mark.parametrize("n_arms", get_invalid_n_arms()) +def test_ntasks_id_wrapper_with_invalid_input(n_arms): + with pytest.raises(Exception): + env = MTBanditWrapper( + env=BanditEnv(n_arms), + task_observation_space=spaces.Box(low=0.0, high=1.0, shape=(n_arms,)), + ) + validate_mtenv(env=env) diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 0000000..54cb2d2 --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +from typing import Tuple + +import gym +import numpy as np + +from mtenv import MTEnv +from mtenv.utils.types import ( + DoneType, + EnvObsType, + InfoType, + ObsType, + RewardType, + StepReturnType, +) + +StepReturnTypeSingleEnv = Tuple[EnvObsType, RewardType, DoneType, InfoType] + + +def validate_obs_type(obs: ObsType): + assert isinstance(obs, dict) + assert "env_obs" in obs + assert "task_obs" in obs + + +def validate_step_return_type(step_return: StepReturnType): + obs, reward, done, info = step_return + validate_obs_type(obs) + assert isinstance(reward, (float, int)) + assert isinstance(done, bool) + assert isinstance(info, dict) + + +def valiate_obs_type_single_env(obs: EnvObsType): + assert isinstance(obs, np.ndarray) + + +def validate_step_return_type_single_env(step_return: StepReturnType): + obs, reward, done, info = step_return + valiate_obs_type_single_env(obs) + assert isinstance(reward, float) + assert isinstance(done, bool) + assert isinstance(info, dict) + + +def validate_mtenv(env: MTEnv) -> None: + env.seed(5) + env.assert_env_seed_is_set() + env.seed_task(15) + env.assert_task_seed_is_set() + for _env_index in range(10): + env.reset_task_state() + obs = env.reset() + validate_obs_type(obs) + for _step_index in range(3): + action = env.action_space.sample() + step_return = env.step(action) + validate_step_return_type(step_return) + + +def validate_single_task_env(env: gym.Env) -> None: + for _episode in range(10): + obs = env.reset() + valiate_obs_type_single_env(obs) + for _ in range(3): + action = env.action_space.sample() + step_return = env.step(action) + validate_step_return_type_single_env(step_return) diff --git a/tests/wrappers/__init__.py b/tests/wrappers/__init__.py new file mode 100644 index 0000000..168f997 --- /dev/null +++ b/tests/wrappers/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/tests/wrappers/ntasks_id_test.py b/tests/wrappers/ntasks_id_test.py new file mode 100644 index 0000000..b8e00ac --- /dev/null +++ b/tests/wrappers/ntasks_id_test.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved + + +from typing import List + +import pytest + +from mtenv.envs.control.cartpole import MTCartPole +from mtenv.wrappers.ntasks_id import NTasksId as NTasksIdWrapper +from tests.utils.utils import validate_mtenv + + +def get_valid_num_tasks() -> List[int]: + return [1, 10, 100] + + +def get_invalid_num_tasks() -> List[int]: + return [-1, 0] + + +@pytest.mark.parametrize("n_tasks", get_valid_num_tasks()) +def test_ntasks_id_wrapper_with_valid_input(n_tasks): + env = MTCartPole() + env = NTasksIdWrapper(env, n_tasks=n_tasks) + validate_mtenv(env=env) + + +@pytest.mark.parametrize("n_tasks", get_invalid_num_tasks()) +def test_ntasks_id_wrapper_with_invalid_input(n_tasks): + with pytest.raises(Exception): + env = MTCartPole() + env = NTasksIdWrapper(env, n_tasks=n_tasks) + validate_mtenv(env=env) diff --git a/tests/wrappers/ntasks_test.py b/tests/wrappers/ntasks_test.py new file mode 100644 index 0000000..e073b3f --- /dev/null +++ b/tests/wrappers/ntasks_test.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + + +from typing import List + +import pytest + +from mtenv.envs.control.cartpole import MTCartPole +from mtenv.wrappers.ntasks import NTasks as NTasksWrapper +from tests.utils.utils import validate_mtenv + + +def get_valid_num_tasks() -> List[int]: + return [1, 10, 100] + + +def get_invalid_num_tasks() -> List[int]: + return [-1, 0] + + +@pytest.mark.parametrize("n_tasks", get_valid_num_tasks()) +def test_ntasks_wrapper_with_valid_input(n_tasks): + env = MTCartPole() + env = NTasksWrapper(env, n_tasks=n_tasks) + validate_mtenv(env=env) + + +@pytest.mark.parametrize("n_tasks", get_invalid_num_tasks()) +def test_ntasks_wrapper_with_invalid_input(n_tasks): + with pytest.raises(Exception): + env = MTCartPole() + env = NTasksWrapper(env, n_tasks=n_tasks) + validate_mtenv(env=env) diff --git a/towncrier.toml b/towncrier.toml new file mode 100644 index 0000000..1a3f671 --- /dev/null +++ b/towncrier.toml @@ -0,0 +1,41 @@ +[tool.towncrier] +package = "mtenv" +package_dir = "" +filename = "NEWS.md" +directory = "news/" +title_format = "{version} ({project_date})" +issue_format = "[#{issue}](https://github.com/facebookresearch/mtenv/issues/{issue})" +template = "news/_template.rst" +start_string = "\n" + + +[[tool.towncrier.type]] +directory = "api_change" +name = "API Changes" +showcontent = true + +[[tool.towncrier.type]] +directory = "bugfix" +name = "Bug Fixes" +showcontent = true + +[[tool.towncrier.type]] +directory = "doc" +name = "Documentation Changes" +showcontent = true + +[[tool.towncrier.type]] +directory = "environment" +name = "Environment Chages (addition or removal)" +showcontent = true + +[[tool.towncrier.type]] +directory = "feature" +name = "Features" +showcontent = true + + +[[tool.towncrier.type]] +directory = "misc" +name = "Miscellaneous Changes" +showcontent = true \ No newline at end of file
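Taken together, the pieces in this patch are meant to be composed roughly as the tests above do. A minimal usage sketch, assuming the `MTCartPole` environment and the `NTasksId` wrapper exercised in `tests/wrappers`; the seeds and counts below are arbitrary:

# Illustrative sketch mirroring tests/utils/utils.py, not a prescribed workflow.
from mtenv.envs.control.cartpole import MTCartPole
from mtenv.wrappers.ntasks_id import NTasksId

env = NTasksId(MTCartPole(), n_tasks=10)  # fix the task set to 10 tasks
env.seed(5)        # seed the per-environment RNG
env.seed_task(15)  # seed the task-sampling RNG

for _episode in range(3):
    env.reset_task_state()  # sample a task id in [0, n_tasks) and switch to it
    obs = env.reset()       # obs == {"env_obs": <state>, "task_obs": <task id>}
    for _step in range(5):
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        if done:
            break

As in `tests/utils/utils.py`, seeding both generators before touching the task state is what keeps the sampled tasks and trajectories reproducible.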