In [None]:
from abc import ABC
from functools import partial
from functools import reduce
from io import BytesIO
from ipywidgets import interact
from IPython.display import display
from IPython.display import Image
from IPython.display import Markdown
from jupyter_renderer_widget import PyplotRenderer
from matplotlib import animation
from matplotlib import pyplot as plt
from matplotlib import rc
from matplotlib.lines import Line2D
from textwrap import dedent
from tqdm.notebook import tqdm
import daglet
import ipywidgets
import matplotlib.patches as patches
import numpy as np
import operator
import tempfile
import PIL
import random
import seaborn as sns

In [None]:
# Render matplotlib animations as embedded HTML5 video in notebook output.
rc('animation', html='html5')
# Apply seaborn's default plotting theme to all subsequent figures.
sns.set()

In [None]:
def draw_line(ax, x1, y1, x2, y2, **kwargs):
    """Add a straight segment from (x1, y1) to (x2, y2) to ``ax``.

    Defaults to a solid black line of width 2; any of ``linestyle``,
    ``linewidth`` or ``color`` (and any other ``Line2D`` keyword) may be
    overridden through ``kwargs``.
    """
    defaults = (('linestyle', '-'), ('linewidth', 2.), ('color', 'k'))
    for key, value in defaults:
        if key not in kwargs:
            kwargs[key] = value
    segment = Line2D((x1, x2), (y1, y2), **kwargs)
    ax.add_line(segment)
    
def init_ax(ax, lim=10, grid=False, hide_axes=True, aspect=16/9):
    """Configure ``ax`` as an equal-aspect canvas centred on the origin.

    The visible region is ``[-lim * aspect, lim * aspect] x [-lim, lim]``.

    Args:
        ax: matplotlib Axes to configure.
        lim: half-height of the visible region in data units.
        grid: only when ``hide_axes`` is False, draw major grid lines every
            ``lim / 5`` units and minor ones halfway between, without tick
            marks.
        hide_axes: hide the axes entirely; otherwise draw x/y axes through
            the origin with no tick labels.
        aspect: width / height ratio of the visible region.
    """
    ax.set_aspect('equal')
    x_max = lim * aspect
    y_max = lim

    if hide_axes:
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.axis('off')
    else:
        # Draw the axes as lines through the origin instead of a bounding box.
        ax.spines['bottom'].set_position('zero')
        ax.spines['left'].set_position('zero')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        if grid:
            major_spacing = lim / 5
            minor_spacing = major_spacing / 2
            # Round the extent up past the visible limit so the grid fills it.
            major_x_max = (int(x_max / major_spacing) + 1) * major_spacing
            major_y_max = (int(y_max / major_spacing) + 1) * major_spacing
            ax.set_xticks(np.arange(-major_x_max, major_x_max, major_spacing))
            ax.set_xticks(np.arange(-major_x_max, major_x_max, minor_spacing), minor=True)
            ax.set_yticks(np.arange(-major_y_max, major_y_max, major_spacing))
            ax.set_yticks(np.arange(-major_y_max, major_y_max, minor_spacing), minor=True)
            ax.grid(which='minor', alpha=0.2)
            ax.grid(which='major', alpha=0.5)
            # Keep the grid lines but hide the tick marks themselves.
            tick_groups = (
                ax.xaxis.get_major_ticks(),
                ax.xaxis.get_minor_ticks(),
                ax.yaxis.get_major_ticks(),
                ax.yaxis.get_minor_ticks(),
            )
            for ticks in tick_groups:
                for tick in ticks:
                    tick.tick1line.set_visible(False)

    ax.set_xlim(-x_max, x_max)
    ax.set_ylim(-y_max, y_max)

In [None]:
@interact(
    q0=(-1, 1, 0.1),
    q1=(-1, 1, 0.1),
    q2=(-1, 1, 0.1),
    qd0=(-2, 2, 0.1),
    qd1=(-2, 2, 0.1),
    qd2=(-2, 2, 0.1),
)
def f(
    q0=-0.6,
    q1=0.1,
    q2=0.3,
    qd0=1.2,
    qd1=-0.5,
    qd2=1.,
):
    """Draw a triple pendulum (right panel) and its velocity chain (left panel).

    Every link has length ``l = 4``. Each link hangs from the end of the
    previous one at an absolute angle ``q_i`` measured from straight down.
    The ``q`` sliders are in units of pi; the ``qd`` sliders are angular rates.

    Left panel: the per-link velocity contributions
    ``qd_i * l * (cos q_i, sin q_i)`` are chained tip-to-tail, so the final
    endpoint is the velocity of the last joint.

    Right panel: the links, plus red sample points on concentric circles
    (radii 4 down to 0) around the last joint. The black dot is the angle-0
    sample on the outer circle.
    """
    # Slider values are fractions of pi.
    q0 *= np.pi
    q1 *= np.pi
    q2 *= np.pi

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(16, 9))
    init_ax(ax0)
    init_ax(ax1)
    ax0.set_title('velocity')
    ax1.set_title('position')

    l = 4.  # link length
    # Joint positions.
    x0 = l * np.sin(q0)
    y0 = l * -np.cos(q0)
    x1 = x0 + l * np.sin(q1)
    y1 = y0 + l * -np.cos(q1)
    x2 = x1 + l * np.sin(q2)
    y2 = y1 + l * -np.cos(q2)
    # Joint velocities: time derivatives of the positions above.
    xd0 = qd0 * l * np.cos(q0)
    yd0 = qd0 * l * np.sin(q0)
    xd1 = xd0 + qd1 * l * np.cos(q1)
    yd1 = yd0 + qd1 * l * np.sin(q1)
    xd2 = xd1 + qd2 * l * np.cos(q2)
    yd2 = yd1 + qd2 * l * np.sin(q2)

    draw_line(ax1, 0, 0, x0, y0)
    draw_line(ax1, x0, y0, x1, y1)
    draw_line(ax1, x1, y1, x2, y2)
    draw_line(ax0, 0, 0, xd0, yd0)
    draw_line(ax0, xd0, yd0, xd1, yd1)
    draw_line(ax0, xd1, yd1, xd2, yd2)

    # 8 angles x 6 radii of sample points around the last joint.
    X = np.linspace(0, 1, 8, endpoint=False)
    Y = np.linspace(0, 1, 6)
    xx, yy = np.meshgrid(X, Y)
    pos_xx = (1 - yy) * 4 * np.cos(xx * 2 * np.pi) + x2
    pos_yy = (1 - yy) * 4 * np.sin(xx * 2 * np.pi) + y2
    ax1.plot(pos_xx.flatten(), pos_yy.flatten(), 'ro', ms=3.)
    ax1.plot(pos_xx[0][0], pos_yy[0][0], 'ko', ms=6.)


In [None]:
@interact(
    theta=(-3.1, 3.1, 0.1),
    theta_d=(-4., 4., 0.1),
    x=(-3., 3., 0.1),
    y=(-3., 3., 0.1),
    xd=(-6., 6., 0.2),
    yd=(-6., 6., 0.2),
    t=(-5., 5., 0.1),
)
def f(theta=0.5, theta_d=1.2, x=1.5, y=0.5, xd=-0.7, yd=1.5, t=0., show_taylor_approx=False):
    """Visualize a moving rigid frame and its velocity field.

    A square with corners at (+-1, +-1) (plus its centre) in local coordinates
    is placed by the homogeneous transform ``A``: rotation ``theta`` and
    translation ``(x, y)``. ``Ad`` is the time derivative of ``A`` for
    angular rate ``theta_d`` and linear velocity ``(xd, yd)``. Drawn:

    * dashed square, red points: the current pose ``A``;
    * solid square, green points: the exact pose after time ``t`` with
      constant rates;
    * optional blue points: the first-order estimate ``(A + Ad * t)``;
    * arrows: velocity of grid points rigidly attached to the frame;
    * black dot: the instantaneous centre of rotation (when ``theta_d != 0``).
    """
    fig, ax0 = plt.subplots(1, 1, figsize=(16, 9))
    init_ax(ax0, lim=5.)

    def draw_outline(points, **kwargs):
        # Join columns 0..3 of `points` into a closed quadrilateral.
        for i in range(4):
            j = (i + 1) % 4
            draw_line(ax0, points[0, i], points[1, i], points[0, j], points[1, j], **kwargs)

    # Square corners (columns 0-3) and centre (column 4), homogeneous local coords.
    local1 = np.array([
        [-1., -1., 1.],
        [1., -1., 1.],
        [1., 1., 1.],
        [-1., 1., 1.],
        [0., 0., 1.],
    ]).T
    c = np.cos(theta)
    s = np.sin(theta)
    A = np.array([
        [c, -s, x],
        [s, c, y],
        [0., 0., 1.],
    ])
    # d/dt of A; the homogeneous row is constant, so its derivative is zero.
    Ad = np.array([
        [theta_d * -s, theta_d * -c, xd],
        [theta_d * c, theta_d * -s, yd],
        [0., 0., 0.],
    ])
    theta_t = theta + theta_d * t
    A2 = np.array([
        [np.cos(theta_t), -np.sin(theta_t), x + xd * t],
        [np.sin(theta_t), np.cos(theta_t), y + yd * t],
        [0., 0., 1.],
    ])
    global1 = A.dot(local1)
    global2 = A2.dot(local1)
    global3 = (A + Ad * t).dot(local1)

    draw_line(ax0, 0, 0, global1[0, -1], global1[1, -1], linestyle='--', linewidth=1.)
    draw_outline(global1, linestyle='--')
    ax0.plot(global1[0, :], global1[1, :], 'ro')

    if show_taylor_approx:
        draw_line(ax0, 0, 0, global3[0, -1], global3[1, -1], linestyle='--', linewidth=0.5)
        draw_outline(global3, linestyle='--', linewidth=1.)
        ax0.plot(global3[0, :], global3[1, :], 'bo', ms=4.)

    draw_line(ax0, 0, 0, global2[0, -1], global2[1, -1], linestyle='--', linewidth=1.)
    draw_outline(global2)
    ax0.plot(global2[0, :], global2[1, :], 'go')

    # Velocity of a 15x15 grid of points rigidly attached to the frame:
    # world point -> local coords (inv(A)) -> velocity (Ad).
    xs = np.linspace(-5., 5., 15)
    ys = np.linspace(-5., 5., 15)
    xg, yg = np.meshgrid(xs, ys)
    grid_global = np.vstack((xg.ravel(), yg.ravel(), np.ones(xg.size)))
    grid_local = np.linalg.inv(A).dot(grid_global)
    vels = Ad.dot(grid_local)
    ax0.quiver(xg, yg, vels[0, :].reshape(xg.shape), vels[1, :].reshape(yg.shape))

    # Instantaneous centre of rotation: the local point v with zero velocity,
    #   theta_d * R'(theta) v + (xd, yd) = 0,  R' = [[-s, -c], [c, -s]].
    # R' is orthogonal, so v = -R'^T (xd, yd) / theta_d. This form is finite
    # for every theta; the centre only fails to exist for pure translation.
    if not np.isclose(theta_d, 0.):
        v1 = (s * xd - c * yd) / theta_d
        v2 = (c * xd + s * yd) / theta_d
        center = A.dot(np.array([v1, v2, 1.]))
        ax0.plot(center[0], center[1], 'ko', ms=7.)

In [None]:
def draw_line(ax, x1, y1, x2, y2, **kwargs):
    """Add a line segment from (x1, y1) to (x2, y2) to ``ax``.

    Unset ``linestyle``/``linewidth``/``color`` default to '-', 2. and 'k'.
    NOTE: same behaviour as the draw_line defined in an earlier cell; this
    definition shadows it from here on.
    """
    for option, fallback in (('linestyle', '-'), ('linewidth', 2.), ('color', 'k')):
        kwargs.setdefault(option, fallback)
    ax.add_line(Line2D((x1, x2), (y1, y2), **kwargs))


def draw_circle(ax, x, y, radius, **kwargs):
    """Add a circle of ``radius`` centred at (x, y) to ``ax``; colour defaults to black."""
    if 'color' not in kwargs:
        kwargs['color'] = 'k'
    ax.add_artist(plt.Circle((x, y), radius, **kwargs))


#def draw_box(ax, x, y, width, height, **kwargs):
#    kwargs.setdefault('color', 'k')
#    rect = patches.Rectangle((x, y), width, height, **kwargs)
#    ax.add_patch(rect)


def do_render_animation(draw_func, max_time_index, sample_interval, fps=25, fig=None, tqdm=None):
    """Build a ``FuncAnimation`` calling ``draw_func`` once per sampled time.

    Args:
        draw_func: callable ``draw_func(ax, time_index)`` drawing one frame,
            with ``time_index = i * sample_interval`` for frame ``i``. ``ax``
            is None when the caller supplies ``fig``.
        max_time_index: time span to cover; the frame count is
            ``max_time_index / sample_interval``, truncated.
        sample_interval: time step between consecutive frames.
        fps: playback frame rate.
        fig: optional existing figure. When omitted, a 16x9 figure with one
            axis-less Axes is created and cleared before every frame.
        tqdm: optional progress-bar factory (e.g. ``tqdm.notebook.tqdm``).

    Returns:
        The animation. Its figure is closed so the notebook does not also
        display a static copy.
    """
    ax = None
    if fig is None:
        fig, ax = plt.subplots(figsize=(16, 9), dpi=80)
        ax.axis('off')
    # Lay out the figure being animated, not pyplot's "current" figure
    # (they differ when the caller passes `fig`).
    fig.tight_layout()

    # Snap ratios that are integers up to float error (e.g. 0.7 / 0.1 ==
    # 6.999...) so the last frame is not dropped; otherwise truncate.
    ratio = max_time_index / sample_interval
    nearest = round(ratio)
    if np.isclose(ratio, nearest, rtol=0., atol=1e-9):
        frame_count = int(nearest)
    else:
        frame_count = int(ratio)
    # NOTE(review): total is twice the frame count, presumably because frames
    # are drawn twice during rendering -- confirm.
    progress = tqdm(total=frame_count * 2) if tqdm else None

    def animate(i):
        if ax is not None:
            ax.clear()
        time_index = i * sample_interval
        draw_func(ax, time_index)
        if progress:
            progress.update()
        return (fig,)

    anim = animation.FuncAnimation(
        fig,
        animate,
        frames=frame_count,
        interval=int(1000/fps),
        blit=True,
    )
    plt.close(fig)
    return anim

In [None]:
# Default gravitational acceleration for Scene.
DEFAULT_GRAVITY = 10.

# Homogeneous 3x1 columns: a position carries a trailing 1, a velocity a 0.
ZERO_POS = np.array([0., 0., 1.]).reshape(3, 1)
ZERO_VEL = np.array([0., 0., 0.]).reshape(3, 1)


def _homogeneous_columns(points):
    # (x, y) pairs -> 3xN array whose columns are (x, y, 1).
    return np.array([(px, py, 1.) for px, py in points]).T


# Unit square centred on the origin, corners counter-clockwise from (-.5, -.5).
CENTERED_SQUARE = _homogeneous_columns([(-0.5, -0.5), (0.5, -0.5), (0.5, 0.5), (-0.5, 0.5)])

# Unit square in the first quadrant, corners counter-clockwise from the origin.
QUAD1_SQUARE = _homogeneous_columns([(0., 0.), (1., 0.), (1., 1.), (0., 1.)])

# Unit zigzag from (0, 0) to (1, 0) with one full period.
ZIGZAG1 = _homogeneous_columns([(0., 0.), (0.25, 0.5), (0.75, -0.5), (1., 0.)])

# Unit zigzag from (0, 0) to (1, 0) with five full periods: peaks alternate
# +0.5 / -0.5 at x = 1/20, 3/20, ..., 19/20.
ZIGZAG5 = _homogeneous_columns(
    [(0., 0.)]
    + [((2 * i + 1) / 20., 0.5 if i % 2 == 0 else -0.5) for i in range(10)]
    + [(1., 0.)]
)

# (q, qd) state of a frame at rest at its zero coordinate.
ZERO_STATE = (0., 0.)


def coerce_position_vector(a):
    """Return ``a`` as a homogeneous 3x1 position column ``[[x], [y], [1.]]``.

    Accepted inputs:
      * a real scalar ``x`` (Python or NumPy int/float) -> ``(x, 0)``;
      * any array-like holding 2 or 3 numbers, in any shape -> its first two
        numbers are ``(x, y)``; a third number, if present, is replaced by 1.

    Raises:
        AssertionError: if an array-like does not hold 2 or 3 numbers.
    """
    # np.integer / np.floating cover NumPy scalars such as np.int64 and
    # np.float32, which are not subclasses of Python int/float.
    if isinstance(a, (int, float, np.integer, np.floating)):
        return np.array([[float(a), 0., 1.]]).T
    flat = np.asarray(a).flatten()
    assert len(flat) == 2 or len(flat) == 3, 'expected 2 or 3 coordinates, got {}'.format(len(flat))
    return np.array([[flat[0], flat[1], 1.]]).T


