## Yuhao Test  with same PDE 

In [None]:
import pandas as pd
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# ------------------------ 数据加载与预处理 ------------------------

def prepare_data(file_path, batch_size=32, test_size=0.2, group_key='mydata'):
    df = pd.read_hdf(file_path, group_key)
    n_samples, T = len(df), 180

    X = np.zeros((n_samples, T, 2), dtype=np.float32)
    for i, row in df.iterrows():
        X[i, :, 0] = np.array(row['freq'], dtype=np.float32)
        X[i, :, 1] = np.array(row['m_c'], dtype=np.float32)
    Y = df[['w_t', 'l_t', 'Q', 'V']].values.astype(np.float32)

    Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=test_size, random_state=42)

    sf = StandardScaler(); sm = StandardScaler(); st = StandardScaler()
    flat_tr = Xtr.reshape(-1, 2); flat_te = Xte.reshape(-1, 2)
    flat_tr[:, 0] = sf.fit_transform(flat_tr[:, 0:1]).ravel()
    flat_tr[:, 1] = sm.fit_transform(flat_tr[:, 1:2]).ravel()
    flat_te[:, 0] = sf.transform(flat_te[:, 0:1]).ravel()
    flat_te[:, 1] = sm.transform(flat_te[:, 1:2]).ravel()
    Xtr = flat_tr.reshape(-1, T, 2)
    Xte = flat_te.reshape(-1, T, 2)
    Ytr = st.fit_transform(Ytr)
    Yte = st.transform(Yte)

    tr_loader = DataLoader(
        TensorDataset(torch.from_numpy(Xtr), torch.from_numpy(Ytr)),
        batch_size=batch_size, shuffle=True)
    te_loader = DataLoader(
        TensorDataset(torch.from_numpy(Xte), torch.from_numpy(Yte)),
        batch_size=batch_size)

    return tr_loader, te_loader, sf, sm, st

# ------------------------ 物理常数（部分可调整） ------------------------

def get_constants(device):
    return {
        'E': torch.tensor(169e9, dtype=torch.float32, device=device),
        'rho': torch.tensor(2330, dtype=torch.float32, device=device),
        't': torch.tensor(25e-6, dtype=torch.float32, device=device),
        'beta': torch.tensor(4.730041, dtype=torch.float32, device=device),
        'm_coef_b': 0.39648,
        'k_coef_b': 198.4629,
        'k_coef_b3': 12.5643,
        'electrode_length': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'electrode_width': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'w_c': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'l_c': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'd': torch.tensor(1e-6, dtype=torch.float32, device=device),
        'Vac_ground': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'phi': torch.linspace(
            math.pi/10,
            math.pi/10 + math.pi/160,
            180, device=device)
    }



In [None]:
# ------------------------ 神经网络模型 ------------------------

