In [None]:
from IPython.display import HTML
from ipywidgets import interact
from matplotlib import pyplot as plt, cm, animation, rc
from matplotlib.lines import Line2D
import functools
import matplotlib
import matplotlib.patches as patches
import numpy as np
import ipywidgets

In [None]:
# Draw figures inline in the notebook and set matplotlib's `animation.html`
# option to 'html5' (the 'jshtml' alternative is kept below, disabled).
%matplotlib inline
rc('animation', html='html5')
#rc('animation', html='jshtml')

In [None]:
def memoize(func):
    """Wrap ``func`` so that repeating the most recent call reuses its result.

    Only ONE result is remembered: the positional and keyword arguments of
    the last computed call are compared (with ``!=``) against the incoming
    ones, and ``func`` is re-run whenever either differs.  Arguments only
    need to support equality comparison, not hashing.
    """
    last_args = None
    last_kwargs = None
    last_value = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal last_args, last_kwargs, last_value

        is_stale = args != last_args or kwargs != last_kwargs
        if is_stale:
            # Compute first so that a raising call leaves the cache untouched.
            fresh = func(*args, **kwargs)
            last_args, last_kwargs, last_value = args, kwargs, fresh

        return last_value

    return wrapper

In [None]:
#dx1 = x2
#dx2 = -g/L*sin(x1)

@interact(
    x1p=(-2., 4., 0.05),
    x2=(-5., 5., 0.05),
    k=(0., 5.),
)
def f(x1p, x2, k):
    """Plot the derivative vector (dx1, dx2) = (x2, -k*sin(x1)) for one state.

    ``x1p`` is the angle in units of pi.  The same vector is drawn at every
    node of a fixed grid, because the derivative is evaluated at the slider
    state rather than at the grid coordinates.
    """
    angle = x1p * np.pi
    ticks = np.arange(-2 * np.pi, 2 * np.pi, 1.)
    grid_x, grid_y = np.meshgrid(ticks, ticks)

    # Constant fields: every grid node carries the derivative of the slider state.
    d_angle = np.ones(grid_x.shape) * x2
    d_velocity = -k * np.sin(np.ones(grid_y.shape) * angle)

    plt.figure()
    arrows = plt.quiver(grid_x, grid_y, d_angle, d_velocity, units='width')
    plt.quiverkey(arrows, 0.9, 0.9, 2, r'$2 \frac{m}{s}$', labelpos='E', coordinates='figure')

    # Mark the grid nodes themselves.
    plt.scatter(grid_x, grid_y, color='k', s=0.5)

    plt.show()

In [None]:
#dx1 = x2
#dx2 = -g/L*sin(x1) - b*x2

# Dimensions of the drawn pendulum; draw_world multiplies them by `scale`.
BASE_WIDTH = 2.
BASE_HEIGHT = 0.3
BALL_RADIUS = 1.

# Ranges (and defaults) for the initial angle x1 and initial angular velocity x2.
# They bound the sliders and the grid of background trajectories in simulate_many.
X1I_MIN = -np.pi * 3
X1I_MAX = np.pi * 5
X1I_DEFAULT = np.pi / 3.
X2I_MIN = -5.
X2I_MAX = 5.
X2I_DEFAULT = 0.
# Number of integration steps per simulation.
T_MAX = 500
# Gravitational acceleration used by tick.
GRAVITY = 10.
# Number of initial values sampled along each axis of the background grid.
PHASE_COUNT = 10


def tick(length, b, dt, x1, x2):
    """Advance the damped pendulum by one explicit Euler step of size ``dt``.

    ``x1`` is the angle and ``x2`` the angular velocity; ``b`` is the damping
    coefficient.  Returns ``(new_x1, new_x2, dx1, dx2)`` where the derivatives
    are the ones evaluated at the state *before* the step.
    """
    velocity = x2
    acceleration = -GRAVITY / length * np.sin(x1) - b * x2
    x1 += velocity * dt
    x2 += acceleration * dt
    return x1, x2, velocity, acceleration


@memoize
def simulate(initial_x1, initial_x2, length, b, dt=0.01):
    """Integrate the pendulum for ``T_MAX`` steps from the given initial state.

    Returns four lists ``(X1, X2, DX1, DX2)``, one entry per step: the state
    after the step and the derivatives that produced it.
    """
    columns = ([], [], [], [])
    x1, x2 = initial_x1, initial_x2
    for _ in range(T_MAX):
        x1, x2, dx1, dx2 = tick(length, b, dt, x1, x2)
        for column, value in zip(columns, (x1, x2, dx1, dx2)):
            column.append(value)
    return columns


@memoize
def simulate_many(length, b, dt):
    """Simulate a ``PHASE_COUNT`` x ``PHASE_COUNT`` grid of initial conditions.

    The outer list runs over initial angles, the inner one over initial
    angular velocities; each cell holds the result of ``simulate``.
    """
    angle_starts = np.linspace(X1I_MIN, X1I_MAX, num=PHASE_COUNT)
    velocity_starts = np.linspace(X2I_MIN, X2I_MAX, num=PHASE_COUNT)

    grid = []
    for angle in angle_starts:
        row = []
        for velocity in velocity_starts:
            row.append(simulate(angle, velocity, length, b, dt))
        grid.append(row)
    return grid
    
    
