In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.transforms import v2
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#混淆矩阵
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import itertools

#混淆矩阵计算模块
def get_confusion_matrix(model, data_loader, num_classes, device):
    """Return the confusion matrix of *model* over every batch in *data_loader*.

    The model is switched to eval mode and run without gradient tracking.
    Predictions are the arg-max over dimension 1 of the model output. The
    counts come from sklearn's ``confusion_matrix``, called with
    ``labels=0..num_classes-1``, so the result is always
    ``num_classes x num_classes``.
    """
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch_images, batch_targets in data_loader:
            logits = model(batch_images.to(device))
            predicted = torch.max(logits, 1).indices
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(batch_targets.numpy())

    class_ids = np.arange(num_classes)
    return confusion_matrix(all_targets, all_preds, labels=class_ids)
#混淆矩阵可视化模块
def plot_confusion_matrix(cm, class_names):
    """Build a labelled heat-map Figure for a confusion matrix.

    Args:
        cm: 2-D array of counts. The y axis is labelled 'True Label' and the
            x axis 'Predicted Label', matching rows = true class and
            columns = predicted class.
        class_names: tick labels, one per row/column of ``cm``.

    Returns:
        The matplotlib Figure. It is not closed here; the caller owns it.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title("Confusion Matrix")
    fig.colorbar(heatmap, ax=ax)

    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    # Anchor rotated labels at their right end so each sits under its column.
    ax.set_xticklabels(class_names, rotation=45, ha="right", rotation_mode="anchor")
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Annotate every cell with its count. Use white text on dark (high-count)
    # cells so the number stays legible.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        ax.text(j, i, cm[i, j], ha="center", va="center", color=color)

    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    # Lay out only after every label exists. The original called
    # tight_layout() before setting the axis labels, so they could be clipped.
    fig.tight_layout()
    return fig
#TensorBoard集成混淆矩阵模块
def log_confusion_matrix(visualizer, cm, class_names, epoch):
    """Render *cm* with plot_confusion_matrix and write it to TensorBoard.

    The figure goes to ``visualizer.writer`` under the tag 'Confusion_Matrix',
    using *epoch* as the global step.
    """
    figure = plot_confusion_matrix(cm, class_names)
    writer = visualizer.writer
    writer.add_figure("Confusion_Matrix", figure, epoch)



"""
MLP---main()
   |--配置模块
   |--可视化模块
   |--数据模块
   |--模型模块
   |--训练测试模块
"""


# ==========配置模块==========
class Config:
    """Static hyper-parameters and logging settings for the FashionMNIST MLP.

    Values are plain class attributes. They are read both as ``Config.x``
    and through an instance (``Config().x``).
    """

    # ---- training ----
    batch_size_ = 256
    num_epochs_ = 30
    learning_rate_ = 0.005
    # Use the GPU when one is visible to torch; otherwise fall back to CPU.
    device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- model ----
    hidden_sizes_ = [128, 64]  # widths of the hidden Linear layers, in order

    # ---- visualization / logging ----
    log_dir_ = "../logdir/mlp"
    log_images_every_ = 50  # log a training image grid every N batches
    sample_grid_size_ = (4, 8)
    # Display names for the 10 FashionMNIST labels, indexed by class id.
    class_names_ = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]


# ==========可视化模块==========
class Visualizer:
    """Thin wrapper around a TensorBoard ``SummaryWriter`` with a shared step.

    ``self.step`` is the global step used for every scalar and image logged
    through this class. Callers advance it with :meth:`increment_step`.
    """

    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir=log_dir)
        self.step = 0

    def log_scalars(
        self, train_loss=None, test_loss=None, train_acc=None, test_acc=None
    ):
        """Write each metric that is not ``None`` to TensorBoard at the current step."""
        scalars = {
            "Loss/train": train_loss,
            "Loss/test": test_loss,
            "Accuracy/train": train_acc,
            "Accuracy/test": test_acc,
        }
        for tag, value in scalars.items():
            if value is not None:
                self.writer.add_scalar(tag, value, self.step)

    def log_images(self, images, tag="Training_images", normalize=True):
        """Tile *images* into a grid and write it to TensorBoard at the current step.

        Args:
            images: a batch accepted by ``torchvision.utils.make_grid``
                (presumably ``(B, C, H, W)``; TODO confirm for other callers).
            tag: TensorBoard tag. The previous default was the misspelled
                "Trainning_images".
            normalize: min-max rescale the grid into [0, 1] before logging.
                TensorBoard scales float images by 255 and clips them, so
                inputs standardised to roughly [-1, 1] (as the
                ``Normalize(mean=0.5, std=0.5)`` pipeline in this file
                produces) would otherwise lose every negative pixel to black.
        """
        img_grid_ = torchvision.utils.make_grid(images, normalize=normalize)
        self.writer.add_image(tag, img_grid_, self.step)

    def log_model_graph(self, model, dummy_input):
        """Trace *model* with *dummy_input* and write its graph to TensorBoard."""
        self.writer.add_graph(model, dummy_input)

    def close(self):
        """Close the underlying SummaryWriter."""
        self.writer.close()

    def increment_step(self):
        """Advance the shared global step by one."""
        self.step += 1


# ==========数据模块==========
def get_dataloaders(batch_size, root="../data", num_workers=4):
    """Build the FashionMNIST train and test DataLoaders.

    Images are converted to float tensors scaled to [0, 1], then normalised
    with mean=0.5 and std=0.5, which maps them to [-1, 1].

    Args:
        batch_size: samples per batch for both loaders.
        root: dataset directory. The dataset is downloaded there if missing.
        num_workers: DataLoader worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    transform_ = v2.Compose(
        [
            # ToImage + ToDtype(scale=True) is the documented replacement for
            # the deprecated v2.ToTensor().
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=[0.5],
                std=[0.5],
            ),
        ]
    )
    train_dataset_ = datasets.FashionMNIST(
        root=root, train=True, download=True, transform=transform_
    )
    test_dataset_ = datasets.FashionMNIST(
        root=root, train=False, download=True, transform=transform_
    )
    train_loader_ = DataLoader(
        dataset=train_dataset_, shuffle=True, num_workers=num_workers, batch_size=batch_size
    )
    # Evaluation does not benefit from shuffling, and a fixed order keeps
    # test passes reproducible.
    test_loader_ = DataLoader(
        dataset=test_dataset_, shuffle=False, num_workers=num_workers, batch_size=batch_size
    )
    return train_loader_, test_loader_