def get_scale_matrix(x, y):
    """Homogeneous 3x3 matrix scaling 2D points by ``x`` horizontally and ``y`` vertically."""
    return np.diag([x, y, 1.])


def get_rotation_matrix(angle):
    """Homogeneous 3x3 matrix rotating 2D points counter-clockwise by ``angle`` radians."""
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    rotation = np.identity(3)
    rotation[:2, :2] = [[cos_a, -sin_a], [sin_a, cos_a]]
    return rotation


def get_translation_matrix(offset):
    """Homogeneous 3x3 matrix translating 2D points by ``offset``.

    ``offset`` may be anything ``coerce_position_vector`` accepts.
    """
    column = coerce_position_vector(offset)
    translation = np.identity(3)
    translation[0, 2] = column[0, 0]
    translation[1, 2] = column[1, 0]
    return translation


def get_rotation_translation_matrix(angle, offset):
    """Homogeneous 3x3 matrix that rotates by ``angle`` radians, then translates by ``offset``.

    ``offset`` may be anything ``coerce_position_vector`` accepts.
    """
    column = coerce_position_vector(offset)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    top = (cos_a, -sin_a, column[0, 0])
    middle = (sin_a, cos_a, column[1, 0])
    bottom = (0., 0., 1.)
    return np.array([top, middle, bottom])


def normalize_vector(vec):
    """Return ``vec`` scaled to unit Euclidean norm.

    A vector whose norm is approximately zero (``np.isclose`` to 0) yields
    zeros of the same shape instead of dividing by ~0.
    """
    length = np.linalg.norm(vec)
    if not np.isclose(length, 0.):
        return vec / length
    return np.zeros(vec.shape)


def coerce_state_tuple(t):
    """Normalize a frame state to a ``(q, qd)`` tuple of floats.

    Accepted inputs:
      * ``None`` or ``ZERO_STATE`` -> ``ZERO_STATE``;
      * a real scalar ``q`` (Python or NumPy int/float) -> ``(q, 0.)``;
      * a non-empty sequence of at most two numbers -> padded with ``0.``.

    Raises:
        AssertionError: if a sequence is empty or longer than two.
    """
    if t is None or t is ZERO_STATE:
        return ZERO_STATE
    # np.integer / np.floating cover NumPy scalars such as np.float32, which
    # are not Python int/float subclasses and have no len().
    if isinstance(t, (int, float, np.integer, np.floating)):
        return (float(t), 0.)
    if isinstance(t, list):
        t = tuple(t)
    assert len(t) != 0
    assert len(t) <= 2
    q = float(t[0])
    qd = float(t[1]) if len(t) >= 2 else 0.
    return (q, qd)


class Connector(object):
    """Graph node produced by the ``|`` operator.

    ``a | b`` yields a Connector whose ``node`` is ``b`` and whose ``parents``
    list holds the connector(s) for ``a``. A tuple/list ``node`` is wrapped
    in a ``Group``.
    """
    def __init__(self, node=None, parents=None):
        if isinstance(node, (tuple, list)):
            node = Group(*node)
        self.node = node
        # A fresh list per instance; a literal ``[]`` default would be one
        # list shared by every Connector created without ``parents``.
        self.parents = [] if parents is None else parents

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        return Connector(other, [self])


class Group(object):
    """Ordered collection of nodes usable on either side of ``|``.

    ``Group(a, b)`` stores the tuple ``(a, b)``. A single list/tuple argument
    is stored as-is, so ``Group([a, b])`` keeps that list object.
    """
    def __init__(self, *nodes):
        # TODO: support generators too.
        unwrap = len(nodes) == 1 and isinstance(nodes[0], (list, tuple))
        self.nodes = nodes[0] if unwrap else nodes

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        wrapped = Connector(self)
        return Connector(other, [wrapped])


class StaticTransform(ABC):
    """Base class for transforms whose matrix does not depend on frame state.

    Subclasses implement ``get_xform_matrix()`` returning a 3x3 homogeneous
    matrix.
    """
    def get_xform_matrix(self):
        raise NotImplementedError()

    def __or__(self, other):
        wrapped = Connector(self)
        return Connector(other, [wrapped])


class Rotation(StaticTransform):
    """Fixed rotation by ``angle`` radians."""
    def __init__(self, angle):
        self.angle = float(angle)

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        return header + '\nangle={:.2f}'.format(self.angle)

    def get_xform_matrix(self):
        return get_rotation_matrix(self.angle)


class Translation(StaticTransform):
    """Fixed translation by ``offset`` (anything ``coerce_position_vector`` accepts)."""
    def __init__(self, offset):
        self.offset = coerce_position_vector(offset)

    def __repr__(self):
        dx = self.offset[0, 0]
        dy = self.offset[1, 0]
        return '{} [0x{:x}]\noffset=({:.2f}, {:.2f})'.format(type(self).__name__, id(self), dx, dy)

    def get_xform_matrix(self):
        return get_translation_matrix(self.offset)


class Frame(object):
    """Node in the scene's frame tree, parameterized by one coordinate ``q``.

    The base class is a fixed frame: its position matrix is the identity and
    its derivatives are zero. Subclasses override ``get_pos_matrix`` (3x3
    homogeneous transform from this frame's coordinates into its parent's),
    ``get_vel_matrix`` (its first derivative with respect to ``q``) and
    ``get_accel_matrix`` (its second derivative).

    Args:
        decals: decals drawn in this frame's coordinates.
        masses: ``Mass`` objects attached to this frame.
        frames: child frames.
        resistance: stored as a float. NOTE(review): presumably a damping
            coefficient on ``q`` used by the solver -- confirm.
        initial_state: initial ``(q, qd)``, normalized by ``coerce_state_tuple``.
    """
    def __init__(self, decals=None, masses=None, frames=None, resistance=0., initial_state=ZERO_STATE):
        # Fresh lists per instance; literal ``[]`` defaults would be shared by
        # every frame constructed without these arguments.
        self.decals = [] if decals is None else decals
        self.masses = [] if masses is None else masses
        self.frames = [] if frames is None else frames
        self.resistance = float(resistance)
        self.initial_state = coerce_state_tuple(initial_state)

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        return Connector(other, [Connector(self)])

    def get_pos_matrix(self, q):
        return np.identity(3)

    def get_vel_matrix(self, q):
        return np.zeros((3, 3))

    def get_accel_matrix(self, q):
        return np.zeros((3, 3))

    def draw(self, ax, state_map, xform_matrix, scale):
        """Draw this frame's decals, then recurse into its child frames.

        ``xform_matrix`` maps the parent's coordinates to plot coordinates. It
        is composed with this frame's position matrix at its ``q`` from
        ``state_map`` (``ZERO_STATE`` if the frame is absent).
        """
        state = state_map.get(self, ZERO_STATE)
        q = state[0]
        xform_matrix = xform_matrix.dot(self.get_pos_matrix(q))
        for decal in self.decals:
            decal.draw(ax, xform_matrix, scale)
        for frame in self.frames:
            frame.draw(ax, state_map, xform_matrix, scale)

    
class RotationalFrame(Frame):
    """Frame rotated by angle ``q`` (radians) about its origin.

    The origin sits at ``position`` in the parent frame's coordinates.
    """
    def __init__(self, position=ZERO_POS, decals=None, masses=None, frames=None, resistance=0., initial_state=ZERO_STATE):
        # Resolve ``None`` to fresh per-instance lists (no shared mutable defaults).
        super(RotationalFrame, self).__init__(
            [] if decals is None else decals,
            [] if masses is None else masses,
            [] if frames is None else frames,
            resistance,
            initial_state,
        )
        self.position = coerce_position_vector(position)

    def xform(self, xform_matrix, decals=None, masses=None, frames=None):
        """Return a copy with ``position`` mapped through ``xform_matrix``.

        ``decals``/``masses``/``frames`` replace this frame's lists when given;
        otherwise the copy shares them. Resistance and initial state are kept.
        """
        if decals is None:
            decals = self.decals
        if masses is None:
            masses = self.masses
        if frames is None:
            frames = self.frames
        return RotationalFrame(
            position=xform_matrix.dot(self.position),
            decals=decals,
            masses=masses,
            frames=frames,
            resistance=self.resistance,
            initial_state=self.initial_state,
        )

    def get_pos_matrix(self, q):
        """Rotate by ``q``, then translate to ``position``."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (c, -s, self.position[0, 0]),
            (s, c, self.position[1, 0]),
            (0., 0., 1.),
        ])

    def get_vel_matrix(self, q):
        """First derivative of ``get_pos_matrix`` with respect to ``q``."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (-s, -c, 0.),
            (c, -s, 0.),
            (0., 0., 0.),
        ])

    def get_accel_matrix(self, q):
        """Second derivative of ``get_pos_matrix`` with respect to ``q``."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (-c, s, 0.),
            (-s, -c, 0.),
            (0., 0., 0.),
        ])


class TrackFrame(Frame):
    """Frame sliding a distance ``q`` along a straight track.

    The track passes through ``position`` (parent coordinates) in direction
    ``angle`` (radians). The frame translates only; it does not rotate.
    """
    def __init__(self, position=ZERO_POS, angle=0., decals=None, masses=None, frames=None, resistance=0., initial_state=ZERO_STATE):
        # Resolve ``None`` to fresh per-instance lists (no shared mutable defaults).
        super(TrackFrame, self).__init__(
            [] if decals is None else decals,
            [] if masses is None else masses,
            [] if frames is None else frames,
            resistance,
            initial_state,
        )
        self.position = coerce_position_vector(position)
        self.angle = float(angle)

    def xform(self, xform_matrix, decals=None, masses=None, frames=None):
        """Return a copy with ``position`` mapped through ``xform_matrix``.

        The track angle is currently not updated from the matrix (see TODO).
        ``decals``/``masses``/``frames`` replace this frame's lists when given;
        otherwise the copy shares them.
        """
        angle = 0.  # TODO: detect from xform_matrix.
        if decals is None:
            decals = self.decals
        if masses is None:
            masses = self.masses
        if frames is None:
            frames = self.frames
        return TrackFrame(
            position=xform_matrix.dot(self.position),
            angle=self.angle + angle,
            decals=decals,
            masses=masses,
            frames=frames,
            resistance=self.resistance,
            initial_state=self.initial_state,
        )

    def get_pos_matrix(self, q):
        """Translation to ``position + q * (cos angle, sin angle)``."""
        c = np.cos(self.angle)
        s = np.sin(self.angle)
        return get_translation_matrix(
            self.position + np.array([[q * c, q * s, 0.]]).T
        )

    def get_vel_matrix(self, q):
        """First derivative of ``get_pos_matrix`` with respect to ``q``.

        The second derivative is zero, so the base ``get_accel_matrix`` applies.
        """
        return np.array([
            (0, 0, np.cos(self.angle)),
            (0, 0, np.sin(self.angle)),
            (0., 0., 0.),
        ])


class QuadraticFrame(Frame):
    """Frame sliding along a parabola: its origin is at ``position + (q, a * q**2)``.

    The frame translates only; it does not rotate.

    NOTE(review): the coefficients suggest a general conic
    ``a x**2 + b x y + c y**2 + d x + e y = f``, but only ``a`` is used by the
    position/velocity/acceleration matrices. ``b``..``f`` are stored and
    copied by ``xform`` but otherwise unused.
    """
    def __init__(self, position=ZERO_POS, decals=None, masses=None, frames=None, resistance=0., a=1., b=0., c=0., d=0., e=0., f=0., initial_state=ZERO_STATE):
        # Resolve ``None`` to fresh per-instance lists (no shared mutable defaults).
        super(QuadraticFrame, self).__init__(
            [] if decals is None else decals,
            [] if masses is None else masses,
            [] if frames is None else frames,
            resistance,
            initial_state,
        )
        self.position = coerce_position_vector(position)
        self.a = float(a)
        self.b = float(b)
        self.c = float(c)
        self.d = float(d)
        self.e = float(e)
        self.f = float(f)

    def xform(self, xform_matrix, decals=None, masses=None, frames=None):
        """Return a copy with ``position`` mapped through ``xform_matrix``.

        ``decals``/``masses``/``frames`` replace this frame's lists when given;
        otherwise the copy shares them. All coefficients are copied unchanged.
        """
        if decals is None:
            decals = self.decals
        if masses is None:
            masses = self.masses
        if frames is None:
            frames = self.frames
        return QuadraticFrame(
            position=xform_matrix.dot(self.position),
            decals=decals,
            masses=masses,
            frames=frames,
            resistance=self.resistance,
            initial_state=self.initial_state,
            a=self.a,
            b=self.b,
            c=self.c,
            d=self.d,
            e=self.e,
            f=self.f,
        )

    def get_pos_matrix(self, q):
        """Translation to ``position + (q, a * q**2)``."""
        return get_translation_matrix(
            self.position + np.array([[
                q,
                self.a * q * q,
                0.
            ]]).T
        )

    def get_vel_matrix(self, q):
        """First derivative of ``get_pos_matrix`` with respect to ``q``."""
        return np.array([
            (0, 0, 1),
            (0, 0, 2. * self.a * q),
            (0., 0., 0.),
        ])

    def get_accel_matrix(self, q):
        """Second derivative of ``get_pos_matrix`` with respect to ``q``."""
        return np.array([
            (0, 0, 0),
            (0, 0, 2 * self.a),
            (0., 0., 0.),
        ])


class Mass(object):
    """Point mass ``mass`` at ``position`` in its frame's coordinates, with a ``drag`` value."""
    def __init__(self, mass=1., position=ZERO_POS, drag=0.):
        self.mass = float(mass)
        self.position = coerce_position_vector(position)
        self.drag = float(drag)

    def __repr__(self):
        return '\n'.join((
            '{} [0x{:x}]'.format(type(self).__name__, id(self)),
            'mass={:.2f}'.format(self.mass),
            'drag={:.2f}'.format(self.drag),
        ))

    def __or__(self, other):
        wrapped = Connector(self)
        return Connector(other, [wrapped])

    def xform(self, xform_matrix):
        """Return a copy with ``position`` mapped through ``xform_matrix``."""
        moved = xform_matrix.dot(self.position)
        return Mass(mass=self.mass, position=moved, drag=self.drag)


class Decal(ABC):
    """Base class for shapes drawn in a frame's coordinates.

    Subclasses implement ``xform`` (return a copy with geometry mapped through
    a 3x3 homogeneous matrix) and ``draw`` (render onto an Axes through
    ``xform_matrix``, multiplying widths/radii by ``scale``).
    """
    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        wrapped = Connector(self)
        return Connector(other, [wrapped])

    def xform(self, xform_matrix):
        raise NotImplementedError()

    def draw(self, ax, xform_matrix, scale):
        raise NotImplementedError()
        

class LineDecal(Decal):
    """Straight segment from ``start_pos`` to ``end_pos`` in frame coordinates."""
    def __init__(self, end_pos, start_pos=ZERO_POS, linewidth=2., color='k'):
        self.start_pos = coerce_position_vector(start_pos)
        self.end_pos = coerce_position_vector(end_pos)
        self.linewidth = float(linewidth)
        # Both endpoints as one 3x2 block so `draw` maps them in one product.
        self._positions = np.hstack((self.start_pos, self.end_pos))
        self.color = color

    def xform(self, xform_matrix):
        """Return a copy with both endpoints mapped through ``xform_matrix``."""
        factor = 1.  # TODO: detect from xform_matrix.
        moved_end = xform_matrix.dot(self.end_pos)
        moved_start = xform_matrix.dot(self.start_pos)
        return LineDecal(
            end_pos=moved_end,
            start_pos=moved_start,
            linewidth=self.linewidth * factor,
            color=self.color,
        )

    def draw(self, ax, xform_matrix, scale):
        ends = xform_matrix.dot(self._positions)
        x_start, y_start = ends[0, 0], ends[1, 0]
        x_end, y_end = ends[0, 1], ends[1, 1]
        draw_line(ax, x_start, y_start, x_end, y_end, linewidth=self.linewidth * scale, color=self.color)


