In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from sklearn.model_selection import train_test_split                                                                                                   
from sklearn.preprocessing import StandardScaler
import numpy as np
import os
from torch.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from sklearn.metrics import r2_score as sklearn_r2_score
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import sys
sys.path.append('/home/jovyan/work/GNO/GNN/GNNShap')
from gnnshap.explainer import GNNShapExplainer
import uuid

# 导入Blitz库
from blitz.modules import BayesianLinear                 
from blitz.utils import variational_estimator

# Device configuration: use the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# 自定义数据集类
class HydroDataset(torch.utils.data.Dataset):
    def __init__(self, graphs):
        self.graphs = graphs
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return self.graphs[idx]

class BlitzSTConv(MessagePassing):
    """
    Bayesian spatio-temporal graph convolution assembled from Blitz ``BayesianLinear`` layers.

    Every learnable path is a plain ``nn.Linear`` projection followed by a square
    ``BayesianLinear``:

    * message:  [x_i, x_j, edge_attr] -> Linear -> BayesianLinear -> ReLU -> LayerNorm -> Dropout,
      mean-aggregated over incoming edges (``aggr='mean'``);
    * gate:     [x, agg, x - agg] -> Linear -> BayesianLinear -> Sigmoid;
    * residual: x -> Linear -> BayesianLinear (both prior sigmas halved).

    The layer returns ``x + gate * agg + 0.1 * residual``.

    Args:
        spatial_dim: node embedding width (input and output width of the layer).
        prior_sigma_1, prior_sigma_2: prior sigmas forwarded to each ``BayesianLinear``.
        posterior_mu_init, posterior_rho_init: posterior initialisation forwarded to each
            ``BayesianLinear``.
        dropout: dropout probability applied to messages.

    The message projection takes ``2 * spatial_dim + 3`` inputs, i.e. it expects
    3-dimensional edge attributes. Any exception raised in ``forward``/``message`` is
    reported together with the input shapes/dtypes and then re-raised.
    """
    def __init__(self, spatial_dim, prior_sigma_1=0.1, prior_sigma_2=0.002, posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.1):
        super().__init__(aggr='mean')

        def make_bayes(sigma_1, sigma_2):
            # Square Bayesian layer sharing this conv's posterior initialisation.
            return BayesianLinear(spatial_dim, spatial_dim,
                                  prior_sigma_1=sigma_1,
                                  prior_sigma_2=sigma_2,
                                  posterior_mu_init=posterior_mu_init,
                                  posterior_rho_init=posterior_rho_init)

        # Construction order is kept (msg -> gate -> res) so parameter initialisation
        # consumes the RNG exactly as before and parameter names stay unchanged.
        self.pre_msg = nn.Linear(2 * spatial_dim + 3, spatial_dim)
        self.bayes_msg = nn.Sequential(
            make_bayes(prior_sigma_1, prior_sigma_2),
            nn.ReLU(inplace=True),
            nn.LayerNorm(spatial_dim),
            nn.Dropout(dropout)
        )

        self.pre_gate = nn.Linear(3 * spatial_dim, spatial_dim)
        self.bayes_gate = nn.Sequential(
            make_bayes(prior_sigma_1, prior_sigma_2),
            nn.Sigmoid()
        )

        self.pre_res = nn.Linear(spatial_dim, spatial_dim)
        self.bayes_res = make_bayes(prior_sigma_1 / 2, prior_sigma_2 / 2)

    def forward(self, x, edge_index, edge_attr):
        try:
            edge_attr = edge_attr.float()
            aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
            gate_in = torch.cat([x, aggregated, x - aggregated], dim=-1)
            gate = self.bayes_gate(self.pre_gate(gate_in))
            residual = self.bayes_res(self.pre_res(x))
            return x + gate * aggregated + 0.1 * residual
        except Exception as err:
            print(f"Error in BlitzSTConv.forward: {err}")
            for label, tensor in (("x", x), ("edge_index", edge_index), ("edge_attr", edge_attr)):
                print(f"{label} shape: {tensor.shape}, dtype: {tensor.dtype}")
            raise

    def message(self, x_i, x_j, edge_attr):
        # Parameter names must stay x_i / x_j / edge_attr: MessagePassing maps them by name.
        try:
            edge_attr = edge_attr.to(x_i.dtype).to(x_i.device)
            msg_in = torch.cat([x_i, x_j, edge_attr], dim=-1)
            return self.bayes_msg(self.pre_msg(msg_in))
        except Exception as err:
            print(f"Error in BlitzSTConv.message: {err}")
            for label, tensor in (("x_i", x_i), ("x_j", x_j), ("edge_attr", edge_attr)):
                print(f"{label} shape: {tensor.shape}, dtype: {tensor.dtype}")
            raise

class BlitzBoundaryProcessor(nn.Module):
    """
    Bayesian boundary-condition processor built from Blitz ``BayesianLinear`` layers.

    ``bc_mask`` columns are read as laid out by ``build_bc_mask``:
    0 = CHD flag, 1 = river flag, 2 = river stage, 3 = well rate, 4 = well flag.

    Three branch features (CHD, river, well) are each multiplied by their flag
    column and summed; a sigmoid gate then blends that sum with ``x``. Rows whose
    CHD flag is > 0 are finally replaced by ``chd_enforcer`` applied to the blended
    output.

    Args:
        dim: feature width of ``x`` (input and output).
        prior_sigma_1, prior_sigma_2, posterior_mu_init, posterior_rho_init:
            forwarded to every ``BayesianLinear``; the gate and CHD enforcer use
            half of each prior sigma.
    """
    def __init__(self, dim, prior_sigma_1=0.1, prior_sigma_2=0.002, posterior_mu_init=0.0, posterior_rho_init=-3.0):
        super().__init__()

        def make_bayes(in_features, sigma_1, sigma_2):
            # Bayesian projection to `dim` outputs with this module's posterior init.
            return BayesianLinear(in_features, dim,
                                  prior_sigma_1=sigma_1,
                                  prior_sigma_2=sigma_2,
                                  posterior_mu_init=posterior_mu_init,
                                  posterior_rho_init=posterior_rho_init)

        half_sigma_1, half_sigma_2 = prior_sigma_1 / 2, prior_sigma_2 / 2

        # Same construction order as before: boundary, river, well, gate, CHD enforcer.
        self.boundary_net = nn.Sequential(make_bayes(dim + 1, prior_sigma_1, prior_sigma_2), nn.ReLU())
        self.river_net = nn.Sequential(make_bayes(dim + 2, prior_sigma_1, prior_sigma_2), nn.ReLU())
        self.well_net = make_bayes(dim + 1, prior_sigma_1, prior_sigma_2)
        self.gate = nn.Sequential(make_bayes(2 * dim, half_sigma_1, half_sigma_2), nn.Sigmoid())
        self.chd_enforcer = make_bayes(dim, half_sigma_1, half_sigma_2)

    def forward(self, x, bc_mask):
        def branch(net, extra_cols, flag_col):
            # Append the branch's bc_mask columns to x, then zero rows where the flag is 0.
            return net(torch.cat([x, bc_mask[:, extra_cols]], dim=-1)) * bc_mask[:, flag_col]

        combined = (branch(self.boundary_net, slice(0, 1), slice(0, 1))
                    + branch(self.river_net, slice(1, 3), slice(1, 2))
                    + branch(self.well_net, slice(3, 4), slice(4, 5)))

        gate = self.gate(torch.cat([x, combined], dim=-1))
        out = x * (1 - gate) + combined * gate

        is_chd = bc_mask[:, 0] > 0
        if is_chd.any():
            # .to(out.dtype) keeps the indexed write valid under mixed precision.
            out[is_chd] = self.chd_enforcer(out[is_chd]).to(out.dtype)

        return out

