# Reinforcement Learning for a Cart-Pole

In [None]:
from __future__ import division, print_function

import copy
import os
import time

import numpy as np
import tensorflow as tf
from scipy.linalg import block_diag
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
from matplotlib.font_manager import FontProperties
%matplotlib inline

import safe_learning
import plotting
from utilities import constrained_batch_sampler, CartPole, compute_closedloop_response, get_parameter_change

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable):
        """Stand-in used when tqdm is not installed: iterate without a progress bar."""
        return iterable

# Floating-point types configured by safe_learning (NumPy side and TensorFlow side)
np_dtype = safe_learning.config.np_dtype
tf_dtype = safe_learning.config.dtype

# TODO testing ****************************************#

class Options(object):
    """Lightweight namespace: every keyword argument becomes an attribute."""

    def __init__(self, **kwargs):
        super(Options, self).__init__()
        for key, value in kwargs.items():
            setattr(self, key, value)

OPTIONS = Options(np_dtype              = safe_learning.config.np_dtype,
                  tf_dtype              = safe_learning.config.dtype,
                  train_policy          = True,
                  train_value_function  = True,
                  use_linear_dynamics   = False,
                  saturate              = True,
                  eps                   = 1e-8,
                  dpi                   = 100,
                  fontproperties        = FontProperties(size=10),
                  save_figs             = True,
                  fig_path              = 'figures/cartpole_rl/')

# set_over/set_under mutate a colormap in place. Depending on the matplotlib
# version, plt.get_cmap can return the globally registered instance, so work on
# private copies to avoid changing 'inferno'/'viridis' for every other plot.
HEAT_MAP = copy.copy(plt.get_cmap('inferno', lut=None))
HEAT_MAP.set_over('white')
HEAT_MAP.set_under('black')

# 21 discrete levels
LEVEL_MAP = copy.copy(plt.get_cmap('viridis', lut=21))
LEVEL_MAP.set_over('gold')
LEVEL_MAP.set_under('white')

# 0 -> fully transparent, 1 -> semi-transparent green
BINARY_MAP = ListedColormap([(1., 1., 1., 0.), (0., 1., 0., 0.65)])

def binary_cmap(color='red', alpha=1.):
    """Return a two-entry colormap: transparent white for 0, a solid colour for 1.

    ``color`` may be 'red', 'green' or 'blue'; any other value yields black.
    ``alpha`` is the opacity of the non-transparent entry.
    """
    palette = (('red', (1., 0., 0.)),
               ('green', (0., 1., 0.)),
               ('blue', (0., 0., 1.)))
    rgb = next((code for name, code in palette if color == name), (0., 0., 0.))
    return ListedColormap([(1., 1., 1., 0.), rgb + (alpha,)])

#******************************************************#


## TensorFlow Session

In [None]:
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# ConfigProto's device_count needs an integer, so fall back to 1.
MAX_CPU_COUNT = os.cpu_count() or 1
NUM_CORES = 8     # intra-op thread count, also exported as OMP_NUM_THREADS
NUM_SOCKETS = 2   # inter-op thread count

# OpenMP threading settings (KMP_* are read by the Intel OpenMP runtime);
# set before the session below is created.
os.environ["KMP_BLOCKTIME"]    = str(0)
os.environ["KMP_SETTINGS"]     = str(1)
os.environ["KMP_AFFINITY"]     = 'granularity=fine,noverbose,compact,1,0'
os.environ["OMP_NUM_THREADS"]  = str(NUM_CORES)

config = tf.ConfigProto(intra_op_parallelism_threads  = NUM_CORES,
                        inter_op_parallelism_threads  = NUM_SOCKETS,
                        allow_soft_placement          = False,
#                         log_device_placement          = True,
                        device_count                  = {'CPU': MAX_CPU_COUNT})

# Graph-wide JIT compilation level ON_1
# TODO manually for CPU-only?
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

# Close a session left over from a previous run of this cell (NameError on the first run)
try:
    session.close()
except NameError:
    pass
session = tf.InteractiveSession(config=config)

# print('Found MAX_CPU_COUNT =', MAX_CPU_COUNT)
# for dev in session.list_devices():
#     print(dev)

initialized = False

