# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, average_precision_score, cohen_kappa_score
import seaborn as sns
from tqdm import tqdm
from collections import Counter
import json
import cv2

# Reproducibility: fix the RNG seeds and force cuDNN onto deterministic kernels.
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Input locations, all resolved relative to the current working directory.
data_dir = os.getcwd()
image_dir, annotation_file, mask_dir = (
    os.path.join(data_dir, "Fibrosis"),
    os.path.join(data_dir, "Fibrosislabel.xlsx"),
    os.path.join(data_dir, "Mask"),
)

# Output directory for checkpoints, heatmaps and result files (created if missing).
model_save_dir = os.path.join(data_dir, "Model")
os.makedirs(model_save_dir, exist_ok=True)
print(f"模型将保存到: {model_save_dir}")

class SoftMaskConstrainedHeatmapLayer(nn.Module):
    """1x1-conv heatmap head whose output is attenuated, not zeroed, outside a mask.

    Where the mask is 1 the raw heatmap passes through unchanged (factor 1.0);
    where it is 0 the signal is scaled down to 20% (factor 0.2).
    """

    def __init__(self, in_channels, num_classes, heatmap_size=(56, 56)):
        super().__init__()
        # Stored for reference only; forward() does not read it.
        self.heatmap_size = heatmap_size
        # One output channel per class.
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, features, binary_mask=None):
        """Project features to per-class heatmaps and optionally soften the background.

        Args:
            features: feature map [B, C, H, W].
            binary_mask: optional mask [B, 1, h, w]; bilinearly resized to (H, W)
                when its spatial size differs.

        Returns:
            Heatmap [B, num_classes, H, W].
        """
        heat = self.conv(features)
        if binary_mask is None:
            return heat

        spatial = heat.size()[-2:]
        if binary_mask.size()[-2:] != spatial:
            binary_mask = F.interpolate(binary_mask, size=spatial,
                                        mode='bilinear', align_corners=False)

        # A 0.2 floor keeps 20% of the background signal instead of erasing it.
        return heat * (0.2 + 0.8 * binary_mask)

class RenalFibrosisModel(nn.Module):
    """ResNet-18 binary classifier with a soft-mask-constrained heatmap head.

    forward() returns (logits, heatmap, features). The classification path is
    backbone features -> global average pooling -> linear layer; the mask only
    affects the heatmap, never the logits.

    NOTE(review): ``heatmap_layer`` only receives gradients if the training
    loss includes a term computed from it; otherwise its 1x1 conv keeps its
    random initialisation and the heatmaps are not learned class evidence.
    """

    def __init__(self, num_classes=2, model_name='resnet18', pretrained=True, mask_constraint='soft'):
        super().__init__()
        # Final masking in generate_attention_map: 'soft' attenuates the
        # background, 'strict' zeroes it; any other value applies no final mask.
        self.mask_constraint = mask_constraint

        if model_name == 'resnet18':
            # `self.backbone` is kept even though forward() only uses
            # `self.features` (which shares its modules), so state_dict keys
            # stay compatible with existing checkpoints.
            self.backbone = models.resnet18(pretrained=pretrained)
            # Drop the final avgpool and fc layers to keep the spatial feature map.
            self.features = nn.Sequential(*list(self.backbone.children())[:-2])
            num_features = self.backbone.fc.in_features
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(num_features, num_classes)
        # Heatmap head; applies the soft mask (background keeps 20%).
        self.heatmap_layer = SoftMaskConstrainedHeatmapLayer(num_features, num_classes)

        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, x, mask=None):
        """
        Args:
            x: input images [B, 3, H, W].
            mask: optional binary mask [B, 1, H, W].

        Returns:
            outputs: classification logits [B, num_classes].
            heatmap: soft-mask-constrained heatmap [B, num_classes, h, w] at
                feature-map resolution.
            features: backbone feature map [B, C, h, w].
        """
        features = self.features(x)  # [B, 512, H/32, W/32] for resnet18

        # The mask is only used by the heatmap head, resized to feature resolution.
        mask_resized = None
        if mask is not None:
            mask_resized = F.interpolate(mask, size=features.size()[-2:],
                                         mode='bilinear', align_corners=False)

        heatmap = self.heatmap_layer(features, mask_resized)

        pooled = self.global_avg_pool(features)
        pooled = pooled.view(pooled.size(0), -1)
        outputs = self.classifier(pooled)

        return outputs, heatmap, features

    def generate_attention_map(self, x, mask=None, class_idx=None):
        """Return a per-sample min-max normalised attention map at input resolution.

        Args:
            x: input images [B, 3, H, W].
            mask: optional binary mask [B, 1, H, W] at input resolution.
            class_idx: heatmap channel to use; None averages all class channels.

        Returns:
            Tensor [B, H, W]; each sample is scaled to [0, 1] independently.
        """
        _, heatmap, _ = self.forward(x, mask)

        if class_idx is not None:
            attention_map = heatmap[:, class_idx:class_idx + 1]  # [B, 1, h, w]
        else:
            attention_map = heatmap.mean(dim=1, keepdim=True)  # [B, 1, h, w]

        attention_map = F.interpolate(attention_map, size=x.size()[-2:],
                                      mode='bilinear', align_corners=False)

        if mask is not None:
            if self.mask_constraint == 'soft':
                # mask==1 keeps the full signal; background keeps 40%.
                attention_map = attention_map * (mask * 0.6 + 0.4)
            elif self.mask_constraint == 'strict':
                # Hard constraint: nothing survives outside the mask.
                attention_map = attention_map * mask

        # Normalise each sample on its own; a batch-wide min/max would let one
        # sample's range compress every other sample's map.
        flat = attention_map.flatten(start_dim=1)
        lo = flat.min(dim=1).values.view(-1, 1, 1, 1)
        hi = flat.max(dim=1).values.view(-1, 1, 1, 1)
        attention_map = (attention_map - lo) / (hi - lo + 1e-8)

        return attention_map.squeeze(1)  # [B, H, W]

