In [None]:
#%matplotlib widget

In [None]:
from ipywidgets import interact
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import LinearLocator
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import kde
import numpy as np
import seaborn as sns

In [None]:
def simulate_unidirectionally(get_ddq_func, q, dq, dt=0.01, t_max=100):
    """Integrate q'' = get_ddq_func(q, dq) for `t_max` steps of size `dt`.

    Uses semi-implicit (symplectic) Euler: the velocity is updated first and
    the updated velocity is then used to advance the position. A negative `dt`
    integrates backward in time.

    Args:
        get_ddq_func: callable (q, dq) -> acceleration.
        q, dq: initial position and velocity (floats or numpy arrays). The
            inputs are not modified.
        dt: step size (may be negative).
        t_max: number of integration steps, not a duration; the simulated
            time span is t_max * dt.

    Returns:
        (Q, DQ): lists of length t_max holding the state after each step.
        The initial state itself is not included.
    """
    Q = []
    DQ = []
    for _ in range(t_max):
        ddq = get_ddq_func(q, dq)
        # Rebind rather than use `+=`: with numpy arrays `+=` mutates the
        # caller's array in place and makes every entry of Q/DQ alias the
        # same object.
        dq = dq + ddq * dt
        q = q + dq * dt
        Q.append(q)
        DQ.append(dq)
    return Q, DQ


def simulate(get_ddq_func, q, dq, dt=0.01, t_max=100):
    """Integrate both backward and forward in time from the state (q, dq).

    Returns (Q, DQ) ordered by increasing time: the t_max backward samples
    (earliest first), then the initial state, then the t_max forward samples,
    i.e. 2 * t_max + 1 entries each when t_max >= 0.
    """
    # Backward pass first, then forward, exactly as before.
    back_q, back_dq = simulate_unidirectionally(get_ddq_func, q, dq, -dt, t_max)
    fwd_q, fwd_dq = simulate_unidirectionally(get_ddq_func, q, dq, dt, t_max)

    # The backward samples come out latest-first; flip them so time increases.
    positions = [*back_q[::-1], q, *fwd_q]
    velocities = [*back_dq[::-1], dq, *fwd_dq]
    return positions, velocities

In [None]:
# Acceleration magnitude used by get_ddx (which returns -g).
g = 10.0
# Particle mass, used in the kinetic/potential energy terms.
m = 1.0


def get_ddx(x, dx):
    """Acceleration of a particle under constant gravity: always -g.

    `x` and `dx` are accepted only so this matches the (q, dq) callback
    signature used by simulate(); they do not affect the result, which is a
    scalar even when array arguments are passed.
    """
    acceleration = -g
    return acceleration


def get_ddq(q, dq):
    """Acceleration q'' = -(g + 2 * dq**2) / (2 * q) for an alternative model.

    Not used by the code visible here; presumably an alternative callback to
    get_ddx for simulate(). Singular at q == 0: Python floats raise
    ZeroDivisionError, while numpy arrays yield inf/nan.
    """
    numerator = g + 2 * dq ** 2
    return -numerator / (2 * q)


def _make_axes(view_angle, x_lim, dx_lim, t_lim):
    """Create and label the four panels; returns (ax0, ax1, ax2, ax3).

    ax0: position over time, ax1: phase portrait (x, dx),
    ax2: energy surface over (x, dx), ax3: 3D curve (t, dx, x).
    """
    fig = plt.figure(figsize=(15, 10))
    ax0 = fig.add_subplot(2, 2, 1)
    ax1 = fig.add_subplot(2, 2, 2)
    ax2 = fig.add_subplot(2, 2, 3, projection='3d')
    ax3 = fig.add_subplot(2, 2, 4, projection='3d')

    ax0.set_xlabel('time')
    ax0.set_ylabel('position')
    ax0.set_xlim(*t_lim)
    ax0.set_ylim(*x_lim)

    ax1.set_xlabel('x')
    ax1.set_ylabel('dx')
    ax1.set_xlim(*x_lim)
    ax1.set_ylim(*dx_lim)

    ax2.set_xlabel('x')
    ax2.set_ylabel('dx')
    ax2.set_zlabel('energy')
    ax2.set_xlim(*x_lim)
    ax2.set_ylim(*dx_lim)

    ax3.set_xlabel('t')
    ax3.set_ylabel('dx')
    ax3.set_zlabel('x')
    ax3.set_xlim(*t_lim)
    ax3.set_ylim(*dx_lim)
    ax3.set_zlim(*x_lim)

    # Both 3D panels rotate together with the slider; elevation fixed at 30 deg.
    ax2.view_init(30, -view_angle)
    ax3.view_init(30, -view_angle)
    return ax0, ax1, ax2, ax3


