In [None]:
%matplotlib widget

import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import simple_pinn
import matplotlib.pyplot as plt
import scipy.io

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# for dev purposes, reload the simple_pinn module each time this cell is run,
# so that edits made to the module are picked up without restarting the kernel
import importlib
importlib.reload(simple_pinn)

# run on the GPU when PyTorch reports an available CUDA device, otherwise fall back to the CPU;
# every tensor and the model created below are moved to this device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def ideal_equation_of_state(mhd_state_variables, gamma=2.0):
    """Gas pressure of an ideal gas, computed from the MHD state variables.

    Implements p = (gamma - 1) * (E - rho * |v|^2 / 2 - |B|^2 / 2), i.e. the
    total energy density minus its kinetic and magnetic parts, scaled by
    (gamma - 1).

    Args:
        mhd_state_variables: array/tensor whose last axis is indexed as
            (rho, v_x, v_y, v_z, B_x, B_y, B_z, E) at positions 0..7.
        gamma: ratio of specific heats (adiabatic index).

    Returns:
        The pressure, with the shape of the input minus its last axis.
    """
    # split the last axis into its eight named components
    rho, v_x, v_y, v_z, B_x, B_y, B_z, E = (
        mhd_state_variables[..., index] for index in range(8)
    )
    kinetic_energy_density = rho * (v_x**2 + v_y**2 + v_z**2) / 2.0
    magnetic_energy_density = (B_x**2 + B_y**2 + B_z**2) / 2.0
    # what remains of E after removing the kinetic and magnetic parts is the internal energy
    internal_energy_density = E - kinetic_energy_density - magnetic_energy_density
    return (gamma - 1.0) * internal_energy_density

In [None]:
# construct the fluxes F_t, F_x, F_y, and F_z out of the MHD state variables and a provided equation of state function;
# right now this only constructs F_t and F_x, for the 1D (in space) problem;
# these fluxes are taken from Gardiner and Stone (2005), https://arxiv.org/pdf/astro-ph/0501557.pdf, where F_t, F_x, and F_y are equations 7 and 8;
# you can also just expand out the ideal MHD equations in conservative form (equations 1--4)
def construct_fluxes(mhd_state_variables, equation_of_state):
    """Build the temporal and x-directed ideal-MHD fluxes from the state variables.

    Args:
        mhd_state_variables: tensor whose last axis is indexed as
            (rho, v_x, v_y, v_z, B_x, B_y, B_z, E) at positions 0..7.
        equation_of_state: callable mapping ``mhd_state_variables`` to the gas
            pressure (same shape as the input minus its last axis).

    Returns:
        A tuple ``(F_t, F_x)``; each has the same leading shape as the input
        and a last axis of length 8:
          * ``F_t`` holds the conserved quantities
            (rho, rho*v_x, rho*v_y, rho*v_z, B_x, B_y, B_z, E);
          * ``F_x`` holds the corresponding fluxes in the x direction.
    """
    # extract the state variables from the flattened input
    rho = mhd_state_variables[..., 0]
    v_x = mhd_state_variables[..., 1]
    v_y = mhd_state_variables[..., 2]
    v_z = mhd_state_variables[..., 3]
    B_x = mhd_state_variables[..., 4]
    B_y = mhd_state_variables[..., 5]
    B_z = mhd_state_variables[..., 6]
    E = mhd_state_variables[..., 7]
    # the selection operations above reduced the rank of the tensors by 1; stacking along a new
    # trailing axis restores it, so that each flux is a tensor of shape (..., 8)
    F_t = torch.stack(
        [rho, rho * v_x, rho * v_y, rho * v_z, B_x, B_y, B_z, E], dim=-1
    )
    # for the spatial components of the flux, we need the total pressure P_star,
    # which is the sum of the gas pressure P and the magnetic pressure B^2 / 2;
    # the gas pressure comes from the provided equation of state
    P = equation_of_state(mhd_state_variables)
    P_star = P + (B_x**2 + B_y**2 + B_z**2) / 2.0
    # B . v appears in the energy flux; every term pairs matching components
    # (the z term is B_z * v_z, not B_z * v_y)
    B_dot_v = B_x * v_x + B_y * v_y + B_z * v_z
    F_x = torch.stack(
        [
            rho * v_x,
            rho * v_x**2 + P_star - B_x * B_x,
            rho * v_x * v_y - B_x * B_y,
            rho * v_x * v_z - B_x * B_z,
            # the x-flux of B_x vanishes identically
            torch.zeros_like(rho),
            v_x * B_y - B_x * v_y,
            v_x * B_z - B_x * v_z,
            (E + P_star) * v_x - B_x * B_dot_v,
        ],
        dim=-1,
    )
    # at some point F_y and F_z could be constructed here as well
    return F_t, F_x

