In [20]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets

## Projection Functions

In [21]:
def region_nearest_convex_set_a(pt):
    """Orthogonally project ``pt`` onto convex set A, the line y = x/2.

    The projection is ``((2x + y) / 5) * (2, 1)``.

    Parameters
    ----------
    pt : sequence of two floats
        The point ``(x, y)``.

    Returns
    -------
    numpy.ndarray, shape (2,)
    """
    x, y = pt
    s = 2*x + y  # dot product with the (unnormalised) direction (2, 1)
    return np.array([2/5*s, s/5])

def region_nearest_convex_set_b(pt):
    """Orthogonally project ``pt`` onto convex set B, the line y = x/8.

    Mathematically this is ``((8x + y) / 65) * (8, 1)``; both coordinates are
    computed as multiples of the shared scalar ``2x + y/4``.

    Parameters
    ----------
    pt : sequence of two floats
        The point ``(x, y)``.

    Returns
    -------
    numpy.ndarray, shape (2,)
    """
    x, y = pt
    shared = 2*x + y/4
    return np.array([32/65*shared, 4/65*shared])

In [22]:
def region_nearest_nonconvex_set_a(pt):
    """Project ``pt`` onto nonconvex set A.

    Set A is the line y = x/2, parameterised as ``(2t, t)``, with the open gap
    ``1/8 < t < 1/4`` removed (between (1/4, 1/8) and (1/2, 1/4)). A point whose
    orthogonal projection falls inside the gap snaps to the nearer gap
    endpoint; the exact midpoint ``t = 3/16`` goes to the lower endpoint.

    Parameters
    ----------
    pt : sequence of two floats
        The point ``(x, y)``.

    Returns
    -------
    numpy.ndarray, shape (2,)
        The nearest point of set A.
    """
    x, y = pt
    # s = 5t, where t is the line parameter of the orthogonal projection.
    s = 2*x + y
    if s >= 5/4 or s <= 5/8:
        # Projection already lies on the set (t >= 1/4 or t <= 1/8).
        common = s / 5
    elif s > 15/16:
        # Upper half of the gap: nearest point is the endpoint t = 1/4.
        common = 1/4
    else:
        # Lower half of the gap: nearest point is the endpoint t = 1/8.
        common = 1/8
    return np.array([2*common, common])

def region_nearest_nonconvex_set_b(pt):
    """Project ``pt`` onto nonconvex set B.

    Set B is the line y = x/8, parameterised as ``(8t, t)``, with the open gap
    ``1/32 < t < 1/16`` removed (between (1/4, 1/32) and (1/2, 1/16)). A point
    whose orthogonal projection falls inside the gap snaps to the nearer gap
    endpoint; the exact midpoint ``t = 3/64`` goes to the lower endpoint.

    All tests are on ``s = 8x + y``. Earlier versions mistakenly used
    ``8x + 4y`` and ``x + y`` in two of the conditions, copied from set A.
    Those versions returned points inside the gap, e.g. for (0.45, 0.3).

    Parameters
    ----------
    pt : sequence of two floats
        The point ``(x, y)``.

    Returns
    -------
    numpy.ndarray, shape (2,)
        The nearest point of set B.
    """
    x, y = pt
    # s = 65t, where t is the line parameter of the orthogonal projection.
    s = 8*x + y
    if s >= 65/16 or s <= 65/32:
        # Projection already lies on the set (t >= 1/16 or t <= 1/32).
        common = s / 65
    elif s > 195/64:
        # Upper half of the gap: nearest point is the endpoint t = 1/16.
        common = 1/16
    else:
        # Lower half of the gap: nearest point is the endpoint t = 1/32.
        common = 1/32
    return np.array([8*common, common])

