In [None]:
## imports & configuration

# standard imports
import collections
import threading
import types

# custom imports
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d
import ipywidgets

# inline plots
%config InlineBackend.figure_formats = ['svg']
%matplotlib notebook

# jupyter theme
import jupyterthemes as jt
jt.jtplot.style()

In [None]:
## manage ipython magics

# imports
from IPython.core.magic import register_line_cell_magic

# register magic
@register_line_cell_magic
def IPif(line, cell=None):
    '''Only executes current cell if line evaluates to True.

    Usage (cell magic)::

        %%IPif <python expression>
        <cell body>

    ``line`` is evaluated with ``eval`` in the interactive user namespace; the
    cell body is executed with ``run_cell`` only when the result is truthy.
    A trailing ``# comment`` on the magic line is ignored by ``eval``.

    Used as a line magic (``%IPif <expr>``) there is no cell body: the
    expression is still evaluated, but nothing is run afterwards (instead of
    handing ``None`` to ``run_cell``).
    '''
    shell = get_ipython()
    # evaluate against the user namespace explicitly so the magic's own locals
    # (`line`, `cell`, `shell`) can never shadow notebook variables of that name;
    # NOTE: `line` is notebook-authored code, not external input
    if not eval(line, shell.user_ns):
        return
    # line-magic usage: there is no cell body to execute
    if cell is None:
        return
    shell.run_cell(cell)
del IPif

In [None]:
## custom poly-surface plotter

# imports
import matplotlib.colors
import matplotlib.tri
import mpl_toolkits.mplot3d.art3d

# shade a collection of polys
def shade_polyc(
    ax, polyc, verts,
    color=None, norm=None, vmin=None, vmax=None, lightsource=None, cmap=None, shade=None
):
    """
    Color the faces of an existing `Poly3DCollection` in place.

    Mirrors the face-coloring part of matplotlib's ``Axes3D.plot_surface`` so
    it can be re-applied whenever the collection's vertices change.

    Parameters
    ----------
    ax : `~mpl_toolkits.mplot3d.axes3d.Axes3D`
        Axes that provides the default color and the shading helpers.
    polyc : `~mpl_toolkits.mplot3d.art3d.Poly3DCollection`
        Collection whose face colors (or scalar array) are updated.
    verts : ndarray
        Polygon vertices indexed as ``verts[polygon, vertex, xyz]``; with a
        colormap, each face is colored by the mean z of its vertices.
    color
        Color of the surface patches. Defaults to the axes' next color-cycle
        color. Ignored when *cmap* is given.
    cmap
        A colormap (or colormap name) for the surface patches. It is applied
        to *polyc*.
    norm : Normalize
        An instance of Normalize to map values to colors.
    vmin, vmax : float, default: None
        Minimum and maximum value to map.
    shade : bool, default: ``cmap is None``
        Whether to shade the facecolors.  Shading is always disabled when
        *cmap* is specified.
    lightsource : `~matplotlib.colors.LightSource`
        The lightsource to use when *shade* is True.
    """
    # validate arguments
    assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D)
    assert isinstance(polyc, mpl_toolkits.mplot3d.art3d.Poly3DCollection)

    # set default for shade
    if shade is None:
        shade = cmap is None

    if cmap:
        # colormapped faces: scalar per face = mean z over its vertices
        avg_z = verts[:, :, 2].mean(axis=1)
        polyc.set_array(avg_z)
        # apply the requested colormap; previously it only took effect if it had
        # also been passed to the collection's constructor
        polyc.set_cmap(cmap)
        if vmin is not None or vmax is not None:
            polyc.set_clim(vmin, vmax)
        if norm is not None:
            polyc.set_norm(norm)
        return

    # solid color: only consume a color-cycle entry when a color is actually needed
    # (this function is called once per animation frame)
    if color is None:
        color = ax._get_lines.get_next_color()
    color = np.array(matplotlib.colors.to_rgba(color))

    if shade:
        normals = ax._generate_normals(verts)
        colset = ax._shade_colors(color, normals, lightsource)
    else:
        colset = color
    polyc.set_facecolors(colset)

