See code at https://github.com/google-research/vision_transformer/

See papers at

- Vision Transformer: https://arxiv.org/abs/2010.11929
- MLP-Mixer: https://arxiv.org/abs/2105.01601
- How to train your ViT: https://arxiv.org/abs/2106.10270
- When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations: https://arxiv.org/abs/2106.01548

This Colab was adapted from the Colab for the [JAX](https://jax.readthedocs.org) implementation of the Vision Transformer; the cells below instead fine-tune a pretrained DeiT model (PyTorch) on CIFAR-10.

If you just want to load a pre-trained checkpoint from a large repository and
use it directly for inference, you probably want to go to [this Colab](https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax_augreg.ipynb).

##### Copyright 2021 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


### Setup

Needs to be executed once in every VM.

The cells below clone the DeiT code from GitHub and install the necessary dependencies.

In [2]:
#@markdown Select whether you would like to store data in your personal drive.
#@markdown
#@markdown If you select **yes**, you will need to authorize Colab to access
#@markdown your personal drive.
#@markdown
#@markdown If you select **no**, then any changes you make will disappear when
#@markdown this Colab's VM restarts after some time of inactivity.
use_gdrive = 'no'  #@param ["yes", "no"]

if use_gdrive == 'yes':
  from google.colab import drive
  drive.mount('/gdrive')
  root = '/gdrive/My Drive/vision_transformer_colab'
  import os
  if not os.path.isdir(root):
    os.mkdir(root)
  os.chdir(root)
  print(f'\nChanged CWD to "{root}"')
else:
  from IPython import display
  display.display(display.HTML(
      '<h1 style="color:red">CHANGES NOT PERSISTED</h1>'))

In [3]:
# Clone repository and pull latest changes.
![ -d deit ] || git clone --depth=1 https://github.com/facebookresearch/deit
!cd deit && git pull

Already up to date.


In [4]:
# DeiT is built on top of timm version 0.3.2, so that version needs to be installed first
!pip install timm==0.3.2



In [5]:
!pip install torch



In [6]:
!pip install torchvision



In [7]:
import os
import shutil
from torchvision import datasets
from PIL import Image

def prepare_cifar10_dataset(download_dir, dataset_dir):
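    """
    Download CIFAR-10 and organize it in the ImageFolder layout.

    Parameters:
    - download_dir: temporary directory for the raw download.
    - dataset_dir: directory for the organized dataset (will contain train and val folders).
    """
    # Class names, indexed by CIFAR-10 label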
    classes = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    ]

    # Create the directory structure
    train_dir = os.path.join(dataset_dir, 'train')
    val_dir = os.path.join(dataset_dir, 'val')

    for split_dir in [train_dir, val_dir]:
        if not os.path.exists(split_dir):
            os.makedirs(split_dir)
            print(f"Created directory: {split_dir}")

        for cls in classes:
            cls_dir = os.path.join(split_dir, cls)
            if not os.path.exists(cls_dir):
                os.makedirs(cls_dir)
                print(f"Created class directory: {cls_dir}")

    # Download the CIFAR-10 training set
    train_dataset = datasets.CIFAR10(
        root=download_dir,
        train=True,
        download=True
    )

    # Download the CIFAR-10 test set (used here as the validation split)
    val_dataset = datasets.CIFAR10(
        root=download_dir,
        train=False,
        download=True
    )

    # Save the training images
    print("Saving training images...")
    for idx, (img, label) in enumerate(train_dataset):
        cls_name = classes[label]
        cls_dir = os.path.join(train_dir, cls_name)
        img_filename = f"{cls_name}_{idx:05d}.png"
        img_path = os.path.join(cls_dir, img_filename)
        img.save(img_path)
        if (idx + 1) % 10000 == 0:
            print(f"Saved {idx + 1} training images")

    print("Finished saving training images.")

    # Save the validation images
    print("Saving validation images...")
    for idx, (img, label) in enumerate(val_dataset):
        cls_name = classes[label]
        cls_dir = os.path.join(val_dir, cls_name)
        img_filename = f"{cls_name}_{idx:05d}.png"
        img_path = os.path.join(cls_dir, img_filename)
        img.save(img_path)
        if (idx + 1) % 2000 == 0:
            print(f"Saved {idx + 1} validation images")

    print("Finished saving validation images.")




# Paths for the raw download and for the organized dataset
DOWNLOAD_DIR = "./deit/tmp"
DATASET_DIR = "./deit/data"

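# Skip preparation if the organized dataset directory already exists and is not empty
if not os.path.exists(DATASET_DIR) or not os.listdir(DATASET_DIR):
    prepare_cifar10_dataset(DOWNLOAD_DIR, DATASET_DIR)
else:
    print(f"Dataset directory '{DATASET_DIR}' already exists and is not empty. Skipping download and organization.")

# Remove the temporary download directory (optional)
if os.path.exists(DOWNLOAD_DIR):
    shutil.rmtree(DOWNLOAD_DIR)
    print(f"Removed temporary download directory: {DOWNLOAD_DIR}")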


Dataset directory './deit/data' already exists and is not empty. Skipping download and organization.
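The skip check above only tests whether `./deit/data` is non-empty, so a run interrupted while saving images would be treated as complete. A minimal sketch for verifying the layout, assuming the `.png` files written by `prepare_cifar10_dataset`; `count_imagefolder` is a hypothetical helper, not part of DeiT. A complete CIFAR-10 export should hold 5,000 images per class under `train` and 1,000 per class under `val`.

```python
import os
from collections import Counter


def count_imagefolder(split_dir, extensions=(".png",)):
    """Count image files in each class sub-folder of an ImageFolder-style split."""
    counts = Counter()
    for cls in sorted(os.listdir(split_dir)):
        cls_dir = os.path.join(split_dir, cls)
        if not os.path.isdir(cls_dir):
            continue  # ignore stray files next to the class folders
        counts[cls] = sum(
            1 for name in os.listdir(cls_dir) if name.lower().endswith(extensions)
        )
    return counts


if __name__ == "__main__":
    # Hypothetical check against the layout written by prepare_cifar10_dataset.
    expected_per_class = {"train": 5000, "val": 1000}
    for split, expected in expected_per_class.items():
        split_dir = os.path.join("./deit/data", split)
        if not os.path.isdir(split_dir):
            print(f"{split}: missing")
            continue
        counts = count_imagefolder(split_dir)
        incomplete = {cls: n for cls, n in counts.items() if n != expected}
        print(f"{split}: {sum(counts.values())} images, incomplete classes: {incomplete}")
```

If any class is reported as incomplete, deleting `./deit/data` and rerunning the preparation cell rebuilds it from scratch.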


### Download pretrained weights

In [8]:
import os
# Install the wget module
!pip install wget

import wget  # used to download files

# URLs of the pretrained models
pretrained_models = {
    "deit_tiny_distilled_patch16_224": "https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
    "deit_small_distilled_patch16_224": "https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
    "deit_base_distilled_patch16_224": "https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth"
}

# Name of the model to download
selected_model = "deit_tiny_distilled_patch16_224"  # one of ["deit_tiny_distilled_patch16_224", "deit_small_distilled_patch16_224", "deit_base_distilled_patch16_224"]

# Directory where the pretrained models are stored
pretrained_dir = "./deit/pretrained"

# Create the directory if it does not exist
os.makedirs(pretrained_dir, exist_ok=True)
print(f"Pretrained models will be downloaded to: {pretrained_dir}")

# Look up the download URL
model_url = pretrained_models[selected_model]
model_filename = os.path.basename(model_url)
model_path = os.path.join(pretrained_dir, model_filename)

# Download the model file (unless it is already present)
if not os.path.exists(model_path):
    print(f"Downloading the {selected_model} model...")
    wget.download(model_url, out=model_path)
    print(f"\nModel downloaded and saved to: {model_path}")
else:
    print(f"Model file already exists: {model_path}")


Pretrained models will be downloaded to: ./deit/pretrained
Model file already exists: ./deit/pretrained/deit_tiny_distilled_patch16_224-b40b3cf7.pth
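The `wget` pip package is only used for this one download. A minimal sketch of the same cache-or-download step with the standard library's `urllib.request.urlretrieve`, assuming the checkpoint URLs above are publicly reachable; `download_checkpoint` is a hypothetical helper. Writing to a `.part` file first means an interrupted transfer is never mistaken for a cached checkpoint.

```python
import os
import urllib.request


def download_checkpoint(url, out_dir):
    """Download `url` into `out_dir`, reusing the file if it is already there."""
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, os.path.basename(url))
    if os.path.exists(path):
        return path  # cached from a previous run
    part_path = path + ".part"
    urllib.request.urlretrieve(url, part_path)
    # Rename only after the transfer finished, so a partial file is never cached.
    os.replace(part_path, path)
    return path
```

In this notebook it would be called as `model_path = download_checkpoint(pretrained_models[selected_model], pretrained_dir)`.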


In [9]:
!pip show torchvision


Name: torchvision
Version: 0.20.1+cu121
Summary: image and video datasets and models for torch deep learning
Home-page: https://github.com/pytorch/vision
Author: PyTorch Core Team
Author-email: soumith@pytorch.org
License: BSD
Location: /usr/local/lib/python3.10/dist-packages
Requires: numpy, pillow, torch
Required-by: fastai, timm
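The install cells pin timm to 0.3.2 because DeiT was built against it, while torch and torchvision are installed unpinned. A small sketch that uses `importlib.metadata` (Python 3.8+) to report what is actually installed before training; the `EXPECTED` mapping is an assumption that encodes only the timm pin stated above, and both helpers are hypothetical.

```python
from importlib import metadata

# Only the timm pin comes from the notebook; extend this mapping as needed.
EXPECTED = {"timm": "0.3.2"}


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(packages=("torch", "torchvision", "timm"), expected=EXPECTED):
    """Return (all versions, packages whose version differs from `expected`)."""
    report = {name: installed_version(name) for name in packages}
    mismatches = {
        name: version
        for name, version in report.items()
        if name in expected and version != expected[name]
    }
    return report, mismatches


if __name__ == "__main__":
    report, mismatches = check_versions()
    print(report)
    if mismatches:
        print("Unexpected versions:", mismatches)
```

A `None` in the report means the package is not installed in the current kernel, which usually means the install cell ran in a different runtime.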


In [10]:
# 5. Adjust the training script and run it
import sys
import torch
from pathlib import Path
from argparse import Namespace





In [11]:
# Add the project path to sys.path so that main.py can be imported
project_path = '/content/deit'  # path inside Colab
if project_path not in sys.path:
    sys.path.append(project_path)



In [12]:
!ls /content/deit

augment.py	LICENSE       patchconvnet_models.py  README_deit.md	      resmlp_models.py
cait_models.py	losses.py     pretrained	      README.md		      run_with_submitit.py
data		main.py       __pycache__	      README_patchconvnet.md  samplers.py
datasets.py	models.py     README_3things.md       README_resmlp.md	      tox.ini
engine.py	models_v2.py  README_cait.md	      README_revenge.md       utils.py
hubconf.py	outputs       README_cosub.md	      requirements.txt


In [13]:

from main import main, get_args_parser

In [14]:
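# Get the default argument parser
parser = get_args_parser()

# Parse an empty argument list to get a Namespace holding the default values
args = parser.parse_args([])  # use the defaults

# Adjust the arguments for CIFAR-10
# Note: DeiT's parser selects the dataset through `data_set` (default 'IMNET') and
# takes the class count from build_dataset(), so `dataset` and `num_classes` below
# appear to be ignored. With 'IMNET', the ImageFolder train/ and val/ folders under
# data_path are read as-is, but the classification heads are built for 1000 classes.
args.model = 'deit_tiny_distilled_patch16_224'
args.dataset = 'cifar10'  # CIFAR-10
args.data_path = './deit/data'  # dataset lives in the "deit/data" folder under the current working directory
args.epochs = 20  # train for 20 epochs
args.batch_size = 512
# args.lr = 0.0005  # adjust the learning rate if needed
args.num_classes = 10  # CIFAR-10 has 10 classes
args.output_dir = './deit/outputs/cifar10pre_tiny'  # output directory
args.eval = False  # set to True to run evaluation only
args.finetune = './deit/pretrained/deit_tiny_distilled_patch16_224-b40b3cf7.pth'  # path to the downloaded pretrained checkpoint
args.resume = ''  # leave empty unless resuming training from a checkpoint
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create the output directory
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    print(f"Created output directory: {args.output_dir}")

# Run training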
main(args)

Created output directory: ./deit/outputs/cifar10pre_tiny
Not using distributed mode
Namespace(batch_size=64, epochs=20, bce_loss=False, unscale_lr=False, model='deit_tiny_distilled_patch16_224', input_size=224, drop=0.0, drop_path=0.1, model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, opt='adamw', opt_eps=1e-08, opt_betas=None, clip_grad=None, momentum=0.9, weight_decay=0.05, sched='cosine', lr=0.0005, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, warmup_lr=1e-06, min_lr=1e-05, decay_epochs=30, warmup_epochs=5, cooldown_epochs=10, patience_epochs=10, decay_rate=0.1, color_jitter=0.3, aa='rand-m9-mstd0.5-inc1', smoothing=0.1, train_interpolation='bicubic', repeated_aug=True, train_mode=True, ThreeAugment=False, src=False, reprob=0.25, remode='pixel', recount=1, resplit=False, mixup=0.8, cutmix=1.0, cutmix_minmax=None, mixup_prob=1.0, mixup_switch_prob=0.5, mixup_mode='batch', teacher_model='regnety_160', teacher_path='', distillation_type='none', distillation_alpha=0.5, distillat



Creating model: deit_tiny_distilled_patch16_224


  checkpoint = torch.load(args.finetune, map_location='cpu')


number of params: 5910800
Start training for 20 epochs


  self._scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch: [0]  [  0/781]  eta: 1:57:35  lr: 0.000001  loss: 8.3028 (8.3028)  time: 9.0338  data: 5.8394  max mem: 1700
Epoch: [0]  [ 10/781]  eta: 0:15:46  lr: 0.000001  loss: 8.4700 (8.4557)  time: 1.2270  data: 0.5392  max mem: 1742
Epoch: [0]  [ 20/781]  eta: 0:10:11  lr: 0.000001  loss: 8.3913 (8.4065)  time: 0.3918  data: 0.0071  max mem: 1742
Epoch: [0]  [ 30/781]  eta: 0:07:50  lr: 0.000001  loss: 8.2964 (8.3060)  time: 0.2965  data: 0.0054  max mem: 1742
Epoch: [0]  [ 40/781]  eta: 0:06:37  lr: 0.000001  loss: 7.9127 (8.1950)  time: 0.2553  data: 0.0051  max mem: 1742
Epoch: [0]  [ 50/781]  eta: 0:05:56  lr: 0.000001  loss: 7.8588 (8.1053)  time: 0.2716  data: 0.0055  max mem: 1742
Epoch: [0]  [ 60/781]  eta: 0:05:46  lr: 0.000001  loss: 7.5483 (7.9822)  time: 0.3658  data: 0.0057  max mem: 1742
Epoch: [0]  [ 70/781]  eta: 0:05:21  lr: 0.000001  loss: 7.2250 (7.8812)  time: 0.3606  data: 0.0041  max mem: 1742
Epoch: [0]  [ 80/781]  eta: 0:04:59  lr: 0.000001  loss: 7.0707 (7.7771)

  with torch.cuda.amp.autocast():


Test:  [  0/105]  eta: 0:12:11  loss: 2.6065 (2.6065)  acc1: 20.8333 (20.8333)  acc5: 63.5417 (63.5417)  time: 6.9624  data: 5.2633  max mem: 1746
Test:  [ 10/105]  eta: 0:01:11  loss: 2.7166 (2.6436)  acc1: 19.7917 (19.9811)  acc5: 64.5833 (65.7197)  time: 0.7537  data: 0.5065  max mem: 1746
Test:  [ 20/105]  eta: 0:00:42  loss: 1.7157 (2.0969)  acc1: 34.3750 (39.5833)  acc5: 79.1667 (77.7778)  time: 0.1799  data: 0.0696  max mem: 1746
Test:  [ 30/105]  eta: 0:00:30  loss: 1.7157 (2.3087)  acc1: 17.7083 (31.3844)  acc5: 68.7500 (70.9341)  time: 0.2114  data: 0.0950  max mem: 1746
Test:  [ 40/105]  eta: 0:00:23  loss: 2.8140 (2.4391)  acc1: 5.2083 (24.1870)  acc5: 55.2083 (66.8699)  time: 0.2129  data: 0.0988  max mem: 1746
Test:  [ 50/105]  eta: 0:00:19  loss: 2.4065 (2.3397)  acc1: 5.2083 (25.7149)  acc5: 60.4167 (71.3440)  time: 0.2603  data: 0.1419  max mem: 1746
Test:  [ 60/105]  eta: 0:00:15  loss: 1.9709 (2.3052)  acc1: 29.1667 (25.6660)  acc5: 84.3750 (72.8654)  time: 0.3070  d

KeyboardInterrupt: 
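The log does not match the cell as written: the printed Namespace shows `batch_size=64` and each epoch has 781 training iterations, whereas the cell sets `args.batch_size = 512`, so the output presumably comes from an earlier run. A minimal sketch that recomputes the loader lengths and learning rate from the batch size, assuming the behaviour of DeiT's `main.py` in non-distributed mode (as the log reports): the training loader drops the last incomplete batch, the validation loader uses `int(1.5 * batch_size)` and keeps its last batch, and the base `lr` is scaled by `batch_size / 512` because `unscale_lr=False`. With batch size 64 this gives 781 training and 105 test iterations, matching the log, and the first test batch's `acc1` of 20.8333 equals 20 correct out of 96 images. `expected_schedule` and `loader_length` are hypothetical helpers.

```python
import math
import re


def expected_schedule(batch_size, base_lr=5e-4, n_train=50_000, n_val=10_000):
    """Loader lengths and scaled LR for a single process (assumed DeiT defaults)."""
    # Training loader: drop_last=True, so the incomplete final batch is skipped.
    train_iters = n_train // batch_size
    # Validation loader: 1.5x larger batches, final partial batch kept.
    val_batch = int(1.5 * batch_size)
    val_iters = math.ceil(n_val / val_batch)
    # Linear scaling rule relative to a reference batch size of 512.
    scaled_lr = base_lr * batch_size / 512.0
    return train_iters, val_iters, scaled_lr


def loader_length(log_line):
    """Extract the total iteration count N from a '[  i/N]' progress marker."""
    match = re.search(r"\[\s*\d+/(\d+)\]", log_line)
    return int(match.group(1)) if match else None


if __name__ == "__main__":
    for bs in (64, 512):
        print(bs, expected_schedule(bs))
    print(loader_length("Epoch: [0]  [  0/781]  eta: 1:57:35"))
```

For the batch size of 512 set in the cell, the same arithmetic predicts 97 training and 14 test iterations per epoch and an effective lr of 5e-4.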