def region_nearest_flower(pt):
    """Return the sample of the flower curve closest to ``pt``.

    This is a nearest-sample approximation of the projection onto the flower
    set. It searches the module-level ``flower_vals`` samples by Euclidean
    distance; ties resolve to the first closest sample (``argmin``).
    """
    offsets = flower_vals - pt
    nearest = np.linalg.norm(offsets, axis=1).argmin()
    return flower_vals[nearest]

In [23]:
def generalized_projection(fa, fb, params):
    """Build a one-step iteration map from two projection operators.

    With ``params = (a, b, c)`` the returned map sends ``pt`` to the rows
    ``[P_A, P_B, next]``, where::

        P_A  = fa(pt)
        P_B  = fb(c*P_A + (1 - c)*pt)
        next = (1 - a - b)*pt + a*P_A + b*P_B

    For example, ``(0, 1, 1)`` gives alternating projections and ``(-1, 1, 2)``
    the difference map. ``params`` is unpacked once, when the map is built.
    """
    a, b, c = params
    keep = 1 - a - b    # weight on the current iterate
    reflect = 1 - c     # weight on the iterate when forming fb's input

    def f(pt):
        """Apply one iteration to ``pt`` and return ``[P_A, P_B, next]``."""
        proj_a = fa(pt)
        proj_b = fb(proj_a * c + pt * reflect)
        combined = keep*pt + a*proj_a + b*proj_b
        return np.array([proj_a, proj_b, combined])

    return f

In [24]:
def product_space_projection(fa, fb, fc):
    """Combine three projections into one map acting on a stack of three points.

    The returned map sends ``[p_a, p_b, p_c]`` to
    ``[fa(p_a), fb(p_b), fc(p_c)]``. The input must unpack into exactly
    three points.
    """
    projections = (fa, fb, fc)

    def f(pts):
        """Project each stacked point onto its own set."""
        first, second, third = pts
        return np.array([
            proj(point) for proj, point in zip(projections, (first, second, third))
        ])

    return f

def diagonal_projection(pts):
    """Project stacked points onto the diagonal of the product space.

    The diagonal is the set where all stacked points coincide. Its nearest
    point, in the Euclidean sense, replaces every row with the mean of the rows.

    Parameters
    ----------
    pts : array_like, shape (k, d)
        ``k`` stacked points (assumed 2-D).

    Returns
    -------
    numpy.ndarray, shape (k, d)
        The mean point repeated once per input row. The row count follows the
        input rather than being fixed at three.
    """
    mean_pt = np.mean(pts, axis=0)
    return np.tile(mean_pt, (len(pts), 1))

def generalized_projection_multiple(fa, fb, fc, params):
    """Build the product-space iteration map for three constraint sets.

    The "A" step projects each of the three stacked points onto its own set
    (``fa``, ``fb``, ``fc`` via ``product_space_projection``). The "B" step
    replaces the stack by its mean (``diagonal_projection``). ``params = (a, b, c)``
    weights the steps exactly as in ``generalized_projection``; the returned
    map gives only the next stacked iterate.
    """
    a, b, c = params
    keep = 1 - a - b
    reflect = 1 - c

    project_each = product_space_projection(fa, fb, fc)
    average = diagonal_projection

    def f(pts):
        """Apply one product-space iteration to the stacked points ``pts``."""
        projected = project_each(pts)
        averaged = average(projected * c + reflect * pts)
        return keep * pts + a * projected + b * averaged

    return f

def riffle_arrays(array_1, array_2):
    """Interleave the rows of two arrays.

    Row ``i`` of ``array_1`` lands at output row ``2*i`` and row ``i`` of
    ``array_2`` at ``2*i + 1``. The output shape is derived from ``array_1``,
    so ``array_2`` must broadcast to it. The result uses NumPy's default
    dtype (float64), whatever the inputs' dtype.
    """
    n_rows = array_1.shape[0]
    trailing = array_1.shape[1:]
    interleaved = np.empty((2 * n_rows, *trailing))
    interleaved[::2] = array_1
    interleaved[1::2] = array_2
    return interleaved