# surface plot stolen from:
# https://github.com/matplotlib/matplotlib/blob/v3.5.0/lib/mpl_toolkits/mplot3d/axes3d.py#L1908-L2007
def poly_surface(ax, verts, vmin=None, vmax=None, **kwargs):
    """
    Add a surface built from explicit polygons to a 3D axes.

    Follows the coloring logic of matplotlib's ``Axes3D.plot_surface``
    (v3.5.0), but takes ready-made polygon vertices instead of X/Y/Z grids.

    Parameters
    ----------
    ax : `~mpl_toolkits.mplot3d.axes3d.Axes3D`
        Axes to draw into.
    verts : ndarray
        Polygon vertices indexed as ``verts[polygon, vertex, xyz]``, e.g. an
        ``(n_triangles, 3, 3)`` array for a triangulated surface.
    vmin, vmax : float, optional
        Limits of the color mapping. They are only used when ``cmap`` is given.
    **kwargs
        ``lightsource``, ``color``, ``norm`` and ``shade`` are passed to
        `shade_polyc`. ``cmap`` is passed to both `shade_polyc` and
        :class:`~mpl_toolkits.mplot3d.art3d.Poly3DCollection`. All other
        arguments go to :class:`~mpl_toolkits.mplot3d.art3d.Poly3DCollection`.

    Returns
    -------
    `~mpl_toolkits.mplot3d.art3d.Poly3DCollection`
        The collection, already added to *ax*.
    """

    assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D), "Axis must be `Axes3D` instance!"

    # split off the coloring options; `cmap` is read but left in kwargs so the
    # collection itself is also created with that colormap
    lightsource = kwargs.pop('lightsource', None)
    color = kwargs.pop('color', None)
    norm = kwargs.pop('norm', None)
    cmap = kwargs.get('cmap', None)
    shade = kwargs.pop('shade', cmap is None)
    # vmin/vmax are named parameters, so they can never be inside **kwargs;
    # they are forwarded as given (popping them from kwargs used to reset them to None)

    polyc = mpl_toolkits.mplot3d.art3d.Poly3DCollection(verts, **kwargs)

    shade_polyc(
        ax, polyc, verts,
        color=color, norm=norm, vmin=vmin, vmax=vmax,
        lightsource=lightsource, cmap=cmap, shade=shade,
    )

    ax.add_collection(polyc)

    return polyc

In [None]:
# method to visualize surface
def construct_surface(points, polys, figsize=None, **kwargs):
    """Open a new figure with a 3D axes showing ``points`` joined into ``polys``.

    ``points`` holds one xyz row per point; ``polys`` holds, per polygon, the
    indices of its corner points. Extra keyword arguments go to
    ``poly_surface``. Returns ``(ax, polyc)``: the 3D axes and the surface
    collection.
    """
    fig = plt.figure(figsize=figsize)
    axes3d = fig.add_subplot(projection='3d')

    # a brand-new axes: captured before drawing so autoscaling starts from scratch
    started_with_data = axes3d.has_data()
    surface = poly_surface(axes3d, points[polys], **kwargs)
    axes3d.auto_scale_xyz(*points.T, started_with_data)

    return axes3d, surface

-------------------------------

In [None]:
## enums

# method to generate enums
enum = lambda *args: types.SimpleNamespace(**{a: i for i, a in enumerate(args)})

# define some enums
SMODE = enum("SIMPLE", "STABLE") # spring modes
SIMDIM = enum("ONE_D", "TWO_D") # simulation dimensions
PRESET = enum("DEFAULT", "SIMPLE_CHAIN", "STABLE_CHAIN", "STABLE_SURFACE")

In [None]:
## global constants

# select preset
preset = PRESET.DEFAULT

# baseline simulation parameters (PRESET.DEFAULT uses these unchanged)
G = types.SimpleNamespace(
    integrator="verlet",       # integrator type: "verlet" or "euler"
    dt=1/10,                   # simulated time per frame
    n_steps=1500,              # sim-length in frames
    sim_dim=SIMDIM.TWO_D,      # simulation dimension: SIMDIM.ONE_D or SIMDIM.TWO_D
    n=20,                      # number of points (per dimension)
    spring_mode=SMODE.SIMPLE,  # spring configuration
    k0=0.1,                    # base spring constant
    damping=0.,                # velocity decay rate
)