class CircleDecal(Decal):
    """Circle of ``radius`` centred at ``position`` (frame coordinates), drawn via ``draw_circle``."""
    def __init__(self, position=ZERO_POS, radius=1.):
        self.position = coerce_position_vector(position)
        self.radius = float(radius)

    def __repr__(self):
        return '{name} [0x{ident:x}]\nradius={r}'.format(
            name=type(self).__name__, ident=id(self), r=self.radius)

    def xform(self, xform_matrix):
        """Return a copy whose centre is mapped through ``xform_matrix``."""
        factor = 1.  # TODO: detect from xform_matrix.
        moved = xform_matrix.dot(self.position)
        return CircleDecal(position=moved, radius=self.radius * factor)

    def draw(self, ax, xform_matrix, scale):
        centre = xform_matrix.dot(self.position)
        draw_circle(ax, centre[0, 0], centre[1, 0], self.radius * scale)


class BoxDecal(Decal):
    """Rectangle of ``width`` x ``height`` in frame coordinates.

    Args:
        width, height: side lengths.
        position: where the box is placed. It is the centre when ``centered``
            is true, otherwise the corner mapped from the unit square's (0, 0).
        angle: applied as a rotation by ``-angle`` radians. NOTE(review): this
            is the opposite sign to ``get_rotation_matrix`` -- confirm intended.
        centered: use the origin-centred unit square instead of the
            first-quadrant one.
        solid: fill the box; otherwise draw its outline.
        linewidth: outline width (multiplied by ``scale`` in ``draw``).
        color: fill / outline colour.
    """
    def __init__(self, width=1., height=1., position=ZERO_POS, angle=0., centered=True, solid=True, linewidth=1., color='black'):
        self.width = float(width)
        self.height = float(height)
        self.position = coerce_position_vector(position)
        self.angle = float(angle)
        self.centered = bool(centered)
        self.solid = bool(solid)
        self.linewidth = float(linewidth)
        self.color = color
        # Bake scale, rotation and placement into the corners once.
        template = CENTERED_SQUARE if centered else QUAD1_SQUARE
        placement = get_rotation_translation_matrix(-angle, self.position)
        self._corner_positions = placement.dot(get_scale_matrix(width, height)).dot(template)

    def __repr__(self):
        return '{name} [0x{ident:x}]\nwidth={w}, height={h}'.format(
            name=type(self).__name__, ident=id(self), w=self.width, h=self.height)

    def xform(self, xform_matrix):
        """Return a copy whose anchor ``position`` is mapped through ``xform_matrix``."""
        extra_angle = 0.  # TODO: detect from xform_matrix.
        factor = 1.  # TODO: detect from xform_matrix.
        moved = xform_matrix.dot(self.position)
        return BoxDecal(
            width=self.width * factor,
            height=self.height * factor,
            position=moved,
            angle=self.angle + extra_angle,
            centered=self.centered,
            solid=self.solid,
            linewidth=self.linewidth * factor,
            color=self.color,
        )

    def draw(self, ax, xform_matrix, scale):
        corners = xform_matrix.dot(self._corner_positions)
        if self.solid:
            ax.fill(corners[0, :], corners[1, :], self.color)
            return
        count = self._corner_positions.shape[1]
        for start in range(count):
            end = (start + 1) % count
            draw_line(
                ax,
                corners[0, start],
                corners[1, start],
                corners[0, end],
                corners[1, end],
                linewidth=self.linewidth * scale,
                color=self.color,
            )


class Spring(object):
    """Spring joining ``position1`` on ``frame1`` to ``position2`` on ``frame2``.

    Args:
        frame1, frame2: frames the two ends are attached to.
        k, damping: stored as floats for the solver (presumably stiffness and
            damping coefficient -- NOTE(review): confirm against the solver).
        position1, position2: attachment points in frame1 / frame2 coordinates.
        visible: when false, ``draw`` renders nothing.
        linewidth: stroke width, multiplied by the draw ``scale``.
        zigzag_count: number of full zigzag periods between the two leads.
        zigzag_padding: length of the straight lead at each end.
        zigzag_width: peak-to-peak width of the zigzag.
    """
    def __init__(
        self,
        frame1,
        frame2,
        k=1.,
        damping=0.,
        position1=ZERO_POS,
        position2=ZERO_POS,
        visible=True,
        linewidth=1.,
        zigzag_count=5,
        zigzag_padding=0.8,
        zigzag_width=0.8,
    ):
        assert not isinstance(frame1, list), frame1  # TODO: better type sanity
        assert not isinstance(frame2, list), frame2  # TODO: better type sanity
        self.frame1 = frame1
        self.frame2 = frame2
        self.k = float(k)
        self.damping = float(damping)
        self.position1 = coerce_position_vector(position1)
        self.position2 = coerce_position_vector(position2)
        self.visible = visible
        self.linewidth = float(linewidth)
        self.zigzag_count = int(zigzag_count)
        self.zigzag_padding = float(zigzag_padding)
        self.zigzag_width = float(zigzag_width)

    def __repr__(self):
        return '{} [0x{:x}]\nk={}'.format(type(self).__name__, id(self), self.k)

    def __or__(self, other):
        return Connector(other, [self])

    def xform(self, xform_matrix1, xform_matrix2, frame1=None, frame2=None):
        """Return a copy with the attachment points mapped through the given
        matrices, optionally re-attached to ``frame1`` / ``frame2``."""
        return Spring(
            frame1=self.frame1 if frame1 is None else frame1,
            frame2=self.frame2 if frame2 is None else frame2,
            k=self.k,
            damping=self.damping,
            position1=xform_matrix1.dot(self.position1),
            position2=xform_matrix2.dot(self.position2),
            visible=self.visible,
            linewidth=self.linewidth,
            zigzag_count=self.zigzag_count,
            zigzag_padding=self.zigzag_padding,
            zigzag_width=self.zigzag_width,
        )

    @staticmethod
    def _zigzag_template(count):
        """Unit zigzag from (0, 0) to (1, 0) with ``count`` full periods.

        Returns homogeneous 3xN columns. Peaks alternate +0.5 / -0.5 at
        ``x = (2i + 1) / (4 * count)``; ``count=5`` equals ZIGZAG5 and
        ``count=1`` equals ZIGZAG1. ``count <= 0`` gives a straight line.
        """
        peaks = [
            ((2 * i + 1) / (4. * count), 0.5 if i % 2 == 0 else -0.5, 1.)
            for i in range(2 * count)
        ]
        return np.array([(0., 0., 1.)] + peaks + [(1., 0., 1.)]).T

    def draw(self, ax, root_xform_matrix, frame1_xform_matrix, frame2_xform_matrix, scale):
        """Draw the spring between its transformed attachment points.

        Ends closer than two paddings apart are joined by a plain segment.
        Otherwise straight leads of ``zigzag_padding`` connect each end to a
        ``zigzag_count``-period zigzag. ``root_xform_matrix`` is unused.
        """
        if not self.visible:
            return
        pos1 = frame1_xform_matrix.dot(self.position1)
        pos2 = frame2_xform_matrix.dot(self.position2)
        dist = np.linalg.norm(pos2 - pos1)
        # Use the spring's own width on both paths.
        linewidth = self.linewidth * scale
        # `dist == 0` also guards the divide below when zigzag_padding is 0.
        if dist == 0. or dist < self.zigzag_padding * 2 * scale:
            draw_line(ax, pos1[0, 0], pos1[1, 0], pos2[0, 0], pos2[1, 0], linewidth=linewidth)
            return
        normal = (pos2 - pos1) / dist
        zigzag_start = pos1 + normal * self.zigzag_padding * scale
        zigzag_end = pos2 - normal * self.zigzag_padding * scale
        # Map the unit template: x runs start -> end, y runs across the spring.
        basis1 = zigzag_end - zigzag_start
        basis2 = scale * self.zigzag_width * np.array([(normal[1, 0], -normal[0, 0], 0.)]).T
        basis3 = np.array([(0., 0., 1.)]).T
        zigzag_xform = get_translation_matrix(zigzag_start).dot(np.hstack((basis1, basis2, basis3)))
        zigzag_points = zigzag_xform.dot(self._zigzag_template(self.zigzag_count))
        points = np.hstack((pos1, zigzag_points, pos2))
        for i in range(points.shape[1] - 1):
            draw_line(
                ax,
                points[0, i],
                points[1, i],
                points[0, i + 1],
                points[1, i + 1],
                linewidth=linewidth,
            )

                
class Constraint(object):
    """Link between ``position1`` on ``frame1`` and ``position2`` on ``frame2``.

    Drawn as a straight line between the two points. NOTE(review): the
    constraint actually enforced is defined by the solver, not here.
    """
    def __init__(self, frame1, frame2, position1=ZERO_POS, position2=ZERO_POS, linewidth=1.):
        for frame in (frame1, frame2):
            assert not isinstance(frame, list), frame  # TODO: better type sanity
        self.frame1, self.frame2 = frame1, frame2
        self.position1 = coerce_position_vector(position1)
        self.position2 = coerce_position_vector(position2)
        self.linewidth = float(linewidth)

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        return Connector(other, [self])

    def xform(self, xform_matrix1, xform_matrix2, frame1=None, frame2=None):
        """Return a copy with the attachment points mapped through the given
        matrices, optionally re-attached to ``frame1`` / ``frame2``."""
        moved1 = xform_matrix1.dot(self.position1)
        moved2 = xform_matrix2.dot(self.position2)
        return Constraint(
            frame1=frame1 if frame1 else self.frame1,
            frame2=frame2 if frame2 else self.frame2,
            position1=moved1,
            position2=moved2,
            linewidth=self.linewidth,
        )

    def draw(self, ax, root_xform_matrix, frame1_xform_matrix, frame2_xform_matrix, scale):
        start = frame1_xform_matrix.dot(self.position1)
        end = frame2_xform_matrix.dot(self.position2)
        width = self.linewidth * scale
        draw_line(ax, start[0, 0], start[1, 0], end[0, 0], end[1, 0], linewidth=width)


class DrawOptions:
    """Axes settings that ``Scene.draw`` forwards to ``init_ax``."""
    def __init__(self, grid=False, hide_axes=True, aspect=16/9):
        self.grid, self.hide_axes, self.aspect = grid, hide_axes, aspect


class Scene:
    """A mechanical system: frame tree, decals, springs and constraints.

    Args:
        decals: decals drawn directly in scene coordinates.
        frames: root frames; descendants are reached through ``Frame.frames``.
        springs: ``Spring`` objects joining pairs of frames.
        constraints: ``Constraint`` objects joining pairs of frames.
        gravity: stored for the solver.
        draw_options: default ``DrawOptions`` for ``draw``.
    """
    def __init__(
        self,
        decals=None,
        frames=None,
        springs=None,
        constraints=None,
        gravity=DEFAULT_GRAVITY,
        draw_options=None,
    ):
        get_children = lambda x: x.frames
        # A frame's path is its parent's path followed by the frame itself.
        visit_path = lambda frame, paths: reduce(operator.add, paths, []) + [frame]
        # Fresh lists per instance instead of shared mutable ``[]`` defaults.
        self.decals = [] if decals is None else decals
        self.frames = [] if frames is None else frames
        self.springs = [] if springs is None else springs
        self.constraints = [] if constraints is None else constraints
        self.gravity = gravity
        # NOTE(review): path building below presumably needs parents before
        # children in this order -- confirm against daglet.toposort.
        self.sorted_frames = list(reversed(daglet.toposort(self.frames, get_children)))
        self.frame_parents_map = daglet.get_child_map(self.sorted_frames, get_children)
        assert all(len(x) <= 1 for x in self.frame_parents_map.values())  # frames should only have one parent
        self.frame_parent_map = {k: list(v)[0] if len(v) == 1 else None for k, v in self.frame_parents_map.items()}
        # Each frame -> list of frames from its root down to the frame itself.
        self.frame_path_map = daglet.transform_vertices(self.sorted_frames, self.frame_parents_map.get, visit_path)
        self.draw_options = draw_options or DrawOptions()

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def get_xform_matrix(self, state_map, frame, *deriv_frames):
        """Get matrix that transforms points from one frame into the global
        coordinate space, and optionally take partial derivatives to return
        a velocity transformation matrix or acceleration transformation matrix
        rather than a position transformation matrix.

        Each frame in ``deriv_frames`` (``None`` entries ignored) adds one
        derivative with respect to that frame's ``q``. The result is zero if
        any of them is not on the path to ``frame``.
        """
        path = self.frame_path_map[frame]
        deriv_frames = [x for x in deriv_frames if x is not None]
        if any(x not in path for x in deriv_frames):
            matrix = np.zeros((3, 3))
        else:
            matrix = np.identity(3)
            for frame2 in reversed(path):
                q, _ = state_map[frame2]
                deriv_order = sum(frame2 is x for x in deriv_frames)
                if deriv_order == 0:
                    matrix = frame2.get_pos_matrix(q).dot(matrix)
                elif deriv_order == 1:
                    matrix = frame2.get_vel_matrix(q).dot(matrix)
                elif deriv_order == 2:
                    matrix = frame2.get_accel_matrix(q).dot(matrix)
                else:
                    assert 0, 'Invalid xform matrix derivative order: {}'.format(deriv_order)
        return matrix

    def get_velocity_matrix(self, state_map, frame):
        """Sum over frames j on the path to ``frame`` of ``qd_j * dG/dq_j``."""
        mat = np.zeros((3, 3))
        for frame2 in self.frame_path_map[frame]:
            _, qd = state_map[frame2]
            mat += qd * self.get_xform_matrix(state_map, frame, frame2)
        return mat

    def draw(self, ax, state_map, xform_matrix=None, draw_options=None):
        """Draw the scene onto ``ax`` at the given state.

        ``xform_matrix`` (default: identity) maps scene coordinates to plot
        coordinates; ``draw_options`` defaults to the scene's own options.
        """
        if xform_matrix is None:
            xform_matrix = np.identity(3)
        draw_options = draw_options or self.draw_options
        init_ax(ax, grid=draw_options.grid, hide_axes=draw_options.hide_axes, aspect=draw_options.aspect)
        # Linear scale factor of the view; abs() keeps it real for reflecting
        # transforms (negative determinant), which would otherwise give NaN.
        scale = np.sqrt(abs(np.linalg.det(xform_matrix[:2, :2])))
        for decal in self.decals:
            decal.draw(ax, xform_matrix, scale)
        for frame in self.frames:
            frame.draw(ax, state_map, xform_matrix, scale)
        # Springs first, then constraints, each drawn between its two frames.
        for links in (self.springs, self.constraints):
            for link in links:
                frame1_xform_matrix = xform_matrix.dot(self.get_xform_matrix(state_map, link.frame1))
                frame2_xform_matrix = xform_matrix.dot(self.get_xform_matrix(state_map, link.frame2))
                link.draw(ax, xform_matrix, frame1_xform_matrix, frame2_xform_matrix, scale)

    def get_initial_state_map(self, randomize=False, seed=0):
        """Map every frame to its starting ``(q, qd)``.

        With ``randomize``, both values are drawn uniformly from [-pi, pi)
        using ``seed``. Otherwise each frame's ``initial_state`` is used.
        """
        frames = daglet.toposort(self.frames, lambda x: x.frames)
        if randomize:
            # Private generator: same sequence as random.seed(seed) would
            # produce, without resetting the global `random` state.
            rng = random.Random(seed)
            randpi = lambda: rng.random() * np.pi * 2 - np.pi
            state_map = {frame: (randpi(), randpi()) for frame in frames}
        else:
            state_map = {frame: frame.initial_state for frame in frames}
        return state_map

    