In [None]:
def plot_policy(tf_actions, tf_true_actions, n_points, fixed_state, colors=('r', 'b'), show=True):
    """Compare a learned and a reference control law on two 2-D state slices.

    Two 3-D surface plots are drawn side by side:
      * left:  control over (x, theta) with (x_dot, theta_dot) held fixed;
      * right: control over (x_dot, theta_dot) with (x, theta) held fixed.
    Both axes of each slice span [-1, 1].

    Args:
        tf_actions: tensor giving the learned controls. It is evaluated by
            feeding the grid into the module-level ``states`` placeholder.
        tf_true_actions: tensor giving the reference controls, evaluated the same way.
        n_points: pair ``(n0, n1)`` with the number of grid points per axis.
        fixed_state: ``(x, theta, x_dot, theta_dot)``; supplies the values held fixed.
        colors: surface colors for (learned, reference).
        show: call ``plt.show()`` at the end when True.

    Note: relies on the module-level ``session`` and ``states`` globals.
    """
    fig = plt.figure(figsize=(12, 5), dpi=OPTIONS.dpi)
    fig.subplots_adjust(wspace=0.1, hspace=0.2)
    # A complex step in mgrid means "this many points, endpoint included".
    # Use the builtin ``complex``; the ``np.complex`` alias was removed in NumPy 1.24.
    xx, yy = np.mgrid[-1:1:complex(0, n_points[0]), -1:1:complex(0, n_points[1])]

    x_fix, theta_fix, x_dot_fix, theta_dot_fix = fixed_state

    # Fix x_dot and theta_dot, plot the control over x and theta
    grid = np.column_stack((xx.ravel(), yy.ravel(),
                            x_dot_fix*np.ones_like(xx.ravel()), theta_dot_fix*np.ones_like(yy.ravel())))
    learned_control = session.run(tf_actions, feed_dict={states: grid}).reshape(n_points)
    true_control = session.run(tf_true_actions, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 1, projection='3d')
    ax.plot_surface(xx, yy, learned_control, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_control, color=colors[1], alpha=0.5)
    ax.set_title(r'Control, $\dot{x} = %.3g,\ \dot{\theta} = %.3g$' % (x_dot_fix, theta_dot_fix), fontsize=16)
    ax.set_xlabel(r'$x$', fontsize=14)
    ax.set_ylabel(r'$\theta$', fontsize=14)
    ax.set_zlabel(r'$u$', fontsize=14)
    ax.view_init(elev=20., azim=15.)

    # Fix x and theta, plot the control over x_dot and theta_dot
    grid = np.column_stack((x_fix*np.ones_like(xx.ravel()), theta_fix*np.ones_like(yy.ravel()),
                            xx.ravel(), yy.ravel()))
    learned_control = session.run(tf_actions, feed_dict={states: grid}).reshape(n_points)
    true_control = session.run(tf_true_actions, feed_dict={states: grid}).reshape(n_points)
    ax = fig.add_subplot(1, 2, 2, projection='3d')
    ax.plot_surface(xx, yy, learned_control, color=colors[0], alpha=0.75)
    ax.plot_surface(xx, yy, true_control, color=colors[1], alpha=0.5)
    ax.set_title(r'Control, $x = %.3g,\ \theta = %.3g$' % (x_fix, theta_fix), fontsize=16)
    ax.set_xlabel(r'$\dot{x}$', fontsize=14)
    ax.set_ylabel(r'$\dot{\theta}$', fontsize=14)
    ax.set_zlabel(r'$u$', fontsize=14)
    ax.view_init(elev=20., azim=100.)

    if show:
        plt.show()


def plot_closedloop_response(dynamics, policy1, policy2, steps, dt, reference='zero', const=1.0, ic=None,
                             labels=('Policy 1', 'Policy 2'), colors=('r', 'b'), denormalize=False, show=True):
    """Simulate two policies in closed loop and plot their state and action trajectories.

    Both policies are rolled out with ``compute_closedloop_response`` using the
    same dynamics, reference signal and initial condition. The four state
    components get one subplot each, a fifth subplot shows the actions, and a
    shared legend identifies the two policies.

    Args:
        dynamics: system dynamics, passed through to ``compute_closedloop_response``.
        policy1, policy2: the two policies to compare.
        steps: number of simulation steps.
        dt: sampling time, passed through to the simulator.
        reference: one of ``'zero'``, ``'impulse'`` or ``'step'``.
        const: passed through to the simulator. For ``'step'`` the title
            shows it as a multiple of ``u_max``.
        ic: optional initial state with 4 entries. When ``None``, the title
            shows all zeros.
        labels, colors: legend labels and line colors for (policy1, policy2).
        denormalize: convert trajectories with ``cart_pole.denormalize`` and
            show angle components (columns 1 and 3) in degrees.
        show: call ``plt.show()`` at the end when True.

    Raises:
        ValueError: if ``reference`` is not one of the supported values.

    Note: relies on the module-level globals ``u_max``, ``cart_pole`` and ``session``.
    """
    # Validate before simulating. Otherwise an unknown reference only surfaces
    # later as an unbound ``title_string``, after both rollouts have been computed.
    if reference == 'zero':
        title_string = r'Zero-Input Response'
    elif reference == 'impulse':
        title_string = r'Impulse Response'
    elif reference == 'step':
        title_string = r'Step Response, $r = %.1gu_{max} = %.1g$ N' % (const, const*u_max)
    else:
        raise ValueError("reference must be 'zero', 'impulse' or 'step', got {!r}".format(reference))

    state_dim = 4
    state_traj1, action_traj1, t, _ = compute_closedloop_response(dynamics, policy1, state_dim,
                                                                  steps, dt, reference, const, ic)
    state_traj2, action_traj2, _, _ = compute_closedloop_response(dynamics, policy2, state_dim,
                                                                  steps, dt, reference, const, ic)

    fig = plt.figure(figsize=(10, 5), dpi=OPTIONS.dpi)
    fig.subplots_adjust(wspace=0.5, hspace=0.5)

    if ic is not None:
        ic_tuple = (ic[0], ic[1], ic[2], ic[3])
    else:
        ic_tuple = (0, 0, 0, 0)
    title_string = title_string + r', $s_0 = (%.1g, %.1g, %.1g, %.1g)$' % ic_tuple
    fig.suptitle(title_string, fontsize=18)

    if denormalize:
        state_names = [r'$x$ [m]', r'$\theta$ [deg]', r'$\dot{x}$ [m/s]', r'$\dot{\theta}$ [deg/s]']
        state_traj1, action_traj1 = session.run(cart_pole.denormalize(state_traj1, action_traj1))
        state_traj2, action_traj2 = session.run(cart_pole.denormalize(state_traj2, action_traj2))
        # Angle and angular-rate columns: convert to degrees for display
        for col in [1, 3]:
            state_traj1[:, col] = np.rad2deg(state_traj1[:, col])
            state_traj2[:, col] = np.rad2deg(state_traj2[:, col])
    else:
        state_names = [r'$x$', r'$\theta$', r'$\dot{x}$', r'$\dot{\theta}$']

    # States go in grid cells 1, 2, 4, 5 of a 2x3 layout; cell 3 holds the actions
    plot_idx = (1, 2, 4, 5)
    for i in range(state_dim):
        ax = fig.add_subplot(2, 3, plot_idx[i])
        ax.plot(t, state_traj1[:, i], colors[0])
        ax.plot(t, state_traj2[:, i], colors[1])
        ax.set_xlabel(r'$t$ [s]', fontsize=14)
        ax.set_ylabel(state_names[i], fontsize=14)
    ax = fig.add_subplot(2, 3, 3)
    plot1 = ax.plot(t, action_traj1, color=colors[0])
    plot2 = ax.plot(t, action_traj2, color=colors[1])
    ax.set_xlabel(r'$t$ [s]', fontsize=14)
    if denormalize:
        ax.set_ylabel(r'$u$ [N]', fontsize=14)
    else:
        ax.set_ylabel(r'$u$', fontsize=14)
    fig.legend((plot1[0], plot2[0]), (labels[0], labels[1]), loc=(0.75, 0.2), fontsize=14)

    if show:
        plt.show()
        

