In [None]:
%matplotlib widget

import torch
import torch.nn as nn
import torch.optim as optim
import simple_pinn
import cv_mesh
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from mhd_utils import *

# For dev purposes, reload the project modules each time this cell is run so that edits
# to them are picked up without restarting the kernel.
import importlib

import mhd_utils

for _module in (simple_pinn, cv_mesh, mhd_utils):
    importlib.reload(_module)
# The star import at the top of this cell bound mhd_utils' names *before* the reload;
# re-bind them so the loss functions used below are the freshly reloaded definitions.
from mhd_utils import *

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import numpy as np

# Composite-trapezoid nodes and weights on the reference interval [-1, 1]: equally spaced
# nodes, interior weights twice the endpoint weights, rescaled to sum to the interval
# length 2. The same 1D rule is used for both the t and x directions.
pts = 4
xi = torch.tensor(np.linspace(-1, 1, pts))

trapezoid_weights = np.full(pts, 2.0)
trapezoid_weights[0] = trapezoid_weights[-1] = 1.0
wi = torch.tensor(2.0 * trapezoid_weights / trapezoid_weights.sum())

quad_dict = dict(t=(xi, wi), x=(xi, wi))
quad_dict

In [None]:
# Example control-volume mesh over the (t, x) domain. Nt / Nx presumably give the number
# of cells per direction and quad_pts the per-cell quadrature order — TODO confirm in
# cv_mesh. Passing `quad_dict` in place of `None` selects a custom quadrature rule.
t_domain = [0.0, 0.2]
x_domain = [-1.0, 1.0]
Nt = Nx = 100

mesh = cv_mesh.CVMesh(t_domain, x_domain, Nt, Nx, None, quad_pts=(8, 8))
# alternative with the custom rule: mesh = cv_mesh.CVMesh(t_domain, x_domain, Nt, Nx, quad_dict)
# mesh.plot()  # optional visual check of the mesh

# NOTE(review): the return value is discarded, so this assumes CVMesh.to moves the mesh
# in place (like nn.Module.to) rather than returning a copy (like Tensor.to) — confirm.
mesh.to(DEVICE)

In [None]:
# Boundary collocation points for the IC/BC loss, taken from the mesh's flux-evaluation
# points (column 0 is t, column 1 is x):
#   x_points - x-coordinates of eval points lying on the initial-time plane t = t_domain[0]
#   t_points - t-coordinates of eval points lying on either spatial boundary x = x_domain[0|1]
F_t_eval_points, F_x_eval_points, _, _ = mesh.get_training_eval_points_and_weights()


def _on_plane(coord, value):
    """Boolean mask of the entries of ``coord`` equal to ``value`` up to float round-off.

    Exact ``==`` would silently drop boundary points whose computed coordinate differs
    from the Python-float domain limit by a rounding error.
    """
    return torch.isclose(coord, torch.full_like(coord, value))


x_points = torch.cat(
    [
        points[_on_plane(points[..., 0], t_domain[0])][:, 1]
        for points in (F_t_eval_points, F_x_eval_points)
    ]
)
t_points = torch.cat(
    [
        points[
            _on_plane(points[..., 1], x_domain[0]) | _on_plane(points[..., 1], x_domain[1])
        ][:, 0]
        for points in (F_t_eval_points, F_x_eval_points)
    ]
)

# Only a cell's last expression is displayed, so report point counts and ranges together.
(x_points.shape, t_points.shape), (x_points.min(), x_points.max(), t_points.min(), t_points.max())

In [None]:
# The core neural network model maps space and time (t, x, y, z) to the *ideal* MHD state
# variables, which in general are rho, v_x, v_y, v_z, B_x, B_y, B_z and E, where E is the
# total energy density. In this case we are solving the 1D (in space) problem, so the
# model maps (t, x) to the aforementioned state variables.
torch.manual_seed(0)  # fix the weight initialisation so training runs are reproducible

mhd_state_variables_nn = simple_pinn.BrioAndWuPINN(
    2,  # input features: (t, x)
    [64, 64, 64, 64, 64],  # presumably the hidden-layer widths — TODO confirm in simple_pinn
    8,  # output features: the 8 state variables listed above
    activation=nn.Tanh(),
    use_bias_in_output_layer=False,
).to(DEVICE)

optimizer = optim.Adam(mhd_state_variables_nn.parameters(), lr=1.0e-5)