# ==========模型模块==========
class MLP(nn.Module):
    """Fully connected classifier: flatten -> (Linear -> ReLU) per hidden size -> Linear.

    Args:
        input_size: number of features per sample after flattening.
        hidden_sizes: widths of the hidden layers, in order (may be empty).
        output_size: number of output logits.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.layers_ = nn.ModuleList()
        # Hidden layers: each Linear is followed by a ReLU.
        prev_size_ = input_size
        for size in hidden_sizes:
            self.layers_.append(nn.Linear(prev_size_, size))
            self.layers_.append(nn.ReLU())
            prev_size_ = size
        # Output layer: raw logits, no activation.
        self.layers_.append(nn.Linear(prev_size_, output_size))
        self._initialize_weights()

    def forward(self, x):
        """Flatten all but the first (batch) dimension and return the logits."""
        # reshape (unlike view) also accepts non-contiguous inputs, such as
        # permuted or channels_last batches, copying only when it must.
        x = x.reshape(x.size(0), -1)
        for layer in self.layers_:
            x = layer(x)
        return x

    def _initialize_weights(self):
        """Kaiming-normal weights (fan_out mode, ReLU gain) and zero biases for every Linear."""
        # NOTE(review): for nn.Linear, fan_out is out_features. Conventional
        # He init for an MLP uses mode="fan_in", and the final (non-ReLU)
        # layer also receives the ReLU gain here. Kept unchanged to preserve
        # training behaviour; consider mode="fan_in".
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


# ========训练测试模块=========
class Trainer:
    """Run training epochs and evaluation passes for a classification model.

    Uses cross-entropy loss, moves every batch to ``Config.device_``, and
    reports training metrics through the supplied visualizer.
    """

    def __init__(self, model, optimizer, visualizer):
        self.model_ = model.to(Config.device_)
        self.optimizer_ = optimizer
        self.visualizer_ = visualizer
        self.criterion_ = nn.CrossEntropyLoss()

    def train_epoch(self, train_loader):
        """Train for one pass over *train_loader*.

        Logs train loss and accuracy (percent) to the visualizer and returns
        the mean per-sample training loss.
        """
        # Advance the shared step *before* logging. The train scalars logged
        # here and the test scalars the caller logs after validate() then land
        # on the same step. Previously they were offset by one.
        self.visualizer_.increment_step()
        self.model_.train()
        total_loss_ = 0.0
        total_correct_ = 0
        total_samples_ = 0
        # Log only enough images to fill the configured grid instead of the
        # whole batch. Assumes sample_grid_size_ is (rows, cols).
        grid_rows_, grid_cols_ = Config.sample_grid_size_
        num_logged_images_ = grid_rows_ * grid_cols_
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(Config.device_), labels.to(Config.device_)
            self.optimizer_.zero_grad()
            outputs_ = self.model_(images)
            loss_ = self.criterion_(outputs_, labels)
            loss_.backward()
            self.optimizer_.step()

            batch_size_ = labels.size(0)
            # Weight by batch size so a smaller final batch does not bias the mean.
            total_loss_ += loss_.item() * batch_size_
            total_correct_ += (outputs_.argmax(dim=1) == labels).sum().item()
            total_samples_ += batch_size_

            if batch_idx % Config.log_images_every_ == 0:
                self.visualizer_.log_images(
                    images[:num_logged_images_], tag="Training_images"
                )

        avg_loss_ = total_loss_ / total_samples_
        avg_acc_ = 100 * total_correct_ / total_samples_
        self.visualizer_.log_scalars(train_loss=avg_loss_, train_acc=avg_acc_)
        return avg_loss_

    def validate(self, test_loader):
        """Evaluate on *test_loader* without gradients.

        Returns:
            ``(mean per-sample loss, accuracy in percent)`` over the samples
            the loader actually yields.
        """
        self.model_.eval()
        total_loss_ = 0.0
        correct_ = 0
        total_samples_ = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(Config.device_), labels.to(Config.device_)
                outputs_ = self.model_(images)
                batch_size_ = labels.size(0)
                total_loss_ += self.criterion_(outputs_, labels).item() * batch_size_
                correct_ += (outputs_.argmax(dim=1) == labels).sum().item()
                total_samples_ += batch_size_
        # Count the samples seen rather than using len(dataset), which would be
        # wrong for samplers or drop_last loaders.
        avg_loss_ = total_loss_ / total_samples_
        accuracy_ = 100 * correct_ / total_samples_
        return avg_loss_, accuracy_


# =========主程序==========
def main():
    """Train the MLP on FashionMNIST and log progress to TensorBoard.

    Each epoch runs:
    - one training pass;
    - one validation pass, whose test loss/accuracy are logged;
    - a confusion-matrix figure written to TensorBoard;
    - a one-line console summary.

    The TensorBoard writer is closed even if training raises.
    """
    config = Config()
    num_classes_ = len(config.class_names_)
    train_loader_, test_loader_ = get_dataloaders(config.batch_size_)
    visualizer_ = Visualizer(config.log_dir_)

    try:
        # Move the model to its device before building the optimizer, as the
        # torch.optim documentation recommends.
        model_ = MLP(
            input_size=28 * 28,
            hidden_sizes=config.hidden_sizes_,
            output_size=num_classes_,
        ).to(config.device_)
        optimizer_ = optim.Adam(model_.parameters(), lr=config.learning_rate_)
        trainer_ = Trainer(model_, optimizer_, visualizer_)

        dummy_input_ = torch.zeros(1, 28 * 28, device=config.device_)
        visualizer_.log_model_graph(model=model_, dummy_input=dummy_input_)

        for epoch in range(config.num_epochs_):
            train_loss_ = trainer_.train_epoch(train_loader_)

            val_loss_, val_acc_ = trainer_.validate(test_loader_)
            visualizer_.log_scalars(test_loss=val_loss_, test_acc=val_acc_)
            # NOTE(review): this is a second full pass over the test set each
            # epoch; having validate() return predictions would avoid it.
            cm = get_confusion_matrix(
                model_, test_loader_, num_classes=num_classes_, device=config.device_
            )
            log_confusion_matrix(visualizer_, cm, config.class_names_, epoch)

            print(
                f"Device: {config.device_} | "
                f"Epoch {epoch+1}/{config.num_epochs_} | "
                f"Train Loss: {train_loss_:.4f} | "
                f"Val Loss: {val_loss_:.4f} | "
                f"Accuracy: {val_acc_:.2f}%"
            )
    finally:
        # Flush and close the event file even if training fails part-way.
        visualizer_.close()


if __name__ == "__main__":
    main()



Device: cpu | Epoch 1/30 | Train Loss: 0.7931 | Val Loss: 0.4920 | Accuracy: 82.59%
Device: cpu | Epoch 2/30 | Train Loss: 0.4306 | Val Loss: 0.4512 | Accuracy: 84.37%
Device: cpu | Epoch 3/30 | Train Loss: 0.3890 | Val Loss: 0.4421 | Accuracy: 84.51%
Device: cpu | Epoch 4/30 | Train Loss: 0.3635 | Val Loss: 0.4438 | Accuracy: 84.06%
Device: cpu | Epoch 5/30 | Train Loss: 0.3399 | Val Loss: 0.4075 | Accuracy: 85.56%
Device: cpu | Epoch 6/30 | Train Loss: 0.3226 | Val Loss: 0.3843 | Accuracy: 86.04%
Device: cpu | Epoch 7/30 | Train Loss: 0.3138 | Val Loss: 0.3827 | Accuracy: 86.00%
Device: cpu | Epoch 8/30 | Train Loss: 0.3071 | Val Loss: 0.3873 | Accuracy: 85.95%
Device: cpu | Epoch 9/30 | Train Loss: 0.2962 | Val Loss: 0.3902 | Accuracy: 86.70%
Device: cpu | Epoch 10/30 | Train Loss: 0.2855 | Val Loss: 0.3803 | Accuracy: 86.85%
Device: cpu | Epoch 11/30 | Train Loss: 0.2790 | Val Loss: 0.3769 | Accuracy: 86.67%
Device: cpu | Epoch 12/30 | Train Loss: 0.2754 | Val Loss: 0.3924 | Accura