# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve, average_precision_score, cohen_kappa_score
import seaborn as sns
from sklearn.metrics import precision_score, recall_score

# Fix the PyTorch and NumPy random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Run everything on the CPU device
device = torch.device("cpu")
print(f"Using device: {device}")

# Data path configuration: both the image folder and the annotation workbook
# are resolved relative to the current working directory.
current_dir = os.getcwd()
data_dir = os.path.join(current_dir, "Crescent")
# Images are read directly from the data directory (no sub-folder).
image_dir = data_dir
annotation_file = os.path.join(current_dir, "Crescentlabel.xlsx")

# Load and preprocess the annotation table
def load_and_preprocess_annotations(annotation_file):
    """Load the annotation table and normalise it to ``ID`` / ``LABEL`` columns.

    Steps performed:

    1. Read the table. Files ending in ``.csv`` are read with
       ``pandas.read_csv``; everything else goes through ``pandas.read_excel``
       exactly as before.
    2. Rename the first matching label column (see ``possible_label_names``)
       to ``LABEL``.
    3. Convert ``ID`` to a string left-padded with zeros to at least 2 chars.
    4. If the largest label is greater than 1, shift all labels down by one
       (so 1/2 coding becomes 0/1).
    5. Print the class distribution and the positive:negative ratio.

    Args:
        annotation_file: Path to the annotation file (Excel workbook or CSV).

    Returns:
        pandas.DataFrame: The table with normalised ``ID`` and ``LABEL``.

    Raises:
        ValueError: If none of the known label column names is present.
        KeyError: If the table has no ``ID`` column.
    """
    # Dispatch on the file extension so CSV exports work as well as workbooks.
    extension = os.path.splitext(str(annotation_file))[1].lower()
    if extension == '.csv':
        annotations = pd.read_csv(annotation_file)
    else:
        annotations = pd.read_excel(annotation_file)

    # Auto-detect the label column; 'LABLE' (a common misspelling) is tried first.
    possible_label_names = ['LABLE', 'LABEL', 'Label', 'label', '分级', '评分', 'fibrosis']
    label_col = next(
        (name for name in possible_label_names if name in annotations.columns),
        None,
    )
    if label_col is None:
        raise ValueError("无法找到标签列")

    annotations.rename(columns={label_col: 'LABEL'}, inplace=True)

    # Standardise the ID format (string, zero-padded to two characters).
    annotations['ID'] = annotations['ID'].apply(lambda x: str(x).zfill(2))

    # If the labels are not already 0/1, shift them down by one.
    if annotations['LABEL'].max() > 1:
        annotations['LABEL'] = annotations['LABEL'] - 1

    # Report class balance. Use .get() so that a missing class does not raise
    # KeyError, and guard the division so an absent negative class does not
    # crash what is only a diagnostic print.
    class_counts = annotations['LABEL'].value_counts()
    print(f"\n类别分布:\n{class_counts}")
    negative_count = class_counts.get(0, 0)
    positive_count = class_counts.get(1, 0)
    if negative_count > 0:
        print(f"正负类比例: {positive_count/negative_count:.2f}:1")
    else:
        print("正负类比例: 无法计算（负类样本数为0）")

    return annotations

# Load the annotation table once at module level; main() reads this
# `annotations` DataFrame when building the cross-validation folds.
annotations = load_and_preprocess_annotations(annotation_file)