@variational_estimator
class BlitzHeadGNN(nn.Module):
    """
    Bayesian hydraulic-head prediction GNN built with Blitz layers.

    Forward pipeline: ``data.x`` concatenated with a learned time-step embedding ->
    deterministic MLP encoder -> 4 ``BlitzSTConv`` layers with scaled residuals ->
    sigmoid feature attention -> ``BlitzBoundaryProcessor`` -> two Softplus output
    heads (deterministic and Bayesian) blended with FIXED weights 0.7 / 0.3.

    Args:
        node_features: width of ``data.x``. NOTE(review): presumably the 16-column head
            feature matrix from ``build_spatiotemporal_graph`` (static features plus
            lag-1/lag-2 heads) — confirm.
        max_time_steps: the embedding table has ``max_time_steps + 1`` rows, so
            ``data.time_step`` values must lie in ``[0, max_time_steps]``.
        spatial_dim: hidden width; the output heads narrow to ``spatial_dim // 2`` and
            ``spatial_dim // 4``.
        temporal_dim: width of the time-step embedding.
        output_dim: number of predicted values per node.
        prior_sigma_1, prior_sigma_2, posterior_mu_init, posterior_rho_init: forwarded
            to the Bayesian conv layers, boundary processor and Bayesian head (the
            head's last layer uses half of each prior sigma).
        dropout: dropout probability used throughout.
    """
    def __init__(self, node_features=16, max_time_steps=40, spatial_dim=64,
                temporal_dim=64, output_dim=1, prior_sigma_1=0.05, prior_sigma_2=0.001,
                posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.2):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.time_embed = nn.Embedding(max_time_steps + 1, temporal_dim)
        
        # Node encoder: input width is node_features + temporal_dim because the
        # time embedding is concatenated onto the raw node features in forward().
        # Deliberately conservative design: deterministic Linear/BatchNorm layers only.
        self.node_enc = nn.Sequential(
            nn.Linear(node_features + temporal_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(spatial_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        # Graph convolution layers (4 stacked Bayesian spatio-temporal convs).
        self.conv_layers = nn.ModuleList([
            BlitzSTConv(spatial_dim, prior_sigma_1, prior_sigma_2, 
                     posterior_mu_init, posterior_rho_init, dropout) 
            for _ in range(4)
        ])
        
        # Boundary-condition processor.
        self.bc_processor = BlitzBoundaryProcessor(
            spatial_dim, prior_sigma_1, prior_sigma_2, 
            posterior_mu_init, posterior_rho_init
        )
        
        # Deterministic output head (weighted 0.7 in forward). The final Softplus
        # makes its output strictly positive.
        self.deterministic_path = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 2),
            nn.BatchNorm1d(spatial_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(spatial_dim // 2, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, output_dim),
            nn.Softplus()
        )
        
        # Bayesian output head (weighted 0.3 in forward); kept shallow to limit complexity.
        self.bayesian_path = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim // 2,
                          prior_sigma_1=prior_sigma_1, 
                          prior_sigma_2=prior_sigma_2,
                          posterior_mu_init=posterior_mu_init,
                          posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            nn.Dropout(dropout),
            BayesianLinear(spatial_dim // 2, output_dim,
                          prior_sigma_1=prior_sigma_1/2, 
                          prior_sigma_2=prior_sigma_2/2,
                          posterior_mu_init=posterior_mu_init,
                          posterior_rho_init=posterior_rho_init),
            nn.Softplus()
        )
        
        # Channel attention: per-feature sigmoid weights in (0, 1) used for feature selection.
        self.attention = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, spatial_dim),
            nn.Sigmoid()
        )

    def forward(self, data):
        """Predict per-node head values; returns a tensor of shape (num_nodes, output_dim)."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr.to(torch.float32)
        
        # Optional feature-engineering hook. It is not created in __init__, so this
        # branch only runs if a `feature_engineering` module is attached externally.
        if hasattr(self, 'feature_engineering'):
            x = self.feature_engineering(x)
        
        # Time embedding, looked up per node from data.time_step.
        time_emb = self.time_embed(data.time_step)
        node_feat = torch.cat([x, time_emb], dim=-1)
        
        # Node encoding.
        h = self.node_enc(node_feat)
        
        # Simplified graph convolution with a scaled residual update.
        # NOTE(review): BlitzSTConv.forward already returns x + ..., so each layer
        # effectively computes 1.1 * h + 0.1 * (gate*agg + 0.1*res) — confirm intended.
        for conv in self.conv_layers:
            h_new = conv(h, edge_index, edge_attr)
            h = h + 0.1 * h_new  # residual connection
        
        # Attention weighting.
        attention_weights = self.attention(h)
        h = h * attention_weights
        
        # Boundary-condition processing.
        h = self.bc_processor(h, data.bc_mask)
        
        # Two output heads. The deterministic head sees h.detach(), so its loss
        # gradient does NOT reach the encoder/conv/boundary layers; only the Bayesian
        # head (30% weight) trains the shared representation.
        # NOTE(review): confirm this detach is intentional.
        det_pred = self.deterministic_path(h.detach())
        bayes_pred = self.bayesian_path(h)
        
        # Fixed (not adaptive) 0.7 / 0.3 blend of the two heads.
        return det_pred * 0.7 + bayes_pred * 0.3

class ImprovedPhysicsInformedLoss(nn.Module):
    """
    Physics-informed loss for the head model.

    Only the MSE term is currently optimised (``total_loss == mse_loss``); the
    temporal-smoothness ("flux"), CHD-boundary and well terms are computed and
    returned for logging only. ``alpha`` and ``kl_weight`` are stored but unused
    by ``forward``; the KL term is expected to come from Blitz's ``sample_elbo``.

    Args:
        alpha: weight of the physics terms in the (currently disabled) blended loss.
        kl_weight: KL weight, kept for interface compatibility.
    """
    def __init__(self, alpha=0.5, kl_weight=1e-4):
        super().__init__()
        self.alpha = alpha
        self.kl_weight = kl_weight

    def forward(self, pred, data, model=None):
        """Compute the loss.

        Args:
            pred: predictions of shape (N, 1).
            data: object with ``head_y`` (N,), ``time_step`` (N,) and ``bc_mask``
                (N, >=5); column 0 is the CHD flag and column 4 the well flag.
            model: unused; accepted so all criteria share one call signature.

        Returns:
            ``(total_loss, (mse, flux, bc, well, 0.0))``. The tuple holds Python
            floats; the trailing ``0.0`` is a KL placeholder.
        """
        mse_loss = F.mse_loss(pred, data.head_y.unsqueeze(1))

        # Temporal smoothness between consecutive time steps. Start from a tensor
        # (not the int 0) so `.item()` works even when no consecutive pair exists,
        # e.g. a batch with a single time step or with gaps in time_step.
        # NOTE(review): the subtraction assumes every time step has the same number of
        # nodes in the same order.
        time_steps = data.time_step.unique(sorted=True)
        flux_loss = torch.zeros((), device=pred.device, dtype=pred.dtype)
        for t in time_steps[:-1]:
            mask_t = data.time_step == t
            mask_next = data.time_step == t + 1
            if mask_t.any() and mask_next.any():
                flux_loss = flux_loss + torch.mean((pred[mask_next] - pred[mask_t]) ** 2)
        flux_loss = flux_loss / max(len(time_steps) - 1, 1)

        # L1 error on constant-head (CHD) nodes.
        bc_nodes = data.bc_mask[:, 0] > 0
        if bc_nodes.any():
            bc_loss = F.l1_loss(pred[bc_nodes], data.head_y[bc_nodes].unsqueeze(1))
        else:
            bc_loss = torch.tensor(0.0, device=pred.device)

        # L1 error on well nodes.
        well_nodes = data.bc_mask[:, 4] > 0
        if well_nodes.any():
            well_loss = F.l1_loss(pred[well_nodes], data.head_y[well_nodes].unsqueeze(1))
        else:
            well_loss = torch.tensor(0.0, device=pred.device)

        # Physics terms are disabled; the blended form would be:
        # (1 - self.alpha) * mse_loss + self.alpha * (flux_loss + bc_loss + well_loss)
        total_loss = mse_loss
        return total_loss, (mse_loss.item(), flux_loss.item(), bc_loss.item(), well_loss.item(), 0.0)


@variational_estimator
class BlitzConcGNN(nn.Module):
    """
    Bayesian concentration-prediction GNN (Blitz) that reads ``data.conc_x`` directly.

    Forward pipeline: ``conc_x`` concatenated with a learned time-step embedding ->
    deterministic Linear/BatchNorm encoder -> Bayesian encoder layer -> 4
    ``BlitzSTConv`` layers with scaled residuals -> sigmoid feature attention ->
    ``BlitzBoundaryProcessor`` (skipped when the graph has no ``bc_mask``) -> Bayesian
    and deterministic Softplus decoders blended by a learned weight
    ``sigmoid(branch_weight)`` (initialised at sigmoid(0.3) ~= 0.574 for the Bayesian branch).

    Args:
        node_features: width of ``data.conc_x`` (19 for the graphs produced by
            ``build_spatiotemporal_graph``: 15 static columns + 4 lagged head/conc columns).
        max_time_steps: the embedding table has ``max_time_steps + 1`` rows, so
            ``data.time_step`` must lie in ``[0, max_time_steps]``.
        spatial_dim: hidden width.
        temporal_dim: width of the time-step embedding.
        output_dim: number of predicted values per node.
        prior_sigma_1, prior_sigma_2, posterior_mu_init, posterior_rho_init: forwarded
            to the Bayesian layers (decoder layers use reduced prior sigmas).
        dropout: dropout probability.
    """
    def __init__(self, node_features=19, max_time_steps=40, spatial_dim=128,
                temporal_dim=64, output_dim=1, prior_sigma_1=0.1, prior_sigma_2=0.01,
                posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.1):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.time_embed = nn.Embedding(max_time_steps + 1, temporal_dim)

        # Deterministic first encoder stage: (node_features + temporal_dim) -> spatial_dim.
        self.node_enc_scale1 = nn.Sequential(
            nn.Linear(node_features + temporal_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        # Bayesian second encoder stage.
        self.node_enc = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim,
                        prior_sigma_1=prior_sigma_1,
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(inplace=True),
            nn.LayerNorm(spatial_dim),
            nn.Dropout(dropout)
        )

        # Graph convolution layers.
        self.conv_layers = nn.ModuleList([
            BlitzSTConv(spatial_dim, prior_sigma_1, prior_sigma_2,
                     posterior_mu_init, posterior_rho_init, dropout)
            for _ in range(4)
        ])

        # Boundary-condition processor.
        self.bc_processor = BlitzBoundaryProcessor(
            spatial_dim, prior_sigma_1, prior_sigma_2,
            posterior_mu_init, posterior_rho_init
        )

        # Channel attention (per-feature sigmoid weights).
        self.attention = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, spatial_dim),
            nn.Sigmoid()
        )

        # Bayesian decoder branch.
        self.decoder = nn.Sequential(
            BayesianLinear(spatial_dim, 128,
                        prior_sigma_1=prior_sigma_1/2,
                        prior_sigma_2=prior_sigma_2/2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            nn.Dropout(dropout),
            BayesianLinear(128, 64,
                        prior_sigma_1=prior_sigma_1/2,
                        prior_sigma_2=prior_sigma_2/2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            BayesianLinear(64, output_dim,
                        prior_sigma_1=prior_sigma_1/4,
                        prior_sigma_2=prior_sigma_2/4,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.Softplus()
        )

        # Deterministic decoder branch.
        self.decoder_det = nn.Sequential(
            nn.Linear(spatial_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Softplus()
        )

        # Learned blend weight; sigmoid(branch_weight) is the Bayesian branch's share.
        self.branch_weight = nn.Parameter(torch.tensor(0.3))

    def forward(self, data, pred_head=None):
        """Predict per-node concentration; returns shape (num_nodes, output_dim).

        Args:
            data: graph with ``conc_x``, ``edge_index``, ``edge_attr``, ``time_step``
                and optionally ``bc_mask``.
            pred_head: unused; accepted for call-site compatibility.
        """
        x = data.conc_x
        edge_index, edge_attr = data.edge_index, data.edge_attr

        # Keep edge attributes float32 and on the same device as the node features.
        if edge_attr is not None:
            edge_attr = edge_attr.to(torch.float32).to(x.device)

        time_emb = self.time_embed(data.time_step.to(x.device))
        node_feat = torch.cat([x, time_emb], dim=-1)

        h = self.node_enc(self.node_enc_scale1(node_feat))

        # Graph convolutions with a scaled residual update.
        for conv in self.conv_layers:
            h = h + 0.1 * conv(h, edge_index, edge_attr)

        h = h * self.attention(h)

        # Boundary processing needs bc_mask; when the graph has none, skip the step
        # instead of passing None into the processor (which would fail on indexing).
        bc_mask = getattr(data, 'bc_mask', None)
        if bc_mask is not None:
            h = self.bc_processor(h, bc_mask.to(x.device))

        # The deterministic decoder sees h.detach(): its gradient does not reach the
        # shared encoder. NOTE(review): confirm this is intentional.
        bayes_output = self.decoder(h)
        det_output = self.decoder_det(h.detach())

        bayes_share = torch.sigmoid(self.branch_weight)
        return bayes_share * bayes_output + (1 - bayes_share) * det_output

class ImprovedConcLoss(nn.Module):
    """
    Concentration loss: MSE plus a very small L1 penalty on selected weights.

    The L1 term sums ``|param|`` over every parameter whose name contains
    ``'weight'`` and whose lower-cased name contains ``'bayes'``. This is a pure
    name heuristic: Bayesian layers stored under other names (e.g. ``decoder.*``) are
    not included, and both posterior means and ``*_rho`` parameters are matched.
    NOTE(review): confirm this selection is intended.

    Args:
        kl_weight: stored for interface compatibility; not used in ``forward``.
        l1_weight: multiplier for the L1 term (a value <= 0 disables it).
    """
    def __init__(self, kl_weight=5e-5, l1_weight=1e-8):
        super().__init__()
        self.kl_weight = kl_weight
        self.l1_weight = l1_weight

    def forward(self, pred, data, model=None):
        """Return ``(total_loss, (mse, 0.0, l1))``; the tuple holds Python floats."""
        mse_loss = F.mse_loss(pred, data.y.unsqueeze(1))

        l1_reg = torch.tensor(0., device=pred.device)
        wants_l1 = model is not None and self.l1_weight > 0
        if wants_l1:
            for param_name, param in model.named_parameters():
                is_bayes_weight = 'weight' in param_name and 'bayes' in param_name.lower()
                if is_bayes_weight:
                    l1_reg += torch.norm(param, 1)

        total_loss = mse_loss + self.l1_weight * l1_reg
        return total_loss, (mse_loss.item(), 0.0, l1_reg.item())

def generate_spatial_edges(model_df):
    """Build grid-neighbour edges between cells of the same time step.

    For each node, an edge ``[idx, neighbour]`` is emitted to the cell at
    ``(row, col + 1)`` and to the cell at ``(row + 1, col)`` when that cell exists
    in the same time step. Edges are one-directional (+col / +row only).
    NOTE(review): with PyG's default source->target flow, messages then only travel
    toward higher row/col; add reverse edges if undirected coupling is intended.

    Args:
        model_df: frame with ``time_step``, ``local_index``, ``row``, ``col`` columns.

    Returns:
        ``(edges, attrs)``: an int64 array of shape (E, 2) holding ``[src, dst]``
        local indices, which is (0, 2) when there are no edges so it always
        concatenates with other edge arrays; and a float32 array of shape (E, 3)
        filled with the spatial edge type ``[1, 0, 0]``.
    """
    spatial_edges = []
    # One pass over the groups instead of re-filtering the whole frame per time step.
    for _, time_df in model_df.groupby('time_step', sort=False):
        coord_to_idx = {(r, c): idx for idx, r, c in zip(time_df['local_index'], time_df['row'], time_df['col'])}
        for idx, row, col in zip(time_df['local_index'], time_df['row'], time_df['col']):
            for neighbour in ((row, col + 1), (row + 1, col)):
                if neighbour in coord_to_idx:
                    spatial_edges.append([idx, coord_to_idx[neighbour]])
    edges = np.array(spatial_edges, dtype=np.int64).reshape(-1, 2)
    attrs = np.full((len(edges), 3), [1.0, 0, 0], dtype=np.float32)
    return edges, attrs

def generate_temporal_edges(model_df):
    """Link each grid cell to itself at the next observed time step.

    Rows are grouped by ``(row, col)`` and ordered by ``time_step``; consecutive
    entries produce a forward-in-time edge ``[earlier, later]``.

    Args:
        model_df: frame with ``row``, ``col``, ``time_step``, ``local_index`` columns.

    Returns:
        ``(edges, attrs)``: an int64 array of shape (E, 2), which is (0, 2) when no
        cell has more than one time step so it always concatenates with other edge
        arrays; and a float32 array of shape (E, 3) filled with the temporal edge
        type ``[0, 1, 0]``.
    """
    temporal_edges = []
    for _, cell in model_df.groupby(['row', 'col'], sort=False):
        ordered = cell.sort_values('time_step')['local_index'].to_numpy()
        temporal_edges.extend(zip(ordered[:-1], ordered[1:]))
    edges = np.array(temporal_edges, dtype=np.int64).reshape(-1, 2)
    attrs = np.full((len(edges), 3), [0.0, 1.0, 0], dtype=np.float32)
    return edges, attrs

def build_bc_mask(model_df):
    """Stack per-node boundary-condition descriptors into an (N, 5) float32 array.

    Column layout:
        0: ``chd_mask``
        1: river indicator (``river_cond > 0`` as 0.0 / 1.0)
        2: ``river_stage``
        3: ``well_rate``
        4: ``well_mask``
    """
    columns = (
        model_df['chd_mask'].values,
        (model_df['river_cond'] > 0).values,
        model_df['river_stage'].values,
        model_df['well_rate'].values,
        model_df['well_mask'].values,
    )
    return np.stack([np.asarray(column, dtype=np.float32) for column in columns], axis=1)
def _standardize_non_mask_columns(model_df, cols, col_types):
    """Return ``model_df[cols]`` as float32, z-scoring every column whose dtype is not uint8.

    uint8 columns (the 0/1 masks and ``lytyp`` after the cast in
    ``build_spatiotemporal_graph``) are left unscaled. The scaler is fit on the rows
    of ``model_df`` only, i.e. per model.
    """
    feats = model_df[cols].values.astype(np.float32)
    float_idx = [i for i, col in enumerate(cols) if col_types[col] != np.uint8]
    feats[:, float_idx] = StandardScaler().fit_transform(feats[:, float_idx])
    return feats


def _lagged_head_and_conc(model_df):
    """Per-cell lag-1 and lag-2 copies of ``head`` and ``concentration``.

    Requires ``model_df`` to have a 0..N-1 RangeIndex, because index labels are used
    as array positions. Within each (row, col) series ordered by time step, entries
    with no real predecessor reuse the first observed value. A cell observed at a
    single time step gets zeros.

    Returns:
        ``(prev_head, prev2_head, prev_conc, prev2_conc)``, each float32 of length N.
    """
    n_rows = len(model_df)
    prev_head = np.zeros(n_rows, dtype=np.float32)
    prev2_head = np.zeros(n_rows, dtype=np.float32)
    prev_conc = np.zeros(n_rows, dtype=np.float32)
    prev2_conc = np.zeros(n_rows, dtype=np.float32)

    for _, cell in model_df.groupby(['row', 'col'], sort=False):
        series = cell.sort_values('time_step')
        if len(series) < 2:
            continue  # single-step cell keeps its zero initialisation
        idx = series.index.to_numpy()
        heads = series['head'].values
        concs = series['concentration'].values
        prev_head[idx] = np.roll(heads, 1)
        prev2_head[idx] = np.roll(heads, 2)
        prev_conc[idx] = np.roll(concs, 1)
        prev2_conc[idx] = np.roll(concs, 2)
        # np.roll wraps around; overwrite the wrapped entries with the first value.
        prev_head[idx[0]] = prev2_head[idx[0]] = prev2_head[idx[1]] = heads[0]
        prev_conc[idx[0]] = prev2_conc[idx[0]] = prev2_conc[idx[1]] = concs[0]

    return prev_head, prev2_head, prev_conc, prev2_conc


def build_spatiotemporal_graph(df):
    """Convert a long-format simulation table into one PyG ``Data`` graph per model.

    Each row of ``df`` (one grid cell at one time step of one model) becomes a node.
    Edges link grid neighbours within a time step (``generate_spatial_edges``) and
    each cell to its next time step (``generate_temporal_edges``).

    Within each model, rows are sorted by (row, col, time_step) and the index is
    reset BEFORE ``local_index`` is assigned. This keeps node features, targets,
    masks, lag features and edge endpoints all in the same row order. (Previously
    ``local_index`` was taken before sorting, so edges and lag features pointed at
    the wrong nodes whenever the input was not already in that order.)

    Node tensors per graph:
        x:      14 base features (non-mask columns z-scored per model) + lag-1/lag-2 head -> 16
        conc_x: 15 base features (adds ``conc_mask``) + lag-1/lag-2 head + lag-1/lag-2 conc -> 19

    Args:
        df: table with ``model_name``, ``time_step``, ``row``, ``col``, the feature
            columns listed below, ``head`` and ``concentration``. Not modified.

    Returns:
        List of ``torch_geometric.data.Data`` objects, one per model.
    """
    print(f"\n▶ Started building spatiotemporal graphs")
    print(f"▷ Total models to process: {len(df['model_name'].unique())}")
    graphs = []

    # Fix dtypes up front. uint8 columns are treated as masks and are not scaled.
    df = df.astype({
        'x': np.float32, 'y': np.float32, 'top': np.float32,
        'bottom': np.float32, 'K': np.float32, 'recharge': np.float32,
        'ET': np.float32, 'river_stage': np.float32, 'river_cond': np.float32,
        'river_rbot': np.float32, 'well_rate': np.float32, 'well_mask': np.uint8,
        'chd_mask': np.uint8, 'lytyp': np.uint8, 'head': np.float32,
        'concentration': np.float32, 'conc_mask': np.uint8
    })

    # Shift time steps to start at 0 so they can index the time-embedding tables.
    time_min = df['time_step'].min()
    df['time_step'] = df['time_step'] - time_min
    col_types = df.dtypes.to_dict()

    head_feature_cols = [
        'x', 'y', 'top', 'bottom', 'K', 'recharge', 'ET',
        'river_stage', 'river_cond', 'river_rbot', 'well_rate', 'well_mask',
        'chd_mask', 'lytyp'
    ]
    conc_feature_cols = head_feature_cols + ['conc_mask']

    model_groups = list(df.groupby('model_name', sort=False))
    total_models = len(model_groups)

    for model_idx, (model_name, model_df) in enumerate(model_groups, 1):
        # Sort first, THEN reset the index, so index label == row position == local_index.
        model_df = model_df.sort_values(['row', 'col', 'time_step']).reset_index(drop=True)
        model_df['local_index'] = model_df.index
        print(f"\n▣ Processing model {model_idx}/{total_models}: {model_name}")

        node_feats = _standardize_non_mask_columns(model_df, head_feature_cols, col_types)
        conc_node_feats = _standardize_non_mask_columns(model_df, conc_feature_cols, col_types)
        prev_head, prev2_head, prev_conc, prev2_conc = _lagged_head_and_conc(model_df)

        # Head model input: 14 base features + lag-1/lag-2 head (16 columns).
        head_feats = np.concatenate([
            node_feats,
            prev_head[:, None],
            prev2_head[:, None]
        ], axis=1)

        # Concentration model input: 15 base features + lag-1/lag-2 head
        # + lag-1/lag-2 concentration (19 columns).
        conc_feats = np.concatenate([
            conc_node_feats,
            prev_head[:, None],
            prev2_head[:, None],
            prev_conc[:, None],
            prev2_conc[:, None]
        ], axis=1)

        if np.any(np.isnan(head_feats)) or np.any(np.isinf(head_feats)):
            print(f"Warning: head_feats contains NaN or Inf for model {model_name}")
            head_feats = np.nan_to_num(head_feats, nan=0.0, posinf=1e6, neginf=-1e6)

        if np.any(np.isnan(conc_feats)) or np.any(np.isinf(conc_feats)):
            print(f"Warning: conc_feats contains NaN or Inf for model {model_name}")
            conc_feats = np.nan_to_num(conc_feats, nan=0.0, posinf=1e6, neginf=-1e6)

        conc = model_df['concentration'].values.astype(np.float32)
        head = model_df['head'].values.astype(np.float32)
        spatial_edges, spatial_attrs = generate_spatial_edges(model_df)
        temporal_edges, temporal_attrs = generate_temporal_edges(model_df)
        edges = np.concatenate([spatial_edges, temporal_edges], axis=0)
        edge_attr = np.concatenate([spatial_attrs, temporal_attrs], axis=0)
        bc_mask = build_bc_mask(model_df)

        assert bc_mask.shape == (len(model_df), 5), \
            f"Invalid bc_mask shape: {bc_mask.shape} for model {model_name}"

        graph = Data(
            x=torch.from_numpy(head_feats),      # head-model features (16)
            conc_x=torch.from_numpy(conc_feats), # concentration-model features (19)
            edge_index=torch.tensor(edges.T, dtype=torch.long),
            edge_attr=torch.from_numpy(edge_attr),
            y=torch.from_numpy(conc),
            head_y=torch.from_numpy(head),
            bc_mask=torch.from_numpy(bc_mask),
            time_step=torch.from_numpy(model_df['time_step'].values).long(),
            time_steps=model_df['time_step'].nunique(),
            model_name=str(model_name),
            row=torch.from_numpy(model_df['row'].values).long(),
            col=torch.from_numpy(model_df['col'].values).long(),
        )
        graphs.append(graph)

    print(f"\n✅ All models processed! Total graphs created: {len(graphs):,}")
    return graphs


def prepare_data(data, batch_size=4):
    """Build all graphs from ``data`` and wrap a 70/30 train/validation split in loaders.

    The split is over whole graphs (one per model) with a fixed ``random_state=42``.
    Only the training loader shuffles.

    Args:
        data: raw table accepted by ``build_spatiotemporal_graph``.
        batch_size: graphs per batch for both loaders.

    Returns:
        ``(train_loader, val_loader)`` as PyG ``DataLoader`` objects.
    """
    print('正在处理数据...')
    graphs = build_spatiotemporal_graph(data)
    print('数据处理完成！')

    train_part, val_part = train_test_split(graphs, test_size=0.3, random_state=42)
    train_set, val_set = HydroDataset(train_part), HydroDataset(val_part)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
    )

# Training configuration (shared by the head and concentration training stages).
config = dict(
    head_input_dim=16,                # 14 base features + lag-1/lag-2 head
    conc_input_dim=19,                # 15 base features + lag-1/lag-2 head + lag-1/lag-2 conc
    hidden_dim=96,                    # enlarged hidden width
    num_epochs=500,
    lr=1e-3,                          # kept modest for training stability
    weight_decay=1e-4,
    patience=30,                      # early-stopping patience, in epochs
    save_path='./saved_models/blitz_bayesian_gnn_dual_base',
    mc_samples=10,                    # Monte-Carlo passes for uncertainty estimates
    head_prior_sigma_1=0.01,          # Blitz prior parameters, head model
    head_prior_sigma_2=0.002,
    conc_prior_sigma_1=0.05,          # concentration model; NOTE: larger than the head model's 0.01
    conc_prior_sigma_2=0.002,
    kl_weight=1e-4,                   # KL-divergence weight used with Blitz
)

def compute_metrics(y_true, y_pred):
    """Compute MSE, RMSE, MAE and R² over the finite entries of two arrays.

    Inputs may be torch tensors (detached and moved to CPU), numpy arrays or
    array-likes; both are flattened. Pairs where either value is NaN or ±inf are
    dropped. If no finite pair remains, every metric is ``nan``. This avoids the
    ValueError that sklearn's ``r2_score`` raises on empty input, e.g. after a
    diverged run produces all-NaN predictions.

    Returns:
        dict with keys ``'mse'``, ``'rmse'``, ``'mae'``, ``'r2'``.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()

    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    finite = np.isfinite(y_true) & np.isfinite(y_pred)
    y_true = y_true[finite]
    y_pred = y_pred[finite]

    if y_true.size == 0:
        nan = float('nan')
        return {'mse': nan, 'rmse': nan, 'mae': nan, 'r2': nan}

    errors = y_true - y_pred
    mse = np.mean(errors ** 2)
    return {
        'mse': mse,
        'rmse': np.sqrt(mse),
        'mae': np.mean(np.abs(errors)),
        'r2': sklearn_r2_score(y_true, y_pred),
    }

