# In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from sklearn.model_selection import train_test_split                                                                                                   
from sklearn.preprocessing import StandardScaler
import numpy as np
import os
from torch.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
from sklearn.metrics import r2_score as sklearn_r2_score
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import sys
sys.path.append('/home/jovyan/work/GNO/GNN/GNNShap')
from gnnshap.explainer import GNNShapExplainer
import uuid


from blitz.modules import BayesianLinear                 
from blitz.utils import variational_estimator


# Module-wide compute device: the first CUDA GPU when one is available,
# otherwise the CPU. The choice is printed once at import time.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


class HydroDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves pre-built graphs by position."""

    def __init__(self, graphs):
        # Keep a reference to the caller's container; nothing is copied.
        self.graphs = graphs

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        graph = self.graphs[idx]
        return graph

    def __len__(self):
        """Return the number of graphs held by the dataset."""
        count = len(self.graphs)
        return count

class BlitzSTConv(MessagePassing):
    """
    Bayesian spatio-temporal convolution layer implemented with Blitz,
    intended to improve robustness.

    Messages are built from the target features, source features and edge
    attributes, mean-aggregated per node, and blended back into the node
    state through a learned sigmoid gate plus a small residual branch.
    """
    def __init__(self, spatial_dim, prior_sigma_1=0.1, prior_sigma_2=0.002, posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.1):
        super().__init__(aggr='mean')

        # Deterministic projection of [x_i, x_j, edge_attr]; the "+ 3" is the
        # edge-attribute width expected by `message`.
        self.pre_msg = nn.Linear(2 * spatial_dim + 3, spatial_dim)


        # Bayesian message transform applied after the projection.
        self.bayes_msg = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim,
                        prior_sigma_1=prior_sigma_1, 
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(inplace=True),
            nn.LayerNorm(spatial_dim),
            nn.Dropout(dropout)
        )


        # Gate computed from [x, aggregated, x - aggregated] (3 * spatial_dim).
        self.pre_gate = nn.Linear(3 * spatial_dim, spatial_dim)
        self.bayes_gate = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim,
                        prior_sigma_1=prior_sigma_1, 
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.Sigmoid()
        )


        # Residual branch; its Bayesian layer uses half the prior sigmas.
        self.pre_res = nn.Linear(spatial_dim, spatial_dim)
        self.bayes_res = BayesianLinear(spatial_dim, spatial_dim,
                                     prior_sigma_1=prior_sigma_1/2, 
                                     prior_sigma_2=prior_sigma_2/2,
                                     posterior_mu_init=posterior_mu_init,
                                     posterior_rho_init=posterior_rho_init)

    def forward(self, x, edge_index, edge_attr):
        """Return ``x + gate * aggregated_messages + 0.1 * residual``.

        Any exception is logged together with the input shapes/dtypes and
        then re-raised unchanged.
        """
        try:

            # Edge attributes are cast to float before message passing.
            edge_attr = edge_attr.float()


            # Mean-aggregated messages per node (see `message`).
            out = self.propagate(edge_index, x=x, edge_attr=edge_attr)


            combined = torch.cat([x, out, x - out], dim=-1)
            gate_pre = self.pre_gate(combined)
            gate = self.bayes_gate(gate_pre)


            res_pre = self.pre_res(x)
            res = self.bayes_res(res_pre)

            return x + gate * out + 0.1 * res
        except Exception as e:
            print(f"Error in BlitzSTConv.forward: {e}")

            # NOTE(review): these diagnostics assume the inputs are tensors;
            # a None edge_attr would raise again here.
            print(f"x shape: {x.shape}, dtype: {x.dtype}")
            print(f"edge_index shape: {edge_index.shape}, dtype: {edge_index.dtype}")
            print(f"edge_attr shape: {edge_attr.shape}, dtype: {edge_attr.dtype}")
            raise

    def message(self, x_i, x_j, edge_attr):
        """Build the per-edge message from ``[x_i, x_j, edge_attr]``.

        Any exception is logged together with the input shapes/dtypes and
        then re-raised unchanged.
        """
        try:

            # Align edge attributes with the node features' dtype and device
            # so the concatenation below is valid.
            edge_attr = edge_attr.to(x_i.dtype).to(x_i.device)


            combined = torch.cat([x_i, x_j, edge_attr], dim=-1)
            pre_msg = self.pre_msg(combined)
            return self.bayes_msg(pre_msg)
        except Exception as e:
            print(f"Error in BlitzSTConv.message: {e}")

            print(f"x_i shape: {x_i.shape}, dtype: {x_i.dtype}")
            print(f"x_j shape: {x_j.shape}, dtype: {x_j.dtype}")
            print(f"edge_attr shape: {edge_attr.shape}, dtype: {edge_attr.dtype}")
            raise

class BlitzBoundaryProcessor(nn.Module):
    """
    Bayesian boundary-condition processor implemented with Blitz.

    Reads a per-node ``bc_mask`` and injects boundary, river and well
    information into the node embedding through a learned gate. ``forward``
    indexes columns 0..4 of ``bc_mask``; in this file ``build_bc_mask``
    fills them as: 0 = chd_mask, 1 = river flag (river_cond > 0),
    2 = river_stage, 3 = well_rate, 4 = well_mask.
    """
    def __init__(self, dim, prior_sigma_1=0.1, prior_sigma_2=0.002, posterior_mu_init=0.0, posterior_rho_init=-3.0):
        super().__init__()
        # Input: embedding + bc_mask column 0.
        self.boundary_net = nn.Sequential(
            BayesianLinear(dim + 1, dim,
                        prior_sigma_1=prior_sigma_1, 
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU()
        )

        # Input: embedding + bc_mask columns 1 and 2.
        self.river_net = nn.Sequential(
            BayesianLinear(dim + 2, dim,
                        prior_sigma_1=prior_sigma_1, 
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU()
        )

        # Input: embedding + bc_mask column 3; no activation on this branch.
        self.well_net = BayesianLinear(dim + 1, dim,
                                    prior_sigma_1=prior_sigma_1, 
                                    prior_sigma_2=prior_sigma_2,
                                    posterior_mu_init=posterior_mu_init,
                                    posterior_rho_init=posterior_rho_init)

        # Sigmoid gate over [x, combined boundary features]; half-width priors.
        self.gate = nn.Sequential(
            BayesianLinear(2 * dim, dim,
                        prior_sigma_1=prior_sigma_1/2, 
                        prior_sigma_2=prior_sigma_2/2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.Sigmoid()
        )

        # Extra transform applied only to nodes whose bc_mask column 0 is > 0.
        self.chd_enforcer = BayesianLinear(dim, dim,
                                        prior_sigma_1=prior_sigma_1/2, 
                                        prior_sigma_2=prior_sigma_2/2,
                                        posterior_mu_init=posterior_mu_init,
                                        posterior_rho_init=posterior_rho_init)

    def forward(self, x, bc_mask):
        """Blend boundary-aware features into ``x`` and return the result.

        Each branch output is multiplied by a mask column (0, 1 and 4
        respectively), so it only contributes where that column is non-zero.
        """
        boundary_feat = self.boundary_net(
            torch.cat([x, bc_mask[:, 0:1]], dim=-1)
        ) * bc_mask[:, 0:1]

        river_feat = self.river_net(
            torch.cat([x, bc_mask[:, 1:3]], dim=-1)
        ) * bc_mask[:, 1:2]

        well_feat = self.well_net(
            torch.cat([x, bc_mask[:, 3:4]], dim=-1)
        ) * bc_mask[:, 4:5]

        combined = boundary_feat + river_feat + well_feat
        gate = self.gate(torch.cat([x, combined], dim=-1))
        # Convex combination of the incoming embedding and boundary features.
        out = x * (1 - gate) + combined * gate

        chd_mask = bc_mask[:, 0] > 0
        if chd_mask.sum() > 0:
            # Cast back to out's dtype before the in-place row assignment.
            chd_out = self.chd_enforcer(out[chd_mask]).to(out.dtype)
            out[chd_mask] = chd_out

        return out

@variational_estimator
class BlitzHeadGNN(nn.Module):
    """Hydraulic-head GNN mixing a deterministic and a Bayesian output head.

    Pipeline in ``forward``: time-step embedding + node features -> MLP
    encoder -> four ``BlitzSTConv`` layers -> feature-wise sigmoid attention
    -> ``BlitzBoundaryProcessor`` -> fixed 0.7 / 0.3 blend of the
    deterministic and Bayesian heads (both end in Softplus, so the output is
    non-negative).
    """

    def __init__(self, node_features=16, max_time_steps=40, spatial_dim=64, 
                temporal_dim=64, output_dim=1, prior_sigma_1=0.05, prior_sigma_2=0.001,
                posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.2):
        super().__init__()
        self.spatial_dim = spatial_dim
        # Lookup table with max_time_steps + 1 rows, indexed by data.time_step.
        self.time_embed = nn.Embedding(max_time_steps + 1, temporal_dim)


        # Deterministic encoder over [node features, time embedding].
        self.node_enc = nn.Sequential(
            nn.Linear(node_features + temporal_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(spatial_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )


        # Four Bayesian message-passing layers sharing the same hyper-parameters.
        self.conv_layers = nn.ModuleList([
            BlitzSTConv(spatial_dim, prior_sigma_1, prior_sigma_2, 
                     posterior_mu_init, posterior_rho_init, dropout) 
            for _ in range(4)
        ])


        self.bc_processor = BlitzBoundaryProcessor(
            spatial_dim, prior_sigma_1, prior_sigma_2, 
            posterior_mu_init, posterior_rho_init
        )


        # Deterministic output head.
        self.deterministic_path = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 2),
            nn.BatchNorm1d(spatial_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(spatial_dim // 2, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, output_dim),
            nn.Softplus()
        )


        # Bayesian output head; the last layer uses half the prior sigmas.
        self.bayesian_path = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim // 2,
                          prior_sigma_1=prior_sigma_1, 
                          prior_sigma_2=prior_sigma_2,
                          posterior_mu_init=posterior_mu_init,
                          posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            nn.Dropout(dropout),
            BayesianLinear(spatial_dim // 2, output_dim,
                          prior_sigma_1=prior_sigma_1/2, 
                          prior_sigma_2=prior_sigma_2/2,
                          posterior_mu_init=posterior_mu_init,
                          posterior_rho_init=posterior_rho_init),
            nn.Softplus()
        )


        # Bottleneck MLP producing per-feature weights in (0, 1).
        self.attention = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, spatial_dim),
            nn.Sigmoid()
        )

    def forward(self, data):
        """Predict head for every node of ``data``.

        Reads ``data.x``, ``data.edge_index``, ``data.edge_attr``,
        ``data.time_step`` and ``data.bc_mask``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr.to(torch.float32)


        # Optional hook: only runs if a `feature_engineering` attribute has
        # been attached to the model; __init__ does not define one.
        if hasattr(self, 'feature_engineering'):
            x = self.feature_engineering(x)


        time_emb = self.time_embed(data.time_step)
        node_feat = torch.cat([x, time_emb], dim=-1)

        h = self.node_enc(node_feat)


        # Damped residual update around each convolution.
        for conv in self.conv_layers:
            h_new = conv(h, edge_index, edge_attr)
            h = h + 0.1 * h_new


        attention_weights = self.attention(h)
        h = h * attention_weights


        h = self.bc_processor(h, data.bc_mask)


        # The deterministic head sees a detached embedding, so its gradients
        # do not flow back into the shared encoder.
        det_pred = self.deterministic_path(h.detach())
        bayes_pred = self.bayesian_path(h)


        return det_pred * 0.7 + bayes_pred * 0.3

class ImprovedPhysicsInformedLoss(nn.Module):
    """MSE on head plus temporal-smoothness and boundary/well L1 penalties.

    ``total = (1 - alpha) * mse + alpha * (flux + bc + well)`` where

    * ``flux`` is the mean squared change of the prediction between
      consecutive time steps, averaged over the time-step transitions,
    * ``bc`` is the L1 error on nodes with ``bc_mask[:, 0] > 0``,
    * ``well`` is the L1 error on nodes with ``bc_mask[:, 4] > 0``.

    ``kl_weight`` is stored but not used by ``forward``; the last element of
    the returned components tuple is always ``0.0``.
    """

    def __init__(self, alpha=0.5, kl_weight=1e-4):
        super().__init__()
        self.alpha = alpha
        self.kl_weight = kl_weight

    def forward(self, pred, data, model=None):
        """Return ``(total_loss, (mse, flux, bc, well, 0.0))``.

        Args:
            pred: Model output compared against ``data.head_y.unsqueeze(1)``.
            data: Object exposing ``head_y``, ``time_step`` and ``bc_mask``.
            model: Accepted for interface compatibility; unused.
        """
        target = data.head_y.unsqueeze(1)
        mse_loss = F.mse_loss(pred, target)

        time_steps = data.time_step.unique(sorted=True)
        # Start from a tensor (not a Python int) so that ``.item()`` below
        # works even when no pair of consecutive time steps exists, e.g. for
        # a batch containing a single time step.
        flux_loss = torch.zeros((), device=pred.device, dtype=pred.dtype)
        for t in time_steps[:-1]:
            mask_t = data.time_step == t
            mask_next = data.time_step == t + 1
            if mask_t.any() and mask_next.any():
                # Assumes both steps hold the same nodes in the same order.
                step_change = pred[mask_next] - pred[mask_t]
                flux_loss = flux_loss + torch.mean(step_change ** 2)
        flux_loss = flux_loss / max(len(time_steps) - 1, 1)

        zero = torch.tensor(0.0, device=pred.device)

        chd_nodes = data.bc_mask[:, 0] > 0
        bc_loss = F.l1_loss(pred[chd_nodes], target[chd_nodes]) if chd_nodes.any() else zero

        well_nodes = data.bc_mask[:, 4] > 0
        well_loss = F.l1_loss(pred[well_nodes], target[well_nodes]) if well_nodes.any() else zero

        total_loss = (1 - self.alpha) * mse_loss + self.alpha * (flux_loss + bc_loss + well_loss)

        return total_loss, (mse_loss.item(), flux_loss.item(), bc_loss.item(), well_loss.item(), 0.0)


