In [None]:
from abc import ABC
from functools import partial
from functools import reduce
from io import BytesIO
from ipywidgets import interact
from IPython.display import display
from IPython.display import Image
from matplotlib import animation
from matplotlib import pyplot as plt
from matplotlib import rc
from matplotlib.lines import Line2D
import daglet
import ipywidgets
import matplotlib.patches as patches
import numpy as np
import operator
import tempfile
import PIL

In [None]:
# Display matplotlib animations inline as HTML5 video when an Animation object
# is rendered by the notebook.
rc('animation', html='html5')

In [None]:
def draw_line(ax, x1, y1, x2, y2, **kwargs):
    """Add a straight segment from (x1, y1) to (x2, y2) to ``ax``.

    Any extra keyword arguments are forwarded to ``Line2D``.  Unless the caller
    overrides them, the segment is drawn solid, black and 2 units wide.
    """
    for option, fallback in (('linestyle', '-'), ('linewidth', 2.), ('color', 'k')):
        kwargs.setdefault(option, fallback)
    segment = Line2D((x1, x2), (y1, y2), **kwargs)
    ax.add_line(segment)
    
def init_ax(ax, lim=10, hide_axes=True):
    """Prepare ``ax`` as a square drawing canvas spanning ``[-lim, lim]`` on both axes.

    With ``hide_axes`` set, both axis objects are made invisible.  Otherwise the
    bottom/left spines are moved through the origin, the top/right spines are
    hidden and the tick labels are blanked.
    """
    ax.set_aspect('equal')
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)

    if not hide_axes:
        for side in ('bottom', 'left'):
            ax.spines[side].set_position('zero')
        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        for axis in (ax.xaxis, ax.yaxis):
            axis.set_ticklabels([])
        return

    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

In [None]:
@interact(
    q0=(-1, 1, 0.1),
    q1=(-1, 1, 0.1),
    q2=(-1, 1, 0.1),
    qd0=(-2, 2, 0.1),
    qd1=(-2, 2, 0.1),  # was misspelled `qq1`, so this range never reached the qd1 slider
    qd2=(-2, 2, 0.1),
)
def f(
    q0=-0.6,
    q1=0.1,
    q2=0.3,
    qd0=1.2,
    qd1=-0.5,
    qd2=1.,
):
    """Plot a three-link chain and the matching chain of link velocities.

    Parameters
    ----------
    q0, q1, q2 : float
        Link angles expressed as multiples of pi (they are scaled by ``np.pi``
        below); an angle of zero makes the link point straight down.
    qd0, qd1, qd2 : float
        Angular rates of the three links.

    The right-hand axes show the link positions, the left-hand axes show the
    velocity contributed by each link stacked tip to tail.
    """
    q0 *= np.pi
    q1 *= np.pi
    q2 *= np.pi

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(16, 9))
    init_ax(ax0)
    init_ax(ax1)
    ax0.set_title('velocity')
    ax1.set_title('position')

    # All links share the same length.
    l = 4.
    # Joint positions: each link adds l*(sin q, -cos q) to the previous joint.
    x0 = l * np.sin(q0)
    y0 = l * -np.cos(q0)
    x1 = x0 + l * np.sin(q1)
    y1 = y0 + l * -np.cos(q1)
    x2 = x1 + l * np.sin(q2)
    y2 = y1 + l * -np.cos(q2)
    # Joint velocities: time derivative of the positions above.
    xd0 = qd0 * l * np.cos(q0)
    yd0 = qd0 * l * np.sin(q0)
    xd1 = xd0 + qd1 * l * np.cos(q1)
    yd1 = yd0 + qd1 * l * np.sin(q1)
    xd2 = xd1 + qd2 * l * np.cos(q2)
    yd2 = yd1 + qd2 * l * np.sin(q2)

    draw_line(ax1, 0, 0, x0, y0)
    draw_line(ax1, x0, y0, x1, y1)
    draw_line(ax1, x1, y1, x2, y2)
    draw_line(ax0, 0, 0, xd0, yd0)
    draw_line(ax0, xd0, yd0, xd1, yd1)
    draw_line(ax0, xd1, yd1, xd2, yd2)

    # Polar grid of sample points (8 angles x 6 radii, radius 4 down to 0)
    # centred on the tip of the last link.
    X = np.linspace(0, 1, 8, endpoint=False)
    Y = np.linspace(0, 1, 6)
    xx, yy = np.meshgrid(X, Y)
    pos_xx = (1 - yy) * 4 * np.cos(xx * 2 * np.pi) + x2
    pos_yy = (1 - yy) * 4 * np.sin(xx * 2 * np.pi) + y2
    ax1.plot(pos_xx.flatten(), pos_yy.flatten(), 'ro', ms=3.)
    # Highlight the first grid point with a larger black marker.
    ax1.plot(pos_xx[0][0], pos_yy[0][0], 'ko', ms=6.)