def draw_world(ax, X1, X2, t, length, scale):
    """Draw the pendulum at step ``t``: bob, rod and base, pivot at (0.5, 0.5).

    ``X1`` is the sequence of angles; ``X2`` is accepted but not used here.
    All lengths are multiplied by ``scale`` before drawing.
    """
    ax.set_aspect('equal')

    angle = X1[t]
    bob_x = np.sin(angle) * length * scale + 0.5
    bob_y = -np.cos(angle) * length * scale + 0.5

    # Bob.
    bob = plt.Circle((bob_x, bob_y), BALL_RADIUS * scale / 2., color='k')
    ax.add_artist(bob)

    # Rod from the pivot to the bob.
    rod = Line2D((0.5, bob_x), (0.5, bob_y), linestyle='-', linewidth=0.5, color='k')
    ax.add_line(rod)

    # Base, centred horizontally on the pivot.
    base_corner = (0.5 - BASE_WIDTH * scale / 2., 0.5)
    base = patches.Rectangle(base_corner, BASE_WIDTH * scale, BASE_HEIGHT * scale, color='k')
    ax.add_patch(base)

    
def draw_phase(ax, DX1, DX2, t=None, **kwargs):
    """Plot one trajectory (``DX1`` against ``DX2``) on fixed axis limits.

    Extra keyword arguments are forwarded to the line plot.  When ``t`` is
    given, the point at index ``t`` is additionally marked with a blue dot.
    """
    ax.set_aspect('equal')
    ax.set_xlim((-10., 15.))
    ax.set_ylim((-5., 5.))
    ax.plot(DX1, DX2, '-', **kwargs)
    if t is None:
        return
    ax.plot(DX1[t], DX2[t], 'bo')


def draw_many_phases(ax, length, b, dt):
    """Overlay the whole grid of background trajectories on ``ax``.

    Each trajectory is drawn faintly, with a colour derived from its position
    in the grid of initial conditions.
    """
    shades = np.linspace(0., 1., num=PHASE_COUNT)
    grid = simulate_many(length, b, dt)
    for row, shade_x1 in zip(grid, shades):
        for trajectory, shade_x2 in zip(row, shades):
            angles, velocities, _, _ = trajectory
            tint = (
                shade_x1 * 0.8 + 0.2,
                0.2 + shade_x1 * 0.3 + shade_x2 * 0.3,
                shade_x2 * 0.5 + 0.5,
            )
            draw_phase(ax, angles, velocities, color=tint, alpha=0.2, lw=1.5)
        
@interact(
    time=(0, T_MAX - 1, 10),
    initial_x1=(X1I_MIN, X1I_MAX),
    initial_x2=(X2I_MIN, X2I_MAX),
    length=(0.01, 8.),
    b=(0., 1., 0.01),
    scale=(2., 15.),
    dt=(0.005, 0.05, 0.001),
)
def f(
        time=0,
        initial_x1=X1I_DEFAULT,
        initial_x2=X2I_DEFAULT,
        length=3.,
        b=0.01,
        scale=8.,
        dt=0.01,
        simulate_all=True,
):
    """Interactive pendulum view: the pendulum itself (left), phase plot (right).

    ``time`` selects the simulation step to display.  With ``simulate_all``
    enabled, the grid of background trajectories is drawn behind the
    selected one.
    """
    trajectory = simulate(initial_x1, initial_x2, length, b, dt)
    X1, X2 = trajectory[0], trajectory[1]

    _, axes = plt.subplots(1, 2, figsize=(24,8))
    world_ax, phase_ax = axes
    draw_world(world_ax, X1, X2, time, length, scale/100.)

    if simulate_all:
        draw_many_phases(phase_ax, length, b, dt)

    # Highlighted trajectory goes last so it sits on top of the background.
    draw_phase(phase_ax, X1, X2, time, color='C0', alpha=1., lw=3.)


In [None]:
@interact(
    initial_x1=(-np.pi, np.pi),
    initial_x2=(-5., 5.),
    length=(0.01, 8.),
    dt=(0.005, 0.05, 0.001),
    b=(0., 1., 0.01),
)
def f(
        initial_x1=np.pi/2.,
        initial_x2=0.,
        length=3.,
        dt=0.01,
        b=0.,
):
    """Plot pendulum trajectories in the derivative plane (dx1 against dx2).

    The grid of background trajectories is drawn faintly and the trajectory
    for the selected initial state is drawn in black on top.

    ``b`` is the damping coefficient.  It is passed explicitly because
    ``simulate_many`` takes ``(length, b, dt)`` and ``simulate`` takes
    ``(initial_x1, initial_x2, length, b, dt)``: omitting it makes the first
    call fail and makes the second one treat ``dt`` as the damping value.
    The default of 0. gives the undamped pendulum.
    """
    fig, ax0, = plt.subplots(1, 1, figsize=(15,8))
    uniform_values = np.linspace(0.5, 1., num=PHASE_COUNT)

    simulations = simulate_many(length, b, dt)
    for sim2, color1 in zip(simulations, uniform_values):
        for simulation, color2 in zip(sim2, uniform_values):
            X1, X2, DX1, DX2 = simulation
            draw_phase(ax0, DX1, DX2, color=(0.5, color1, color2), alpha=0.2)

    X1, X2, DX1, DX2 = simulate(initial_x1, initial_x2, length, b, dt)
    draw_phase(ax0, DX1, DX2, color='k', alpha=0.7)

