<a href="https://colab.research.google.com/github/nanpolend/machine-learning/blob/master/%E9%A0%90%E8%A8%93%E7%B7%B4%E5%92%8Ctensorflow%E8%A8%93%E7%B7%B4%E5%9C%96%E8%A1%A8%E5%92%8C%E8%B6%85%E5%8F%83%E6%95%B8%E8%AA%BF%E6%95%B4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ========== Colab environment setup ==========
# Install PyTorch 2.3.0 / torchvision 0.18.0 built for CUDA 12.1 from the
# official PyTorch wheel index, plus TensorBoard for the SummaryWriter logs.
# NOTE(review): presumably this replaces the runtime's preinstalled torch; if
# torch was already imported in this runtime, restart it so the pinned version
# is the one actually loaded.
!pip install torch==2.3.0+cu121 torchvision==0.18.0+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
!pip install tensorboard

# ========== 导入库 ==========
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
import os

# ========== Hyperparameter configuration ==========
# Compute device: the GPU when one is visible to torch, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Optimisation settings.
BATCH_SIZE = 128
LR = 0.001
EPOCHS = 20    # kept small so a Colab session finishes in reasonable time
PATIENCE = 5   # epochs without a val-accuracy improvement before early stop

# Colab file-system locations.
MODEL_SAVE_PATH = '/content/best_model.pth'  # best checkpoint (state_dict)
DATA_PATH = '/content/data'                  # CIFAR-10 download directory