@variational_estimator
class BlitzConcGNN(nn.Module):
    """Concentration GNN with a Bayesian and a deterministic decoder.

    Pipeline in ``forward``: time-step embedding + ``data.conc_x`` ->
    deterministic projection -> Bayesian encoder -> four ``BlitzSTConv``
    layers -> feature-wise sigmoid attention -> ``BlitzBoundaryProcessor``
    (skipped when the graph carries no ``bc_mask``) -> learnable blend of
    the two decoders. Both decoders end in Softplus, so the output is
    non-negative.
    """

    def __init__(self, node_features=19, max_time_steps=40, spatial_dim=128,
                temporal_dim=64, output_dim=1, prior_sigma_1=0.1, prior_sigma_2=0.01,
                posterior_mu_init=0.0, posterior_rho_init=-3.0, dropout=0.1):
        super().__init__()
        self.spatial_dim = spatial_dim
        # Lookup table with max_time_steps + 1 rows, indexed by data.time_step.
        self.time_embed = nn.Embedding(max_time_steps + 1, temporal_dim)

        # Deterministic projection of [node features, time embedding]
        # (node_features + temporal_dim inputs) to spatial_dim.
        self.node_enc_scale1 = nn.Sequential(
            nn.Linear(node_features + temporal_dim, spatial_dim),
            nn.BatchNorm1d(spatial_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5)
        )

        # Bayesian refinement of the projected features.
        self.node_enc = nn.Sequential(
            BayesianLinear(spatial_dim, spatial_dim,
                        prior_sigma_1=prior_sigma_1,
                        prior_sigma_2=prior_sigma_2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(inplace=True),
            nn.LayerNorm(spatial_dim),
            nn.Dropout(dropout)
        )

        # Four Bayesian message-passing layers sharing the same hyper-parameters.
        self.conv_layers = nn.ModuleList([
            BlitzSTConv(spatial_dim, prior_sigma_1, prior_sigma_2,
                     posterior_mu_init, posterior_rho_init, dropout)
            for _ in range(4)
        ])

        self.bc_processor = BlitzBoundaryProcessor(
            spatial_dim, prior_sigma_1, prior_sigma_2,
            posterior_mu_init, posterior_rho_init
        )

        # Bottleneck MLP producing per-feature weights in (0, 1).
        self.attention = nn.Sequential(
            nn.Linear(spatial_dim, spatial_dim // 4),
            nn.ReLU(),
            nn.Linear(spatial_dim // 4, spatial_dim),
            nn.Sigmoid()
        )

        # Bayesian decoder; prior sigmas shrink towards the output layer.
        self.decoder = nn.Sequential(
            BayesianLinear(spatial_dim, 128,
                        prior_sigma_1=prior_sigma_1/2,
                        prior_sigma_2=prior_sigma_2/2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            nn.Dropout(dropout),
            BayesianLinear(128, 64,
                        prior_sigma_1=prior_sigma_1/2,
                        prior_sigma_2=prior_sigma_2/2,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.ReLU(),
            BayesianLinear(64, output_dim,
                        prior_sigma_1=prior_sigma_1/4,
                        prior_sigma_2=prior_sigma_2/4,
                        posterior_mu_init=posterior_mu_init,
                        posterior_rho_init=posterior_rho_init),
            nn.Softplus()
        )

        # Deterministic decoder with the same layer widths.
        self.decoder_det = nn.Sequential(
            nn.Linear(spatial_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim),
            nn.Softplus()
        )

        # Raw (pre-sigmoid) mixing weight of the Bayesian branch.
        self.branch_weight = nn.Parameter(torch.tensor(0.3))

    def forward(self, data, pred_head=None):
        """Predict concentration for every node of ``data``.

        Args:
            data: Graph exposing ``conc_x``, ``edge_index``, ``edge_attr``,
                ``time_step`` and, optionally, ``bc_mask``.
            pred_head: Accepted for interface compatibility; unused.

        Returns:
            ``sigmoid(w) * bayes + (1 - sigmoid(w)) * det`` with
            ``w = self.branch_weight``.
        """
        x = data.conc_x
        edge_index, edge_attr = data.edge_index, data.edge_attr

        if edge_attr is not None:
            edge_attr = edge_attr.to(torch.float32).to(x.device)

        time_step = data.time_step.to(x.device)
        time_emb = self.time_embed(time_step)
        node_feat = torch.cat([x, time_emb], dim=-1)

        h_scale1 = self.node_enc_scale1(node_feat)
        h = self.node_enc(h_scale1)

        # Damped residual update around each convolution.
        for conv in self.conv_layers:
            h_new = conv(h, edge_index, edge_attr)
            h = h + 0.1 * h_new

        attention_weights = self.attention(h)
        h = h * attention_weights

        # Boundary conditioning is optional: passing None into the processor
        # would fail when it slices the mask, so skip the step when the
        # graph has no (or an empty) bc_mask attribute.
        bc_mask = getattr(data, 'bc_mask', None)
        if bc_mask is not None:
            h = self.bc_processor(h, bc_mask.to(x.device))

        bayes_output = self.decoder(h)
        # The deterministic decoder sees a detached embedding, so its
        # gradients do not flow back into the shared encoder.
        det_output = self.decoder_det(h.detach())

        bayes_share = torch.sigmoid(self.branch_weight)
        combined_output = bayes_share * bayes_output + (1 - bayes_share) * det_output

        return combined_output

class ImprovedConcLoss(nn.Module):
    """MSE on concentration plus an optional L1 penalty on Bayesian weights.

    ``kl_weight`` is stored but not used by ``forward``; the middle element
    of the returned components tuple is always ``0.0``.
    """

    def __init__(self, kl_weight=5e-5, l1_weight=1e-8):
        super().__init__()
        self.kl_weight = kl_weight
        self.l1_weight = l1_weight

    def forward(self, pred, data, model=None):
        """Return ``(total_loss, (mse, 0.0, l1_reg))``.

        The L1 term sums the 1-norm of every parameter of ``model`` whose
        name contains ``weight`` and, case-insensitively, ``bayes``. It is
        only computed when a model is given and ``l1_weight`` is positive.
        """
        target = data.y.unsqueeze(1)
        mse_loss = F.mse_loss(pred, target)

        l1_reg = torch.tensor(0., device=pred.device)
        wants_l1 = model is not None and self.l1_weight > 0
        if wants_l1:
            for param_name, tensor in model.named_parameters():
                if 'weight' not in param_name:
                    continue
                # 'bayesian' always contains 'bayes', so one check suffices.
                if 'bayes' not in param_name.lower():
                    continue
                l1_reg += torch.norm(tensor, 1)

        total_loss = mse_loss + self.l1_weight * l1_reg

        components = (mse_loss.item(), 0.0, l1_reg.item())
        return total_loss, components

def generate_spatial_edges(model_df):
    """Connect each cell to its (row, col + 1) and (row + 1, col) neighbours.

    Edges are built independently for every time step and are directed from
    the cell to the neighbour. Endpoints are taken from the ``local_index``
    column.

    Args:
        model_df: DataFrame with ``time_step``, ``local_index``, ``row`` and
            ``col`` columns.

    Returns:
        ``(edges, attrs)``: ``edges`` is an int64 array of shape ``(E, 2)``
        and ``attrs`` a float32 array of shape ``(E, 3)`` filled with
        ``[1, 0, 0]``. Both keep their 2-D shape when ``E == 0`` so they can
        always be concatenated with other edge sets.
    """
    spatial_edges = []
    # One pass per time step; sort=False keeps first-appearance order.
    for _, time_df in model_df.groupby('time_step', sort=False):
        ids = time_df['local_index'].to_numpy()
        rows = time_df['row'].to_numpy()
        cols = time_df['col'].to_numpy()
        coord_to_idx = {(r, c): idx for idx, r, c in zip(ids, rows, cols)}
        for idx, r, c in zip(ids, rows, cols):
            for neighbour in ((r, c + 1), (r + 1, c)):
                dst = coord_to_idx.get(neighbour)
                if dst is not None:
                    spatial_edges.append((idx, dst))

    # reshape(-1, 2) turns the empty case into shape (0, 2) instead of (0,).
    edges = np.array(spatial_edges, dtype=np.int64).reshape(-1, 2)
    attrs = np.full((len(edges), 3), [1.0, 0, 0], dtype=np.float32)
    return edges, attrs

def generate_temporal_edges(model_df):
    """Link every cell to itself at the next recorded time step.

    Rows are grouped by ``(row, col)``; inside each group they are ordered
    by ``time_step`` and consecutive rows are joined by a directed edge from
    the earlier to the later one. Endpoints are taken from ``local_index``.

    Args:
        model_df: DataFrame with ``row``, ``col``, ``time_step`` and
            ``local_index`` columns.

    Returns:
        ``(edges, attrs)``: ``edges`` is an int64 array of shape ``(E, 2)``
        and ``attrs`` a float32 array of shape ``(E, 3)`` filled with
        ``[0, 1, 0]``. Both keep their 2-D shape when ``E == 0`` so they can
        always be concatenated with other edge sets.
    """
    temporal_edges = []
    for _, group in model_df.groupby(['row', 'col'], sort=False):
        ids = group.sort_values('time_step')['local_index'].to_numpy()
        # Pair each entry with its successor in time.
        temporal_edges.extend(zip(ids[:-1], ids[1:]))

    # reshape(-1, 2) turns the empty case into shape (0, 2) instead of (0,).
    edges = np.array(temporal_edges, dtype=np.int64).reshape(-1, 2)
    attrs = np.full((len(edges), 3), [0.0, 1.0, 0], dtype=np.float32)
    return edges, attrs

def build_bc_mask(model_df):
    """Assemble the float32 ``(n_rows, 5)`` boundary-condition matrix.

    Column layout: 0 = ``chd_mask``, 1 = 1.0 where ``river_cond > 0`` else
    0.0, 2 = ``river_stage``, 3 = ``well_rate``, 4 = ``well_mask``.
    """
    river_active = model_df['river_cond'] > 0
    column_values = (
        model_df['chd_mask'].values,
        river_active.values,
        model_df['river_stage'].values,
        model_df['well_rate'].values,
        model_df['well_mask'].values,
    )
    as_float = [np.asarray(values).astype(np.float32) for values in column_values]
    return np.column_stack(as_float)
def _standardise_float_columns(model_df, columns, col_types):
    """Return ``model_df[columns]`` as float32 with non-uint8 columns z-scored."""
    feats = model_df[columns].values.astype(np.float32)
    float_indices = [i for i, col in enumerate(columns) if col_types[col] != np.uint8]
    feats[:, float_indices] = StandardScaler().fit_transform(feats[:, float_indices])
    return feats


def _lag_pair(values):
    """Return the one-step and two-step lagged copies of a time-ordered series.

    Positions without enough history are padded with the first value; a
    series of length one has no history at all and gets zeros.
    """
    n = len(values)
    if n <= 1:
        return np.zeros(n, dtype=np.float32), np.zeros(n, dtype=np.float32)
    prev = np.roll(values, 1)
    prev[0] = values[0]
    prev2 = np.roll(values, 2)
    prev2[:2] = values[0]
    return prev, prev2


def build_spatiotemporal_graph(df):
    """Build one spatio-temporal ``Data`` graph per ``model_name`` in ``df``.

    Every row of a model becomes a node. Rows are ordered by
    ``(row, col, time_step)`` and that order defines the node ids used by
    the features, the labels, the masks and both edge sets.

    Node features:
        * ``x``: 14 static columns (float columns z-scored per model) plus
          the head at the previous one and two time steps.
        * ``conc_x``: the same 14 columns plus ``conc_mask``, the two head
          lags and the two concentration lags.

    Args:
        df: DataFrame holding all columns cast below plus ``model_name``,
            ``row``, ``col`` and ``time_step``. It is not modified.

    Returns:
        List of ``torch_geometric.data.Data`` objects, one per model.
    """
    print(f"\n▶ Started building spatiotemporal graphs")
    print(f"▷ Total models to process: {len(df['model_name'].unique())}")
    graphs = []

    # astype returns a new frame, so the caller's DataFrame stays untouched.
    df = df.astype({
        'x': np.float32, 'y': np.float32, 'top': np.float32,
        'bottom': np.float32, 'K': np.float32, 'recharge': np.float32,
        'ET': np.float32, 'river_stage': np.float32, 'river_cond': np.float32,
        'river_rbot': np.float32, 'well_rate': np.float32, 'well_mask': np.uint8,
        'chd_mask': np.uint8, 'lytyp': np.uint8, 'head': np.float32,
        'concentration': np.float32, 'conc_mask': np.uint8
    })

    # Shift time steps so that the smallest one becomes 0.
    df['time_step'] = df['time_step'] - df['time_step'].min()
    col_types = df.dtypes.to_dict()

    feature_cols = [
        'x', 'y', 'top', 'bottom', 'K', 'recharge', 'ET',
        'river_stage', 'river_cond', 'river_rbot', 'well_rate', 'well_mask',
        'chd_mask', 'lytyp'
    ]
    conc_feature_cols = feature_cols + ['conc_mask']

    model_groups = list(df.groupby('model_name', sort=False))
    total_models = len(model_groups)

    for model_idx, (model_name, model_df) in enumerate(model_groups, 1):
        print(f"\n▣ Processing model {model_idx}/{total_models}: {model_name}")

        # Sort FIRST, then number the rows. Numbering before sorting makes
        # the edge endpoints and the lag arrays refer to pre-sort positions
        # while the feature matrices are in sorted order.
        model_df = (
            model_df.sort_values(['row', 'col', 'time_step'])
            .reset_index(drop=True)
            .copy()
        )
        model_df['local_index'] = model_df.index

        node_feats = _standardise_float_columns(model_df, feature_cols, col_types)
        conc_node_feats = _standardise_float_columns(model_df, conc_feature_cols, col_types)

        n_nodes = len(model_df)
        prev_head = np.zeros(n_nodes, dtype=np.float32)
        prev2_head = np.zeros(n_nodes, dtype=np.float32)
        prev_conc = np.zeros(n_nodes, dtype=np.float32)
        prev2_conc = np.zeros(n_nodes, dtype=np.float32)

        for _, group in model_df.groupby(['row', 'col'], sort=False):
            time_series = group.sort_values('time_step')
            # After the reset above, index labels equal node positions.
            positions = time_series.index.to_numpy()
            prev_head[positions], prev2_head[positions] = _lag_pair(
                time_series['head'].to_numpy()
            )
            prev_conc[positions], prev2_conc[positions] = _lag_pair(
                time_series['concentration'].to_numpy()
            )

        head_feats = np.concatenate([
            node_feats,
            prev_head[:, None],
            prev2_head[:, None]
        ], axis=1)

        conc_feats = np.concatenate([
            conc_node_feats,
            prev_head[:, None],
            prev2_head[:, None],
            prev_conc[:, None],
            prev2_conc[:, None]
        ], axis=1)

        if np.any(np.isnan(head_feats)) or np.any(np.isinf(head_feats)):
            print(f"Warning: head_feats contains NaN or Inf for model {model_name}")
            head_feats = np.nan_to_num(head_feats, nan=0.0, posinf=1e6, neginf=-1e6)

        if np.any(np.isnan(conc_feats)) or np.any(np.isinf(conc_feats)):
            print(f"Warning: conc_feats contains NaN or Inf for model {model_name}")
            conc_feats = np.nan_to_num(conc_feats, nan=0.0, posinf=1e6, neginf=-1e6)

        conc = model_df['concentration'].values.astype(np.float32)
        head = model_df['head'].values.astype(np.float32)

        spatial_edges, spatial_attrs = generate_spatial_edges(model_df)
        temporal_edges, temporal_attrs = generate_temporal_edges(model_df)
        # Force (E, 2) so that an empty edge set still concatenates.
        edges = np.concatenate([
            np.asarray(spatial_edges, dtype=np.int64).reshape(-1, 2),
            np.asarray(temporal_edges, dtype=np.int64).reshape(-1, 2)
        ], axis=0)
        edge_attr = np.concatenate([spatial_attrs, temporal_attrs], axis=0)
        bc_mask = build_bc_mask(model_df)

        assert bc_mask.shape == (len(model_df), 5), \
            f"Invalid bc_mask shape: {bc_mask.shape} for model {model_name}"

        graph = Data(
            x=torch.from_numpy(head_feats),
            conc_x=torch.from_numpy(conc_feats),
            edge_index=torch.tensor(edges.T, dtype=torch.long),
            edge_attr=torch.from_numpy(edge_attr),
            y=torch.from_numpy(conc),
            head_y=torch.from_numpy(head),
            bc_mask=torch.from_numpy(bc_mask),
            time_step=torch.from_numpy(model_df['time_step'].values).long(),
            time_steps=model_df['time_step'].nunique(),
            model_name=str(model_name),
            row=torch.from_numpy(model_df['row'].values).long(),
            col=torch.from_numpy(model_df['col'].values).long(),
        )
        graphs.append(graph)

    print(f"\n✅ All models processed! Total graphs created: {len(graphs):,}")
    return graphs


def prepare_data(data, batch_size=4):
    """Turn the raw table into training and validation graph loaders.

    Graphs are built with ``build_spatiotemporal_graph`` and split 70/30
    with a fixed seed (42). Only the training loader shuffles.

    Returns:
        ``(train_loader, val_loader)``.
    """
    print('正在处理数据...')
    graphs = build_spatiotemporal_graph(data)
    print('数据处理完成！')

    split = train_test_split(graphs, test_size=0.3, random_state=42)
    train_part, val_part = split

    train_loader = DataLoader(
        HydroDataset(train_part), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(HydroDataset(val_part), batch_size=batch_size)
    return train_loader, val_loader


# Global hyper-parameters and paths for the dual head/concentration training run.
config = {
    # Input feature widths for the head and concentration models.
    'head_input_dim': 16,
    'conc_input_dim': 19,
    # Passed as both spatial_dim and temporal_dim when the head model is built.
    'hidden_dim': 96,  
    'num_epochs': 500,
    # Optimiser settings (used with AdamW for the head model).
    'lr': 1e-3,  
    'weight_decay': 1e-4,
    # Presumably the early-stopping patience in epochs - TODO confirm in the training loop.
    'patience': 30,
    # Directory created with os.makedirs before training starts.
    'save_path': './saved_models/blitz_bayesian_gnn_dual_guass',
    # Presumably the number of Monte-Carlo forward passes for uncertainty - TODO confirm.
    'mc_samples': 10,
    # Prior sigmas forwarded to the Bayesian layers of each model.
    'head_prior_sigma_1': 0.01,  
    'head_prior_sigma_2': 0.002,
    'conc_prior_sigma_1': 0.05,  
    'conc_prior_sigma_2': 0.002,
    # Forwarded to ImprovedPhysicsInformedLoss as kl_weight.
    'kl_weight': 1e-4  
}

def compute_metrics(y_true, y_pred):
    """Compute MSE, RMSE, MAE and R² over the finite entries of both inputs.

    Tensors are detached and moved to the CPU first; both inputs are then
    flattened. A position is kept only when the target and the prediction
    are both neither NaN nor infinite.

    Returns:
        Dict with the keys ``mse``, ``rmse``, ``mae`` and ``r2``.
    """
    def _flat_array(values):
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return values.flatten()

    truth = _flat_array(y_true)
    guess = _flat_array(y_pred)

    # isfinite is exactly "not NaN and not +/-inf".
    keep = np.isfinite(truth) & np.isfinite(guess)
    truth, guess = truth[keep], guess[keep]

    mse = np.mean((truth - guess) ** 2)
    metrics = {
        'mse': mse,
        'rmse': np.sqrt(mse),
        'mae': np.mean(np.abs(truth - guess)),
        'r2': sklearn_r2_score(truth, guess),
    }
    return metrics

def compute_uncertainty(model, data, mc_samples=10):
    """
    Compute the Monte-Carlo predictive mean and standard deviation of a
    Blitz model.

    The model is run ``mc_samples`` times under ``torch.no_grad()`` with
    stochastic layers (e.g. dropout) active. BatchNorm layers are kept in
    eval mode so that these evaluation passes neither use batch statistics
    nor overwrite the running statistics learned during training. Every
    module's train/eval flag is restored before returning.

    Args:
        model: Callable ``nn.Module`` taking ``data``.
        data: Input forwarded unchanged to ``model``.
        mc_samples: Number of stochastic forward passes.

    Returns:
        ``(mean_pred, std_pred)`` taken over the sample dimension.
    """
    saved_modes = [(module, module.training) for module in model.modules()]

    model.train()
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()

    try:
        with torch.no_grad():
            predictions = torch.stack(
                [model(data) for _ in range(mc_samples)], dim=0
            )
    finally:
        # Leave the model exactly as the caller handed it over.
        for module, was_training in saved_modes:
            module.training = was_training

    mean_pred = predictions.mean(dim=0)
    std_pred = predictions.std(dim=0)

    return mean_pred, std_pred

def compute_feature_shap_values_improved(model, data, n_samples=20, num_samples=10):
    """Estimate feature importance by mean-masking one feature at a time.

    For every column of ``data.x`` the column is replaced by its mean for
    all nodes, the model output is averaged over ``num_samples`` stochastic
    forward passes, and the absolute change against the unmasked baseline is
    recorded for up to ``n_samples`` randomly chosen nodes of the first time
    step. Importances are averaged over those nodes and normalised to sum to
    one when their sum is positive.

    The masked prediction does not depend on the sampled node, so it is
    computed once per feature and then read at all sampled nodes.
    ``data.x`` is always restored, also when the model raises.

    Args:
        model: ``nn.Module`` called as ``model(data)``.
        data: Graph-like object with ``to(device)``, ``x`` and optionally
            ``time_step`` (otherwise ``num_nodes`` is used).
        n_samples: Maximum number of nodes to evaluate.
        num_samples: Monte-Carlo forward passes per prediction.

    Returns:
        ``(importances, expected_value)`` where ``importances`` is a NumPy
        array with one entry per feature, or ``(None, 0.0)`` on failure.
    """
    # NOTE(review): train mode keeps stochastic layers active for the
    # Monte-Carlo passes, but it also puts any BatchNorm layer in training
    # mode and leaves the model in train mode afterwards.
    model.train()
    data = data.to(next(model.parameters()).device)

    print(f"[FeatureSHAP] 开始特征重要性分析，抽样{n_samples}个节点...")

    if not hasattr(data, 'x') or not torch.is_tensor(data.x):
        print("[FeatureSHAP] 错误: 数据缺少节点特征 (data.x)")
        return None, 0.0

    def _mc_mean_prediction():
        """Average ``num_samples`` stochastic forward passes."""
        with torch.no_grad():
            draws = [model(data) for _ in range(num_samples)]
        return torch.stack(draws, dim=0).mean(dim=0)

    try:
        # Candidate nodes: those of the first unique time step, or every
        # node when the graph carries no time information.
        if hasattr(data, 'time_step'):
            current_time_step = data.time_step.unique()[0].item()
            time_mask = data.time_step == current_time_step
        else:
            time_mask = torch.ones(data.num_nodes, dtype=torch.bool, device=data.x.device)
        candidate_nodes = torch.where(time_mask)[0]

        if len(candidate_nodes) == 0:
            print("[FeatureSHAP] 错误: 找不到满足条件的节点")
            return None, 0.0

        actual_n_samples = min(n_samples, len(candidate_nodes))
        if actual_n_samples < n_samples:
            print(f"[FeatureSHAP] 警告: 候选节点数({len(candidate_nodes)})少于请求的样本数({n_samples})，调整为{actual_n_samples}")

        sampled_indices = torch.randperm(len(candidate_nodes))[:actual_n_samples]
        sampled_nodes = candidate_nodes[sampled_indices]

        num_features = data.x.size(1)
        all_shap_values = torch.zeros(actual_n_samples, num_features, device=data.x.device)

        baseline_pred = _mc_mean_prediction()
        all_expected_values = baseline_pred[sampled_nodes].reshape(-1).float()

        for feat_idx in range(num_features):
            original_feat = data.x[:, feat_idx].clone()
            data.x[:, feat_idx] = original_feat.mean()
            try:
                masked_pred = _mc_mean_prediction()
            finally:
                # Undo the perturbation even if the forward pass failed.
                data.x[:, feat_idx] = original_feat

            masked_values = masked_pred[sampled_nodes].reshape(-1).float()
            all_shap_values[:, feat_idx] = (all_expected_values - masked_values).abs()

        if actual_n_samples:
            print(f"[FeatureSHAP] 已完成 {actual_n_samples}/{actual_n_samples} 个节点的分析")

        avg_shap_values = all_shap_values.mean(dim=0)
        avg_expected_value = all_expected_values.mean().item()

        if avg_shap_values.sum() > 0:
            avg_shap_values = avg_shap_values / avg_shap_values.sum()

        print("[FeatureSHAP] 特征重要性分析完成")
        return avg_shap_values.cpu().numpy(), avg_expected_value

    except Exception as e:
        print(f"[FeatureSHAP] 特征重要性分析出错: {e}")
        import traceback
        traceback.print_exc()
        return None, 0.0

class BlitzHeadLossAdapter(nn.Module):
    """Expose a ``(pred, data)`` head criterion as a ``(pred, labels)`` loss.

    The graph batch is captured at construction time; ``labels`` is ignored
    and only the scalar loss of the wrapped criterion is returned.
    """

    def __init__(self, base_criterion, data, kl_weight=1e-4):
        super().__init__()
        self.base_criterion = base_criterion
        self.data = data
        self.kl_weight = kl_weight

    def forward(self, pred, labels):
        """Return the first element of ``base_criterion(pred, self.data)``."""
        loss_and_parts = self.base_criterion(pred, self.data)
        total, _components = loss_and_parts
        return total

class BlitzConcLossAdapter(nn.Module):
    """Expose a ``(pred, data)`` concentration criterion as a ``(pred, labels)`` loss.

    The graph batch is captured at construction time; ``labels`` is ignored
    and only the scalar loss of the wrapped criterion is returned.
    """

    def __init__(self, base_criterion, data, kl_weight=5e-5):
        super().__init__()
        self.base_criterion = base_criterion
        self.data = data
        self.kl_weight = kl_weight

    def forward(self, pred, labels):
        """Return the first element of ``base_criterion(pred, self.data)``."""
        loss_and_parts = self.base_criterion(pred, self.data)
        total, _components = loss_and_parts
        return total

# In [None]:
def _dual_prepare_batch(batch):
    """Move ``batch`` to the module-level ``device`` and cast ``edge_attr`` to float if present."""
    batch = batch.to(device)
    if hasattr(batch, 'edge_attr') and batch.edge_attr is not None:
        batch.edge_attr = batch.edge_attr.float()
    return batch


def _dual_unpack_loss(criterion_output):
    """Split a criterion result (bare loss or tuple) into ``(loss, extras_or_None)``."""
    if isinstance(criterion_output, tuple):
        extras = criterion_output[1] if len(criterion_output) > 1 else None
        return criterion_output[0], extras
    return criterion_output, None


def _dual_as_number(value):
    """Return ``value.item()`` when the object offers it, otherwise ``value`` unchanged."""
    return value.item() if hasattr(value, 'item') else value


def _dual_save_checkpoint(model, filename, criterion, epoch, train_loss, val_loss, val_metrics):
    """Save ``model``'s state dict plus epoch statistics under ``config['save_path']``."""
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_metrics': val_metrics,
        'config': config,
        'criterion': criterion,
    }, os.path.join(config['save_path'], filename))


def _dual_save_val_arrays(tag, predictions, targets, uncertainties):
    """Dump the concatenated validation arrays of the current epoch as ``best_conc_*_{tag}.npy``."""
    if not predictions:
        return
    for name, chunks in (('predictions', predictions),
                         ('targets', targets),
                         ('uncertainties', uncertainties)):
        np.save(os.path.join(config['save_path'], f'best_conc_{name}_{tag}.npy'),
                np.concatenate(chunks, axis=0))


def _dual_history_frame(losses):
    """Flatten a ``{'train': [...], 'val': [...]}`` history into a DataFrame, or None if empty."""
    rows = [
        {
            'epoch': epoch_no,
            'train_loss': train_loss,
            'val_loss': val_data['loss'],
            'val_mse': val_data['metrics']['mse'],
            'val_rmse': val_data['metrics']['rmse'],
            'val_mae': val_data['metrics']['mae'],
            'val_r2': val_data['metrics']['r2'],
        }
        for epoch_no, (train_loss, val_data)
        in enumerate(zip(losses['train'], losses['val']), start=1)
    ]
    return pd.DataFrame(rows) if rows else None


def _dual_mark_best(history_df, column, minimize, color, label):
    """Star-mark the best epoch of ``column`` on the current axes; skipped when the column is all NaN."""
    series = history_df[column]
    if not series.notna().any():
        return
    best_idx = series.idxmin() if minimize else series.idxmax()
    best_epoch = history_df.loc[best_idx, 'epoch']
    plt.scatter(best_epoch, series.loc[best_idx], color=color, s=100, marker='*',
                label=f'{label} (E{best_epoch})')


def _dual_plot_history_row(history_df, prefix, first_panel, linestyle):
    """Draw the loss / R2 / MSE / RMSE panels of one model into a 2x4 subplot grid."""
    epochs = history_df['epoch']

    def finish_panel(title, ylabel):
        plt.title(f'{prefix} Model: {title}')
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    plt.subplot(2, 4, first_panel)
    plt.plot(epochs, history_df['train_loss'], color='b', linestyle=linestyle,
             label=f'{prefix} Train Loss')
    plt.plot(epochs, history_df['val_loss'], color='r', linestyle=linestyle,
             label=f'{prefix} Val Loss')
    _dual_mark_best(history_df, 'val_loss', minimize=True, color='red', label='Best Loss')
    finish_panel('Training and Validation Loss', 'Loss')

    plt.subplot(2, 4, first_panel + 1)
    plt.plot(epochs, history_df['val_r2'], color='g', linestyle=linestyle, label=f'{prefix} R2')
    _dual_mark_best(history_df, 'val_r2', minimize=False, color='green', label='Best R2')
    finish_panel('Validation R2 Score', 'R2')

    plt.subplot(2, 4, first_panel + 2)
    plt.plot(epochs, history_df['val_mse'], color='orange', linestyle=linestyle,
             label=f'{prefix} MSE')
    finish_panel('Validation MSE', 'MSE')

    plt.subplot(2, 4, first_panel + 3)
    plt.plot(epochs, history_df['val_rmse'], color='purple', linestyle=linestyle,
             label=f'{prefix} RMSE')
    finish_panel('Validation RMSE', 'RMSE')


def train_dual_model_improved(train_loader, val_loader, evaluation_criterion='r2'):
    """Train the head model, then the concentration model on top of the frozen head model.

    Stage 1 trains a ``BlitzHeadGNN``. Stage 2 reloads the head checkpoint
    named ``best_head_model_{evaluation_criterion}.pth`` (falling back to the
    in-memory weights if loading fails), freezes it, and trains a
    ``BlitzConcGNN`` that receives the head prediction as its second forward
    argument. Both stages add ``nn_kl_divergence() * config['kl_weight']`` to
    the criterion loss, clip gradients to norm 1.0, and stop early after
    ``config['patience']`` epochs without a validation R2 improvement.

    Per-batch exceptions are printed and the batch is skipped. An epoch in
    which no validation batch succeeds is recorded with ``inf`` loss and NaN
    metrics so that it can never be selected as the best epoch.

    Reads the module-level ``config`` (keys: ``hidden_dim``,
    ``head_prior_sigma_1``, ``head_prior_sigma_2``, ``kl_weight``, ``lr``,
    ``weight_decay``, ``save_path``, ``num_epochs``, ``mc_samples``,
    ``patience``) and ``device``.

    Files written to ``config['save_path']``: ``best_{head,conc}_model_{loss,r2}.pth``,
    ``best_conc_{predictions,targets,uncertainties}_{loss,r2}.npy``,
    ``head_training_history.csv``, ``conc_training_history.csv`` and
    ``dual_model_training_curves_improved.png``.

    Args:
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches; batches must expose
            ``head_y`` and ``y`` targets.
        evaluation_criterion: Suffix of the head checkpoint to reload before
            stage 2 (checkpoints are saved for ``'loss'`` and ``'r2'``).

    Returns:
        ``(head_model, conc_model, {'head': head_losses, 'conc': conc_losses})``
        with the models in their last-epoch state for the concentration model,
        or ``(None, None, None)`` if clearing the CUDA cache raises
        ``RuntimeError``.
    """
    try:
        torch.cuda.empty_cache()
        print("CUDA缓存已成功清除")
    except RuntimeError as e:
        print(f"无法清除CUDA缓存: {e}")
        return None, None, None

    # ------------------------------------------------------------------
    # Stage 1: head model
    # ------------------------------------------------------------------
    print("=" * 80)
    print("第一阶段：开始训练水头模型")
    print("=" * 80)

    head_input_dim = 16

    head_model = BlitzHeadGNN(
        node_features=head_input_dim,
        spatial_dim=config['hidden_dim'],
        temporal_dim=config['hidden_dim'],
        prior_sigma_1=config['head_prior_sigma_1'],
        prior_sigma_2=config['head_prior_sigma_2'],
        dropout=0.1
    ).to(device)

    criterion_head = ImprovedPhysicsInformedLoss(alpha=0.1, kl_weight=config['kl_weight'])

    head_params = list(head_model.parameters())
    head_optimizer = torch.optim.AdamW(head_params, lr=config['lr'], weight_decay=config['weight_decay'])
    head_scheduler = CosineAnnealingWarmRestarts(
        head_optimizer, T_0=20, T_mult=2, eta_min=1e-5
    )

    best_head_val_loss = float('inf')
    best_head_r2 = float('-inf')
    head_early_stop_counter = 0
    head_losses = {'train': [], 'val': []}

    os.makedirs(config['save_path'], exist_ok=True)

    print("开始训练水头模型")
    print(f"水头模型参数数量: {sum(p.numel() for p in head_model.parameters() if p.requires_grad)}")

    for epoch in range(config['num_epochs']):
        head_model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                batch = _dual_prepare_batch(batch)

                head_optimizer.zero_grad()
                pred_head = head_model(batch)

                head_criterion_loss, physics_loss = _dual_unpack_loss(
                    criterion_head(pred_head, batch)
                )
                # The second criterion output is only used for logging.
                if physics_loss is None:
                    physics_loss_value = 0.0
                elif isinstance(physics_loss, tuple):
                    physics_loss_value = sum(_dual_as_number(p) for p in physics_loss)
                else:
                    physics_loss_value = _dual_as_number(physics_loss)

                kl_loss = head_model.nn_kl_divergence() * config['kl_weight']
                total_loss = head_criterion_loss + kl_loss

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(head_params, max_norm=1.0)
                head_optimizer.step()

                train_loss += total_loss.item()
                train_batches += 1

                if batch_idx % 50 == 0:
                    print(f"水头模型 Epoch {epoch+1}, Batch {batch_idx}: "
                          f"Total Loss: {total_loss.item():.4f}, "
                          f"Criterion Loss: {head_criterion_loss.item():.4f}, "
                          f"Physics Loss: {physics_loss_value:.4f}, "
                          f"KL Loss: {kl_loss.item():.4f}")

            except Exception as e:
                print(f"水头模型训练批次 {batch_idx} 出错: {e}")
                continue

        if train_batches == 0:
            print("警告: 水头模型本轮训练没有成功处理任何批次，跳过本轮")
            continue

        avg_train_loss = train_loss / train_batches

        val_loss = 0.0
        val_metrics = {'mse': 0.0, 'rmse': 0.0, 'mae': 0.0, 'r2': 0.0}
        val_batches = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    batch = _dual_prepare_batch(batch)

                    # Validation deliberately runs in train mode so that the
                    # stochastic layers stay active for the MC samples.
                    # NOTE(review): re-applied per batch in case
                    # compute_uncertainty changes the mode - confirm.
                    head_model.train()
                    pred_head, _ = compute_uncertainty(
                        head_model, batch, mc_samples=config['mc_samples']
                    )

                    head_criterion_loss, _ = _dual_unpack_loss(criterion_head(pred_head, batch))

                    metrics = compute_metrics(batch.head_y, pred_head)
                    for k in metrics:
                        val_metrics[k] += metrics[k]

                    val_loss += head_criterion_loss.item()
                    val_batches += 1

                except Exception as e:
                    print(f"水头模型验证批次 {batch_idx} 出错: {e}")
                    continue

        if val_batches > 0:
            avg_val_loss = val_loss / val_batches
            for k in val_metrics:
                val_metrics[k] /= val_batches
        else:
            print("警告: 水头模型本轮验证没有成功处理任何批次")
            avg_val_loss = float('inf')
            # NaN never compares greater than the best R2, so a fully failed
            # validation pass cannot be saved as "best" or reset early stopping.
            val_metrics = {k: float('nan') for k in val_metrics}

        head_losses['train'].append(avg_train_loss)
        head_losses['val'].append({
            'loss': avg_val_loss,
            'metrics': val_metrics
        })

        head_scheduler.step()
        current_lr = head_scheduler.get_last_lr()[0]

        print(f"水头模型 Epoch {epoch+1:03d}/{config['num_epochs']} | "
              f"训练损失: {avg_train_loss:.4f} | 验证损失: {avg_val_loss:.4f} | "
              f"LR: {current_lr:.6f}")
        print(f"水头验证指标 - MSE: {val_metrics['mse']:.4f}, "
              f"RMSE: {val_metrics['rmse']:.4f}, "
              f"MAE: {val_metrics['mae']:.4f}, "
              f"R2: {val_metrics['r2']:.4f}")

        if avg_val_loss < best_head_val_loss:
            best_head_val_loss = avg_val_loss
            try:
                _dual_save_checkpoint(head_model, 'best_head_model_loss.pth', 'loss',
                                      epoch, avg_train_loss, avg_val_loss, val_metrics)
                print(f"保存基于损失的最佳水头模型，验证损失: {best_head_val_loss:.4f}")
            except Exception as e:
                print(f"保存水头模型失败: {e}")

        # Early stopping is driven by R2 only, independent of evaluation_criterion.
        if val_metrics['r2'] > best_head_r2:
            best_head_r2 = val_metrics['r2']
            head_early_stop_counter = 0
            try:
                _dual_save_checkpoint(head_model, 'best_head_model_r2.pth', 'r2',
                                      epoch, avg_train_loss, avg_val_loss, val_metrics)
                print(f"保存基于R2的最佳水头模型，R2: {best_head_r2:.4f}")
            except Exception as e:
                print(f"保存水头模型失败: {e}")
        else:
            head_early_stop_counter += 1

        if head_early_stop_counter >= config['patience']:
            print(f"水头模型早停触发! 在第{epoch+1}个epoch停止训练")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n水头模型训练完成!")
    print(f"基于损失的最佳验证损失: {best_head_val_loss:.4f}")
    print(f"基于R2的最佳R2分数: {best_head_r2:.4f}")

    # ------------------------------------------------------------------
    # Stage 2: concentration model on top of the frozen head model
    # ------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("第二阶段：开始训练浓度模型")
    print("=" * 80)

    head_model_file = f'best_head_model_{evaluation_criterion}.pth'
    try:
        checkpoint = torch.load(os.path.join(config['save_path'], head_model_file), weights_only=False)
        head_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"成功加载基于{evaluation_criterion}的最佳水头模型")
    except Exception as e:
        print(f"加载水头模型失败，使用当前模型: {e}")

    for param in head_model.parameters():
        param.requires_grad = False
    head_model.eval()

    conc_input_dim = 19
    conc_model = BlitzConcGNN(
        node_features=conc_input_dim,
        spatial_dim=config['hidden_dim'],
        temporal_dim=config['hidden_dim'],
        prior_sigma_1=0.1,
        prior_sigma_2=0.01,
        dropout=0.1,
        posterior_mu_init=0.0,
        posterior_rho_init=-3.0
    ).to(device)

    criterion_conc = ImprovedConcLoss(kl_weight=config['kl_weight'], l1_weight=1e-5)

    conc_params = list(conc_model.parameters())
    conc_optimizer = torch.optim.AdamW(conc_params, lr=config['lr'] * 0.8, weight_decay=config['weight_decay'])
    conc_scheduler = CosineAnnealingWarmRestarts(
        conc_optimizer, T_0=15, T_mult=2, eta_min=1e-6
    )

    best_conc_val_loss = float('inf')
    best_conc_r2 = float('-inf')
    conc_early_stop_counter = 0
    conc_losses = {'train': [], 'val': []}

    print(f"浓度模型参数数量: {sum(p.numel() for p in conc_model.parameters() if p.requires_grad)}")

    for epoch in range(config['num_epochs']):
        conc_model.train()
        train_loss = 0.0
        train_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                batch = _dual_prepare_batch(batch)

                # The head model is frozen; its prediction is a fixed input here.
                with torch.no_grad():
                    pred_head = head_model(batch)

                conc_optimizer.zero_grad()
                pred_conc = conc_model(batch, pred_head)

                conc_criterion_loss, loss_components = _dual_unpack_loss(
                    criterion_conc(pred_conc, batch, conc_model)
                )
                # Components are only used for logging.
                if isinstance(loss_components, tuple) and len(loss_components) >= 3:
                    mse_loss, _, l1_reg = loss_components[:3]
                else:
                    mse_loss, l1_reg = 0.0, 0.0

                kl_loss = conc_model.nn_kl_divergence() * config['kl_weight']
                total_loss = conc_criterion_loss + kl_loss

                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(conc_params, max_norm=1.0)
                conc_optimizer.step()

                train_loss += total_loss.item()
                train_batches += 1

                if batch_idx % 50 == 0:
                    print(f"浓度模型 Epoch {epoch+1}, Batch {batch_idx}: "
                          f"Total Loss: {total_loss.item():.4f}, "
                          f"Criterion Loss: {conc_criterion_loss.item():.4f}, "
                          f"MSE: {mse_loss:.4f}, "
                          f"KL Loss: {kl_loss.item():.4f}, "
                          f"L1 Reg: {l1_reg:.4f}")

            except Exception as e:
                print(f"浓度模型训练批次 {batch_idx} 出错: {e}")
                continue

        if train_batches == 0:
            print("警告: 浓度模型本轮训练没有成功处理任何批次，跳过本轮")
            continue

        avg_train_loss = train_loss / train_batches

        val_loss = 0.0
        val_metrics = {'mse': 0.0, 'rmse': 0.0, 'mae': 0.0, 'r2': 0.0}
        val_batches = 0

        # Per-epoch validation outputs, dumped to .npy when the epoch is a new best.
        all_conc_predictions = []
        all_conc_targets = []
        all_conc_uncertainties = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    batch = _dual_prepare_batch(batch)

                    pred_head = head_model(batch)

                    # Train mode keeps the stochastic layers active for MC sampling.
                    conc_model.train()
                    # Sample through a closure so validation feeds pred_head to
                    # the model exactly as the training step does.
                    pred_conc, conc_std = compute_uncertainty_with_func(
                        lambda b, ph=pred_head: conc_model(b, ph),
                        batch,
                        mc_samples=config['mc_samples'],
                    )

                    conc_criterion_loss, _ = _dual_unpack_loss(
                        criterion_conc(pred_conc, batch, conc_model)
                    )

                    metrics = compute_metrics(batch.y, pred_conc)
                    for k in metrics:
                        val_metrics[k] += metrics[k]

                    val_loss += conc_criterion_loss.item()
                    val_batches += 1

                    all_conc_predictions.append(pred_conc.cpu().numpy())
                    all_conc_targets.append(batch.y.cpu().numpy())
                    all_conc_uncertainties.append(conc_std.cpu().numpy())

                except Exception as e:
                    print(f"浓度模型验证批次 {batch_idx} 出错: {e}")
                    continue

        if val_batches > 0:
            avg_val_loss = val_loss / val_batches
            for k in val_metrics:
                val_metrics[k] /= val_batches
        else:
            print("警告: 浓度模型本轮验证没有成功处理任何批次")
            avg_val_loss = float('inf')
            # See stage 1: NaN keeps a failed validation pass from being "best".
            val_metrics = {k: float('nan') for k in val_metrics}

        conc_losses['train'].append(avg_train_loss)
        conc_losses['val'].append({
            'loss': avg_val_loss,
            'metrics': val_metrics
        })

        conc_scheduler.step()
        current_lr = conc_scheduler.get_last_lr()[0]

        print(f"浓度模型 Epoch {epoch+1:03d}/{config['num_epochs']} | "
              f"训练损失: {avg_train_loss:.4f} | 验证损失: {avg_val_loss:.4f} | "
              f"LR: {current_lr:.6f}")
        print(f"浓度验证指标 - MSE: {val_metrics['mse']:.4f}, "
              f"RMSE: {val_metrics['rmse']:.4f}, "
              f"MAE: {val_metrics['mae']:.4f}, "
              f"R2: {val_metrics['r2']:.4f}")

        if avg_val_loss < best_conc_val_loss:
            best_conc_val_loss = avg_val_loss
            try:
                _dual_save_checkpoint(conc_model, 'best_conc_model_loss.pth', 'loss',
                                      epoch, avg_train_loss, avg_val_loss, val_metrics)
                _dual_save_val_arrays('loss', all_conc_predictions,
                                      all_conc_targets, all_conc_uncertainties)
                print(f"保存基于损失的最佳浓度模型，验证损失: {best_conc_val_loss:.4f}")
            except Exception as e:
                print(f"保存浓度模型失败: {e}")

        if val_metrics['r2'] > best_conc_r2:
            best_conc_r2 = val_metrics['r2']
            conc_early_stop_counter = 0
            try:
                _dual_save_checkpoint(conc_model, 'best_conc_model_r2.pth', 'r2',
                                      epoch, avg_train_loss, avg_val_loss, val_metrics)
                _dual_save_val_arrays('r2', all_conc_predictions,
                                      all_conc_targets, all_conc_uncertainties)
                print(f"保存基于R2的最佳浓度模型，R2: {best_conc_r2:.4f}")
            except Exception as e:
                print(f"保存浓度模型失败: {e}")
        else:
            conc_early_stop_counter += 1

        if conc_early_stop_counter >= config['patience']:
            print(f"浓度模型早停触发! 在第{epoch+1}个epoch停止训练")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n浓度模型训练完成!")
    print(f"基于损失的最佳验证损失: {best_conc_val_loss:.4f}")
    print(f"基于R2的最佳R2分数: {best_conc_r2:.4f}")

    # ------------------------------------------------------------------
    # History export and training curves (best effort)
    # ------------------------------------------------------------------
    try:
        head_history_df = _dual_history_frame(head_losses)
        if head_history_df is not None:
            head_history_df.to_csv(os.path.join(config['save_path'], 'head_training_history.csv'), index=False)

        conc_history_df = _dual_history_frame(conc_losses)
        if conc_history_df is not None:
            conc_history_df.to_csv(os.path.join(config['save_path'], 'conc_training_history.csv'), index=False)

        if head_history_df is not None and conc_history_df is not None:
            fig = plt.figure(figsize=(20, 12))
            try:
                _dual_plot_history_row(head_history_df, 'Head', 1, '-')
                _dual_plot_history_row(conc_history_df, 'Conc', 5, '--')
                plt.tight_layout()
                plt.savefig(os.path.join(config['save_path'], 'dual_model_training_curves_improved.png'),
                            dpi=300, bbox_inches='tight')
            finally:
                # Always release the figure, even when plotting or saving fails.
                plt.close(fig)

            print("改进的双模型训练曲线已保存")

    except Exception as e:
        print(f"保存训练历史或绘图失败: {e}")

    # Undo the stage-2 freeze so callers get a trainable head model back.
    for param in head_model.parameters():
        param.requires_grad = True

    print("\n" + "=" * 80)
    print("双模型训练完成总结:")
    print(f"水头模型 - 基于损失的最佳验证损失: {best_head_val_loss:.4f}")
    print(f"水头模型 - 基于R2的最佳R2分数: {best_head_r2:.4f}")
    print(f"浓度模型 - 基于损失的最佳验证损失: {best_conc_val_loss:.4f}")
    print(f"浓度模型 - 基于R2的最佳R2分数: {best_conc_r2:.4f}")
    print(f"评估将使用基于{evaluation_criterion}的模型")
    print("=" * 80)

    return head_model, conc_model, {'head': head_losses, 'conc': conc_losses}

def compute_uncertainty_with_func(forward_func, data, mc_samples=30):
    """Estimate a predictive mean and standard deviation by repeated forward calls.

    ``forward_func(data)`` is called ``mc_samples`` times with gradient
    tracking disabled; the results are stacked along a new leading dimension
    and reduced over it. Any randomness has to come from ``forward_func``
    itself.

    Args:
        forward_func: Callable taking ``data`` and returning a tensor; every
            call must return the same shape.
        data: Passed unchanged to every ``forward_func`` call.
        mc_samples: Number of forward calls; must be at least 1.

    Returns:
        ``(pred_mean, pred_std)`` with the shape of one prediction. The std
        is the unbiased sample std, or all zeros when ``mc_samples == 1``.

    Raises:
        ValueError: If ``mc_samples`` is smaller than 1.
    """
    if mc_samples < 1:
        raise ValueError(f"mc_samples must be >= 1, got {mc_samples}")

    with torch.no_grad():
        samples = torch.stack([forward_func(data) for _ in range(mc_samples)], dim=0)

    pred_mean = samples.mean(dim=0)
    if mc_samples == 1:
        # The unbiased std of a single sample is NaN; report zero spread instead
        # so downstream plots and saved arrays stay finite.
        pred_std = torch.zeros_like(pred_mean)
    else:
        pred_std = samples.std(dim=0)

    return pred_mean, pred_std

def evaluate_dual_model_improved(head_model, conc_model, val_loader, evaluation_criterion='loss'):

    print(f"开始评估双模型性能（基于{evaluation_criterion}标准）...")
    

    try:
        head_model_file = f'best_head_model_{evaluation_criterion}.pth'
        head_checkpoint = torch.load(os.path.join(config['save_path'], head_model_file), weights_only=False)
        head_model.load_state_dict(head_checkpoint['model_state_dict'])
        
        conc_model_file = f'best_conc_model_{evaluation_criterion}.pth'
        conc_checkpoint = torch.load(os.path.join(config['save_path'], conc_model_file), weights_only=False)
        conc_model.load_state_dict(conc_checkpoint['model_state_dict'])
        
        print(f"成功加载基于{evaluation_criterion}的最佳模型权重")
        print(f"水头模型来自epoch {head_checkpoint['epoch']}, 验证损失: {head_checkpoint['val_loss']:.4f}, R2: {head_checkpoint['val_metrics']['r2']:.4f}")
        print(f"浓度模型来自epoch {conc_checkpoint['epoch']}, 验证损失: {conc_checkpoint['val_loss']:.4f}, R2: {conc_checkpoint['val_metrics']['r2']:.4f}")
        
    except Exception as e:
        print(f"加载模型权重失败，使用当前权重: {e}")
    
    # 设置模型为训练模式以进行MC dropout
    head_model.train()  
    conc_model.train()
    

    results_dir = os.path.join(config['save_path'], 'evaluation_results')
    os.makedirs(results_dir, exist_ok=True)
    

    all_head_preds = []
    all_head_targets = []
    all_head_uncertainties = []
    all_conc_preds = []
    all_conc_targets = []
    all_conc_uncertainties = []
    

    predictions = []
    uncertainties = []
    

    def head_forward_func(batch):

        return head_model(batch)
    
    def conc_forward_func(batch_with_head):

        return conc_model(batch_with_head)
    
    print("开始处理验证数据...")
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            try:
                batch = batch.to(device)
                if hasattr(batch, 'edge_attr') and batch.edge_attr is not None:
                    batch.edge_attr = batch.edge_attr.float()
                
      
                head_pred, head_std = compute_uncertainty_with_func(
                    head_forward_func, batch, mc_samples=config.get('mc_samples', 30)
                )
                
            
                batch_conc = batch.clone()
              
                if hasattr(batch_conc, 'x'):

                    batch_conc.x = torch.cat([batch.x, head_pred.detach()], dim=1)
                else:

                    batch_conc.x = head_pred.detach()

                conc_pred, conc_std = compute_uncertainty_with_func(
                    conc_forward_func, batch_conc, mc_samples=config.get('mc_samples', 30)
                )
                
                # 收集结果
                all_head_preds.append(head_pred.cpu().numpy())
                all_head_targets.append(batch.head_y.cpu().numpy())
                all_head_uncertainties.append(head_std.cpu().numpy())
                all_conc_preds.append(conc_pred.cpu().numpy())
                all_conc_targets.append(batch.y.cpu().numpy())
                all_conc_uncertainties.append(conc_std.cpu().numpy())
                

                batch_predictions = {
                    'batch_idx': batch_idx,
                    'pred_head': head_pred.cpu().numpy().flatten(),
                    'true_head': batch.head_y.cpu().numpy().flatten(),
                    'pred_conc': conc_pred.cpu().numpy().flatten(),
                    'true_conc': batch.y.cpu().numpy().flatten()
                }
                

                if hasattr(batch, 'row') and hasattr(batch, 'col'):
                    batch_predictions['row'] = batch.row.cpu().numpy()
                    batch_predictions['col'] = batch.col.cpu().numpy()
                if hasattr(batch, 'time_step'):
                    batch_predictions['time_step'] = batch.time_step.cpu().numpy()
                
                predictions.append(batch_predictions)
                

                batch_uncertainties = {
                    'batch_idx': batch_idx,
                    'head_std': head_std.cpu().numpy().flatten(),
                    'conc_std': conc_std.cpu().numpy().flatten()
                }
                

                if hasattr(batch, 'row') and hasattr(batch, 'col'):
                    batch_uncertainties['row'] = batch.row.cpu().numpy()
                    batch_uncertainties['col'] = batch.col.cpu().numpy()
                if hasattr(batch, 'time_step'):
                    batch_uncertainties['time_step'] = batch.time_step.cpu().numpy()
                
                uncertainties.append(batch_uncertainties)
                
                # 每10个批次输出进度
                if batch_idx % 10 == 0:
                    print(f"处理批次 {batch_idx}/{len(val_loader)}")
                
            except Exception as e:
                print(f"评估批次 {batch_idx} 出错: {e}")
                continue
    

    all_head_preds = np.concatenate(all_head_preds, axis=0)
    all_head_targets = np.concatenate(all_head_targets, axis=0)
    all_head_uncertainties = np.concatenate(all_head_uncertainties, axis=0)
    all_conc_preds = np.concatenate(all_conc_preds, axis=0)
    all_conc_targets = np.concatenate(all_conc_targets, axis=0)
    all_conc_uncertainties = np.concatenate(all_conc_uncertainties, axis=0)

    head_metrics = compute_metrics(all_head_targets, all_head_preds)
    conc_metrics = compute_metrics(all_conc_targets, all_conc_preds)
    
    print(f"\n水头模型评估结果（基于{evaluation_criterion}）:")
    for metric, value in head_metrics.items():
        print(f"  {metric.upper()}: {value:.4f}")
    
    print(f"\n浓度模型评估结果（基于{evaluation_criterion}）:")
    for metric, value in conc_metrics.items():
        print(f"  {metric.upper()}: {value:.4f}")
    

    plt.figure(figsize=(10, 10))
    plt.errorbar(all_head_targets.flatten(), all_head_preds.flatten(), 
                yerr=all_head_uncertainties.flatten(), fmt='o', alpha=0.3, 
                ecolor='lightgray', elinewidth=0.5, capsize=0, markersize=2)
    
    min_val = min(np.min(all_head_targets), np.min(all_head_preds))
    max_val = max(np.max(all_head_targets), np.max(all_head_preds))
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2)
    
    plt.xlabel('True Head Values', fontsize=14)
    plt.ylabel('Predicted Head Values', fontsize=14)
    plt.title(f'Head Predictions (R² = {head_metrics["r2"]:.4f}, RMSE = {head_metrics["rmse"]:.4f})', fontsize=16)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f'head_predictions_{evaluation_criterion}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    

    plt.figure(figsize=(10, 10))
    plt.errorbar(all_conc_targets.flatten(), all_conc_preds.flatten(), 
                yerr=all_conc_uncertainties.flatten(), fmt='o', alpha=0.3,
                ecolor='lightgray', elinewidth=0.5, capsize=0, markersize=2)
    
    min_val = min(np.min(all_conc_targets), np.min(all_conc_preds))
    max_val = max(np.max(all_conc_targets), np.max(all_conc_preds))
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2)
    
    plt.xlabel('True Concentration Values', fontsize=14)
    plt.ylabel('Predicted Concentration Values', fontsize=14)
    plt.title(f'Concentration Predictions (R² = {conc_metrics["r2"]:.4f}, RMSE = {conc_metrics["rmse"]:.4f})', fontsize=16)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f'conc_predictions_{evaluation_criterion}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    ax1.hist(all_head_uncertainties.flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black')
    ax1.set_xlabel('Head Uncertainty (Std)', fontsize=12)
    ax1.set_ylabel('Frequency', fontsize=12)
    ax1.set_title('Head Uncertainty Distribution', fontsize=14)
    ax1.grid(True, alpha=0.3)
    
    ax2.hist(all_conc_uncertainties.flatten(), bins=50, alpha=0.7, color='green', edgecolor='black')
    ax2.set_xlabel('Concentration Uncertainty (Std)', fontsize=12)
    ax2.set_ylabel('Frequency', fontsize=12)
    ax2.set_title('Concentration Uncertainty Distribution', fontsize=14)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f'uncertainty_distributions_{evaluation_criterion}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    

    head_errors = np.abs(all_head_preds.flatten() - all_head_targets.flatten())
    conc_errors = np.abs(all_conc_preds.flatten() - all_conc_targets.flatten())
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    ax1.scatter(all_head_uncertainties.flatten(), head_errors, alpha=0.5, s=10)
    head_corr = np.corrcoef(all_head_uncertainties.flatten(), head_errors)[0, 1]
    ax1.set_xlabel('Head Uncertainty (Std)', fontsize=12)
    ax1.set_ylabel('Head Absolute Error', fontsize=12)
    ax1.set_title(f'Head: Uncertainty vs Error (r={head_corr:.3f})', fontsize=14)
    ax1.grid(True, alpha=0.3)
    
    ax2.scatter(all_conc_uncertainties.flatten(), conc_errors, alpha=0.5, s=10)
    conc_corr = np.corrcoef(all_conc_uncertainties.flatten(), conc_errors)[0, 1]
    ax2.set_xlabel('Concentration Uncertainty (Std)', fontsize=12)
    ax2.set_ylabel('Concentration Absolute Error', fontsize=12)
    ax2.set_title(f'Concentration: Uncertainty vs Error (r={conc_corr:.3f})', fontsize=14)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, f'uncertainty_vs_error_{evaluation_criterion}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    

    evaluation_results = {
        'criterion': evaluation_criterion,
        'head_metrics': head_metrics,
        'conc_metrics': conc_metrics,
        'head_predictions': all_head_preds,
        'head_targets': all_head_targets,
        'head_uncertainties': all_head_uncertainties,
        'conc_predictions': all_conc_preds,
        'conc_targets': all_conc_targets,
        'conc_uncertainties': all_conc_uncertainties,
        'detailed_predictions': predictions,
        'detailed_uncertainties': uncertainties
    }
    

    filename = f'dual_model_evaluation_{evaluation_criterion}.npy'
    np.save(os.path.join(config['save_path'], filename), evaluation_results)
    print(f"\n评估结果已保存到: {config['save_path']}/{filename}")
    
    # 保存CSV格式的预测结果
    if predictions:
        pred_dfs = []
        for pred_dict in predictions:
            pred_df = pd.DataFrame({k: v for k, v in pred_dict.items() if not k.startswith('batch')})
            pred_dfs.append(pred_df)
        
        if pred_dfs:
            predictions_df = pd.concat(pred_dfs, ignore_index=True)
            predictions_df.to_csv(os.path.join(results_dir, f'predictions_{evaluation_criterion}.csv'), index=False)
            print(f"预测结果CSV已保存到: {results_dir}/predictions_{evaluation_criterion}.csv")
    
    # 保存CSV格式的不确定性结果
    if uncertainties:
        unc_dfs = []
        for unc_dict in uncertainties:
            unc_df = pd.DataFrame({k: v for k, v in unc_dict.items() if not k.startswith('batch')})
            unc_dfs.append(unc_df)
        
        if unc_dfs:
            uncertainties_df = pd.concat(unc_dfs, ignore_index=True)
            uncertainties_df.to_csv(os.path.join(results_dir, f'uncertainties_{evaluation_criterion}.csv'), index=False)
            print(f"不确定性结果CSV已保存到: {results_dir}/uncertainties_{evaluation_criterion}.csv")
    
    # 保存整体评估指标
    metrics_summary = {
        'evaluation_criterion': evaluation_criterion,
        'head_mse': head_metrics['mse'],
        'head_rmse': head_metrics['rmse'],
        'head_mae': head_metrics['mae'],
        'head_r2': head_metrics['r2'],
        'conc_mse': conc_metrics['mse'],
        'conc_rmse': conc_metrics['rmse'],
        'conc_mae': conc_metrics['mae'],
        'conc_r2': conc_metrics['r2'],
        'head_uncertainty_mean': np.mean(all_head_uncertainties),
        'head_uncertainty_std': np.std(all_head_uncertainties),
        'conc_uncertainty_mean': np.mean(all_conc_uncertainties),
        'conc_uncertainty_std': np.std(all_conc_uncertainties),
        'head_error_uncertainty_correlation': head_corr,
        'conc_error_uncertainty_correlation': conc_corr
    }
    
    metrics_df = pd.DataFrame([metrics_summary])
    metrics_df.to_csv(os.path.join(results_dir, f'evaluation_summary_{evaluation_criterion}.csv'), index=False)
    
    print(f"\n📊 评估完成!")
    print(f"📈 水头模型 - R2: {head_metrics['r2']:.4f}, RMSE: {head_metrics['rmse']:.4f}")
    print(f"📈 浓度模型 - R2: {conc_metrics['r2']:.4f}, RMSE: {conc_metrics['rmse']:.4f}")
    print(f"📁 所有结果已保存到: {results_dir}")
    
    return evaluation_results

