### Simple PINN but now with $\Delta b$ as input

In [27]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
import xarray as xr


In [29]:
# Load dataset
ds = xr.open_dataset('../dat/RBC_Output.nc')

# Select depth (z) and horizontal (x) indices
z_vals = ds['z_aac'].values
x_vals = ds['x_caa'].values

top_100_z_idx = np.linspace(0, int(0.4 * len(z_vals)), 100, dtype=int)
z_sel_idx = top_100_z_idx[::2]  # 50 evenly spaced depth points
x_sel_idx = np.arange(0, len(x_vals), 2)  # 128 x points (even indices)

def slice_field(field, z_dim, x_dim):
    return field.isel({z_dim: z_sel_idx, x_dim: x_sel_idx}).values.astype(np.float32)

b = slice_field(ds['b'], 'z_aac', 'x_caa')
u = slice_field(ds['u'], 'z_aac', 'x_faa')
w = slice_field(ds['w'], 'z_aaf', 'x_caa')
p_dyn = slice_field(ds['p_dyn'], 'z_aac', 'x_caa')

time_vals = ds['time'].values.astype(np.float32)
depth_vals = ds['z_aac'].isel(z_aac=z_sel_idx).values
x_vals = ds['x_caa'].isel(x_caa=x_sel_idx).values


  ds = xr.open_dataset('../dat/RBC_Output.nc')


In [30]:
# Compute Δb
delta_b = b[1:] - b[:-1]
delta_t = time_vals[1:] - time_vals[:-1]
avg_time = (time_vals[1:] + time_vals[:-1]) / 2

u_mid = u[1:]
w_mid = w[1:]
p_mid = p_dyn[1:]

inputs = []
targets = []

for i in range(len(delta_b)):
    input_features = np.stack([u_mid[i], w_mid[i], p_mid[i]], axis=0)  # shape (3, z, x)
    inputs.append(input_features.flatten())
    targets.append(delta_b[i].flatten())

inputs = np.array(inputs)
targets = np.array(targets)

# Normalize inputs if needed
inputs_mean = inputs.mean(axis=0)
inputs_std = inputs.std(axis=0) + 1e-8
inputs_norm = (inputs - inputs_mean) / inputs_std

# Split train/test
train_idx = np.arange(6, 36)
test_idx = np.arange(36, 50)

X_train = torch.tensor(inputs_norm[train_idx], dtype=torch.float32).cuda()
Y_train = torch.tensor(targets[train_idx], dtype=torch.float32).cuda()

X_test = torch.tensor(inputs_norm[test_idx], dtype=torch.float32).cuda()
Y_test = torch.tensor(targets[test_idx], dtype=torch.float32).cuda()

z_len = len(z_sel_idx)
x_len = len(x_sel_idx)
dz = np.abs(depth_vals[1] - depth_vals[0])
dx = np.abs(x_vals[1] - x_vals[0])
κ = 1e-5  # Diffusivity


In [31]:
class PINN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=512, num_layers=4):
        super(PINN, self).__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.Tanh()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

model = PINN(input_dim=3*z_len*x_len, output_dim=z_len*x_len).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
mse = nn.MSELoss()


