# Intelligent Intersection

In [None]:
## Install dependencies for the project ##
# NOTE(review): these are shell commands. Inside a Jupyter code cell they need a
# leading `!` (or `%pip` for the pip lines) unless this export stripped it.
# ffmpeg is presumably needed by the HTML5-video / moviepy exports further down.
sudo apt install ffmpeg
pip install -U pip
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -e ./hj_reachability
pip install -U numpy matplotlib plotly kaleido tqdm moviepy scikit-image ipywidgets nbformat ipympl

In [None]:
%matplotlib ipympl

from functools import partial
from time import time

from tqdm import tqdm

import jax
import jax.numpy as jnp
import numpy as np
import numpy.lib.stride_tricks as st

from IPython.display import HTML
from ipywidgets import interact, IntSlider
import matplotlib.animation as anim
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import skimage.io as sio

import hj_reachability as hj
import hj_reachability.shapes as shp

### System

In [None]:
# Two copies of the same SVEA5D model with identical input bounds: one in
# 'reach' mode and one in 'avoid' mode. brs() defaults to 'reach', frs() to
# 'avoid'; any other mode string selects avoid_dynamics.
reach_dynamics = hj.systems.SVEA5D(min_steer=-jnp.pi, 
                                   max_steer=+jnp.pi,
                                   min_accel=-0.5,
                                   max_accel=+0.5).with_mode('reach')
avoid_dynamics = hj.systems.SVEA5D(min_steer=-jnp.pi, 
                                   max_steer=+jnp.pi,
                                   min_accel=-0.5,
                                   max_accel=+0.5).with_mode('avoid')

# State bounds, ordered as indexed by X, Y, YAW, DELTA, VEL in the next cell.
min_bounds = np.array([-1.2, -1.2, -np.pi, -np.pi/5, +0])
max_bounds = np.array([+1.2, +1.2, +np.pi, +np.pi/5, +1])
# 31 x 31 x 31 x 7 x 11 grid over the box above; dimension 2 (yaw) is periodic.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(min_bounds, max_bounds),
                                                               (31, 31, 31, 7, 11),
                                                               periodic_dims=2)

solver_settings = hj.SolverSettings.with_accuracy("high")

# Heuristic margin subtracted from other vehicles' value functions. Since a
# state is inside a set where its value is <= 0, subtracting enlarges the
# region this vehicle has to avoid.
AVOID_MARGIN = 0.1

In [None]:
# Axis indices into the state (and grid) dimensions.
X, Y, YAW, DELTA, VEL = range(5)

# Heading sectors of +/- pi/5 around each compass direction (east is yaw ~ 0,
# north yaw ~ +pi/2). As everywhere in this notebook, a state belongs to a set
# where the set's value is <= 0.
north = shp.intersection(shp.upper_half_space(grid, YAW, +np.pi/2 - np.pi/5),
                         shp.lower_half_space(grid, YAW, +np.pi/2 + np.pi/5))
south = shp.intersection(shp.upper_half_space(grid, YAW, -np.pi/2 - np.pi/5),
                         shp.lower_half_space(grid, YAW, -np.pi/2 + np.pi/5))
east = shp.intersection(shp.upper_half_space(grid, YAW, - np.pi/5),
                        shp.lower_half_space(grid, YAW, + np.pi/5))
# West straddles the +/-pi wrap-around, hence a union instead of an intersection.
west = shp.union(shp.upper_half_space(grid, YAW, +np.pi - np.pi/5),
                 shp.lower_half_space(grid, YAW, -np.pi + np.pi/5))

# Minimum-speed requirements: "slow" on the roads, "fast" inside the intersection.
slow_lower_speed_limit = shp.upper_half_space(grid, VEL, +0.4)
fast_lower_speed_limit = shp.upper_half_space(grid, VEL, +0.7)

# Roads
# r1: x <= 0 with 0.2 <= y <= 1.1; r2: y <= 0.5 with |x| <= 0.5;
# r3: x >= 0 with 0.2 <= y <= 1.1.
r1 = shp.intersection(slow_lower_speed_limit,
                      shp.lower_half_space(grid, X, +0.0),
                      shp.upper_half_space(grid, Y, +0.2),
                      shp.lower_half_space(grid, Y, +1.1))
r2 = shp.intersection(slow_lower_speed_limit,
                      shp.lower_half_space(grid, Y, +0.5),
                      shp.upper_half_space(grid, X, -0.5),
                      shp.lower_half_space(grid, X, +0.5))
r3 = shp.intersection(slow_lower_speed_limit,
                      shp.upper_half_space(grid, X, +0.0),
                      shp.upper_half_space(grid, Y, +0.2),
                      shp.lower_half_space(grid, Y, +1.1))

# Ports (in & out of region)
# Each port is one lane of a road: the road restricted to a heading sector and
# to one side of its centre line (e.g. p1w: road 1, heading west, y >= 0.6).
p1w = shp.intersection(r1,  west,   shp.upper_half_space(grid, Y, +0.6))
p1e = shp.intersection(r1,  east,   shp.lower_half_space(grid, Y, +0.6))
p2n = shp.intersection(r2,  north,  shp.upper_half_space(grid, X, 0.0))
p2s = shp.intersection(r2,  south,  shp.lower_half_space(grid, X, 0.0))
p3w = shp.intersection(r3,  west,   shp.upper_half_space(grid, Y, +0.6))
p3e = shp.intersection(r3,  east,   shp.lower_half_space(grid, Y, +0.6))

p1 = (p1w, p1e)
p2 = (p2n, p2s)
p3 = (p3w, p3e)

# intersections 
# Central box |x| <= 0.5, 0.2 <= y <= 1.1 (all headings/steering/speeds allowed
# by the box) with the "fast" minimum speed.
i1 = shp.intersection(fast_lower_speed_limit,
                      shp.rectangle(grid,
                                    target_min=[-0.5, +0.2, *min_bounds[Y+1:]],
                                    target_max=[+0.5, +1.1, *max_bounds[Y+1:]]))

# Regions
rg1 = shp.union(i1, *p1, *p2, *p3)

# Specific routes
# Named <entry lane>_<exit lane>: the union of both lanes and the intersection box.
p2n_p1w = shp.union(p2n, i1, p1w)
p3w_p2s = shp.union(p3w, i1, p2s)
p1e_p3e = shp.union(p1e, i1, p3e)

# Entry and Exit targets
# Rectangles in (x, y) spanning the full range of every other state.
entry1  = dict(target_min=[-1.0, +0.2, *min_bounds[Y+1:]],
               target_max=[-0.7, +0.5, *max_bounds[Y+1:]])
entry2  = dict(target_min=[+0.0, -1.2, *min_bounds[Y+1:]],
               target_max=[+0.5, -0.7, *max_bounds[Y+1:]])
entry3  = dict(target_min=[+0.7, +0.6, *min_bounds[Y+1:]],
               target_max=[+1.1, +1.0, *max_bounds[Y+1:]])
exit1   = dict(target_min=[-1.0, +0.7, *min_bounds[Y+1:]],
               target_max=[-0.7, +1.0, *max_bounds[Y+1:]])
exit2   = dict(target_min=[-0.4, -1.0, *min_bounds[Y+1:]],
               target_max=[-0.1, -0.7, *max_bounds[Y+1:]])
exit3   = dict(target_min=[+0.7, +0.2, *min_bounds[Y+1:]],
               target_max=[+1.0, +0.5, *max_bounds[Y+1:]])

start1, end1 = shp.rectangle(grid, **entry1), shp.rectangle(grid, **exit1)
start2, end2 = shp.rectangle(grid, **entry2), shp.rectangle(grid, **exit2)
start3, end3 = shp.rectangle(grid, **entry3), shp.rectangle(grid, **exit3)

