# 下载数据

In [None]:
# Notebook shell escape: fetch the archive identified by the given file ID with
# gdown and store it as food-11.zip in the current working directory.
# NOTE(review): newer gdown releases reportedly deprecate the `--id` flag —
# confirm against the installed gdown version.
!gdown --id '1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy' --output food-11.zip

# Unzip the dataset (quietly) next to the archive.
!unzip -q food-11.zip


# 导入包

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import numpy as np
import sys
import os
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
from datetime import datetime
import torch.optim as optim
from torchvision.datasets import DatasetFolder


# 构建ResNet-18-512模型

## 基础残差块

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
    where ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a
    ``downsample`` module is supplied.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to ``x`` to build the shortcut.
    """

    # Channel multiplier between the block's nominal and actual output width.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual transformation to ``x`` and return the result."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Project the input only when a shortcut module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


## 适配512×512的ResNet-18模型

In [None]:
class ResNet18_512(nn.Module):
    """ResNet-style image classifier with a four-stage residual trunk.

    The network is a 7x7 stride-2 stem convolution, a stride-2 max-pool, four
    residual stages (64/128/256/512 base channels, stages 2-4 with stride 2),
    adaptive average pooling to 1x1, dropout and a linear classifier. Because
    the pooling is adaptive, any input resolution that survives the 32x total
    downsampling is accepted; the size comments in ``forward`` assume 512x512.

    Args:
        block: residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(in_channels, out_channels, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: size of the classifier output.
        num_channels: number of channels of the input images.
        dropout_rate: dropout probability applied before the classifier.
    """

    def __init__(self, block, layers, num_classes=11, num_channels=3, dropout_rate=0.5):
        super().__init__()
        # Running channel count consumed by the next block built in _make_layer.
        self.in_channels = 64
        self.dropout_rate = dropout_rate

        # Stem: 7x7 stride-2 convolution followed by a stride-2 max-pool.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation.
        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block may change resolution/width; it receives a 1x1
        conv + batch-norm shortcut projection whenever the stride is not 1 or
        the channel count changes.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        layers.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-initialise convolutions and reset batch-norm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                # Fixed: the original referenced the non-existent attribute
                # ``m.bin_channels`` here, which raised AttributeError as soon
                # as the model was constructed.
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        # Spatial sizes below are for a 512x512 input.
        x = self.conv1(x)    # 512 -> 256
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 256 -> 128

        x = self.layer1(x)   # 128 -> 128
        x = self.layer2(x)   # 128 -> 64
        x = self.layer3(x)   # 64 -> 32
        x = self.layer4(x)   # 32 -> 16

        x = self.avgpool(x)  # 16 -> 1
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

## 创建模型的函数

In [None]:
def resnet18_512(num_classes=11):
    """Build a ResNet18_512 made of BasicBlock units, two blocks per stage.

    Args:
        num_classes: size of the classifier output.

    Returns:
        The freshly constructed (untrained) model.
    """
    blocks_per_stage = [2, 2, 2, 2]
    network = ResNet18_512(BasicBlock, blocks_per_stage, num_classes=num_classes)
    return network

# 添加工具函数

## 绘制训练和验证损失曲线

In [None]:
def plot_loss_curves(train_losses, val_losses, save_path=None):
    """Plot the training and validation loss histories on one figure.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        save_path: when truthy, the figure is written to this path and then
            closed; otherwise the figure is shown.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='训练损失', linewidth=2)
    plt.plot(val_losses, label='验证损失', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('训练和验证损失曲线')
    plt.legend()
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图片已保存到: {save_path}")
        # Release the figure once it is on disk, as the other plot helpers in
        # this file do; the original left it open, leaking one figure per call.
        plt.close()
    else:
        plt.show()

## 绘制训练和验证准确率曲线

In [None]:
def plot_accuracy_curves(train_accuracies, val_accuracies, save_path=None):
    """Plot training and validation accuracy and mark the best validation epoch.

    Args:
        train_accuracies: per-epoch training accuracies (percent).
        val_accuracies: per-epoch validation accuracies (percent).
        save_path: when truthy, the figure is written to this path and closed;
            otherwise it is shown.
    """
    # Materialise as a list so ``.index`` works for any sequence (tuple,
    # ndarray, ...) and so emptiness can be tested reliably.
    val_history = list(val_accuracies)

    plt.figure(figsize=(10, 6))
    epochs = range(1, len(train_accuracies) + 1)

    plt.plot(epochs, train_accuracies, 'b-', label='训练准确率', linewidth=2)
    plt.plot(epochs, val_accuracies, 'r-', label='验证准确率', linewidth=2)

    plt.title('训练和验证准确率', fontsize=14, fontweight='bold')
    plt.xlabel('Epochs', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.ylim(bottom=0)

    # Annotate the best validation accuracy. Skipped for an empty history,
    # where max() would raise ValueError.
    if val_history:
        best_val_acc = max(val_history)
        best_epoch = val_history.index(best_val_acc) + 1  # epochs are 1-based on the axis
        plt.axvline(x=best_epoch, color='gray', linestyle='--', alpha=0.7)
        plt.text(best_epoch, best_val_acc / 2, f'最佳: {best_val_acc:.2f}%\nEpoch: {best_epoch}',
                 ha='center', va='center', fontsize=10,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

## 绘制综合曲线

In [None]:
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, save_path=None):
    """Plot loss (top) and accuracy (bottom) histories in one two-panel figure.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        train_accuracies: per-epoch training accuracies (percent).
        val_accuracies: per-epoch validation accuracies (percent).
        save_path: when truthy, the figure is written to this path and closed;
            otherwise it is shown.
    """
    # Materialise as a list so ``.index`` works for any sequence and so
    # emptiness can be tested reliably.
    val_history = list(val_accuracies)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
    epochs = range(1, len(train_losses) + 1)

    # Loss panel.
    ax1.plot(epochs, train_losses, 'b-', label='训练损失', linewidth=2)
    ax1.plot(epochs, val_losses, 'r-', label='验证损失', linewidth=2)
    ax1.set_title('训练和验证损失', fontsize=14, fontweight='bold')
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.legend(fontsize=12)
    ax1.grid(True, alpha=0.3)

    # Accuracy panel.
    ax2.plot(epochs, train_accuracies, 'b-', label='训练准确率', linewidth=2)
    ax2.plot(epochs, val_accuracies, 'r-', label='验证准确率', linewidth=2)
    ax2.set_title('训练和验证准确率', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Epochs', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.legend(fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(bottom=0)

    # Annotate the best validation accuracy. Skipped for an empty history,
    # where max() would raise ValueError.
    if val_history:
        best_val_acc = max(val_history)
        best_epoch = val_history.index(best_val_acc) + 1  # epochs are 1-based on the axis
        ax2.axvline(x=best_epoch, color='gray', linestyle='--', alpha=0.7)
        ax2.text(best_epoch, best_val_acc / 2, f'最佳: {best_val_acc:.2f}%',
                 ha='center', va='center', fontsize=10,
                 bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()


# 半监督学习训练器

## 图片加载器

In [None]:
# Module-level image loader (a plain function, so it can be pickled and handed
# to DataLoader worker processes).
def pil_loader(path):
    """Open the image file at ``path`` with PIL and return it converted to RGB."""
    with open(path, 'rb') as handle:
        # convert() runs while the file is still open, so the pixel data is
        # fully read before the handle is closed.
        return Image.open(handle).convert('RGB')

## 训练器主体类

In [None]:
class SemiSupervisedTrainer:
    """Semi-supervised trainer: supervised loss + consistency loss + pseudo-labels.

    Each training epoch runs the labelled loader with the supervised
    ``criterion``. Optionally it adds a KL consistency term computed on batches
    drawn from the unlabelled loader, and (every fifth epoch from epoch 10 on)
    an extra pass over confidently pseudo-labelled unlabelled samples.

    Args:
        model: the network being trained.
        train_loader: iterable of ``(data, target)`` labelled batches.
        unlabeled_loader: iterable of ``(data, _)`` batches; the second item
            is ignored.
        valid_loader: iterable of ``(data, target)`` validation batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: supervised loss, called as ``criterion(logits, target)``.
        device: device the batches are moved to.
        pseudo_threshold: base confidence threshold for pseudo-labels.
        consistency_weight: base weight of the consistency loss.
    """

    def __init__(self, model, train_loader, unlabeled_loader, valid_loader,
                 optimizer, criterion, device, pseudo_threshold=0.9, consistency_weight=0.3):
        self.model = model
        self.train_loader = train_loader
        self.unlabeled_loader = unlabeled_loader
        self.valid_loader = valid_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.pseudo_threshold = pseudo_threshold
        self.consistency_weight = consistency_weight

        # Augmentations used by consistency_loss; they are applied to whole
        # batches, so one random draw is shared by every sample in the batch.
        # NOTE(review): if the unlabelled loader already normalises its tensors,
        # ColorJitter may clamp the values to [0, 1] — confirm, and consider
        # augmenting before normalisation.
        self.weak_augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ])

        self.strong_augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ])

        # One stats dict per generate_pseudo_labels call.
        self.pseudo_label_stats = []

    def generate_pseudo_labels(self, epoch=None):
        """Label the unlabelled data with the current model, keeping confident samples.

        The confidence threshold is ``pseudo_threshold`` shifted by +0.03 for
        ``epoch < 20``, +0.01 for ``epoch < 50`` and -0.02 afterwards; with
        ``epoch=None`` the base threshold is used unchanged.

        Args:
            epoch: current epoch index, or ``None``.

        Returns:
            ``(pseudo_data, pseudo_labels)``: two equally long lists of CPU
            tensors (one sample / one predicted class index per entry).
        """
        # Fixed: the original compared ``None < 20`` and raised TypeError when
        # the documented default ``epoch=None`` was used.
        if epoch is None:
            confidence_threshold = self.pseudo_threshold
        elif epoch < 20:
            confidence_threshold = self.pseudo_threshold + 0.03  # strictest early on
        elif epoch < 50:
            confidence_threshold = self.pseudo_threshold + 0.01
        else:
            confidence_threshold = self.pseudo_threshold - 0.02  # slightly looser later

        pseudo_data = []
        pseudo_labels = []
        pseudo_confidences = []
        total_samples = 0

        # Fixed: the original switched to eval mode and never switched back,
        # so the training loop that follows ran in eval mode (frozen
        # batch-norm statistics, dropout disabled). Restore the previous mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for data, _ in self.unlabeled_loader:
                    data = data.to(self.device)
                    probabilities = F.softmax(self.model(data), dim=1)
                    max_probs, predictions = torch.max(probabilities, 1)

                    total_samples += data.size(0)

                    # Keep only samples whose top probability clears the threshold.
                    mask = max_probs > confidence_threshold
                    if mask.any():
                        pseudo_data.extend(data[mask].cpu())
                        pseudo_labels.extend(predictions[mask].cpu())
                        pseudo_confidences.extend(max_probs[mask].cpu().tolist())
        finally:
            self.model.train(was_training)

        # Summary statistics.
        selection_rate = len(pseudo_data) / total_samples if total_samples > 0 else 0
        avg_confidence = float(np.mean(pseudo_confidences)) if pseudo_confidences else 0

        print(f"Epoch {epoch}: 生成了 {len(pseudo_data)}/{total_samples} 个伪标签 "
              f"(选择率: {selection_rate:.3f}, 阈值: {confidence_threshold:.3f}, "
              f"平均置信度: {avg_confidence:.4f})")

        self.pseudo_label_stats.append({
            'epoch': epoch,
            'count': len(pseudo_data),
            'selection_rate': selection_rate,
            'avg_confidence': avg_confidence,
            'threshold': confidence_threshold
        })

        return pseudo_data, pseudo_labels

    def get_dynamic_consistency_weight(self, epoch):
        """Return the consistency weight for ``epoch`` (ramped 0.3x/0.7x/1x/1.2x)."""
        if epoch < 10:
            return self.consistency_weight * 0.3  # early: favour the supervised signal
        elif epoch < 30:
            return self.consistency_weight * 0.7
        elif epoch < 60:
            return self.consistency_weight
        else:
            return self.consistency_weight * 1.2  # late: lean more on unlabelled data

    def consistency_loss(self, unlabeled_batch):
        """Return KL(weak-view prediction || strong-view prediction) for a batch.

        The weak-view prediction is computed without gradients and acts as the
        target; gradients flow only through the strong-view forward pass.
        """
        weak_aug = self.weak_augment(unlabeled_batch)
        strong_aug = self.strong_augment(unlabeled_batch)

        with torch.no_grad():
            weak_output = F.softmax(self.model(weak_aug), dim=1)

        strong_output = F.log_softmax(self.model(strong_aug), dim=1)

        # F.kl_div expects log-probabilities as input and probabilities as target.
        return F.kl_div(strong_output, weak_output, reduction='batchmean')

    def calculate_topk_accuracy(self, outputs, targets, k=5):
        """Return how many samples have their target among the top-``k`` logits.

        ``k`` is clamped to the number of classes so that e.g. top-5 on a
        3-class problem does not make ``topk`` raise.
        """
        k = min(k, outputs.size(1))
        _, topk_pred = outputs.topk(k, 1, True, True)
        topk_correct = topk_pred.eq(targets.view(-1, 1).expand_as(topk_pred))
        return topk_correct.any(1).sum().item()

    def _train_on_pseudo_labels(self, pseudo_data, pseudo_labels):
        """Run one shuffled supervised pass over pseudo-labelled samples.

        Returns the mean batch loss (0.0 when there are no samples).
        """
        batch_size = getattr(self.train_loader, 'batch_size', None) or 32
        order = torch.randperm(len(pseudo_data)).tolist()
        running_loss = 0.0
        num_batches = 0
        for start in range(0, len(order), batch_size):
            chosen = order[start:start + batch_size]
            data = torch.stack([pseudo_data[i] for i in chosen]).to(self.device)
            target = torch.stack([pseudo_labels[i] for i in chosen]).to(self.device)

            self.optimizer.zero_grad()
            loss = self.criterion(self.model(data), target)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        return running_loss / max(num_batches, 1)

    def train_epoch(self, epoch, use_consistency=True, use_pseudo_labels=False):
        """Train for one epoch.

        Args:
            epoch: epoch index; drives the consistency weight and the
                pseudo-label schedule.
            use_consistency: add the consistency loss on unlabelled batches.
            use_pseudo_labels: on every fifth epoch from epoch 10 on, generate
                pseudo-labels and, if more than 50 samples qualify, train on
                them after the labelled pass.

        Returns:
            ``(avg_train_loss, train_accuracy, train_top5_accuracy)`` computed
            over the labelled loader only; accuracies are percentages.
        """
        train_loss = 0.0
        train_correct = 0
        train_top5_correct = 0
        train_total = 0

        current_consistency_weight = self.get_dynamic_consistency_weight(epoch) if use_consistency else 0.0
        apply_consistency = use_consistency and current_consistency_weight > 0

        # Pseudo-label generation (periodic, and only after a warm-up).
        pseudo_data, pseudo_labels = [], []
        train_on_pseudo = False
        if use_pseudo_labels and epoch % 5 == 0 and epoch >= 10:
            pseudo_data, pseudo_labels = self.generate_pseudo_labels(epoch=epoch)
            if len(pseudo_data) > 50:  # only worth using with enough samples
                train_on_pseudo = True
                print(f"使用 {len(pseudo_data)} 个伪标签样本扩展训练集")
            else:
                print("伪标签样本不足，跳过使用")

        # Enter train mode after pseudo-label inference so the epoch can never
        # run in eval mode.
        self.model.train()

        # Only start iterating the unlabelled loader when it is actually needed.
        unlabeled_iter = iter(self.unlabeled_loader) if apply_consistency else None

        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            # Supervised loss.
            output = self.model(data)
            supervised_loss = self.criterion(output, target)
            total_loss = supervised_loss

            # Unsupervised consistency regularisation.
            cons_loss_val = 0.0
            if apply_consistency:
                unlabeled_batch = next(unlabeled_iter, None)
                if unlabeled_batch is None:
                    # Fixed: the original restarted the iterator but skipped
                    # the consistency term for this batch; restart and draw.
                    unlabeled_iter = iter(self.unlabeled_loader)
                    unlabeled_batch = next(unlabeled_iter, None)
                if unlabeled_batch is not None:
                    unlabeled_data = unlabeled_batch[0].to(self.device)
                    consistency_loss = self.consistency_loss(unlabeled_data)
                    total_loss = supervised_loss + current_consistency_weight * consistency_loss
                    cons_loss_val = consistency_loss.item()

            total_loss.backward()
            self.optimizer.step()

            train_loss += total_loss.item()

            # Accuracy bookkeeping.
            _, predicted = torch.max(output.data, 1)
            train_total += target.size(0)
            train_correct += (predicted == target).sum().item()
            train_top5_correct += self.calculate_topk_accuracy(output, target, k=5)

            # Periodic progress line.
            if batch_idx % 50 == 0:
                print(f'Epoch: {epoch} | Batch: {batch_idx}/{len(self.train_loader)} | '
                      f'总损失: {total_loss.item():.4f} | 有监督: {supervised_loss.item():.4f} | '
                      f'一致性: {cons_loss_val:.4f} | 权重: {current_consistency_weight:.4f}')

        # Fixed: the original built the pseudo-labelled set and then never
        # trained on it. Metrics below still cover labelled data only.
        if train_on_pseudo:
            pseudo_loss = self._train_on_pseudo_labels(pseudo_data, pseudo_labels)
            print(f'Epoch: {epoch} | 伪标签损失: {pseudo_loss:.4f}')

        # Averages; guarded so an empty loader yields zeros instead of a crash.
        avg_train_loss = train_loss / max(len(self.train_loader), 1)
        train_accuracy = 100.0 * train_correct / max(train_total, 1)
        train_top5_accuracy = 100.0 * train_top5_correct / max(train_total, 1)

        return avg_train_loss, train_accuracy, train_top5_accuracy

    def validate(self):
        """Evaluate on the validation loader.

        Returns:
            ``(avg_val_loss, val_accuracy, val_top5_accuracy)``; accuracies are
            percentages. The model is left in eval mode.
        """
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        val_top5_correct = 0
        val_total = 0

        with torch.no_grad():
            for data, target in self.valid_loader:
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                val_loss += loss.item()
                _, predicted = torch.max(output.data, 1)
                val_total += target.size(0)
                val_correct += (predicted == target).sum().item()
                val_top5_correct += self.calculate_topk_accuracy(output, target, k=5)

        # Guarded so an empty loader yields zeros instead of a crash.
        avg_val_loss = val_loss / max(len(self.valid_loader), 1)
        val_accuracy = 100.0 * val_correct / max(val_total, 1)
        val_top5_accuracy = 100.0 * val_top5_correct / max(val_total, 1)

        return avg_val_loss, val_accuracy, val_top5_accuracy