In [None]:
#dx1 = x2
#dx2 = -k*x1 - b*x2

# Dimensions of the drawn wall (base) and ball; draw_world multiplies them by `scale`.
BASE_WIDTH = 0.3
BASE_HEIGHT = 2.
# Distance from the wall to the position x1 = 0, as drawn by draw_world.
EQUILIBRIUM_LENGTH = 4.
BALL_RADIUS = 0.5

# Ranges of the initial position x1 and initial velocity x2
# (slider bounds and bounds of the background grid).
X1I_MIN = -2.
X1I_MAX = 2.
X2I_MIN = -2.
X2I_MAX = 2.
# Number of integration steps per simulation.
T_MAX = 500
# Number of initial values sampled along each axis of the background grid.
PHASE_COUNT = 10
# Spring drawing: straight lead at each end, peak-to-peak height, number of zigzags.
ZIGZAG_PADDING = 0.2
ZIGZAG_HEIGHT = 0.6
ZIGZAG_COUNT = 5


def tick(b, k, dt, x1, x2):
    """Advance the damped spring (dx2 = -k*x1 - b*x2) by one explicit Euler step.

    Returns ``(new_x1, new_x2, dx1, dx2)``; the derivatives are those
    evaluated at the state before the step.
    """
    velocity = x2
    acceleration = -k * x1 - b * x2
    x1 += velocity * dt
    x2 += acceleration * dt
    return x1, x2, velocity, acceleration


@memoize
def simulate(initial_x1, initial_x2, b, k, dt):
    """Integrate the damped spring for ``T_MAX`` steps from the initial state.

    Returns four lists ``(X1, X2, DX1, DX2)`` with one entry per step: the
    state after the step and the derivatives that produced it.
    """
    columns = ([], [], [], [])
    x1, x2 = initial_x1, initial_x2
    for _ in range(T_MAX):
        x1, x2, dx1, dx2 = tick(b, k, dt, x1, x2)
        for column, value in zip(columns, (x1, x2, dx1, dx2)):
            column.append(value)
    return columns


@memoize
def simulate_many(b, k, dt):
    """Simulate a ``PHASE_COUNT`` x ``PHASE_COUNT`` grid of initial conditions.

    The outer list runs over initial positions, the inner one over initial
    velocities; each cell holds the result of ``simulate``.
    """
    position_starts = np.linspace(X1I_MIN, X1I_MAX, num=PHASE_COUNT)
    velocity_starts = np.linspace(X2I_MIN, X2I_MAX, num=PHASE_COUNT)

    grid = []
    for position in position_starts:
        row = []
        for velocity in velocity_starts:
            row.append(simulate(position, velocity, b, k, dt))
        grid.append(row)
    return grid
    
    
def draw_world(ax, X1, X2, t, scale):
    """Draw the wall, the zigzag spring and the ball at step ``t``.

    ``X1`` is the sequence of ball positions; ``X2`` is accepted but not used.
    Model coordinates are mapped to axes coordinates as ``v * scale + 0.5``.
    """
    ax.set_aspect('equal')

    ball_x = X1[t]

    def to_axes(value):
        # Model coordinate -> axes coordinate.
        return value * scale + 0.5

    def segment(xa, ya, xb, yb):
        # Thin black line between two model-space points.
        ax.add_line(
            Line2D(
                (to_axes(xa), to_axes(xb)),
                (to_axes(ya), to_axes(yb)),
                linestyle='-',
                linewidth=0.5,
                color='k',
            )
        )

    # Dashed guide at the x1 = 0 position.
    ax.axvline(0.5, alpha=0.2, linestyle='--')

    # The spring runs from the wall to the near edge of the ball, with a
    # short straight lead at each end.
    wall_x = -EQUILIBRIUM_LENGTH
    coil_start = wall_x + ZIGZAG_PADDING
    ball_edge = ball_x - BALL_RADIUS
    coil_end = ball_edge - ZIGZAG_PADDING

    joints = np.linspace(coil_start, coil_end, num=ZIGZAG_COUNT+1)
    amplitude = ZIGZAG_HEIGHT/2.
    for index in range(len(joints) - 1):
        left = joints[index]
        right = joints[index + 1]
        width = (right - left)
        peak_x = left + width / 4.
        trough_x = left + width * 3. / 4.
        segment(left, 0., peak_x, amplitude)
        segment(peak_x, amplitude, trough_x, -amplitude)
        segment(trough_x, -amplitude, right, 0.)

    segment(wall_x, 0., coil_start, 0.)
    segment(coil_end, 0., ball_edge, 0.)

    # Wall, placed to the left of the spring's fixed end.
    wall = patches.Rectangle(
        (to_axes(-(BASE_WIDTH + EQUILIBRIUM_LENGTH)), to_axes(-BASE_HEIGHT / 2.)),
        BASE_WIDTH * scale,
        BASE_HEIGHT * scale,
        color='k'
    )
    ax.add_patch(wall)

    # Ball.
    ball = plt.Circle(
        (to_axes(ball_x), to_axes(0.)),
        BALL_RADIUS * scale,
        color='k',
    )
    ax.add_artist(ball)
    

    
