In [None]:
%matplotlib notebook
from hamiltonian import *
from equivariant_color.datasets import NBodyDynamics

import math

# Make float64 the default floating-point dtype for newly created tensors.
# torch.set_default_dtype is the supported replacement for the deprecated
# torch.set_default_tensor_type(torch.DoubleTensor).
torch.set_default_dtype(torch.float64)

## Kepler Dynamics

In [None]:
# Problem size: n bodies in d spatial dimensions.
n, d = 6, 2

# Sample systems from the N-body dataset helper. Entry 0 of the result is used
# as the initial state; the first element of entry 1 is used as the masses.
n_body_dynamics = NBodyDynamics(n_systems=4, regen=False)
systems = n_body_dynamics.sample_system(n_systems=2, space_dim=d)
z0, masses = systems[0], systems[1][0]


def H(t, z):
    """Kepler Hamiltonian of state ``z``; the time argument ``t`` is unused."""
    return KeplerH(z, masses)


Dynamics = HamiltonianDynamics(H, wgrad=False)

# 100 evenly spaced evaluation times on [0, 2].
ts = torch.linspace(0, 2., 100).double()

# Integrate with gradient tracking disabled and convert the result to NumPy.
with torch.no_grad():
    zt = odeint(Dynamics, z0, ts, rtol=1e-8, method='dopri5').detach().cpu().numpy()

In [None]:
# Take the first n*d state entries of batch element 0 (the position part, given
# the [q, p] layout used elsewhere in this notebook) and reorder the axes from
# (time, body, coordinate) to (body, coordinate, time).
qt = zt[:, 0, :n*d].reshape(-1, n, d).transpose(1, 2, 0)

# Axis limits; nothing in this cell consumes them.
xlim = ylim = zlim = (-1, 1)

A = AnimationNd(d)(qt, None)
a = A.animate()
plt.show()

In [None]:
# n = 6
# G = 1. # 6.67408 * 10 ** (-11)
# d = 3
# star_mass = 32.
# star_pos = (0, 0, 0)
# star_vel = (0, 0, 0)

# planet_mass_min, planet_mass_max = 2e-2, 2e-1
# planet_mass_range = planet_mass_max - planet_mass_min

# planet_dist_min, planet_dist_max = 0.25, 1.
# planet_dist_range = planet_dist_max - planet_dist_min

In [None]:
# star = torch.tensor([star_mass, *star_pos, *star_vel])

# # sample planet masses, radius vectors
# planet_masses = planet_mass_range * torch.rand(n-1, 1) + planet_mass_min

# theta = 2 * math.pi * torch.rand(n-1)
# phi = torch.acos(2 * torch.rand(n-1) - 1) # sample cos(phi) uniformly on [-1, 1]; drawing \phi itself uniformly from [0, \pi] would not give uniform directions on the sphere
# rho = torch.linspace(planet_dist_min, planet_dist_max, n-1) + (0.3 * (torch.rand(n-1)-.5)*(planet_dist_range/(n-1)))

# planet_pos = torch.stack([
#     rho * torch.sin(phi) * torch.cos(theta),
#     rho * torch.sin(phi) * torch.sin(theta),
#     rho * torch.cos(phi)
# ], dim=1)

# # get orthonormal tangent plane basis vectors
# e_1 = torch.stack([torch.zeros(n-1), -planet_pos[:, 2], planet_pos[:, 1]], dim=1)
# e_2 = torch.cross(planet_pos, e_1, dim=-1)

# # print((e_1 * e_2).sum(-1))
# e_1 = e_1 / e_1.norm(dim=-1, keepdim=True)
# e_2 = e_2 / e_2.norm(dim=-1, keepdim=True)

# # sample initial stable orbit velocities
# planet_vel_magnitude = (G * star_mass / rho).sqrt().unsqueeze(-1)
# omega = 2 * math.pi * torch.rand(n - 1, 1)
# planet_vel = torch.cos(omega) * e_1 + torch.sin(omega) * e_2
# # print(planet_vel.norm(dim=-1))
# planet_vel = planet_vel_magnitude * planet_vel
# # planet_vel *= 1 + 0.2 * (torch.rand(n-1, 1)-0.5)
# planet_momentum = planet_masses * planet_vel