In [None]:
@interact(
    theta=(-3.1, 3.1, 0.1),
    theta_d=(-4., 4., 0.1),
    x=(-3., 3., 0.1),
    y=(-3., 3., 0.1),
    xd=(-6., 6., 0.2),
    yd=(-6., 6., 0.2),
    t=(-5., 5., 0.1),
)
def f(theta=0.5, theta_d=1.2, x=1.5, y=0.5, xd=-0.7, yd=1.5, t=0., show_taylor_approx=False):
    """Visualise a rotating, translating 2-D frame and its velocity field.

    Parameters
    ----------
    theta, x, y : float
        Orientation and origin of the frame at time zero.
    theta_d, xd, yd : float
        Constant angular and linear rates of the frame.
    t : float
        Time at which the moved frame (solid outline, green markers) is drawn.
    show_taylor_approx : bool
        Also draw the first-order approximation ``(A + Ad*t)`` of the moved
        frame (thin dashed outline, blue markers).

    The dashed outline with red markers is the frame at time zero, the arrows
    are the velocities the frame imparts at time zero on a grid of points, and
    the large black dot is the point with zero velocity (instantaneous centre
    of rotation); it is omitted when ``theta_d`` is zero because no such
    finite point exists.
    """
    fig, ax0 = plt.subplots(1, 1, figsize=(16, 9))
    init_ax(ax0, lim=5.)

    def draw_square(points, marker_fmt, center_linewidth, edge_kwargs, **marker_kwargs):
        """Draw origin->centre line, the four edges and the five marker points."""
        draw_line(ax0, 0, 0, points[0, -1], points[1, -1], linestyle='--', linewidth=center_linewidth)
        for i in range(4):
            j = (i + 1) % 4
            draw_line(ax0, points[0, i], points[1, i], points[0, j], points[1, j], **edge_kwargs)
        ax0.plot(points[0, :], points[1, :], marker_fmt, **marker_kwargs)

    # Corners of a 2x2 square plus its centre, as homogeneous column vectors.
    local1 = np.array([
        [-1., -1., 1.],
        [1., -1., 1.],
        [1., 1., 1.],
        [-1., 1., 1.],
        [0., 0., 1.],
    ]).transpose()
    # Pose of the frame at time zero.
    A = np.matrix([
        [np.cos(theta), -np.sin(theta), x],
        [np.sin(theta), np.cos(theta), y],
        [0., 0., 1.],
    ])
    # Time derivative of A.  The bottom row is all zeros because the bottom row
    # of A is constant (the original had a stray 1 in the corner).
    Ad = np.matrix([
        [theta_d * -np.sin(theta), theta_d * -np.cos(theta), xd],
        [theta_d * np.cos(theta), theta_d * -np.sin(theta), yd],
        [0., 0., 0.],
    ])
    # Exact pose of the frame at time t.
    A2 = np.matrix([
        [np.cos(theta + theta_d*t), -np.sin(theta + theta_d*t), x + xd*t],
        [np.sin(theta + theta_d*t), np.cos(theta + theta_d*t), y + yd*t],
        [0., 0., 1.],
    ])
    global1 = A * local1
    global2 = A2 * local1
    global3 = (A + Ad*t) * local1

    draw_square(global1, 'ro', 1., dict(linestyle='--'))
    if show_taylor_approx:
        draw_square(global3, 'bo', 0.5, dict(linestyle='--', linewidth=1.), ms=4.)
    draw_square(global2, 'go', 1., {})

    # Velocity field: map a regular grid of global points into the frame's
    # local coordinates, then apply the velocity matrix to each of them.
    xs = np.linspace(-5., 5., 15)
    ys = np.linspace(-5., 5., 15)
    xg, yg = np.meshgrid(xs, ys)
    grid_global = np.matrix(np.array([
        xg.ravel(),
        yg.ravel(),
        np.ones(xg.size),
    ]))
    grid_local = np.linalg.inv(A) * grid_global

    # Projection that keeps the x/y rows and zeroes the homogeneous row.
    ij0 = np.matrix([[1, 0, 0], [0, 1, 0], [0, 0, 0]])
    vels = ij0 * Ad * grid_local
    vel_xg = np.array(vels[0, :]).reshape(xg.shape)
    vel_yg = np.array(vels[1, :]).reshape(yg.shape)
    ax0.quiver(xg, yg, vel_xg, vel_yg)

    # Instantaneous centre of rotation: the local point (v1, v2) whose velocity
    # Ad * (v1, v2, 1) vanishes.  This closed form is algebraically equal to the
    # earlier expression but does not divide by sin(theta), so it stays finite
    # at theta == 0.
    if theta_d != 0.:
        c = np.cos(theta)
        s = np.sin(theta)
        v1 = (s * xd - c * yd) / theta_d
        v2 = (c * xd + s * yd) / theta_d
        w = A * np.matrix([[v1, v2, 1.]]).T
        ax0.plot(w[0, 0], w[1, 0], 'ko', ms=7.)

In [None]:
def draw_line(ax, x1, y1, x2, y2, **kwargs):
    """Draw a segment between (x1, y1) and (x2, y2) on ``ax``.

    Keyword arguments go straight to ``Line2D``; the defaults are a solid black
    line of width 2.
    """
    defaults = (('linestyle', '-'), ('linewidth', 2.), ('color', 'k'))
    for key, value in defaults:
        if key not in kwargs:
            kwargs[key] = value
    ax.add_line(Line2D((x1, x2), (y1, y2), **kwargs))


def draw_circle(ax, x, y, radius, **kwargs):
    """Add a circle of the given ``radius`` centred on (x, y) to ``ax``.

    Keyword arguments are forwarded to ``plt.Circle``; the colour defaults to
    black.
    """
    if 'color' not in kwargs:
        kwargs['color'] = 'k'
    ax.add_artist(plt.Circle((x, y), radius, **kwargs))


#def draw_box(ax, x, y, width, height, **kwargs):
#    kwargs.setdefault('color', 'k')
#    rect = patches.Rectangle((x, y), width, height, **kwargs)
#    ax.add_patch(rect)


def do_render_animation(draw_func, max_time_index, sample_interval, fps=25, fig=None):
    """Build a ``FuncAnimation`` that calls ``draw_func`` once per sampled time index.

    Parameters
    ----------
    draw_func : callable
        Called as ``draw_func(ax, time_index)`` for every frame.  ``ax`` is the
        axes created here, or ``None`` when the caller supplied ``fig``.
    max_time_index : int
        Exclusive upper bound of the time indices to render.
    sample_interval : int
        Step between rendered time indices; frame ``i`` shows index
        ``i * sample_interval``.
    fps : number
        Playback rate, converted to a per-frame interval in milliseconds.
    fig : matplotlib Figure, optional
        Figure to animate.  When omitted a new figure with one axes is created
        and that axes is cleared before every frame.

    Returns
    -------
    The ``FuncAnimation`` object.  The figure is closed before returning so the
    notebook does not additionally show it as a static image.
    """
    ax = None
    if fig is None:
        fig, ax = plt.subplots()
    # Lay out the figure being animated (not whatever figure happens to be current).
    fig.tight_layout()

    def animate(i):
        if ax is not None:
            ax.clear()
        time_index = i * sample_interval
        draw_func(ax, time_index)
        return (fig,)

    anim = animation.FuncAnimation(
        fig,
        animate,
        frames=int(max_time_index / sample_interval),
        interval=int(1000/fps),
        # Every frame is redrawn from scratch and the callback returns the
        # figure rather than the modified artists, so blitting cannot be used.
        blit=False,
    )
    plt.close(fig)
    return anim


class RendererWidget(ipywidgets.VBox):
    """A 'Render' button stacked above an output area.

    Clicking the button clears the output area and runs ``render_func`` with
    its output captured there.  ``render_func`` is a plain attribute that the
    owner replaces; it starts out as a no-op.
    """

    def __init__(self):
        self.render_func = lambda: None
        self.__out = ipywidgets.Output()
        render_button = ipywidgets.Button(description='Render')
        render_button.on_click(self.__on_click)
        children = [render_button, self.__out]
        super(RendererWidget, self).__init__(children)

    def __on_click(self, _):
        """Button callback: wipe the previous output, then render into it."""
        output_area = self.__out
        output_area.clear_output()
        with output_area:
            self.render_func()

In [None]:
# Default gravitational acceleration used by Scene.
DEFAULT_GRAVITY = 10.

# Homogeneous 3x1 column vectors: a position carries 1 in its last entry, a
# velocity carries 0.
ZERO_POS = np.array([[0.], [0.], [1.]])
ZERO_VEL = np.array([[0.], [0.], [0.]])

# Unit square centred on the origin; one homogeneous corner per column.
CENTERED_SQUARE = np.transpose([
    [-0.5, -0.5, 1.],
    [0.5, -0.5, 1.],
    [0.5, 0.5, 1.],
    [-0.5, 0.5, 1.],
])

