In [1]:
import os
import sys
from collections.abc import Mapping
from functools import partial

import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from IPython.display import display, HTML, Image

sys.path.insert(0, '../../..')
sys.path.insert(1, '../../../logs')

from dynnn.simulation.mve_ensemble import MveEnsembleMechanics
from dynnn.simulation.mve_ensemble.viz import visualize_trajectory, plot_energy
from dynnn.simulation.mve_ensemble.mve_ensemble import energy_conservation_loss, calc_kinetic_energy, get_initial_conditions, calc_total_energy_per_cell
from dynnn.layers import TaskModel
from dynnn.types import Dataset, PinnTrainingArgs, SimulatorTrainingArgs, SimulatorArgs
from dynnn.utils import load_model

In [2]:
# Device selection: CUDA when present, otherwise the CPU. Apple MPS is not
# auto-selected; PyTorch's MPS backend lacks `aten::_cdist_backward`, which
# dataset generation hit (see the NotImplementedError recorded in the dataset cell).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set to True to render the optional diagnostic plots/animations in later cells.
verbose = False

# Tensors created without an explicit device are allocated on DEVICE.
torch.set_default_device(DEVICE)

In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Shared simulator instance used by the helpers and cells below (potential
# energy, trajectories, dataset generation and the domain bounds).
mechanics = MveEnsembleMechanics()

def plot_energy_from_coords(r, v, time, masses):
    """Plot potential, kinetic and total energy along a trajectory.

    Args:
        r: Tensor passed to ``mechanics.no_bc_potential_fn`` to obtain the
            potential energy.
        v: Tensor passed (with ``masses``) to ``calc_kinetic_energy`` to obtain
            the kinetic energy.
        time: Time tensor used as the x-axis; moved to CPU before plotting.
        masses: Masses forwarded to ``calc_kinetic_energy``.
    """
    def _to_host(tensor):
        # Drop autograd history and move to CPU so the plotting code can use it.
        return tensor.detach().cpu()

    potential = _to_host(mechanics.no_bc_potential_fn(r))
    kinetic = _to_host(calc_kinetic_energy(v, masses))
    plot_energy(potential, kinetic, potential + kinetic, time.cpu())

def plot_loss(stats: dict, zoom_length: int = 500, key_sets: list[list[str]] | None = None):
    """Draw loss histories as rows of log-scale plots and display them.

    Each inner list of ``key_sets`` becomes one row with two panels: the full
    history of every named series (left) and only its last ``zoom_length``
    entries (right).

    Args:
        stats: Loss histories. Either a mapping (read with ``stats[key]``) or
            an object exposing them as attributes (``getattr(stats, key)``).
            Each series must be sliceable and plottable.
        zoom_length: Number of trailing steps shown in the right-hand panels.
        key_sets: Rows of series names. Defaults to a single row containing
            every key of ``stats``. An empty list draws nothing.
    """
    stats_is_mapping = isinstance(stats, Mapping)

    def _series(key):
        # Accept both dict-like and attribute-style stats containers.
        return stats[key] if stats_is_mapping else getattr(stats, key)

    if key_sets is None:
        all_keys = stats.keys() if stats_is_mapping else vars(stats).keys()
        key_sets = [list(all_keys)]

    if not key_sets:
        # plt.subplots cannot create a grid with zero rows.
        return

    n_rows = len(key_sets)
    # A fresh figure per call. squeeze=False keeps `axes` 2-D even for a
    # single row, so axes[row] always unpacks into two Axes.
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4 * n_rows), squeeze=False)

    for row, keys in enumerate(key_sets):
        full_ax, zoom_ax = axes[row]

        full_ax.set_title("Loss")
        full_ax.set_xlabel("time")
        for key in keys:
            full_ax.plot(_series(key), label=key)
        full_ax.legend(fontsize=8)
        full_ax.set_yscale("log")

        zoom_ax.set_title(f"Last {zoom_length} steps loss")
        zoom_ax.set_xlabel("time")
        for key in keys:
            zoom_ax.plot(_series(key)[-zoom_length:], label=key)
        zoom_ax.legend(fontsize=8)
        zoom_ax.set_yscale("log")

    fig.tight_layout()
    display(fig, clear=True)
    # Close the figure after displaying it. Otherwise every call (e.g. once per
    # training step/epoch) leaves an open figure behind, growing memory and
    # letting the inline backend re-render them all when the cell finishes.
    plt.close(fig)

In [4]:
if verbose:
    # Simulate one trajectory with default settings.
    # NOTE(review): the unpacking relies on the field order of `.dict()`.
    # `masses` defined here is also reused by the dataset cell's energy plot.
    r, v, dr, dv, time, masses = mechanics.get_trajectory({}).dict().values()
    plot_energy_from_coords(r, v, time, masses)

    ani = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
    display(HTML(ani.to_jshtml()))

    # sys.path[0] is presumably still the '../../..' entry inserted in the
    # imports cell; build the GIF path once instead of concatenating twice.
    gif_path = os.path.join(sys.path[0], '..', 'images', 'mve_ensemble.gif')
    # ani.save does not create missing directories, so make sure it exists.
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    ani.save(gif_path, writer='pillow')
    display(Image(filename=gif_path))