class NaiveSolver(object):
    """Reference solver that builds the equations of motion directly from the
    scene's transformation-matrix derivatives.

    The unknowns are the generalized accelerations of ``scene.sorted_frames``
    followed by one multiplier per entry of ``scene.constraints``. The system
    is solved with ``np.linalg.solve`` and integrated with classic RK4 in
    `tick`.
    """

    def __init__(self, scene):
        self.scene = scene

    def _make_state_map(self, qs, qds):
        """Zip per-frame positions and velocities into ``{frame: (q, qd)}``,
        ordered like ``scene.sorted_frames``."""
        return {frame: (q, qd) for frame, q, qd in zip(self.scene.sorted_frames, qs, qds)}

    def _solve(self, qs, qds, qfs):
        """Return the generalized accelerations for state ``(qs, qds)``.

        Args:
            qs, qds: generalized positions and velocities, ordered like
                ``scene.sorted_frames``.
            qfs: external generalized forces in the same order.

        Returns:
            Array of ``len(scene.sorted_frames)`` accelerations; the
            constraint multipliers are solved for but dropped.

        Raises:
            numpy.linalg.LinAlgError: if the assembled system is singular.
        """
        scene = self.scene
        sorted_frames = scene.sorted_frames
        state_map = self._make_state_map(qs, qds)
        # Row/column of each frame's unknown in the linear system.
        frame_index_map = {frame: index for index, frame in enumerate(sorted_frames)}
        nframes = len(sorted_frames)
        nconstraints = len(scene.constraints)
        ncoeffs = nframes + nconstraints
        a_mat = np.zeros((ncoeffs, ncoeffs))
        b_vec = np.zeros(ncoeffs)

        for k, frame_k in enumerate(sorted_frames):
            # Inertia and gravity terms from every frame whose path contains frame_k.
            for frame_i in sorted_frames:
                frame_i_path = scene.frame_path_map[frame_i]
                if frame_k not in frame_i_path:
                    continue
                submat2 = np.zeros((3, 3))
                submat3 = np.zeros((3, 3))
                submat4 = np.zeros((3, 3))
                dgi_dqk = scene.get_xform_matrix(state_map, frame_i, frame_k)
                for frame_j in frame_i_path:
                    j = frame_index_map[frame_j]
                    _, qdj = state_map[frame_j]
                    dgi_dqj = scene.get_xform_matrix(state_map, frame_i, frame_j)
                    d2gi_dqjdqk = scene.get_xform_matrix(state_map, frame_i, frame_j, frame_k)
                    qddj_coeff_mat = dgi_dqj.T.dot(dgi_dqk)
                    for mass in frame_i.masses:
                        a_mat[k, j] += mass.mass * mass.position.T.dot(qddj_coeff_mat).dot(mass.position)
                    for frame_l in frame_i_path:
                        _, qdl = state_map[frame_l]
                        d2gi_dqjdql = scene.get_xform_matrix(state_map, frame_i, frame_j, frame_l)
                        d2gi_dqkdql = scene.get_xform_matrix(state_map, frame_i, frame_k, frame_l)
                        submat2 += qdj * qdl * (d2gi_dqjdql.T.dot(dgi_dqk) + dgi_dqj.T.dot(d2gi_dqkdql))
                    submat3 += qdj * dgi_dqj.T
                    submat4 += qdj * d2gi_dqjdqk
                submat5 = submat2 - submat3.dot(submat4)
                for mass in frame_i.masses:
                    b_vec[k] -= mass.mass * mass.position.T.dot(submat5).dot(mass.position)
                    # Gravity uses row 1 (the y component) of the transform derivative.
                    b_vec[k] -= mass.mass * scene.gravity * dgi_dqk[1, :].dot(mass.position)

            # Spring stiffness and damping projected onto q_k.
            for spring in scene.springs:
                pos_mat1 = scene.get_xform_matrix(state_map, spring.frame1)
                pos_mat2 = scene.get_xform_matrix(state_map, spring.frame2)
                perturb_mat1 = scene.get_xform_matrix(state_map, spring.frame1, frame_k)
                perturb_mat2 = scene.get_xform_matrix(state_map, spring.frame2, frame_k)
                displacement = pos_mat1.dot(spring.position1) - pos_mat2.dot(spring.position2)
                perturb = perturb_mat1.dot(spring.position1) - perturb_mat2.dot(spring.position2)
                b_vec[k] -= spring.k * displacement.T.dot(perturb)
                vel_mat1 = scene.get_velocity_matrix(state_map, spring.frame1)
                vel_mat2 = scene.get_velocity_matrix(state_map, spring.frame2)
                velocity = vel_mat1.dot(spring.position1) - vel_mat2.dot(spring.position2)
                b_vec[k] -= spring.damping * velocity.T.dot(perturb)

            _, qdk = state_map[frame_k]
            resistance_force = -qdk * frame_k.resistance
            b_vec[k] += resistance_force + qfs[k]

        for c, constraint in enumerate(scene.constraints):
            coeff_index = nframes + c
            subframes = [constraint.frame1, constraint.frame2]
            positions = [constraint.position1, constraint.position2]
            signs = [1., -1.]
            pos_mats = [scene.get_xform_matrix(state_map, subframe) for subframe in subframes]
            displacement = pos_mats[0].dot(positions[0]) - pos_mats[1].dot(positions[1])
            vel_mats = [scene.get_velocity_matrix(state_map, subframe) for subframe in subframes]
            velocity = vel_mats[0].dot(positions[0]) - vel_mats[1].dot(positions[1])
            b_vec[coeff_index] -= velocity.T.dot(velocity)
            for subframe, position, sign in zip(subframes, positions, signs):
                subframe_path = scene.frame_path_map[subframe]
                for frame_k in subframe_path:
                    # Row/column of frame_k in the system. Its position within
                    # this path is a different number unless the path happens
                    # to be a prefix of sorted_frames.
                    k = frame_index_map[frame_k]
                    _, qdk = state_map[frame_k]
                    perturb_mat = scene.get_xform_matrix(state_map, subframe, frame_k)
                    disp_perturb = sign * displacement.T.dot(perturb_mat).dot(position)
                    a_mat[k, coeff_index] += 0.5 * disp_perturb
                    a_mat[coeff_index, k] += disp_perturb
                    for frame_l in subframe_path:
                        _, qdl = state_map[frame_l]
                        d2gc_dqkdql = scene.get_xform_matrix(state_map, subframe, frame_k, frame_l)
                        b_vec[coeff_index] += sign * qdk * qdl * displacement.T.dot(d2gc_dqkdql).dot(position)

        qdds = np.linalg.solve(a_mat, b_vec)
        return qdds[:nframes]

    def tick(self, state_map, delta_time, force_map=None):
        """Advance the state by ``delta_time`` with classic 4th-order Runge-Kutta.

        Args:
            state_map: ``{frame: (q, qd)}`` for every frame in ``sorted_frames``.
            delta_time: step size.
            force_map: optional ``{frame: generalized_force}``, held constant
                over the step; missing frames get 0.

        Returns:
            The new ``{frame: (q, qd)}`` map.
        """
        if force_map is None:
            force_map = {}
        sorted_frames = self.scene.sorted_frames
        qs0 = np.array([state_map[frame][0] for frame in sorted_frames])
        qds0 = np.array([state_map[frame][1] for frame in sorted_frames])
        qfs = np.array([force_map.get(frame, 0.) for frame in sorted_frames])

        def step(qs, qds):
            # (d qs, d qds) over delta_time: velocities and solved accelerations.
            return qds * delta_time, self._solve(qs, qds, qfs) * delta_time

        k1_qs, k1_qds = step(qs0, qds0)
        k2_qs, k2_qds = step(qs0 + 0.5 * k1_qs, qds0 + 0.5 * k1_qds)
        k3_qs, k3_qds = step(qs0 + 0.5 * k2_qs, qds0 + 0.5 * k2_qds)
        k4_qs, k4_qds = step(qs0 + k3_qs, qds0 + k3_qds)
        new_qs = qs0 + (k1_qs + 2 * k2_qs + 2 * k3_qs + k4_qs) / 6.
        new_qds = qds0 + (k1_qds + 2 * k2_qds + 2 * k3_qds + k4_qds) / 6.
        return self._make_state_map(new_qs, new_qds)

In [None]:
def is_binding(node):
    """Return True if ``node`` is a Connector that has parents and whose
    ``node`` attribute is itself a Connector.

    Always returns a bool. The bare ``and`` chain used to leak the
    ``node.parents`` value (e.g. ``[]``).
    """
    return bool(
        isinstance(node, Connector)
        and node.parents
        and isinstance(node.node, Connector)
    )


def walk_nodes(node):
    """Return the upstream nodes of ``node`` for graph rendering.

    - Connector: its parents, plus its ``node`` when that is not None.
    - Group: its member nodes.
    - Frame: its child frames, decals and masses.
    - Spring / Constraint: its two endpoint frames.
    - Anything else: no upstream nodes.
    """
    if isinstance(node, Connector):
        if node.node is None:
            return node.parents
        return node.parents + [node.node]
    if isinstance(node, Group):
        return node.nodes
    if isinstance(node, Frame):
        return node.frames + node.decals + node.masses
    if isinstance(node, (Spring, Constraint)):
        return [node.frame1, node.frame2]
    return []


def label_node(node):
    """Text label used for a node in rendered graphs (currently its repr)."""
    return '%r' % (node,)  # TODO: improve


def get_node_color(node):
    """Fill colour for a node in rendered graphs.

    Bindings (per ``is_binding``) are checked first. Otherwise the first
    matching node type in the table below wins, and unknown nodes are white.
    """
    if is_binding(node):
        return 'honeydew'
    type_colors = (
        (Connector, 'ivory'),
        (Group, 'orange'),
        (Frame, 'greenyellow'),
        (Spring, 'yellow'),
        (Constraint, 'tomato'),
        (Decal, 'khaki'),
        (Mass, 'lightblue'),
        (StaticTransform, 'gold'),
    )
    for node_type, color in type_colors:
        if isinstance(node, node_type):
            return color
    return 'white'