In [None]:
# compute the collocation point-based PDE residual loss function, i.e., equation 6 of Gardiner and Stone (2005),
# that is, just the continuity equation that the ideal MHD equations form in their conservative form
# d/dt F_t + d/dx F_x + d/dy F_y + d/dz F_z = 0;
# note this is just for the 1D (in space) problem for now
def collocation_based_PDE_residual_loss(inputs, F_t, F_x):
    """Mean squared residual of d/dt F_t + d/dx F_x over all collocation points.

    Args:
        inputs: tensor of collocation points that ``F_t`` and ``F_x`` were
            computed from through autograd; index 0 of its last axis is read
            as t and index 1 as x.
        F_t: temporal flux; its last axis enumerates the flux components.
        F_x: x-directed flux, with the same layout as ``F_t``.

    Returns:
        Scalar tensor holding the mean of the squared residual over every
        point and component. The derivative graph is kept (``create_graph``),
        so the loss can itself be backpropagated.
    """
    TIME_AXIS, SPACE_AXIS = 0, 1

    def partial_derivative(flux_component, axis):
        # seeding with ones yields, at each point, the derivative of that point's own flux
        # value, provided each output value depends only on its own input point
        (grad_wrt_inputs,) = torch.autograd.grad(
            flux_component,
            inputs,
            torch.ones_like(flux_component),
            retain_graph=True,
            create_graph=True,
        )
        return grad_wrt_inputs[..., axis]

    time_derivatives = torch.zeros_like(F_t)
    space_derivatives = torch.zeros_like(F_x)
    n_components = F_t.shape[-1]
    # autograd differentiates one scalar field at a time, so go component by component
    for component in range(n_components):
        time_derivatives[..., component] = partial_derivative(
            F_t[..., component], TIME_AXIS
        )
        space_derivatives[..., component] = partial_derivative(
            F_x[..., component], SPACE_AXIS
        )
    residual = time_derivatives + space_derivatives
    return residual.square().mean()

In [None]:
# compute the collocation point-based initial condition and boundary loss functions
def collocation_based_IC_BC_residual_loss(model, eos, Nx=101, Nt=101):
    """Initial- and boundary-condition losses for the Brio and Wu shock tube.

    Args:
        model: callable mapping a tensor of (t, x) rows to a tensor whose last
            axis is (rho, v_x, v_y, v_z, B_x, B_y, B_z, E).
        eos: callable mapping the model output to the gas pressure.
        Nx: number of x samples on [-1, 1] used for the initial condition (t = 0).
        Nt: number of t samples on [0, 0.2] used at each boundary (x = -1 and x = 1).

    Returns:
        ``(ic_loss, bc_loss)``, both scalar tensors. Each is a mean squared
        error in which the energy column of the prediction is replaced by
        pressure; the boundary loss is the average of the two boundary errors.
    """

    def with_pressure(state):
        # Brio and Wu specify a pressure rather than a total energy,
        # so swap the trailing E column of the model output for p
        pressure = eos(state)
        return torch.cat([state[..., :-1], pressure.unsqueeze(1)], dim=1)

    def mean_squared_error(prediction, target):
        return torch.mean(torch.square(prediction - target))

    # ---- initial condition: sample the model along x at t = 0
    x_ic = torch.linspace(-1.0, 1.0, Nx).to(DEVICE)
    t_ic = torch.zeros_like(x_ic)
    predicted_ic = with_pressure(model(torch.stack([t_ic, x_ic], dim=1)))

    # reference state: a discontinuity at x = 0 in rho, B_y and p; everything else is uniform
    left_of_origin = x_ic < 0.0
    zero_field = torch.zeros_like(x_ic)
    target_ic = torch.stack(
        [
            torch.where(left_of_origin, 1.0, 0.125),  # rho
            zero_field,  # v_x
            zero_field,  # v_y
            zero_field,  # v_z
            0.75 * torch.ones_like(x_ic),  # B_x
            torch.where(left_of_origin, 1.0, -1.0),  # B_y
            zero_field,  # B_z
            torch.where(left_of_origin, 1.0, 0.1),  # p (in place of E)
        ],
        dim=1,
    )
    ic_loss = mean_squared_error(predicted_ic, target_ic)

    # ---- boundary conditions: both ends of the domain are held at their initial state
    target_left = target_ic[0, :]
    target_right = target_ic[-1, :]

    t_bc = torch.linspace(0.0, 0.2, Nt).to(DEVICE)
    left_points = torch.stack([t_bc, -torch.ones_like(t_bc)], dim=1)
    right_points = torch.stack([t_bc, torch.ones_like(t_bc)], dim=1)
    left_state = model(left_points)
    right_state = model(right_points)
    loss_left = mean_squared_error(with_pressure(left_state), target_left)
    loss_right = mean_squared_error(with_pressure(right_state), target_right)
    bc_loss = 0.5 * (loss_left + loss_right)

    return ic_loss, bc_loss

