In [1]:
import torch
import torch.nn as nn
from torch.autograd import grad

In [4]:
# Multi-layer Perceptron
class PINN(nn.Module):
    """
    Physics-Informed Neural Network (PINN) that models one scalar physical quantity
    (e.g., electron density n_e, electron temperature T_e, or electric potential φ)
    as a function of spatiotemporal coordinates (x, y, t).

    This class corresponds to a single fully-connected neural network used in the
    multi-network PINN architecture described in:

        "Uncovering turbulent plasma dynamics via deep learning from partial observations"
        by Mathews et al.

    Each PINN instance:
        - Takes input tensor (x, y, t)
        - Returns a scalar output corresponding to the predicted field value
        - Uses 5 hidden layers with 50 neurons and tanh activations

    In the original paper:
        - Networks for n_e and T_e are supervised by observational data and physics loss
        - The φ network is learned implicitly through its role in the PDEs (not supervised directly)
    """
    def __init__(self, layers=[3, 50, 50, 50, 50, 1]):
        super().__init__()
        modules = []
        for i in range(len(layers) - 2):
            modules.append(nn.Linear(layers[i], layers[i + 1]))
            modules.append(nn.Tanh())
        modules.append(nn.Linear(layers[-2], layers[-1]))
        self.net = nn.Sequential(*modules)

    def forward(self, xyt):
        return self.net(xyt)    

In [5]:
# Losses
def compute_grad(output, coords):
    return grad(output, coords, torch.ones_like(output), create_graph=True)[0]

def physics_residual_ne(ne_net, Te_net, phi_net, coords):
    coords.requires_grad_(True)
    ne = ne_net(coords)
    phi = phi_net(coords)

    dt_ne = compute_grad(ne, coords)[:, 2]  # ∂ne/∂t
    # Placeholder residual, replace with full PDE terms
    residual = dt_ne + phi
    return torch.mean(residual ** 2)

def data_loss(pred, target):
    return torch.mean((pred - target) ** 2)

In [6]:
# Create networks
ne_net = PINN()
Te_net = PINN()
phi_net = PINN()

# Dummy data
xyt_obs = torch.rand(500, 3)
ne_obs = torch.sin(xyt_obs[:, 0])
xyt_phys = torch.rand(500, 3)

# Optimizer
optimizer = torch.optim.Adam(
    list(ne_net.parameters()) + 
    list(Te_net.parameters()) + 
    list(phi_net.parameters()), lr=1e-3
)

for epoch in range(1000):
    optimizer.zero_grad()
    ne_pred = ne_net(xyt_obs).squeeze()
    loss_d = data_loss(ne_pred, ne_obs)
    loss_p = physics_residual_ne(ne_net, Te_net, phi_net, xyt_phys)
    total_loss = loss_d + loss_p
    total_loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss.item():.4e}")

Epoch 0, Loss: 3.6717e-01
Epoch 100, Loss: 4.6088e-04
Epoch 200, Loss: 1.4676e-04
Epoch 300, Loss: 8.9036e-05
Epoch 400, Loss: 6.3831e-05
Epoch 500, Loss: 4.7203e-05
Epoch 600, Loss: 3.6377e-05
Epoch 700, Loss: 2.9317e-05
Epoch 800, Loss: 2.4488e-05
Epoch 900, Loss: 2.0926e-05
