In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cvxpy
%matplotlib inline

# Try to import safe_rl from system
# if it fails get it from the main folder directly instead.
try:
    import safe_rl
except ImportError:
    import utilities
    safe_rl = utilities.import_from_directory('safe_rl', '../')
    
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

#### Get matrix for values of test points

For each vertex $x$, the next state is given by the dynamics $f(x)$. The goal is to express the value $V(f(x))$ as a linear function of the values at the grid vertices. This linear expression can then be used as a constraint in the optimization problem.
$$V(f(x)) = a^\mathrm{T} V(\mathrm{vertices})$$

We start by finding the simplex corresponding to $f(x)$. Given this simplex's vertices, we solve for $a$ above, using one arbitrary vertex as a reference point:
$$V(f(x)) = V(v_1) + a_1 \left(V(v_2) - V(v_1)\right) + a_2 \left(V(v_3) - V(v_1)\right)$$

Writing $v_i = (x_i, y_i)$ for the simplex vertices, the linear interpolant on the simplex has a gradient $(\xi_1, \xi_2)$ that satisfies
\begin{align*}
\xi_1 ( x_2 - x_1) + \xi_2 (y_2 - y_1) = V(v_2) - V(v_1) \\
\xi_1 ( x_3 - x_1) + \xi_2 (y_3 - y_1) = V(v_3) - V(v_1) 
\end{align*}

and, as a result, writing $f(x) = (x_t, y_t)$,
$$V(f(x)) = V(v_1) + (x_t - x_1, y_t - y_1)
\left( \begin{matrix}
x_2 - x_1 & y_2 - y_1 \\
x_3 - x_1 & y_3 - y_1
\end{matrix} \right)^{-1}
\left( \begin{matrix}
V(v_2) - V(v_1) \\
V(v_3) - V(v_1)
\end{matrix} \right),
$$
so the coefficients $(a_1, a_2)$ are the row vector $(x_t - x_1, y_t - y_1)$ multiplied by the inverse matrix above.

## TODO: Handle different terminal states in a better way!

In [None]:
def dynamics(states, actions):
    """Return future states of the car.

    Parameters
    ----------
    states : array_like
        A single state ``(position, velocity)`` or an array of shape
        ``(n, 2)`` with one state per row.
    actions : array_like
        A scalar action, a 1-D array with one action per state, or an
        array of shape ``(n, 1)``.

    Returns
    -------
    future_states : ndarray
        Array of shape ``(n, 2)`` with the next state for each row.
    """
    states = np.atleast_2d(states)
    # Bring actions into column form (n, 1). np.atleast_2d would turn a 1-D
    # vector of per-state actions into a single row (1, n), so that
    # actions[:, 0] silently applied the first action to every state.
    actions = np.asarray(actions)
    if actions.ndim < 2:
        actions = actions.reshape(-1, 1)

    future_states = states.copy()
    # Position is advanced with the current (not the updated) velocity.
    future_states[:, 0] += states[:, 1]
    # Velocity: action force plus a position-dependent term
    # (presumably gravity on a hill of shape sin(3 * position) — TODO confirm).
    future_states[:, 1] += 0.001 * actions[:, 0] - 0.0025 * np.cos(3 * states[:, 0])

    return future_states

def reward_function(states, actions, next_states):
    """Return 1 for every state whose position is at least 0.6, else 0.

    ``actions`` and ``next_states`` are part of the interface but unused.
    """
    positions = np.atleast_2d(states)[:, 0]
    reached_goal = positions >= 0.6
    return reached_goal.astype(int)

# Bounds of the state space: position (first row) and velocity (second row).
domain = [[-1.3, 0.7],
          [-0.08, 0.08]]
# Number of intervals per dimension; the plotting cells reshape the vertex
# values to (n_points[0] + 1, n_points[1] + 1).
n_points = [10, 10]
# safe_rl grid object used below to enumerate vertices and interpolate values.
delaunay = safe_rl.Delaunay(domain, n_points)

In [None]:
# One state per grid vertex, and the two admissible actions.
state_space = delaunay.index_to_state(np.arange(delaunay.nindex))
action_space = np.array([-1, 1])

# Discount factor, chosen extremely close to 1 (nearly undiscounted).
gamma = 1 - 1e-10

