In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cvxpy
%matplotlib inline

# Try to import safe_rl from system
# if it fails get it from the main folder directly instead.
try:
    import safe_rl
except ImportError:
    import utilities
    safe_rl = utilities.import_from_directory('safe_rl', '../')

#### Get matrix for values of test points

For each vertex x, the next state is given by some function $f(x)$. The goal is to find the value $V(f(x))$ as a function of the values at the other vertices. This value can then be used as a constraint in the optimization problem.
$$V(f(x)) = a^\mathrm{T} V(\mathrm{vertices})$$

We start by finding the simplex that contains $f(x)$. Given this simplex's vertices $v_1, v_2, v_3$, we solve for $a$ above, using one arbitrary vertex ($v_1$) as a reference point:
$$V(f(x)) = V(v_1) + a_1 \left( V(v_2) - V(v_1) \right) + a_2 \left( V(v_3) - V(v_1) \right)$$

The hyperplane spanned by the simplex has gradient $(\xi_1, \xi_2)$, which is determined by
\begin{align*}
\xi_1 ( x_2 - x_1) + \xi_2 (y_2 - y_1) = V(v_2) - V(v_1) \\
\xi_1 ( x_3 - x_1) + \xi_2 (y_3 - y_1) = V(v_3) - V(v_1) 
\end{align*}

and, as a result,
$$V(f(x_t)) = V(v_1) + (x_t - x_1, y_t - y_1)
\left( \begin{matrix}
x_2 - x_1 & y_2 - y_1 \\
x_3 - x_1 & y_3 - y_1
\end{matrix} \right)^{-1}
\left( \begin{matrix}
V(v_2) - V(v_1) \\
V(v_3) - V(v_1)
\end{matrix} \right),
$$
which can easily be solved for $a_1$ and $a_2$.

## TODO: Add fake terminal state with zero reward.

In [None]:
def dynamics(states, actions):
    """Return the successor states of the car after one time step.

    Parameters
    ----------
    states : array_like
        A single state ``(position, velocity)`` or an ``(n, 2)`` array with
        one state per row.
    actions : array_like
        The action applied to each state. Only the first action dimension is
        used. A scalar, a flat vector with one action per state, or an
        ``(n, 1)`` array are accepted; a single action is broadcast over all
        states.

    Returns
    -------
    future_states : ndarray
        ``(n, 2)`` array of next states. The input array is not modified.
    """
    states = np.atleast_2d(states)
    if not np.issubdtype(states.dtype, np.floating):
        # Integer input would make the in-place float updates below fail.
        states = states.astype(float)

    # One action per row. `np.atleast_2d` would turn a flat vector of n
    # actions into shape (1, n), so that `actions[:, 0]` only ever picked the
    # first action and applied it to every state.
    actions = np.asarray(actions)
    if actions.ndim < 2:
        actions = actions.reshape(-1, 1)

    future_states = states.copy()
    future_states[:, 0] += states[:, 1]
    future_states[:, 1] += 0.001 * actions[:, 0] - 0.0025 * np.cos(3 * states[:, 0])

    # Add constraints
    np.clip(future_states[:, 1], -0.07, 0.07, out=future_states[:, 1])
    # States whose current position is already >= 0.6 keep their position
    # pinned at 0.6 (the velocity is still updated).
    future_states[states[:, 0] >= 0.6, 0] = 0.6
    future_states[future_states[:, 0] < -1.2, 0] = -1.2
    return future_states

def reward_function(states, actions, next_states):
    """Return the reward for each transition.

    Every step whose next position is still below 0.6 costs 0.01; steps that
    end at a position of 0.6 or more give zero reward.

    Parameters
    ----------
    states : array_like
        Current states. Unused by the active reward.
    actions : array_like
        Applied actions. Unused by the active reward.
    next_states : array_like
        A single next state or an ``(n, 2)`` array with one state per row.

    Returns
    -------
    rewards : ndarray
        Array of length ``n`` with entries ``-0.01`` or ``0``.
    """
    # The reward only depends on `next_states`, so that is the array that
    # must be promoted to 2-D for the column indexing below to work on a
    # single state.
    next_states = np.atleast_2d(next_states)
    # Alternative (disabled): reward 1 only for the transition into the goal,
    #     np.logical_and(states[:, 0] < 0.6, next_states[:, 0] >= 0.6).astype('int')
    return -0.01 * (next_states[:, 0] < 0.6).astype('int')

# Bounds of the discretized state space, slightly larger than the range the
# dynamics clip to (position in [-1.2, 0.6], velocity in [-0.07, 0.07]).
position_bounds = [-1.3, 0.7]
velocity_bounds = [-0.08, 0.08]
domain = [position_bounds, velocity_bounds]

# Discretization setting handed to the triangulation, one entry per dimension.
n_points = [5, 5]

delaunay = safe_rl.Delaunay(domain, n_points)

In [None]:
# State belonging to every index of the discretization (indexed row-wise below).
states = delaunay.index_to_state(np.arange(delaunay.nindex))

# Random initial policy: one action drawn from `action_space` per state.
# (Thresholding `np.random.rand` would give booleans, i.e. the actions 0 and 1,
# which are not the actions in `action_space`.)
action_space = np.array([-1, 1])
np.random.seed(0)  # fixed seed so that a fresh Restart & Run All is reproducible
policy = np.random.choice(action_space, size=len(states))

