In [None]:
import numpy as np
import matplotlib.pyplot as plt
# import cvxpy
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

# Import the safe_learning package through the local `utilities` helper.
# Intended behavior: use an installed copy if one is available, otherwise load
# it from the parent directory ('../').
import utilities
safe_learning = utilities.import_from_directory('safe_learning', '../')
    
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

In [None]:
def plot_triangulation(values, axis=None, three_dimensional=False, zlabel=None, **kwargs):
    """Plot values defined on the vertices of the global `delaunay` grid.

    This helper reads the notebook globals `delaunay`, `state_space`,
    `n_points` and `domain`, so they must be defined before it is called.

    Parameters
    ----------
    values: ndarray
        One value per grid vertex, ordered like `state_space`. The 2D image
        reshapes it to ``(n_points[0] + 1, n_points[1] + 1)``.
    axis: matplotlib axes, optional
        Axes to draw on. When omitted, a new figure is created. With
        ``three_dimensional=True`` a supplied axis must be a 3D axes.
    three_dimensional: bool, optional
        Whether to plot 3D. If True, draw a surface over the triangulation;
        otherwise draw a 2D color image.
    zlabel: str, optional
        Label for the colorbar.
    **kwargs:
        Forwarded to ``plot_trisurf`` (3D) or ``imshow`` (2D).

    Returns
    -------
    axis:
        The axis on which we plotted.
    """
    if three_dimensional:
        if axis is None:
            # `Axes3D(fig)` no longer adds itself to the figure in recent
            # matplotlib releases (the plot comes out empty); add_subplot
            # with a 3d projection works across versions.
            axis = plt.figure().add_subplot(111, projection='3d')

        # Get the simplices and plot
        simplices = delaunay.simplices(np.arange(delaunay.nsimplex))
        mappable = axis.plot_trisurf(state_space[:, 0], state_space[:, 1], values,
                                     triangles=simplices.copy(),
                                     cmap='viridis', lw=0.1, **kwargs)
    else:
        if axis is None:
            axis = plt.figure().gca()

        # Reshape to the vertex grid. This assumes row-major ordering with the
        # position index varying slowest (TODO confirm against
        # delaunay.index_to_state). Transpose so rows are velocity, then flip
        # the rows so velocity increases upwards with origin='upper'.
        vals = values.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]
        # Draw on `axis` itself rather than the pyplot "current" axes, so a
        # caller-supplied axis is honored.
        mappable = axis.imshow(vals.copy(), origin='upper',
                               extent=domain[0] + domain[1],
                               aspect='auto', cmap='viridis',
                               interpolation='bilinear', **kwargs)

    # Attach the colorbar to the axis we drew on, not whatever is current.
    cbar = plt.colorbar(mappable, ax=axis)

    axis.set_xlabel('position')
    axis.set_ylabel('velocity')
    if zlabel is not None:
        cbar.set_label(zlabel)

    return axis

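##### TODO: Find a better way to handle terminal states. The goal condition `position >= 0.6` is currently hard-coded separately in `reward_function`, in the `PolicyIteration` setup (`terminal_states`), in the simulation loop and in the trajectory plot.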

In [None]:
# Discount factor, and the reward paid by `reward_function` at the goal.
gamma = .98
terminal_reward = 1

# Discrete actions: push left (-1) or right (+1); see `dynamics`.
action_space = np.array([-1, 1])

# State space: a (position, velocity) box split into n_points cells per
# dimension, i.e. n_points + 1 vertices per dimension.
# NOTE(review): project=True presumably maps out-of-domain states back onto
# the grid -- confirm in safe_learning.Triangulation.
domain = [[-1.2, 0.7], [-.07, .07]]
n_points = [50, 50]
delaunay = safe_learning.Triangulation(domain, n_points, project=True)

# One state per grid vertex, in the triangulation's index order.
state_space = delaunay.index_to_state(np.arange(delaunay.nindex))


def dynamics(states, actions):
    """Return future states of the car after one time step.

    This is the mountain-car update, without the usual clipping of position
    and velocity:

        position' = position + velocity
        velocity' = velocity + 0.001 * action - 0.0025 * cos(3 * position)

    Parameters
    ----------
    states: array_like
        A single state ``(position, velocity)`` or an array of shape (n, 2).
    actions: array_like
        Scalar actions: either one scalar (applied to every state), a 1D
        array with one action per state, or an array of shape (n, 1).

    Returns
    -------
    next_states: ndarray
        Array of shape (n, 2). The inputs are not modified.
    """
    states = np.atleast_2d(states)
    actions = np.asarray(actions)
    # np.atleast_2d would turn a per-state action vector of shape (n,) into a
    # single row (1, n); `actions[:, 0]` would then apply the first action to
    # every state. Treat 0-d/1-d input as a column of scalar actions instead.
    if actions.ndim < 2:
        actions = actions.reshape(-1, 1)

    next_states = states.copy()

    next_states[:, 0] += states[:, 1]
    next_states[:, 1] += 0.001 * actions[:, 0] - 0.0025 * np.cos(3 * states[:, 0])

    return next_states


def reward_function(states, actions, next_states):
    """Return `terminal_reward` for each state whose position is >= 0.6, else 0.

    Only the current `states` are inspected. `actions` and `next_states` are
    unused here and are kept for the reward-function interface passed to
    `safe_learning.PolicyIteration`.
    """
    positions = np.atleast_2d(states)[:, 0]
    at_goal = (positions >= 0.6).astype('int')
    return at_goal * terminal_reward

In [None]:
rl = safe_learning.PolicyIteration(
    state_space,
    action_space,
    dynamics,
    reward_function,
    function_approximator=delaunay,
    gamma=gamma,
    terminal_states=state_space[:, 0] >= 0.6,
    # Use the configured constant rather than repeating the literal.
    terminal_reward=terminal_reward)

max_iterations = 1000
tolerance = 1e-5  # largest absolute value change still counted as converged

# Previous iterate, kept only for the convergence check below.
# (`np.float` was removed in NumPy 1.24; builtin `float` is the same float64.)
old_values = np.zeros(delaunay.nindex, dtype=float)
# Arbitrary starting policy; it only serves as the baseline for the first check.
old_policy = np.random.choice(action_space, size=len(state_space))

converged = False

# Alternate value-function optimization and policy updates until neither
# changes. `rl.update_value_function()` is an alternative to
# `optimize_value_function()`; both are timed in the last cells.
for i in tqdm(range(max_iterations)):
    rl.optimize_value_function()
    rl.update_policy()

    # Compute errors
    value_change = np.max(np.abs(old_values - rl.values))
    policy_converged = np.all(old_policy == rl.policy)

    if value_change <= tolerance and policy_converged:
        converged = True
        break

    old_values[:] = rl.values
    old_policy[:] = rl.policy

if converged:
    print('converged after {} iterations. \nerror: {}, \npolicy: {}'
          .format(i + 1, value_change, policy_converged))
else:
    print('didnt converge, error: {} and policy: {}'
          .format(value_change, policy_converged))

In [None]:
# Value function after policy iteration: 2D image first, then 3D surface.
for is_3d in (False, True):
    plot_triangulation(rl.values, three_dimensional=is_3d, zlabel='values')
    plt.show()

In [None]:
# Policy found by policy iteration (one action per grid vertex).
policy_axis = plot_triangulation(rl.policy, three_dimensional=False,
                                 zlabel='policy')
plt.show()

In [None]:
n_steps = 1000
start_position = -0.5
goal_position = 0.6

# Trajectory buffer: start at rest at `start_position`.
# (`np.float` was removed in NumPy 1.24; builtin `float` is the same float64.)
states = np.zeros((n_steps, 2), dtype=float)
states[0, 0] = start_position

# Wrap the per-vertex policy as a piecewise-constant function on the same grid.
policy = safe_learning.PiecewiseConstant(rl.value_function.limits,
                                         rl.value_function.num_points,
                                         vertex_values=rl.policy)

for i in range(n_steps - 1):
    # Look up the policy's action at the current state (piecewise constant,
    # presumably no interpolation between vertices).
    action = policy.evaluate(states[[i], :])

    states[i + 1, :] = dynamics(states[i, :], action)
    if states[i + 1, 0] >= goal_position:
        # Goal reached: hold the terminal state for the rest of the buffer.
        states[i + 1:, :] = states[i + 1]
        break


In [None]:
# Closed-loop trajectory (black) on top of the value function; the red line
# marks the goal position. Axis labels come from plot_triangulation, so they
# match the other plots.
ax = plot_triangulation(rl.values, zlabel='values')
ax.plot(states[:, 0], states[:, 1], lw=3, color='k')
# axvline spans the full axes height even after the y-limits are reset below,
# unlike a segment drawn between the limits captured at this point.
ax.axvline(0.6, lw=2, color='r')

ax.set_xlim(domain[0])
ax.set_ylim(domain[1])

plt.show()

In [None]:
# Time one more call of the solver used in the policy-iteration loop
# (this re-runs it on the current `rl` object).
%time rl.optimize_value_function()

In [None]:
# Time the alternative value-function update for comparison with the cell above
# (this re-runs it on the current `rl` object).
%time rl.update_value_function()