# Reinforcement Learning for a Cart-Pole System

In [None]:
from __future__ import division, print_function

import numpy as np
import tensorflow as tf
from scipy.linalg import block_diag
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import time
%matplotlib inline

import safe_learning
import plotting
from utilities import (sample_box, sample_box_boundary, sample_ellipsoid, constrained_batch_sampler,
                       CartPole, compute_closedloop_response, get_max_parameter_change)

# Use tqdm for progress bars when it is installed; otherwise iterate without one.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable):
        """Stand-in for tqdm: return *iterable* unchanged."""
        return iterable

# Floating-point types configured by the safe_learning library
np_dtype, tf_dtype = safe_learning.config.np_dtype, safe_learning.config.dtype

# Start from an empty default graph and replace any session left over from a previous
# execution of this cell (NameError means no session exists yet).
tf.reset_default_graph()
try:
    session.close()
except NameError:
    pass
session = tf.InteractiveSession()

# Flipped to True by whichever later cell first runs the global variable initializer
initialized = False


# TODO testing ****************************************#
# Experiment switches read by later cells:
#   train_policy         -- learn a neural-network policy instead of using the LQR policy
#   train_value_function -- learn a neural-network value function instead of the LQR one
#   clip_states          -- clip successor states to [-1, 1] before evaluating the value function
#   saturate             -- bound actions to [-1, 1] (Saturation on the LQR policy,
#                           tanh output layer on the learned policy)
train_policy, train_value_function = True, True
clip_states, saturate = False, True
#******************************************************#


In [None]:
def plot_value_function(tf_learned_values, tf_true_values, n_points, fixed_state, colors=('r', 'b'), show=True):
    """Plot learned and true value functions on two 2-D slices of the normalized state space.

    Both tensors are evaluated on a grid over [-1, 1] x [-1, 1] by feeding the global
    ``states`` placeholder through the global ``session``.

    Left panel: x and theta vary; x_dot and theta_dot are held at ``fixed_state``.
    Right panel: x_dot and theta_dot vary; x and theta are held at ``fixed_state``.

    Args:
        tf_learned_values: Tensor of learned values for the fed ``states``.
        tf_true_values: Tensor of reference values (e.g. the LQR value) for the fed ``states``.
        n_points: Pair ``(n0, n1)`` giving the grid resolution along the two plotted axes;
            evaluated values are reshaped to this shape.
        fixed_state: Sequence ``(x, theta, x_dot, theta_dot)`` supplying the held-fixed coordinates.
        colors: Surface colors for the learned and true values, in that order.
        show: If True, call ``plt.show()`` after drawing.
    """
    fig = plt.figure(figsize=(12, 5), dpi=200)
    fig.subplots_adjust(wspace=0.1, hspace=0.2)
    # An imaginary step in np.mgrid means "this many points, end point included".
    # The builtin `complex` is used because `np.complex` was removed in NumPy 1.24.
    xx, yy = np.mgrid[-1:1:complex(0, n_points[0]), -1:1:complex(0, n_points[1])]
    ones = np.ones_like(xx.ravel())

    x_fix, theta_fix, x_dot_fix, theta_dot_fix = fixed_state

    # Fix x_dot and theta_dot, plot value function over x and theta
    grid = np.column_stack((xx.ravel(), yy.ravel(), x_dot_fix*ones, theta_dot_fix*ones))
    learned_values = session.run(tf_learned_values, feed_dict={states: grid}).reshape(n_points)
    true_values = session.run(tf_true_values, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 1, projection='3d')
    ax.plot_surface(xx, yy, learned_values, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_values, color=colors[1], alpha=0.5)
    ax.set_title(r'Value function, $\dot{x} = %.3g,\ \dot{\theta} = %.3g$' % (x_dot_fix, theta_dot_fix), fontsize=16)
    ax.set_xlabel(r'$x$', fontsize=14)
    ax.set_ylabel(r'$\theta$', fontsize=14)
    ax.set_zlabel(r'$V(s)$', fontsize=14)

    # Fix x and theta, plot value function over x_dot and theta_dot
    grid = np.column_stack((x_fix*ones, theta_fix*ones, xx.ravel(), yy.ravel()))
    learned_values = session.run(tf_learned_values, feed_dict={states: grid}).reshape(n_points)
    true_values = session.run(tf_true_values, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 2, projection='3d')
    ax.plot_surface(xx, yy, learned_values, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_values, color=colors[1], alpha=0.5)
    ax.set_title(r'Value function, $x = %.3g,\ \theta = %.3g$' % (x_fix, theta_fix), fontsize=16)
    ax.set_xlabel(r'$\dot{x}$', fontsize=14)
    ax.set_ylabel(r'$\dot{\theta}$', fontsize=14)
    ax.set_zlabel(r'$V(s)$', fontsize=14)

    if show:
        plt.show()

        
