In [None]:
from mymath import *
from potentials import *
import torch
import plotly.graph_objects as go
import numpy as np
from ipywidgets import widgets

In [None]:
# Potential field of a single edge e0 -> e1, sampled on a regular grid.

dhat = 1
r = 1
alpha = 2
# Edge endpoints, both slightly above y = 0.
e0 = torch.tensor([-1, 1e-4])
e1 = torch.tensor([0, 1e-4])

resolution = 5e-3
xs = torch.arange(-1.1, 0.1, step=resolution)
ys = torch.arange(-0.1, 0.1, step=resolution)

# Flattened sample points: row j * len(xs) + i holds (xs[i], ys[j]) (x varies fastest),
# which is the layout the .view(len(ys), len(xs)) below relies on. meshgrid replaces a
# per-point Python loop that allocated one small tensor per sample.
grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
points = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

# Despite the name, `distances` holds the potential reshaped to the (len(ys), len(xs)) grid.
distances = point_edge_potential_discrete(points, e0, e1, dhat, r, alpha, 2).view(ys.shape[0], xs.shape[0])

# Cap the contour levels at percentile(distances, 0.95) -- presumably the 95th percentile
# (see mymath.percentile) so the largest values do not swamp the levels; TODO confirm.
# size = major_max / 20 gives 20 bands between 0 and major_max.
major_max = percentile(distances, 0.95)
print("major range", [0, major_max])
fig = go.Figure(data=[
    go.Scatter(x=[e0[0], e1[0]], y=[e0[1], e1[1]]),
    go.Contour(z=distances, x=xs, y=ys, contours=dict(
        start=0,
        end=major_max,
        size=major_max / 20,
    )),
], layout=go.Layout(width=800, height=400))
fig.show()

In [None]:
# Plot domain as [[x_min, x_max], [y_min, y_max]].
domain_range = torch.tensor([[-0.10001, 1.1], [-0.20001, 0.2]])

# Interactive potential field of a polyline.

# Initial barrier parameters; they also seed the slider values below.
dhat = 1e-1
r = 1
alpha = 2
# Polyline vertices, one (x, y) row each; the middle vertex is moved by the X/Y sliders.
edges = torch.tensor([[-1., 0], [0, 0], [1, 0]])


def _make_float_slider(value, low, high, step, label, fmt):
    """Return a FloatSlider that only reports its value once the handle is released."""
    return widgets.FloatSlider(
        value=value,
        min=low,
        max=high,
        step=step,
        description=label,
        continuous_update=False,
        readout_format=fmt,
    )


# NOTE(review): sample_slider is observed but is not placed in `container`, and its value
# is never read in this cell -- the grid spacing is the fixed `resolution` constant.
sample_slider = widgets.IntSlider(
    value=200, min=100, max=1e4, step=1,
    description='Resolution:', continuous_update=False, readout_format='.0f',
)
r_slider = _make_float_slider(r, 1, 5, 1e-1, 'barrier p:', '.2e')
alpha_slider = _make_float_slider(alpha, 0, 10, 1e-2, 'alpha:', '.2e')
eps_slider = _make_float_slider(dhat, 0, 2, 1e-3, 'eps:', '.3e')
X_slider = _make_float_slider(edges[1, 0], -1, 1, 1e-3, 'X:', '.3e')
Y_slider = _make_float_slider(edges[1, 1], -0.1, 0.1, 1e-4, 'Y:', '.4e')

sliders = [r_slider, alpha_slider, eps_slider, X_slider, Y_slider, sample_slider]
# Two rows of controls: barrier parameters on top, middle-vertex position underneath.
barrier_row = widgets.HBox([r_slider, alpha_slider, eps_slider])
vertex_row = widgets.HBox([X_slider, Y_slider])
container = widgets.VBox([barrier_row, vertex_row])

# Sample the plot domain on a regular grid with spacing `resolution`.
resolution = 5e-3
xs = torch.arange(domain_range[0, 0], domain_range[0, 1], step=resolution)
ys = torch.arange(domain_range[1, 0], domain_range[1, 1], step=resolution)

# Flattened sample points: row j * len(xs) + i holds (xs[i], ys[j]) (x varies fastest),
# matching the .view(len(ys), len(xs)) applied to the potential. meshgrid replaces a
# per-point Python loop that allocated one small tensor per sample.
grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
points = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2)

# Initial frame. Parameters are read from the sliders -- with dhat = eps**2, exactly as
# `response` passes it -- so the first render matches what is drawn after any slider
# change (passing `dhat` directly here would use 0.1 instead of eps**2 = 0.01).
potential = torch.zeros((ys.shape[0], xs.shape[0]))
for i in range(edges.shape[0] - 1):
    # Sum the contribution of each polyline segment (edges[i] -> edges[i + 1]).
    potential += point_edge_potential_discrete(
        points, edges[i, :], edges[i + 1, :],
        eps_slider.value ** 2, r_slider.value, alpha_slider.value, 2,
    ).view(ys.shape[0], xs.shape[0])

# Trace 0 is the polyline, trace 1 the potential contour (`response` relies on this order).
trace1 = go.Scatter(x=edges[:, 0], y=edges[:, 1])
# Cap contour levels at percentile(potential, 0.95) -- presumably the 95th percentile; 20 bands.
major_max = percentile(potential, 0.95)
trace2 = go.Contour(z=potential, x=xs, y=ys,
    contours=dict(
        start=0,
        end=major_max,
        size=major_max / 20,
    ))
g = go.FigureWidget(data=[trace1, trace2],
                    layout=go.Layout(width=800, height=600))

def response(change):
    """Recompute the potential from the current slider values and redraw ``g``.

    Registered via ``observe(..., names="value")``; ``change`` is the ipywidgets change
    notification and is not used. Mutates the module-level ``edges`` tensor so its
    middle vertex follows the X/Y sliders.
    """
    edges[1, 0] = X_slider.value
    edges[1, 1] = Y_slider.value

    potential = torch.zeros((ys.shape[0], xs.shape[0]))
    for i in range(edges.shape[0] - 1):
        # dhat is passed as eps**2; one contribution per segment edges[i] -> edges[i + 1].
        potential += point_edge_potential_discrete(
            points, edges[i, :], edges[i + 1, :],
            eps_slider.value ** 2, r_slider.value, alpha_slider.value, 2,
        ).view(ys.shape[0], xs.shape[0])
    major_max = percentile(potential, 0.95)

    # The sample grid (xs, ys) never changes, so only z, the contour levels and the
    # polyline are pushed to the front end.
    with g.batch_update():
        g.data[1].z = potential
        g.data[1].contours = dict(
            start=0,
            end=major_max,
            size=major_max / 20,
        )
        g.data[0].x = edges[:, 0]
        g.data[0].y = edges[:, 1]

# Redraw whenever any slider commits a new value.
for slider in sliders:
    slider.observe(response, names="value")

# Last expression of the cell: Jupyter renders the controls above the live figure.
widgets.VBox([container, g])