def _show_png(png_data, scale=0.5):
    """Display PNG bytes inline, at ``scale`` times the image's pixel width.

    Args:
        png_data: encoded PNG bytes.
        scale: display-width multiplier.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly rather than relying on another library having done so.
    import PIL.Image
    # Only the width is needed; close the decoder handle afterwards.
    with PIL.Image.open(BytesIO(png_data)) as pil_image:
        width = pil_image.width
    display(Image(png_data, width=width * scale, unconfined=True))


def show_node_graph(nodes, parent_func=walk_nodes, label_func=label_node, scale=0.6):
    """Render a node graph with daglet and display it inline as a PNG.

    Args:
        nodes: starting vertices.
        parent_func: maps a vertex to its parents.
        label_func: maps a vertex to its label text.
        scale: display-width multiplier passed to ``_show_png``.
    """
    rendered_png = daglet.make_graph(
        nodes,
        parent_func=parent_func,
        vertex_label_func=label_func,
        vertex_color_func=get_node_color,
    ).pipe(format='png')
    _show_png(rendered_png, scale)


def merge_node_parents_maps(*node_parents_maps):  # TODO: use List[Tuple[Node, Node]] edgelist instead of Dict[Node, Set]
    """Union several ``{node: set_of_parents}`` maps into a new map.

    Keys keep first-appearance order. Every value in the result is a fresh
    set, so the input maps are never mutated.
    """
    merged = {}
    for mapping in node_parents_maps:
        for node, parents in mapping.items():
            merged[node] = merged.get(node, set()) | parents
    return merged


def process_node(node, root_node_parents=None):
    """Flatten ``node`` (and what it references) into a connector-free parents map.

    Connectors are delegated to ``process_connector``. Groups and
    Springs/Constraints get their referenced nodes (``node.nodes`` or the two
    endpoint frames) as parents. Any other node gets ``root_node_parents`` as
    its parents.

    Args:
        node: the node to process.
        root_node_parents: set of nodes to attach as parents; defaults to an
            empty set.

    Returns:
        ``(nodes, node_parents_map)``: ``nodes`` is the set that callers feed
        forward as the parents of the next node in a chain;
        ``node_parents_map`` maps each processed node to its set of parents.
    """
    # None sentinel instead of a single shared mutable default set.
    if root_node_parents is None:
        root_node_parents = set()
    if isinstance(node, Connector):
        return process_connector(node, root_node_parents)

    node_parents_map = {node: set()}
    if isinstance(node, Group):
        parents = node.nodes
        nodes = root_node_parents
    elif isinstance(node, (Spring, Constraint)):
        parents = [node.frame1, node.frame2]
        nodes = root_node_parents
    else:
        parents = []
        nodes = {node}
        # Copy so the returned map never aliases the caller's set.
        node_parents_map = {node: set(root_node_parents)}

    for node2 in parents:
        more_nodes, more_node_parents_map = process_node(node2, root_node_parents)
        node_parents_map = merge_node_parents_maps(node_parents_map, more_node_parents_map)
        node_parents_map[node] |= more_nodes
    return nodes, node_parents_map
    

def process_connector(connector, root_node_parents=None):
    """Flatten a Connector into a parents map, dropping the connector itself.

    The parent connectors are processed first and their output node sets are
    unioned. With no parents, ``root_node_parents`` is used instead. The
    connector's ``node`` (if truthy) is then processed with that set as its
    parents.

    Args:
        connector: a Connector instance.
        root_node_parents: parents for the chain's first real node(s);
            defaults to an empty set.

    Returns:
        ``(nodes, node_parents_map)`` as for ``process_node``.
    """
    assert isinstance(connector, Connector)
    # None sentinel instead of a single shared mutable default set.
    if root_node_parents is None:
        root_node_parents = set()
    nodes = set()
    node_parents_map = {}
    if connector.parents:
        for parent in connector.parents:
            parent_nodes, parent_node_parents_map = process_connector(parent, root_node_parents)
            nodes |= parent_nodes
            node_parents_map = merge_node_parents_maps(node_parents_map, parent_node_parents_map)
    else:
        nodes = root_node_parents
    if connector.node:
        nodes, rhs_node_parents_map = process_node(connector.node, nodes)
        node_parents_map = merge_node_parents_maps(node_parents_map, rhs_node_parents_map)
    return nodes, node_parents_map


def transform_connectors(leaf_nodes):
    """Strip every Connector out of the node graph reachable from ``leaf_nodes``.

    Each leaf is flattened with ``process_node``, which stitches nodes
    together in the order the connectors indicate. The per-leaf results are
    unioned.

    Returns:
        A new ``{node: set_of_parents}`` map with no Connector nodes.
    """
    per_leaf_maps = [process_node(leaf)[1] for leaf in leaf_nodes]
    return merge_node_parents_maps(*per_leaf_maps)


def get_parent_frame_xform(parent, frame, xform_matrix):
    """Resolve the (frame, local transform) a child inherits from ``parent``.

    If ``parent`` is itself a Frame, the child lives in that frame with an
    identity transform. Otherwise the parent's ``frame`` and ``xform_matrix``
    pass through unchanged.
    """
    if not isinstance(parent, Frame):
        return frame, xform_matrix
    return parent, np.identity(3)

    
def get_node_context(node, parent_contexts):
    """Compute ``(node, frame, xform_matrix)`` for ``node`` from its parents' contexts.

    ``parent_contexts`` are presumably the ``(node, frame, xform_matrix)``
    triples this function already produced for the parents (via
    ``daglet.transform_vertices``).

    - Spring/Constraint: exactly two parents; ``frame`` and ``xform_matrix``
      become two-element lists, one entry per endpoint.
    - Group, or a node with no parents: no frame, identity transform.
    - Otherwise: exactly one parent, resolved via ``get_parent_frame_xform``.

    A StaticTransform additionally right-multiplies its own matrix in.
    """
    if isinstance(node, (Spring, Constraint)):
        assert len(parent_contexts) == 2
        endpoints = [get_parent_frame_xform(*context) for context in parent_contexts]
        frame = [endpoints[0][0], endpoints[1][0]]
        xform_matrix = [endpoints[0][1], endpoints[1][1]]
    elif isinstance(node, Group) or not parent_contexts:
        frame, xform_matrix = None, np.identity(3)
    else:
        assert len(parent_contexts) == 1
        frame, xform_matrix = get_parent_frame_xform(*parent_contexts[0])

    if isinstance(node, StaticTransform):
        xform_matrix = xform_matrix.dot(node.get_xform_matrix())
    return node, frame, xform_matrix


def infer_frames_and_xforms(node_parents_map):
    """Work out each node's owning frame and its accumulated local transform.

    Runs ``get_node_context`` over the graph with ``daglet.transform_vertices``.

    Returns:
        ``(node_frame_map, node_xform_map)``: the frame each node belongs to,
        and the node's transformation matrix within that frame's local
        coordinate space.
    """
    node_context_map = daglet.transform_vertices(node_parents_map.keys(), node_parents_map.get, get_node_context)
    node_frame_map = {}
    node_xform_map = {}
    for node, (_, owner_frame, local_xform) in node_context_map.items():
        node_frame_map[node] = owner_frame
        node_xform_map[node] = local_xform
    return node_frame_map, node_xform_map


def group_frame_nodes(node_frame_map):
    """Invert ``node_frame_map`` into ``{frame: set_of_nodes}``.

    Groups and StaticTransforms are dropped. Springs and Constraints are
    filed under ``None`` (scene level), since they span two frames.
    """
    frame_nodes_map = {}
    for node, frame in node_frame_map.items():
        if isinstance(node, (Group, StaticTransform)):
            continue
        owner = None if isinstance(node, (Spring, Constraint)) else frame
        frame_nodes_map.setdefault(owner, set()).add(node)
    return frame_nodes_map


def get_frame_graph(node_parents_map, node_frame_map, frame_nodes_map):
    """Build the frame hierarchy.

    A frame's parent frames are the frames owning its nodes' parents (other
    than itself). Springs and Constraints are ignored.

    Returns:
        ``{frame: child_frames}`` as produced by ``daglet.get_child_map``.
    """
    frame_parents_map = {}
    for frame, nodes in frame_nodes_map.items():
        parent_frames = frame_parents_map.setdefault(frame, set())
        for node in nodes:
            if isinstance(node, (Spring, Constraint)):
                continue
            for parent in node_parents_map[node]:
                parent_frame = node_frame_map[parent]
                if parent_frame is not frame:
                    parent_frames.add(parent_frame)
    return daglet.get_child_map(frame_parents_map.keys(), frame_parents_map.get)


def xform_decals(frame_nodes_map, node_xform_map, parent_frame):
    """Return the Decals owned by ``parent_frame``, each transformed by its local matrix."""
    return [
        node.xform(node_xform_map[node])  # TODO: preserve mapping
        for node in frame_nodes_map[parent_frame]
        if isinstance(node, Decal)
    ]


def xform_masses(frame_nodes_map, node_xform_map, parent_frame):
    """Return the Masses owned by ``parent_frame``, each transformed by its local matrix."""
    return [
        node.xform(node_xform_map[node])  # TODO: preserve mapping
        for node in frame_nodes_map[parent_frame]
        if isinstance(node, Mass)
    ]


def construct_frame(frame_nodes_map, node_xform_map, parent_frame, frames):
    """Build the concrete frame for ``parent_frame``.

    Its transformed decals and masses are attached, plus the already-built
    child ``frames``. For the scene level (``parent_frame is None``) only
    ``frames`` is returned.
    """
    decals = xform_decals(frame_nodes_map, node_xform_map, parent_frame)
    masses = xform_masses(frame_nodes_map, node_xform_map, parent_frame)
    if parent_frame is None:
        return frames
    return parent_frame.xform(node_xform_map[parent_frame], decals=decals, masses=masses, frames=frames)


def xform_springs(frame_nodes_map, node_frame_map, node_xform_map, frame_map):
    """Return the scene-level Springs, rebuilt against their endpoints.

    Each spring gets its two endpoint matrices and the constructed endpoint
    frames.
    """
    result = []
    for node in frame_nodes_map[None]:
        if not isinstance(node, Spring):
            continue
        matrix1, matrix2 = node_xform_map[node]
        source_frame1, source_frame2 = node_frame_map[node]
        result.append(node.xform(matrix1, matrix2, frame_map[source_frame1], frame_map[source_frame2]))
    return result


def xform_constraints(frame_nodes_map, node_frame_map, node_xform_map, frame_map):
    """Return the scene-level Constraints, rebuilt against their endpoints.

    Each constraint gets its two endpoint matrices and the constructed
    endpoint frames.
    """
    result = []
    for node in frame_nodes_map[None]:
        if not isinstance(node, Constraint):
            continue
        matrix1, matrix2 = node_xform_map[node]
        source_frame1, source_frame2 = node_frame_map[node]
        result.append(node.xform(matrix1, matrix2, frame_map[source_frame1], frame_map[source_frame2]))
    return result


def do_construct_scene(frame_nodes_map, node_frame_map, node_xform_map, frame_map, scene_kwargs={}):
    """Assemble the final Scene.

    The Scene is built from the top-level frames, the scene-level decals,
    springs and constraints, and any extra ``scene_kwargs``.
    """
    return Scene(
        frames=frame_map[None],
        decals=xform_decals(frame_nodes_map, node_xform_map, parent_frame=None),
        springs=xform_springs(frame_nodes_map, node_frame_map, node_xform_map, frame_map),
        constraints=xform_constraints(frame_nodes_map, node_frame_map, node_xform_map, frame_map),
        **scene_kwargs
    )


# Default display scale for the compilation-stage graph images; used as the
# default `info_scale` of construct_scene and the `scale` default of the
# display_compilation_stage* helpers.
_DEFAULT_SCALE = 0.6


def construct_scene(leaf_nodes, show_info=False, info_scale=_DEFAULT_SCALE, **scene_kwargs):
    """Compile a graph of scene nodes into a Scene.

    Pipeline stages:
    1. connector stripping
    2. frame/xform inference
    3. frame grouping
    4. frame ordering
    5. frame construction
    Then final Scene assembly and a spring sanity check.

    Args:
        leaf_nodes: a single node, or a list/tuple of nodes, to compile.
        show_info: if true, render every stage's diagnostic graph after a
            successful compile.
        info_scale: display scale for the diagnostic graphs.
        **scene_kwargs: extra keyword arguments forwarded to ``Scene``.

    Returns:
        The constructed Scene.

    Raises:
        Any exception from the pipeline is re-raised, after a compilation-error
        banner and the diagnostics of every completed stage are displayed.
    """
    if isinstance(leaf_nodes, tuple):
        leaf_nodes = list(leaf_nodes)
    elif not isinstance(leaf_nodes, list):
        leaf_nodes = [leaf_nodes]

    # Deferred diagnostic renderers, one per completed stage. They are
    # replayed on failure, or after success when `show_info` is set.
    info_func_stack = []
    try:
        node_parents_map = transform_connectors(leaf_nodes)
        info_func_stack.append(lambda: display_compilation_stage1(leaf_nodes, scale=info_scale))

        node_frame_map, node_xform_map = infer_frames_and_xforms(node_parents_map)
        info_func_stack.append(lambda: display_compilation_stage2(node_parents_map, scale=info_scale))

        frame_nodes_map = group_frame_nodes(node_frame_map)
        info_func_stack.append(lambda: display_compilation_stage3(node_parents_map, node_frame_map, node_xform_map, scale=info_scale))

        frame_children_map = get_frame_graph(node_parents_map, node_frame_map, frame_nodes_map)
        info_func_stack.append(lambda: display_compilation_stage4(frame_children_map, scale=info_scale))

        frame_map = daglet.transform_vertices(
            frame_children_map.keys(),
            frame_children_map.get,
            partial(
                construct_frame,
                frame_nodes_map,
                node_xform_map,
            ),
        )
        info_func_stack.append(lambda: display_compilation_stage5(frame_nodes_map, frame_children_map, scale=info_scale))

        # Final assembly and the sanity check are inside the try as well, so
        # failures here also get the compilation diagnostics.
        scene = do_construct_scene(frame_nodes_map, node_frame_map, node_xform_map, frame_map, scene_kwargs)

        # Every Spring in the flattened graph must have been filed at scene level.
        scene_springs = {x for x in frame_nodes_map[None] if isinstance(x, Spring)}
        graph_springs = {x for x in node_parents_map.keys() if isinstance(x, Spring)}
        assert scene_springs == graph_springs, 'Springs were lost during scene compilation'

    except Exception as error:
        display_compilation_error(error)
        for info_func in info_func_stack:
            info_func()
        raise

    if show_info:
        for info_func in info_func_stack:
            info_func()

    return scene


def label_node_with_context(node_frame_map, node_xform_map, node):
    """Multi-line graph label: the node's label, its frame's repr and its local transform."""
    owner_frame = node_frame_map[node]
    local_xform = node_xform_map[node]
    return f'{label_node(node)}\nframe: {owner_frame!r}\nxform: {local_xform}\n'


def display_markdown(md):
    """Render a (possibly indented, triple-quoted) markdown snippet inline."""
    unindented = dedent(md)
    display(Markdown(unindented))


def display_header(text):
    """Render a horizontal rule followed by a bold level-3 heading ``text:``."""
    header_md = f'''
        <hr />

        ### **{text}:**
    '''
    display_markdown(header_md)


def display_compilation_error(error):
    """Show a red "Compilation error" banner and the error message.

    A note is also shown that the stack trace follows below.
    """
    banner_md = '''
        <hr />

        ## <font color='red'>Compilation error</font>
    '''
    display_markdown(banner_md)
    display(str(error))
    display(Markdown('<br />\n\n_(Stack trace follows below...)_'))


def display_compilation_stage1(leaf_nodes, scale=_DEFAULT_SCALE):
    """Render the raw connection graph reachable from ``leaf_nodes``."""
    display_header('Stage 1 - Connection graph')
    show_node_graph(leaf_nodes, parent_func=walk_nodes, scale=scale)


def display_compilation_stage2(node_parents_map, scale=_DEFAULT_SCALE):
    """Render the connector-free graph produced by ``transform_connectors``."""
    display_header('Stage 2 - Connection reordering')
    graph = node_parents_map
    show_node_graph(graph.keys(), parent_func=graph.get, scale=scale)


def display_compilation_stage3(node_parents_map, node_frame_map, node_xform_map, scale=_DEFAULT_SCALE):
    """Render the connector-free graph with each node labelled by its inferred
    frame and local transform matrix.

    Args:
        node_parents_map: ``{node: set_of_parents}``.
        node_frame_map: ``{node: frame}``.
        node_xform_map: ``{node: local xform matrix}``.
        scale: display scale for the rendered image.
    """
    display_header('Stage 3 - Frame and local-xform inference')
    show_node_graph(
        node_parents_map.keys(),
        node_parents_map.get,
        partial(label_node_with_context, node_frame_map, node_xform_map),
        scale=scale,  # was previously dropped, so `info_scale` had no effect here
    )


def display_compilation_stage4(frame_children_map, scale=_DEFAULT_SCALE):
    """Render the frame hierarchy (frame -> child frames)."""
    display_header('Stage 4 - Frame ordering')
    graph = frame_children_map
    show_node_graph(graph.keys(), parent_func=graph.get, scale=scale)


def display_compilation_stage5(frame_nodes_map, frame_children_map, scale=_DEFAULT_SCALE):
    """Render frames linked both to their owned nodes and to their child frames."""
    display_header('Stage 5 - Scene construction')
    combined = merge_node_parents_maps(frame_nodes_map, frame_children_map)
    show_node_graph(combined.keys(), parent_func=combined.get, scale=scale)

In [None]:
class ImprovedSolver(object):
    """Equation-of-motion solver built on global-space velocity and
    acceleration field matrices.

    For a state ``(qs, qds)`` ordered like ``scene.sorted_frames``, it
    precomputes per-frame local->global position matrices and global
    velocity/acceleration field matrices. It then assembles a linear system
    ``a_mat @ qdds = b_vec`` for the generalized accelerations, and
    integrates it with RK4 (`tick`) or forward Euler (`tick_simple`).

    Contributions:
    - masses: inertia, ``drag`` and gravity;
    - frames: ``resistance``;
    - springs: stiffness ``k`` and ``damping``.
    ``scene.constraints`` is not read by this solver.
    """

    def __init__(self, scene):
        self.scene = scene

    def _make_state_map(self, qs, qds):
        """Zip per-frame positions and velocities into ``{frame: (q, qd)}``."""
        return {frame: (q, qd) for frame, q, qd in zip(self.scene.sorted_frames, qs, qds)}

    def _get_parent_pos_mat(self, pos_mat_map, frame):
        """Local->global position matrix of ``frame``'s parent (identity for a root frame)."""
        parent = self.scene.frame_parent_map[frame]
        return pos_mat_map[parent] if parent is not None else np.identity(3)

    def _get_pos_mat_map(self, state_map):
        """Determine all the local->global position transformation matrices,
        indexed by frame.

        Relies on ``sorted_frames`` listing each parent before its children.
        """
        pos_mat_map = {}
        for frame in self.scene.sorted_frames:
            q, _ = state_map[frame]
            mat = frame.get_pos_matrix(q)
            parent = self.scene.frame_parent_map[frame]
            if parent is not None:
                mat = pos_mat_map[parent] @ mat
            pos_mat_map[frame] = mat
        return pos_mat_map

    def _get_inv_pos_mat_map(self, pos_mat_map):
        """Determine all the global->local ("inverse") position transformation
        matrices, indexed by frame.
        """
        return {k: np.linalg.inv(v) for k, v in pos_mat_map.items()}

    def _get_vel_mat_map(self, pos_mat_map, inv_pos_mat_map, state_map):
        """Determine each frame's global position -> global velocity matrix.

        Each matrix is the velocity field of the frame's own coordinate in
        global coordinates: right-multiplying it by a global position vector
        yields a global velocity vector. Results are indexed by frame.
        """
        vel_mat_map = {}
        for frame in self.scene.sorted_frames:
            q, _ = state_map[frame]
            parent_mat = self._get_parent_pos_mat(pos_mat_map, frame)
            vel_mat_map[frame] = parent_mat @ frame.get_vel_matrix(q) @ inv_pos_mat_map[frame]
        return vel_mat_map

    def _get_accel_mat_map(self, pos_mat_map, inv_pos_mat_map, state_map):
        """Global position -> global acceleration, indexed by frame, where
        each matrix represents the acceleration field of the corresponding
        frame in global coordinates.
        """
        accel_mat_map = {}
        for frame in self.scene.sorted_frames:
            q, _ = state_map[frame]
            parent_mat = self._get_parent_pos_mat(pos_mat_map, frame)
            accel_mat_map[frame] = parent_mat @ frame.get_accel_matrix(q) @ inv_pos_mat_map[frame]
        return accel_mat_map

    def _get_vel_sum_mat_map(
        self, pos_mat_map, inv_pos_mat_map, vel_mat_map, state_map,
    ):
        """Total velocity field of each frame.

        This is the ``qd``-weighted sum of the velocity field matrices of the
        frame and all its ancestors.
        """
        vel_sum_mat_map = {}
        for frame in self.scene.sorted_frames:
            _, qd = state_map[frame]
            parent = self.scene.frame_parent_map[frame]
            parent_mat = vel_sum_mat_map[parent] if parent is not None else np.zeros((3, 3))
            vel_sum_mat_map[frame] = parent_mat + qd * vel_mat_map[frame]
        return vel_sum_mat_map

    def _get_accel_sum_mat_map(
        self, pos_mat_map, vel_mat_map, accel_mat_map, vel_sum_mat_map, state_map,
    ):
        """Velocity-dependent acceleration field of each frame.

        Accumulates the parent's term, a ``2 * qd * (parent_vel_sum @ vel)``
        cross term, and ``qd**2`` times the frame's acceleration matrix.
        """
        accel_sum_mat_map = {}
        for frame in self.scene.sorted_frames:
            _, qd = state_map[frame]
            parent = self.scene.frame_parent_map[frame]
            if parent is not None:
                parent_accel_sum_mat = accel_sum_mat_map[parent]
                parent_vel_sum_mat = vel_sum_mat_map[parent]
                vel_mat = vel_mat_map[frame]
                mat = parent_accel_sum_mat + 2 * qd * (parent_vel_sum_mat @ vel_mat)
            else:
                mat = np.zeros((3, 3))
            accel_sum_mat_map[frame] = mat + qd * qd * accel_mat_map[frame]
        return accel_sum_mat_map

    def _get_mass_pos_map(self, pos_mat_map, state_map):
        """Transform all the mass positions of all the frames into global
        positions, indexed by frame and mass reference.
        """
        mass_pos_map = {}
        for frame in self.scene.sorted_frames:
            pos_mat = pos_mat_map[frame]
            mass_pos_map[frame] = {mass: pos_mat @ mass.position for mass in frame.masses}
        return mass_pos_map

    def _get_mass_vel_map(self, mass_pos_map, vel_sum_mat_map, state_map):
        """Compute global velocities of all masses, indexed by frame and mass
        reference, from their global positions and the frames' total
        velocity fields.
        """
        mass_vel_map = {}
        for frame in self.scene.sorted_frames:
            frame_vel_sum_mat = vel_sum_mat_map[frame]
            frame_mass_pos_map = mass_pos_map[frame]
            # Previously stored the position map here by mistake.
            mass_vel_map[frame] = {
                mass: frame_vel_sum_mat @ frame_mass_pos_map[mass]
                for mass in frame.masses
            }
        return mass_vel_map

    def get_system_of_equations(self, qs, qds, qfs=None):
        """Assemble the linear system ``a_mat @ qdds = b_vec``.

        Args:
            qs, qds: generalized positions and velocities, ordered like
                ``scene.sorted_frames``.
            qfs: optional external generalized forces in the same order.

        Returns:
            ``(a_mat, b_vec)`` of shapes ``(n, n)`` and ``(n,)``, where ``n``
            is the number of frames.
        """
        state_map = self._make_state_map(qs, qds)
        sorted_frames = self.scene.sorted_frames
        frame_path_map = self.scene.frame_path_map
        nframes = len(sorted_frames)
        a_mat = np.zeros((nframes, nframes))
        b_vec = np.zeros(nframes)

        pos_mat_map = self._get_pos_mat_map(state_map)
        inv_pos_mat_map = self._get_inv_pos_mat_map(pos_mat_map)
        vel_mat_map = self._get_vel_mat_map(pos_mat_map, inv_pos_mat_map, state_map)
        accel_mat_map = self._get_accel_mat_map(pos_mat_map, inv_pos_mat_map, state_map)
        vel_sum_mat_map = self._get_vel_sum_mat_map(
            pos_mat_map, inv_pos_mat_map, vel_mat_map, state_map
        )
        accel_sum_mat_map = self._get_accel_sum_mat_map(
            pos_mat_map, vel_mat_map, accel_mat_map, vel_sum_mat_map, state_map
        )
        mass_pos_map = self._get_mass_pos_map(pos_mat_map, state_map)

        for index_i, frame_i in enumerate(sorted_frames):
            frame_i_path = frame_path_map[frame_i]
            vel_mat_i = vel_mat_map[frame_i]

            # Mass-matrix row i: couple q_i with every q_h on the same branch.
            for index_h, frame_h in enumerate(sorted_frames):
                frame_h_path = frame_path_map[frame_h]
                if frame_i not in frame_h_path and frame_h not in frame_i_path:
                    continue
                qdd_coeff = 0.
                vel_mat_h = vel_mat_map[frame_h]
                # Frames before min(index_i, index_h) are skipped; this
                # presumes sorted_frames is topologically ordered.
                min_index_j = min(index_i, index_h)
                for frame_j in sorted_frames[min_index_j:]:
                    frame_j_path = frame_path_map[frame_j]
                    if frame_i not in frame_j_path or frame_h not in frame_j_path:
                        continue
                    for mass in frame_j.masses:
                        mass_pos = mass_pos_map[frame_j][mass]
                        qdd_coeff += mass.mass * ((vel_mat_h @ mass_pos).transpose() @ (vel_mat_i @ mass_pos))
                a_mat[index_i, index_h] = qdd_coeff

            # Velocity-dependent inertia, drag and gravity of masses moved by q_i.
            for frame_j in sorted_frames:
                if frame_i not in frame_path_map[frame_j]:
                    continue
                vel_sum_mat_j = vel_sum_mat_map[frame_j]
                accel_sum_mat_j = accel_sum_mat_map[frame_j]
                for mass in frame_j.masses:
                    mass_pos = mass_pos_map[frame_j][mass]
                    mass_vel_i = vel_mat_i @ mass_pos
                    mass_vel_sum_j = vel_sum_mat_j @ mass_pos
                    mass_accel_sum_j = accel_sum_mat_j @ mass_pos
                    b_vec[index_i] -= mass.mass * (mass_vel_i.transpose() @ mass_accel_sum_j)
                    b_vec[index_i] -= mass.drag * (mass_vel_i.transpose() @ mass_vel_sum_j)
                    b_vec[index_i] -= mass.mass * self.scene.gravity * mass_vel_i[1]

            _, qd_i = state_map[frame_i]
            b_vec[index_i] -= frame_i.resistance * qd_i

            if qfs is not None:
                b_vec[index_i] += qfs[index_i]

            # Spring stiffness and damping projected onto q_i.
            for spring in self.scene.springs:
                path1 = frame_path_map[spring.frame1]
                path2 = frame_path_map[spring.frame2]
                if frame_i not in path1 and frame_i not in path2:
                    continue
                pos1 = pos_mat_map[spring.frame1] @ spring.position1
                pos2 = pos_mat_map[spring.frame2] @ spring.position2
                vel1 = vel_sum_mat_map[spring.frame1] @ pos1
                vel2 = vel_sum_mat_map[spring.frame2] @ pos2
                pos_diff = pos1 - pos2
                vel_diff = vel1 - vel2
                # Match the position vectors' shape (column (3,1) or flat (3,)).
                perturb = np.zeros_like(pos1)
                if frame_i in path1:
                    perturb += (vel_mat_i @ pos1)
                if frame_i in path2:
                    perturb -= (vel_mat_i @ pos2)
                b_vec[index_i] -= (pos_diff.transpose() @ perturb) * spring.k
                b_vec[index_i] -= (vel_diff.transpose() @ perturb) * spring.damping

        return a_mat, b_vec

    def _solve(self, qs, qds, qfs):
        """Return generalized accelerations for the given state and forces.

        Raises:
            numpy.linalg.LinAlgError: if the system is singular.
        """
        a_mat, b_vec = self.get_system_of_equations(qs, qds, qfs)
        qdds = np.linalg.solve(a_mat, b_vec)
        nframes = len(self.scene.sorted_frames)
        return qdds[:nframes]

    def _unpack_state(self, state_map, force_map):
        """Flatten state/force maps into ``(qs, qds, qfs)`` arrays ordered like
        ``sorted_frames``; frames missing from ``force_map`` get 0."""
        if force_map is None:
            force_map = {}
        sorted_frames = self.scene.sorted_frames
        qs = np.array([state_map[frame][0] for frame in sorted_frames])
        qds = np.array([state_map[frame][1] for frame in sorted_frames])
        qfs = np.array([force_map.get(frame, 0.) for frame in sorted_frames])
        return qs, qds, qfs

    def tick_simple(self, state_map, delta_time, force_map=None):
        """Advance one forward-Euler step of ``delta_time``; returns the new state map."""
        qs, qds, qfs = self._unpack_state(state_map, force_map)
        qdds = self._solve(qs, qds, qfs)
        new_qs = qs + qds * delta_time
        new_qds = qds + qdds * delta_time
        return self._make_state_map(new_qs, new_qds)

    def tick(self, state_map, delta_time, force_map=None):
        """Advance one classic 4th-order Runge-Kutta step of ``delta_time``.

        ``force_map`` optionally maps frames to external generalized forces,
        held constant over the step. Returns the new state map.
        """
        qs0, qds0, qfs = self._unpack_state(state_map, force_map)

        def step(qs, qds):
            # (d qs, d qds) over delta_time: velocities and solved accelerations.
            return qds * delta_time, self._solve(qs, qds, qfs) * delta_time

        k1_qs, k1_qds = step(qs0, qds0)
        k2_qs, k2_qds = step(qs0 + 0.5 * k1_qs, qds0 + 0.5 * k1_qds)
        k3_qs, k3_qds = step(qs0 + 0.5 * k2_qs, qds0 + 0.5 * k2_qds)
        k4_qs, k4_qds = step(qs0 + k3_qs, qds0 + k3_qds)
        new_qs = qs0 + (k1_qs + 2 * k2_qs + 2 * k3_qs + k4_qs) / 6.
        new_qds = qds0 + (k1_qds + 2 * k2_qds + 2 * k3_qds + k4_qds) / 6.
        return self._make_state_map(new_qs, new_qds)