total_loss_history = []
pde_loss_history = []
mnpl_loss_history = []
ic_loss_history = []
bc_loss_history = []
n_epochs = 100_000
for epoch in range(n_epochs):
    # Control-volume PDE residual; loss_structure is the unreduced residual that the
    # plotting cells index as [t, x, component].
    pde_loss, loss_structure = cv_based_PDE_residual_loss(
        mhd_state_variables_nn, ideal_equation_of_state, mesh
    )
    # The monopole term is currently disabled; it is kept at 0.0 so the printout and the
    # history lists keep the same layout.
    # mnpl_loss = monopole_loss(inputs, outputs)
    mnpl_loss = torch.tensor(0.0)
    (
        ic_loss,
        bc_loss,
        mhd_state_variables_ic_with_p,
    ) = collocation_based_brio_and_wu_IC_BC_residual_loss(
        mhd_state_variables_nn,
        ideal_equation_of_state,
        Nx=Nx,
        Nt=Nt,
        t_p=t_points,
        x_p=x_points,
        device=DEVICE,
    )
    loss = pde_loss + ic_loss + bc_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Convert each loss to a Python float once (every .item() forces a device sync) and
    # reuse the values for both the progress line and the history.
    loss_value = loss.item()
    pde_value = pde_loss.item()
    mnpl_value = mnpl_loss.item()
    ic_value = ic_loss.item()
    bc_value = bc_loss.item()
    # write out the loss values
    if epoch % 100 == 0:
        print(
            f"epoch {epoch} loss: {loss_value} pde_loss: {pde_value} monopole_loss: {mnpl_value} ic_loss: {ic_value} bc_loss: {bc_value}"
        )
    total_loss_history.append(loss_value)
    pde_loss_history.append(pde_value)
    mnpl_loss_history.append(mnpl_value)
    ic_loss_history.append(ic_value)
    bc_loss_history.append(bc_value)

# put the model in evaluation mode
mhd_state_variables_nn.eval()
# The loss_structure kept from the loop was computed *before* the final optimizer.step(),
# i.e. for different parameters than the ones used for `outputs` below. Recompute it with
# the final model so the residual plots match the plotted predictions.
_, loss_structure = cv_based_PDE_residual_loss(
    mhd_state_variables_nn, ideal_equation_of_state, mesh
)
# Evaluation grid for the prediction plots, and its centered=True counterpart
# (x_c is what the residual overlays are plotted against).
t, x, inputs = mesh.get_eval_points(centered=False)
t_c, x_c, _ = mesh.get_eval_points(centered=True)
outputs = mhd_state_variables_nn(inputs)

In [None]:
tuple(tensor.shape for tensor in (t, x, inputs, outputs))  # grid / model I/O shapes

In [None]:
# Training-loss curves on a log y-axis. A series with no positive entries (e.g. the
# monopole loss while the training loop fixes it at 0.0) cannot be drawn on a log scale,
# so it is skipped rather than producing an empty curve and a misleading legend entry.
# Colours are pinned per series so they stay the same whether or not a series is skipped.
fig, ax = plt.subplots()
loss_curves = [
    ("PDE Residual Loss", pde_loss_history, "C0"),
    ("Monopole Loss", mnpl_loss_history, "C1"),
    ("IC Loss", ic_loss_history, "C2"),
    ("BC Loss", bc_loss_history, "C3"),
    ("Total Loss", total_loss_history, "black"),
]
for label, history, color in loss_curves:
    if any(value > 0 for value in history):
        ax.plot(history, label=label, color=color)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss History")
ax.legend()
ax.set_yscale("log")
plt.show()

In [None]:
# Snapshot of every model component at one time index (0 = first time level). Each
# panel's y-range is the initial-condition range padded by 0.1 on both sides.
t_ndx = 0

component_names = "rho v_x v_y v_z B_x B_y B_z p".split()

fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(12, 6))
fig.suptitle("Component Plots")

for i, component in enumerate(component_names):
    row, col = divmod(i, 4)
    panel = axs[row, col]

    ic_values = mhd_state_variables_ic_with_p[:, i].detach().cpu().numpy()
    # The "p" panel shows the EOS applied to the whole predicted state at this time,
    # not output column i.
    predicted = (
        ideal_equation_of_state(outputs[t_ndx].detach().cpu())
        if component == "p"
        else outputs[t_ndx, :, i].detach().cpu().numpy()
    )

    panel.plot(x.detach().cpu().numpy(), predicted)
    # panel.plot(x.detach().cpu().numpy(), ic_values, linestyle="--")  # IC overlay (disabled)
    panel.set_xlabel("x")
    panel.set_ylabel(component)
    panel.set_title(f"{component} at t_ndx = {t_ndx}")
    panel.set_ylim([ic_values.min() - 0.1, ic_values.max() + 0.1])

fig.tight_layout()
plt.show()

In [None]:
# Last time level (t_ndx = -1): each model component, with the PDE residual for that
# component overlaid as a red dotted line on a twin y-axis.
t_ndx = -1

component_names = "rho v_x v_y v_z B_x B_y B_z p".split()

fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(16, 6))
fig.suptitle("Component Plots with PDE Residual Loss")

