In [None]:
from functools import reduce
from fuzix.sims.pendulumcart import draw_world
from fuzix.sims.pendulumcart import simulate
from fuzix.utils import do_render_animation
from fuzix.utils import memoize
from fuzix.utils import RendererWidget
from fuzix.draw_utils import draw_circle
from fuzix.draw_utils import draw_rect
from fuzix.draw_utils import draw_line
from IPython.display import display
from IPython.display import HTML
from ipywidgets import interact
from matplotlib import pyplot as plt, cm, animation, rc
from matplotlib.lines import Line2D
from sympy.physics.mechanics import dynamicsymbols, init_vprinting
import daglet
import functools
import ipywidgets
import matplotlib
import matplotlib.patches as patches
import numpy as np
import operator
import sympy as sp

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Embed animations in the notebook as HTML5 video (matplotlib 'animation.html' rc setting).
rc('animation', html='html5')
# Alternative embedding: matplotlib's JavaScript animation player.
#rc('animation', html='jshtml')

# Pretty-print sympy expressions, using unicode characters.
sp.init_printing(use_unicode=True)
# Printing setup from sympy.physics.mechanics
# (presumably for time-dependent dynamicsymbols -- TODO confirm it is still needed here).
init_vprinting()

In [None]:
# Passed to render_animation() as ``sample_interval`` when do_render() exports the animation.
SAMPLE_INTERVAL = 4
# Stored as params['max_time_index'] and used as the (exclusive) upper bound
# of the ``time_index`` slider.
MAX_TIME_INDEX = 800


# Widget shown above the interactive plot. Its ``render_func`` attribute is
# reassigned by f() on every update; presumably the widget invokes it on
# demand to render the full animation -- TODO confirm against RendererWidget.
renderer_widget = RendererWidget()
display(renderer_widget)


def init_render():
    """Create the figure and the four axes that ``render`` draws into.

    The axes are laid out on a 2x4 grid of a 16x9 inch figure:

    * ``ax0`` -- left half (2 rows x 2 columns)
    * ``ax1`` -- top right (1 row x 2 columns)
    * ``ax2`` -- bottom right, first column
    * ``ax3`` -- bottom right, second column

    Returns:
        A ``(fig, (ax0, ax1, ax2, ax3))`` tuple.
    """
    # Start from an empty figure. ``plt.subplots`` would also create a
    # default full-figure axes that is never used and that overlaps the
    # grid cells added below.
    fig = plt.figure(figsize=(16, 9))
    grid_shape = (2, 4)
    ax0 = plt.subplot2grid(grid_shape, (0, 0), rowspan=2, colspan=2, fig=fig)
    ax1 = plt.subplot2grid(grid_shape, (0, 2), rowspan=1, colspan=2, fig=fig)
    ax2 = plt.subplot2grid(grid_shape, (1, 2), rowspan=1, colspan=1, fig=fig)
    ax3 = plt.subplot2grid(grid_shape, (1, 3), rowspan=1, colspan=1, fig=fig)
    fig.tight_layout(pad=0.4, w_pad=2.0, h_pad=5.0)

    return fig, (ax0, ax1, ax2, ax3)

def get_tension(params, state):
    """Return ``cart_mass * cart_d2x / sin(theta)`` for one state, or ``None``.

    Args:
        params: mapping providing ``'cart_mass'``.
        state: mapping providing ``'theta'`` and, optionally, ``'cart_d2x'``.

    Returns:
        The computed value, or ``None`` when the state carries no
        ``'cart_d2x'`` (missing or ``None``) or when ``|sin(theta)|`` does
        not exceed 0.05, i.e. when the division would be by a value close
        to zero.
    """
    acceleration = state.get('cart_d2x')
    if acceleration is None:
        return None

    mass = params['cart_mass']
    sin_theta = np.sin(state['theta'])
    # Keep the comparison in this direction so that a NaN angle also
    # yields None.
    if np.abs(sin_theta) > 0.05:
        return mass * acceleration / sin_theta
    return None
        
