In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

import safe_learning
import plotting

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

In [None]:
def plot_triangulation(values, axis=None, three_dimensional=False, zlabel=None, **kwargs):
    """Plot per-vertex values of the notebook's grid triangulation.

    Relies on the notebook-level globals ``value_function`` (its ``tri``
    attribute supplies the simplices), ``state_space``, ``n_points`` and
    ``domain``, so it can only be called after the setup cell has run.

    Parameters
    ----------
    values: ndarray
        One value per grid vertex. The 3D branch indexes ``values[:, 0]``
        (so it needs a 2D array with one column); the 2D branch reshapes the
        values to ``(n_points[0], n_points[1])``.
    axis: optional
        Axis to draw on. If None, a new figure is created (with a 3D axis
        when ``three_dimensional`` is True).
    three_dimensional: bool, optional
        Whether to plot a 3D surface instead of a 2D heat map.
    zlabel: str, optional
        Label for the colorbar.
    **kwargs:
        Forwarded to ``plot_trisurf`` (3D) or ``imshow`` (2D).

    Returns
    -------
    axis:
        The axis on which we plotted.
    """
    if three_dimensional:
        if axis is None:
            # Axes3D(fig) no longer attaches itself to the figure in recent
            # matplotlib releases; add_subplot(projection='3d') works everywhere.
            axis = plt.figure().add_subplot(111, projection='3d')

        # Get the simplices and plot
        delaunay = value_function.tri
        simplices = delaunay.simplices(np.arange(delaunay.nsimplex))
        mappable = axis.plot_trisurf(state_space[:, 0], state_space[:, 1], values[:, 0],
                                     triangles=simplices.copy(),
                                     cmap='viridis', lw=0.1, **kwargs)
    else:
        if axis is None:
            axis = plt.figure().gca()

        # Reshape to the grid, transpose so the first state dimension runs
        # along x, and flip rows so the second dimension increases upwards
        # with origin='upper'. NOTE(review): assumes the grid ordering of
        # `values` has the first dimension varying slowest — confirm.
        vals = values.reshape(n_points[0], n_points[1]).T[::-1]
        # Draw on `axis` explicitly: plt.imshow would target the *current*
        # axes and silently ignore an axis passed in by the caller.
        mappable = axis.imshow(vals.copy(), origin='upper',
                               extent=domain[0] + domain[1],
                               aspect='auto', cmap='viridis', interpolation='bilinear', **kwargs)

    # Attach the colorbar to the axis we drew on, not whatever axis is current.
    cbar = plt.colorbar(mappable, ax=axis)

    axis.set_xlabel('position')
    axis.set_ylabel('velocity')
    if zlabel is not None:
        cbar.set_label(zlabel)

    return axis

##### TODO: Handle different terminal states in a better way!

In [None]:
# State-space bounds: [position range, velocity range].
domain = [[-1.2, 0.7], [-.07, .07]]
# Grid resolution per state dimension.
n_points = [20, 20]


discretization = safe_learning.GridWorld(domain, n_points)
state_space = discretization.all_points
# Value function: one parameter per grid vertex, initialised to 0.
value_function = safe_learning.Triangulation(discretization, np.zeros_like(state_space[:, 0]), project=True,
                                             name='tri_value_function')

# Policy: one action per grid vertex, initialised to 1.
# Note: assigning `policy.__call__ = ...` on the instance would NOT clip the
# actions. Python looks up special methods on the type, so `policy(x)` keeps
# using the class's __call__. Action bounds must be enforced where the
# actions are consumed.
policy = safe_learning.Triangulation(discretization, np.ones_like(state_space[:, 0]), project=True,
                                     name='tri_policy')

gamma = .98
# Maximum long-term reward is 1: a per-step reward of (1 - gamma) collected
# forever sums to (1 - gamma) * sum_t gamma^t = 1.
terminal_reward = 1 - gamma

@safe_learning.utilities.with_scope('true_dynamics')
def dynamics(states, actions):
    """Return future states of the car.

    Parameters
    ----------
    states: tf.Tensor
        Batch of states; column 0 is the position, column 1 the velocity.
    actions: tf.Tensor
        Batch of actions; only column 0 is used, clipped to [-1, 1].

    Returns
    -------
    next_states: tf.Tensor
        Stacked (position, velocity) pairs, one row per input state.
    """
    position = states[:, 0]
    velocity = states[:, 1]
    # Bound the applied force here so the action limits hold no matter how
    # the actions were produced.
    force = tf.clip_by_value(actions[:, 0], -1., 1.)

    # Explicit Euler step: the position advances with the current velocity.
    # The velocity gains the applied force plus the hill term -0.0025 cos(3 x).
    next_position = position + velocity
    next_velocity = velocity + 0.001 * force - 0.0025 * tf.cos(3 * position)

    return tf.stack((next_position, next_velocity), axis=1)


@safe_learning.utilities.with_scope('reward_function')
def reward_function(states, actions):
    """Return `terminal_reward` for states past the goal position, else 0.

    Parameters
    ----------
    states: tf.Tensor
        Batch of states; column 0 is the position.
    actions: tf.Tensor
        Unused; accepted so the signature matches what the RL algorithm passes.

    Returns
    -------
    rewards: tf.Tensor
        float64 tensor of shape (batch, 1).
    """
    # Goal region: position strictly greater than 0.6. Slicing with `:1`
    # keeps the column dimension. Unlike building zeros from
    # states.shape[0], this also works when the batch size is not known
    # statically.
    in_goal = tf.cast(tf.greater(states[:, :1], 0.6), tf.float64)
    return terminal_reward * in_goal

