In [1]:
from IPython.display import display, HTML, Image
import matplotlib
import matplotlib.pyplot as plt
import math
import torch
import torch.nn.functional as F
import sys  
import os
from functools import partial

sys.path.insert(0, '../../..')
sys.path.insert(1, '../../../logs')

from dynnn.simulation.mve_ensemble import MveEnsembleMechanics
from dynnn.simulation.mve_ensemble.viz import visualize_trajectory, plot_energy
from dynnn.simulation.mve_ensemble.mve_ensemble import energy_conservation_loss, calc_kinetic_energy, get_initial_conditions, calc_total_energy_per_cell
from dynnn.layers import TaskModel
from dynnn.types import Dataset, PinnTrainingArgs, SimulatorTrainingArgs, SimulatorArgs
from dynnn.utils import load_model

In [2]:
DEVICE='cpu'
if torch.cuda.is_available():
    DEVICE='cuda'

verbose=False
torch.set_default_device(DEVICE)

In [3]:
%matplotlib inline

mechanics = MveEnsembleMechanics()

def plot_energy_from_coords(r, v, time, masses):
    pe = mechanics.no_bc_potential_fn(r).detach().cpu()
    ke = calc_kinetic_energy(v, masses).detach().cpu()
    te = pe + ke

    plot_energy(pe, ke, te, time.cpu())

def plot_loss(stats: dict, zoom_length: int = 500, key_sets: list[list[str]] | None = None):
    if key_sets is None:
        key_sets = [stats.keys()]

    if not 'fig' in vars():
        fig, axes = plt.subplots(len(key_sets), 2, figsize=(10, 4))

    for i, keys in enumerate(key_sets):
        axes[i][0].clear()
        axes[i][0].set_title("Loss")
        axes[i][0].set_xlabel("time")
        for key in keys:
            axes[i][0].plot(stats.get_as_float(key), label=key)
        axes[i][0].legend(fontsize=8)
        axes[i][0].set_yscale("log")

        axes[i][1].clear()
        axes[i][1].set_title(f"Last {zoom_length} steps loss")
        axes[i][1].set_xlabel("time")
        for key in keys:
            axes[i][1].plot(stats.get_as_float(key)[-zoom_length:], label=key)
        axes[i][1].legend(fontsize=8)
        axes[i][1].set_yscale("log")
        fig.tight_layout()
    display(fig, clear=True)

In [4]:
if verbose:
    r, v, dr, dv, time, masses = mechanics.get_trajectory({}).dict().values()
    plot_energy_from_coords(r, v, time, masses)
    ani = visualize_trajectory(r.detach().cpu(), len(time), (mechanics.domain_min, mechanics.domain_max))
    display(HTML(ani.to_jshtml()))

    ani.save(sys.path[0] + '/../images/mve_ensemble.gif', writer='pillow')
    display(Image(filename=sys.path[0] + '/../images/mve_ensemble.gif'))

In [6]:
data, _ = mechanics.get_dataset({"n_samples": 2}, { "n_bodies": 10, "time_scale": 1, "t_span_max": 20, "odeint_atol": 1e-8, "odeint_rtol": 1e-9 })

if verbose:
    data_r, data_v = [v.squeeze(-2) for v in torch.split(data["x"][0], 1, dim=-2)]
    plot_energy_from_coords(data_r, data_v, data["time"], masses)
    ani = visualize_trajectory(data_r, len(data["time"]), mechanics.domain)
    display(HTML(ani.to_jshtml()))

2024-06-25 07:47:51,047 dynnn.utils - INFO:Data file mve_data-n_samples-2-n_bodies-10_time_scale-1_t_span_max-20_odeint_rtol-1e-09_odeint_atol-1e-08_odeint_solver-8.pkl not found.
2024-06-25 07:47:51,048 dynnn.utils - INFO:Creating new data...
2024-06-25 07:47:52,720 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 500 steps (last t: 0.0009932521497830749)
2024-06-25 07:47:52,825 dynnn.mechanics.base_mechanics - INFO:Trajectory 93e98e3dca574a8ca235afd8c309b004: 500 steps (last t: 2.959224900678237e-07)
2024-06-25 07:47:53,347 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 1000 steps (last t: 0.0018320853123441339)
2024-06-25 07:47:53,652 dynnn.mechanics.base_mechanics - INFO:Trajectory 93e98e3dca574a8ca235afd8c309b004: 1000 steps (last t: 4.188604236787796e-07)
2024-06-25 07:47:54,102 dynnn.mechanics.base_mechanics - INFO:Trajectory 800fe68073e1425e99f944a42e10cdcb: 500 steps (last t: 6.80118574791777e-08)
2024-06-25

DatasetGenerationFailure: Trajectory stalled

