# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, average_precision_score, cohen_kappa_score
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json

# Seed the NumPy and PyTorch RNGs and pin cuDNN to deterministic behaviour
# for reproducibility.
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Device selection: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ==================== Settings to adjust ====================
# 1. Data locations - point these at the new dataset.
new_data_dir = os.getcwd()  # edit: root directory of the new raw-image data
new_image_dir = os.path.join(new_data_dir, "Ftest")  # edit: image directory
new_annotation_file = os.path.join(new_data_dir, "Ftest.xlsx")  # edit: annotation file

# 2. Model location - point this at the trained model.
trained_model_path = os.path.join(new_data_dir, "F.pth")  # edit: trained model file

# 3. Output location for the validation results (created if missing).
results_save_dir = os.path.join(new_data_dir, "Ftestresults")  # edit: results directory
os.makedirs(results_save_dir, exist_ok=True)
# ============================================================

# 加载模型定义（需要与训练时相同的模型类）
class SoftMaskConstrainedHeatmapLayer(nn.Module):
    """Heatmap head with a loose (soft) mask constraint.

    A 1x1 convolution turns the incoming feature map into one heatmap channel
    per class. When a mask is supplied the heatmap is multiplied by
    ``mask * 0.8 + 0.2``, so positions where the mask is 0 keep 20% of their
    response instead of being zeroed out.
    """

    def __init__(self, in_channels, num_classes, heatmap_size=(56, 56)):
        super().__init__()
        # Stored for reference only; nothing in this class reads it.
        self.heatmap_size = heatmap_size
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)
        nn.init.xavier_uniform_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, features, binary_mask=None):
        """Return the class heatmap, softly weighted by ``binary_mask`` if given.

        Args:
            features: Feature map fed to the 1x1 convolution.
            binary_mask: Optional mask; it is bilinearly resized to the
                heatmap's spatial size when the two differ.

        Returns:
            The raw heatmap when no mask is given, otherwise the heatmap
            multiplied by the soft mask.
        """
        heatmap = self.conv(features)
        if binary_mask is None:
            return heatmap

        target_hw = heatmap.shape[-2:]
        if binary_mask.shape[-2:] != target_hw:
            binary_mask = F.interpolate(
                binary_mask, size=target_hw, mode='bilinear', align_corners=False
            )

        # Map the mask from [0, 1] onto [0.2, 1.0] before weighting.
        return heatmap * (binary_mask * 0.8 + 0.2)

class RenalFibrosisModel(nn.Module):
    """ResNet-18 based classifier that also emits a class heatmap.

    The convolutional part of the backbone produces a feature map. That map
    feeds two heads: global average pooling followed by a linear classifier,
    and a ``SoftMaskConstrainedHeatmapLayer`` that yields per-class heatmaps.
    """

    def __init__(self, num_classes=2, model_name='resnet18', pretrained=True, mask_constraint='soft'):
        super().__init__()
        # Recorded only; the forward pass does not branch on it.
        self.mask_constraint = mask_constraint

        if model_name != 'resnet18':
            raise ValueError(f"Unsupported model: {model_name}")

        self.backbone = models.resnet18(pretrained=pretrained)
        # Keep every backbone child except the last two as the feature extractor.
        backbone_layers = list(self.backbone.children())
        self.features = nn.Sequential(*backbone_layers[:-2])
        feature_dim = self.backbone.fc.in_features

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(feature_dim, num_classes)
        self.heatmap_layer = SoftMaskConstrainedHeatmapLayer(feature_dim, num_classes)

        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(self, x, mask=None):
        """Classify ``x`` and compute its heatmap.

        Args:
            x: Input image batch.
            mask: Optional mask; bilinearly resized to the feature map's
                spatial size before it is handed to the heatmap layer.

        Returns:
            Tuple ``(outputs, heatmap, features)``: classifier logits, the
            heatmap from the heatmap layer, and the backbone feature map.
        """
        features = self.features(x)

        mask_resized = (
            None
            if mask is None
            else F.interpolate(
                mask, size=features.shape[-2:], mode='bilinear', align_corners=False
            )
        )

        heatmap = self.heatmap_layer(features, mask_resized)

        pooled = self.global_avg_pool(features)
        logits = self.classifier(pooled.view(pooled.size(0), -1))

        return logits, heatmap, features

