# In[ ]:
import math
import os
import time
import random
import json
import itertools
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast

# Visualization imports
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix, 
                           classification_report, precision_score, recall_score)
from sklearn.manifold import TSNE
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from itertools import cycle
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Model imports
import wandb
from ptflops import get_model_complexity_info
from models.tiny_vit import (tiny_vit_5m_224, tiny_vit_11m_224, 
                           tiny_vit_21m_224, tiny_vit_21m_384, tiny_vit_21m_512)

# ==================== Timing and model-profiling utilities ====================
class TimeTracker:
    """Track durations (in seconds) of training epochs and batches.

    Call ``start_epoch``/``end_epoch`` around each epoch and
    ``start_batch``/``end_batch`` around each batch. Durations accumulate in
    ``epoch_times`` and ``batch_times`` and are summarised by ``get_stats``.
    """

    def __init__(self):
        self.epoch_times = []
        self.batch_times = []
        # perf_counter() readings of the most recent start_* call (None = not started).
        self.epoch_start_time = None
        self.batch_start_time = None

    @staticmethod
    def _elapsed_since(start, what):
        """Return seconds since ``start`` (a ``time.perf_counter()`` reading).

        Raises:
            RuntimeError: if the matching ``start_<what>()`` was never called.
        """
        if start is None:
            raise RuntimeError(f"end_{what}() called before start_{what}()")
        return time.perf_counter() - start

    def start_epoch(self):
        """Mark the start of an epoch."""
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted, corrupting measured durations.
        self.epoch_start_time = time.perf_counter()

    def end_epoch(self):
        """Record and return the seconds elapsed since the last ``start_epoch``."""
        epoch_time = self._elapsed_since(self.epoch_start_time, 'epoch')
        self.epoch_times.append(epoch_time)
        return epoch_time

    def start_batch(self):
        """Mark the start of a batch."""
        self.batch_start_time = time.perf_counter()

    def end_batch(self):
        """Record and return the seconds elapsed since the last ``start_batch``."""
        batch_time = self._elapsed_since(self.batch_start_time, 'batch')
        self.batch_times.append(batch_time)
        return batch_time

    def get_stats(self):
        """Return summary statistics of the recorded times.

        Returns:
            An empty dict if no epoch was recorded; otherwise total/avg/median/
            min/max epoch time plus avg/total batch time (0 when no batches
            were recorded), all in seconds.
        """
        if not self.epoch_times:
            return {}

        return {
            'total_epoch_time': sum(self.epoch_times),
            'avg_epoch_time': np.mean(self.epoch_times),
            'median_epoch_time': np.median(self.epoch_times),
            'min_epoch_time': np.min(self.epoch_times),
            'max_epoch_time': np.max(self.epoch_times),
            'avg_batch_time': np.mean(self.batch_times) if self.batch_times else 0,
            'total_batch_time': sum(self.batch_times) if self.batch_times else 0
        }

def get_model_stats(model, input_size=(3, 224, 224)):
    """Compute FLOPs and parameter count of ``model`` with ptflops.

    Args:
        model: the ``nn.Module`` to profile. The whole model is profiled (for
            ``EnhancedTinyViT`` that is backbone + linear head), which is also
            what the fallback parameter count below covers.
        input_size: per-sample input shape ``(C, H, W)``. A 4-D shape
            ``(B, C, H, W)`` is also accepted; its leading batch dimension is
            dropped, because ptflops builds its own batch of size 1 from the
            shape it is given (passing 4-D would yield a 5-D input tensor).

    Returns:
        dict with ``flops`` / ``flops_m`` / ``flops_str`` (``None`` / ``None`` /
        ``"N/A"`` when profiling fails) and ``params`` / ``params_m`` /
        ``params_str``.
    """
    input_res = tuple(input_size)
    if len(input_res) == 4:
        input_res = input_res[1:]

    was_training = model.training
    macs = params = None
    try:
        macs, params = get_model_complexity_info(
            model,
            input_res,
            as_strings=False,
            print_per_layer_stat=False,
            ignore_modules=[nn.Dropout, nn.BatchNorm2d, nn.LayerNorm]
        )
    except Exception as e:  # best-effort: any profiling failure falls back to a parameter count
        print(f"Error calculating FLOPs: {e}")
        macs = params = None
    finally:
        # ptflops may switch the model to eval mode; restore the caller's mode.
        model.train(was_training)

    # Recent ptflops versions catch their own errors and return (None, None).
    if macs is None or params is None:
        params = sum(p.numel() for p in model.parameters())
        params_m = params / 1e6
        return {
            'flops': None,
            'flops_m': None,
            'params': params,
            'params_m': params_m,
            'flops_str': "N/A",
            'params_str': f"{params_m:.2f}M"
        }

    flops = macs * 2  # 1 MAC = 2 FLOPs (one multiply + one add)
    flops_m = flops / 1e6
    params_m = params / 1e6

    return {
        'flops': flops,
        'flops_m': flops_m,
        'params': params,
        'params_m': params_m,
        'flops_str': f"{flops_m:.2f}M",
        'params_str': f"{params_m:.2f}M"
    }