def gridify(norms, maxes=None, num_points=25):
    """Build a ``safe_learning.GridWorld`` over normalized coordinates.

    Dimension ``i`` spans ``[-maxes[i] / norms[i], maxes[i] / norms[i]]``,
    i.e. the range ``[-maxes[i], maxes[i]]`` expressed in units of
    ``norms[i]``. With ``maxes=None`` every dimension spans ``[-1, 1]``.

    Args:
        norms: per-dimension normalization constants (flattened to 1-D).
        maxes: per-dimension half-widths (flattened to 1-D); defaults to ``norms``.
        num_points: points per dimension. A single integer (Python or NumPy)
            is repeated for every dimension; anything else is passed to
            ``GridWorld`` unchanged.

    Returns:
        The constructed ``safe_learning.GridWorld``.
    """
    norms = np.asarray(norms).ravel()
    maxes = norms if maxes is None else np.asarray(maxes).ravel()
    limits = np.column_stack((-maxes / norms, maxes / norms))

    # Also accept NumPy integer scalars (e.g. a value taken out of an array),
    # which are not instances of the builtin int.
    if isinstance(num_points, (int, np.integer)):
        num_points = [int(num_points)] * len(norms)
    return safe_learning.GridWorld(limits, num_points)


def compute_roa(grid, closed_loop_dynamics, horizon=250, tol=1e-3, equilibrium=None, no_traj=True):
    """Approximate a region of attraction (ROA) by forward simulation.

    Every initial state is pushed ``horizon - 1`` times through
    ``closed_loop_dynamics``. A state counts as inside the ROA when its final
    Euclidean distance to ``equilibrium`` is at most ``tol``.

    Args:
        grid: either an ``(n, ndim)`` array of initial states, or an object
            exposing ``all_points``, ``nindex`` and ``ndim`` (e.g. a GridWorld).
        closed_loop_dynamics: maps an ``(n, ndim)`` array of states to next states.
        horizon: trajectory length, counting the initial state.
        tol: distance threshold for membership.
        equilibrium: target state, broadcast against the final states;
            defaults to the origin.
        no_traj: if False, also store and return every intermediate state.

    Returns:
        ``(roa, dists)``: boolean mask and final distances, each of length n.
        If ``no_traj`` is False, the result is ``(roa, dists, trajectories)``,
        where ``trajectories`` has shape ``(n, ndim, horizon)``.
    """
    if isinstance(grid, np.ndarray):
        initial_states = grid
        num_states, state_dims = grid.shape[0], grid.shape[1]
    else:
        initial_states = grid.all_points
        num_states, state_dims = grid.nindex, grid.ndim

    if no_traj:
        # Only the end point is needed, so avoid storing the whole rollout
        final_states = initial_states
        for _ in range(horizon - 1):
            final_states = closed_loop_dynamics(final_states)
    else:
        trajectories = np.empty((num_states, state_dims, horizon))
        trajectories[:, :, 0] = initial_states
        for step in range(1, horizon):
            trajectories[:, :, step] = closed_loop_dynamics(trajectories[:, :, step - 1])
        final_states = trajectories[:, :, -1]

    target = np.zeros((1, state_dims)) if equilibrium is None else equilibrium

    dists = np.linalg.norm(final_states - target, ord=2, axis=1, keepdims=True).ravel()
    roa = dists <= tol
    return (roa, dists) if no_traj else (roa, dists, trajectories)


def estimate_cost(grid, closed_loop_dynamics, cost_function, discount, horizon=250, tol=1e-3):
    """Estimate a discounted cost-to-go by a finite-horizon rollout.

    Starting from each state, the per-step cost ``discount**t * cost(x_t)`` is
    accumulated for at most ``horizon`` steps. The rollout stops early once the
    largest absolute per-step cost in the batch falls below ``tol``. A message
    reports whether that early stop happened.

    Args:
        grid: a ``safe_learning.GridWorld`` (its ``all_points`` are used) or an
            ``(n, ndim)`` array of start states.
        closed_loop_dynamics: maps a batch of states to next states.
        cost_function: maps a batch of states to per-state costs (flattened).
        discount: discount factor applied per step.
        horizon: maximum number of rollout steps.
        tol: early-stop threshold on the discounted per-step cost.

    Returns:
        1-D array with the accumulated cost per start state.
    """
    if isinstance(grid, safe_learning.GridWorld):
        total_cost = np.zeros(grid.nindex)
        states_now = grid.all_points
    else:
        total_cost = np.zeros(grid.shape[0])
        states_now = grid

    steps_taken = None  # set when the early-stop criterion is met
    for step in range(horizon):
        stage_cost = (discount ** step) * cost_function(states_now).ravel()
        total_cost += stage_cost
        if np.max(np.abs(stage_cost)) < tol:
            steps_taken = step + 1
            break
        states_now = closed_loop_dynamics(states_now)

    if steps_taken is None:
        print('Cost did not converge!')
    else:
        print('Cost converged after {} steps!'.format(steps_taken))

    return total_cost


