In [12]:
import h5py
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from utils.constants import *
from pinn_setup import PINN
from config.settings import LAYERS  # Your MLP layer config

rcParams['figure.dpi'] = 150
EPOCHS = 4000

# === Grid and frame constants ===
DATA_FILE_NAME = 'data/plasma_data.h5'
NT = 4
NX, NY = 128, 89
N_FRAME = NX * NY
SPACE_FACTOR = 100.0 * MINOR_RADIUS  # cm

# === Load and reshape frame 0 ===
vec_to_mat = lambda vec: np.fliplr(vec.reshape(NX, NY, order='F')).T
with h5py.File(DATA_FILE_NAME, 'r') as f:
    frame_idxs = range(0, N_FRAME)
    den_frame0 = vec_to_mat(f['y_den'][frame_idxs]) * PLASMA_DENSITY
    Te_frame0 = vec_to_mat(f['y_Te'][frame_idxs]) * ELECTRON_TEMP
    x_grid = vec_to_mat(f['x_x'][frame_idxs]) * SPACE_FACTOR
    y_grid = vec_to_mat(f['x_y'][frame_idxs]) * SPACE_FACTOR

# === Prepare training data ===
t_fixed = 0.0
x_vec = x_grid.T.flatten().reshape(-1, 1)
y_vec = y_grid.T.flatten().reshape(-1, 1)
t_vec = np.full_like(x_vec, t_fixed)
v1_vec = den_frame0.T.flatten().reshape(-1, 1)
v5_vec = Te_frame0.T.flatten().reshape(-1, 1)

# Convert to tensors
x_train = torch.tensor(x_vec, dtype=torch.float32)
y_train = torch.tensor(y_vec, dtype=torch.float32)
t_train = torch.tensor(t_vec, dtype=torch.float32)
v1_train = torch.tensor(v1_vec, dtype=torch.float32)
v5_train = torch.tensor(v5_vec, dtype=torch.float32)

# Normalize
v1_mean, v1_std = v1_train.mean(), v1_train.std()
v5_mean, v5_std = v5_train.mean(), v5_train.std()
v1_train_norm = (v1_train - v1_mean) / v1_std
v5_train_norm = (v5_train - v5_mean) / v5_std

# === Train PINN with PDE ===
model = PINN(x_train, y_train, t_train, v1_train_norm, v5_train_norm, layers=LAYERS, use_pde=True)
model.setup_optimizers()
model.apply(model.xavier_init)

loss_fn = torch.nn.MSELoss()
for epoch in range(EPOCHS):
    loss = model.train_step(loss_fn, model.optimizer_ne, model.optimizer_Te, model.optimizer_f)
    if epoch % 2000 == 0:
        ne_loss = loss.get('ne', 0.0)
        Te_loss = loss.get('Te', 0.0)
        f1_loss = loss.get('f1', 0.0)
        f5_loss = loss.get('f5', 0.0)
        total_loss = loss.get('total', 0.0)
        print(f"Epoch {epoch}: Total = {total_loss:.4e}, Data = {ne_loss + Te_loss:.4e}, PDE = {f1_loss + f5_loss:.4e}")

# === Predict ===
with torch.no_grad():
    preds = model.predict(x_vec, y_vec, t_vec)
    v1_pred = preds['ne'] * v1_std.item() + v1_mean.item()
    v5_pred = preds['Te'] * v5_std.item() + v5_mean.item()
    phi_pred = preds['phi']
    v1_pred_grid = np.fliplr(v1_pred.reshape(NX, NY)).T
    v5_pred_grid = np.fliplr(v5_pred.reshape(NX, NY)).T
    phi_pred_grid = np.fliplr(phi_pred.reshape(NX, NY)).T

# === Plot ===
x_min, x_max = np.min(x_grid)
y_min, y_max = np.min(y_grid)
fig, axs = plt.subplots(2, 3, figsize=(18, 8))

# True ne
im0 = axs[0, 0].imshow(den_frame0, cmap='YlOrRd_r', vmin=0, vmax=3.5e19,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 0].set_title("True $n_e$ at $t=0$")
axs[0, 0].set_xlabel("x (cm)")
axs[0, 0].set_ylabel("y (cm)")
plt.colorbar(im0, ax=axs[0, 0])