In [34]:
def compute_physics_loss(model, X_batch, z_len=z_len, x_len=x_len, κ=κ, dz=dz, dx=dx):
    # Predict Δb
    b_pred = model(X_batch).reshape(-1, z_len, x_len)

    # Input fields
    u = X_batch[:, :z_len * x_len].reshape(-1, z_len, x_len)
    w = X_batch[:, z_len * x_len : 2 * z_len * x_len].reshape(-1, z_len, x_len)

    # To compute central differences safely, crop all equally:
    b_center = b_pred[:, 1:-1, 1:-1]    # shape: (batch, z-2, x-2)
    u_center = u[:, 1:-1, 1:-1]         # shape: (batch, z-2, x-2)
    w_center = w[:, 1:-1, 1:-1]         # shape: (batch, z-2, x-2)

    # Gradients: match shapes to (batch, z-2, x-2)
    b_x = (b_pred[:, 1:-1, 2:] - b_pred[:, 1:-1, :-2]) / (2 * dx)
    b_z = (b_pred[:, 2:, 1:-1] - b_pred[:, :-2, 1:-1]) / (2 * dz)

    advection = u_center * b_x + w_center * b_z

    # Second derivatives (Laplacian)
    b_xx = (b_pred[:, 1:-1, 2:] - 2 * b_pred[:, 1:-1, 1:-1] + b_pred[:, 1:-1, :-2]) / (dx**2)
    b_zz = (b_pred[:, 2:, 1:-1] - 2 * b_pred[:, 1:-1, 1:-1] + b_pred[:, :-2, 1:-1]) / (dz**2)

    diffusion = κ * (b_xx + b_zz)

    residual = advection - diffusion
    phy_loss = torch.mean(residual**2)
    
    return phy_loss


In [38]:
num_epochs = 1000
obs_weight = 0.01

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    pred_train = model(X_train)
    obs_loss = mse(pred_train, Y_train)

    phy_loss = compute_physics_loss(model, X_train)
    ic_loss = torch.tensor(0.0, device=X_train.device)

    total_loss = phy_loss * 2 + ic_loss + obs_weight * obs_loss
    total_loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} | Obs Loss: {obs_loss.item():.6f} | Phy Loss: {phy_loss.item():.6f} | Total: {total_loss.item():.6f}")


Epoch 0 | Obs Loss: 0.000005 | Phy Loss: 0.000080 | Total: 0.000160
Epoch 100 | Obs Loss: 0.000005 | Phy Loss: 0.000118 | Total: 0.000235
Epoch 200 | Obs Loss: 0.000004 | Phy Loss: 0.000058 | Total: 0.000117
Epoch 300 | Obs Loss: 0.000004 | Phy Loss: 0.000042 | Total: 0.000083
Epoch 400 | Obs Loss: 0.000003 | Phy Loss: 0.000032 | Total: 0.000064
Epoch 500 | Obs Loss: 0.000003 | Phy Loss: 0.000026 | Total: 0.000051
Epoch 600 | Obs Loss: 0.000003 | Phy Loss: 0.000021 | Total: 0.000042
Epoch 700 | Obs Loss: 0.000003 | Phy Loss: 0.000387 | Total: 0.000774
Epoch 800 | Obs Loss: 0.000003 | Phy Loss: 0.000199 | Total: 0.000398
Epoch 900 | Obs Loss: 0.000002 | Phy Loss: 0.000013 | Total: 0.000026


In [39]:
# Save the model
torch.save(model.state_dict(), 'pinn_2Dprofile_model.pth')

In [41]:
model.eval()
with torch.no_grad():
    pred_test = model(X_test).cpu().numpy().reshape(len(X_test), z_len, x_len)
    true_test = Y_test.cpu().numpy().reshape(len(Y_test), z_len, x_len)

z_norm = (depth_vals - depth_vals.min()) / (depth_vals.max() - depth_vals.min())
x_norm = (x_vals - x_vals.min()) / (x_vals.max() - x_vals.min())

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ims = []

for i in range(len(pred_test)):
    im1 = axes[0].imshow(true_test[i], cmap='RdBu_r', origin='lower', extent=[0, 1, 0, 1], animated=True)
    axes[0].set_title('True b')
    im2 = axes[1].imshow(pred_test[i], cmap='RdBu_r', origin='lower', extent=[0, 1, 0, 1], animated=True)
    axes[1].set_title('Predicted b')
    ims.append([im1, im2])

for ax in axes:
    ax.set_xlabel('x (normalized)')
    ax.set_ylabel('z (normalized)')

ani = animation.ArtistAnimation(fig, ims, interval=300, blit=True)
plt.close(fig)
display(HTML(ani.to_jshtml()))