def find_nearest(array, value, sorted_1d=True):
    """Return ``(index, element)`` for the entry of a 1-D array closest to ``value``.

    Ties are resolved toward the larger neighbour. Values outside the range
    snap to the first or last element.

    Args:
        array: 1-D array-like of numbers.
        value: scalar to look up.
        sorted_1d: set to True when ``array`` is already sorted ascending.
            Otherwise it is sorted internally. The returned index always
            refers to the caller's ``array``, not to the sorted copy.

    Returns:
        ``(idx, array[idx])``.

    Raises:
        ValueError: if ``array`` is empty.
    """
    array = np.asarray(array)
    if array.size == 0:
        raise ValueError('find_nearest requires a non-empty array')

    if sorted_1d:
        order = None
        candidates = array
    else:
        # Search a sorted view, then map the hit back to the caller's indexing
        order = np.argsort(array, kind='mergesort')
        candidates = array[order]

    idx = np.searchsorted(candidates, value, side='left')
    # searchsorted gives the right neighbour; step left if that one is strictly closer
    if idx > 0 and (idx == len(candidates) or np.abs(value - candidates[idx - 1]) < np.abs(value - candidates[idx])):
        idx -= 1

    if order is not None:
        idx = order[idx]
    return idx, array[idx]


def plot_cost(cost, discretization, norms, masks, roa=None):
    """Plot cost surfaces over two 2-D slices of a state discretization.

    The top panel varies coordinates (0, 2) and the bottom panel varies (1, 3).
    ``masks[i]`` selects the discretization points belonging to panel ``i``.

    Args:
        cost: either a list with one precomputed cost array per panel, or a
            callable mapping states to a TF tensor, which is evaluated with ``.eval()``.
        discretization: grid exposing ``all_points``, ``num_points`` and ``discrete_points``.
        norms: per-dimension scale factors used to convert grid coordinates for the axes.
        masks: one boolean mask per panel over ``discretization.all_points``.
        roa: optional list of per-panel boolean arrays, drawn as a filled
            contour at z = 0.
    """
    planes = [[0, 2], [1, 3]]

    plt.rc('font', size=10)
    fig = plt.figure(figsize=(6, 12), dpi=OPTIONS.dpi)
#     fig.subplots_adjust(wspace=0.4, hspace=0.2)

    # Axis coordinates come from the discretization passed in, not from a global grid
    scaled_discrete_points = [norm * points for norm, points in zip(norms, discretization.discrete_points)]

    for i, p in enumerate(planes):
        if isinstance(cost, list):
            z = cost[i].ravel()
        else:
            z = cost(discretization.all_points[masks[i]]).eval()
        z = z.reshape(discretization.num_points[p])

        xx, yy = np.meshgrid(*[scaled_discrete_points[p[0]], scaled_discrete_points[p[1]]])

        ax = fig.add_subplot(211 + i, projection='3d')
        surf = ax.plot_surface(xx, yy, z, color='blue', alpha=0.65) #, label=r'$J_{\bf \theta}({\bf x})$')
        # Workaround so a 3-D surface can be used as a legend handle
        surf._facecolors2d = surf._facecolors3d
        surf._edgecolors2d = surf._edgecolors3d

        if roa is not None:
            z = roa[i].reshape(discretization.num_points[p])
            ax.contourf(xx, yy, z, cmap=BINARY_MAP, zdir='z', offset=0)

        if i == 0:
            ax.set_xlabel(r'$x$ [m]')
            ax.set_ylabel(r'$\dot{x}$ [m/s]')
        else:
            ax.set_xlabel(r'$\phi$ [deg]')
            ax.set_ylabel(r'$\dot{\phi}$ [deg/s]')
        ax.view_init(None, -45)

    fig.tight_layout()
    plt.show()

## True Dynamics

In [None]:
# System parameters
# (Earlier parameter set, kept for reference)
# m = 0.1     # pendulum mass
# M = 1.5     # cart mass
# L = 0.2     # pole length
# b = 0.0     # rotational friction
m = 0.175    # pendulum mass
M = 1.732    # cart mass
L = 0.28     # pole length
b = 0.01      # rotational friction

# Constants
dt = 0.01   # sampling time (presumably seconds; plots label t in [s])
g = 9.81    # gravity
# NOTE(review): g is not passed to CartPole below; presumably CartPole uses its own constant — confirm

# State and action normalizers: the learning problem works in coordinates
# divided by these values, so the normalized state/action box is [-1, 1]
# (see state_limits/action_limits below). Angles are given in radians via deg2rad.
x_max         = 0.5
theta_max     = np.deg2rad(30)
x_dot_max     = 2
theta_dot_max = np.deg2rad(30)
# theta_dot_max = np.deg2rad(30)
# Alternative force scalings that were tried:
# u_max         = (m + M) * x_dot_max / (5 * dt)
# u_max         = (M + m) * g * np.tan(theta_max)
u_max         = (m + M) * (x_dot_max ** 2) / x_max

state_norm = (x_max, theta_max, x_dot_max, theta_dot_max)
action_norm = (u_max,)

# Define system and dynamics
cart_pole = CartPole(m, M, L, b, dt, [state_norm, action_norm])
state_dim = 4
action_dim = 1

# Box constraints in normalized coordinates
state_limits = np.array([[-1., 1.]]*cart_pole.state_dim)
action_limits = np.array([[-1., 1.]]*cart_pole.action_dim)