def load_new_annotations(annotation_file, image_dir):
    """Load the annotation spreadsheet (no masks) and attach image paths.

    The label column is found by trying a fixed list of candidate names and
    is renamed to ``LABEL``. Rows whose label cell is empty are dropped
    before the label classes are inspected. If more than two distinct labels
    remain they are binarised as ``label >= 3``. Each ``ID`` is then matched
    to a file ``<image_dir>/<ID><ext>`` and rows without a file are dropped.

    Args:
        annotation_file: Path of the Excel file; it must contain an ``ID``
            column and one of the recognised label columns.
        image_dir: Directory that is searched for the image files.

    Returns:
        pandas.DataFrame holding only rows that have a label and an existing
        image, with the columns ``ID``, ``LABEL`` and ``image_path`` (plus
        whatever other columns the spreadsheet had).

    Raises:
        ValueError: If no label column is found, fewer than two label classes
            are present, or no row has an image.
        KeyError: If the spreadsheet has no ``ID`` column.
        Exception: Any other error is printed and re-raised unchanged.
    """
    try:
        annotations = pd.read_excel(annotation_file)

        # Auto-detect the label column (the first candidate present wins).
        possible_label_names = ['LABLE', 'LABEL', 'Label', 'label', '分级', '评分', 'fibrosis', 'stage']
        label_col = next(
            (name for name in possible_label_names if name in annotations.columns),
            None,
        )
        if label_col is None:
            raise ValueError("无法找到标签列")

        annotations.rename(columns={label_col: 'LABEL'}, inplace=True)
        annotations['ID'] = annotations['ID'].astype(str).str.strip()

        # Drop unlabeled rows first. Otherwise NaN counts as an extra "class":
        # a 0/1 sheet with one blank cell would take the multi-class branch
        # below and `>= 3` would turn every label into 0.
        unlabeled = int(annotations['LABEL'].isna().sum())
        if unlabeled > 0:
            print(f"警告: 有{unlabeled}行缺少标签, 已忽略")
            # .copy() detaches the filtered frame so the column assignment
            # below does not trigger a chained-assignment warning.
            annotations = annotations.dropna(subset=['LABEL']).copy()

        # Convert the labels to a 0/1 binary problem.
        unique_labels = sorted(annotations['LABEL'].unique())
        if len(unique_labels) != 2:
            if len(unique_labels) > 2:
                print(f"警告: 检测到多类标签{unique_labels}, 将转换为二分类问题")
                annotations['LABEL'] = (annotations['LABEL'] >= 3).astype(int)
            else:
                raise ValueError("标签类别不足2类")

        def find_image_path(x):
            """Return the first existing ``<image_dir>/<x><ext>`` or None."""
            base_path = os.path.join(image_dir, x)
            for ext in ['.jpg', '.tif', '.tiff', '.png', '.jpeg']:
                # Try the lower-case extension first, then its upper-case form.
                for suffix in (ext, ext.upper()):
                    img_path = base_path + suffix
                    if os.path.exists(img_path):
                        return img_path
            return None

        annotations['image_path'] = annotations['ID'].apply(find_image_path)

        # Report IDs that have no image on disk.
        missing_images = annotations['image_path'].isnull().sum()
        if missing_images > 0:
            print(f"警告: 有{missing_images}个ID找不到对应的图像文件")

        # Keep only samples that actually have an image.
        annotations = annotations.dropna(subset=['image_path'])

        if len(annotations) == 0:
            raise ValueError("没有找到有效的样本")

        print(f"找到{len(annotations)}个有效样本")

        return annotations

    except Exception as e:
        print(f"加载标注文件出错: {str(e)}")
        raise