def measure_latency(model, input_size=(3, 224, 224), repetitions=100, warmup=10):
    """Measure single-input inference latency of ``model``.

    The model is run in eval mode under ``torch.no_grad()`` (warmup included),
    so Dropout-style layers are inactive and no autograd graph is built; the
    caller's train/eval mode is restored afterwards.

    Args:
        model: module whose first parameter's device is used for the input.
        input_size: input shape; a 3-tuple gets a leading batch dim of 1.
        repetitions: number of timed forward passes (must be >= 1).
        warmup: number of untimed forward passes run first.

    Returns:
        dict with latency mean/median/min/max/std in milliseconds and ``fps``
        (1000 / mean latency).

    Raises:
        ValueError: if ``repetitions`` < 1.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be >= 1")

    device = next(model.parameters()).device
    input_size = tuple(input_size)
    if len(input_size) == 3:
        input_size = (1,) + input_size
    dummy_input = torch.randn(input_size, device=device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    was_training = model.training
    model.eval()
    timings = []
    try:
        with torch.no_grad():
            for _ in range(warmup):
                model(dummy_input)
            _sync()

            for _ in range(repetitions):
                start = time.perf_counter()
                model(dummy_input)
                _sync()
                timings.append(time.perf_counter() - start)
    finally:
        model.train(was_training)

    timings_ms = np.asarray(timings) * 1000.0
    return {
        'latency_mean': np.mean(timings_ms),
        'latency_median': np.median(timings_ms),
        'latency_min': np.min(timings_ms),
        'latency_max': np.max(timings_ms),
        'latency_std': np.std(timings_ms),
        'fps': 1000 / np.mean(timings_ms)
    }


# ==================== Enhanced Components ====================
class DropKeyAttention(nn.Module):
    """Multi-head self-attention with DropKey regularisation.

    DropKey masks attention *logits* at random before the softmax (during
    training only), instead of dropping attention weights after it. Masked
    logits are set to the most negative finite value of their dtype, so they
    receive zero weight while each softmax row still sums to 1. No
    renormalisation is needed, and a fully masked row degrades to uniform
    attention instead of NaN.

    Args:
        dim: embedding width ``C``; ``forward`` expects ``x`` of shape (B, N, C).
        num_heads: number of heads (``dim`` should be divisible by it).
        qkv_bias: whether the fused qkv projection has a bias.
        attn_drop: dropout on attention weights (after softmax).
        proj_drop: dropout after the output projection.
        drop_key_prob: probability of masking each attention logit in training.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., drop_key_prob=0.1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.drop_key_prob = drop_key_prob

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if self.training and self.drop_key_prob > 0:
            # Do not use nn.Dropout here: it sets logits to 0 (which still
            # receives softmax weight) and rescales the rest. Masking with the
            # dtype minimum removes the key from the softmax entirely.
            drop_mask = torch.rand(attn.shape, device=attn.device) < self.drop_key_prob
            attn = attn.masked_fill(drop_mask, torch.finfo(attn.dtype).min)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class AdvancedMixupCutmix:
    """Randomly apply either CutMix or Mixup to a batch.

    Each call pairs every sample with the batch reversed along dim 0
    (``x.flip(0)``). With probability ``switch_prob`` CutMix pastes a random
    box from the reversed batch into ``x`` (modifying ``x`` in place);
    otherwise Mixup returns a convex blend ``lam * x + (1 - lam) * x.flip(0)``.

    Returns ``(mixed_x, (y, y.flip(0), lam))``. For CutMix, ``lam`` is
    recomputed as the fraction of the image that was *not* replaced.
    """

    def __init__(self, mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5):
        self.mixup_beta = torch.distributions.Beta(mixup_alpha, mixup_alpha)
        self.cutmix_beta = torch.distributions.Beta(cutmix_alpha, cutmix_alpha)
        self.switch_prob = switch_prob

    def __call__(self, x, y):
        use_cutmix = random.random() < self.switch_prob
        if not use_cutmix:
            lam = self.mixup_beta.sample().item()
            blended = lam * x + (1 - lam) * x.flip(0)
            return blended, (y, y.flip(0), lam)

        lam = self.cutmix_beta.sample().item()
        r0, c0, r1, c1 = self.rand_bbox(x.size(), lam)
        donor = x.flip(0)
        x[:, :, r0:r1, c0:c1] = donor[:, :, r0:r1, c0:c1]
        # Correct lam to the area actually pasted (the box may be clipped).
        pasted_area = (r1 - r0) * (c1 - c0)
        lam = 1 - (pasted_area / (x.size()[-1] * x.size()[-2]))
        return x, (y, y.flip(0), lam)

    def rand_bbox(self, size, lam):
        """Sample a box whose side ratio is ``sqrt(1 - lam)`` of the image.

        ``size`` is the 4-D batch size; returns ``(r0, c0, r1, c1)`` where the
        ``r`` bounds index dim 2 and the ``c`` bounds index dim 3, clipped to
        the image extent.
        """
        rows, cols = size[2], size[3]
        side_ratio = np.sqrt(1. - lam)
        box_rows = int(rows * side_ratio)
        box_cols = int(cols * side_ratio)
        center_r = np.random.randint(rows)
        center_c = np.random.randint(cols)
        r0 = np.clip(center_r - box_rows // 2, 0, rows)
        c0 = np.clip(center_c - box_cols // 2, 0, cols)
        r1 = np.clip(center_r + box_rows // 2, 0, rows)
        c1 = np.clip(center_c + box_cols // 2, 0, cols)
        return r0, c0, r1, c1

class SmartRandomErasing:
    """Random erasing that accepts either a CHW tensor or a PIL image.

    With probability ``p`` a random rectangle (area fraction drawn from
    ``scale``, aspect ratio from ``ratio``; up to 10 attempts) is filled with
    random values. ``mode='pixel'`` draws a fresh value per pixel; any other
    mode uses one random value per channel for the whole rectangle.

    Tensor input: values are drawn from ``torch.rand`` (i.e. [0, 1)) and the
    tensor is modified in place. PIL input: values are random integers in
    [0, 255] and a new PIL image is returned. ``value`` is stored but not used
    by either path.
    """

    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, mode='pixel'):
        self.p = p
        self.scale = scale
        self.ratio = ratio
        self.value = value
        self.mode = mode

    def __call__(self, img):
        if random.random() > self.p:
            return img

        if isinstance(img, torch.Tensor):
            return self._erase_tensor(img)
        else:  # PIL Image
            return self._erase_pil(img)

    def _erase_tensor(self, img):
        """Erase a rectangle of a (C, H, W) tensor in place and return it."""
        C, H, W = img.shape
        area = H * W

        for _ in range(10):
            erase_area = random.uniform(*self.scale) * area
            aspect_ratio = random.uniform(*self.ratio)

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))

            if h < H and w < W:
                i = random.randint(0, H - h)
                j = random.randint(0, W - w)

                if self.mode == 'pixel':
                    v = torch.rand(C, h, w, dtype=img.dtype, device=img.device)
                else:
                    v = torch.rand(C, 1, 1, dtype=img.dtype, device=img.device).expand(-1, h, w)

                img[:, i:i+h, j:j+w] = v
                return img
        return img

    def _erase_pil(self, img):
        """Erase a rectangle of a PIL image and return a new PIL image."""
        img_np = np.array(img)
        # Single-channel images (e.g. mode 'L') come back as (H, W); give them
        # a channel axis so the code below can treat every image as (H, W, C).
        is_2d = img_np.ndim == 2
        if is_2d:
            img_np = img_np[:, :, None]
        H, W, C = img_np.shape
        area = H * W

        for _ in range(10):
            erase_area = random.uniform(*self.scale) * area
            aspect_ratio = random.uniform(*self.ratio)

            h = int(round(math.sqrt(erase_area * aspect_ratio)))
            w = int(round(math.sqrt(erase_area / aspect_ratio)))

            if h < H and w < W:
                i = random.randint(0, H - h)
                j = random.randint(0, W - w)

                if self.mode == 'pixel':
                    v = np.random.randint(0, 256, (h, w, C))
                else:
                    v = np.random.randint(0, 256, (1, 1, C))
                    v = np.tile(v, (h, w, 1))

                img_np[i:i+h, j:j+w, :] = v
                if is_2d:
                    img_np = img_np[:, :, 0]
                # ToPILImage interprets an ndarray as (H, W[, C]); it must NOT be
                # transposed to CHW first (that is the tensor convention).
                return transforms.ToPILImage()(img_np)
        return img

# ==================== Enhanced TinyViT ====================

class EnhancedTinyViT(nn.Module):
    """Backbone + DropKey attention (where it can be applied losslessly) + new linear head.

    The backbone's output is treated as a feature vector. Its last dimension,
    measured with one dummy forward pass, sets the input width of ``self.head``.

    NOTE(review): ``main`` builds the backbone with ``pretrained=True``. If that
    backbone keeps its own classification head, the "features" seen here are
    that head's outputs. Confirm against models/tiny_vit.py.

    Args:
        base_model: backbone ``nn.Module``; expected to map (B, 3, H, W)
            images to (B, D) features.
        num_classes: output size of the new linear head.
        drop_key_prob: DropKey masking probability for swapped attention layers.
        img_size: spatial size of the dummy probe input (224 by default, as
            before); set it for backbones trained at another resolution.
    """

    def __init__(self, base_model, num_classes, drop_key_prob=0.1, img_size=224):
        super().__init__()
        self.backbone = base_model
        # Make it visible whether DropKey was actually applied anywhere.
        self.num_replaced_attn = 0
        self.num_skipped_attn = 0
        self._replace_attention(drop_key_prob)
        if self.num_skipped_attn:
            print(f"[EnhancedTinyViT] DropKey not applied to {self.num_skipped_attn} attention "
                  f"module(s) whose parameters do not map onto DropKeyAttention; left unchanged.")

        # Probe the feature width. Run in eval mode so BatchNorm running stats
        # of the (pretrained) backbone are not updated by random noise, then
        # restore the previous mode.
        ref = next(self.backbone.parameters())
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, img_size, img_size, device=ref.device, dtype=ref.dtype)
                feature_dim = self.backbone(dummy_input).shape[-1]
        finally:
            self.backbone.train(was_training)

        self.head = nn.Linear(feature_dim, num_classes).to(device=ref.device, dtype=ref.dtype)

    @staticmethod
    def _dropout_p(attn_module, name):
        """Dropout probability of ``attn_module.<name>`` if it is an nn.Dropout, else 0.0."""
        layer = getattr(attn_module, name, None)
        return layer.p if isinstance(layer, nn.Dropout) else 0.0

    def _build_dropkey_attention(self, old_attn, drop_key_prob):
        """Return a DropKeyAttention carrying ``old_attn``'s weights, or None.

        None is returned when ``old_attn`` has no ``qkv`` Linear / integer
        ``num_heads``, or when its state_dict keys/shapes differ from the new
        module's. Swapping in that case would silently discard state (extra
        norms, biases, ...) and pretrained weights.

        NOTE(review): matching keys/shapes does not prove the qkv weight uses
        the same (3, heads, head_dim) split as DropKeyAttention.forward.
        Confirm for any backbone whose attention gets swapped.
        """
        qkv = getattr(old_attn, 'qkv', None)
        num_heads = getattr(old_attn, 'num_heads', None)
        if not isinstance(qkv, nn.Linear) or not isinstance(num_heads, int):
            return None

        new_attn = DropKeyAttention(
            dim=qkv.in_features,
            num_heads=num_heads,
            qkv_bias=qkv.bias is not None,
            attn_drop=self._dropout_p(old_attn, 'attn_drop'),
            proj_drop=self._dropout_p(old_attn, 'proj_drop'),
            drop_key_prob=drop_key_prob,
        )
        old_state = old_attn.state_dict()
        new_state = new_attn.state_dict()
        if old_state.keys() != new_state.keys():
            return None
        if any(old_state[key].shape != new_state[key].shape for key in new_state):
            return None

        # Place the new module where the old one lived (the backbone may
        # already be on GPU), then copy over the existing weights.
        new_attn = new_attn.to(device=qkv.weight.device, dtype=qkv.weight.dtype)
        new_attn.load_state_dict(old_state)
        return new_attn

    def _replace_attention(self, drop_key_prob, module=None):
        """Recursively swap ``<child>.attn`` modules for DropKeyAttention where lossless."""
        if module is None:
            module = self.backbone

        for child in module.children():
            old_attn = getattr(child, 'attn', None)
            if isinstance(old_attn, nn.Module) and not isinstance(old_attn, DropKeyAttention):
                new_attn = self._build_dropkey_attention(old_attn, drop_key_prob)
                if new_attn is None:
                    self.num_skipped_attn += 1
                else:
                    child.attn = new_attn
                    self.num_replaced_attn += 1

            # Recurse into every child uniformly. Containers such as
            # ModuleList/Sequential are visited too, so their elements are
            # themselves checked for an ``attn`` attribute.
            self._replace_attention(drop_key_prob, child)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)
        