# per-preset overrides, applied on top of the defaults above
preset_overrides = {
    PRESET.SIMPLE_CHAIN: dict(
        n_steps=1000, sim_dim=SIMDIM.ONE_D, n=20, spring_mode=SMODE.SIMPLE,
    ),
    PRESET.STABLE_CHAIN: dict(
        n_steps=4000, sim_dim=SIMDIM.ONE_D, n=100, spring_mode=SMODE.STABLE, damping=0.03,
    ),
    PRESET.STABLE_SURFACE: dict(
        n_steps=2000, sim_dim=SIMDIM.TWO_D, n=10, spring_mode=SMODE.STABLE, damping=0.03,
    ),
}
vars(G).update(preset_overrides.get(preset, {}))

In [None]:
%%IPif G.sim_dim == SIMDIM.ONE_D ## define initial conditions (1D)

# alias # of points
n = G.n

# initialize points along x: n cell-centred positions in (0, 1) with spacing 1/n
pos = np.zeros((n, 3))
pos[:, 0] = np.linspace(0, 1, n+1)[:n]+0.5/n
# initial z displacement; pos[:, 1] is all zeros here, so the second cosine factor is 1
pos[:, 2] = 0.5*np.cos(10*pos[:, 0])*np.cos(10*pos[:, 1])

# initialize velocities (system starts at rest)
vel = np.zeros((n, 3))

# other particle properties (unit masses)
masses = np.ones((n,))

# general observables, keyed by observable name, then by time
obs = collections.defaultdict(collections.OrderedDict)

# indexing method: grid coordinate -> point index, or -1 when outside the chain
i = lambda x: -1 if x < 0 or x >= n else x

# define grid spring network
# Row `node` of neighbors/distances/spring_ks describes node's springs. Slots without a
# real spring point back at `node` itself with rest length 0; such a slot has zero
# displacement and therefore contributes no force in compute_acc.
if G.spring_mode == SMODE.SIMPLE:
    # 2 slots per point (left and right neighbour); rest length = grid spacing 1/n
    neighbors = np.zeros((n, 2), dtype=int)
    distances = np.zeros((n, 2))
    spring_ks = np.zeros((n, 2))
    perturbations = np.array([-1, 1]).astype(int)
    for x in range(n):
        node = i(x)
        neighs = [i(x+s) for s in perturbations]
        # chain ends: a missing neighbour becomes a self-reference with rest length 0
        neighbors[node] = [node if index == -1 else index for index in neighs]
        distances[node] = [0 if index == -1 else 1/n for index in neighs]
        spring_ks[node, :] = G.k0