# 训练配置

In [None]:
# Basic parameters: working directory and compute device.
BASE_DIR = os.getcwd()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Results directory, named after the current month-day_hour-minute.
now_time = datetime.now()
time_str = now_time.strftime('%m-%d_%H-%M')
log_dir = os.path.join(BASE_DIR, "results", time_str)
# exist_ok avoids the check-then-create race of the original
# `if not os.path.exists(...)` guard.
os.makedirs(log_dir, exist_ok=True)
print(f"结果保存目录: {log_dir}")

# Training hyper-parameters.
MAX_EPOCH = 200
BATCH_SIZE = 32
LR = 0.001
PATIENCE = 25  # epochs without improvement tolerated before early stopping

# Semi-supervised hyper-parameters.
PSEUDO_THRESHOLD = 0.90  # base confidence threshold for pseudo-labels
CONSISTENCY_WEIGHT = 0.2  # base weight of the consistency loss

print("训练配置完成")

# 数据加载

## 数据增强

In [None]:
# Image pipelines. Both resize to 512x512 and end with the same
# tensor conversion + per-channel normalisation; only the training pipeline
# adds random augmentation in between.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps shared by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


_train_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomCrop(512, padding=16),
]

train_tfm = transforms.Compose(
    [transforms.Resize((512, 512))] + _train_augmentations + _tensor_and_normalize()
)