def load_and_preprocess_annotations(annotation_file):
    """Load the annotation spreadsheet, binarise labels and attach image/mask paths.

    Steps:
      1. Read `annotation_file` with pandas and locate the label column among a
         list of candidate names; rename it to 'LABEL'. An 'ID' column is required.
      2. Normalise 'ID' to stripped strings and drop rows without a label.
      3. Binarise labels: with exactly two distinct values the lower maps to 0
         and the higher to 1 (e.g. {1, 2} -> {0, 1}); with more than two,
         values >= 3 map to 1 and the rest to 0.
      4. Resolve 'image_path' under the module-level `image_dir` and
         'mask_path' under `mask_dir`, trying several extensions.
      5. Keep only rows that have both an image and a mask.

    Returns:
        The filtered DataFrame including 'ID', 'LABEL', 'image_path', 'mask_path'.

    Raises:
        ValueError: no label column, fewer than two label classes (before or
            after binarisation), or no sample with both an image and a mask.
        Any other exception (e.g. from pd.read_excel) is printed and re-raised.
    """
    try:
        annotations = pd.read_excel(annotation_file)

        # Auto-detect the label column.
        possible_label_names = ['LABLE', 'LABEL', 'Label', 'label', '分级', '评分', 'fibrosis', 'stage']
        label_col = next((name for name in possible_label_names if name in annotations.columns), None)
        if label_col is None:
            raise ValueError("无法找到标签列")

        annotations.rename(columns={label_col: 'LABEL'}, inplace=True)
        annotations['ID'] = annotations['ID'].astype(str).str.strip()

        # Unlabelled rows would otherwise appear as an extra "class" (NaN) and
        # be silently binarised to 0 by the >= 3 threshold below.
        missing_labels = annotations['LABEL'].isnull().sum()
        if missing_labels > 0:
            print(f"\n警告: 有{missing_labels}行缺少标签，已忽略")
            annotations = annotations[annotations['LABEL'].notnull()].copy()

        unique_labels = sorted(annotations['LABEL'].unique())
        if len(unique_labels) < 2:
            raise ValueError("标签类别不足2类")
        if len(unique_labels) > 2:
            print(f"警告: 检测到多类标签{unique_labels}, 将转换为二分类问题")
            annotations['LABEL'] = (annotations['LABEL'] >= 3).astype(int)
        else:
            # Map the two observed values onto 0/1; CrossEntropyLoss and the
            # dataset's class weights require labels exactly 0 and 1.
            annotations['LABEL'] = (annotations['LABEL'] == unique_labels[1]).astype(int)

        if annotations['LABEL'].nunique() < 2:
            raise ValueError("二值化后标签只剩一个类别")

        def find_image_path(x):
            """Return the first existing image file for ID `x`, or None."""
            base_path = os.path.join(image_dir, x)
            for ext in ['.jpg', '.tif', '.tiff', '.png', '.jpeg']:
                for candidate in (base_path + ext, base_path + ext.upper()):
                    if os.path.exists(candidate):
                        return candidate
            return None

        annotations['image_path'] = annotations['ID'].apply(find_image_path)

        def find_mask_path(x):
            """Return the first existing mask file for ID `x` ('<ID>_mask.ext' or '<ID>.ext'), or None."""
            base_path = os.path.join(mask_dir, x)
            for ext in ['.png', '.jpg', '.tif', '.tiff']:
                for candidate in (base_path + '_mask' + ext,
                                  base_path + '_mask' + ext.upper(),
                                  os.path.join(mask_dir, x + ext)):
                    if os.path.exists(candidate):
                        return candidate
            return None

        annotations['mask_path'] = annotations['ID'].apply(find_mask_path)

        missing_images = annotations['image_path'].isnull().sum()
        missing_masks = annotations['mask_path'].isnull().sum()
        if missing_images > 0:
            print(f"\n警告: 有{missing_images}个ID找不到对应的图像文件")
        if missing_masks > 0:
            print(f"\n警告: 有{missing_masks}个ID找不到对应的掩码文件")

        # Keep only samples that have both an image and a mask.
        annotations = annotations.dropna(subset=['image_path', 'mask_path'])
        if len(annotations) == 0:
            raise ValueError("没有找到同时包含图像和掩码的有效样本")

        print(f"\n找到{len(annotations)}个有效样本（同时包含图像和掩码）")
        return annotations

    except Exception as e:
        print(f"加载标注文件出错: {str(e)}")
        raise

