In [1]:
from IPython.display import display, HTML, Image
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import sys  
import os
from functools import partial

sys.path.insert(0, '../../..')

from dynnn.simulation.mve_ensemble import MveEnsembleMechanics
from dynnn.simulation.mve_ensemble.viz import visualize_trajectory, plot_energy
from dynnn.simulation.mve_ensemble.mve_ensemble import energy_conservation_loss, calc_kinetic_energy, get_initial_conditions
from dynnn.train.simulator.train_simulator import train_simulator as train
from dynnn.utils import load_model




ModuleNotFoundError: No module named 'dynnn.train.train_simulator'

In [None]:
import argparse

def get_args():
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--device', default="cpu", type=str, help="device to run on")
    parser.add_argument('--rl_learn_rate', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--rl_weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--num_experiments', default=100, type=int, help='number of RL param switch experiments')
    parser.add_argument('--max_simulator_steps', default=1000, type=int, help='max steps within an experiment')
    parser.add_argument('--verbose', default=False, type=bool, help='is notebook verbose? shows extra stuff that takes time to compute.')
    parser.set_defaults(feature=True)
    return parser.parse_known_args()[0]

# Resolve the run options once. From here on, every tensor created without an
# explicit device is allocated on `args.device`.
args = get_args()
torch.set_default_device(device=args.device)

In [None]:
mechanics = MveEnsembleMechanics()


def plot_energy_from_coords(r, v, time, masses):
    """Plot potential, kinetic and total energy along a trajectory.

    Args:
        r: positions, evaluated with the module-level `mechanics.no_bc_potential_fn`.
        v: velocities, evaluated with `calc_kinetic_energy` together with `masses`.
        time: time stamps (tensor), moved to CPU for plotting.
        masses: particle masses.
    """
    # Detach and move to CPU so plotting never touches autograd graphs or device memory.
    potential = mechanics.no_bc_potential_fn(r).detach().cpu()
    kinetic = calc_kinetic_energy(v, masses).detach().cpu()
    plot_energy(potential, kinetic, potential + kinetic, time.cpu())

In [None]:
# NOTE(review): the argument 5 is presumably the number of bodies — confirm in get_initial_conditions.
y0, masses = get_initial_conditions(5)

# Optional sanity check (only with --verbose): simulate the ground-truth system
# from y0, plot its energy curves and show an animation of the motion.
if args.verbose:
    trajectory = mechanics.get_trajectory({"y0": y0, "masses": masses})
    r, v, dr, dv, time = trajectory.dict().values()
    plot_energy_from_coords(r, v, time, masses)
    domain_bounds = (mechanics.domain_min, mechanics.domain_max)
    ani = visualize_trajectory(r.detach().cpu(), len(time), domain_bounds)
    display(HTML(ani.to_jshtml()))

In [None]:
# Save the ground-truth animation from the previous cell as a GIF and show it inline.
if args.verbose:
    # sys.path[0] is presumably still the '../../..' entry inserted in the imports cell;
    # build the target path once instead of repeating the string concatenation.
    gif_path = os.path.join(sys.path[0], "..", "images", "mve_ensemble.gif")
    # ani.save raises if the target directory does not exist, so create it first.
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    ani.save(gif_path, writer='pillow')
    display(Image(filename=gif_path))

In [None]:
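# Build the training dataset used by the verbose-only visualization in the next cell.
# Left commented out by default (uncomment it to run); without it the next cell has no `data` to plot.
# data, _ = mechanics.get_dataset({}, { "y0": y0, "masses": masses })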

In [None]:
# Visualize the generated dataset (presumably its first sample) — only with --verbose.
if args.verbose:
    if "data" not in globals():
        # The dataset cell above is commented out by default; skip instead of raising NameError.
        print("Skipping dataset visualization: `data` is undefined (build it with mechanics.get_dataset first).")
    else:
        # data["x"][0] holds positions and velocities stacked along dim -2
        # (split into exactly two 1-wide chunks); drop that axis after splitting.
        data_r, data_v = [part.squeeze(-2) for part in torch.split(data["x"][0], 1, dim=-2)]
        plot_energy_from_coords(data_r, data_v, data["time"], masses)
        # Same domain bounds and detach/CPU handling as the ground-truth animation cell;
        # a separate name avoids clobbering `ani`, which the GIF-saving cell relies on.
        ani_data = visualize_trajectory(
            data_r.detach().cpu(),
            len(data["time"]),
            (mechanics.domain_min, mechanics.domain_max),
        )
        display(HTML(ani_data.to_jshtml()))

In [None]:
# load or train model
%matplotlib inline


# Figures reused across plot_loss calls, keyed by number of subplot rows, so a training
# callback that fires every step redraws one figure instead of opening a new one each time.
_LOSS_FIGURES = {}


def plot_loss(stats: dict, zoom_length: int = 500, key_sets: list[list[str]] | None = None):
    """Plot loss curves (full history and the last `zoom_length` steps) and (re)display them.

    Args:
        stats: loss-history container. Series are read by key when it is a mapping,
            otherwise by attribute (e.g. a stats object with `train_loss` fields).
        zoom_length: number of trailing steps shown in the right-hand column.
        key_sets: one list of series names per row of subplots; defaults to a single
            row containing every series in `stats`.
    """
    from collections.abc import Mapping

    def _series(key):
        # Support both dict-like stats and attribute-style stats objects.
        return stats[key] if isinstance(stats, Mapping) else getattr(stats, key)

    if key_sets is None:
        key_sets = [list(stats.keys()) if isinstance(stats, Mapping) else list(vars(stats))]

    n_rows = len(key_sets)
    if n_rows not in _LOSS_FIGURES:
        # squeeze=False keeps `axes` 2-D even for a single row, so axes[row][col] always works.
        fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4), squeeze=False)
        # Close right away so the inline backend does not render an extra, stale copy at
        # the end of the cell; display(fig) below still renders the figure.
        plt.close(fig)
        _LOSS_FIGURES[n_rows] = (fig, axes)
    fig, axes = _LOSS_FIGURES[n_rows]

    # Left column: whole history; right column: only the last `zoom_length` steps.
    panels = (
        ("Loss", lambda series: series),
        (f"Last {zoom_length} steps loss", lambda series: series[-zoom_length:]),
    )
    for row, keys in enumerate(key_sets):
        for col, (title, select) in enumerate(panels):
            ax = axes[row][col]
            ax.clear()
            ax.set_title(title)
            ax.set_xlabel("time")
            for key in keys:
                ax.plot(select(_series(key)), label=key)
            ax.legend(fontsize=8)
            ax.set_yscale("log")
    fig.tight_layout()
    display(fig, clear=True)

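# To reuse a saved model instead of training, call load_model with its path, e.g.:
# (note: get_args() defines no --model option, so args.model must be added before this works)
# model = load_model(args.model)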
def loss_fn(dxdt, dxdt_hat, s, masses, *, dt: float = 0.01):
    """
    Calculate the training loss: derivative MSE plus an energy-conservation penalty.

    Args:
        dxdt: target time derivatives of the state.
        dxdt_hat: model-predicted time derivatives.
        s: state at which the derivatives were evaluated.
        masses: particle masses, forwarded to `energy_conservation_loss`.
        dt: step size used to advance `s` by `dxdt_hat` before measuring energy drift
            (keyword-only; default 0.01 matches the previously hard-coded value).

    Returns:
        (loss, energy_loss): the total loss (MSE + energy term) and the summed energy term.
    """
    mse = F.mse_loss(dxdt, dxdt_hat)

    # One explicit Euler step with the predicted derivative (assumes dxdt_hat is d(s)/dt);
    # energy_conservation_loss compares the state before and after that step.
    next_state = s + dxdt_hat * dt
    energy_loss = energy_conservation_loss(s, next_state, masses).sum()

    # Out-of-place sum: no in-place mutation of the autograd output returned by mse_loss.
    loss = mse + energy_loss

    return loss, energy_loss

# Train the simulator with loss_fn. The callback draws train/test curves for the total
# loss (row 1) and for the second term loss_fn returns (row 2).
# NOTE(review): the training log prints the train-side second term as `additional_loss`;
# confirm the stats field plot_loss reads is really named `train_additional_loss`.
loss_plot_callback = partial(
    plot_loss,
    key_sets=[
        ["train_loss", "test_loss"],
        ["train_additional_loss", "test_additional_loss"],
    ],
)
model, stats = train(args, pinn_loss_fn=loss_fn, plot_loss_callback=loss_plot_callback)


Loading data from ../../../../data/mve_data-n_samples-2-n_bodies-2_time_scale-1_t_span_min-0_t_span_max-5.pkl
Data file mve_data-n_samples-2-n_bodies-2_time_scale-1_t_span_min-0_t_span_max-5.pkl not found.
Creating new data...
step 0, train_loss 4.3160e-01, additional_loss 3.5048e-02, test_loss 9.3974e-01, test_additional_loss 6.4161e-01
step 1, train_loss 1.5437e+00, additional_loss 1.1670e+00, test_loss 4.8948e-01, test_additional_loss 2.2893e-01
step 2, train_loss 6.8031e-01, additional_loss 3.2191e-01, test_loss 4.2185e-01, test_additional_loss 1.5199e-01
step 3, train_loss 6.5449e-01, additional_loss 2.9111e-01, test_loss 4.3366e-01, test_additional_loss 1.7331e-01
step 4, train_loss 3.8956e-01, additional_loss 3.3426e-02, test_loss 4.3167e-01, test_additional_loss 1.7467e-01
step 5, train_loss 4.4098e-01, additional_loss 8.9884e-02, test_loss 4.4701e-01, test_additional_loss 1.9743e-01
step 6, train_loss 3.6953e-01, additional_loss 2.5082e-02, test_loss 5.1056e-01, test_additiona

In [None]:
# plot model output

# Roll out the trained model from fresh initial conditions and inspect its energy and motion.
test_y0, test_masses = get_initial_conditions(5)
# The rollout starts from a fresh leaf tensor that requires grad — presumably the learned
# dynamics differentiate with respect to the state; confirm in mechanics.get_trajectory.
initial_state = test_y0.clone().detach().requires_grad_()
r, v, dr, dv, time = mechanics.get_trajectory({"y0": initial_state, "masses": test_masses, "model": model}).dict().values()
plot_energy_from_coords(r, v, time, test_masses)

# Detach and move to CPU before animating, as the ground-truth cell does: this trajectory
# starts from a requires-grad state and lives on the default device (args.device).
ani_model = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
HTML(ani_model.to_jshtml())