def compare_model_criteria(head_model, conc_model, val_loader):
    """Compare both models under the 'loss' and 'r2' evaluation criteria.

    Calls ``evaluate_dual_model_improved`` once with ``'loss'`` and once with
    ``'r2'``, tabulates the error metrics, mean uncertainty and the
    correlation between uncertainty and absolute error for the head and the
    concentration model, and stores the table plus a 2x4 summary figure in
    ``<config['save_path']>/evaluation_results``.

    Relies on the module-level ``config`` dict for the output location.

    Args:
        head_model: Head model, forwarded to ``evaluate_dual_model_improved``.
        conc_model: Concentration model, forwarded likewise.
        val_loader: Validation loader, forwarded likewise.

    Returns:
        tuple: ``(comparison_df, loss_results, r2_results)`` where
        ``comparison_df`` has one row per (model, criterion) pair and the
        other two are the raw result dicts of the two evaluation runs.
    """
    loss_results = evaluate_dual_model_improved(head_model, conc_model, val_loader, 'loss')
    r2_results = evaluate_dual_model_improved(head_model, conc_model, val_loader, 'r2')

    # Row order is Head/Loss, Head/R2, Concentration/Loss, Concentration/R2.
    table_runs = (('Loss', loss_results), ('R2', r2_results))
    comparison_data = []
    for model_label, prefix in (('Head', 'head'), ('Concentration', 'conc')):
        for criterion_label, run in table_runs:
            metrics = run[f'{prefix}_metrics']
            spread = run[f'{prefix}_uncertainties']
            abs_error = np.abs(
                run[f'{prefix}_predictions'].flatten() - run[f'{prefix}_targets'].flatten()
            )
            comparison_data.append({
                'Model': model_label,
                'Criterion': criterion_label,
                'MSE': metrics['mse'],
                'RMSE': metrics['rmse'],
                'MAE': metrics['mae'],
                'R2': metrics['r2'],
                'Uncertainty_Mean': np.mean(spread),
                'Error_Uncertainty_Corr': np.corrcoef(spread.flatten(), abs_error)[0, 1],
            })

    comparison_df = pd.DataFrame(comparison_data)

    results_dir = os.path.join(config['save_path'], 'evaluation_results')
    os.makedirs(results_dir, exist_ok=True)
    comparison_df.to_csv(os.path.join(results_dir, 'model_criteria_comparison.csv'), index=False)

    print("\n模型选择标准比较结果:")
    print(comparison_df.to_string(index=False, float_format='%.4f'))

    plt.figure(figsize=(20, 12))

    tick_labels = ['Loss-based', 'R2-based']
    plot_runs = (('Loss-based', loss_results), ('R2-based', r2_results))
    head_rows, conc_rows = comparison_data[:2], comparison_data[2:]

    # Panels 1-4: grouped bars, one table column per panel.
    bar_panels = (
        ('R2', 'R2 Score', 'R2 Score Comparison'),
        ('MSE', 'MSE', 'MSE Comparison'),
        ('Uncertainty_Mean', 'Mean Uncertainty', 'Uncertainty Comparison'),
        ('Error_Uncertainty_Corr', 'Error-Uncertainty Correlation', 'Calibration Quality Comparison'),
    )
    for slot, (column, y_label, heading) in enumerate(bar_panels, start=1):
        plt.subplot(2, 4, slot)
        plt.bar([0, 1], [row[column] for row in head_rows],
                alpha=0.7, label='Head Model', width=0.35)
        plt.bar([0.35, 1.35], [row[column] for row in conc_rows],
                alpha=0.7, label='Concentration Model', width=0.35)
        plt.xlabel('Model Selection Criterion')
        plt.ylabel(y_label)
        plt.title(heading)
        plt.xticks([0.175, 1.175], tick_labels)
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Panels 5-6: true vs predicted scatter for both criteria.
    scatter_panels = (
        (5, 'head', 'True Head Values', 'Predicted Head Values',
         'Head Model: True vs Predicted'),
        (6, 'conc', 'True Concentration Values', 'Predicted Concentration Values',
         'Concentration Model: True vs Predicted'),
    )
    for slot, prefix, x_label, y_label, heading in scatter_panels:
        plt.subplot(2, 4, slot)
        for legend_label, run in plot_runs:
            plt.scatter(run[f'{prefix}_targets'].flatten(),
                        run[f'{prefix}_predictions'].flatten(),
                        alpha=0.5, label=legend_label, s=5)
        # The identity line spans the target range of the loss-based run.
        reference = loss_results[f'{prefix}_targets']
        plt.plot([reference.min(), reference.max()],
                 [reference.min(), reference.max()], 'r--', alpha=0.8)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(heading)
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Panels 7-8: overlaid uncertainty histograms.
    hist_panels = (
        (7, 'head', 'Head Uncertainty', 'Head Uncertainty Distribution'),
        (8, 'conc', 'Concentration Uncertainty', 'Concentration Uncertainty Distribution'),
    )
    for slot, prefix, x_label, heading in hist_panels:
        plt.subplot(2, 4, slot)
        for legend_label, run in plot_runs:
            plt.hist(run[f'{prefix}_uncertainties'].flatten(),
                     bins=30, alpha=0.5, label=legend_label, density=True)
        plt.xlabel(x_label)
        plt.ylabel('Density')
        plt.title(heading)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'model_criteria_comparison_comprehensive.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\n比较图表已保存到: {results_dir}/model_criteria_comparison_comprehensive.png")
    print(f"比较数据已保存到: {results_dir}/model_criteria_comparison.csv")

    return comparison_df, loss_results, r2_results