def compute_uncertainty(model, data, mc_samples=10):
    """
    Monte-Carlo predictive mean and standard deviation for a (Blitz) model.

    Runs ``mc_samples`` forward passes of ``model(data)`` under ``torch.no_grad()``.
    The model is put in train mode so Dropout stays stochastic. NOTE(review): Blitz
    ``BayesianLinear`` layers are believed to sample weights on every call regardless
    of mode — confirm. BatchNorm layers are switched back to eval, so the passes use
    the learned running statistics and do not overwrite them. Every submodule's
    original train/eval flag is restored on exit, including on error.

    Args:
        model: module called as ``model(data)``.
        data: model input.
        mc_samples: number of stochastic passes; must be >= 1. With 1 the
            (unbiased) std is NaN.

    Returns:
        ``(mean_pred, std_pred)``, reduced over the sample dimension.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    saved_modes = {module: module.training for module in model.modules()}

    model.train()
    for module in model.modules():
        if isinstance(module, batchnorm_types):
            module.eval()

    try:
        with torch.no_grad():
            predictions = torch.stack([model(data) for _ in range(mc_samples)], dim=0)
    finally:
        for module, was_training in saved_modes.items():
            module.training = was_training

    return predictions.mean(dim=0), predictions.std(dim=0)

def compute_feature_shap_values_improved(model, data, n_samples=20, num_samples=10, feature_attr='x'):
    """
    Estimate per-feature importance by mean-substitution perturbation.

    Despite the name, these are not Shapley values. For up to ``n_samples`` nodes
    sampled from the smallest time step, each feature column of
    ``data.<feature_attr>`` is replaced by its mean over all nodes. The absolute
    change in that node's Monte-Carlo-averaged prediction is recorded. Scores are
    averaged over the sampled nodes and normalised to sum to 1 (when the sum is
    positive).

    Model modes: Dropout is kept stochastic (train mode) for the MC passes, while
    BatchNorm layers run in eval mode so running statistics are neither used from
    the batch nor overwritten. All original train/eval flags are restored on exit.

    Args:
        model: module called as ``model(data)``; ``data`` is moved to the device of
            its first parameter.
        data: graph object with ``<feature_attr>`` (2-D node-feature tensor) and
            optionally ``time_step``. The feature tensor is modified temporarily and
            always restored, even if a forward pass raises.
        n_samples: number of nodes to sample.
        num_samples: Monte-Carlo forward passes per averaged prediction.
        feature_attr: name of the node-feature tensor to perturb. The default
            ``'x'`` is read by the head model; use ``'conc_x'`` for the
            concentration model, which does not read ``data.x``.

    Returns:
        ``(importances, mean_baseline)``: a numpy array of shape (num_features,)
        and the mean baseline prediction over the sampled nodes, or
        ``(None, 0.0)`` if the analysis fails.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    saved_modes = {module: module.training for module in model.modules()}

    model.train()
    for module in model.modules():
        if isinstance(module, batchnorm_types):
            module.eval()

    try:
        return _mean_masking_importance(model, data, n_samples, num_samples, feature_attr)
    finally:
        for module, was_training in saved_modes.items():
            module.training = was_training