In [None]:
def monopole_loss(inputs, mhd_state_variables):
    """Mean squared divergence of the magnetic field for the 1D (in space) problem.

    Args:
        inputs: tensor of collocation points that ``mhd_state_variables`` was
            computed from through autograd; index 1 of its last axis is read as x.
        mhd_state_variables: tensor whose last axis holds B_x at index 4.

    Returns:
        Scalar tensor with the mean of (dB_x/dx)^2 over all points. The
        derivative graph is kept (``create_graph``), so the loss can itself be
        backpropagated.
    """
    B_x = mhd_state_variables[..., 4]
    # dB_y/dy and dB_z/dz are zero because this is a 1D spatial problem,
    # so the divergence of B reduces to dB_x/dx
    (grad_wrt_inputs,) = torch.autograd.grad(
        B_x,
        inputs,
        torch.ones_like(B_x),
        retain_graph=True,
        create_graph=True,
    )
    dB_x_dx = grad_wrt_inputs[..., 1]
    return torch.mean(torch.square(dB_x_dx))

In [None]:
# the core neural network model maps from space and time (t, x, y, z) to the *ideal* MHD state variables, which in general are
# rho, v_x, v_y, v_z, B_x, B_y, B_z, and E, where E is the total energy density;
# in this case, we are solving the 1D (in space) problem, so the model maps from (t, x) to the aforementioned state variables
# NOTE(review): the positional arguments are presumably (input size, hidden layer widths, output size) --
# confirm against simple_pinn.SimplePINN
mhd_state_variables_nn = simple_pinn.SimplePINN(
    2,
    [32, 64, 128, 64, 32],
    8,
    activation=nn.Softplus(),
).to(DEVICE)

optimizer = optim.Adam(mhd_state_variables_nn.parameters(), lr=0.0025)

# collocation grid: 51 times on [0, 0.2] by 3001 positions on [-1, 1];
# requires_grad=True lets the loss functions differentiate the model output with respect to t and x
t = torch.linspace(0.0, 0.2, 51, requires_grad=True).to(DEVICE)
x = torch.linspace(-1.0, 1.0, 3001, requires_grad=True).to(DEVICE)
# indexing="ij" makes the first axis run over t and the second over x
T, X = torch.meshgrid(t, x, indexing="ij")
# stack along a new trailing axis so that inputs[..., 0] is t and inputs[..., 1] is x
inputs = torch.stack([T, X], dim=len(T.shape))