# Predicted ne
im1 = axs[0, 1].imshow(v1_pred_grid, origin='lower', cmap='YlOrRd_r', vmin=0, vmax=3.5e19,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 1].set_title("PINN Predicted $n_e$ at $t=0$")
axs[0, 1].set_xlabel("x (cm)")
axs[0, 1].set_ylabel("y (cm)")
plt.colorbar(im1, ax=axs[0, 1])

# Predicted phi
im2 = axs[0, 2].imshow(phi_pred_grid, origin='lower', cmap='coolwarm',
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 2].set_title("PINN Predicted $\phi$ at $t=0$")
axs[0, 2].set_xlabel("x (cm)")
axs[0, 2].set_ylabel("y (cm)")
plt.colorbar(im2, ax=axs[0, 2])

# True Te
im3 = axs[1, 0].imshow(Te_frame0, cmap='YlOrRd_r', vmin=0, vmax=80,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[1, 0].set_title("True $T_e$ at $t=0$")
axs[1, 0].set_xlabel("x (cm)")
axs[1, 0].set_ylabel("y (cm)")
plt.colorbar(im3, ax=axs[1, 0])

# Predicted Te
im4 = axs[1, 1].imshow(v5_pred_grid, origin='lower', cmap='YlOrRd_r', vmin=0, vmax=80,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[1, 1].set_title("PINN Predicted $T_e$ at $t=0$")
axs[1, 1].set_xlabel("x (cm)")
axs[1, 1].set_ylabel("y (cm)")
plt.colorbar(im4, ax=axs[1, 1])

# Empty placeholder or additional field (e.g., residual)
axs[1, 2].axis('off')

plt.tight_layout()
plt.show()

Epoch 0: Total = 1.4887e+02, Data = 3.0643e+00, PDE = 1.4581e+02
Epoch 2000: Total = 8.7398e+04, Data = 3.0461e+01, PDE = 8.7367e+04


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
# === Plot ===
x_min, x_max = np.min(x_grid)
y_min, y_max = np.min(y_grid)
fig, axs = plt.subplots(2, 3, figsize=(18, 8))

# True ne
im0 = axs[0, 0].imshow(den_frame0, cmap='YlOrRd_r', vmin=0, vmax=3.5e19,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 0].set_title("True $n_e$ at $t=0$")
axs[0, 0].set_xlabel("x (cm)")
axs[0, 0].set_ylabel("y (cm)")
plt.colorbar(im0, ax=axs[0, 0])

# Predicted ne
im1 = axs[0, 1].imshow(v1_pred_grid, origin='lower', cmap='YlOrRd_r', vmin=0, vmax=3.5e19,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 1].set_title("PINN Predicted $n_e$ at $t=0$")
axs[0, 1].set_xlabel("x (cm)")
axs[0, 1].set_ylabel("y (cm)")
plt.colorbar(im1, ax=axs[0, 1])

# Predicted phi
im2 = axs[0, 2].imshow(phi_pred_grid, origin='lower', cmap='coolwarm',
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[0, 2].set_title("PINN Predicted $\phi$ at $t=0$")
axs[0, 2].set_xlabel("x (cm)")
axs[0, 2].set_ylabel("y (cm)")
plt.colorbar(im2, ax=axs[0, 2])

# True Te
im3 = axs[1, 0].imshow(Te_frame0, cmap='YlOrRd_r', vmin=0, vmax=80,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[1, 0].set_title("True $T_e$ at $t=0$")
axs[1, 0].set_xlabel("x (cm)")
axs[1, 0].set_ylabel("y (cm)")
plt.colorbar(im3, ax=axs[1, 0])

# Predicted Te
im4 = axs[1, 1].imshow(v5_pred_grid, origin='lower', cmap='YlOrRd_r', vmin=0, vmax=80,
                       extent=[x_min, x_max, y_min, y_max], aspect='equal')
axs[1, 1].set_title("PINN Predicted $T_e$ at $t=0$")
axs[1, 1].set_xlabel("x (cm)")
axs[1, 1].set_ylabel("y (cm)")
plt.colorbar(im4, ax=axs[1, 1])

# Empty placeholder or additional field (e.g., residual)
axs[1, 2].axis('off')

plt.tight_layout()
plt.show()