def _mean_masking_importance(model, data, n_samples, num_samples, feature_attr):
    """Run the perturbation analysis for ``compute_feature_shap_values_improved`` (modes set by caller)."""
    data = data.to(next(model.parameters()).device)

    print(f"[FeatureSHAP] 开始特征重要性分析，抽样{n_samples}个节点...")

    feats = getattr(data, feature_attr, None)
    if feats is None or not torch.is_tensor(feats):
        print(f"[FeatureSHAP] 错误: 数据缺少节点特征 (data.{feature_attr})")
        return None, 0.0

    def mc_mean_prediction(graph):
        # Average of num_samples stochastic forward passes, without autograd.
        with torch.no_grad():
            return torch.stack([model(graph) for _ in range(num_samples)], dim=0).mean(dim=0)

    try:
        # Candidate nodes: those at the smallest time step (all nodes if no time_step).
        has_time = hasattr(data, 'time_step')
        current_time_step = data.time_step.unique()[0].item() if has_time else 0
        if has_time:
            time_mask = data.time_step == current_time_step
        else:
            time_mask = torch.ones(data.num_nodes, dtype=torch.bool, device=feats.device)
        candidate_nodes = torch.where(time_mask)[0]

        if len(candidate_nodes) == 0:
            print("[FeatureSHAP] 错误: 找不到满足条件的节点")
            return None, 0.0

        actual_n_samples = min(n_samples, len(candidate_nodes))
        if actual_n_samples < n_samples:
            print(f"[FeatureSHAP] 警告: 候选节点数({len(candidate_nodes)})少于请求的样本数({n_samples})，调整为{actual_n_samples}")

        sampled_nodes = candidate_nodes[torch.randperm(len(candidate_nodes))[:actual_n_samples]]

        num_features = feats.size(1)
        all_shap_values = torch.zeros(actual_n_samples, num_features, device=feats.device)
        all_expected_values = torch.zeros(actual_n_samples, device=feats.device)

        baseline_pred = mc_mean_prediction(data)

        for i, node_idx in enumerate(sampled_nodes):
            node_idx = node_idx.item()
            original_value = baseline_pred[node_idx].item()
            all_expected_values[i] = original_value

            for feat_idx in range(num_features):
                original_col = feats[:, feat_idx].clone()
                feats[:, feat_idx] = original_col.mean()  # mask the feature with its mean
                try:
                    masked_pred = mc_mean_prediction(data)
                finally:
                    feats[:, feat_idx] = original_col  # always undo the perturbation
                all_shap_values[i, feat_idx] = abs(original_value - masked_pred[node_idx].item())

            if (i + 1) % 5 == 0 or i == len(sampled_nodes) - 1:
                print(f"[FeatureSHAP] 已完成 {i+1}/{len(sampled_nodes)} 个节点的分析")

        avg_shap_values = all_shap_values.mean(dim=0)
        avg_expected_value = all_expected_values.mean().item()

        if avg_shap_values.sum() > 0:
            avg_shap_values = avg_shap_values / avg_shap_values.sum()

        print("[FeatureSHAP] 特征重要性分析完成")
        return avg_shap_values.cpu().numpy(), avg_expected_value

    except Exception as e:
        # Best-effort analysis: report the failure and return a sentinel instead of raising.
        print(f"[FeatureSHAP] 特征重要性分析出错: {e}")
        import traceback
        traceback.print_exc()
        return None, 0.0