# Bundle policy, true dynamics, reward and value-function approximation into a
# policy-iteration solver, passing `gamma` as the discount factor.
rl = safe_learning.PolicyIteration(policy, dynamics, reward_function,
                                   value_function, gamma=gamma)

In [None]:
# Close the session left over from a previous run of this cell (if any), then
# open a fresh interactive session and initialise all global variables.
try:
    session.close()
except NameError:
    # First execution: `session` has not been defined yet.
    pass
finally:
    # Runs even if closing the old session raised something other than
    # NameError, so `session` is always (re)assigned here.
    session = tf.InteractiveSession()
    session.run(tf.global_variables_initializer())

In [None]:
with tf.variable_scope('optimization'):
    # Outer-loop settings for policy iteration.
    max_iterations = 30
    # Converged once neither the value nor the policy parameters change by
    # more than this (max absolute change between consecutive iterations).
    tolerance = 1e-1
    # Set to True to also refine the policy with gradient steps on
    # `policy_loss`, constrained to [-1, 1], after the discrete update.
    use_gradient_policy_update = False
    gradient_steps = 200

    old_values = np.zeros_like(rl.value_function.parameters.eval())
    old_policy = np.zeros_like(rl.policy.parameters.eval())
    converged = False
    # Candidate actions for the discrete policy-improvement step.
    action_space = np.array([[-1.], [1.]])

    # Alternative: value_opt = rl.value_iteration()
    value_opt = rl.optimize_value_function()
    policy_loss = -tf.reduce_sum(rl.future_values(rl.state_space))

    adapt_policy = tf.train.GradientDescentOptimizer(0.1).minimize(policy_loss,
                                                                   var_list=[rl.policy.parameters])

    # Constrain the optimization
    adapt_policy = safe_learning.utilities.add_constraint(adapt_policy,
                                                          var_list=[rl.policy.parameters],
                                                          bound_list=[(-1, 1)])

    for i in range(max_iterations):
        # Optimize value function. session.run handles both Tensors and
        # Operations (Operations have no .eval()).
        session.run(value_opt)

        # Optimize policy
        rl.discrete_policy_optimization(action_space)
        if use_gradient_policy_update:
            for _ in range(gradient_steps):
                session.run(adapt_policy)

        # Get new parameters. Distinct names avoid clobbering the global
        # `policy` Triangulation created in the setup cell.
        new_values, new_policy = session.run([rl.value_function.parameters,
                                              rl.policy.parameters])

        # Compute errors
        value_change = np.max(np.abs(old_values - new_values))
        policy_change = np.max(np.abs(old_policy - new_policy))

        # Break if converged
        if value_change <= tolerance and policy_change <= tolerance:
            converged = True
            break

        old_values = new_values
        old_policy = new_policy

    if converged:
        print('converged after {} iterations. \nerror: {}, \npolicy: {}'
              .format(i + 1, value_change, policy_change))
    else:
        print('didnt converge, error: {} and policy: {}'
              .format(value_change, policy_change))

In [None]:
# Show the learned value function first as a 2D heat map, then as a 3D surface.
value_parameters = rl.value_function.parameters.eval()
for as_surface in (False, True):
    plot_triangulation(value_parameters, three_dimensional=as_surface, zlabel='values')
    plt.show()

In [None]:
# Show the learned policy (one parameter per grid vertex) as a 2D heat map.
policy_parameters = rl.policy.parameters.eval()
plot_triangulation(policy_parameters, three_dimensional=False, zlabel='policy')
plt.show()

In [None]:
with tf.variable_scope('evaluate_dynamics'):
    # Roll out the learned policy for at most `n_steps` states, starting at
    # position -0.5 with zero velocity.
    n_steps = 1000
    # Goal threshold (the reward function also uses 0.6).
    goal_position = 0.6

    # np.float was removed in NumPy 1.24; it was an alias for float64.
    states = np.zeros((n_steps, 2), dtype=np.float64)
    states[0, 0] = -0.5

    state = tf.placeholder(tf.float64, [1, 2])
    # Named `next_state` so it does not overwrite the global `dynamics` function.
    next_state = rl.dynamics(state, rl.policy(state))

    for i in range(n_steps - 1):
        states[i + 1, :] = next_state.eval(feed_dict={state: states[[i], :]})

        # Once the goal is reached, hold the terminal state for the rest of
        # the trajectory and stop simulating.
        if states[i + 1, 0] >= goal_position:
            states[i + 1:] = states[i + 1]
            break


In [None]:
ax = plot_triangulation(rl.value_function.parameters.eval())
# Closed-loop trajectory from the rollout cell.
ax.plot(states[:, 0], states[:, 1], lw=3, color='k')
# Goal boundary at position 0.6. axvline spans the full axis height regardless
# of the current y-limits.
ax.axvline(0.6, lw=2, color='r')

# Axis labels ('position' / 'velocity') are already set by plot_triangulation.
ax.set_xlim(domain[0])
ax.set_ylim(domain[1])

plt.show()

In [None]:
# Visualise the TensorFlow graph built by the cells above.
default_graph = tf.get_default_graph()
plotting.show_graph(default_graph)