test_tfm = transforms.Compose(
    [transforms.Resize((512, 512))] + _tensor_and_normalize()
)

## 加载数据

In [None]:
# Load the three food-11 splits from disk.
def _food11_folder(*subdirs, transform):
    """Build a DatasetFolder of jpg images under BASE_DIR/food-11/<subdirs>."""
    root = os.path.join(BASE_DIR, "food-11", *subdirs)
    return DatasetFolder(root, loader=pil_loader, extensions="jpg", transform=transform)


train_set = _food11_folder("training", "labeled", transform=train_tfm)
valid_set = _food11_folder("validation", transform=test_tfm)
unlabeled_set = _food11_folder("training", "unlabeled", transform=train_tfm)

for _split_name, _split in (("训练集", train_set), ("无标签集", unlabeled_set), ("验证集", valid_set)):
    print(f"{_split_name}: {len(_split)}")

## 数据加载器

In [None]:
# Data loaders. Worker count depends on the OS (fewer on Windows); pinned
# memory is only requested when running on CUDA.
num_workers = 4 if os.name != 'nt' else 2
pin_memory = (device.type == 'cuda')

# Settings common to all three loaders.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

train_loader = DataLoader(train_set, shuffle=True, **_loader_options)
unlabeled_loader = DataLoader(unlabeled_set, shuffle=True, **_loader_options)
valid_loader = DataLoader(valid_set, shuffle=False, **_loader_options)

