Skip to content

Commit

Permalink
fix support for preconstructed gym environments (#170)
Browse files Browse the repository at this point in the history
* fix support for preconstructed gym environments

* add documentation to GymEnvironment

* whitespace
  • Loading branch information
cpnota committed Sep 29, 2020
1 parent 72a5ffc commit c68570f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
19 changes: 18 additions & 1 deletion all/environments/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,27 @@
gym.logger.set_level(40)

class GymEnvironment(Environment):
'''
A wrapper for OpenAI Gym environments (see: https://gym.openai.com).
This wrapper converts the output of the gym environment to PyTorch tensors,
and wraps them in a State object that can be passed to an Agent.
This constructor supports either a string, which will be passed to the
gym.make(name) function, or a preconstructed gym environment. Note that
in the latter case, the name property is set to be the whatever the name
of the outermost wrapper on the environment is.
Args:
env: Either a string or an OpenAI gym environment
device (optional): the device on which tensors will be stored
'''
def __init__(self, env, device=torch.device('cpu')):
self._name = env
if isinstance(env, str):
self._name = env
env = gym.make(env)
else:
self._name = env.__class__.__name__

self._env = env
self._state = None
self._action = None
Expand Down
17 changes: 17 additions & 0 deletions all/environments/gym_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import unittest
import gym
from all.environments import GymEnvironment


class GymEnvironmentTest(unittest.TestCase):
def test_env_name(self):
env = GymEnvironment('CartPole-v0')
self.assertEqual(env.name, 'CartPole-v0')

def test_preconstructed_env_name(self):
env = GymEnvironment(gym.make('Blackjack-v0'))
self.assertEqual(env.name, 'BlackjackEnv')

def test_reset(self):
env = GymEnvironment('CartPole-v0')
state = env.reset()
Expand All @@ -11,6 +20,14 @@ def test_reset(self):
self.assertFalse(state.done)
self.assertEqual(state.mask, 1)

def test_reset_preconstructed_env(self):
env = GymEnvironment(gym.make('CartPole-v0'))
state = env.reset()
self.assertEqual(state.observation.shape, (4,))
self.assertEqual(state.reward, 0)
self.assertFalse(state.done)
self.assertEqual(state.mask, 1)

def test_step(self):
env = GymEnvironment('CartPole-v0')
env.reset()
Expand Down

0 comments on commit c68570f

Please sign in to comment.