In [None]:
import os
import matplotlib.pyplot as plt
import torch
from hamiltonian import *
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from oil.utils.utils import LoaderTo, cosLr, recursively_update,islice, Eval
from equivariant_color.datasets import NBodyDynamics
from equivariant_color.dynamicsTrainer import IntegratedDynamicsTrainer,FCHamNet, Partial,RawDynamicsNet, LieConvNetT2, LieResNetT2
from graphnets import OGN,HOGN, VOGN

class ZeroBaseline(torch.nn.Module):
    """Trivial baseline model whose forward pass always returns a zero tensor.

    It exposes a single dummy trainable tensor through ``parameters()`` so that an
    optimizer can be constructed for it, and it counts forward calls in ``nfe``.

    Args:
        device: device on which the zero output of ``forward`` is created.
    """

    def __init__(self, device):
        super().__init__()
        self.nfe = 0  # incremented once per forward() call
        self.device = device
        # Created once (not per parameters() call) so every caller, e.g. the
        # optimizer and any later inspection, sees the same tensor object.
        # It is a plain tensor, not an nn.Parameter, so it is not registered
        # in self._parameters.
        self._dummy = torch.tensor([1.], requires_grad=True)

    def parameters(self, recurse=True):
        """Return the single dummy trainable tensor.

        ``recurse`` is accepted (and ignored) so the signature stays compatible
        with ``torch.nn.Module.parameters``.
        """
        return [self._dummy]

    def forward(self, *args, **kwargs):
        """Ignore all inputs and return a shape-(1,) zero tensor that requires grad."""
        self.nfe += 1
        return torch.tensor((0.,), requires_grad=True, device=self.device)

def makeTrainer(lr=1e-2, N=200, regen=True, device=None, batch_size=200, chunk_len=20):
    """Build an ``IntegratedDynamicsTrainer`` for the N-body dynamics task.

    Args:
        lr: learning rate for the Adam optimizer.
        N: ``n_systems`` for the training set; the test set uses ``N // 10``
            systems, clamped to at least 1.
        regen: forwarded to ``NBodyDynamics`` (presumably forces the dataset to be
            regenerated rather than loaded from cache -- confirm).
        device: torch device (or device string) for the model and data. Defaults
            to CUDA when available, otherwise CPU.
        batch_size: batch size of both the train and test DataLoaders.
        chunk_len: forwarded to ``NBodyDynamics`` for both datasets.

    Returns:
        An ``IntegratedDynamicsTrainer`` wrapping the model, dataloaders,
        optimizer constructor and (constant) learning-rate schedule.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    dtype = torch.float64

    # Create training and test sets
    trainset = NBodyDynamics(n_systems=N, regen=regen, chunk_len=chunk_len)
    testset = NBodyDynamics(n_systems=max(1, N // 10), train=False, regen=regen,
                            chunk_len=chunk_len)

    # Alternative architectures, swap in as needed:
    #     model = ZeroBaseline(device).to(device)
    #     model = LieResNetT2(k=384, num_layers=4, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device,dtype=dtype)
    #     model = HOGN(k=512, num_layers=1, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device,dtype=dtype)
    #     model = OGN(k=256, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device,dtype=dtype)
    #     model = VOGN(k=512, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device,dtype=dtype)
    #     model = RawDynamicsNet(k=512, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device,dtype=dtype)
    model = FCHamNet(k=256, d=trainset.space_dim, sys_dim=trainset.sys_dim).to(device=device, dtype=dtype)
    # Redundant for FCHamNet, but guarantees float64 for alternatives that are
    # moved with .to(device) only.
    model.double()

    # Pinned host memory only helps host->GPU copies.
    pin_memory = device.type == 'cuda'
    dataloaders = {}
    dataloaders['train'] = DataLoader(trainset, batch_size=batch_size, num_workers=0,
                                      shuffle=True, pin_memory=pin_memory)
    dataloaders['Train'] = dataloaders['train']
    dataloaders['Test'] = DataLoader(testset, batch_size=batch_size, num_workers=0,
                                     shuffle=False, pin_memory=pin_memory)
    # Wrap loaders so batches are moved to the target device/dtype
    dataloaders = {k: LoaderTo(v, device=device, dtype=dtype) for k, v in dataloaders.items()}

    # Optimizer constructor and a constant (factor 1) learning-rate schedule
    opt_constr = lambda params: Adam(params, lr=lr)
    lr_sched = lambda e: 1
    return IntegratedDynamicsTrainer(model, dataloaders, opt_constr, lr_sched, tol=1e-4,
                                     log_args={'timeFrac': 1/2, 'minPeriod': 0.0})

In [None]:
# Build the trainer on 5000 training systems and run training. The argument to
# train() is presumably the number of epochs (oil Trainer API) -- confirm.
trainer = makeTrainer(N=5000, lr=5e-3, regen=True)
trainer.train(100)

In [None]:
# Plot the logged scalar metrics on a log y-axis. The 'lr0' (presumably learning
# rate) and 'nfe' (presumably function-evaluation count) columns are dropped;
# errors='ignore' keeps the cell working if either column was not logged.
metrics = trainer.logger.scalar_frame.drop(columns=['lr0', 'nfe'], errors='ignore')
ax = metrics.plot(logy=True, title='Logged training metrics')
ax.set_ylabel('value (log scale)');

In [None]:
# n_systems = 20
# n_bodies = 6  # NBodyDynamics assumes n=6
# d = 3
# trainset = NBodyDynamics(N=1,regen=False)
# timesteps, z, sys_params = [], [], []
# for i in range(n_systems):
#     z0, params = trainset.sample_system(bs=1, n=n_bodies)
#     traj = trainset.sim_trajectories(z0, params)
#     z.append(traj[1])
#     sys_params.append(params[0])
    
# t = traj[0]
# z = torch.cat(z, dim=1)
# sys_params = torch.cat(sys_params).unsqueeze(-1)
# t, z, p = trainset.chunk(t, z, sys_params, C=5)

In [None]:
# root_dir = os.path.expanduser("~/datasets/ODEDynamics/NBodyDynamics")
# torch.save((t, z, p), os.path.join(root_dir, f"dyn_{n_systems}_test.pz"))

In [None]:
# # print(pos_mom.shape)
# qt = pos_mom[:, 0, :n_bodies*d].reshape(-1, n_bodies, d).permute(1, 2, 0)
# print(qt.shape)

# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# for i in range(n_bodies):
#     ax.scatter(qt[i, 0], qt[i, 1], qt[i, 2])
#     ax.set_xlim((-150, 150))
#     ax.set_ylim((-150, 150))
#     ax.set_zlim((-150, 150))

In [None]:
# %matplotlib notebook
# xlim = ylim = zlim = (-150, 150)
# A = AnimationNd(d)(qt.detach().cpu().numpy(), None, xlim, ylim, zlim)
# a = A.animate()
# plt.show()