def _sample_trajectory(x0, dx0, dt, t_max, stride, x_lim, dx_lim):
    """Simulate get_ddx from (x0, dx0); return (t, x, dx, energy) every `stride` steps.

    Samples outside the plotting box are replaced by NaN so matplotlib breaks
    the line there instead of drawing a straight segment across the gap.
    """
    xs, dxs = simulate(get_ddx, x0, dx0, dt, t_max)
    # simulate() returns 2*t_max + 1 samples; sample k is at time (k - t_max)*dt.
    steps = np.arange(0, len(xs), stride)
    t = (steps - t_max) * dt
    x = np.asarray(xs, dtype=float)[steps]
    dx = np.asarray(dxs, dtype=float)[steps]

    inside = (
        (x >= x_lim[0]) & (x <= x_lim[1])
        & (dx >= dx_lim[0]) & (dx <= dx_lim[1])
    )
    x = np.where(inside, x, np.nan)
    dx = np.where(inside, dx, np.nan)
    # Total mechanical energy for the constant acceleration -g.
    energy = 1/2 * m * dx ** 2 + m * g * x
    return t, x, dx, energy


def _plot_energy_field(ax1, ax2, x_lim, dx_lim, n_grid=15):
    """Draw the energy map and the (dx, ddx) flow field of get_ddx on an (x, dx) grid.

    ax1 gets a colour map of the energy plus 2D flow arrows. ax2 gets the
    energy surface plus arrows drawn at surface height with no vertical
    component.
    """
    X, DX = np.meshgrid(
        np.linspace(x_lim[0], x_lim[1], n_grid),
        np.linspace(dx_lim[0], dx_lim[1], n_grid),
    )
    # get_ddx returns a scalar for constant gravity; give it the grid's shape.
    DDX = np.broadcast_to(get_ddx(X, DX), X.shape)

    # Energy = kinetic + m*g*x. It is conserved along exact trajectories of
    # get_ddx, so the plotted curves lie (up to integration error) on this surface.
    kinetic = 1/2 * m * DX**2
    potential = m * g * X
    energy = kinetic + potential

    ax1.imshow(
        energy,
        interpolation='bilinear',
        origin='lower',
        cmap=cm.coolwarm,
        extent=(x_lim[0], x_lim[1], dx_lim[0], dx_lim[1]),
    )

    ax2.plot_surface(
        X,
        DX,
        energy,
        cmap=cm.coolwarm,
        linewidth=0,
        antialiased=False,
        alpha=0.5,
    )

    ax1.quiver(X, DX, DX, DDX, scale=200., units='width', color='C0', alpha=0.2)
    ax2.quiver(X, DX, energy, DX*0.1, DDX*0.1, np.zeros_like(energy), color='C0', alpha=0.2)


@interact(
    view_angle=(0, 360., 10.),
    sim_count=(1, 10),
)
def f(
    view_angle=120.,
    sim_count=1,
):
    """Interactive view of trajectories under get_ddx together with the energy landscape.

    view_angle: azimuth (degrees, sign flipped) applied to both 3D panels.
    sim_count: number of random initial conditions to simulate. Draws come
        from a fixed seed, so the first k trajectories are identical for
        every sim_count >= k.
    """
    x_lim = (-10., 10.)
    dx_lim = (-10., 10.)
    dt = 0.01
    t_max = 100  # integration steps in each time direction
    stride = 10  # plot every `stride`-th integration step
    t_lim = (-t_max * dt, t_max * dt)

    ax0, ax1, ax2, ax3 = _make_axes(view_angle, x_lim, dx_lim, t_lim)

    # Local generator with the old seed: same draws as np.random.seed(1212899)
    # without clobbering numpy's global RNG state.
    rng = np.random.RandomState(1212899)

    # get_ddx does not divide; presumably this guard lets a callback such as
    # get_ddq (which divides by q) be swapped in without warning spam.
    with np.errstate(divide='ignore', invalid='ignore'):
        for _ in range(sim_count):
            x0 = rng.uniform(low=x_lim[0], high=x_lim[1])
            dx0 = rng.uniform(low=dx_lim[0], high=dx_lim[1])
            t, x, dx, energy = _sample_trajectory(
                x0, dx0, dt, t_max, stride, x_lim, dx_lim)

            ax0.plot(t, x, color='b')
            ax1.plot(x, dx, color='b', alpha=0.5)
            ax2.plot(x, dx, energy, color='b')
            ax3.plot(t, dx, x, color='b')

        _plot_energy_field(ax1, ax2, x_lim, dx_lim)