# Adapter criteria for Blitz's sample_elbo: it calls criterion(pred, labels), but the
# physics-informed losses need the whole graph object, so the graph is bound up front.
class BlitzHeadLossAdapter(nn.Module):
    """Expose a ``(pred, data) -> (loss, details)`` head criterion as ``(pred, labels) -> loss``.

    Args:
        base_criterion: callable returning ``(loss, details)``.
        data: graph object passed to ``base_criterion`` on every call.
        kl_weight: stored for reference; not used in ``forward``.
    """

    def __init__(self, base_criterion, data, kl_weight=1e-4):
        super().__init__()
        self.base_criterion, self.data, self.kl_weight = base_criterion, data, kl_weight

    def forward(self, pred, labels):
        # `labels` is ignored: targets are read from the bound data object.
        total, _details = self.base_criterion(pred, self.data)
        return total

class BlitzConcLossAdapter(nn.Module):
    """Wrap a concentration-model criterion into a ``criterion(pred, labels)`` callable.

    The wrapped criterion is always called as ``base_criterion(pred, data)``, where
    ``data`` is the graph object captured at construction time. The ``labels``
    argument is ignored. ``kl_weight`` is only stored for callers; it is not
    applied inside this class.

    NOTE(review): elsewhere in this notebook the concentration criterion is called
    as ``criterion_conc(pred, batch, conc_model)``. If that model argument has no
    default, this two-argument call will fail — confirm the criterion's signature.
    """

    def __init__(self, base_criterion, data, kl_weight=5e-5):
        super().__init__()
        self.base_criterion = base_criterion
        self.data = data  # full data object handed to the base criterion
        self.kl_weight = kl_weight

    def forward(self, pred, labels):
        """Return the loss produced by ``base_criterion(pred, self.data)``.

        The criterion may return either a bare loss, or a tuple/list whose first
        element is the loss (any extra diagnostics are discarded).
        """
        output = self.base_criterion(pred, self.data)
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


Using device: cuda:0


In [2]:
def _dual_prepare_batch(batch):
    """Move a PyG batch to the global ``device`` and cast ``edge_attr`` (when present) to float32."""
    batch = batch.to(device)
    if getattr(batch, 'edge_attr', None) is not None:
        batch.edge_attr = batch.edge_attr.float()
    return batch


def _dual_save_checkpoint(model, filename, epoch, train_loss, val_loss, val_metrics,
                          criterion_name, arrays=None):
    """Save ``model`` plus training metadata to ``config['save_path']/filename``.

    Args:
        arrays: optional mapping of ``.npy`` filename templates (containing
            ``{criterion}``) to lists of per-batch numpy arrays. Each non-empty list
            is concatenated along axis 0 and saved next to the checkpoint.

    Exceptions propagate to the caller.
    """
    save_dir = config['save_path']
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_metrics': val_metrics,
        'config': config,
        'criterion': criterion_name
    }, os.path.join(save_dir, filename))
    for name_template, chunks in (arrays or {}).items():
        if chunks:
            np.save(os.path.join(save_dir, name_template.format(criterion=criterion_name)),
                    np.concatenate(chunks, axis=0))


def _dual_update_best(model, file_prefix, model_cn, tracker, epoch, avg_train_loss,
                      avg_val_loss, val_metrics, val_ok, arrays=None):
    """Save best-by-loss / best-by-R2 checkpoints and advance the early-stopping counter.

    Args:
        model: model whose weights are checkpointed.
        file_prefix: checkpoint prefix, e.g. 'best_head_model' -> 'best_head_model_loss.pth'.
        model_cn: model name used inside the log messages ('水头' or '浓度').
        tracker: dict with 'best_loss', 'best_r2', 'stale_epochs'; updated in place.
        val_ok: False when no validation batch succeeded. That epoch's metrics are
            placeholder zeros, so it must not count as an R2 improvement.
        arrays: optional extra arrays forwarded to ``_dual_save_checkpoint``.
    """
    if avg_val_loss < tracker['best_loss']:
        tracker['best_loss'] = avg_val_loss
        try:
            _dual_save_checkpoint(model, f'{file_prefix}_loss.pth', epoch, avg_train_loss,
                                  avg_val_loss, val_metrics, 'loss', arrays)
            print(f"保存基于损失的最佳{model_cn}模型，验证损失: {tracker['best_loss']:.4f}")
        except Exception as e:
            print(f"保存{model_cn}模型失败: {e}")

    # Early stopping is driven by validation R2 only
    if val_ok and val_metrics['r2'] > tracker['best_r2']:
        tracker['best_r2'] = val_metrics['r2']
        tracker['stale_epochs'] = 0
        try:
            _dual_save_checkpoint(model, f'{file_prefix}_r2.pth', epoch, avg_train_loss,
                                  avg_val_loss, val_metrics, 'r2', arrays)
            print(f"保存基于R2的最佳{model_cn}模型，R2: {tracker['best_r2']:.4f}")
        except Exception as e:
            print(f"保存{model_cn}模型失败: {e}")
    else:
        tracker['stale_epochs'] += 1


def _dual_log_epoch(model_cn, epoch, avg_train_loss, avg_val_loss, current_lr, val_metrics):
    """Print the per-epoch loss/LR line and validation-metric line for one model."""
    print(f"{model_cn}模型 Epoch {epoch+1:03d}/{config['num_epochs']} | "
          f"训练损失: {avg_train_loss:.4f} | 验证损失: {avg_val_loss:.4f} | "
          f"LR: {current_lr:.6f}")
    print(f"{model_cn}验证指标 - MSE: {val_metrics['mse']:.4f}, "
          f"RMSE: {val_metrics['rmse']:.4f}, "
          f"MAE: {val_metrics['mae']:.4f}, "
          f"R2: {val_metrics['r2']:.4f}")


def _dual_history_frame(losses):
    """Flatten ``{'train': [...], 'val': [{'loss', 'metrics'}]}`` into a per-epoch DataFrame (None if empty)."""
    records = [
        {
            'epoch': epoch_idx + 1,
            'train_loss': train_value,
            'val_loss': val_entry['loss'],
            'val_mse': val_entry['metrics']['mse'],
            'val_rmse': val_entry['metrics']['rmse'],
            'val_mae': val_entry['metrics']['mae'],
            'val_r2': val_entry['metrics']['r2'],
        }
        for epoch_idx, (train_value, val_entry) in enumerate(zip(losses['train'], losses['val']))
    ]
    return pd.DataFrame(records) if records else None


