diff --git a/README.md b/README.md
index f20636e..70079a6 100644
--- a/README.md
+++ b/README.md
@@ -26,10 +26,12 @@ this list):
 | Algorithm | Publication | Status |
 | :-------- | :---------: | :----- |
 | Deep Q-Network (DQN) | Mnih et al. 2015 ([PDF][paper-dqn]) | Working consistently. |
+| Double Deep Q-Network (DDQN) | van Hasselt et al. 2015 ([PDF][paper-ddqn]) | Working consistently. |
 | Asynchronous Advantage Actor-Critic (A3C) | Mnih et al. 2016 ([PDF][paper-a3c]) | Partly working. |
 | Reinforce | Williams 1992 ([PDF][paper-reinforce]) | Currently being tested. |
 
 [paper-dqn]: https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
+[paper-ddqn]: https://arxiv.org/pdf/1509.06461v3.pdf
 [paper-a3c]: https://arxiv.org/pdf/1602.01783v2.pdf
 [paper-reinforce]: http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf
 
diff --git a/definition/cartpole.yaml b/definition/cartpole.yaml
deleted file mode 100644
index 2c19ae7..0000000
--- a/definition/cartpole.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-epochs: 20
-test_steps: 1000
-repeats: 1
-envs:
-  - CartPole-v0
-algorithms:
-  -
-    name: Reinforce
-    type: Reinforce
-    train_steps: 20000
-    config:
-      heads: 32
-      discount: 0.99
-      update_every: 10000
-      preprocess: default
-      preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
-      initial_learning_rate: 1e-2
-      network: control
-  -
-    name: DQN
-    type: DQN
-    train_steps: 10000
-    config:
-      discount: 0.95
-      epsilon: {from_: 1.0, offset: 10000, over: 50000, test: 0.0, to: 0.05}
-      preprocess: default
-      preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
-      initial_learning_rate: 1e-3
-      network: control
-      replay_capacity: 1000
-      start_learning: 1000
-      sync_target: 50
-      batch_size: 128
-  -
-    name: A3C
-    type: A3C
-    train_steps: 10000
-    config:
-      learners: 16
-      discount: 0.95 # 0.5
-      preprocess: default
-      preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
-      initial_learning_rate: 1e-3
-      heads: 10
-      update_every: 10000
-      batch_size: 50
-      preprocess: default
-      preprocess_config: {subsample: 0, frame_skip: 1, history: 4, noop_max: 5}
-      approximation: {scale_critic_loss: 0.5, regularize: 0.01}
-      gradient_clipping: 1
-      initial_learning_rate: 1e-4
-      network: control
diff --git a/definition/control.yaml b/definition/control.yaml
new file mode 100644
index 0000000..53603d5
--- /dev/null
+++ b/definition/control.yaml
@@ -0,0 +1,68 @@
+epochs: 20
+test_steps: 1000
+repeats: 1
+envs:
+  - CartPole-v0
+  - Acrobot-v1
+algorithms:
+  -
+    name: DDQN
+    type: DDQN
+    train_steps: 10000
+    config:
+      discount: 0.95
+      epsilon: {from_: 1.0, offset: 10000, over: 50000, test: 0.0, to: 0.05}
+      preprocess: default
+      preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
+      initial_learning_rate: 1e-3
+      network: control
+      replay_capacity: 1000
+      start_learning: 1000
+      sync_target: 50
+      batch_size: 128
+  -
+    name: DQN
+    type: DQN
+    train_steps: 10000
+    config:
+      discount: 0.95
+      epsilon: {from_: 1.0, offset: 10000, over: 50000, test: 0.0, to: 0.05}
+      preprocess: default
+      preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
+      initial_learning_rate: 1e-3
+      network: control
+      replay_capacity: 1000
+      start_learning: 1000
+      sync_target: 50
+      batch_size: 128
+# -
+#   name: Reinforce
+#   type: Reinforce
+#   train_steps: 20000
+#   config:
+#     heads: 32
+#     discount: 0.99
+#     update_every: 10000
+#     preprocess: default
+#     preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
+#     initial_learning_rate: 1e-2
+#     network: control
+# -
+#   name: A3C
+#   type: A3C
+#   train_steps: 10000
+#   config:
+#     learners: 16
+#     discount: 0.95 # 0.5
+#     preprocess: default
+#     preprocess_config: {subsample: 1, frame_skip: 1, history: 4, noop_max: 5}
+#     initial_learning_rate: 1e-3
+#     heads: 10
+#     update_every: 10000
+#     batch_size: 50
+#     preprocess: default
+#     preprocess_config: {subsample: 0, frame_skip: 1, history: 4, noop_max: 5}
+#     approximation: {scale_critic_loss: 0.5, regularize: 0.01}
+#     gradient_clipping: 1
+#     initial_learning_rate: 1e-4
+#     network: control
diff --git a/definition/reinforce.yaml b/definition/reinforce.yaml
index f75dec4..7036fc5 100644
--- a/definition/reinforce.yaml
+++ b/definition/reinforce.yaml
@@ -11,10 +11,14 @@ algorithms:
     config:
       heads: 32
      preprocess: default
-      preprocess_config: {subsample: 1, frame_skip: 1, history: 2, noop_max: 3}
-      approximation: 'advantage_policy_gradient'
-      approximation_config: {scale_critic_loss: 0.5, regularize: 0.01}
-      discount: 0.95
+      preprocess_config: {subsample: 1, frame_skip: 1, history: 3, noop_max: 5}
+      approximation: advantage_policy_gradient
+      approximation_config:
+        actor_weight: 1.0
+        critic_weight: 1.0
+        entropy_weight: 1.0
+      # TODO: Add eligibility parameter.
+      discount: 1 # 0.95
       network: control
       update_every: 10000
       batch_size: 32
diff --git a/mindpark/algorithm/__init__.py b/mindpark/algorithm/__init__.py
index a80b31a..972fb6a 100644
--- a/mindpark/algorithm/__init__.py
+++ b/mindpark/algorithm/__init__.py
@@ -1,5 +1,6 @@
 from .random import Random
 from .keyboard import KeyboardDoom
 from .dqn import DQN
+from .ddqn import DDQN
 from .a3c import A3C
 from .reinforce import Reinforce
diff --git a/mindpark/algorithm/a3c.py b/mindpark/algorithm/a3c.py
index 6713184..d83951d 100644
--- a/mindpark/algorithm/a3c.py
+++ b/mindpark/algorithm/a3c.py
@@ -19,7 +19,8 @@ def defaults(cls):
         preprocess_config = dict()
         network = 'a3c_lstm'
         learners = 16
-        approximation = dict(scale_critic_loss=0.5, regularize=0.01)
+        approximation_config = dict(
+            actor_weight=1.0, critic_weight=0.5, entropy_weight=0.01)
         apply_gradient = 5
         initial_learning_rate = 7e-4
         optimizer = tf.train.RMSPropOptimizer
@@ -75,7 +76,7 @@ def _create_network(self, model):
         observs = self._preprocess.above_task.observs.shape
         actions = self._preprocess.above_task.actions.n
         mp.part.approximation.advantage_policy_gradient(
-            model, network, observs, actions, self.config.approximation)
+            model, network, observs, actions, self.config.approximation_config)
 
     def _create_preprocess(self):
         policy = mp.Sequential(self.task)
diff --git a/mindpark/algorithm/ddqn.py b/mindpark/algorithm/ddqn.py
new file mode 100644
index 0000000..b86180c
--- /dev/null
+++ b/mindpark/algorithm/ddqn.py
@@ -0,0 +1,138 @@
+import numpy as np
+import tensorflow as tf
+import mindpark as mp
+import mindpark.part.preprocess
+import mindpark.part.approximation
+import mindpark.part.network
+import mindpark.part.replay
+
+
+class DDQN(mp.Algorithm, mp.step.Experience):
+
+    """
+    Algorithm: Double Deep Q-Network (DDQN)
+    Paper: Deep Reinforcement Learning with Double Q-learning
+    Authors: van Hasselt, Guez, Silver, 2015
+    PDF: https://arxiv.org/pdf/1509.06461v3.pdf
+    """
+
+    @classmethod
+    def defaults(cls):
+        preprocess = 'dqn_2015'
+        preprocess_config = dict(frame_skip=4)
+        network = 'dqn_2015'
+        replay_capacity = 1e5  # 1e6
+        start_learning = 5e4
+        epsilon = dict(
+            from_=1.0, to=0.1, test=0.05, over=1e6, offset=start_learning)
+        batch_size = 32
+        sync_target = 2500
+        initial_learning_rate = 2.5e-4
+        optimizer = tf.train.RMSPropOptimizer
+        optimizer_config = dict(decay=0.95, epsilon=0.1)
+        return mp.utility.merge_dicts(super().defaults(), locals())
+
+    def __init__(self, task, config):
+        mp.Algorithm.__init__(self, task, config)
+        self._parse_config()
+        self._preprocess = self._create_preprocess()
+        mp.step.Experience.__init__(self, self._preprocess.above_task)
+        self._model = mp.model.Model(self._create_network)
+        self._target = mp.model.Model(self._create_network)
+        self._target.weights = self._model.weights
+        self._sync_target = mp.utility.Every(
+            self.config.sync_target, self.config.start_learning)
+        print(str(self._model))
+        self._learning_rate = mp.utility.Decay(
+            self.config.initial_learning_rate, 0, self.task.steps)
+        self._cost_metric = mp.Metric(self.task, 'dqn/cost', 1)
+        self._sync_target_metric = mp.Metric(self.task, 'dqn/sync_target', 1)
+        self._learning_rate_metric = mp.Metric(
+            self.task, 'dqn/learning_rate', 1)
+        self._memory = self._create_memory()
+
+    def end_epoch(self):
+        super().end_epoch()
+        if self.task.directory:
+            self._model.save(self.task.directory, 'model')
+
+    def perform(self, observ):
+        return self._model.compute('qvalues', state=observ)
+
+    def experience(self, observ, action, reward, successor):
+        action = action.argmax()
+        self._memory.push(observ, action, reward, successor)
+        if self.task.step < self.config.start_learning:
+            return
+        self._train_network()
+
+    @property
+    def policy(self):
+        # TODO: Why doesn't self.task work here?
+        policy = mp.Sequential(self._preprocess.task)
+        policy.add(self._preprocess)
+        policy.add(self)
+        return policy
+
+    def _train_network(self):
+        self._model.set_option(
+            'learning_rate', self._learning_rate(self.task.step))
+        self._learning_rate_metric(self._model.get_option('learning_rate'))
+        observ, action, reward, successor = \
+            self._memory.batch(self.config.batch_size)
+        return_ = self._estimated_return(reward, successor)
+        cost = self._model.train(
+            'cost', state=observ, action=action, return_=return_)
+        self._cost_metric(cost)
+        if self._sync_target(self.task.step):
+            self._target.weights = self._model.weights
+            self._sync_target_metric(True)
+        else:
+            self._sync_target_metric(False)
+
+    def _estimated_return(self, reward, successor):
+        terminal = np.isnan(successor.reshape((len(successor), -1))).any(1)
+        successor = np.nan_to_num(successor)
+        assert np.isfinite(successor).all()
+        # NOTE: Swapping the models below seems to work similarly well.
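+        # Double Q-learning: the online model picks the greedy action for
+        # each successor state and the target model provides the value
+        # estimate for that action, which reduces the overestimation bias
+        # of taking a maximum over the target network's own estimates.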
+        future = self._target.compute('qvalues', state=successor)
+        choice = self._model.compute('choice', state=successor)
+        future = choice.choose(future.T)
+        future[terminal] = 0
+        return_ = reward + self.config.discount * future
+        assert np.isfinite(return_).all()
+        return return_
+
+    def _create_memory(self):
+        observ_shape = self._preprocess.above_task.observs.shape
+        shapes = observ_shape, tuple(), tuple(), observ_shape
+        memory = mp.part.replay.Random(self.config.replay_capacity, shapes)
+        memory.log_memory_size()
+        return memory
+
+    def _create_preprocess(self):
+        policy = mp.Sequential(self.task)
+        preprocess = getattr(mp.part.preprocess, self.config.preprocess)
+        policy.add(preprocess, self.config.preprocess_config)
+        policy.add(mp.step.EpsilonGreedy, **self.config.epsilon)
+        return policy
+
+    def _create_network(self, model):
+        learning_rate = model.add_option(
+            'learning_rate', self.config.initial_learning_rate)
+        model.set_optimizer(self.config.optimizer(
+            learning_rate=learning_rate,
+            **self.config.optimizer_config))
+        network = getattr(mp.part.network, self.config.network)
+        observs = self._preprocess.above_task.observs.shape
+        actions = self._preprocess.above_task.actions.shape[0]
+        mp.part.approximation.q_function(model, network, observs, actions)
+
+    def _parse_config(self):
+        if self.config.start_learning > self.config.replay_capacity:
+            raise KeyError('Why not start learning after the buffer is full?')
+        if self.config.start_learning < self.config.batch_size:
+            raise KeyError('Must collect at least one batch before learning.')
+        self.config.start_learning *= self.config.preprocess_config.frame_skip
+        self.config.sync_target *= self.config.preprocess_config.frame_skip
+        self.config.epsilon.over *= self.config.preprocess_config.frame_skip
diff --git a/mindpark/algorithm/dqn.py b/mindpark/algorithm/dqn.py
index c724a63..2516eb3 100644
--- a/mindpark/algorithm/dqn.py
+++ b/mindpark/algorithm/dqn.py
@@ -89,7 +89,7 @@ def _estimated_return(self, reward, successor):
         terminal = np.isnan(successor.reshape((len(successor), -1))).any(1)
         successor = np.nan_to_num(successor)
         assert np.isfinite(successor).all()
-        future = self._target.compute('value', state=successor)
+        future = self._target.compute('qvalue', state=successor)
         future[terminal] = 0
         return_ = reward + self.config.discount * future
         assert np.isfinite(return_).all()
diff --git a/mindpark/algorithm/reinforce.py b/mindpark/algorithm/reinforce.py
index ea8c65c..1c5e92d 100644
--- a/mindpark/algorithm/reinforce.py
+++ b/mindpark/algorithm/reinforce.py
@@ -27,11 +27,12 @@ def defaults(cls):
         heads = 16
         discount = 0.999
         initial_learning_rate = 2.5e-4
-        optimizer = tf.train.AdamOptimizer
-        gradient_clipping = 10  # 1e-2
+        optimizer = 'AdamOptimizer'
+        gradient_clipping = 10
         optimizer_config = dict()
         approximation = 'advantage_policy_gradient'
-        approximation_config = dict(scale_critic_loss=0.5, regularize=0.01)
+        approximation_config = dict(
+            actor_weight=1.0, critic_weight=1.0, entropy_weight=1.0)
         return mp.utility.merge_dicts(super().defaults(), locals())
 
     def __init__(self, task, config):
diff --git a/mindpark/model/model.py b/mindpark/model/model.py
index 5e58c38..c08f1a4 100644
--- a/mindpark/model/model.py
+++ b/mindpark/model/model.py
@@ -102,6 +102,7 @@ def has_cost(self, name):
     def train(self, cost, batch=None, epochs=1, **data):
         costs = []
         for batch in self._chunks(data, batch, epochs):
+            # TODO: See if training directly is more efficient.
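+            # Computes the gradient and cost in one pass; the gradient is
+            # then applied as a separate step below.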
             delta, cost = self.delta(cost, **data)
             self.apply(delta)
             costs.append(cost)
diff --git a/mindpark/part/approximation.py b/mindpark/part/approximation.py
index 8890ffb..8ffcdf9 100644
--- a/mindpark/part/approximation.py
+++ b/mindpark/part/approximation.py
@@ -11,19 +11,19 @@ def q_function(model, network, observs, actions, config=None):
         hidden = network(model, state)
         qvalues = dense(hidden, actions, tf.identity)
         qvalues = model.add_output('qvalues', qvalues)
+        model.add_output('choice', tf.argmax(qvalues, 1))
     with tf.variable_scope('learning'):
         action = model.add_input('action', type_=tf.int32)
         action = tf.one_hot(action, actions)
         return_ = model.add_input('return_')
-        model.add_output('value', tf.reduce_max(qvalues, 1))
+        model.add_output('qvalue', tf.reduce_max(qvalues, 1))
         model.add_cost(
             'cost', (tf.reduce_sum(action * qvalues, 1) - return_) ** 2)
 
 
 def policy_gradient(model, network, observs, actions, config):
     """
-    Policy gradient of the advantage function. Estimates the advantage from a
-    learned value function and experiences returns.
+    Policy gradient of the return.
     """
     with tf.variable_scope('behavior'):
         state = model.add_input('state', observs)
@@ -31,15 +31,16 @@ def policy_gradient(model, network, observs, actions, config):
         value = model.add_output(
             'value', tf.squeeze(dense(hidden, 1, tf.identity), [1]))
         policy = dense(value, actions, tf.nn.softmax)
-        model.add_output(
-            'choice', tf.squeeze(tf.multinomial(tf.log(policy), 1), [1]))
+        model.add_output('choice', tf.squeeze(
+            tf.multinomial(tf.log(policy + 1e-13), 1), [1]))
     with tf.variable_scope('learning'):
         action = model.add_input('action', type_=tf.int32)
         action = tf.one_hot(action, actions)
         return_ = model.add_input('return_')
         logprob = tf.log(tf.reduce_sum(policy * action, 1) + 1e-13)
         entropy = -tf.reduce_sum(tf.log(policy + 1e-13) * policy)
-        model.add_cost('cost', return_ * logprob + config.regularize * entropy)
+        actor = config.actor_weight * return_ * logprob
+        entropy = config.entropy_weight * entropy
+        model.add_cost('cost', -actor - entropy)
 
 
 def advantage_policy_gradient(model, network, observs, actions, config):
@@ -53,18 +54,18 @@ def advantage_policy_gradient(model, network, observs, actions, config):
         value = model.add_output(
             'value', tf.squeeze(dense(hidden, 1, tf.identity), [1]))
         policy = dense(value, actions, tf.nn.softmax)
-        model.add_output(
-            'choice', tf.squeeze(tf.multinomial(tf.log(policy), 1), [1]))
+        model.add_output('choice', tf.squeeze(
+            tf.multinomial(tf.log(policy + 1e-13), 1), [1]))
     with tf.variable_scope('learning'):
         action = model.add_input('action', type_=tf.int32)
         action = tf.one_hot(action, actions)
         return_ = model.add_input('return_')
+        advantage = tf.stop_gradient(return_ - value)
         logprob = tf.log(tf.reduce_sum(policy * action, 1) + 1e-13)
         entropy = -tf.reduce_sum(tf.log(policy + 1e-13) * policy)
-        advantage = tf.stop_gradient(return_ - value)
-        actor = advantage * logprob + config.regularize * entropy
-        critic = config.scale_critic_loss * (return_ - value) ** 2 / 2
-        model.add_cost('cost', critic - actor)
+        actor = config.actor_weight * advantage * logprob
+        critic = config.critic_weight * (return_ - value) ** 2 / 2
+        entropy = config.entropy_weight * entropy
+        model.add_cost('cost', critic - actor - entropy)
 
 
 def approx_advantage_policy_gradient(model, network, observs, actions, config):
@@ -77,25 +78,20 @@
         hidden = network(model, state)
         value = model.add_output(
             'value', tf.squeeze(dense(hidden, 1, tf.identity), [1]))
-        advantages = model.add_output(
-            'advantages', dense(hidden, actions, tf.identity))
+        qvalues = model.add_output(
+            'qvalues', dense(hidden, actions, tf.identity))
         policy = dense(value, actions, tf.nn.softmax)
-        model.add_output(
-            'choice', tf.squeeze(tf.multinomial(tf.log(policy), 1), [1]))
+        model.add_output('choice', tf.squeeze(
+            tf.multinomial(tf.log(policy + 1e-13), 1), [1]))
     with tf.variable_scope('learning'):
         action = model.add_input('action', type_=tf.int32)
-        return_ = model.add_input('return_')
         action = tf.one_hot(action, actions)
-        with tf.variable_scope('value'):
-            critic_v = (return_ - value) ** 2 / 2
-        with tf.variable_scope('advantage'):
-            advantage = tf.reduce_max(action * advantages, [1])
-            qvalue = value + advantage
-            critic_q = (return_ - qvalue) ** 2 / 2
-        with tf.variable_scope('policy'):
-            advantage = tf.stop_gradient(advantage)
-            logprob = tf.log(tf.reduce_sum(policy * action, 1) + 1e-13)
-            entropy = -tf.reduce_sum(tf.log(policy + 1e-13) * policy)
-            actor = advantage * logprob + config.regularize * entropy
-        critic = config.scale_critic_loss * (critic_v + critic_q)
-        model.add_cost('cost', critic - actor)
+        return_ = model.add_input('return_')
+        qvalue = tf.reduce_sum(qvalues * action, 1)
+        advantage = tf.stop_gradient(qvalue - value)
+        logprob = tf.log(tf.reduce_sum(policy * action, 1) + 1e-13)
+        entropy = -tf.reduce_sum(tf.log(policy + 1e-13) * policy)
+        actor = config.actor_weight * advantage * logprob
+        critic = config.critic_weight * (return_ - value) ** 2 / 2
+        qcritic = config.critic_weight * (return_ - qvalue) ** 2 / 2
+        entropy = config.entropy_weight * entropy
+        model.add_cost('cost', critic + qcritic - actor - entropy)
diff --git a/mindpark/part/network.py b/mindpark/part/network.py
index 561097e..53747c7 100644
--- a/mindpark/part/network.py
+++ b/mindpark/part/network.py
@@ -75,6 +75,8 @@ def test(model, x):
 
 
 def control(model, x):
-    x = dense(x, 100, tf.nn.relu)
-    x = dense(x, 50, tf.nn.relu)
+    # x = dense(x, 100, tf.nn.relu)
+    # x = dense(x, 50, tf.nn.relu)
+    x = dense(x, 32, tf.nn.relu)
+    x = dense(x, 32, tf.nn.relu)
     return x
diff --git a/test/part/test_replay.py b/test/part/test_replay.py
index a2d8daf..fec9eb3 100644
--- a/test/part/test_replay.py
+++ b/test/part/test_replay.py
@@ -78,6 +78,16 @@ def test_shuffle(self):
         assert (np.sort(batch) == list(range(10, 20))).all()
         assert not (batch == list(range(10, 20))).all()
 
+    def test_shuffle_underfull(self):
+        random = np.random.RandomState(0)
+        memory = mp.part.replay.Sequential(20, [[]], random)
+        for number in range(10):
+            memory.push(number)
+        memory.shuffle()
+        batch = memory.batch(10)[0]
+        assert (np.sort(batch) == list(range(10))).all()
+        assert not (batch == list(range(10))).all()
+
 
 class TestRandom:
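
Note: the following is a minimal NumPy sketch of the target computation that
DDQN._estimated_return above implements; the function name and array shapes are
made up for illustration and do not mirror the mindpark API. It shows the core
idea of Double Q-learning: the online network selects the greedy action and the
target network evaluates it.

import numpy as np

def double_dqn_target(reward, q_online_next, q_target_next, terminal,
                      discount=0.95):
    # q_online_next and q_target_next hold the successor-state q-values of the
    # online and target networks, each with shape (batch, actions).
    choice = q_online_next.argmax(axis=1)
    future = q_target_next[np.arange(len(choice)), choice]
    future[terminal] = 0.0  # No bootstrapping past the end of an episode.
    return reward + discount * future

reward = np.array([1.0, 0.0])
q_online_next = np.array([[0.1, 0.5, 0.2], [0.3, 0.2, 0.4]])
q_target_next = np.array([[0.0, 0.4, 0.9], [0.1, 0.1, 0.1]])
terminal = np.array([False, True])
print(double_dqn_target(reward, q_online_next, q_target_next, terminal))
# Roughly [1.38, 0.0]: the first target bootstraps from the action the online
# network prefers (index 1), evaluated by the target network.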