Skip to content

Commit

Permalink
Refactor/pybullet import (#236)
Browse files Browse the repository at this point in the history
* add PybulletEnvironment class

* refactor the way pybullet imports work

* PybulletEnvironment accept kwargs

* loosen constraint on pybullet env test
  • Loading branch information
cpnota committed Mar 22, 2021
1 parent 230d0fb commit 111cd83
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 39 deletions.
2 changes: 2 additions & 0 deletions all/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .atari import AtariEnvironment
from .multiagent_atari import MultiagentAtariEnv
from .multiagent_pettingzoo import MultiagentPettingZooEnv
from .pybullet import PybulletEnvironment

__all__ = [
"Environment",
Expand All @@ -12,4 +13,5 @@
"AtariEnvironment",
"MultiagentAtariEnv",
"MultiagentPettingZooEnv",
"PybulletEnvironment",
]
8 changes: 5 additions & 3 deletions all/environments/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ class GymEnvironment(Environment):
Args:
env: Either a string or an OpenAI gym environment
device (optional): the device on which tensors will be stored
name (str, optional): the name of the environment
device (str, optional): the device on which tensors will be stored
'''

def __init__(self, env, device=torch.device('cpu')):
def __init__(self, env, name=None, device=torch.device('cpu')):
if isinstance(env, str):
self._name = env
env = gym.make(env)
else:
self._name = env.__class__.__name__

if name:
self._name = name
self._env = env
self._state = None
self._action = None
Expand Down
17 changes: 17 additions & 0 deletions all/environments/pybullet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .gym import GymEnvironment


class PybulletEnvironment(GymEnvironment):
    '''
    A GymEnvironment specialised for the PyBullet robotics environments.

    The environment may be selected either by one of the aliases listed in
    ``short_names`` (e.g. ``"cheetah"``) or by any other identifier, which
    is handed to GymEnvironment unchanged.

    Args:
        name: A key of ``short_names`` or an identifier accepted by
            GymEnvironment (e.g. ``"HalfCheetahBulletEnv-v0"``).
        **kwargs: Forwarded as-is to ``GymEnvironment.__init__``
            (for example ``device``).
    '''

    # Alias -> full environment id.
    short_names = {
        "ant": "AntBulletEnv-v0",
        "cheetah": "HalfCheetahBulletEnv-v0",
        "humanoid": "HumanoidBulletEnv-v0",
        "hopper": "HopperBulletEnv-v0",
        "walker": "Walker2DBulletEnv-v0"
    }

    def __init__(self, name, **kwargs):
        # Imported only for its side effects, and deferred to construction
        # time so that importing this module does not itself need pybullet.
        # NOTE(review): presumably this import is what registers the
        # *BulletEnv-v0 ids with gym -- confirm against pybullet_envs.
        import pybullet_envs  # pylint: disable=unused-import

        # Translate an alias to its full id; anything else passes through.
        env_id = self.short_names.get(name, name)
        super().__init__(env_id, **kwargs)
33 changes: 33 additions & 0 deletions all/environments/pybullet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest
from all.environments import PybulletEnvironment, GymEnvironment


class PybulletEnvironmentTest(unittest.TestCase):
    '''Checks name resolution, reset and step of PybulletEnvironment.'''

    def test_env_short_name(self):
        '''Each alias in short_names yields an env named by its full id.'''
        for alias in PybulletEnvironment.short_names:
            expected_name = PybulletEnvironment.short_names[alias]
            created = PybulletEnvironment(alias)
            self.assertEqual(created.name, expected_name)

    def test_env_full_name(self):
        '''A full environment id is accepted and kept as the env name.'''
        full_id = 'HalfCheetahBulletEnv-v0'
        created = PybulletEnvironment(full_id)
        self.assertEqual(created.name, 'HalfCheetahBulletEnv-v0')

    def test_reset(self):
        '''reset() returns a non-terminal state with zero reward.'''
        cheetah = PybulletEnvironment('cheetah')
        initial = cheetah.reset()
        self.assertEqual(initial.observation.shape, (26,))
        self.assertEqual(initial.reward, 0.)
        self.assertFalse(initial.done)
        self.assertEqual(initial.mask, 1)

    def test_step(self):
        '''One sampled action gives a non-terminal state with a small, nonzero reward.'''
        cheetah = PybulletEnvironment('cheetah')
        cheetah.seed(0)
        cheetah.reset()
        action = cheetah.action_space.sample()
        after_step = cheetah.step(action)
        self.assertEqual(after_step.observation.shape, (26,))
        # The exact reward is not pinned; only that it is nonzero
        # and lies strictly between -1 and 1.
        self.assertGreater(after_step.reward, -1.)
        self.assertLess(after_step.reward, 1)
        self.assertNotEqual(after_step.reward, 0.)
        self.assertFalse(after_step.done)
        self.assertEqual(after_step.mask, 1)
12 changes: 2 additions & 10 deletions benchmarks/pybullet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pybullet
import pybullet_envs
from all.experiments import SlurmExperiment
from all.presets.continuous import ddpg, ppo, sac
from all.environments import GymEnvironment
from all.environments import PybulletEnvironment


def main():
Expand All @@ -16,13 +14,7 @@ def main():
sac
]

envs = [GymEnvironment(env, device) for env in [
'AntBulletEnv-v0',
"HalfCheetahBulletEnv-v0",
'HumanoidBulletEnv-v0',
'HopperBulletEnv-v0',
'Walker2DBulletEnv-v0'
]]
envs = [PybulletEnvironment(env, device) for env in PybulletEnvironment.short_names]

SlurmExperiment(agents, envs, frames, sbatch_args={
'partition': '1080ti-long'
Expand Down
5 changes: 2 additions & 3 deletions docs/source/guide/basic_concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ Below, we show how several different types of environments can be created:

.. code-block:: python
from all.environments import AtariEnvironment, GymEnvironment
from all.environments import AtariEnvironment, GymEnvironment, PybulletEnvironment
# create an Atari environment on the gpu
env = AtariEnvironment('Breakout', device='cuda')
Expand All @@ -190,8 +190,7 @@ Below, we show how several different types of environments can be created:
env = GymEnvironment('CartPole-v0')
# create a PyBullet environment on the cpu
import pybullet_envs
env = GymEnvironment('HalfCheetahBulletEnv-v0')
env = PybulletEnvironment('cheetah')
Now we can write our first control loop:

Expand Down
25 changes: 8 additions & 17 deletions scripts/continuous.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
# pylint: disable=unused-import
import argparse
import pybullet
import pybullet_envs
from all.environments import GymEnvironment
from all.environments import GymEnvironment, PybulletEnvironment
from all.experiments import run_experiment
from all.presets import continuous

# some example envs
# can also enter ID directly

# see also: PybulletEnvironment.short_names
ENVS = {
# classic continuous environments
"mountaincar": "MountainCarContinuous-v0",
"lander": "LunarLanderContinuous-v2",
# Bullet robotics environments
"ant": "AntBulletEnv-v0",
"cheetah": "HalfCheetahBulletEnv-v0",
"humanoid": "HumanoidBulletEnv-v0",
"hopper": "HopperBulletEnv-v0",
"walker": "Walker2DBulletEnv-v0"
}


def main():
parser = argparse.ArgumentParser(description="Run a continuous actions benchmark.")
parser.add_argument("env", help="Name of the env (see envs)")
parser.add_argument("env", help="Name of the env (e.g. 'lander', 'cheetah')")
parser.add_argument(
"agent", help="Name of the agent (e.g. ddpg). See presets for available agents."
)
Expand Down Expand Up @@ -51,11 +42,11 @@ def main():
args = parser.parse_args()

if args.env in ENVS:
env_id = ENVS[args.env]
env = GymEnvironment(args.env, device=args.device)
elif 'BulletEnv' in args.env or args.env in PybulletEnvironment.short_names:
env = PybulletEnvironment(args.env, device=args.device)
else:
env_id = args.env

env = GymEnvironment(env_id, device=args.device)
env = GymEnvironment(args.env, device=args.device)

agent_name = args.agent
agent = getattr(continuous, agent_name)
Expand Down
11 changes: 5 additions & 6 deletions scripts/watch_continuous.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# pylint: disable=unused-import
import argparse
import pybullet
import pybullet_envs
from all.bodies import TimeFeature
from all.environments import GymEnvironment
from all.environments import GymEnvironment, PybulletEnvironment
from all.experiments import load_and_watch
from .continuous import ENVS

Expand All @@ -25,11 +23,12 @@ def main():
args = parser.parse_args()

if args.env in ENVS:
env_id = ENVS[args.env]
env = GymEnvironment(args.env, device=args.device)
elif 'BulletEnv' in args.env or args.env in PybulletEnvironment.short_names:
env = PybulletEnvironment(args.env, device=args.device)
else:
env_id = args.env
env = GymEnvironment(args.env, device=args.device)

env = GymEnvironment(env_id, device=args.device)
load_and_watch(args.filename, env, fps=args.fps)


Expand Down

0 comments on commit 111cd83

Please sign in to comment.