# Unit square occupying the first quadrant; one homogeneous corner per column.
QUAD1_SQUARE = np.transpose([
    [0., 0., 1.],
    [1., 0., 1.],
    [1., 1., 1.],
    [0., 1., 1.],
])

# Single zigzag running from x=0 to x=1 with peaks at y=+/-0.5.
ZIGZAG1 = np.transpose([
    [0., 0., 1.],
    [0.25, 0.5, 1.],
    [0.75, -0.5, 1.],
    [1., 0., 1.],
])

# Five zigzags from x=0 to x=1: peaks sit at odd multiples of 1/20 and
# alternate between y=+0.5 and y=-0.5.
ZIGZAG5 = np.transpose(
    [(0., 0., 1.)]
    + [((2 * k + 1) / 20., 0.5 if k % 2 == 0 else -0.5, 1.) for k in range(10)]
    + [(1., 0., 1.)]
)


def coerce_position_vector(a):
    """Turn a 2- or 3-element array-like into a homogeneous 3x1 position column.

    Only the first two entries are kept; the third entry of the result is
    always 1, whatever the input's third element was.  Inputs of any other
    length fail the assertion.
    """
    flat = np.asarray(a).flatten()
    assert len(flat) in (2, 3)
    x, y = flat[0], flat[1]
    return np.array([[x], [y], [1.]])


def get_scale_matrix(x, y):
    """Return the 3x3 homogeneous matrix scaling by ``x`` and ``y`` along the two axes."""
    return np.diag([x, y, 1.])


def get_rotation_matrix(angle):
    """Return the 3x3 homogeneous matrix rotating by ``angle`` radians about the origin."""
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.identity(3)
    rotation[:2, :2] = [[cos_a, -sin_a], [sin_a, cos_a]]
    return rotation


def get_translation_matrix(offset):
    """Return the 3x3 homogeneous matrix translating by ``offset``.

    ``offset`` is anything ``coerce_position_vector`` accepts.
    """
    shift = coerce_position_vector(offset)
    matrix = np.identity(3)
    matrix[0, 2] = shift[0, 0]
    matrix[1, 2] = shift[1, 0]
    return matrix


def get_rotation_translation_matrix(angle, offset):
    """Return the 3x3 homogeneous matrix that rotates by ``angle`` and then translates by ``offset``.

    ``offset`` is anything ``coerce_position_vector`` accepts.
    """
    shift = coerce_position_vector(offset)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rows = [
        (cos_a, -sin_a, shift[0, 0]),
        (sin_a, cos_a, shift[1, 0]),
        (0., 0., 1.),
    ]
    return np.array(rows)


class Connector(object):
    """Link in a ``a | b | c`` chain: wraps a node together with its parent connectors.

    Parameters
    ----------
    node : object, optional
        The wrapped node.  A tuple or list is wrapped in a ``Group``.
    parents : list, optional
        Connectors that precede this one in the chain.  Defaults to a fresh
        empty list for every instance.
    """

    def __init__(self, node=None, parents=None):
        if isinstance(node, (tuple, list)):
            node = Group(*node)
        self.node = node
        # A `[]` default in the signature would be shared by every instance
        # created without parents.
        self.parents = [] if parents is None else parents

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        """Chain ``other`` after this connector."""
        return Connector(other, [self])


class Group(object):
    """Bundle of nodes that takes part in a ``|`` chain as a single node."""

    def __init__(self, *nodes):
        # Kept as the tuple of positional arguments.
        self.nodes = nodes

    def __repr__(self):
        return '%s [0x%x]' % (type(self).__name__, id(self))

    def __or__(self, other):
        """Start a chain: this group becomes the parent of ``other``."""
        head = Connector(self)
        return Connector(other, [head])

    
class Frame(object):
    """Coordinate frame holding decals, masses and child frames.

    The base class is a fixed frame: its position matrix is the identity and
    its velocity and acceleration matrices are zero.  Subclasses override the
    ``get_*_matrix`` methods to depend on the generalized coordinate ``q``.

    Parameters
    ----------
    decals, masses, frames : list, optional
        Contents of the frame.  Each defaults to a fresh empty list.
    """

    def __init__(self, decals=None, masses=None, frames=None):
        # `[]` defaults in the signature would be shared between all frames
        # created without that argument.
        self.decals = [] if decals is None else decals
        self.masses = [] if masses is None else masses
        self.frames = [] if frames is None else frames

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def __or__(self, other):
        """Start a ``|`` chain with this frame as the parent of ``other``."""
        return Connector(other, [Connector(self)])

    def get_pos_matrix(self, q):
        """Return the 3x3 matrix mapping this frame's coordinates into its parent's."""
        return np.identity(3)

    def get_vel_matrix(self, q):
        """Return the derivative of the position matrix with respect to ``q``."""
        return np.zeros((3, 3))

    def get_accel_matrix(self, q):
        """Return the second derivative of the position matrix with respect to ``q``."""
        return np.zeros((3, 3))

    def draw(self, ax, state_map, xform_matrix, scale):
        """Draw this frame's decals and, recursively, its child frames.

        ``state_map`` maps frames to ``(q, qd)`` pairs; a frame missing from it
        is drawn at ``q = 0``.  ``xform_matrix`` is the accumulated transform of
        the parent and ``scale`` is passed through to the decals.
        """
        state = state_map.get(self, (0., 0.))
        q = state[0]
        xform_matrix = xform_matrix.dot(self.get_pos_matrix(q))
        for decal in self.decals:
            decal.draw(ax, xform_matrix, scale)
        for frame in self.frames:
            frame.draw(ax, state_map, xform_matrix, scale)
            

class StaticTransform(ABC):
    """Common base of the transform nodes (``Rotation``, ``Translation``) used in ``|`` chains."""

    def __or__(self, other):
        """Start a chain with this transform as the parent of ``other``."""
        head = Connector(self)
        return Connector(other, [head])


class Rotation(StaticTransform):
    """Transform node that rotates by a fixed ``angle`` (radians)."""

    def __init__(self, angle):
        self.angle = angle

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        return header + '\nangle={:.2f}'.format(self.angle)

    def get_xform_matrix(self):
        """Return the 3x3 rotation matrix for ``self.angle``."""
        return get_rotation_matrix(self.angle)


class Translation(StaticTransform):
    """Transform node that translates by a fixed ``offset``."""

    def __init__(self, offset):
        # Stored as a homogeneous 3x1 column.
        self.offset = coerce_position_vector(offset)

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        dx, dy = self.offset[0, 0], self.offset[1, 0]
        return header + '\noffset=({:.2f}, {:.2f})'.format(dx, dy)

    def __or__(self, other):
        """Start a chain with this translation as the parent of ``other``."""
        head = Connector(self)
        return Connector(other, [head])

    def get_xform_matrix(self):
        """Return the 3x3 translation matrix for ``self.offset``."""
        return get_translation_matrix(self.offset)

    