class NewDataDataset(Dataset):
    """Image dataset for the new validation data (no masks).

    Every row of ``dataframe`` is screened once at construction time: the
    path must be a string naming an existing file that PIL can verify. Rows
    that pass are stored in ``valid_samples`` as ``(image_path, label)``
    pairs; the others are reported with a warning and skipped.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe.copy()
        self.transform = transform
        self.valid_samples = []

        # Check up front that every image is usable.
        for _, record in self.dataframe.iterrows():
            path = record['image_path']

            if not (isinstance(path, str) and os.path.exists(path)):
                print(f"警告: 文件不存在 {path}")
                continue

            try:
                with Image.open(path) as probe:
                    probe.verify()
                self.valid_samples.append((path, record['LABEL']))
            except Exception as err:
                print(f"警告: 文件损坏 {path}: {str(err)}")

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        """Return ``(image, label, image_path)`` for sample ``idx``.

        If loading or transforming fails, the error is printed and a zero
        tensor of shape (3, 224, 224) with label -1 is returned instead.
        """
        path, target = self.valid_samples[idx]

        try:
            with Image.open(path) as raw:
                picture = raw.convert('RGB')

            picture = self.transform(picture) if self.transform else picture

            return picture, torch.tensor(target, dtype=torch.long), path
        except Exception as err:
            print(f"加载文件 {path} 出错: {str(err)}")
            placeholder = torch.zeros(3, 224, 224)
            return placeholder, torch.tensor(-1, dtype=torch.long), path

def get_validation_transforms():
    """Build the preprocessing pipeline applied to validation images.

    Steps, in order: resize to 256, centre-crop to 224, convert to a tensor,
    then normalise each channel with the fixed mean/std values below
    (presumably the usual ImageNet statistics - confirm against training).
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def calculate_metrics(y_true, y_pred, y_probs):
    """Compute the evaluation metrics for a binary classifier.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_probs: Scores passed to the ranking metrics (ROC-AUC, PR-AUC).

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``roc_auc``, ``pr_auc`` and ``kappa``. ``roc_auc`` is NaN when
        ``roc_auc_score`` cannot compute it.
    """
    try:
        roc_auc = roc_auc_score(y_true, y_probs)
    except ValueError:
        # roc_auc_score rejects targets that contain a single class. Report
        # NaN for this one metric instead of aborting and losing the rest.
        roc_auc = float('nan')

    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'roc_auc': roc_auc,
        'pr_auc': average_precision_score(y_true, y_probs),
        'kappa': cohen_kappa_score(y_true, y_pred)
    }
    return metrics

def _collect_predictions(model, loader):
    """Run ``model`` over ``loader`` and gather the per-sample results.

    Samples whose label is -1 (the marker the dataset returns for an image it
    failed to load) are skipped.

    Returns:
        Tuple ``(labels, preds, probs, image_paths)`` of equally long lists;
        ``probs`` holds the softmax probability of class 1.
    """
    all_labels, all_preds, all_probs, all_image_paths = [], [], [], []

    with torch.no_grad():
        for inputs, labels, img_paths in tqdm(loader, desc='Validation'):
            # Filter out invalid samples.
            valid_mask = labels != -1
            if not valid_mask.any():
                continue

            inputs = inputs[valid_mask].to(device)
            labels = labels[valid_mask]
            valid_img_paths = [
                path for path, keep in zip(img_paths, valid_mask.tolist()) if keep
            ]

            # Forward pass without a mask.
            outputs, _, _ = model(inputs, mask=None)

            _, preds = torch.max(outputs, 1)
            probs = torch.softmax(outputs, dim=1)[:, 1]

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_image_paths.extend(valid_img_paths)

    return all_labels, all_preds, all_probs, all_image_paths