# Random initial policy drawn from the admissible actions. A seeded generator
# keeps the notebook reproducible, and using action_space values (instead of
# booleans) gives the policy the same dtype as optimize_policy's output, so
# arrays derived from it (e.g. via np.zeros_like) can hold -1/1 actions.
policy = np.random.RandomState(0).choice(action_space, size=len(state_space))


def is_terminal(states):
    """Return a boolean mask of the rows of ``states`` whose position is >= 0.6."""
    positions = states[:, 0]
    return np.greater_equal(positions, 0.6)

def is_in_domain(states, domain, eps=1e-2):
    """Return a boolean mask of states strictly inside ``domain`` shrunk by ``eps``.

    Parameters
    ----------
    states : ndarray
        Array of shape (n, d) with one state per row.
    domain : sequence of [lower, upper] pairs
        Bounds for each state dimension, in the same order as the columns.
    eps : float, optional
        Margin kept away from each bound; a state must satisfy
        ``lower + eps < value < upper - eps`` in every checked dimension.

    Returns
    -------
    ndarray of bool, shape (n,)
    """
    # Builtin ``bool``: the ``np.bool`` alias was deprecated in NumPy 1.20
    # and removed in 1.24.
    constraint = np.ones(len(states), dtype=bool)
    for dimension, bound in zip(states.T, domain):
        constraint &= (bound[0] + eps < dimension) & (dimension < bound[1] - eps)
    return constraint

def get_value_function(states, actions, vertex_values):
    """Perform one Bellman backup of the value function on the grid vertices.

    The successor of every state is computed with ``dynamics`` and its value
    is obtained from ``delaunay.function_values_at`` using ``vertex_values``.
    Successors that fail ``is_in_domain`` keep value 0, and terminal states
    are assigned value 1.

    Assumes ``states`` has one row per entry of ``vertex_values`` (the
    boolean masks over states are used to index the returned array).

    Returns
    -------
    ndarray
        Array shaped like ``vertex_values`` holding the updated values.
    """
    successors = dynamics(states, actions)
    step_rewards = reward_function(states, actions, successors)

    inside = is_in_domain(successors, domain)
    interpolated = delaunay.function_values_at(successors[inside],
                                               vertex_values=vertex_values)

    new_values = np.zeros_like(vertex_values)
    new_values[inside] = step_rewards[inside] + gamma * interpolated
    # Applied last so terminal states end up with value 1 regardless of the above.
    new_values[is_terminal(states)] = 1

    return new_values
    
def optimize_policy(states, vertex_values):
    """Return the greedy action for every state with respect to ``vertex_values``.

    Each action in ``action_space`` is applied to all states, the resulting
    one-step values from ``get_value_function`` are compared, and the action
    with the highest value is chosen per state. Terminal states are always
    assigned ``action_space[1]``.
    """
    # ``float`` instead of ``np.float``: the alias was removed in NumPy 1.24.
    values = np.zeros((len(vertex_values), len(action_space)), dtype=float)
    # One action per state as a column, sized from ``states`` rather than the
    # global triangulation so the function does not depend on ``delaunay``.
    actions = np.ones((len(states), 1), dtype=float)
    for i, action in enumerate(action_space):
        actions[:] = action
        values[:, i] = get_value_function(states, actions, vertex_values)

    best_actions = np.argmax(values, axis=1)
    best_actions[is_terminal(states)] = 1
    return action_space[best_actions]

def value_iteration(policy, values):
    """Run one iteration: evaluate ``policy`` for one step, then act greedily.

    Returns the greedy policy and the backed-up values, in that order.
    """
    backed_up = get_value_function(state_space, policy, values)
    greedy = optimize_policy(state_space, backed_up)
    return greedy, backed_up
    
# def optimize_value_function(states, actions, vertex_values):
#     next_states = dynamics(states, actions)
#     rewards = reward_function(states, actions, next_states)
#     gamma = 0.9
    
#     a, b = next_states[:, 0], next_states[:, 1]
#     legal_id = (a > domain[0][0]) & (a < domain[0][1]) & (b > domain[1][0]) & (b < domain[1][1])
#     terminal_id = is_terminal(states)
#     legal_id = np.logical_and(legal_id, ~terminal_id)
#     outside_id = np.logical_and(~legal_id, ~terminal_id)
    
    
#     values = cvxpy.Variable(delaunay.nindex)
#     gamma = cvxpy.Parameter(sign='positive')
    