In [None]:
def Pendulum(length=3., mass=1., radius=0.5, initial_state=-np.pi/2, resistance=0., drag=0., linewidth=1.5):
    """Build a rigid pendulum: a rotational joint, a rod and a bob.

    The chain is: ``RotationalFrame`` joint -> rod ``LineDecal`` of ``length``
    -> ``Translation`` to the rod tip -> ``CircleDecal`` bob -> ``Mass``.
    The default ``initial_state`` is -pi/2. That is presumably hanging
    straight down; confirm against RotationalFrame's angle convention.
    """
    hinge = RotationalFrame(resistance=resistance, initial_state=initial_state)
    node = hinge | LineDecal((length, 0.), linewidth=linewidth)
    node = node | Translation((length, 0.))
    node = node | CircleDecal(radius=radius)
    return node | Mass(mass, drag=drag)


def SpringPendulum(base, length=3., k=1., mass=0.1, radius=0.5, initial_state_x=0, initial_state_y=2):
    """Hang a bob with two sliding degrees of freedom from ``base`` on a spring.

    The bob sits on two stacked ``TrackFrame``s. The first has
    ``angle=np.pi/2``, presumably giving a perpendicular axis to the second.
    A ``Spring`` of stiffness ``k`` connects the bob back to ``base``.

    Args:
        base: node the bob is attached under; also the spring's other end.
        length: currently unused (kept for backward compatibility).
        k: spring stiffness forwarded to ``Spring``.
        mass: bob mass. This argument used to be ignored in favour of a
            hard-coded 0.1. The default is now 0.1, so default calls behave
            as before and explicit values are finally honored.
        radius: radius of the bob's circle decal.
        initial_state_x: initial state of the first (angle=pi/2) track frame.
        initial_state_y: initial state of the second track frame.

    Returns:
        The ``Spring`` node.
    """
    pendulum = (
        base
        | TrackFrame(initial_state=initial_state_x, angle=np.pi/2)
        | TrackFrame(initial_state=initial_state_y)
        | CircleDecal(radius=radius)
        | Mass(mass)
    )
    return Spring(base, pendulum, k=k)


def TrackCart(mass=1., size=1., ratio=1.618, resistance=0., drag=0., initial_state=ZERO_STATE, left=True, right=True, solid=True):
    """Build a box-shaped cart riding on a straight track.

    Args:
        mass, drag: forwarded to the cart's ``Mass``.
        size: box height. The width is ``ratio * size``, which is the golden
            ratio by default.
        resistance, initial_state: forwarded to the ``TrackFrame``.
        left, right: whether the drawn rail decal reaches -100 (left) and
            +100 (right). A disabled side stops at 0.
        solid: forwarded to ``BoxDecal``.
    """
    rail_start = -100 if left else 0
    rail_end = 100 if right else 0
    rail = LineDecal(rail_start, rail_end, linewidth=0.5)
    cart = rail | TrackFrame(resistance=resistance, initial_state=initial_state)
    cart = cart | BoxDecal(width=ratio * size, height=size, solid=solid)
    return cart | Mass(mass, drag=drag)


def ParabolicTrackBall(mass=1., initial_state=ZERO_STATE, radius=None, a=0.02, half_span=30):
    """Build a ball constrained to a quadratic (parabolic) track.

    The track is drawn as a polyline through the frame's positions at the
    integer parameters ``-half_span .. half_span`` inclusive, so it is
    symmetric about 0. Previously it stopped at ``half_span - 1``, which made
    the right side one segment shorter than the left.

    Args:
        mass: ball mass. It also sets the default radius ``sqrt(mass) / 2``.
        initial_state: forwarded to ``QuadraticFrame``.
        radius: circle decal radius. ``None`` derives it from ``mass``.
        a: forwarded to ``QuadraticFrame`` (presumably the quadratic coefficient).
        half_span: extent of the drawn track in the frame's parameter.
    """
    if radius is None:
        radius = np.sqrt(mass) / 2.
    frame = QuadraticFrame(a=a, initial_state=initial_state)
    positions = [
        frame.get_pos_matrix(i).dot(ZERO_POS)
        for i in range(-half_span, half_span + 1)
    ]
    # One straight segment between each pair of consecutive track points.
    lines = [LineDecal(pos1, pos2, linewidth=0.5) for pos1, pos2 in zip(positions, positions[1:])]
    return (
        Group(lines)
        | frame
        | Mass(mass)
        | CircleDecal(radius=radius)
    )


def Spoke(angle, length=10., mass=5.):
    """Build one spoke: a ``Rotation`` by ``angle``, a rod, and an end mass.

    It uses ``Rotation`` rather than ``RotationalFrame``, so this is
    presumably a fixed offset with no degree of freedom of its own.
    """
    spoke = Rotation(angle) | LineDecal((length, 0.), linewidth=0.5)
    spoke = spoke | Translation((length, 0.))
    spoke = spoke | Mass(mass=mass)
    return spoke | CircleDecal(radius=0.2)


# Default view and initial-state settings. The scene blocks below may override
# them, and the interactive cells use them as widget starting values.
scale, view_x, view_y = 0.7, 0, 0
randomize, seed = False, 0


if False:  # anim027
    # Scene "anim027": a heavy ball on a parabolic track carrying a rotational
    # joint (the "wheel"), with spokes, some of which carry pendulums ("pois").
    poi_count = 7
    poi_length = 6.
    spoke_length = 10.
    wheel = (
        ParabolicTrackBall(initial_state=(22., 0.), mass=30, radius=0.6)
        | RotationalFrame(initial_state=(-3.5, 0.5))
    )
    # Angles are spaced for poi_count spokes, but only the first
    # poi_count - 3 (= 4) spokes are created.
    spokes = [
        Spoke(length=spoke_length, angle=(i / poi_count * 2 * np.pi))
        for i in range(poi_count - 3)
    ]
    # One pendulum per spoke. The initial_state alternates between
    # (-pi/2, 0) and (-pi/2, 2).
    pois = [
        spoke
        #| SpringPendulum(spoke)
        | Pendulum(length=poi_length, mass=1., initial_state=(-np.pi/2, (i % 2) * 2))
        for i, spoke in enumerate(spokes)
    ]
    # Only pois[0] and pois[3] go into the Group; spokes[1] and spokes[2] are
    # added directly. NOTE(review): whether those two spokes still carry their
    # pendulums depends on whether `|` mutates its left operand. Confirm.
    node = (
        wheel
        | CircleDecal(radius=0.1)
        | Group([pois[0], spokes[1], spokes[2], pois[3]])
    )

if False:  # anim031
    # Scene "anim031": a double pendulum under a cart whose track is mounted on
    # a damped rotational base. The cart is tied by a spring to the tip of a
    # rod on a second, independent cart.
    scale = 0.7
    view_x = 0
    view_y = -2

    # Spring anchor: a rod from its own TrackCart out to (6, 4), ending in a
    # small circle.
    anchor = (
        TrackCart()
        | LineDecal((6, 4))
        | Translation((6, 4))
        | CircleDecal(radius=0.2)
    )
    # Rotational base with resistance=1. that the main cart's track hangs from.
    base = (
        RotationalFrame(resistance=1.)
        | CircleDecal(radius=0.1)
    )
    # Main cart: rail drawn only to the right (left=False), initial_state=2.
    cart = base | TrackCart(initial_state=2, resistance=0.5, left=False)
    # NOTE(review): Group is called with positional nodes here but with a list
    # elsewhere. It presumably accepts both forms; confirm.
    node = (
        Group(
            (
                cart
                | Pendulum(resistance=0.5, initial_state=1.)
                | Pendulum(resistance=0.5, initial_state=-0.3)
            ),
            Spring(cart, anchor, k=3., damping=1.)
        )
    )


if False:  # anim032
    # Scene "anim032": like anim027, but with all poi_count spokes populated
    # and damped pendulums (resistance=2.).
    scale = 0.4
    poi_count = 8
    poi_length = 4.
    spoke_length = 10.
    wheel = (
        ParabolicTrackBall(initial_state=(22., 0.), mass=30, radius=0.6)
        | RotationalFrame(initial_state=(-3.5, 1.5))
    )
    # poi_count evenly spaced spokes around the full circle.
    spokes = [
        Spoke(length=spoke_length, angle=(i / poi_count * 2 * np.pi))
        for i in range(poi_count)
    ]
    # One pendulum per spoke. The initial_state alternates between
    # (-pi/2, 0) and (-pi/2, 2).
    pois = [
        spoke
        #| SpringPendulum(spoke)
        | Pendulum(length=poi_length, mass=1., resistance=2., initial_state=(-np.pi/2, (i % 2) * 2))
        for i, spoke in enumerate(spokes)
    ]
    node = (
        wheel
        | CircleDecal(radius=0.1)
        | Group(pois)
    )

    
if False:  # anim033
    # Scene "anim033": a rotor with very high joint resistance (100.) and
    # poi_count spokes, each carrying a damped pendulum.
    scale = 0.4
    poi_count = 8
    poi_length = 4.
    spoke_length = 10.
    node = (
        RotationalFrame(resistance=100.)
        | CircleDecal(radius=0.1)
        | Group([
            Spoke(length=spoke_length, angle=(i / poi_count * 2 * np.pi))
            | Pendulum(length=poi_length, mass=1., resistance=3.)
            for i in range(poi_count)
        ])
    )
    
