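# ViT CIFAR-10 (Complete Colab A100 Workflow)
This notebook walks through eight predefined sections to install, train, evaluate, run inference with, and Git-sync a Vision Transformer on a Colab A100. Run the cells in order.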

## 1. Environment Setup
- Select `Runtime → Change runtime type → GPU → A100`.
- Run the cell below to install the dependencies.

In [None]:
!nvidia-smi
!pip install --quiet --upgrade pip
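# Force numpy<2 for compatibility with torch 2.2.x
!pip install --quiet "numpy<2"
!pip install --quiet torch==2.2.2+cu121 torchvision==0.17.2+cu121 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
!pip install --quiet timm==0.9.16 torchmetrics==1.3.2 tensorboard==2.17.0 scikit-learn==1.4.2 einops==0.7.0
# After the installs finish, restart the session (Runtime → Restart session) so the
# downgraded numpy is the one that gets imported; mixing numpy versions in one kernel
# can cause binary-incompatibility errors like the ones in Sections 3 and 6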

Wed Nov 26 14:29:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P0             48W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

^C
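The `^C` above suggests this run was interrupted, so the installs may not have completed. A quick check after restarting can confirm the `numpy<2` pin actually holds. The sketch below is a minimal, hypothetical helper (not part of the repository) that only uses the standard library's `importlib.metadata`; note that it reports the version pip installed, so if `numpy.__version__` inside the kernel differs from it, the kernel is still holding an older import and needs a restart.

```python
from importlib import metadata


def major_version(version: str) -> int:
    """Return the leading integer of a version string such as '1.26.4'."""
    return int(version.split(".")[0])


def numpy_pin_ok(installed: str, max_major_exclusive: int = 2) -> bool:
    """True if the installed numpy satisfies the `numpy<2` pin used above."""
    return major_version(installed) < max_major_exclusive


def report_numpy() -> str:
    """Describe the numpy distribution that pip has installed in this runtime."""
    try:
        installed = metadata.version("numpy")
    except metadata.PackageNotFoundError:
        return "numpy is not installed"
    status = "OK" if numpy_pin_ok(installed) else "violates numpy<2; reinstall and restart the session"
    return f"numpy {installed}: {status}"


if __name__ == "__main__":
    print(report_numpy())
```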


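## 2. Git Setup and Repository Sync
- Store your personal access token in the `GITHUB_TOKEN` environment variable (or enter it when prompted), then run the cells below.
- The repository must be cloned again in every new Colab session.
- Reminder: if you changed the code elsewhere, push those changes first.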

In [None]:
# Only works inside an existing clone; in a fresh session, run the clone cell below first
!git pull origin main

fatal: not a git repository (or any of the parent directories): .git


In [None]:
import getpass, os, subprocess
REPO_URL = "https://github.com/lachlanye/MLLM-from-scratch.git"  
if "GITHUB_TOKEN" not in os.environ:
    # Prefer Colab's Secrets feature for storing the token, or enter it manually at runtime
    print("Enter your GitHub token:")
    os.environ["GITHUB_TOKEN"] = getpass.getpass()

repo_name = REPO_URL.split("/")[-1].replace(".git", "")
if not os.path.exists(repo_name):
    subprocess.run(["git", "clone", f"https://{os.environ['GITHUB_TOKEN']}@" + REPO_URL.split("https://")[-1]], check=True)
%cd $repo_name
!git config user.name "lachlanye"
!git config user.email "colab@example.com"

Enter your GitHub token:
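The clone cell embeds the token in the HTTPS URL. Git then stores that URL as the `origin` remote in `.git/config`, and if the clone fails, the `CalledProcessError` raised by `subprocess.run(..., check=True)` includes the full command, token and all. A minimal sketch of building the same URL and masking the token before anything is printed; both helpers are hypothetical, not part of the repository. On Colab, the token itself can come from Secrets via `from google.colab import userdata` and `userdata.get("GITHUB_TOKEN")`.

```python
from urllib.parse import urlsplit, urlunsplit


def authenticated_clone_url(repo_url: str, token: str) -> str:
    """Put the token in the userinfo part of an HTTPS URL, as the clone cell does."""
    parts = urlsplit(repo_url)
    if parts.scheme != "https":
        raise ValueError("expected an https:// repository URL")
    netloc = f"{token}@{parts.hostname}"
    if parts.port:
        netloc += f":{parts.port}"
    return urlunsplit((parts.scheme, netloc, parts.path, parts.query, parts.fragment))


def redact(text: str, token: str) -> str:
    """Replace every occurrence of the token before text is printed or logged."""
    return text.replace(token, "***") if token else text
```

For example, wrapping the clone in `try`/`except subprocess.CalledProcessError as e:` and printing `redact(str(e), token)` keeps the token out of the notebook output.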


## 3. Dataset Preparation
This project uses the wrapper in `datasets/cifar10.py`, which downloads CIFAR-10 into the `data/` directory by default.

In [11]:
import importlib.util
import sys
from pathlib import Path

data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

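# Load the module directly from its file path, bypassing the import cache and any installed package also named `datasets`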
file_path = "datasets/cifar10.py"
spec = importlib.util.spec_from_file_location("datasets.cifar10", file_path)
cifar_module = importlib.util.module_from_spec(spec)
sys.modules["datasets.cifar10"] = cifar_module
spec.loader.exec_module(cifar_module)

train_dataset, test_dataset = cifar_module.build_cifar10_datasets(data_dir=str(data_dir))
len(train_dataset), len(test_dataset)

100%|██████████| 170498071/170498071 [05:32<00:00, 513267.46it/s]



Extracting data/cifar-10-python.tar.gz to data


TypeError: _reconstruct: First argument must be a sub-type of ndarray
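The cell above imports `datasets/cifar10.py` by file path. A likely reason is that Colab preinstalls the Hugging Face `datasets` package, which would otherwise shadow the local `datasets/` folder. The same pattern can be wrapped in a reusable helper. The sketch below is hypothetical, not a function in the repository, and adds two guards: a clear error when the file is missing (for example, when the working directory is not the repository root), and cleanup of `sys.modules` if the module fails while executing. It does not address the `TypeError` above, which comes from inside numpy.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_path(module_name: str, file_path: str) -> ModuleType:
    """Import a single .py file under an explicit module name.

    Mirrors the cell above: the module is registered in sys.modules before
    execution, so code that looks itself up by name still works.
    """
    path = Path(file_path)
    if not path.is_file():
        raise FileNotFoundError(f"{path} not found; is the working directory the repo root?")
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except Exception:
        # Do not leave a half-initialised module behind for later imports
        sys.modules.pop(module_name, None)
        raise
    return module
```

Usage: `cifar_module = load_module_from_path("datasets.cifar10", "datasets/cifar10.py")`.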

## 4. Training Parameters and Configuration
Edit `configs/vit_config.yaml` directly, or load it and override values in the cell below (the cell uses PyYAML).
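A minimal sketch of a dynamic override in plain Python, operating on the dict that `yaml.safe_load` returns. `override` is a hypothetical helper, not part of the repository or of `omegaconf`. Keys follow the layout of `configs/vit_config.yaml` printed below, where the batch size lives under `training_params`. The helper raises on an unknown key instead of silently creating it, so a typo such as `training.batch_size` fails loudly.

```python
from copy import deepcopy
from typing import Any, Dict


def override(cfg: Dict[str, Any], dotted_key: str, value: Any) -> Dict[str, Any]:
    """Return a copy of cfg with one nested value replaced.

    The dotted key follows the YAML nesting, e.g. "training_params.batch_size".
    Missing keys raise KeyError instead of being created.
    """
    new_cfg = deepcopy(cfg)  # leave the original config untouched
    *parents, leaf = dotted_key.split(".")
    node = new_cfg
    for key in parents:
        node = node[key]
    if leaf not in node:
        raise KeyError(f"{dotted_key!r} is not in the config")
    node[leaf] = value
    return new_cfg
```

Usage: `vit_cfg_overrides = override(vit_cfg, "training_params.batch_size", 256)`, then write it with `yaml.safe_dump` and pass the file to the training script via `--config`.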

In [12]:
import yaml
from copy import deepcopy
from pathlib import Path

config_path = Path("configs/vit_config.yaml")
with open(config_path) as f:
    vit_cfg = yaml.safe_load(f)
display(vit_cfg)
# Example: quickly override the batch size from the notebook
# (the config nests it under "training_params"; see the output below)
# vit_cfg_overrides = deepcopy(vit_cfg)
# vit_cfg_overrides["training_params"]["batch_size"] = 256
# print("Override batch size -> 256")
# with open("/tmp/vit_config_colab.yaml", "w") as f:
#     yaml.safe_dump(vit_cfg_overrides, f)
# print("Wrote /tmp/vit_config_colab.yaml; pass it to the training script via --config")

{'data_params': {'dataset': 'CIFAR10',
  'data_dir': './data/cifar10',
  'img_size': 32,
  'patch_size': 4,
  'in_channels': 3,
  'num_classes': 10,
  'class_names': ['plane',
   'car',
   'bird',
   'cat',
   'deer',
   'dog',
   'frog',
   'horse',
   'ship',
   'truck'],
  'mean': [0.4914, 0.4822, 0.4465],
  'std': [0.247, 0.2435, 0.2616]},
 'model_params': {'d_model': 512,
  'num_layers': 6,
  'n_heads': 8,
  'd_ff': 2048,
  'dropout': 0.1},
 'training_params': {'device': 'cuda',
  'num_epochs': 10,
  'batch_size': 128,
  'learning_rate': 0.001,
  'weight_decay': 0.0001,
  'eval_interval': 5,
  'model_save_path': './checkpoints/vit_cifar10.pth'},
 'prediction_params': {'image_source': ''}}
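The `data_params` and `model_params` above fix the sequence the model sees. With `img_size: 32` and `patch_size: 4`, each image becomes (32/4)^2 = 64 patches of 4·4·3 = 48 values, so inputs resized to any other size no longer match. A sketch of that arithmetic, plus a rough encoder parameter count that assumes a standard Transformer encoder layer (Q/K/V/output projections, a two-layer feed-forward block, two LayerNorms, all with biases). The actual `vision_transformer/vit.py` may differ, for example in whether it adds a class token. Both helper names are hypothetical.

```python
def patch_grid(img_size: int, patch_size: int, in_channels: int = 3) -> tuple[int, int]:
    """Return (number of patches, values per flattened patch)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size
    return per_side * per_side, patch_size * patch_size * in_channels


def encoder_params(d_model: int, d_ff: int, num_layers: int) -> int:
    """Parameter count of a stack of standard encoder layers (assumed layout)."""
    attention = 4 * (d_model * d_model + d_model)        # Q, K, V, output projections
    feed_forward = 2 * d_model * d_ff + d_ff + d_model    # two linear layers with biases
    layer_norms = 2 * 2 * d_model                          # two LayerNorms (weight + bias)
    return num_layers * (attention + feed_forward + layer_norms)


num_patches, patch_dim = patch_grid(32, 4)
print(num_patches, patch_dim)        # 64 48
print(patch_grid(224, 4)[0])         # 3136 patches instead of 64
print(encoder_params(512, 2048, 6))  # 18914304
```

Under this assumed layout, about 18.9M encoder weights is roughly 76 MB in float32 (4 bytes each), before the patch embedding, position embeddings, and classifier head.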

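## 5. Training
The default script is `vision_transformer/train_vit.py`, and training logs are written to `runs/vit_cifar10`. Run it from the repository root (the directory cloned in Section 2).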
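`python -m vision_transformer.train_vit` adds the current directory to the start of `sys.path`, so it only finds the package when run from the repository root; elsewhere it fails with `ModuleNotFoundError`. A minimal pre-flight check, assuming only the two paths this notebook references; the helper name is hypothetical.

```python
from pathlib import Path


def missing_for_module_run(repo_root: str = ".") -> list[str]:
    """List paths that `python -m vision_transformer.train_vit` needs under repo_root."""
    root = Path(repo_root)
    expected = [
        root / "vision_transformer" / "train_vit.py",  # module run via -m
        root / "configs" / "vit_config.yaml",          # file passed via --config
    ]
    return [str(p) for p in expected if not p.exists()]
```

An empty list means the working directory looks right; otherwise `%cd MLLM-from-scratch` (or rerun the clone cell) before training.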

In [1]:
!python -m vision_transformer.train_vit --config configs/vit_config.yaml --device cuda --log_dir runs/vit_cifar10
!ls -R runs/vit_cifar10

/usr/bin/python3: Error while finding module specification for 'vision_transformer.train_vit' (ModuleNotFoundError: No module named 'vision_transformer')
ls: cannot access 'runs/vit_cifar10': No such file or directory


## 6. Evaluation and Visualization
Compute test accuracy with `torchmetrics`, then inspect the TensorBoard logs.

In [14]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassAccuracy
from torchvision import transforms
from vision_transformer.vit import VisionTransformer
from configs import config_parser

cfg = config_parser.load_config("configs/vit_config.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Keys follow configs/vit_config.yaml as printed in Section 4
state_dict = torch.load(cfg["training_params"]["model_save_path"], map_location=device)
# Adjust the constructor arguments to the signature in vision_transformer/vit.py
model = VisionTransformer(cfg["model_params"]).to(device)
model.load_state_dict(state_dict)
model.eval()
# No Resize: the model is configured for 32x32 inputs (img_size=32, patch_size=4)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cfg["data_params"]["mean"], std=cfg["data_params"]["std"]),
])
test_dataset.transform = transform
loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)
# average="micro" reports plain top-1 accuracy (the torchmetrics default is "macro")
metric = MulticlassAccuracy(num_classes=cfg["data_params"]["num_classes"], average="micro").to(device)
with torch.no_grad():
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        metric.update(logits, labels)
print("Test Top-1 Acc:", metric.compute().item())

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject
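`MulticlassAccuracy` in torchmetrics 1.x defaults to `average="macro"`, the mean of per-class accuracy, rather than the fraction of all samples predicted correctly. On CIFAR-10's balanced test set (1,000 images per class) the two coincide, but they diverge on imbalanced data. A pure-Python sketch of both definitions, with hypothetical helper names, to make the difference concrete:

```python
from collections import Counter
from typing import Sequence


def top1_predictions(logits: Sequence[Sequence[float]]) -> list[int]:
    """Arg-max over classes for each row of logits."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def micro_accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of all samples predicted correctly (plain top-1 accuracy)."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def macro_accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Mean of per-class recall, over the classes present in labels."""
    totals = Counter(labels)
    correct = Counter(y for p, y in zip(preds, labels) if p == y)
    return sum(correct[c] / totals[c] for c in totals) / len(totals)


# Balanced labels: both averages agree
print(micro_accuracy([0, 1, 1, 1], [0, 0, 1, 1]), macro_accuracy([0, 1, 1, 1], [0, 0, 1, 1]))  # 0.75 0.75
# Imbalanced labels: a model that always predicts class 0 looks better under micro
print(micro_accuracy([0, 0, 0, 0], [0, 0, 0, 1]), macro_accuracy([0, 0, 0, 0], [0, 0, 0, 1]))  # 0.75 0.5
```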

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs/vit_cifar10

## 7. Inference and Prediction Visualization
Run `vision_transformer/predict_vit.py`, or run inference on a few images directly in the notebook.

In [None]:
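!python -m vision_transformer.predict_vit --config configs/vit_config.yaml --weights_path runs/vit_cifar10/best.ckpt --samples 8
import matplotlib.pyplot as plt
import random
import torch

classes = cfg["data_params"]["class_names"]
# Per-channel statistics used by transforms.Normalize, shaped (C, 1, 1) for broadcasting
mean = torch.tensor(cfg["data_params"]["mean"]).view(3, 1, 1)
std = torch.tensor(cfg["data_params"]["std"]).view(3, 1, 1)
indices = random.sample(range(len(test_dataset)), 6)
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
model.eval()
for ax, idx in zip(axes.flatten(), indices):
    image, label = test_dataset[idx]
    with torch.no_grad():
        pred = model(image.unsqueeze(0).to(device))
        pred_idx = pred.argmax(dim=1).item()
    # Undo Normalize (x * std + mean) and clip to [0, 1] for display
    shown = (image * std + mean).clamp(0, 1)
    ax.imshow(shown.permute(1, 2, 0).numpy())
    ax.set_title(f"GT: {classes[label]}, Pred: {classes[pred_idx]}")
    ax.axis("off")
plt.tight_layout()
plt.show()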
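`transforms.Normalize` maps each channel as (x - mean) / std, so displaying a normalized image needs the inverse x * std + mean with the same `mean`/`std` from `data_params`. A shortcut like `* 0.5 + 0.5` is only the inverse when mean = std = 0.5. A per-pixel sketch of the round trip; both helper names are hypothetical.

```python
from typing import Sequence


def normalize(pixel: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> list[float]:
    """Apply (x - mean) / std per channel, as transforms.Normalize does."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


def denormalize(pixel: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> list[float]:
    """Invert normalize per channel and clip to [0, 1] for display."""
    return [min(max(v * s + m, 0.0), 1.0) for v, m, s in zip(pixel, mean, std)]


cifar_mean = [0.4914, 0.4822, 0.4465]
cifar_std = [0.247, 0.2435, 0.2616]
print(denormalize(normalize([0.2, 0.5, 0.9], cifar_mean, cifar_std), cifar_mean, cifar_std))
```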

## 8. Save Results and Push
Push the trained weights, logs, and notebook back to the remote repository.

In [None]:
!git status
!git add runs/vit_cifar10 notebooks/vit_cifar10_colab.ipynb configs/vit_config.yaml
!git commit -m "Update ViT Colab run" || echo "Nothing to commit"
!git push origin main
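GitHub rejects pushes containing any file larger than 100 MiB, and checkpoint files for a model of this size can approach that limit (especially if they also store optimizer state). A minimal sketch that lists oversized files under the paths passed to `git add` before committing; `oversized_files` is a hypothetical helper, and files it reports would need to stay out of the commit (for example via Git LFS, or by downloading them instead).

```python
from pathlib import Path

GITHUB_FILE_LIMIT = 100 * 1024 * 1024  # GitHub blocks files over 100 MiB


def oversized_files(*roots: str, limit: int = GITHUB_FILE_LIMIT) -> list[tuple[str, int]]:
    """Return sorted (path, size in bytes) pairs for files under roots that exceed limit."""
    found = []
    for root in map(Path, roots):
        # A root may be a single file or a directory to scan recursively
        candidates = [root] if root.is_file() else root.rglob("*")
        for path in candidates:
            if path.is_file() and path.stat().st_size > limit:
                found.append((str(path), path.stat().st_size))
    return sorted(found)
```

Usage: `oversized_files("runs/vit_cifar10", "checkpoints")` before `git add`; an empty list means nothing exceeds the limit.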

---
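**Tips**
- To reinitialize the environment, run `Runtime → Factory reset runtime` and start again from Section 1.
- To fetch only the latest changes without re-cloning, run `git pull origin main` inside the cloned repository (see Section 2).
- Before pushing, download `runs/` and `best.ckpt` as a backup of the model.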