In [None]:
from spring import frutcherman_reingold
import cooling_schedules
from jax import numpy as jnp
from jax import random
from matplotlib import pyplot as plt
from IPython.display import clear_output
import time
from functools import partial
from matplotlib.collections import LineCollection
import utils
import samples

In [None]:
# Which entry of `problems` (defined in the next cell) to lay out; 0 is the first.
problem_index = 0

In [None]:
# One [graph, node_colors] pair per sample; `problem_index` selects the entry,
# in this order: triangles, tri_quad_line, k_5, k_3_3, rings, fireworks.
problems = [
    [getattr(samples, sample_name), getattr(samples, sample_name + "_colors")]
    for sample_name in (
        "triangles",
        "tri_quad_line",
        "k_5",
        "k_3_3",
        "rings",
        "fireworks",
    )
]


In [None]:
# Fixed seed so the random starting layout is reproducible; `key` itself is
# kept for the solver call further down.
key = random.PRNGKey(0)
num_time_steps = 100

# plot() uses [-width, width] x [-height, height] (padded by 10%) as its axes.
width = 1
height = width

subkey, key = random.split(key)

E, colors = problems[problem_index]
num_points = E.shape[0]  # one vertex per row of the adjacency matrix

# Random 2-D directions, each rescaled to length 0.5 so every vertex starts
# on a circle of radius 0.5 about the origin; the tiny epsilon only guards the
# division against a zero-length draw.
directions = random.normal(subkey, [num_points, 2])
V = 0.5 * directions / (jnp.linalg.norm(directions, axis=-1, keepdims=True) + 1e-17)

# Edges as vertex-index pairs, the form plot() uses to index V.
edge_list = jnp.array(utils.adjacency_matrix_to_list(E))

In [None]:
def plot(V, edge_list, title="", colors=None):
    """Draw a 2-D graph layout: edges as line segments, vertices as dots.

    Parameters
    ----------
    V : array
        Vertex positions; column 0 is x and column 1 is y.
    edge_list : integer array
        Vertex-index pairs; ``V[edge_list]`` must yield one 2-point segment
        per edge (presumably shape (num_edges, 2) -- confirm against
        ``utils.adjacency_matrix_to_list``).
    title : str
        Figure title.
    colors : optional
        Forwarded to ``scatter(c=...)`` to colour the vertices.

    The axis limits come from the notebook-level ``width`` and ``height``
    globals, padded by 10%.
    """
    fig, ax = plt.subplots(figsize=(8, 8), dpi=200)
    pad = 1.1
    ax.set_xlim(-width * pad, width * pad)
    ax.set_ylim(-height * pad, height * pad)
    ax.set_title(title)

    ax.add_collection(LineCollection(V[edge_list]))
    # High zorder keeps the vertex markers drawn on top of the edges.
    ax.scatter(V[:, 0], V[:, 1], c=colors, zorder=1000)

    # pyplot.show() takes no figure argument: passing one raises TypeError on
    # standard backends (the inline backend silently treats it as its `close`
    # flag). Show, then close explicitly so per-step redraws from the solver
    # callback do not accumulate open figures.
    plt.show()
    plt.close(fig)

In [None]:
# Show the random starting layout before the solver runs.
plot(V, edge_list, title="", colors=colors)

In [None]:
# Linear cooling schedule: starts at 3*width/num_time_steps and ends at 0.
cooling_fn_lin = partial(
    cooling_schedules.linear,
    start_temperature=3 * width / num_time_steps,
    end_temperature=0.0,
)

In [None]:
def callback(V, E, time_step, initial_pause=10, frame_pause=0.000_001):
    """Solver hook: clear the cell output and redraw the current layout.

    Parameters
    ----------
    V : array
        Current vertex positions, forwarded to ``plot``.
    E : any
        Not used here; edges are drawn from the notebook-level ``edge_list``
        instead. Presumably accepted only to match the signature the solver
        calls its callback with.
    time_step : int
        Step index, shown in the figure title.
    initial_pause : float
        Seconds to sleep after drawing step 0, so the starting layout can be
        inspected before the animation proceeds.
    frame_pause : float
        Seconds to sleep after every frame.

    The defaults reproduce the values that were previously hard-coded.
    """
    clear_output(wait=True)
    plot(V, edge_list, f"time={time_step}", colors)
    if time_step == 0:
        time.sleep(initial_pause)
    # Near-zero sleep after every frame, presumably to give the front end a
    # chance to flush the redraw.
    time.sleep(frame_pause)

In [None]:
# Run the layout solver; `callback` is handed over so the solver can redraw
# intermediate layouts (presumably called as callback(V, E, time_step)).
new_V = frutcherman_reingold.apply_frutcherman_reingold(
    key,
    V,
    E,
    width,
    height,
    num_time_steps,
    cooling_fn_lin,
    callback,
)