if False:
    # Unnamed scene: a heavy cart with two rigid arms (base1 and base2). A chain
    # of light pendulums hangs from base1, the chain branches at `middle`, and
    # the end of one branch is tied to base2 by a spring.
    scale = 0.6
    view_y = -2
    cart = TrackCart(initial_state=(-10, 0), mass=20)
    # Left arm: a rod to (-6, -2) with a small circle at its tip.
    base1 = (
        cart
        | LineDecal((-6, -2), linewidth=3)
        | Translation((-6, -2))
        | CircleDecal(radius=0.2)
    )
    # Right arm: a rod to (6, -2); this is the spring's fixed end.
    base2 = (
        cart
        | LineDecal((6, -2), linewidth=3)
        | Translation((6, -2))
        | CircleDecal(radius=0.2)
    )
    def Pendulum2(
        resistance=5,
        linewidth=0.5,
        length=3.,
        initial_state=0,
        **kwargs
    ):
        """Light, damped ``Pendulum`` link (radius 0.3, mass 0.2).

        Extra keyword arguments are forwarded to ``Pendulum``.
        """
        return Pendulum(
            resistance=resistance,
            linewidth=linewidth,
            length=length,
            initial_state=initial_state,
            radius=0.3,
            mass=0.2,
            **kwargs
        )
    middle = (
        base1
        | Pendulum2(initial_state=-np.pi/2, length=4)
        | Pendulum2(initial_state=np.pi/2)
        | Pendulum2()
    )
    end = (
        middle
        | Pendulum2()
        | Pendulum2()
    )
    # `middle` is the parent of both `end` and the long (length=9) pendulum,
    # so the chain branches here.
    node = (
        Group(
            middle | Pendulum(length=9),
            #Constraint(
            #    pendulum,
            #    base2,
            #    linewidth=0.5
            #),
            Spring(end, base2, k=20, damping=3)
        )
    )
    #node = pendulum

if True:  # swingycart
    # Active scene "swingycart": a wide cart on a track with a chain of eight
    # light links ending in a heavy bob.
    
    def Pendulum2(length=0.8, mass=0.1, radius=0., initial_state=0., resistance=0.3, drag=0.05, linewidth=2.5):
        """One short, light, slightly damped chain link (no visible bob by default)."""
        return (
            RotationalFrame(resistance=resistance, initial_state=initial_state)
            | LineDecal((length, 0.), linewidth=linewidth)
            | Translation((length, 0.))
            | CircleDecal(radius=radius)
            | Mass(mass, drag=drag)
        )

    scale = 1.
    view_y = -4
    ball1 = (
        TrackCart(size=0.8, mass=8, initial_state=-8, ratio=3)
        | Translation((0, -0.4))
        | CircleDecal(radius=0.3)

        #RotationalFrame(initial_state=(-1, -2))
        #| LineDecal((3, 0.), linewidth=16)
        #| Translation((3, 0.))
        #| Mass(mass=200)
        #| CircleDecal(radius=0.6)

        # The first link starts at initial_state (0.45 * pi/2, 0.5); the last
        # link is the heavy bob.
        | Pendulum2(initial_state=(np.pi / 2 * 0.45, 0.5))
        | Pendulum2()
        | Pendulum2()
        | Pendulum2()
        | Pendulum2()
        | Pendulum2()
        | Pendulum2()
        | Pendulum2()
        | Pendulum2(mass=4, radius=0.6, length=2, drag=1.3)        
    )
    # A second ball on a horizontal track at y=3, the spring partner of ball1.
    ball2 = (
        Translation((0, 3))
        #| Rotation(np.pi/2)
        | LineDecal((-30, 0), (30, 0), linewidth=0.5)
        | TrackFrame(initial_state=(2, 0))
        | CircleDecal(radius=0.6)
        | Mass(4, drag=1.3)
    )
    # NOTE(review): the Spring node is overwritten on the next line, so the
    # simulated scene is ball1 alone and ball2 is built but unused. This looks
    # like a leftover toggle; confirm which scene is intended.
    node = Spring(ball1, ball2, k=5.5, damping=0.01)
    node = ball1


# Compile the selected `node` into the scene used by the preview, simulation
# and debugger cells.
scene = construct_scene(node)
# Per-frame external forces passed to the solver; empty means no forcing.
force_map = dict()


@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
    seed=(0, 100, 1),
)
def f(
    scale=scale,
    view_x=view_x,
    view_y=view_y,
    randomize=randomize,
    seed=seed
):
    """Preview the scene's initial state with pan/zoom controls.

    ``randomize`` and ``seed`` are forwarded to
    ``scene.get_initial_state_map``.
    """
    _, axis = plt.subplots(1, 1, figsize=(16, 9))
    view_transform = get_translation_matrix((-view_x, -view_y)).dot(
        get_scale_matrix(scale, scale)
    )
    initial_states = scene.get_initial_state_map(randomize=randomize, seed=seed)
    scene.draw(axis, initial_states, view_transform)

@interact()
def f(show_compilation_info=False):
    """Optionally rebuild the scene with ``show_info=True`` to inspect compilation."""
    if not show_compilation_info:
        display('')
        return
    construct_scene(node, show_info=True)
        
#construct_scene(node, show_info=True)

In [None]:
fps = 25                             # frames per second of the animation
sample_interval = 4                  # simulation steps per animation frame
sps = fps * sample_interval          # simulation steps per second
time_limit = 15                      # simulated duration
state_count = int(sps * time_limit)  # number of solver ticks in simulate()


def has_nans(state_map):
    """Return True if any frame's position or velocity in ``state_map`` is NaN.

    ``state_map`` maps frame -> (q, qd). ``simulate()`` uses this to stop once
    the integration has blown up.
    """
    return any(
        np.isnan(q) or np.isnan(qd)
        for q, qd in state_map.values()
    )


def simulate():
    """Integrate the global ``scene`` for ``state_count`` ticks of ``1 / sps``.

    Each tick calls ``ImprovedSolver(scene).tick`` with the global
    ``force_map``.

    Returns:
        The list of state maps (frame -> (q, qd)), starting with the initial
        state. If a tick produces NaNs, the simulation stops early and that
        unstable state is *not* appended. Previously it was kept, so
        downstream rendering and the q/qd grids received NaN values.
    """
    print('Simulating...')
    solver = ImprovedSolver(scene)
    state_map = scene.get_initial_state_map(randomize=False)
    state_maps = [state_map]
    for _ in tqdm(range(state_count)):
        state_map = solver.tick(state_map, 1 / sps, force_map)
        if has_nans(state_map):
            print('Aborting simulation due to unstable results')
            break
        state_maps.append(state_map)
    return state_maps


def render(state_maps, view_x=view_x, view_y=view_y, scale=scale, filename='anim.mp4'):
    """Render ``state_maps`` to an animation, save it, and display it.

    The ``view_x``/``view_y``/``scale`` defaults are captured from the globals
    when this cell runs, not when ``render`` is called.

    Args:
        state_maps: sequence of frame -> (q, qd) maps, e.g. from ``simulate()``.
        view_x, view_y: pan offsets (the scaled scene is translated by
            ``(-view_x, -view_y)``).
        scale: uniform zoom factor.
        filename: path passed to ``anim.save`` (default ``'anim.mp4'``).
    """
    print('Rendering...')
    xform_matrix = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    def draw(fig, time_index):
        # Draw one animation frame into a fresh single-axes layout on `fig`.
        ax = fig.subplots(1, 1)
        return scene.draw(ax, state_maps[time_index], xform_matrix)

    # `draw` is the per-frame callback. This previously referenced the
    # undefined name `draw_func` and raised NameError.
    anim = do_render_animation(draw, len(state_maps), sample_interval=sample_interval, tqdm=tqdm)
    anim.save(filename)
    display(anim)


# Run the simulation once. The player, grid and debugger cells below read
# the resulting `state_maps`.
state_maps = simulate()

In [None]:
@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
)
def f(
    scale=scale,
    view_x=view_x,
    view_y=view_y,
):
    """Scrub through the simulated ``state_maps`` in an interactive player.

    Only every ``2 * sample_interval``-th state is shown, presumably to keep
    the renderer responsive.
    """
    stride = sample_interval * 2
    frames_to_show = state_maps[::stride]
    view_transform = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    def draw(fig, time_index):
        axis = fig.subplots(1, 1)
        return scene.draw(axis, frames_to_show[time_index], view_transform)

    display(PyplotRenderer(draw, len(frames_to_show)))

In [None]:
def Pendulum(length=3., mass=1., radius=0.5, initial_state=0, resistance=0., drag=0., linewidth=1.5):
    """Build a rigid pendulum (joint, rod, bob) with default joint state 0.

    This shadows any earlier ``Pendulum`` definition.
    """
    node = RotationalFrame(resistance=resistance, initial_state=initial_state)
    # Parts are created lazily, one per iteration, so construction and `|`
    # calls happen in the same order as a single chained expression.
    for make_part in (
        lambda: LineDecal((length, 0.), linewidth=linewidth),
        lambda: Translation((length, 0.)),
        lambda: CircleDecal(radius=radius),
        lambda: Mass(mass, drag=drag),
    ):
        node = node | make_part()
    return node


# Position and velocity histories as (frame, time-step) grids.
frame_count = len(scene.sorted_frames)
q_grid = np.zeros((frame_count, len(state_maps)))
qd_grid = np.zeros_like(q_grid)

for i, frame in enumerate(scene.sorted_frames):
    for t, state_map in enumerate(state_maps):
        q_grid[i, t], qd_grid[i, t] = state_map[frame]

# Taper window for the spectral view, with one copy per frame row.
# (np.hamming(31) was an earlier alternative.)
window = np.kaiser(51, 14)
window_tile = np.tile(window, (frame_count, 1))


def render_scene_debugger(fig, time_index, state_map, scale=0.5, view_x=view_x, view_y=view_y):
    """Draw one frame of the scene debugger into ``fig``.

    Left panel: the scene at ``state_map``. Each mass gets a cornflower-blue
    velocity arrow (scaled by 0.5 in view space) and a lime acceleration
    arrow (scaled by 0.1). Right panel: a heatmap of the solver's linear
    system ``[a_mat | b_vec]`` from ``get_system_of_equations``.

    Args:
        fig: figure to draw into. A 1x2 subplot grid is created on it.
        time_index: column index into the global ``qd_grid``. It is only read
            by the disabled spectrum (FFT) view.
        state_map: mapping frame -> (q, qd) to visualize.
        scale, view_x, view_y: zoom and pan of the scene panel. The view
            defaults are the globals' values when this cell runs.
    """
    solver = ImprovedSolver(scene)
    pos_mat_map = solver._get_pos_mat_map(state_map)
    inv_pos_mat_map = solver._get_inv_pos_mat_map(pos_mat_map)
    vel_mat_map = solver._get_vel_mat_map(pos_mat_map, inv_pos_mat_map, state_map)
    accel_mat_map = solver._get_accel_mat_map(pos_mat_map, inv_pos_mat_map, state_map)
    vel_sum_mat_map = solver._get_vel_sum_mat_map(
        pos_mat_map, inv_pos_mat_map, vel_mat_map, state_map
    )
    accel_sum_mat_map = solver._get_accel_sum_mat_map(
        pos_mat_map, vel_mat_map, accel_mat_map, vel_sum_mat_map, state_map
    )
    mass_pos_map = solver._get_mass_pos_map(pos_mat_map, state_map)

    qs = np.array([state_map[frame][0] for frame in scene.sorted_frames])
    qds = np.array([state_map[frame][1] for frame in scene.sorted_frames])
    a_mat, b_vec = solver.get_system_of_equations(qs, qds)

    ax0, ax1 = fig.subplots(1, 2)
    xform_matrix = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    scene.draw(ax0, state_map, xform_matrix, draw_options=DrawOptions(aspect=1))

    for frame in scene.sorted_frames:
        for mass in frame.masses:
            mass_pos = mass_pos_map[frame][mass]
            mass_vel = vel_sum_mat_map[frame] @ mass_pos
            mass_accel = accel_sum_mat_map[frame] @ mass_pos
            view_mass_pos = xform_matrix @ mass_pos
            # Arbitrary display scalings so the arrows stay readable.
            view_mass_vel = (xform_matrix @ mass_vel) * 0.5
            view_mass_accel = (xform_matrix @ mass_accel) * 0.1
            ax0.arrow(
                view_mass_pos[0, 0],
                view_mass_pos[1, 0],
                view_mass_vel[0, 0],
                view_mass_vel[1, 0],
                color='cornflowerblue',
                zorder=3,
                width=0.05,
                head_width=8*0.05,
                alpha=0.7,
            )
            ax0.arrow(
                view_mass_pos[0, 0],
                view_mass_pos[1, 0],
                view_mass_accel[0, 0],
                view_mass_accel[1, 0],
                color='lime',
                zorder=2,
                width=0.05,
                head_width=8*0.05,
                alpha=0.7,
            )

    if False:
        # Disabled spectrum view: windowed FFT magnitude (dB) of the velocity
        # history starting at `time_index`.
        chunk = qd_grid[:, time_index:time_index + len(window)]
        chunk = chunk * window_tile
        A = np.fft.fft(chunk, 256)
        response = np.abs(A)
        response = 20 * np.log10(response)

        sns.heatmap(
            response,
            ax=ax1,
            cmap='cool',
            vmin=-75,
            vmax=75,
        )

    mat = np.hstack((a_mat, b_vec[:, np.newaxis]))
    # Draw explicitly into ax1. `fig` may not be pyplot's current figure (for
    # example when it is supplied by PyplotRenderer), so relying on plt.gca()
    # could put the heatmap on a different figure.
    sns.heatmap(
        mat,
        annot=True,
        annot_kws={'size':10},
        fmt='.0f',
        alpha=0.7,
        cmap='coolwarm',
        vmin=0,
        vmax=500,
        ax=ax1,
    )
    ax1.set_title('q̣̈ coefficients | virtual forces')


@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
)
def f(
    scale=0.8,
    view_x=-5.,
    view_y=view_y,
):
    """Step through the simulation in the scene debugger.

    Every ``sample_interval``-th state is shown; time is presumably chosen
    with PyplotRenderer's own control. (A former ``time`` slider matched no
    parameter of ``f``, so ``interact`` silently ignored it.)
    """
    sampled_state_maps = state_maps[::sample_interval]

    def draw_debug_frame(fig, sampled_index):
        state_map = sampled_state_maps[sampled_index]
        # render_scene_debugger treats time_index as a column of the
        # *unsampled* qd_grid, so convert back from the sampled index.
        return render_scene_debugger(
            fig,
            sampled_index * sample_interval,
            state_map,
            scale=scale,
            view_x=view_x,
            view_y=view_y,
        )

    display(PyplotRenderer(draw_debug_frame, len(sampled_state_maps), width=1440, height=810))

In [None]:
def f(time_index=210, frame_index=0):
    """Plot one frame's Kaiser-windowed velocity chunk starting at ``time_index``.

    This is the same windowed chunk the disabled spectrum view feeds to the
    FFT. The defaults reproduce the previously hard-coded values.

    Args:
        time_index: first column of ``qd_grid`` to include.
        frame_index: row (frame) of the chunk to plot.

    Raises:
        ValueError: if fewer than ``len(window)`` samples remain after
            ``time_index``.
    """
    chunk = qd_grid[:, time_index:time_index + len(window)]
    if chunk.shape[1] < len(window):
        raise ValueError(
            f'time_index={time_index} leaves only {chunk.shape[1]} samples; '
            f'the window needs {len(window)}'
        )
    chunk = chunk * window_tile
    plt.plot(chunk[frame_index, :])

#f()

In [None]:
# Scratch cell: create a small figure and immediately clear it.
fig = plt.figure(figsize=(3, 4))
fig.clf()

In [None]:
def Pendulum(length=3., mass=1., radius=0.5, initial_state=0, resistance=0., drag=0., linewidth=1.5):
    """Build a rigid pendulum (joint, rod, bob) with default joint state 0.

    This shadows any earlier ``Pendulum`` definition.
    """
    joint = RotationalFrame(resistance=resistance, initial_state=initial_state)
    with_rod = joint | LineDecal((length, 0.), linewidth=linewidth)
    at_tip = with_rod | Translation((length, 0.))
    with_bob = at_tip | CircleDecal(radius=radius)
    return with_bob | Mass(mass, drag=drag)


