In [1]:
import sys, os

# Make the parent directory importable (the `from src.systems import ...` cell
# below depends on it; presumably the notebook sits one level under the repo
# root). The membership guard keeps re-runs from stacking duplicate entries.
_parent_dir = os.path.abspath('..')
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

In [2]:
import torch
import torch.nn as nn
# from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint
import altair as alt
import pandas as pd
import numpy as np

# Run on the first GPU when one is available, otherwise on the CPU. Using an
# explicit torch.device (instead of the implicit `None` sentinel) makes the
# placement visible and works with any `.to(...)` / `device=` argument.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
class Lambda(nn.Module):
    """Vector field of the toy ODE  dy/dt = (y ** 3) @ A, integrated with ``odeint``.

    Args:
        A: Optional square matrix used instead of the default
            ``[[-0.1, 2.0], [-2.0, -0.1]]``. It is registered as a buffer so it
            moves with the module under ``.to(device)`` without being a
            trainable parameter.

    Raises:
        ValueError: if a user-supplied ``A`` is not a 2-D square matrix.
    """

    def __init__(self, A=None):
        super().__init__()
        if A is None:
            A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])
        else:
            # Copy so later in-place edits by the caller don't leak into the module.
            A = torch.as_tensor(A).clone()
            if A.dim() != 2 or A.shape[0] != A.shape[1]:
                raise ValueError(f"A must be a square matrix, got shape {tuple(A.shape)}")
        self.register_buffer('A', A)

    def forward(self, t, y):
        """Evaluate dy/dt.

        Args:
            t: Current time. Unused (the field does not depend on time), but
                required by the ``odeint`` call signature ``func(t, y)``.
            y: 2-D state tensor ``(batch, d)`` with ``d`` matching ``A``
                (``.mm`` requires 2-D operands).

        Returns:
            Tensor of shape ``(batch, d)``: ``(y ** 3) @ A``.
        """
        return (y ** 3).mm(self.A)

In [4]:
# Reference trajectory: integrate the toy vector field from y0 = (2, 0) over a
# 1000-point grid on [0, 25] with dopri5. No gradients are needed here.
with torch.no_grad():
    true_y0 = torch.tensor([[2., 0.]]).to(device)
    t = torch.linspace(0., 25., 1000).to(device)
    vector_field = Lambda().to(device)
    true_y = odeint(vector_field, true_y0, t, method='dopri5')

In [5]:
# Scatter the reference trajectory in the (x, y) state plane, coloured by time.
trajectory_df = pd.DataFrame({
    'x': true_y[:, 0, 0].cpu().numpy(),
    'y': true_y[:, 0, 1].cpu().numpy(),
    't': t.cpu().numpy(),
})
time_color = alt.Color('t:Q', scale=alt.Scale(scheme='plasma'))
alt.Chart(trajectory_df).mark_circle().encode(x='x:Q', y='y:Q', color=time_color)

## Example Chain Pendulum Simulation

In [8]:
import matplotlib.pyplot as plt

from src.systems import ChainPendulum

# Simulation settings (the plotting cell below reads sizes from the tensors).
N_LINKS = 3
N_INITIAL_CONDITIONS = 10
SEED = 0

# Seed so the sampled initial conditions, and every figure below, survive a
# kernel restart. NOTE(review): assumes the sampler draws from torch's (or
# numpy's) global RNG — confirm in src.systems.
torch.manual_seed(SEED)
np.random.seed(SEED)

body = ChainPendulum(N_LINKS)
z0s = body.sample_initial_conditions(N_INITIAL_CONDITIONS)

# NOTE(review): torch.arange with a non-integer step can gain or lose the final
# point from floating-point rounding against `end`; if the number of time steps
# matters downstream, consider building the grid from an integer step count.
ts = torch.arange(0., body.integration_time, body.dt, device=z0s.device, dtype=z0s.dtype)
new_zs = body.integrate(z0s, ts)

In [None]:
print(np.shape(new_zs))


def _as_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Index 0 on the phase-space axis — presumably positions (vs. velocities/momenta);
# TODO confirm against ChainPendulum's state layout.
pos_idx = 0
z0s_np = _as_numpy(z0s)    # indexed as [ic, phase, link, coord]
zs_np = _as_numpy(new_zs)  # indexed as [ic, time, phase, link, coord]

# Derive loop bounds from the data instead of repeating the simulation cell's
# magic numbers, so changing those settings cannot desynchronise this plot.
n_initial_conditions, n_links = z0s_np.shape[0], z0s_np.shape[2]

for ic_idx in range(n_initial_conditions):
    fig, ax = plt.subplots()
    for body_idx in range(n_links):
        trajectory = zs_np[ic_idx, :, pos_idx, body_idx]
        line, = ax.plot(trajectory[:, 0], trajectory[:, 1], label=f'link {body_idx}')
        start = z0s_np[ic_idx, pos_idx, body_idx]
        # Reuse the line's colour so each start marker matches its own trajectory.
        ax.scatter(start[0], start[1], color=line.get_color())
    # Assumes the last axis holds planar (x, y) coordinates.
    ax.set(title=f'Chain pendulum trajectories (initial condition {ic_idx})',
           xlabel='x', ylabel='y')
    ax.legend()
    plt.show()
    plt.close(fig)