def draw_phase(ax, DX1, DX2, t=None, **kwargs):
    """Plot one trajectory (``DX1`` against ``DX2``) on fixed axis limits.

    Extra keyword arguments are forwarded to the line plot.  When ``t`` is
    given, the point at index ``t`` is additionally marked with a blue dot.
    """
    ax.set_aspect('equal')
    ax.set_xlim((-5., 5.))
    ax.set_ylim((-4., 4.))
    ax.plot(DX1, DX2, '-', **kwargs)
    if t is None:
        return
    ax.plot(DX1[t], DX2[t], 'bo')


def draw_many_phases(ax, b, k, dt):
    """Overlay the whole grid of background trajectories on ``ax``.

    Each trajectory is drawn faintly; its green and blue channels encode its
    row and column in the grid of initial conditions.
    """
    shades = np.linspace(0.5, 1., num=PHASE_COUNT)
    grid = simulate_many(b, k, dt)
    for row, green in zip(grid, shades):
        for trajectory, blue in zip(row, shades):
            positions, velocities, _, _ = trajectory
            draw_phase(ax, positions, velocities, color=(0.5, green, blue), alpha=0.2)


@interact(
    time=(0, T_MAX - 1, 10),
    initial_x1=(X1I_MIN, X1I_MAX),
    initial_x2=(X2I_MIN, X2I_MAX),
    b=(0., 5., 0.15),
    k=(0.01, 8.),
    scale=(2., 15.),
    # The lower bound includes the default dt=0.01; with a minimum of 0.02
    # the slider would clamp the default and start at 0.02 instead.
    dt=(0.01, 0.1, 0.01),
)
def f(
        time=0,
        initial_x1=1.,
        initial_x2=0.,
        b=0.12,
        k=3.,
        scale=8.,
        dt=0.01,
        simulate_all=False,
):
    """Interactive damped-spring view: the system (left) and phase plot (right).

    ``time`` selects the simulation step to display, ``b`` is the damping
    coefficient and ``k`` the spring constant.  With ``simulate_all``
    enabled, the grid of background trajectories is drawn behind the
    selected one.
    """
    X1, X2, DX1, DX2 = simulate(initial_x1, initial_x2, b, k, dt)
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(15,8))
    draw_world(ax0, X1, X2, time, scale/100.)

    if simulate_all:
        draw_many_phases(ax1, b, k, dt)

    # Highlighted trajectory goes last so it sits on top of the background.
    draw_phase(ax1, X1, X2, time, color='C0', alpha=1., lw=3.)


In [None]:
#dx1 = x2
#dx2 = -k*x1 - b*x2

#dy1 = y2
#dy2 = b/m1 * (z2-y2) + k/m1 * (z1-y1-d)
#dz1 = z2
#dz2 = -b/m2 * (z2-y2) - k/m2 * (z1-y1-d)

# Size of a block of mass 1; draw_world scales it by sqrt(mass).
BLOCK_WIDTH = 1.
BLOCK_HEIGHT = 1.
# Spring drawing: straight lead at each end, peak-to-peak height, number of zigzags.
ZIGZAG_PADDING = 0.2
ZIGZAG_HEIGHT = 0.6
ZIGZAG_COUNT = 5

# Slider ranges for the initial state: y1/y2 are the position and velocity
# of the first block, z1/z2 those of the second block.
Y1I_MIN = -3.
Y1I_MAX = -1.
Y2I_MIN = -2.
Y2I_MAX = 2.
Z1I_MIN = 1.
Z1I_MAX = 3.
Z2I_MIN = -2.
Z2I_MAX = 2.
# Number of integration steps per simulation.
T_MAX = 1000
# Number of initial values sampled along each axis in simulate_many.
PHASE_COUNT = 5