class PINN_MultiHead(nn.Module):
    def __init__(self, hidden_dim=256, T=180):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*T, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim//2), nn.ReLU(),
            nn.Linear(hidden_dim//2, hidden_dim//4), nn.ReLU()
        )
        self.heads = nn.ModuleList([nn.Linear(hidden_dim//4, 1) for _ in range(4)])

    def forward(self, x):
        z = self.trunk(x)
        return torch.cat([h(z) for h in self.heads], dim=1)

# ------------------------ 参数反归一化工具 ------------------------

def calculate_params_vectorized(w_n, l_n, Q_n, V_n, consts, norm_sfs):
    _, _, st = norm_sfs
    out_mean = torch.tensor(st.mean_, device=w_n.device, dtype=torch.float32)
    out_std  = torch.tensor(st.scale_, device=w_n.device, dtype=torch.float32)
    w = w_n * out_std[0] + out_mean[0]
    l = l_n * out_std[1] + out_mean[1]
    Q = Q_n * out_std[2] + out_mean[2]
    V = V_n * out_std[3] + out_mean[3]
    eps = 1e-9

    k_t  = (consts['k_coef_b'] / 12) * consts['E'] * consts['t'] * (w / (l + eps))**3
    k_t3 =  consts['k_coef_b3'] * consts['E'] * consts['t'] * (w / (l + eps)**3)
    Mass = consts['rho'] * (
        consts['t'] * w * l * consts['m_coef_b']
        + consts['electrode_length'] * consts['electrode_width'] * consts['t']
        + 2 * consts['w_c'] * consts['l_c'] * consts['t']
    )
    trans = 8.85e-12 * V * consts['electrode_length'] * consts['t'] / (consts['d']**2)
    k_e  = 2 * trans * V / consts['d']
    k_e3 = 4 * trans * V / (consts['d']**3)
    fac  = consts['Vac_ground'] * trans
    c    = torch.sqrt(Mass * k_t) / Q

    return {
        'k_t':   k_t, 'k_t3':  k_t3, 'Mass':  Mass, 'k_e':   k_e, 'k_e3':  k_e3,
        'fac':   fac, 'c':     c,    'trans': trans, 'phi':  consts['phi']
    }



In [None]:
# ------------------------ 物理PDE损失函数（核心！） ------------------------

def spring_mass_damper_pde_loss(pred, Xb, constants, norm_sfs, denom_clamp=1e-4, loss_mode='clamp', clamp_val=1e-3):
    """
    对应非线性弹簧-质量-阻尼动力学方程的物理loss，适用于PINN
    """
    sf, sm, st = norm_sfs
    freq  = Xb[:,:,0] * sf.scale_[0] + sf.mean_[0]
    mc    = Xb[:,:,1] * sm.scale_[0] + sm.mean_[0]
    omega = freq * (2 * math.pi)

    w_n, l_n, Q_n, V_n = pred[:,0], pred[:,1], pred[:,2], pred[:,3]
    p = calculate_params_vectorized(w_n, l_n, Q_n, V_n, constants, norm_sfs)

    denom = (omega * p['trans'].unsqueeze(1)).clamp(min=denom_clamp)
    A = (mc * 1e-9) / denom

    kt, ke, kt3, ke3, mass, c, fac = [p[k].unsqueeze(1) for k in ['k_t','k_e','k_t3','k_e3','Mass','c','fac']]
    phi = p['phi'].unsqueeze(0)

    # 图中公式的sin/cos项残差
    sin_res = -mass * A * omega**2 + (kt - ke) * A + (3/4) * (kt3 - ke3) * (A**3) - fac * torch.cos(phi)
    cos_res = c * A * omega - fac * torch.sin(phi)

    res_all = sin_res.pow(2) + cos_res.pow(2)
    if loss_mode == 'clamp':
        loss = res_all.clamp(max=clamp_val).mean()
    else:
        loss = res_all.mean()
    return loss

# ------------------------ 综合loss封装 ------------------------

def pinn_total_loss(pred, Yb, Xb, constants, norm_sfs, 
                    lambda_data=1.0, lambda_phys=1.0,
                    denom_clamp=1e-4, loss_mode='clamp', clamp_val=1e-3):
    """
    总损失 = 数据MSE + 物理PDE残差loss
    """
    L_data = nn.MSELoss()(pred, Yb)
    L_phys = spring_mass_damper_pde_loss(pred, Xb, constants, norm_sfs, denom_clamp, loss_mode, clamp_val)
    L_total = lambda_data * L_data + lambda_phys * L_phys
    return L_total, L_data, L_phys



In [None]:
# ------------------------ 训练与验证循环 ------------------------

def train_one_epoch(model, loader, optimizer, constants, norm_sfs, 
                   lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val, device):
    model.train()
    total_loss, total_data_loss, total_phys_loss = 0.0, 0.0, 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        optimizer.zero_grad()
        pred = model(Xb)
        L_total, L_data, L_phys = pinn_total_loss(
            pred, Yb, Xb, constants, norm_sfs, lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val
        )
        L_total.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += L_total.item()
        total_data_loss += L_data.item()
        total_phys_loss += L_phys.item()
    n = len(loader)
    return total_loss / n, total_data_loss / n, total_phys_loss / n

@torch.no_grad()
def validate_one_epoch(model, loader, constants, norm_sfs,
                      lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val, device):
    model.eval()
    total_loss, total_data_loss, total_phys_loss = 0.0, 0.0, 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        pred = model(Xb)
        L_total, L_data, L_phys = pinn_total_loss(
            pred, Yb, Xb, constants, norm_sfs, lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val
        )
        total_loss += L_total.item()
        total_data_loss += L_data.item()
        total_phys_loss += L_phys.item()
    n = len(loader)
    return total_loss / n, total_data_loss / n, total_phys_loss / n

# ------------------------ 主训练流程 ------------------------

def train_pinn(file_path, device, 
               lr=1e-3, weight_decay=1e-5, hidden_dim=256,
               lambda_data=1.0, lambda_phys=1.0, 
               denom_clamp=1e-4, loss_mode='clamp', clamp_val=1e-3,
               batch_size=32, epochs=500, patience=20):
    tr_loader, te_loader, sf, sm, st = prepare_data(file_path, batch_size=batch_size)
    norm_sfs = (sf, sm, st)
    constants = get_constants(device)

    model = PINN_MultiHead(hidden_dim=hidden_dim, T=180).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    history = {'train_loss': [], 'val_loss': [], 'train_data': [], 'val_data': [], 'train_phys': [], 'val_phys': []}
    best_val = float('inf'); no_improv = 0

    for epoch in range(epochs):
        tr_loss, tr_data, tr_phys = train_one_epoch(
            model, tr_loader, optimizer, constants, norm_sfs, lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val, device)
        val_loss, val_data, val_phys = validate_one_epoch(
            model, te_loader, constants, norm_sfs, lambda_data, lambda_phys, denom_clamp, loss_mode, clamp_val, device)
        scheduler.step(val_loss)
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_data'].append(tr_data)
        history['val_data'].append(val_data)
        history['train_phys'].append(tr_phys)
        history['val_phys'].append(val_phys)
        print(f"Epoch {epoch+1}/{epochs} | Train: {tr_loss:.4g} (data {tr_data:.4g}, phys {tr_phys:.4g}) "
              f"| Val: {val_loss:.4g} (data {val_data:.4g}, phys {val_phys:.4g})")
        if val_loss < best_val - 1e-5:
            best_val = val_loss
            no_improv = 0
            torch.save(model.state_dict(), 'best_pinn.pth')
        else:
            no_improv += 1
            if no_improv > patience:
                print("Early stopping!")
                break
    return model, history



In [None]:
# ------------------------ 可视化工具 ------------------------

def plot_training_history(history):
    epochs = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 2, 1)
    plt.plot(epochs, history['train_loss'], label='Train')
    plt.plot(epochs, history['val_loss'], label='Val')
    plt.title('Total Loss'); plt.legend(); plt.grid()
    plt.subplot(2, 2, 2)
    plt.plot(epochs, history['train_data'], label='Train Data')
    plt.plot(epochs, history['val_data'], label='Val Data')
    plt.title('Data Loss'); plt.legend(); plt.grid()
    plt.subplot(2, 2, 3)
    plt.plot(epochs, history['train_phys'], label='Train Phys')
    plt.plot(epochs, history['val_phys'], label='Val Phys')
    plt.title('Physics Loss'); plt.legend(); plt.grid()
    plt.tight_layout(); plt.show()

# ------------------------ 主程序入口 ------------------------



In [None]:
if __name__ == "__main__":
    # 你的.h5文件路径
    file_path = 'pinns_1.h5'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, history = train_pinn(
        file_path=file_path,
        device=device,
        lr=1e-3,
        weight_decay=1e-5,
        hidden_dim=256,
        lambda_data=1.0,
        lambda_phys=1.0,
        denom_clamp=1e-4,
        loss_mode='clamp',
        clamp_val=1e-3,
        batch_size=32,
        epochs=200,
        patience=20,
    )
    plot_training_history(history)


## 更改 v1

In [37]:
import pandas as pd
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# ------------------------ 数据加载与预处理（nan保留） ------------------------

def prepare_data(file_path, batch_size=32, test_size=0.2, group_key='mydata'):
    df = pd.read_hdf(file_path, group_key)
    df = df[df[['w_t','l_t','Q','V']].notnull().all(axis=1)].reset_index(drop=True)
    n_samples, T = len(df), 180

    X = np.zeros((n_samples, T, 2), dtype=np.float32)
    for i, row in df.iterrows():
        X[i, :, 0] = np.array(row['freq'], dtype=np.float32)  # 这里允许nan
        X[i, :, 1] = np.array(row['m_c'], dtype=np.float32)
    Y = df[['w_t', 'l_t', 'Q', 'V']].values.astype(np.float32)

    Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=test_size, random_state=42)

    # 标准化freq和m_c时跳过nan
    def nan_scale_fit(x): return (np.nanmean(x), np.nanstd(x) + 1e-12)
    freq_mean, freq_std = nan_scale_fit(Xtr[..., 0])
    mc_mean, mc_std     = nan_scale_fit(Xtr[..., 1])
    for arr in [Xtr, Xte]:
        arr[..., 0] = (arr[..., 0] - freq_mean) / freq_std
        arr[..., 1] = (arr[..., 1] - mc_mean)   / mc_std

    # 目标归一化不涉及nan
    y_mean, y_std = np.mean(Ytr, axis=0), np.std(Ytr, axis=0) + 1e-12
    Ytr = (Ytr - y_mean) / y_std
    Yte = (Yte - y_mean) / y_std

    tr_loader = DataLoader(
        TensorDataset(torch.from_numpy(Xtr), torch.from_numpy(Ytr)),
        batch_size=batch_size, shuffle=True)
    te_loader = DataLoader(
        TensorDataset(torch.from_numpy(Xte), torch.from_numpy(Yte)),
        batch_size=batch_size)
    # 用标量均值/方差对象返回，方便后续反归一化
    norm_dict = {
        'freq': (freq_mean, freq_std),
        'mc':   (mc_mean, mc_std),
        'y':    (y_mean, y_std)
    }
    return tr_loader, te_loader, norm_dict

# ------------------------ 物理常数 ------------------------

def get_constants(device):
    return {
        'E': torch.tensor(169e9, dtype=torch.float32, device=device),
        'rho': torch.tensor(2330, dtype=torch.float32, device=device),
        't': torch.tensor(25e-6, dtype=torch.float32, device=device),
        'beta': torch.tensor(4.730041, dtype=torch.float32, device=device),
        'm_coef_b': 0.39648,
        'k_coef_b': 198.4629,
        'k_coef_b3': 12.5643,
        'electrode_length': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'electrode_width': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'w_c': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'l_c': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'd': torch.tensor(1e-6, dtype=torch.float32, device=device),
        'Vac_ground': torch.tensor(1e-3, dtype=torch.float32, device=device),
        'phi': torch.linspace(
            math.pi/10,
            math.pi/10 + math.pi/160,
            180, device=device)
    }

# ------------------------ 网络结构 ------------------------

class PINN_MultiHead(nn.Module):
    def __init__(self, hidden_dim=256, T=180):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*T, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim//2), nn.ReLU(),
            nn.Linear(hidden_dim//2, hidden_dim//4), nn.ReLU()
        )
        self.heads = nn.ModuleList([nn.Linear(hidden_dim//4, 1) for _ in range(4)])

    def forward(self, x):
        z = self.trunk(x)
        return torch.cat([h(z) for h in self.heads], dim=1)

# ------------------------ 参数反归一化 ------------------------

def denorm_pred(pred, norm_dict):
    y_mean = torch.tensor(norm_dict['y'][0], device=pred.device, dtype=torch.float32)
    y_std  = torch.tensor(norm_dict['y'][1], device=pred.device, dtype=torch.float32)
    return pred * y_std + y_mean

# ------------------------ nan mask数据loss ------------------------

def robust_mse(pred, target):
    mask = torch.isfinite(pred) & torch.isfinite(target)
    n_valid = mask.sum()
    if n_valid == 0:
        return torch.tensor(0., device=pred.device)
    diff = (pred - target)[mask]
    return (diff ** 2).mean()

# ------------------------ nan mask物理loss ------------------------

def physics_loss_with_nan_penalty(freq_res, disp_res, penalty=1.0):
    valid = torch.isfinite(freq_res) & torch.isfinite(disp_res)
    f = torch.where(valid, freq_res, torch.zeros_like(freq_res))
    d = torch.where(valid, disp_res, torch.zeros_like(disp_res))
    sq = f.pow(2) + d.pow(2)
    sum_sq = sq.sum(dim=1)
    n_valid = valid.sum(dim=1).float()
    T = freq_res.size(1)
    n_invalid = T - n_valid
    avg_valid = sum_sq / torch.clamp(n_valid, min=1.0)
    frac_invalid = n_invalid / T
    loss_per_sample = avg_valid + penalty * frac_invalid
    return loss_per_sample.mean()

# ------------------------ 物理PDE残差 ------------------------

def spring_mass_damper_pde_residuals(pred, Xb, norm_dict, consts, denom_clamp=1e-4):
    # 反标准化freq/mc
    freq_mean, freq_std = norm_dict['freq']
    mc_mean, mc_std = norm_dict['mc']
    freq = Xb[:,:,0] * freq_std + freq_mean
    mc   = Xb[:,:,1] * mc_std   + mc_mean
    omega = freq * (2 * math.pi)

    # 反归一化 pred
    w_n, l_n, Q_n, V_n = [pred[:,i] for i in range(4)]
    y_mean, y_std = norm_dict['y']
    w = w_n * y_std[0] + y_mean[0]
    l = l_n * y_std[1] + y_mean[1]
    Q = Q_n * y_std[2] + y_mean[2]
    V = V_n * y_std[3] + y_mean[3]

    eps = 1e-9
    k_t  = (consts['k_coef_b'] / 12) * consts['E'] * consts['t'] * (w / (l + eps))**3
    k_t3 =  consts['k_coef_b3'] * consts['E'] * consts['t'] * (w / (l + eps)**3)
    Mass = consts['rho'] * (
        consts['t'] * w * l * consts['m_coef_b']
        + consts['electrode_length'] * consts['electrode_width'] * consts['t']
        + 2 * consts['w_c'] * consts['l_c'] * consts['t']
    )
    trans = 8.85e-12 * V * consts['electrode_length'] * consts['t'] / (consts['d']**2)
    k_e  = 2 * trans * V / consts['d']
    k_e3 = 4 * trans * V / (consts['d']**3)
    fac  = consts['Vac_ground'] * trans
    c    = torch.sqrt(Mass * k_t) / Q

    denom = (omega * trans.unsqueeze(1)).clamp(min=denom_clamp)
    A = (mc * 1e-9) / denom

    # 扩展物理参数形状
    kt, ke, kt3, ke3, mass, c_, fac_ = [x.unsqueeze(1) for x in [k_t,k_e,k_t3,k_e3,Mass,c,fac]]
    phi = consts['phi'].unsqueeze(0)

    sin_res = -mass * A * omega**2 + (kt - ke) * A + (3/4) * (kt3 - ke3) * (A**3) - fac_ * torch.cos(phi)
    cos_res = c_ * A * omega - fac_ * torch.sin(phi)
    return cos_res, sin_res

# ------------------------ 综合loss ------------------------

def pinn_total_loss(pred, Yb, Xb, norm_dict, consts, penalty=1.0, denom_clamp=1e-4):
    L_data = robust_mse(pred, Yb)
    freq_res, disp_res = spring_mass_damper_pde_residuals(pred, Xb, norm_dict, consts, denom_clamp)
    L_phys = physics_loss_with_nan_penalty(freq_res, disp_res, penalty=penalty)
    L_total = L_data + L_phys
    return L_total, L_data, L_phys

# ------------------------ 训练与验证循环 ------------------------

def train_one_epoch(model, loader, optimizer, norm_dict, consts, penalty, denom_clamp, device):
    model.train()
    total_loss, total_data_loss, total_phys_loss = 0.0, 0.0, 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        optimizer.zero_grad()
        pred = model(Xb)
        L_total, L_data, L_phys = pinn_total_loss(
            pred, Yb, Xb, norm_dict, consts, penalty, denom_clamp
        )
        L_total.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += L_total.item()
        total_data_loss += L_data.item()
        total_phys_loss += L_phys.item()
    n = len(loader)
    return total_loss / n, total_data_loss / n, total_phys_loss / n

@torch.no_grad()
def validate_one_epoch(model, loader, norm_dict, consts, penalty, denom_clamp, device):
    model.eval()
    total_loss, total_data_loss, total_phys_loss = 0.0, 0.0, 0.0
    for Xb, Yb in loader:
        Xb, Yb = Xb.to(device), Yb.to(device)
        pred = model(Xb)
        L_total, L_data, L_phys = pinn_total_loss(
            pred, Yb, Xb, norm_dict, consts, penalty, denom_clamp
        )
        total_loss += L_total.item()
        total_data_loss += L_data.item()
        total_phys_loss += L_phys.item()
    n = len(loader)
    return total_loss / n, total_data_loss / n, total_phys_loss / n

# ------------------------ 主训练流程 ------------------------

def train_pinn(file_path, device, 
               lr=1e-3, weight_decay=1e-5, hidden_dim=256,
               penalty=1.0, denom_clamp=1e-4,
               batch_size=32, epochs=500, patience=20):
    tr_loader, te_loader, norm_dict = prepare_data(file_path, batch_size=batch_size)
    consts = get_constants(device)

    model = PINN_MultiHead(hidden_dim=hidden_dim, T=180).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    history = {'train_loss': [], 'val_loss': [], 'train_data': [], 'val_data': [], 'train_phys': [], 'val_phys': []}
    best_val = float('inf'); no_improv = 0

    for epoch in range(epochs):
        tr_loss, tr_data, tr_phys = train_one_epoch(
            model, tr_loader, optimizer, norm_dict, consts, penalty, denom_clamp, device)
        val_loss, val_data, val_phys = validate_one_epoch(
            model, te_loader, norm_dict, consts, penalty, denom_clamp, device)
        scheduler.step(val_loss)
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_data'].append(tr_data)
        history['val_data'].append(val_data)
        history['train_phys'].append(tr_phys)
        history['val_phys'].append(val_phys)
        print(f"Epoch {epoch+1}/{epochs} | Train: {tr_loss:.4g} (data {tr_data:.4g}, phys {tr_phys:.4g}) "
              f"| Val: {val_loss:.4g} (data {val_data:.4g}, phys {val_phys:.4g})")
        if val_loss < best_val - 1e-5:
            best_val = val_loss
            no_improv = 0
            torch.save(model.state_dict(), 'best_pinn.pth')
        else:
            no_improv += 1
            if no_improv > patience:
                print("Early stopping!")
                break
    return model, history

# ------------------------ 可视化 ------------------------

def plot_training_history(history):
    epochs = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 2, 1)
    plt.plot(epochs, history['train_loss'], label='Train')
    plt.plot(epochs, history['val_loss'], label='Val')
    plt.title('Total Loss'); plt.legend(); plt.grid()
    plt.subplot(2, 2, 2)
    plt.plot(epochs, history['train_data'], label='Train Data')
    plt.plot(epochs, history['val_data'], label='Val Data')
    plt.title('Data Loss'); plt.legend(); plt.grid()
    plt.subplot(2, 2, 3)
    plt.plot(epochs, history['train_phys'], label='Train Phys')
    plt.plot(epochs, history['val_phys'], label='Val Phys')
    plt.title('Physics Loss'); plt.legend(); plt.grid()
    plt.tight_layout(); plt.show()

# ------------------------ 主程序入口 ------------------------

if __name__ == "__main__":
    file_path = 'pinns_1.h5'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, history = train_pinn(
        file_path=file_path,
        device=device,
        lr=1e-3,
        weight_decay=1e-5,
        hidden_dim=256,
        penalty=1.0,
        denom_clamp=1e-4,
        batch_size=32,
        epochs=200,
        patience=20,
    )
    plot_training_history(history)


Epoch 1/200 | Train: 0.9991 (data 0.006438, phys 0.9926) | Val: 1 (data 0, phys 1)
Epoch 2/200 | Train: 1 (data 0, phys 1) | Val: 1 (data 0, phys 1)
Epoch 3/200 | Train: 1 (data 0, phys 1) | Val: 1 (data 0, phys 1)


KeyboardInterrupt: 