Refactor/agent types (#221)
* add new abstract classes

* flatten all.agents directory structure

* add uncommitted files

* refactor PresetBuilder

* update classic control presets

* update atari presets

* fix continuous presets

* update independent multiagent

* fix unit tests

* update integration tests

* run formatter

* update atari script

* update run scripts

* update docstring for agent base classes

* make parameters optional

* remove unnecessary print statement

* fix multiagent atari script
cpnota committed Jan 24, 2021
1 parent fad9f4a commit 15411aa
Showing 58 changed files with 524 additions and 408 deletions.
11 changes: 7 additions & 4 deletions all/agents/__init__.py
@@ -1,10 +1,12 @@
from .multi import Multiagent, IndependentMultiagent
from ._agent import Agent
from ._multiagent import Multiagent
from ._parallel_agent import ParallelAgent
from .a2c import A2C, A2CTestAgent
from .c51 import C51, C51TestAgent
from .ddpg import DDPG, DDPGTestAgent
from .ddqn import DDQN, DDQNTestAgent
from .dqn import DQN, DQNTestAgent
from .independent import IndependentMultiagent
from .ppo import PPO, PPOTestAgent
from .rainbow import Rainbow, RainbowTestAgent
from .sac import SAC, SACTestAgent
@@ -15,8 +17,11 @@


__all__ = [
# single agents
# Agent interfaces
"Agent",
"Multiagent",
"ParallelAgent",
# Agent implementations
"A2C",
"A2CTestAgent",
"C51",
@@ -41,7 +46,5 @@
"VQNTestAgent",
"VSarsa",
"VSarsaTestAgent",
# multiagents
"Multiagent",
"IndependentMultiagent",
]
6 changes: 3 additions & 3 deletions all/agents/_agent.py
@@ -7,10 +7,10 @@ class Agent(ABC, Schedulable):
A reinforcement learning agent.
In reinforcement learning, an Agent learns by interacting with an Environment.
Usually, an agent tries to maximize a reward signal.
Usually, an Agent tries to maximize a reward signal.
It does this by observing environment "states", taking "actions", receiving "rewards",
and in doing so, learning which state-action pairs correlate with high rewards.
An Agent implementation should encapsulate some particular reinforcement learning algorihthm.
and learning which state-action pairs correlate with high rewards.
An Agent implementation should encapsulate some particular reinforcement learning algorithm.
"""

@abstractmethod
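For context on the refactored interface, here is a minimal sketch of a concrete `Agent`; the random-action behaviour and the gym-style `action_space` argument are illustrative and not part of this commit.

```python
import torch

from all.agents import Agent


class RandomAgent(Agent):
    """Minimal sketch: acts on a single State and returns a single action."""

    def __init__(self, action_space):
        self.action_space = action_space  # assumed: a gym-style action space

    def act(self, state):
        # A learning agent would also compute its TD error and update its
        # value function and/or policy here before returning an action.
        return torch.tensor(self.action_space.sample())
```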
12 changes: 6 additions & 6 deletions all/agents/multi/_multiagent.py → all/agents/_multiagent.py
@@ -4,17 +4,17 @@

class Multiagent(ABC, Schedulable):
"""
A reinforcement learning agent.
A multiagent RL agent. Differs from standard agents in that it accepts a multiagent state.
In reinforcement learning, an Agent learns by interacting with an Environment.
Usually, an agent tries to maximize a reward signal.
It does this by observing environment "states", taking "actions", receiving "rewards",
and in doing so, learning which state-action pairs correlate with high rewards.
An Agent implementation should encapsulate some particular reinforcement learning algorihthm.
and learning which state-action pairs correlate with high rewards.
An Agent implementation should encapsulate some particular reinforcement learning algorithm.
"""

@abstractmethod
def act(self, state):
def act(self, multiagent_state):
"""
Select an action for the current timestep and update internal parameters.
@@ -27,8 +27,8 @@ def act(self, state):
However, the agent must ultimately return an action.
Args:
state (all.core.MultiAgentState): The environment state at the current timestep.
multiagent_state (all.core.MultiAgentState): The environment state at the current timestep.
Returns:
torch.Tensor: The action to take at the current timestep.
torch.Tensor: The action for the current agent to take at the current timestep.
"""
36 changes: 36 additions & 0 deletions all/agents/_parallel_agent.py
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from all.optim import Schedulable


class ParallelAgent(ABC, Schedulable):
"""
A reinforcement learning agent that chooses actions for multiple states simultaneously.
Differs from SingleAgent in that it accepts a StateArray instead of a State to process
input from multiple environments in parallel.
In reinforcement learning, an Agent learns by interacting with an Environment.
Usually, an Agent tries to maximize a reward signal.
It does this by observing environment "states", taking "actions", receiving "rewards",
and learning which state-action pairs correlate with high rewards.
An Agent implementation should encapsulate some particular reinforcement learning algorithm.
"""

@abstractmethod
def act(self, state_array):
"""
Select an action for the current timestep and update internal parameters.
In general, a reinforcement learning agent does several things during a timestep:
1. Choose an action,
2. Compute the TD error from the previous time step
3. Update the value function and/or policy
The order of these steps differs depending on the agent.
This method allows the agent to do whatever is necessary for itself on a given timestep.
However, the agent must ultimately return an action.
Args:
state_array (all.environment.StateArray): An array of states for each parallel environment.
Returns:
torch.Tensor: The actions to take for each parallel environment.
"""
3 changes: 2 additions & 1 deletion all/agents/a2c.py
@@ -3,9 +3,10 @@
from all.logging import DummyWriter
from all.memory import NStepAdvantageBuffer
from ._agent import Agent
from ._parallel_agent import ParallelAgent


class A2C(Agent):
class A2C(ParallelAgent):
"""
Advantage Actor-Critic (A2C).
A2C is policy gradient method in the actor-critic family.
File renamed without changes.
7 changes: 0 additions & 7 deletions all/agents/multi/__init__.py

This file was deleted.

3 changes: 2 additions & 1 deletion all/agents/ppo.py
@@ -3,10 +3,11 @@
from all.logging import DummyWriter
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent
from ._parallel_agent import ParallelAgent
from .a2c import A2CTestAgent


class PPO(Agent):
class PPO(ParallelAgent):
"""
Proximal Policy Optimization (PPO).
PPO is an actor-critic style policy gradient algorithm that allows for the reuse of samples
3 changes: 2 additions & 1 deletion all/agents/vac.py
@@ -1,9 +1,10 @@
from torch.nn.functional import mse_loss
from ._agent import Agent
from ._parallel_agent import ParallelAgent
from .a2c import A2CTestAgent


class VAC(Agent):
class VAC(ParallelAgent):
'''
Vanilla Actor-Critic (VAC).
VAC is an implementation of the actor-critic alogorithm found in the Sutton and Barto (2018) textbook.
3 changes: 2 additions & 1 deletion all/agents/vqn.py
@@ -1,10 +1,11 @@
import torch
from torch.nn.functional import mse_loss
from ._agent import Agent
from ._parallel_agent import ParallelAgent
from .dqn import DQNTestAgent


class VQN(Agent):
class VQN(ParallelAgent):
'''
Vanilla Q-Network (VQN).
VQN is an implementation of the Q-learning algorithm found in the Sutton and Barto (2018) textbook.
3 changes: 2 additions & 1 deletion all/agents/vsarsa.py
@@ -1,9 +1,10 @@
from torch.nn.functional import mse_loss
from ._agent import Agent
from ._parallel_agent import ParallelAgent
from .dqn import DQNTestAgent


class VSarsa(Agent):
class VSarsa(ParallelAgent):
'''
Vanilla SARSA (VSarsa).
SARSA (State-Action-Reward-State-Action) is an on-policy alternative to Q-learning. Unlike Q-learning,
1 change: 0 additions & 1 deletion all/environments/multiagent_pettingzoo_test.py
@@ -59,7 +59,6 @@ def test_action_spaces(self):

def test_list_agents(self):
env = self._make_env()
print(env.agents)
self.assertEqual(env.agents, ['leadadversary_0', 'adversary_0', 'adversary_1', 'adversary_2', 'agent_0', 'agent_1'])

def test_is_done(self):
4 changes: 2 additions & 2 deletions all/experiments/multiagent_env_experiment_test.py
@@ -94,8 +94,8 @@ def test_writes_loss(self):
self.assertFalse(experiment._writer.write_loss)

def make_preset(self):
return IndependentMultiagentPreset({
agent: dqn().device('cpu').env(env).build()
return IndependentMultiagentPreset('independent', 'cpu', {
agent: dqn.device('cpu').env(env).build()
for agent, env in self.env.subenvs.items()
})

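Written out as a standalone sketch, the new `IndependentMultiagentPreset` call convention looks roughly like this; the `MultiagentAtariEnv` class and the `'space_invaders_v1'` id are placeholders for any multiagent environment that exposes per-agent sub-environments via `.subenvs`.

```python
from all.environments import MultiagentAtariEnv  # placeholder environment choice
from all.presets import IndependentMultiagentPreset
from all.presets.atari import dqn

env = MultiagentAtariEnv('space_invaders_v1', device='cpu')  # assumed constructor and id
preset = IndependentMultiagentPreset('independent', 'cpu', {
    agent: dqn.device('cpu').env(subenv).build()
    for agent, subenv in env.subenvs.items()
})
```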
2 changes: 1 addition & 1 deletion all/experiments/parallel_env_experiment_test.py
@@ -62,7 +62,7 @@ def test_writes_loss(self):
self.assertFalse(experiment._writer.write_loss)

def make_agent(self):
return a2c().device('cpu').env(self.env).build()
return a2c.device('cpu').env(self.env).build()


if __name__ == "__main__":
7 changes: 2 additions & 5 deletions all/experiments/run_experiment.py
@@ -1,5 +1,6 @@
from .single_env_experiment import SingleEnvExperiment
from .parallel_env_experiment import ParallelEnvExperiment
from all.presets import ParallelPreset


def run_experiment(
@@ -40,10 +41,6 @@ def run_experiment(


def get_experiment_type(preset):
if preset.is_parallel():
if isinstance(preset, ParallelPreset):
return ParallelEnvExperiment
return SingleEnvExperiment


def is_parallel_env_agent(agent):
return isinstance(agent, tuple)
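With `is_parallel()` removed, experiment selection now keys off the preset's class. A small usage sketch, assuming the classic-control presets were migrated to the new base classes the same way as the Atari presets later in this diff; the `GymEnvironment` construction is illustrative.

```python
from all.environments import GymEnvironment
from all.experiments.run_experiment import get_experiment_type
from all.presets.classic_control import a2c, dqn

env = GymEnvironment('CartPole-v0', device='cpu')  # illustrative environment

# a2c builds a ParallelPreset and dqn a plain Preset, so they dispatch to
# ParallelEnvExperiment and SingleEnvExperiment respectively.
parallel_experiment_type = get_experiment_type(a2c.device('cpu').env(env).build())
single_experiment_type = get_experiment_type(dqn.device('cpu').env(env).build())
```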
2 changes: 1 addition & 1 deletion all/experiments/single_env_experiment_test.py
@@ -105,7 +105,7 @@ def test_writes_loss(self):
self.assertFalse(experiment._writer.write_loss)

def make_preset(self):
return dqn().device('cpu').env(self.env).build()
return dqn.device('cpu').env(self.env).build()


if __name__ == "__main__":
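The dropped parentheses in these tests reflect that `dqn` and `a2c` are now module-level builder instances rather than factory functions, so configuration chains directly off the imported name. A hedged usage sketch: `device()`, `env()`, and `build()` appear in this diff, while the `hyperparameters()` override method is an assumption about the new builder API.

```python
from all.environments import GymEnvironment
from all.presets.classic_control import dqn

env = GymEnvironment('CartPole-v0', device='cpu')  # illustrative environment

# Each call is assumed to return a (re)configured builder; build() constructs the Preset.
preset = (
    dqn
    .device('cpu')
    .env(env)
    .hyperparameters(discount_factor=0.98)  # assumption: per-run hyperparameter override
    .build()
)
agent = preset.agent()  # Preset.agent() is shown in the Atari presets later in this diff
```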
6 changes: 5 additions & 1 deletion all/presets/__init__.py
@@ -1,11 +1,15 @@
from all.presets import atari
from all.presets import classic_control
from all.presets import continuous
from .preset import Preset
from .preset import Preset, ParallelPreset
from .builder import PresetBuilder, ParallelPresetBuilder
from .independent_multiagent import IndependentMultiagentPreset

__all__ = [
"Preset",
"ParallelPreset",
"PresetBuilder",
"ParallelPresetBuilder",
"atari",
"classic_control",
"continuous",
20 changes: 9 additions & 11 deletions all/presets/atari/a2c.py
@@ -7,9 +7,9 @@
from all.approximation import VNetwork, FeatureNetwork
from all.logging import DummyWriter
from all.policies import SoftmaxPolicy
from .models import nature_features, nature_value_head, nature_policy_head
from ..builder import preset_builder
from ..preset import Preset
from all.presets.builder import ParallelPresetBuilder
from all.presets.preset import ParallelPreset
from all.presets.atari.models import nature_features, nature_value_head, nature_policy_head


default_hyperparameters = {
@@ -32,13 +32,14 @@
}


class A2CAtariPreset(Preset):
class A2CAtariPreset(ParallelPreset):
"""
Advantage Actor-Critic (A2C) Atari preset.
Args:
env (all.environments.AtariEnvironment): The environment for which to construct the agent.
device (torch.device, optional): the device on which to load the agent
name (str): A human-readable name for the preset.
device (torch.device): The device on which to load the agent.
Keyword Args:
discount_factor (float): Discount factor for future rewards.
@@ -55,14 +56,11 @@ class A2CAtariPreset(Preset):
policy_model_constructor (function): The function used to construct the neural policy model.
"""

def __init__(self, env, device="cuda", **hyperparameters):
hyperparameters = {**default_hyperparameters, **hyperparameters}
super().__init__(n_envs=hyperparameters['n_envs'])
def __init__(self, env, name, device, **hyperparameters):
super().__init__(name, device, hyperparameters)
self.value_model = hyperparameters['value_model_constructor']().to(device)
self.policy_model = hyperparameters['policy_model_constructor'](env).to(device)
self.feature_model = hyperparameters['feature_model_constructor']().to(device)
self.hyperparameters = hyperparameters
self.device = device

def agent(self, writer=DummyWriter(), train_steps=float('inf')):
n_updates = train_steps / (self.hyperparameters['n_steps'] * self.hyperparameters['n_envs'])
@@ -115,4 +113,4 @@ def test_agent(self):
return DeepmindAtariBody(A2CTestAgent(features, policy))


a2c = preset_builder('a2c', default_hyperparameters, A2CAtariPreset)
a2c = ParallelPresetBuilder('a2c', default_hyperparameters, A2CAtariPreset)
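Judging from the removed assignments (`self.hyperparameters`, `self.device`, and the old `n_envs` keyword), the new `ParallelPreset` base constructor now stores the name, device, and hyperparameters itself. A speculative sketch of that base class, not the actual contents of `all/presets/preset.py`:

```python
from abc import ABC, abstractmethod


class ParallelPresetSketch(ABC):
    """Speculative reconstruction of the refactored parallel preset base class."""

    def __init__(self, name, device, hyperparameters):
        self.name = name
        self.device = device
        self.hyperparameters = hyperparameters

    @property
    def n_envs(self):
        # Assumption: the number of parallel environments is read from hyperparameters,
        # replacing the old super().__init__(n_envs=...) call.
        return self.hyperparameters['n_envs']

    @abstractmethod
    def agent(self, writer=None, train_steps=float('inf')):
        """Construct a training ParallelAgent for this preset."""

    @abstractmethod
    def test_agent(self):
        """Construct an evaluation Agent for this preset."""
```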
18 changes: 8 additions & 10 deletions all/presets/atari/c51.py
@@ -7,9 +7,9 @@
from all.logging import DummyWriter
from all.memory import ExperienceReplayBuffer
from all.optim import LinearScheduler
from .models import nature_c51
from ..builder import preset_builder
from ..preset import Preset
from all.presets.builder import PresetBuilder
from all.presets.preset import Preset
from all.presets.atari.models import nature_c51


default_hyperparameters = {
@@ -44,7 +44,8 @@ class C51AtariPreset(Preset):
Args:
env (all.environments.AtariEnvironment): The environment for which to construct the agent.
device (torch.device, optional): the device on which to load the agent
name (str): A human-readable name for the preset.
device (torch.device): The device on which to load the agent.
Keyword Args:
discount_factor (float): Discount factor for future rewards.
@@ -67,13 +68,10 @@ class C51AtariPreset(Preset):
model_constructor (function): The function used to construct the neural model.
"""

def __init__(self, env, device="cuda", **hyperparameters):
hyperparameters = {**default_hyperparameters, **hyperparameters}
super().__init__()
def __init__(self, env, name, device, **hyperparameters):
super().__init__(name, device, hyperparameters)
self.model = hyperparameters['model_constructor'](env, atoms=hyperparameters['atoms']).to(device)
self.hyperparameters = hyperparameters
self.n_actions = env.action_space.n
self.device = device

def agent(self, writer=DummyWriter(), train_steps=float('inf')):
n_updates = (train_steps - self.hyperparameters['replay_start_size']) / self.hyperparameters['update_frequency']
@@ -134,4 +132,4 @@ def test_agent(self):
return DeepmindAtariBody(C51TestAgent(q_dist, self.n_actions, self.hyperparameters["test_exploration"]))


c51 = preset_builder('c51', default_hyperparameters, C51AtariPreset)
c51 = PresetBuilder('c51', default_hyperparameters, C51AtariPreset)
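Connecting the two ends visible in this diff — builders created as `PresetBuilder(name, default_hyperparameters, constructor)` and presets constructed as `constructor(env, name, device, **hyperparameters)` — a speculative sketch of how such a builder could work. The real implementation lives in `all/presets/builder.py`, which is not shown in this commit view.

```python
class PresetBuilderSketch:
    """Speculative sketch of the fluent builder wired up in this commit."""

    def __init__(self, name, default_hyperparameters, constructor,
                 device='cuda', env=None, hyperparameters=None):
        self._name = name
        self._defaults = default_hyperparameters
        self._constructor = constructor
        self._device = device
        self._env = env
        self._hyperparameters = {**default_hyperparameters, **(hyperparameters or {})}

    def _clone(self, **changes):
        # Assumption: each configuration call returns a new builder, so module-level
        # instances such as `c51` are never mutated by callers.
        kwargs = {
            'device': self._device,
            'env': self._env,
            'hyperparameters': self._hyperparameters,
        }
        kwargs.update(changes)
        return PresetBuilderSketch(self._name, self._defaults, self._constructor, **kwargs)

    def device(self, device):
        return self._clone(device=device)

    def env(self, env):
        return self._clone(env=env)

    def hyperparameters(self, **hyperparameters):
        return self._clone(hyperparameters={**self._hyperparameters, **hyperparameters})

    def build(self):
        # Matches the preset constructors shown above: (env, name, device, **hyperparameters).
        return self._constructor(self._env, self._name, self._device, **self._hyperparameters)
```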