def find_windows(mask, N=1, M=None):
    """Find the start indices of windows where at least N but fewer than M
    consecutive elements are true.

    Args:
        mask: 1-D boolean sequence.
        N: minimum number of consecutive true elements starting at the index.
        M: optional exclusive upper bound. A start qualifies only if the M
            elements beginning there are not all true. Starts closer than M to
            the end of ``mask`` always qualify, because they cannot hold M
            consecutive trues.

    Returns:
        Sorted integer array of qualifying start indices.
    """
    mask = np.asarray(mask)
    assert N <= len(mask)
    # The length-N window starting at k is all-true iff its sum equals N.
    short_view = st.sliding_window_view(mask, window_shape=N)
    ix, = np.where(np.sum(short_view, axis=1) == N)
    if M is not None:
        assert M <= len(mask)
        assert N < M
        long_view = st.sliding_window_view(mask, window_shape=M)
        # Only starts with a complete M-window can be tested against the upper
        # bound; the remaining (tail) starts satisfy it trivially.
        has_full = ix + M <= len(mask)
        tested = ix[has_full]
        below_max = np.sum(long_view[tested], axis=1) < M
        # `tested` is a prefix of `ix` and the tail starts follow it, so the
        # concatenation stays sorted.
        ix = np.concatenate([tested[below_max], ix[~has_full]])
    return ix

def earliest_window(mask, N=1, M=None):
    """Find the first window where at least N but fewer than M consecutive elements are true.

    Starting at the first qualifying window start (see ``find_windows``), the
    run is extended while the mask stays true. When M is given, it is capped at
    M elements.

    Returns:
        Integer array of the indices in that run, or an empty integer array
        when no window qualifies.
    """
    mask = np.asarray(mask)
    starts = find_windows(mask, N, M)
    if not len(starts):
        return np.array([], int)
    first = starts[0]
    segment = mask[first:] if M is None else mask[first:first + M]
    # While the prefix is all true, the running count equals position + 1.
    # The first deviation at or after position N-1 marks the end of the run.
    counts = segment.cumsum()
    expected = np.arange(1, len(segment) + 1)
    deviations = np.flatnonzero(counts[N - 1:] != expected[N - 1:])
    length = counts[N - 1 + deviations[0]] if deviations.size else counts[-1]
    return np.arange(first, first + length)

def new_timeline(target_time, start_time=0, time_step=0.2):
    """Return evenly spaced time stamps from ``start_time`` to ``target_time``.

    ``target_time`` is included (up to floating-point slack). When
    ``target_time < start_time`` the timeline runs backwards with the same
    step magnitude.
    """
    assert time_step > 0
    direction = 1 if target_time >= start_time else -1
    # Nudge the stop value slightly past the target so np.arange includes it.
    stop = target_time + direction * 1e-5
    return np.arange(start_time, stop, direction * time_step)

def brs(times, target, constraints=None, *, mode='reach'):
    """Compute a backward reachable set of ``target`` over ``times``.

    The solver is run on the negated time stamps. Every input that is not
    time-invariant (per ``shp.is_invariant``) is reversed along axis 0 before
    solving. The solver output is reversed along axis 0 afterwards, presumably
    so that it follows the same ordering as ``times`` (confirm against hj.solve).

    Args:
        times: 1-D time stamps.
        target: target value function, time-invariant or stacked along axis 0.
        constraints: optional constraint value function, same conventions.
        mode: ``'reach'`` selects ``reach_dynamics``, anything else
            ``avoid_dynamics``.

    Returns:
        NumPy array with the solver output reversed along axis 0.
    """
    dynamics = avoid_dynamics if mode != 'reach' else reach_dynamics
    solve_times = -np.asarray(times)
    if not shp.is_invariant(grid, solve_times, target):
        target = np.flip(target, axis=0)
    if not shp.is_invariant(grid, solve_times, constraints):
        constraints = np.flip(constraints, axis=0)
    raw = hj.solve(solver_settings, dynamics, grid, solve_times, target, constraints)
    return np.flip(np.asarray(raw), axis=0)

def frs(times, target, constraints=None, *, mode='avoid'):
    """Compute a forward reachable set starting from ``target`` over ``times``.

    Unlike ``brs``, the time stamps are handed to the solver unchanged and the
    result is not reordered.

    Args:
        times: 1-D time stamps.
        target: initial set as a value function.
        constraints: optional constraint value function.
        mode: ``'reach'`` selects ``reach_dynamics``, anything else
            ``avoid_dynamics`` (the default).

    Returns:
        The solver output as a NumPy array.
    """
    solve_times = np.asarray(times)
    dynamics = avoid_dynamics if mode != 'reach' else reach_dynamics
    return np.asarray(hj.solve(solver_settings, dynamics, grid,
                               solve_times, target, constraints))


### Save/Load data

In [None]:
## Save results to file ##
# Flip an `if False:` to `if True:` to export the corresponding results.
# NOTE(review): these write to the working directory, whereas the load cell
# reads from an `{order}/` sub-directory; move the files or align the paths.

if False: 
    np.savez_compressed('v1.npz',
                        v1_pass1=v1_pass1,
                        v1_pass2=v1_pass2,
                        v1_pass3=v1_pass3,
                        v1_pass4=v1_pass4)

if False:
    np.savez_compressed('v2.npz',
                        v2_pass1=v2_pass1,
                        v2_pass2=v2_pass2,
                        v2_pass3=v2_pass3,
                        v2_pass4=v2_pass4)

if False:
    np.savez_compressed('v3.npz',
                        v3_pass1=v3_pass1,
                        v3_pass2=v3_pass2,
                        v3_pass3=v3_pass3,
                        v3_pass4=v3_pass4)

if False:
    # Convert an .npz archive into a MATLAB .mat file (numpy is already imported above).
    import numpy as np 
    from scipy.io import savemat
    npfile = np.load('arrays.npz')
    savemat('arrays.mat', npfile)

In [None]:
if True: ## Load results from file ##
    # `order` names the sub-directory to load from (presumably the order in
    # which the vehicles were planned -- confirm). Only pass 4 is loaded.
    order = '123'
    with np.load(f'{order}/v1.npz') as npfile:
        # v1_pass1 = npfile['v1_pass1']
        # v1_pass2 = npfile['v1_pass2']
        # v1_pass3 = npfile['v1_pass3']
        v1_pass4 = npfile['v1_pass4']

    with np.load(f'{order}/v2.npz') as npfile:
        # v2_pass1 = npfile['v2_pass1']
        # v2_pass2 = npfile['v2_pass2']
        # v2_pass3 = npfile['v2_pass3']
        v2_pass4 = npfile['v2_pass4']

    with np.load(f'{order}/v3.npz') as npfile:
        # v3_pass1 = npfile['v3_pass1']
        # v3_pass2 = npfile['v3_pass2']
        # v3_pass3 = npfile['v3_pass3']
        v3_pass4 = npfile['v3_pass4']

## Scenario