def plot_policy(tf_actions, tf_true_actions, n_points, fixed_state, colors=('r', 'b'), show=True):
    """Plot learned and true control actions on two 2-D slices of the normalized state space.

    Both tensors are evaluated on a grid over [-1, 1] x [-1, 1] by feeding the global
    ``states`` placeholder through the global ``session``. Results are reshaped to
    ``n_points``, so a single action dimension is assumed.

    Left panel: x and theta vary; x_dot and theta_dot are held at ``fixed_state``.
    Right panel: x_dot and theta_dot vary; x and theta are held at ``fixed_state``.

    Args:
        tf_actions: Tensor of learned actions for the fed ``states``.
        tf_true_actions: Tensor of reference actions (e.g. the LQR policy) for the fed ``states``.
        n_points: Pair ``(n0, n1)`` giving the grid resolution along the two plotted axes.
        fixed_state: Sequence ``(x, theta, x_dot, theta_dot)`` supplying the held-fixed coordinates.
        colors: Surface colors for the learned and true actions, in that order.
        show: If True, call ``plt.show()`` after drawing.
    """
    fig = plt.figure(figsize=(12, 5), dpi=200)
    fig.subplots_adjust(wspace=0.1, hspace=0.2)
    # An imaginary step in np.mgrid means "this many points, end point included".
    # The builtin `complex` is used because `np.complex` was removed in NumPy 1.24.
    xx, yy = np.mgrid[-1:1:complex(0, n_points[0]), -1:1:complex(0, n_points[1])]
    ones = np.ones_like(xx.ravel())

    x_fix, theta_fix, x_dot_fix, theta_dot_fix = fixed_state

    # Fix x_dot and theta_dot, plot the control over x and theta
    grid = np.column_stack((xx.ravel(), yy.ravel(), x_dot_fix*ones, theta_dot_fix*ones))
    learned_control = session.run(tf_actions, feed_dict={states: grid}).reshape(n_points)
    true_control = session.run(tf_true_actions, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 1, projection='3d')
    ax.plot_surface(xx, yy, learned_control, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_control, color=colors[1], alpha=0.5)
    ax.set_title(r'Control, $\dot{x} = %.3g,\ \dot{\theta} = %.3g$' % (x_dot_fix, theta_dot_fix), fontsize=16)
    ax.set_xlabel(r'$x$', fontsize=14)
    ax.set_ylabel(r'$\theta$', fontsize=14)
    ax.set_zlabel(r'$u$', fontsize=14)
    ax.view_init(elev=20., azim=15.)

    # Fix x and theta, plot the control over x_dot and theta_dot
    grid = np.column_stack((x_fix*ones, theta_fix*ones, xx.ravel(), yy.ravel()))
    learned_control = session.run(tf_actions, feed_dict={states: grid}).reshape(n_points)
    true_control = session.run(tf_true_actions, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 2, projection='3d')
    ax.plot_surface(xx, yy, learned_control, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_control, color=colors[1], alpha=0.5)
    ax.set_title(r'Control, $x = %.3g,\ \theta = %.3g$' % (x_fix, theta_fix), fontsize=16)
    ax.set_xlabel(r'$\dot{x}$', fontsize=14)
    ax.set_ylabel(r'$\dot{\theta}$', fontsize=14)
    ax.set_zlabel(r'$u$', fontsize=14)
    ax.view_init(elev=20., azim=100.)

    if show:
        plt.show()