def tick(b, k, spring_length, m1, m2, dt, y1, y2, z1, z2):
    """Advance two masses joined by a damped spring by one explicit Euler step.

    ``y1``/``y2`` are the position and velocity of the first mass (``m1``),
    ``z1``/``z2`` those of the second (``m2``).  The spring force is
    proportional to the stretch ``z1 - y1 - spring_length`` and the damping
    to the relative velocity ``z2 - y2``.  Returns the new
    ``(y1, y2, z1, z2)``.
    """
    velocity_y = y2
    accel_y = b/m1 * (z2 - y2) + k/m1 * (z1 - y1 - spring_length)
    velocity_z = z2
    accel_z = -b/m2 * (z2 - y2) - k/m2 * (z1 - y1 - spring_length)
    y1 += velocity_y * dt
    y2 += accel_y * dt
    z1 += velocity_z * dt
    z2 += accel_z * dt
    return y1, y2, z1, z2


@memoize
def simulate(initial_y1, initial_y2, initial_z1, initial_z2, b, k, spring_length, m1, m2, dt):
    """Integrate the two-mass system for ``T_MAX`` steps.

    Returns four lists ``(Y1, Y2, Z1, Z2)`` holding the state after each step.
    """
    columns = ([], [], [], [])
    y1, y2, z1, z2 = initial_y1, initial_y2, initial_z1, initial_z2
    for _ in range(T_MAX):
        y1, y2, z1, z2 = tick(b, k, spring_length, m1, m2, dt, y1, y2, z1, z2)
        for column, value in zip(columns, (y1, y2, z1, z2)):
            column.append(value)
    return columns


@memoize
def simulate_many(b, k, spring_length, m1, m2, dt, initial_z1=None, initial_z2=None):
    """Simulate a grid of initial conditions for the first mass.

    The initial position and velocity of the first mass are swept over
    ``PHASE_COUNT`` values each (``Y1I_*`` / ``Y2I_*`` ranges); the outer
    list runs over positions and the inner one over velocities.

    The second mass starts from the same state in every run:
    ``initial_z1`` / ``initial_z2``, which default to the midpoints of the
    ``Z1I_*`` and ``Z2I_*`` ranges.

    Returns a nested list of ``simulate`` results.
    """
    if initial_z1 is None:
        initial_z1 = (Z1I_MIN + Z1I_MAX) / 2.
    if initial_z2 is None:
        initial_z2 = (Z2I_MIN + Z2I_MAX) / 2.

    y1i_values = np.linspace(Y1I_MIN, Y1I_MAX, num=PHASE_COUNT)
    y2i_values = np.linspace(Y2I_MIN, Y2I_MAX, num=PHASE_COUNT)
    return [
        [
            # simulate takes all four initial values before the parameters.
            simulate(y1i, y2i, initial_z1, initial_z2, b, k, spring_length, m1, m2, dt)
            for y2i in y2i_values
        ]
        for y1i in y1i_values
    ]


def draw_circle(ax, x, y, radius, scale):
    """Add a filled black circle; model coordinates map to ``v * scale + 0.5``."""
    centre = (x * scale + 0.5, y * scale + 0.5)
    disc = plt.Circle(centre, radius * scale, color='k')
    ax.add_artist(disc)
    
    
def draw_line(ax, x1, y1, x2, y2, scale):
    """Add a thin black segment from (x1, y1) to (x2, y2), given in model coordinates."""
    xs = (x1 * scale + 0.5, x2 * scale + 0.5)
    ys = (y1 * scale + 0.5, y2 * scale + 0.5)
    ax.add_line(Line2D(xs, ys, linestyle='-', linewidth=0.5, color='k'))
    
    
def draw_rect(ax, x, y, width, height, scale):
    """Add a filled black rectangle whose lower-left corner is (x, y) in model coordinates."""
    corner = (x * scale + 0.5, y * scale + 0.5)
    box = patches.Rectangle(corner, width * scale, height * scale, color='k')
    ax.add_patch(box)


def draw_zigzag(ax, x1, x2, scale):
    """Draw one zigzag between x1 and x2 on the y = 0 line.

    The path rises to a peak at a quarter of the width, drops to a trough at
    three quarters and returns to zero at ``x2``.
    """
    span = x2 - x1
    amplitude = ZIGZAG_HEIGHT/2.
    peak_x = x1 + span / 4.
    trough_x = x1 + span * 3. / 4.
    corners = [(x1, 0.), (peak_x, amplitude), (trough_x, -amplitude), (x2, 0.)]
    for (xa, ya), (xb, yb) in zip(corners[:-1], corners[1:]):
        draw_line(ax, xa, ya, xb, yb, scale)

    
def draw_spring(ax, x1, x2, scale):
    """Draw a spring from x1 to x2: a straight lead at each end and zigzags between."""
    coil_start = x1 + ZIGZAG_PADDING
    coil_end = x2 - ZIGZAG_PADDING
    draw_line(ax, x1, 0., coil_start, 0., scale)
    draw_line(ax, x2, 0., coil_end, 0., scale)

    joints = np.linspace(coil_start, coil_end, num=ZIGZAG_COUNT+1)
    for index in range(len(joints) - 1):
        draw_zigzag(ax, joints[index], joints[index + 1], scale)


