In [1]:
from IPython.display import display, HTML, Image
import matplotlib
import matplotlib.pyplot as plt
import math
import torch
import torch.nn.functional as F
import sys  
import os
from functools import partial

sys.path.insert(0, '../../..')
sys.path.insert(1, '../../../logs')

from dynnn.simulation.mve_ensemble import MveEnsembleMechanics
from dynnn.simulation.mve_ensemble.viz import visualize_trajectory, plot_energy
from dynnn.simulation.mve_ensemble.mve_ensemble import energy_conservation_loss, calc_kinetic_energy, get_initial_conditions, calc_total_energy_per_cell
from dynnn.layers import TaskModel
from dynnn.types import Dataset, PinnTrainingArgs, SimulatorTrainingArgs, SimulatorArgs
from dynnn.utils import load_model

In [2]:
# Pick the compute device once; later tensor allocations default to it.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Toggle for the optional diagnostic plots / animations in later cells.
verbose = False
torch.set_default_device(DEVICE)

In [3]:
# Render matplotlib figures inline in the notebook output (IPython magic, not plain Python).
%matplotlib inline

# Shared simulator instance used by the plotting helpers, `transform_y`, and the cells below.
mechanics = MveEnsembleMechanics()

def plot_energy_from_coords(r, v, time, masses):
    """
    Plot potential, kinetic and total energy of a trajectory over time.

    Potential energy comes from `mechanics.no_bc_potential_fn(r)` and kinetic
    energy from `calc_kinetic_energy(v, masses)`. Both are detached and moved
    to the CPU before being handed to `plot_energy`, together with `time`.
    """
    def _to_host(tensor):
        # Drop autograd history and bring the values to host memory for plotting.
        return tensor.detach().cpu()

    potential = _to_host(mechanics.no_bc_potential_fn(r))
    kinetic = _to_host(calc_kinetic_energy(v, masses))
    plot_energy(potential, kinetic, potential + kinetic, time.cpu())

def plot_loss(stats: dict, zoom_length: int = 500, key_sets: list[list[str]] | None = None):
    """
    Plot training curves: full history (left) and the last `zoom_length` steps (right).

    One row of panels is drawn per entry in `key_sets`; every series in an entry
    shares that row's log-scaled y axis. The figure is displayed (replacing the
    cell's previous output) and then closed, so calling this repeatedly as a
    training callback does not accumulate open figures.

    Args:
        stats: Loss-history container. Dicts are indexed by key; any other object
            is read with `getattr` (the training callback presumably passes an
            attribute-style stats object — TODO confirm).
        zoom_length: Number of trailing steps shown in the right-hand panel.
        key_sets: Groups of series names, one row of panels per group. Defaults
            to a single row containing every key of `stats`.
    """
    def _series(key):
        # Accept both mapping-style and attribute-style stats containers.
        return stats[key] if isinstance(stats, dict) else getattr(stats, key)

    if key_sets is None:
        all_keys = stats.keys() if hasattr(stats, "keys") else vars(stats).keys()
        key_sets = [list(all_keys)]
    if not key_sets:
        return

    # Always build a fresh figure; squeeze=False keeps `axes` 2-D even for a single row.
    n_rows = len(key_sets)
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4 * n_rows), squeeze=False)

    for row_axes, keys in zip(axes, key_sets):
        panels = (
            (row_axes[0], "Loss", slice(None)),
            (row_axes[1], f"Last {zoom_length} steps loss", slice(-zoom_length, None)),
        )
        for ax, title, window in panels:
            ax.set_title(title)
            ax.set_xlabel("time")
            for key in keys:
                ax.plot(_series(key)[window], label=key)
            ax.legend(fontsize=8)
            ax.set_yscale("log")

    fig.tight_layout()
    display(fig, clear=True)
    # Without this, every callback leaks a figure and the inline backend re-renders
    # all of them when the cell finishes.
    plt.close(fig)