print("数据加载完成")

# 模型初始化

In [None]:
# Build the 11-class network and move it to the selected device.
model = resnet18_512(num_classes=11)
# Kept as a separate trailing expression: in a notebook cell this also echoes
# the module structure.
model.to(device)

# 损失函数和优化器

In [None]:
# Supervised loss and optimiser.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# Cosine annealing with warm restarts: first cycle of 50 scheduler steps,
# each following cycle twice as long, learning rate floor of 1e-6.
_restart_schedule = dict(T_0=50, T_mult=2, eta_min=1e-6)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **_restart_schedule)

_param_count = sum(p.numel() for p in model.parameters())
print(f"模型初始化完成，参数量: {_param_count:,}")

# 创建训练器

In [None]:
# Wire the model, loaders, optimisation objects and semi-supervised
# hyper-parameters into the trainer.
_trainer_config = {
    "model": model,
    "train_loader": train_loader,
    "unlabeled_loader": unlabeled_loader,
    "valid_loader": valid_loader,
    "optimizer": optimizer,
    "criterion": criterion,
    "device": device,
    "pseudo_threshold": PSEUDO_THRESHOLD,
    "consistency_weight": CONSISTENCY_WEIGHT,
}
trainer = SemiSupervisedTrainer(**_trainer_config)