In [None]:
def run_analysis(*passes, **kwargs):
    """Run the multi-pass reachability analysis for one vehicle.

    The passes build on each other:

    * ``pass1`` -- BRS towards ``end`` constrained to the road ``rules``.
    * ``pass2`` -- ``pass1`` recomputed up to just past the last step at which
      it overlaps the vehicles in ``avoid`` (enlarged by ``AVOID_MARGIN``), with
      those vehicles removed from the constraints.
    * ``pass3`` -- forward set from the earliest departure window inside
      ``start``, constrained to ``pass2``.
    * ``pass4`` -- backward set from the earliest arrival window inside ``end``,
      constrained to ``pass3`` and starting at the departure step.

    Args:
        *passes: any of ``'pass1'`` .. ``'pass4'`` or ``'all'``; no names means
            every pass. Unknown names are ignored. A requested pass also runs
            its predecessor unless that predecessor's result is supplied in
            ``kwargs`` (this applies transitively).
        **kwargs: inputs needed by the passes that run -- ``times``, ``rules``,
            ``end``, ``avoid`` (sequence of other vehicles' value functions),
            ``start`` -- and optionally precomputed ``passN`` arrays.

    Returns:
        List with the result of every requested pass, in pass order.

    Raises:
        AssertionError: if no departure or arrival window exists.
        KeyError: if an input needed by a pass that runs is missing.
    """
    ALL_PASSES = ['pass1', 'pass2', 'pass3', 'pass4']
    requested = passes or ALL_PASSES
    if 'all' in requested:
        requested = ALL_PASSES
    returns = [name for name in ALL_PASSES if name in requested]

    # Pull in predecessors transitively (pass4 -> pass3 -> ...), stopping at any
    # pass whose result was supplied through kwargs. Walking backwards avoids
    # growing the list that is being iterated.
    to_run = set(returns)
    for idx in range(len(ALL_PASSES) - 1, 0, -1):
        name, prev = ALL_PASSES[idx], ALL_PASSES[idx - 1]
        if name in to_run and prev not in to_run and prev not in kwargs:
            to_run.add(prev)
    passes = [name for name in ALL_PASSES if name in to_run]

    msg = 'Running analysis with the following passes: '
    msg += ', '.join(passes) + '\n'
    msg += '-' * (len(msg)-1)
    print(msg)

    def to_shared(values, **proj_kwargs):
        # Project onto the axes shared between vehicles (0, 1, 2).
        return shp.project_onto(values, 0, 1, 2, **proj_kwargs)

    def when_overlapping(a, b):
        # Time indices at which the sets a and b intersect.
        return np.where(shp.project_onto(shp.intersection(a, b), 0) <= 0)[0]

    if 'pass1' in passes:
        print('\n', 'Pass 1: Initial BRS')
        times = kwargs['times']
        rules = kwargs['rules']
        end = kwargs['end']

        target = shp.make_tube(times, end)
        constraints = rules

        start_time = time()
        pass1 = brs(times, target, constraints)
        stop_time = time()

        print(f'Time To Compute: {stop_time - start_time:.02f}')

        kwargs['pass1'] = pass1

    if 'pass2' in passes:
        print('\n', 'Pass 2: Avoidance')
        pass2 = kwargs['pass1'].copy()
        times = kwargs['times']
        avoid = kwargs['avoid']
        end = kwargs['end']

        if avoid:
            avoid_target = shp.union(
                *(to_shared(target, keepdims=True) for target in avoid),
            ) - AVOID_MARGIN # increase avoid set with heuristic margin

            interact_window = when_overlapping(pass2, avoid_target)
            if interact_window.size > 0:
                first_hit, stop = interact_window[0], interact_window[-1] + 1

                # Recompute solution until after last interaction
                constraints = shp.setminus(pass2, avoid_target)[:stop+1]
                # Copy so that rewriting the target does not alias pass2.
                target = pass2[:stop+1].copy()  # last step is target to reach pass1
                target[:-1] = end               # all other steps are recomputed to end

                start_time = time()
                pass2[:stop+1] = brs(times[:stop+1], target, constraints)
                stop_time = time()

                print(f'Time To Compute: {stop_time - start_time:.02f}')
                print(f'First Interaction: {times[first_hit]:.01f}')
                print(f'Last Interaction: {times[stop-1]:.01f}')

        kwargs['pass2'] = pass2

    depart_idx = None  # earliest departure step, set by pass 3 when it runs
    if 'pass3' in passes:
        print('\n', 'Pass 3: Planning, Departure')
        pass3 = kwargs['pass2'].copy()
        times = kwargs['times']
        start = kwargs['start']

        depart_target = shp.intersection(pass3, start)
        depart_window = earliest_window(shp.project_onto(depart_target, 0) <= 0, 5)
        assert depart_window.size > 0, 'No window to depart'
        depart_idx = depart_window[0] # Earliest departure
        depart_stop = depart_window[min(11, len(depart_window)-1)] # Depart within 2 seconds of earliest departure
        depart_target[depart_stop:] = 1

        start_time = time()
        pass3[depart_idx:] = frs(times[depart_idx:], depart_target[depart_idx:], pass3[depart_idx:])
        pass3[:depart_idx] = 1 # Remove all values before departure
        stop_time = time()

        print(f'Time To Compute: {stop_time - start_time:.02f}')
        print(f'Earliest Departure: {times[depart_idx]:.01f}')
        print(f'Latest Departure: {times[depart_stop-1]:.01f}')

        kwargs['pass3'] = pass3

    if 'pass4' in passes:
        print('\n', 'Pass 4: Planning, Arrival')
        pass4 = kwargs['pass3'].copy()
        times = kwargs['times']
        end = kwargs['end']

        if depart_idx is None:
            # pass3 was supplied precomputed. Pass 3 clears every step before
            # the departure, so the first non-empty step is presumably the
            # departure step.
            occupied, = np.where(shp.project_onto(pass4, 0) <= 0)
            assert occupied.size > 0, 'No departure found in pass3'
            depart_idx = occupied[0]

        arrival_target = shp.intersection(pass4, end)
        arrival_window = earliest_window(shp.project_onto(arrival_target, 0) <= 0, N=5)
        assert arrival_window.size > 0, 'No window to arrive'
        m = arrival_window[0] # Earliest arrival
        n = arrival_window[min(11, len(arrival_window)-1)] # Arrive within 2 seconds of earliest arrival
        arrival_target[n:] = 1

        start_time = time()
        pass4[depart_idx:n] = brs(times[depart_idx:n], arrival_target[depart_idx:n], pass4[depart_idx:n])
        pass4[n:] = 1 # Remove all values after arrival
        stop_time = time()

        print(f'Time To Compute: {stop_time - start_time:.02f}')
        print(f'Earliest Arrival: {times[m]:.01f}')
        print(f'Latest Arrival: {times[n-1]:.01f}')

        kwargs['pass4'] = pass4

    return [kwargs[name] for name in returns]

def run_many(*objectives, **kwargs):
    """Plan several vehicles one after another, each avoiding the earlier ones.

    Each objective dict (``rules``, ``start``, ``end``) is combined with the
    shared ``kwargs`` and analysed through pass 4. The pass-4 results collected
    so far are passed as ``avoid``.

    Returns:
        List of pass-4 value functions, one per objective.
    """
    results = []
    for objective in objectives:
        # `results` is passed by reference, i.e. it holds every earlier vehicle.
        results.extend(run_analysis('pass4', **objective, **kwargs, avoid=results))
        print('\n')
    return results

### All Vehicles

In [None]:
# Plan all three vehicles in sequence; each one avoids the pass-4 results of
# the vehicles planned before it.
(
    v1_pass4,
    v2_pass4,
    v3_pass4,
) = run_many(
    dict(rules=p3w_p2s, start=start3, end=end2),
    dict(rules=p2n_p1w, start=start2, end=end1),
    dict(rules=p1e_p3e, start=start1, end=end3),
    times=new_timeline(10),
)

### Vehicle 1

In [None]:
## Objective ##
# Vehicle 1: entry 3 -> exit 2 along route p3w_p2s, nothing to avoid.
# These globals are also read by the "Manual" cells that follow.
times = new_timeline(10)
start = start3
end = end2
rule = p3w_p2s
avoid = ()