In [25]:
# Flower-shaped closed curve used as the third constraint set.
# 1024 evenly spaced angles over one full turn (step pi/512).
theta = np.arange(0, 2*np.pi, np.pi/512)
# Offset vector; the curve is translated by -origin before scaling by 1/3.
origin = -np.array([np.cos(np.pi/8), np.sin(np.pi/8)])*5/4
# Four-petal radius profile oscillating between 3/4 and 5/4.
radius = 1 + np.sin(4*theta)/4
x_vals, y_vals = (
    (radius*trig(theta) - shift) / 3
    for trig, shift in ((np.cos, origin[0]), (np.sin, origin[1]))
)
# (N, 2) array of (x, y) samples.
flower_vals = np.column_stack((x_vals, y_vals))

In [26]:
# Iteration budget, and method weights [a, b, c] (alternating projections by
# default). `parameters` is later mutated in place by set_parameters.
num_iterations, parameters = 99, [0, 1, 1]

In [27]:
# Default problem: the two convex sets under the current `parameters` weights.
# Kept in a one-element list so callbacks can swap the map in place.
projection_sets = [
    generalized_projection(region_nearest_convex_set_a, region_nearest_convex_set_b, parameters)
]

# pts[k+1] holds the rows returned by step k, [fa(x), fb(...), next iterate];
# pts[0] is the start point (1, 0.3) repeated in all three rows.
pts = np.empty((num_iterations + 1, 3, 2))
pts[0] = np.tile([1, 0.3], (3, 1))

for i in range(num_iterations):
    # Only the combined iterate (last row) is fed back into the map.
    pts[i + 1] = projection_sets[0](pts[i, -1])

## Visualization Functions

In [28]:
def add_convex_sets(ax, col_a, col_b):
    """Draw the two convex constraint sets as line segments.

    Set A lies on y = x/2 and set B on y = x/8; both are drawn over
    x in [-1/2, 9/8].

    Returns
    -------
    tuple
        The two created line artists, ``(set_a, set_b)``.
    """
    specs = (
        ([-1/4, 9/16], col_a),   # y = x/2
        ([-1/16, 9/64], col_b),  # y = x/8
    )
    return tuple(ax.plot([-1/2, 9/8], ys, color=colour)[0] for ys, colour in specs)

def add_nonconvex_sets(ax, col_a, col_b):
    """Draw the two nonconvex constraint sets, each as two line segments.

    Each set is its line (y = x/2 for A, y = x/8 for B) drawn over
    x in [-1/2, 1/4] and [1/2, 9/8], leaving a visible gap between them.

    Returns
    -------
    tuple
        The four created line artists, ``(set_a1, set_a2, set_b1, set_b2)``.
    """
    segments = (
        ([-1/2, 1/4], [-1/4, 1/8], col_a),
        ([1/2, 9/8], [1/4, 9/16], col_a),
        ([-1/2, 1/4], [-1/16, 1/32], col_b),
        ([1/2, 9/8], [1/16, 9/64], col_b),
    )
    artists = []
    for xs, ys, colour in segments:
        artists.append(ax.plot(xs, ys, color=colour)[0])
    return tuple(artists)

def add_flower_set(ax, col, **kwargs):
    """Draw the flower curve from the module-level ``flower_vals`` samples.

    Extra keyword arguments are forwarded to ``ax.plot``. Returns the created
    line artist.
    """
    xs, ys = flower_vals[:, 0], flower_vals[:, 1]
    line = ax.plot(xs, ys, color=col, **kwargs)[0]
    return line

def add_trace(ax, trace, col, **kwargs):
    """Plot a sequence of 2-D points as one connected line.

    ``trace`` is indexed as ``trace[:, 0]`` (x) and ``trace[:, 1]`` (y), so it
    must support 2-D NumPy-style slicing. Extra keyword arguments are
    forwarded to ``ax.plot``. Returns the created line artist.
    """
    xs = trace[:, 0]
    ys = trace[:, 1]
    line = ax.plot(xs, ys, color=col, **kwargs)[0]
    return line

