In [1]:
import matplotlib.pyplot as plt
import scipy.integrate
import torch
import sys  
sys.path.insert(0, '../../..')

from hnn.simulation.mve_ensemble import MveEnsembleHamiltonianDynamics
from hnn.utils import integrate_model
from hnn.train import train
from hnn.types import TrajectoryArgs

In [2]:
import argparse

def get_args():
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--input_dim', default=2, type=int, help='dimensionality of input tensor')
    parser.add_argument('--hidden_dim', default=200, type=int, help='hidden dimension of mlp')
    parser.add_argument('--learn_rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
    parser.add_argument('--total_steps', default=20, type=int, help='number of gradient steps')
    parser.add_argument('--field_type', default='conservative', type=str, help='type of vector field to learn') # solenoidal, conservative
    parser.set_defaults(feature=True)
    return parser.parse_known_args()[0]

In [3]:
# Test the MVE ensemble Hamiltonian function

from hnn.simulation.mve_ensemble import mve_ensemble_fn

# Example coordinates and masses for three particles
coords = torch.tensor([[1.0, 0.5, 1.0, -1.0, -0.5, -1.0], [-0.1, 0.1, 0.3, 0.1, 0.4, -0.1]]).T  # [q1, q2, q3, ...,], [p1, p2, p3, ...]
masses = torch.tensor([1.0, 1.5, 2.0])

# Hamiltonian for the ensemble
H = mve_ensemble_fn(coords, masses)
print("Hamiltonian (Total Energy):", H.item())

Hamiltonian (Total Energy): -0.5924206376075745


In [4]:

args = get_args()

hamiltonian = MveEnsembleHamiltonianDynamics()
data = hamiltonian.get_dataset({}, {})
model, stats = train(args, data)

x torch.Size([2, 33, 10, 2])
F1 torch.Size([2, 33, 10, 1]) F2 torch.Size([2, 33, 10, 1]) y torch.Size([2, 33, 10, 2])
eye_tensor torch.Size([2, 33, 10, 2, 2]) dF1 torch.Size([2, 33, 10, 2])
dxdt_hat torch.Size([2, 33, 10, 2]) dxdt torch.Size([2, 33, 10, 2]) x torch.Size([2, 33, 10, 2])
x torch.Size([1, 33, 10, 2])
F1 torch.Size([1, 33, 10, 1]) F2 torch.Size([1, 33, 10, 1]) y torch.Size([1, 33, 10, 2])
eye_tensor torch.Size([1, 33, 10, 2, 2]) dF1 torch.Size([1, 33, 10, 2])
step 0, train_loss 8.3159e+03, test_loss 9.4752e+00
x torch.Size([2, 33, 10, 2])
F1 torch.Size([2, 33, 10, 1]) F2 torch.Size([2, 33, 10, 1]) y torch.Size([2, 33, 10, 2])
eye_tensor torch.Size([2, 33, 10, 2, 2]) dF1 torch.Size([2, 33, 10, 2])
dxdt_hat torch.Size([2, 33, 10, 2]) dxdt torch.Size([2, 33, 10, 2]) x torch.Size([2, 33, 10, 2])
x torch.Size([1, 33, 10, 2])
F1 torch.Size([1, 33, 10, 1]) F2 torch.Size([1, 33, 10, 1]) y torch.Size([1, 33, 10, 2])
eye_tensor torch.Size([1, 33, 10, 2, 2]) dF1 torch.Size([1, 33, 10

In [5]:
R = 2.5
LINE_SEGMENTS = 10
ARROW_SCALE = 40
ARROW_WIDTH = 6e-3
LINE_WIDTH = 2

field_args = {
    'xmin': -R,
    'xmax': R,
    'ymin': -R,
    'ymax': R,
    'gridsize': 10,
}
field = hamiltonian.get_field(field_args)
vector_field = hamiltonian.get_vector_field(model, field_args)

ys torch.Size([10, 100, 2])
ys torch.Size([10, 100, 2])
x torch.Size([1, 10, 100, 2])
F1 torch.Size([1, 10, 100, 1]) F2 torch.Size([1, 10, 100, 1]) y torch.Size([1, 10, 100, 2])
eye_tensor torch.Size([1, 10, 100, 2, 2]) dF1 torch.Size([1, 10, 100, 2])


In [13]:

from hnn.simulation.mve_ensemble.mve_ensemble import DEFAULT_TRAJECTORY_ARGS, DEFAULT_ODE_ARGS

ivp = integrate_model(
        model,
        t_span=DEFAULT_TRAJECTORY_ARGS['t_span'],
        y0=DEFAULT_ODE_ARGS['y0'],
        timescale=30,
        rtol=1e-12
)

{'y0': tensor([[ 0.3513,  0.5774],
        [-0.7414, -0.1744],
        [ 0.4781, -0.2955],
        [ 0.0400,  0.0929],
        [-0.7541,  0.4436],
        [ 0.5900,  0.2341],
        [ 0.0738,  2.2308],
        [ 0.9430, -0.0165],
        [-0.7673, -1.6551],
        [-1.7867, -0.1605]])} {'t_span': (0, 11)}
x torch.Size([10, 2])


ValueError: not enough values to unpack (expected 4, got 2)

In [18]:
from torchdiffeq import odeint

def get_timepoints(t_span: tuple[int, int], timescale: int = 30) -> torch.Tensor:
    return torch.linspace(
        t_span[0], t_span[1], int(timescale * (t_span[1] - t_span[0]))
    )

def integrate_model(model, t_span: tuple[int, int], y0: torch.tensor, timescale: int = 30, **kwargs):
    # x torch.Size([10, 2])
    print("t_span", t_span, "y0", y0.shape)
    def fun(t, x):
        print("XXXXX", x.shape)
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        _x = x.clone().detach().requires_grad_()
        dx = model.time_derivative(_x).data
        return dx

    t = get_timepoints(t_span, timescale)
    return odeint(fun, t=t, y0=y0, **kwargs)

ivp = integrate_model(
        model,
        t_span=DEFAULT_TRAJECTORY_ARGS['t_span'],
        y0=DEFAULT_ODE_ARGS['y0'],
        timescale=30,
        rtol=1e-12
)

t_span (0, 11) y0 torch.Size([10, 2])
XXXXX torch.Size([10, 2])
x torch.Size([10, 2])


ValueError: not enough values to unpack (expected 4, got 2)

In [None]:
fig = plt.figure(figsize=(11.3, 3.2), facecolor='white', dpi=300)

# plot data
fig.add_subplot(1, 4, 1, frameon=True)
x, y, dx, dy, t = hamiltonian.get_trajectory({})
N = x.shape[1]
point_colors = [(i/N, 0, 1-i/N) for i in range(N)]
plt.scatter(x, y, s=14, label='data', c=point_colors)
plt.quiver(
        field.x[:,:,0],
        field.x[:,:,1],
        field.dx[:,0],
        field.dx[:,1],
        cmap='gray_r',
        scale=ARROW_SCALE,
        width=ARROW_WIDTH,
        color=(.2,.2,.2)
)  
plt.xlabel("$q$", fontsize=14)
plt.ylabel("$p$", rotation=0, fontsize=14)
plt.title("Data", pad=10)

# plot HNN-modeled data
fig.add_subplot(1, 4, 2, frameon=True)
plt.quiver(
        field.x[:,:,0],
        field.x[:,:,1],
        vector_field[:,:,0],
        vector_field[:,:,1],
        cmap='gray_r',
        scale=ARROW_SCALE,
        width=ARROW_WIDTH,
        color=(.5,.5,.5)
)

for i, l in enumerate(torch.tensor_split(ivp, LINE_SEGMENTS)):
        color = (float(i)/LINE_SEGMENTS, 0, 1-float(i)/LINE_SEGMENTS)
        plt.plot(l[:, 0], l[:, 1], color=color, linewidth=LINE_WIDTH)

plt.xlabel("$q$", fontsize=14)
plt.ylabel("$p$", rotation=0, fontsize=14)
plt.title("Hamiltonian NN", pad=10)
plt.tight_layout() ; plt.show()