In [1]:
import matplotlib.pyplot as plt
import scipy.integrate
import torch
import sys  
sys.path.insert(0, '../../..')

from hnn.simulation.mve_ensemble import MveEnsembleHamiltonianDynamics
from hnn.utils import integrate_model
from hnn.train import train
from hnn.types import TrajectoryArgs

In [2]:
import argparse

def get_args(argv=None):
    """Build the hyper-parameter namespace used to train the HNN.

    Only the options declared below are parsed. Any other flags are ignored
    (``parse_known_args``), so the function also works inside a notebook kernel
    whose ``sys.argv`` carries launcher flags.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour.

    Returns:
        argparse.Namespace with ``input_dim``, ``hidden_dim``, ``learn_rate``,
        ``weight_decay``, ``total_steps``, ``field_type`` and ``feature``.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--input_dim', default=2, type=int, help='dimensionality of input tensor')
    parser.add_argument('--hidden_dim', default=200, type=int, help='hidden dimension of mlp')
    parser.add_argument('--learn_rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
    parser.add_argument('--total_steps', default=2000, type=int, help='number of gradient steps')
    parser.add_argument('--field_type', default='conservative', type=str, help='type of vector field to learn') # solenoidal, conservative
    # No `--feature` flag is declared; this only guarantees `args.feature` exists (always True).
    parser.set_defaults(feature=True)
    # Unrecognised arguments are returned separately and deliberately discarded.
    args, _unknown = parser.parse_known_args(argv)
    return args

In [3]:
# Sanity check: evaluate the MVE ensemble Hamiltonian on a freshly generated state.
from hnn.simulation.mve_ensemble.mve_ensemble import (
    get_initial_conditions,
    mve_ensemble_fn,
)

# NOTE(review): the argument 10 presumably sets the ensemble size — confirm in get_initial_conditions.
coords, masses = get_initial_conditions(10)

# Total energy of the ensemble; `.item()` unwraps the scalar tensor for printing.
H = mve_ensemble_fn(coords, masses)
total_energy = H.item()
print("Hamiltonian (Total Energy):", total_energy)

Hamiltonian (Total Energy): 308.671630859375


In [4]:

# Seed the global RNGs so the dataset generated here, and the training run that
# follows, are reproducible under Restart & Run All.
# NOTE(review): assumes the hnn helpers draw from torch's and/or numpy's global
# RNGs — confirm, and seed any other source they use as well.
import numpy as np

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

args = get_args()

# Build the MVE-ensemble system and generate its training dataset
# (both option dicts are left empty).
hamiltonian = MveEnsembleHamiltonianDynamics()
data = hamiltonian.get_dataset({}, {})

In [None]:
# Fit the Hamiltonian NN on `data` using the hyper-parameters from get_args().
# `stats` holds whatever diagnostics `train` returns (not inspected in this notebook).
model, stats = train(args, data)

In [None]:
# Generate one ground-truth trajectory of the ensemble (empty options dict).
# NOTE(review): the meaning of the five components is inferred from their names
# (positions, momenta, their time derivatives, time stamps) — confirm against get_trajectory.
q, p, dq, dp, t = hamiltonian.get_trajectory({})

In [None]:
# Reorder so we iterate per body; each `pos` is then indexed as pos[:, 0] (x) and pos[:, 1] (y).
# NOTE(review): assumes q is (time, body, 2) — confirm against get_trajectory.
positions = q.transpose(0, 1)
time = t

# Single-panel figure: the energy plot lives in its own cell/figure. A 1x2 layout
# here would leave the right half blank, because the inline backend closes the
# figure when this cell finishes.
fig, ax = plt.subplots(figsize=(5, 4), dpi=100)
for i, pos in enumerate(positions):
    ax.plot(pos[:, 0], pos[:, 1], label=f'body {i} path')

ax.set_title('Trajectories')
ax.axis('equal')
ax.set_xlabel('$x$')
ax.set_ylabel('$y$')
ax.legend(fontsize=8);

In [None]:

from hnn.simulation.mve_ensemble.mve_ensemble import lennard_jones_potential

# Evaluate the Lennard-Jones potential along the trajectory once, and reuse the result for plotting.
ljp = lennard_jones_potential(positions)

# Draw into an explicit figure rather than relying on a subplot grid created in another cell.
fig, ax = plt.subplots(figsize=(5, 4), dpi=100)
ax.set_title('Energy')
ax.set_xlabel('time')
ax.plot(time, ljp, label='potential')
ax.legend();

In [None]:
# Styling constants shared by the phase-portrait figure below.
R = 2.5              # half-width of the square sampling window [-R, R] x [-R, R]
LINE_SEGMENTS = 10   # number of colour bands the model rollout is split into
ARROW_SCALE = 40     # passed to quiver as `scale`
ARROW_WIDTH = 6e-3   # passed to quiver as `width`
LINE_WIDTH = 2       # line width of the rollout curve

# Sample the true field and the model's learned field on the same square grid.
field_args = dict(xmin=-R, xmax=R, ymin=-R, ymax=R, gridsize=10)
field = hamiltonian.get_field(field_args)
vector_field = hamiltonian.get_vector_field(model, field_args)

In [None]:

from hnn.simulation.mve_ensemble.mve_ensemble import DEFAULT_TRAJECTORY_ARGS, DEFAULT_ODE_ARGS

# Integrate the trained model forward from the default initial state over the
# default time span; the tight rtol keeps the rollout close to the learned dynamics.
rollout_kwargs = dict(
    t_span=DEFAULT_TRAJECTORY_ARGS['t_span'],
    y0=DEFAULT_ODE_ARGS['y0'],
    timescale=30,
    rtol=1e-12,
)
ivp = integrate_model(model, **rollout_kwargs)

In [None]:
def _plot_vector_field(ax, points, vectors, color):
    """Draw a quiver of `vectors` anchored at `points` (both indexed as [:, 0] / [:, 1]) on `ax`."""
    ax.quiver(
        points[:, 0],
        points[:, 1],
        vectors[:, 0],
        vectors[:, 1],
        cmap='gray_r',
        scale=ARROW_SCALE,
        width=ARROW_WIDTH,
        color=color,
    )


def _label_phase_axes(ax, title):
    """Apply the shared $q$ / $p$ axis labels and a panel title to `ax`."""
    ax.set_xlabel("$q$", fontsize=14)
    ax.set_ylabel("$p$", rotation=0, fontsize=14)
    ax.set_title(title, pad=10)


# NOTE(review): the grid has 4 columns but only panels 1-2 are drawn; the rest stay empty.
fig = plt.figure(figsize=(11.3, 3.2), facecolor='white', dpi=300)

# Panel 1: sampled ground-truth trajectory over the true vector field.
ax_data = fig.add_subplot(1, 4, 1, frameon=True)
x, y, dx, dy, t = hamiltonian.get_trajectory({})
# scatter() flattens x and y, so there must be one colour per element.
# Count them directly instead of assuming 10 points per row.
N = x.reshape(-1).shape[0]
point_colors = [(i / N, 0, 1 - i / N) for i in range(N)]
ax_data.scatter(x, y, s=14, label='data', c=point_colors)
_plot_vector_field(ax_data, field.x, field.dx, color=(.2, .2, .2))
_label_phase_axes(ax_data, "Data")

# Panel 2: learned vector field with the model rollout, coloured blue -> red along its length.
ax_hnn = fig.add_subplot(1, 4, 2, frameon=True)
_plot_vector_field(ax_hnn, field.x, vector_field, color=(.5, .5, .5))
for i, segment in enumerate(torch.tensor_split(ivp, LINE_SEGMENTS)):
    color = (i / LINE_SEGMENTS, 0, 1 - i / LINE_SEGMENTS)
    ax_hnn.plot(segment[:, 0], segment[:, 1], color=color, linewidth=LINE_WIDTH)
_label_phase_axes(ax_hnn, "Hamiltonian NN")

fig.tight_layout()
plt.show()