## Run analysis ##
# NOTE(review): the call repeats the objective as literals instead of using the
# variables above; keep the two in sync.
(
    v1_pass1,
    v1_pass2,
    v1_pass3,
    v1_pass4,
) = run_analysis('all',
                 times=new_timeline(10), 
                 rules=p3w_p2s, avoid=[],
                 start=start3, end=end2)

#### Manual

In [None]:
print('Pass 1: Initial BRS')

# BRS towards the `end` tube under the road `rule` (objective globals).
v1_pass1 = brs(times, shp.make_tube(times, end), rule)

print()

In [None]:
print('Pass 2: Avoidance')

if not avoid:
    v1_pass2 = v1_pass1
else:
    # Avoid-mode BRS of the other vehicles' sets (projected onto axes 0-2),
    # constrained to this vehicle's pass-1 result.
    v1_avoid = brs(times, 
                target=shp.union(
                    *map(lambda target: shp.project_onto(target, 0, 1, 2, keepdims=True),
                         avoid),
                ),
                constraints=v1_pass1, 
                mode='avoid')

    # Recompute the BRS to `end` with the margin-enlarged avoid set removed.
    v1_pass2 = brs(times, shp.make_tube(times, end), shp.setminus(v1_pass1, v1_avoid - AVOID_MARGIN))

print()

In [None]:
print('Pass 3: Planning, Departure')

# Earliest run of >= 5 consecutive steps at which pass 2 overlaps the start region.
depart_target = hj.shapes.intersection(v1_pass2, start)
depart_window = earliest_window(shp.project_onto(depart_target, 0) <= 0, 5)
assert depart_window.size > 0, 'No window to depart'
i = depart_window[0] # Earliest departure
j = depart_window[min(11, len(depart_window)-1)] # Depart within 2 seconds of earliest departure
depart_target[j:] = 1  # no departures from step j on

# Forward set from the departure window, constrained to pass 2; steps before
# the departure stay at 1, i.e. empty.
v1_pass3 = np.ones_like(v1_pass2)
v1_pass3[i:] = frs(times[i:], depart_target[i:], v1_pass2[i:])

print(f'Earliest Departure: {times[i]:.01f}')
print(f'Latest Departure: {times[j-1]:.01f}')
print()

In [None]:
print('Pass 4: Planning, Arrival')

# Earliest run of >= 5 consecutive steps at which pass 3 reaches the end region.
# `i` (earliest departure) comes from the Pass 3 cell.
arrival_target = hj.shapes.intersection(v1_pass3, end)
arrival_window = earliest_window(shp.project_onto(arrival_target, 0) <= 0, N=5)
assert arrival_window.size > 0, 'No window to arrive'
m = arrival_window[0] # Earliest arrival
# Cap the arrival slack using the arrival window itself (not the departure window).
n = arrival_window[min(11, len(arrival_window)-1)] # Arrive within 2 seconds of earliest arrival
arrival_target[n:] = 1

# Backward set from the arrival window between departure and arrival; all
# other steps stay at 1, i.e. empty.
v1_pass4 = np.ones_like(v1_pass3)
v1_pass4[i:n] = brs(times[i:n], arrival_target[i:n], v1_pass3[i:n])

# m is the first index of the arrival window (as reported by run_analysis).
print(f'Earliest Arrival: {times[m]:.01f}')
print(f'Latest Arrival: {times[n-1]:.01f}')
print()

### Vehicle 2

In [None]:
## Objective ##
# Vehicle 2: entry 2 -> exit 1 along route p2n_p1w, avoiding vehicles 1 and 3.
# The "Manual" cells below read these globals, so they must describe vehicle 2
# rather than whichever objective was set last.
times = new_timeline(10)
start = start2
end = end1
rule = p2n_p1w
avoid = [v1_pass4, v3_pass4]

## Run analysis ##
(
    v2_pass1,
    v2_pass2,
    v2_pass3,
    v2_pass4,
) = run_analysis('all',
                 times=times,
                 rules=rule, avoid=avoid,
                 start=start, end=end)

#### Manual

In [None]:
print('Pass 1: Initial BRS')

# BRS towards the `end` tube under the road `rule`.
# NOTE(review): `times`, `end` and `rule` are the objective globals from the
# objective cell that ran last -- make sure they describe vehicle 2.
v2_pass1 = brs(times, shp.make_tube(times, end), rule)

print()

In [None]:
print('Pass 2: Avoidance')

# NOTE(review): `avoid`, `times` and `end` are objective globals -- make sure
# the cell that set them last describes vehicle 2.
if not avoid:
    v2_pass2 = v2_pass1
else:
    # Avoid-mode BRS of the other vehicles' sets (projected onto axes 0-2),
    # constrained to this vehicle's pass-1 result.
    v2_avoid = brs(times, 
                target=shp.union(
                    *map(lambda target: shp.project_onto(target, 0, 1, 2, keepdims=True),
                         avoid),
                ),
                constraints=v2_pass1, 
                mode='avoid')

    # Recompute the BRS to `end` with the margin-enlarged avoid set removed.
    v2_pass2 = brs(times, shp.make_tube(times, end), shp.setminus(v2_pass1, v2_avoid - AVOID_MARGIN))

print()

In [None]:
print('Pass 3: Planning, Departure')

# Earliest run of >= 5 consecutive steps at which pass 2 overlaps the start region.
# NOTE(review): `start` and `times` are objective globals -- check they are vehicle 2's.
depart_target = hj.shapes.intersection(v2_pass2, start)
depart_window = earliest_window(shp.project_onto(depart_target, 0) <= 0, 5)
assert depart_window.size > 0, 'No window to depart'
i = depart_window[0] # Earliest departure
j = depart_window[min(11, len(depart_window)-1)] # Depart within 2 seconds of earliest departure
depart_target[j:] = 1  # no departures from step j on

# Forward set from the departure window, constrained to pass 2; steps before
# the departure stay at 1, i.e. empty.
v2_pass3 = np.ones_like(v2_pass2)
v2_pass3[i:] = frs(times[i:], depart_target[i:], v2_pass2[i:])

print(f'Earliest Departure: {times[i]:.01f}')
print(f'Latest Departure: {times[j-1]:.01f}')
print()

In [None]:
print('Pass 4: Planning, Arrival')

# Earliest run of >= 5 consecutive steps at which pass 3 reaches the end region.
# `i` (earliest departure) comes from the Pass 3 cell.
arrival_target = hj.shapes.intersection(v2_pass3, end)
arrival_window = earliest_window(shp.project_onto(arrival_target, 0) <= 0, N=5)
assert arrival_window.size > 0, 'No window to arrive'
m = arrival_window[0] # Earliest arrival
# Cap the arrival slack using the arrival window itself (not the departure window).
n = arrival_window[min(11, len(arrival_window)-1)] # Arrive within 2 seconds of earliest arrival
arrival_target[n:] = 1

# Backward set from the arrival window between departure and arrival; all
# other steps stay at 1, i.e. empty.
v2_pass4 = np.ones_like(v2_pass3)
v2_pass4[i:n] = brs(times[i:n], arrival_target[i:n], v2_pass3[i:n])

# m is the first index of the arrival window (as reported by run_analysis).
print(f'Earliest Arrival: {times[m]:.01f}')
print(f'Latest Arrival: {times[n-1]:.01f}')
print()

### Vehicle 3

In [None]:
## Objective ##
# Vehicle 3: entry 1 -> exit 3 along route p1e_p3e, avoiding vehicles 1 and 2.
# The "Manual" cells below read these globals, so they must describe vehicle 3
# rather than whichever objective was set last.
times = new_timeline(10)
start = start1
end = end3
rule = p1e_p3e
avoid = [v1_pass4, v2_pass4]