def draw_world(ax, y1, z1, spring_length, m1, m2, scale):
    """Draw both blocks and the spring between them.

    The spring spans ``y1`` to ``z1``; the first block sits to the left of
    ``y1`` and the second to the right of ``z1``.  Each block's side is
    proportional to the square root of its mass.  ``spring_length`` is
    accepted but not used for drawing.
    """
    ax.set_aspect('equal')

    # Dashed guide through the horizontal centre of the axes.
    ax.axvline(0.5, alpha=0.2, linestyle='--')

    left_side = np.sqrt(m1)
    right_side = np.sqrt(m2)
    draw_spring(ax, y1, z1, scale)

    left_width = BLOCK_WIDTH * left_side
    left_height = BLOCK_HEIGHT * left_side
    draw_rect(ax, y1 - left_width, -left_height / 2., left_width, left_height, scale)

    right_width = BLOCK_WIDTH * right_side
    right_height = BLOCK_HEIGHT * right_side
    draw_rect(ax, z1, -right_height / 2., right_width, right_height, scale)

    
    
def draw_phase(ax, DX1, DX2, t=None, **kwargs):
    """Plot one trajectory (``DX1`` against ``DX2``) on fixed axis limits.

    Extra keyword arguments are forwarded to the line plot.  When ``t`` is
    given, the point at index ``t`` is additionally marked with a blue dot.
    """
    ax.set_aspect('equal')
    ax.set_xlim((-5., 5.))
    ax.set_ylim((-4., 4.))
    ax.plot(DX1, DX2, '-', **kwargs)
    if t is None:
        return
    ax.plot(DX1[t], DX2[t], 'bo')


def draw_many_phases(ax, b, k, dt, spring_length=3., m1=1., m2=1.):
    """Overlay the first mass's phase trajectories for a grid of initial states.

    ``simulate_many`` in this cell takes ``(b, k, spring_length, m1, m2, dt)``,
    so the spring rest length and both masses are forwarded as well; their
    defaults match the defaults of the interactive function ``f``.

    Each simulation is a ``(Y1, Y2, Z1, Z2)`` tuple; the first mass's position
    ``Y1`` is plotted against its velocity ``Y2``.
    """
    uniform_values = np.linspace(0.5, 1., num=PHASE_COUNT)
    simulations = simulate_many(b, k, spring_length, m1, m2, dt)
    for sim2, color1 in zip(simulations, uniform_values):
        for simulation, color2 in zip(sim2, uniform_values):
            Y1, Y2, Z1, Z2 = simulation
            draw_phase(ax, Y1, Y2, color=(0.5, color1, color2), alpha=0.2)


# Holder for the current render callback.  The interactive function `f`
# below empties this list and stores a fresh zero-argument closure in it on
# every call, so the button (created once) always runs the latest one.
stupid = []
def on_render(x):
    """Button click handler: run the registered render callback, if there is one."""
    if stupid:
        stupid[0]()  # call whatever closure was registered most recently

            
# Output area that receives the rendered animation, and the button that triggers it.
out = ipywidgets.Output()
render_button = ipywidgets.Button(description='Render')
render_button.on_click(on_render)

display(render_button)
display(out)
            
@interact(
    time=(0, T_MAX - 1, 10),
    initial_y1=(Y1I_MIN, Y1I_MAX),
    initial_y2=(Y2I_MIN, Y2I_MAX),
    initial_z1=(Z1I_MIN, Z1I_MAX),
    initial_z2=(Z2I_MIN, Z2I_MAX),
    b=(0., 3., 0.05),
    k=(0.01, 8.),
    m1=(0.1, 5.),
    m2=(0.1, 5.),
    spring_length=(2., 4.),
    dt=(0.02, 0.1, 0.01),
    scale=(2., 15.),
)
def f(
        time=0,
        initial_y1=-2.,
        initial_y2=0.,
        initial_z1=2.,
        initial_z2=0.,
        b=0.2,
        k=3.,
        m1=1.,
        m2=1.,
        spring_length=3.,
        dt=0.05,
        scale=8.,
        simulate_all=False,
):
    """Interactive two-mass view: draw the system at simulation step ``time``.

    Also registers a ``render`` closure in ``stupid`` so that the Render
    button produces an animation for the current slider values.  The phase
    plot is disabled (commented out below), so ``simulate_all`` has no
    effect and the right-hand axes stay empty.
    """
    def render():
        """Render an animation of the current simulation into the ``out`` widget."""
        out.clear_output()
        with out:
            fig, ax = plt.subplots()
            def animate(i):
                # Every animation frame advances two simulation steps.
                time = i * 2
                ax.clear()
                draw_world(ax, Y1[time], Z1[time], spring_length, m1, m2, scale/100.)
                return (fig,)

            # 100 frames at two steps per frame cover simulation steps 0..198.
            anim = animation.FuncAnimation(fig, animate, frames=100, interval=40, blit=True)
            # Close the figure so only the video is shown, not a static plot.
            plt.close(anim._fig)
            # Writes anim.mp4 to the working directory, then embeds the video.
            anim.save('anim.mp4')
            display(HTML(anim.to_html5_video()))
            
    # Replace the previously registered callback.  `render` reads Y1 and Z1
    # only when it is invoked, which is after they are assigned below.
    stupid.clear()
    stupid.append(render)
    
    Y1, Y2, Z1, Z2, = simulate(initial_y1, initial_y2, initial_z1, initial_z2, b, k, spring_length, m1, m2, dt)
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(15,8))
    y1 = Y1[time]
    z1 = Z1[time]
    draw_world(ax0, y1, z1, spring_length, m1, m2, scale/100.)

    #if simulate_all:
    #    draw_many_phases(ax1, b, k, spring_length, m1, m2, dt)
                
    #draw_phase(ax1, Y1, Y2, time, color='C0', alpha=1., lw=3.)

In [None]:
#dx1 = x2
#dx2 = -k*x1 - b*x2

#dy1 = y2
#dy2 = b/m1 * (z2-y2) + k/m1 * (z1-y1-d)
#dz1 = z2
#dz2 = -b/m2 * (z2-y2) - k/m2 * (z1-y1-d)

# Size of a block of mass 1; draw_world scales it by sqrt(mass).
BLOCK_WIDTH = 1.
BLOCK_HEIGHT = 1.
# Spring drawing: straight lead at each end, peak-to-peak height, number of zigzags.
ZIGZAG_PADDING = 0.2
ZIGZAG_HEIGHT = 0.6
ZIGZAG_COUNT = 5

# Slider ranges for the initial state: y1/y2 are the position and velocity
# of the first block, z1/z2 those of the second block.
Y1I_MIN = -3.
Y1I_MAX = -1.
Y2I_MIN = -2.
Y2I_MAX = 2.
Z1I_MIN = 1.
Z1I_MAX = 3.
Z2I_MIN = -2.
Z2I_MAX = 2.
# Number of integration steps per simulation.
T_MAX = 400
# Not referenced by the code in this cell.
PHASE_COUNT = 5


def tick(params, state):
    """Advance two masses joined by a damped spring by one explicit Euler step.

    ``params`` supplies ``b``, ``k``, ``spring_length``, ``m1``, ``m2`` and
    ``dt``; ``state`` supplies ``y1``, ``y2``, ``z1`` and ``z2`` (position
    and velocity of each mass).  A new state dict is returned; the dict
    passed in is not modified.
    """
    b, k, spring_length, m1, m2, dt = (
        params[name] for name in ('b', 'k', 'spring_length', 'm1', 'm2', 'dt')
    )
    y1, y2, z1, z2 = (state[name] for name in ('y1', 'y2', 'z1', 'z2'))

    # Derivatives evaluated at the current state.
    velocity_y = y2
    accel_y = b/m1 * (z2 - y2) + k/m1 * (z1 - y1 - spring_length)
    velocity_z = z2
    accel_z = -b/m2 * (z2 - y2) - k/m2 * (z1 - y1 - spring_length)

    y1 += velocity_y * dt
    y2 += accel_y * dt
    z1 += velocity_z * dt
    z2 += accel_z * dt

    return dict(y1=y1, y2=y2, z1=z1, z2=z2)


@memoize
def simulate(params, initial_state):
    """Integrate for ``T_MAX`` steps and return every state, initial one first.

    The returned list therefore holds ``T_MAX + 1`` state dicts.
    """
    states = [initial_state]
    for _ in range(T_MAX):
        states.append(tick(params, states[-1]))
    return states


def draw_circle(ax, x, y, radius, scale):
    """Add a filled black circle; model coordinates map to ``v * scale + 0.5``."""
    centre = (x * scale + 0.5, y * scale + 0.5)
    disc = plt.Circle(centre, radius * scale, color='k')
    ax.add_artist(disc)
    
    
def draw_line(ax, x1, y1, x2, y2, scale):
    """Add a thin black segment from (x1, y1) to (x2, y2), given in model coordinates."""
    xs = (x1 * scale + 0.5, x2 * scale + 0.5)
    ys = (y1 * scale + 0.5, y2 * scale + 0.5)
    ax.add_line(Line2D(xs, ys, linestyle='-', linewidth=0.5, color='k'))
    
    
def draw_rect(ax, x, y, width, height, scale):
    """Add a filled black rectangle whose lower-left corner is (x, y) in model coordinates."""
    corner = (x * scale + 0.5, y * scale + 0.5)
    box = patches.Rectangle(corner, width * scale, height * scale, color='k')
    ax.add_patch(box)


def draw_zigzag(ax, x1, x2, scale):
    """Draw one zigzag between x1 and x2 on the y = 0 line.

    The path rises to a peak at a quarter of the width, drops to a trough at
    three quarters and returns to zero at ``x2``.
    """
    span = x2 - x1
    amplitude = ZIGZAG_HEIGHT/2.
    peak_x = x1 + span / 4.
    trough_x = x1 + span * 3. / 4.
    corners = [(x1, 0.), (peak_x, amplitude), (trough_x, -amplitude), (x2, 0.)]
    for (xa, ya), (xb, yb) in zip(corners[:-1], corners[1:]):
        draw_line(ax, xa, ya, xb, yb, scale)

    
