In [13]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# ─── 1) Define PINN Architecture ─────────────────────────────────────────────
L, W = 1.0, 0.5

class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 128),
            nn.Softplus(beta=10),
            nn.Linear(128, 256),
            nn.Softplus(beta=10),
            nn.Linear(256, 256),
            nn.Softplus(beta=10),
            nn.Linear(256, 256),
            nn.Softplus(beta=10),
            nn.Linear(256, 6)
        )

    def forward(self, x, y):
        xi  = 2.0 * x / L
        eta = 2.0 * y / W
        return self.net(torch.cat([xi, eta], dim=1))

# ─── 2) Load Trained Weights ─────────────────────────────────────────────────
model = PINN()
state = torch.load("pinn_elasticity.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()

# ─── 3) Load Actual Displacements from COMSOL CSV ──────────────────────────
csv_path = '/Users/murat/Downloads/data.csv'  # keep filename unchanged
df = pd.read_csv(csv_path, comment='%')

# Columns: X, Y, u1 (cm), u2 (cm), ...
X = df['X'].values
Y = df['Y'].values
u_act = df['u1 (cm)'].values  # u to compare
v_act = df['u2 (cm)'].values  # v to compare

# ─── 4) Prepare Input Tensors ───────────────────────────────────────────────
x_t = torch.tensor(X, dtype=torch.float32).reshape(-1,1)
y_t = torch.tensor(Y, dtype=torch.float32).reshape(-1,1)

# ─── 5) Forward Pass: Predict u, v ─────────────────────────────────────────
out = model(x_t, y_t)
# network outputs: [u_pred, v_pred, u_x, u_y, v_x, v_y]
u_pred = out[:,0].detach().numpy().flatten()
v_pred = out[:,1].detach().numpy().flatten()

df['u_pred'] = u_pred
df['v_pred'] = v_pred

# ─── 6) Compute Absolute & Percent Errors ───────────────────────────────────
df['err_u'] = np.abs(u_act - df['u_pred'])
df['err_v'] = np.abs(v_act - df['v_pred'])

df['pct_err_u'] = np.where(
    np.isclose(u_act, 0),
    0,
    df['err_u'] / np.abs(u_act) * 100
)
df['pct_err_v'] = np.where(
    np.isclose(v_act, 0),
    0,
    df['err_v'] / np.abs(v_act) * 100
)

# ─── 7) Display Percent Errors (printf-style) ───────────────────────────────
print(f"Mean % error in u: {df['pct_err_u'].mean():.2f}%  (max {df['pct_err_u'].max():.2f}%)")
print(f"Mean % error in v: {df['pct_err_v'].mean():.2f}%  (max {df['pct_err_v'].max():.2f}%)")

for i, (xi, yi, pu, pv) in enumerate(zip(X, Y, df['pct_err_u'], df['pct_err_v'])):
    print(f"Point {i}: X={xi:.3f}, Y={yi:.3f} → %err u={pu:.2f}%, v={pv:.2f}%")

# ─── 8) (Optional) Save comparison to CSV ─────────────────────────────────
# df.to_csv('pinn_uv_comparison.csv', index=False)

# ─── 9) Plotting Helpers ───────────────────────────────────────────────────
triang = mtri.Triangulation(X, Y)

def plot_contour(field, title, units=''):
    fig, ax = plt.subplots(figsize=(6,5))
    cf = ax.tricontourf(triang, field, levels=100, cmap='viridis')
    fig.colorbar(cf, ax=ax, label=units)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(title)
    ax.set_aspect('equal', 'box')
    plt.tight_layout()
    plt.show()

# ─── 10) Visualize Absolute Errors ───────────────────────────────────────────
plot_contour(df['err_u'], 'Absolute Error in u', '|u_act - u_pred|')
plot_contour(df['err_v'], 'Absolute Error in v', '|v_act - v_pred|')

# ─── 11) Parity Plots ──────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12,5))
axes[0].scatter(df['u1 (cm)'], df['u_pred'], s=10)
axes[0].plot([
    df['u1 (cm)'].min(), df['u1 (cm)'].max()
], [
    df['u1 (cm)'].min(), df['u1 (cm)'].max()
], 'k--')
axes[0].set_xlabel('u1 (cm)')
axes[0].set_ylabel('u_pred (cm)')
axes[0].set_title('Parity Plot: u')

axes[1].scatter(df['u2 (cm)'], df['v_pred'], s=10)
axes[1].plot([
    df['u2 (cm)'].min(), df['u2 (cm)'].max()
], [
    df['u2 (cm)'].min(), df['u2 (cm)'].max()
], 'k--')
axes[1].set_xlabel('u2 (cm)')
axes[1].set_ylabel('v_pred (cm)')
axes[1].set_title('Parity Plot: v')

plt.tight_layout()
plt.show()


KeyError: '% X'