def add_points(ax, pts, col):
    """Scatter the rows of ``pts`` as large, semi-transparent markers.

    Returns whatever ``ax.scatter`` returns (the collection artist).
    """
    marker_style = dict(s=200, alpha=0.5, color=col)
    return ax.scatter(pts[:, 0], pts[:, 1], **marker_style)

In [29]:
# Palette: set A, set B, the flower set / third trace, and the highlighted
# solution path.
col_a, col_b, col_c, col_sol = "#f994ff", "#94fff9", "#949aff", "#fff994"

In [30]:
# Create the figure with interactive mode off so it is not auto-displayed
# here; it is shown later through `fig.canvas` inside the widget layout.
with plt.ioff():
    dpi=72
    # 660 x 310 pixel figure at 72 dpi.
    fig,ax = plt.subplots(figsize=(660/dpi,310/dpi),dpi=dpi)

# Transparent figure and axes backgrounds.
fig.patch.set_alpha(0)
ax.patch.set_alpha(0)
# Convex sets are the initially visible problem (toggle value 0).
convex_sets = add_convex_sets(ax,col_a,col_b)

# hidden nonconvex
nonconvex_sets = add_nonconvex_sets(ax,col_a,col_b)
for line in nonconvex_sets:
    line.set_visible(False)

# hidden flower and individual iterations
# The per-component traces are only used for the 3-set problem
# (see update_traces); they start with just iteration 0.
flower_set = add_flower_set(ax,col_c)
iteration_a = add_trace(ax,pts[:1,0],col_a,linestyle="--",alpha=0.5)
iteration_b = add_trace(ax,pts[:1,1],col_b,linestyle="--",alpha=0.5)
iteration_c = add_trace(ax,pts[:1,2],col_c,linestyle="--",alpha=0.5)
flower_set.set_visible(False)
iteration_a.set_visible(False)
iteration_b.set_visible(False)
iteration_c.set_visible(False)

# Zig-zag trace between the first two rows (used for the 2-set problems),
# plus the solution path and its markers; update_iterations replaces their data.
iteration_trace = add_trace(ax,riffle_arrays(pts[:1,0],pts[:1,1]),"gray",linestyle="--",alpha=0.5)
sol_trace = add_trace(ax,pts[:1,2],col_sol)
scatter_pts = add_points(ax,pts[:1,2],col_sol)
ax.set(aspect="equal",xlim=[-1/2,9/8],ylim=[-1/4,1/2],xticks=[],yticks=[])
ax.axis("off")
fig.tight_layout()
# Strip the ipympl canvas chrome and pin its size.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
fig.canvas.layout.width = '660px'
# NOTE(review): 315px differs from the 310px figure height above; confirm
# this padding is intentional.
fig.canvas.layout.height = "315px"
# A bare None as the last expression keeps the cell from displaying output.
None

In [33]:
# Shared sizing for the three playback controls that share one row.
layout_thirds = ipywidgets.Layout(width='220px',height='30px')
# Show widget descriptions at their natural width instead of truncating them.
style = {
    'description_width': 'initial',
}

# Initial playback interval in milliseconds. It is shared by the Play widget
# and the speed slider so the displayed speed matches the real one from the
# start. update_speed only re-syncs them when the slider value changes.
initial_interval_ms = 250

play = ipywidgets.Play(
    value=1,
    min=1,
    max=num_iterations+1,
    step=1,
    interval=initial_interval_ms,
    show_repeat=False,
    style=style,
    layout=layout_thirds,
)

play_speed = ipywidgets.IntSlider(
    min=50,
    max=1000,
    step=50,
    value=initial_interval_ms,
    layout=layout_thirds,
    description="speed (ms):",
    style=style,
)