# Learning uses the linearized model (A, B) of the cart-pole.
# NOTE(review): OPTIONS.use_linear_dynamics is not consulted here; the linear
# model is used unconditionally.
A, B = cart_pole.linearize()   
dynamics = safe_learning.functions.LinearSystem((A, B), name='dynamics')

print(state_norm)
print(action_norm)

## Reward Function

In [None]:
# State cost matrix
Q = np.diag([0.1, 0.1, 0.1, 0.1])

# Action cost matrix
R = 0.1 * np.identity(action_dim)

# Quadratic reward (-cost) function, built from the negated block-diagonal
# matrix diag(-Q, -R). Presumably evaluated as [x; u]^T diag(-Q, -R) [x; u],
# based on the QuadraticFunction name — confirm in safe_learning.
reward_function = safe_learning.QuadraticFunction(block_diag(-Q, -R), name='reward_function')

## Parametric Policy and Value Function

In [None]:
# Policy: two hidden ReLU layers of width 64 and a scalar output. With
# OPTIONS.saturate the output goes through tanh, bounding it to [-1, 1]
# (the normalized action range).
layer_dims = [64, 64, action_dim]
activations = [tf.nn.relu, tf.nn.relu, None]
if OPTIONS.saturate:
    activations[-1] = tf.nn.tanh
policy = safe_learning.functions.NeuralNetwork(layer_dims, activations, name='policy', use_bias=False)

# Value function: same hidden architecture with an unbounded linear scalar output.
# Both networks are built without bias terms (use_bias=False).
layer_dims = [64, 64, 1]
activations = [tf.nn.relu, tf.nn.relu, None]
value_function = safe_learning.functions.NeuralNetwork(layer_dims, activations, name='value_function', use_bias=False)

## LQR Policy

In [None]:
# Discrete-time LQR on the linearized model; the reference policy is u = -K x.
# P is returned by dlqr as well (presumably the Riccati solution) but is not used below.
K, P = safe_learning.utilities.dlqr(A, B, Q, R)
policy_lqr = safe_learning.functions.LinearSystem((-K, ), name='policy_lqr')
if OPTIONS.saturate:
    # Clip to the normalized action range, matching the tanh-saturated learned policy
    policy_lqr = safe_learning.Saturation(policy_lqr, -1, 1)

## TensorFlow Graph and Initialization

In [None]:
# Use parametric policy and value function
states = tf.placeholder(OPTIONS.tf_dtype, shape=[None, state_dim], name='states')
actions = policy(states)
rewards = reward_function(states, actions)
values = value_function(states)
future_states = dynamics(states, actions)
future_values = value_function(future_states)

# Compare with LQR solution, possibly with saturation constraints
actions_lqr = policy_lqr(states)
rewards_lqr = reward_function(states, actions_lqr)
future_states_lqr = dynamics(states, actions_lqr)

# Discount factor and scaling.
# r_max is the quadratic cost at the corner of the normalized state/action box
# (all coordinates equal to 1): 1^T Q 1 + 1^T R 1.
max_state = np.ones((1, state_dim))
max_action = np.ones((1, action_dim))
r_max = np.linalg.multi_dot((max_state, Q, max_state.T)) + np.linalg.multi_dot((max_action, R, max_action.T))
gamma = tf.placeholder(OPTIONS.tf_dtype, shape=[], name='discount_factor')

val_scaling = 1 / r_max.ravel()
pol_scaling = (1 - gamma) / r_max.ravel()

# Policy evaluation
with tf.name_scope('value_optimization'):
    value_learning_rate = tf.placeholder(OPTIONS.tf_dtype, shape=[], name='learning_rate')
    # One-step Bellman target; stop_gradient keeps it fixed while the value function is fitted
    target = tf.stop_gradient(rewards + gamma * future_values, name='target')
    # NOTE(review): this objective uses pol_scaling; val_scaling above is otherwise unused — confirm intent
    value_objective = pol_scaling * tf.reduce_mean(tf.abs(values - target), name='objective')
    optimizer = tf.train.GradientDescentOptimizer(value_learning_rate)
    value_update = optimizer.minimize(value_objective, var_list=value_function.parameters)

# Policy improvement: maximize the one-step lookahead return; only policy parameters are updated
with tf.name_scope('policy_optimization'):
    # Same dtype source as every other placeholder in this graph
    lagrange_multiplier = tf.placeholder(OPTIONS.tf_dtype, shape=[], name='lagrange_multiplier')
    policy_learning_rate = tf.placeholder(OPTIONS.tf_dtype, shape=[], name='learning_rate')
    policy_objective = - pol_scaling * tf.reduce_mean(rewards + gamma * future_values, name='objective') # + lagrange_multiplier * policy.lipschitz()
    optimizer = tf.train.GradientDescentOptimizer(policy_learning_rate)
    policy_update = optimizer.minimize(policy_objective, var_list=policy.parameters)

# Sampling: uniform random batch of normalized states in [-1, 1]^state_dim
with tf.name_scope('state_sampler'):
    batch_size = tf.placeholder(tf.int32, shape=[], name='batch_size')
    batch = tf.random_uniform([batch_size, state_dim], -1, 1, dtype=OPTIONS.tf_dtype, name='batch')

## Approximate Policy Iteration

### Initialization

In [None]:
session.run(tf.global_variables_initializer())
# Histories accumulated across (possibly repeated) runs of the training cell
value_obj_history = np.zeros(0)
value_param_history = np.zeros(0)
policy_obj_history = np.zeros(0)
policy_param_history = np.zeros(0)

