In [None]:
from collections import OrderedDict, namedtuple
import time

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import theano as th
import theano.tensor as T

%load_ext autoreload
%autoreload 2

In [None]:
from utils.cartpole import CartPole
from utils import VariableStore, Linear, SGD, momentum

In [None]:
# Seed from the wall clock and print it, so a run can be reproduced by
# hard-coding the printed value here.
seed = int(time.time())
print(seed)
# Theano random streams: used by DPGModel for the exploration noise.
rng = T.shared_randomstreams.RandomStreams(seed)
# Replay-buffer minibatch sampling uses numpy's global RNG (np.random.choice),
# so seed it as well; otherwise the printed seed does not reproduce a run.
np.random.seed(seed)

In [None]:
# CartPole problem dimensions: 4 state features, 1 action component.
STATE_DIM, ACTION_DIM = 4, 1
# Half-width of the uniform exploration noise added to actor outputs.
EXPLORE_RANGE = 0.5

# Bundle of the symbolic pieces that make up one critic (Q-function) head.
Critic = namedtuple("Critic", "pred targets cost updates")


class DPGModel(object):
    """Deterministic policy gradient (DPG) actor-critic built from Linear maps.

    The actor (``a_pred``) is a ``Linear`` map from states to actions; the
    critic is a ``Linear`` map from the concatenated (state, action) pair to a
    scalar Q-value. With ``track=True`` a second "tracking" instance is built
    as ``self.track``: it reuses this model's symbolic inputs, its variable
    stores are created with ``VariableStore.snapshot`` of this model's stores,
    and ``self.track.f_track_update(tau)`` moves its parameters towards this
    model's (see ``_make_updates``).

    Compiled functions (built for root and tracking instances alike):
        f_action_on(X)             -> deterministic actions ``a_pred``
        f_action_off(X)            -> ``a_pred`` plus uniform noise in
                                      [-explore_range, explore_range]
        f_q(X, actions)            -> Q-values for the given actions
        f_update(X, q_targets, lr) -> (critic cost, Q predictions), applying
                                      ``self.updates`` (critic + actor)
        f_track_update(tau)        -> tracking instances only; applies
                                      ``self.target_updates``
    """
    
    def __init__(self, state_dim, action_dim, explore_range=0.5, track=True,
                 _parent=None, name="dpg"):
        """Build the symbolic graph, update rules and compiled functions.

        Args:
            state_dim: number of state features (columns of ``X``).
            action_dim: number of action components.
            explore_range: half-width of the uniform noise used by the
                exploration policy.
            track: if True, also build a tracking copy as ``self.track``
                (that copy is itself built with ``track=False``).
            _parent: internal. The model this instance tracks. When given,
                symbolic inputs are reused from it and its variable stores
                are snapshotted instead of created fresh.
            name: identifier of this model. For a root model it prefixes the
                variable-store and parameter names; the tracking copy is named
                ``"<name>_track"``.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.explore_range = explore_range
        self.name = name
        
        self.parent = _parent
        if _parent is None:
            # Root model: fresh variable stores and fresh symbolic inputs.
            self._vs_actor = VariableStore("%s/vs_a" % name)
            self._vs_critic = VariableStore("%s/vs_c" % name)
            self._vs_prefix = self.name
            self._make_vars()
        else:
            # Tracking copy: snapshot the parent's stores and keep the
            # parent's name as parameter prefix, so the Linear layers below
            # use the same parameter names as in the parent.
            self._vs_actor = VariableStore.snapshot(_parent._vs_actor)
            self._vs_critic = VariableStore.snapshot(_parent._vs_critic)
            self._vs_prefix = _parent.name
            self._pull_vars(_parent)
        
        self._make_graph()
        self._make_updates()
        self._make_functions()
        
        if track:
            self.track = DPGModel(state_dim, action_dim, explore_range,
                                  track=False, _parent=self,
                                  name="%s_track" % name)
        
    def _make_vars(self):
        """Create the symbolic inputs (root model only)."""
        # Batch of states, one row per sample.
        self.X = T.matrix("X")

        # Optionally directly provide actions predicted
        self.actions = T.matrix("actions")
        # Q target values
        self.q_targets = T.vector("q_targets")
        # Learning rate
        self.lr = T.scalar("lr")
        
    def _pull_vars(self, parent):
        """Reuse the parent's symbolic inputs and create the ``tau`` input."""
        self.X = parent.X
        self.actions = parent.actions
        self.q_targets = parent.q_targets
        self.lr = parent.lr
        
        # Target network: tracking coefficient
        self.tau = T.scalar("tau")
        
    def _make_graph(self):
        """Build the actor outputs and the three critic heads."""
        # Deterministic policy: linear map
        self.a_pred = Linear(self.X, self.state_dim, self.action_dim,
                             self._vs_actor, name="%s/a" % self._vs_prefix)

        # Exploration policy: add noise
        # (uniform in [-explore_range, explore_range], drawn from the
        # module-level `rng` streams on every call)
        self.a_explore = self.a_pred + rng.uniform(self.a_pred.shape,
                                                   -self.explore_range,
                                                   self.explore_range, ndim=2)

        # Create a few different Critic instances (Q-functions). These all
        # share parameters; they only differ in the sources of their inputs.
        # (Sharing relies on Linear reusing the same-named variable in
        # self._vs_critic; that behavior lives in utils and is not shown here.)
        #
        # Critic 1: actions given
        self.critic_given = self._make_critic(self.actions, self.q_targets)
        # Critic 2: with deterministic policy
        self.critic_det = self._make_critic(self.a_pred, self.q_targets)
        # Critic 3: with noised / exploration policy
        self.critic_exp = self._make_critic(self.a_explore, self.q_targets)
        
    def _make_critic(self, actions, targets):
        """Build one critic head evaluated on ``(self.X, actions)``.

        Args:
            actions: symbolic action matrix concatenated column-wise to
                ``self.X``.
            targets: symbolic vector of Q targets.

        Returns:
            Critic(pred, targets, cost, updates): flattened Q predictions,
            the given targets, their mean squared error, and the
            ``momentum`` updates over the critic's variables only (actor
            variables are not updated by this cost).
        """
        # Q-function is a linear map on state+action pair.
        q_pred = Linear(T.concatenate([self.X, actions], axis=1),
                        self.state_dim + self.action_dim, 1, self._vs_critic,
                        "%s/q" % self._vs_prefix)
        q_pred = q_pred.reshape((-1,))
        
        # MSE loss on TD backup targets.
        q_cost = ((targets - q_pred) ** 2).mean()
        # NOTE: called for every critic head, but only critic_exp.updates is
        # used (see _make_updates).
        q_updates = momentum(q_cost, self._vs_critic.vars.values(), self.lr)
        return Critic(q_pred, targets, q_cost, q_updates)
    
    def _make_updates(self):
        """Collect training updates and, for tracking copies, target updates.

        ``self.updates``: the critic updates of ``critic_exp`` plus actor
        updates that minimize ``-mean(Q(X, a_explore))`` over the actor's
        variables.

        ``self.target_updates`` (tracking copies only): for every actor and
        critic variable, ``new = tau * own + (1 - tau) * parent``. So ``tau``
        is the fraction of the tracking copy's own value that is kept, which
        is the reverse of the convention where tau weights the online net.
        """
        # Actor-critic learning w/ critic 3
        # NB, need to flatten all timesteps into a single batch
        # NOTE(review): critic_exp fits Q at freshly sampled noisy actions,
        # not at the actions that produced the rewards behind q_targets
        # (f_update takes no action input). Standard DPG fits the critic on
        # stored (state, action) pairs, which is what critic_given provides —
        # confirm this is intended.
        self.updates = OrderedDict(self.critic_exp.updates)
        # Add policy gradient updates
        self.updates.update(momentum(-self.critic_exp.pred.mean(),
                                     self._vs_actor.vars.values(),
                                     self.lr))
        
        # Target network: update w.r.t. parent
        if self.parent is not None:
            self.target_updates = OrderedDict()
            for vs, parent_vs in [(self._vs_actor, self.parent._vs_actor),
                                  (self._vs_critic, self.parent._vs_critic)]:
                # Variables are matched to the parent's by name; param_var is
                # the same object as vs.vars[param_name].
                for param_name, param_var in vs.vars.iteritems():
                    self.target_updates[param_var] = (
                        self.tau * vs.vars[param_name]
                        + (1 - self.tau) * parent_vs.vars[param_name])
        
    def _make_functions(self):
        """Compile the Theano functions listed in the class docstring."""
        # On-policy action prediction function
        self.f_action_on = th.function([self.X], self.a_pred)
        # Off-policy action prediction function
        self.f_action_off = th.function([self.X], self.a_explore)

        # Q-function
        self.f_q = th.function([self.X, self.actions], self.critic_given.pred)

        # Actor-critic update
        # Returns (critic_exp cost, critic_exp predictions) computed before
        # the updates are applied.
        self.f_update = th.function([self.X, self.q_targets, self.lr],
                                    (self.critic_exp.cost, self.critic_exp.pred),
                                    updates=self.updates)
        
        # Target networks only: update w.r.t. parent
        if self.parent is not None:
            self.f_track_update = th.function([self.tau], updates=self.target_updates)
        
        