In [None]:
# End-to-end driver cell: load data -> build loaders -> train -> evaluate -> compare.
# Load the raw table; the expected column layout is defined by prepare_data (not visible here).
data = pd.read_csv('conc_dual_guass.csv')  
# Build the training and validation loaders with batch_size=4.
train_loader, val_loader = prepare_data(data, batch_size=4)
# Train both models; the third return value holds the recorded training losses.
head_model, conc_model, training_losses = train_dual_model_improved(train_loader, val_loader)
# Evaluate on the validation loader using the function's default evaluation criterion.
evaluation_results = evaluate_dual_model_improved(head_model, conc_model, val_loader)
# Evaluate again under both the 'loss' and 'r2' criteria and tabulate/plot the comparison.
# NOTE(review): this re-runs the evaluation twice more; the call above may be redundant - confirm.
comparison_df, loss_results, r2_results=compare_model_criteria(head_model, conc_model, val_loader)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import gc
from sklearn.metrics import r2_score
import warnings
# Suppress every Python warning from here on: no category or module filter is
# given, so this also hides deprecation and numerical warnings.
warnings.filterwarnings('ignore')

def create_complete_data_copy(original_data, device):
    """Clone ``original_data`` onto ``device`` and fill in expected attributes.

    Every attribute in the expected list that exists on the source and is not
    ``None`` is written onto the copy (tensors are cloned and moved to
    ``device``, anything else is assigned as-is). If ``bc_mask`` or
    ``time_step`` is absent on the source and the copy has node features
    ``x``, a zero placeholder sized by ``x.shape[0]`` is created instead.

    Args:
        original_data: Object providing ``clone()`` whose result provides
            ``to(device)``.
        device: Target device for the copy and for any created tensors.

    Returns:
        The populated copy, or ``None`` if any step raised (the error is
        printed rather than propagated).
    """
    expected = ('x', 'conc_x', 'edge_index', 'edge_attr', 'y', 'head_y',
                'time_step', 'bc_mask', 'row', 'col', 'model_name')

    try:
        duplicate = original_data.clone().to(device)

        for name in expected:
            if hasattr(original_data, name):
                source_value = getattr(original_data, name)
                if source_value is None:
                    continue
                if torch.is_tensor(source_value):
                    source_value = source_value.clone().to(device)
                setattr(duplicate, name, source_value)
                continue

            # Attribute missing on the source: only these two get a placeholder,
            # and only when node features exist to size it.
            if name not in ('bc_mask', 'time_step'):
                continue
            if not hasattr(duplicate, 'x') or duplicate.x is None:
                continue

            node_count = duplicate.x.shape[0]
            if name == 'bc_mask':
                placeholder = torch.zeros(node_count, 5, dtype=torch.float32, device=device)
            else:
                placeholder = torch.zeros(node_count, dtype=torch.long, device=device)
            setattr(duplicate, name, placeholder)

        return duplicate
    except Exception as e:
        print(f"❌ 创建数据副本失败: {e}")
        return None