def _save_confusion_matrix(y_true, y_pred, save_path):
    """Plot the 2x2 confusion matrix as a heatmap and save it to ``save_path``."""
    # labels=[0, 1] keeps the matrix 2x2 (matching the tick labels) even when
    # one class never occurs in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Pred 0', 'Pred 1'],
                yticklabels=['True 0', 'True 1'])
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def _save_roc_curve(y_true, y_probs, roc_auc, save_path):
    """Plot the ROC curve (with ``roc_auc`` in the legend) and save it."""
    fpr, tpr, _ = roc_curve(y_true, y_probs)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def validate_model():
    """Validate the trained model on the new dataset (main entry point).

    Loads the annotations and images configured at module level, restores the
    trained weights, predicts every valid sample, prints the metrics and the
    classification report, and writes to ``results_save_dir``:
    ``confusion_matrix.png``, ``roc_curve.png``, ``validation_results.json``
    and ``validation_results.csv``.

    Raises:
        FileNotFoundError: If ``trained_model_path`` does not exist.
    """
    print("开始验证模型...")

    # Load the new annotation data.
    print("加载新的标注数据...")
    new_annotations = load_new_annotations(new_annotation_file, new_image_dir)

    val_transforms = get_validation_transforms()

    # Build the dataset and its loader.
    print("创建数据集...")
    new_dataset = NewDataDataset(new_annotations, transform=val_transforms)

    new_loader = DataLoader(
        new_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on
        # a CPU-only run just produces a warning.
        pin_memory=(device.type == 'cuda')
    )

    # Instantiate the model with the same architecture used for training.
    print("初始化模型...")
    model = RenalFibrosisModel(
        num_classes=2,
        model_name='resnet18',
        pretrained=False,
        mask_constraint='soft'
    ).to(device)

    # Restore the trained weights.
    print(f"加载训练好的模型: {trained_model_path}")
    if not os.path.exists(trained_model_path):
        raise FileNotFoundError(f"模型文件不存在: {trained_model_path}")

    model.load_state_dict(torch.load(trained_model_path, map_location=device))
    model.eval()
    print("模型加载完成!")

    # Predict.
    print("开始预测...")
    all_labels, all_preds, all_probs, all_image_paths = _collect_predictions(model, new_loader)

    # Compute the evaluation metrics.
    print("计算评估指标...")
    if not all_labels:
        print("没有有效的预测结果!")
        return

    metrics = calculate_metrics(
        np.array(all_labels),
        np.array(all_preds),
        np.array(all_probs)
    )

    # Print the summary.
    print("\n" + "="*50)
    print("验证结果汇总")
    print("="*50)
    print(f"总样本数: {len(all_labels)}")
    print(f"准确率: {metrics['accuracy']:.4f}")
    print(f"精确率: {metrics['precision']:.4f}")
    print(f"召回率: {metrics['recall']:.4f}")
    print(f"F1分数: {metrics['f1']:.4f}")
    print(f"ROC-AUC: {metrics['roc_auc']:.4f}")
    print(f"PR-AUC: {metrics['pr_auc']:.4f}")
    print(f"Kappa: {metrics['kappa']:.4f}")

    # Classification report. labels=[0, 1] keeps it aligned with the two
    # target names even if only one class shows up in the data.
    print("\n详细分类报告:")
    print(classification_report(all_labels, all_preds, labels=[0, 1],
                                target_names=['Class 0', 'Class 1']))

    # Confusion matrix.
    cm_path = os.path.join(results_save_dir, 'confusion_matrix.png')
    _save_confusion_matrix(all_labels, all_preds, cm_path)
    print(f"混淆矩阵已保存: {cm_path}")

    # ROC curve.
    roc_path = os.path.join(results_save_dir, 'roc_curve.png')
    _save_roc_curve(all_labels, all_probs, metrics['roc_auc'], roc_path)
    print(f"ROC曲线已保存: {roc_path}")

    # Detailed per-image results.
    results = {
        'model_path': trained_model_path,
        'total_samples': len(all_labels),
        'metrics': metrics,
        'predictions': [
            {
                'image_path': path,
                'true_label': int(label),
                'predicted_label': int(pred),
                'probability': float(prob)
            }
            for path, label, pred, prob in zip(all_image_paths, all_labels, all_preds, all_probs)
        ]
    }

    # Save as JSON.
    results_path = os.path.join(results_save_dir, 'validation_results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

    # Save as CSV.
    results_df = pd.DataFrame({
        'image_path': all_image_paths,
        'true_label': all_labels,
        'predicted_label': all_preds,
        'probability': all_probs
    })
    csv_path = os.path.join(results_save_dir, 'validation_results.csv')
    results_df.to_csv(csv_path, index=False, encoding='utf-8')

    print(f"\n所有结果已保存到: {results_save_dir}")
    print(f"- 验证结果(JSON): {results_path}")
    print(f"- 验证结果(CSV): {csv_path}")
    print(f"- 混淆矩阵: {cm_path}")
    print(f"- ROC曲线: {roc_path}")

# Run the validation only when this file is executed as a script.
if __name__ == '__main__':
    validate_model()