# In [None]:

import os
import time
import random
import itertools
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    confusion_matrix,
    classification_report
)
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from ptflops import get_model_complexity_info
from models.tiny_vit import tiny_vit_5m_224, tiny_vit_11m_224, tiny_vit_21m_224, tiny_vit_21m_384, tiny_vit_21m_512

# ----------------- Data augmentation -----------------
# Training pipeline: random crop / flips / colour jitter / rotation on the PIL
# image, then tensor conversion, random erasing and per-channel normalisation
# (the commonly used ImageNet mean/std values).
# NOTE(review): RandomErasing runs *before* Normalize here, so erased pixels are
# filled in [0, 1] pixel space (torchvision's default fill value is 0, i.e. black)
# and become negative after normalisation. torchvision examples usually erase
# after Normalize -- confirm which behaviour is intended.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.2)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation pipeline: deterministic resize + centre crop to 224x224, same normalisation.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ----------------- Visualization tools -----------------
class Visualizer:
    """Static plotting helpers; each method saves one figure to ``filename`` and closes it."""

    @staticmethod
    def cnf_matrix_plotter(cm, classes, cmap=plt.cm.Blues, filename='confusion_matrix.pdf'):
        """Save an annotated confusion-matrix heatmap.

        Args:
            cm: 2-D array of counts; row ``i`` is drawn on the "True label" axis
                and column ``j`` on the "Predicted label" axis.
            classes: tick labels, one per row/column of ``cm``.
            cmap: matplotlib colormap for the heatmap.
            filename: output path passed to ``plt.savefig`` (600 dpi).
        """
        cm = np.asarray(cm)
        plt.figure(figsize=(8, 6))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title('Confusion Matrix', fontsize=14, pad=20)
        plt.colorbar()

        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45, fontsize=12)
        plt.yticks(tick_marks, classes, rotation=45, fontsize=12)

        # Cells above half of the maximum count get white text so the number stays
        # readable on the darker end of sequential colormaps such as Blues.
        thresh = cm.max() / 2.0 if cm.size else 0.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, f"{cm[i, j]}",
                     horizontalalignment="center",
                     verticalalignment="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=12)

        plt.ylabel('True label', fontsize=14)
        plt.xlabel('Predicted label', fontsize=14)
        plt.tight_layout()
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_tsne(features, labels, class_names, filename='tsne.pdf'):
        """Project ``features`` to 2-D with t-SNE and save a scatter plot coloured by class.

        Args:
            features: array-like of shape (n_samples, n_features), one row per sample.
            labels: integer class index for each row of ``features``.
            class_names: display names; entry ``i`` names class index ``i``.
            filename: output path passed to ``plt.savefig`` (600 dpi).

        Raises:
            ValueError: if ``features`` and ``labels`` have different lengths.
        """
        features = np.asarray(features)
        labels = np.asarray(labels)
        n_samples = len(features)
        if n_samples != len(labels):
            raise ValueError(
                f"plot_tsne needs one feature row per label, got {n_samples} rows "
                f"for {len(labels)} labels"
            )
        if n_samples < 2:
            print("Skipping t-SNE plot: at least 2 samples are required")
            return

        # scikit-learn rejects perplexity >= n_samples, so clamp it for small sets.
        perplexity = min(30.0, n_samples - 1)
        try:
            # scikit-learn >= 1.5 calls the iteration budget ``max_iter``
            # (the old ``n_iter`` name was removed in 1.7).
            tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, max_iter=300)
        except TypeError:
            # Older scikit-learn releases only accept ``n_iter``.
            tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, n_iter=300)
        features_tsne = tsne.fit_transform(features)

        plt.figure(figsize=(10, 8))
        colors = list(mcolors.TABLEAU_COLORS.values())

        for i, class_name in enumerate(class_names):
            mask = labels == i
            # Cycle the palette so more classes than colours do not raise IndexError.
            plt.scatter(features_tsne[mask, 0], features_tsne[mask, 1],
                        c=[colors[i % len(colors)]], label=class_name, s=50, alpha=0.6)

        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.xlabel('t-SNE 1', fontsize=14)
        plt.ylabel('t-SNE 2', fontsize=14)
        plt.legend(loc='upper right', fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close()

    @staticmethod
    def plot_training_metrics(history, filename='training_metrics.pdf'):
        """Save a 2x2 grid of loss / accuracy / F1 curves (train vs validation) and the LR.

        Args:
            history: dict of per-epoch lists; reads 'train_loss', 'val_loss',
                'train_acc', 'val_acc', 'train_f1', 'val_f1' and 'lr'.
            filename: output path passed to ``savefig`` (600 dpi).
        """
        epochs = range(1, len(history['train_loss']) + 1)

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

        # (axis, history-key suffix, title, y-axis label) for the train/val panels.
        curves = [
            (ax1, 'loss', 'Loss Curve', 'Loss'),
            (ax2, 'acc', 'Accuracy Curve', 'Accuracy'),
            (ax3, 'f1', 'F1 Score Curve', 'F1 Score'),
        ]
        for ax, metric, title, ylabel in curves:
            ax.plot(epochs, history[f'train_{metric}'], 'b-', label='Train')
            ax.plot(epochs, history[f'val_{metric}'], 'r-', label='Validation')
            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel(ylabel)
            ax.legend()
            ax.grid(True, alpha=0.3)

        ax4.plot(epochs, history['lr'], 'g-', label='Learning Rate')
        ax4.set_title('Learning Rate Schedule')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Learning Rate')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(filename, dpi=600, bbox_inches='tight')
        plt.close(fig)

# ----------------- Knowledge distillation loss -----------------
class DistillationLoss(nn.Module):
    """Blend a label-smoothed hard-target loss with a teacher-guided soft loss.

    ``loss = alpha * soft_loss + (1 - alpha) * hard_loss``

    Args:
        temperature: softening temperature for logit distillation.
        alpha: weight of the soft (teacher) term; ``1 - alpha`` weights the hard term.
        mode: ``'logits'`` -> KL divergence between temperature-softened student
            and teacher distributions (scaled by ``T**2``); any other value ->
            feature distillation on the ``'features'`` dicts.
    """

    def __init__(self, temperature=3.0, alpha=0.7, mode='logits'):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.mode = mode
        self.ce_loss = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, student_outputs, teacher_outputs, targets):
        """Return the combined loss.

        ``student_outputs``/``teacher_outputs`` are dicts with 'logits' and
        'features' entries. When ``teacher_outputs`` is None (e.g. during
        validation) only the hard-target loss is returned.

        Raises:
            ValueError: in feature mode, if the student returns features but the
                teacher's 'features' entry is None.
        """
        # Hard-target loss against the ground-truth labels.
        hard_loss = self.ce_loss(student_outputs['logits'], targets)
        if teacher_outputs is None:
            return hard_loss

        # Soft-target loss from the teacher.
        if self.mode == 'logits':
            soft_loss = F.kl_div(
                F.log_softmax(student_outputs['logits'] / self.T, dim=1),
                F.softmax(teacher_outputs['logits'] / self.T, dim=1),
                reduction='batchmean'
            ) * (self.T ** 2)
        else:
            soft_loss = self._feature_loss(
                student_outputs.get('features'), teacher_outputs.get('features')
            )

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

    def _feature_loss(self, student_feats, teacher_feats):
        """Sum per-pair feature losses between student and teacher feature dicts.

        Pairs are matched by name when every student name exists on the teacher
        side; otherwise the i-th student feature is paired with the i-th teacher
        feature (extra entries are ignored). Equal shapes use MSE, different
        shapes use the shape-agnostic similarity loss.
        """
        if not student_feats:
            return 0.0
        if teacher_feats is None:
            raise ValueError(
                "feature distillation needs teacher_outputs['features']; "
                "call the teacher with return_features=True"
            )
        if all(name in teacher_feats for name in student_feats):
            pairs = [(student_feats[name], teacher_feats[name]) for name in student_feats]
        else:
            pairs = list(zip(student_feats.values(), teacher_feats.values()))

        loss = 0.0
        for s_feat, t_feat in pairs:
            if s_feat.shape == t_feat.shape:
                loss = loss + F.mse_loss(s_feat, t_feat)
            else:
                loss = loss + self._similarity_loss(s_feat, t_feat)
        return loss

    @staticmethod
    def _similarity_loss(s_feat, t_feat):
        """Similarity-preserving KD loss (Tung & Mori, 2019), independent of channel width.

        Compares the row-L2-normalised batch Gram matrices of the two feature
        sets; the mean squared difference equals ||G_s - G_t||_F^2 / batch^2.
        """
        batch = s_feat.shape[0]
        s = s_feat.reshape(batch, -1).float()
        t = t_feat.reshape(batch, -1).float()
        # Keep the Gram matrices in fp32 even inside a CUDA autocast region.
        with torch.autocast(device_type='cuda', enabled=False):
            g_s = F.normalize(s @ s.t(), p=2, dim=1)
            g_t = F.normalize(t @ t.t(), p=2, dim=1)
            return F.mse_loss(g_s, g_t)