def _mc_mean_output(model, data, mc_samples):
    """Average ``mc_samples`` gradient-free forward passes of ``model`` on ``data`` into one scalar."""
    with torch.no_grad():
        draws = [model(data).detach() for _ in range(mc_samples)]
    return torch.stack(draws, dim=0).mean(dim=0).mean().item()


def compute_feature_effects_and_importance(model, data, model_type='head', mc_samples=20, perturbation_scale=0.1):
    """
    Estimate the signed effect and the importance of every input feature.

    Each feature column is shifted up and down by ``perturbation_scale`` times
    its own standard deviation; the change of the mean model output relative
    to the unperturbed baseline is measured with ``mc_samples`` forward passes
    per configuration.

    The model is switched to train mode while sampling (presumably so that
    dropout / Bayesian layers stay stochastic - confirm against the model
    definition) and its previous mode is restored before returning.

    Args:
        model: Model to analyse; called as ``model(data)`` and expected to
            return a tensor.
        data: Input sample; copied via ``create_complete_data_copy`` so the
            caller's object is not modified.
        model_type: ``'head'`` perturbs ``data.x``; any other value perturbs
            ``data.conc_x``.
        mc_samples: Number of forward passes averaged per configuration.
        perturbation_scale: Perturbation size as a multiple of the feature's
            standard deviation.

    Returns:
        dict: ``{'effects', 'importance', 'baseline'}`` where ``effects`` is
        the central difference ``(up - down) / 2`` per feature,
        ``importance`` is ``|up| + |down|`` divided by its maximum (when that
        maximum is positive) and ``baseline`` is the unperturbed mean output.
        An empty dict is returned when the data copy fails, the feature
        tensor is missing, or any error occurs.
    """
    device = next(model.parameters()).device
    # Remember the caller's mode so the analysis does not leave the model in train mode.
    was_training = getattr(model, 'training', True)
    model.train()

    print(f"  🔍 计算{model_type.upper()}模型的特征作用和重要性...")

    try:
        data_copy = create_complete_data_copy(data, device)
        if data_copy is None:
            return {}

        # Select the feature tensor that belongs to the analysed model.
        if model_type == 'head':
            if not hasattr(data_copy, 'x') or data_copy.x is None:
                print(f"    ❌ 水头模型缺少节点特征 x")
                return {}
            feature_tensor = data_copy.x
            feature_name = 'x'
        else:  # conc
            if not hasattr(data_copy, 'conc_x') or data_copy.conc_x is None:
                print(f"    ❌ 浓度模型缺少节点特征 conc_x")
                return {}
            feature_tensor = data_copy.conc_x
            feature_name = 'conc_x'

        num_features = feature_tensor.shape[1]
        print(f"    📊 特征数量: {num_features}")

        # Baseline: mean output on the unperturbed data.
        baseline_mean = _mc_mean_output(model, data_copy, mc_samples)
        print(f"    📈 基线预测均值: {baseline_mean:.6f}")

        feature_effects = np.zeros(num_features)     # signed effect per feature
        feature_importance = np.zeros(num_features)  # magnitude per feature

        for feat_idx in range(num_features):
            print(f"      分析特征 {feat_idx + 1}/{num_features}")

            original_feature = getattr(data_copy, feature_name)[:, feat_idx].clone()
            step = perturbation_scale * original_feature.std().item()

            # A constant column (std == 0) or a single-row column (std is NaN)
            # cannot be perturbed meaningfully: any measured difference would
            # be sampling noise or NaN, so the feature keeps effect 0.
            if not np.isfinite(step) or step == 0:
                continue

            # Upward shift.
            data_pos = data_copy.clone()
            getattr(data_pos, feature_name)[:, feat_idx] = original_feature + step
            pos_effect = _mc_mean_output(model, data_pos, mc_samples) - baseline_mean

            # Downward shift.
            data_neg = data_copy.clone()
            getattr(data_neg, feature_name)[:, feat_idx] = original_feature - step
            neg_effect = _mc_mean_output(model, data_neg, mc_samples) - baseline_mean

            feature_effects[feat_idx] = (pos_effect - neg_effect) / 2
            feature_importance[feat_idx] = abs(pos_effect) + abs(neg_effect)

        # Scale importance so the strongest feature of this sample is 1.
        if feature_importance.max() > 0:
            feature_importance = feature_importance / feature_importance.max()

        print(f"    ✅ 特征作用和重要性计算完成")

        return {
            'effects': feature_effects,
            'importance': feature_importance,
            'baseline': baseline_mean
        }

    except Exception as e:
        print(f"    ❌ 特征作用计算失败: {e}")
        import traceback
        traceback.print_exc()
        return {}
    finally:
        model.train(was_training)