class RotationalFrame(Frame):
    """Frame pinned at ``position`` in its parent that rotates by its coordinate ``q``.

    Parameters
    ----------
    position : array-like of length 2 or 3
        Pivot location in the parent frame.
    decals, masses, frames : list, optional
        Contents of the frame.  Each defaults to a fresh empty list.
    """

    def __init__(self, position=ZERO_POS, decals=None, masses=None, frames=None):
        # Fresh lists per instance: `[]` defaults in the signature would be
        # shared by every frame constructed without that argument.
        super(RotationalFrame, self).__init__(
            [] if decals is None else decals,
            [] if masses is None else masses,
            [] if frames is None else frames,
        )
        self.position = coerce_position_vector(position)

    def get_pos_matrix(self, q):
        """Rotation by ``q`` followed by translation to the pivot position."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (c, -s, self.position[0, 0]),
            (s, c, self.position[1, 0]),
            (0., 0., 1.),
        ])

    def get_vel_matrix(self, q):
        """First derivative of ``get_pos_matrix`` with respect to ``q``."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (-s, -c, 0.),
            (c, -s, 0.),
            (0., 0., 0.),
        ])

    def get_accel_matrix(self, q):
        """Second derivative of ``get_pos_matrix`` with respect to ``q``."""
        c = np.cos(q)
        s = np.sin(q)
        return np.array([
            (-c, s, 0.),
            (-s, -c, 0.),
            (0., 0., 0.),
        ])


class TrackFrame(Frame):
    """Frame that slides by its coordinate ``q`` along a straight track.

    Parameters
    ----------
    position : array-like of length 2 or 3
        Location of the frame in its parent when ``q`` is zero.
    angle : float
        Direction of the track in radians.
    decals, masses, frames : list, optional
        Contents of the frame.  Each defaults to a fresh empty list.

    The acceleration matrix is inherited from ``Frame`` (all zeros).
    """

    def __init__(self, position=ZERO_POS, angle=0., decals=None, masses=None, frames=None):
        # Fresh lists per instance: `[]` defaults in the signature would be
        # shared by every frame constructed without that argument.
        super(TrackFrame, self).__init__(
            [] if decals is None else decals,
            [] if masses is None else masses,
            [] if frames is None else frames,
        )
        self.position = coerce_position_vector(position)
        self.angle = angle

    def get_pos_matrix(self, q):
        """Translation to ``position`` plus ``q`` units along the track direction."""
        c = np.cos(self.angle)
        s = np.sin(self.angle)
        return get_translation_matrix(
            self.position + np.array([[q * c, q * s, 0.]]).T
        )

    def get_vel_matrix(self, q):
        """First derivative of ``get_pos_matrix`` with respect to ``q``."""
        return np.array([
            (0, 0, np.cos(self.angle)),
            (0, 0, np.sin(self.angle)),
            (0., 0., 0.),
        ])


class Mass(object):
    """Point mass of size ``mass`` located at ``position`` within its frame."""

    def __init__(self, mass=1., position=ZERO_POS):
        self.position = coerce_position_vector(position)
        self.mass = mass

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        return header + '\nmass={:.2f}'.format(self.mass)

    def __or__(self, other):
        """Start a ``|`` chain with this mass as the parent of ``other``."""
        head = Connector(self)
        return Connector(other, [head])


class Decal(object):
    """Base class of drawable decorations; subclasses implement ``draw``."""

    def __or__(self, other):
        """Start a ``|`` chain with this decal as the parent of ``other``."""
        head = Connector(self)
        return Connector(other, [head])

    def __repr__(self):
        return '%s [0x%x]' % (type(self).__name__, id(self))

    def draw(self, ax, xform_matrix, scale):
        """Render the decal on ``ax``; not implemented in the base class."""
        raise NotImplementedError()
        

class LineDecal(Decal):
    """Straight line between ``start_pos`` and ``end_pos`` in frame coordinates."""

    def __init__(self, end_pos, start_pos=ZERO_POS, linewidth=2.):
        self.start_pos = coerce_position_vector(start_pos)
        self.end_pos = coerce_position_vector(end_pos)
        self.linewidth = linewidth
        # Both endpoints side by side (3x2) so one matrix product transforms them.
        self._positions = np.hstack((self.start_pos, self.end_pos))

    def draw(self, ax, xform_matrix, scale):
        """Draw the transformed line on ``ax`` with its width multiplied by ``scale``."""
        endpoints = xform_matrix.dot(self._positions)
        start_x, start_y = endpoints[0, 0], endpoints[1, 0]
        end_x, end_y = endpoints[0, 1], endpoints[1, 1]
        draw_line(ax, start_x, start_y, end_x, end_y, linewidth=self.linewidth*scale)


class CircleDecal(Decal):
    """Circle of the given ``radius`` centred on ``position`` in frame coordinates."""

    def __init__(self, position=ZERO_POS, radius=1.):
        self.radius = radius
        self.position = coerce_position_vector(position)

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        return header + '\nradius={}'.format(self.radius)

    def draw(self, ax, xform_matrix, scale):
        """Draw the circle at its transformed centre with its radius multiplied by ``scale``."""
        center = xform_matrix.dot(self.position)
        draw_circle(ax, center[0, 0], center[1, 0], self.radius * scale)


class BoxDecal(Decal):
    """Rectangle of ``width`` x ``height`` placed at ``position`` in frame coordinates.

    ``centered`` selects whether ``position`` is the box centre or its corner
    (``CENTERED_SQUARE`` versus ``QUAD1_SQUARE`` as the template).  The box is
    rotated by ``-angle``.  A ``solid`` box is filled black; otherwise only its
    outline is drawn using ``linewidth``.
    """

    def __init__(self, width=1., height=1., position=ZERO_POS, angle=0., centered=True, solid=True, linewidth=1.):
        self.width = width
        self.height = height
        self.position = coerce_position_vector(position)
        self.angle = angle
        self.centered = centered
        self.solid = solid
        self.linewidth = linewidth
        if centered:
            template = CENTERED_SQUARE
        else:
            template = QUAD1_SQUARE
        # Scale the unit square first, then rotate and translate it into place.
        sizing = get_scale_matrix(width, height)
        placement = get_rotation_translation_matrix(-angle, self.position)
        self._corner_positions = placement.dot(sizing).dot(template)

    def __repr__(self):
        header = '{} [0x{:x}]'.format(type(self).__name__, id(self))
        return header + '\nwidth={}, height={}'.format(self.width, self.height)

    def draw(self, ax, xform_matrix, scale):
        """Draw the transformed box on ``ax``; ``scale`` multiplies the outline width."""
        corners = xform_matrix.dot(self._corner_positions)
        if self.solid:
            ax.fill(corners[0, :], corners[1, :], 'k')
            return
        count = self._corner_positions.shape[1]
        for start in range(count):
            # Wrap around so the last corner connects back to the first.
            end = (start + 1) % count
            draw_line(
                ax,
                corners[0, start],
                corners[1, start],
                corners[0, end],
                corners[1, end],
                linewidth=self.linewidth*scale
            )


