## Train the PointConv HamNet on Spring Dynamics

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torch.optim import SGD, Adam
from oil.utils.utils import LoaderTo, cosLr, recursively_update,islice, Eval
from equivariant_color.datasets import SpringDynamics
from equivariant_color.dynamicsTrainer import IntegratedDynamicsTrainer,FCHamNet, Partial,RawDynamicsNet, LieConvNetT2, LieResNetT2
from graphnets import OGN,HOGN, VOGN
def makeTrainer(lr=1e-2,N=4000, regen=False, device=None):
    """Build an IntegratedDynamicsTrainer for the SpringDynamics dataset.

    Args:
        lr: learning rate handed to the Adam optimizer.
        N: forwarded to SpringDynamics as ``N`` for the training set; the
            test set is built with ``N//10``.
        regen: forwarded unchanged to both SpringDynamics constructors.
        device: optional torch device (or device string). When omitted, CUDA
            is used if available and the CPU otherwise, so the function no
            longer fails outright on machines without a GPU.

    Returns:
        An IntegratedDynamicsTrainer wrapping the model, the dataloaders
        ('train', 'Train', 'Test'), the optimizer constructor and a constant
        learning-rate schedule.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)
    dtype = torch.float32

    # Create Training set and model
    data_root = '~/datasets/ODEDynamics/SpringDynamics/'
    trainset = SpringDynamics(data_root,N=N,regen=regen)
    testset = SpringDynamics(data_root,N=N//10,train=False,regen=regen)
    model = FCHamNet(k=256).to(device=device,dtype=dtype)
    # Alternative architectures that were swapped in for the comparison runs:
    #model = LieConvNetT2(k=256,bn=False).to(device=device,dtype=dtype)
    #model = LieResNetT2(k=512,bn=False).to(device=device,dtype=dtype)
    #model = HOGN(k=256).to(device=device,dtype=dtype)
    #model = OGN(k=256).to(device=device,dtype=dtype)
    #model = VOGN(k=512).to(device=device,dtype=dtype)
    #model = RawDynamicsNet(k=256).to(device=device,dtype=dtype)
    #model.double()

    # Create train and Dev(Test) dataloaders and move elems to the device.
    # Pinned host memory only helps host->GPU copies, so skip it on CPU.
    pin = device.type == 'cuda'
    train_loader = DataLoader(trainset,batch_size=200,num_workers=0,shuffle=True,pin_memory=pin)
    test_loader = DataLoader(testset,batch_size=100,num_workers=0,shuffle=False,pin_memory=pin)
    # 'Train' deliberately aliases the same loader object as 'train'
    # (presumably one key drives optimisation and the other metric logging -- verify against the trainer).
    raw_loaders = {'train': train_loader, 'Train': train_loader, 'Test': test_loader}
    dataloaders = {k:LoaderTo(v,device=device,dtype=dtype) for k,v in raw_loaders.items()}

    # Initialize optimizer and learning rate schedule
    opt_constr = lambda params: Adam(params, lr=lr)
    lr_sched = lambda e: 1  # constant multiplier; a cosine schedule (cosLr) was tried here previously
    return IntegratedDynamicsTrainer(model,dataloaders,opt_constr,lr_sched,log_args={'timeFrac':1/2,'minPeriod':0.0})

In [None]:
# Build a trainer on a small, freshly regenerated dataset and run training.
trainer_kwargs = dict(lr=1e-2, N=200, regen=True)
trainer = makeTrainer(**trainer_kwargs)
trainer.train(32)  # argument is presumably the number of epochs -- confirm against the trainer API

All models below were trained for 10 epochs at N=4000 on SpringDynamics with the original chunking.

LieResNet k=1024 lr=3e-3
Test_MSE  Train_MSE    lr0   nfe
0.000269   0.000044  0.003  23.5

LieResNet k=512 lr=3e-3 ===================
Test_MSE  Train_MSE    lr0        nfe
0.000143   0.000047  0.003  25.411765
            
LieResNet k=256 lr=3e-3
Test_MSE  Train_MSE    lr0        nfe
0.000774   0.000055  0.003  26.641711

OGN k=256 lr=1e-2 =========================
Test_MSE	Train_MSE	lr0	nfe
0.002872	0.001338	0.01	46.545455

HOGN k=256(tuned) lr=1e-2(tuned) ==========
Test_MSE	Train_MSE	lr0	      nfe
0.003142	0.001004	0.01	29.333333

VOGN k=256 lr=1e-2
Test_MSE	Train_MSE	lr0	nfe
0.000295	0.000136	0.01	23.272727

VOGN k=512 lr=1e-2 =======================
Test_MSE  Train_MSE   lr0        nfe
0.000243   0.000153  0.01  23.179487

RawDynamics k=1024 lr=3e-3
Test_MSE	Train_MSE	lr0	nfe
0.053781	0.003038	0.003	33.777778

RawDynamics k=512 lr=3e-3
Test_MSE	Train_MSE	lr0	nfe
0.039395	0.004135	0.003	33.777778

RawDynamics k=256 lr=3e-3 ================
Test_MSE	Train_MSE	lr0	nfe
0.036957	0.008535	0.003	23.249545

RawDynamics k=128 lr=3e-3
Test_MSE	Train_MSE	lr0	nfe
0.084753	0.034763	0.003	22.956522

FCHamNet k=512 lr=3e-3
Test_MSE	Train_MSE	lr0	nfe
0.057443	0.002975	0.003	24.0

FCHamNet k=256 lr=1e-2 ===================
Test_MSE  Train_MSE   lr0   nfe
0.028576   0.003485  0.01  24.0

FCHamNet k=128 lr=1e-2
Test_MSE  Train_MSE	  lr0	nfe
0.03035	   0.008771	 0.01	24.0

In [None]:
import matplotlib.pyplot as plt

# Plot every logged scalar except the 'lr0' and 'nfe' columns, on a log y-axis.
logged_scalars = trainer.logger.scalar_frame
loss_curves = logged_scalars.drop(columns=['lr0', 'nfe'])
loss_curves.plot()
plt.yscale('log')

In [None]:
import matplotlib.pyplot as plt

# Same plot as the previous cell: logged scalars without 'lr0' and 'nfe', log-scaled y-axis.
columns_to_hide = ['lr0', 'nfe']
curves = trainer.logger.scalar_frame.drop(columns=columns_to_hide)
curves.plot()
plt.yscale('log')

## Do the trajectories look reasonable?

In [None]:
# Draw one initial condition (bs=1) with its system parameters from the
# training dataset object, then simulate its trajectory with the dataset's simulator.
ds = trainer.dataloaders['train'].dataset
sampled = ds.sample_system(bs=1)
z0s, system_params = sampled
ts, zts = ds.sim_trajectories(z0s, system_params)

In [None]:
from torchdiffeq import odeint

# Integrate the learned dynamics from z0s over the times ts on the GPU and
# bring the result back as a NumPy array. No gradients are tracked; Eval(...)
# presumably switches the model to evaluation mode for the duration -- verify.
with torch.no_grad(), Eval(trainer.model):
    sys_params_gpu = torch.stack(system_params, dim=-1).cuda()
    # Bind the stacked system parameters as the model's sysP keyword
    tz_model = Partial(trainer.model.cuda(), sysP=sys_params_gpu)
    rollout = odeint(
        tz_model,
        z0s.cuda(),
        ts.cuda(),
        rtol=trainer.hypers['tol'],
        method='rk4',
    )
    zt = rollout.detach().cpu().data.numpy()

In [None]:
%matplotlib notebook
from hamiltonian import HamiltonianDynamics, SpringH, AnimationNd
import matplotlib.pyplot as plt
n =6
d =2
qt = zt[:,0,:n*d].reshape(zt.shape[0],n,d).transpose((1,2,0))
A = AnimationNd(d)(qt)
a = A.animate()
plt.show()

In [None]:
# Ground-truth rollout: integrate the dynamics generated by the SpringH
# Hamiltonian (evaluated with the sampled system parameters) from the same
# initial condition and time grid as the learned model.
def H(t, z):
    """Energy of state z under SpringH with the sampled system parameters; t is unused."""
    return SpringH(z, *system_params)

Dynamics = HamiltonianDynamics(H, wgrad=False)
with torch.no_grad():
    ztgt = odeint(Dynamics, z0s, ts, rtol=1e-4, method='rk4').cpu().data.numpy()
    # Size the reshape from ztgt itself (the original used zt.shape[0], which
    # made this cell depend on the learned-model rollout having been run first).
    qtgt = ztgt[:, 0, :n * d].reshape(ztgt.shape[0], n, d).transpose((1, 2, 0))

In [None]:
# Animate the ground-truth positions with the same d-dimensional animator.
animator_cls = AnimationNd(d)
A = animator_cls(qtgt)
a = A.animate()
plt.show()

## Are Energy, Momentum, Angular momentum conserved?

In [None]:
import numpy as np

# Set to True to analyse the ground-truth rollout; False analyses the learned model's rollout.
gt = False
if gt:
    ztt, qtt = ztgt, qtgt
else:
    ztt, qtt = zt, qt

# State entries after the first n*d (batch element 0) are treated as momenta -> (steps, n, d).
p = ztt[:, 0, n * d:].reshape(-1, n, d)
p0 = z0s[0, n * d:].reshape(n, d)
p0n = p0.cpu().data.numpy().reshape(1, n, d)
# qtt is laid out (n, d, steps); move the step axis to the front to match p.
q = qtt.transpose(2, 0, 1)

# Energy at each saved step, evaluated with H on the corresponding state.
energies = [H(t, torch.from_numpy(z)) for t, z in zip(ts, ztt)]
Et = torch.stack(energies, dim=0).cpu().data.numpy()

# Remove the per-step mean over the particle axis from momenta and positions.
pcm = p - p.mean(1, keepdims=True)
qcm = q - q.mean(1, keepdims=True)
# Antisymmetric d x d matrix per step: sum_i (q_a p_b - q_b p_a).
q_outer_p = qcm[:, :, :, None] * pcm[:, :, None, :]
p_outer_q = pcm[:, :, :, None] * qcm[:, :, None, :]
cross = (q_outer_p - p_outer_q).sum(1)
angmom = cross.reshape(-1, d**2)#[:,[5,2,1]] # pull out (2,3),(0,2),(0,1) = Lx, -Ly, Lz
#angmom[:,1] *=-1

fig, axs = plt.subplots(3, 1, sharex=True)

# Relative drifts with respect to the initial values; the 1e-10 guards against a zero denominator.
energy_drift = (Et - Et[0]) / Et[0]
total_p = p.sum(1)
total_p0 = p0n.sum(1)
momentum_drift = np.linalg.norm(total_p - total_p0, axis=-1) / (np.linalg.norm(total_p0, axis=-1) + 1e-10)
angmom0 = angmom[:1]
angmom_drift = np.linalg.norm(angmom - angmom0, axis=-1) / (np.linalg.norm(angmom0, axis=-1) + 1e-10)

panels = (
    (energy_drift, "Energy drift"),
    (momentum_drift, "Momentum drift"),
    (angmom_drift, "Angular Momentum drift"),
)
for ax, (drift, title) in zip(axs, panels):
    ax.plot(ts, drift)
    ax.set_title(title)
    ax.set(ylabel='relative error')
plt.xlabel('time')
plt.show()

In [None]:
#from oil.tuning.args import argupdated_config
#from oil.tuning.study import Study
#import __init__
#from __init__ import *
#TrainTrial = train_trial(makeTrainer)
# thestudy = Study(TrainTrial,argupdated_config(config_spec,namespace=__init__),
#                 study_name="springpoint",base_log_dir=log_dir)
# thestudy.run(ordered=False)
#print(thestudy.covariates())

In [None]:
# Build index pairs for bs fully connected graphs of n nodes each, with node
# ids offset by n per graph so that every graph gets its own id range.
# NOTE: this rebinds the notebook-level name `n` (set to 6 in an earlier cell).
bs, n = 3, 5
node_ids = torch.arange(n)
graph_offsets = n * torch.arange(bs)
# cols[b, i, j] = i + n*b, shape (bs, n, n)
cols = node_ids[:, None].expand(n, n)[None, :, :] + graph_offsets[:, None, None]
# Flatten (bs, n, n) -> (bs*n*n): first tensor varies fastest, second is held per row
edge_index = (cols.transpose(1, 2).reshape(-1), cols.reshape(-1))

In [None]:
# Display the edge-index pair: two flattened index tensors derived from `cols`.
edge_index

In [None]:
# Display only the second tensor of the pair (the un-permuted, flattened `cols`).
edge_index[-1]

In [None]:
# One entry per node holding the index of the graph it belongs to:
# [0]*n, [1]*n, ..., [bs-1]*n, stored in the default floating dtype.
graph_ids = torch.arange(bs)
batch = graph_ids.repeat_interleave(n).to(torch.get_default_dtype())

In [None]:
# Look up the `batch` entry at every index in the second edge-index tensor.
batch[edge_index[-1]]

In [None]:
# Display the `batch` vector itself for comparison with the lookup above.
batch