def prepare_single_graph_batch(batch):
    """Extract the first graph of a mini-batch as a stand-alone object.

    If ``batch`` has no ``batch`` assignment vector, or that vector contains
    at most one graph id, ``batch`` is returned unchanged. Otherwise a new
    empty object of the same type is created and filled with:

    * the node-level tensors (``x``, ``conc_x``, ``head_y``, ``y``,
      ``time_step``, ``bc_mask``, ``row``, ``col``) restricted to the nodes of
      the smallest graph id,
    * the edges whose two endpoints both belong to that graph, re-indexed to
      ``0..n-1``, together with the matching ``edge_attr`` rows (an empty
      selection when the graph has no edges),
    * ``model_name`` and ``time_steps`` copied as they are.

    Args:
        batch: Batched graph object; ``type(batch)()`` must be constructible
            without arguments.

    Returns:
        The single-graph object, or the unmodified ``batch`` when no
        extraction is needed or when extraction fails (the error is printed).
    """
    try:
        if not hasattr(batch, 'batch') or batch.batch is None:
            return batch

        unique_graphs = torch.unique(batch.batch)
        if len(unique_graphs) <= 1:
            return batch

        first_graph_mask = batch.batch == unique_graphs[0]
        new_batch = type(batch)()

        # Node-level tensors: keep only rows that belong to the first graph.
        node_attrs = ['x', 'conc_x', 'head_y', 'y', 'time_step', 'bc_mask', 'row', 'col']
        for attr in node_attrs:
            attr_value = getattr(batch, attr, None)
            if torch.is_tensor(attr_value) and len(attr_value) == len(first_graph_mask):
                setattr(new_batch, attr, attr_value[first_graph_mask])

        edge_index = getattr(batch, 'edge_index', None)
        if edge_index is not None:
            first_nodes = torch.where(first_graph_mask)[0]
            edge_mask = torch.isin(edge_index[0], first_nodes) & \
                        torch.isin(edge_index[1], first_nodes)

            # Lookup table old node id -> position inside the first graph
            # (-1 for nodes outside it, which the edge mask never selects).
            remap = torch.full((first_graph_mask.numel(),), -1,
                               dtype=edge_index.dtype, device=edge_index.device)
            remap[first_nodes] = torch.arange(first_nodes.numel(),
                                              dtype=edge_index.dtype, device=edge_index.device)

            # Yields a (2, 0) tensor when the first graph has no edges.
            new_batch.edge_index = remap[edge_index[:, edge_mask]]

            edge_attr = getattr(batch, 'edge_attr', None)
            if edge_attr is not None:
                # Kept in step with edge_index, including the empty case.
                new_batch.edge_attr = edge_attr[edge_mask]

        # Remaining attributes are copied without slicing.
        other_attrs = ['model_name', 'time_steps']
        for attr in other_attrs:
            if hasattr(batch, attr):
                setattr(new_batch, attr, getattr(batch, attr))

        return new_batch
    except Exception as e:
        print(f"准备单图batch失败: {e}")
        return batch