class Spring(object):
    """Linear spring of stiffness ``k`` between a point on ``frame1`` and a point on ``frame2``.

    Parameters
    ----------
    frame1, frame2 : Frame
        Frames the two ends are attached to.
    k : float
        Spring constant.
    position1, position2 : array-like of length 2 or 3
        Attachment points, each in its own frame's coordinates.
    linewidth : float
        Line width used when drawing (multiplied by the draw scale).
    zigzag_count : int
        Number of full zigzags drawn between the two straight end pieces.
    zigzag_padding : float
        Length of the straight piece at each end (multiplied by the draw scale).
    zigzag_width : float
        Peak-to-peak width of the zigzag (multiplied by the draw scale).
    """

    def __init__(
        self,
        frame1,
        frame2,
        k=1.,
        position1=ZERO_POS,
        position2=ZERO_POS,
        linewidth=1.,
        zigzag_count=5,
        zigzag_padding=0.8,
        zigzag_width=0.8,
    ):
        self.frame1 = frame1
        self.frame2 = frame2
        self.k = k
        self.position1 = coerce_position_vector(position1)
        self.position2 = coerce_position_vector(position2)
        self.zigzag_count = zigzag_count
        self.zigzag_padding = zigzag_padding
        self.zigzag_width = zigzag_width
        self.linewidth = linewidth

    def __repr__(self):
        return '{} [0x{:x}]\nk={}'.format(type(self).__name__, id(self), self.k)

    def __or__(self, other):
        """Start a ``|`` chain with this spring as the parent of ``other``.

        The spring is wrapped in a ``Connector`` like every other node type, so
        that the parents of the resulting connector are all connectors.
        """
        return Connector(other, [Connector(self)])

    def _zigzag_points(self):
        """Return the unit zigzag template as homogeneous columns.

        The template runs from (0, 0) to (1, 0) with ``2 * zigzag_count`` peaks
        alternating between y=+0.5 and y=-0.5.  A count of zero or less yields
        just the two end points.
        """
        count = int(self.zigzag_count)
        columns = [(0., 0., 1.)]
        columns.extend(
            ((2 * i + 1) / (4. * count), 0.5 if i % 2 == 0 else -0.5, 1.)
            for i in range(2 * count)
        )
        columns.append((1., 0., 1.))
        return np.array(columns).T

    def draw(self, ax, state_map, root_xform_matrix, frame1_xform_matrix, frame2_xform_matrix, scale):
        """Draw the spring between its two transformed attachment points.

        ``frame1_xform_matrix`` / ``frame2_xform_matrix`` map each frame's
        coordinates to drawing coordinates.  ``state_map`` and
        ``root_xform_matrix`` are accepted but not used here.  When the ends
        are too close to fit both straight end pieces, a plain line is drawn.
        """
        pos1 = frame1_xform_matrix.dot(self.position1)
        pos2 = frame2_xform_matrix.dot(self.position2)
        linewidth = self.linewidth * scale
        dist = np.linalg.norm(pos2 - pos1)
        # `<=` also covers coincident end points with zero padding, which would
        # otherwise divide by zero when normalising below.
        if dist <= self.zigzag_padding * 2 * scale:
            draw_line(ax, pos1[0, 0], pos1[1, 0], pos2[0, 0], pos2[1, 0], linewidth=linewidth)
            return

        normal = (pos2 - pos1) / dist
        zigzag_start = pos1 + normal * self.zigzag_padding * scale
        zigzag_end = pos2 - normal * self.zigzag_padding * scale
        # Basis mapping the unit template onto the gap: x along the spring,
        # y perpendicular to it, third column keeps the homogeneous coordinate.
        basis1 = zigzag_end - zigzag_start
        basis2 = scale * self.zigzag_width * np.array([(normal[1, 0], -normal[0, 0], 0.)]).T
        basis3 = np.array([(0., 0., 1.)]).T
        zigzag_xform = get_translation_matrix(zigzag_start).dot(np.hstack((basis1, basis2, basis3)))
        zigzag_points = zigzag_xform.dot(self._zigzag_points())
        points = np.hstack((
            pos1,
            zigzag_points,
            pos2,
        ))
        for i in range(points.shape[1] - 1):
            draw_line(
                ax,
                points[0, i],
                points[1, i],
                points[0, i + 1],
                points[1, i + 1],
                linewidth=linewidth
            )


class Scene(object):
    """Collection of root frames, free-standing decals and springs, plus gravity.

    Parameters
    ----------
    decals, frames, springs : list, optional
        Scene contents.  Each defaults to a fresh empty list.
    gravity : float
        Gravitational acceleration read by the solver.

    Attributes computed at construction (via ``daglet``):

    ``sorted_frames``
        Every frame reachable from ``frames``, in reversed topological order.
    ``frame_parent_map``
        Mapping used to look up a frame's parents.
    ``frame_path_map``
        Maps each frame to the list of frames leading to it, ending with the
        frame itself.
    """

    def __init__(self, decals=None, frames=None, springs=None, gravity=DEFAULT_GRAVITY):
        get_children = lambda x: x.frames
        visit_path = lambda frame, paths: reduce(operator.add, paths, []) + [frame]
        # Fresh lists per instance instead of shared `[]` defaults.
        self.decals = [] if decals is None else decals
        self.frames = [] if frames is None else frames
        self.springs = [] if springs is None else springs
        self.gravity = gravity
        self.sorted_frames = list(reversed(daglet.toposort(self.frames, get_children)))
        self.frame_parent_map = daglet.get_child_map(self.sorted_frames, get_children)
        self.frame_path_map = daglet.transform_vertices(self.sorted_frames, self.frame_parent_map.get, visit_path)

    def __repr__(self):
        return '{} [0x{:x}]'.format(type(self).__name__, id(self))

    def get_xform_matrix(self, state_map, frame, deriv_frame1=None, deriv_frame2=None):
        """Get matrix that transforms points from one frame into the global
        coordinate space, and optionally take partial derivatives to return
        a velocity transformation matrix or acceleration transformation matrix
        rather than a position transformation matrix.

        ``deriv_frame1`` / ``deriv_frame2`` name the frames whose coordinate
        the matrix is differentiated by.  If either is given but does not lie
        on the path to ``frame``, the result is the zero matrix.  ``state_map``
        must contain every frame on the path.
        """
        path = self.frame_path_map[frame]
        if deriv_frame1 is not None and deriv_frame1 not in path:
            matrix = np.zeros((3, 3))
        elif deriv_frame2 is not None and deriv_frame2 not in path:
            matrix = np.zeros((3, 3))
        else:
            matrix = np.identity(3)
            # Walk from `frame` back to the root, left-multiplying each level.
            for frame2 in reversed(path):
                q, _ = state_map[frame2]
                if frame2 is deriv_frame1 and frame2 is deriv_frame2:
                    # Differentiated twice by the same coordinate.
                    matrix = frame2.get_accel_matrix(q).dot(matrix)
                elif frame2 is deriv_frame1 or frame2 is deriv_frame2:
                    matrix = frame2.get_vel_matrix(q).dot(matrix)
                else:
                    matrix = frame2.get_pos_matrix(q).dot(matrix)
        return matrix

    def draw(self, ax, state_map, xform_matrix=None):
        """Draw the whole scene on ``ax`` for the given ``state_map``.

        ``xform_matrix`` is the 3x3 view transform applied to everything; it
        defaults to the identity.  The line-width/radius scale handed to the
        decals is derived from the determinant of its upper-left 2x2 part.
        """
        if xform_matrix is None:
            xform_matrix = np.identity(3)
        init_ax(ax)
        scale = np.sqrt(np.linalg.det(xform_matrix[:2, :2]))
        for decal in self.decals:
            decal.draw(ax, xform_matrix, scale)
        for frame in self.frames:
            frame.draw(ax, state_map, xform_matrix, scale)
        for spring in self.springs:
            frame1_xform_matrix = xform_matrix.dot(self.get_xform_matrix(state_map, spring.frame1))
            frame2_xform_matrix = xform_matrix.dot(self.get_xform_matrix(state_map, spring.frame2))
            spring.draw(ax, state_map, xform_matrix, frame1_xform_matrix, frame2_xform_matrix, scale)


