In [3]:
import re
import numpy as np
from collections import defaultdict

In [4]:
def parse_kfold_metrics(text, target_epoch=30):
    """
    Parse K-fold training output and collect the metrics logged at one epoch.

    The text is split on fold headers matching ``Kfold: =+ <digits> =+``.
    In every fold section the first line matching
    ``Epoch: <target_epoch> -- <rest>`` is located and six metrics
    (Accuracy, Precision, Recall, F1, ROC AUC, PR AUC) are read from
    ``<rest>`` with ``extract_metric``. A fold contributes to the result
    only when all six metrics were found; one status line is printed per
    fold.

    Fold numbers in the printed messages are positional (1-based). Text
    that precedes the first header and contains no target-epoch line is
    treated as a log preamble and ignored, and empty fold sections are kept
    (reported as "epoch not found") so that the numbering of later folds
    does not shift.

    Args:
        text: Full text of the training log.
        target_epoch: Epoch number whose metrics line should be parsed.

    Returns:
        A ``defaultdict(list)`` mapping each metric name to the list of
        values from the folds that had a complete set of metrics, in the
        order the folds appear in ``text``.
    """
    metric_names = ('Accuracy', 'Precision', 'Recall', 'F1', 'ROC AUC', 'PR AUC')
    epoch_pattern = rf'Epoch: {target_epoch} -- (.*)'

    sections = [section.strip() for section in re.split(r'Kfold: =+ \d+ =+', text)]
    if len(sections) > 1:
        # At least one header matched, so sections[0] is whatever came before
        # the first header. Keep it only if it actually carries the target
        # epoch line; otherwise it is a preamble, not a fold.
        if not re.search(epoch_pattern, sections[0]):
            sections = sections[1:]
    else:
        # No header at all: the whole text is one fold, unless it is empty.
        sections = [section for section in sections if section]

    metrics_dict = defaultdict(list)

    for fold_idx, fold_text in enumerate(sections):
        match = re.search(epoch_pattern, fold_text)
        if not match:
            print(f"Fold {fold_idx+1}: 未找到Epoch {target_epoch}")
            continue

        epoch_line = match.group(1)
        metrics = {name: extract_metric(epoch_line, name) for name in metric_names}

        # A missing metric (extract_metric returned None) invalidates the
        # whole fold, so every list in metrics_dict keeps the same length.
        if any(value is None for value in metrics.values()):
            print(f"Fold {fold_idx+1}: 包含无效指标，已跳过")
            continue

        for metric_name, value in metrics.items():
            metrics_dict[metric_name].append(value)
        print(f"Fold {fold_idx+1}: {metrics}")

    return metrics_dict

def extract_metric(text, metric_name):
    """
    Extract the numeric value that follows ``<metric_name>: `` in ``text``.

    The metric name is matched literally (regex metacharacters in it are
    escaped). The value must be an unsigned decimal number, optionally in
    scientific notation (``0.93``, ``.5``, ``1e-05``); trailing punctuation
    such as a sentence-ending period is not part of the value.

    Args:
        text: Text to search, e.g. one metrics line of a training log.
        metric_name: Name of the metric as it appears before the colon.

    Returns:
        The first matching value as a ``float``, or ``None`` when the name
        is absent or is not followed by a number (for example ``nan``).
    """
    # Digits with an optional fraction, or a bare fraction, plus an optional
    # exponent. Unlike a plain character class of digits and dots, this
    # cannot match strings that float() would reject ("0.93.", "...").
    number = r'(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    pattern = rf'{re.escape(metric_name)}: ({number})'
    match = re.search(pattern, text)
    if match:
        return float(match.group(1))
    return None