## True Dynamics

In [None]:
# System parameters
m = 0.1     # pendulum mass
M = 1.5     # cart mass
L = 0.2     # pole length
b = 0.0     # rotational friction

# Constants
dt = 0.01   # sampling time
g = 9.81    # gravity

# State and action normalizers (passed to CartPole as the normalization constants)
x_max = 0.25
theta_max = np.deg2rad(15)
x_dot_max = 0.5
theta_dot_max = np.deg2rad(15)
# Force that changes the velocity of the total mass (m + M) by x_dot_max within 5 sampling periods
u_max = (m + M)*x_dot_max / (5*dt)

state_norm = (x_max, theta_max, x_dot_max, theta_dot_max)
action_norm = (u_max,)

# Define system and dynamics
cart_pole = CartPole(m, M, L, b, dt, [state_norm, action_norm])
# Take the dimensions from the system itself so they cannot drift out of sync with the
# cart_pole.state_dim / cart_pole.action_dim used elsewhere in this notebook.
state_dim = cart_pole.state_dim
action_dim = cart_pole.action_dim

# Normalized state and action boxes [-1, 1] per dimension
state_limits = np.array([[-1., 1.]]*state_dim)
action_limits = np.array([[-1., 1.]]*action_dim)

# Linearize the cart-pole model and wrap (A, B) as a linear dynamics function
A, B = cart_pole.linearize()
dynamics = safe_learning.functions.LinearSystem((A, B), name='dynamics')


def plot_closedloop_response(dynamics, policy1, policy2, steps, dt, reference='zero', const=1.0, ic=None,
                             labels=('Policy 1', 'Policy 2'), colors=('r', 'b'), denormalize=False, show=True):
    """Simulate two policies in closed loop with the same dynamics and plot their trajectories.

    States are drawn in a 2 x 3 grid of subplots, the action in the remaining cell.

    Args:
        dynamics: System model passed to ``compute_closedloop_response``.
        policy1, policy2: Policies to compare; drawn with ``colors[0]`` and ``colors[1]``.
        steps: Number of simulation steps.
        dt: Sampling time passed to ``compute_closedloop_response``.
        reference: One of ``'zero'``, ``'impulse'`` or ``'step'``.
        const: Reference amplitude passed to ``compute_closedloop_response``; the step
            title reports it as ``const * u_max`` newtons.
        ic: Initial state, or None for the zero state. Any array holding four values
            (shape ``(4,)`` or ``(1, 4)``) is accepted for the title.
        labels: Legend labels for the two policies.
        colors: Line colors for the two policies.
        denormalize: If True, convert trajectories with ``cart_pole.denormalize`` and
            show angle and angular rate in degrees.
        show: If True, call ``plt.show()`` after drawing.

    Raises:
        ValueError: If ``reference`` is not one of the supported values.
    """
    # Validate `reference` before simulating so a typo fails fast with a clear message
    # (otherwise `title_string` would be unbound after both simulations had run).
    if reference == 'zero':
        title_string = r'Zero-Input Response'
    elif reference == 'impulse':
        title_string = r'Impulse Response'
    elif reference == 'step':
        title_string = r'Step Response, $r = %.1gu_{max} = %.1g$ N' % (const, const*u_max)
    else:
        raise ValueError("reference must be 'zero', 'impulse' or 'step', got {!r}".format(reference))

    # Use the system's own state dimension, consistent with the plotting loop below
    state_dim = cart_pole.state_dim
    state_traj1, action_traj1, t, _ = compute_closedloop_response(dynamics, policy1, state_dim,
                                                                  steps, dt, reference, const, ic)
    state_traj2, action_traj2, _, _ = compute_closedloop_response(dynamics, policy2, state_dim,
                                                                  steps, dt, reference, const, ic)

    fig = plt.figure(figsize=(10, 5), dpi=200)
    fig.subplots_adjust(wspace=0.5, hspace=0.5)

    # np.ravel accepts an initial state of shape (4,) as well as (1, 4)
    ic_tuple = (0, 0, 0, 0) if ic is None else tuple(np.ravel(ic)[:4])
    title_string = title_string + r', $s_0 = (%.1g, %.1g, %.1g, %.1g)$' % ic_tuple
    fig.suptitle(title_string, fontsize=18)

    if denormalize:
        # Convert to physical units, then report angle and angular rate in degrees
        state_names = [r'$x$ [m]', r'$\theta$ [deg]', r'$\dot{x}$ [m/s]', r'$\dot{\theta}$ [deg/s]']
        state_traj1, action_traj1 = session.run(cart_pole.denormalize(state_traj1, action_traj1))
        state_traj2, action_traj2 = session.run(cart_pole.denormalize(state_traj2, action_traj2))
        for col in [1, 3]:
            state_traj1[:, col] = np.rad2deg(state_traj1[:, col])
            state_traj2[:, col] = np.rad2deg(state_traj2[:, col])
    else:
        state_names = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']

    # States occupy grid positions 1, 2, 4, 5; the action goes in position 3
    plot_idx = (1, 2, 4, 5)
    for i in range(cart_pole.state_dim):
        ax = fig.add_subplot(2, 3, plot_idx[i])
        ax.plot(t, state_traj1[:, i], colors[0])
        ax.plot(t, state_traj2[:, i], colors[1])
        ax.set_xlabel(r'$t$ [s]', fontsize=14)
        ax.set_ylabel(state_names[i], fontsize=14)
    ax = fig.add_subplot(2, 3, 3)
    plot1 = ax.plot(t, action_traj1, color=colors[0])
    plot2 = ax.plot(t, action_traj2, color=colors[1])
    ax.set_xlabel(r'$t$ [s]', fontsize=14)
    if denormalize:
        ax.set_ylabel(r'$u$ [N]', fontsize=14)
    else:
        ax.set_ylabel(r'$u$', fontsize=14)
    fig.legend((plot1[0], plot2[0]), (labels[0], labels[1]), loc=(0.75, 0.2), fontsize=14)

    if show:
        plt.show()

