In [None]:
import numpy as np
import matplotlib.pyplot as plt
# import cvxpy
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

# Try to import safe_rl from system
# if it fails get it from the main folder directly instead.
try:
    import safe_rl
except ImportError:
    import utilities
    safe_rl = utilities.import_from_directory('safe_rl', '../')
    
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

In [None]:
def plot_triangulation(values, axis=None, three_dimensional=False, **kwargs):
    """Plot values defined on the triangulation.

    Reads the notebook-level globals ``delaunay``, ``state_space``,
    ``n_points`` and ``domain``.

    Parameters
    ----------
    values: ndarray
        Values to plot. Must contain ``(n_points[0] + 1) * (n_points[1] + 1)``
        entries for the 2D plot.
    axis: optional
        Matplotlib axis to draw on. A new figure is created when omitted.
        For ``three_dimensional=True`` this has to be a 3D axis.
    three_dimensional: bool, optional
        Whether to plot 3D
    **kwargs:
        Forwarded to ``plot_trisurf`` (3D) or ``imshow`` (2D).

    Returns
    -------
    axis:
        The axis on which we plotted.
    """
    if three_dimensional:
        if axis is None:
            # ``Axes3D(fig)`` does not attach itself to the figure in recent
            # matplotlib releases, which leaves an empty plot. ``add_subplot``
            # registers the axis with the figure in all versions.
            axis = plt.figure().add_subplot(111, projection='3d')

        # Get the simplices and plot
        simplices = delaunay.simplices(np.arange(delaunay.nsimplex))
        surface = axis.plot_trisurf(state_space[:, 0], state_space[:, 1], values,
                                    triangles=simplices.copy(),
                                    cmap='viridis', lw=0.1, **kwargs)
        # Attach the colorbar to the axis we drew on, not to pyplot's
        # notion of the "current" axis.
        axis.figure.colorbar(surface, ax=axis)
    else:
        if axis is None:
            axis = plt.figure().gca()

        # Some magic reshaping to go to physical coordinates
        # NOTE(review): assumes ``values`` is ordered so that this reshape
        # recovers the grid layout -- confirm against ``safe_rl.Delaunay``.
        vals = values.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]

        # Draw on the requested axis; ``plt.imshow`` would silently use the
        # current axis and ignore a caller-supplied one.
        image = axis.imshow(vals.copy(), origin='upper',
                            extent=domain[0] + domain[1],
                            aspect='auto', cmap='viridis',
                            interpolation='bilinear', **kwargs)
        axis.figure.colorbar(image, ax=axis)

    return axis

## TODO: Handle different terminal states in a better way!
## TODO: Why are there troubles for finely discretized domains and many iterations?

In [None]:
# Problem setup. ``domain`` holds one [lower, upper] pair per state dimension
# (position first, then velocity -- see the axis labels used for plotting).
domain = [[-1.2, 0.7], [-0.07, 0.07]]
n_points = [10, 10]
delaunay = safe_rl.Delaunay(domain, n_points)

# One state per triangulation index, and the two admissible actions.
vertex_indices = np.arange(delaunay.nindex)
state_space = delaunay.index_to_state(vertex_indices)
action_space = np.array([-1, 1])

# Discount factor and the reward handed out at terminal states.
gamma, terminal_reward = 0.99, 1


def dynamics(states, actions):
    """Return future states of the car.

    Parameters
    ----------
    states: array_like
        A single state ``[position, velocity]`` or an array with one state
        per row.
    actions: array_like
        A scalar action, a 1-D array with one action per state, or a 2-D
        array with one action per row. Only the first action column is used.

    Returns
    -------
    next_states: ndarray
        2-D float array with one successor state per row.
    """
    # Work on floats so the in-place updates below cannot fail on int input.
    states = np.atleast_2d(np.asarray(states, dtype=float))

    # A 1-D array holds one action per state, so it must become a column.
    # ``np.atleast_2d`` would turn it into a single row and ``actions[:, 0]``
    # would then apply only the first action to every state.
    actions = np.asarray(actions, dtype=float)
    if actions.ndim < 2:
        actions = actions.reshape(-1, 1)

    next_states = states.copy()

    # Position advances by the current velocity.
    next_states[:, 0] += states[:, 1]
    # Velocity changes with the action and a position-dependent cosine term.
    next_states[:, 1] += 0.001 * actions[:, 0] - 0.0025 * np.cos(3 * states[:, 0])

    return next_states


def reward_function(states, actions, next_states):
    """Return the reward obtained in each state.

    Parameters
    ----------
    states: ndarray
        A single state or one state per row; the first column is compared
        against the goal threshold of 0.6.
    actions: ndarray
        Unused.
    next_states: ndarray
        Unused.

    Returns
    -------
    rewards: ndarray
        ``terminal_reward`` where the first state component is at least 0.6,
        zero elsewhere.
    """
    positions = np.atleast_2d(states)[:, 0]
    at_goal = positions >= 0.6
    return terminal_reward * at_goal.astype('int')