# Custom dataset class (with defensive error handling)
class CrescentDataset(Dataset):
    """Image/label dataset driven by an annotation table.

    Each row of ``dataframe`` is mapped to ``<image_dir>/<ID>.jpg``. Rows whose
    image file does not exist are skipped (a warning is printed) when the
    dataset is constructed, so ``len()`` only counts images found on disk.

    Args:
        dataframe: Table with ``ID`` and ``LABEL`` columns; a copy is stored.
        image_dir: Directory that holds the ``.jpg`` files.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe.copy()
        self.image_dir = image_dir
        self.transform = transform
        self.valid_samples = []

        # Resolve every image path up front and keep only those that exist.
        for _, record in self.dataframe.iterrows():
            candidate = os.path.join(self.image_dir, f"{record['ID']}.jpg")
            if not os.path.exists(candidate):
                print(f"警告：跳过不存在的图像 {candidate}")
                continue
            self.valid_samples.append((candidate, record['LABEL']))

    def __len__(self):
        """Return the number of samples whose image file was found."""
        return len(self.valid_samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If loading or transforming the image raises, the error is printed and
        a blank ``3x224x224`` tensor with label ``-1`` is returned instead;
        the training/validation loops filter out samples labelled ``-1``.
        """
        img_path, label = self.valid_samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            sample = (image, torch.tensor(label, dtype=torch.long))
        except Exception as exc:
            print(f"加载图像 {img_path} 出错: {str(exc)}")
            # Placeholder sample: blank image with the -1 "invalid" label.
            sample = (torch.zeros(3, 224, 224), torch.tensor(-1, dtype=torch.long))
        return sample