print("训练器创建完成")

# 开始训练

## 初始化记录变量

In [None]:
# Per-epoch histories filled in by the training loop.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
train_top5_accuracies, val_top5_accuracies = [], []
learning_rates = []

# Best-model tracking and early-stopping state.
best_val_accuracy, best_epoch = 0.0, 0
early_stop_counter = 0

## 正式开始训练

In [None]:
print("开始半监督训练...")

for epoch in range(MAX_EPOCH):
    # Training phase.
    train_loss, train_accuracy, train_top5_accuracy = trainer.train_epoch(
        epoch,
        use_consistency=True,
        use_pseudo_labels=True
    )

    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    train_top5_accuracies.append(train_top5_accuracy)

    # Record the learning rate that was in effect for this epoch, then advance
    # the schedule.
    current_lr = scheduler.get_last_lr()[0]
    learning_rates.append(current_lr)
    scheduler.step()

    # Validation phase.
    val_loss, val_accuracy, val_top5_accuracy = trainer.validate()
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    val_top5_accuracies.append(val_top5_accuracy)

    # Best-model tracking and early-stopping bookkeeping.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_epoch = epoch
        early_stop_counter = 0

        # Save the best model together with optimiser/scheduler state.
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch,
            "best_val_accuracy": best_val_accuracy,
            "val_loss": val_loss,
            "train_accuracy": train_accuracy,
            "train_loss": train_loss
        }
        path_checkpoint = os.path.join(log_dir, "checkpoint_best.pkl")
        torch.save(checkpoint, path_checkpoint)
        print(f"✅ 保存最佳模型，验证准确率: {best_val_accuracy:.2f}%")
    else:
        # Fixed: every epoch without a new best now counts towards early
        # stopping. The original skipped the increment whenever the accuracy
        # was within 2 points of the best, so a plateau near the best value
        # could never trigger early stopping.
        early_stop_counter += 1

        # Also keep near-best models (within 2 points) for later ensembling.
        # NOTE(review): this can write a checkpoint on many epochs; watch disk
        # usage on long runs.
        if val_accuracy > best_val_accuracy - 2.0:
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "val_accuracy": val_accuracy,
                "epoch": epoch
            }
            torch.save(checkpoint, os.path.join(log_dir, f"checkpoint_epoch_{epoch}_acc_{val_accuracy:.2f}.pth"))

    # Epoch summary.
    print(f'Epoch: {epoch:03d}/{MAX_EPOCH}, '
          f'训练损失: {train_loss:.4f}, 训练准确率: {train_accuracy:.2f}% (Top-5: {train_top5_accuracy:.2f}%), '
          f'验证损失: {val_loss:.4f}, 验证准确率: {val_accuracy:.2f}% (Top-5: {val_top5_accuracy:.2f}%), '
          f'学习率: {current_lr:.6f}, '
          f'最佳: {best_val_accuracy:.2f}% @ Epoch {best_epoch}, '
          f'早停计数: {early_stop_counter}/{PATIENCE}')
    print('-' * 80)

    # Early-stopping check.
    if early_stop_counter >= PATIENCE:
        print(f"🚨 早停触发！在 epoch {epoch} 停止训练")
        print(f"🏆 最佳模型在 epoch {best_epoch}, 验证准确率: {best_val_accuracy:.2f}%")
        break