#     h = delaunay.function_values_at(next_states[legal_id]).toarray()
    
#     constant = np.sum(h[:, terminal_id], axis=1)
#     expected_values = constant + h[:, legal_id] * values[legal_id]

#     objective = cvxpy.Maximize(cvxpy.sum_entries(values))
#     constraints = [values[legal_id] <= rewards[legal_id] + gamma * expected_values,
#                    values[terminal_id] == 1,
#                    values[outside_id] == 0]

    
#     prob = cvxpy.Problem(objective, constraints)
#     gamma.value = 0.98
#     prob.solve()
    
#     if not prob.status == cvxpy.OPTIMAL:
#         raise ValueError('Optimization problem is {}'.format(prob.status))
    
#     return np.asarray(values.value).squeeze()
    

In [None]:
MAX_ITERATIONS = 1000
VALUE_TOLERANCE = 1e-4

# Start from all-zero values and an all-zero placeholder policy. Using
# action_space's dtype (instead of np.zeros_like of an earlier policy) keeps
# this cell idempotent and lets the policy arrays hold -1/1 actions.
values = np.zeros(delaunay.nindex)
policy = np.zeros(delaunay.nindex, dtype=action_space.dtype)

old_values = values.copy()
old_policy = policy.copy()

converged = False

for i in tqdm(range(MAX_ITERATIONS)):
    policy, values = value_iteration(policy, values)

    value_change = np.max(np.abs(old_values - values))
    # Only reported below; convergence is decided on the values alone.
    policy_converged = np.all(old_policy == policy)
    if value_change <= VALUE_TOLERANCE:
        converged = True
        break
    # Rebind to fresh copies instead of assigning in place: an in-place
    # assignment into an array of a different dtype (e.g. bool) would silently
    # cast the -1/1 actions and break the comparison above.
    old_values = values.copy()
    old_policy = policy.copy()

if converged:
    print('converged after {} iterations'.format(i + 1))
else:
    print('didnt converge, error: {} and policy: {}'.format(value_change, policy_converged))

In [None]:
# Lay the vertex values out on a (position, velocity) grid, transpose so rows
# are velocities, and flip so the highest velocity is drawn at the top.
# Assumes vertices are ordered position-major by delaunay — TODO confirm.
vals = values.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]

image = plt.imshow(vals, origin='upper',
                   extent=domain[0] + domain[1],
                   aspect='auto', cmap='viridis')
plt.xlabel('pos')
plt.ylabel('vel')
plt.title('Value function')
plt.colorbar(image);

In [None]:
# Greedy action at each vertex on the (position, velocity) grid, with
# velocity increasing upwards. Assumes position-major vertex ordering — TODO confirm.
acts = policy.reshape(n_points[0] + 1, n_points[1] + 1).T[::-1]
image = plt.imshow(acts, origin='upper', extent=domain[0] + domain[1], aspect='auto')
plt.xlabel('pos')
plt.ylabel('vel')
plt.title('Policy (action)')
plt.colorbar(image);

In [None]:
N_SIMULATION_STEPS = 1000

# ``float`` instead of ``np.float``: the alias was removed in NumPy 1.24.
states = np.zeros((N_SIMULATION_STEPS, 2), dtype=float)
states[0, 0] = -0.5  # start at position -0.5 with zero velocity

for i in range(len(states) - 1):
    # Interpolate the vertex policy at the current state, then snap the
    # result to the nearest admissible action.
    action = delaunay.function_values_at(states[[i], :], vertex_values=policy)
    action = action_space[np.argmin(np.abs(action - action_space))]

    states[i + 1, :] = dynamics(states[i, :], action)
    # Reuse the notebook's goal test instead of duplicating the threshold.
    if is_terminal(states[[i + 1], :])[0]:
        # Goal reached: hold the final state for the rest of the trajectory.
        states[i + 1:, :] = states[i + 1]
        break


In [None]:
# Phase-space trajectory of the closed-loop simulation.
plt.plot(states[:, 0], states[:, 1])
plt.xlabel('pos')
plt.ylabel('vel')
plt.title('Simulated trajectory under the learned policy')
plt.xlim(-1.2, 0.6)
plt.ylim(-0.07, 0.07);