Feature/release preparation (#234)
* update getting started

* update some of the basic concepts

* MultiAgent -> Multiagent

* update benchmark scripts

* more doc

* fix docs build issue

* update GPU recommendation

* make experiments write cleaner names for presets

* update directory names in getting started guide

* update basic concepts

* update benchmark performance docs page and fix some typos
cpnota committed Mar 22, 2021
1 parent b7d6494 commit e41fbec
Showing 18 changed files with 123 additions and 111 deletions.
2 changes: 1 addition & 1 deletion all/agents/_multiagent.py
@@ -27,7 +27,7 @@ def act(self, multiagent_state):
         However, the agent must ultimately return an action.

         Args:
-            multiagent_state (all.core.MultiAgentState): The environment state at the current timestep.
+            multiagent_state (all.core.MultiagentState): The environment state at the current timestep.

         Returns:
             torch.Tensor: The action for the current agent to take at the current timestep.
4 changes: 2 additions & 2 deletions all/core/__init__.py
@@ -1,3 +1,3 @@
-from .state import State, StateArray, MultiAgentState
+from .state import State, StateArray, MultiagentState

-__all__ = ['State', 'StateArray', 'MultiAgentState']
+__all__ = ['State', 'StateArray', 'MultiagentState']
8 changes: 4 additions & 4 deletions all/core/state.py
@@ -388,10 +388,10 @@ def __len__(self):
         return self.shape[0]


-class MultiAgentState(State):
+class MultiagentState(State):
     def __init__(self, x, device='cpu', **kwargs):
         if 'agent' not in x:
-            raise Exception('MultiAgentState must contain an agent ID')
+            raise Exception('MultiagentState must contain an agent ID')
         super().__init__(x, device=device, **kwargs)

     @property
@@ -412,7 +412,7 @@ def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
             A State object.
         """
         if not isinstance(state, tuple):
-            return MultiAgentState({
+            return MultiagentState({
                 'agent': agent,
                 'observation': torch.from_numpy(
                     np.array(
@@ -438,7 +438,7 @@ def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
         info = info if info else {}
         for key in info:
             x[key] = info[key]
-        return MultiAgentState(x, device=device)
+        return MultiagentState(x, device=device)

     def to(self, device):
         if device == self.device:
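For downstream code, the rename is a breaking change: the class must now be imported and constructed as MultiagentState. A minimal sketch of the new usage, with an illustrative agent ID and observation (the 'agent' key is required, as the exception in the diff above shows):

    import torch
    from all.core import MultiagentState

    # 'agent' is mandatory; omitting it raises an Exception.
    state = MultiagentState({
        'agent': 'first_0',                  # illustrative agent ID
        'observation': torch.zeros(84, 84),  # illustrative observation
        'reward': 0.0,
        'done': False,
    }, device='cpu')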
2 changes: 1 addition & 1 deletion all/environments/multiagent_atari.py
@@ -2,7 +2,7 @@
 import numpy as np
 import torch
 import gym
-from all.core import MultiAgentState
+from all.core import MultiagentState
 from ._multiagent_environment import MultiagentEnvironment
 from .multiagent_pettingzoo import MultiagentPettingZooEnv

4 changes: 2 additions & 2 deletions all/environments/multiagent_pettingzoo.py
@@ -3,7 +3,7 @@
 import torch
 import cloudpickle
 import gym
-from all.core import MultiAgentState
+from all.core import MultiagentState
 from ._multiagent_environment import MultiagentEnvironment


@@ -80,7 +80,7 @@ def duplicate(self, n):
     def last(self):
         observation, reward, done, info = self._env.last()
         selected_obs_space = self._env.observation_spaces[self._env.agent_selection]
-        return MultiAgentState.from_zoo(self._env.agent_selection, (observation, reward, done, info), device=self._device, dtype=selected_obs_space.dtype)
+        return MultiagentState.from_zoo(self._env.agent_selection, (observation, reward, done, info), device=self._device, dtype=selected_obs_space.dtype)

     @property
     def name(self):
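The same rename applies at the PettingZoo boundary, where last() wraps the raw (observation, reward, done, info) tuple. A hedged sketch of calling the classmethod directly, with illustrative stand-in values:

    import numpy as np
    from all.core import MultiagentState

    # Illustrative stand-ins for what a PettingZoo env's last() returns.
    observation = np.zeros((84, 84), dtype=np.float32)
    reward, done, info = 0.0, False, {}

    # from_zoo packs the tuple into a MultiagentState on the given device.
    state = MultiagentState.from_zoo(
        'first_0',  # illustrative agent ID
        (observation, reward, done, info),
        device='cpu',
        dtype=np.float32,
    )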
2 changes: 1 addition & 1 deletion all/experiments/multiagent_env_experiment.py
@@ -34,7 +34,7 @@ def __init__(
         write_loss=True,
         writer="tensorboard"
     ):
-        self._name = name if name is not None else preset.__class__.__name__
+        self._name = name if name is not None else preset.name
         self._writer = self._make_writer(logdir, self._name, env.name, write_loss, writer)
         self._agent = preset.agent(writer=self._writer, train_steps=train_steps)
         self._env = env
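Here and in the single- and parallel-environment experiments below, the default experiment name now comes from the preset's short name attribute rather than its class name, which is what yields the cleaner writer labels asserted in the updated tests. A minimal sketch of the fallback, assuming a preset object that exposes name (e.g. 'dqn' rather than 'DQNClassicControlPreset'):

    def default_experiment_name(preset, name=None):
        # Old behavior: preset.__class__.__name__, e.g. 'DQNClassicControlPreset'.
        # New behavior: the preset's short name, e.g. 'dqn'.
        return name if name is not None else preset.name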
2 changes: 1 addition & 1 deletion all/experiments/multiagent_env_experiment_test.py
@@ -61,7 +61,7 @@ def setUp(self):

     def test_adds_default_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
-        self.assertEqual(experiment._writer.label, "IndependentMultiagentPreset_space_invaders_v1")
+        self.assertEqual(experiment._writer.label, "independent_space_invaders_v1")

     def test_adds_custom_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, name='custom', quiet=True, save_freq=float('inf'))
2 changes: 1 addition & 1 deletion all/experiments/parallel_env_experiment.py
@@ -22,7 +22,7 @@ def __init__(
         write_loss=True,
         writer="tensorboard"
     ):
-        self._name = name if name is not None else preset.__class__.__name__
+        self._name = name if name is not None else preset.name
         super().__init__(self._make_writer(logdir, self._name, env.name, write_loss, writer), quiet)
         self._n_envs = preset.n_envs
         self._envs = env.duplicate(self._n_envs)
2 changes: 1 addition & 1 deletion all/experiments/parallel_env_experiment_test.py
@@ -24,7 +24,7 @@ def setUp(self):
             env.seed(i)

     def test_adds_default_label(self):
-        self.assertEqual(self.experiment._writer.label, "A2CClassicControlPreset_CartPole-v0")
+        self.assertEqual(self.experiment._writer.label, "a2c_CartPole-v0")

     def test_adds_custom_label(self):
         env = GymEnvironment('CartPole-v0')
2 changes: 1 addition & 1 deletion all/experiments/single_env_experiment.py
@@ -20,7 +20,7 @@ def __init__(
         write_loss=True,
         writer="tensorboard"
     ):
-        self._name = name if name is not None else preset.__class__.__name__
+        self._name = name if name is not None else preset.name
         super().__init__(self._make_writer(logdir, self._name, env.name, write_loss, writer), quiet)
         self._logdir = logdir
         self._preset = preset
2 changes: 1 addition & 1 deletion all/experiments/single_env_experiment_test.py
@@ -60,7 +60,7 @@ def setUp(self):

     def test_adds_default_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, quiet=True)
-        self.assertEqual(experiment._writer.label, "DQNClassicControlPreset_CartPole-v0")
+        self.assertEqual(experiment._writer.label, "dqn_CartPole-v0")

     def test_adds_custom_name(self):
         experiment = MockExperiment(self.make_preset(), self.env, name='dqn', quiet=True)
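Both test updates assert the same label convention: the preset's short name joined to the environment name with an underscore. A quick sketch of the expected format:

    agent_name, env_name = 'dqn', 'CartPole-v0'
    label = '%s_%s' % (agent_name, env_name)
    assert label == 'dqn_CartPole-v0'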
13 changes: 4 additions & 9 deletions all/experiments/writer.py
@@ -23,15 +23,10 @@ class ExperimentWriter(SummaryWriter, Writer):

     def __init__(self, experiment, agent_name, env_name, loss=True, logdir='runs'):
         self.env_name = env_name
-        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S %f')
-        os.makedirs(
-            os.path.join(
-                logdir, ("%s %s %s" % (agent_name, COMMIT_HASH, current_time)), env_name
-            )
-        )
-        self.log_dir = os.path.join(
-            logdir, ("%s %s %s" % (agent_name, COMMIT_HASH, current_time))
-        )
+        current_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S_%f')
+        dir_name = "%s_%s_%s" % (agent_name, COMMIT_HASH, current_time)
+        os.makedirs(os.path.join(logdir, dir_name, env_name))
+        self.log_dir = os.path.join(logdir, dir_name)
         self._experiment = experiment
         self._loss = loss
         super().__init__(log_dir=self.log_dir)
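The run directory is now a single underscore-separated name instead of a space-separated one, which is friendlier to shells and log-parsing tools. A sketch of the resulting layout, with an illustrative agent name, commit hash, and timestamp:

    import os
    from datetime import datetime

    agent_name, env_name, commit_hash = 'dqn', 'CartPole-v0', 'e41fbec'  # illustrative
    current_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S_%f')

    dir_name = "%s_%s_%s" % (agent_name, commit_hash, current_time)
    log_dir = os.path.join('runs', dir_name)
    os.makedirs(os.path.join(log_dir, env_name))
    # e.g. runs/dqn_e41fbec_2021-03-22_10:15:00_123456/CartPole-v0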
4 changes: 4 additions & 0 deletions all/nn/__init__.py
@@ -6,6 +6,10 @@
 from all.core import State


+"""A PyTorch Module"""
+Module = nn.Module
+
+
 class RLNetwork(nn.Module):
     """
     Wraps a network such that States can be given as input.
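The new Module alias re-exports torch.nn.Module from all.nn, so networks can be defined against a single import. A small sketch, assuming all.nn also re-exports the standard torch.nn layers alongside its own wrappers:

    from all import nn

    class MyNet(nn.Module):  # all.nn.Module is torch.nn.Module under the alias
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)  # assumes nn.Linear is re-exported from torch.nn

        def forward(self, x):
            return self.fc(x)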
13 changes: 6 additions & 7 deletions benchmarks/atari40.py
@@ -4,14 +4,13 @@


 def main():
-    device = 'cuda'
     agents = [
-        atari.a2c(device=device),
-        atari.c51(device=device),
-        atari.dqn(device=device),
-        atari.ddqn(device=device),
-        atari.ppo(device=device),
-        atari.rainbow(device=device),
+        atari.a2c,
+        atari.c51,
+        atari.dqn,
+        atari.ddqn,
+        atari.ppo,
+        atari.rainbow,
     ]
     envs = [AtariEnvironment(env, device=device) for env in ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders']]
     SlurmExperiment(agents, envs, 10e6, sbatch_args={
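The benchmark now passes preset builder functions (atari.a2c, atari.dqn, ...) rather than presets constructed up front, leaving instantiation to the experiment machinery on each cluster node. A hedged sketch of the updated calling pattern, with an illustrative device and a shorter environment list:

    from all.environments import AtariEnvironment
    from all.experiments import SlurmExperiment
    from all.presets import atari

    # Builders, not constructed presets: construction is deferred.
    agents = [atari.dqn, atari.ppo]
    envs = [AtariEnvironment(env, device='cuda') for env in ['Pong', 'Breakout']]

    SlurmExperiment(agents, envs, 10e6, sbatch_args={'partition': '1080ti-long'})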
6 changes: 3 additions & 3 deletions benchmarks/pybullet.py
@@ -11,9 +11,9 @@ def main():
     frames = int(1e7)

     agents = [
-        ddpg(device=device),
-        ppo(device=device),
-        sac(device=device)
+        ddpg,
+        ppo,
+        sac
     ]

     envs = [GymEnvironment(env, device) for env in [