# Uniformly distributed test set
test_size = 1e3
# Points per dimension so that the grid has roughly test_size points, rounded
# down to an odd count (for test_size=1e3, state_dim=4: 1e3**0.25 ~= 5.6 -> 5,
# i.e. 625 points). An odd count on the symmetric range presumably puts a grid
# point at 0 — depends on GridWorld's spacing.
grid_length = np.power(test_size, 1 / state_dim)
grid_length = int(2*np.floor(grid_length / 2) + 1)
test_set = safe_learning.GridWorld(state_limits, [grid_length,]*state_dim).all_points

### Training

In [None]:
# Training hyperparameters
max_iters    = 200   # outer approximate-policy-iteration steps
value_iters  = 100   # value-function gradient steps per outer step
policy_iters = 10    # policy gradient steps per outer step

feed_dict = {
    states:                test_set,
    gamma:                 0.99,
    value_learning_rate:   0.2,
    policy_learning_rate:  0.5,
    batch_size:            100,   # batch_size is a tf.int32 placeholder, so feed an int
    lagrange_multiplier:   0.
}
# batch = constrained_batch_sampler(dynamics, policy, state_dim, feed_dict[batch_size], action_limit=0.999, zero_pad=0)

# Record objective values over time; index 0 holds the values before training
value_obj_eval = np.zeros(max_iters + 1)
policy_obj_eval = np.zeros(max_iters + 1)
value_obj_eval[0] = value_objective.eval(feed_dict)
policy_obj_eval[0] = policy_objective.eval(feed_dict)

# For convergence, check the parameter values
converged = False
value_param_changes = np.zeros(max_iters)
policy_param_changes = np.zeros(max_iters)
old_value_params = session.run(value_function.parameters)
old_policy_params = session.run(policy.parameters)

# Number of completed outer iterations (stays 0 when max_iters == 0)
final_iter = 0
for i in tqdm(range(max_iters)):
    # Policy evaluation (value update) on freshly sampled batches
    for _ in range(value_iters):
        feed_dict[states] = batch.eval(feed_dict)
        session.run(value_update, feed_dict)
    new_value_params = session.run(value_function.parameters)
    value_param_changes[i] = get_parameter_change(old_value_params, new_value_params)
    old_value_params = new_value_params

    # Policy improvement (policy update) on freshly sampled batches
    for _ in range(policy_iters):
        feed_dict[states] = batch.eval(feed_dict)
        session.run(policy_update, feed_dict)
    new_policy_params = session.run(policy.parameters)
    policy_param_changes[i] = get_parameter_change(old_policy_params, new_policy_params)
    old_policy_params = new_policy_params

    # Record objectives on the fixed test set
    feed_dict[states] = test_set
    value_obj_eval[i+1] = value_objective.eval(feed_dict)
    policy_obj_eval[i+1] = policy_objective.eval(feed_dict)

    # TODO debugging
    if np.isnan(value_obj_eval[i+1]) or np.isnan(policy_obj_eval[i+1]):
        raise ValueError('Encountered NAN value after {} iterations!'.format(i+1))

    final_iter = i + 1

# Append this run to the running histories. Objective arrays hold
# final_iter + 1 entries (the initial value plus one per iteration).
# Parameter-change arrays hold exactly one entry per completed iteration.
value_obj_history    = np.concatenate((value_obj_history, value_obj_eval[:final_iter+1]))
value_param_history  = np.concatenate((value_param_history, value_param_changes[:final_iter]))
policy_obj_history   = np.concatenate((policy_obj_history, policy_obj_eval[:final_iter+1]))
policy_param_history = np.concatenate((policy_param_history, policy_param_changes[:final_iter]))

### Results

In [None]:
plt.rc('font', size=12)

if OPTIONS.save_figs:
    # savefig fails if the output directory does not exist yet
    os.makedirs(OPTIONS.fig_path, exist_ok=True)
    # File-name suffix built from the discount factor, e.g. 0.99 -> '99'
    gamma_string = str(feed_dict[gamma])[2:]

# Value-function training: objective and parameter change per iteration
fig = plt.figure(figsize=(6, 5), dpi=200)
fig.subplots_adjust(hspace=0.4)

ax = fig.add_subplot(211)
ax.plot(value_obj_history, '.-r')
ax.set_xlabel(r'Policy iteration $k$')
ax.set_ylabel(r'$G(\mathcal{X}_v, {\bf \theta}_k)$')
ax.set_ylim([0, 0.06])

ax = fig.add_subplot(212)
ax.plot(value_param_history, '.-r')
ax.set_xlabel(r'Policy iteration $k$')
ax.set_ylabel(r'$||{\bf \theta}_k - {\bf \theta}_{k-1}||_\infty$')
ax.set_ylim([0, 0.8])

if OPTIONS.save_figs:
    fig.savefig(OPTIONS.fig_path + 'cartpole_policyiter_costfunc_training_gamma' + gamma_string + '.pdf', bbox_inches='tight')

# Policy training: objective and parameter change per iteration
fig = plt.figure(figsize=(6, 5), dpi=200)
fig.subplots_adjust(hspace=0.4)

ax = fig.add_subplot(211)
ax.plot(policy_obj_history, '.-r')
ax.set_xlabel(r'Policy iteration $k$')
ax.set_ylabel(r'$H(\mathcal{X}_v, {\bf \delta}_k)$')
ax.set_ylim([0, 3])

ax = fig.add_subplot(212)
ax.plot(policy_param_history, '.-r')
ax.set_xlabel(r'Policy iteration $k$')
ax.set_ylabel(r'$||{\bf \delta}_k - {\bf \delta}_{k-1}||_\infty$')
ax.set_ylim([0, 0.1])

# Save before plt.show(): depending on the backend, show() may close or clear the figure
if OPTIONS.save_figs:
    fig.savefig(OPTIONS.fig_path + 'cartpole_policyiter_policy_training_gamma' + gamma_string + '.pdf', bbox_inches='tight')

plt.show()