In [4]:
# Optional sanity check: simulate one default trajectory, plot its energy budget,
# then render and save an animation of the bodies. `masses` bound here is also
# read by the next cell's verbose branch.
if verbose:
    trajectory = mechanics.get_trajectory({})
    r, v, dr, dv, time, masses = trajectory.dict().values()
    plot_energy_from_coords(r, v, time, masses)

    domain_bounds = (mechanics.domain_min, mechanics.domain_max)
    ani = visualize_trajectory(r.detach().cpu(), len(time), domain_bounds)
    display(HTML(ani.to_jshtml()))

    gif_path = sys.path[0] + '/../images/mve_ensemble.gif'
    ani.save(gif_path, writer='pillow')
    display(Image(filename=gif_path))

In [5]:
# Generate a small training dataset: 2 samples of a 3-body ensemble.
data, _ = mechanics.get_dataset({"n_samples": 2}, { "n_bodies": 3, "time_scale": 3, "t_span_max": 10 })

if verbose:
    # Unpacking into two names means dim -2 of each sample has size 2; the two
    # slices are presumably positions and velocities.
    data_r, data_v = [chunk.squeeze(-2) for chunk in torch.split(data["x"][0], 1, dim=-2)]
    # FIXME: `masses` is left over from the previous cell's standalone trajectory,
    # not the masses used to generate `data`; these energies are only meaningful
    # if the two happen to match.
    plot_energy_from_coords(data_r, data_v, data["time"], masses)
    # Use the same domain-bounds tuple and host transfer as the other animation
    # cells, so this also works when the default device is CUDA.
    ani = visualize_trajectory(
        data_r.detach().cpu(), len(data["time"]), (mechanics.domain_min, mechanics.domain_max)
    )
    display(HTML(ani.to_jshtml()))

In [6]:
def loss_fn(dxdt, dxdt_hat, s, masses, dt: float = 0.01, log_energy: bool = False):
    """
    Training loss: MSE on the predicted time derivative plus an energy-conservation penalty.

    The penalty is `energy_conservation_loss` between the state `s` and its
    explicit-Euler update `s + dxdt_hat * dt`, summed to a scalar.

    Args:
        dxdt: Target time derivative.
        dxdt_hat: Model-predicted time derivative (same shape as `dxdt`).
        s: State the derivative applies to.
        masses: Forwarded unchanged to `energy_conservation_loss`.
        dt: Euler step size used for the penalty (previously hard-coded 0.01).
        log_energy: If True, print the penalty. Off by default because `.item()`
            forces a device synchronisation on every call and floods the output.

    Returns:
        (total_loss, energy_loss): the MSE (mean reduction) plus the summed
        penalty, and the summed penalty on its own.
    """
    mse = F.mse_loss(dxdt_hat, dxdt)
    energy_loss = energy_conservation_loss(s, s + dxdt_hat * dt, masses).sum()
    if log_energy:
        print("Energy Loss:", energy_loss.item())

    # Out-of-place add: avoids mutating a tensor that is part of the autograd graph.
    loss = mse + energy_loss
    return loss, energy_loss

def transform_y(q, p, masses):
    """
    Build the task target: total energy binned into the cells of the task grid.

    Delegates to `calc_total_energy_per_cell`, using the module-level
    `task_grid_resolution` (looked up at call time) and the simulator's
    `(domain_min, domain_max)` as the grid boundaries. `q` and `p` are
    presumably positions and momenta/velocities — TODO confirm against the caller.
    """
    domain_bounds = (mechanics.domain_min, mechanics.domain_max)
    return calc_total_energy_per_cell(
        q,
        p,
        masses,
        grid_resolution=task_grid_resolution,
        boundaries=domain_bounds,
    )

