In [3]:
import torch
import torch.nn as nn
import numpy as np

# Synthetic DEM (replace with real data): a 50 x 50 grid covering [0, 100] x [0, 100]
# whose elevations are uniform random values in [0, 1).
x_grid, y_grid = np.linspace(0, 100, 50), np.linspace(0, 100, 50)
z_grid = torch.tensor(np.random.rand(50, 50), dtype=torch.float32)  # swap in the real DEM here

# Bilinear interpolation of the DEM
def get_z(x, y, x_grid, y_grid, z_grid):
    """Bilinearly interpolate DEM elevations at the query points (x, y).

    Args:
        x, y: Query coordinates, tensors of shape [N] or [N, 1]; a trailing
            singleton dimension is squeezed away. The result is differentiable
            with respect to x and y (through the interpolation weights), so
            autograd yields the local terrain slope.
        x_grid, y_grid: 1-D node coordinates of the DEM columns and rows
            (NumPy arrays or tensors). Nodes are assumed evenly spaced with at
            least two per axis; ascending or descending order both work.
        z_grid: Elevations indexed as z_grid[row, col], where row follows
            y_grid and col follows x_grid (NumPy array or tensor).

    Returns:
        Tensor of interpolated elevations with the (squeezed) shape of x.
        Points outside the grid are linearly extrapolated from the nearest
        border cell.
    """
    x = x.squeeze(-1)  # [N] or [N, 1] -> [N]
    y = y.squeeze(-1)  # [N] or [N, 1] -> [N]

    # Bring the grid onto the query dtype/device once (no NumPy round-trip per call,
    # and torch or NumPy grids are both accepted).
    xg = torch.as_tensor(x_grid, dtype=x.dtype, device=x.device)
    yg = torch.as_tensor(y_grid, dtype=y.dtype, device=y.device)
    zg = torch.as_tensor(z_grid, dtype=x.dtype, device=x.device)
    nx, ny = xg.numel(), yg.numel()

    # Fractional node index derived from the grid's own extent, instead of
    # assuming the domain is exactly [0, 100].
    x_idx = (x - xg[0]) / (xg[-1] - xg[0]) * (nx - 1)
    y_idx = (y - yg[0]) / (yg[-1] - yg[0]) * (ny - 1)

    # Lower corner of the enclosing cell; clamping keeps out-of-range points on
    # the border cell so that x1/y1 stay valid indices.
    x0 = torch.floor(x_idx).long().clamp(0, nx - 2)
    x1 = x0 + 1
    y0 = torch.floor(y_idx).long().clamp(0, ny - 2)
    y1 = y0 + 1

    # Interpolation weights (the only path through which gradients reach x and y).
    wx = (x - xg[x0]) / (xg[x1] - xg[x0])
    wy = (y - yg[y0]) / (yg[y1] - yg[y0])

    # Cell corner elevations: rows follow y, columns follow x.
    z00 = zg[y0, x0]
    z01 = zg[y0, x1]
    z10 = zg[y1, x0]
    z11 = zg[y1, x1]

    return (
        (1 - wx) * (1 - wy) * z00
        + wx * (1 - wy) * z01
        + (1 - wx) * wy * z10
        + wx * wy * z11
    )

# Neural network
class PINN(nn.Module):
    """Fully connected network taking (t, x, y) and producing 2 output values.

    Layout: Linear(3, 32) -> tanh -> Linear(32, 32) -> tanh -> Linear(32, 2).
    """

    def __init__(self):
        super().__init__()
        width = 32
        self.net = nn.Sequential(
            nn.Linear(3, width), nn.Tanh(),
            nn.Linear(width, width), nn.Tanh(),
            nn.Linear(width, 2),
        )

    def forward(self, t, x, y):
        """Run the network on (t, x, y), each given as [N] or [N, 1]; returns [N, 2]."""
        features = torch.stack([v.squeeze(-1) for v in (t, x, y)], dim=-1)
        return self.net(features)

# Physics loss
def physics_loss(model, t, x, y, x_grid, y_grid, z_grid, g=9.81):
    """Mean squared residual of terrain-driven equations of motion.

    Interpretation (NOTE(review): confirm against the intended formulation):
    model(t, x, y) returns the position (x_pred, y_pred) at time t of a
    particle whose trajectory is labelled by the collocation coordinates (x, y).
    The enforced residual is

        d2 x_pred / dt2 = -g * dz/dx evaluated at (x_pred, y_pred)
        d2 y_pred / dt2 = -g * dz/dy evaluated at (x_pred, y_pred)

    i.e. the slope is sampled where the particle currently is, not at the
    collocation input (which would make the forcing constant in time).

    Args:
        model: Callable mapping (t, x, y) to an [N, 2] tensor of positions.
        t, x, y: Collocation inputs (e.g. [N, 1] tensors). They are not
            modified in place.
        x_grid, y_grid, z_grid: DEM passed through to get_z.
        g: Gravitational acceleration (units assumed consistent with the
            DEM and time axis).

    Returns:
        Scalar tensor: mean of the squared x and y residuals.
    """
    # Fresh leaf for time so the caller's tensor is not flagged requires_grad in
    # place (which made every backward() accumulate an unused .grad into it).
    t = t.detach().requires_grad_(True)

    xy_pred = model(t, x, y)
    x_pred, y_pred = xy_pred[:, 0], xy_pred[:, 1]

    # Differentiating the summed outputs gives per-sample derivatives because
    # each output row depends only on its own input row.
    vx = torch.autograd.grad(x_pred, t, grad_outputs=torch.ones_like(x_pred), create_graph=True)[0]
    ax = torch.autograd.grad(vx, t, grad_outputs=torch.ones_like(vx), create_graph=True)[0]
    vy = torch.autograd.grad(y_pred, t, grad_outputs=torch.ones_like(y_pred), create_graph=True)[0]
    ay = torch.autograd.grad(vy, t, grad_outputs=torch.ones_like(vy), create_graph=True)[0]
    # Flatten to [N] so the residual below never broadcasts [N, 1] against [N].
    ax = ax.reshape(-1)
    ay = ay.reshape(-1)

    # Terrain slope at the predicted position; column 0 is dz/dx, column 1 is dz/dy.
    z = get_z(x_pred, y_pred, x_grid, y_grid, z_grid)
    grad_z = torch.autograd.grad(z, xy_pred, grad_outputs=torch.ones_like(z), create_graph=True)[0]
    dz_dx, dz_dy = grad_z[:, 0], grad_z[:, 1]

    fx = -g * dz_dx
    fy = -g * dz_dy
    return torch.mean((ax - fx) ** 2 + (ay - fy) ** 2)

