added docstrings to hamiltonian dynamics and removed some unnecessary features
mfinzi committed Apr 23, 2021
1 parent 1542fcb commit d958c16
Showing 1 changed file with 60 additions and 27 deletions: experiments/trainer/hamiltonian_dynamics.py
@@ -24,6 +24,8 @@
from functools import partial
from itertools import islice

## Code to roll out a Hamiltonian system

def unpack(z):
D = jnp.shape(z)[-1]
assert D % 2 == 0
@@ -35,29 +37,54 @@ def pack(q, p_or_v):
return jnp.concatenate([q, p_or_v], axis=-1)

def symplectic_form(z):
""" Equivalent to multiplying z by the matrix J=[[0,I],[-I,0]]"""
q, p = unpack(z)
return pack(p, -q)
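# Illustrative check (added here, not part of the original commit): for z = pack(q,p),
# multiplying by J swaps the two halves and negates q, e.g.
# symplectic_form(jnp.array([1., 2.])) # -> [ 2., -1.], i.e. (q,p) -> (p,-q)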

def hamiltonian_dynamics(hamiltonian, z, t):
""" Takes a Hamiltonian function, a state vector z, and an (unused) time t
and computes the Hamiltonian dynamics J∇H """
grad_h = grad(hamiltonian) # ∇H
gh = grad_h(z) # ∇H(z)
return symplectic_form(gh) # J∇H(z)

def HamiltonianFlow(H, z0, T):
""" Rolls out a Hamiltonian H from initial conditions z0
into a trajectory evaluated at time points T.
z0 of shape (state_dim,) and T of shape (t,) yield a (t,state_dim) rollout."""
dynamics = lambda z, t: hamiltonian_dynamics(H, z, t)
return odeint(dynamics, z0, T, rtol=1e-4, atol=1e-4)

def BHamiltonianFlow(H, z0, T, tol=1e-4):
""" Batched version of HamiltonianFlow, essentially equivalent to vmap(HamiltonianFlow).
z0 of shape (bs,state_dim) and T of shape (t,) yield (bs,t,state_dim) rollouts """
dynamics = jit(vmap(jit(partial(hamiltonian_dynamics,H)),(0,None)))
return odeint(dynamics, z0, T, rtol=tol).transpose((1,0,2))

def BOdeFlow(dynamics,z0,T,tol=1e-4):
""" Batched integration of ODE dynamics into rollout trajectories.
Given dynamics (state_dim->state_dim) and z0 of shape (bs,state_dim)
and T of shape (t,) outputs trajectories (bs,t,state_dim) """
dynamics = jit(vmap(jit(dynamics),(0,None)))
return odeint(dynamics, z0, T, rtol=tol).transpose((1,0,2))
#BHamiltonianFlow = jit(vmap(HamiltonianFlow,(None,0,None)),static_argnums=(0,))
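# Example usage (illustrative sketch, not part of the original commit): rolling out
# a batch of simple harmonic oscillator trajectories with BHamiltonianFlow, assuming
# jnp, vmap, jit, grad, and odeint are imported at the top of this file as above.
# H_sho = lambda z: 0.5*(z**2).sum() # H(q,p) = (q^2+p^2)/2
# z0 = jnp.ones((3, 2)) # batch of 3 initial states (q,p)
# T_eval = jnp.linspace(0., 2*jnp.pi, 50) # evaluation timepoints
# zt = BHamiltonianFlow(H_sho, z0, T_eval) # rollout of shape (3, 50, 2)
# energies = vmap(vmap(H_sho))(zt) # (3, 50); ~constant, the flow conserves H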

class HamiltonianDataset(Dataset):

""" A dataset that generates trajectory chunks from integrating the Hamiltonian dynamics
from a given Hamiltonian system and initial condition distribution.
Each element ds[i] = ((ic,T),z_target) where ic (state_dim,) are the initial conditions,
T are the evaluation timepoints, and z_target (T,state_dim) is the ground truth trajectory chunk.
Here state_dim includes both the position q and canonical momentum p concatenated together.
Args:
n_systems (int): total number of trajectory chunks that makeup the dataset.
chunk_len (int): the number of timepoints at which each chunk is evaluated
dt (float): the spacing of the evaluation points (not the integrator step size which is set by tol=1e-4)
integration_time (float): The integration time for evaluation rollouts and also
the total integration time from which each trajectory chunk is randomly sampled
regen (bool): whether or not to regenerate and overwrite any datasets cached to disk
with the same arguments. If false, will use trajectories saved at {filename}
Returns:
Dataset: A (torch style) dataset. """
def __init__(self,n_systems=100,chunk_len=5,dt=0.2,integration_time=30,regen=False):
super().__init__()
root_dir = os.path.expanduser(f"~/datasets/ODEDynamics/{self.__class__}/")
@@ -106,10 +133,17 @@ def chunk_training_data(self, zs, chunk_len):
return chosen_zs

def H(self,z):
""" The Hamiltonian function, depending on z=pack(q,p)"""
raise NotImplementedError

def sample_initial_conditions(self,bs):
""" Initial condition distribution """
raise NotImplementedError

def animate(self, zt=None):
""" Visualize the dynamical system, or given input trajectories.
Usage: from IPython.display import HTML
HTML(dataset.animate())"""
if zt is None:
zt = np.asarray(self.integrate(self.sample_initial_conditions(10)[0],self.T_long))
# bs, T, 2nd
@@ -122,6 +156,7 @@ def animate(self, zt=None):
return anim.animate()

class SHO(HamiltonianDataset):
""" A basic simple harmonic oscillator"""
def H(self,z):
ke = (z[...,1]**2).sum()/2
pe = (z[...,0]**2).sum()/2
@@ -130,6 +165,7 @@ def sample_initial_conditions(self,bs):
return np.random.randn(bs,2)
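# Example usage (illustrative, not part of the original commit): building the SHO
# dataset and inspecting one trajectory chunk; assumes the cache directory is writable.
# ds = SHO(n_systems=100, chunk_len=5, dt=0.2, integration_time=30)
# (ic, T_chunk), z_target = ds[0] # ic: (2,), T_chunk: (5,), z_target: (5, 2)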

class DoubleSpringPendulum(HamiltonianDataset):
""" The double spring pendulum dataset described in the paper."""
def __init__(self,*args,**kwargs):
super().__init__(*args,**kwargs)
self.rep_in = 4*T(1)#Vector
@@ -161,12 +197,11 @@ def animator(self):


class IntegratedDynamicsTrainer(Regressor):
""" A trainer for training the Hamiltonian Neural Networks. Feel free to use your own instead."""
def __init__(self,model,*args,**kwargs):
super().__init__(model,*args,**kwargs)
self.loss = objax.Jit(self.loss,model.vars())
self.gradvals = objax.Jit(objax.GradValues(self.loss,model.vars()))

def loss(self, minibatch):
""" Standard cross-entropy loss """
@@ -185,6 +220,7 @@ def logStuff(self, step, minibatch=None):
super().logStuff(step,minibatch)

class IntegratedODETrainer(Regressor):
""" A trainer for training the Neural ODEs. Feel free to use your own instead."""
def __init__(self,model,*args,**kwargs):
super().__init__(model,*args,**kwargs)
self.loss = objax.Jit(self.loss,model.vars())
@@ -209,9 +245,13 @@ def logStuff(self, step, minibatch=None):
super().logStuff(step,minibatch)

def rel_err(a,b):
""" Relative error |a-b|/|a+b|"""
return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))#
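# e.g. (illustrative, not part of the original commit):
# rel_err(jnp.ones(4), jnp.zeros(4)) # -> 1.0
# rel_err(10*a, 10*b) == rel_err(a, b) # scale invariant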

def log_rollout_error(ds,model,minibatch):
""" Computes the log of the geometric mean of the rollout
error computed between the dataset ds and HNN model
on the initial condition in the minibatch."""
(z0, _), _ = minibatch
pred_zs = BHamiltonianFlow(model,z0,ds.T_long)
gt_zs = BHamiltonianFlow(ds.H,z0,ds.T_long)
@@ -228,6 +268,9 @@ def pred_and_gt(ds,model,minibatch):


def log_rollout_error_ode(ds,model,minibatch):
""" Computes the log of the geometric mean of the rollout
error computed between the dataset ds and NeuralODE model
on the initial condition in the minibatch."""
(z0, _), _ = minibatch
pred_zs = BOdeFlow(model,z0,ds.T_long)
gt_zs = BHamiltonianFlow(ds.H,z0,ds.T_long)
@@ -267,7 +310,7 @@ def pred_and_gt_ode(ds,model,minibatch):




### Some extra code to make pretty visualizations for the given system

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
@@ -386,9 +429,8 @@ def update(self,i=0):

@export
class hnn_trial(object):
""" Assumes trainer is an object of type Trainer, trains for num_epochs which may be an
integer or an iterable containing intermediate points at which to save.
Pulls out special (resume, save, early_stop_metric, local_rank) args from the cfg """
""" A training trial for the HNNs, contains lots of boiler plate which is not necessary.
Feel free to use your own."""
def __init__(self,make_trainer,strict=True):
self.make_trainer = make_trainer
self.strict=strict
@@ -402,12 +444,8 @@ def __call__(self,cfg,i=None):
cfg['trainer_config']['log_suffix'] = os.path.join(orig_suffix,f'trial{i}/')
trainer = self.make_trainer(**cfg)
trainer.logger.add_scalars('config',flatten_dict(cfg))
if resume: trainer.load_checkpoint(None if resume==True else resume)
trainer.train(cfg['num_epochs'])
if save: cfg['saved_at']=trainer.save_checkpoint()
outcome = trainer.ckpt['outcome']
trajectories = []
for mb in trainer.dataloaders['test']:
@@ -422,9 +460,8 @@ def __call__(self,cfg,i=None):

@export
class ode_trial(object):
""" Assumes trainer is an object of type Trainer, trains for num_epochs which may be an
integer or an iterable containing intermediate points at which to save.
Pulls out special (resume, save, early_stop_metric, local_rank) args from the cfg """
""" A training trial for the Neural ODEs, contains lots of boiler plate which is not necessary.
Feel free to use your own."""
def __init__(self,make_trainer,strict=True):
self.make_trainer = make_trainer
self.strict=strict
@@ -438,12 +475,8 @@ def __call__(self,cfg,i=None):
cfg['trainer_config']['log_suffix'] = os.path.join(orig_suffix,f'trial{i}/')
trainer = self.make_trainer(**cfg)
trainer.logger.add_scalars('config',flatten_dict(cfg))
if resume: trainer.load_checkpoint(None if resume==True else resume)
trainer.train(cfg['num_epochs'])
if save: cfg['saved_at']=trainer.save_checkpoint()
outcome = trainer.ckpt['outcome']
trajectories = []
for mb in trainer.dataloaders['test']:
