# In [None]:
"""
CMT-Mamba知识蒸馏半监督学习框架
数据集: CIFAR10
教师模型: CMT-Small with Mamba
学生模型: ResNet18
训练方式: 半监督学习
"""

import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore.common import set_seed
from mindspore.dataset import GeneratorDataset
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
from mindspore.train.callback import ModelCheckpoint, LossMonitor, TimeMonitor
import sys
import datetime
import logging
from tqdm import tqdm

# 导入MindCV相关模块
from mindcv.models import create_model
from mindcv.data import create_dataset, create_transforms, create_loader
from mindcv.loss import create_loss
from mindcv.scheduler import create_scheduler

# Fix the global random seed so that experiments are reproducible.
set_seed(42)

# Command-line interface. All options are parsed once at import time into the
# module-level ``args`` namespace, which the functions below read directly.
import argparse
parser = argparse.ArgumentParser(description='半监督知识蒸馏学习 CMT-Mamba → ResNet')

# Model options.
parser.add_argument('--teacher_model', type=str, default='cmt_small', help='教师模型名称')
parser.add_argument('--student_model', type=str, default='resnet18', help='学生模型名称')
# ``--use_mamba`` defaults to True, so on its own the store_true flag could never
# switch the option off; ``--no_use_mamba`` writes False into the same destination.
parser.add_argument('--use_mamba', dest='use_mamba', action='store_true', default=True,
                    help='教师模型是否使用Mamba模块')
parser.add_argument('--no_use_mamba', dest='use_mamba', action='store_false',
                    help='Disable the Mamba module in the teacher model')
parser.add_argument('--pretrained_teacher', type=str, default='', help='预训练教师模型路径')

# Dataset options.
parser.add_argument('--dataset', type=str, default='cifar10', help='数据集名称')
parser.add_argument('--data_dir', type=str, default='cifar-10-batches-bin', help='数据集路径')
parser.add_argument('--labeled_ratio', type=float, default=0.1, help='有标签数据比例')
parser.add_argument('--batch_size', type=int, default=64, help='批量大小')
parser.add_argument('--num_workers', type=int, default=8, help='数据加载器工作线程数')

# Training options.
parser.add_argument('--epochs', type=int, default=200, help='训练轮数')
parser.add_argument('--lr', type=float, default=0.01, help='学习率')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD动量')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='权重衰减')
parser.add_argument('--T', type=float, default=2.0, help='蒸馏温度')
parser.add_argument('--alpha', type=float, default=0.5, help='蒸馏损失权重')
parser.add_argument('--beta', type=float, default=0.3, help='一致性损失权重')
parser.add_argument('--ema_decay', type=float, default=0.999, help='指数移动平均衰减率')

# Device options.
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                    help='运行设备')
parser.add_argument('--device_id', type=int, default=0, help='设备ID')
parser.add_argument('--amp_level', type=str, default='O2', help='混合精度级别')

# Checkpoint options.
parser.add_argument('--save_dir', type=str, default='./checkpoints', help='模型保存路径')
parser.add_argument('--save_interval', type=int, default=10, help='模型保存间隔')

# NOTE(review): parse_args() reads sys.argv; inside a notebook kernel the
# launcher's own arguments may be rejected here - confirm how this is run.
args = parser.parse_args()

# Configure the MindSpore runtime (graph mode on the selected device).
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

# 创建数据增强
def create_weak_augmentation():
    """Build the weak augmentation pipeline.

    The pipeline upsamples to 224x224, applies a padded random crop and a
    random horizontal flip, normalizes per channel and converts HWC to CHW.

    NOTE(review): the mean/std values are on a 0-1 scale and no rescaling op
    precedes Normalize - confirm the incoming pixel range matches.

    Returns:
        list: transform objects in application order.
    """
    target_size = 224  # input resolution used for the CMT teacher
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    pipeline = [vision.Resize(target_size)]
    pipeline.append(vision.RandomCrop(target_size, padding=28))
    pipeline.append(vision.RandomHorizontalFlip(prob=0.5))
    pipeline.append(vision.Normalize(mean=channel_mean, std=channel_std))
    pipeline.append(vision.HWC2CHW())
    return pipeline

def create_strong_augmentation():
    """Build the strong augmentation pipeline (a hand-rolled RandAugment substitute).

    On top of the weak pipeline's resize / padded crop / flip, this adds color
    jitter, a random affine warp and random erasing.

    NOTE(review): the mean/std values are on a 0-1 scale and no rescaling op
    precedes Normalize - confirm the incoming pixel range matches.

    Returns:
        list: transform objects in application order.
    """
    return [
        # Upsample to the 224x224 resolution used for the CMT teacher.
        vision.Resize(224),
        vision.RandomCrop(224, padding=28),
        vision.RandomHorizontalFlip(prob=0.5),
        # RandAugment stand-in: a fixed combination of photometric and
        # geometric perturbations.
        vision.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        vision.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5),
        vision.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW(),
        # RandomErasing must run after HWC2CHW: it interprets its input as a
        # <C, H, W> array, so on an HWC image it would pick the rectangle over
        # the wrong axes. Erasing with value 0 after normalization fills the
        # patch with the per-channel mean.
        vision.RandomErasing(prob=0.2),
    ]

# 自定义EMA模型更新器
class EMA:
    """Exponential-moving-average bookkeeping between two networks.

    A private copy ("shadow") of ``shadow_model``'s parameters is cloned at
    construction time. ``update`` blends the current parameters of ``model``
    into that copy, ``apply_shadow`` writes the copy into ``model`` (after
    backing up the live values) and ``restore`` undoes ``apply_shadow``.
    Parameters are paired purely by their position in ``get_parameters()``.
    """

    def __init__(self, model, shadow_model, decay=0.999):
        self.model = model
        self.shadow_model = shadow_model
        self.decay = decay
        # Detached clones: updating them never touches shadow_model itself.
        self.shadow_params = []
        for source_param in shadow_model.get_parameters():
            self.shadow_params.append(source_param.clone())
        self.backup_params = []

    def update(self):
        """Blend the live parameters of ``model`` into the shadow copies."""
        live_params = list(self.model.get_parameters())
        for position in range(len(self.shadow_params)):
            shadow = self.shadow_params[position]
            blended = self.decay * shadow + (1 - self.decay) * live_params[position]
            shadow.assign_value(blended)

    def apply_shadow(self):
        """Back up the live parameters of ``model``, then overwrite them with the shadow values."""
        live_params = list(self.model.get_parameters())
        snapshot = []
        for live in live_params:
            snapshot.append(live.clone())
        self.backup_params = snapshot
        for position, live in enumerate(live_params):
            live.assign_value(self.shadow_params[position])

    def restore(self):
        """Write the values saved by ``apply_shadow`` back into ``model``."""
        position = 0
        for live in list(self.model.get_parameters()):
            live.assign_value(self.backup_params[position])
            position += 1


# 定义CMT-Mamba知识蒸馏半监督训练流程
def train_distill():
    """Run semi-supervised knowledge distillation from the teacher to the student.

    Steps, as implemented below:
      1. Configure logging to ``logs/cmt_mamba_distill_<timestamp>/training.log``
         and stdout.
      2. Load the training split into memory, shuffle it with a fixed seed and
         split it per class into labeled / unlabeled subsets according to
         ``args.labeled_ratio``.
      3. Build the labeled (weak aug), unlabeled (weak + strong aug) and test
         pipelines.
      4. Create teacher and student with ``create_model``. The teacher is
         loaded from ``args.pretrained_teacher`` when given, otherwise trained
         on the labeled subset via ``train_teacher``.
      5. Train the student with hard-label cross entropy, temperature-scaled
         KL distillation and a confidence-masked pseudo-label consistency loss.
      6. Evaluate after every epoch; save the best and periodic checkpoints
         under ``args.save_dir``.

    All configuration is read from the module-level ``args`` namespace.

    Raises:
        ValueError: if the labeled or unlabeled subset is empty after the split.
    """
    # ---- Logging ---------------------------------------------------------
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join("logs", f"cmt_mamba_distill_{current_time}")
    os.makedirs(log_dir, exist_ok=True)

    log_file = os.path.join(log_dir, "training.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )

    # Record the run configuration.
    logging.info("启动CMT-Mamba知识蒸馏半监督训练")
    logging.info(f"配置参数: {vars(args)}")

    # ---- Data: load and split ---------------------------------------------
    cifar10_dataset = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split='train',
        download=False
    )

    # Materialize the training split as (image ndarray, int label) pairs so it
    # can be split manually.
    dataset_np = []
    print("加载并转换CIFAR10数据集...")
    for img, label in cifar10_dataset:
        dataset_np.append((img.asnumpy(), label.asnumpy().item()))

    print(f"数据集大小: {len(dataset_np)}")
    rng = np.random.RandomState(42)
    rng.shuffle(dataset_np)

    # Group sample indices by class (10 classes are hard-coded here).
    class_indices = [[] for _ in range(10)]
    for i, (img, label) in enumerate(dataset_np):
        class_indices[label].append(i)

    # Stratified split: the first labeled_ratio share of each class is labeled.
    labeled_indices = []
    unlabeled_indices = []
    for indices in class_indices:
        n_labeled = max(1, int(len(indices) * args.labeled_ratio))
        labeled_indices.extend(indices[:n_labeled])
        unlabeled_indices.extend(indices[n_labeled:])

    labeled_data = [(dataset_np[i][0], dataset_np[i][1]) for i in labeled_indices]
    unlabeled_data = [(dataset_np[i][0], dataset_np[i][1]) for i in unlabeled_indices]

    print(f"有标签数据量: {len(labeled_data)}, 无标签数据量: {len(unlabeled_data)}")

    # The generators below loop forever; with an empty subset they would spin
    # without ever yielding, so fail early instead.
    if not labeled_data or not unlabeled_data:
        raise ValueError(
            "Both the labeled and the unlabeled subset must be non-empty "
            f"(labeled={len(labeled_data)}, unlabeled={len(unlabeled_data)}); "
            "check the dataset and --labeled_ratio."
        )

    # Optimizer steps per epoch: bounded by the smaller subset and never zero,
    # so the LR schedule and the average-loss division stay well defined.
    steps_per_epoch = max(
        1,
        min(len(labeled_data) // args.batch_size, len(unlabeled_data) // args.batch_size),
    )

    # ---- Data: pipelines ---------------------------------------------------
    weak_transform = create_weak_augmentation()
    strong_transform = create_strong_augmentation()

    def labeled_generator():
        """Yield (image, label) pairs forever, reshuffling on every pass."""
        indices = list(range(len(labeled_data)))
        while True:
            rng.shuffle(indices)
            for idx in indices:
                yield labeled_data[idx]

    labeled_ds = ds.GeneratorDataset(
        source=labeled_generator(),
        column_names=["image", "label"],
        shuffle=True
    )
    labeled_ds = labeled_ds.map(operations=weak_transform, input_columns=["image"])
    labeled_ds = labeled_ds.batch(args.batch_size)

    def unlabeled_generator():
        """Yield the same raw image twice (weak view, strong view) forever."""
        indices = list(range(len(unlabeled_data)))
        while True:
            rng.shuffle(indices)
            for idx in indices:
                img, _ = unlabeled_data[idx]
                yield img, img

    unlabeled_ds = ds.GeneratorDataset(
        source=unlabeled_generator(),
        column_names=["weak_image", "strong_image"],
        shuffle=True
    )
    unlabeled_ds = unlabeled_ds.map(operations=weak_transform, input_columns=["weak_image"])
    unlabeled_ds = unlabeled_ds.map(operations=strong_transform, input_columns=["strong_image"])
    unlabeled_ds = unlabeled_ds.batch(args.batch_size)

    # Test-time preprocessing: resize + normalize only, no random ops.
    test_transform = [
        vision.Resize(224),
        vision.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    test_ds = create_dataset(
        name=args.dataset,
        root=args.data_dir,
        split='test',
        download=False
    )
    test_ds = test_ds.map(operations=test_transform, input_columns=["image"])
    test_ds = test_ds.batch(args.batch_size)

    # ---- Models ------------------------------------------------------------
    print("创建教师模型(CMT-Small)和学生模型(ResNet18)...")

    teacher_model = create_model(
        args.teacher_model,
        num_classes=10,
        in_channels=3,
        use_mamba=args.use_mamba,
    )

    student_model = create_model(
        args.student_model,
        num_classes=10,
        in_channels=3,
    )

    if args.pretrained_teacher:
        # A checkpoint path was supplied: load it instead of leaving the
        # teacher at its random initialization.
        logging.info(f"加载预训练教师模型: {args.pretrained_teacher}")
        teacher_params = ms.load_checkpoint(args.pretrained_teacher)
        ms.load_param_into_net(teacher_model, teacher_params)
    else:
        logging.info("未找到预训练教师模型，开始使用有标签数据训练教师模型...")

        teacher_optimizer = nn.Momentum(
            params=teacher_model.trainable_params(),
            learning_rate=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )

        # Supervised pre-training of the teacher on the labeled subset only.
        train_teacher(teacher_model, teacher_optimizer, labeled_ds, test_ds, epochs=100, labeled_size=len(labeled_data))

        ms.save_checkpoint(teacher_model, "teacher_pretrained.ckpt")
        logging.info("教师模型预训练完成，已保存检查点")

    # ---- Student optimizer ---------------------------------------------------
    # Cosine schedule sized to the number of steps the loop below really runs.
    lr = nn.cosine_decay_lr(
        min_lr=args.lr * 0.01,
        max_lr=args.lr,
        total_step=args.epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=args.epochs
    )

    optimizer = nn.Momentum(
        params=student_model.trainable_params(),
        learning_rate=lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # NOTE(review): this EMA helper is constructed but never updated or applied
    # anywhere in this function - confirm whether it is still needed.
    teacher_ema = EMA(student_model, teacher_model, decay=args.ema_decay)

    class DistillationTrainStep(nn.Cell):
        """One optimization step of the student under the combined loss."""

        def __init__(self, student_model, teacher_model, optimizer, args):
            super(DistillationTrainStep, self).__init__()
            self.student_model = student_model
            self.teacher_model = teacher_model
            self.optimizer = optimizer
            self.args = args
            # Plain scalars and loss cells are created once here rather than
            # inside the traced forward pass.
            self.temperature = args.T
            self.alpha = args.alpha
            self.beta = args.beta
            self.ce_mean = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
            self.ce_none = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
            # Gradients are taken w.r.t. the student's trainable parameters
            # only, matching the parameters handed to the optimizer.
            self.weights = self.student_model.trainable_params()
            self.grad_fn = ops.value_and_grad(self.forward, None, self.weights, has_aux=False)

        def forward(self, labeled_imgs, labels, unlabeled_weak, unlabeled_strong):
            """Return the scalar total loss for one labeled + unlabeled batch."""
            # Teacher outputs are targets only, so cut them out of the backward graph.
            teacher_labeled_logits = ops.stop_gradient(self.teacher_model(labeled_imgs))
            teacher_unlabeled_logits = ops.stop_gradient(self.teacher_model(unlabeled_weak))

            student_labeled_logits = self.student_model(labeled_imgs)
            student_unlabeled_strong_logits = self.student_model(unlabeled_strong)

            # Supervised + distillation loss on the labeled batch.
            distill_loss = self._distill_loss_fn(student_labeled_logits, teacher_labeled_logits, labels)

            # Pseudo-labels from the teacher's prediction on the weak view;
            # only predictions with confidence above 0.95 contribute.
            pseudo_labels = ops.softmax(teacher_unlabeled_logits, axis=1)
            max_probs, targets = ops.max(pseudo_labels, axis=1)
            mask = (max_probs > 0.95).astype(ms.float32)

            # Student prediction on the strong view must match the pseudo-label.
            consistency_loss = self.ce_none(student_unlabeled_strong_logits, targets)
            consistency_loss = (consistency_loss * mask).mean()

            total_loss = distill_loss + self.beta * consistency_loss

            return total_loss

        def _distill_loss_fn(self, student_logits, teacher_logits, labels):
            """Combine hard-label CE with temperature-scaled KL to the teacher."""
            hard_loss = self.ce_mean(student_logits, labels)

            T = self.temperature
            soft_student = ops.softmax(student_logits / T, axis=1)
            soft_teacher = ops.softmax(teacher_logits / T, axis=1)

            # The small epsilon keeps the log finite; the T*T factor rescales
            # the soft-target term.
            kl_div = ops.kl_div(
                ops.log(soft_student + 1e-10),
                soft_teacher,
                reduction='batchmean'
            ) * (T * T)

            return (1 - self.alpha) * hard_loss + self.alpha * kl_div

        def construct(self, labeled_imgs, labels, unlabeled_weak, unlabeled_strong):
            loss, grads = self.grad_fn(labeled_imgs, labels, unlabeled_weak, unlabeled_strong)
            self.optimizer(grads)
            return loss

    net = DistillationTrainStep(student_model, teacher_model, optimizer, args)

    def eval_model():
        """Return (top-1, top-5) accuracy of the student on the test set."""
        top1_correct = 0
        top5_correct = 0
        total = 0
        student_model.set_train(False)

        logging.info("开始评估模型...")
        for data in tqdm(test_ds.create_dict_iterator(), desc="Evaluating"):
            images = data["image"]
            labels = data["label"]
            outputs = student_model(images)

            # Top-1 prediction.
            _, top1_pred = ops.max(outputs, axis=1)
            top1_pred = top1_pred.astype(ms.int32)
            labels = labels.astype(ms.int32)

            # Top-5 predictions.
            top5_preds = ops.topk(outputs, 5)[1]
            top5_preds = top5_preds.astype(ms.int32)

            # A sample is top-5 correct if any of its 5 predictions equals the label.
            top5_correct_per_sample = ops.equal(
                top5_preds,
                ops.reshape(labels, (-1, 1)).repeat(5, axis=1)
            )
            top5_correct_mask = ops.cast(ops.any(top5_correct_per_sample, 1), ms.float32)

            top1_correct_mask = ops.cast(top1_pred == labels, ms.float32)

            total += labels.shape[0]
            top1_correct += top1_correct_mask.sum().asnumpy().item()
            top5_correct += top5_correct_mask.sum().asnumpy().item()

        top1_acc = top1_correct / total
        top5_acc = top5_correct / total

        logging.info(f"评估结果 - TOP1准确率: {top1_acc:.4f}, TOP5准确率: {top5_acc:.4f}")
        return top1_acc, top5_acc

    os.makedirs(args.save_dir, exist_ok=True)

    # ---- Training loop -------------------------------------------------------
    print("开始知识蒸馏半监督训练...")
    best_acc = 0.0

    labeled_iter = labeled_ds.create_dict_iterator()
    unlabeled_iter = unlabeled_ds.create_dict_iterator()

    for epoch in range(args.epochs):
        student_model.set_train(True)
        teacher_model.set_train(False)  # the teacher always stays in eval mode

        total_loss = 0.0
        steps = steps_per_epoch

        logging.info(f"Epoch {epoch+1}/{args.epochs} 开始训练...")
        epoch_progress = tqdm(range(steps), desc=f"Epoch {epoch+1}/{args.epochs}")

        for step in epoch_progress:
            # Restart an iterator if it ever runs dry.
            try:
                labeled_batch = next(labeled_iter)
            except StopIteration:
                labeled_iter = labeled_ds.create_dict_iterator()
                labeled_batch = next(labeled_iter)

            try:
                unlabeled_batch = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = unlabeled_ds.create_dict_iterator()
                unlabeled_batch = next(unlabeled_iter)

            labeled_imgs = labeled_batch["image"]
            labels = labeled_batch["label"]
            weak_imgs = unlabeled_batch["weak_image"]
            strong_imgs = unlabeled_batch["strong_image"]

            loss = net(labeled_imgs, labels, weak_imgs, strong_imgs)
            # Fetch the scalar once and reuse it for accumulation and display.
            loss_value = loss.asnumpy().item()
            total_loss += loss_value

            epoch_progress.set_postfix(loss=f"{loss_value:.4f}")

            if step % 50 == 0:
                logging.info(f"Epoch: {epoch+1}/{args.epochs}, Step: {step+1}/{steps}, Loss: {loss_value:.4f}")

        avg_loss = total_loss / steps
        logging.info(f"Epoch: {epoch+1}/{args.epochs}, Avg Loss: {avg_loss:.4f}")

        # Evaluate the student.
        top1_acc, top5_acc = eval_model()
        logging.info(f"Epoch: {epoch+1}/{args.epochs}, TOP1 Accuracy: {top1_acc:.4f}, TOP5 Accuracy: {top5_acc:.4f}")

        # Keep the best checkpoint by top-1 accuracy.
        if top1_acc > best_acc:
            best_acc = top1_acc
            ms.save_checkpoint(student_model, os.path.join(args.save_dir, f"{args.student_model}_best.ckpt"))
            logging.info(f"已保存最佳模型，TOP1准确率: {best_acc:.4f}, TOP5准确率: {top5_acc:.4f}")

        # Periodic checkpoint.
        if (epoch + 1) % args.save_interval == 0:
            ms.save_checkpoint(student_model, os.path.join(args.save_dir, f"{args.student_model}_epoch{epoch+1}.ckpt"))
            logging.info(f"已保存第{epoch+1}轮检查点")

    print(f"训练完成！最佳TOP1准确率: {best_acc:.4f}")


def train_teacher(model, optimizer, train_ds, test_ds, epochs, labeled_size):
    """Train the teacher on labeled data only and keep its best-scoring weights.

    Args:
        model: network to train; its parameters are updated in place.
        optimizer: accepted for interface compatibility but not used - a new
            Momentum optimizer with a cosine-decayed learning rate is built
            here from the module-level ``args``.
        train_ds: dataset whose dict iterator yields "image" and "label".
        test_ds: dataset used for the per-epoch accuracy evaluation.
        epochs: number of training epochs.
        labeled_size: number of labeled samples; together with
            ``args.batch_size`` it sets the number of steps per epoch.

    Side effects:
        Writes ``teacher_best.ckpt`` in the working directory whenever the
        accuracy improves and reloads it into ``model`` at the end.
    """
    # At least one step per epoch, so the schedule and the average-loss
    # division stay well defined when labeled_size < batch_size.
    steps_per_epoch = max(1, labeled_size // args.batch_size)
    lr = nn.cosine_decay_lr(
        min_lr=0.0,
        max_lr=args.lr,
        total_step=epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=epochs
    )

    # The per-step learning-rate list is passed directly, the same way the
    # student optimizer in train_distill is configured.
    optimizer = nn.Momentum(
        params=model.trainable_params(),
        learning_rate=lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    class TeacherTrainStep(nn.Cell):
        """One supervised optimization step (cross entropy on hard labels)."""

        def __init__(self, network, optimizer):
            super(TeacherTrainStep, self).__init__()
            self.network = network
            self.optimizer = optimizer
            self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
            self.grad_fn = ops.value_and_grad(self.forward, None, self.network.trainable_params())

        def forward(self, images, labels):
            logits = self.network(images)
            loss = self.loss_fn(logits, labels)
            return loss

        def construct(self, images, labels):
            loss, grads = self.grad_fn(images, labels)
            self.optimizer(grads)
            return loss

    train_step = TeacherTrainStep(model, optimizer)

    def evaluate():
        """Return top-1 accuracy of ``model`` on ``test_ds``."""
        correct = 0
        total = 0
        model.set_train(False)
        for data in test_ds.create_dict_iterator():
            images = data["image"]
            labels = data["label"]
            outputs = model(images)
            _, predicted = ops.max(outputs, axis=1)
            predicted = predicted.astype(ms.int32)
            labels = labels.astype(ms.int32)
            total += labels.shape[0]
            correct += (predicted == labels).astype(ms.float32).sum().asnumpy().item()
        return correct / total

    logging.info("开始训练教师模型...")
    best_acc = 0.0
    # Tracks whether this run has written teacher_best.ckpt, so that a stale
    # file from an earlier run is never loaded at the end.
    best_saved = False

    for epoch in range(epochs):
        model.set_train(True)
        total_loss = 0.0
        steps = 0

        train_iter = tqdm(range(steps_per_epoch), desc=f"Teacher Epoch {epoch+1}/{epochs}")

        data_iter = train_ds.create_dict_iterator()
        for _ in train_iter:
            # Restart the iterator if it runs dry before the epoch is complete.
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = train_ds.create_dict_iterator()
                data = next(data_iter)

            images = data["image"]
            labels = data["label"]
            loss = train_step(images, labels)
            # Fetch the scalar once and reuse it for accumulation and display.
            loss_value = loss.asnumpy().item()
            total_loss += loss_value
            steps += 1

            train_iter.set_postfix(loss=f"{loss_value:.4f}")

        avg_loss = total_loss / steps

        acc = evaluate()

        logging.info(f"Teacher Epoch: {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

        # Save on improvement; the first epoch always saves, so a checkpoint
        # from this run exists even if the accuracy never rises above 0.
        if acc > best_acc or not best_saved:
            best_acc = acc
            ms.save_checkpoint(model, "teacher_best.ckpt")
            best_saved = True
            logging.info(f"保存最佳教师模型，准确率: {acc:.4f}")

    # Reload the best weights of this run (skipped if nothing was trained).
    if best_saved:
        param_dict = ms.load_checkpoint("teacher_best.ckpt")
        ms.load_param_into_net(model, param_dict)
    logging.info(f"教师模型训练完成，最佳准确率: {best_acc:.4f}")


# Script entry point: run the full distillation pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_distill() 

# In [None]:
"""
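CMT-Mamba knowledge-distillation framework for semi-supervised learning
Datasets: CIFAR10 (default), CIFAR100, CUB200 (selected with --dataset)
Teacher model: CMT-Small with Mamba
Student model: ResNet18
Training mode: semi-supervised learning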
"""

import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore.common import set_seed
from mindspore.dataset import GeneratorDataset
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision
from mindspore.train.callback import ModelCheckpoint, LossMonitor, TimeMonitor
import sys
import datetime
import logging
from tqdm import tqdm
from PIL import Image
from typing import Optional, Callable, Tuple, Any

# 导入MindCV相关模块
from mindcv.models import create_model
from mindcv.data import create_dataset, create_transforms, create_loader
from mindcv.loss import create_loss
from mindcv.scheduler import create_scheduler

# Fix the global random seed so that experiments are reproducible.
set_seed(42)

# Command-line interface. All options are parsed once at import time into the
# module-level ``args`` namespace, which the functions below read directly.
import argparse
parser = argparse.ArgumentParser(description='半监督知识蒸馏学习 CMT-Mamba → ResNet')

# Model options.
parser.add_argument('--teacher_model', type=str, default='cmt_small', help='教师模型名称')
parser.add_argument('--student_model', type=str, default='resnet18', help='学生模型名称')
# ``--use_mamba`` defaults to True, so on its own the store_true flag could never
# switch the option off; ``--no_use_mamba`` writes False into the same destination.
parser.add_argument('--use_mamba', dest='use_mamba', action='store_true', default=True,
                    help='教师模型是否使用Mamba模块')
parser.add_argument('--no_use_mamba', dest='use_mamba', action='store_false',
                    help='Disable the Mamba module in the teacher model')
parser.add_argument('--pretrained_teacher', type=str, default='', help='预训练教师模型路径')

# Dataset options.
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'cub200'],
                    help='数据集名称')
parser.add_argument('--num_classes', type=int, default=10, help='类别数量(cifar10为10，cifar100为100，cub200为200)')
parser.add_argument('--data_dir', type=str, default='', help='数据集路径')
parser.add_argument('--labeled_ratio', type=float, default=0.1, help='有标签数据比例')
parser.add_argument('--batch_size', type=int, default=64, help='批量大小')
parser.add_argument('--num_workers', type=int, default=2, help='数据加载器工作线程数')

# Training options.
parser.add_argument('--epochs', type=int, default=200, help='训练轮数')
parser.add_argument('--lr', type=float, default=0.01, help='学习率')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD动量')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='权重衰减')
parser.add_argument('--T', type=float, default=2.0, help='蒸馏温度')
parser.add_argument('--alpha', type=float, default=0.5, help='蒸馏损失权重')
parser.add_argument('--beta', type=float, default=0.3, help='一致性损失权重')
parser.add_argument('--ema_decay', type=float, default=0.999, help='指数移动平均衰减率')

# Device options.
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                    help='运行设备')
parser.add_argument('--device_id', type=int, default=0, help='设备ID')
parser.add_argument('--amp_level', type=str, default='O2', help='混合精度级别')

# Checkpoint options.
parser.add_argument('--save_dir', type=str, default='./checkpoints', help='模型保存路径')
parser.add_argument('--save_interval', type=int, default=10, help='模型保存间隔')

# NOTE(review): parse_args() reads sys.argv; inside a notebook kernel the
# launcher's own arguments may be rejected here - confirm how this is run.
args = parser.parse_args()

# Configure the MindSpore runtime (graph mode on the selected device).
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

# 创建数据增强
def create_weak_augmentation():
    """Build the weak augmentation pipeline for the selected dataset.

    Normalization statistics depend on ``args.dataset``: the CIFAR-10 values
    are used for 'cifar10', and the CIFAR-100 values for every other dataset
    name.

    NOTE(review): the mean/std values are on a 0-1 scale and no rescaling op
    precedes Normalize - confirm the incoming pixel range matches.

    Returns:
        list: transform objects in application order.
    """
    is_cifar10 = args.dataset == 'cifar10'
    # Any dataset other than 'cifar10' falls back to the CIFAR-100 statistics.
    mean = [0.4914, 0.4822, 0.4465] if is_cifar10 else [0.5071, 0.4867, 0.4408]
    std = [0.2023, 0.1994, 0.2010] if is_cifar10 else [0.2675, 0.2565, 0.2761]

    target_size = 224  # input resolution used for the CMT teacher
    pipeline = [vision.Resize(target_size)]
    pipeline.append(vision.RandomCrop(target_size, padding=28))
    pipeline.append(vision.RandomHorizontalFlip(prob=0.5))
    pipeline.append(vision.Normalize(mean=mean, std=std))
    pipeline.append(vision.HWC2CHW())
    return pipeline

def create_strong_augmentation():
    """Build the strong augmentation pipeline (a hand-rolled RandAugment substitute).

    On top of the weak pipeline's resize / padded crop / flip, this adds color
    jitter, a random affine warp and random erasing. Normalization statistics
    depend on ``args.dataset``: the CIFAR-10 values are used for 'cifar10',
    and the CIFAR-100 values for every other dataset name.

    NOTE(review): the mean/std values are on a 0-1 scale and no rescaling op
    precedes Normalize - confirm the incoming pixel range matches.

    Returns:
        list: transform objects in application order.
    """
    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    else:  # cifar100 statistics (also used for any other dataset name)
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    return [
        # Upsample to the 224x224 resolution used for the CMT teacher.
        vision.Resize(224),
        vision.RandomCrop(224, padding=28),
        vision.RandomHorizontalFlip(prob=0.5),
        # RandAugment stand-in: a fixed combination of photometric and
        # geometric perturbations.
        vision.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        vision.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=5),
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW(),
        # RandomErasing must run after HWC2CHW: it interprets its input as a
        # <C, H, W> array, so on an HWC image it would pick the rectangle over
        # the wrong axes. Erasing with value 0 after normalization fills the
        # patch with the per-channel mean.
        vision.RandomErasing(prob=0.2),
    ]

# 创建一个函数来处理测试数据
def process_test_batch(batch_data):
    """Extract (images, labels) from a test batch and apply the test transforms.

    Args:
        batch_data: either a dict with an "image" entry and a "label" (or
            "fine_label") entry, or a sequence whose first two items are the
            images and the labels. Extra trailing items are ignored.

    Returns:
        tuple: ``(images, labels)`` where ``images`` has been passed through
        Resize(224) -> Normalize -> HWC2CHW and ``labels`` is returned as-is.
    """
    # Decide on the container type first: unpacking a dict like a tuple would
    # yield its keys (or fail when the key count does not match).
    if isinstance(batch_data, dict):
        images = batch_data.get("image")
        labels = batch_data.get("label")
        if labels is None:
            # Batches without a "label" entry may carry "fine_label" instead.
            labels = batch_data.get("fine_label")
    else:
        # Positional layout: image first, label second. This covers 2-item
        # batches as well as 3-item batches.
        # NOTE(review): for 3-item CIFAR-100 batches this assumes the fine
        # label is the second item - confirm the column order of the loader.
        images, labels = batch_data[0], batch_data[1]

    # Dataset-specific normalization statistics; every dataset other than
    # 'cifar10' uses the CIFAR-100 values.
    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    else:
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    test_transform = [
        vision.Resize(224),  # match the resolution used during training
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW()
    ]

    for op in test_transform:
        images = op(images)

    return images, labels

# 自定义EMA模型更新器
class EMA:
    """Exponential-moving-average bookkeeping between two networks.

    A private copy ("shadow") of ``shadow_model``'s parameters is cloned at
    construction time. ``update`` blends the current parameters of ``model``
    into that copy, ``apply_shadow`` writes the copy into ``model`` (after
    backing up the live values) and ``restore`` undoes ``apply_shadow``.
    Parameters are paired purely by their position in ``get_parameters()``.
    """

    def __init__(self, model, shadow_model, decay=0.999):
        self.model = model
        self.shadow_model = shadow_model
        self.decay = decay
        # Detached clones: updating them never touches shadow_model itself.
        self.shadow_params = []
        for source_param in shadow_model.get_parameters():
            self.shadow_params.append(source_param.clone())
        self.backup_params = []

    def update(self):
        """Blend the live parameters of ``model`` into the shadow copies."""
        live_params = list(self.model.get_parameters())
        for position in range(len(self.shadow_params)):
            shadow = self.shadow_params[position]
            blended = self.decay * shadow + (1 - self.decay) * live_params[position]
            shadow.assign_value(blended)

    def apply_shadow(self):
        """Back up the live parameters of ``model``, then overwrite them with the shadow values."""
        live_params = list(self.model.get_parameters())
        snapshot = []
        for live in live_params:
            snapshot.append(live.clone())
        self.backup_params = snapshot
        for position, live in enumerate(live_params):
            live.assign_value(self.shadow_params[position])

    def restore(self):
        """Write the values saved by ``apply_shadow`` back into ``model``."""
        position = 0
        for live in list(self.model.get_parameters()):
            live.assign_value(self.backup_params[position])
            position += 1

# 添加CUB200Dataset类
class CUB200Dataset:
    """Random-access loader for the CUB_200_2011 dataset.

    The constructor reads four metadata files from ``root`` - ``images.txt``,
    ``image_class_labels.txt``, ``train_test_split.txt`` and ``classes.txt`` -
    each holding two whitespace-separated fields per line. Image files are
    resolved under ``root/images/``.

    Attributes set by the constructor:
        image_paths: absolute/joined paths of the images in the chosen split.
        targets: zero-based class index for each entry of ``image_paths``.
        classes: class names in the order they appear in ``classes.txt``.
    """

    def __init__(self, root, split='train', transform=None):
        """Index the requested split.

        Args:
            root: dataset root directory (``~`` is expanded).
            split: 'train' or 'test'; any other value selects no images.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            ValueError: if a non-blank metadata line does not have exactly
                two fields.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.split = split

        # image id -> relative image path
        image_paths = dict(self._read_pairs('images.txt'))
        # image id -> class index; the file is 1-based, targets are 0-based
        image_classes = {
            image_id: int(class_id) - 1
            for image_id, class_id in self._read_pairs('image_class_labels.txt')
        }
        # image id -> 1 for training images, 0 for test images
        train_test_split = {
            image_id: int(is_train)
            for image_id, is_train in self._read_pairs('train_test_split.txt')
        }

        self.image_paths = []
        self.targets = []
        for image_id, relative_path in image_paths.items():
            is_train = train_test_split[image_id] == 1
            if (self.split == 'train' and is_train) or (self.split == 'test' and not is_train):
                self.image_paths.append(os.path.join(self.root, 'images', relative_path))
                self.targets.append(image_classes[image_id])

        self.classes = [class_name for _, class_name in self._read_pairs('classes.txt')]

        print(f"CUB200数据集加载完成，{split}集包含 {len(self.image_paths)} 张图像")

    def _read_pairs(self, filename):
        """Return the two-field records of a metadata file as a list of tuples."""
        pairs = []
        with open(os.path.join(self.root, filename), 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                first, second = fields
                pairs.append((first, second))
        return pairs

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``; the image is RGB, transformed if a transform was given."""
        path = self.image_paths[index]
        target = self.targets[index]

        # convert() returns an independent image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(path) as raw_image:
            img = raw_image.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_paths)

# 修改命令行参数，添加cub200选项
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    The option set mirrors the module-level parser so that the returned
    namespace carries every attribute the training code reads.

    Args:
        argv: optional list of argument strings. When None (the default),
            ``sys.argv[1:]`` is parsed, as before.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='半监督知识蒸馏学习 CMT-Mamba → ResNet')

    # Model options.
    parser.add_argument('--teacher_model', type=str, default='cmt_small', help='教师模型名称')
    parser.add_argument('--student_model', type=str, default='resnet18', help='学生模型名称')
    # ``--use_mamba`` defaults to True, so on its own the store_true flag could
    # never switch the option off; ``--no_use_mamba`` writes False into the
    # same destination.
    parser.add_argument('--use_mamba', dest='use_mamba', action='store_true', default=True,
                        help='教师模型是否使用Mamba模块')
    parser.add_argument('--no_use_mamba', dest='use_mamba', action='store_false',
                        help='Disable the Mamba module in the teacher model')
    parser.add_argument('--pretrained_teacher', type=str, default='', help='预训练教师模型路径')

    # Dataset options.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100', 'cub200'],
                        help='数据集名称')
    parser.add_argument('--num_classes', type=int, default=10, help='类别数量(cifar10为10，cifar100为100，cub200为200)')
    parser.add_argument('--data_dir', type=str, default='', help='数据集路径')
    parser.add_argument('--labeled_ratio', type=float, default=0.1, help='有标签数据比例')
    parser.add_argument('--batch_size', type=int, default=64, help='批量大小')
    parser.add_argument('--num_workers', type=int, default=2, help='数据加载器工作线程数')

    # Training options (same names and defaults as the module-level parser).
    parser.add_argument('--epochs', type=int, default=200, help='训练轮数')
    parser.add_argument('--lr', type=float, default=0.01, help='学习率')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD动量')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='权重衰减')
    parser.add_argument('--T', type=float, default=2.0, help='蒸馏温度')
    parser.add_argument('--alpha', type=float, default=0.5, help='蒸馏损失权重')
    parser.add_argument('--beta', type=float, default=0.3, help='一致性损失权重')
    parser.add_argument('--ema_decay', type=float, default=0.999, help='指数移动平均衰减率')

    # Device options.
    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
                        help='运行设备')
    parser.add_argument('--device_id', type=int, default=0, help='设备ID')
    parser.add_argument('--amp_level', type=str, default='O2', help='混合精度级别')

    # Checkpoint options.
    parser.add_argument('--save_dir', type=str, default='./checkpoints', help='模型保存路径')
    parser.add_argument('--save_interval', type=int, default=10, help='模型保存间隔')

    return parser.parse_args(argv)

# 定义CMT-Mamba知识蒸馏半监督训练流程
def _split_labeled_unlabeled(labels, num_classes, labeled_ratio):
    """Split sample positions into labeled/unlabeled sets, class by class.

    Args:
        labels: sequence of integer class labels, one per sample.
        num_classes: number of classes; every label must lie in
            ``[0, num_classes)``.
        labeled_ratio: fraction of each class kept as labeled data.

    Returns:
        ``(labeled_indices, unlabeled_indices)``: two lists of positions into
        ``labels``. Every non-empty class contributes at least one labeled
        sample.

    Raises:
        ValueError: if a label falls outside ``[0, num_classes)``.
    """
    per_class = [[] for _ in range(num_classes)]
    for position, label in enumerate(labels):
        # A negative label would silently index from the end of the list and
        # a too-large one would raise a bare IndexError, so check explicitly.
        if not 0 <= label < num_classes:
            raise ValueError(
                f"标签 {label} 超出范围 [0, {num_classes})，请检查 --num_classes 设置"
            )
        per_class[label].append(position)

    labeled_indices = []
    unlabeled_indices = []
    for members in per_class:
        n_labeled = max(1, int(len(members) * labeled_ratio))
        labeled_indices.extend(members[:n_labeled])
        unlabeled_indices.extend(members[n_labeled:])
    return labeled_indices, unlabeled_indices


def train_distill():
    """Run the semi-supervised teacher -> student distillation pipeline.

    All settings are read from the module-level ``args`` namespace. The steps
    are: configure logging, load the training set into memory, split it per
    class into labeled and unlabeled subsets, build the test pipeline, obtain
    a teacher (load ``args.pretrained_teacher`` or train one on the labeled
    subset), then train the student with a distillation loss plus a
    pseudo-label consistency loss, evaluating and checkpointing every epoch.

    Side effects: creates ``logs/cmt_mamba_distill_<timestamp>/training.log``,
    may write ``teacher_pretrained.ckpt`` in the working directory, and writes
    student checkpoints into ``args.save_dir``.

    Raises:
        ValueError: if a label is outside ``[0, args.num_classes)`` or if the
            labeled or unlabeled subset ends up empty.
    """
    # ---- Logging ---------------------------------------------------------
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join("logs", f"cmt_mamba_distill_{current_time}")
    os.makedirs(log_dir, exist_ok=True)

    log_file = os.path.join(log_dir, "training.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )

    # Record the run configuration.
    logging.info("启动CMT-Mamba知识蒸馏半监督训练")
    logging.info(f"配置参数: {vars(args)}")

    # ---- Training data ---------------------------------------------------
    if args.dataset == 'cub200':
        print("加载CUB_200_2011数据集...")

        # No transform here: augmentation is applied later in the pipelines.
        train_dataset = CUB200Dataset(
            root=args.data_dir,
            split='train',
            transform=None
        )
        test_dataset = CUB200Dataset(
            root=args.data_dir,
            split='test',
            transform=None
        )

        # Materialise the training images as numpy arrays in memory.
        dataset_np = []
        for idx in range(len(train_dataset)):
            img, label = train_dataset[idx]
            dataset_np.append((np.array(img), label))
    else:
        cifar_dataset = create_dataset(
            name=args.dataset,
            root=args.data_dir,
            split='train',
            download=False
        )

        dataset_np = []
        print(f"加载并转换{args.dataset.upper()}数据集...")
        # Columns are selected by name rather than by position: CIFAR-10 has
        # a "label" column while CIFAR-100 exposes both "coarse_label" and
        # "fine_label", and positional unpacking depends on their order.
        for row in cifar_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
            label = row["fine_label"] if "fine_label" in row else row["label"]
            dataset_np.append((row["image"], int(label)))

    print(f"数据集大小: {len(dataset_np)}")
    rng = np.random.RandomState(42)
    rng.shuffle(dataset_np)

    # ---- Stratified labeled / unlabeled split ------------------------------
    labeled_indices, unlabeled_indices = _split_labeled_unlabeled(
        [label for _, label in dataset_np],
        args.num_classes,
        args.labeled_ratio,
    )
    labeled_data = [dataset_np[i] for i in labeled_indices]
    unlabeled_data = [dataset_np[i] for i in unlabeled_indices]

    print(f"有标签数据量: {len(labeled_data)}, 无标签数据量: {len(unlabeled_data)}")

    # The generators below loop forever; with an empty subset they would spin
    # without ever yielding, so fail early with a clear message instead.
    if not labeled_data or not unlabeled_data:
        raise ValueError("有标签或无标签数据为空，请检查 --labeled_ratio 设置")

    # ---- Augmentation pipelines --------------------------------------------
    weak_transform = create_weak_augmentation()
    strong_transform = create_strong_augmentation()

    def labeled_generator():
        """Yield (image, label) pairs forever, reshuffling on every pass."""
        indices = list(range(len(labeled_data)))
        while True:
            rng.shuffle(indices)
            for idx in indices:
                yield labeled_data[idx]

    # The generator *function* is passed (not a generator object) so that
    # every iterator created on the dataset gets its own fresh generator.
    labeled_ds = ds.GeneratorDataset(
        source=labeled_generator,
        column_names=["image", "label"],
        shuffle=True
    )
    labeled_ds = labeled_ds.map(operations=weak_transform, input_columns=["image"])
    labeled_ds = labeled_ds.batch(args.batch_size)

    def unlabeled_generator():
        """Yield the same raw image twice (weak view, strong view) forever."""
        indices = list(range(len(unlabeled_data)))
        while True:
            rng.shuffle(indices)
            for idx in indices:
                img, _ = unlabeled_data[idx]
                yield img, img

    unlabeled_ds = ds.GeneratorDataset(
        source=unlabeled_generator,
        column_names=["weak_image", "strong_image"],
        shuffle=True
    )
    unlabeled_ds = unlabeled_ds.map(operations=weak_transform, input_columns=["weak_image"])
    unlabeled_ds = unlabeled_ds.map(operations=strong_transform, input_columns=["strong_image"])
    unlabeled_ds = unlabeled_ds.batch(args.batch_size)

    # ---- Test pipeline -----------------------------------------------------
    if args.dataset == 'cifar10':
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2023, 0.1994, 0.2010]
    else:
        # NOTE(review): cub200 also lands in this branch and therefore reuses
        # the CIFAR-100 statistics; confirm that this is intended.
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]

    # NOTE(review): the statistics above are on a 0-1 scale and no rescaling
    # op precedes Normalize; confirm the images reaching it are in [0, 1] and
    # that this matches what the training augmentations do.
    test_transform = [
        # An int size only fixes the shorter edge, which would leave
        # non-square CUB images with differing shapes that cannot be batched.
        # An explicit (h, w) gives the same result for square CIFAR inputs.
        vision.Resize((224, 224)),
        vision.Normalize(mean=mean, std=std),
        vision.HWC2CHW()
    ]

    if args.dataset == 'cub200':
        def test_generator():
            """Yield transformed (image, label) pairs of the CUB test split."""
            for idx in range(len(test_dataset)):
                img, label = test_dataset[idx]
                img_np = np.array(img)
                for op in test_transform:
                    img_np = op(img_np)
                yield img_np, label

        # Passing the function instead of a generator object matters here:
        # a generator object is exhausted after the first evaluation pass,
        # so every later epoch would see an empty test set.
        test_ds = ds.GeneratorDataset(
            source=test_generator,
            column_names=["image", "label"],
            shuffle=False
        )
        test_ds = test_ds.batch(args.batch_size)
    else:
        test_ds = create_dataset(
            name=args.dataset,
            root=args.data_dir,
            split='test',
            download=False
        )
        test_ds = test_ds.map(operations=test_transform, input_columns=["image"])
        test_ds = test_ds.batch(args.batch_size)

    # ---- Models ------------------------------------------------------------
    print("创建教师模型(CMT-Small)和学生模型(ResNet18)...")

    teacher_model = create_model(
        args.teacher_model,
        num_classes=args.num_classes,
        in_channels=3,
        use_mamba=args.use_mamba,
    )

    student_model = create_model(
        args.student_model,
        num_classes=args.num_classes,
        in_channels=3,
    )

    # ---- Teacher weights ---------------------------------------------------
    if args.pretrained_teacher:
        # Previously a supplied checkpoint path only skipped teacher training
        # and the weights were never loaded, leaving a random teacher.
        ms.load_param_into_net(teacher_model, ms.load_checkpoint(args.pretrained_teacher))
        logging.info(f"已加载预训练教师模型: {args.pretrained_teacher}")
    else:
        logging.info("未找到预训练教师模型，开始使用有标签数据训练教师模型...")

        teacher_optimizer = nn.Momentum(
            params=teacher_model.trainable_params(),
            learning_rate=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )

        # The teacher is trained on the labeled subset only.
        train_teacher(teacher_model, teacher_optimizer, labeled_ds, test_ds, epochs=100, labeled_size=len(labeled_data))

        ms.save_checkpoint(teacher_model, "teacher_pretrained.ckpt")
        logging.info("教师模型预训练完成，已保存检查点")

    # ---- Student optimizer -------------------------------------------------
    # One value shared by the LR schedule and the training loop, so that the
    # schedule length always matches the number of optimizer steps. Both
    # pipelines are endless, hence at least one full batch is always there.
    steps_per_epoch = max(1, min(len(labeled_data), len(unlabeled_data)) // args.batch_size)

    lr = nn.cosine_decay_lr(
        min_lr=args.lr * 0.01,
        max_lr=args.lr,
        total_step=args.epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=args.epochs
    )

    optimizer = nn.Momentum(
        params=student_model.trainable_params(),
        learning_rate=lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # NOTE(review): this EMA object is constructed but never invoked anywhere
    # in this function; confirm whether an update call is missing or the
    # object can be dropped.
    teacher_ema = EMA(student_model, teacher_model, decay=args.ema_decay)

    class DistillationTrainStep(nn.Cell):
        """One optimisation step of the student under teacher supervision."""

        def __init__(self, student_model, teacher_model, optimizer, args):
            super(DistillationTrainStep, self).__init__()
            self.student_model = student_model
            self.teacher_model = teacher_model
            self.optimizer = optimizer
            self.args = args
            # Loss cells are built once here instead of on every forward pass.
            self.hard_ce = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
            self.per_sample_ce = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
            # Gradients are taken w.r.t. the student's parameters only, which
            # is the parameter set the optimizer was built with.
            self.weights = self.student_model.trainable_params()
            self.grad_fn = ops.value_and_grad(self.forward, None, self.weights, has_aux=False)

        def forward(self, labeled_imgs, labels, unlabeled_weak, unlabeled_strong):
            """Return distillation loss + beta * masked consistency loss."""
            teacher_labeled_logits = self.teacher_model(labeled_imgs)
            teacher_unlabeled_logits = self.teacher_model(unlabeled_weak)

            student_labeled_logits = self.student_model(labeled_imgs)
            student_unlabeled_strong_logits = self.student_model(unlabeled_strong)

            distill_loss = self._distill_loss_fn(student_labeled_logits, teacher_labeled_logits, labels)

            # Pseudo labels come from the teacher's prediction on the weak
            # view; only predictions above 0.95 confidence contribute.
            pseudo_labels = ops.softmax(teacher_unlabeled_logits, axis=1)
            max_probs, targets = ops.max(pseudo_labels, axis=1)
            mask = (max_probs > 0.95).astype(ms.float32)

            consistency_loss = self.per_sample_ce(student_unlabeled_strong_logits, targets)
            consistency_loss = (consistency_loss * mask).mean()

            return distill_loss + self.args.beta * consistency_loss

        def _distill_loss_fn(self, student_logits, teacher_logits, labels):
            """Blend hard-label CE with temperature-scaled KL to the teacher."""
            hard_loss = self.hard_ce(student_logits, labels)

            T = self.args.T
            soft_student = ops.softmax(student_logits / T, axis=1)
            soft_teacher = ops.softmax(teacher_logits / T, axis=1)

            # The 1e-10 offset keeps log() finite; the T*T factor rescales
            # the soft-target term after temperature scaling.
            kl_div = ops.kl_div(
                ops.log(soft_student + 1e-10),
                soft_teacher,
                reduction='batchmean'
            ) * (T * T)

            return (1 - self.args.alpha) * hard_loss + self.args.alpha * kl_div

        def construct(self, labeled_imgs, labels, unlabeled_weak, unlabeled_strong):
            loss, grads = self.grad_fn(labeled_imgs, labels, unlabeled_weak, unlabeled_strong)
            self.optimizer(grads)
            return loss

    net = DistillationTrainStep(student_model, teacher_model, optimizer, args)

    def evaluate():
        """Return the student's top-1 accuracy on the test set."""
        correct = 0
        total = 0
        student_model.set_train(False)
        for row in test_ds.create_dict_iterator(num_epochs=1):
            images = row["image"]
            # By-name lookup covers the two-column datasets (cifar10, cub200)
            # as well as CIFAR-100's separate fine/coarse label columns.
            labels = row["fine_label"] if "fine_label" in row else row["label"]

            outputs = student_model(images)
            _, predicted = ops.max(outputs, axis=1)
            predicted = predicted.astype(ms.int32)
            labels = labels.astype(ms.int32)
            total += labels.shape[0]
            correct += (predicted == labels).astype(ms.float32).sum().asnumpy().item()
        return correct / total

    os.makedirs(args.save_dir, exist_ok=True)

    # ---- Training loop -----------------------------------------------------
    print("开始知识蒸馏半监督训练...")
    best_acc = 0.0

    labeled_iter = labeled_ds.create_dict_iterator()
    unlabeled_iter = unlabeled_ds.create_dict_iterator()

    for epoch in range(args.epochs):
        student_model.set_train(True)
        teacher_model.set_train(False)  # the teacher always stays in eval mode

        total_loss = 0.0
        steps = steps_per_epoch

        logging.info(f"Epoch {epoch+1}/{args.epochs} 开始训练...")
        epoch_progress = tqdm(range(steps), desc=f"Epoch {epoch+1}/{args.epochs}")

        for step in epoch_progress:
            try:
                labeled_batch = next(labeled_iter)
            except StopIteration:
                labeled_iter = labeled_ds.create_dict_iterator()
                labeled_batch = next(labeled_iter)

            try:
                unlabeled_batch = next(unlabeled_iter)
            except StopIteration:
                unlabeled_iter = unlabeled_ds.create_dict_iterator()
                unlabeled_batch = next(unlabeled_iter)

            labeled_imgs = labeled_batch["image"]
            labels = labeled_batch["label"]
            weak_imgs = unlabeled_batch["weak_image"]
            strong_imgs = unlabeled_batch["strong_image"]

            loss = net(labeled_imgs, labels, weak_imgs, strong_imgs)
            # Fetch the scalar once; every asnumpy() forces a host sync.
            loss_value = loss.asnumpy().item()
            total_loss += loss_value

            epoch_progress.set_postfix(loss=f"{loss_value:.4f}")

            if step % 50 == 0:
                logging.info(f"Epoch: {epoch+1}/{args.epochs}, Step: {step+1}/{steps}, Loss: {loss_value:.4f}")

        avg_loss = total_loss / steps
        logging.info(f"Epoch: {epoch+1}/{args.epochs}, Avg Loss: {avg_loss:.4f}")

        # Evaluate the student.
        top1_acc = evaluate()
        logging.info(f"Epoch: {epoch+1}/{args.epochs}, TOP1 Accuracy: {top1_acc:.4f}")

        # Keep the best checkpoint by top-1 accuracy.
        if top1_acc > best_acc:
            best_acc = top1_acc
            ms.save_checkpoint(student_model, os.path.join(args.save_dir, f"{args.student_model}_best.ckpt"))
            logging.info(f"已保存最佳模型，TOP1准确率: {best_acc:.4f}")

        # Periodic checkpoint.
        if (epoch + 1) % args.save_interval == 0:
            ms.save_checkpoint(student_model, os.path.join(args.save_dir, f"{args.student_model}_epoch{epoch+1}.ckpt"))
            logging.info(f"已保存第{epoch+1}轮检查点")

    print(f"训练完成！最佳TOP1准确率: {best_acc:.4f}")


def train_teacher(model, optimizer, train_ds, test_ds, epochs, labeled_size):
    """Train the teacher network with cross-entropy and a cosine LR schedule.

    Hyper-parameters (batch size, lr, momentum, weight decay) are read from
    the module-level ``args`` namespace.

    Args:
        model: network to train; updated in place.
        optimizer: kept for interface compatibility but not used. A new
            ``nn.Momentum`` optimizer driven by the cosine schedule is built
            inside this function.
        train_ds: dataset whose dict rows provide "image" and "label".
        test_ds: evaluation dataset with an "image" column and either a
            "label" or a "fine_label" column.
        epochs: number of training epochs.
        labeled_size: number of labeled samples; the steps per epoch are
            ``labeled_size // args.batch_size`` (at least 1).

    Side effects: writes ``teacher_best.ckpt`` in the working directory
    whenever accuracy improves and reloads it into ``model`` at the end.
    """
    # At least one step per epoch, otherwise the schedule is empty and the
    # average loss below would divide by zero.
    steps_per_epoch = max(1, labeled_size // args.batch_size)
    lr = nn.cosine_decay_lr(
        min_lr=0.0,
        max_lr=args.lr,
        total_step=epochs * steps_per_epoch,
        step_per_epoch=steps_per_epoch,
        decay_epoch=epochs
    )

    # The list of per-step rates is handed to the optimizer directly; the
    # i-th optimizer step then uses the i-th value.
    optimizer = nn.Momentum(
        params=model.trainable_params(),
        learning_rate=lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    class TeacherTrainStep(nn.Cell):
        """Forward pass, gradient computation and optimizer update in one cell."""

        def __init__(self, network, optimizer):
            super(TeacherTrainStep, self).__init__()
            self.network = network
            self.optimizer = optimizer
            self.grad_fn = ops.value_and_grad(self.forward, None, self.network.trainable_params())
            self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

        def forward(self, images, labels):
            logits = self.network(images)
            return self.loss_fn(logits, labels)

        def construct(self, images, labels):
            loss, grads = self.grad_fn(images, labels)
            self.optimizer(grads)
            return loss

    train_step = TeacherTrainStep(model, optimizer)

    def evaluate():
        """Return the model's top-1 accuracy on ``test_ds``."""
        correct = 0
        total = 0
        model.set_train(False)

        for row in test_ds.create_dict_iterator(num_epochs=1):
            images = row["image"]
            # By-name lookup covers the two-column datasets (cifar10, cub200)
            # as well as CIFAR-100's separate fine/coarse label columns;
            # unpacking three values failed on two-column test sets.
            labels = row["fine_label"] if "fine_label" in row else row["label"]

            outputs = model(images)
            _, predicted = ops.max(outputs, axis=1)
            predicted = predicted.astype(ms.int32)
            labels = labels.astype(ms.int32)
            total += labels.shape[0]
            correct += (predicted == labels).astype(ms.float32).sum().asnumpy().item()

        return correct / total

    logging.info("开始训练教师模型...")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise the reload at the end could fail or pick up a
    # stale file when accuracy never rises above zero.
    best_acc = -1.0

    # One iterator is reused across epochs instead of creating a new one per
    # epoch on top of a dataset that may never end.
    data_iter = train_ds.create_dict_iterator()

    for epoch in range(epochs):
        model.set_train(True)
        total_loss = 0.0

        train_iter = tqdm(range(steps_per_epoch), desc=f"Teacher Epoch {epoch+1}/{epochs}")
        for _ in train_iter:
            try:
                data = next(data_iter)
            except StopIteration:
                data_iter = train_ds.create_dict_iterator()
                data = next(data_iter)

            images, labels = data["image"], data["label"]
            loss = train_step(images, labels)
            # Fetch the scalar once; every asnumpy() forces a host sync.
            loss_value = loss.asnumpy().item()
            total_loss += loss_value

            train_iter.set_postfix(loss=f"{loss_value:.4f}")

        avg_loss = total_loss / steps_per_epoch

        acc = evaluate()

        logging.info(f"Teacher Epoch: {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

        # Keep the best checkpoint by accuracy.
        if acc > best_acc:
            best_acc = acc
            ms.save_checkpoint(model, "teacher_best.ckpt")
            logging.info(f"保存最佳教师模型，准确率: {acc:.4f}")

    # Restore the best weights seen during training.
    param_dict = ms.load_checkpoint("teacher_best.ckpt")
    ms.load_param_into_net(model, param_dict)
    logging.info(f"教师模型训练完成，最佳准确率: {best_acc:.4f}")


# Script entry point: run the distillation pipeline when executed directly.
# NOTE(review): train_distill() reads a module-level `args`; nothing in this
# guard creates it, so it presumably comes from argument parsing done earlier
# in the file. Confirm before relying on parse_args() defined above.
if __name__ == "__main__":
    train_distill() 