Feature/save agents (#185)
* initial implementation for dqn

* update watch code

* simplify usage of dqn builder

* add a2c preset

* make approximation optimizer optional

* update a2c

* add preset unit test

* add c51 atari preset

* add train_steps parameter

* add train_step to a2c

* add train_steps to c51

* change render command line usage

* add hyperparameter parser

* add DDQNAtariPreset

* update ppo

* rainbow preset

* add vac atari preset

* vpg preset

* add vqn and vsarsa presets

* update integration tests

* make parallel env experiment test with single env agent

* tweak function signature for Preset

* try to get docstrings working

* get documentation working properly

* re-add model constructor to a2c preset

* separate keyword args

* update all docstrings and re-add model constructors

* start converting cc presets

* update c51 cc preset

* update ddqn classic control preset

* update dqn cc preset

* add classic control preset test

* ppo cc preset

* add rainbow cc preset

* add VAC cc preset

* add vpg cc preset

* add vqn cc preset

* add vsarsa cc preset

* export presets

* add ddpg preset

* ppo

* add sac

* fix continuous preset integration tests

* update classic control integration tests

* fix single env experiment test

* fix policy tests

* run autopep8

* deep copy everything

* run autopep8 on integration tests

* fix linting

* update watch scripts
cpnota committed Nov 23, 2020
1 parent 79d5ee6 commit 8f65a70
Showing 78 changed files with 2,488 additions and 1,431 deletions.
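The diffs below split each algorithm into a training agent plus a lightweight test agent, and the experiment gains a save() hook that writes the preset to disk (see the experiment.py hunk at the end). As a rough, self-contained sketch of the workflow this enables; the TinyPreset and GreedyTestAgent classes and the test_agent() method are invented for illustration and are not the library's actual API:

```python
# Toy reconstruction of the idea, not the library's API: a "preset" object
# that can be saved with torch.save and later reloaded to build a test agent.
import torch
from torch import nn


class GreedyTestAgent:
    """Evaluation-only agent: acts greedily and never updates parameters."""

    def __init__(self, q):
        self.q = q

    def act(self, state):
        with torch.no_grad():
            return torch.argmax(self.q(state)).item()


class TinyPreset:
    """Bundles the trained model so it can be pickled and reloaded later."""

    def __init__(self, q):
        self.q = q

    def save(self, filename):
        torch.save(self, filename)

    def test_agent(self):
        return GreedyTestAgent(self.q)


# A training script would train the model and then save the preset...
q = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
TinyPreset(q).save('preset.pt')

# ...and a watch script would reload it and act with the test agent.
preset = torch.load('preset.pt')
print(preset.test_agent().act(torch.zeros(4)))
```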
36 changes: 24 additions & 12 deletions all/agents/__init__.py
@@ -1,29 +1,41 @@
from ._agent import Agent
from .a2c import A2C
from .c51 import C51
from .ddpg import DDPG
from .ddqn import DDQN
from .dqn import DQN
from .ppo import PPO
from .rainbow import Rainbow
from .sac import SAC
from .vac import VAC
from .vpg import VPG
from .vqn import VQN
from .vsarsa import VSarsa
from .a2c import A2C, A2CTestAgent
from .c51 import C51, C51TestAgent
from .ddpg import DDPG, DDPGTestAgent
from .ddqn import DDQN, DDQNTestAgent
from .dqn import DQN, DQNTestAgent
from .ppo import PPO, PPOTestAgent
from .rainbow import Rainbow, RainbowTestAgent
from .sac import SAC, SACTestAgent
from .vac import VAC, VACTestAgent
from .vpg import VPG, VPGTestAgent
from .vqn import VQN, VQNTestAgent
from .vsarsa import VSarsa, VSarsaTestAgent

__all__ = [
"Agent",
"A2C",
"A2CTestAgent",
"C51",
"C51TestAgent",
"DDPG",
"DDPGTestAgent",
"DDQN",
"DDQNTestAgent",
"DQN",
"DQNTestAgent",
"PPO",
"PPOTestAgent",
"Rainbow",
"RainbowTestAgent",
"SAC",
"SACTestAgent",
"VAC",
"VACTestAgent",
"VPG",
"VPGTestAgent",
"VQN",
"VQNTestAgent",
"VSarsa",
"VSarsaTestAgent"
]
17 changes: 0 additions & 17 deletions all/agents/_agent.py
@@ -32,20 +32,3 @@ def act(self, state):
Returns:
torch.Tensor: The action to take at the current timestep.
"""

@abstractmethod
def eval(self, state):
"""
Select an action for the current timestep in evaluation mode.
Unlike act, this method should NOT update the internal parameters of the agent.
Most of the time, this method should return the greedy action according to the current policy.
This method is useful when using evaluation methodologies that distinguish between the performance
of the agent during training and the performance of the resulting policy.
Args:
state (all.environment.State): The environment state at the current timestep.
Returns:
torch.Tensor: The action to take at the current timestep.
"""
13 changes: 10 additions & 3 deletions all/agents/a2c.py
@@ -1,3 +1,4 @@
import torch
from torch.nn.functional import mse_loss
from all.logging import DummyWriter
from all.memory import NStepAdvantageBuffer
@@ -61,9 +62,6 @@ def act(self, states):
self._actions = self.policy.no_grad(self.features.no_grad(states)).sample()
return self._actions

def eval(self, states):
return self.policy.eval(self.features.eval(states))

def _train(self, next_states):
if len(self._buffer) >= self._batch_size:
# load trajectories from buffer
@@ -100,3 +98,12 @@ def _make_buffer(self):
self.n_envs,
discount_factor=self.discount_factor
)


class A2CTestAgent(Agent):
def __init__(self, features, policy):
self.features = features
self.policy = policy

def act(self, state):
return self.policy.eval(self.features.eval(state)).sample()
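A2CTestAgent only assumes two objects with an eval() method: a feature network and a policy head whose output can be sampled. A minimal stand-alone illustration of that contract; FakeFeatures and FakePolicy are invented for the example, and the test agent class is copied from the diff above (minus the Agent base class) so the snippet runs on its own:

```python
import torch
from torch.distributions import Categorical


class A2CTestAgent:
    def __init__(self, features, policy):
        self.features = features
        self.policy = policy

    def act(self, state):
        return self.policy.eval(self.features.eval(state)).sample()


class FakeFeatures:
    def eval(self, state):
        return state  # a real feature network would embed the observation


class FakePolicy:
    def eval(self, features):
        return Categorical(logits=torch.zeros(3))  # a real head maps features to action logits


agent = A2CTestAgent(FakeFeatures(), FakePolicy())
print(agent.act(torch.zeros(4)))  # a sampled action index in {0, 1, 2}
```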
13 changes: 13 additions & 0 deletions all/agents/c51.py
@@ -114,3 +114,16 @@ def _kl(self, dist, target_dist):
log_dist = torch.log(torch.clamp(dist, min=self.eps))
log_target_dist = torch.log(torch.clamp(target_dist, min=self.eps))
return (target_dist * (log_target_dist - log_dist)).sum(dim=-1)


class C51TestAgent(Agent):
def __init__(self, q_dist, n_actions, exploration=0.):
self.q_dist = q_dist
self.n_actions = n_actions
self.exploration = exploration

def act(self, state):
if np.random.rand() < self.exploration:
return np.random.randint(0, self.n_actions)
q_values = (self.q_dist(state) * self.q_dist.atoms).sum(dim=-1)
return torch.argmax(q_values, dim=-1)
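C51 represents each action value as a categorical distribution over a fixed support of atoms, so the test agent acts greedily with respect to the expected value Q(s, a) = sum_i p_i(s, a) * z_i (plus optional epsilon exploration, which defaults to 0). A small numeric sketch of that reduction, with made-up probabilities and atoms:

```python
import torch

# Made-up example: 2 actions, 5 atoms spanning returns from -2 to 2.
atoms = torch.linspace(-2.0, 2.0, 5)               # support values z_i
probs = torch.tensor([[0.1, 0.2, 0.4, 0.2, 0.1],   # action 0: mass centered near 0
                      [0.0, 0.1, 0.2, 0.3, 0.4]])  # action 1: mass on higher returns

q_values = (probs * atoms).sum(dim=-1)  # expected return per action: 0.0 and 1.0
print(torch.argmax(q_values).item())    # 1, the greedy action
```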
8 changes: 8 additions & 0 deletions all/agents/ddpg.py
@@ -93,3 +93,11 @@ def _train(self):
def _should_train(self):
self._frames_seen += 1
return self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0


class DDPGTestAgent(Agent):
def __init__(self, policy):
self.policy = policy

def act(self, state):
return self.policy.eval(state)
4 changes: 4 additions & 0 deletions all/agents/ddqn.py
@@ -1,6 +1,7 @@
import torch
from all.nn import weighted_mse_loss
from ._agent import Agent
from .dqn import DQNTestAgent


class DDQN(Agent):
@@ -80,3 +81,6 @@ def _train(self):
def _should_train(self):
self._frames_seen += 1
return self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0


DDQNTestAgent = DQNTestAgent
13 changes: 13 additions & 0 deletions all/agents/dqn.py
@@ -1,3 +1,4 @@
import numpy as np
import torch
from torch.nn.functional import mse_loss
from ._agent import Agent
@@ -77,3 +78,15 @@ def _train(self):
def _should_train(self):
self._frames_seen += 1
return (self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0)


class DQNTestAgent(Agent):
def __init__(self, q, n_actions, exploration=0.):
self.q = q
self.n_actions = n_actions
self.exploration = exploration

def act(self, state):
if np.random.rand() < self.exploration:
return np.random.randint(0, self.n_actions)
return torch.argmax(self.q.eval(state)).item()
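DQNTestAgent is plain epsilon-greedy over the Q-network outputs, with exploration defaulting to 0 (fully greedy). The same selection rule as a stand-alone function, with made-up Q-values:

```python
import numpy as np
import torch


def epsilon_greedy(q_values, n_actions, exploration=0.0):
    # With probability `exploration` pick a random action, otherwise the greedy one.
    if np.random.rand() < exploration:
        return np.random.randint(0, n_actions)
    return torch.argmax(q_values).item()


q_values = torch.tensor([0.3, 1.2, -0.5])  # made-up Q-values for 3 actions
print(epsilon_greedy(q_values, 3))         # 1, the greedy action
print(epsilon_greedy(q_values, 3, 0.1))    # usually 1, occasionally random
```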
4 changes: 4 additions & 0 deletions all/agents/ppo.py
@@ -3,6 +3,7 @@
from all.logging import DummyWriter
from all.memory import GeneralizedAdvantageBuffer
from ._agent import Agent
from .a2c import A2CTestAgent


class PPO(Agent):
@@ -139,3 +140,6 @@ def _make_buffer(self):
discount_factor=self.discount_factor,
lam=self.lam
)


PPOTestAgent = A2CTestAgent
5 changes: 4 additions & 1 deletion all/agents/rainbow.py
@@ -1,4 +1,4 @@
from .c51 import C51
from .c51 import C51, C51TestAgent


class Rainbow(C51):
@@ -29,3 +29,6 @@ class Rainbow(C51):
when training begins.
update_frequency (int): Number of timesteps per training update.
"""


RainbowTestAgent = C51TestAgent
11 changes: 8 additions & 3 deletions all/agents/sac.py
@@ -72,9 +72,6 @@ def act(self, state):
self._action = self.policy.no_grad(state)[0]
return self._action

def eval(self, state):
return self.policy.eval(state)

def _train(self):
if self._should_train():
# sample from replay buffer
@@ -113,3 +110,11 @@ def _train(self):
def _should_train(self):
self._frames_seen += 1
return self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0


class SACTestAgent(Agent):
def __init__(self, policy):
self.policy = policy

def act(self, state):
return self.policy.eval(state)
4 changes: 4 additions & 0 deletions all/agents/vac.py
@@ -1,5 +1,6 @@
from torch.nn.functional import mse_loss
from ._agent import Agent
from .a2c import A2CTestAgent


class VAC(Agent):
@@ -56,3 +57,6 @@ def _train(self, state, reward):
self.v.reinforce(value_loss)
self.policy.reinforce(policy_loss)
self.features.reinforce()


VACTestAgent = A2CTestAgent
4 changes: 4 additions & 0 deletions all/agents/vpg.py
@@ -2,6 +2,7 @@
from torch.nn.functional import mse_loss
from all.core import State
from ._agent import Agent
from .a2c import A2CTestAgent


class VPG(Agent):
@@ -130,3 +131,6 @@ def _compute_discounted_returns(self, rewards):
returns[t] = discounted_return
t -= 1
return returns


VPGTestAgent = A2CTestAgent
4 changes: 4 additions & 0 deletions all/agents/vqn.py
@@ -1,6 +1,7 @@
import torch
from torch.nn.functional import mse_loss
from ._agent import Agent
from .dqn import DQNTestAgent


class VQN(Agent):
@@ -46,3 +47,6 @@ def _train(self, reward, next_state):
loss = mse_loss(value, target)
# backward pass
self.q.reinforce(loss)


VQNTestAgent = DQNTestAgent
4 changes: 4 additions & 0 deletions all/agents/vsarsa.py
@@ -1,5 +1,6 @@
from torch.nn.functional import mse_loss
from ._agent import Agent
from .dqn import DQNTestAgent


class VSarsa(Agent):
@@ -43,3 +44,6 @@ def _train(self, reward, next_state, next_action):
loss = mse_loss(value, target)
# backward pass
self.q.reinforce(loss)


VSarsaTestAgent = DQNTestAgent
6 changes: 3 additions & 3 deletions all/approximation/approximation.py
@@ -3,7 +3,7 @@
from torch.nn import utils
from all.logging import DummyWriter
from .target import TrivialTarget
from .checkpointer import PeriodicCheckpointer
from .checkpointer import DummyCheckpointer

DEFAULT_CHECKPOINT_FREQUENCY = 200

@@ -51,7 +51,7 @@ class Approximation():
def __init__(
self,
model,
optimizer,
optimizer=None,
checkpointer=None,
clip_grad=0,
loss_scaling=1,
@@ -74,7 +74,7 @@ def __init__(
self._name = name

if checkpointer is None:
checkpointer = PeriodicCheckpointer(DEFAULT_CHECKPOINT_FREQUENCY)
checkpointer = DummyCheckpointer()
self._checkpointer = checkpointer
self._checkpointer.init(
self.model,
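Making optimizer optional (together with defaulting the checkpointer to DummyCheckpointer, so periodic on-disk checkpointing becomes opt-in) lets a saved network be wrapped for evaluation without constructing a full training setup. A toy illustration of the idea; TinyApproximation is invented for the example and is far simpler than the real Approximation class:

```python
import torch
from torch import nn


class TinyApproximation:
    """Invented, stripped-down stand-in: step() is a no-op when no optimizer is given."""

    def __init__(self, model, optimizer=None):
        self.model = model
        self._optimizer = optimizer

    def step(self):
        # Training update; silently a no-op when no optimizer was provided.
        if self._optimizer is not None:
            self._optimizer.step()
            self._optimizer.zero_grad()

    def eval(self, x):
        with torch.no_grad():
            return self.model(x)


model = nn.Linear(4, 2)

# Training-time usage: pass an optimizer so step() updates the weights.
train_q = TinyApproximation(model, torch.optim.Adam(model.parameters()))

# Evaluation-time usage (e.g. a reloaded preset): no optimizer required.
eval_q = TinyApproximation(model)
print(eval_q.eval(torch.zeros(4)))
```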
2 changes: 1 addition & 1 deletion all/approximation/q_network.py
@@ -7,7 +7,7 @@ class QNetwork(Approximation):
def __init__(
self,
model,
optimizer,
optimizer=None,
name='q',
**kwargs
):
3 changes: 1 addition & 2 deletions all/experiments/__init__.py
@@ -5,15 +5,14 @@
from .writer import ExperimentWriter
from .plots import plot_returns_100
from .slurm import SlurmExperiment
from .watch import GreedyAgent, watch, load_and_watch
from .watch import watch, load_and_watch

__all__ = [
"run_experiment",
"Experiment",
"SingleEnvExperiment",
"ParallelEnvExperiment",
"SlurmExperiment",
"GreedyAgent",
"ExperimentWriter",
"watch",
"load_and_watch",
7 changes: 7 additions & 0 deletions all/experiments/experiment.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
import numpy as np
from scipy import stats
import torch


class Experiment(ABC):
@@ -73,4 +75,9 @@ def _log_test_episode(self, episode, returns):
print('test episode: {}, returns: {}'.format(episode, returns))

def _log_test(self, returns):
if not self._quiet:
print('test returns (mean ± sem): {} ± {}'.format(np.mean(returns), stats.sem(returns)))
self._writer.add_summary('returns-test', np.mean(returns), np.std(returns))

def save(self):
return self._preset.save('{}/preset.pt'.format(self._writer.log_dir))
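_log_test now prints the mean test return with its standard error (SEM = sample std / sqrt(n)), and Experiment.save() simply delegates to the preset's save(), writing preset.pt into the writer's log directory. For example, the printed summary for five made-up returns:

```python
import numpy as np
from scipy import stats

returns = [200.0, 180.0, 220.0, 210.0, 190.0]  # made-up test-episode returns
print('test returns (mean ± sem): {} ± {}'.format(np.mean(returns), stats.sem(returns)))
# mean = 200.0, sem = sample std / sqrt(5) ≈ 7.07
```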