def _dual_plot_history_row(history_df, first_subplot, short_name, model_title, linestyle):
    """Draw the loss, R2, MSE and RMSE panels for one model.

    The panels go into subplots ``first_subplot`` .. ``first_subplot + 3`` of a 2x4
    grid on the current figure.
    """
    epochs = history_df['epoch']

    def _finish_panel(title, ylabel):
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    # Train/val loss, with the lowest validation loss starred
    plt.subplot(2, 4, first_subplot)
    plt.plot(epochs, history_df['train_loss'], color='b', linestyle=linestyle, label=f'{short_name} Train Loss')
    plt.plot(epochs, history_df['val_loss'], color='r', linestyle=linestyle, label=f'{short_name} Val Loss')
    best_loss_epoch = history_df.loc[history_df['val_loss'].idxmin(), 'epoch']
    plt.scatter(best_loss_epoch, history_df['val_loss'].min(), color='red', s=100, marker='*',
                label=f'Best Loss (E{best_loss_epoch})')
    _finish_panel(f'{model_title}: Training and Validation Loss', 'Loss')

    # Validation R2, with the highest R2 starred
    plt.subplot(2, 4, first_subplot + 1)
    plt.plot(epochs, history_df['val_r2'], color='g', linestyle=linestyle, label=f'{short_name} R2')
    best_r2_epoch = history_df.loc[history_df['val_r2'].idxmax(), 'epoch']
    plt.scatter(best_r2_epoch, history_df['val_r2'].max(), color='green', s=100, marker='*',
                label=f'Best R2 (E{best_r2_epoch})')
    _finish_panel(f'{model_title}: Validation R2 Score', 'R2')

    plt.subplot(2, 4, first_subplot + 2)
    plt.plot(epochs, history_df['val_mse'], color='orange', linestyle=linestyle, label=f'{short_name} MSE')
    _finish_panel(f'{model_title}: Validation MSE', 'MSE')

    plt.subplot(2, 4, first_subplot + 3)
    plt.plot(epochs, history_df['val_rmse'], color='purple', linestyle=linestyle, label=f'{short_name} RMSE')
    _finish_panel(f'{model_title}: Validation RMSE', 'RMSE')


def train_dual_model_improved(train_loader, val_loader, evaluation_criterion='r2'):
    """
    Two-stage training.

    Stage 1 trains the head model. Stage 2 trains the concentration model, using
    predictions from the frozen head model as an input. For each stage, two
    checkpoints are written to ``config['save_path']``: the best by validation loss
    and the best by validation R2. Per-epoch history CSVs and a training-curve PNG
    are also written there.

    Relies on notebook globals: ``config``, ``device``, ``BlitzHeadGNN``,
    ``BlitzConcGNN``, ``ImprovedPhysicsInformedLoss``, ``ImprovedConcLoss``,
    ``compute_uncertainty`` and ``compute_metrics``.

    Args:
        train_loader: training data loader (batches must provide ``head_y`` and ``y``).
        val_loader: validation data loader.
        evaluation_criterion: 'loss' or 'r2'. Selects which saved head checkpoint is
            loaded before stage 2, and which concentration checkpoint is restored
            before returning.

    Returns:
        ``(head_model, conc_model, {'head': head_losses, 'conc': conc_losses})``, or
        ``(None, None, None)`` if clearing the CUDA cache raises RuntimeError.
    """
    try:
        torch.cuda.empty_cache()
        print("CUDA缓存已成功清除")
    except RuntimeError as e:
        print(f"无法清除CUDA缓存: {e}")
        return None, None, None

    # ===============================
    # Stage 1: train the head model
    # ===============================
    print("=" * 80)
    print("第一阶段：开始训练水头模型")
    print("=" * 80)

    head_input_dim = 16  # 14 base features + 2 previous-time-step head values

    head_model = BlitzHeadGNN(
        node_features=head_input_dim,
        spatial_dim=config['hidden_dim'],
        temporal_dim=config['hidden_dim'],
        prior_sigma_1=config['head_prior_sigma_1'],
        prior_sigma_2=config['head_prior_sigma_2'],
        dropout=0.1
    ).to(device)

    # NOTE(review): the criterion receives kl_weight, and the loop below also adds
    # nn_kl_divergence() * kl_weight. Confirm the criterion does not already include
    # a KL term; otherwise KL is counted twice.
    criterion_head = ImprovedPhysicsInformedLoss(alpha=0.1, kl_weight=config['kl_weight'])

    head_params = list(head_model.parameters())
    head_optimizer = torch.optim.AdamW(head_params, lr=config['lr'], weight_decay=config['weight_decay'])
    head_scheduler = CosineAnnealingWarmRestarts(head_optimizer, T_0=20, T_mult=2, eta_min=1e-5)

    head_tracker = {'best_loss': float('inf'), 'best_r2': float('-inf'), 'stale_epochs': 0}
    head_losses = {'train': [], 'val': []}

    os.makedirs(config['save_path'], exist_ok=True)

    print("开始训练水头模型")
    print(f"水头模型参数数量: {sum(p.numel() for p in head_model.parameters() if p.requires_grad)}")

    for epoch in range(config['num_epochs']):
        head_model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                batch = _dual_prepare_batch(batch)
                head_optimizer.zero_grad()

                pred_head = head_model(batch)
                criterion_output = criterion_head(pred_head, batch)

                # The criterion may return a bare loss or (loss, physics_terms)
                physics_loss_value = 0.0
                if isinstance(criterion_output, tuple):
                    head_criterion_loss = criterion_output[0]
                    if len(criterion_output) > 1:
                        physics_loss = criterion_output[1]
                        if isinstance(physics_loss, tuple):
                            physics_loss_value = sum(p.item() if hasattr(p, 'item') else p for p in physics_loss)
                        else:
                            physics_loss_value = physics_loss.item() if hasattr(physics_loss, 'item') else physics_loss
                else:
                    head_criterion_loss = criterion_output

                kl_loss = head_model.nn_kl_divergence() * config['kl_weight']
                total_loss = head_criterion_loss + kl_loss

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(head_params, max_norm=1.0)
                head_optimizer.step()

                train_loss += total_loss.item()
                train_batches += 1

                if batch_idx % 50 == 0:
                    print(f"水头模型 Epoch {epoch+1}, Batch {batch_idx}: "
                          f"Total Loss: {total_loss.item():.4f}, "
                          f"Criterion Loss: {head_criterion_loss.item():.4f}, "
                          f"Physics Loss: {physics_loss_value:.4f}, "
                          f"KL Loss: {kl_loss.item():.4f}")
            except Exception as e:
                print(f"水头模型训练批次 {batch_idx} 出错: {e}")

        if train_batches == 0:
            print("警告: 水头模型本轮训练没有成功处理任何批次，跳过本轮")
            continue
        avg_train_loss = train_loss / train_batches

        # ---- validation ----
        head_model.eval()
        val_loss = 0.0
        val_metrics = {'mse': 0.0, 'rmse': 0.0, 'mae': 0.0, 'r2': 0.0}
        val_batches = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    batch = _dual_prepare_batch(batch)
                    # Re-enable dropout for MC sampling (set per batch in case
                    # compute_uncertainty switches the mode itself)
                    head_model.train()
                    pred_head, head_std = compute_uncertainty(head_model, batch, mc_samples=config['mc_samples'])

                    criterion_output = criterion_head(pred_head, batch)
                    head_criterion_loss = criterion_output[0] if isinstance(criterion_output, tuple) else criterion_output

                    metrics = compute_metrics(batch.head_y, pred_head)
                    for k in metrics:
                        val_metrics[k] += metrics[k]

                    val_loss += head_criterion_loss.item()
                    val_batches += 1
                except Exception as e:
                    print(f"水头模型验证批次 {batch_idx} 出错: {e}")

        val_ok = val_batches > 0
        if val_ok:
            avg_val_loss = val_loss / val_batches
            for k in val_metrics:
                val_metrics[k] /= val_batches
        else:
            print("警告: 水头模型本轮验证没有成功处理任何批次")
            avg_val_loss = float('inf')

        head_losses['train'].append(avg_train_loss)
        head_losses['val'].append({'loss': avg_val_loss, 'metrics': val_metrics})

        head_scheduler.step()
        current_lr = head_scheduler.get_last_lr()[0]

        _dual_log_epoch('水头', epoch, avg_train_loss, avg_val_loss, current_lr, val_metrics)
        _dual_update_best(head_model, 'best_head_model', '水头', head_tracker, epoch,
                          avg_train_loss, avg_val_loss, val_metrics, val_ok)

        if head_tracker['stale_epochs'] >= config['patience']:
            print(f"水头模型早停触发! 在第{epoch+1}个epoch停止训练")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    best_head_val_loss = head_tracker['best_loss']
    best_head_r2 = head_tracker['best_r2']
    print(f"\n水头模型训练完成!")
    print(f"基于损失的最佳验证损失: {best_head_val_loss:.4f}")
    print(f"基于R2的最佳R2分数: {best_head_r2:.4f}")

    # ===============================
    # Stage 2: train the concentration model
    # ===============================
    print("\n" + "=" * 80)
    print("第二阶段：开始训练浓度模型")
    print("=" * 80)

    # Restore the head model selected by evaluation_criterion
    head_model_file = f'best_head_model_{evaluation_criterion}.pth'
    try:
        checkpoint = torch.load(os.path.join(config['save_path'], head_model_file),
                                map_location=device, weights_only=False)
        head_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"成功加载基于{evaluation_criterion}的最佳水头模型")
    except Exception as e:
        print(f"加载水头模型失败，使用当前模型: {e}")

    # Freeze the head model; it only supplies inputs to the concentration model
    for param in head_model.parameters():
        param.requires_grad = False
    head_model.eval()

    # NOTE(review): the feature breakdown (14 base + 2 head + 2 concentration) sums to
    # 18, but 19 is used here. Confirm the true input width against the data pipeline.
    conc_input_dim = 19
    conc_model = BlitzConcGNN(
        node_features=conc_input_dim,
        spatial_dim=config['hidden_dim'],
        temporal_dim=config['hidden_dim'],
        prior_sigma_1=0.1,
        prior_sigma_2=0.01,
        dropout=0.1,
        posterior_mu_init=0.0,
        posterior_rho_init=-3.0
    ).to(device)

    criterion_conc = ImprovedConcLoss(kl_weight=config['kl_weight'], l1_weight=1e-5)

    conc_params = list(conc_model.parameters())
    conc_optimizer = torch.optim.AdamW(conc_params, lr=config['lr'] * 0.8, weight_decay=config['weight_decay'])
    conc_scheduler = CosineAnnealingWarmRestarts(conc_optimizer, T_0=15, T_mult=2, eta_min=1e-6)

    conc_tracker = {'best_loss': float('inf'), 'best_r2': float('-inf'), 'stale_epochs': 0}
    conc_losses = {'train': [], 'val': []}

    print(f"浓度模型参数数量: {sum(p.numel() for p in conc_model.parameters() if p.requires_grad)}")

    for epoch in range(config['num_epochs']):
        conc_model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                batch = _dual_prepare_batch(batch)

                # Frozen head model supplies the predicted head
                with torch.no_grad():
                    pred_head = head_model(batch)

                conc_optimizer.zero_grad()
                pred_conc = conc_model(batch, pred_head)
                criterion_output = criterion_conc(pred_conc, batch, conc_model)

                # The criterion may return a bare loss or (loss, (mse, kl, l1, ...))
                mse_loss, l1_reg = 0.0, 0.0
                if isinstance(criterion_output, tuple):
                    conc_criterion_loss = criterion_output[0]
                    if len(criterion_output) > 1:
                        loss_components = criterion_output[1]
                        if isinstance(loss_components, tuple) and len(loss_components) >= 3:
                            mse_loss, _, l1_reg = loss_components[:3]
                else:
                    conc_criterion_loss = criterion_output

                kl_loss = conc_model.nn_kl_divergence() * config['kl_weight']
                total_loss = conc_criterion_loss + kl_loss

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(conc_params, max_norm=1.0)
                conc_optimizer.step()

                train_loss += total_loss.item()
                train_batches += 1

                if batch_idx % 50 == 0:
                    print(f"浓度模型 Epoch {epoch+1}, Batch {batch_idx}: "
                          f"Total Loss: {total_loss.item():.4f}, "
                          f"Criterion Loss: {conc_criterion_loss.item():.4f}, "
                          f"MSE: {mse_loss:.4f}, "
                          f"KL Loss: {kl_loss.item():.4f}, "
                          f"L1 Reg: {l1_reg:.4f}")
            except Exception as e:
                print(f"浓度模型训练批次 {batch_idx} 出错: {e}")

        if train_batches == 0:
            print("警告: 浓度模型本轮训练没有成功处理任何批次，跳过本轮")
            continue
        avg_train_loss = train_loss / train_batches

        # ---- validation ----
        conc_model.eval()
        val_loss = 0.0
        val_metrics = {'mse': 0.0, 'rmse': 0.0, 'mae': 0.0, 'r2': 0.0}
        val_batches = 0

        # Collected so the best epoch's predictions can be saved alongside the checkpoint
        all_conc_predictions = []
        all_conc_targets = []
        all_conc_uncertainties = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    batch = _dual_prepare_batch(batch)

                    # NOTE(review): pred_head is computed but not passed to
                    # compute_uncertainty, whereas training calls conc_model(batch,
                    # pred_head). Confirm compute_uncertainty / BlitzConcGNN get the
                    # head input some other way.
                    pred_head = head_model(batch)

                    # Re-enable dropout for MC sampling
                    conc_model.train()
                    pred_conc, conc_std = compute_uncertainty(conc_model, batch, mc_samples=config['mc_samples'])

                    criterion_output = criterion_conc(pred_conc, batch, conc_model)
                    conc_criterion_loss = criterion_output[0] if isinstance(criterion_output, tuple) else criterion_output

                    metrics = compute_metrics(batch.y, pred_conc)
                    for k in metrics:
                        val_metrics[k] += metrics[k]

                    val_loss += conc_criterion_loss.item()
                    val_batches += 1

                    all_conc_predictions.append(pred_conc.cpu().numpy())
                    all_conc_targets.append(batch.y.cpu().numpy())
                    all_conc_uncertainties.append(conc_std.cpu().numpy())
                except Exception as e:
                    print(f"浓度模型验证批次 {batch_idx} 出错: {e}")

        val_ok = val_batches > 0
        if val_ok:
            avg_val_loss = val_loss / val_batches
            for k in val_metrics:
                val_metrics[k] /= val_batches
        else:
            print("警告: 浓度模型本轮验证没有成功处理任何批次")
            avg_val_loss = float('inf')

        conc_losses['train'].append(avg_train_loss)
        conc_losses['val'].append({'loss': avg_val_loss, 'metrics': val_metrics})

        conc_scheduler.step()
        current_lr = conc_scheduler.get_last_lr()[0]

        _dual_log_epoch('浓度', epoch, avg_train_loss, avg_val_loss, current_lr, val_metrics)
        _dual_update_best(
            conc_model, 'best_conc_model', '浓度', conc_tracker, epoch,
            avg_train_loss, avg_val_loss, val_metrics, val_ok,
            arrays={
                'best_conc_predictions_{criterion}.npy': all_conc_predictions,
                'best_conc_targets_{criterion}.npy': all_conc_targets,
                'best_conc_uncertainties_{criterion}.npy': all_conc_uncertainties,
            },
        )

        if conc_tracker['stale_epochs'] >= config['patience']:
            print(f"浓度模型早停触发! 在第{epoch+1}个epoch停止训练")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    best_conc_val_loss = conc_tracker['best_loss']
    best_conc_r2 = conc_tracker['best_r2']
    print(f"\n浓度模型训练完成!")
    print(f"基于损失的最佳验证损失: {best_conc_val_loss:.4f}")
    print(f"基于R2的最佳R2分数: {best_conc_r2:.4f}")

    # Return the concentration model selected by the same criterion as the head
    # model, rather than the last-epoch weights
    conc_model_file = f'best_conc_model_{evaluation_criterion}.pth'
    try:
        conc_checkpoint = torch.load(os.path.join(config['save_path'], conc_model_file),
                                     map_location=device, weights_only=False)
        conc_model.load_state_dict(conc_checkpoint['model_state_dict'])
        print(f"成功加载基于{evaluation_criterion}的最佳浓度模型")
    except Exception as e:
        print(f"加载浓度模型失败，使用当前模型: {e}")

    # ===============================
    # Save training history and curves
    # ===============================
    try:
        head_history_df = _dual_history_frame(head_losses)
        if head_history_df is not None:
            head_history_df.to_csv(os.path.join(config['save_path'], 'head_training_history.csv'), index=False)

        conc_history_df = _dual_history_frame(conc_losses)
        if conc_history_df is not None:
            conc_history_df.to_csv(os.path.join(config['save_path'], 'conc_training_history.csv'), index=False)

        if head_history_df is not None and conc_history_df is not None:
            plt.figure(figsize=(20, 12))
            _dual_plot_history_row(head_history_df, 1, 'Head', 'Head Model', '-')
            _dual_plot_history_row(conc_history_df, 5, 'Conc', 'Conc Model', '--')
            plt.tight_layout()
            plt.savefig(os.path.join(config['save_path'], 'dual_model_training_curves_improved.png'),
                        dpi=300, bbox_inches='tight')
            plt.close()
            print(f"改进的双模型训练曲线已保存")
    except Exception as e:
        print(f"保存训练历史或绘图失败: {e}")

    # Re-enable gradients on the head model for any later fine-tuning
    for param in head_model.parameters():
        param.requires_grad = True

    print("\n" + "=" * 80)
    print("双模型训练完成总结:")
    print(f"水头模型 - 基于损失的最佳验证损失: {best_head_val_loss:.4f}")
    print(f"水头模型 - 基于R2的最佳R2分数: {best_head_r2:.4f}")
    print(f"浓度模型 - 基于损失的最佳验证损失: {best_conc_val_loss:.4f}")
    print(f"浓度模型 - 基于R2的最佳R2分数: {best_conc_r2:.4f}")
    print(f"评估将使用基于{evaluation_criterion}的模型")
    print("=" * 80)

    return head_model, conc_model, {'head': head_losses, 'conc': conc_losses}