# Task target: energy per cell on a 4x4x4 spatial grid. `transform_y` reads
# `task_grid_resolution` at call time, so it must be bound before training starts.
task_grid_resolution = (4, 4, 4)
task_output_dim = math.prod(task_grid_resolution) # number of cells in grid
# Input width = product of `data.x` dims from axis 2 onward (presumably one
# flattened per-timestep state — TODO confirm the layout of `data.x`).
task_input_dim = math.prod(data.x.shape[2:])

# Live loss plot: one row for the main train/test losses and one for the
# "additional" losses (presumably the energy term returned by `loss_fn`).
plot_loss_cb = partial(plot_loss, key_sets=[["train_loss", "test_loss"], ["train_additional_loss", "test_additional_loss"]])
initial_args = SimulatorArgs(training_args=PinnTrainingArgs(loss_fn=loss_fn, plot_loss_callback=plot_loss_cb))

# Build the task model and train it on `data`, with `transform_y` producing the
# per-cell energy targets. `model` is reused by the rollout cell at the end.
model = TaskModel(input_dim=task_input_dim, output_dim=task_output_dim, initial_sim_args=initial_args)
model.do_train(SimulatorTrainingArgs(), data, transform_y)

2024-06-21 10:06:51,424 dynnn.utils - INFO:Data file mve_data-n_samples-2-n_bodies-3_generator_type-1_time_scale-1_t_span_max-21_odeint_rtol-2e-06_odeint_atol-9e-07_odeint_solver-1_odeint_order-1.pkl not found.
2024-06-21 10:06:51,424 dynnn.utils - INFO:Creating new data...
2024-06-21 10:06:53,730 dynnn.mechanics.base_mechanics - INFO:Data traj: 1 of 2
2024-06-21 10:06:53,731 dynnn.mechanics.base_mechanics - INFO:Data traj: 2 of 2
  return func(*args, **kwargs)
2024-06-21 10:06:53,759 dynnn.train.task - INFO:Training task model
  return func(*args, **kwargs)