# Root model; its tracking (target) copy is available as dpg.track.
dpg = DPGModel(state_dim=STATE_DIM, action_dim=ACTION_DIM,
               explore_range=EXPLORE_RANGE, track=True)

In [None]:
def run_episode(f_onpolicy, f_offpolicy, f_q):
    """
    Simulate one CartPole episode under the exploration policy and record it.

    Args:
        f_onpolicy: deterministic policy function (currently unused).
        f_offpolicy: exploration policy; called with a (1, k) array built from
            the positional arguments CartPole passes to the policy.
        f_q: Q-function (currently unused).

    Returns:
        (n_steps, states, actions, rewards): n_steps is the trace length;
        states and rewards are the recorded values as numpy arrays; actions
        are binarized to int32 (0 if the recorded action is negative, else 1).
    """
    def policy(*args):
        # By default Theano will not downcast a float64 ndarray to a float32
        # input (allow_input_downcast), so build the input in floatX
        # explicitly instead of relying on np.array's float64 default.
        return f_offpolicy(np.asarray(args, dtype=th.config.floatX).reshape((1, -1)))

    cp = CartPole()
    trace = cp.single_episode(policy=policy)

    states, actions, rewards = [], [], []
    # Trace entries are unpacked as 5-tuples; only the first three fields
    # (state, action, reward) are kept.
    for state_t, action_t, reward_t, _, _ in trace:
        states.append(state_t)
        actions.append(0 if action_t < 0 else 1)
        rewards.append(reward_t)

    return (len(trace),
            np.array(states),
            np.array(actions, dtype=np.int32),
            np.array(rewards))