## Estimated Value Functions and ROAs

In [None]:
# Grid for plotting (normalized coordinates in [-1, 1]; norms rescale them for axes)
N = 51
norms = np.asarray([x_max, np.rad2deg(theta_max), x_dot_max, np.rad2deg(theta_dot_max)])
maxes = np.copy(norms)
grid = gridify(norms, maxes, N)

# Estimate value functions and ROAs with rollout
roa_horizon  = 5000
cost_horizon = 500
roa_tol      = 0.1
cost_tol     = 0.01
discount     = feed_dict[gamma]
fixed_state  = [0., 0., 0., 0.]

# Snap fixed_state to the closest grid point
fixed_state = np.asarray(fixed_state, dtype=OPTIONS.np_dtype)
fixed_index = np.zeros_like(fixed_state, dtype=int)
for d in range(grid.ndim):
    fixed_index[d], fixed_state[d] = find_nearest(grid.discrete_points[d], fixed_state[d])

# Get 2d-planes of the discretization (x vs. v, theta vs. omega) according to fixed_state.
# Each entry of `planes` lists the two coordinates held FIXED: [1, 3] gives the
# (x, x_dot) slice and [0, 2] the (theta, theta_dot) slice.
planes = [[1, 3], [0, 2]]
grid_slices = []
for p in planes:
    # fixed_state was cast to OPTIONS.np_dtype, so compare with a tolerance
    # rather than exact float equality against the grid's own dtype
    grid_slices.append(np.logical_and(np.isclose(grid.all_points[:, p[0]], fixed_state[p[0]]),
                                      np.isclose(grid.all_points[:, p[1]], fixed_state[p[1]])).ravel())

# LQR solution
closed_loop_dynamics = lambda x: future_states_lqr.eval({states: x})
cost_function        = lambda x: - rewards_lqr.eval({states: x})
true_costs           = [estimate_cost(grid.all_points[mask], closed_loop_dynamics, cost_function, discount, cost_horizon, cost_tol)
                        for mask in grid_slices]
# true_costs           = [c / c.max() for c in true_costs]
true_roas            = [compute_roa(grid.all_points[mask], closed_loop_dynamics, roa_horizon, roa_tol)[0]
                        for mask in grid_slices]

# Parametric policy's value function (rollout of the learned policy)
closed_loop_dynamics = lambda x: future_states.eval({states: x})
cost_function        = lambda x: - rewards.eval({states: x})
est_costs            = [estimate_cost(grid.all_points[mask], closed_loop_dynamics, cost_function, discount, cost_horizon, cost_tol)
                        for mask in grid_slices]
# est_costs            = [c / c.max() for c in est_costs]
est_roas             = [compute_roa(grid.all_points[mask], closed_loop_dynamics, roa_horizon, roa_tol)[0]
                        for mask in grid_slices]

# Parametric value function (negated network output, i.e. a cost)
par_costs = [- values.eval({states: grid.all_points[mask]}) for mask in grid_slices]
# par_costs = [c / c.max() for c in par_costs]


In [None]:
# Plotting
# plot_cost(true_costs, grid, norms, grid_slices, true_roas)
# plot_cost(est_costs, grid, norms, grid_slices, est_roas)

# Each entry lists the two coordinates VARIED on a panel; panel i pairs with grid_slices[i]
planes = [[0, 2], [1, 3]]
limits = np.asarray(norms).reshape((-1, 1)) * grid.limits
scaled_discrete_points = [norm * points for norm, points in zip(norms, grid.discrete_points)]

plt.rc('font', size=20)
pad = 20
fig = plt.figure(figsize=(10, 16), dpi=OPTIONS.dpi)
# fig.subplots_adjust(wspace=0.4, hspace=0.2)

# Surface colors: LQR rollout (blue), learned-policy rollout (green), value network (red)
surface_colors = [(0, 0, 1, 0.6), (0, 1, 0, 0.8), (1, 0, 0, 0.65)]

for i, p in enumerate(planes):
    ax = fig.add_subplot(211 + i, projection='3d')
    if i == 0:
        ax.set_title(r'$\phi = \dot{\phi} = 0$' + '\n')
        ax.set_xlabel(r'$x$ [m]', labelpad=pad)
        ax.set_ylabel(r'$\dot{x}$ [m/s]', labelpad=pad)
        ax.xaxis.set_ticks(np.arange(-x_max, 1.01 * x_max, 0.25))
        ax.yaxis.set_ticks(np.arange(-x_dot_max, 1.01 * x_dot_max, 1))
    else:
        ax.set_title(r'$x = \dot{x} = 0$' + '\n')
        ax.set_xlabel(r'$\phi$ [deg]', labelpad=pad)
        ax.set_ylabel(r'$\dot{\phi}$ [deg/s]', labelpad=pad)
    ax.view_init(None, -45)

    xx, yy = np.meshgrid(*[scaled_discrete_points[p[0]], scaled_discrete_points[p[1]]])

    for j, (costs, roas, color) in enumerate(zip([true_costs, est_costs, par_costs],
                                                 [true_roas, est_roas, None],
                                                 surface_colors)):
        z = costs[i].reshape(grid.num_points[p])
#         z /= z.max()
        surf = ax.plot_surface(xx, yy, z, color=color) #, label=r'$J_{\bf \theta}({\bf x})$')
        # Workaround so a 3-D surface can be used as a legend handle
        surf._facecolors2d = surf._facecolors3d
        surf._edgecolors2d = surf._edgecolors3d