# Manual scrubbing through the stored iterations (1 .. num_iterations+1).
slider = ipywidgets.IntSlider(
    min=1,
    max=num_iterations+1,
    step=1,
    value=1,
    layout=layout_thirds,
    description="iteration #:",
    style=style,
)

def update_iterations(change):
    """Redraw every artist to show the first ``change['new']`` iterations.

    For the two-set problems the highlighted solution path is the last row of
    each stored step. For the three-set problem it is the mean of the three
    stacked points.
    """
    index = change['new']
    shown = pts[:index]
    if projection_sets_toggle.value < 2:
        solution = shown[:, 2]
    else:
        solution = shown.mean(1)

    sol_trace.set_data(solution.T)
    iteration_trace.set_data(riffle_arrays(shown[:, 0], shown[:, 1]).T)
    for artist, component in ((iteration_a, 0), (iteration_b, 1), (iteration_c, 2)):
        artist.set_data(shown[:, component].T)
    scatter_pts.set_offsets(solution)
    fig.canvas.draw_idle()

def update_speed(change):
    """Copy the speed slider's new value (ms per frame) onto the Play widget."""
    play.interval = change['new']
    
# Keep the slider and the player in sync in the browser, and redraw or retime
# in the kernel when their values change.
ipywidgets.jslink((play, 'value'), (slider, 'value'))
play.observe(update_iterations, names="value")
play_speed.observe(update_speed, names="value")

In [34]:

# Two controls per row for the method selector and relaxation slider.
layout_half = ipywidgets.Layout(width='325px',height='30px')
# Option values are read by set_parameters: 0 -> AP, 1 -> DM, otherwise RAAR.
projection_method = ipywidgets.Dropdown(
    options=[
        ("Alternating Projections (AP)", 0),
        ("Difference Map (DM)", 1),
        ("Relaxed Averaged Alternating Reflections (RAAR)", 3),
    ],
    value=0,
    description="projection method:",
    style=style,
    layout=layout_half,
)

def set_parameters(change):
    """Switch the projection method and recompute the iterates.

    Method 0 uses the alternating-projection weights ``[0, 1, 1]``, and method 1
    the difference-map weights ``[-1, 1, 2]``. Any other value is RAAR with
    ``[1 - 2*gamma, gamma, 2]``, where gamma comes from the relaxation slider;
    the slider is only enabled in that case. ``parameters`` is updated in place,
    then the projection map and iterates are rebuilt for the active sets.
    """
    method = change['new']
    fixed_weights = {0: [0, 1, 1], 1: [-1, 1, 2]}
    if method in fixed_weights:
        relaxation_gamma.disabled = True
        parameters[:] = fixed_weights[method]
    else:
        relaxation_gamma.disabled = False
        gamma = relaxation_gamma.value
        parameters[:] = [1-2*gamma, gamma, 2]

    active_sets = projection_sets_toggle.value
    update_projection_sets_function(active_sets)
    update_pts(active_sets)
    return None

# RAAR relaxation parameter; set_parameters enables it only for RAAR.
relaxation_gamma = ipywidgets.FloatSlider(
    description="relaxation parameter:",
    style=style,
    layout=layout_half,
    disabled=True,
    value=0.85,
    min=0,
    max=1,
    step=0.05,
)

def set_gamma(change):
    """Re-apply the current projection method after the relaxation slider moves.

    ``set_parameters`` reads the new gamma directly from the slider, so the
    ``change`` payload is not needed; it is kept for the observer signature.
    """
    set_parameters({'new': projection_method.value})
    return None

# Rebuild the iterates whenever the method or the relaxation parameter changes.
projection_method.observe(set_parameters, names="value")
relaxation_gamma.observe(set_gamma, names="value")

In [39]:
# Problem selector; its integer value is the `set_index` used by the callbacks.
projection_sets_toggle = ipywidgets.ToggleButtons(
    options=[
        ("2 Convex Sets", 0),
        ("2 Nonconvex Sets", 1),
        ("3 Nonconvex Sets", 2),
    ],
    value=0,
)