def analyze_model_effects(model, val_loader, config, model_type='head', 
                         evaluation_criterion='loss', n_samples=300):
    """Collect per-feature effects and importances over validation samples.

    Tries to load ``best_<model_type>_model_<evaluation_criterion>.pth`` from
    ``config['save_path']``; if that fails a warning is printed and the
    analysis continues with the weights the model currently holds. One sample
    is analysed per batch (the first graph of the batch, obtained through
    ``prepare_single_graph_batch``), and the loader is traversed repeatedly
    until ``n_samples`` results were collected or
    ``ceil(n_samples / len(val_loader))`` passes are done.

    Args:
        model: Model to analyse.
        val_loader: Sized iterable of batches.
        config: Dict providing ``'save_path'``.
        model_type: ``'head'`` or anything else for the concentration model;
            selects the checkpoint name and the feature-name list.
        evaluation_criterion: Criterion tag used in the checkpoint file name.
        n_samples: Number of successful sample analyses to collect.

    Returns:
        tuple: ``(effect_df, feature_stats)`` - the long-format per-sample
        table and the per-feature mean/std/min/max summary sorted by
        descending mean importance - or ``(None, None)`` when nothing could
        be collected (including an empty ``val_loader``).
    """
    try:
        model_file = f'best_{model_type}_model_{evaluation_criterion}.pth'
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; only point this at checkpoints from a trusted source.
        checkpoint = torch.load(os.path.join(config['save_path'], model_file), 
                              map_location='cpu', weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"✅ 成功加载{model_type.upper()}模型 (Epoch: {checkpoint['epoch']}, Val Loss: {checkpoint['val_loss']:.4f})")
    except Exception as e:
        # Best effort: keep going with the weights already in the model.
        print(f"⚠️  加载{model_type.upper()}模型权重失败: {e}")

    model.eval()

    # Human-readable names for the feature columns; indices beyond the list
    # fall back to 'Feature_<i>' below.
    if model_type == 'head':
        feature_names = [
            'x', 'y', 'top', 'bottom', 'K', 'recharge', 'ET',
            'river_stage', 'river_cond', 'river_rbot', 'well_rate', 
            'well_mask', 'chd_mask', 'lytyp', 'prev_head', 'prev2_head'
        ]
    else:  
        feature_names = [
            'x', 'y', 'top', 'bottom', 'K', 'recharge', 'ET',
            'river_stage', 'river_cond', 'river_rbot', 'well_rate', 
            'well_mask', 'chd_mask', 'lytyp', 'conc_mask',
            'prev_head', 'prev2_head', 'prev_conc', 'prev2_conc'
        ]

    all_effect_data = []
    successful_samples = 0
    total_batches = len(val_loader)

    print(f"🎯 开始分析 {n_samples} 个样本...")
    print(f"📊 验证数据加载器总批次数: {total_batches}")

    # Number of passes over the loader needed to reach n_samples at one sample
    # per batch. An empty loader gets zero passes instead of dividing by zero.
    if total_batches > 0:
        cycles_needed = max(1, (n_samples + total_batches - 1) // total_batches)
    else:
        cycles_needed = 0
    print(f"🔄 需要重复遍历数据加载器 {cycles_needed} 次")

    for cycle in range(cycles_needed):
        print(f"\n🔄 第 {cycle + 1}/{cycles_needed} 轮遍历...")

        for batch_idx, batch in enumerate(val_loader):
            if successful_samples >= n_samples:
                break

            try:
                single_batch = prepare_single_graph_batch(batch)

                result = compute_feature_effects_and_importance(
                    model, single_batch, model_type=model_type
                )

                if result and 'effects' in result and 'importance' in result:
                    effects = result['effects']
                    importance = result['importance']
                    baseline = result['baseline']

                    # One long-format row per (sample, feature).
                    for i, (effect, imp) in enumerate(zip(effects, importance)):
                        feature_name = feature_names[i] if i < len(feature_names) else f'Feature_{i}'
                        all_effect_data.append({
                            'Sample': successful_samples + 1,
                            'Feature': feature_name,
                            'Effect': effect,  
                            'Importance': imp,  
                            'Feature_Index': i,
                            'Baseline': baseline,
                            'Cycle': cycle + 1,
                            'Batch_in_Cycle': batch_idx + 1
                        })

                    successful_samples += 1
                    print(f"  ✅ 样本 {successful_samples}/{n_samples} 分析完成 (轮次{cycle+1}, 批次{batch_idx+1})")
                else:
                    print(f"  ❌ 样本分析失败 (轮次{cycle+1}, 批次{batch_idx+1})")

                # Release cached memory between samples.
                torch.cuda.empty_cache()
                gc.collect()

            except Exception as e:
                print(f"  ❌ 处理批次 {batch_idx} 失败: {e}")
                continue

        if successful_samples >= n_samples:
            break

    print(f"\n📊 {model_type.upper()}模型分析完成，成功分析 {successful_samples} 个样本")

    if not all_effect_data:
        print("❌ 没有收集到任何有效数据")
        return None, None

    effect_df = pd.DataFrame(all_effect_data)

    # Per-feature summary statistics across all analysed samples.
    feature_stats = effect_df.groupby('Feature').agg({
        'Effect': ['mean', 'std', 'min', 'max'],
        'Importance': ['mean', 'std', 'min', 'max']
    }).reset_index()

    # Flatten the two-level column index produced by agg().
    feature_stats.columns = ['Feature', 'Effect_Mean', 'Effect_Std', 'Effect_Min', 'Effect_Max',
                            'Importance_Mean', 'Importance_Std', 'Importance_Min', 'Importance_Max']

    # Most important feature first.
    feature_stats = feature_stats.sort_values('Importance_Mean', ascending=False)

    return effect_df, feature_stats

def _draw_effect_swarm(ax, effect_df, stats, feature_order, point_size, point_alpha,
                       diamond_size, label_fontsize, zero_label=None, vmin=None, vmax=None):
    """Draw one effect swarm (points, zero line, mean markers, y ticks) on ``ax`` and return the scatter artist."""
    n_features = len(feature_order)
    # Rank 0 (most important) gets the largest y so it is drawn at the top.
    y_of = {name: n_features - 1 - rank for rank, name in enumerate(feature_order)}

    scatter = ax.scatter(
        effect_df['Effect'],
        effect_df['Feature'].map(y_of),
        c=effect_df['Importance'],
        cmap='viridis',
        s=point_size,
        alpha=point_alpha,
        edgecolors='black',
        linewidth=0.5,
        vmin=vmin,
        vmax=vmax,
    )

    # Vertical reference at zero effect; labelled only when a legend is wanted.
    line_kwargs = dict(x=0, color='red', linestyle='--', alpha=0.8, linewidth=2)
    if zero_label is not None:
        line_kwargs['label'] = zero_label
    ax.axvline(**line_kwargs)

    # Red diamond plus numeric label at each feature's mean effect.
    for name in feature_order:
        y = y_of[name]
        mean_effect = stats[stats['Feature'] == name]['Effect_Mean'].iloc[0]
        ax.scatter(mean_effect, y, marker='D', s=diamond_size, color='red',
                   edgecolors='black', linewidth=1, zorder=5)
        ax.text(mean_effect, y + 0.15, f'{mean_effect:.4f}',
                ha='center', va='bottom', fontsize=label_fontsize, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

    # Tick j sits at y == j, which holds the feature of rank n-1-j.
    ax.set_yticks(range(n_features))
    ax.set_yticklabels(list(reversed(feature_order)))
    ax.grid(True, alpha=0.3, axis='x')
    return scatter


def create_effect_swarm_plots(head_df, head_stats, conc_df, conc_stats, config, 
                             evaluation_criterion):
    """
    Draw feature-effect swarm plots (x axis: effect, colour: importance).

    Features are ordered by mean importance with the most important feature
    at the top. Up to three PNG files are written to
    ``<config['save_path']>/effect_swarm_analysis``:

    * one figure for the head model (when ``head_df`` is not ``None``),
    * one figure for the concentration model (when ``conc_df`` is not ``None``),
    * a side-by-side comparison of each model's top 10 features (when both
      are available), using one colour scale shared by both panels so the
      single colorbar is valid for both.

    Args:
        head_df: Long-format effect table of the head model, or ``None``.
        head_stats: Per-feature summary of the head model
            (``Feature``, ``Effect_Mean``, ``Importance_Mean`` are used).
        conc_df: Long-format effect table of the concentration model, or ``None``.
        conc_stats: Per-feature summary of the concentration model.
        config: Dict providing ``'save_path'``.
        evaluation_criterion: Tag used in titles and file names.

    Returns:
        str: Directory the figures were written to.
    """
    print("🎨 生成特征作用蜂群图...")

    # Output directory for all figures of this analysis.
    results_dir = os.path.join(config['save_path'], 'effect_swarm_analysis')
    os.makedirs(results_dir, exist_ok=True)

    plt.style.use('default')

    # One full-feature figure per model.
    single_figures = (
        ('head', 'Head', 'Head Model', (16, 12), head_df, head_stats),
        ('conc', 'Conc', 'Concentration Model', (16, 14), conc_df, conc_stats),
    )
    for file_tag, log_tag, title_tag, figsize, effect_df, stats in single_figures:
        if effect_df is None:
            continue

        plt.figure(figsize=figsize)
        ax = plt.gca()

        feature_order = stats.sort_values("Importance_Mean", ascending=False)['Feature'].tolist()
        scatter = _draw_effect_swarm(ax, effect_df, stats, feature_order,
                                     point_size=80, point_alpha=0.7,
                                     diamond_size=100, label_fontsize=9,
                                     zero_label='No Effect')

        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Feature Importance', fontsize=12, fontweight='bold')

        ax.set_xlabel('Feature Effect on Model Output', fontsize=14, fontweight='bold')
        ax.set_ylabel('Features (sorted by importance)', fontsize=14, fontweight='bold')
        ax.set_title(f'{title_tag}: Feature Effects Distribution\nCriterion: {evaluation_criterion}\n(Red diamonds = mean effect)', 
                     fontsize=16, fontweight='bold', pad=20)
        ax.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(results_dir, f'{file_tag}_effect_swarm_{evaluation_criterion}_sorted.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()

        print(f"  📊 {log_tag}模型作用蜂群图已保存（已按重要性排序）")

    # Side-by-side comparison restricted to each model's ten most important features.
    if head_df is not None and conc_df is not None:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(28, 14))

        panels = []
        for ax, effect_df, stats, heading in (
            (ax1, head_df, head_stats, 'Head Model - Top 10 Features'),
            (ax2, conc_df, conc_stats, 'Concentration Model - Top 10 Features'),
        ):
            top_features = stats.sort_values("Importance_Mean", ascending=False).head(10)['Feature'].tolist()
            top_df = effect_df[effect_df['Feature'].isin(top_features)]
            panels.append((ax, top_df, stats, top_features, heading))

        # Common colour limits so the single colorbar describes both panels.
        vmin = min(panel[1]['Importance'].min() for panel in panels)
        vmax = max(panel[1]['Importance'].max() for panel in panels)
        if not (np.isfinite(vmin) and np.isfinite(vmax)):
            vmin = vmax = None  # fall back to matplotlib's autoscaling

        scatter = None
        for ax, top_df, stats, top_features, heading in panels:
            scatter = _draw_effect_swarm(ax, top_df, stats, top_features,
                                         point_size=100, point_alpha=0.8,
                                         diamond_size=120, label_fontsize=10,
                                         vmin=vmin, vmax=vmax)
            ax.set_xlabel('Feature Effect on Output', fontsize=12, fontweight='bold')
            ax.set_ylabel('Features (sorted by importance)', fontsize=12, fontweight='bold')
            ax.set_title(heading, fontsize=14, fontweight='bold')

        # Reserve a strip on the right for the shared colorbar.
        fig.subplots_adjust(right=0.85)
        cbar_ax = fig.add_axes([0.87, 0.15, 0.02, 0.7])
        cbar = fig.colorbar(scatter, cax=cbar_ax)
        cbar.set_label('Feature Importance', fontsize=12, fontweight='bold')

        plt.suptitle(f'Model Comparison: Feature Effects on Output\nCriterion: {evaluation_criterion}', 
                    fontsize=16, fontweight='bold')
        plt.savefig(os.path.join(results_dir, f'comparison_effect_swarm_{evaluation_criterion}_sorted.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()

        print(f"  📊 对比作用蜂群图已保存（已按重要性排序）")

    return results_dir

def _effect_summary_rows(model_label, stats):
    """Build one summary record per row of a per-feature statistics frame.

    Args:
        model_label: Value stored under the ``'Model'`` key of every record.
        stats: DataFrame providing the ``'Feature'``, ``'Effect_Mean'``,
            ``'Importance_Mean'``, ``'Effect_Std'`` and ``'Importance_Std'``
            columns.

    Returns:
        A list of dicts, one per row of ``stats``, in row order.
    """
    rows = []
    for _, row in stats.iterrows():
        mean_effect = row['Effect_Mean']
        if mean_effect > 0:
            direction = 'Positive'
        elif mean_effect < 0:
            direction = 'Negative'
        else:
            # Reached for an exact zero, and also for NaN (both comparisons
            # above are False for NaN).
            direction = 'Neutral'
        rows.append({
            'Model': model_label,
            'Feature': row['Feature'],
            'Mean_Effect': mean_effect,
            'Effect_Direction': direction,
            'Mean_Importance': row['Importance_Mean'],
            'Effect_Std': row['Effect_Std'],
            'Importance_Std': row['Importance_Std'],
        })
    return rows


def save_effect_results(head_df, head_stats, conc_df, conc_stats, results_dir,
                        evaluation_criterion):
    """Save the feature-effect analysis results as CSV files.

    Every frame that is not ``None`` is written to ``results_dir`` under a
    file name that embeds ``evaluation_criterion``. A combined summary CSV
    (one row per model/feature pair, with the sign of the mean effect spelled
    out as Positive/Negative/Neutral) is written when at least one statistics
    frame contributes a row.

    Args:
        head_df: Raw per-sample effects of the head model, or ``None``.
        head_stats: Per-feature statistics of the head model, or ``None``.
        conc_df: Raw per-sample effects of the concentration model, or
            ``None``.
        conc_stats: Per-feature statistics of the concentration model, or
            ``None``.
        results_dir: Output directory; created if it does not exist.
        evaluation_criterion: Tag appended to every output file name.

    Returns:
        None.
    """
    print("💾 保存特征作用分析结果...")

    # Make the function usable on its own: to_csv() does not create missing
    # parent directories. An empty path means "current directory" and needs
    # no creation.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    # Each frame is guarded individually so that a raw frame without its
    # statistics (or the reverse) cannot trigger an AttributeError on None.
    exports = (
        (head_df, f'head_effects_raw_{evaluation_criterion}.csv'),
        (head_stats, f'head_effects_stats_{evaluation_criterion}.csv'),
        (conc_df, f'conc_effects_raw_{evaluation_criterion}.csv'),
        (conc_stats, f'conc_effects_stats_{evaluation_criterion}.csv'),
    )
    for frame, filename in exports:
        if frame is not None:
            frame.to_csv(os.path.join(results_dir, filename), index=False)

    # Head rows come first, then Concentration rows.
    summary_data = []
    for model_label, stats in (('Head', head_stats), ('Concentration', conc_stats)):
        if stats is not None:
            summary_data.extend(_effect_summary_rows(model_label, stats))

    if summary_data:
        summary_path = os.path.join(
            results_dir, f'feature_effects_summary_{evaluation_criterion}.csv'
        )
        pd.DataFrame(summary_data).to_csv(summary_path, index=False)

        print(f"  📄 总结报告已保存")

def _print_top_effects(model_label, stats, top_n=5):
    """Print the first ``top_n`` rows of a statistics frame as a ranked list.

    Args:
        model_label: Model name interpolated into the heading line.
        stats: DataFrame providing the ``'Feature'``, ``'Effect_Mean'`` and
            ``'Importance_Mean'`` columns; rows are printed in frame order.
        top_n: Number of leading rows to print.
    """
    print(f"\n🏆 {model_label}模型 Top {top_n} 重要特征作用:")
    # Number the rows by position. The label yielded by iterrows() is the
    # frame's index value, which equals the rank only for a fresh 0..n-1
    # index and cannot be incremented at all when it is not an integer.
    for rank, (_, row) in enumerate(stats.head(top_n).iterrows(), start=1):
        mean_effect = row['Effect_Mean']
        if mean_effect > 0:
            direction = "正向"
        elif mean_effect < 0:
            direction = "负向"
        else:
            direction = "中性"
        print(f"  {rank}. {row['Feature']}: 作用 {mean_effect:.6f} ({direction}), 重要性 {row['Importance_Mean']:.6f}")


def run_effect_swarm_analysis(head_model, conc_model, val_loader, config,
                              evaluation_criterion='loss', n_samples=300):
    """Run the feature-effect swarm-plot analysis for both models.

    Calls ``analyze_model_effects`` once with ``model_type='head'`` and once
    with ``model_type='conc'``. If at least one call yields a non-None raw
    frame, the results are plotted with ``create_effect_swarm_plots``, saved
    with ``save_effect_results`` and the first five rows of each available
    statistics frame are printed.

    Args:
        head_model: Model analysed with ``model_type='head'``.
        conc_model: Model analysed with ``model_type='conc'``.
        val_loader: Data loader forwarded to ``analyze_model_effects``.
        config: Configuration object forwarded to the analysis and plotting
            helpers.
        evaluation_criterion: Criterion name forwarded to the helpers and
            used by them to tag outputs.
        n_samples: Sample count forwarded to ``analyze_model_effects``.

    Returns:
        A dict with keys ``'head_df'``, ``'head_stats'``, ``'conc_df'``,
        ``'conc_stats'`` and ``'results_dir'``, or ``None`` when neither
        model produced a raw frame.
    """
    print("=" * 80)
    print("🎯 特征作用蜂群图分析")
    print("  横轴：特征对模型输出的作用（正向/负向）")
    print("  颜色：特征重要性")
    print("=" * 80)

    print("\n🌊 分析Head模型特征作用...")
    head_df, head_stats = analyze_model_effects(
        head_model, val_loader, config, model_type='head',
        evaluation_criterion=evaluation_criterion, n_samples=n_samples
    )

    print("\n🧪 分析Conc模型特征作用...")
    conc_df, conc_stats = analyze_model_effects(
        conc_model, val_loader, config, model_type='conc',
        evaluation_criterion=evaluation_criterion, n_samples=n_samples
    )

    # Nothing to plot or save when both analyses came back empty.
    if head_df is None and conc_df is None:
        print("❌ 没有成功分析任何模型")
        return None

    results_dir = create_effect_swarm_plots(
        head_df, head_stats, conc_df, conc_stats,
        config, evaluation_criterion
    )

    save_effect_results(
        head_df, head_stats, conc_df, conc_stats,
        results_dir, evaluation_criterion
    )

    print(f"\n🎉 特征作用蜂群图分析完成!")
    print(f"📁 结果保存至: {results_dir}")

    if head_stats is not None:
        _print_top_effects("Head", head_stats)

    if conc_stats is not None:
        _print_top_effects("Conc", conc_stats)

    return {
        'head_df': head_df,
        'head_stats': head_stats,
        'conc_df': conc_df,
        'conc_stats': conc_stats,
        'results_dir': results_dir
    }


def run_complete_effect_analysis(head_model, conc_model, val_loader, config,
                                 criteria=('r2',), n_samples=300):
    """Run the feature-effect swarm analysis once per evaluation criterion.

    Args:
        head_model: Head model forwarded to ``run_effect_swarm_analysis``.
        conc_model: Concentration model forwarded to
            ``run_effect_swarm_analysis``.
        val_loader: Data loader forwarded to ``run_effect_swarm_analysis``.
        config: Configuration object forwarded to
            ``run_effect_swarm_analysis``.
        criteria: Evaluation criterion names to analyse, in order. A single
            string is treated as one criterion. Defaults to ``('r2',)``.
        n_samples: Sample count forwarded to each analysis run.

    Returns:
        A dict mapping each criterion name to whatever
        ``run_effect_swarm_analysis`` returned for it.
    """
    # Imported locally so the cleanup below does not rely on a module-level
    # ``import gc`` being present elsewhere in the file.
    import gc

    print("🚀 开始完整的特征作用蜂群图分析")

    # Iterating a bare string would analyse it character by character.
    if isinstance(criteria, str):
        criteria = (criteria,)

    results = {}

    for criterion in criteria:
        print(f"\n📊 分析评估标准: {criterion}")

        results[criterion] = run_effect_swarm_analysis(
            head_model, conc_model, val_loader, config,
            evaluation_criterion=criterion, n_samples=n_samples
        )

        # Release cached GPU memory and collect garbage between runs.
        torch.cuda.empty_cache()
        gc.collect()

    print(f"\n✅ 完整的特征作用蜂群图分析完成!")
    return results


if __name__ == "__main__":
    # Script entry point: run the full feature-effect analysis and keep the
    # per-criterion results. The four inputs are not defined in this part of
    # the file; they are expected to exist at module level already.
    all_results = run_complete_effect_analysis(
        head_model=head_model,
        conc_model=conc_model,
        val_loader=val_loader,
        config=config,
    )
    