def calculate_statistics(metrics_dict):
    """
    Compute the mean and sample standard deviation of every metric.

    Args:
        metrics_dict: Mapping of metric name to a list of numeric values.
            Metrics whose list is empty are left out of the result.

    Returns:
        A dict mapping each metric name to a dict with keys ``'mean'``,
        ``'std'`` (sample standard deviation, ``ddof=1``) and ``'values'``
        (the values as a plain list). With a single value the sample
        standard deviation is undefined, so ``'std'`` is NaN.
    """
    results = {}

    for metric_name, values in metrics_dict.items():
        if not values:
            continue

        values_array = np.array(values)
        if values_array.size > 1:
            std = np.std(values_array, ddof=1)
        else:
            # np.std(..., ddof=1) on one sample divides by zero degrees of
            # freedom: it yields NaN and emits a RuntimeWarning. Return the
            # same NaN directly, without the warning.
            std = np.float64(np.nan)

        results[metric_name] = {
            'mean': np.mean(values_array),
            'std': std,
            'values': values_array.tolist(),
        }

    return results

def print_results(results, target_epoch, n_folds=5):
    """
    Print a summary of the cross-validation statistics.

    For every metric one line with ``mean ± std`` (four decimals) is
    printed, followed by a line listing the individual fold values.

    Args:
        results: Mapping of metric name to a dict with keys ``'mean'``,
            ``'std'`` and ``'values'``, as returned by
            ``calculate_statistics``.
        target_epoch: Epoch number shown in the header.
        n_folds: Number of folds shown in the header. Defaults to 5, which
            reproduces the previously hard-coded "5-Fold" title.
    """
    separator = '=' * 60

    print(f"\n{separator}")
    print(f"Epoch {target_epoch} - {n_folds}-Fold 交叉验证结果")
    print(separator)

    for metric_name, stats in results.items():
        fold_values = [f'{x:.4f}' for x in stats['values']]
        print(f"{metric_name:>10}: {stats['mean']:.4f} ± {stats['std']:.4f}")
        print(f"            各Fold值: {fold_values}")

    print(separator)

   

In [None]:
# Epoch whose metrics are summarised across the folds
target_epoch = 14

# Training log to analyse
result_file = "result/yeast1/PIPR.txt"

# Read the whole log into memory
with open(result_file, "r", encoding="utf-8") as fin:
    result_text = fin.read()

# Per-fold metrics -> mean / sample std -> printed report
metrics_dict = parse_kfold_metrics(result_text, target_epoch)
results = calculate_statistics(metrics_dict)
print_results(results, target_epoch)

# parse_kfold_metrics appends to all metric lists together, so the length of
# the first list is the number of folds with a complete set of metrics.
if metrics_dict:
    valid_folds = len(next(iter(metrics_dict.values())))
else:
    valid_folds = 0
print(f"有效Fold数量: {valid_folds}/5")

Fold 1: {'Accuracy': 0.9383, 'Precision': 0.9681, 'Recall': 0.9051, 'F1': 0.9355, 'ROC AUC': 0.938, 'PR AUC': 0.9231}
Fold 2: {'Accuracy': 0.9214, 'Precision': 0.951, 'Recall': 0.8878, 'F1': 0.9183, 'ROC AUC': 0.9212, 'PR AUC': 0.9001}
Fold 3: {'Accuracy': 0.9254, 'Precision': 0.9467, 'Recall': 0.902, 'F1': 0.9238, 'ROC AUC': 0.9255, 'PR AUC': 0.9031}
Fold 4: {'Accuracy': 0.9464, 'Precision': 0.9746, 'Recall': 0.9176, 'F1': 0.9453, 'ROC AUC': 0.9466, 'PR AUC': 0.9359}
Fold 5: {'Accuracy': 0.9437, 'Precision': 0.9835, 'Recall': 0.9029, 'F1': 0.9414, 'ROC AUC': 0.9438, 'PR AUC': 0.9367}

Epoch 14 - 5-Fold 交叉验证结果
  Accuracy: 0.9350 ± 0.0111
            各Fold值: ['0.9383', '0.9214', '0.9254', '0.9464', '0.9437']
 Precision: 0.9648 ± 0.0156
            各Fold值: ['0.9681', '0.9510', '0.9467', '0.9746', '0.9835']
    Recall: 0.9031 ± 0.0106
            各Fold值: ['0.9051', '0.8878', '0.9020', '0.9176', '0.9029']
        F1: 0.9329 ± 0.0115
            各Fold值: ['0.9355', '0.9183', '0.9238', '0.945