# planets = torch.cat([planet_masses, planet_pos, planet_momentum], dim=1)
# sim_params = torch.cat([star.unsqueeze(0), planets])

In [None]:
# masses = .1+5*torch.rand(n)[None].double()#torch.tensor([10,40,70,500]).double()[None,:n]#10*(torch.rand(n)+.1)[None]
# q0 = .8*torch.randn(n,d).double()#torch.tensor([3,2,-3,0,0,0,0,5,0,1,1,6]).double()[:n*d]#
# p0 = .4*torch.randn(n,d).double()#torch.tensor([1,20,-6,0,0,14,50,0,0,0,0,0]).double()[:n*d]#30*(torch.randn(n,d)).reshape(n*d)
# p0 -= p0.mean(0,keepdim=True)
# #q0 -= q0.mean(0,keepdim=True)
# z0 = torch.cat([q0.reshape(n*d),p0.reshape(n*d)])[None,:]
# masses.shape

In [None]:
# n_body_dynamics = NBodyDynamics(N=4, regen=True)
# systems = n_body_dynamics.sample_system(4)
# time_points, obs = n_body_dynamics.sim_trajectories(systems[0], systems[1])
# print(time_points.shape)

In [None]:
# masses = sim_params[:, 0].unsqueeze(0)
# pos = sim_params[:, 1:4].reshape(-1)
# momentum = sim_params[:, 4:7].reshape(-1)
# z0 = torch.cat([pos, momentum]).unsqueeze(0)


In [None]:
# momentums = zt[:,0,n*d:].reshape(zt.shape[0],n,d)
# np.linalg.norm(momentums, axis=-1).std(0)

## Spring Dynamics

In [None]:
# Problem size: n bodies in d spatial dimensions.
n, d = 6, 2

# Per-body parameters with a leading batch axis of size 1:
# masses are uniform on [0.1, 1.0), spring constants on [0.5, 1.0).
masses = .9*torch.rand(n).double()[None] + .1
k = .5*torch.rand(n).double()[None] + .5

# Initial positions uniform on [-1, 1) and momenta uniform on [-3, 3) per coordinate.
q0 = 2*torch.rand(n, d).double() - 1
p0 = 3*(2*torch.rand(n, d).double() - 1)
# Remove the mean momentum over bodies so the total momentum starts at zero.
p0 = p0 - p0.mean(dim=0, keepdim=True)

# Flattened phase-space state [q, p] with a batch axis in front.
z0 = torch.cat([q0.flatten(), p0.flatten()]).unsqueeze(0)


def H(t, z):
    """Spring Hamiltonian of state ``z``; the time argument ``t`` is unused."""
    return SpringH(z, masses, k)


Dynamics = HamiltonianDynamics(H, wgrad=False)

# 500 evenly spaced evaluation times on [0, 5].
ts = torch.linspace(0, 5, 500).double()
with torch.no_grad():
    zt = odeint(Dynamics, z0, ts, rtol=1e-5).detach().cpu().numpy()

# Positions of batch element 0, reordered to (body, coordinate, time).
qt = zt[:, 0, :n*d].reshape(-1, n, d).transpose(1, 2, 0)
A = AnimationNd(d)(qt)
a = A.animate()
plt.show()

## Ball Dynamics

In [None]:
# Problem size: n balls in d spatial dimensions.
n, d = 6, 2

# Radii uniform on [0.1, 0.25), with a leading batch axis of size 1.
rs = .15*torch.rand(n).double()[None] + .1
# Masses grow with the d-th power of the radius.
masses = 45*(12*rs)**d

# Initial positions uniform on [-0.5, 0.5); momenta are Gaussian scaled by 3**d.
q0 = .5*(2*torch.rand(n, d).double() - 1)
p0 = torch.randn(n, d).double() * 3**d

