In [None]:
# Reload edited helper modules automatically so changes to safe_region_utils are picked up.
%load_ext autoreload
%autoreload 2

from sympy import *
# NOTE(review): later cells also use np, plt, time-free helpers such as plot_polygon,
# compute_unsafe_cond, print_mathematica, plot_condition, compute_polygon_angles,
# slope_sym and find_transitions without importing them here; presumably they come
# from one of these star imports — verify.
from safe_region_utils import *
# Pretty-print SymPy expressions in cell outputs.
init_printing()

# ACAS X Example

## Define a polygon

In [None]:
# Rectangular obstacle centred at the origin: 4*w wide and 2*w tall.
w = 0.75
rect_points: list = [
    geometry.Point(corner)
    for corner in ((2 * w, -w), (2 * w, w), (-2 * w, w), (-2 * w, -w))
]
rectangle: geometry.Polygon = Polygon(*rect_points)
plot_polygon(rectangle)

## Define a trajectory

In [None]:
# Symbolic coordinates used by every cell below.
x, y = symbols("x y")

# ACAS X climb: a parabola up to x = 4, then the line tangent to it there
# (both pieces equal 1 with slope 1/2 at x = 4).
traj_piecewise = Piecewise(
    (x**2 / 16, x < 4),
    (x / 2 - 1, x >= 4),
)
plot(traj_piecewise)

## Define domain and plotting bounds

In [None]:
# Trajectory domain, plus a 3-unit margin on each side for the plot window.
domain = Interval(-6, 15)
xbounds = (domain.start - 3, domain.end + 3)
ybounds = (-2, 9)

In [None]:
example_name = "ACAS X Climb with Rectangle"

# Symbolic condition in (x, y) describing the unsafe region; later cells treat ~cond as "safe".
cond = compute_unsafe_cond(x, y, rectangle, traj_piecewise, domain=domain)
cond

In [None]:
# Plotting is disabled for this example; uncomment the next line to render the region.
# plot_condition(x, y, cond, xbounds, ybounds, title=example_name)
mathematica_output = print_mathematica(
    x, y, cond, xbounds, ybounds, traj_piecewise, rectangle
)
print(mathematica_output)

In [None]:
sum(xbounds) / 2  # centre of the x plotting window

# UAV Top-Down Collision Avoidance

## Define a polygon

In [None]:
# Regular hexagon centred at the origin; rp is the centre-to-vertex radius.
rp = 2
hexagon = RegularPolygon(Point(0, 0), rp, n=6)
plot_polygon(hexagon)

## Define a trajectory

In [None]:
# Top-down UAV path: an arc of the circle of radius R (y = sqrt(R**2 - x**2)) for
# x > bound, continued by the line tangent to that circle at (R*cos(theta), R*sin(theta)).
R = 10
theta = pi/3
# Tangent-point abscissa: R / sqrt(tan(theta)**2 + 1) == R*|cos(theta)|.
bound = R / sqrt(tan(theta)**2 + 1)

circle_arc = sqrt(R**2 - x**2)
tangent_line = -1/tan(theta)*(x - R*cos(theta)) + R*sin(theta)
traj_piecewise = Piecewise((circle_arc, x > bound), (tangent_line, x <= bound))
plot(traj_piecewise)

## Define domain and plotting bounds

In [None]:
# Domain of the UAV trajectory, padded by 3 on each side for the plot window.
domain = Interval(-12, 10)
xbounds = (domain.start - 3, domain.end + 3)
ybounds = (-3, 19)

## Run algorithm

In [None]:
example_name = "Top-Down UAV Trajectory"

# Unsafe-region condition in (x, y) for the hexagon along the UAV path.
cond = compute_unsafe_cond(x, y, hexagon, traj_piecewise, domain=domain)
cond

In [None]:
plot_condition(x, y, cond, xbounds, ybounds, resolution=0.75, title=example_name)
mathematica_output = print_mathematica(
    x, y, cond, xbounds, ybounds, traj_piecewise, hexagon
)
print(mathematica_output)

# Short Examples

In [None]:
# Unit square (side 2*w = 1) centred at the origin.
w = 0.5
square_points: list = [
    geometry.Point(corner) for corner in ((w, -w), (w, w), (-w, w), (-w, -w))
]
square: geometry.Polygon = Polygon(*square_points)

# Sine wave for x < 0, straight line of slope 1/2 for x >= 0.
traj_piecewise = Piecewise(
    (sin(x / 2), x < 0),
    (x / 2, x >= 0),
)
plot(traj_piecewise)

In [None]:
# Symbols the trajectory depends on (only x appears in its definition above).
traj_piecewise.free_symbols

In [None]:
domain = Interval(-12, 9)
xbounds = [-15, 12]  # domain padded by 3 on each side
ybounds = [-3, 9]

cond = compute_unsafe_cond(x, y, square, traj_piecewise, domain)
cond

In [None]:
# Give this example its own title; otherwise `example_name` would still hold the
# UAV example's name assigned in an earlier cell.
example_name = "Sine-then-Line Trajectory with Square"
plot_condition(x, y, cond, xbounds, ybounds, title=example_name)
mathematica_output = print_mathematica(x, y, cond, xbounds, ybounds, traj_piecewise, square, False)
print(mathematica_output)

In [None]:
# Same domain and plot window as above, now with the straight-line trajectory y = 0.8*x.
domain = Interval(-12, 9)
xbounds = [-15, 12]
ybounds = [-3, 9]

cond = compute_unsafe_cond(x, y, square, 0.8 * x, domain)
cond

In [None]:
# `cond` above was computed for the linear trajectory y = 0.8*x, so both the title
# and the Mathematica export must describe that trajectory, not traj_piecewise.
example_name = "Linear Trajectory with Square"
plot_condition(x, y, cond, xbounds, ybounds, title=example_name)
mathematica_output = print_mathematica(x, y, cond, xbounds, ybounds, 0.8 * x, square, False)
print(mathematica_output)

## Testing $x = f(y)$

In [None]:
square: geometry.Polygon = Polygon(*square_points)
plot_polygon(square)

domain = Interval(0, 10)
# Trajectory given as x = f(y): x = 4*sqrt(y) for y < 1, then x = 2*y + 2
# (both pieces give x = 4 at y = 1).
traj_piecewise = Piecewise(
    (4*sqrt(y), y < 1),
    (2*y + 2, y >= 1),
)
# plot_implicit(traj_piecewise)

In [None]:
# The trajectory's only free symbol is y, so y is on the horizontal axis here.
plot(traj_piecewise)

In [None]:
# The y >= 1 piece, x = 2*y + 2, drawn in the usual (x, y) plane.
plot_implicit(Eq(x, 2*y + 2))

In [None]:
# The y < 1 piece, x = 4*sqrt(y), drawn in the usual (x, y) plane.
plot_implicit(Eq(x, 4*sqrt(y)))

In [None]:
# Narrower domain for the x = f(y) trajectory (presumably an interval of y here — confirm).
domain = Interval(0, 4)
xbounds = [-2, 12]
ybounds = [-2, 6]

cond = compute_unsafe_cond(x, y, square, traj_piecewise, domain)
cond

In [None]:
mathematica_command = print_mathematica(
    x, y, cond, xbounds, ybounds, traj_piecewise, square
)
print(mathematica_command)

In [None]:
# Dot plot of the x = f(y) example at a fine 0.25 grid spacing.
plot_condition(x, y, cond, xbounds, ybounds, resolution=0.25, title="test")

In [None]:
# Non-piecewise case: the single expression x = 4*sqrt(y), with the hexagon obstacle.
cond = compute_unsafe_cond(x, y, hexagon, 4*sqrt(y), domain)
cond

In [None]:
mathematica_command = print_mathematica(
    x, y, cond, xbounds, ybounds, 4*sqrt(y), hexagon
)
print(mathematica_command)

In [None]:
# Coarser 0.75 grid for the hexagon / x = 4*sqrt(y) case.
plot_condition(x, y, cond, xbounds, ybounds, resolution=0.75, title="test")

In [None]:
# Slope x/8 of y = x**2/16 at x = 8*sqrt(3): evaluates to sqrt(3) = tan(pi/3).
diff(x**2 / 16, x).subs(x, 8*sqrt(3))

In [None]:
plot_polygon(hexagon)  # visual check: one hexagon side is inclined at 60 deg (pi/3)

In [None]:
# dy/dx of the parabola y = x**2/16.
diff(x**2 / 16, x)

In [None]:
# dy/dx for the inverse form x = 4*sqrt(y), as the reciprocal of dx/dy.
1 / diff(4*sqrt(y), y)

In [None]:
# Sample point on y = x**2/16 used by the next two checks.
x_val = 8
y_val = (x**2 / 16).subs({x: x_val})
y_val

In [None]:
# x = 4*sqrt(y) should recover x_val from y_val = x_val**2/16.
(4 * sqrt(y)).subs(y, y_val) == x_val

In [None]:
# The parabola's slope at x_val should equal 1/(dx/dy) of the inverse at y_val.
diff(x**2 / 16, x).subs(x, x_val) == (1 / diff(4*sqrt(y), y)).subs(y, y_val)

In [None]:
# Check, for integer x in [0, 100), that x = 4*sqrt(y) inverts y = x**2/16 and that
# dy/dx from the parabola equals 1/(dx/dy) from the inverse; prints (x, y) on any mismatch.
# The two derivatives do not depend on the loop variable, so they are computed once.
dydx = diff(x**2 / 16, x)
dydx_via_inverse = 1 / diff(4*sqrt(y), y)
for x_val in np.arange(0, 100):
    y_val = (x**2 / 16).subs(x, x_val)
    inverse_matches = (4 * sqrt(y)).subs(y, y_val) == x_val
    slope_matches = dydx.subs(x, x_val) == dydx_via_inverse.subs(y, y_val)
    if not (inverse_matches and slope_matches):
        print(x_val, y_val)

In [None]:
# Scratch check: None is falsy, so this branch never runs and nothing is printed.
if None:
    print("test")

## Testing `slope_sym`

In [None]:
plot_polygon(hexagon)
# compute_polygon_angles returns a pair; only the first element (angles) is displayed.
angles, vertex_pairs = compute_polygon_angles(hexagon)
angles

In [None]:
# Define sin_traj here as well so this cell does not depend on running the next cell first
# (on a fresh kernel the bare `type(sin_traj)` raised NameError).
sin_traj = sin(2*x)
type(sin_traj)

In [None]:
# Test trajectory y = sin(2*x) for exercising slope_sym / find_transitions.
sin_traj = sin(2*x)
plot(sin_traj)

In [None]:
# slope_sym on the explicit expression sin(2*x).
slope_sym(sin_traj, x, y)

In [None]:
# slope_sym on the implicit form -y + sin(2*x), which is zero on the curve y = sin(2*x).
slope_sym(-y + sin_traj, x, y)

In [None]:
# Reference slope 2*cos(2*x) to compare with the slope_sym outputs above.
diff(sin_traj, x)

In [None]:
# Transitions of y = sin(2*x) against the hexagon's angles on x in [-10, 10].
# dtrans is indexed by angle in the cells below (presumably one entry per element of `angles`).
dtrans, settrans = find_transitions(-y + sin_traj, angles, x, y, domain = Interval(-10, 10))

In [None]:
# Transitions recorded for the pi/3 angle.
dtrans[pi/3]

In [None]:
# 4*pi/3 = pi/3 + pi: same line direction as pi/3, opposite orientation.
dtrans[4*pi/3]

In [None]:
# Transitions recorded for the 2*pi/3 angle.
dtrans[2*pi/3]

In [None]:
# 5*pi/3 = 2*pi/3 + pi: same line direction as 2*pi/3, opposite orientation.
dtrans[5*pi/3]

In [None]:
# 2*cos(2*x) at x = pi/12 is sqrt(3) = tan(pi/3).
diff(sin_traj, x).subs(x, pi/12)

In [None]:
# 2*cos(2*x) at x = 11*pi/12 is also sqrt(3) = tan(pi/3).
diff(sin_traj, x).subs(x, 11*pi/12)

In [None]:
# 2*cos(2*x) at x = 13*pi/12 is again sqrt(3) (one period after pi/12).
diff(sin_traj, x).subs(x, 13*pi/12)

In [None]:
# NOTE(review): unlike the earlier call, this passes -sin(2*x) with no -y term; confirm
# find_transitions accepts an explicit y = f(x) expression as well as an implicit form.
dtrans, settrans = find_transitions(-sin_traj, angles, x, y, domain = Interval(-10, 10))

In [None]:
# Visualise the negated test trajectory y = -sin(2*x).
plot(-sin_traj)

In [None]:
# Scratch inspection: display the slope_sym function object itself.
slope_sym

In [None]:
# The trajectory variable is named traj_piecewise; `piecewise_traj` is never defined
# and raised NameError.
traj_piecewise

## Testing left-inclusive piecewise filtering

In [None]:
# Unit square (side 2*w = 1) and the sine-then-line trajectory, redefined for this section.
w = 0.5
square_points: list = [
    geometry.Point(corner) for corner in ((w, -w), (w, w), (-w, w), (-w, -w))
]
square: geometry.Polygon = Polygon(*square_points)

traj_piecewise = Piecewise(
    (sin(x / 2), x < 0),
    (x / 2, x >= 0),
)
plot(traj_piecewise)

In [None]:
# Both pieces include x = 0 (x <= 0 and x >= 0).
traj_piecewise = Piecewise((sin(x/2), x <= 0), (x/2, x >= 0))
plot(traj_piecewise)

piece = Interval(0, 5)  # closed query interval [0, 5]
traj_piecewise.as_expr_set_pairs(piece)

In [None]:
# Number of (expression, set) pieces returned for the closed query interval.
len(traj_piecewise.as_expr_set_pairs(piece))

In [None]:
# Left piece closed at 0, right piece open: x = 0 belongs only to sin(x/2).
traj_piecewise = Piecewise((sin(x/2), x <= 0), (x/2, x > 0))
plot(traj_piecewise)

piece = Interval(0, 5)  # closed query interval [0, 5]
traj_piecewise.as_expr_set_pairs(piece)

In [None]:
# Both pieces exclude x = 0, so no condition holds at the boundary.
traj_piecewise = Piecewise((sin(x/2), x < 0), (x/2, x > 0))
plot(traj_piecewise)

piece = Interval(0, 5)  # closed query interval [0, 5]
traj_piecewise.as_expr_set_pairs(piece)

In [None]:
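# When the left piece's condition includes the boundary (x <= 0), it claims x = 0 first, so querying the closed interval [0, 5] also returns a single-point piece at x = 0, which is bad. Use open intervals to query?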

In [None]:
# Both pieces include x = 0, now queried with the open interval (0, 5).
traj_piecewise = Piecewise((sin(x/2), x <= 0), (x/2, x >= 0))
plot(traj_piecewise)

piece = Interval.open(0, 5)
traj_piecewise.as_expr_set_pairs(piece)

In [None]:
# Left piece closed at 0, right piece open, queried with the open interval (0, 5).
traj_piecewise = Piecewise((sin(x/2), x <= 0), (x/2, x > 0))
plot(traj_piecewise)

piece = Interval.open(0, 5)
traj_piecewise.as_expr_set_pairs(piece)

In [None]:
# Number of (expression, set) pieces returned for the open query interval.
len(traj_piecewise.as_expr_set_pairs(piece))

In [None]:
# Open query intervals seem to fix the boundary problem (no single-point piece at x = 0).

## Speeding up the dot plot with vectorization?

In [None]:
# Current unsafe-region condition (from the most recent compute_unsafe_cond call).
cond

In [None]:
# Evaluate the unsafe condition at the origin (x, y) = (0, 0).
cond.subs([(x, 0), (y, 0)])

In [None]:
# lambdify needs the argument symbols first and the expression second;
# `lambdify(cond)` omitted the required `expr` argument and raised TypeError.
lambdify((x, y), cond)

In [None]:
# Scratch: read lambdify's signature/docs (args, expr, modules).
help(lambdify)

In [None]:
# Compile the unsafe condition into a NumPy-backed function f(x, y).
f = lambdify([x, y], cond, modules="numpy")

In [None]:
# Element-wise evaluation at the points (0, 0), (1, 1), (2, 2).
f(np.array([0, 1, 2]), np.array([0, 1, 2]))

In [None]:
# Array x with scalar y: checks whether the compiled function broadcasts.
f(np.array([0, 1, 2]), 0)

In [None]:
# 20 samples per axis: the grid used by the lambdify-vs-subs agreement check below.
len(np.arange(0, 10, 0.5))

In [None]:
import time
# with subs: time evaluating the safe condition (~cond) point by point via .subs.
t0 = time.time()
grid = np.arange(-5, 5, 0.25)
# -1 marks cells not yet filled; the assert in the next cell checks none remain.
vals = np.full((len(grid), len(grid)), -1.0)
for i, x_O in enumerate(grid):
    for j, y_O in enumerate(grid):
        # Substitute y_O (not x_O) for y so the whole grid is sampled,
        # not just the diagonal x == y.
        vals[i, j] = bool(~cond.subs([(x, x_O), (y, y_O)]))
print(time.time() - t0)

In [None]:
# Every grid cell was overwritten, i.e. no -1 sentinel remains in `vals`.
assert(np.all(vals != -1))

In [None]:
# Blue = safe (vals == 1), red = unsafe, one marker per grid point.
grid = np.arange(-5, 5, 0.25)
plt.figure()
for i, x_O in enumerate(grid):
    for j, y_O in enumerate(grid):
        marker = "bo" if vals[i, j] == 1 else "ro"
        plt.plot(x_O, y_O, marker)
plt.show()

In [None]:
# Re-display the condition before the lambdify timing comparison.
cond

In [None]:
import time
# with lambdify: time the same grid evaluation using a NumPy-compiled condition.
t0 = time.time()
grid = np.arange(-5, 5, 0.25)
# Same -1 sentinel as `vals` (np.zeros(...) * -1 left every cell at zero).
vals_l = np.full((len(grid), len(grid)), -1.0)
# Compile ~cond (the safe condition), as `vals` stores ~cond, so the grids are comparable.
f = lambdify((x, y), ~cond, "numpy")
for i, x_O in enumerate(grid):
    for j, y_O in enumerate(grid):
        vals_l[i, j] = bool(f(x_O, y_O))
print(time.time() - t0)

In [None]:
# NOTE(review): no bounds are passed here, unlike every other call; confirm
# plot_condition provides defaults for xbounds/ybounds.
plot_condition(x, y, cond)

In [None]:
import time  # kept: later timing cells use `time`
# Check that the lambdified condition agrees with .subs at every point of a 20x20 grid.
check_grid = np.arange(0, 10, 0.5)
f = lambdify((x, y), cond, "numpy")
for x_O in check_grid:
    for y_O in check_grid:
        # Evaluate at (x_O, y_O); passing x_O for y only ever tested the diagonal.
        assert bool(f(x_O, y_O)) == bool(cond.subs([(x, x_O), (y, y_O)])), (x_O, y_O)

In [None]:
# Compiled condition at the origin; compare with the .subs result in the next cell.
bool(f(0, 0))

In [None]:
# .subs evaluation at the origin, for comparison with the compiled version.
bool(cond.subs([(x, 0), (y, 0)]))

In [None]:
# Grid filled via .subs.
vals

In [None]:
# Grid filled via lambdify.
vals_l

In [None]:
# Baseline timing: evaluate ~cond with .subs and issue one ax.plot call per grid point.
t0 = time.time()
fig = plt.figure()
ax = fig.gca()
resolution = 0.25
# Loop-invariant (depends only on resolution), so computed once.
# NOTE(review): dotscale is not used below (markersize is fixed at 1); presumably
# intended as the marker size — confirm.
if resolution < 0.5:
    dotscale = 6
elif resolution < 1:
    dotscale = 4
else:
    dotscale = 3
for x0 in np.arange(xbounds[0], xbounds[1], resolution):
    for y0 in np.arange(ybounds[0], ybounds[1], resolution):
        is_safe = (~cond).subs([(x, x0), (y, y0)])
        # TODO(nishant): make matplotlib default blue / red-orange
        color = "#0000bb" if is_safe else "#bb0010"
        ax.plot(x0, y0, "o", color=color, markersize=1)
print(time.time() - t0)

In [None]:
# Same per-point plot, but evaluating a lambdified ~cond instead of calling .subs.
t0 = time.time()
fig = plt.figure()
ax = fig.gca()
resolution = 0.25
f = lambdify([x, y], ~cond)
dotscale = 6 if resolution < 0.5 else (4 if resolution < 1 else 3)
for x0 in np.arange(xbounds[0], xbounds[1], resolution):
    for y0 in np.arange(ybounds[0], ybounds[1], resolution):
        is_safe = f(x0, y0)
        # TODO(nishant): make matplotlib default blue / red-orange
        point_color = "#0000bb" if is_safe else "#bb0010"
        ax.plot(x0, y0, "o", color=point_color, markersize=1)
print(time.time() - t0)

In [None]:
# Scratch: read plt.scatter's docs (s and c parameters) before switching to scatter.
help(plt.scatter)

In [None]:
# Scatter + lambdify variant (author's note: "lambdify is slow").
t0 = time.time()
fig = plt.figure()
ax = fig.gca()
resolution = 0.25
f = lambdify([x, y], ~cond)
xpoints, ypoints, colors = [], [], []
for x0 in np.arange(xbounds[0], xbounds[1], resolution):
    for y0 in np.arange(ybounds[0], ybounds[1], resolution):
        is_safe = f(x0, y0)
        xpoints.append(x0)
        ypoints.append(y0)
        colors.append("b" if is_safe else "r")

# One scatter call for all points instead of one plot call per point.
ax.scatter(xpoints, ypoints, s=1.0, c=colors)
print(time.time() - t0)

In [None]:
# Scatter variant evaluating ~cond with .subs, at the coarser resolution 0.5.
t0 = time.time()
fig = plt.figure()
ax = fig.gca()
resolution = 0.5
xpoints, ypoints, colors = [], [], []
for x0 in np.arange(xbounds[0], xbounds[1], resolution):
    for y0 in np.arange(ybounds[0], ybounds[1], resolution):
        is_safe = (~cond).subs([(x, x0), (y, y0)])
        xpoints.append(x0)
        ypoints.append(y0)
        colors.append("b" if is_safe else "r")

ax.scatter(xpoints, ypoints, s=1.0, c=colors)
runtime = time.time() - t0
print(runtime)

In [None]:
# Benchmark: accumulate points in Python lists (author's note: "lists win").
# Prints every third run's time, then the mean over ntries runs.
ntries = 30
total = 0
for i in range(ntries):
    t0 = time.time()
    fig = plt.figure()
    ax = fig.gca()
    resolution = 0.25
    xpoints, ypoints, colors = [], [], []
    for x0 in np.arange(xbounds[0], xbounds[1], resolution):
        for y0 in np.arange(ybounds[0], ybounds[1], resolution):
            is_safe = (~cond).subs([(x, x0), (y, y0)])
            xpoints.append(x0)
            ypoints.append(y0)
            colors.append("b" if is_safe else "r")

    ax.scatter(xpoints, ypoints, s=1.0, c=colors)
    plt.close(fig)  # avoid accumulating ntries open figures
    runtime = time.time() - t0
    if not i % 3:
        print(runtime)
    total += runtime
print(total / ntries)

In [None]:
# Benchmark: accumulate points with np.append (which returns a new copy on every call).
# Prints every third run's time, then the mean over ntries runs.
ntries = 30
total = 0
for i in range(ntries):
    t0 = time.time()
    fig = plt.figure()
    ax = fig.gca()
    resolution = 0.25
    xpoints, ypoints = np.array([]), np.array([])
    colors = []
    for x0 in np.arange(xbounds[0], xbounds[1], resolution):
        for y0 in np.arange(ybounds[0], ybounds[1], resolution):
            is_safe = (~cond).subs([(x, x0), (y, y0)])
            xpoints = np.append(xpoints, x0)
            ypoints = np.append(ypoints, y0)
            colors.append("b" if is_safe else "r")

    ax.scatter(xpoints, ypoints, s=1.0, c=colors)
    plt.close(fig)
    runtime = time.time() - t0
    if not i % 3:
        print(runtime)
    total += runtime
print(total / ntries)

In [None]:
# Benchmark: build the point coordinates with np.meshgrid instead of appending per point.
# Prints every third run's time, then the mean over ntries runs.
ntries = 30
total = 0
for i in range(ntries):
    # with meshgrid this time
    t0 = time.time()
    fig = plt.figure()
    ax = fig.gca()
    resolution = 0.25
    xrange = np.arange(xbounds[0], xbounds[1], resolution)
    yrange = np.arange(ybounds[0], ybounds[1], resolution)
    # meshgrid(yrange, xrange) with default 'xy' indexing yields arrays of shape
    # (len(xrange), len(yrange)); flattened row-major they run x outer / y inner,
    # the same order in which `colors` is filled below.
    ypoints, xpoints = np.meshgrid(yrange, xrange)
    # reshape/ravel return new arrays rather than modifying in place, so the
    # flattened result must be assigned back (the bare reshape calls were no-ops).
    xpoints = xpoints.ravel()
    ypoints = ypoints.ravel()
    colors = []
    # Reuse the same ranges so each color lines up with its meshgrid point.
    for x0 in xrange:
        for y0 in yrange:
            is_safe = (~cond).subs([(x, x0), (y, y0)])
            colors.append("b" if is_safe else "r")
    assert len(colors) == xpoints.size == ypoints.size

    ax.scatter(xpoints, ypoints, s=1.0, c=colors)
    plt.close(fig)
    runtime = time.time() - t0
    if i % 3 == 0:
        print(runtime)

    total += runtime
print(total / ntries)

In [None]:
# Inspect the meshgrid x-coordinates left over from the last benchmark run.
np.array(xpoints)

In [None]:
# Inspect the meshgrid y-coordinates left over from the last benchmark run.
np.array(ypoints)

In [None]:
# reshape returns a new (1, N) array; xpoints itself is left unchanged.
xpoints.reshape((1, -1))

In [None]:
# reshape returns a new (1, N) array; ypoints itself is left unchanged.
ypoints.reshape((1, -1))

In [None]:
# Repeat of the per-point .subs + ax.plot baseline timing.
t0 = time.time()
fig = plt.figure()
ax = fig.gca()
resolution = 0.25
# Loop-invariant (depends only on resolution), so computed once.
# NOTE(review): dotscale is not used below (markersize is fixed at 1); presumably
# intended as the marker size — confirm.
if resolution < 0.5:
    dotscale = 6
elif resolution < 1:
    dotscale = 4
else:
    dotscale = 3
for x0 in np.arange(xbounds[0], xbounds[1], resolution):
    for y0 in np.arange(ybounds[0], ybounds[1], resolution):
        is_safe = (~cond).subs([(x, x0), (y, y0)])
        # TODO(nishant): make matplotlib default blue / red-orange
        color = "#0000bb" if is_safe else "#bb0010"
        ax.plot(x0, y0, "o", color=color, markersize=1)
print(time.time() - t0)