## Run analysis ##
(
    v3_pass1,
    v3_pass2,
    v3_pass3,
    v3_pass4,
) = run_analysis('all',
                 times=times,
                 rules=rule, avoid=avoid,
                 start=start, end=end)

#### Manual

In [None]:
print('Pass 1: Initial BRS')

# BRS towards the `end` tube under the road `rule`.
# NOTE(review): `times`, `end` and `rule` are the objective globals from the
# objective cell that ran last -- make sure they describe vehicle 3.
v3_pass1 = brs(times, shp.make_tube(times, end), rule)

print()

In [None]:
print('Pass 2: Avoidance')

# NOTE(review): `avoid`, `times` and `end` are objective globals -- make sure
# the cell that set them last describes vehicle 3.
if not avoid:
    v3_pass2 = v3_pass1
else:
    # Avoid-mode BRS of the other vehicles' sets (projected onto axes 0-2),
    # constrained to this vehicle's pass-1 result.
    v3_avoid = brs(times, 
                target=shp.union(
                    *map(lambda target: shp.project_onto(target, 0, 1, 2, keepdims=True),
                         avoid),
                ),
                constraints=v3_pass1, 
                mode='avoid')

    # Recompute the BRS to `end` with the margin-enlarged avoid set removed.
    v3_pass2 = brs(times, shp.make_tube(times, end), shp.setminus(v3_pass1, v3_avoid - AVOID_MARGIN))

print()

In [None]:
print('Pass 3: Planning, Departure')

# Earliest run of >= 5 consecutive steps at which pass 2 overlaps the start region.
# NOTE(review): `start` and `times` are objective globals -- check they are vehicle 3's.
depart_target = hj.shapes.intersection(v3_pass2, start)
depart_window = earliest_window(shp.project_onto(depart_target, 0) <= 0, 5)
assert depart_window.size > 0, 'No window to depart'
i = depart_window[0] # Earliest departure
j = depart_window[min(11, len(depart_window)-1)] # Depart within 2 seconds of earliest departure
depart_target[j:] = 1  # no departures from step j on

# Forward set from the departure window, constrained to pass 2; steps before
# the departure stay at 1, i.e. empty.
v3_pass3 = np.ones_like(v3_pass2)
v3_pass3[i:] = frs(times[i:], depart_target[i:], v3_pass2[i:])

print(f'Earliest Departure: {times[i]:.01f}')
print(f'Latest Departure: {times[j-1]:.01f}')
print()

In [None]:
print('Pass 4: Planning, Arrival')

# Earliest run of >= 5 consecutive steps at which pass 3 reaches the end region.
# `i` (earliest departure) comes from the Pass 3 cell.
arrival_target = hj.shapes.intersection(v3_pass3, end)
arrival_window = earliest_window(shp.project_onto(arrival_target, 0) <= 0, N=5)
assert arrival_window.size > 0, 'No window to arrive'
m = arrival_window[0] # Earliest arrival
# Cap the arrival slack using the arrival window itself (not the departure window).
n = arrival_window[min(11, len(arrival_window)-1)] # Arrive within 2 seconds of earliest arrival
arrival_target[n:] = 1

# Backward set from the arrival window between departure and arrival; all
# other steps stay at 1, i.e. empty.
v3_pass4 = np.ones_like(v3_pass3)
v3_pass4[i:n] = brs(times[i:n], arrival_target[i:n], v3_pass3[i:n])

# m is the first index of the arrival window (as reported by run_analysis).
print(f'Earliest Arrival: {times[m]:.01f}')
print(f'Latest Arrival: {times[n-1]:.01f}')
print()

## Scratchpad

```
Theta_1 := ...
Theta_2 := BRT^M(T, K)
Theta_2 := Theta_2 \ BRT^m(Theta_1, Theta_2) # We can be selective for which time steps we do this
Theta_2 := FRT(Theta_2[0], Theta_2) # We can stop when Theta_2[i] < FRT(.)[i]
```

In [None]:
# Avoid-mode BRS of vehicle 1's final set, constrained to vehicle 2's initial BRS.
r = brs(new_timeline(10),
        target=v1_pass4,
        constraints=v2_pass1,
        mode='avoid')

# interact_tubes takes (values, colorscale, trace_kwargs) triplets and has no
# `opacity` parameter, so the opacity is forwarded to each isosurface instead.
I, F = interact_tubes(
    new_timeline(10),
    (v1_pass4, 'Reds', dict(opacity=0.8)),
    (r,        'Mint', dict(opacity=0.8)),
    eye=EYE_W,
)
F(50).show()

## Plotting

In [None]:
## Code ##

def sphere_to_cartesian(r, theta, phi):
    """Convert spherical coordinates to a plotly-style ``dict(x=..., y=..., z=...)``.

    Args:
        r: radius.
        theta: polar angle from the +z axis, in degrees.
        phi: azimuth from the +x axis towards +y, in degrees.

    Returns:
        Dict with keys ``'x'``, ``'y'`` and ``'z'``.
    """
    # np.deg2rad returns new values; an in-place `*=` would rescale any float
    # array the caller passed in.
    theta = np.deg2rad(theta)
    phi = np.deg2rad(phi)
    return dict(x=r*np.sin(theta)*np.cos(phi),
                y=r*np.sin(theta)*np.sin(phi),
                z=r*np.cos(theta))

def to_rect(kw):
    """Turn an entry/exit target dict into ``xlim``/``ylim`` pairs for ``plot_rect``.

    Only the first two coordinates (x and y) of ``target_min``/``target_max`` are used.
    """
    lo, hi = kw['target_min'], kw['target_max']
    return {'xlim': (lo[0], hi[0]),
            'ylim': (lo[1], hi[1])}

def plot_line(ax, vf, time_idx, **kwargs):
    """Plot the smoothed path of the per-step minimiser of ``vf`` up to ``time_idx``.

    ``vf`` is projected onto axes (0, 1, 2). For every slice along axis 0, the
    grid point with the lowest value is looked up in the first two coordinate
    vectors, or NaN is used when the slice is empty (minimum > 0). Both
    coordinate series are then smoothed with a 5-sample moving average.

    Returns:
        One-element list holding what ``ax.plot`` returned.
    """
    window = 5
    kernel = np.ones(window) / window
    x_coords, y_coords = grid.coordinate_vectors[0], grid.coordinate_vectors[1]
    xs, ys = [], []
    for frame in shp.project_onto(vf, 0, 1, 2):
        if frame.min() <= 0:
            best = np.unravel_index(frame.argmin(), frame.shape)
            xs.append(x_coords[best[0]])
            ys.append(y_coords[best[1]])
        else:
            xs.append(np.nan)
            ys.append(np.nan)
    xs = np.convolve(np.array(xs), kernel, mode='valid')
    ys = np.convolve(np.array(ys), kernel, mode='valid')
    return [ax.plot(xs[:time_idx], ys[:time_idx], **kwargs)]

def plot_arrow(ax, vf, time_idx, **kwargs):
    """Draw a short heading arrow at the minimiser of ``vf[time_idx]``.

    The slice is projected onto axes (0, 1, 2), i.e. x, y and yaw grid
    indices. An arrow of length 0.1 pointing along the yaw angle is drawn only
    when the minimum value is <= 0.

    Returns:
        List with the arrow artist, or an empty list.
    """
    shared = shp.project_onto(vf[time_idx], 0, 1, 2) # project non-shared states to the shared states
    ix, iy, ia = np.unravel_index(shared.argmin(), shared.shape)
    x = grid.coordinate_vectors[0][ix]
    y = grid.coordinate_vectors[1][iy]
    heading = grid.coordinate_vectors[2][ia]
    if not shared[ix, iy, ia] <= 0:
        return []
    return [ax.arrow(x, y, 0.1*np.cos(heading), 0.1*np.sin(heading), **kwargs)]