## Reward Function

In [None]:
# Quadratic cost weights: Q penalizes the (normalized) state, R the (normalized) action
Q = np.diag([0.1] * 4)
R = np.identity(action_dim) * 0.1

# Reward = negative quadratic cost. The block-diagonal matrix diag(-Q, -R) is applied to the
# stacked (state, action) pair -- presumably as [s, a]^T M [s, a]; confirm in safe_learning.
reward_function = safe_learning.QuadraticFunction(block_diag(-Q, -R), name='reward_function')

## Exact Optimal Policy and Value Function for the True Dynamics

In [None]:
# Exact LQR solution for the linear model (A, B) with weights (Q, R): feedback gain K, so the
# optimal action is u = -K s, and cost matrix P used below as the quadratic value function.
# NOTE(review): dlqr presumably solves the discrete algebraic Riccati equation -- confirm in
# safe_learning.utilities.
K, P = safe_learning.utilities.dlqr(A, B, Q, R)
print('LQR gain:\n{}\n'.format(-K))
print('Induced 2-norm of LQR gain:\n{}\n'.format(np.linalg.norm(K, 2)))
print('LQR cost matrix:\n{}\n'.format(P))

# LQR policy u = -K s, saturated to [-1, 1] when `saturate` is set
lqr_policy = safe_learning.functions.LinearSystem((-K,), name='LQR_policy')
if saturate:
    lqr_policy = safe_learning.Saturation(lqr_policy, -1, 1)

# Optimal value function: quadratic form with matrix -P (negative cost, matching the reward sign)
lqr_value_function = safe_learning.functions.QuadraticFunction(-P, name='LQR_value_function')

# Approximate the largest ellipsoidal level set contained in [-1, 1]**d: the smallest |V| found
# on the boundary of the box bounds that level set.
samples = sample_box_boundary(state_limits, 1e6)
test_values = lqr_value_function(tf.constant(samples, tf_dtype)).eval()
c_min, c_max = np.abs(test_values).min(), np.abs(test_values).max()
print('Minimum LQR cost (approx.):\n{}\n'.format(c_min))
print('Maximum LQR cost (approx.):\n{}\n'.format(c_max))