class NaiveSolver(object):
    """Integrates a ``Scene`` by solving a linear system for the accelerations.

    For a state ``(qs, qds)`` ordered like ``scene.sorted_frames``, ``_solve``
    assembles ``a_mat * qdds = b_vec`` from the masses, gravity and springs of
    the scene and solves it; ``tick`` advances the state by one fourth-order
    Runge-Kutta step.
    """

    def __init__(self, scene):
        self.scene = scene

    def _make_state_map(self, qs, qds):
        """Pair each frame of ``scene.sorted_frames`` with its ``(q, qd)``."""
        return {frame: (q, qd) for frame, q, qd in zip(self.scene.sorted_frames, qs, qds)}

    def _solve(self, qs, qds):
        """Return the array of accelerations ``qdds`` for the state ``(qs, qds)``."""
        state_map = self._make_state_map(qs, qds)
        sorted_frames = self.scene.sorted_frames
        nframes = len(sorted_frames)
        # Position of each frame in the state vectors (first occurrence wins,
        # matching list.index) so the inner loop does not scan the list.
        frame_index = {}
        for index, frame in enumerate(sorted_frames):
            frame_index.setdefault(frame, index)
        a_mat = np.zeros((nframes, nframes))
        b_vec = np.zeros((nframes))
        for k, frame_k in enumerate(sorted_frames):
            for frame_i in sorted_frames:
                frame_i_path = self.scene.frame_path_map[frame_i]
                # Masses on frame_i only depend on q_k if frame_k is on the path to frame_i.
                if frame_k not in frame_i_path:
                    continue
                submat2 = np.zeros((3, 3))
                submat3 = np.zeros((3, 3))
                submat4 = np.zeros((3, 3))
                dgi_dqk = self.scene.get_xform_matrix(state_map, frame_i, frame_k)
                for frame_j in frame_i_path:
                    j = frame_index[frame_j]
                    _, qdj = state_map[frame_j]
                    dgi_dqj = self.scene.get_xform_matrix(state_map, frame_i, frame_j)
                    d2gi_dqjdqk = self.scene.get_xform_matrix(state_map, frame_i, frame_j, frame_k)
                    qddj_coeff_mat = dgi_dqj.T.dot(dgi_dqk)
                    for mass in frame_i.masses:
                        # The quadratic form is a size-1 array; .item() extracts
                        # the scalar explicitly (implicit conversion of such
                        # arrays is deprecated in recent NumPy).
                        a_mat[k, j] += mass.mass * mass.position.T.dot(qddj_coeff_mat).dot(mass.position).item()
                    for frame_l in frame_i_path:
                        _, qdl = state_map[frame_l]
                        d2gi_dqjdql = self.scene.get_xform_matrix(state_map, frame_i, frame_j, frame_l)
                        d2gi_dqkdql = self.scene.get_xform_matrix(state_map, frame_i, frame_k, frame_l)
                        submat2 += qdj * qdl * (d2gi_dqjdql.T.dot(dgi_dqk) + dgi_dqj.T.dot(d2gi_dqkdql))
                    submat3 += qdj * dgi_dqj.T
                    submat4 += qdj * d2gi_dqjdqk
                submat5 = submat2 - submat3.dot(submat4)
                for mass in frame_i.masses:
                    # Velocity-dependent terms.
                    b_vec[k] -= mass.mass * mass.position.T.dot(submat5).dot(mass.position).item()
                    # Gravity acts on the global y coordinate (row 1).
                    b_vec[k] -= mass.mass * self.scene.gravity * dgi_dqk[1, :].dot(mass.position).item()
            for spring in self.scene.springs:
                pos_mat1 = self.scene.get_xform_matrix(state_map, spring.frame1)
                pos_mat2 = self.scene.get_xform_matrix(state_map, spring.frame2)
                vel_mat1 = self.scene.get_xform_matrix(state_map, spring.frame1, frame_k)
                vel_mat2 = self.scene.get_xform_matrix(state_map, spring.frame2, frame_k)
                displacement = pos_mat1.dot(spring.position1) - pos_mat2.dot(spring.position2)
                # Derivative of the displacement with respect to q_k.
                velocity = vel_mat1.dot(spring.position1) - vel_mat2.dot(spring.position2)
                b_vec[k] -= spring.k * displacement.T.dot(velocity).item()

        qdds = np.linalg.solve(a_mat, b_vec)
        return qdds

    def tick(self, state_map, delta_time):
        """Advance ``state_map`` by ``delta_time`` with one RK4 step.

        ``state_map`` maps every frame of ``scene.sorted_frames`` to ``(q, qd)``;
        a new map of the same form is returned and the input is left untouched.
        """
        def f1(qs, qds):
            # d(qs)/dt
            return qds

        def f2(qs, qds):
            # d(qds)/dt
            return self._solve(qs, qds)

        qs0 = np.array([state_map[frame][0] for frame in self.scene.sorted_frames])
        qds0 = np.array([state_map[frame][1] for frame in self.scene.sorted_frames])
        k1_qs = f1(qs0, qds0) * delta_time
        k1_qds = f2(qs0, qds0) * delta_time
        k2_qs = f1(qs0 + 0.5 * k1_qs, qds0 + 0.5 * k1_qds) * delta_time
        k2_qds = f2(qs0 + 0.5 * k1_qs, qds0 + 0.5 * k1_qds) * delta_time
        k3_qs = f1(qs0 + 0.5 * k2_qs, qds0 + 0.5 * k2_qds) * delta_time
        k3_qds = f2(qs0 + 0.5 * k2_qs, qds0 + 0.5 * k2_qds) * delta_time
        k4_qs = f1(qs0 + k3_qs, qds0 + k3_qds) * delta_time
        k4_qds = f2(qs0 + k3_qs, qds0 + k3_qds) * delta_time
        new_qs = qs0 + (k1_qs + 2 * k2_qs + 2 * k3_qs + k4_qs) / 6.
        new_qds = qds0 + (k1_qds + 2 * k2_qds + 2 * k3_qds + k4_qds) / 6.
        return self._make_state_map(new_qs, new_qds)