for i, component in enumerate(component_names):
    ax = axs.flat[i]  # row-major: same panel as axs[i // 4, i % 4]

    ic_values = mhd_state_variables_ic_with_p[:, i].detach().cpu().numpy()
    if component == "p":
        # EOS applied to the whole predicted state rather than output column i
        predicted = ideal_equation_of_state(outputs[t_ndx].detach().cpu())
    else:
        predicted = outputs[t_ndx, :, i].detach().cpu().numpy()

    ax.plot(x.detach().cpu().numpy(), predicted)
    # ax.plot(x.detach().cpu().numpy(), ic_values, linestyle="--")  # IC overlay (disabled)
    ax.set(xlabel="x", ylabel=component, title=f"{component} at t_ndx = {t_ndx}")
    ax.set_ylim([ic_values.min() - 0.1, ic_values.max() + 0.1])  # y-range from the IC

    residual_ax = ax.twinx()
    residual_ax.plot(
        x_c.detach().cpu().numpy(),
        loss_structure[t_ndx, :, i].detach().cpu().numpy(),
        color="red",
        linestyle=":",
    )
    residual_ax.set_ylabel("Loss", color="red")

fig.tight_layout()
plt.show()

In [None]:
# Figure for the (currently disabled) time animation. It relies on `outputs`,
# `mhd_state_variables_ic_with_p`, `x` and `t` from the cells above; `animate` redraws
# the `axs` created here and takes the time index as its own argument.
t_ndx = 0

component_names = "rho v_x v_y v_z B_x B_y B_z p".split()

fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(12, 6))
fig.suptitle("Component Plots")


def animate(t_ndx):
    """Redraw every component panel of ``axs`` for time index ``t_ndx``.

    Frame callback for ``FuncAnimation``. Each panel shows the model prediction
    ("MHD Value") and the dashed initial condition for one entry of
    ``component_names``; the "p" panel shows ``ideal_equation_of_state`` applied to the
    full predicted state instead of a raw output column. Reads the notebook globals
    ``axs``, ``component_names``, ``outputs``, ``mhd_state_variables_ic_with_p`` and ``x``.

    Args:
        t_ndx: index along the first (time) axis of ``outputs``.

    Returns:
        list of the ``Line2D`` artists drawn for this frame, which is what
        ``FuncAnimation`` expects its callback to return.
    """
    # Clear previous plots
    for ax in axs.flat:
        ax.clear()

    x_np = x.to("cpu").detach().numpy()
    artists = []
    # axs.flat is row-major, so panel i is axs[i // 4, i % 4]
    for i, (component, ax1) in enumerate(zip(component_names, axs.flat)):
        component_ic_values = (
            mhd_state_variables_ic_with_p[:, i].to("cpu").detach().numpy()
        )
        if component == "p":
            component_values = ideal_equation_of_state(
                outputs[t_ndx, :].to("cpu").detach()
            )
        else:
            component_values = outputs[t_ndx, :, i].to("cpu").detach().numpy()

        artists += ax1.plot(x_np, component_values, label="MHD Value")
        artists += ax1.plot(
            x_np,
            component_ic_values,
            linestyle="--",
            label="Initial Condition",
        )

        ax1.set_xlabel("x")
        ax1.set_ylabel(component)
        ax1.set_title(f"{component} at t_ndx = {t_ndx}")
        # y-axis limits for the MHD values, taken from the initial-condition range
        ax1.set_ylim([component_ic_values.min() - 0.1, component_ic_values.max() + 0.1])

    # Lay out the figure that owns `axs`, not whatever figure is current when the
    # animation timer fires (later cells create other figures).
    axs.flat[0].get_figure().tight_layout()
    return artists


# Create animation. blit=False: every frame clears the axes and changes titles/limits,
# which blitting (redrawing only the returned artists over a cached background) would not show.
# ani = FuncAnimation(fig, animate, frames=len(t), interval=100, blit=False)

# ani.save("brio_and_wu.gif", writer="pillow")

# Show the animation
# plt.show()

In [None]:
# Heat maps of the final PDE residual for each component over the (t, x) index grid.
# The colour shows the natural log of the loss, so the colorbar is labelled accordingly.
# Cells with zero loss (log = -inf) are masked so they do not wreck colour autoscaling.
# imshow is called without `extent`, so the axes show array indices, not coordinates.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))  # Adjust size as needed
fig.suptitle("Component Loss Structures")
for i, component in enumerate(component_names):
    ax = axes[i // 4, i % 4]  # Select subplot
    with np.errstate(divide="ignore", invalid="ignore"):
        log_loss = np.log(loss_structure[:, :, i].detach().cpu().numpy())
    im = ax.imshow(np.ma.masked_invalid(log_loss), cmap="RdYlGn_r", origin="lower")
    ax.set_title(f"{component}")
    ax.set_xlabel("x index")
    ax.set_ylabel("t index")

    cbar = fig.colorbar(im, ax=ax, fraction=0.025, pad=0.04)
    cbar.ax.set_ylabel("log(Loss Value)", rotation=270, labelpad=15)

plt.tight_layout()
plt.show()