# Data augmentation configuration (chosen to suit medical images)
data_transforms = {
    # Training pipeline: mild random crop/flip/rotation/colour jitter, then
    # conversion to a normalised tensor of spatial size 224x224.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        # NOTE(review): these look like the usual ImageNet channel mean/std,
        # presumably chosen to match the pretrained backbone — confirm.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    # Validation pipeline: deterministic resize + centre crop to 224x224 with
    # the same normalisation as training.
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Lightweight model definition (suited to CPU training)
class CrescentModel(nn.Module):
    """Pretrained ResNet-18 with a small two-layer classification head.

    The first 50 parameter tensors of the backbone are frozen and the original
    ``fc`` layer is replaced by ``Linear -> ReLU -> Dropout(0.3) -> Linear``.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Pretrained resnet18 is used because it is comparatively light.
        backbone = models.resnet18(pretrained=True)

        # Freeze the leading parameter tensors to cut down on computation.
        for position, weight in enumerate(backbone.parameters()):
            if position >= 50:
                break
            weight.requires_grad = False

        # Swap the final fully connected layer for the new head.
        hidden_units = 128
        head = nn.Sequential(
            nn.Linear(backbone.fc.in_features, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_units, num_classes)
        )
        backbone.fc = head
        self.backbone = backbone

    def forward(self, x):
        """Return the class logits produced by the backbone for batch ``x``."""
        logits = self.backbone(x)
        return logits

# Training and validation passes (one shared implementation, no duplication)
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given.

    Samples whose label is ``-1`` are dropped from every batch before the
    forward pass. Loss and accuracy are averaged over the samples that were
    actually processed, so dropped samples do not dilute the metrics.

    Returns:
        tuple: ``(epoch_loss, epoch_acc, probs, labels)`` where ``epoch_loss``
        is a float, ``epoch_acc`` is a 0-dim float64 CPU tensor, ``probs`` is
        a NumPy array of class-1 softmax probabilities and ``labels`` the
        matching NumPy array of targets. If no valid sample is seen, loss and
        accuracy are 0.0 and both arrays are empty.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    seen = 0
    all_probs, all_labels = [], []

    # Gradients are only tracked while training.
    with torch.set_grad_enabled(is_training):
        for inputs, labels in loader:
            # Filter out invalid samples (those labelled -1).
            valid_mask = labels != -1
            if not valid_mask.any():
                continue

            inputs = inputs[valid_mask].to(device)
            labels = labels[valid_mask].to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)
            n_batch = inputs.size(0)
            # criterion returns a batch mean; re-weight by the batch size so
            # the epoch value is a per-sample average.
            loss_sum += loss.item() * n_batch
            correct += int((preds == labels).sum().item())
            seen += n_batch

            # Class-1 probability (used for AUC).
            probs = torch.softmax(outputs, dim=1)[:, 1]
            all_probs.extend(probs.detach().cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Guard against an epoch with no valid samples.
    denominator = max(seen, 1)
    epoch_loss = loss_sum / denominator
    # Kept as a tensor so callers that use .item() continue to work.
    epoch_acc = torch.tensor(correct / denominator, dtype=torch.float64)

    return epoch_loss, epoch_acc, np.array(all_probs), np.array(all_labels)


def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one epoch.

    Args:
        model: Network to optimise; put into training mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(epoch_loss, epoch_acc, probs, labels)``; see ``_run_epoch``.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=optimizer)


def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; put into evaluation mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(epoch_loss, epoch_acc, probs, labels)``; see ``_run_epoch``.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=None)

# Early-stopping helper
class EarlyStopping:
    """Stop training when validation accuracy stops improving.

    Call the instance once per epoch with the validation accuracy and the
    model. Whenever the score is at least ``best_score + delta`` the model's
    ``state_dict`` is saved to ``path`` and the patience counter is reset;
    otherwise the counter is incremented, and once it reaches ``patience`` the
    ``early_stop`` flag is set.

    Note that with the default ``delta=0`` a score equal to the current best
    is treated as an improvement (checkpoint saved, counter reset).
    """

    def __init__(self, patience=5, delta=0, verbose=False, path='checkpoint.pth'):
        """
        Args:
            patience (int): Number of epochs to wait without improvement
                before requesting a stop.
            delta (float): Minimum change that counts as an improvement.
            verbose (bool): If True, print early-stopping progress messages.
            path (str): File the best ``state_dict`` is written to. Defaults
                to ``'checkpoint.pth'`` in the working directory.
        """
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_acc_max = -np.inf

    def __call__(self, val_acc, model):
        """Update the stopping state with this epoch's validation accuracy."""
        score = val_acc

        if self.best_score is None:
            # First call: always record the score and save a checkpoint.
            self.best_score = score
            self.save_checkpoint(val_acc, model)
        elif score < self.best_score + self.delta:
            # No sufficient improvement: spend one unit of patience.
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_acc, model)
            self.counter = 0

    def save_checkpoint(self, val_acc, model):
        """Save the model's ``state_dict`` to ``self.path`` and record ``val_acc``."""
        if self.verbose:
            print(f'Validation accuracy increased ({self.val_acc_max:.6f} --> {val_acc:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_acc_max = val_acc

# Main training workflow
def _binarize(probs, threshold=0.5):
    """Turn class-1 probabilities into integer 0/1 predictions."""
    return (np.asarray(probs) > threshold).astype(int)


def _safe_roc_auc(labels, probs):
    """Return ROC-AUC, or NaN when scikit-learn rejects the input as undefined."""
    try:
        return roc_auc_score(labels, probs)
    except ValueError:
        # e.g. only one class present in `labels`; keep the run alive.
        return float('nan')


def _summary_metrics(y_true, y_probs):
    """Compute the thresholded and ranking metrics for one set of predictions."""
    y_pred = _binarize(y_probs)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred),
        'roc_auc': _safe_roc_auc(y_true, y_probs),
        'pr_auc': average_precision_score(y_true, y_probs),
    }


def _print_metrics(title, metrics):
    """Print one metrics dictionary under a section heading."""
    print(f"\n=== {title} ===")
    print(f"Accuracy:    {metrics['accuracy']:.4f}")
    print(f"Precision:   {metrics['precision']:.4f}")
    print(f"Recall:      {metrics['recall']:.4f}")
    print(f"F1-score:    {metrics['f1']:.4f}")
    print(f"ROC-AUC:     {metrics['roc_auc']:.4f}")
    print(f"PR-AUC:      {metrics['pr_auc']:.4f}")
    if 'kappa' in metrics:
        print(f"Cohen's κ:   {metrics['kappa']:.4f}")