# ========== Dataset with augmentation preview ==========
class CIFAR10Enhanced(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset with a helper to preview transformed samples.

    Construction is inherited unchanged from ``torchvision.datasets.CIFAR10``;
    the only addition is :meth:`show_augmentation`.
    """

    def show_augmentation(self, num_samples=4):
        """Plot ``num_samples`` random samples as produced by ``self.transform``.

        Each sample is fetched through ``self[idx]``, so the configured
        transform (including any random augmentation) is applied. The result is
        mapped back to display range with the module-level ``inv_normalize``.

        Assumes the transform ends in ToTensor + Normalize (as the
        train/test pipelines in this notebook do), i.e. ``self[idx]`` yields a
        (C, H, W) float tensor — TODO confirm for other transforms.

        Args:
            num_samples: number of images shown side by side; must be >= 1.

        Raises:
            ValueError: if ``num_samples`` < 1.
        """
        if num_samples < 1:
            raise ValueError(f'num_samples must be >= 1, got {num_samples}')

        # Draw distinct indices whenever the dataset is large enough, so the
        # preview never shows the same image twice.
        indices = np.random.choice(
            len(self), num_samples, replace=num_samples > len(self))
        # squeeze=False keeps `axes` 2-D even when num_samples == 1 (otherwise
        # subplots returns a bare Axes that cannot be indexed).
        _, axes = plt.subplots(1, num_samples, figsize=(15, 3), squeeze=False)
        for ax, idx in zip(axes[0], indices):
            img, label = self[idx]
            # Undoing Normalize can overshoot [0, 1] by float rounding; clamp so
            # imshow neither warns nor clips silently.
            img = inv_normalize(img).clamp(0, 1).numpy().transpose((1, 2, 0))
            ax.imshow(img)
            ax.set_title(f'Label: {self.classes[label]}')
            ax.axis('off')
        plt.show()

# ========== Data preprocessing ==========
# Per-channel (R, G, B) statistics used to normalise CIFAR-10 tensors.
# NOTE(review): these std values are the widely copied ones; the per-pixel std
# over the training set is usually reported as ~(0.247, 0.243, 0.261). Left
# unchanged here so existing checkpoints stay compatible — confirm intent.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Evaluation pipeline: deterministic, tensor conversion + normalisation only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Training pipeline: random crop / flip / rotation / colour jitter on the PIL
# image, then the same tensor conversion + normalisation as evaluation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Inverse of Normalize(mean, std). Normalize computes (x - m') / s', so with
# m' = -m/s and s' = 1/s it yields x * s + m, undoing the forward transform
# for display purposes.
inv_normalize = transforms.Normalize(
    mean=[-mu / sigma for mu, sigma in zip(CIFAR_MEAN, CIFAR_STD)],
    std=[1 / sigma for sigma in CIFAR_STD],
)

# ========== Data loading & visualisation ==========
# Make sure the download directory exists.
os.makedirs(DATA_PATH, exist_ok=True)

# Load CIFAR-10 (downloaded into DATA_PATH if not already present).
train_dataset = CIFAR10Enhanced(
    root=DATA_PATH, train=True, download=True, transform=train_transform)
test_dataset = CIFAR10Enhanced(
    root=DATA_PATH, train=False, download=True, transform=test_transform)

# Preview a few augmented training samples.
print("🎨 数据增强可视化：")
train_dataset.show_augmentation()

# Settings shared by both loaders. Pinned host memory only speeds up
# host-to-GPU copies; on a CPU-only runtime requesting it just produces a
# warning, so enable it only when training on CUDA.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    pin_memory=(DEVICE == 'cuda'),
    persistent_workers=True,  # keep worker processes alive across epochs
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# ========== Model definition ==========
class EnhancedCNN(nn.Module):
    """Three conv-BN-ReLU-maxpool stages followed by a two-layer MLP head.

    The defaults reproduce the original CIFAR-10 model (3 input channels,
    10 classes, dropout 0.3). An ``AdaptiveAvgPool2d((4, 4))`` after the conv
    stages makes the head independent of input resolution. For 32x32 inputs
    the feature map is already 4x4, so that pool is an exact identity. It has
    no parameters and is appended after the existing layers, so saved
    state_dict keys are unchanged.

    Args:
        num_classes: number of output logits.
        in_channels: number of channels in the input images.
        dropout: dropout probability in the classifier head.
    """

    def __init__(self, num_classes=10, in_channels=3, dropout=0.3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            # Identity for 32x32 inputs; resamples any other resolution to the
            # 4x4 grid the classifier expects. Kept last so the indices of the
            # parameterised layers (and thus checkpoint keys) do not shift.
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal conv weights, N(0, 0.01) linear weights, zero biases.

        BatchNorm layers keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, num_classes) logits."""
        x = self.features(x)
        x = self.classifier(x)
        return x

model = EnhancedCNN().to(DEVICE)
print(f"🔄 模型已加载到 {DEVICE}")

# ========== Training configuration ==========
criterion = nn.CrossEntropyLoss()
# A single param group: decoupled weight decay applies to every parameter.
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# One cosine half-period spanning EPOCHS scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Loss scaling is only meaningful for CUDA mixed precision. Enabling it
# explicitly by device avoids the "CUDA not available" warning on CPU runtimes;
# a disabled scaler makes scale/unscale_/update no-ops and step() a plain
# optimizer.step().
scaler = GradScaler(enabled=(DEVICE == 'cuda'))
writer = SummaryWriter('/content/logs')  # Colab log directory

# Record the model graph once, traced with a single CIFAR-sized dummy input.
dummy_input = torch.randn(1, 3, 32, 32).to(DEVICE)
writer.add_graph(model, dummy_input)

# ========== Training ==========
def train_epoch(epoch):
    """Run one mixed-precision training epoch over ``train_loader``.

    Also writes to TensorBoard:
    - gradient histograms from the last batch (after unscaling and clipping),
    - a grid of the first batch's de-normalised input images,
    - weight histograms after the epoch.

    Args:
        epoch: epoch index, used as the TensorBoard global step.

    Returns:
        Mean per-sample training loss over ``train_dataset``.
    """
    model.train()
    total_loss = 0.0
    last_batch_idx = len(train_loader) - 1

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale at this point.
        # Unscale them in place so clipping (and the histograms below) see the
        # true magnitudes; otherwise the effective clip threshold would be
        # 1.0 / loss_scale. unscale_ also records inf/NaN for scaler.step.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Gradient histograms from the final batch of the epoch.
        if batch_idx == last_batch_idx:
            for name, param in model.named_parameters():
                grad = param.grad
                # Skip absent grads and overflowed (inf/NaN) steps: the scaler
                # discards such steps anyway, and histograms cannot bin them.
                if grad is not None and torch.isfinite(grad).all():
                    writer.add_histogram(f'Gradients/{name}', grad, epoch)

        # Skips the parameter update if unscale_ found inf/NaN gradients.
        scaler.step(optimizer)
        scaler.update()

        # Log the first batch's inputs, mapped back to display range.
        if batch_idx == 0:
            with torch.no_grad():
                images_inv = inv_normalize(images)
                img_grid = torchvision.utils.make_grid(images_inv.cpu())
                writer.add_image('Training Images', img_grid, epoch)

        # Weight by batch size so the final division gives a per-sample mean
        # even when the last batch is smaller.
        total_loss += loss.item() * images.size(0)

    # Weight distributions at the end of the epoch.
    for name, param in model.named_parameters():
        writer.add_histogram(f'Weights/{name}', param, epoch)

    return total_loss / len(train_dataset)

# ========== Validation ==========
def validate(epoch):
    """Evaluate ``model`` on every batch of ``test_loader`` without gradients.

    Args:
        epoch: currently unused by the body.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the per-sample loss averaged over
        ``test_dataset``, and the fraction of samples whose highest-scoring
        class equals the label.
    """
    model.eval()
    loss_sum, num_correct = 0.0, 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            loss_sum += batch_loss.item() * inputs.size(0)

            predicted = logits.max(1).indices
            num_correct += (predicted == targets).sum().item()

    num_samples = len(test_dataset)
    return loss_sum / num_samples, num_correct / num_samples

# ========== Main training loop ==========
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; with 0.0, a run whose accuracy never exceeds 0 would never save
# and the final torch.load of MODEL_SAVE_PATH would fail.
best_acc = -1.0
patience_counter = 0

print("\n🚀 训练启动:")
try:
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(epoch)
        val_loss, val_acc = validate(epoch)
        # Capture the LR this epoch actually trained with *before* stepping the
        # scheduler; reading it afterwards reports the next epoch's LR.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        # Log metrics to TensorBoard.
        writer.add_scalars('Loss', {'train': train_loss, 'val': val_loss}, epoch)
        writer.add_scalar('Accuracy/val', val_acc, epoch)
        writer.add_scalar('Learning Rate', epoch_lr, epoch)

        # Console progress line.
        print(f"Epoch {epoch:02d} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.2%} | "
              f"LR: {epoch_lr:.2e}")

        # Early stopping: checkpoint on every improvement; stop after PATIENCE
        # consecutive epochs without one.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            patience_counter = 0
            print(f"💾 模型已保存 (准确率 {val_acc:.2%})")
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("🛑 提前停止训练")
                break
finally:
    # Flush and close the event files even if training is interrupted.
    writer.close()

# ========== Final evaluation ==========
print("\n🔍 最终测试:")
# map_location lets a checkpoint saved from a GPU run load on a CPU-only
# runtime; weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True))
# NOTE(review): the checkpoint was chosen by accuracy on this same test set
# (validate() iterates test_loader), so this figure is optimistically biased.
# A validation split held out from the training data would give an unbiased
# estimate.
final_loss, final_acc = validate(0)
print(f"🏆 测试准确率: {final_acc:.2%}")

# ========== TensorBoard启动指令 ==========
print("\n🔬 启动TensorBoard：")
!tensorboard --logdir=/content/logs --port=6006