# Zero-input closed-loop response of the LQR policy from s_0 = (1, 1, 0, 0)
T = 1000
ic = np.array([1., 1., 0., 0.]).reshape(1, -1)
x, u, t, r = compute_closedloop_response(dynamics, lqr_policy, state_dim, T, dt, 'zero', ic=ic)

names = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$', r'$u$']
plt.plot(t, np.hstack((x, u)))
plt.legend(names)
plt.show()

## Function Approximators

In [None]:
# Scaling: r_max is the quadratic cost at the point where every normalized state and action
# component equals 1. (1 - gamma) / r_max roughly maps discounted returns (bounded by
# r_max / (1 - gamma) in magnitude while |reward| <= r_max) to unit scale.
max_state = np.ones(state_dim).reshape((1, -1))
max_action = np.ones(action_dim).reshape((1, -1))
# .item() collapses the (1, 1) matrix product to a Python float, so the objectives built
# from `scaling` are scalar tensors instead of (1, 1) tensors.
r_max = (np.linalg.multi_dot((max_state, Q, max_state.T))
         + np.linalg.multi_dot((max_action, R, max_action.T))).item()
gamma = tf.placeholder(tf_dtype, shape=[], name='discount_factor')
scaling = (1 - gamma) / r_max

# Value function: two ReLU hidden layers and a linear scalar output, or the exact LQR value
if train_value_function:
    layer_dims = [64, 64, 1]
    activations = [tf.nn.relu, tf.nn.relu, None]
    value_function = safe_learning.functions.NeuralNetwork(layer_dims, activations, name='value_function')
else:
    value_function = lqr_value_function

# Policy: two ReLU hidden layers; a tanh output keeps actions in [-1, 1] when saturating.
# Otherwise the (possibly saturated) LQR policy is used.
if train_policy:
    layer_dims = [64, 64, action_dim]
    activations = [tf.nn.relu, tf.nn.relu, None]
    if saturate:
        activations[-1] = tf.nn.tanh
    policy = safe_learning.functions.NeuralNetwork(layer_dims, activations, scaling=1., name='policy')
else:
    policy = lqr_policy

## TensorFlow Graph

In [None]:
# Learned closed loop: actions of the current policy, their rewards and the successor states
states = tf.placeholder(tf_dtype, shape=[None, state_dim], name='states')
actions = policy(states)
rewards = reward_function(states, actions)
future_states = dynamics(states, actions)

# Value estimates at the current and successor states; successor states are first clipped
# to [-1, 1] when `clip_states` is set
values = value_function(states)
future_values = value_function(tf.clip_by_value(future_states, -1, 1) if clip_states else future_states)

# Reference quantities from the exact LQR solution, evaluated on the same states
true_values = lqr_value_function(states)
true_actions = lqr_policy(states)
true_rewards = reward_function(states, true_actions)
true_future_states = dynamics(states, true_actions)

## Optimization Objectives

In [None]:
# Bellman error objective for value update
# Value step: fit V(s) to the one-step target r(s, pi(s)) + gamma * V(s') using the mean
# absolute Bellman error, multiplied by `scaling`. tf.stop_gradient freezes the target, so
# gradients flow only through V(s).
with tf.name_scope('value_optimization'):
    value_learning_rate = tf.placeholder(tf_dtype, shape=[], name='learning_rate')
    target = tf.stop_gradient(rewards + gamma*future_values, name='target')
    value_obj = scaling*tf.reduce_mean(tf.abs(values - target), name='objective')
    optimizer = tf.train.GradientDescentOptimizer(value_learning_rate)
    # The update op only exists when the value function is trained; var_list restricts the
    # step to the value-function parameters.
    if train_value_function:
        value_update = optimizer.minimize(value_obj, var_list=value_function.parameters)