In [None]:
# Keep a replay buffer of states, actions, rewards and episode-end flags.
R_states = np.empty((0, STATE_DIM), dtype=th.config.floatX)
R_actions = np.empty((0,), dtype=np.int32)
R_rewards = np.empty((0,), dtype=th.config.floatX)
# R_dones[i] is True when transition i is the last one of its episode; then
# R_states[i + 1] (if it exists) starts a different episode and must not be
# used to bootstrap the TD target.
R_dones = np.empty((0,), dtype=bool)

steps = []
q_costs = []

N_EPISODES = 1000
BATCH_SIZE = 50
LR = 0.01
GAMMA = 0.99
TAU = 0.75  # fraction of the tracking net's own value kept per soft update

for t in range(N_EPISODES):
    steps_t, states_t, actions_t, rewards_t = run_episode(dpg.f_action_on, dpg.f_action_off, dpg.f_q)
    if len(states_t) == 0:
        # Empty trace: nothing to store (and a (0,)-shaped array cannot be
        # appended along axis 0 to R_states).
        continue

    # NOTE(review): the last step of every trace is treated as terminal. This
    # is exact if episodes end on failure; if CartPole also stops at a step
    # cap, the value of that final state is dropped — confirm in
    # utils.cartpole.
    dones_t = np.zeros(len(states_t), dtype=bool)
    dones_t[-1] = True

    R_states = np.append(R_states, states_t, axis=0)
    R_actions = np.append(R_actions, actions_t)
    R_rewards = np.append(R_rewards, rewards_t)
    R_dones = np.append(R_dones, dones_t)

    if len(R_states) < BATCH_SIZE:
        # Not enough data. Keep collecting trajectories.
        continue

    # Sample a training minibatch. R_actions is not used here: f_update takes
    # no action input (the critic is fitted at freshly sampled exploration
    # actions inside DPGModel).
    idxs = np.random.choice(len(R_states), size=BATCH_SIZE, replace=False)
    # Cast to floatX: np.append upcasts the buffers, and Theano rejects
    # float64 arrays for float32 inputs.
    b_states = R_states[idxs].astype(th.config.floatX)
    b_rewards = R_rewards[idxs]
    b_not_done = (~R_dones[idxs]).astype(th.config.floatX)

    # Compute TD backup targets with the tracking model. The newest buffer
    # entry is always terminal, so clamping idxs + 1 only affects entries
    # whose bootstrap term is masked out by b_not_done.
    next_idxs = np.minimum(idxs + 1, len(R_states) - 1)
    next_states = R_states[next_idxs].astype(th.config.floatX)
    next_actions = dpg.track.f_action_on(next_states)
    next_q = dpg.track.f_q(next_states, next_actions).reshape((-1,))
    b_targets = (b_rewards + GAMMA * b_not_done * next_q).astype(th.config.floatX)

    # SGD update.
    cost_t, _ = dpg.f_update(b_states, b_targets, LR)
    # Update tracking model.
    dpg.track.f_track_update(TAU)

    steps.append(steps_t)
    q_costs.append(cost_t)
    print("%i\t% 4i\t%10f" % (t, steps_t, cost_t))

In [None]:
# Learning curves: episode length (figure 0) and critic TD loss (figure 1),
# one point per training iteration.
curve_specs = [
    (0, (steps,), "Steps until failure"),
    (1, (q_costs, "r"), "Q cost"),
]
for fig_id, plot_args, y_label in curve_specs:
    ax = plt.figure(fig_id).gca()
    ax.plot(*plot_args)
    ax.set_xlabel("Iteration")
    ax.set_ylabel(y_label)