# ----------------- TinyViT model wrapper (with feature extraction) -----------------
class TinyViTWrapper(nn.Module):
    """TinyViT student that can also return pooled outputs of selected blocks.

    Args:
        model_name: one of the TinyViT factory names imported at the top of the file.
        num_classes: size of the classification head.
        pretrained: forwarded to the TinyViT factory.
        pretrained_path: optional local checkpoint; tensors whose name and shape
            match ``self.model`` are loaded, all others are skipped and reported.
        distill_layers: names (``'layers.{i}.blocks.{j}'`` or ``'layers.{i}'``)
            whose pooled outputs are returned as features. Defaults to every block
            of ``layers[len(layers) - 3]``.

    Raises:
        ValueError: for an unknown ``model_name``.
    """

    def __init__(self, model_name='tiny_vit_5m_224', num_classes=6, pretrained=True,
                 pretrained_path=None, distill_layers=None):
        super().__init__()
        factories = {
            'tiny_vit_5m_224': tiny_vit_5m_224,
            'tiny_vit_11m_224': tiny_vit_11m_224,
            'tiny_vit_21m_224': tiny_vit_21m_224,
            'tiny_vit_21m_384': tiny_vit_21m_384,
            'tiny_vit_21m_512': tiny_vit_21m_512,
        }
        if model_name not in factories:
            raise ValueError(
                f"Unknown TinyViT variant {model_name!r}; expected one of {sorted(factories)}"
            )
        self.model = factories[model_name](pretrained=pretrained, num_classes=num_classes)

        if pretrained_path:
            self._load_checkpoint(pretrained_path)

        self.distill_layers = distill_layers or [
            f'layers.{len(self.model.layers) - 3}.blocks.{i}'
            for i in range(self.model.depths[-3])
        ]

        # The old head stays registered so state_dicts remain compatible with
        # checkpoints saved earlier (they contain 'original_head.*' entries).
        self.original_head = self.model.head
        self.model.head = nn.Linear(self.model.head.in_features, num_classes)

        # (H, W) of the patch-embedding grid followed by one entry per layer.
        self.layer_resolutions = self._get_actual_resolutions()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Unwrap a ``{'model': ...}`` / ``{'state_dict': ...}`` container and strip 'module.' prefixes."""
        if isinstance(checkpoint, dict):
            for container_key in ('model', 'state_dict'):
                inner = checkpoint.get(container_key)
                if isinstance(inner, dict):
                    checkpoint = inner
                    break
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }

    def _load_checkpoint(self, path):
        """Load every checkpoint tensor whose name and shape match ``self.model``; report the rest."""
        # map_location='cpu' lets GPU-saved checkpoints load on any machine.
        checkpoint = torch.load(path, map_location='cpu', weights_only=True)
        state_dict = self._extract_state_dict(checkpoint)
        own_state = self.model.state_dict()
        compatible = {
            name: tensor for name, tensor in state_dict.items()
            if name in own_state and getattr(tensor, 'shape', None) == own_state[name].shape
        }
        self.model.load_state_dict(compatible, strict=False)
        prefix = "WARNING: " if not compatible else ""
        print(f"{prefix}Loaded {len(compatible)} of {len(own_state)} model tensors from {path}; "
              f"{len(state_dict) - len(compatible)} checkpoint entries skipped "
              f"(unknown name or shape mismatch)")

    def _get_actual_resolutions(self):
        """Return [(H, W)] for the patch embedding followed by one entry per layer.

        A layer's entry is its first block's ``input_resolution`` when present;
        otherwise the previous entry is repeated.
        """
        patch_res = self.model.patch_embed.patches_resolution
        resolutions = [(patch_res[0], patch_res[1])]
        for layer in self.model.layers:
            blocks = getattr(layer, 'blocks', None)
            if blocks is not None and len(blocks) > 0 and hasattr(blocks[0], 'input_resolution'):
                h, w = blocks[0].input_resolution
                resolutions.append((h, w))
            else:
                resolutions.append(resolutions[-1])
        return resolutions[:len(self.model.layers) + 1]

    @staticmethod
    def _pool_feature(x):
        """Average a block output to one vector per sample: (B, C, H, W) or (B, L, C) -> (B, C)."""
        if x.dim() == 4:
            return x.mean(dim=(2, 3))
        return x.mean(dim=1)

    def forward(self, x, return_features=False):
        """Classify ``x``; optionally also return pooled outputs of ``distill_layers``.

        Returns:
            dict with 'logits' and 'features' ({name: (B, C) tensor} when
            ``return_features`` is true, otherwise None).

        Raises:
            ValueError: if the patch-embedding grid differs from the one the
                model was built for.
        """
        if not return_features:
            return {
                'logits': self.model(x),
                'features': None
            }

        features = {}
        wanted = set(self.distill_layers)

        # 1. Patch embedding (kept in its native layout for the first stage).
        x = self.model.patch_embed(x)
        grid = (x.shape[2], x.shape[3])
        if grid != self.layer_resolutions[0]:
            raise ValueError(
                f"Initial resolution mismatch: expected {self.layer_resolutions[0]}, got {grid}"
            )

        # 2. Every stage runs ALL of its blocks first and only then its optional
        #    downsample.
        #    NOTE(review): this is meant to reproduce the model's own layer-by-layer
        #    forward so logits match self.model(x); verify against models/tiny_vit.py.
        for layer_idx, layer in enumerate(self.model.layers):
            blocks = getattr(layer, 'blocks', None)
            if blocks is None:
                x = layer(x)
                name = f'layers.{layer_idx}'
                if name in wanted:
                    features[name] = self._pool_feature(x)
                continue

            for block_idx, block in enumerate(blocks):
                x = block(x)
                name = f'layers.{layer_idx}.blocks.{block_idx}'
                if name in wanted:
                    features[name] = self._pool_feature(x)

            downsample = getattr(layer, 'downsample', None)
            if downsample is not None:
                x = downsample(x)

        # 3. Classification head; assumes the last stage emits a (B, L, C) token sequence.
        x = x.mean(1)  # (B, C)
        x = self.model.norm_head(x)
        logits = self.model.head(x)

        return {
            'logits': logits,
            'features': features
        }


# ----------------- ResNet teacher model wrapper -----------------
class ResNetTeacher(nn.Module):
    """torchvision ResNet teacher that can also return pooled stage outputs.

    NOTE: ``fc`` is replaced by a freshly initialised ``nn.Linear`` and nothing in
    this class loads weights into it, so the teacher's logits carry no
    task-specific knowledge unless it is trained or loaded elsewhere.

    Args:
        model_name: name of a torchvision ``models`` constructor (e.g. 'resnet50').
        num_classes: output size of the replacement ``fc`` layer.
    """

    _STAGES = ('layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, model_name='resnet50', num_classes=6):
        super().__init__()
        self.model = getattr(models, model_name)(pretrained=True)

        # Replace the final fully connected layer.
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

        # Stages whose globally averaged outputs are returned as features.
        self.feature_layers = ['layer1', 'layer2']

    def forward(self, x, return_features=False):
        """Return {'logits', 'features'}; features map stage name -> (B, C) when requested."""
        if not return_features:
            return {
                'logits': self.model(x),
                'features': None
            }

        net = self.model
        features = {}
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in self._STAGES:
            x = getattr(net, stage)(x)
            if stage in self.feature_layers:
                # flatten(., 1) keeps the batch dim even when B == 1 (squeeze() did not).
                features[stage] = torch.flatten(F.adaptive_avg_pool2d(x, (1, 1)), 1)

        x = torch.flatten(net.avgpool(x), 1)
        return {
            'logits': net.fc(x),
            'features': features
        }


# ----------------- Training time tracker -----------------
class TimeTracker:
    """Record wall-clock durations of epochs and of individual batches (seconds)."""

    def __init__(self):
        self.epoch_times = []
        self.batch_times = []
        self.start_time = None
        self._last_mark = None  # end of the previous batch (or epoch start)

    def epoch_start(self):
        self.start_time = time.perf_counter()
        self._last_mark = self.start_time

    def epoch_end(self):
        """Store and return the time since ``epoch_start``."""
        epoch_time = time.perf_counter() - self.start_time
        self.epoch_times.append(epoch_time)
        return epoch_time

    def batch_end(self):
        """Store the duration of the batch that just finished (not the cumulative time)."""
        now = time.perf_counter()
        self.batch_times.append(now - self._last_mark)
        self._last_mark = now

    def get_stats(self):
        return {
            'total_time': sum(self.epoch_times),
            'avg_epoch_time': float(np.mean(self.epoch_times)) if self.epoch_times else 0.0,
            'avg_batch_time': float(np.mean(self.batch_times)) if self.batch_times else 0,
            'epoch_times': self.epoch_times
        }


# ----------------- Training function -----------------
def train_one_epoch(student, teacher, loader, optimizer, scaler, criterion, epoch, time_tracker, config):
    """Run one distillation epoch over ``loader`` and return training metrics.

    Reads the module-level ``device`` global. ``config`` is accepted for
    interface compatibility and is not read.

    Returns:
        dict with 'train_loss' (mean per-batch loss), 'train_acc', 'train_f1'
        (macro) and 'epoch_time' (seconds).
    """
    student.train()
    # Intermediate features are only consumed by the feature-distillation mode.
    need_features = getattr(criterion, 'mode', 'features') != 'logits'
    use_amp = device.type == 'cuda'
    total_loss = 0.0
    preds, labels = [], []

    time_tracker.epoch_start()
    pbar = tqdm(loader, desc=f'Epoch {epoch} Training')

    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type='cuda', enabled=use_amp):
            # Teacher inference (no gradients).
            with torch.no_grad():
                teacher_outputs = teacher(images, return_features=need_features)
            # Student inference.
            student_outputs = student(images, return_features=need_features)
            # Distillation loss.
            loss = criterion(student_outputs, teacher_outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()  # synchronises, so the batch timing covers the GPU work
        time_tracker.batch_end()

        total_loss += loss_value
        preds.extend(torch.argmax(student_outputs['logits'], 1).cpu().numpy())
        labels.extend(targets.cpu().numpy())

        pbar.set_postfix({'Loss': loss_value})

    epoch_time = time_tracker.epoch_end()
    return {
        'train_loss': total_loss / max(len(loader), 1),
        'train_acc': accuracy_score(labels, preds),
        'train_f1': f1_score(labels, preds, average='macro', zero_division=0),
        'epoch_time': epoch_time
    }


# ----------------- Validation function -----------------
@torch.no_grad()
def evaluate(model, loader, criterion, epoch, is_best=False):
    """Evaluate ``model`` on ``loader`` and return validation metrics.

    The loss is ``criterion(outputs, None, targets)``; with ``DistillationLoss``
    that is the hard-target loss alone. When ``is_best`` is true this pass also
    saves a t-SNE plot (one point per sample) and a confusion matrix under
    ``checkpoint/`` and prints a classification report. Reads the module-level
    ``device`` global and ``idx_to_labels.npy`` from the working directory.
    """
    model.eval()
    use_amp = device.type == 'cuda'
    total_loss = 0.0
    preds, labels, embeddings = [], [], []

    for images, targets in tqdm(loader, desc='Evaluating'):
        images = images.to(device)
        targets = targets.to(device)

        with torch.autocast(device_type='cuda', enabled=use_amp):
            outputs = model(images, return_features=is_best)
            loss = criterion(outputs, None, targets)

        total_loss += loss.item()
        preds.extend(torch.argmax(outputs['logits'], 1).cpu().numpy())
        labels.extend(targets.cpu().numpy())

        if is_best:
            # One embedding row per sample: concatenated pooled features, or the
            # logits when the model returned no features.
            feats = outputs.get('features') or {}
            if feats:
                batch_emb = torch.cat(
                    [f.reshape(len(targets), -1).float() for f in feats.values()], dim=1
                )
            else:
                batch_emb = outputs['logits'].float()
            embeddings.append(batch_emb.cpu().numpy())

    metrics = {
        'val_loss': total_loss / max(len(loader), 1),
        'val_acc': accuracy_score(labels, preds),
        'val_precision': precision_score(labels, preds, average='macro', zero_division=0),
        'val_recall': recall_score(labels, preds, average='macro', zero_division=0),
        'val_f1': f1_score(labels, preds, average='macro', zero_division=0)
    }

    if is_best:
        # File written by main(); allow_pickle is needed for the stored dict.
        idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()
        class_names = [idx_to_labels[i] for i in range(len(idx_to_labels))]
        # Fix the label set so absent classes do not shrink the matrix/report.
        class_ids = list(range(len(class_names)))

        if embeddings and len(labels) > 1:
            Visualizer.plot_tsne(np.concatenate(embeddings), np.array(labels), class_names,
                                 filename=f'checkpoint/best_tsne_epoch{epoch}.pdf')

        cm = confusion_matrix(labels, preds, labels=class_ids)
        Visualizer.cnf_matrix_plotter(cm, class_names, filename=f'checkpoint/best_cm_epoch{epoch}.pdf')
        print("\nClassification Report:")
        print(classification_report(labels, preds, labels=class_ids, target_names=class_names,
                                    digits=4, zero_division=0))

    return metrics

# ----------------- Main -----------------
def main():
    """Train the TinyViT student by distillation from a ResNet teacher.

    Expects ImageFolder datasets under ``soil/train`` and ``soil/val``; writes
    ``idx_to_labels.npy`` to the working directory and all checkpoints, plots and
    statistics under ``checkpoint/``. Reads the module-level ``device`` global,
    which is set in the ``__main__`` guard.

    Raises:
        ValueError: if the train and val folders define different classes.
    """
    # Experiment configuration
    config = {
        'model_name': 'tiny_vit_5m_224',
        'teacher_model': 'resnet50',
        'num_classes': 6,
        'pretrained_path': 'tiny_vit_5m_22kto1k_distill.pth',
        'distill_mode': 'features',  # 'logits' or 'features'
        'temperature': 3.0,
        'alpha': 0.7,
        'epochs': 80,
        'batch_size': 64,
        'lr': 2e-4,
        'weight_decay': 0.05,
        'seed': 42
    }

    os.makedirs('checkpoint', exist_ok=True)
    time_tracker = TimeTracker()

    # Seed every RNG used by the pipeline
    random.seed(config['seed'])
    np.random.seed(config['seed'])
    torch.manual_seed(config['seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config['seed'])

    # Data loading
    train_dataset = datasets.ImageFolder('soil/train', train_transform)
    test_dataset = datasets.ImageFolder('soil/val', test_transform)
    if test_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            "soil/train and soil/val must contain the same class folders; got "
            f"{sorted(train_dataset.class_to_idx)} vs {sorted(test_dataset.class_to_idx)}"
        )

    # Persist the index -> class-name mapping; evaluate() reloads it for plots.
    idx_to_labels = {v: k for k, v in train_dataset.class_to_idx.items()}
    np.save('idx_to_labels.npy', idx_to_labels)

    train_loader = DataLoader(
        train_dataset, batch_size=config['batch_size'], shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config['batch_size'],
        num_workers=4, pin_memory=True, persistent_workers=True
    )

    # Models. The class count comes from the dataset; config['num_classes'] is not read.
    student = TinyViTWrapper(
        model_name=config['model_name'],
        num_classes=len(idx_to_labels),
        pretrained_path=config['pretrained_path']
    ).to(device)

    teacher = ResNetTeacher(
        model_name=config['teacher_model'],
        num_classes=len(idx_to_labels)
    ).to(device).eval()

    # Freeze the teacher
    for param in teacher.parameters():
        param.requires_grad = False

    # Model complexity. ptflops reports multiply-accumulate operations (MACs);
    # one MAC is conventionally counted as two FLOPs.
    macs, params = get_model_complexity_info(
        student, (3, 224, 224), as_strings=False, print_per_layer_stat=False
    )
    model_stats = {}
    if macs is not None and params is not None:
        model_stats = {
            'params(M)': params / 1e6,
            'FLOPs(G)': macs * 2 / 1e9,
            'MACs(G)': macs / 1e9
        }
        print("\nStudent Model Analysis:")
        print(f"Parameters: {model_stats['params(M)']:.2f}M")
        print(f"FLOPs: {model_stats['FLOPs(G)']:.2f}G")
        print(f"MACs: {model_stats['MACs(G)']:.2f}G")
    else:
        print("\nStudent Model Analysis: ptflops could not profile the model")

    # Training setup
    criterion = DistillationLoss(
        temperature=config['temperature'],
        alpha=config['alpha'],
        mode=config['distill_mode']
    )
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    scaler = GradScaler()
    # NOTE(review): stepped once per epoch, so restarts happen after epochs 10, 30
    # and 70; with 80 epochs the LR jumps back to its maximum for the last 10.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2
    )

    # Training loop
    best_acc = 0.0
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': [],
        'lr': [], 'epoch_times': []
    }

    for epoch in range(1, config['epochs'] + 1):
        # LR in effect for this epoch (the scheduler is stepped after training).
        epoch_lr = optimizer.param_groups[0]['lr']

        train_metrics = train_one_epoch(
            student, teacher, train_loader,
            optimizer, scaler, criterion,
            epoch, time_tracker, config
        )
        val_metrics = evaluate(student, test_loader, criterion, epoch)

        # Track the best model
        if val_metrics['val_acc'] > best_acc:
            best_acc = val_metrics['val_acc']
            torch.save(student.state_dict(), f'checkpoint/best_model_{best_acc:.3f}.pth')
            # Second pass only to produce the best-epoch plots and report.
            evaluate(student, test_loader, criterion, epoch, is_best=True)
            print(f'New best model saved with acc {best_acc:.3f}')

        # Record history
        for key in ('train_loss', 'train_acc', 'train_f1'):
            history[key].append(train_metrics[key])
        for key in ('val_loss', 'val_acc', 'val_f1'):
            history[key].append(val_metrics[key])
        history['lr'].append(epoch_lr)
        history['epoch_times'].append(train_metrics['epoch_time'])

        scheduler.step()

        # Per-epoch summary
        print(f"Epoch {epoch}/{config['epochs']}: "
              f"Train Loss: {train_metrics['train_loss']:.4f}, "
              f"Train Acc: {train_metrics['train_acc']:.4f}, "
              f"Val Loss: {val_metrics['val_loss']:.4f}, "
              f"Val Acc: {val_metrics['val_acc']:.4f}, "
              f"LR: {epoch_lr:.6f}")

    # Save the final model
    torch.save(student.state_dict(), 'checkpoint/final_model.pth')

    # End-of-training statistics
    timing = time_tracker.get_stats()
    training_stats = {
        'total_time': timing['total_time'],
        'avg_epoch_time': timing['avg_epoch_time'],
        'best_val_acc': best_acc,
        **model_stats
    }
    np.savez('checkpoint/final_stats.npz', **training_stats)

    # Training curves
    Visualizer.plot_training_metrics(history, filename='checkpoint/training_metrics.pdf')

    print("\nTraining Completed with Stats:")
    print(f"Total Training Time: {training_stats['total_time']/3600:.2f} hours")
    print(f"Average Epoch Time: {training_stats['avg_epoch_time']:.2f} seconds")
    print(f"Best Validation Accuracy: {best_acc:.4f}")

if __name__ == '__main__':
    # Script entry point. `device` becomes a module-level global that main(),
    # train_one_epoch() and evaluate() read directly.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Output directory for checkpoints, plots and statistics.
    os.makedirs('checkpoint', exist_ok=True)
    # Let cuDNN benchmark convolution algorithms for the fixed input size.
    torch.backends.cudnn.benchmark = True
    main()