# Flattened phase-space state [q, p] with a batch axis in front.
z0 = torch.cat([q0.flatten(), p0.flatten()]).unsqueeze(0)


def H(t, z):
    """Ball Hamiltonian of state ``z``; the time argument ``t`` is unused."""
    return BallH(z, masses, rs)


Dynamics = HamiltonianDynamics(H, wgrad=False)

# 500 evenly spaced evaluation times on [0, 1].
ts = torch.linspace(0, 1, 500).double()
with torch.no_grad():
    zt = odeint(Dynamics, z0, ts, rtol=1e-3).detach().cpu().numpy()

# Positions of batch element 0, reordered to (body, coordinate, time).
qt = zt[:, 0, :n*d].reshape(-1, n, d).transpose(1, 2, 0)
# Marker sizes are passed proportional to the radii.
A = AnimationNd(d)(qt, ms=300*rs[0])
a = A.animate()
plt.show()

In [None]:
# Momenta of batch element 0 over time, shaped (time, body, coordinate).
p = zt[:, 0, n*d:].reshape(-1, n, d)
p0n = p0.detach().cpu().numpy().reshape(1, n, d)
# Positions reordered from (body, coordinate, time) to (time, body, coordinate).
q = qt.transpose(2, 0, 1)

# Hamiltonian evaluated at every saved time step.
energies = [H(t, torch.from_numpy(z)) for t, z in zip(ts, zt)]
Et = torch.stack(energies, dim=0).detach().cpu().numpy()

# Momenta and positions relative to their per-step mean over bodies.
pcm = p - p.mean(1, keepdims=True)
qcm = q - q.mean(1, keepdims=True)

# Antisymmetrised outer product summed over bodies: one d x d matrix per step,
# flattened to d**2 columns. Original author's note: for d = 3, columns
# [5, 2, 1] pick out the (2,3), (0,2), (0,1) entries = Lx, -Ly, Lz.
outer_pq = pcm[:, :, None, :]*qcm[:, :, :, None]
outer_qp = qcm[:, :, None, :]*pcm[:, :, :, None]
cross = (outer_pq - outer_qp).sum(1)
angmom = cross.reshape(-1, d**2)

# Relative drifts; the 1e-10 guards against division by a zero reference norm.
total_p = p.sum(1)
total_p0 = p0n.sum(1)
momentum_drift = np.linalg.norm(total_p - total_p0, axis=-1)/(np.linalg.norm(total_p0, axis=-1) + 1e-10)
angmom_drift = np.linalg.norm(angmom - angmom[:1], axis=-1)/(np.linalg.norm(angmom[:1], axis=-1) + 1e-10)

fig, axs = plt.subplots(3, 1, sharex=True)
panels = [
    ("Energy drift", (Et - Et[0])/Et[0]),
    ("Momentum drift", momentum_drift),
    ("Angular Momentum drift", angmom_drift),
]
for ax, (title, series) in zip(axs.flat, panels):
    ax.plot(ts, series)
    ax.set_title(title)
    ax.set(ylabel='relative error')
plt.xlabel('time')
plt.show()

In [None]:
import ipywidgets as widgets
from ipywidgets import interact

# Index of the last saved frame; time is the last axis of `qt`.
last_frame = qt.shape[-1] - 1

# Play/slider controls. They are constructed here but neither displayed nor
# linked, because the jslink/HBox wiring below is disabled.
play = widgets.Play(
#     interval=10,
    value=2,
    min=2,
    max=last_frame,
    step=1,
    description="Press play",
    disabled=False
)
slider = widgets.IntSlider()
# widgets.jslink((play, 'value'), (slider, 'value'))
# widgets.HBox([play, slider])

# Scrub through the animation by frame index. The upper bound follows the
# trajectory length rather than a fixed 99, so every frame stays reachable
# when the trajectory has more than 100 time steps.
interact(A.update, i=(1, last_frame))