def plot_set(ax, vf, time_idx, **kwargs):
    """Shade the (x, y) footprint of ``vf[time_idx]`` on ``ax``.

    Cells where the projection onto axes (0, 1) is <= 0 get the constant value
    0.5 (mid-colormap with vmin=0, vmax=1); all others are NaN and so stay
    transparent.

    Returns:
        One-element list holding the ``AxesImage``.
    """
    footprint = shp.project_onto(vf[time_idx], 0, 1)
    shaded = np.where(footprint <= 0, 0.5, np.nan)
    # Transpose: imshow puts the first array axis on the vertical.
    image = ax.imshow(shaded.T, vmin=0, vmax=1, origin='lower', **kwargs)
    return [image]

def plot_rect(ax, xlim, ylim, **kwargs):
    """Outline the box ``xlim`` x ``ylim`` on ``ax`` with a dashed, unfilled rectangle.

    Extra keyword arguments (e.g. ``edgecolor``) go to ``patches.Rectangle``.

    Returns:
        The added ``Rectangle`` patch.
    """
    corner = (xlim[0], ylim[0])
    width = xlim[1] - xlim[0]
    height = ylim[1] - ylim[0]
    outline = patches.Rectangle(corner, width, height, linewidth=2,
                                facecolor='none', linestyle='--', **kwargs)
    ax.add_patch(outline)
    return outline

def make_frames(eye):
    """Build plotly frames that orbit the camera ``eye`` about the z axis.

    The eye's (x, y) position is rotated by -angle for angle = 0, 0.1, ...,
    6.2 rad (roughly one full turn); z is kept.

    Returns:
        List of ``go.Frame`` objects that only update the scene camera.
    """
    origin = eye['x'] + 1j*eye['y']
    frames = []
    for angle in np.arange(0, 6.26, 0.1):
        spun = np.exp(1j*(-angle)) * origin
        camera_eye = dict(x=np.real(spun), y=np.imag(spun), z=eye['z'])
        frames.append(go.Frame(layout=dict(scene=dict(camera=dict(eye=camera_eye)))))
    return frames

def interact_vf(times, values):
    """Build a slider view of one value function as a 3-D surface over (x, y).

    Args:
        times: time stamps for axis 0 of ``values`` (only the length is used).
        values: value function stacked along axis 0. Each slice is projected
            onto its first two axes and drawn as a surface with the zero
            contour outlined.

    Returns:
        ``(interaction, render_frame)``; display with ``interaction(render_frame)``.
    """
    occupied, = np.where(shp.project_onto(values, 0) <= 0)
    # Open the slider in the middle of the steps where the set is non-empty.
    # If it is empty everywhere, np.median would give NaN and int(NaN) would
    # raise, so start at step 0 instead.
    start_idx = int(np.median(occupied)) if occupied.size else 0
    interaction = partial(interact, time_idx=IntSlider(start_idx, min=0, max=len(times)-1))
    def render_frame(time_idx):
        data = [
            go.Surface(
                x=grid.coordinate_vectors[0],
                y=np.flip(grid.coordinate_vectors[1]),
                z=shp.project_onto(values[time_idx], 0, 1),
                showscale=False,
                contours=dict(z=dict(show=True, start=0, end=0.1, size=0.1, color='black')),
            ),
        ]
        eye = dict(x=-0.5, y=-1.4, z=1.6)
        fw = go.FigureWidget(data=data)
        fw.update_layout(width=600, height=600,
                         scene=dict(xaxis_visible=False,
                                    yaxis_visible=False,
                                    zaxis_title='V'),
                         scene_camera=dict(eye=eye))
        fw._config = dict(toImageButtonOptions=dict(height=600, width=600, scale=6))
        return fw
    return interaction, render_frame

def record_levelset(values, target=None):
    """Animate the (x, y) projection of a value function as an HTML5 video.

    Every frame shows filled contours of ``values[i]`` projected onto its first
    two axes. When the slice has negative values, its zero level is also drawn
    in black, plus the zero level of ``target`` in gray if a target is given.

    Args:
        values: value function stacked along axis 0, one slice per frame.
        target: optional target value function, either a single slice shaped
            like ``values[i]`` (time-invariant) or stacked like ``values``.

    Returns:
        IPython ``HTML`` object embedding the animation.
    """
    fig, ax = plt.figure(layout='tight'), plt.gca()

    def style_axes():
        ax.set_xlabel("x [m]")
        ax.set_ylabel("y [m]")
        ax.axis('off')

    style_axes()

    vmin, vmax = values.min(), values.max()
    levels = np.linspace(vmin, vmax, 20)

    # A time-varying target is sliced per frame inside render_frame.
    is_target_invariant = target is None or values.shape[1:] == target.shape

    def render_frame(i, colorbar=False):
        # Start each frame from empty axes; otherwise the contour lines of
        # earlier frames, drawn above the filled contours, remain visible.
        ax.clear()
        style_axes()
        projected = shp.project_onto(values[i, ...], 0, 1).T
        filled = ax.contourf(grid.coordinate_vectors[0],
                             grid.coordinate_vectors[1],
                             projected,
                             vmin=vmin,
                             vmax=vmax,
                             levels=levels)
        if colorbar:
            fig.colorbar(filled, ax=ax)
        if not values[i, ...].min() < 0:
            return
        # `levels` must be a sequence: an int is read as a number of levels,
        # not as the level value to draw.
        ax.contour(grid.coordinate_vectors[0],
                   grid.coordinate_vectors[1],
                   projected,
                   levels=[0],
                   colors="black",
                   linewidths=3)
        if target is None:
            return
        frame_target = target if is_target_invariant else target[i, ...]
        ax.contour(grid.coordinate_vectors[0],
                   grid.coordinate_vectors[1],
                   shp.project_onto(frame_target, 0, 1).T,
                   levels=[0],
                   colors="gray",
                   linewidths=3)

    return HTML(anim.FuncAnimation(fig, render_frame, values.shape[0], interval=200).to_html5_video())

def interact_scenario(times, *pairs):
    """Build a slider view of several value functions drawn over the map image.

    Args:
        times: time stamps; only the length is used (slider range).
        *pairs: ``(cmap_name, values)`` tuples. Each ``values`` is indexed by
            time step along axis 0 and drawn with ``plot_set``.

    Returns:
        ``(interaction, render_frame)``; display with ``interaction(render_frame)``.
    """
    fig, ax = plt.figure(), plt.gca()
    # Map the background image onto the grid's x/y bounds.
    extent=[min_bounds[0], max_bounds[0],
            min_bounds[1], max_bounds[1]]
    ax.set_xlabel("x [m]")
    ax.set_ylabel("y [m]")
    ax.axis('off')
    ax.invert_yaxis()
    background = plt.imread('background.png')
    interaction = partial(interact, time_idx=IntSlider(0, min=0, max=len(times)-1))
    def render_frame(time_idx):
        # NOTE(review): with clear() disabled, every slider move stacks another
        # background and set images on the same axes.
        # ax.clear()
        ax.imshow(background, extent=extent)
        for cmap, vf in pairs:
            cmap = plt.get_cmap(cmap)
            plot_set(ax, vf, time_idx, alpha=0.9, cmap=cmap, extent=extent)
            # plot_arrow(ax, vf, time_idx, color=cmap(0.75))
        fig.tight_layout()
        return fig
    return interaction, render_frame