@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
    length=(0.1, 8.,0.1),
    q1=(-180, 180, 5),
    q2=(-180, 180, 5),
    q3=(-180, 180, 5),
    q1d=(-3., 3., 0.1),
    q2d=(-3., 3., 0.1),
    q3d=(-3., 3., 0.1),
)
def f(
    scale=scale,
    view_x=view_x,
    view_y=view_y,
    length=5.,
    q1=-60.,
    q2=40.,
    q3=0.,
    q1d=1.,
    q2d=1.,
    q3d=1.,
):
    """Show a triple pendulum at a hand-set state next to its solver system.

    Joint angles ``q1``..``q3`` are given in degrees and converted to radians.
    ``q1d``..``q3d`` are the matching joint velocities. The right panel is a
    heatmap of ``[a_mat | b_vec]`` from ``ImprovedSolver.get_system_of_equations``
    at that state.
    """
    q1 = q1 / 180 * np.pi
    q2 = q2 / 180 * np.pi
    q3 = q3 / 180 * np.pi
    node = (
        Pendulum(length=length, initial_state=q1)
        | Pendulum(length=length, initial_state=q2)
        | Pendulum(length=length, initial_state=q3)
    )
    scene = construct_scene(node)
    # Walk down the chain: each pendulum frame is frames[0] of its parent.
    pendulum1 = scene.frames[0]
    pendulum2 = pendulum1.frames[0]
    pendulum3 = pendulum2.frames[0]
    state_map = scene.get_initial_state_map()
    state_map[pendulum1] = (q1, q1d)
    state_map[pendulum2] = (q2, q2d)
    state_map[pendulum3] = (q3, q3d)

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(20, 9), linewidth=8, edgecolor='black')
    init_ax(ax0)
    xform_matrix = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    scene.draw(ax0, state_map, xform_matrix)
    ax0.set_title('Scene')

    solver = ImprovedSolver(scene)

    qs = np.array([state_map[frame][0] for frame in scene.sorted_frames])
    qds = np.array([state_map[frame][1] for frame in scene.sorted_frames])
    a_mat, b_vec = solver.get_system_of_equations(qs, qds)
    mat = np.hstack((a_mat, b_vec[:, np.newaxis]))
    # Target ax1 explicitly instead of relying on pyplot's current axes.
    sns.heatmap(mat, annot=True, fmt='.2f', alpha=0.7, cmap='coolwarm', vmin=0, vmax=500, ax=ax1)
    ax1.set_title('q̣̈ coefficients | virtual forces')


In [None]:

scale = 1.0  # default zoom for the transform visualizer below

def Pendulum(length=3., mass=1., radius=0.5, initial_state=0, resistance=0., drag=0., linewidth=1.5):
    """Build a rigid pendulum (joint, rod, bob) with default joint state 0.

    This shadows any earlier ``Pendulum`` definition.
    """
    pendulum = RotationalFrame(resistance=resistance, initial_state=initial_state)
    pendulum = pendulum | LineDecal((length, 0.), linewidth=linewidth)
    pendulum = pendulum | Translation((length, 0.))
    pendulum = pendulum | CircleDecal(radius=radius)
    pendulum = pendulum | Mass(mass, drag=drag)
    return pendulum

@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
    length=(0.1, 5.,0.1),
    q1=(-180, 180, 5),
    q2=(-180, 180, 5),
    q1d=(-3., 3., 0.1),
    q2d=(-3., 3., 0.1),
)
def f(
    scale=scale,
    view_x=view_x,
    view_y=view_y,
    length=3.,
    q1=0.,
    q2=55.,
    q1d=1.,
    q2d=1.,
    show_Gd=True,
    show_dGdq1=False,
    show_dGdq2=False,
    show_dGdq1Tdgdq2=False,
    show_dGdq2Tdgdq1=False,
    show_d2Gdq1dq1=False,
    show_d2Gdq2dq2=False,
    show_d2Gdq1dq2=False,
    show_scene=True,
    show_contour=True,
    show_vectors=True,
):
    """Visualize the outer link's transform and related matrices for a double pendulum.

    A 15x15 grid of homogeneous points (x, y in [-10, 10]) in the outer
    pendulum's local frame is mapped to the view through ``G``, the outer
    frame's transform. For each enabled matrix M:

    - ``do_quiver`` draws ``M @ p`` at the grid points.
    - ``do_contour`` draws contours of the quadratic form ``p^T M p``.

    ``Gd`` comes from ``scene.get_velocity_matrix``. The ``dG*``/``d2G*``
    matrices come from ``scene.get_xform_matrix`` with extra frame arguments;
    presumably these are partial derivatives of G with respect to those
    joints (confirm against the Scene implementation).

    ``q1``/``q2`` are joint angles in degrees; ``q1d``/``q2d`` are joint
    velocities. A former ``seed`` slider matched no parameter and was
    silently ignored by ``interact``, so it has been removed.
    """
    q1 = q1 / 180 * np.pi
    q2 = q2 / 180 * np.pi
    node = Pendulum(length=length, initial_state=q1 ) | Pendulum(length=length, initial_state=q2)
    scene = construct_scene(node)
    pendulum1 = scene.frames[0]
    pendulum2 = pendulum1.frames[0]
    state_map = scene.get_initial_state_map()
    state_map[pendulum1] = (q1, q1d)
    state_map[pendulum2] = (q2, q2d)

    fig, ax0 = plt.subplots(1, 1, figsize=(16, 9))
    init_ax(ax0)
    xform_matrix = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    if show_scene:
        scene.draw(ax0, state_map, xform_matrix)

    G = scene.get_xform_matrix(state_map, pendulum2)
    Gd = scene.get_velocity_matrix(state_map, pendulum2)
    dGdq1 = scene.get_xform_matrix(state_map, pendulum2, pendulum1)
    dGdq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum2)

    xs = np.linspace(-10., 10., 15)
    ys = np.linspace(-10., 10., 15)
    xg, yg = np.meshgrid(xs, ys)
    zg = np.ones(xg.shape)
    # 3 x N homogeneous local points. np.vstack replaces np.row_stack,
    # which is deprecated since NumPy 2.0.
    local_pos = np.vstack((xg.ravel(), yg.ravel(), zg.ravel()))
    global_pos = G.dot(local_pos)
    view_pos = xform_matrix.dot(global_pos)

    def do_contour(mat, color):
        # Contours of p^T M p, evaluated for all grid points at once.
        if show_contour:
            quad = np.sum(local_pos * (np.asarray(mat) @ local_pos), axis=0)
            ax0.contour(view_pos[0].reshape(xg.shape), view_pos[1].reshape(yg.shape), quad.reshape(xg.shape), levels=6, colors=color)

    def do_quiver(mat, color):
        # Arrows of M @ p, mapped to view space, drawn at each grid point.
        if show_vectors:
            global_vel = mat.dot(local_pos)
            view_vel = xform_matrix.dot(global_vel)
            ax0.quiver(view_pos[0], view_pos[1], view_vel[0], view_vel[1], width=0.002, color=color, alpha=0.3)

    if show_Gd:
        do_quiver(Gd, 'green')
        do_contour(Gd.T.dot(Gd), 'green')

    if show_dGdq1:
        do_quiver(dGdq1, 'red')
        do_contour(dGdq1.T.dot(dGdq1), 'red')

    if show_dGdq2:
        do_quiver(dGdq2, 'blue')
        do_contour(dGdq2.T.dot(dGdq2), 'blue')

    if show_dGdq1Tdgdq2:
        do_contour(dGdq1.T.dot(dGdq2), 'purple')

    if show_dGdq2Tdgdq1:
        do_contour(dGdq2.T.dot(dGdq1), 'pink')

    if show_d2Gdq1dq1:
        d2Gdq1dq1 = scene.get_xform_matrix(state_map, pendulum2, pendulum1, pendulum1)
        do_quiver(d2Gdq1dq1, 'purple')
        do_contour(d2Gdq1dq1.T.dot(d2Gdq1dq1), 'purple')

    if show_d2Gdq2dq2:
        d2Gdq2dq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum2, pendulum2)
        do_quiver(d2Gdq2dq2, 'blue')
        do_contour(d2Gdq2dq2.T.dot(d2Gdq2dq2), 'blue')

    if show_d2Gdq1dq2:
        d2Gdq1dq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum1, pendulum2)
        do_quiver(d2Gdq1dq2, 'green')
        do_contour(d2Gdq1dq2.T.dot(d2Gdq1dq2), 'green')

In [None]:

scale = 1.0  # default zoom for the transform visualizer below

def Pendulum(length=3., mass=1., radius=0.5, initial_state=0, resistance=0., drag=0., linewidth=1.5):
    """Build a rigid pendulum (joint, rod, bob) with default joint state 0.

    This shadows any earlier ``Pendulum`` definition.
    """
    rod_end = (length, 0.)
    chain = RotationalFrame(resistance=resistance, initial_state=initial_state)
    chain = chain | LineDecal(rod_end, linewidth=linewidth)
    chain = chain | Translation(rod_end)
    chain = chain | CircleDecal(radius=radius)
    return chain | Mass(mass, drag=drag)

@interact(
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
    length=(0.1, 5.,0.1),
    q1=(-180, 180, 5),
    q2=(-180, 180, 5),
    q1d=(-3., 3., 0.1),
    q2d=(-3., 3., 0.1),
    deriv_frame1=[None, 0, 1, 2],
    deriv_frame2=[None, 0, 1, 2],
)
def f(
    scale=scale,
    view_x=view_x,
    view_y=view_y,
    length=3.,
    q1=0.,
    q2=55.,
    q1d=1.,
    q2d=1.,
    show_Gd=True,
    deriv_frame1=None,
    deriv_frame2=None,
    show_dGdq1=False,
    show_dGdq2=False,
    show_dGdq1Tdgdq2=False,
    show_dGdq2Tdgdq1=False,
    show_d2Gdq1dq1=False,
    show_d2Gdq2dq2=False,
    show_d2Gdq1dq2=False,
    show_scene=True,
    show_contour=True,
    show_vectors=True,
):
    """Visualize the outer link's transform and related matrices for a double pendulum.

    A 15x15 grid of homogeneous points (x, y in [-10, 10]) in the outer
    pendulum's local frame is mapped to the view through ``G``. For each
    enabled matrix M, ``do_quiver`` draws ``M @ p`` and ``do_contour`` draws
    contours of ``p^T M p``. The ``dG*``/``d2G*`` matrices come from
    ``scene.get_xform_matrix`` with extra frame arguments; presumably these
    are partial derivatives of G (confirm).

    NOTE(review): ``deriv_frame1``/``deriv_frame2`` produce dropdowns but are
    not read by the body yet. They are presumably meant to select the
    derivative frames.

    A former ``seed`` slider matched no parameter and was silently ignored by
    ``interact``, so it has been removed.
    """
    q1 = q1 / 180 * np.pi
    q2 = q2 / 180 * np.pi
    node = Pendulum(length=length, initial_state=q1 ) | Pendulum(length=length, initial_state=q2)
    scene = construct_scene(node)
    pendulum1 = scene.frames[0]
    pendulum2 = pendulum1.frames[0]
    state_map = scene.get_initial_state_map()
    state_map[pendulum1] = (q1, q1d)
    state_map[pendulum2] = (q2, q2d)

    fig, ax0 = plt.subplots(1, 1, figsize=(16, 9))
    init_ax(ax0)
    xform_matrix = get_translation_matrix((-view_x, -view_y)).dot(get_scale_matrix(scale, scale))

    if show_scene:
        scene.draw(ax0, state_map, xform_matrix)

    G = scene.get_xform_matrix(state_map, pendulum2)
    Gd = scene.get_velocity_matrix(state_map, pendulum2)
    dGdq1 = scene.get_xform_matrix(state_map, pendulum2, pendulum1)
    dGdq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum2)

    xs = np.linspace(-10., 10., 15)
    ys = np.linspace(-10., 10., 15)
    xg, yg = np.meshgrid(xs, ys)
    zg = np.ones(xg.shape)
    # 3 x N homogeneous local points. np.vstack replaces np.row_stack,
    # which is deprecated since NumPy 2.0.
    local_pos = np.vstack((xg.ravel(), yg.ravel(), zg.ravel()))
    global_pos = G.dot(local_pos)
    view_pos = xform_matrix.dot(global_pos)

    def do_contour(mat, color):
        # Contours of p^T M p, evaluated for all grid points at once.
        if show_contour:
            quad = np.sum(local_pos * (np.asarray(mat) @ local_pos), axis=0)
            ax0.contour(view_pos[0].reshape(xg.shape), view_pos[1].reshape(yg.shape), quad.reshape(xg.shape), levels=6, colors=color)

    def do_quiver(mat, color):
        # Arrows of M @ p, mapped to view space, drawn at each grid point.
        if show_vectors:
            global_vel = mat.dot(local_pos)
            view_vel = xform_matrix.dot(global_vel)
            ax0.quiver(view_pos[0], view_pos[1], view_vel[0], view_vel[1], width=0.002, color=color, alpha=0.3)

    if show_Gd:
        do_quiver(Gd, 'green')
        do_contour(Gd.T.dot(Gd), 'green')

    if show_dGdq1:
        do_quiver(dGdq1, 'red')
        do_contour(dGdq1.T.dot(dGdq1), 'red')

    if show_dGdq2:
        do_quiver(dGdq2, 'blue')
        do_contour(dGdq2.T.dot(dGdq2), 'blue')

    if show_dGdq1Tdgdq2:
        do_contour(dGdq1.T.dot(dGdq2), 'purple')

    if show_dGdq2Tdgdq1:
        do_contour(dGdq2.T.dot(dGdq1), 'pink')

    if show_d2Gdq1dq1:
        d2Gdq1dq1 = scene.get_xform_matrix(state_map, pendulum2, pendulum1, pendulum1)
        do_quiver(d2Gdq1dq1, 'purple')
        do_contour(d2Gdq1dq1.T.dot(d2Gdq1dq1), 'purple')

    if show_d2Gdq2dq2:
        d2Gdq2dq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum2, pendulum2)
        do_quiver(d2Gdq2dq2, 'blue')
        do_contour(d2Gdq2dq2.T.dot(d2Gdq2dq2), 'blue')

    if show_d2Gdq1dq2:
        d2Gdq1dq2 = scene.get_xform_matrix(state_map, pendulum2, pendulum1, pendulum2)
        do_quiver(d2Gdq1dq2, 'green')
        do_contour(d2Gdq1dq2.T.dot(d2Gdq1dq2), 'green')

In [None]:
max_t = 20.0  # upper bound of the `t` slider and of the plotted time range below

@interact(
    init_x1=(-10., 10.),
    init_x2=(-10., 10.),
    init_dx1=(-5., 5., 0.1),
    init_dx2=(-5., 5., 0.1),
    #d0=(0., 10., 0.1),
    k=(0., 10., 0.01),
    m1=(0.1, 10., 0.1),
    m2=(0.1, 10., 0.1),
    t=(0., max_t, 0.1),
    ts=(0., 5., 0.1),
    scale=(0.2, 2.),
)
def f2(
    init_x1=-5.,
    init_x2=5.,
    init_dx1=5.,
    init_dx2=0.,
    #d0=10.,
    k=1.,
    m1=1.,
    m2=1.,
    t=0.,
    ts=1.1,
    scale=1.,
    grid=False,
):
    d0 = init_x2 - init_x1
    dd0 = init_dx2 - init_dx1
    w = np.sqrt((1 / m1 + 1 / m2) * k)
    #a = -dd0 / w # - 2 * d0 / np.pi
    a = dd0 / (np.cos(w * ts) - 1)
    #collision_start_time = (d0 - (init_x2 - init_x1)) / (init_dx2 - init_dx1)
    #collision_end_time = collision_start_time + np.pi / (2 * w)
    #print(a)
    #print(collision_end_time)

    dx1 = init_dx1
    dx2 = init_dx2
    get_x1 = lambda t: init_x1 + init_dx1 * t - k * a / (m1 * w) * (t - 1/w * np.sin(w * t))
    get_x2 = lambda t: init_x2 + init_dx2 * t + k * a / (m2 * w) * (t - 1/w * np.sin(w * t))
    x1 = get_x1(t)
    x2 = get_x2(t)
    t_vals = np.linspace(0., max_t, 100)
    x1_vals = get_x1(t_vals)
    x2_vals = get_x2(t_vals)

    box_width = 2.
    cart1 = TrackFrame(decals=[BoxDecal(width=box_width, height=box_width / 1.618, position=(-box_width/2, 0))])
    cart2 = TrackFrame(decals=[BoxDecal(width=box_width, height=box_width / 1.618, position=(box_width/2, 0))])
    track_line = LineDecal(-100, 100, linewidth=0.5)
    colliding = True
    collision_line = LineDecal(x1, x2, linewidth=4., color='r' if colliding else 'g')
    scene = Scene(frames=[cart1, cart2], decals=[track_line, collision_line])

    fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(18, 10))
    xform_matrix = get_scale_matrix(scale, scale)
    state_map = {
        cart1: (x1, dx1),
        cart2: (x2, dx2),
    }
    scene.draw(ax0, state_map, xform_matrix, draw_options=DrawOptions(grid=grid))
    
    with np.errstate(all='ignore'):
        a_vals = (d0 + dd0 * t_vals) / (np.sin(w * t_vals) - t_vals) - dd0 / (np.cos(w * t_vals) - 1)
    ax1.plot(t_vals, a_vals)
    ax1.set_ylim(-50., 50.)
    ax1.axhline(color='k')
    ax1.axvline(color='k')
    ax1.axvline(t, color='r', linestyle='--')
    
    ax2.plot(x1_vals, t_vals)
    ax2.plot(x2_vals, t_vals)
    ax2.axhline(t)