# per-epoch loss values, kept for the loss-history plot
total_loss_history = []
pde_loss_history = []
mnpl_loss_history = []
ic_loss_history = []
bc_loss_history = []
n_epochs = 10_000
for epoch in range(n_epochs):
    # forward pass on the full collocation grid
    # NOTE(review): `outputs` is computed before optimizer.step(), so after the loop it holds the
    # prediction made with the parameters from before the final update
    outputs = mhd_state_variables_nn(inputs)
    F_t, F_x = construct_fluxes(outputs, ideal_equation_of_state)
    pde_loss = collocation_based_PDE_residual_loss(inputs, F_t, F_x)
    mnpl_loss = monopole_loss(inputs, outputs)
    ic_loss, bc_loss = collocation_based_IC_BC_residual_loss(
        mhd_state_variables_nn, ideal_equation_of_state
    )
    # the four loss terms are summed with equal (unit) weights
    loss = pde_loss + mnpl_loss + ic_loss + bc_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # write out the loss values
    if epoch % 100 == 0:
        print(
            f"epoch {epoch} loss: {loss.item()} pde_loss: {pde_loss.item()} monopole_loss: {mnpl_loss.item()} ic_loss: {ic_loss.item()} bc_loss: {bc_loss.item()}"
        )
    # .item() stores plain Python floats, so the histories do not hold on to the autograd graph
    total_loss_history.append(loss.item())
    pde_loss_history.append(pde_loss.item())
    mnpl_loss_history.append(mnpl_loss.item())
    ic_loss_history.append(ic_loss.item())
    bc_loss_history.append(bc_loss.item())

In [None]:
# plot every loss component, plus their total, against the epoch on a logarithmic y-axis
plt.figure()
loss_curves = [
    (pde_loss_history, {"label": "PDE Residual Loss"}),
    (mnpl_loss_history, {"label": "Monopole Loss"}),
    (ic_loss_history, {"label": "IC Loss"}),
    (bc_loss_history, {"label": "BC Loss"}),
    # the total is drawn last, in black, so it sits on top of the components
    (total_loss_history, {"label": "Total Loss", "color": "black"}),
]
for curve_values, curve_style in loss_curves:
    plt.plot(curve_values, **curve_style)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss History")
plt.legend()
# called without data, semilogy only switches the y-axis to a log scale
plt.semilogy()
plt.show()

In [None]:
# bring the grid and the model output back to host memory as NumPy arrays for matplotlib
x_d = x.to("cpu").detach().numpy()
t_d = t.to("cpu").detach().numpy()
# index 2 of the state axis is v_y; the result has one row per time and one column per position
var_to_plot = outputs[..., 2].to("cpu").detach().numpy()
var_name = "v_y"

# Set up the figure
fig, ax = plt.subplots()
(line,) = ax.plot(x_d, var_to_plot[0, :], color="blue")

# The axes autoscale to the first frame only, and set_ydata() does not rescale them, so later
# frames would be clipped; pin the y-range to the finite data range over all frames instead
finite_values = var_to_plot[np.isfinite(var_to_plot)]
if finite_values.size > 0:
    y_min = float(finite_values.min())
    y_max = float(finite_values.max())
    # 5% padding; fall back to a fixed pad when the field is constant (zero range)
    y_pad = 0.05 * (y_max - y_min) or 0.05
    ax.set_ylim(y_min - y_pad, y_max + y_pad)

# Add a text element for the time
time_text = ax.text(
    0.95,
    0.95,
    "",
    transform=ax.transAxes,
    verticalalignment="top",
    horizontalalignment="right",
)

# Axis labels
ax.set_xlabel("x")
ax.set_ylabel(var_name)
ax.set_title(f"{var_name} as a function of x over time")


# Initialization function
def init():
    """Blank the curve and the time label; returns the artists to redraw when blitting."""
    # a fully masked array draws nothing
    line.set_ydata(np.ma.array(x_d, mask=True))
    time_text.set_text("")
    return line, time_text


# Animation update function
def update(frame):
    """Show the profile and the time of row ``frame``; returns the artists to redraw."""
    line.set_ydata(var_to_plot[frame, :])
    time_text.set_text(f"Time: {t_d[frame]:.3f}")
    return line, time_text


# Create the animation, with one frame per time sample
ani = FuncAnimation(fig, update, frames=len(t_d), init_func=init, blit=True)

ani.save(f"{var_name}_animation_with_time.gif", fps=5)

plt.show()