def get_value_function(policy, gamma=0.99):
    """Evaluate a policy by solving a linear program over the vertex values.

    The program maximizes the sum of all vertex values subject to
    ``values <= rewards + gamma * V(next_states)``, where the value at the
    next states is a linear combination of the vertex values.

    Parameters
    ----------
    policy : array_like
        One action per entry of the module-level ``states``.
    gamma : float, optional
        Discount factor. Defaults to 0.99.

    Returns
    -------
    values : ndarray
        The optimized vertex values.
    prob : cvxpy.Problem
        The solved problem; check ``prob.status`` before trusting ``values``.
    eps : ndarray
        Value of an auxiliary variable that is not part of the problem. It is
        returned only to keep the three-element return signature.
    """
    # Column vector with one action per state, which is the layout that
    # `dynamics` indexes as `actions[:, 0]`. A flat length-n vector could
    # otherwise be read as a single row, so that only the first action was
    # applied to every state.
    actions = np.asarray(policy).reshape(-1, 1)
    next_states = dynamics(states, actions)
    rewards = reward_function(states, actions, next_states)

    # Row i of `h` holds the weights that express the value at next state i
    # as a linear combination of the vertex values.
    h = delaunay.function_values_at(next_states).toarray()
    values = cvxpy.Variable(delaunay.nindex)
    # Not used in the objective or the constraints below.
    eps = cvxpy.Variable(delaunay.nindex)
    future_values = h * values

    discount = cvxpy.Parameter(sign='positive')

    objective = cvxpy.Maximize(cvxpy.sum_entries(values))
    constraints = [values <= rewards + discount * future_values]

    prob = cvxpy.Problem(objective, constraints)
    discount.value = gamma
    prob.solve()

    return np.asarray(values.value).squeeze(), prob, np.asarray(eps.value).squeeze()



def optimize_policy(vertex_values):
    """Return the greedy policy with respect to the given vertex values.

    For every action in the module-level ``action_space`` the successor of
    each entry of the module-level ``states`` is computed and the value
    function is evaluated there. The action leading to the largest value is
    selected for each state.

    Parameters
    ----------
    vertex_values : ndarray
        Value function at the vertices, one entry per state.

    Returns
    -------
    policy : ndarray
        The selected element of ``action_space`` for every state.
    """
    # `np.float` was only an alias of the builtin `float` and no longer exists
    # in recent NumPy releases.
    values = np.empty((len(vertex_values), len(action_space)), dtype=float)
    for i, action in enumerate(action_space):
        next_states = dynamics(states, action * np.ones((len(states), 1)))
        # NOTE(review): only the value at the next state is compared; the
        # immediate reward is not added. Confirm that this is intended.
        values[:, i] = delaunay.function_values_at(next_states, vertex_values=vertex_values)

    best_actions = np.argmax(values, axis=1)
    return action_space[best_actions]

In [None]:
# Policy iteration: alternate between evaluating the current policy and
# making it greedy with respect to the resulting values, until neither the
# values nor the policy change any more.
old_values = -100 * np.ones(delaunay.nindex)
old_policy = np.zeros_like(policy)

for i in range(100):
    values, prob, error = get_value_function(policy)

    # Without an optimal solution the returned values are not usable.
    if prob.status != cvxpy.OPTIMAL:
        print('{} - optimization status: {}'.format(i, prob.status))
        break

    policy = optimize_policy(values)

    if np.max(np.abs(old_values - values)) <= 1e-10 and np.all(old_policy == policy):
        print('converged after {} iterations'.format(i))
        break

    old_values = values.copy()
    old_policy = policy.copy()
else:
    # Only reached when the loop ran out of iterations without a `break`.
    print('no convergence after {} iterations'.format(i + 1))

In [None]:
# Arrange the vertex values on the grid with the first state dimension along
# the horizontal axis; the rows are flipped to match `origin='upper'`.
# NOTE(review): assumes the vertex indices are ordered like a C-ordered
# (n_points[0] + 1) x (n_points[1] + 1) grid - confirm against Delaunay.
vals = values.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]

image = plt.imshow(vals, origin='upper',
                   extent=domain[0] + domain[1],
                   aspect='auto', cmap='viridis')
colorbar = plt.colorbar(image)
colorbar.set_label('value')
plt.xlabel('pos')
plt.ylabel('vel')
plt.title('Value function')

In [None]:
# Same grid layout as for the value function plot: first state dimension
# along the horizontal axis, rows flipped to match `origin='upper'`.
acts = policy.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]
image = plt.imshow(acts, origin='upper', extent=domain[0] + domain[1], aspect='auto')
colorbar = plt.colorbar(image)
colorbar.set_label('action')
plt.xlabel('pos')
plt.ylabel('vel')
plt.title('Policy')

In [None]:
# Closed-loop rollout of the interpolated policy, starting at position -0.5
# with zero velocity.
# NOTE: this rebinds the module-level `states` that `get_value_function` and
# `optimize_policy` read; re-run the cell that defines the grid states before
# calling them again.
# `np.float` no longer exists in recent NumPy releases; the builtin `float`
# is the same dtype.
states = np.zeros((1000, 2), dtype=float)
states[0, 0] = -0.5

for i in range(len(states) - 1):
    # Interpolate the per-vertex policy at the current state. The disabled
    # alternative was a lookup of the policy at
    # `delaunay.state_to_index(states[i, :])`.
    action = delaunay.function_values_at(states[i, :], vertex_values=policy)
    states[i+1, :] = dynamics(states[i, :], action)
    if states[i+1, 0] >= 0.6:
        # Goal reached: hold the final state for the rest of the trajectory.
        states[i+1:, :] = states[i+1]
        break


In [None]:
# Phase portrait of the simulated trajectory, shown on the range that the
# dynamics clip position and velocity to.
ax = plt.gca()
ax.plot(states[:, 0], states[:, 1])
ax.set_xlabel('pos')
ax.set_ylabel('vel')
ax.set_xlim(-1.2, 0.6)
ax.set_ylim(-0.07, 0.07)