# Pseudo-integration objective for policy update
# Policy step: maximize the mean one-step lookahead r(s, pi(s)) + gamma * V(s') minus
# lagrange_multiplier * policy.lipschitz() (presumably a Lipschitz-constant estimate of the
# policy -- confirm in safe_learning). The whole expression is negated because the optimizer
# minimizes. Gradients reach the policy through `actions` in both the reward and the
# dynamics; var_list keeps the value-function parameters fixed.
with tf.name_scope('policy_optimization'):
    policy_learning_rate = tf.placeholder(tf_dtype, shape=[], name='learning_rate')
    lagrange_multiplier = tf.placeholder(tf_dtype, shape=[], name='lagrange_multiplier')
    regularizer = -policy.lipschitz()
    policy_obj = -scaling*(tf.reduce_mean(rewards + gamma*future_values, name='objective') 
                           + lagrange_multiplier*regularizer)    
    # Rebinding `optimizer` is harmless: value_update already holds its own training op.
    optimizer = tf.train.GradientDescentOptimizer(policy_learning_rate)
    if train_policy:
        policy_update = optimizer.minimize(policy_obj, var_list=policy.parameters)

## Training: Policy Iteration

In [None]:
# Allow this cell to be run repeatedly to continue training if desired
if not initialized:
    session.run(tf.global_variables_initializer())
    initialized = True
if 'value_obj_history' not in locals():
    value_obj_history = np.zeros(0)
    value_param_history = np.zeros(0)
    policy_obj_history = np.zeros(0)
    policy_param_history = np.zeros(0)

# Uniformly distributed test set of roughly test_size points. The per-dimension count is
# forced to be odd so that an evenly spaced grid on [-1, 1] contains the origin.
test_size = 1e3
grid_length = np.power(test_size, 1 / state_dim)
grid_length = int(2*np.floor(grid_length / 2) + 1)
test_set = safe_learning.GridWorld(state_limits, [grid_length,]*state_dim).all_points

# Training hyperparameters
max_iters = 100
min_iters = 100
batch_size = 1e3
batch = constrained_batch_sampler(dynamics, policy, state_dim, batch_size, action_limit=None, zero_pad=0)

value_iters = 50    # value-function gradient steps per policy-iteration step
policy_iters = 10   # policy gradient steps per policy-iteration step

value_tol = 1e-1
policy_tol = 1e-3

feed_dict = {
    states:               test_set,
    gamma:                0.99,
    value_learning_rate:  0.2,
    policy_learning_rate: 0.7,
    lagrange_multiplier:  0.0,
}

# Record objective values over time; entry 0 holds the values before this run's first update
value_obj_eval = np.zeros(max_iters + 1)
policy_obj_eval = np.zeros(max_iters + 1)
value_obj_eval[0] = value_obj.eval(feed_dict)
policy_obj_eval[0] = policy_obj.eval(feed_dict)

# For convergence, check the parameter values
converged = False
iter_memory = 5
value_param_changes = np.zeros(max_iters)
policy_param_changes = np.zeros(max_iters)
old_value_params = session.run(value_function.parameters)
old_policy_params = session.run(policy.parameters)

# Number of completed iterations; stays 0 (instead of raising NameError) if max_iters == 0
final_iter = 0
for i in tqdm(range(max_iters)):

    # Policy evaluation (value update) on freshly sampled batches
    if train_value_function:
        for _ in range(value_iters):
            feed_dict[states] = session.run(batch)
            session.run(value_update, feed_dict)
        new_value_params = session.run(value_function.parameters)
        value_param_changes[i] = get_max_parameter_change(old_value_params, new_value_params)
        old_value_params = new_value_params

    # Policy improvement (policy update) on freshly sampled batches
    if train_policy:
        for _ in range(policy_iters):
            feed_dict[states] = session.run(batch)
            session.run(policy_update, feed_dict)
        new_policy_params = session.run(policy.parameters)
        policy_param_changes[i] = get_max_parameter_change(old_policy_params, new_policy_params)
        old_policy_params = new_policy_params

    # Evaluate both objectives on the fixed test set
    feed_dict[states] = test_set
    value_obj_eval[i+1] = value_obj.eval(feed_dict)
    policy_obj_eval[i+1] = policy_obj.eval(feed_dict)
    final_iter = i + 1

    # TODO debugging
    if np.isnan(value_obj_eval[i+1]) or np.isnan(policy_obj_eval[i+1]):
        raise ValueError('Encountered NAN value after {} iterations!'.format(i+1))

    # TODO Break if converged
    # NOTE(review): with min_iters == max_iters this check could never fire, even if re-enabled.
    # if i >= iter_memory and i >= min_iters:
    #     value_params_converged = np.all(value_param_changes[i-iter_memory+1:i+1] <= value_tol)
    #     policy_params_converged = np.all(policy_param_changes[i-iter_memory+1:i+1] <= policy_tol)
    #     if value_params_converged and policy_params_converged:
    #         converged = True
    #         break