def evaluate_dual_model_improved(head_model, conc_model, val_loader, evaluation_criterion='loss'):
    """
    Evaluate the head and concentration models on ``val_loader``, after loading the
    checkpoints saved for ``evaluation_criterion``.

    If a checkpoint cannot be loaded, a message is printed and the models' current
    weights are used.

    Args:
        head_model: head prediction model.
        conc_model: concentration prediction model.
        val_loader: validation data loader (batches must provide ``head_y`` and ``y``).
        evaluation_criterion: 'loss' or 'r2'. Selects
            ``best_{head,conc}_model_{criterion}.pth`` in ``config['save_path']``.

    Returns:
        dict with keys 'criterion', 'head_metrics', 'conc_metrics',
        'head_predictions', 'head_targets', 'head_uncertainties',
        'conc_predictions', 'conc_targets', 'conc_uncertainties'. The same dict is
        written with ``np.save`` (as a pickled object array), so reading it back
        requires ``np.load(path, allow_pickle=True).item()``.

    Raises:
        ValueError: if no validation batch could be evaluated.
    """
    print(f"开始评估双模型性能（基于{evaluation_criterion}标准）...")

    try:
        head_model_file = f'best_head_model_{evaluation_criterion}.pth'
        head_checkpoint = torch.load(os.path.join(config['save_path'], head_model_file),
                                     map_location=device, weights_only=False)
        head_model.load_state_dict(head_checkpoint['model_state_dict'])

        conc_model_file = f'best_conc_model_{evaluation_criterion}.pth'
        conc_checkpoint = torch.load(os.path.join(config['save_path'], conc_model_file),
                                     map_location=device, weights_only=False)
        conc_model.load_state_dict(conc_checkpoint['model_state_dict'])

        print(f"成功加载基于{evaluation_criterion}的最佳模型权重")
        print(f"水头模型来自epoch {head_checkpoint['epoch']}, 验证损失: {head_checkpoint['val_loss']:.4f}, R2: {head_checkpoint['val_metrics']['r2']:.4f}")
        print(f"浓度模型来自epoch {conc_checkpoint['epoch']}, 验证损失: {conc_checkpoint['val_loss']:.4f}, R2: {conc_checkpoint['val_metrics']['r2']:.4f}")
    except Exception as e:
        print(f"加载模型权重失败，使用当前权重: {e}")

    # NOTE(review): the validation loops used during training call .train() before
    # compute_uncertainty to keep dropout active. Here the models stay in eval()
    # mode, so MC samples only vary if compute_uncertainty or the Bayesian layers
    # add randomness themselves — confirm this is intended.
    head_model.eval()
    conc_model.eval()

    all_head_preds = []
    all_head_targets = []
    all_head_uncertainties = []
    all_conc_preds = []
    all_conc_targets = []
    all_conc_uncertainties = []

    with torch.no_grad():
        for batch in val_loader:
            try:
                batch = batch.to(device)
                if getattr(batch, 'edge_attr', None) is not None:
                    batch.edge_attr = batch.edge_attr.float()

                head_pred, head_std = compute_uncertainty(head_model, batch, mc_samples=config['mc_samples'])

                # NOTE(review): head_pred is not passed to the concentration model
                # here; confirm compute_uncertainty supplies the head input itself.
                conc_pred, conc_std = compute_uncertainty(conc_model, batch, mc_samples=config['mc_samples'])

                all_head_preds.append(head_pred.cpu().numpy())
                all_head_targets.append(batch.head_y.cpu().numpy())
                all_head_uncertainties.append(head_std.cpu().numpy())
                all_conc_preds.append(conc_pred.cpu().numpy())
                all_conc_targets.append(batch.y.cpu().numpy())
                all_conc_uncertainties.append(conc_std.cpu().numpy())
            except Exception as e:
                print(f"评估批次出错: {e}")
                continue

    # Fail with an explicit message instead of np.concatenate's opaque error on []
    if not all_head_preds:
        raise ValueError("evaluate_dual_model_improved: no validation batch could be evaluated")

    all_head_preds = np.concatenate(all_head_preds, axis=0)
    all_head_targets = np.concatenate(all_head_targets, axis=0)
    all_head_uncertainties = np.concatenate(all_head_uncertainties, axis=0)
    all_conc_preds = np.concatenate(all_conc_preds, axis=0)
    all_conc_targets = np.concatenate(all_conc_targets, axis=0)
    all_conc_uncertainties = np.concatenate(all_conc_uncertainties, axis=0)

    head_metrics = compute_metrics(all_head_targets, all_head_preds)
    conc_metrics = compute_metrics(all_conc_targets, all_conc_preds)

    print(f"\n水头模型评估结果（基于{evaluation_criterion}）:")
    for metric, value in head_metrics.items():
        print(f"  {metric.upper()}: {value:.4f}")

    print(f"\n浓度模型评估结果（基于{evaluation_criterion}）:")
    for metric, value in conc_metrics.items():
        print(f"  {metric.upper()}: {value:.4f}")

    evaluation_results = {
        'criterion': evaluation_criterion,
        'head_metrics': head_metrics,
        'conc_metrics': conc_metrics,
        'head_predictions': all_head_preds,
        'head_targets': all_head_targets,
        'head_uncertainties': all_head_uncertainties,
        'conc_predictions': all_conc_preds,
        'conc_targets': all_conc_targets,
        'conc_uncertainties': all_conc_uncertainties
    }

    filename = f'dual_model_evaluation_{evaluation_criterion}.npy'
    np.save(os.path.join(config['save_path'], filename), evaluation_results)
    print(f"\n评估结果已保存到: {config['save_path']}/{filename}")

    return evaluation_results

