In [1]:
from IPython.display import display, HTML, Image
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import sys  
import os
from functools import partial

sys.path.insert(0, '../../..')

from dynnn.simulation.mve_ensemble import MveEnsembleMechanics
from dynnn.simulation.mve_ensemble.viz import visualize_trajectory, plot_energy
from dynnn.simulation.mve_ensemble.mve_ensemble import energy_conservation_loss, calc_kinetic_energy, get_initial_conditions
from dynnn.train.simulator.train_simulator import train_simulator as train
from dynnn.utils import load_model




In [2]:
import argparse

def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``1``/``0`` or ``yes``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (including ``"False"``), so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the notebook's configuration from ``sys.argv``.

    ``parse_known_args`` is used so that arguments this parser does not define
    (such as those the Jupyter kernel process was started with) are ignored
    rather than raising an error.

    Returns:
        argparse.Namespace with the options below, plus ``feature=True``.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument('--device', default="cpu", type=str, help="device to run on")
    parser.add_argument('--rl_learn_rate', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--rl_weight_decay', default=1e-5, type=float, help='weight decay')
    parser.add_argument('--num_experiments', default=100, type=int, help='number of RL param switch experiments')
    parser.add_argument('--max_simulator_steps', default=1000, type=int, help='max steps within an experiment')
    # `type=bool` would turn `--verbose False` into True; parse the text explicitly instead.
    parser.add_argument('--verbose', default=False, type=_str_to_bool, help='is notebook verbose? shows extra stuff that takes time to compute.')
    # Optional RNG seed; None (the default) leaves torch unseeded, matching previous behaviour.
    parser.add_argument('--seed', default=None, type=int, help='random seed for reproducible runs (default: unseeded)')
    # NOTE(review): `feature` is not read anywhere in this notebook; kept in case downstream code (e.g. train) reads it.
    parser.set_defaults(feature=True)
    return parser.parse_known_args()[0]

args = get_args()

# Tensors created from here on without an explicit device are placed on args.device.
torch.set_default_device(args.device)

# Seed torch's RNG when a seed was supplied, so randomly sampled inputs (presumably
# get_initial_conditions) are reproducible. getattr keeps this cell working with a
# parser that does not define --seed.
if getattr(args, "seed", None) is not None:
    torch.manual_seed(args.seed)

In [3]:
mechanics = MveEnsembleMechanics()


def plot_energy_from_coords(r, v, time, masses):
    """Plot the energy budget of a trajectory.

    Potential energy is ``mechanics.no_bc_potential_fn(r)`` and kinetic energy is
    ``calc_kinetic_energy(v, masses)``. Both are detached and moved to the CPU.
    They are passed to ``plot_energy`` together with their sum (the total energy)
    and ``time`` moved to the CPU.
    """
    potential = mechanics.no_bc_potential_fn(r).detach().cpu()
    kinetic = calc_kinetic_energy(v, masses).detach().cpu()
    plot_energy(potential, kinetic, potential + kinetic, time.cpu())

In [4]:
y0, masses = get_initial_conditions(5)

# Preview a ground-truth trajectory; only with --verbose because it takes time to compute.
if args.verbose:
    trajectory = mechanics.get_trajectory(dict(y0=y0, masses=masses)).dict()
    r, v, dr, dv, time, masses = trajectory.values()
    plot_energy_from_coords(r, v, time, masses)
    domain_bounds = (mechanics.domain_min, mechanics.domain_max)
    ani = visualize_trajectory(r.detach().cpu(), len(time), domain_bounds)
    display(HTML(ani.to_jshtml()))

In [5]:
if args.verbose:
    # Export the preview animation as a GIF and show it inline.
    # NOTE(review): sys.path[0] is the '../../..' entry inserted in the first cell and is
    # used as the repository anchor; this breaks if anything else prepends to sys.path.
    gif_path = os.path.join(sys.path[0], '..', 'images', 'mve_ensemble.gif')
    # Create the target directory so saving doesn't fail when images/ is missing.
    os.makedirs(os.path.dirname(gif_path), exist_ok=True)
    ani.save(gif_path, writer='pillow')
    display(Image(filename=gif_path))

In [6]:
# data, _ = mechanics.get_dataset({}, { "y0": y0, "masses": masses })

In [7]:
if args.verbose:
    # Build the dataset here instead of relying on the commented-out cell above, so this
    # cell works after Restart & Run All rather than failing with NameError on `data`.
    data, _ = mechanics.get_dataset({}, {"y0": y0, "masses": masses})
    # Split dim -2 of the first entry of data["x"] into two parts (positions, velocities).
    data_r, data_v = [part.squeeze(-2) for part in torch.split(data["x"][0], 1, dim=-2)]
    plot_energy_from_coords(data_r, data_v, data["time"], masses)
    # Same domain bounds and CPU/detached input as the ground-truth preview cell.
    ani = visualize_trajectory(
        data_r.detach().cpu(),
        len(data["time"]),
        (mechanics.domain_min, mechanics.domain_max),
    )
    display(HTML(ani.to_jshtml()))

In [8]:
# load or train model
%matplotlib inline


def plot_loss(stats: dict, zoom_length: int = 500, key_sets: list[list[str]] | None = None):
    """
    Plot loss curves: the full history and a zoom on the last ``zoom_length`` steps.

    Meant to be used as a live training callback. Each call draws a fresh figure,
    replaces the cell's previous output with it, then closes it so figures don't pile up.

    Args:
        stats: loss history; either a mapping ``{name: sequence}`` or an object that
            exposes each series as an attribute.
        zoom_length: number of trailing steps shown in the right-hand column.
        key_sets: one list of series names per figure row. Defaults to a single row
            holding every key of ``stats``.
    """
    from collections.abc import Mapping

    def series(key):
        # Accept plain dicts as well as attribute-style stats objects.
        return stats[key] if isinstance(stats, Mapping) else getattr(stats, key)

    if key_sets is None:
        key_sets = [list(stats.keys()) if hasattr(stats, "keys") else list(vars(stats))]

    if not key_sets:
        # plt.subplots rejects zero rows; nothing to draw.
        return

    # squeeze=False keeps `axes` 2-D even with a single row, so each row is always a pair.
    fig, axes = plt.subplots(len(key_sets), 2, figsize=(10, 4), squeeze=False)

    for row, keys in zip(axes, key_sets):
        full_ax, zoom_ax = row

        full_ax.set_title("Loss")
        full_ax.set_xlabel("time")
        for key in keys:
            full_ax.plot(series(key), label=key)
        full_ax.legend(fontsize=8)
        full_ax.set_yscale("log")

        zoom_ax.set_title(f"Last {zoom_length} steps loss")
        zoom_ax.set_xlabel("time")
        for key in keys:
            zoom_ax.plot(series(key)[-zoom_length:], label=key)
        zoom_ax.legend(fontsize=8)
        zoom_ax.set_yscale("log")

    fig.tight_layout()
    display(fig, clear=True)
    # Close so the inline backend doesn't render it again at cell end and figures don't accumulate.
    plt.close(fig)

# model = load_model(args.model)
def loss_fn(dxdt, dxdt_hat, s, masses, dt: float = 0.01):
    """
    Physics-informed loss: derivative regression (MSE) plus an energy-conservation penalty.

    Args:
        dxdt: target time-derivative of the state.
        dxdt_hat: model-predicted time-derivative, combined elementwise with ``s``.
        s: state at which the derivatives were evaluated.
        masses: forwarded unchanged to ``energy_conservation_loss``.
        dt: step size for the single explicit Euler step ``s + dxdt_hat * dt`` whose
            energy is compared with that of ``s``. Defaults to the previously
            hard-coded 0.01.

    Returns:
        ``(loss, energy_loss)``: the total loss (MSE + energy term) and the summed
        energy-conservation term on its own.
    """
    mse = F.mse_loss(dxdt, dxdt_hat)

    # Energy of the current state vs. the state one Euler step ahead under the prediction.
    next_s = s + dxdt_hat * dt
    energy_loss = energy_conservation_loss(s, next_s, masses).sum()

    # Out-of-place add: avoids mutating a tensor that is part of the autograd graph.
    loss = mse + energy_loss

    return loss, energy_loss

# One figure row per group: the main train/test loss, then the additional loss term
# (presumably the energy term returned as the second value of loss_fn).
loss_key_sets = [
    ["train_loss", "test_loss"],
    ["train_additional_loss", "test_additional_loss"],
]
model, stats = train(
    args,
    pinn_loss_fn=loss_fn,
    plot_loss_callback=partial(plot_loss, key_sets=loss_key_sets),
)


Attempting to load data from ../../../../data/mve_data-n_samples-5-n_bodies-2_time_scale-8_t_span_max-74_odeint_rtol-2e-07_odeint_atol-4e-08.pkl
Data file mve_data-n_samples-5-n_bodies-2_time_scale-8_t_span_max-74_odeint_rtol-2e-07_odeint_atol-4e-08.pkl not found.
Creating new data...
Trajectory 7fd21a9aee924791b5a70d732b212497: 500 steps (last t: 10.142131805419922)
Trajectory 7fd21a9aee924791b5a70d732b212497: 1000 steps (last t: 20.32520866394043)
Trajectory 7fd21a9aee924791b5a70d732b212497: 1500 steps (last t: 20.308124542236328)
Trajectory 7fd21a9aee924791b5a70d732b212497: Not making progress (tensor([-0.0888], grad_fn=<SubBackward0>)); giving up
Trajectory 2881b4c2f0cb4a098c31948c8a8520ad: 500 steps (last t: 10.267343521118164)
Trajectory 2881b4c2f0cb4a098c31948c8a8520ad: 1000 steps (last t: 20.700843811035156)
Trajectory 2881b4c2f0cb4a098c31948c8a8520ad: 1500 steps (last t: 31.1751651763916)
Trajectory 2881b4c2f0cb4a098c31948c8a8520ad: 2000 steps (last t: 32.50510787963867)
Traje

In [None]:
# plot model output

test_y0, test_masses = get_initial_conditions(5)
# Track gradients on the initial state (presumably the learned model differentiates w.r.t. it).
initial_state = test_y0.clone().detach().requires_grad_()
rollout = mechanics.get_trajectory({"y0": initial_state, "masses": test_masses, "model": model}).dict()
# Take the leading five fields by position. The ground-truth cell above unpacks six values
# (..., time, masses) from this same call, so a strict 5-way unpack could raise.
r, v, dr, dv, time = list(rollout.values())[:5]
plot_energy_from_coords(r, v, time, test_masses)

# Detach and move to CPU before animating, as the ground-truth preview cell does.
ani_model = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
HTML(ani_model.to_jshtml())