In [None]:
def is_binding(node):
    """Tell whether ``node`` is a connector with parents whose own node is another connector.

    The result is meant for truth testing: for a ``Connector`` without parents
    the (empty) parents list itself is returned.
    """
    if not isinstance(node, Connector):
        return False
    return node.parents and isinstance(node.node, Connector)


def walk_nodes(node, bypass_bindings=False):
    """Return the nodes that ``node`` points at, for graph traversal.

    * ``Connector``: its parents followed by its wrapped node (if any).  With
      ``bypass_bindings`` set, the wrapped node of a binding is left out.
    * ``Group``: its member nodes.
    * ``Frame``: child frames, then decals, then masses.
    * ``Spring``: the two frames it connects.
    * anything else: an empty list.
    """
    if isinstance(node, Connector):
        skip_wrapped = is_binding(node) and bypass_bindings
        if skip_wrapped or node.node is None:
            return node.parents
        return node.parents + [node.node]
    if isinstance(node, Group):
        return node.nodes
    if isinstance(node, Frame):
        return node.frames + node.decals + node.masses
    if isinstance(node, Spring):
        return [node.frame1, node.frame2]
    return []


def label_node(node):
    """Return the text shown for ``node`` in the rendered graph (currently its ``repr``)."""
    label = '{!r}'.format(node)  # TODO: improve
    return label


def get_node_color(node):
    """Return the fill colour used for ``node`` in the rendered graph.

    Bindings are checked first; after that the first matching type in the
    table below decides.  Unknown node types are white.
    """
    if is_binding(node):
        return 'honeydew'
    palette = (
        (Connector, 'ivory'),
        (Group, 'orange'),
        (Frame, 'greenyellow'),
        (Spring, 'yellow'),
        (Decal, 'khaki'),
        (Mass, 'lightblue'),
        (StaticTransform, 'gold'),
    )
    for node_type, color in palette:
        if isinstance(node, node_type):
            return color
    return 'white'


def _show_png(png_data, scale=0.5):
    """Display PNG bytes in the notebook, scaled to ``scale`` times their pixel width.

    ``png_data`` is the raw content of a PNG file; it is decoded once only to
    read the image width.
    """
    # `import PIL` on its own does not load the `PIL.Image` submodule, so
    # import it explicitly instead of relying on another package having done so.
    import PIL.Image

    buffer = BytesIO(png_data)
    image = PIL.Image.open(buffer)
    display(Image(png_data, width=image.width * scale, unconfined=True))


def show_node_graph(nodes, scale=0.5, parent_func=walk_nodes):
    """Render the graph reachable from ``nodes`` as a PNG and display it.

    ``parent_func`` maps a node to the nodes it points at (``walk_nodes`` by
    default); ``scale`` is the display scale handed to ``_show_png``.  Labels
    and colours come from ``label_node`` and ``get_node_color``.
    """
    graph_options = dict(
        parent_func=parent_func,
        vertex_label_func=label_node,
        vertex_color_func=get_node_color,
    )
    graph = daglet.make_graph(nodes, **graph_options)
    _show_png(graph.pipe(format='png'), scale)
    
    
class Context(object):
    """Accumulator holding a position plus decal and mass lists.

    NOTE(review): this looks like an unfinished stub and nothing visible in
    the notebook instantiates it.
    """

    def __init__(self, node=None, other_context=None):
        self.position = ZERO_POS
        self.decals = []
        self.masses = []
        if isinstance(node, Frame):
            # NOTE(review): `transplant_frame` is not defined anywhere in the
            # visible notebook, so this branch would raise NameError as written.
            self.frame = transplant_frame(other_context)
        elif other_context is not None:
            # Placeholder branch: nothing is taken from `other_context` yet.
            pass


#c = RotationalFrame() | BoxDecal() | Mass()

def Pendulum(length=3., mass=1., radius=0.5):
    """Build the connector chain describing a simple pendulum.

    The chain is: a rotational frame, a line of ``length`` along x, a
    translation to the end of that line, a circle of ``radius`` and a point
    mass of ``mass``.
    """
    tip = (length, 0.)
    chain = RotationalFrame() | LineDecal(tip)
    chain = chain | Translation(tip)
    chain = chain | CircleDecal(radius=radius)
    return chain | Mass(mass)

# A pendulum followed by a spring whose two ends are each a double pendulum.
thing1 = Pendulum() | Spring(
    Pendulum() | Pendulum(),
    Pendulum() | Pendulum(),
)

# `thing1` feeding a group of five spokes at evenly spaced angles; every spoke
# is a rotation, a line of length 3, a translation to the end of that line and
# a pendulum.
leaf_node = Group(thing1) | Group(*[
    Rotation(np.pi * 2/5. * i)
    | LineDecal((3., 0.))
    | Translation((3., 0.))
    | Pendulum()
    for i in range(5)
])


leaf_nodes = [leaf_node]


def merge_node_maps(*node_maps):  # TODO: replace with list of edges.
    """Union several ``{node: set_of_nodes}`` maps into one new map.

    Sets stored under the same key are united.  The input maps and their sets
    are not modified.
    """
    merged = {}
    for mapping in node_maps:
        for key in mapping:
            if key not in merged:
                merged[key] = set()
            merged[key] |= mapping[key]
    return merged


def process_node(node, root_node_parents=None):
    """Flatten ``node`` into ``(nodes, node_map)``.

    ``nodes`` is the set of nodes that anything chained after ``node`` should
    use as parents, and ``node_map`` maps every node encountered to the set of
    its parents.

    * A ``Connector`` is delegated to ``process_connector``.
    * A ``Group`` or ``Spring`` gets the processed results of its members
      (respectively its two frames) as parents.
    * Any other node gets a copy of ``root_node_parents`` as parents.

    ``root_node_parents`` defaults to an empty set; it is never modified and
    never stored in the result.
    """
    if root_node_parents is None:
        root_node_parents = set()
    if isinstance(node, Connector):
        return process_connector(node, root_node_parents)

    nodes = {node}
    if isinstance(node, Group):
        parents = node.nodes
    elif isinstance(node, Spring):
        parents = [node.frame1, node.frame2]
    else:
        # Copy so the caller's set (or a shared default) cannot be altered
        # through the returned map.
        return nodes, {node: set(root_node_parents)}

    node_map = {node: set()}
    for node2 in parents:
        more_nodes, more_node_map = process_node(node2, root_node_parents)
        node_map = merge_node_maps(node_map, more_node_map)
        node_map[node] |= more_nodes
    return nodes, node_map
    

