In [None]:
import matplotlib.pyplot as plt
import scipy.integrate
import torch
import sys  
sys.path.insert(0, '../../..')

from hnn.simulation.mve_ensemble import MveEnsembleHamiltonianDynamics
from hnn.utils import integrate_model
from hnn.train import train
from hnn.types import TrajectoryArgs

In [None]:
import argparse

def get_args():
    """Build the hyperparameter namespace used to train the HNN.

    Options are registered from a small spec table (flag, default, type,
    help). Any command-line tokens the parser does not recognise are
    silently dropped because ``parse_known_args`` is used. That presumably
    matters here because the notebook kernel may put its own flags in
    ``sys.argv``. ``allow_abbrev=False`` means prefixes such as ``--hidden``
    are NOT accepted as shorthand for ``--hidden_dim``.

    Returns:
        argparse.Namespace with ``input_dim``, ``hidden_dim``, ``learn_rate``,
        ``weight_decay``, ``total_steps``, ``field_type`` and ``feature``
        (the latter always defaulting to ``True``).
    """
    option_specs = (
        ('--input_dim', 2, int, 'dimensionality of input tensor'),
        ('--hidden_dim', 300, int, 'hidden dimension of mlp'),
        ('--learn_rate', 1e-3, float, 'learning rate'),
        ('--weight_decay', 1e-4, float, 'weight decay'),
        ('--total_steps', 1000, int, 'number of gradient steps'),
        # field_type choices used in this project: solenoidal, conservative
        ('--field_type', 'conservative', str, 'type of vector field to learn'),
    )

    parser = argparse.ArgumentParser(allow_abbrev=False)
    for flag, default_value, value_type, help_text in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type, help=help_text)
    parser.set_defaults(feature=True)

    known_args, _unrecognised = parser.parse_known_args()
    return known_args

In [None]:
# Sanity check: evaluate the MVE ensemble Hamiltonian (total energy) on the
# output of get_initial_conditions(4) and print it as a Python scalar.
# NOTE(review): the argument 4 is presumably the number of bodies — confirm.
from hnn.simulation.mve_ensemble.mve_ensemble import (
    get_initial_conditions,
    mve_ensemble_fn,
)

coords, masses = get_initial_conditions(4)
H = mve_ensemble_fn(coords, masses)

# .item() unwraps the single-element tensor so only the number is printed.
print("Hamiltonian (Total Energy):", H.item())

In [None]:
# Simulate one trajectory with the dynamics object's default settings.
hamiltonian = MveEnsembleHamiltonianDynamics()
xy, v, dq, dp, t = hamiltonian.get_trajectory({})

# Swap the first two axes of positions/velocities. After the swap, axis 0 is
# what the animation cell iterates over as "bodies".
positions, velocities = xy.transpose(0, 1), v.transpose(0, 1)
time = t
timepoints = len(t)

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Animate every body's path growing over time. `positions` is indexed as
# positions[body, timestep, coord], with coord 0 = x and coord 1 = y.

fig, ax = plt.subplots(figsize=[10, 4], dpi=100)
ax.set_title('Trajectories')
lines = []
for i in range(positions.shape[0]):
    line, = ax.plot([], [], label='body {} path'.format(i))
    lines.append(line)

ax.axis('equal')
# torch.min/torch.max return 0-d tensors; hand matplotlib plain floats.
ax.set_xlim(torch.min(positions[:, :, 0]).item(), torch.max(positions[:, :, 0]).item())
ax.set_ylim(torch.min(positions[:, :, 1]).item(), torch.max(positions[:, :, 1]).item())
ax.set_xlabel('$x$')
ax.set_ylabel('$y$')
ax.legend(fontsize=8)


def init():
    """Reset every path to empty; used as the blitting background frame."""
    for line in lines:
        line.set_data([], [])
    return lines


def update(frame):
    """Draw each body's path from timestep 0 up to (excluding) ``frame``.

    Returns the updated artists, as required by ``blit=True``.
    """
    for i, line in enumerate(lines):
        line.set_data(positions[i, :frame, 0], positions[i, :frame, 1])
    return lines


# Plain Python ints as frame indices (1..timepoints), so the final frame
# shows every timestep.
ani = FuncAnimation(fig, update, frames=range(1, timepoints + 1),
                    init_func=init, blit=True, repeat=True)

# Render the animation first, then close the figure. Otherwise the inline
# backend also displays a duplicate static snapshot under the animation.
animation_html = HTML(ani.to_jshtml())
plt.close(fig)
animation_html

In [None]:
from hnn.simulation.mve_ensemble.mve_ensemble import calc_lennard_jones_potential, calc_kinetic_energy

# Energy bookkeeping along the simulated trajectory: Lennard-Jones potential,
# kinetic energy, and their sum.
ljp = calc_lennard_jones_potential(positions)

# One unit mass per body. The body count comes from positions (axis 0 indexes
# bodies, as in the animation cell) rather than a hard-coded 5. The distinct
# name keeps `masses` from the Hamiltonian sanity-check cell intact.
# NOTE(review): assumes every simulated body has unit mass — confirm against
# the masses the simulation itself uses.
unit_masses = torch.ones(positions.shape[0])
ke = calc_kinetic_energy(velocities, unit_masses)
te = ljp + ke

# Show only the first few values instead of dumping the whole time series.
print("Potential Energy (first 5 steps):", ljp[:5])

# Use a dedicated figure/axes. plt.subplot(1, 2, 2) would draw into the right
# half of whichever figure pyplot considered current.
fig_energy, ax_energy = plt.subplots(figsize=[6, 4], dpi=100)
ax_energy.set_title('Energy')
ax_energy.set_xlabel('time')
ax_energy.set_yscale('symlog')
ax_energy.plot(time, ljp, label='potential')
ax_energy.plot(time, ke, label='kinetic')
ax_energy.plot(time, te, label='total')
ax_energy.legend(fontsize=8)
plt.show()


In [7]:
# Build the training set from the same dynamics object used for the
# trajectory above, then fit the model with the hyperparameters from get_args().
data = hamiltonian.get_dataset({}, {})
args = get_args()

model, stats = train(args, data)