def draw_spring(ax, x1, x2, scale):
    """Draw a spring from x1 to x2: a straight lead at each end and zigzags between."""
    coil_start = x1 + ZIGZAG_PADDING
    coil_end = x2 - ZIGZAG_PADDING
    draw_line(ax, x1, 0., coil_start, 0., scale)
    draw_line(ax, x2, 0., coil_end, 0., scale)

    joints = np.linspace(coil_start, coil_end, num=ZIGZAG_COUNT+1)
    for index in range(len(joints) - 1):
        draw_zigzag(ax, joints[index], joints[index + 1], scale)


def draw_world(ax, params, state, scale):
    """Draw both blocks and the spring between them for one state dict.

    The spring spans ``state['y1']`` to ``state['z1']``; the first block sits
    to the left of ``y1`` and the second to the right of ``z1``.  Each
    block's side is proportional to the square root of its mass.
    """
    ax.set_aspect('equal')

    spring_length = params['spring_length']  # read here, not needed for drawing
    left_side = np.sqrt(params['m1'])
    right_side = np.sqrt(params['m2'])
    left_x = state['y1']
    right_x = state['z1']

    draw_spring(ax, left_x, right_x, scale)

    left_width = BLOCK_WIDTH * left_side
    left_height = BLOCK_HEIGHT * left_side
    draw_rect(ax, left_x - left_width, -left_height / 2., left_width, left_height, scale)

    right_width = BLOCK_WIDTH * right_side
    right_height = BLOCK_HEIGHT * right_side
    draw_rect(ax, right_x, -right_height / 2., right_width, right_height, scale)


# Holder for the current render callback.  The interactive function `f`
# below empties this list and stores a fresh zero-argument closure in it on
# every call, so the button (created once) always runs the latest one.
stupid = []
on_render = lambda x: stupid[0]() if stupid else None  # click handler: run the registered callback, if any
# Output area that receives the rendered animation, and the button that triggers it.
out = ipywidgets.Output()
render_button = ipywidgets.Button(description='Render')
render_button.on_click(on_render)
display(render_button)
display(out)

def render_animation(params, states, scale):
    """Render the simulated states as a video inside the ``out`` widget.

    Every ``FRAMESKIP``-th state becomes one frame.  The animation is also
    written to ``anim.mp4`` in the working directory.
    """
    out.clear_output()
    with out:
        fig, ax = plt.subplots()

        def draw_frame(frame_index):
            # Redraw the axes from scratch for each frame.
            ax.clear()
            step = frame_index * FRAMESKIP
            draw_world(ax, params, states[step], scale/100.)
            return (fig,)

        frame_count = int(T_MAX / FRAMESKIP)
        anim = animation.FuncAnimation(
            fig,
            draw_frame,
            frames=frame_count,
            interval=20 * FRAMESKIP,
            blit=True,
        )
        # Close the figure so only the video is shown, not a static plot.
        plt.close(anim._fig)
        anim.save('anim.mp4')
        display(HTML(anim.to_html5_video()))


# Number of simulation steps per animation frame (used by render_animation).
FRAMESKIP = 2

@interact(
    time=(0, T_MAX - 1, 1),
    initial_y1=(Y1I_MIN, Y1I_MAX),
    initial_y2=(Y2I_MIN, Y2I_MAX),
    initial_z1=(Z1I_MIN, Z1I_MAX),
    initial_z2=(Z2I_MIN, Z2I_MAX),
    b=(0., 3., 0.05),
    k=(0.01, 8.),
    m1=(0.1, 5.),
    m2=(0.1, 5.),
    spring_length=(2., 4.),
    dt=(0.02, 0.1, 0.01),
    scale=(2., 15.),
)
def f(
        time=0,
        initial_y1=-2.,
        initial_y2=0.,
        initial_z1=2.,
        initial_z2=0.,
        b=0.2,
        k=3.,
        m1=1.,
        m2=1.,
        spring_length=3.,
        dt=0.05,
        scale=8.,
        simulate_all=False,
):
    """Interactive two-mass view: draw the system at simulation step ``time``.

    Also registers a callback in ``stupid`` so that the Render button
    animates the simulation for the current slider values.
    ``simulate_all`` is accepted but has no effect, and the right-hand axes
    are left empty.
    """
    params = dict(b=b, k=k, m1=m1, m2=m2, spring_length=spring_length, dt=dt)

    # Replace the previously registered callback.  The lambda looks up
    # `states` only when invoked, which is after it is assigned below.
    stupid.clear()
    stupid.append(lambda: render_animation(params, states, scale))

    initial_state = dict(
        y1=initial_y1,
        y2=initial_y2,
        z1=initial_z1,
        z2=initial_z2,
    )
    states = simulate(params, initial_state)

    _, (world_ax, _unused_ax) = plt.subplots(1, 2, figsize=(15,8))
    draw_world(world_ax, params, states[time], scale/100.)