elif G.spring_mode == SMODE.STABLE:
    # dense (n, n) layout: column idx holds idx if node is connected to idx, else node itself
    neighbors = np.zeros((n, n), dtype=int)
    distances = np.zeros((n, n))
    spring_ks = np.zeros((n, n))
    perturbations = np.array([-1, 1]).astype(int)
    # even offsets -2*(n//2) .. 2*(n//2) for weak long-range "stabilization" springs;
    # the range includes 0, which only adds a zero-length self-spring (no force)
    stabilization = np.arange(-2*(n//2), 2*(n//2)+1, 2)
    for x in range(n):
        node = i(x)
        neighbors[node, :] = node
        for j in stabilization:
            if (idx := i(x+j)) == -1: continue
            neighbors[node, idx] = idx
            # rest length = initial separation along x (only the first coordinate is used)
            distances[node, idx] = np.linalg.norm(np.diff(pos[[node, idx], :1], axis=0)[0], axis=-1)
            spring_ks[node, idx] = G.k0/4 # stabilization is weak
        # nearest-neighbour springs at full strength
        for j in perturbations:
            if (idx := i(x+j)) == -1: continue
            neighbors[node, idx] = idx
            distances[node, idx] = np.linalg.norm(np.diff(pos[[node, idx], :1], axis=0)[0], axis=-1)
            spring_ks[node, idx] = G.k0
else:
    raise NotImplementedError(G.spring_mode)

# identify border points (the two ends of the chain)
borders = np.array([0, n-1])

In [None]:
%%IPif G.sim_dim == SIMDIM.TWO_D ## define initial conditions (2D)

# alias # of points
n = G.n

# initialize points on an n x n grid of cell-centred positions in the unit square;
# point index k corresponds to grid cell (x, y) = (k % n, k // n), i.e. k = x + n*y
pos = np.zeros((n**2, 3))
pos[:, 0] = np.tile(np.linspace(0, 1, n+1)[:n], n)
pos[:, 1] = np.tile(np.linspace(0, 1, n+1)[:n], n).reshape(n, n).T.ravel()
pos[:, :2] += 0.5/n
# initial height field (z displacement)
pos[:, 2] = 0.5*np.cos(10*pos[:, 0])*np.cos(10*pos[:, 1])

# initialize velocities (system starts at rest)
vel = np.zeros((n**2, 3))

# other particle properties (unit masses)
masses = np.ones((n**2,))

# general observables, keyed by observable name, then by time
obs = collections.defaultdict(collections.OrderedDict)

# indexing methods: grid coordinate (x, y) -> point index x + n*y, or -1 outside the grid
# (`i` for scalars, `i_arr` vectorised via numpy broadcasting)
i = lambda x, y: -1 if x < 0 or y < 0 or x >= n or y >= n else x+(n*y)
i_arr = lambda x, y: np.where(
    np.logical_or(np.logical_or(x < 0, y < 0), np.logical_or(x >= n, y >= n)),
    -np.ones_like(x), x+(n*y)
)
# sanity check: over the full grid, i_arr hits every point index exactly once
assert sorted(i_arr(np.arange(n), np.arange(n)[:, None]).ravel()) == list(range(n**2))

# define grid spring network
# Row `node` of neighbors/distances/spring_ks describes node's springs. Slots without a
# real spring point back at `node` itself with rest length 0; such a slot has zero
# displacement and therefore contributes no force in compute_acc.
if G.spring_mode == SMODE.SIMPLE:
    # 4 slots per point: axis-aligned nearest neighbours, rest length = grid spacing 1/n
    neighbors = np.zeros((n**2, 4), dtype=int)
    distances = np.zeros((n**2, 4))
    spring_ks = np.zeros((n**2, 4))
    # unit steps (1,0), (0,1), (-1,0), (0,-1) from powers of 1j; astype(int) truncates
    # the ~1e-16 floating-point residue of cos/sin towards zero
    perturbations = np.array([np.exp(1j*t*np.pi/2) for t in range(4)]).view('(2,)float').astype(int)
    for y in range(n):
        for x in range(n):
            node = i(x, y)
            neighs = [i(x+s, y+t) for s, t in perturbations]
            # grid edges: a missing neighbour becomes a self-reference with rest length 0
            neighbors[node] = [node if index == -1 else index for index in neighs]
            distances[node] = [0 if index == -1 else 1/n for index in neighs]
            spring_ks[node, :] = G.k0
elif G.spring_mode == SMODE.STABLE:
    # structural springs: the 4 axis-aligned unit steps (same construction as above)
    immediate = np.array([np.exp(1j*t*np.pi/2) for t in range(4)]).view('(2,)float').astype(int)
    # shear springs: the 4 diagonal steps (1,1), (-1,1), (-1,-1), (1,-1)
    diagonals = (np.array([np.exp(1j*np.pi*(0.25+t/2)) for t in range(4)])*2/np.sqrt(2)).round().view('(2,)float').astype(int)
    # stabilization springs: offsets (2a, 2b) with a+b even and |a|, |b| <= n;
    # this set includes (0, 0), which only produces a zero-length self-spring (no force)
    coords = np.mgrid[-n:n+1, -n:n+1].reshape(2, -1).T
    everyother = 2*coords[(coords[:, 0] + coords[:, 1]) % 2 == 0]
    # debug with: `plt.scatter(*everyother.T)`
    # dense (n**2, n**2) layout: column idx holds idx if node is connected to idx, else
    # node itself; the 'condense spring lists' cell compacts it afterwards
    neighbors = np.zeros((n**2, n**2), dtype=int)
    distances = np.zeros((n**2, n**2))
    spring_ks = np.zeros((n**2, n**2))
    for y in range(n):
        for x in range(n):
            node = i(x, y)
            neighbors[node, :] = node
            for s, t in everyother:
                if (idx := i(x+s, y+t)) == -1: continue
                neighbors[node, idx] = idx
                # rest length = initial separation in the xy-plane
                distances[node, idx] = np.linalg.norm(np.diff(pos[[node, idx], :2], axis=0)[0], axis=-1)
                spring_ks[node, idx] = G.k0/4 # stabilization
            for s, t in diagonals:
                if (idx := i(x+s, y+t)) == -1: continue
                neighbors[node, idx] = idx
                distances[node, idx] = np.linalg.norm(np.diff(pos[[node, idx], :2], axis=0)[0], axis=-1)
                spring_ks[node, idx] = G.k0/2 # shear
            for s, t in immediate:
                if (idx := i(x+s, y+t)) == -1: continue
                neighbors[node, idx] = idx
                distances[node, idx] = np.linalg.norm(np.diff(pos[[node, idx], :2], axis=0)[0], axis=-1)
                spring_ks[node, idx] = G.k0 # structural
else:
    raise NotImplementedError(G.spring_mode)

# (self-referencing slots are not dropped here; for SMODE.STABLE the dense arrays are
# condensed in the 'condense spring lists' cell)

# identify border points: bottom row (y=0), top row (y=n-1), then the left/right
# columns without their corners
borders = np.concatenate([
    i_arr(np.arange(n), 0  ),
    i_arr(np.arange(n), n-1),
    i_arr(0  , np.arange(1, n-1)),
    i_arr(n-1, np.arange(1, n-1))
])

# manually triangulate grid points: each grid cell (x, y) is split into two triangles
triangles = np.zeros((2*(n-1)**2, 3), dtype=int)
# NOTE: this rebinds the global `i` to an index map without bounds checking (every index
# used below is in range); later cells that call `i(x, y)` get this version
i = lambda x, y: x+(n*y)
j = 0  # next free row of `triangles`
for y in range(n-1):
    for x in range(n-1):
        triangles[j] = [i(x+0, y+0), i(x+1, y+0), i(x+0, y+1)]; j += 1
        triangles[j] = [i(x+0, y+1), i(x+1, y+0), i(x+1, y+1)]; j += 1

In [None]:
%%IPif G.sim_dim == SIMDIM.TWO_D and G.spring_mode == SMODE.STABLE and neighbors.shape == (n**2, n**2) ## condense spring lists

# gather, per node, only its real springs: in the dense layout every column that is not
# a spring points back at the node itself, so those entries are skipped
spring_partners = []
spring_lengths = []
spring_strengths = []
for node in range(n**2):
    partners = [other for other in neighbors[node] if other != node]
    spring_partners.append(partners)
    spring_lengths.append(distances[node, partners])
    spring_strengths.append(spring_ks[node, partners])

# size the compact arrays for the best-connected node
max_nbrs = max(map(len, spring_partners))
neighbors = np.zeros((n**2, max_nbrs), dtype=int)
distances = np.zeros((n**2, max_nbrs))
spring_ks = np.zeros((n**2, max_nbrs))

# write each node's springs into its leading slots; the unused tail points back at the
# node itself with zero rest length and zero stiffness
for node, (partners, lengths, strengths) in enumerate(zip(spring_partners, spring_lengths, spring_strengths)):
    count = len(partners)
    neighbors[node] = node
    neighbors[node, :count] = partners
    distances[node, :count] = lengths
    spring_ks[node, :count] = strengths

In [None]:
# compute accelerations
def compute_acc(pos, vel, t=None):
    """Compute per-particle accelerations from the spring network.

    Reads the module-level spring tables ``neighbors``, ``distances`` and
    ``spring_ks`` (one row per particle, one column per spring slot) and the
    ``masses`` vector.

    Parameters
    ----------
    pos, vel : ndarray, shape (n_particles, 3)
        Current positions and velocities.
    t : float, optional
        Simulation time of this state. When given (including ``t == 0``), the
        kinetic, spring and total energy of the state are stored in
        ``obs["kinetic"][t]``, ``obs["spring"][t]`` and ``obs["total"][t]``.

    Returns
    -------
    ndarray, shape (n_particles, 3)
        Accelerations (spring force divided by mass).
    """
    # displacement from each particle to each of its spring partners
    disps = pos[neighbors] - pos[:, None]
    dists = np.linalg.norm(disps, axis=-1)
    # stretch beyond rest length, halved: every spring is listed in both endpoints' rows,
    # so its summed energy below is k*stretch**2/4, matching the k*stretch/2 force
    # applied to each endpoint
    offsets = (dists - distances)/2
    # unit directions; zero-length slots (self-references) give 0/0 -> nan -> 0
    with np.errstate(divide="ignore", invalid="ignore"):
        units = np.nan_to_num(disps/dists[:, :, None])
    forces = np.sum((spring_ks*offsets)[:, :, None]*units, axis=1)

    # store energy as observable; `is not None` so the state at t == 0 is recorded too
    if t is not None:
        kinetic = 0.5*np.sum((masses[:, None]*vel)*vel)
        spring = 0.5*np.sum(spring_ks*(offsets**2))
        obs["kinetic"][t] = kinetic
        obs["spring" ][t] = spring
        obs["total"  ][t] = kinetic + spring

    # return accelerations
    return forces/masses[:, None]

In [None]:
# update particles state
def update_state(t, dt, integrator="verlet", damping=0., anchor_borders=False):
    """Advance the global ``pos``/``vel`` in place by one step of length ``dt``.

    Parameters
    ----------
    t : float
        Time label passed to ``compute_acc`` for the start-of-step state, so
        its energies are recorded in ``obs``.
    dt : float
        Step length.
    integrator : {"verlet", "euler"}
        ``"euler"``: semi-implicit Euler (velocity update, then position update
        with the new velocity). ``"verlet"``: drift-kick-drift (half position
        step, full velocity step, half position step).
    damping : float
        When > 0, velocities are scaled by ``(1 - damping)**dt`` after the step.
    anchor_borders : bool, default False
        When True, the points listed in the global ``borders`` keep their
        start-of-step x/y and have z reset to 0 after the step.

    Raises
    ------
    NotImplementedError
        For an unknown ``integrator``.
    """
    # load globals (updated in place)
    global pos, vel

    # snapshot only when needed for anchoring
    if anchor_borders:
        init_pos = pos.copy()

    # euler
    if integrator == "euler":
        acc = compute_acc(pos, vel, t)
        vel += acc*dt
        pos += vel*dt

    # verlet
    elif integrator == "verlet":
        # evaluated at the start-of-step state only to record observables for time t
        compute_acc(pos, vel, t)
        pos += vel*(dt/2)
        vel += compute_acc(pos, vel)*dt
        pos += vel*(dt/2)

    # not implemented
    else:
        raise NotImplementedError(f"The integrator '{integrator}' is not recognized!")

    # damp velocities
    if damping > 0: vel *= (1.-damping)**dt

    # optionally anchor edge positions
    if anchor_borders:
        pos[borders, :2] = init_pos[borders, :2]
        pos[borders, 2] = 0.

In [None]:
## method to process a timestep

# global position lock and cache
lock = threading.Lock()
state_cache = collections.OrderedDict()

# update positions given new time frame
def set_frame(f=0):
    """Show simulation frame ``f``, integrating it first if it is not cached yet.

    Frames are cached in ``state_cache`` as ``(pos, vel)`` copies. A missing
    frame is integrated in a single step of ``(f - prev)*G.dt`` from the latest
    cached frame ``prev < f``. When no such frame exists, the step starts at
    frame 0 from the current global ``pos``/``vel``; on the very first call
    these are the initial conditions.

    The global ``pos``/``vel`` are always set to *copies* of cached arrays.
    ``update_state`` modifies them in place, so handing out the cached arrays
    themselves would let the next integration overwrite the cache entry of the
    frame just shown.

    The work is serialised with ``lock``, held via ``with`` so an exception
    during integration or plotting cannot leave the lock acquired forever.
    """
    # globals that are rebound here
    global pos, vel

    with lock:
        # if frame is missing
        if f not in state_cache:
            # integrate from the latest earlier cached frame (restored from copies)
            earlier = [frame for frame in state_cache if frame < f]
            prev_frame = max(earlier, default=0)
            if earlier:
                cached_pos, cached_vel = state_cache[prev_frame]
                pos, vel = cached_pos.copy(), cached_vel.copy()

            # energies are recorded for the state being advanced, i.e. at prev_frame's time
            update_state(prev_frame*G.dt, (f-prev_frame)*G.dt, G.integrator, damping=G.damping)

            # populate cache
            state_cache[f] = pos.copy(), vel.copy()

        # retrieve copies of the cached state so later steps cannot mutate the cache
        cached_pos, cached_vel = state_cache[f]
        pos, vel = cached_pos.copy(), cached_vel.copy()

        # update plot with new positions
        update_plot(pos, sim_dim=G.sim_dim)

-------------------------------

In [None]:
# artist redrawn by update_plot; stays None until a visualization cell creates it
live_figure = None # set later by sim visualization
def update_plot(pos, sim_dim, rescale=True, _locals={"old_max": 0}):
    """Push the positions ``pos`` into the live plot artist ``live_figure``.

    Does nothing while ``live_figure`` is None.

    * ``SIMDIM.ONE_D``: ``live_figure`` is a line of x vs z. The view is
      re-limited only when the largest absolute plotted coordinate exceeds
      every value seen before; that running maximum persists across calls in
      the ``_locals`` default dict.
    * ``SIMDIM.TWO_D``: ``live_figure`` is a surface over the global
      ``triangles``. Its vertices and colormapped face colors are refreshed,
      and the 3D view only ever grows.

    Raises ``NotImplementedError`` for any other ``sim_dim``.
    """
    global live_axis, live_figure, triangles

    if live_figure is None:
        return

    if sim_dim == SIMDIM.ONE_D:
        live_figure.set_data(pos[:, 0], pos[:, 2])
        peak = np.abs(live_figure.get_data()).max()
        if rescale and peak > _locals["old_max"]:
            live_axis.relim()
            live_axis.autoscale_view()
            _locals["old_max"] = peak
    elif sim_dim == SIMDIM.TWO_D:
        tri_verts = pos[triangles]
        live_figure.set_verts(tri_verts)
        shade_polyc(live_axis, live_figure, tri_verts, cmap='viridis')
        if rescale:
            live_axis.auto_scale_xyz(*pos.T, True)
    else:
        # unsupported simulation dimension
        raise NotImplementedError(sim_dim)

In [None]:
## prepare ipywidget

# play button: one frame every 1000*G.dt ms, i.e. one simulated time unit per second
play = ipywidgets.Play(
    value=0,
    min=0,
    max=G.n_steps,
    step=1,
    interval=int(1000*G.dt),
    disabled=False,
)

# frame slider spanning the same range; `interactive` calls set_frame(f) whenever it moves
frame_slider = ipywidgets.IntSlider(min=play.min, max=play.max, step=play.step, value=play.value)
slider = ipywidgets.interactive(set_frame, f=frame_slider)

# keep play button and slider in lockstep on the frontend
ipywidgets.jslink((play, 'value'), (slider.children[0], 'value'))

# play button and slider side by side
player = ipywidgets.HBox([play, slider])

-------------------------------

In [None]:
%%timeit -r1 -n1 # pre-render simulation offline
frames = range(play.min, play.max+1, play.step)
for frame in frames:
    set_frame(frame)
# every frame the player can reach must now be cached
assert len(frames) == len(state_cache)

In [None]:
## plot observables

# figure with a 5-row grid: 1 row for the total energy, 4 rows for all observables
fig = plt.figure(figsize=(6, 4))
gs = mpl.gridspec.GridSpec(5, 1)

# top panel: total energy alone (x axis hidden, it is shared in meaning with the panel below)
ax_total = fig.add_subplot(gs[:1, 0])
ax_total.plot(*zip(*obs["total"].items()))
ax_total.xaxis.set_visible(False)
ax_total.set_ylabel("Total")

# bottom panel: every recorded observable against time
ax = fig.add_subplot(gs[1:, 0])
for series in obs.values():
    ax.plot(*zip(*series.items()))
ax.set_xlabel("Time")
ax.set_ylabel("Energy")
ax.legend([name[0].upper()+name[1:] for name in obs])

# display tight figure
fig.tight_layout()

-----------------------

In [None]:
%%IPif G.sim_dim == SIMDIM.ONE_D # visualize system (1D)
fig, live_axis = plt.subplots(figsize=(5, 3))
live_figure = live_axis.plot(pos[:, 0], pos[:, 2], 'o-')[0]
fig.tight_layout()

# `pos` holds whichever frame set_frame handled last (e.g. the final pre-rendered one);
# redraw the frame the slider actually shows so plot and widget agree from the start
set_frame(slider.children[0].value)

# widget to scrub through simulation
player

In [None]:
%%IPif G.sim_dim == SIMDIM.TWO_D # visualize system (2D)
live_axis, live_figure = construct_surface(pos, triangles, figsize=(6, 5), cmap="viridis", edgecolor=None)

# `pos` holds whichever frame set_frame handled last (e.g. the final pre-rendered one);
# redraw the frame the slider actually shows so plot and widget agree from the start
set_frame(slider.children[0].value)

# widget to scrub through simulation
player

-----------------------