def update_projection_sets_function(set_index):
    """Rebuild ``projection_sets[0]`` for the chosen constraint sets.

    Index 0 uses the two convex lines and index 1 the two nonconvex pieces.
    Any other index uses both nonconvex sets plus the flower through the
    product-space formulation. The current ``parameters`` weights are captured
    when the map is built.
    """
    if set_index in (0, 1):
        pair = (
            (region_nearest_convex_set_a, region_nearest_convex_set_b)
            if set_index == 0
            else (region_nearest_nonconvex_set_a, region_nearest_nonconvex_set_b)
        )
        projection_sets[0] = generalized_projection(*pair, parameters)
    else:
        projection_sets[0] = generalized_projection_multiple(
            region_nearest_nonconvex_set_a,
            region_nearest_nonconvex_set_b,
            region_nearest_flower,
            parameters,
        )
    return None
        
def update_pts(set_index):
    """Recompute all iterates with the current projection map and rewind playback.

    For the two-set problems only the combined iterate (last row) is fed back.
    For the three-set problem the whole stacked point is fed back. The
    module-level ``pts`` buffer is overwritten in place, then playback is
    stopped and reset to the first frame.
    """
    step_map = projection_sets[0]
    trajectory = np.empty((num_iterations + 1, 3, 2))
    trajectory[0] = np.tile([1, 0.3], (3, 1))

    feed_whole_point = set_index >= 2
    for k in range(1, num_iterations + 1):
        previous = trajectory[k - 1]
        trajectory[k] = step_map(previous if feed_whole_point else previous[-1])

    pts[:] = trajectory
    play.playing = False
    play.value = 1
    return None

def update_projection_sets(change):
    """React to the set selector.

    Swaps the projection map, shows only the outlines of the chosen sets,
    refreshes trace visibility and recomputes the iterates.
    """
    set_index = change['new']
    update_projection_sets_function(set_index)

    visibility = ((convex_sets, set_index == 0), (nonconvex_sets, set_index > 0))
    for group, visible in visibility:
        for line in group:
            line.set_visible(visible)
    flower_set.set_visible(set_index == 2)

    update_traces({'new': traces_toggle.value})
    update_pts(set_index)
    return None

projection_sets_toggle.observe(update_projection_sets, names="value")

# Toggles the dashed iteration traces on and off.
traces_toggle = ipywidgets.ToggleButton(
    description="show traces",
    value=True,
)

def update_traces(change):
    """Show or hide the dashed iteration traces.

    The combined zig-zag trace belongs to the two-set problems; the three
    per-component traces belong to the three-set problem. All of them are
    hidden when ``change['new']`` is falsy.
    """
    show_traces = change['new']
    set_index = projection_sets_toggle.value
    combined_visible = set_index < 2 and show_traces
    per_set_visible = set_index == 2 and show_traces

    iteration_trace.set_visible(combined_visible)
    for trace_line in (iteration_a, iteration_b, iteration_c):
        trace_line.set_visible(per_set_visible)
    fig.canvas.draw_idle()
    return None

traces_toggle.observe(update_traces, names="value")  # redraw trace visibility on toggle

In [40]:
#| label: app:projection-sets
# Final app: set/trace toggles, playback controls, method controls, then the figure.
ipywidgets.VBox(
    children=[
        ipywidgets.HBox([projection_sets_toggle, traces_toggle]),
        ipywidgets.HBox([play, play_speed, slider]),
        ipywidgets.HBox([projection_method, relaxation_gamma]),
        fig.canvas,
    ],
    layout=ipywidgets.Layout(align_items="center", width="660px"),
)

VBox(children=(HBox(children=(ToggleButtons(options=(('2 Convex Sets', 0), ('2 Nonconvex Sets', 1), ('3 Noncon…