Feature/multi agent atari (#201)
* add initial multiagent_atari implementation

* add multiagent atari initial code

* improve package structure

* add multiagent atari

* start updating to new preset method

* add independent

* render and replay buffer size

* add watch script

* remove starter code for parameter sharing dqn

* update multiagent atari env unittest

* update tests and doc for MultiagentAtari environment

* add abstract multiagent environment

* add multiagent env documentation

* update documentation and make MultiagentAtari implement abstract methods

* add multiagent env test

* add test mode

* add ma-atari to extras

* upgrade gym version

* add autorom

* add integration tests

* run formatter

* install unrar on travis

* fix mock writer

* make unit test not write preset to disk

* add multiagent atari preset unittest

* formatting

Co-authored-by: Ben Black <weepingwillowben@gmail.com>
cpnota and benblack769 committed Jan 12, 2021
1 parent 7a8860d commit 97e64d5
Showing 29 changed files with 1,044 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
@@ -8,10 +8,12 @@ branches:
 - master
 - develop
 before_install:
+- sudo apt-get install unrar
 - sudo apt-get install swig
 install:
 - pip install torch==1.5.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
 - pip install -e .["dev"]
+- AutoROM -v
 script:
 - make lint
 - make test
24 changes: 23 additions & 1 deletion all/__init__.py
@@ -1,4 +1,26 @@
+import all.agents
+import all.approximation
+import all.core
+import all.environments
+import all.logging
+import all.memory
 import all.nn
+import all.optim
+import all.policies
+import all.presets
 from all.core import State, StateArray
 
-__all__ = ['nn', 'State', 'StateArray']
+__all__ = [
+    'agents',
+    'approximation',
+    'core',
+    'environments',
+    'logging',
+    'memory',
+    'nn',
+    'optim',
+    'policies',
+    'presets',
+    'State',
+    'StateArray'
+]
8 changes: 7 additions & 1 deletion all/agents/__init__.py
@@ -1,3 +1,4 @@
+from .multi import Multiagent, IndependentMultiagent
 from ._agent import Agent
 from .a2c import A2C, A2CTestAgent
 from .c51 import C51, C51TestAgent
@@ -12,7 +13,9 @@
 from .vqn import VQN, VQNTestAgent
 from .vsarsa import VSarsa, VSarsaTestAgent
 
+
 __all__ = [
+    # single agents
     "Agent",
     "A2C",
     "A2CTestAgent",
@@ -37,5 +40,8 @@
     "VQN",
     "VQNTestAgent",
     "VSarsa",
-    "VSarsaTestAgent"
+    "VSarsaTestAgent",
+    # multiagents
+    "Multiagent",
+    "IndependentMultiagent",
 ]
2 changes: 1 addition & 1 deletion all/agents/dqn.py
@@ -84,7 +84,7 @@ class DQNTestAgent(Agent):
     def __init__(self, q, n_actions, exploration=0.):
         self.q = q
         self.n_actions = n_actions
-        self.exploration = exploration
+        self.exploration = 0.001
 
     def act(self, state):
         if np.random.rand() < self.exploration:
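The hunk cuts off inside act(), but the visible lines outline epsilon-greedy action selection. Note that the new line hard-codes self.exploration = 0.001, so the constructor's exploration argument is effectively ignored; presumably a small fixed epsilon during evaluation keeps deterministic Atari policies from getting stuck in loops. Below is a minimal sketch of how such a test agent behaves — illustrative only, not the repository's exact code — assuming q maps a state to a 1-D tensor of action values:

```python
import numpy as np
import torch


class EpsilonGreedyTestAgent:
    """Sketch of an evaluation-time epsilon-greedy agent (illustrative only)."""

    def __init__(self, q, n_actions, exploration=0.001):
        self.q = q                      # maps a state to a 1-D tensor of action values
        self.n_actions = n_actions
        self.exploration = exploration  # small epsilon keeps evaluation mostly greedy

    def act(self, state):
        # with probability epsilon, explore with a uniformly random action
        if np.random.rand() < self.exploration:
            return torch.randint(self.n_actions, (1,))
        # otherwise exploit: pick the action with the highest estimated value
        return torch.argmax(self.q(state)).unsqueeze(0)
```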
7 changes: 7 additions & 0 deletions all/agents/multi/__init__.py
@@ -0,0 +1,7 @@
+from ._multiagent import Multiagent
+from .independent import IndependentMultiagent
+
+__all__ = [
+    "Multiagent",
+    "IndependentMultiagent"
+]
34 changes: 34 additions & 0 deletions all/agents/multi/_multiagent.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from all.optim import Schedulable
+
+
+class Multiagent(ABC, Schedulable):
+    """
+    A multiagent reinforcement learning agent.
+    In reinforcement learning, an agent learns by interacting with an environment.
+    Usually, the agent tries to maximize a reward signal.
+    It does this by observing environment "states", taking "actions", receiving "rewards",
+    and in doing so, learning which state-action pairs correlate with high rewards.
+    A Multiagent implementation should encapsulate some particular multiagent reinforcement learning algorithm.
+    """
+
+    @abstractmethod
+    def act(self, state):
+        """
+        Select an action for the current timestep and update internal parameters.
+        In general, a reinforcement learning agent does several things during a timestep:
+        1. Choose an action,
+        2. Compute the TD error from the previous time step
+        3. Update the value function and/or policy
+        The order of these steps differs depending on the agent.
+        This method allows the agent to do whatever is necessary for itself on a given timestep.
+        However, the agent must ultimately return an action.
+        Args:
+            state (all.core.MultiAgentState): The environment state at the current timestep.
+        Returns:
+            torch.Tensor: The action to take at the current timestep.
+        """
9 changes: 9 additions & 0 deletions all/agents/multi/independent.py
@@ -0,0 +1,9 @@
+from ._multiagent import Multiagent
+
+
+class IndependentMultiagent(Multiagent):
+    def __init__(self, agents):
+        self.agents = agents
+
+    def act(self, state):
+        return self.agents[state['agent']].act(state)
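IndependentMultiagent is pure dispatch: it looks up the sub-agent named by the state's 'agent' entry and delegates to it. A small usage sketch — the PettingZoo-style agent names and the trivial ConstantAgent are illustrative assumptions:

```python
import torch
from all.agents import IndependentMultiagent


class ConstantAgent:
    """Toy stand-in for a trained agent."""

    def __init__(self, action):
        self.action = action

    def act(self, state):
        return torch.tensor([self.action])


multiagent = IndependentMultiagent({
    'first_0': ConstantAgent(0),   # PettingZoo-style agent names (assumed)
    'second_0': ConstantAgent(1),
})

# the 'agent' entry of the state routes the call to the matching sub-agent
action = multiagent.act({'agent': 'second_0'})  # tensor([1])
```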
4 changes: 2 additions & 2 deletions all/core/__init__.py
@@ -1,3 +1,3 @@
-from .state import State, StateArray
+from .state import State, StateArray, MultiAgentState
 
-__all__ = ['State', 'StateArray']
+__all__ = ['State', 'StateArray', 'MultiAgentState']
64 changes: 64 additions & 0 deletions all/core/state.py
@@ -381,3 +381,67 @@ def shape(self):
 
     def __len__(self):
         return self.shape[0]
+
+
+class MultiAgentState(State):
+    def __init__(self, x, device='cpu', **kwargs):
+        if 'agent' not in x:
+            raise Exception('MultiAgentState must contain an agent ID')
+        super().__init__(x, device=device, **kwargs)
+
+    @property
+    def agent(self):
+        return self['agent']
+
+    @classmethod
+    def from_zoo(cls, agent, state, device='cpu', dtype=np.float32):
+        """
+        Constructs a MultiAgentState object given the return value of a PettingZoo reset()/step(action) call.
+        Args:
+            agent (str): The name of the agent for which the state is constructed.
+            state (tuple): The return value of a PettingZoo reset()/step(action) call.
+            device (string): The device on which to store resulting tensors.
+            dtype: The type of the observation.
+        Returns:
+            A MultiAgentState object.
+        """
+        if not isinstance(state, tuple):
+            return MultiAgentState({
+                'agent': agent,
+                'observation': torch.from_numpy(
+                    np.array(
+                        state,
+                        dtype=dtype
+                    )
+                ).to(device)
+            }, device=device)
+
+        observation, reward, done, info = state
+        observation = torch.from_numpy(
+            np.array(
+                observation,
+                dtype=dtype
+            )
+        ).to(device)
+        x = {
+            'agent': agent,
+            'observation': observation,
+            'reward': float(reward),
+            'done': done,
+        }
+        info = info if info else {}
+        for key in info:
+            x[key] = info[key]
+        return MultiAgentState(x, device=device)
+
+    def to(self, device):
+        if device == self.device:
+            return self
+        x = {}
+        for key, value in self.items():
+            if torch.is_tensor(value):
+                x[key] = value.to(device)
+            else:
+                x[key] = value
+        return type(self)(x, device=device, shape=self._shape)
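A short sketch of how from_zoo is meant to be called, under the assumption that the observation is a NumPy array and the tuple form mirrors PettingZoo's (observation, reward, done, info); the agent name and observation shape are hypothetical:

```python
import numpy as np
from all.core import MultiAgentState

obs = np.zeros((84, 84), dtype=np.uint8)  # hypothetical observation

# reset()-style call: a bare observation, no reward/done yet
state = MultiAgentState.from_zoo('first_0', obs)
print(state.agent)  # 'first_0'

# step()-style call: an (observation, reward, done, info) tuple
state = MultiAgentState.from_zoo('first_0', (obs, 1.0, False, {}))
print(state['reward'], state['done'])  # 1.0 False
```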
12 changes: 10 additions & 2 deletions all/environments/__init__.py
@@ -1,5 +1,13 @@
-from .abstract import Environment
+from ._environment import Environment
+from ._multiagent_environment import MultiagentEnvironment
 from .gym import GymEnvironment
 from .atari import AtariEnvironment
+from .multiagent_atari import MultiagentAtariEnv
 
-__all__ = ["Environment", "GymEnvironment", "AtariEnvironment"]
+__all__ = [
+    "Environment",
+    "MultiagentEnvironment",
+    "GymEnvironment",
+    "AtariEnvironment",
+    "MultiagentAtariEnv",
+]
File renamed without changes.
104 changes: 104 additions & 0 deletions all/environments/_multiagent_environment.py
@@ -0,0 +1,104 @@
+from abc import ABC, abstractmethod
+
+
+class MultiagentEnvironment(ABC):
+    '''
+    A multiagent reinforcement learning Environment.
+    The multiagent variant of the Environment object.
+    An Environment defines the dynamics of a particular problem:
+    the states, the actions, the transitions between states, and the rewards given to the agent.
+    Environments are often used to benchmark reinforcement learning agents,
+    or to define real problems that the user hopes to solve using reinforcement learning.
+    '''
+
+    @abstractmethod
+    def reset(self):
+        '''
+        Reset the environment and return a new initial state for the first agent.
+        Returns:
+            all.core.MultiagentState: The initial state for the next episode.
+        '''
+
+    @abstractmethod
+    def step(self, action):
+        '''
+        Apply an action for the current agent and get the multiagent state for the next agent.
+        Parameters:
+            action: The Action for the current agent and timestep.
+        Returns:
+            all.core.MultiagentState: The state for the next agent.
+        '''
+
+    @abstractmethod
+    def render(self, **kwargs):
+        '''Render the current environment state.'''
+
+    @abstractmethod
+    def close(self):
+        '''Clean up any extraneous environment objects.'''
+
+    @abstractmethod
+    def agent_iter(self):
+        '''
+        Create an iterable in which the next element is always the name of the agent whose turn it is to act.
+        Returns:
+            An iterable over agent name strings.
+        '''
+
+    @abstractmethod
+    def last(self):
+        '''
+        Get the MultiagentState object for the current agent.
+        Returns:
+            The all.core.MultiagentState object for the current agent.
+        '''
+
+    @abstractmethod
+    def is_done(self, agent):
+        '''
+        Determine whether a given agent is done.
+        Args:
+            agent (str): The name of the agent.
+        Returns:
+            A boolean representing whether the given agent is done.
+        '''
+
+    @property
+    def state(self):
+        '''The State for the current agent.'''
+        return self.last()
+
+    @property
+    @abstractmethod
+    def name(self):
+        '''str: The name of the environment.'''
+
+    @property
+    @abstractmethod
+    def state_spaces(self):
+        '''A dictionary of state spaces for each agent.'''
+
+    @property
+    def observation_spaces(self):
+        '''Alias for MultiagentEnvironment.state_spaces.'''
+        return self.state_spaces
+
+    @property
+    @abstractmethod
+    def action_spaces(self):
+        '''A dictionary of action spaces for each agent.'''
+
+    @property
+    @abstractmethod
+    def device(self):
+        '''
+        The torch device the environment lives on.
+        '''
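Taken together, these methods support a PettingZoo-style interaction loop. A sketch under assumptions: the game ID is hypothetical, the random action stands in for a real multiagent.act(state) call, and passing None for finished agents follows the PettingZoo convention:

```python
import torch
from all.environments import MultiagentAtariEnv

env = MultiagentAtariEnv('pong_v1')  # hypothetical game ID
env.reset()
for agent in env.agent_iter():
    state = env.last()               # MultiAgentState for the agent about to act
    if env.is_done(agent):
        env.step(None)               # finished agents pass a null action
    else:
        # act uniformly at random; a real run would call multiagent.act(state)
        env.step(torch.randint(env.action_spaces[agent].n, (1,)))
env.close()
```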
2 changes: 1 addition & 1 deletion all/environments/gym.py
@@ -1,7 +1,7 @@
@@ -1,7 +1,7 @@
 import gym
 import torch
 from all.core import State
-from .abstract import Environment
+from ._environment import Environment
 import cloudpickle
 gym.logger.set_level(40)