def process_connector(connector, root_node_parents=None):
    """Flatten a ``Connector`` chain into ``(nodes, node_map)``.

    The parents of the connector are processed first; the nodes they end in
    become the parents of the connector's own node.  A connector without
    parents starts from a copy of ``root_node_parents`` (empty by default).
    See ``process_node`` for the meaning of the two return values.
    """
    assert isinstance(connector, Connector)
    if root_node_parents is None:
        root_node_parents = set()
    nodes = set()
    node_map = {}
    if connector.parents:
        for parent in connector.parents:
            parent_nodes, parent_node_map = process_connector(parent, root_node_parents)
            nodes |= parent_nodes
            node_map = merge_node_maps(node_map, parent_node_map)
    else:
        # Copy rather than alias, so the returned set never is the caller's
        # set or a shared default.
        nodes = set(root_node_parents)
    if connector.node:
        nodes, rhs_node_map = process_node(connector.node, nodes)
        node_map = merge_node_maps(node_map, rhs_node_map)
    return nodes, node_map


# Flatten the demo chain into a node -> parents map.
nodes, node_map = process_node(leaf_node)
# First picture: the raw connector structure as built by the `|` operators.
show_node_graph([leaf_node], 0.6, walk_nodes)
# Second picture: the flattened graph, with parents looked up in `node_map`.
show_node_graph(node_map.keys(), 0.6, node_map.get)

In [None]:
# Pendulum arm of length 5 with a mass of 3 at its tip; attached to frame1b below.
frame4 = RotationalFrame(
    (0., 0.),
    decals=[
        LineDecal((5., 0.)),
        CircleDecal((5., 0.), 0.7),
    ],
    masses=[
        Mass(3., (5., 0.)),
    ],
    frames=[],
)

# Second link of a double pendulum: pivots at the tip of frame2 (x=3).
frame3 = RotationalFrame(
    (3., 0.),
    decals=[
        LineDecal((3., 0.)),
        CircleDecal((3., 0.), 0.5),
    ],
    masses=[
        Mass(1., (3., 0.)),
    ],
    frames=[],
)

# First link of the double pendulum; carries frame3.
frame2 = RotationalFrame(
    (0., 0.),
    decals=[
        LineDecal((3., 0.)),
        CircleDecal((3., 0.), 0.5),
    ],
    masses=[
        Mass(1., (3., 0.)),
    ],
    frames=[
        frame3,
    ],
)

# Sliding hub (mass 5) that carries both the double pendulum and frame4.
frame1b = TrackFrame(
    (0., 0.),
    decals=[
        #LineDecal((3., 0.)),
        CircleDecal((0., 0.), 0.6),
        #BoxDecal(3, 1, (3., 0.)),
    ],
    masses=[
        Mass(5., (0., 0.)),
    ],
    frames=[
        frame2,        
        frame4,
    ],
)

# Massless rotating frame that sets the direction along which frame1b slides.
frame1 = RotationalFrame(
    (0., 0.),
    frames=[frame1b],
)

# Angle (radians) of the track that frame0 slides along.
track_angle = 0.05

# Cart sliding along the tilted track; root of the pendulum assembly.
frame0 = TrackFrame(
    (0., 0.),
    angle=track_angle,
    decals=[
        BoxDecal(width=2., angle=-track_angle, solid=True)
    ],
    masses=[
        Mass(1.),
    ],
    frames=[
        frame1,
    ],
)

# Block sliding along a horizontal track (default angle 0) at y=5.
block1 = TrackFrame(
    (0., 5.),
    decals=[
        BoxDecal(width=2.5, height=1.25),
    ],
    masses=[
        Mass(3.),
    ],
)

# Second block at y=3.5; currently not part of the scene (commented out below).
block2 = TrackFrame(
    (0., 3.5),
    decals=[
        BoxDecal(width=2.),
    ],
    masses=[
        Mass(3.),
    ],
)
    
scene = Scene(
    frames=[
        frame0,
        block1,
        #block2,
    ],
    decals=[
        # Long line along the tilted track of frame0.
        LineDecal(
            (-100. * np.cos(track_angle), -100. * np.sin(track_angle)),
            (100. * np.cos(track_angle), 100. * np.sin(track_angle))
        ),
        # Long horizontal line along the track of block1.
        LineDecal(
            (-100., 5.),
            (100., 5.),
        ),
    ],
    springs=[
        Spring(block1, frame0, k=5.),
        #Spring(frame3, frame4, 1., (3., 0.), (5., 0.)),
        Spring(frame0, frame1b, 28.),
    ]
)

solver = NaiveSolver(scene)

# Initial (q, qd) for every frame.  block2 is listed although it is not in the
# scene; the solver only reads the frames in scene.sorted_frames.
state_map = {
    frame0: (5., 2.),
    frame1: (-1.2, 0.),
    frame1b: (3., 0.1),
    frame2: (0.2, 0.),
    frame3: (0.3, 0.),
    frame4: (-2., 0.2),
    block1: (2., 0),
    block2: (-2., 0.),
}

# Simulate 20 steps of 0.05 time units, keeping every intermediate state.
# Note that `state_map` is rebound each step, so after the loop it holds the
# final state; the initial state remains available as state_maps[0].
state_maps = [state_map]
for i in range(20):
    state_map = solver.tick(state_map, 0.05)
    state_maps.append(state_map)
    

renderer_widget = RendererWidget()


@interact(
    time_index=(0, len(state_maps) - 1),
    scale=(0.1, 5.),
    view_x=(-30, 30),
    view_y=(-30, 30),
)
def f(time_index=0, scale=1., view_x=0, view_y=0):
    """Draw the simulated scene at ``time_index`` with the chosen zoom and pan.

    As a side effect the 'Render' button of ``renderer_widget`` is re-targeted
    so that it renders the whole simulation with the current view, saves it as
    ``anim.mp4`` and displays it.
    """
    _, axes = plt.subplots(1, 1, figsize=(16, 9))
    # Zoom about the origin first, then shift the view.
    zoom = get_scale_matrix(scale, scale)
    pan = get_translation_matrix((-view_x, -view_y))
    xform_matrix = pan.dot(zoom)
    scene.draw(axes, state_maps[time_index], xform_matrix)

    def render():
        def draw_frame(ax, frame_time_index):
            scene.draw(ax, state_maps[frame_time_index], xform_matrix)

        anim = do_render_animation(draw_frame, len(state_maps), sample_interval=2)
        anim.save('anim.mp4')
        display(anim)

    renderer_widget.render_func = render


display(renderer_widget)