In [5]:
def transform_y(q, p, masses):
    """Target transform: total energy per cell on a 4x4x4 grid, averaged over dim 0.

    Delegates to ``calc_total_energy_per_cell`` using the mechanics domain
    ``(domain_min, domain_max)`` as the grid boundaries. Dim 0 is presumably
    the time axis (TODO confirm).
    """
    domain_bounds = (mechanics.domain_min, mechanics.domain_max)
    energy_per_cell = calc_total_energy_per_cell(
        q,
        p,
        masses,
        grid_resolution=(4, 4, 4),
        boundaries=domain_bounds,
    )
    return energy_per_cell.mean(dim=0)

data, _ = mechanics.get_dataset({"n_samples": 5}, { "n_bodies": 10, "time_scale": 3, "t_span_max": 30 })

if verbose:
    # data["x"][0] holds two state components along dim -2; split them apart
    # (presumably positions and velocities).
    data_r, data_v = [part.squeeze(-2) for part in torch.split(data["x"][0], 1, dim=-2)]
    # NOTE(review): `masses` is not defined in this cell - it leaks from the
    # verbose trajectory cell (a different, default-configured run). Confirm it
    # matches this dataset's bodies, or take masses from the dataset instead.
    plot_energy_from_coords(data_r, data_v, data["time"], masses)
    # Use the same (min, max) domain bounds as the other visualizations, and
    # detach/move to CPU so the animation can be drawn from any device.
    ani = visualize_trajectory(
        data_r.detach().cpu(),
        len(data["time"]),
        (mechanics.domain_min, mechanics.domain_max),
    )
    display(HTML(ani.to_jshtml()))

2024-06-19 18:33:56,182 dynnn.utils - INFO:Data file mve_data-n_samples-5-n_bodies-10_time_scale-3_t_span_max-30_odeint_rtol-1e-10_odeint_atol-1e-06_odeint_order-2.pkl not found.
2024-06-19 18:33:56,182 dynnn.utils - INFO:Creating new data...


NotImplementedError: The operator 'aten::_cdist_backward' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

In [None]:
def loss_fn(dxdt, dxdt_hat, s, masses, dt: float = 0.01):
    """
    Calculate the training loss: derivative MSE plus an energy-conservation term.

    Args:
        dxdt: Target time derivatives.
        dxdt_hat: Predicted time derivatives (compared to ``dxdt`` by MSE).
        s: Current state.
        masses: Masses forwarded to ``energy_conservation_loss``.
        dt: Step size of the explicit Euler step ``s + dxdt_hat * dt`` used to
            build the next state for the energy term. Defaults to 0.01.

    Returns:
        Tuple ``(total_loss, energy_loss)``, where ``total_loss`` is the MSE
        plus ``energy_loss``. ``energy_loss`` is returned separately, presumably
        so the trainer can log it as the "additional" loss.
    """
    derivative_loss = F.mse_loss(dxdt, dxdt_hat)

    # Advance the state one Euler step with the predicted derivatives; the
    # energy term compares s with this stepped state (summed to a scalar).
    next_state = s + dxdt_hat * dt
    energy_loss = energy_conservation_loss(s, next_state, masses).sum()

    # Out-of-place sum: avoids mutating the mse_loss output in place.
    return derivative_loss + energy_loss, energy_loss


# Live loss plot: row 1 shows the total train/test loss, row 2 the "additional"
# losses (presumably the energy term returned by loss_fn - confirm).
plot_loss_cb = partial(
    plot_loss,
    key_sets=[
        ["train_loss", "test_loss"],
        ["train_additional_loss", "test_additional_loss"],
    ],
)
initial_args = SimulatorArgs(
    training_args=PinnTrainingArgs(
        loss_fn=loss_fn,
        plot_loss_callback=plot_loss_cb,
    ),
)
model = TaskModel(initial_sim_args=initial_args)
# Train with default SimulatorTrainingArgs on the generated dataset, passing
# transform_y (presumably applied to the targets).
model.train(SimulatorTrainingArgs(), data, transform_y)

In [None]:
# plot model output

# NOTE(review): 5 is presumably the number of bodies; the training data used
# n_bodies=10 - confirm the mismatch is intentional.
test_y0, test_masses = get_initial_conditions(5)
initial_state = test_y0.clone().detach().requires_grad_()
# The verbose trajectory cell unpacks six fields from the same call (including
# masses); `*_` ignores any trailing fields instead of raising
# "too many values to unpack".
r, v, dr, dv, time, *_ = mechanics.get_trajectory({"y0": initial_state, "masses": test_masses, "model": model}).dict().values()
plot_energy_from_coords(r, v, time, test_masses)

# The rollout starts from a requires_grad state and may live on DEVICE, so
# detach and move to CPU before animating (as the verbose trajectory cell does).
ani_model = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
HTML(ani_model.to_jshtml())