class RenalFibrosisDataset(Dataset):
    """(image, mask, label) dataset over rows of an annotation DataFrame.

    Each row must provide 'image_path', 'mask_path' and 'LABEL'. At
    construction, rows whose files are missing or fail PIL's verify() are
    skipped with a printed warning. __getitem__ returns (image, mask, label)
    with the mask binarised at 0.5. If loading fails at access time, it
    returns zero tensors with label -1 so the training loop can drop the sample.
    """

    def __init__(self, dataframe, image_dir, mask_dir, transform=None, mask_transform=None):
        self.dataframe = dataframe.copy()
        # Stored for reference; file paths are read from the DataFrame columns.
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Default mask geometry mirrors the validation image pipeline
        # (Resize(256) on the short side, then CenterCrop(224)), so masks line up
        # with images and with the 224x224 fallback returned by __getitem__.
        self.mask_transform = mask_transform or transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        self.valid_samples = []

        # Check up front that every image/mask exists and is readable.
        for _, row in self.dataframe.iterrows():
            img_path = row['image_path']
            mask_path = row['mask_path']

            if not (isinstance(img_path, str) and os.path.exists(img_path) and
                    isinstance(mask_path, str) and os.path.exists(mask_path)):
                print(f"警告: 文件不存在 {img_path} 或 {mask_path}")
                continue
            try:
                with Image.open(img_path) as img:
                    img.verify()
                with Image.open(mask_path) as mask:
                    mask.verify()
            except Exception as e:
                # Best-effort validation: warn and skip corrupted files.
                print(f"警告: 文件损坏 {img_path} 或 {mask_path}: {str(e)}")
                continue
            self.valid_samples.append((img_path, mask_path, row['LABEL']))

        # Inverse-frequency class weights for labels 0 and 1. Fall back to uniform
        # weights unless exactly those two labels are present (avoids dividing
        # by a zero count, e.g. for a label set like {1, 2}).
        class_counts = Counter(label for _, _, label in self.valid_samples)
        total = sum(class_counts.values())
        if set(class_counts) == {0, 1}:
            weights = [total / class_counts[0], total / class_counts[1]]
        else:
            weights = [1.0, 1.0]
        self.class_weights = torch.tensor(weights)

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        img_path, mask_path, label = self.valid_samples[idx]

        try:
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            with Image.open(mask_path) as mask_img:
                mask = mask_img.convert('L')  # single-channel grayscale

            if self.transform:
                image = self.transform(image)

            if self.mask_transform:
                mask = self.mask_transform(mask)
                # Interpolation during resizing blurs edges; re-binarise.
                mask = (mask > 0.5).float()

            return image, mask, torch.tensor(label, dtype=torch.long)
        except Exception as e:
            print(f"加载文件 {img_path} 或 {mask_path} 出错: {str(e)}")
            # Placeholder sample; label -1 marks it for filtering in the loops.
            return torch.zeros(3, 224, 224), torch.zeros(1, 224, 224), torch.tensor(-1, dtype=torch.long)

def get_transforms():
    """Return the (train, validation) torchvision image pipelines.

    Training: random resized crop to 224 (scale 0.8-1.0), random horizontal and
    vertical flips, mild brightness/contrast jitter.
    Validation: resize the short side to 256, then center-crop 224.
    Both end with ToTensor and ImageNet mean/std normalisation.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def finish(steps):
        # Shared tail: PIL image -> float tensor -> ImageNet-normalised tensor.
        return transforms.Compose(
            steps + [transforms.ToTensor(), transforms.Normalize(imagenet_mean, imagenet_std)]
        )

    train_pipeline = finish([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
    ])
    val_pipeline = finish([
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ])
    return train_pipeline, val_pipeline

class EarlyStopping:
    """Signal a stop once validation loss fails to improve for `patience` calls.

    A call counts as an improvement when it is not worse than the best score by
    more than `delta`. Each improvement writes model.state_dict() to `path`.
    """

    def __init__(self, patience=5, delta=0, verbose=True, path='checkpoint.pth'):
        self.patience, self.delta = patience, delta
        self.verbose, self.path = verbose, path
        self.counter = 0
        self.best_score = None  # best (negated) validation loss seen so far
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        score = -val_loss
        # Written as `not (score < ...)` on purpose: a NaN loss then counts as an
        # improvement, exactly as the original if/elif/else did.
        improved = self.best_score is None or not (score < self.best_score + self.delta)
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter}/{self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Persist the model weights and remember the new minimum loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

def calculate_metrics(y_true, y_pred, y_probs):
    """Compute binary-classification metrics.

    Args:
        y_true: ground-truth labels (0/1).
        y_pred: hard predictions (0/1 or bool).
        y_probs: predicted probability of class 1, used for ROC-AUC and PR-AUC.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1', 'roc_auc', 'pr_auc'
        and 'kappa'. 'roc_auc' and 'pr_auc' are NaN when y_true contains a
        single class, where they are undefined (roc_auc_score would raise
        ValueError and abort training). Precision/recall/F1 are 0 when
        undefined, which matches sklearn's default value but without the warning.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_probs = np.asarray(y_probs)
    has_both_classes = np.unique(y_true).size > 1

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        'roc_auc': roc_auc_score(y_true, y_probs) if has_both_classes else float('nan'),
        'pr_auc': average_precision_score(y_true, y_probs) if has_both_classes else float('nan'),
        'kappa': cohen_kappa_score(y_true, y_pred),
    }

def visualize_heatmap(image, mask, heatmap, save_path, mask_strength=0.6):
    """Save a 1x4 panel: image, mask, soft-masked heatmap, and overlay.

    Args:
        image: normalised image tensor [3, H, W]; ImageNet mean/std are undone
            here, matching get_transforms().
        mask: binary mask tensor, [1, H, W] or [H, W].
        heatmap: attention tensor, squeezed to [H, W].
        save_path: output image path.
        mask_strength: clamped to [0, 1]. Background heatmap values are scaled by
            (1 - mask_strength): 0 disables masking, 1 zeroes the background.
    """
    # Undo the ImageNet normalisation and convert to uint8 RGB.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image_np = image.permute(1, 2, 0).cpu().numpy()
    image_np = np.clip((image_np * std + mean) * 255, 0, 255).astype(np.uint8)

    mask_np = mask.squeeze().cpu().numpy() if mask.dim() == 3 else mask.cpu().numpy()
    heatmap_np = heatmap.squeeze().cpu().numpy()

    # Soft constraint: mask==1 keeps the full signal, mask==0 keeps
    # (1 - strength). Always applied, so strength 1.0 means a hard mask rather
    # than (as before) no mask at all.
    strength = min(max(float(mask_strength), 0.0), 1.0)
    heatmap_np = heatmap_np * (mask_np * strength + (1 - strength))

    # Min-max normalise; a constant map is left as-is.
    lo, hi = heatmap_np.min(), heatmap_np.max()
    if hi > lo:
        heatmap_np = (heatmap_np - lo) / (hi - lo)
    # Clip so the uint8 cast below cannot wrap for out-of-range constant maps.
    heatmap_np = np.clip(heatmap_np, 0.0, 1.0)

    heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_np), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    overlayed = cv2.addWeighted(image_np, 0.6, heatmap_colored, 0.3, 0)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = [
        (image_np, None, 'Original Image'),
        (mask_np, 'gray', 'Binary Mask'),
        (heatmap_np, 'jet', 'Soft Attention Heatmap'),
        (overlayed, None, 'Overlayed Result'),
    ]
    for ax, (data, cmap, title) in zip(axes, panels):
        if cmap is None:
            ax.imshow(data)
        else:
            ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def _to_serializable(obj):
    """Recursively convert tensors / numpy scalars / arrays into JSON-friendly values."""
    if torch.is_tensor(obj):
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_serializable(x) for x in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