class Visualizer:
    """Static helpers that render training/evaluation plots to files.

    Each matplotlib helper opens a new figure and closes it after ``savefig``,
    so repeated calls do not accumulate open figures.
    """

    @staticmethod
    def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues, filename='confusion_matrix.pdf'):
        """Draw a confusion matrix with per-cell counts and save it.

        Args:
            cm: 2-D array indexed as ``cm[true, predicted]`` (rows are drawn
                on the "True label" axis, columns on "Predicted label").
            classes: tick labels, one per row/column of ``cm``.
            cmap: matplotlib colormap for cell shading.
            filename: output path; the format follows the file extension.
        """
        cm = np.asarray(cm)
        plt.figure(figsize=(8, 6))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title('Confusion Matrix', fontsize=14, pad=20)
        plt.colorbar()

        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
        plt.yticks(tick_marks, classes, rotation=45, fontsize=12)

        # Use white text on dark cells and black text on light cells.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, f"{cm[i, j]}",
                     horizontalalignment="center",
                     verticalalignment="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=12)

        plt.ylabel('True label', fontsize=14)
        plt.xlabel('Predicted label', fontsize=14)
        plt.tight_layout()
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_tsne(features, labels, class_names, filename='tsne.pdf'):
        """Project ``features`` to 2-D with t-SNE and scatter-plot them by class.

        Args:
            features: array of shape (n_samples, n_features).
            labels: integer class index per sample.
            class_names: legend labels; class ``i`` is ``class_names[i]``.
            filename: output path.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)

        # scikit-learn requires perplexity < n_samples; keep 30 when possible.
        perplexity = min(30.0, max(1.0, features.shape[0] - 1.0))
        try:
            tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, max_iter=300)
        except TypeError:
            # scikit-learn < 1.5 only knows the old name ``n_iter``.
            tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, n_iter=300)
        features_tsne = tsne.fit_transform(features)

        plt.figure(figsize=(10, 8))
        colors = list(mcolors.TABLEAU_COLORS.values())

        for i, class_name in enumerate(class_names):
            mask = labels == i
            # Wrap around the 10-colour Tableau palette for larger class counts.
            plt.scatter(features_tsne[mask, 0], features_tsne[mask, 1],
                        c=[colors[i % len(colors)]], label=class_name, s=50, alpha=0.6)

        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.xlabel('t-SNE 1', fontsize=14)
        plt.ylabel('t-SNE 2', fontsize=14)
        plt.legend(loc='upper right', fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_training_metrics(history, filename='training_metrics.pdf'):
        """Plot loss, accuracy, F1 and learning-rate curves in a 2x2 grid.

        ``history`` must contain equal-length lists under 'train_loss',
        'val_loss', 'train_acc', 'val_acc', 'train_f1', 'val_f1' and 'lr'.
        """
        epochs = range(1, len(history['train_loss']) + 1)

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Loss curves
        ax1.plot(epochs, history['train_loss'], 'b-', linewidth=2, label='Train')
        ax1.plot(epochs, history['val_loss'], 'r-', linewidth=2, label='Validation')
        ax1.set_title('Training and Validation Loss', fontsize=14)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.legend(fontsize=12)
        ax1.grid(True, linestyle='--', alpha=0.5)

        # Accuracy curves
        ax2.plot(epochs, history['train_acc'], 'b-', linewidth=2, label='Train')
        ax2.plot(epochs, history['val_acc'], 'r-', linewidth=2, label='Validation')
        ax2.set_title('Training and Validation Accuracy', fontsize=14)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Accuracy', fontsize=12)
        ax2.legend(fontsize=12)
        ax2.grid(True, linestyle='--', alpha=0.5)

        # F1 score curves
        ax3.plot(epochs, history['train_f1'], 'b-', linewidth=2, label='Train')
        ax3.plot(epochs, history['val_f1'], 'r-', linewidth=2, label='Validation')
        ax3.set_title('Training and Validation F1 Score', fontsize=14)
        ax3.set_xlabel('Epoch', fontsize=12)
        ax3.set_ylabel('F1 Score', fontsize=12)
        ax3.legend(fontsize=12)
        ax3.grid(True, linestyle='--', alpha=0.5)

        # Learning-rate schedule
        ax4.plot(epochs, history['lr'], 'g-', linewidth=2, label='Learning Rate')
        ax4.set_title('Learning Rate Schedule', fontsize=14)
        ax4.set_xlabel('Epoch', fontsize=12)
        ax4.set_ylabel('Learning Rate', fontsize=12)
        ax4.legend(fontsize=12)
        ax4.grid(True, linestyle='--', alpha=0.5)

        plt.tight_layout()
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_roc_curve(y_true, y_score, class_names, filename='roc_curve.pdf'):
        """Plot one-vs-rest ROC curves per class plus the micro-average.

        Args:
            y_true: integer class index per sample.
            y_score: per-class scores; assumed shape (n_samples,
                len(class_names)), e.g. softmax probabilities. TODO confirm
                at the call site (none visible here).
            class_names: legend labels.
            filename: output path.
        """
        y_true = np.asarray(y_true)
        y_score = np.asarray(y_score)
        y_true_bin = label_binarize(y_true, classes=np.arange(len(class_names)))
        if len(class_names) == 2 and y_true_bin.shape[1] == 1:
            # For two classes label_binarize returns only the positive-class
            # column; rebuild one indicator column per class so column i lines
            # up with y_score[:, i].
            y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
        n_classes = y_true_bin.shape[1]

        # Per-class ROC curve and AUC
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_score[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Micro-average: pool all (sample, class) decisions into one curve.
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_score.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        plt.figure(figsize=(10, 8))
        colors = cycle(['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b'])

        for i, color in zip(range(n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2,
                     label='{0} (AUC = {1:0.2f})'.format(class_names[i], roc_auc[i]))

        plt.plot(fpr["micro"], tpr["micro"], color='navy', linestyle=':', lw=2,
                 label='micro-average (AUC = {0:0.2f})'.format(roc_auc["micro"]))

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=14)
        plt.ylabel('True Positive Rate', fontsize=14)
        plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=16, pad=20)
        plt.legend(loc="lower right", fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_class_distribution(labels, class_names, filename='class_distribution.pdf'):
        """Draw a bar chart of per-class sample counts with value labels.

        Args:
            labels: integer class index per sample (0 .. len(class_names)-1).
            class_names: bar labels; class ``i`` is ``class_names[i]``.
            filename: output path.
        """
        plt.figure(figsize=(10, 6))
        # minlength keeps one bar per class even if trailing classes are absent.
        counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=len(class_names))
        colors = plt.cm.viridis(np.linspace(0, 1, len(class_names)))

        bars = plt.bar(class_names, counts, color=colors)
        plt.title('Class Distribution', fontsize=16, pad=20)
        plt.xlabel('Class', fontsize=14)
        plt.ylabel('Count', fontsize=14)
        plt.xticks(rotation=45, fontsize=12)
        plt.yticks(fontsize=12)

        # Write the count above each bar.
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                     f'{int(height)}',
                     ha='center', va='bottom', fontsize=12)

        plt.grid(True, axis='y', linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_interactive_metrics(history, filename='interactive_metrics.html'):
        """Write an interactive (plotly) 2x2 version of the training curves to HTML.

        ``history`` uses the same keys as ``plot_training_metrics``.
        """
        epochs = list(range(1, len(history['train_loss']) + 1))

        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=("Training and Validation Loss",
                                            "Training and Validation Accuracy",
                                            "Training and Validation F1 Score",
                                            "Learning Rate Schedule"))

        # Loss curves
        fig.add_trace(
            go.Scatter(x=epochs, y=history['train_loss'], name="Train Loss", line=dict(color='blue')),
            row=1, col=1
        )
        fig.add_trace(
            go.Scatter(x=epochs, y=history['val_loss'], name="Validation Loss", line=dict(color='red')),
            row=1, col=1
        )

        # Accuracy curves
        fig.add_trace(
            go.Scatter(x=epochs, y=history['train_acc'], name="Train Accuracy", line=dict(color='blue')),
            row=1, col=2
        )
        fig.add_trace(
            go.Scatter(x=epochs, y=history['val_acc'], name="Validation Accuracy", line=dict(color='red')),
            row=1, col=2
        )

        # F1 score curves
        fig.add_trace(
            go.Scatter(x=epochs, y=history['train_f1'], name="Train F1", line=dict(color='blue')),
            row=2, col=1
        )
        fig.add_trace(
            go.Scatter(x=epochs, y=history['val_f1'], name="Validation F1", line=dict(color='red')),
            row=2, col=1
        )

        # Learning-rate schedule
        fig.add_trace(
            go.Scatter(x=epochs, y=history['lr'], name="Learning Rate", line=dict(color='green')),
            row=2, col=2
        )

        fig.update_layout(height=800, width=1000, title_text="Training Metrics", showlegend=True)
        fig.update_xaxes(title_text="Epoch", row=1, col=1)
        fig.update_xaxes(title_text="Epoch", row=1, col=2)
        fig.update_xaxes(title_text="Epoch", row=2, col=1)
        fig.update_xaxes(title_text="Epoch", row=2, col=2)

        fig.update_yaxes(title_text="Loss", row=1, col=1)
        fig.update_yaxes(title_text="Accuracy", row=1, col=2)
        fig.update_yaxes(title_text="F1 Score", row=2, col=1)
        fig.update_yaxes(title_text="Learning Rate", row=2, col=2)

        fig.write_html(filename)

# ==================== Training Pipeline ====================
def train_one_epoch(model, train_loader, optimizer, criterion, scaler, config, epoch, time_tracker=None):
    """Run one mixed-precision training epoch and return aggregate metrics.

    Args:
        model: model to train; its first parameter's device is used for data.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer stepped once per batch through ``scaler``.
        criterion: loss function ``criterion(outputs, targets)``.
        scaler: gradient scaler (``GradScaler``) used with autocast.
        config: not used here; kept for interface compatibility.
        epoch: epoch number shown in the progress-bar label.
        time_tracker: optional ``TimeTracker``; when given, per-batch times
            are recorded and shown in the progress bar.

    Returns:
        dict with 'train_loss' (mean of per-batch losses), 'train_acc'
        (fraction in [0, 1]) and 'train_f1' (macro F1). All are 0.0 when the
        loader yields no batches.
    """
    # Take the device from the model instead of a module-level ``device``
    # global that only exists when this file is run as a script.
    device = next(model.parameters()).device
    use_amp = device.type == 'cuda'

    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for batch_idx, (inputs, targets) in enumerate(pbar):
        if time_tracker is not None:
            time_tracker.start_batch()

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()

        # Mixed precision only when the model lives on CUDA; CPU stays fp32.
        with torch.amp.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Keep predictions and labels for the epoch-level F1 score.
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

        postfix = {
            'Loss': total_loss / num_batches,
            'Acc': 100. * correct / total
        }
        if time_tracker is not None:
            batch_time = time_tracker.end_batch()
            postfix['BatchTime'] = f'{batch_time:.3f}s'
        pbar.set_postfix(postfix)

    if num_batches == 0:
        return {'train_loss': 0.0, 'train_acc': 0.0, 'train_f1': 0.0}

    return {
        'train_loss': total_loss / num_batches,
        'train_acc': correct / total,
        'train_f1': f1_score(all_targets, all_preds, average='macro')
    }


# ==================== 验证函数 ====================
@torch.no_grad()
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader``, print per-class reports, and return metrics.

    ``val_loader.dataset.classes`` supplies the class names (as with
    ``torchvision.datasets.ImageFolder``). Class ``i`` is ``classes[i]``.

    Returns:
        dict with 'val_loss' (mean of per-batch losses), 'val_acc',
        'val_f1' / 'val_precision' / 'val_recall' (macro averages over the
        labels present in targets or predictions, sklearn's default), and
        'confusion_matrix' (K x K, K = number of classes, rows = true labels).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    # Take the device from the model instead of a script-only global.
    device = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_targets = []

    for inputs, targets in val_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        total_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)

        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

    if not all_targets:
        raise ValueError("validate() received an empty val_loader; cannot compute metrics")

    class_names = list(val_loader.dataset.classes)
    # Pin the label set to every class index so per-class arrays and the
    # confusion matrix line up with class_names even if a class is absent
    # from both targets and predictions in this split.
    label_ids = list(range(len(class_names)))

    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    class_precision = precision_score(all_targets, all_preds, labels=label_ids, average=None, zero_division=0)
    class_recall = recall_score(all_targets, all_preds, labels=label_ids, average=None, zero_division=0)
    class_f1 = f1_score(all_targets, all_preds, labels=label_ids, average=None, zero_division=0)

    print("\nClassification Report:")
    print(classification_report(all_targets, all_preds, labels=label_ids,
                                target_names=class_names, zero_division=0))

    print("\nConfusion Matrix Derived Metrics:")
    print(f"{'Class':<15}{'Precision':>12}{'Recall':>12}{'F1-score':>12}")
    for i, class_name in enumerate(class_names):
        print(f"{class_name:<15}{class_precision[i]:>12.4f}{class_recall[i]:>12.4f}{class_f1[i]:>12.4f}")
    print(f"{'Macro Avg':<15}{precision:>12.4f}{recall:>12.4f}{f1:>12.4f}")

    return {
        'val_loss': total_loss / num_batches,
        'val_acc': accuracy,
        'val_f1': f1,
        'val_precision': precision,
        'val_recall': recall,
        'confusion_matrix': cm
    }

# ==================== 主函数 ====================
def main():
    """Train an EnhancedTinyViT classifier on the ImageFolder data under ``soil/``.

    Reads ``soil/train`` and ``soil/val`` and trains for ``config['epochs']``
    epochs with AdamW + cosine-annealed LR. Whenever validation accuracy
    improves it saves a checkpoint ``best_model_<acc>.pth``. It prints
    per-epoch summaries and final timing statistics.

    Raises:
        ValueError: if ``config['model_name']`` is unknown, or the dataset's
            classes disagree with ``config['num_classes']`` or between splits.
    """
    config = {
        'model_name': 'tiny_vit_5m_224',
        'num_classes': 6,
        'drop_key_prob': 0.1,
        'epochs': 100,
        'batch_size': 64,
        'lr': 2e-4,
        'weight_decay': 0.05,
        'seed': 42
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == 'cuda'

    # Reproducibility
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])
    random.seed(config['seed'])
    time_tracker = TimeTracker()

    # Data augmentation and loading.
    # NOTE: all transforms use 224 px, so only the *_224 backbones fit as-is.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder('soil/train', train_transform)
    val_dataset = datasets.ImageFolder('soil/val', val_transform)

    # The head is sized from config and labels are folder indices; fail fast
    # instead of silently training a mis-sized or mis-aligned classifier.
    if len(train_dataset.classes) != config['num_classes']:
        raise ValueError(f"config['num_classes']={config['num_classes']} but soil/train has "
                         f"{len(train_dataset.classes)} classes: {train_dataset.classes}")
    if val_dataset.classes != train_dataset.classes:
        raise ValueError(f"Class folders differ between splits: train={train_dataset.classes}, "
                         f"val={val_dataset.classes}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=4,
        pin_memory=use_cuda
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        num_workers=4,
        pin_memory=use_cuda
    )

    # Model: honour config['model_name'] instead of a hard-coded constructor.
    model_builders = {
        'tiny_vit_5m_224': tiny_vit_5m_224,
        'tiny_vit_11m_224': tiny_vit_11m_224,
        'tiny_vit_21m_224': tiny_vit_21m_224,
        'tiny_vit_21m_384': tiny_vit_21m_384,
        'tiny_vit_21m_512': tiny_vit_21m_512,
    }
    if config['model_name'] not in model_builders:
        raise ValueError(f"Unknown model_name {config['model_name']!r}; "
                         f"expected one of {sorted(model_builders)}")
    base_model = model_builders[config['model_name']](pretrained=True).to(device)
    model = EnhancedTinyViT(base_model, config['num_classes'], config['drop_key_prob']).to(device)

    # Training components
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    scaler = GradScaler(enabled=use_cuda)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])

    best_acc = 0
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'train_f1': [],
        'val_acc': [],
        'val_f1': [],
        'val_precision': [],
        'val_recall': [],
        'lr': []
    }

    for epoch in range(1, config['epochs'] + 1):
        time_tracker.start_epoch()
        # LR actually used for this epoch; read it before scheduler.step()
        # advances it to next epoch's value.
        epoch_lr = optimizer.param_groups[0]['lr']

        train_metrics = train_one_epoch(model, train_loader, optimizer, criterion, scaler, config, epoch, time_tracker)
        val_metrics = validate(model, val_loader, criterion)

        scheduler.step()

        for key in ('train_loss', 'train_acc', 'train_f1'):
            history[key].append(train_metrics[key])
        for key in ('val_loss', 'val_acc', 'val_f1', 'val_precision', 'val_recall'):
            history[key].append(val_metrics[key])
        history['lr'].append(epoch_lr)

        # Checkpoint on every improvement of validation accuracy.
        if val_metrics['val_acc'] > best_acc:
            best_acc = val_metrics['val_acc']
            torch.save(model.state_dict(), f'best_model_{best_acc:.4f}.pth')

        epoch_time = time_tracker.end_epoch()
        print(f"\nEpoch {epoch}/{config['epochs']} Summary:")
        print(f"Train Loss: {train_metrics['train_loss']:.4f} | Train Acc: {train_metrics['train_acc']:.4f} | Train F1: {train_metrics['train_f1']:.4f}")
        print(f"Val Loss: {val_metrics['val_loss']:.4f} | Val Acc: {val_metrics['val_acc']:.4f}")
        print(f"Val Precision: {val_metrics['val_precision']:.4f} | Val Recall: {val_metrics['val_recall']:.4f} | Val F1: {val_metrics['val_f1']:.4f}")
        print(f"Time: {epoch_time:.2f}s | LR: {epoch_lr:.2e}")

    # NOTE: post-training visualizations (Visualizer) are not generated here.

    print(f"\nTraining completed. Best Val Acc: {best_acc:.4f}")
    time_stats = time_tracker.get_stats()
    if time_stats:
        print(f"\nTime Statistics:")
        print(f"Total training time: {time_stats['total_epoch_time']:.2f}s")
        print(f"Average epoch time: {time_stats['avg_epoch_time']:.2f}s")
        print(f"Median epoch time: {time_stats['median_epoch_time']:.2f}s")
        print(f"Fastest epoch: {time_stats['min_epoch_time']:.2f}s")
        print(f"Slowest epoch: {time_stats['max_epoch_time']:.2f}s")

if __name__ == '__main__':
    # Script entry point (wandb.init() was removed from here).
    # `device` is assigned at module level on purpose: helper functions in
    # this file may read it as a global.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    main()