2024-06-25 07:58:56,108 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 32000 steps (last t: 2.271740674972534)
2024-06-25 07:59:16,484 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 32500 steps (last t: 2.315324544906616)
2024-06-25 07:59:37,635 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 33000 steps (last t: 2.325186252593994)
2024-06-25 07:59:59,152 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 33500 steps (last t: 2.3260936737060547)
2024-06-25 08:00:20,757 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 34000 steps (last t: 2.3265600204467773)
2024-06-25 08:00:42,562 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 34500 steps (last t: 2.3268940448760986)
2024-06-25 08:01:04,849 dynnn.mechanics.base_mechanics - INFO:Trajectory 1cf6ed5a17354f0b84ef98cfeff619b0: 35000 st

In [None]:
def loss_fn(dxdt_hat, dxdt, s, masses):
    """
    Calculate the loss
    """
    loss = F.mse_loss(dxdt_hat, dxdt)
    energy_loss = energy_conservation_loss(s + dxdt_hat * 0.01, s, masses).sum()
    return loss + energy_loss, energy_loss

def transform_y(q, p, masses):
    energy_by_cell = calc_total_energy_per_cell(
        q, p, masses,
        grid_resolution=task_grid_resolution,
        boundaries=(mechanics.domain_min, mechanics.domain_max)
    )
    return energy_by_cell

task_grid_resolution = (4, 4, 4)
task_output_dim = math.prod(task_grid_resolution) # number of cells in grid
task_input_dim = math.prod(data.x.shape[2:])

plot_loss_cb = partial(plot_loss, key_sets=[["train_loss", "test_loss"], ["train_additional_loss", "test_additional_loss"]])
initial_args = SimulatorArgs(training_args=PinnTrainingArgs(loss_fn=loss_fn, plot_loss_callback=plot_loss_cb))

model = TaskModel(input_dim=task_input_dim, output_dim=task_output_dim, initial_sim_args=initial_args)
model.do_train(SimulatorTrainingArgs(), data, transform_y)

In [None]:
# plot model output

test_y0, test_masses = get_initial_conditions(5)
initial_state = test_y0.clone().detach().requires_grad_()
r, v, dr, dv, time = mechanics.get_trajectory({"y0": initial_state, "masses": test_masses, "model": model}).dict().values()
plot_energy_from_coords(r, v, time, test_masses)

ani_model = visualize_trajectory(r, len(time), (mechanics.domain_min, mechanics.domain_max))
HTML(ani_model.to_jshtml())

In [7]:
import math

class ComplexNumber:
    def __init__(self, real, imaginary):
        self.real = real
        self.i = imaginary

    @property
    def imaginary(self):
        return self.i

    @classmethod
    def parse(cls, other):
        if isinstance(other, cls):
            return other

        return ComplexNumber(other, 0)
            
    def __eq__(self, other):
        o = ComplexNumber.parse(other)
        return o.real == self.real \
               and o.i == self.i

    def __add__(self, other): 
        o = ComplexNumber.parse(other)
        return ComplexNumber(self.real + o.real, self.i + o.i)

    def __mul__(self, other):
        # (a + i * b) * (c + i * d) = (a * c - b * d) + (b * c + a * d) * i
        o = ComplexNumber.parse(other)
        real_part = (self.real * o.real) - (self.i * o.i)
        i_part = (self.i * o.real) + (self.real * o.i)
        return ComplexNumber(real_part, i_part)

    def __sub__(self, other):
        o = ComplexNumber.parse(other)
        return ComplexNumber(self.real - o.real, self.i - o.i)

    def __div__(self, other):
        return self.__truediv(other)

    def __truediv__(self, other):
        # (a * c + b * d)/(c^2 + d^2)
        o = ComplexNumber.parse(other)
        real_part = ((self.real * o.real) + (self.i * o.i)) / (o.real ** 2 + o.i ** 2)

        # (b * c - a * d)/(c^2 + d^2) * i
        i_part = ((self.i * o.real) - (self.real * o.i)) / (o.real ** 2 + o.i ** 2)
        return ComplexNumber(real_part, i_part)

    def __round__(self, precision = 0):
        return ComplexNumber(round(self.real, precision), round(self.i, precision))
        
    def __abs__(self):
        # |z| = sqrt(a^2 + b^2)
        return self * self.conjugate()

    def conjugate(self):
        # 1 / (a + i * b) = a/(a^2 + b^2) - b/(a^2 + b^2) * i
        a2b2 = self.real ** 2 + self.i ** 2
        return (self.real / a2b2) - (self.i / a2b2)

    def exp(self):
        # e^(a + i * b) = e^a * e^(i * b)
        # e^(i * b) = cos(b) + i * sin(b)

        real_part = math.exp(self.real) + math.cos(self.i)
        i_part = math.sin(self.i)
        return ComplexNumber(real_part, i_part)

    def __str__(self):
        return f"{self.real} + {self.i}i"

    def __repr__(self):
        return str(self)

In [9]:
ComplexNumber(5, 0).conjugate()

0.2