def record_scenario(times, *pairs):
    """Animate several value functions over the map image as an HTML5 video.

    Args:
        times: time stamps (NumPy array); one frame is rendered per entry.
        *pairs: ``(cmap_name, values)`` tuples drawn with ``plot_set``.

    Returns:
        IPython ``HTML`` object embedding the animation.
    """
    num_frames = times.shape[0]
    # The progress bar advances once per rendered frame.
    with tqdm(total=num_frames) as pbar:
        fig, ax = plt.figure(dpi=96), plt.gca()
        extent=[min_bounds[0], max_bounds[0],
                min_bounds[1], max_bounds[1]]
        ax.set_xlabel("x [m]")
        ax.set_ylabel("y [m]")
        ax.axis('off')
        ax.invert_yaxis()
        background = plt.imread('background.png')
        def render_frame(i):
            # NOTE(review): clear() may reset the axis setup made above (labels,
            # y inversion) -- confirm the rendered frames look as intended.
            ax.clear()
            artists = [ax.imshow(background, extent=extent)]
            for cmap, vf in pairs:
                cmap = plt.get_cmap(cmap)
                artists += plot_set(ax, vf, i, alpha=0.9, cmap=cmap, extent=extent, animated=True)
            fig.tight_layout()
            pbar.update()
            return artists
        return HTML(anim.FuncAnimation(fig, render_frame, frames=num_frames, interval=200).to_html5_video())

def comic_scenario(times, ncols, *pairs):
    """Draw a grid of snapshots ("comic strip") of several vehicles' sets.

    Row ``r`` shows vehicle ``r`` in its own colour map, with the vehicles of
    earlier rows drawn in gray as obstacles. The columns are ``ncols`` evenly
    spaced interior time steps.

    Args:
        times: time stamps for axis 0 of every value function.
        ncols: number of snapshot columns.
        *pairs: ``(cmap_name, values, opts)`` triplets. ``opts`` may hold
            ``'entry'`` / ``'exit'`` target dicts (``target_min``/``target_max``)
            to outline with a dashed rectangle.

    Returns:
        The matplotlib figure.
    """
    # squeeze=False keeps `axs` two-dimensional even with a single row or
    # column; the row/column indexing below relies on that.
    fig, axs = plt.subplots(len(pairs), ncols, sharex=True, sharey=True,
                            figsize=(ncols*3, len(pairs)*3), squeeze=False)
    extent=[min_bounds[0], max_bounds[0],
            min_bounds[1], max_bounds[1]]
    background = plt.imread('background.png')
    # ncols evenly spaced snapshot indices, excluding the first and last step.
    ix = np.linspace(0, len(times)-1, ncols + 2, dtype=int)[1:-1]
    for i, ax in enumerate(axs[0, :]):
        ax.set_title(f't = {times[ix][i]:0.01f} s')
    for ax in axs[:, 0]:
        ax.set_ylabel("y [m]")
    for ax in axs[-1, :]:
        ax.set_xlabel("x [m]")
    for ax in axs.flatten():
        ax.invert_yaxis()
        ax.imshow(background, extent=extent)
    for i, j in np.ndindex(axs.shape):
        ax = axs[i, j]
        time_idx = ix[j]
        # Vehicles of earlier rows are drawn in gray as obstacles.
        for _, vf, _ in pairs[:i]:
            plot_set(ax, vf, time_idx, alpha=0.9, cmap=plt.get_cmap('gray'), extent=extent)
        cmap, vf, opts = pairs[i]
        cmap = plt.get_cmap(cmap)
        plot_set(ax, vf, time_idx, alpha=0.9, cmap=cmap, extent=extent)
        if 'entry' in opts:
            plot_rect(ax, **to_rect(opts['entry']), edgecolor=cmap(0.8))
        if 'exit' in opts:
            plot_rect(ax, **to_rect(opts['exit']), edgecolor=cmap(0.8))
    fig.tight_layout()
    # Legend: one swatch per vehicle plus a gray swatch for obstacles.
    handles, labels = [], []
    for i, (cmap, _, _) in enumerate(pairs):
        cmap = plt.get_cmap(cmap)
        handles.append(patches.Patch(color=cmap(0.5)))
        labels.append(r'$\mathcal{V\,}' + f'_{i+1}$')
    handles += [patches.Patch(color='gray')]
    labels += ['Obstacles']
    fig.subplots_adjust(bottom=fig.subplotpars.bottom + 0.04)
    fig.legend(handles, labels, ncol=len(pairs)+1, loc='center', bbox_to_anchor=(0.5, 0.025), borderaxespad=0)
    return fig

def interact_tubes(times, *triplets, eye):
    """Build an interactive 3-D (x, y, t) view of reachable tubes over the map.

    Args:
        times: time stamps for axis 0 of every value function.
        *triplets: ``(values, colorscale, trace_kwargs)``. ``values`` is
            projected onto axes (0, 1, 2) and drawn as its zero isosurface;
            ``trace_kwargs`` are forwarded to ``go.Isosurface`` (e.g. ``name``,
            ``opacity``).
        eye: plotly camera eye position, e.g. one of the ``EYE_*`` constants.

    Returns:
        ``(interaction, render_frame)``. ``render_frame(time_idx)`` returns a
        ``FigureWidget`` in which every tube is cut off after ``time_idx``.
    """
    background = sio.imread('background.png', as_gray=True)
    background = np.flipud(background)
    interaction = partial(interact, time_idx=IntSlider(len(times)-1, min=0, max=len(times)-1))

    # Sample points (t, x, y) shared by every isosurface; independent of time_idx.
    meshgrid = np.mgrid[times[0]:times[-1]:complex(0, len(times)),
                        min_bounds[0]:max_bounds[0]:complex(0, grid.shape[0]),
                        min_bounds[1]:max_bounds[1]:complex(0, grid.shape[1])]
    # Project once up front instead of on every slider move.
    projected = [(shp.project_onto(values, 0, 1, 2), colorscale, kwargs)
                 for values, colorscale, kwargs in triplets]

    def render_frame(time_idx):
        data = [
            go.Surface(
                # plotly places z[row][col] at (x[col], y[row]), so x follows
                # the image columns and y its rows.
                x=np.linspace(min_bounds[0], max_bounds[0], background.shape[1]),
                y=np.linspace(min_bounds[1], max_bounds[1], background.shape[0]),
                z=np.zeros_like(background)-0.1,
                surfacecolor=background,
                colorscale='gray',
                showscale=False,
            ),
        ]

        for shared, colorscale, kwargs in projected:
            # Copy so that hiding later steps never writes into the cached
            # projection (or an array shared with the caller).
            vf = np.array(shared)
            if time_idx < len(times)-1:
                vf[time_idx+1:, ...] = 1  # value 1 > 0: outside the set

            data.append(
                go.Isosurface(
                    x=meshgrid[1].flatten(),
                    y=meshgrid[2].flatten(),
                    z=meshgrid[0].flatten(),
                    value=vf.flatten(),
                    colorscale=colorscale,
                    showscale=False,
                    isomin=0,
                    surface_count=1,
                    isomax=0,
                    caps=dict(x_show=True, y_show=True),
                    **kwargs,
                )
            )

        fw = go.FigureWidget(data=data)
        fw.layout.update(width=720, height=720,
                         margin=dict(l=10, r=10, t=10, b=10),
                         scene=dict(zaxis_title='t'),
                         scene_camera=dict(eye=eye))
        fw._config = dict(toImageButtonOptions=dict(height=720, width=720, scale=6))
        return fw
    return interaction, render_frame