def _plot_fold_history(history, fold):
    """Save the loss / accuracy / AUC curves of one fold as a PNG."""
    panels = (
        ('loss', 'Loss', 'Loss'),
        ('acc', 'Acc', 'Accuracy'),
        ('auc', 'AUC', 'ROC-AUC'),
    )
    plt.figure(figsize=(15, 5))
    for position, (key, short_name, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(history[f'train_{key}'], label=f'Train {short_name}')
        plt.plot(history[f'val_{key}'], label=f'Val {short_name}')
        plt.legend()
        plt.title(f'Fold {fold + 1} {title}')
    plt.tight_layout()
    plt.savefig(f'training_curve_fold{fold}.png')
    plt.close()


def main():
    """Run stratified 5-fold cross-validation and save all artefacts.

    For every fold a fresh ``CrescentModel`` is trained with Adam,
    ``ReduceLROnPlateau`` and early stopping on validation accuracy. The
    weights of the best epoch are reloaded, predictions are collected, and
    per-fold as well as pooled metrics are reported.

    Files written to the working directory: ``best_model_fold{k}.pth``,
    ``fold{k}_predictions.npz``, ``training_curve_fold{k}.png`` (``k`` is the
    zero-based fold index), ``confusion_matrix.png``,
    ``roc_curve_combined.png`` and ``fold_metrics.csv``. ``EarlyStopping``
    additionally writes its own checkpoint file.
    """
    # Hyper-parameters (sized for CPU training)
    batch_size = 16   # a small batch keeps CPU memory use low
    num_epochs = 15   # few epochs to keep training time short
    num_workers = 0   # load data in the main process
    n_splits = 5

    # Stratified k-fold cross-validation
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    X = annotations
    y = annotations['LABEL']

    # Predictions pooled over all folds
    all_train_labels, all_train_probs = [], []
    all_val_labels, all_val_probs = [], []
    fold_metrics = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\n========== Fold {fold + 1}/{n_splits} ==========")

        # Prepare the data for this fold
        train_df = X.iloc[train_idx].reset_index(drop=True)
        val_df = X.iloc[val_idx].reset_index(drop=True)

        train_dataset = CrescentDataset(train_df, image_dir, data_transforms['train'])
        val_dataset = CrescentDataset(val_df, image_dir, data_transforms['val'])

        train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)

        # Initialise the model (every fold starts from scratch)
        model = CrescentModel(num_classes=2).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0005)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.5)

        early_stopping = EarlyStopping(patience=5, verbose=True)

        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint; otherwise a fold stuck at 0 accuracy would leave no
        # file for the reload below.
        best_acc = float('-inf')
        best_model_path = f'best_model_fold{fold}.pth'
        history = {'train_loss': [], 'val_loss': [],
                   'train_acc': [], 'val_acc': [],
                   'train_auc': [], 'val_auc': []}

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')

            # Train
            train_loss, train_acc, train_probs, train_labels = train_epoch(
                model, train_loader, criterion, optimizer, device)
            train_auc = _safe_roc_auc(train_labels, train_probs)

            # Validate
            val_loss, val_acc, val_probs, val_labels = validate_epoch(
                model, val_loader, criterion, device)
            val_auc = _safe_roc_auc(val_labels, val_probs)

            # Learning-rate schedule follows validation accuracy
            scheduler.step(val_acc)

            # float() accepts both 0-dim tensors and plain numbers.
            train_acc_value = float(train_acc)
            val_acc_value = float(val_acc)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc_value)
            history['val_acc'].append(val_acc_value)
            history['train_auc'].append(train_auc)
            history['val_auc'].append(val_auc)

            print(f'Train - Loss: {train_loss:.4f} | Acc: {train_acc_value:.4f} | AUC: {train_auc:.4f}')
            print(f'Val   - Loss: {val_loss:.4f} | Acc: {val_acc_value:.4f} | AUC: {val_auc:.4f}')

            # Save the best model (and the predictions of that epoch) before
            # the early-stopping check, so a stop can never skip a save.
            if val_acc_value > best_acc:
                best_acc = val_acc_value
                torch.save(model.state_dict(), best_model_path)
                np.savez(f'fold{fold}_predictions.npz',
                         train_probs=train_probs, train_labels=train_labels,
                         val_probs=val_probs, val_labels=val_labels)

            # Early-stopping check
            early_stopping(val_acc, model)
            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

        # Reload the best weights of this fold
        model.load_state_dict(torch.load(best_model_path, map_location=device))

        # Predictions of the best model on the training and validation sets.
        # NOTE(review): train_loader shuffles and its dataset applies the
        # random 'train' transforms, so these training-set predictions are
        # made on augmented images — confirm this is intended.
        _, _, train_probs, train_labels = validate_epoch(
            model, train_loader, criterion, device)
        val_loss, val_acc, val_probs, val_labels = validate_epoch(
            model, val_loader, criterion, device)
        val_auc = _safe_roc_auc(val_labels, val_probs)
        print(f'Final Validation - Loss: {val_loss:.4f} | Acc: {float(val_acc):.4f} | AUC: {val_auc:.4f}')

        # Pool this fold's results
        all_train_labels.extend(train_labels)
        all_train_probs.extend(train_probs)
        all_val_labels.extend(val_labels)
        all_val_probs.extend(val_probs)

        # Per-fold metrics
        train_pred = _binarize(train_probs)
        val_pred = _binarize(val_probs)
        fold_metrics.append({
            'fold': fold + 1,
            'train_acc': accuracy_score(train_labels, train_pred),
            'train_auc': _safe_roc_auc(train_labels, train_probs),
            'val_acc': accuracy_score(val_labels, val_pred),
            'val_auc': val_auc,
            'val_f1': f1_score(val_labels, val_pred)
        })

        # Training curves
        _plot_fold_history(history, fold)

    # Overall evaluation on the pooled predictions
    y_train_true = np.array(all_train_labels)
    y_train_probs = np.array(all_train_probs)
    y_val_true = np.array(all_val_labels)
    y_val_probs = np.array(all_val_probs)

    # 1. Compute the metrics
    y_val_pred = _binarize(y_val_probs)
    train_metrics = _summary_metrics(y_train_true, y_train_probs)
    val_metrics = _summary_metrics(y_val_true, y_val_probs)
    val_metrics['kappa'] = cohen_kappa_score(y_val_true, y_val_pred)

    # 2. Print the results
    print(f"\n================ Final {n_splits}-Fold CV Results ================")
    _print_metrics("Training Set", train_metrics)
    _print_metrics("Validation Set", val_metrics)

    # 3. Confusion matrix; labels=[0, 1] keeps it 2x2 so it always matches
    #    the fixed tick labels.
    plt.figure(figsize=(6, 6))
    sns.heatmap(confusion_matrix(y_val_true, y_val_pred, labels=[0, 1]),
                annot=True, fmt='d', cmap='Blues',
                xticklabels=['Negative', 'Positive'],
                yticklabels=['Negative', 'Positive'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()

    # 4. ROC curves (training and validation on the same figure)
    train_fpr, train_tpr, _ = roc_curve(y_train_true, y_train_probs)
    val_fpr, val_tpr, _ = roc_curve(y_val_true, y_val_probs)

    plt.figure(figsize=(8, 8))
    plt.plot(train_fpr, train_tpr, label=f'Train ROC (AUC = {train_metrics["roc_auc"]:.3f})', color='blue', linestyle='--')
    plt.plot(val_fpr, val_tpr, label=f'Validation ROC (AUC = {val_metrics["roc_auc"]:.3f})', color='red')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves (Train vs Validation)')
    plt.legend(loc='lower right')
    plt.savefig('roc_curve_combined.png')
    plt.close()

    # 5. Save the per-fold metrics
    fold_metrics_df = pd.DataFrame(fold_metrics)
    fold_metrics_df.to_csv('fold_metrics.csv', index=False)
    print("\n各折详细指标已保存到 fold_metrics.csv")

    print("\n训练完成！所有结果和模型已保存。")
    
# Run the cross-validation pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()