2024-06-21 10:06:53,782 dynnn.train.task - INFO:OUTER TASK Step 0, train_loss 4.4662e-03, test_loss 4.9225e-03
2024-06-21 10:06:53,791 dynnn.train.task - INFO:OUTER TASK Step 1, train_loss 1.3968e-02, test_loss 1.0836e-02
2024-06-21 10:06:53,802 dynnn.train.task - INFO:OUTER TASK Step 2, train_loss 1.9774e-02, test_loss 2.6535e-03
2024-06-21 10:06:53,811 dynnn.train.task - INFO:OUTER TASK Step 3, train_loss 2.8591e-03, test_loss 1.

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:53,969 dynnn.train.task - INFO:OUTER TASK Step 19, train_loss 8.7473e-04, test_loss 8.8032e-04
2024-06-21 10:06:53,982 dynnn.train.task - INFO:OUTER TASK Step 20, train_loss 8.6272e-04, test_loss 8.5699e-04
2024-06-21 10:06:53,993 dynnn.train.task - INFO:OUTER TASK Step 21, train_loss 8.4612e-04, test_loss 8.3258e-04
2024-06-21 10:06:54,003 dynnn.train.task - INFO:OUTER TASK Step 22, train_loss 8.2627e-04, test_loss 8.0795e-04
2024-06-21 10:06:54,016 dynnn.train.task - INFO:OUTER TASK Step 23, train_loss 8.0452e-04, test_loss 7.8381e-04
2024-06-21 10:06:54,027 dynnn.train.task - INFO:OUTER TASK Step 24, train_loss 7.8207e-04, test_loss 7.6069e-04
2024-06-21 10:06:54,037 dynnn.train.task - INFO:OUTER TASK Step 25, train_loss 7.5987e-04, test_loss 7.3894e-04
2024-06-21 10:06:54,045 dynnn.train.task - INFO:OUTER TASK Step 26, train_loss 7.3858e-04, test_loss 7.1873e-04
2024-06-21 10:06:54,056 dynnn.train.task - INFO:OUTER TASK Step 27, train_loss 7.1858e-04, test_loss 7.0

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:54,176 dynnn.train.task - INFO:OUTER TASK Step 40, train_loss 5.7020e-04, test_loss 5.6485e-04
2024-06-21 10:06:54,185 dynnn.train.task - INFO:OUTER TASK Step 41, train_loss 5.6485e-04, test_loss 5.6005e-04
2024-06-21 10:06:54,195 dynnn.train.task - INFO:OUTER TASK Step 42, train_loss 5.6005e-04, test_loss 5.5576e-04
2024-06-21 10:06:54,205 dynnn.train.task - INFO:OUTER TASK Step 43, train_loss 5.5576e-04, test_loss 5.5193e-04
2024-06-21 10:06:54,214 dynnn.train.task - INFO:OUTER TASK Step 44, train_loss 5.5193e-04, test_loss 5.4851e-04
2024-06-21 10:06:54,224 dynnn.train.task - INFO:OUTER TASK Step 45, train_loss 5.4851e-04, test_loss 5.4546e-04
2024-06-21 10:06:54,232 dynnn.train.task - INFO:OUTER TASK Step 46, train_loss 5.4546e-04, test_loss 5.4274e-04
2024-06-21 10:06:54,240 dynnn.train.task - INFO:OUTER TASK Step 47, train_loss 5.4274e-04, test_loss 5.4031e-04
2024-06-21 10:06:54,250 dynnn.train.task - INFO:OUTER TASK Step 48, train_loss 5.4031e-04, test_loss 5.3

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:54,400 dynnn.train.task - INFO:OUTER TASK Step 53, train_loss 5.3146e-04, test_loss 5.3017e-04
2024-06-21 10:06:54,410 dynnn.train.task - INFO:OUTER TASK Step 54, train_loss 5.3017e-04, test_loss 5.2899e-04
2024-06-21 10:06:54,426 dynnn.train.task - INFO:OUTER TASK Step 55, train_loss 5.2899e-04, test_loss 5.2790e-04
2024-06-21 10:06:54,439 dynnn.train.task - INFO:OUTER TASK Step 56, train_loss 5.2790e-04, test_loss 5.2689e-04
2024-06-21 10:06:54,452 dynnn.train.task - INFO:OUTER TASK Step 57, train_loss 5.2689e-04, test_loss 5.2596e-04
2024-06-21 10:06:54,463 dynnn.train.task - INFO:OUTER TASK Step 58, train_loss 5.2596e-04, test_loss 5.2508e-04
2024-06-21 10:06:54,472 dynnn.train.task - INFO:OUTER TASK Step 59, train_loss 5.2508e-04, test_loss 5.2426e-04
2024-06-21 10:06:54,482 dynnn.train.task - INFO:OUTER TASK Step 60, train_loss 5.2426e-04, test_loss 5.2348e-04
2024-06-21 10:06:54,492 dynnn.train.task - INFO:OUTER TASK Step 61, train_loss 5.2348e-04, test_loss 5.2

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:54,603 dynnn.train.task - INFO:OUTER TASK Step 73, train_loss 5.1691e-04, test_loss 5.1656e-04
2024-06-21 10:06:54,613 dynnn.train.task - INFO:OUTER TASK Step 74, train_loss 5.1656e-04, test_loss 5.1624e-04
2024-06-21 10:06:54,624 dynnn.train.task - INFO:OUTER TASK Step 75, train_loss 5.1624e-04, test_loss 5.1594e-04
2024-06-21 10:06:54,634 dynnn.train.task - INFO:OUTER TASK Step 76, train_loss 5.1594e-04, test_loss 5.1567e-04
2024-06-21 10:06:54,644 dynnn.train.task - INFO:OUTER TASK Step 77, train_loss 5.1567e-04, test_loss 5.1542e-04
2024-06-21 10:06:54,654 dynnn.train.task - INFO:OUTER TASK Step 78, train_loss 5.1542e-04, test_loss 5.1519e-04
2024-06-21 10:06:54,663 dynnn.train.task - INFO:OUTER TASK Step 79, train_loss 5.1519e-04, test_loss 5.1499e-04
2024-06-21 10:06:54,673 dynnn.train.task - INFO:OUTER TASK Step 80, train_loss 5.1499e-04, test_loss 5.1480e-04
2024-06-21 10:06:54,682 dynnn.train.task - INFO:OUTER TASK Step 81, train_loss 5.1480e-04, test_loss 5.1

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:54,808 dynnn.train.task - INFO:OUTER TASK Step 94, train_loss 5.1366e-04, test_loss 5.1363e-04
2024-06-21 10:06:54,817 dynnn.train.task - INFO:OUTER TASK Step 95, train_loss 5.1363e-04, test_loss 5.1361e-04
2024-06-21 10:06:54,828 dynnn.train.task - INFO:OUTER TASK Step 96, train_loss 5.1361e-04, test_loss 5.1359e-04
2024-06-21 10:06:54,837 dynnn.train.task - INFO:OUTER TASK Step 97, train_loss 5.1359e-04, test_loss 5.1357e-04
2024-06-21 10:06:54,846 dynnn.train.task - INFO:OUTER TASK Step 98, train_loss 5.1357e-04, test_loss 5.1356e-04
2024-06-21 10:06:54,855 dynnn.train.task - INFO:OUTER TASK Step 99, train_loss 5.1356e-04, test_loss 5.1355e-04
2024-06-21 10:06:54,864 dynnn.train.task - INFO:OUTER TASK Step 100, train_loss 5.1355e-04, test_loss 5.1354e-04
2024-06-21 10:06:54,873 dynnn.train.task - INFO:OUTER TASK Step 101, train_loss 5.1354e-04, test_loss 5.1353e-04
2024-06-21 10:06:54,883 dynnn.train.task - INFO:OUTER TASK Step 102, train_loss 5.1353e-04, test_loss 

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size(

2024-06-21 10:06:55,016 dynnn.train.task - INFO:OUTER TASK Step 117, train_loss 5.1350e-04, test_loss 5.1350e-04
2024-06-21 10:06:55,026 dynnn.train.task - INFO:OUTER TASK Step 118, train_loss 5.1350e-04, test_loss 5.1350e-04
2024-06-21 10:06:55,036 dynnn.train.task - INFO:OUTER TASK Step 119, train_loss 5.1350e-04, test_loss 5.1350e-04
2024-06-21 10:06:55,045 dynnn.train.task - INFO:OUTER TASK Step 120, train_loss 5.1350e-04, test_loss 5.1349e-04
2024-06-21 10:06:55,056 dynnn.train.task - INFO:OUTER TASK Step 121, train_loss 5.1349e-04, test_loss 5.1349e-04
2024-06-21 10:06:55,066 dynnn.train.task - INFO:OUTER TASK Step 122, train_loss 5.1349e-04, test_loss 5.1349e-04
2024-06-21 10:06:55,076 dynnn.train.task - INFO:OUTER TASK Step 123, train_loss 5.1349e-04, test_loss 5.1349e-04
2024-06-21 10:06:55,086 dynnn.train.task - INFO:OUTER TASK Step 124, train_loss 5.1349e-04, test_loss 5.1349e-04
2024-06-21 10:06:55,095 dynnn.train.task - INFO:OUTER TASK Step 125, train_loss 5.1349e-04, test

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:55,227 dynnn.train.task - INFO:OUTER TASK Step 139, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,238 dynnn.train.task - INFO:OUTER TASK Step 140, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,306 dynnn.train.task - INFO:OUTER TASK Step 141, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,365 dynnn.train.task - INFO:OUTER TASK Step 142, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,376 dynnn.train.task - INFO:OUTER TASK Step 143, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,391 dynnn.train.task - INFO:OUTER TASK Step 144, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,402 dynnn.train.task - INFO:OUTER TASK Step 145, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,412 dynnn.train.task - INFO:OUTER TASK Step 146, train_loss 5.1348e-04, test_loss 5.1348e-04


H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


2024-06-21 10:06:55,425 dynnn.train.task - INFO:OUTER TASK Step 147, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,434 dynnn.train.task - INFO:OUTER TASK Step 148, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,446 dynnn.train.task - INFO:OUTER TASK Step 149, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,458 dynnn.train.task - INFO:OUTER TASK Step 150, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,468 dynnn.train.task - INFO:OUTER TASK Step 151, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,477 dynnn.train.task - INFO:OUTER TASK Step 152, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,486 dynnn.train.task - INFO:OUTER TASK Step 153, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,494 dynnn.train.task - INFO:OUTER TASK Step 154, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,503 dynnn.train.task - INFO:OUTER TASK Step 155, train_loss 5.1348e-04, test

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size(

2024-06-21 10:06:55,636 dynnn.train.task - INFO:OUTER TASK Step 170, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,645 dynnn.train.task - INFO:OUTER TASK Step 171, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,656 dynnn.train.task - INFO:OUTER TASK Step 172, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,667 dynnn.train.task - INFO:OUTER TASK Step 173, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,676 dynnn.train.task - INFO:OUTER TASK Step 174, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,686 dynnn.train.task - INFO:OUTER TASK Step 175, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,697 dynnn.train.task - INFO:OUTER TASK Step 176, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,706 dynnn.train.task - INFO:OUTER TASK Step 177, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,716 dynnn.train.task - INFO:OUTER TASK Step 178, train_loss 5.1348e-04, test

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size(

2024-06-21 10:06:55,838 dynnn.train.task - INFO:OUTER TASK Step 191, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,848 dynnn.train.task - INFO:OUTER TASK Step 192, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,859 dynnn.train.task - INFO:OUTER TASK Step 193, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,869 dynnn.train.task - INFO:OUTER TASK Step 194, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,878 dynnn.train.task - INFO:OUTER TASK Step 195, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,887 dynnn.train.task - INFO:OUTER TASK Step 196, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,897 dynnn.train.task - INFO:OUTER TASK Step 197, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,907 dynnn.train.task - INFO:OUTER TASK Step 198, train_loss 5.1348e-04, test_loss 5.1348e-04
2024-06-21 10:06:55,916 dynnn.train.task - INFO:OUTER TASK Step 199, train_loss 5.1348e-04, test

H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size(

2024-06-21 10:06:56,076 dynnn.train.task - INFO:OUTER TASK Step 220, train_loss 5.1348e-04, test_loss 5.1348e-04


H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])
H torch.Size([30, 64]) torch.Size([30, 1, 64])


KeyboardInterrupt: 

In [None]:
# plot model output
# Roll out `model` from fresh initial conditions and visualise its energy budget
# and motion. If training was interrupted, `model` may be only partially trained.

test_y0, test_masses = get_initial_conditions(5)
# Gradient tracking on the initial state is presumably needed by the model's
# rollout (e.g. derivatives w.r.t. the state) — TODO confirm.
initial_state = test_y0.clone().detach().requires_grad_()
model_traj = mechanics.get_trajectory({"y0": initial_state, "masses": test_masses, "model": model})
# Take the first five fields positionally: the earlier standalone-trajectory cell
# unpacks six fields (including masses) from the same method, which would break a
# plain 5-way unpack.
r, v, dr, dv, time = list(model_traj.dict().values())[:5]
plot_energy_from_coords(r, v, time, test_masses)

# Drop autograd history and move to host memory before animating, as in the
# standalone-trajectory cell.
ani_model = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
HTML(ani_model.to_jshtml())