Skip to content

Commit

Permalink
moved independent wrapper (#207)
Browse files Browse the repository at this point in the history
* moved independent wrapper

* fixed integration test

* fixed linting issue
  • Loading branch information
benblack769 committed Jan 20, 2021
1 parent 1778f84 commit e244c37
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 18 deletions.
6 changes: 3 additions & 3 deletions all/experiments/multiagent_env_experiment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import torch
from all.presets.atari import dqn
from all.presets.multiagent_atari import IndependentMultiagentAtariPreset
from all.presets import IndependentMultiagentPreset
from all.environments import MultiagentAtariEnv
from all.experiments import MultiagentEnvExperiment
from all.logging import Writer
Expand Down Expand Up @@ -58,7 +58,7 @@ def setUp(self):

def test_adds_default_name(self):
experiment = MockExperiment(self.make_preset(), self.env, quiet=True, save_freq=float('inf'))
self.assertEqual(experiment._writer.label, "IndependentMultiagentAtariPreset_pong_v1")
self.assertEqual(experiment._writer.label, "IndependentMultiagentPreset_pong_v1")

def test_adds_custom_name(self):
experiment = MockExperiment(self.make_preset(), self.env, name='custom', quiet=True, save_freq=float('inf'))
Expand Down Expand Up @@ -96,7 +96,7 @@ def test_writes_loss(self):
self.assertFalse(experiment._writer.write_loss)

def make_preset(self):
return IndependentMultiagentAtariPreset({
return IndependentMultiagentPreset({
agent: dqn().device('cpu').env(env).build()
for agent, env in self.env.subenvs.items()
})
Expand Down
10 changes: 8 additions & 2 deletions all/presets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@

from all.presets import atari
from all.presets import classic_control
from all.presets import continuous
from .preset import Preset
from .independent_multiagent import IndependentMultiagentPreset

__all__ = ["Preset", "atari", "classic_control", "continuous"]
__all__ = [
"Preset",
"atari",
"classic_control",
"continuous",
"IndependentMultiagentPreset"
]
2 changes: 1 addition & 1 deletion all/presets/continuous/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class PPOContinuousPreset(Preset):
eps (float): Stability parameters for the Adam optimizer.
entropy_loss_scaling (float): Coefficient for the entropy term in the total loss.
value_loss_scaling (float): Coefficient for the value function loss.
clip_grad (float): Clips the gradient during training so that its L2 norm (calculated over all parameters)
clip_grad (float): Clips the gradient during training so that its L2 norm (calculated over all parameters)
# is no greater than this bound. Set to 0 to disable.
clip_initial (float): Value for epsilon in the clipped PPO objective function at the beginning of training.
clip_final (float): Value for epsilon in the clipped PPO objective function at the end of training.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from ..builder import preset_builder
from ..preset import Preset
from .builder import preset_builder
from .preset import Preset
from all.agents.multi.independent import IndependentMultiagent
from all.logging import DummyWriter


class IndependentMultiagentAtariPreset(Preset):
class IndependentMultiagentPreset(Preset):
def __init__(self, presets):
self.presets = presets

Expand Down
4 changes: 0 additions & 4 deletions all/presets/multiagent_atari/__init__.py

This file was deleted.

4 changes: 2 additions & 2 deletions all/presets/multiagent_atari_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from all.environments import MultiagentAtariEnv
from all.logging import DummyWriter
from all.presets.atari import dqn
from all.presets.multiagent_atari import IndependentMultiagentAtariPreset
from all.presets import IndependentMultiagentPreset


class TestAtariPresets(unittest.TestCase):
Expand All @@ -22,7 +22,7 @@ def test_independent(self):
agent_id: dqn().device('cpu').env(env.subenvs[agent_id]).build()
for agent_id in env.agents
}
self.validate_preset(IndependentMultiagentAtariPreset(presets), env)
self.validate_preset(IndependentMultiagentPreset(presets), env)

def validate_preset(self, preset, env):
# normal agent
Expand Down
6 changes: 3 additions & 3 deletions integration/multiagent_atari_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import torch
from all.environments import MultiagentAtariEnv
from all.presets.multiagent_atari import IndependentMultiagentAtariPreset
from all.presets import IndependentMultiagentPreset
from all.presets.atari import dqn
from validate_agent import validate_multiagent

Expand All @@ -25,15 +25,15 @@ def test_independent(self):
agent_id : dqn().device(CPU).env(env.subenvs[agent_id]).build()
for agent_id in env.agents
}
validate_multiagent(IndependentMultiagentAtariPreset(presets), env)
validate_multiagent(IndependentMultiagentPreset(presets), env)

def test_independent_cuda(self):
env = MultiagentAtariEnv('pong_v1', device=CUDA)
presets = {
agent_id : dqn().device(CUDA).env(env.subenvs[agent_id]).build()
for agent_id in env.agents
}
validate_multiagent(IndependentMultiagentAtariPreset(presets), env)
validate_multiagent(IndependentMultiagentPreset(presets), env)


if __name__ == "__main__":
Expand Down

0 comments on commit e244c37

Please sign in to comment.