# Data loss
def data_loss(model, t, x, y, x0, y0):
    """Mean squared error between the model's predicted positions and targets.

    Args:
        model: Callable mapping (t, x, y) to an [N, 2] tensor (x_pred, y_pred).
        t, x, y: Model inputs.
        x0, y0: Target coordinates with N elements (e.g. [N] or [N, 1]), or a
            single element that is broadcast to all N predictions.

    Returns:
        Scalar tensor: mean over samples of (x_pred - x0)^2 + (y_pred - y0)^2.
    """
    xy_pred = model(t, x, y)
    x_pred, y_pred = xy_pred[:, 0], xy_pred[:, 1]
    # Flatten targets to [N] to match the predictions; comparing [N] against
    # [N, 1] would silently broadcast to an [N, N] matrix of pairwise differences.
    x0 = torch.as_tensor(x0, dtype=x_pred.dtype, device=x_pred.device).reshape(-1)
    y0 = torch.as_tensor(y0, dtype=y_pred.dtype, device=y_pred.device).reshape(-1)
    return torch.mean((x_pred - x0) ** 2 + (y_pred - y0) ** 2)

# Training
def train_pinn(model, t_data, x_data, y_data, x_grid, y_grid, z_grid, t_ic, x_ic, y_ic,
               epochs=1000, *, lr=0.001, ic_weight=10.0, log_every=100):
    """Train `model` with Adam on physics loss plus a weighted initial-condition loss.

    Each epoch minimizes
        physics_loss(collocation points) + ic_weight * data_loss(initial condition),
    where the initial-condition term asks model(t_ic, x_ic, y_ic) to reproduce
    (x_ic, y_ic). Only the initial position is constrained here; no
    initial-velocity term is included.

    Args:
        model: Network to optimize in place.
        t_data, x_data, y_data: Collocation inputs for the physics residual.
        x_grid, y_grid, z_grid: DEM forwarded to physics_loss.
        t_ic, x_ic, y_ic: Initial-condition point(s).
        epochs: Number of optimization steps.
        lr: Adam learning rate (keyword-only).
        ic_weight: Weight of the initial-condition loss (keyword-only).
        log_every: Print the loss every `log_every` epochs; 0 or None disables
            printing (keyword-only).

    Returns:
        List with the total loss of every epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        phys_loss = physics_loss(model, t_data, x_data, y_data, x_grid, y_grid, z_grid)
        ic_loss = data_loss(model, t_ic, x_ic, y_ic, x_ic, y_ic)
        total_loss = phys_loss + ic_weight * ic_loss
        total_loss.backward()
        optimizer.step()
        history.append(total_loss.item())
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}, Loss: {history[-1]:.6f}")
    return history

# Data preparation: N collocation points with t in [0, 1) and x, y in [0, 100)
N = 1000
t_data = torch.rand(N, 1) * 1.0
x_data = torch.rand(N, 1) * 100.0
y_data = torch.rand(N, 1) * 100.0

# Initial-condition point: t = 0 at (x, y) = (50, 50)
t_ic = torch.zeros(1, 1)
x_ic = torch.full((1, 1), 50.0)
y_ic = torch.full((1, 1), 50.0)

# Debug: check input shapes
print("t_data shape:", t_data.shape)
print("x_data shape:", x_data.shape)
print("y_data shape:", y_data.shape)
print("t_ic shape:", t_ic.shape)
print("x_ic shape:", x_ic.shape)
print("y_ic shape:", y_ic.shape)

# Train
model = PINN()
train_pinn(
    model,
    t_data, x_data, y_data,
    x_grid, y_grid, z_grid,
    t_ic, x_ic, y_ic,
)

t_data shape: torch.Size([1000, 1])
x_data shape: torch.Size([1000, 1])
y_data shape: torch.Size([1000, 1])
t_ic shape: torch.Size([1, 1])
x_ic shape: torch.Size([1, 1])
y_ic shape: torch.Size([1, 1])
Epoch 0, Loss: 50237.449219
Epoch 100, Loss: 41013.957031
Epoch 200, Loss: 34762.566406
Epoch 300, Loss: 29555.302734
Epoch 400, Loss: 25049.261719
Epoch 500, Loss: 21122.011719
Epoch 600, Loss: 17700.240234
Epoch 700, Loss: 14728.769531
Epoch 800, Loss: 12161.251953
Epoch 900, Loss: 9956.504883