# Training finished.
print(f"训练完成！最终最佳验证准确率: {best_val_accuracy:.2f}%")

# 保存训练记录

In [None]:
# Bundle every recorded history and the best-model summary into one dict and
# persist it next to the checkpoints.
training_history = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accuracies=train_accuracies,
    val_accuracies=val_accuracies,
    train_top5_accuracies=train_top5_accuracies,
    val_top5_accuracies=val_top5_accuracies,
    learning_rates=learning_rates,
    best_val_accuracy=best_val_accuracy,
    best_epoch=best_epoch,
    pseudo_label_stats=trainer.pseudo_label_stats,
)
_history_path = os.path.join(log_dir, 'training_history.pth')
torch.save(training_history, _history_path)

# Render the loss, accuracy and combined figures into the results directory.
picture_path_loss, picture_path_acc, picture_path_combined = (
    os.path.join(log_dir, _figure_name)
    for _figure_name in ('loss_curves.png', 'accuracy_curves.png', 'training_curves.png')
)

plot_loss_curves(train_losses, val_losses, picture_path_loss)
plot_accuracy_curves(train_accuracies, val_accuracies, picture_path_acc)
plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies, picture_path_combined)

print(f"训练曲线已保存至: {log_dir}")

# 伪标签统计

In [None]:
# Plot how pseudo-label generation evolved over training (one panel per
# recorded statistic). Skipped when no pseudo-labels were ever generated.
if trainer.pseudo_label_stats:
    pseudo_df = pd.DataFrame(trainer.pseudo_label_stats)
    plt.figure(figsize=(12, 8))

    # (column, format string, extra plot kwargs, title, y-axis label)
    # Fixed: the threshold panel used the format string 'purple-o', which
    # matplotlib rejects (a colour name cannot be embedded in a format
    # string); the colour is now passed as a keyword argument.
    _panels = [
        ('count', 'b-o', {}, '伪标签数量变化', '数量'),
        ('selection_rate', 'g-o', {}, '伪标签选择率', '选择率'),
        ('avg_confidence', 'r-o', {}, '伪标签平均置信度', '置信度'),
        ('threshold', '-o', {'color': 'purple'}, '动态阈值变化', '阈值'),
    ]

    for _position, (_column, _fmt, _extra, _title, _ylabel) in enumerate(_panels, start=1):
        plt.subplot(2, 2, _position)
        plt.plot(pseudo_df['epoch'], pseudo_df[_column], _fmt, **_extra)
        plt.title(_title)
        plt.xlabel('Epoch')
        plt.ylabel(_ylabel)
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(log_dir, 'pseudo_label_stats.png'), dpi=300, bbox_inches='tight')
    plt.show()

# 最终结果统计

In [None]:
# Final summary: best and last-epoch accuracies, epoch count, output directory.
_rule = "=" * 80
_summary_lines = (
    _rule,
    "训练总结:",
    f"最佳验证准确率: {best_val_accuracy:.2f}% (Epoch {best_epoch})",
    f"最终训练准确率: {train_accuracies[-1]:.2f}%",
    f"最终验证准确率: {val_accuracies[-1]:.2f}%",
    f"训练轮数: {len(train_accuracies)}",
    f"结果保存目录: {log_dir}",
    _rule,
)
for _line in _summary_lines:
    print(_line)