def _run_epoch(model, loader, criterion, optimizer=None, heatmap_lambda=0.0):
    """Run one pass over `loader`, training when `optimizer` is given, else evaluating.

    Samples with label -1 (RenalFibrosisDataset's load-failure placeholder)
    are dropped from every batch.

    Returns:
        (mean classification loss, accuracy as float, class-1 probabilities, labels)
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, seen = 0.0, 0, 0
    probs, targets = [], []

    with torch.set_grad_enabled(training):
        for inputs, masks, labels in tqdm(loader, desc='Training' if training else 'Validation'):
            valid = labels != -1
            if not valid.any():
                continue
            inputs = inputs[valid].to(device)
            masks = masks[valid].to(device)
            labels = labels[valid].to(device)

            if training:
                optimizer.zero_grad()
            outputs, _, features = model(inputs, masks)
            loss = criterion(outputs, labels)

            if training:
                total_loss = loss
                if heatmap_lambda > 0:
                    # Without this term the heatmap head never gets gradients and
                    # its 1x1 conv stays at its random init. Train it as a
                    # spatially averaged auxiliary classifier (CAM-style) on the
                    # unmasked features, so the masks (which are not aligned with
                    # randomly cropped/flipped training images) do not matter.
                    aux_logits = model.heatmap_layer(features).mean(dim=(2, 3))
                    total_loss = total_loss + heatmap_lambda * criterion(aux_logits, labels)
                total_loss.backward()
                optimizer.step()

            n = inputs.size(0)
            loss_sum += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += n
            probs.extend(torch.softmax(outputs, dim=1)[:, 1].detach().cpu().numpy())
            targets.extend(labels.detach().cpu().numpy())

    mean_loss = loss_sum / seen if seen else 0.0
    accuracy = correct / seen if seen else 0.0
    return mean_loss, accuracy, probs, targets


def _save_fold_heatmaps(model, loader, heatmap_dir, fold, mask_strength, max_samples=3):
    """Save heatmap panels for the first `max_samples` valid samples of `loader`."""
    os.makedirs(heatmap_dir, exist_ok=True)
    model.eval()
    saved = 0
    with torch.no_grad():
        for inputs, masks, labels in loader:
            valid = labels != -1
            inputs = inputs[valid].to(device)
            masks = masks[valid].to(device)
            for i in range(inputs.size(0)):
                if saved >= max_samples:
                    return
                attention_map = model.generate_attention_map(inputs[i:i + 1], masks[i:i + 1])
                visualize_heatmap(
                    inputs[i], masks[i], attention_map,
                    os.path.join(heatmap_dir, f'fold{fold}_sample{saved}.png'),
                    mask_strength=mask_strength
                )
                saved += 1


def _train_fold(fold, train_df, val_df, config, train_tfs, val_tfs):
    """Train and validate one cross-validation fold.

    Args:
        fold: zero-based fold index (used in checkpoint / heatmap file names).
        train_df, val_df: annotation rows for this fold.
        config: hyper-parameter dict built in main().
        train_tfs, val_tfs: (image_transform, mask_transform) pairs.

    Returns:
        dict with the predictions of the best-validation-accuracy epoch
        ('train_labels', 'train_preds', 'train_probs', 'val_labels',
        'val_preds', 'val_probs') plus a 'summary' row for the fold-metrics
        table, or None when validation accuracy never exceeded 0.
    """
    save_dir = config['model_save_dir']
    pin_memory = device.type == 'cuda'  # pinning only helps host->GPU copies

    train_dataset = RenalFibrosisDataset(train_df, image_dir, mask_dir, *train_tfs)
    val_dataset = RenalFibrosisDataset(val_df, image_dir, mask_dir, *val_tfs)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                              num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                            num_workers=0, pin_memory=pin_memory)

    model = RenalFibrosisModel(
        num_classes=2,
        model_name=config['model_name'],
        pretrained=True,
        mask_constraint=config['mask_constraint']
    ).to(device)

    criterion = nn.CrossEntropyLoss(weight=train_dataset.class_weights.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'],
                            weight_decay=config['weight_decay'])

    early_stopping = EarlyStopping(
        patience=config['patience'],
        verbose=True,
        path=os.path.join(save_dir, f'best_model_fold{fold}_soft.pth')
    )
    best_acc_model_path = os.path.join(save_dir, f'best_acc_model_fold{fold}_soft.pth')
    best_auc_model_path = os.path.join(save_dir, f'best_auc_model_fold{fold}_soft.pth')

    heatmap_lambda = config.get('heatmap_lambda', 0.0)
    best_acc, best_auc = 0.0, 0.0
    best_auc_saved = False
    best = None  # predictions from the best-validation-accuracy epoch

    for epoch in range(config['num_epochs']):
        print(f'\nEpoch {epoch + 1}/{config["num_epochs"]}')

        train_loss, train_acc, train_probs, train_labels = _run_epoch(
            model, train_loader, criterion, optimizer, heatmap_lambda)
        val_loss, val_acc, val_probs, val_labels = _run_epoch(model, val_loader, criterion)

        train_preds = (np.array(train_probs) > 0.5).astype(int)
        val_preds = (np.array(val_probs) > 0.5).astype(int)
        train_metrics = calculate_metrics(np.array(train_labels), train_preds, np.array(train_probs))
        val_metrics = calculate_metrics(np.array(val_labels), val_preds, np.array(val_probs))

        print(f'\nTrain Metrics - Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')
        print(f"Train - Accuracy: {train_metrics['accuracy']:.4f} | Precision: {train_metrics['precision']:.4f} | Recall: {train_metrics['recall']:.4f} | F1: {train_metrics['f1']:.4f} | AUC-ROC: {train_metrics['roc_auc']:.4f}")
        print(f'\nVal Metrics - Loss: {val_loss:.4f} | Acc: {val_acc:.4f}')
        print(f"Val - Accuracy: {val_metrics['accuracy']:.4f} | Precision: {val_metrics['precision']:.4f} | Recall: {val_metrics['recall']:.4f} | F1: {val_metrics['f1']:.4f} | AUC-ROC: {val_metrics['roc_auc']:.4f}")

        # Checkpoint selection runs BEFORE the early-stopping check so the
        # epoch that triggers early stopping is still considered.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_acc_model_path)
            print(f"New best ACC model saved with accuracy: {best_acc:.4f}")
            best = {
                'epoch': epoch,
                'train_labels': train_labels,
                'train_preds': train_preds,
                'train_probs': train_probs,
                'val_labels': val_labels,
                'val_preds': val_preds,
                'val_probs': val_probs,
            }

        if val_metrics['roc_auc'] > best_auc:
            best_auc = val_metrics['roc_auc']
            torch.save(model.state_dict(), best_auc_model_path)
            best_auc_saved = True
            print(f"New best AUC model saved with AUC: {best_auc:.4f}")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Render heatmaps after training (from the best-AUC weights when available)
    # so they are produced even when early stopping ends the fold early.
    if best_auc_saved:
        model.load_state_dict(torch.load(best_auc_model_path, map_location=device))
    _save_fold_heatmaps(model, val_loader, os.path.join(save_dir, 'heatmaps_soft'),
                        fold, config['mask_strength'])

    if best is None:
        return None
    best['summary'] = {
        'fold': fold + 1,
        'best_acc': best_acc,
        'best_auc': best_auc,
        'train_metrics': calculate_metrics(best['train_labels'], best['train_preds'], best['train_probs']),
        'val_metrics': calculate_metrics(best['val_labels'], best['val_preds'], best['val_probs']),
        # Zero-based epoch whose predictions (best validation accuracy) are reported.
        'best_epoch': best['epoch'],
    }
    return best


def _report_and_save(config, pooled, fold_metrics):
    """Print pooled cross-validation metrics and write all result files.

    Writes fold_metrics_soft_mask.csv, combined_predictions_soft_mask.json,
    final_results_soft_mask.json and combined_roc_curve_soft_mask.png into
    config['model_save_dir'].
    """
    save_dir = config['model_save_dir']
    splits = (('train', 'Training', 'Train'), ('val', 'Validation', 'Val'))
    combined = {
        split: calculate_metrics(np.array(pooled[split]['labels']),
                                 np.array(pooled[split]['preds']),
                                 np.array(pooled[split]['probs']))
        for split, _, _ in splits
    }

    print("\n\n================ Final Combined Results (Soft Mask Constraint) ================")
    print(f"Mask Constraint Strength: {config['mask_strength']}")
    for split, title, _ in splits:
        m = combined[split]
        print(f"\nCombined {title} Metrics:")
        print(f"Accuracy:    {m['accuracy']:.4f}")
        print(f"Precision:   {m['precision']:.4f}")
        print(f"Recall:      {m['recall']:.4f}")
        print(f"F1-score:    {m['f1']:.4f}")
        print(f"ROC-AUC:     {m['roc_auc']:.4f}")
        print(f"PR-AUC:      {m['pr_auc']:.4f}")
        print(f"Kappa:       {m['kappa']:.4f}")

    fold_metrics_path = os.path.join(save_dir, 'fold_metrics_soft_mask.csv')
    pd.DataFrame(fold_metrics).to_csv(fold_metrics_path, index=False)

    combined_results = {
        split: {
            'labels': [int(x) for x in pooled[split]['labels']],
            'preds': [int(x) for x in pooled[split]['preds']],
            'probs': [float(x) for x in pooled[split]['probs']],
        }
        for split, _, _ in splits
    }
    combined_results['config'] = config
    with open(os.path.join(save_dir, 'combined_predictions_soft_mask.json'), 'w') as f:
        json.dump(combined_results, f, indent=4)

    results = {
        'config': config,
        'combined_train_metrics': _to_serializable(combined['train']),
        'combined_val_metrics': _to_serializable(combined['val']),
        'fold_metrics': _to_serializable(fold_metrics),
    }
    with open(os.path.join(save_dir, 'final_results_soft_mask.json'), 'w') as f:
        json.dump(results, f, indent=4)

    # Pooled ROC curves for train and validation.
    plt.figure(figsize=(8, 8))
    line_styles = {'train': {'color': 'blue', 'linestyle': '--'}, 'val': {'color': 'red'}}
    for split, _, short in splits:
        fpr, tpr, _ = roc_curve(pooled[split]['labels'], pooled[split]['probs'])
        plt.plot(fpr, tpr, label=f"{short} ROC (AUC = {combined[split]['roc_auc']:.3f})",
                 **line_styles[split])
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Soft Mask Constraint ROC (Strength: {config["mask_strength"]})')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'combined_roc_curve_soft_mask.png'), dpi=300)
    plt.close()


def main():
    """Run stratified k-fold training/validation and save pooled results.

    Each fold contributes the predictions of its best-validation-accuracy
    epoch. These are pooled across folds for the final metrics, the CSV/JSON
    files and the ROC plot, all written to config['model_save_dir'].
    """
    annotations = load_and_preprocess_annotations(annotation_file)

    config = {
        'batch_size': 8,
        'num_epochs': 40,
        'learning_rate': 0.0001,
        'weight_decay': 1e-4,
        'model_name': 'resnet18',
        'k_folds': 5,
        'patience': 7,
        'mask_constraint': 'soft',   # final masking mode in generate_attention_map
        'mask_strength': 0.6,        # background attenuation used by visualize_heatmap (0-1)
        'heatmap_lambda': 0.1,       # weight of the auxiliary heatmap-head loss
        'model_save_dir': model_save_dir
    }

    train_transforms, val_transforms = get_transforms()

    # Masks must go through the same geometry as their images for heatmaps to
    # line up. Validation images use Resize(256) + CenterCrop(224), so masks do
    # too (default bilinear resize; the dataset re-binarises at 0.5).
    val_mask_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    # Training images use random crops/flips that independent torchvision
    # transforms cannot replay on the mask, so training masks are only
    # approximately aligned. The training loss does not use them.
    train_mask_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    skf = StratifiedKFold(n_splits=config['k_folds'], shuffle=True, random_state=42)

    pooled = {split: {'labels': [], 'preds': [], 'probs': []} for split in ('train', 'val')}
    fold_metrics = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(annotations, annotations['LABEL'])):
        print(f"\n{'='*20} Fold {fold + 1}/{config['k_folds']} {'='*20}")
        print(f"使用宽松掩码约束模式，强度: {config['mask_strength']}")

        result = _train_fold(
            fold,
            annotations.iloc[train_idx].reset_index(drop=True),
            annotations.iloc[val_idx].reset_index(drop=True),
            config,
            (train_transforms, train_mask_transforms),
            (val_transforms, val_mask_transforms),
        )
        if result is None:
            continue

        for split in ('train', 'val'):
            pooled[split]['labels'].extend(result[f'{split}_labels'])
            pooled[split]['preds'].extend(result[f'{split}_preds'])
            pooled[split]['probs'].extend(result[f'{split}_probs'])
        fold_metrics.append(result['summary'])

    if pooled['val']['labels']:
        _report_and_save(config, pooled, fold_metrics)

    print(f"\n训练完成！所有结果和模型已保存到: {config['model_save_dir']}")
    print("宽松约束的热图可视化已保存到 'heatmaps_soft' 目录")

if __name__ == "__main__":
    # Script / notebook-cell entry point: run the k-fold training pipeline.
    main()

# In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import os
from pathlib import Path
from torchvision import transforms
import torch.nn.functional as F

def ensure_align_heatmap_with_original():
    """确保热图与原图完美对齐的完整版本"""
    
    # 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 加载模型
    model_path = "F.pth"
    if not os.path.exists(model_path):
        model_files = [f for f in os.listdir('.') if f.startswith('best_auc_model') and f.endswith('_soft.pth')]
        if not model_files:
            model_files = [f for f in os.listdir('.') if f.startswith('best_auc_model') and f.endswith('.pth')]
        if model_files:
            model_path = model_files[0]
        else:
            raise FileNotFoundError("未找到训练好的模型文件")
    
    model = RenalFibrosisModel(
        num_classes=2, 
        model_name='resnet18', 
        pretrained=False, 
        mask_constraint='soft'
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print(f"模型加载成功: {model_path}")
    print(f"使用宽松掩码约束模式")
    
    # 指定路径
    image_folder = os.path.join(data_dir, "Fibrosis")
    mask_folder = os.path.join(data_dir, "Mask")
    output_folder = os.path.join(data_dir, "Fibrosisheatmap")
    os.makedirs(output_folder, exist_ok=True)
    
    supported_formats = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
    mask_strength = 0.6
    
    def correct_interpolate_heatmap(heatmap, target_size):
        """正确地上采样热图到目标尺寸"""
        # heatmap 的维度应该是 [H, W]
        if heatmap.dim() == 2:
            # 添加 batch 和 channel 维度: [1, 1, H, W]
            heatmap_4d = heatmap.unsqueeze(0).unsqueeze(0)
        else:
            # 如果已经是3维或4维，确保是 [1, 1, H, W] 格式
            heatmap_4d = heatmap.unsqueeze(0) if heatmap.dim() == 3 else heatmap
            if heatmap_4d.size(1) != 1:
                heatmap_4d = heatmap_4d.unsqueeze(1)
        
        # 上采样到目标尺寸 (height, width)
        heatmap_resized = F.interpolate(
            heatmap_4d, 
            size=target_size,
            mode='bilinear', 
            align_corners=False
        )
        
        # 移除添加的维度，返回 [H, W]
        return heatmap_resized.squeeze()
    
    def create_perfect_alignment(original_img, original_mask, heatmap, mask_strength=0.6):
        """创建完美对齐的热图叠加"""
        
        # 转换为numpy数组
        original_img_np = np.array(original_img)
        original_mask_np = np.array(original_mask)
        
        # 确保图像是3通道
        if original_img_np.ndim == 2:
            original_img_np = np.stack([original_img_np] * 3, axis=-1)
        elif original_img_np.shape[2] == 1:
            original_img_np = np.concatenate([original_img_np] * 3, axis=-1)
        
        # 确保掩码是2维
        if original_mask_np.ndim == 3:
            original_mask_np = original_mask_np[:, :, 0]
        original_mask_np = (original_mask_np > 128).astype(np.float32)
        
        # 获取原始尺寸
        original_height, original_width = original_img_np.shape[:2]
        
        # 将热图转换为numpy（如果还是tensor）
        if torch.is_tensor(heatmap):
            heatmap_np = heatmap.cpu().numpy()
        else:
            heatmap_np = heatmap
        
        # 确保热图是2维
        if heatmap_np.ndim > 2:
            heatmap_np = heatmap_np.squeeze()
        
        # 使用OpenCV将热图调整到原始尺寸（最可靠的方法）
        heatmap_resized = cv2.resize(heatmap_np, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
        
        # 同样调整掩码尺寸（确保完全一致）
        mask_resized = cv2.resize(original_mask_np, (original_width, original_height), interpolation=cv2.INTER_NEAREST)
        
        # 应用宽松约束
        soft_mask = mask_resized * mask_strength + (1 - mask_strength)
        constrained_heatmap = heatmap_resized * soft_mask
        
        # 归一化热图
        if constrained_heatmap.max() > constrained_heatmap.min():
            heatmap_normalized = (constrained_heatmap - constrained_heatmap.min()) / (constrained_heatmap.max() - constrained_heatmap.min() + 1e-8)
        else:
            heatmap_normalized = constrained_heatmap
        
        # 创建彩色热图
        heatmap_colored = cv2.applyColorMap(np.uint8(255 * heatmap_normalized), cv2.COLORMAP_JET)
        heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
        
        # 完美对齐叠加
        overlay = cv2.addWeighted(original_img_np, 0.6, heatmap_colored, 0.4, 0)
        
        return {
            'original_img': original_img_np,
            'mask': mask_resized,
            'heatmap': heatmap_normalized,
            'overlay': overlay,
            'constrained_heatmap': constrained_heatmap,
            'original_size': (original_width, original_height)
        }
    
    processed_count = 0
    
    for image_file in os.listdir(image_folder):
        file_path = os.path.join(image_folder, image_file)
        file_ext = Path(image_file).suffix.lower()
        
        if file_ext in supported_formats and os.path.isfile(file_path):
            try:
                print(f"\n处理图像: {image_file}")
                file_name = Path(image_file).stem
                
                # 查找掩码文件
                mask_file = None
                mask_patterns = [
                    f"{file_name}_mask.png", f"{file_name}_mask.jpg",
                    f"{file_name}.png", f"{file_name}.jpg",
                    f"{file_name}_mask.tif", f"{file_name}.tif"
                ]
                
                for pattern in mask_patterns:
                    mask_path = os.path.join(mask_folder, pattern)
                    if os.path.exists(mask_path):
                        mask_file = pattern
                        break
                
                if mask_file is None:
                    print(f"警告: 未找到 {file_name} 对应的掩码文件，跳过处理")
                    continue
                
                # 读取原始图像和掩码
                original_img_pil = Image.open(file_path).convert('RGB')
                original_mask_pil = Image.open(os.path.join(mask_folder, mask_file)).convert('L')
                
                original_size = original_img_pil.size
                print(f"原始图像尺寸: {original_size}")
                
                # 数据预处理
                data_transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                ])
                
                mask_transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor()
                ])
                
                # 准备模型输入
                model_input_img = data_transform(original_img_pil).unsqueeze(0).to(device)
                model_input_mask = mask_transform(original_mask_pil).unsqueeze(0).to(device)
                model_input_mask = (model_input_mask > 0.5).float()
                
                # 生成热图
                with torch.no_grad():
                    outputs, heatmap, features = model(model_input_img, model_input_mask)
                    attention_map = model.generate_attention_map(model_input_img, model_input_mask)
                    probs = torch.softmax(outputs, dim=1)
                    pred_class = torch.argmax(probs, dim=1).item()
                    confidence = probs[0, pred_class].item()
                
                print(f"生成的热图尺寸: {attention_map.shape}")
                
                # 创建完美对齐的结果
                alignment_result = create_perfect_alignment(
                    original_img_pil, original_mask_pil, attention_map, mask_strength
                )
                
                # 创建详细的可视化
                fig, axes = plt.subplots(2, 4, figsize=(20, 10))
                
                # 第一行：原始数据和模型输入
                axes[0, 0].imshow(alignment_result['original_img'])
                axes[0, 0].set_title(f"原始图像\n{original_size[0]}x{original_size[1]}")
                axes[0, 0].axis('off')
                
                axes[0, 1].imshow(alignment_result['mask'], cmap='gray')
                axes[0, 1].set_title("对齐后的掩码")
                axes[0, 1].axis('off')
                
                # 显示模型输入尺寸的图像
                model_img_np = model_input_img[0].cpu().numpy()
                model_img_np = np.transpose(model_img_np, (1, 2, 0))
                model_img_np = model_img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                model_img_np = np.clip(model_img_np, 0, 1)
                
                axes[0, 2].imshow(model_img_np)
                axes[0, 2].set_title("模型输入图像\n224x224")
                axes[0, 2].axis('off')
                
                # 显示原始热图（模型输出尺寸）
                original_heatmap = attention_map.squeeze().cpu().numpy()
                axes[0, 3].imshow(original_heatmap, cmap='jet')
                axes[0, 3].set_title("原始热图输出\n224x224")
                axes[0, 3].axis('off')
                
                # 第二行：对齐后的结果
                axes[1, 0].imshow(alignment_result['heatmap'], cmap='jet')
                heatmap_size = alignment_result['heatmap'].shape
                axes[1, 0].set_title(f"对齐后热图\n{heatmap_size[1]}x{heatmap_size[0]}")
                axes[1, 0].axis('off')
                
                axes[1, 1].imshow(alignment_result['overlay'])
                class_name = "纤维化" if pred_class == 1 else "正常"
                axes[1, 1].set_title(f"完美对齐叠加\n预测: {class_name}\n置信度: {confidence:.3f}")
                axes[1, 1].axis('off')
                
                # 热图强度分布
                masked_heatmap = alignment_result['constrained_heatmap'][alignment_result['mask'] > 0.5]
                background_heatmap = alignment_result['constrained_heatmap'][alignment_result['mask'] <= 0.5]
                
                if len(masked_heatmap) > 0:
                    axes[1, 2].hist(masked_heatmap, bins=50, alpha=0.6, color='red', label='掩码区域')
                    if len(background_heatmap) > 0:
                        axes[1, 2].hist(background_heatmap, bins=50, alpha=0.5, color='blue', label='背景区域')
                    axes[1, 2].set_title("热图强度分布")
                    axes[1, 2].set_xlabel("强度值")
                    axes[1, 2].set_ylabel("像素数量")
                    axes[1, 2].legend()
                
                # 详细统计信息
                mask_coverage = np.sum(alignment_result['mask']) / alignment_result['mask'].size * 100
                mask_mean_intensity = np.mean(masked_heatmap) if len(masked_heatmap) > 0 else 0
                bg_mean_intensity = np.mean(background_heatmap) if len(background_heatmap) > 0 else 0
                
                info_text = f"""对齐验证信息:
原始尺寸: {original_size[0]}x{original_size[1]}
热图尺寸: {heatmap_size[1]}x{heatmap_size[0]}
掩码覆盖率: {mask_coverage:.1f}%
掩码区域强度: {mask_mean_intensity:.3f}
背景区域强度: {bg_mean_intensity:.3f}
约束强度: {mask_strength}
预测结果: {class_name}
置信度: {confidence:.3f}"""
                
                axes[1, 3].text(0.05, 0.95, info_text, fontsize=9, verticalalignment='top',
                               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
                axes[1, 3].set_title("详细统计信息")
                axes[1, 3].axis('off')
                
                plt.suptitle(f'肾脏纤维化热图分析 - {file_name} (完美对齐)', fontsize=16, y=0.95)
                plt.tight_layout()
                
                # 保存综合结果
                result_path = os.path.join(output_folder, f"{file_name}_detailed_analysis.jpg")
                plt.savefig(result_path, dpi=300, bbox_inches='tight')
                plt.close()
                
                # 单独保存高分辨率叠加图像
                overlay_highres = cv2.addWeighted(
                    alignment_result['original_img'], 0.6, 
                    cv2.cvtColor(
                        cv2.applyColorMap(
                            np.uint8(255 * alignment_result['heatmap']), 
                            cv2.COLORMAP_JET
                        ), 
                        cv2.COLOR_BGR2RGB
                    ), 0.3, 0
                )
                
                overlay_path = os.path.join(output_folder, f"{file_name}_highres_overlay.jpg")
                cv2.imwrite(overlay_path, cv2.cvtColor(overlay_highres, cv2.COLOR_RGB2BGR))
                
                # 保存对齐检查图像（在四个角标记参考点）
                check_img = alignment_result['original_img'].copy()
                h, w = check_img.shape[:2]
                # 标记四个角
                cv2.circle(check_img, (10, 10), 8, (255, 255, 0), -1)  # 左上-黄色
                cv2.circle(check_img, (w-10, 10), 8, (0, 255, 255), -1)  # 右上-青色
                cv2.circle(check_img, (10, h-10), 8, (255, 0, 255), -1)  # 左下-粉色
                cv2.circle(check_img, (w-10, h-10), 8, (0, 255, 0), -1)  # 右下-绿色
                
                check_overlay = cv2.addWeighted(check_img, 0.8, 
                                              cv2.cvtColor(
                                                  cv2.applyColorMap(
                                                      np.uint8(255 * alignment_result['heatmap']), 
                                                      cv2.COLORMAP_JET
                                                  ), 
                                                  cv2.COLOR_BGR2RGB
                                              ), 0.2, 0)
                
                cv2.imwrite(os.path.join(output_folder, f"{file_name}_alignment_check.jpg"), 
                           cv2.cvtColor(check_overlay, cv2.COLOR_RGB2BGR))
                
                # 新增：保存原始图像-热图-叠加效果三合一图像
                fig_triple, axes_triple = plt.subplots(1, 3, figsize=(18, 6))
                
                # 原始图像
                axes_triple[0].imshow(alignment_result['original_img'])
                axes_triple[0].set_title(f"Original\n{original_size[0]}x{original_size[1]}")
                axes_triple[0].axis('off')
                
                # 热图
                im_heatmap = axes_triple[1].imshow(alignment_result['heatmap'], cmap='jet')
                axes_triple[1].set_title(f"Grad-CAM\n{heatmap_size[1]}x{heatmap_size[0]}")
                axes_triple[1].axis('off')
                # 添加颜色条
                plt.colorbar(im_heatmap, ax=axes_triple[1], fraction=0.046, pad=0.04)
                
                # 叠加效果
                axes_triple[2].imshow(alignment_result['overlay'])
                axes_triple[2].set_title(f"Overlay\n预测: {class_name} (置信度: {confidence:.3f})")
                axes_triple[2].axis('off')
                
                plt.suptitle(f'肾脏纤维化分析 - {file_name}', fontsize=16)
                plt.tight_layout()
                
                # 保存三合一图像
                triple_path = os.path.join(output_folder, f"{file_name}_triple_comparison.jpg")
                plt.savefig(triple_path, dpi=300, bbox_inches='tight')
                plt.close()
                
                print(f"✓ 完美对齐完成: {file_name}")
                print(f"  原始尺寸: {original_size[0]}x{original_size[1]}")
                print(f"  热图尺寸: {alignment_result['heatmap'].shape[1]}x{alignment_result['heatmap'].shape[0]}")
                print(f"  掩码覆盖率: {mask_coverage:.1f}%")
                print(f"  预测: {class_name} (置信度: {confidence:.3f})")
                
                processed_count += 1
                
            except Exception as e:
                print(f"✗ 处理 {image_file} 时出错: {str(e)}")
                import traceback
                traceback.print_exc()
                continue
    
    print(f"\n{'='*60}")
    print(f"完美对齐处理完成! 成功处理 {processed_count} 个图像")
    print(f"详细分析结果保存在: {output_folder}")
    print(f"每个图像包含:")
    print(f"  - 详细分析图 (_detailed_analysis.jpg)")
    print(f"  - 高分辨率叠加图 (_highres_overlay.jpg)")
    print(f"  - 对齐检查图 (_alignment_check.jpg)")
    print(f"  - 三合一对比图 (_triple_comparison.jpg)")

if __name__ == "__main__":
    # RenalFibrosisModel is defined in an earlier notebook cell; without it the
    # heatmap pass cannot build the model, so print a hint instead of crashing.
    if 'RenalFibrosisModel' in globals():
        ensure_align_heatmap_with_original()
    else:
        print("错误: 需要先定义 RenalFibrosisModel 类")
        print("请确保已经运行了包含模型定义的代码")