Feature/parallel test agent (#240)
* add parallel_test_agent to ParallelPreset

* switch parallel experiment to run tests in parallel

* update atari presets

* update classic control presets

* update continuous preset

* autoformat
cpnota committed Mar 31, 2021
1 parent 5ee29ea commit df96ce6
Showing 26 changed files with 211 additions and 62 deletions.
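In short, this commit gives every ParallelPreset a parallel_test_agent() factory alongside test_agent(), and the parallel experiment's test() now evaluates with that agent across all vectorized environments. The sketch below is illustrative only, not code from this commit; it condenses the pattern visible in the preset diffs that follow, and the class name is invented:

# Illustrative sketch (not from this commit): the preset surface after this change.
class ExampleValueBasedParallelPreset:
    def test_agent(self):
        # single-environment evaluation: greedy policy over the trained Q-network
        ...

    def parallel_test_agent(self):
        # vectorized evaluation used by the parallel experiment's test();
        # value-based presets build a ParallelGreedyPolicy here, while
        # actor-critic presets simply return self.test_agent().
        ...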
2 changes: 1 addition & 1 deletion all/agents/a2c.py
@@ -101,7 +101,7 @@ def _make_buffer(self):
        )


-class A2CTestAgent(Agent):
+class A2CTestAgent(Agent, ParallelAgent):
    def __init__(self, features, policy):
        self.features = features
        self.policy = policy
10 changes: 3 additions & 7 deletions all/agents/dqn.py
@@ -81,12 +81,8 @@ def _should_train(self):


class DQNTestAgent(Agent):
-    def __init__(self, q, n_actions, exploration=0.):
-        self.q = q
-        self.n_actions = n_actions
-        self.exploration = 0.001
+    def __init__(self, policy):
+        self.policy = policy

    def act(self, state):
-        if np.random.rand() < self.exploration:
-            return np.random.randint(0, self.n_actions)
-        return torch.argmax(self.q.eval(state)).item()
+        return self.policy.eval(state)
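The rewritten DQNTestAgent delegates action selection entirely to a policy object. A minimal usage sketch, assuming the import paths and placeholders noted in the comments; only the QNetwork, GreedyPolicy, and DQNTestAgent calls mirror this commit (the presets further down construct them exactly this way):

# Sketch only (not code from this commit): building the policy-based test agent
# the way the DQN presets below do. 'model', the epsilon value, and the State
# layout are illustrative placeholders.
import torch
from all.core import State                 # assumption: State import path
from all.approximation import QNetwork     # assumption: QNetwork import path
from all.policies import GreedyPolicy
from all.agents.dqn import DQNTestAgent

n_actions = 4
model = torch.nn.Linear(8, n_actions)      # placeholder network
q = QNetwork(model)
policy = GreedyPolicy(q, n_actions, epsilon=0.001)
test_agent = DQNTestAgent(policy)
state = State({'observation': torch.zeros(8)})
action = test_agent.act(state)             # delegates to policy.eval(state)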
7 changes: 6 additions & 1 deletion all/agents/vqn.py
@@ -50,4 +50,9 @@ def _train(self, reward, next_state):
        self.q.reinforce(loss)


-VQNTestAgent = DQNTestAgent
+class VQNTestAgent(Agent, ParallelAgent):
+    def __init__(self, policy):
+        self.policy = policy
+
+    def act(self, state):
+        return self.policy.eval(state)
5 changes: 2 additions & 3 deletions all/agents/vsarsa.py
@@ -1,7 +1,6 @@
from torch.nn.functional import mse_loss
-from ._agent import Agent
from ._parallel_agent import ParallelAgent
-from .dqn import DQNTestAgent
+from .vqn import VQNTestAgent


class VSarsa(ParallelAgent):
@@ -47,4 +46,4 @@ def _train(self, reward, next_state, next_action):
        self.q.reinforce(loss)


-VSarsaTestAgent = DQNTestAgent
+VSarsaTestAgent = VQNTestAgent
57 changes: 32 additions & 25 deletions all/experiments/parallel_env_experiment.py
@@ -88,31 +88,38 @@ def train(self, frames=np.inf, episodes=np.inf):
        self._episode += episodes_completed

    def test(self, episodes=100):
-        test_agent = self._preset.test_agent()
-        returns = 0
-        first_state = self._env.reset()[0]
-        eps_returns = []
-        while len(eps_returns) < episodes:
-            first_action = test_agent.act(first_state)
-            if isinstance(self._env.action_space, gym.spaces.Discrete):
-                action = torch.tensor([first_action] * self._env.num_envs)
-            else:
-                action = torch.tensor(first_action).reshape(1, -1).repeat(self._env.num_envs, 1)
-            state_array = self._env.step(action)
-            dones = state_array.done.cpu().detach().numpy()
-            rews = state_array.reward.cpu().detach().numpy()
-            first_state = state_array[0]
-            returns += rews[0]
-            for i in range(1):
-                if dones[i]:
-                    episode_return = returns
-                    esp_index = len(eps_returns)
-                    eps_returns.append(episode_return)
-                    returns = 0
-                    self._log_test_episode(esp_index, episode_return)
-
-        self._log_test(eps_returns)
-        return eps_returns
+        test_agent = self._preset.parallel_test_agent()
+
+        # Note that we need to record the first N episodes that are STARTED,
+        # not the first N that are completed, or we introduce bias.
+        test_returns = []
+        episodes_started = self._n_envs
+        should_record = [True] * self._n_envs
+
+        # initialize state
+        states = self._env.reset()
+        returns = states.reward.clone()
+
+        while len(test_returns) < episodes:
+            # step the agent and environments
+            actions = test_agent.act(states)
+            states = self._env.step(actions)
+            returns += states.reward
+
+            # record any episodes that have finished
+            for i, done in enumerate(states.done):
+                if done:
+                    if should_record[i] and len(test_returns) < episodes:
+                        episode_return = returns[i].item()
+                        test_returns.append(episode_return)
+                        self._log_test_episode(len(test_returns), episode_return)
+                    returns[i] = 0.
+                    episodes_started += 1
+                    if episodes_started > episodes:
+                        should_record[i] = False
+
+        self._log_test(test_returns)
+        return test_returns

    def _done(self, frames, episodes):
        return self._frame > frames or self._episode > episodes
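The key behavioral point in the new test() is the comment about bias: with several environments running at once, the first N episodes to complete skew toward short episodes, so returns are recorded only for the first N episodes started. A self-contained toy transcription of that bookkeeping (not library code; the function name and argument layout are invented for illustration):

# Toy illustration of the bookkeeping in test() above: record returns only
# for episodes among the first `episodes` STARTED across n_envs environments.
def demo_test_returns(done_flags, rewards, n_envs, episodes):
    """done_flags and rewards are lists of per-step lists, one value per environment."""
    test_returns = []
    episodes_started = n_envs
    should_record = [True] * n_envs
    returns = [0.0] * n_envs
    for dones, rews in zip(done_flags, rewards):
        for i in range(n_envs):
            returns[i] += rews[i]
            if dones[i]:
                if should_record[i] and len(test_returns) < episodes:
                    test_returns.append(returns[i])
                returns[i] = 0.0
                episodes_started += 1
                if episodes_started > episodes:
                    should_record[i] = False
        if len(test_returns) >= episodes:
            break
    return test_returns

# Two environments, two episodes requested: the longer episode in env 1 is still
# recorded because it was among the first two episodes started, even though a
# later-started episode in env 0 finished sooner.
print(demo_test_returns(
    done_flags=[[True, False], [True, False], [False, True]],
    rewards=[[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]],
    n_envs=2, episodes=2))   # -> [1.0, 3.0]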
3 changes: 3 additions & 0 deletions all/presets/atari/a2c.py
@@ -112,5 +112,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return DeepmindAtariBody(A2CTestAgent(features, policy))

+    def parallel_test_agent(self):
+        return self.test_agent()
+

a2c = ParallelPresetBuilder('a2c', default_hyperparameters, A2CAtariPreset)
7 changes: 5 additions & 2 deletions all/presets/atari/ddqn.py
@@ -124,9 +124,12 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DeepmindAtariBody(
-            DDQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(
+            q,
+            self.n_actions,
+            epsilon=self.hyperparameters['test_exploration']
        )
+        return DeepmindAtariBody(DDQNTestAgent(policy))

ddqn = PresetBuilder('ddqn', default_hyperparameters, DDQNAtariPreset)
7 changes: 5 additions & 2 deletions all/presets/atari/dqn.py
@@ -123,9 +123,12 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DeepmindAtariBody(
-            DQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(
+            q,
+            self.n_actions,
+            epsilon=self.hyperparameters['test_exploration']
        )
+        return DeepmindAtariBody(DQNTestAgent(policy))

dqn = PresetBuilder('dqn', default_hyperparameters, DQNAtariPreset)
3 changes: 3 additions & 0 deletions all/presets/atari/ppo.py
@@ -134,5 +134,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return DeepmindAtariBody(PPOTestAgent(features, policy))

+    def parallel_test_agent(self):
+        return self.test_agent()
+

ppo = ParallelPresetBuilder('ppo', default_hyperparameters, PPOAtariPreset)
3 changes: 3 additions & 0 deletions all/presets/atari/vac.py
@@ -100,5 +100,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return DeepmindAtariBody(VACTestAgent(features, policy))

+    def parallel_test_agent(self):
+        return self.test_agent()
+

vac = ParallelPresetBuilder('vac', default_hyperparameters, VACAtariPreset)
3 changes: 3 additions & 0 deletions all/presets/atari/vpg.py
@@ -99,5 +99,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return DeepmindAtariBody(VPGTestAgent(features, policy))

+    def parallel_test_agent(self):
+        return self.test_agent()
+

vpg = PresetBuilder('vpg', default_hyperparameters, VPGAtariPreset)
12 changes: 8 additions & 4 deletions all/presets/atari/vqn.py
@@ -6,7 +6,7 @@
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.optim import LinearScheduler
-from all.policies import ParallelGreedyPolicy
+from all.policies import GreedyPolicy, ParallelGreedyPolicy
from all.presets.builder import ParallelPresetBuilder
from all.presets.preset import ParallelPreset
from all.presets.atari.models import nature_ddqn
@@ -92,9 +92,13 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DeepmindAtariBody(
-            VQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
-        )
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DeepmindAtariBody(VQNTestAgent(policy))
+
+    def parallel_test_agent(self):
+        q = QNetwork(copy.deepcopy(self.model))
+        policy = ParallelGreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DeepmindAtariBody(VQNTestAgent(policy))


vqn = ParallelPresetBuilder('vqn', default_hyperparameters, VQNAtariPreset)
12 changes: 8 additions & 4 deletions all/presets/atari/vsarsa.py
@@ -6,7 +6,7 @@
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.optim import LinearScheduler
-from all.policies import ParallelGreedyPolicy
+from all.policies import GreedyPolicy, ParallelGreedyPolicy
from all.presets.builder import ParallelPresetBuilder
from all.presets.preset import ParallelPreset
from all.presets.atari.models import nature_ddqn
@@ -92,9 +92,13 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DeepmindAtariBody(
-            VSarsaTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
-        )
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DeepmindAtariBody(VSarsaTestAgent(policy))
+
+    def parallel_test_agent(self):
+        q = QNetwork(copy.deepcopy(self.model))
+        policy = ParallelGreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DeepmindAtariBody(VSarsaTestAgent(policy))


vsarsa = ParallelPresetBuilder('vsarsa', default_hyperparameters, VSarsaAtariPreset)
28 changes: 26 additions & 2 deletions all/presets/atari_test.py
@@ -1,8 +1,9 @@
import os
import unittest
import torch
-from all.environments import AtariEnvironment
+from all.environments import AtariEnvironment, DuplicateEnvironment
from all.logging import DummyWriter
+from all.presets import Preset, ParallelPreset
from all.presets.atari import (
    a2c,
    c51,
@@ -21,6 +22,8 @@ class TestAtariPresets(unittest.TestCase):
    def setUp(self):
        self.env = AtariEnvironment('Breakout')
        self.env.reset()
+        self.parallel_env = DuplicateEnvironment([AtariEnvironment('Breakout'), AtariEnvironment('Breakout')])
+        self.parallel_env.reset()

    def tearDown(self):
        if os.path.exists('test_preset.pt'):
@@ -58,7 +61,12 @@ def test_vqn(self):

    def validate_preset(self, builder):
        preset = builder.device('cpu').env(self.env).build()
-        # normal agent
+        if isinstance(preset, ParallelPreset):
+            return self.validate_parallel_preset(preset)
+        return self.validate_standard_preset(preset)
+
+    def validate_standard_preset(self, preset):
+        # train agent
        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
        agent.act(self.env.state)
        # test agent
@@ -70,6 +78,22 @@
        test_agent = preset.test_agent()
        test_agent.act(self.env.state)

+    def validate_parallel_preset(self, preset):
+        # train agent
+        agent = preset.agent(writer=DummyWriter(), train_steps=100000)
+        agent.act(self.parallel_env.state_array)
+        # test agent
+        test_agent = preset.test_agent()
+        test_agent.act(self.env.state)
+        # parallel test_agent
+        parallel_test_agent = preset.test_agent()
+        parallel_test_agent.act(self.parallel_env.state_array)
+        # test save/load
+        preset.save('test_preset.pt')
+        preset = torch.load('test_preset.pt')
+        test_agent = preset.test_agent()
+        test_agent.act(self.env.state)
+


if __name__ == "__main__":
    unittest.main()
3 changes: 3 additions & 0 deletions all/presets/classic_control/a2c.py
@@ -99,5 +99,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return A2CTestAgent(features, policy)

+    def parallel_test_agent(self):
+        return self.test_agent()
+

a2c = ParallelPresetBuilder('a2c', default_hyperparameters, A2CClassicControlPreset)
3 changes: 2 additions & 1 deletion all/presets/classic_control/ddqn.py
@@ -111,7 +111,8 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DDQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DDQNTestAgent(policy)

ddqn = PresetBuilder('ddqn', default_hyperparameters, DDQNClassicControlPreset)
3 changes: 2 additions & 1 deletion all/presets/classic_control/dqn.py
@@ -103,7 +103,8 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return DQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters['test_exploration'])
+        return DQNTestAgent(policy)

dqn = PresetBuilder('dqn', default_hyperparameters, DQNClassicControlPreset)
3 changes: 3 additions & 0 deletions all/presets/classic_control/ppo.py
@@ -129,5 +129,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return PPOTestAgent(features, policy)

+    def parallel_test_agent(self):
+        return self.test_agent()
+

ppo = ParallelPresetBuilder('ppo', default_hyperparameters, PPOClassicControlPreset)
3 changes: 3 additions & 0 deletions all/presets/classic_control/vac.py
@@ -93,5 +93,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return VACTestAgent(features, policy)

+    def parallel_test_agent(self):
+        return self.test_agent()
+

vac = ParallelPresetBuilder('vac', default_hyperparameters, VACClassicControlPreset)
3 changes: 3 additions & 0 deletions all/presets/classic_control/vpg.py
@@ -92,5 +92,8 @@ def test_agent(self):
        policy = SoftmaxPolicy(copy.deepcopy(self.policy_model))
        return VPGTestAgent(features, policy)

+    def parallel_test_agent(self):
+        return self.test_agent()
+

vpg = PresetBuilder('vpg', default_hyperparameters, VPGClassicControlPreset)
10 changes: 8 additions & 2 deletions all/presets/classic_control/vqn.py
@@ -6,7 +6,7 @@
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.optim import LinearScheduler
-from all.policies import ParallelGreedyPolicy
+from all.policies import GreedyPolicy, ParallelGreedyPolicy
from all.presets.builder import ParallelPresetBuilder
from all.presets.preset import ParallelPreset
from all.presets.classic_control.models import dueling_fc_relu_q
@@ -90,7 +90,13 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return VQNTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters["test_exploration"])
+        return VQNTestAgent(policy)
+
+    def parallel_test_agent(self):
+        q = QNetwork(copy.deepcopy(self.model))
+        policy = ParallelGreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters["test_exploration"])
+        return VQNTestAgent(policy)

vqn = ParallelPresetBuilder('vqn', default_hyperparameters, VQNClassicControlPreset)
10 changes: 8 additions & 2 deletions all/presets/classic_control/vsarsa.py
@@ -6,7 +6,7 @@
from all.bodies import DeepmindAtariBody
from all.logging import DummyWriter
from all.optim import LinearScheduler
-from all.policies import ParallelGreedyPolicy
+from all.policies import GreedyPolicy, ParallelGreedyPolicy
from all.presets.builder import ParallelPresetBuilder
from all.presets.preset import ParallelPreset
from all.presets.classic_control.models import dueling_fc_relu_q
@@ -90,7 +90,13 @@ def agent(self, writer=DummyWriter(), train_steps=float('inf')):

    def test_agent(self):
        q = QNetwork(copy.deepcopy(self.model))
-        return VSarsaTestAgent(q, self.n_actions, exploration=self.hyperparameters['test_exploration'])
+        policy = GreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters["test_exploration"])
+        return VSarsaTestAgent(policy)
+
+    def parallel_test_agent(self):
+        q = QNetwork(copy.deepcopy(self.model))
+        policy = ParallelGreedyPolicy(q, self.n_actions, epsilon=self.hyperparameters["test_exploration"])
+        return VSarsaTestAgent(policy)

vsarsa = ParallelPresetBuilder('vsarsa', default_hyperparameters, VSarsaClassicControlPreset)
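For context, a minimal end-to-end sketch of how the new parallel test path could be exercised. Only the builder usage, train(), test(), and preset.parallel_test_agent() are taken from the diffs above; the GymEnvironment name and the ParallelEnvExperiment constructor arguments are assumptions about the surrounding library, not part of this commit:

# Sketch only: exercising the parallel experiment's test() after this change.
from all.environments import GymEnvironment                               # assumption
from all.experiments.parallel_env_experiment import ParallelEnvExperiment  # assumed signature below
from all.presets.classic_control import a2c

env = GymEnvironment('CartPole-v0', device='cpu')
preset = a2c.device('cpu').env(env).build()        # builder usage mirrors atari_test.py above
experiment = ParallelEnvExperiment(preset, env)    # assumed constructor arguments
experiment.train(frames=10000)
returns = experiment.test(episodes=20)             # now evaluates with preset.parallel_test_agent()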
