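## Background
+ This project uses the MindSpore framework and a pretrained Vision Transformer (ViT) to perform multi-class classification of colorectal tissue images from PathMNIST, a dataset provided by MedMNIST.
+ PathMNIST is a subset of MedMNIST v2 built for multi-class classification of colorectal cancer histology slides. It contains 107,180 color images of 28×28 pixels (89,996 training, 10,004 validation, 7,180 test) covering 9 tissue types, including normal mucosa, cancer-associated stroma, and lymphocytes. The images are standardized in size and preprocessed, which makes the dataset convenient for rapid prototyping and model evaluation.
+ Colorectal cancer is among the malignancies with high incidence and mortality worldwide. Early pathological screening and classification are critical to diagnostic accuracy and treatment outcomes. Automated image recognition can reduce the workload of pathologists and improve diagnostic consistency.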

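## ViT Model Overview
Vision Transformer (ViT) extracts image features by encoding the input image into a visual representation. Its core idea is to **redefine how visual information is represented through attention**. It processes image input only and takes no part in language understanding or generation. ViT was a milestone in extending the Transformer architecture from natural language processing (NLP) to computer vision (CV). It drops the local convolution kernels that traditional convolutional neural networks (CNNs) rely on and **treats an image as "serialized text"**, using global attention to directly capture dependencies between distant parts of the image. Core techniques and implementation: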
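>- The input image is split into patches (for example, 16×16-pixel blocks, as in ViT-B/16), and each flattened patch is linearly projected into a patch embedding.
>- After positional encodings are added, the sequence is fed into a multi-layer Transformer encoder, which extracts global context features.
>- The output is a sequence of image features (for example, the [CLS] token plus the patch embeddings).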
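
The steps above can be sketched in plain NumPy. This is a minimal illustration rather than the MindVision implementation: the projection matrix, [CLS] token, and positional embeddings are random or zero placeholders (in a real model they are learned), and the names `patchify`, `W_proj`, `cls_token`, and `pos_embed` are hypothetical.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an HWC image into flattened, non-overlapping square patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, c) -> (gh, gw, p, p, c) -> (gh * gw, p * p * c)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3), dtype=np.float32)  # stand-in for one resized image

patches = patchify(image)  # 14 x 14 = 196 patches, each 16 * 16 * 3 = 768 values

embed_dim = 768  # ViT-B hidden size
# Placeholder parameters (assumption: random values instead of learned weights)
W_proj = rng.standard_normal((patches.shape[1], embed_dim)).astype(np.float32) * 0.02
cls_token = np.zeros((1, embed_dim), dtype=np.float32)
pos_embed = rng.standard_normal((patches.shape[0] + 1, embed_dim)).astype(np.float32) * 0.02

tokens = patches @ W_proj  # linear projection: patch embeddings
sequence = np.concatenate([cls_token, tokens], axis=0) + pos_embed  # encoder input

print(patches.shape, sequence.shape)  # (196, 768) (197, 768)
```

With a 224×224 input and 16×16 patches, the encoder therefore sees 197 tokens: one [CLS] token plus 196 patch embeddings.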

![Architecture diagram](https://pic4.zhimg.com/v2-b6861b011c966bdeeb0a6236a38b3ce5_r.jpg)

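Advantages: compared with CNNs, ViT **captures long-range dependencies and suits complex visual scenes**. It can also reuse weights pretrained on large datasets such as ImageNet, which improves generalization. ViT marked a paradigm shift in CV from "local convolution" to "global attention" and led to variants such as DeiT (data-efficient training), Swin Transformer (windowed attention), and BEiT (masked-image pretraining), pushing vision tasks (classification, detection, segmentation) toward more efficient and flexible architectures.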
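
To make the "global attention" idea concrete, here is a minimal single-head scaled dot-product self-attention sketch in NumPy. It is an illustration under simplifying assumptions (one head, random weights, no layer norm or MLP), not the code inside `vit_b_16`; the function names and the dimension 64 are chosen only for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a token sequence."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])  # every token scores every other token
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
num_tokens, dim = 197, 64  # 1 [CLS] token + 196 patches; dim is illustrative
x = rng.standard_normal((num_tokens, dim))
w_q, w_k, w_v = (rng.standard_normal((dim, dim)) * dim ** -0.5 for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (197, 64) (197, 197)
```

The 197×197 weight matrix is the point: each token attends to all others in a single layer, regardless of how far apart the corresponding patches are in the image.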

## Environment Setup

In [None]:
%%capture captured_output
# The environment ships with mindspore==2.5.0 preinstalled; to use a different version, change the MINDSPORE_VERSION variable below
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.5.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

In [1]:
# Check the installed mindspore version
!pip show mindspore

Name: mindspore
Version: 2.5.0
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, dill, numpy, packaging, pillow, protobuf, psutil, safetensors, scipy
Required-by: 


The verbosity of MindSpore logging is controlled by setting the log level through an environment variable.

In [None]:
import os
# Set log level: 4(EXCEPTION), 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG)
os.environ['GLOG_v'] = '3'

In [3]:
%%capture captured_output
!pip install medmnist
# `!export` runs in a throwaway subshell and does not persist; %env sets the variable for the kernel
%env HF_ENDPOINT=https://hf-mirror.com

## Dataset Loading and Preprocessing

In [1]:
import os
from medmnist import INFO, PathMNIST
import mindspore.dataset.vision as vision
from mindspore.dataset import GeneratorDataset


# Select the dataset
data_flag = 'pathmnist'
download_root = './data'
os.makedirs(download_root, exist_ok=True)
# Read the dataset metadata from INFO
info = INFO[data_flag]

# Derive the number of classes directly from the label list
labels = info['label']
num_classes = len(labels)
n_channels = info['n_channels']

# Download and load the data
train_dataset = PathMNIST(root=download_root, split='train', download=True)
val_dataset = PathMNIST(root=download_root, split='val', download=True)

In [2]:
from mindspore import dataset as ds

# Build a preprocessed, batched MindSpore dataset pipeline
def create_dataset(dataset, batch_size=64, shuffle=True):
    transform = [
        vision.Resize((224, 224)),  # upsample the 28×28 images to the ViT input size
        vision.ToTensor(),          # HWC [0, 255] -> CHW float32 in [0, 1]
    ]
    ds_gen = ds.GeneratorDataset(dataset, column_names=['image','label'], shuffle=shuffle)
    ds_gen = ds_gen.map(input_columns=['image'], operations=transform)
    # Labels arrive with shape (1,); squeeze to a scalar and cast to int32 for the sparse loss
    ds_gen = ds_gen.map(input_columns=['label'], operations=lambda x: x.squeeze().astype('int32'))
    ds_gen = ds_gen.batch(batch_size)
    return ds_gen

train_ds = create_dataset(train_dataset, batch_size=64, shuffle=True)
val_ds = create_dataset(val_dataset, batch_size=64, shuffle=False)
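
As a rough sketch of what the pipeline does to each sample, the NumPy snippet below mimics the `ToTensor` step (HWC `uint8` in [0, 255] to CHW `float32` in [0, 1]) and the label squeeze, then works out the number of batches per epoch. The helper `to_chw_float` is hypothetical, and the training-split size of 89,996 is taken from the MedMNIST documentation rather than from this notebook.

```python
import math
import numpy as np

def to_chw_float(image_hwc_uint8):
    """Mimic ToTensor: scale to [0, 1] and move channels first (HWC -> CHW)."""
    return (image_hwc_uint8.astype(np.float32) / 255.0).transpose(2, 0, 1)

fake_image = np.full((224, 224, 3), 255, dtype=np.uint8)  # stand-in for a resized image
tensor = to_chw_float(fake_image)

label = np.array([3])                          # MedMNIST-style label with shape (1,)
scalar_label = label.squeeze().astype('int32')  # shape () as expected by a sparse loss

num_train = 89_996  # assumption: PathMNIST training split size per MedMNIST docs
batch_size = 64
batches_per_epoch = math.ceil(num_train / batch_size)  # last partial batch is kept

print(tensor.shape, tensor.dtype, float(tensor.max()))  # (3, 224, 224) float32 1.0
print(scalar_label.shape, batches_per_epoch)            # () 1407
```

The value 1407 agrees with the "dataset size: 1407 batches" line printed in the training log.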

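## Network Construction
ViT splits an image into fixed-size patches and applies self-attention in a Transformer encoder, achieving performance that matches or exceeds traditional CNNs on many vision tasks. This project uses the vit_b_16 model provided by MindVision, which was pretrained on ImageNet-21k and fine-tuned on ImageNet-1k and generalizes well. The original 1000-class classification head is replaced with a 9-class output layer for PathMNIST, while the Transformer backbone weights are kept to speed up training on the downstream task.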

In [None]:
import mindspore as ms
from mindspore import nn, context
from mindspore.train.callback import LossMonitor, ModelCheckpoint, CheckpointConfig
from mindvision.classification.models import vit_b_16
from mindspore.train import Model
from mindspore.train.metrics import Accuracy


# Load the pretrained ViT and replace the classification head
net = vit_b_16(pretrained=True, num_classes=1000, image_size=224)
net.head = nn.Dense(in_channels=768, out_channels=num_classes)  # 9 classes

# Define the loss function and optimizer
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-4)

# Initialize the Accuracy metric
accuracy_metric = Accuracy()

# Define the Model; metrics must be passed as a dict
model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={"Accuracy": accuracy_metric})
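
For intuition about the loss, the sketch below reimplements sparse softmax cross-entropy with mean reduction in pure Python. It is an illustration of the formula, not MindSpore's kernel, and `sparse_softmax_cross_entropy` is a hypothetical name. It also shows that a 9-class head producing uniform predictions gives a loss of ln 9 ≈ 2.197, which is in the same range as the first values in the training log.

```python
import math

def sparse_softmax_cross_entropy(logits, labels):
    """Mean of -log softmax(logits)[label] over the batch (labels are class indices)."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)

num_classes = 9
# All-zero logits correspond to a uniform distribution over the 9 classes
uniform_logits = [[0.0] * num_classes for _ in range(4)]
labels = [0, 3, 5, 8]

loss = sparse_softmax_cross_entropy(uniform_logits, labels)
print(f"{loss:.4f} {math.log(num_classes):.4f}")  # 2.1972 2.1972
```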

## Model Training

In [5]:
import time
from datetime import datetime

# Start the training timer
train_start = time.perf_counter()
# NOTE: the message text was carried over from the evaluation cell; this timestamp marks the start of training
print(f"[{datetime.now():%Y-%m-%d %H:%M:%S}] Evaluating on validation set...")

# Save a checkpoint every 100 steps and keep at most 3 checkpoints
ckpt_config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=3)
ckpt_cb = ModelCheckpoint(prefix="vit_pathmnist",
                          directory="./checkpoints",
                          config=ckpt_config)

# Start training
epoch_num = 5
print(f"Start training for {epoch_num} epochs, dataset size: {train_ds.get_dataset_size()} batches")
model.train(epoch_num,
            train_dataset=train_ds,
            callbacks=[LossMonitor(per_print_times=50), ckpt_cb],
            dataset_sink_mode=False)

# Stop the training timer
train_end = time.perf_counter()
elapsed_ms = (train_end - train_start) * 1000

print(f"Train time: {elapsed_ms:.2f} ms")

[2025-05-14 02:42:10] Evaluating on validation set...
Start training for 5 epochs, dataset size: 1407 batches


.epoch: 1 step: 50, loss is 2.2472798824310303
epoch: 1 step: 100, loss is 2.205869197845459
epoch: 1 step: 150, loss is 2.380136013031006
epoch: 1 step: 200, loss is 1.8169846534729004
epoch: 1 step: 250, loss is 1.7921384572982788
epoch: 1 step: 300, loss is 1.3815851211547852
epoch: 1 step: 350, loss is 1.4965860843658447
epoch: 1 step: 400, loss is 1.318996548652649
epoch: 1 step: 450, loss is 1.1503033638000488
epoch: 1 step: 500, loss is 1.0637834072113037
epoch: 1 step: 550, loss is 1.2114602327346802
epoch: 1 step: 600, loss is 1.1093289852142334
epoch: 1 step: 650, loss is 1.1249001026153564
epoch: 1 step: 700, loss is 1.189340591430664
epoch: 1 step: 750, loss is 0.9763270616531372
epoch: 1 step: 800, loss is 0.9399876594543457
epoch: 1 step: 850, loss is 0.7920712232589722
epoch: 1 step: 900, loss is 0.7658846378326416
epoch: 1 step: 950, loss is 0.9221714735031128
epoch: 1 step: 1000, loss is 0.8943929672241211
epoch: 1 step: 1050, loss is 1.0630111694335938
epoch: 1 step: 

## Model Evaluation

In [6]:
import time
from datetime import datetime

# Start the evaluation timer
eval_start = time.perf_counter()
print(f"[{datetime.now():%Y-%m-%d %H:%M:%S}] Evaluating on validation set...")

# Evaluate on the validation set
metrics = model.eval(val_ds, dataset_sink_mode=False)

# Stop the evaluation timer
eval_end = time.perf_counter()
elapsed_ms = (eval_end - eval_start) * 1000

# Print each metric
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")

print(f"Evaluation time: {elapsed_ms:.2f} ms")

[2025-05-14 03:27:14] Evaluating on validation set...
Accuracy: 0.8784
Evaluation time: 19712.50 ms
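
The reported Accuracy is top-1 accuracy: the fraction of validation samples whose highest-scoring class matches the label. A minimal pure-Python sketch of that computation follows; the function name `top1_accuracy` and the toy logits are made up for illustration and are not model outputs.

```python
def top1_accuracy(logits, labels):
    """Fraction of samples whose argmax over the logits equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(predicted == label)
    return correct / len(labels)

# Toy example: 4 samples, 3 classes (illustrative values only)
logits = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [1.5, 0.2, 0.3],   # predicts class 0
    [0.0, 0.1, 0.9],   # predicts class 2
    [0.4, 0.3, 0.2],   # predicts class 0
]
labels = [1, 0, 2, 1]

accuracy = top1_accuracy(logits, labels)
print(f"Accuracy: {accuracy:.4f}")  # Accuracy: 0.7500
```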
