In [None]:
import sys, os

# Make the parent of the current working directory importable (normally the repo root when the
# notebook is launched from its own folder), so the project-local `src.*` imports below resolve.
# The membership check keeps re-runs of this cell from stacking duplicate entries.
REPO_ROOT = os.path.abspath('..')
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

import torch
import torch.nn as nn
# from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint
import altair as alt
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Altair refuses to embed more than 5,000 rows by default; long-format trajectory frames
# (one row per initial condition x time step) can easily exceed that.
alt.data_transformers.disable_max_rows()

# Prefer the first GPU when available; otherwise None, i.e. let torch use its default placement.
# NOTE(review): `device` is not referenced in the cells shown here; tensors below follow z0's device.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = None

## Example N-Pendulum Simulation

In [None]:
from src.systems import ChainPendulum
from src.systems.rigid_body import project_onto_constraints

# Seed the RNGs so the sampled start state and its perturbations are reproducible under
# Restart & Run All. Both torch and NumPy are seeded because it is not visible here which
# RNG `sample_initial_conditions` draws from.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

N_LINKS = 3             # argument to ChainPendulum (presumably the number of masses in the chain)
N_TRAJECTORIES = 10     # number of perturbed copies of the start state to simulate side by side
NOISE_SCALE = 0.1       # perturbation is NOISE_SCALE * U[0, 1), i.e. always non-negative
PROJECTION_TOL = 1e-5   # tolerance passed to project_onto_constraints

body = ChainPendulum(N_LINKS)

# Draw a single initial condition and replicate it along the batch dimension, so all
# trajectories start from (nearly) the same state. `expand` makes a stride-0 view; `rand_like`
# on such an overlapping view returns a fresh contiguous tensor, so each copy gets independent noise.
z0_base = body.sample_initial_conditions(1).expand(N_TRAJECTORIES, -1, -1, -1)

# Perturb each copy, then project back -- presumably re-imposing the rigid-body constraints
# (e.g. link lengths) that the additive noise violates; TODO confirm against src.systems.rigid_body.
z0 = project_onto_constraints(
    body.body_graph,
    z0_base + NOISE_SCALE * torch.rand_like(z0_base),
    tol=PROJECTION_TOL,
)

In [None]:
# Time grid from 0 up to (but not including) 10 in steps of the system's own dt, created on
# z0's device and with z0's dtype so the integrator sees consistent tensors.
ts = torch.arange(
    0.0,
    10.0,
    body.dt,
    dtype=z0.dtype,
    device=z0.device,
)

# Roll out every initial condition in z0 over the grid; method='rk4' is presumably forwarded
# to the ODE solver (a fixed-step method in torchdiffeq) -- confirm in body.integrate.
zt = body.integrate(z0, ts, method='rk4')

In [None]:
# Which coordinate to plot: mass index within the chain and degree-of-freedom index for that mass.
body_idx, dof_idx = 2, 0

# Dim 0 of zt enumerates the perturbed initial conditions and dim 1 the time steps
# (ts is broadcast against dim 1 below), so the two must agree in length.
n_inits, n_steps = zt.size(0), zt.size(1)
assert n_steps == ts.numel(), 'time grid and trajectories must have the same number of steps'

# Long-format frame: one row per (initial condition, time step).
# .detach() lets .numpy() work even if integrate() returned a tensor that tracks gradients.
traj_df = pd.DataFrame({
    't': ts.detach().cpu().unsqueeze(0).expand(n_inits, -1).numpy().flatten(),
    'y': zt[..., 0, body_idx, dof_idx].detach().cpu().numpy().flatten(),
    'init': (torch.arange(n_inits) + 1).unsqueeze(-1).expand(-1, n_steps).numpy().flatten(),
})

# One semi-transparent line per initial condition, coloured by its 1-based index.
alt.Chart(traj_df).mark_line(opacity=0.5).encode(
    x=alt.X('t:Q', title='t'),
    y=alt.Y('y:Q', title=f'state component 0, mass {body_idx}, dof {dof_idx}'),
    color=alt.Color('init:N', title='initial condition'),
).properties(width=800, title=f'Mass={body_idx}, DoF={dof_idx}')