Feature/vec env (#239)
* added vector env environment

* various additions and fixes to vector env

* added duplicate env, tested

* added vector env similarity test

* fixed test bug

* fixed test

* fixed type problem

* fixed linting

* fixed FPS

* fixed vector env problem

* fixed asserts
benblack769 committed Mar 24, 2021
1 parent edaa476 commit 5ee29ea
Showing 10 changed files with 482 additions and 86 deletions.
7 changes: 6 additions & 1 deletion all/environments/__init__.py
@@ -1,9 +1,12 @@
from ._environment import Environment
from._multiagent_environment import MultiagentEnvironment
from ._multiagent_environment import MultiagentEnvironment
from ._vector_environment import VectorEnvironment
from .gym import GymEnvironment
from .atari import AtariEnvironment
from .multiagent_atari import MultiagentAtariEnv
from .multiagent_pettingzoo import MultiagentPettingZooEnv
from .duplicate_env import DuplicateEnvironment
from .vector_env import GymVectorEnvironment
from .pybullet import PybulletEnvironment

__all__ = [
@@ -13,5 +16,7 @@
"AtariEnvironment",
"MultiagentAtariEnv",
"MultiagentPettingZooEnv",
"GymVectorEnvironment",
"DuplicateEnvironment",
"PybulletEnvironment",
]
116 changes: 116 additions & 0 deletions all/environments/_vector_environment.py
@@ -0,0 +1,116 @@
from abc import ABC, abstractmethod


class VectorEnvironment(ABC):
"""
A reinforcement learning vector environment.
Similar to a regular RL environment, except that many environments are stacked
together in the observations, rewards, and dones, and step() expects one action
per sub-environment.
Also, because sub-environments finish at different times, you do not need to
reset them manually when they are done; the vector environment automatically
resets each sub-environment once its episode is complete.
"""

@property
@abstractmethod
def name(self):
"""
The name of the environment.
"""

@abstractmethod
def reset(self):
"""
Reset the environment and return a new initial state.
Returns
-------
State
The initial state for the next episode.
"""

@abstractmethod
def step(self, action):
"""
Apply an action and get the next state.
Parameters
----------
action : Action
A batch of actions, one for each sub-environment, to apply at the current time step.
Returns
-------
all.core.StateArray
The StateArray of the vector environment after the actions are applied.
This StateArray contains the observations, rewards, done flags, and any
additional "info" for every sub-environment.
"""

@abstractmethod
def close(self):
"""
Clean up any extraneous environment objects.
"""

@property
@abstractmethod
def state_array(self):
"""
A StateArray of the Environments at the current timestep.
"""

@property
@abstractmethod
def state_space(self):
"""
The Space representing the range of observable states for each environment.
Returns
-------
Space
An object of type Space that represents possible states the agent may observe
"""

@property
def observation_space(self):
"""
Alias for Environment.state_space.
Returns
-------
Space
An object of type Space that represents possible states the agent may observe
"""
return self.state_space

@property
@abstractmethod
def action_space(self):
"""
The Space representing the range of possible actions for each environment.
Returns
-------
Space
An object of type Space that represents possible actions the agent may take
"""

@property
@abstractmethod
def device(self):
"""
The torch device the environment lives on.
"""

@property
@abstractmethod
def num_envs(self):
"""
The number of sub-environments in the vector. This is the number of actions that step()
expects as input, and the number of observations, rewards, dones, etc. returned by the environment.
"""
5 changes: 3 additions & 2 deletions all/environments/atari.py
@@ -8,6 +8,7 @@
LifeLostEnv,
)
from all.core import State
from .duplicate_env import DuplicateEnvironment


class AtariEnvironment(GymEnvironment):
@@ -38,6 +39,6 @@ def reset(self):
return self._state

def duplicate(self, n):
return [
return DuplicateEnvironment([
AtariEnvironment(self._name, *self._args, **self._kwargs) for _ in range(n)
]
])
73 changes: 73 additions & 0 deletions all/environments/duplicate_env.py
@@ -0,0 +1,73 @@
import gym
import torch
from all.core import State
from ._vector_environment import VectorEnvironment
import numpy as np


class DuplicateEnvironment(VectorEnvironment):
'''
Turns a list of ALL Environment objects into a VectorEnvironment object.
This wrapper simply takes the list of States the sub-environments generate and outputs
a StateArray object containing all of the environment states. Like all vector
environments, the sub-environments are automatically reset when done.
Args:
envs: A list of ALL environments
device (optional): the device on which tensors will be stored
'''

def __init__(self, envs, device=torch.device('cpu')):
self._name = envs[0].name
self._envs = envs
self._state = None
self._action = None
self._reward = None
self._done = True
self._info = None
self._device = device

@property
def name(self):
return self._name

def reset(self):
self._state = State.array([sub_env.reset() for sub_env in self._envs])
return self._state

def step(self, actions):
states = []
actions = actions.cpu().detach().numpy()
for sub_env, action in zip(self._envs, actions):
state = sub_env.reset() if sub_env.state.done else sub_env.step(action)
states.append(state)
self._state = State.array(states)
return self._state

def close(self):
# close every wrapped sub-environment (this wrapper holds a list of envs, not a single env)
for sub_env in self._envs:
sub_env.close()

def seed(self, seed):
for i, env in enumerate(self._envs):
env.seed(seed + i)

@property
def state_space(self):
return self._envs[0].observation_space

@property
def action_space(self):
return self._envs[0].action_space

@property
def state_array(self):
return self._state

@property
def device(self):
return self._device

@property
def num_envs(self):
return len(self._envs)
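A short sketch of the auto-reset behaviour described in the docstring, mirroring the CartPole-v0 setup in the tests below (the fixed action and step budget are illustrative): once a sub-environment reports done, the next call to step() resets it instead of stepping it.

import torch
from all.environments import DuplicateEnvironment, GymEnvironment

env = DuplicateEnvironment([GymEnvironment('CartPole-v0') for _ in range(2)])
env.seed(0)
env.reset()
actions = torch.zeros(env.num_envs, dtype=torch.int32)
for _ in range(200):
    state = env.step(actions)
    if state.done.any():
        break
# the sub-environment that just finished is reset on this call, so it reports
# a fresh initial state with done == False and mask == 1
state = env.step(actions)
print(state.done, state.mask)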
54 changes: 54 additions & 0 deletions all/environments/duplicate_env_test.py
@@ -0,0 +1,54 @@
import unittest
import gym
import torch
from all.environments import DuplicateEnvironment, GymEnvironment


def make_vec_env(num_envs=3):
env = [GymEnvironment('CartPole-v0') for i in range(num_envs)]
return env


class DuplicateEnvironmentTest(unittest.TestCase):
def test_env_name(self):
env = DuplicateEnvironment(make_vec_env())
self.assertEqual(env.name, 'CartPole-v0')

def test_num_envs(self):
num_envs = 5
env = DuplicateEnvironment(make_vec_env(num_envs))
self.assertEqual(env.num_envs, num_envs)
self.assertEqual((num_envs,), env.reset().shape)

def test_reset(self):
num_envs = 5
env = DuplicateEnvironment(make_vec_env(num_envs))
state = env.reset()
self.assertEqual(state.observation.shape, (num_envs, 4))
self.assertTrue((state.reward == torch.zeros(num_envs, )).all())
self.assertTrue((state.done == torch.zeros(num_envs, )).all())
self.assertTrue((state.mask == torch.ones(num_envs, )).all())

def test_step(self):
num_envs = 5
env = DuplicateEnvironment(make_vec_env(num_envs))
env.reset()
state = env.step(torch.ones(num_envs, dtype=torch.int32))
self.assertEqual(state.observation.shape, (num_envs, 4))
self.assertTrue((state.reward == torch.ones(num_envs, )).all())
self.assertTrue((state.done == torch.zeros(num_envs, )).all())
self.assertTrue((state.mask == torch.ones(num_envs, )).all())

def test_step_until_done(self):
num_envs = 3
env = DuplicateEnvironment(make_vec_env(num_envs))
env.seed(5)
env.reset()
for _ in range(100):
state = env.step(torch.ones(num_envs, dtype=torch.int32))
if state.done[0]:
break
self.assertEqual(state[0].observation.shape, (4,))
self.assertEqual(state[0].reward, 1.)
self.assertTrue(state[0].done)
self.assertEqual(state[0].mask, 0)
3 changes: 2 additions & 1 deletion all/environments/gym.py
@@ -2,6 +2,7 @@
import torch
from all.core import State
from ._environment import Environment
from .duplicate_env import DuplicateEnvironment
import cloudpickle
gym.logger.set_level(40)

@@ -66,7 +67,7 @@ def seed(self, seed):
self._env.seed(seed)

def duplicate(self, n):
return [GymEnvironment(cloudpickle.loads(cloudpickle.dumps(self._env)), device=self.device) for _ in range(n)]
return DuplicateEnvironment([GymEnvironment(cloudpickle.loads(cloudpickle.dumps(self._env)), device=self.device) for _ in range(n)])

@property
def state_space(self):
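With this change, duplicate(n) hands back a ready-to-use DuplicateEnvironment rather than a bare list of copies; a minimal sketch (the environment id and n=3 are illustrative):

import torch
from all.environments import GymEnvironment

env = GymEnvironment('CartPole-v0')
vec_env = env.duplicate(3)       # now a DuplicateEnvironment wrapping 3 cloudpickled copies
state = vec_env.reset()          # StateArray of shape (3,)
state = vec_env.step(torch.zeros(3, dtype=torch.int32))
print(vec_env.num_envs, state.observation.shape)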
85 changes: 85 additions & 0 deletions all/environments/vector_env.py
@@ -0,0 +1,85 @@
import gym
import torch
from all.core import StateArray
from ._vector_environment import VectorEnvironment
import cloudpickle
import numpy as np


class GymVectorEnvironment(VectorEnvironment):
'''
A wrapper for Gym's vector environments
(see: https://github.com/openai/gym/blob/master/gym/vector/vector_env.py).
This wrapper converts the output of the vector environment to PyTorch tensors,
and wraps them in a StateArray object that can be passed to a Parallel Agent.
The constructor accepts a preconstructed Gym vector environment together with a
name used to identify it.
Args:
vec_env: An OpenAI Gym vector environment
name: the name of the environment
device (optional): the device on which tensors will be stored
'''

def __init__(self, vec_env, name, device=torch.device('cpu')):
self._name = name
self._env = vec_env
self._state = None
self._action = None
self._reward = None
self._done = True
self._info = None
self._device = device

@property
def name(self):
return self._name

def reset(self):
state_tuple = self._env.reset(), np.zeros(self._env.num_envs), np.zeros(self._env.num_envs), None
self._state = self._to_state(*state_tuple)
return self._state

def _to_state(self, obs, rew, done, info):
obs = obs.astype(self.observation_space.dtype)
rew = rew.astype("float32")
done = done.astype("bool")
mask = (1 - done).astype("float32")
return StateArray({
"observation": torch.tensor(obs, device=self._device),
"reward": torch.tensor(rew, device=self._device),
"done": torch.tensor(done, device=self._device),
"mask": torch.tensor(mask, device=self._device)
}, shape=(self._env.num_envs,))

def step(self, action):
state_tuple = self._env.step(action.cpu().detach().numpy())
self._state = self._to_state(*state_tuple)
return self._state

def close(self):
return self._env.close()

def seed(self, seed):
self._env.seed(seed)

@property
def state_space(self):
return getattr(self._env, "single_observation_space", self._env.observation_space)

@property
def action_space(self):
return getattr(self._env, "single_action_space", self._env.action_space)

@property
def state_array(self):
return self._state

@property
def device(self):
return self._device

@property
def num_envs(self):
return self._env.num_envs
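A minimal usage sketch for the wrapper above, assuming a Gym version that provides gym.vector.make with an asynchronous flag (the environment id and num_envs are illustrative):

import gym
import torch
from all.environments import GymVectorEnvironment

# wrap a preconstructed Gym vector environment; the name is supplied explicitly
gym_vec = gym.vector.make('CartPole-v0', num_envs=4, asynchronous=False)
vec_env = GymVectorEnvironment(gym_vec, 'CartPole-v0')
state = vec_env.reset()                                  # StateArray of shape (4,)
state = vec_env.step(torch.zeros(4, dtype=torch.int64))  # one action per sub-environment
print(state.observation.shape, state.reward, state.done)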