def is_terminal(states):
    """Return true if states are terminal.

    A state is terminal when its first component is at least 0.6.

    Parameters
    ----------
    states: array_like
        A single state or an array with one state per row.

    Returns
    -------
    is_terminal: boolean array
        One entry per state.
    """
    # Accept a single 1-D state as well, like the other helpers do.
    states = np.atleast_2d(states)
    return states[:, 0] >= 0.6


def get_value_function(states, actions, vertex_values):
    """Perform one round of value updates.

    Each state's new value is its reward plus the discounted value of its
    successor state, where the successor values come from
    ``delaunay.function_values_at`` evaluated with ``vertex_values``.
    Terminal states are pinned to ``terminal_reward``.

    Parameters
    ----------
    states: ndarray
        States at which to update the values.
    actions: ndarray
        Actions applied in ``states``.
    vertex_values: ndarray
        Current values, passed on to ``delaunay.function_values_at``.

    Returns
    -------
    values: ndarray
        The updated values
    """
    successors = dynamics(states, actions)
    immediate = reward_function(states, actions, successors)

    successor_values = delaunay.function_values_at(successors,
                                                   vertex_values=vertex_values)

    # Bellman backup: immediate reward plus discounted successor value
    updated = immediate + gamma * successor_values

    # Terminal states keep a fixed value instead of the backed-up one
    terminal_mask = is_terminal(states)
    updated[terminal_mask] = terminal_reward

    return updated
    

def optimize_policy(states, vertex_values):
    """Optimize the policy for a given value function.

    Evaluates one value update per action in ``action_space`` and picks, for
    every state, the action with the largest resulting value.

    Parameters
    ----------
    states: ndarray
        States for which to select an action, one per row.
    vertex_values: ndarray
        Current values, passed on to ``get_value_function``.

    Returns
    -------
    policy: ndarray
        The optimal policy for the given value function.
    """
    # Size the buffers by the number of states that are actually evaluated.
    n_states = len(states)

    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
    # same dtype and works on every NumPy version.
    values = np.empty((n_states, len(action_space)), dtype=float)
    actions = np.empty((n_states, 1), dtype=float)

    # Compute values for each action
    for i, action in enumerate(action_space):
        actions[:] = action
        values[:, i] = get_value_function(states, actions, vertex_values)

    # Select best one
    return action_space[np.argmax(values, axis=1)]
    

In [None]:
# Iteration limits for the value/policy updates below
max_iterations = 1000
value_tolerance = 1e-2

# Initial guess for values and policy.
# A fixed seed keeps the random initial policy reproducible across runs.
rng = np.random.RandomState(0)
values = np.zeros(delaunay.nindex)
policy = rng.choice(action_space, size=len(state_space))

old_values = values.copy()
old_policy = policy.copy()

converged = False

for i in tqdm(range(max_iterations)):
    # Pass the policy as a column (one action per state). A flat array would
    # be promoted to a single row inside ``dynamics`` and only its first
    # entry would be applied to every state.
    values = get_value_function(state_space, policy[:, None], values)
    policy = optimize_policy(state_space, values)

    # Compute errors
    value_change = np.max(np.abs(old_values - values))
    policy_converged = np.all(old_policy == policy)

    # Break if converged (the policy check is reported but not required)
    if value_change <= value_tolerance:
        converged = True
        break
    else:
        old_values[:] = values
        old_policy[:] = policy

if converged:
    print('converged after {} iterations. \nerror: {}, \npolicy: {}'.format(i + 1, value_change, policy_converged))
else:
    print('didnt converge, error: {} and policy: {}'.format(value_change, policy_converged))

In [None]:
# Value function as a 2D image; label it so the figure stands on its own.
ax = plot_triangulation(values)
ax.set(title='Value function', xlabel='pos', ylabel='vel')
plt.show()

# Same values as a 3D surface.
ax = plot_triangulation(values, three_dimensional=True)
ax.set(title='Value function', xlabel='pos', ylabel='vel', zlabel='value')
plt.show()

In [None]:
# Selected action at every vertex; label it so the figure stands on its own.
ax = plot_triangulation(policy)
ax.set(title='Policy', xlabel='pos', ylabel='vel')
plt.show()

In [None]:
# Roll out the computed policy from a fixed start state.
# ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is equivalent.
states = np.zeros((1000, 2), dtype=float)
states[0, 0] = -0.5

for i in range(len(states) - 1):
    # interpolate action
    action = delaunay.function_values_at(states[[i], :], vertex_values=policy)
    # Snap the interpolated value to the closest admissible action
    action = action_space[np.argmin(np.abs(action - action_space))]

    states[i+1, :] = dynamics(states[i, :], action)
    # Same threshold as ``is_terminal``: once reached, hold the final state
    # for the rest of the trajectory.
    if states[i+1, 0] >= 0.6:
        states[i+1:, :] = states[i+1]
        break


In [None]:
# Overlay the simulated trajectory on top of the value function.
ax = plot_triangulation(values)
positions, velocities = states[:, 0], states[:, 1]
ax.plot(positions, velocities, lw=3, color='k')

# Label the axes and clip the view to the state domain.
ax.set(xlabel='pos', ylabel='vel', xlim=domain[0], ylim=domain[1])

plt.show()