def record_tubes(times, *pairs):
    """Return a renderer that records an orbiting 3-D video of reachable tubes.

    Args:
        times: time stamps for axis 0 of every value function.
        *pairs: ``(colorscale, values)`` tuples. Each ``values`` is projected
            onto axes (0, 1, 2) and drawn as its zero isosurface.

    Returns:
        ``render_frame(time_idx)``. It hides every tube after ``time_idx``,
        writes one PNG per camera position from ``make_frames`` into
        ``frames/``, stitches them into ``animation.mp4`` and returns the figure.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor, as_completed

    background = sio.imread('background.png', as_gray=True)
    background = np.flipud(background)

    # Sample points (t, x, y) shared by every isosurface; independent of time_idx.
    meshgrid = np.mgrid[times[0]:times[-1]:complex(0, len(times)),
                        min_bounds[0]:max_bounds[0]:complex(0, grid.shape[0]),
                        min_bounds[1]:max_bounds[1]:complex(0, grid.shape[1])]

    def render_frame(time_idx):
        data = [
            go.Surface(
                # plotly places z[row][col] at (x[col], y[row]), so x follows
                # the image columns and y its rows.
                x=np.linspace(min_bounds[0], max_bounds[0], background.shape[1]),
                y=np.linspace(min_bounds[1], max_bounds[1], background.shape[0]),
                z=np.zeros_like(background)-0.1,
                surfacecolor=background,
                colorscale='gray',
                showscale=False,
            ),
        ]

        for colorscale, values in pairs:
            # Copy so that hiding later steps never writes into `values`.
            vf = np.array(shp.project_onto(values, 0, 1, 2))
            if time_idx < len(times)-1:
                vf[time_idx+1:, ...] = 1  # value 1 > 0: outside the set

            data.append(
                go.Isosurface(
                    x=meshgrid[1].flatten(),
                    y=meshgrid[2].flatten(),
                    z=meshgrid[0].flatten(),
                    value=vf.flatten(),
                    colorscale=colorscale,
                    opacity=1.0,
                    showscale=False,
                    isomin=0,
                    surface_count=1,
                    isomax=0,
                    caps=dict(x_show=True, y_show=True)
                )
            )

        eye = dict(x=-0.5, y=-1.4, z=1.6)
        frames = make_frames(eye)

        fig = go.Figure(data=data)
        fig.update_layout(width=600, height=600,
                          scene=dict(zaxis_title='t'),
                          scene_camera=dict(eye=eye))

        os.makedirs('frames', exist_ok=True)
        # Zero-padded names keep the frames in order when the directory is
        # listed alphabetically ('frame10' would otherwise precede 'frame2').
        digits = len(str(len(frames) - 1))

        def write_frame(i, frame):
            # Each worker renders its own copy of the figure. Updating one
            # shared figure from several threads could write a frame with
            # another frame's camera.
            frame_fig = go.Figure(fig)
            frame_fig.update_layout(frame.layout)
            frame_fig.write_image(f'frames/frame{i:0{digits}d}.png', format='png',
                                  width=600, height=600, scale=6)

        with tqdm(total=len(frames)) as pbar, ThreadPoolExecutor() as pool:
            futures = [pool.submit(write_frame, i, frame) for i, frame in enumerate(frames)]
            for future in as_completed(futures):
                future.result()  # re-raise rendering errors instead of dropping them
                pbar.update()

        import moviepy.editor as mpy
        mpy.ImageSequenceClip('frames', fps=24).write_videofile('animation.mp4')

        return fig
    return render_frame

# Preset plotly camera positions for the 3-D tube plots, named after the
# compass side they view from (phi = -180 deg puts the eye on the -x side, i.e. west).
# Arguments are (radius, polar angle from +z in degrees, azimuth in degrees).
EYE_W   = sphere_to_cartesian(2.2, 45, -90 - 90)
EYE_WSW = sphere_to_cartesian(2.2, 70, -90 - 70)
EYE_SW  = sphere_to_cartesian(2.5, 60, -90 - 45)
EYE_S   = sphere_to_cartesian(2.5, 45, -90 + 0)
EYE_SE  = sphere_to_cartesian(2.5, 60, -90 + 45)
EYE_ESE = sphere_to_cartesian(2.2, 70, -90 + 70)
EYE_E   = sphere_to_cartesian(2.2, 45, -90 + 90)

In [None]:
# Slider view of vehicle 2's initial BRS (pass 1) as a 3-D surface.
I, F = interact_vf(new_timeline(10), v2_pass1)
I(F)

In [None]:
# Reverse only the time axis (axis 0). np.flip without an axis would also
# mirror every state dimension, misaligning the frames with `end1`.
record_levelset(np.flip(v2_pass1[20:], axis=0), end1)

In [None]:
# Comic strip with 6 snapshot columns: one row per vehicle (pass 4), with the
# entry and exit regions outlined.
comic_scenario(
    new_timeline(10), 6,
    ('Reds',    v1_pass4, dict(entry=entry3, exit=exit2)),
    ('Blues',   v2_pass4, dict(entry=entry2, exit=exit1)),
    ('Purples', v3_pass4, dict(entry=entry1, exit=exit3)),
)

In [None]:
# Slider view of vehicle 2's initial BRS over the map; the commented-out
# entries are alternative layers.
I, F = interact_scenario(
    new_timeline(10),
    # ('Reds',    v1_pass4),
    ('Blues',   v2_pass1),
    # ('Blues',   shp.setminus(shp.project_onto(v2_pass1, 0, 1, 2, keepdims=True),
    #                          shp.project_onto(v1_pass1-AVOID_MARGIN, 0, 1, 2, keepdims=True))),
    # ('Purples', v3_pass4),
)
I(F)

In [None]:
# Animated map of the final (pass-4) sets of all three vehicles.
record_scenario(
    new_timeline(10),
    ('Reds',    v1_pass4),
    ('Blues',   v2_pass4),
    ('Purples', v3_pass4),
)

In [None]:
# 3-D tube view of vehicle 1's final set and vehicle 2's avoidance result (pass 2).
I, F = interact_tubes(
    new_timeline(10),

    (v1_pass4, 'Reds', dict(name=r'$\mathcal{V}_1$')),
    # NOTE(review): the two entries below use an old (values, name, colorscale)
    # order; interact_tubes expects (values, colorscale, trace_kwargs).
    # (v2_pass4, r'$\mathcal{V}_2$', 'Blues'),
    # (v3_pass4, r'$\mathcal{V}_3$', 'Purples'),
    
    (v2_pass2, 'Blues', dict(name=r'$\mathcal{V}_2$')),
    # (shp.setminus(v2_pass1, shp.project_onto(v1_pass4 - AVOID_MARGIN, 0, 1, 2, keepdims=True))[:11],
    #  'Blues',
    #  dict(name=r'$\mathcal{V}_2$')),
    # (shp.setminus(v2_pass1, shp.project_onto(v1_pass4 - AVOID_MARGIN, 0, 1, 2, keepdims=True)),
    #  'Blues',
    #  dict(name=r'$\mathcal{V}_2$', opacity=0.6)),

    # NOTE(review): these (colorscale, values) pairs are also in the old order.
    # ('Blues',   v2_pass1),
    # ('Reds',    v1_pass4),
    # ('Mint',    v2_pass2),
    # ('Peach',   v2_pass3),
    # ('Greens',  v2_pass4),
    
    # eye=sphere_to_cartesian(2.6, 50, -90 + 15),
    eye=EYE_ESE,
)
# `if True` would show the interactive slider; otherwise time step 50 is
# rendered and saved as a PNG.
if False:
    r = I(F)
else:
    r = F(50)
    # r = r.show(config=dict(toImageButtonOptions=dict(height=720, width=720, scale=6)))
    # NOTE(review): write_image returns None, so the trailing `r` displays nothing.
    r = r.write_image('../../Final Presentation/1. Figures/v2_pass2_10s.png', height=720, width=720, scale=6)
r