if converged:
    print('Converged after {} iterations.'.format(final_iter))
else:
    print('Did not converge!')

# Objective evaluations have final_iter + 1 valid entries (initial value plus one per
# iteration); parameter changes have exactly one valid entry per completed iteration.
value_obj_history = np.concatenate((value_obj_history, value_obj_eval[:final_iter+1]))
value_param_history = np.concatenate((value_param_history, value_param_changes[:final_iter]))
policy_obj_history = np.concatenate((policy_obj_history, policy_obj_eval[:final_iter+1]))
policy_param_history = np.concatenate((policy_param_history, policy_param_changes[:final_iter]))

## Training: Results

In [None]:
# Training curves accumulated over all runs of the training cell
fig = plt.figure(figsize=(10, 5), dpi=200)
fig.subplots_adjust(wspace=0.5, hspace=0.5)

cap = 50      # values above this are clipped so early spikes do not dominate the axes
start = 0
end = None    # None plots through the last entry (end = -1 would silently drop it)

# (subplot position, history, y-axis label): value function on the left, policy on the right
panels = [
    (221, value_obj_history, 'Value Function Objective'),
    (223, value_param_history, 'Max. Value Param. Change'),
    (222, policy_obj_history, 'Policy Objective'),
    (224, policy_param_history, 'Max. Policy Param. Change'),
]
for position, history, ylabel in panels:
    ax = fig.add_subplot(position)
    ax.plot(np.clip(history[start:end], None, cap), '.-r')
    ax.set_xlabel('Iteration')
    ax.set_ylabel(ylabel)

plt.show()

## Zero-Input and Step Responses

In [None]:
# Non-zero initial state and a step amplitude of 1 / u_max (reported as 1 N in the step title)
ic = np.array([0.1, 0.2, 0.1, 0.1])
const = 1 / u_max
steps = 1000

# Learned action, then learned value, at the origin
for tensor in (actions, values):
    print(session.run(tensor, {states: np.zeros((1, state_dim))}))

# Zero-input response from the zero initial state, in physical units
plot_closedloop_response(dynamics, policy, lqr_policy, steps, dt, 'zero', ic=None,
                         labels=['Learned', 'True'], denormalize=True)

# Zero-input response from the non-zero initial state, in normalized units
plot_closedloop_response(dynamics, policy, lqr_policy, steps, dt, 'zero', ic=ic,
                         labels=['Learned', 'True'], denormalize=False)

# Step response from the default initial state, in normalized units
plot_closedloop_response(dynamics, policy, lqr_policy, steps, dt, 'step', const=const,
                         labels=['Learned', 'True'], denormalize=False)

# Impulse response (disabled)
# plot_closedloop_response(dynamics, policy, lqr_policy, steps, dt, 'impulse', labels=['Learned','True'])

## Value Function and Policy Comparison

In [None]:
# Initialize variables here if the training cell has not already done so
if not initialized:
    session.run(tf.global_variables_initializer())
    initialized = True
    time.sleep(1.5)
    print('Initialized!')

# Compare learned and LQR quantities on 25 x 25 slices through the origin
n_points = [25] * 2
fixed_state = [0.0] * 4

plot_value_function(values, true_values, n_points, fixed_state)
plot_policy(actions, true_actions, n_points, fixed_state)

## TensorBoard Graph

In [None]:
# plotting.show_graph(tf.get_default_graph())