#         if roas is not None:
#             z = roas[i].reshape(grid.num_points[p])
#             ax.contourf(xx, yy, z, cmap=binary_cmap(color, 0.5), zdir='z', offset=0)
    proxy = [plt.Rectangle((0,0), 1, 1, fc=c) for c in surface_colors]
    ax.legend(proxy, [r'$J_{\pi}({\bf x})$', r'$J_{\pi_{\bf \delta}}({\bf x})$', r'$J_{\bf \theta}({\bf x})$'],
              loc=(0.85, 0.85))

fig.tight_layout()

# Save before plt.show(): depending on the backend, show() may close or clear the figure
if OPTIONS.save_figs:
    os.makedirs(OPTIONS.fig_path, exist_ok=True)
    gamma_string = str(feed_dict[gamma])[2:]
    fig.savefig(OPTIONS.fig_path + 'cartpole_policyiter_costfunc_gamma' + gamma_string + '.pdf', bbox_inches='tight')

plt.show()

## Parametric Policy

In [None]:
plt.rc('font', size=20)
pad = 20
fig = plt.figure(figsize=(10, 16), dpi=OPTIONS.dpi)

# Uses `planes` (varied coordinates per panel) and `scaled_discrete_points`
# from the previous plotting cell; panel i pairs with grid_slices[i].
for i, (p, mask) in enumerate(zip(planes, grid_slices)):
    ax = fig.add_subplot(211 + i, projection='3d')
    if i == 0:
        ax.set_title(r'$\phi = \dot{\phi} = 0$' + '\n')
        ax.set_xlabel(r'$x$ [m]', labelpad=pad)
        ax.set_ylabel(r'$\dot{x}$ [m/s]', labelpad=pad)
        ax.xaxis.set_ticks(np.arange(-x_max, 1.01 * x_max, 0.25))
        ax.yaxis.set_ticks(np.arange(-x_dot_max, 1.01 * x_dot_max, 1))
    else:
        ax.set_title(r'$x = \dot{x} = 0$' + '\n')
        ax.set_xlabel(r'$\phi$ [deg]', labelpad=pad)
        ax.set_ylabel(r'$\dot{\phi}$ [deg/s]', labelpad=pad)
    ax.view_init(None, -45)

    xx, yy = np.meshgrid(*[scaled_discrete_points[p[0]], scaled_discrete_points[p[1]]])
    # Multiply the normalized actions by u_max to plot the force (legend unit: N)
    acts = u_max * actions.eval({states: grid.all_points[mask]})
    true_acts = u_max * actions_lqr.eval({states: grid.all_points[mask]})

    ax.plot_surface(xx, yy, true_acts.reshape(grid.num_points[p]), color='blue', alpha=0.55)
    ax.plot_surface(xx, yy, acts.reshape(grid.num_points[p]), color='red', alpha=0.75)

    # ROA of the learned policy, drawn on the floor of the plot at z = -u_max
    z = est_roas[i].reshape(grid.num_points[p])
    ax.contourf(xx, yy, z, cmap=binary_cmap('green', 0.65), zdir='z', offset=-u_max)

    ax.tick_params(axis='z', which='major', pad=10)
    proxy = [plt.Rectangle((0,0), 1, 1, fc=c) for c in [(0, 0, 1, 0.6), (1, 0, 0, 0.65), (0., 1., 0., 0.65)]]
    ax.legend(proxy, [r'$\pi({\bf x})$ [N]', r'$\pi_{\bf \delta}({\bf x})$ [N]', r'ROA for $\pi_{\bf \delta}$'],
              loc=(0.85, 0.85))

fig.tight_layout()

# Save before plt.show(): depending on the backend, show() may close or clear the figure
if OPTIONS.save_figs:
    os.makedirs(OPTIONS.fig_path, exist_ok=True)
    gamma_string = str(feed_dict[gamma])[2:]
    fig.savefig(OPTIONS.fig_path + 'cartpole_policyiter_policy_gamma' + gamma_string + '.pdf', bbox_inches='tight')

plt.show()

## Full ROA

In [None]:
# Grid for ROA computation
# N = 11
# norms = np.asarray([x_max, np.rad2deg(theta_max), x_dot_max, np.rad2deg(theta_dot_max)])
# maxes = np.copy(norms)
# disc = gridify(norms, maxes, N)

# #
# closed_loop_dynamics = lambda x: future_states.eval({states: x})
# roa_horizon          = 3000
# roa_tol              = 0.1

# #
# roa, _ = compute_roa(disc.all_points, closed_loop_dynamics, roa_horizon, roa_tol)
# roa_size = roa.sum()

In [None]:
# roa_size = roa.sum()

# print(disc.nindex)
# print(roa_size)
# print(roa_size / disc.nindex)

## Zero-Input and Step Responses

In [None]:
# Initial condition (normalized coordinates) for the non-zero-IC response
ic = np.array([0.1, 0.2, 0.1, 0.1])
# Step amplitude expressed as a fraction of u_max (only used by the step response)
const = 1 / u_max
steps = 1000

# Learned action and value at the origin
print(session.run(actions, {states: np.zeros([1, state_dim])}))
print(session.run(values, {states: np.zeros([1, state_dim])}))

# Zero-input response, zero initial condition
plot_closedloop_response(dynamics, policy, policy_lqr, steps, dt, 'zero', ic=None, labels=['Learned','True'], denormalize=True)

# Zero-input response, non-zero initial condition
plot_closedloop_response(dynamics, policy, policy_lqr, steps, dt, 'zero', ic=ic, labels=['Learned','True'], denormalize=False)

# Step response
# plot_closedloop_response(dynamics, policy, policy_lqr, steps, dt, 'step', const=const, labels=['Learned','True'], denormalize=False)

# Impulse response
# plot_closedloop_response(dynamics, policy, policy_lqr, steps, dt, 'impulse', labels=['Learned','True'])

## TensorBoard Graph

In [None]:
# plotting.show_graph(tf.get_default_graph())