def render(ctx, params, states, scale, time_index):
    """Redraw all panels of the figure for the frame at ``time_index``.

    Each panel is cleared and redrawn from scratch: the full trajectory is
    plotted as a line and the current frame is highlighted with a marker.

    Args:
        ctx: ``(fig, (ax0, ax1, ax2, ax3))`` as returned by ``init_render``.
            Any axes entry may be ``None``; that panel is then skipped.
        params: simulation parameters, forwarded to ``draw_world`` and
            ``get_tension``.
        states: sequence of state dicts, one per time step. Keys read here:
            ``'time'``, ``'theta'``, ``'dtheta'``, ``'cart_x'``, ``'cart_dx'``.
        scale: forwarded unchanged to ``draw_world``.
        time_index: index into ``states`` of the frame to highlight.
    """
    _, (ax0, ax1, ax2, ax3) = ctx
    ti = time_index
    state = states[ti]

    if ax0 is not None:
        ax0.clear()
        ax0.set_title('World')
        draw_world(ax0, params, state, scale)

    if ax1 is not None:
        ax1.clear()
        ax1.set_title('Tension')
        ax1.set_xlabel('time')

        times = [s['time'] for s in states]
        # Use a distinct loop variable so the current-frame ``state`` above
        # is never shadowed.
        tension = [get_tension(params, s) for s in states]
        ax1.plot(times, tension, label='tension', color='C0')
        # get_tension() returns None where no value is available; only
        # draw the current-frame marker when there is one.
        if tension[ti] is not None:
            ax1.plot(state['time'], tension[ti], 'o', color='C0')

    if ax2 is not None:
        ax2.clear()
        ax2.set_title('Phase space (ball)')
        ax2.set_xlabel('angle')
        ax2.set_ylabel('momentum')

        # Plotted quantities are sin(theta) and its time derivative
        # dtheta * cos(theta).
        ball_x_vals = [np.sin(s['theta']) for s in states]
        ball_dx_vals = [s['dtheta'] * np.cos(s['theta']) for s in states]
        ax2.plot(ball_x_vals, ball_dx_vals, color='C0', label='ball_x')
        ax2.plot(ball_x_vals[ti], ball_dx_vals[ti], 'o', color='C0')

    if ax3 is not None:
        ax3.clear()
        ax3.set_title('Phase space (cart)')
        ax3.set_xlabel('position')
        ax3.set_ylabel('momentum')

        cart_x_vals = [s['cart_x'] for s in states]
        cart_dx_vals = [s['cart_dx'] for s in states]
        ax3.plot(cart_x_vals, cart_dx_vals, color='C2', label='cart_x')
        ax3.plot(state['cart_x'], state['cart_dx'], 'o', color='C2')


def render_animation(params, states, scale, sample_interval=1):
    """Build an animation of the simulation on a freshly created figure.

    Args:
        params: simulation parameters; ``params['max_time_index']`` is
            passed to ``do_render_animation``.
        states: sequence of state dicts to draw.
        scale: forwarded to ``render``.
        sample_interval: forwarded to ``do_render_animation``.

    Returns:
        Whatever ``do_render_animation`` returns.
    """
    ctx = init_render()
    figure = ctx[0]

    def render_func(_, time_index):
        # The first argument supplied by do_render_animation is ignored;
        # only the frame index is needed to redraw the panels.
        return render(ctx, params, states, scale, time_index)

    return do_render_animation(
        render_func,
        params['max_time_index'],
        sample_interval,
        fig=figure,
    )


def do_render(params, states, scale):
    """Render the full animation, save it to ``anim.mp4`` and display it.

    The animation is built by ``render_animation`` using the module-level
    ``SAMPLE_INTERVAL``.
    """
    movie = render_animation(
        params,
        states,
        scale,
        sample_interval=SAMPLE_INTERVAL,
    )
    movie.save('anim.mp4')
    display(movie)


@interact(
    time_index=(0, MAX_TIME_INDEX - 1, 1),
    cart_mass=(0.1, 5.),
    ball_mass=(0.1, 5.),
    initial_theta=(-np.pi, np.pi),
    initial_dtheta=(-2., 2.),
    initial_cart_x=(-5., 5.),
    initial_cart_dx=(-2., 2.),
    length=(1., 5.),
    dt=(0.01, 0.1, 0.01),
    scale=(2., 15.),
)
def f(
        time_index=0,
        cart_mass=0.9,
        ball_mass=0.6,
        initial_theta=1.96,
        initial_dtheta=-0,
        initial_cart_x=-3.9,
        initial_cart_dx=0.5,
        length=3.,
        dt=0.01,
        scale=9.,
        simulate_all=False,
):
    """Interactive entry point: simulate with the widget values and draw one frame.

    Runs ``simulate`` with the chosen parameters and initial state, stores a
    callback on ``renderer_widget`` that renders the whole animation via
    ``do_render``, and then draws the frame at ``time_index`` on a new figure.

    ``scale`` is divided by 100 before being handed to the drawing code.
    ``simulate_all`` is accepted but not used by this function.
    """
    params = {
        'dt': dt,
        'max_time_index': MAX_TIME_INDEX,
        'length': length,
        'cart_mass': cart_mass,
        'ball_mass': ball_mass,
    }
    initial_state = {
        'time': 0.,
        'theta': initial_theta,
        'dtheta': initial_dtheta,
        'cart_x': initial_cart_x,
        'cart_dx': initial_cart_dx,
    }
    states = simulate(params, initial_state)

    world_scale = scale / 100.

    def render_full_animation():
        # Deferred: only runs when the renderer widget invokes it.
        return do_render(params, states, world_scale)

    renderer_widget.render_func = render_full_animation

    render(init_render(), params, states, world_scale, time_index)