def compare_model_criteria(head_model, conc_model, val_loader):
    """
    Compare model performance under the loss-based and R2-based selection criteria.

    Calls ``evaluate_dual_model_improved`` once with criterion ``'loss'`` and once
    with ``'r2'``, then:

    * tabulates MSE / RMSE / MAE / R2 for the head and concentration models and
      writes the table to ``<config['save_path']>/model_criteria_comparison.csv``;
    * draws a 2x3 figure (four metric bar charts plus true-vs-predicted scatter
      plots for each model) and writes it to
      ``<config['save_path']>/model_criteria_comparison.png``.

    Args:
        head_model: head model, passed through to ``evaluate_dual_model_improved``.
        conc_model: concentration model, passed through likewise.
        val_loader: validation data loader, passed through likewise.

    Returns:
        tuple: ``(comparison_df, loss_results, r2_results)`` where ``comparison_df``
        has columns Model / Criterion / MSE / RMSE / MAE / R2 (one row per
        model x criterion) and the two results are the dicts returned by
        ``evaluate_dual_model_improved``.
    """
    print("=" * 80)
    print("比较不同选择标准的模型性能")
    print("=" * 80)

    # NOTE(review): both calls receive the same model objects; whether the two
    # criteria produce different numbers depends on how
    # evaluate_dual_model_improved uses its criterion argument — confirm.
    loss_results = evaluate_dual_model_improved(head_model, conc_model, val_loader, 'loss')
    r2_results = evaluate_dual_model_improved(head_model, conc_model, val_loader, 'r2')

    comparison_df = _build_criteria_table(loss_results, r2_results)
    comparison_df.to_csv(os.path.join(config['save_path'], 'model_criteria_comparison.csv'), index=False)

    print("\n模型选择标准比较结果:")
    # A callable float_format is accepted by every pandas version (string
    # formats are not supported by to_string on older releases).
    print(comparison_df.to_string(index=False, float_format='{:.4f}'.format))

    fig = _plot_criteria_comparison(loss_results, r2_results)
    fig.savefig(os.path.join(config['save_path'], 'model_criteria_comparison.png'), dpi=300, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"\n比较图表已保存到: {config['save_path']}/model_criteria_comparison.png")
    print(f"比较数据已保存到: {config['save_path']}/model_criteria_comparison.csv")

    return comparison_df, loss_results, r2_results


def _build_criteria_table(loss_results, r2_results):
    """Build the metrics table: one row per (model, criterion), Head rows first, Loss before R2."""
    rows = []
    for model_label, metrics_key in (('Head', 'head_metrics'), ('Concentration', 'conc_metrics')):
        for criterion_label, results in (('Loss', loss_results), ('R2', r2_results)):
            metrics = results[metrics_key]
            rows.append({
                'Model': model_label,
                'Criterion': criterion_label,
                'MSE': metrics['mse'],
                'RMSE': metrics['rmse'],
                'MAE': metrics['mae'],
                'R2': metrics['r2'],
            })
    return pd.DataFrame(rows)


def _plot_criteria_comparison(loss_results, r2_results):
    """Draw the 2x3 comparison figure (4 metric bar charts + 2 scatter plots) and return it."""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    flat_axes = axes.ravel()

    bar_width = 0.35
    criterion_labels = ['Loss-based', 'R2-based']
    # (metric key in the metrics dict, y-axis label, panel title), in panel order.
    bar_panels = (
        ('r2', 'R2 Score', 'R2 Score Comparison'),
        ('mse', 'MSE', 'MSE Comparison'),
        ('rmse', 'RMSE', 'RMSE Comparison'),
        ('mae', 'MAE', 'MAE Comparison'),
    )
    for ax, (metric, ylabel, title) in zip(flat_axes, bar_panels):
        head_vals = [loss_results['head_metrics'][metric], r2_results['head_metrics'][metric]]
        conc_vals = [loss_results['conc_metrics'][metric], r2_results['conc_metrics'][metric]]
        ax.bar([0, 1], head_vals, alpha=0.7, label='Head Model', width=bar_width)
        ax.bar([bar_width, 1 + bar_width], conc_vals, alpha=0.7, label='Concentration Model', width=bar_width)
        ax.set_xlabel('Model Selection Criterion')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Bars are centre-aligned, so each pair's centre is half a bar width right of 0 / 1.
        ax.set_xticks([bar_width / 2, 1 + bar_width / 2])
        ax.set_xticklabels(criterion_labels)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # (key prefix in the results dict, human-readable quantity name)
    scatter_panels = (('head', 'Head'), ('conc', 'Concentration'))
    for ax, (prefix, quantity) in zip(flat_axes[4:], scatter_panels):
        targets_key = f'{prefix}_targets'
        preds_key = f'{prefix}_predictions'
        for label, results in zip(criterion_labels, (loss_results, r2_results)):
            ax.scatter(np.ravel(results[targets_key]), np.ravel(results[preds_key]), alpha=0.5, label=label)

        # y = x reference line spanning the finite targets of BOTH runs; np.ravel
        # makes this work for 1-D or multi-column target arrays alike.
        all_targets = np.concatenate([np.ravel(loss_results[targets_key]), np.ravel(r2_results[targets_key])])
        finite_targets = all_targets[np.isfinite(all_targets)]
        if finite_targets.size:
            lo, hi = float(finite_targets.min()), float(finite_targets.max())
            ax.plot([lo, hi], [lo, hi], 'r--', alpha=0.8)

        ax.set_xlabel(f'True {quantity} Values')
        ax.set_ylabel(f'Predicted {quantity} Values')
        ax.set_title(f'{quantity} Model: True vs Predicted')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


In [3]:
# Load the raw CSV (relative path, resolved against the notebook's working
# directory) and build train/validation loaders with batch size 4.
data = pd.read_csv('conc_dual_guass.csv')  # replace with your own data file
train_loader, val_loader = prepare_data(data, batch_size=4)
# Train the head and concentration models, then evaluate them on the validation loader.
head_model, conc_model, training_losses = train_dual_model_improved(train_loader, val_loader)
# NOTE(review): no criterion argument is passed here, so this relies on
# evaluate_dual_model_improved's default criterion — confirm which one that is.
evaluation_results = evaluate_dual_model_improved(head_model, conc_model, val_loader)

正在处理数据...

▶ Started building spatiotemporal graphs
▷ Total models to process: 100

▣ Processing model 1/100: dual_42

▣ Processing model 2/100: dual_93

▣ Processing model 3/100: dual_71

▣ Processing model 4/100: dual_31

▣ Processing model 5/100: dual_60

▣ Processing model 6/100: dual_15

▣ Processing model 7/100: dual_88

▣ Processing model 8/100: dual_40

▣ Processing model 9/100: dual_20

▣ Processing model 10/100: dual_94

▣ Processing model 11/100: dual_76

▣ Processing model 12/100: dual_84

▣ Processing model 13/100: dual_62

▣ Processing model 14/100: dual_10

▣ Processing model 15/100: dual_21

▣ Processing model 16/100: dual_0

▣ Processing model 17/100: dual_29

▣ Processing model 18/100: dual_49

▣ Processing model 19/100: dual_37

▣ Processing model 20/100: dual_23

▣ Processing model 21/100: dual_61

▣ Processing model 22/100: dual_68

▣ Processing model 23/100: dual_18

▣ Processing model 24/100: dual_66

▣ Processing model 25/100: dual_17

▣ Processing model 26/100: