This repository contains the official implementation of the research paper, "FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization"

FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization

This is the official repository of the following paper:

FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization. Pavan Kumar Anasosalu Vasu, James Gabriel, Jeff Zhu, Oncel Tuzel, Anurag Ranjan. ICCV 2023

arXiv | webpage

FastViT Performance

All models are trained on ImageNet-1K and benchmarked on an iPhone 12 Pro using the ModelBench app.

Setup

conda create -n fastvit python=3.9
conda activate fastvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
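The setup above pins specific PyTorch releases. The following is a minimal sketch for confirming that the active environment matches those pins, using only the standard library's `importlib.metadata`. The `check_pins` helper and the `PINS` mapping are illustrative and not part of this repository. Packages without installed metadata are reported as missing.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the conda command above (distribution names as pip sees them).
PINS = {"torch": "1.11.0", "torchvision": "0.12.0", "torchaudio": "0.11.0"}


def check_pins(pins):
    """Return {package: problem} for every package that is missing or mismatched."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Builds such as "1.11.0+cu113" carry a local suffix; compare the public part only.
        if found.split("+")[0] != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINS)
    for pkg, msg in issues.items():
        print(f"{pkg}: {msg}")
    print("environment OK" if not issues else f"{len(issues)} problem(s) found")
```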

Usage

To use our model, follow the code snippet below:

import torch
import models  # registers the FastViT variants (e.g. "fastvit_t8") with timm
from timm.models import create_model
from models.modules.mobileone import reparameterize_model

# To train from scratch or fine-tune
model = create_model("fastvit_t8")
# ... train ...

# Load an unfused pre-trained checkpoint for fine-tuning
# or for downstream task training, e.g. detection/segmentation
checkpoint = torch.load('/path/to/unfused_checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
# ... train ...

# For inference: switch to eval mode, then reparameterize (fuse) the model
model.eval()
model_inf = reparameterize_model(model)
# Use model_inf at test-time
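`reparameterize_model` converts the training-time model into a simpler inference-time model. One standard ingredient of structural reparameterization is folding a BatchNorm layer into the convolution in front of it. The sketch below illustrates that algebra in plain Python on a single-channel 1D convolution. It is a generic illustration under that simplifying assumption, not this repository's implementation, and `conv1d`, `batchnorm` and `fold_bn` are hypothetical helpers.

```python
import math


def conv1d(x, w, b):
    """Valid (no padding) 1D cross-correlation of signal x with kernel w, plus bias b."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) + b for i in range(len(x) - k + 1)]


def batchnorm(y, mean, var, gamma, beta, eps=1e-5):
    """Inference-mode BatchNorm for one channel, using stored running statistics."""
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in y]


def fold_bn(w, b, mean, var, gamma, beta, eps=1e-5):
    """Return (w', b') such that conv1d(x, w', b') == batchnorm(conv1d(x, w, b), ...)."""
    scale = gamma / math.sqrt(var + eps)
    return [scale * wi for wi in w], (b - mean) * scale + beta


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0, 3.0, -0.5]
    w, b = [0.2, -0.4, 0.1], 0.05
    bn = dict(mean=0.3, var=1.7, gamma=1.2, beta=-0.1)

    two_step = batchnorm(conv1d(x, w, b), **bn)  # train-time structure: conv -> BN
    w_f, b_f = fold_bn(w, b, **bn)
    one_step = conv1d(x, w_f, b_f)               # inference structure: a single conv
    print(max(abs(p - q) for p, q in zip(two_step, one_step)))  # only float rounding error
```

The same algebra applies to each output channel of a 2D convolution. Merging parallel branches then reduces to adding their folded kernels (after zero-padding smaller kernels to a common size) and adding their biases.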

FastViT Model Zoo

Image Classification

Models trained on ImageNet-1K

| Model | Top-1 Acc. (%) | Latency (ms) | PyTorch Checkpoint | CoreML Model |
|:------|:--------------:|:------------:|:------------------:|:------------:|
| FastViT-T8 | 76.2 | 0.8 | T8(unfused) | fastvit_t8.mlpackage.zip |
| FastViT-T12 | 79.3 | 1.2 | T12(unfused) | fastvit_t12.mlpackage.zip |
| FastViT-S12 | 79.9 | 1.4 | S12(unfused) | fastvit_s12.mlpackage.zip |
| FastViT-SA12 | 80.9 | 1.6 | SA12(unfused) | fastvit_sa12.mlpackage.zip |
| FastViT-SA24 | 82.7 | 2.6 | SA24(unfused) | fastvit_sa24.mlpackage.zip |
| FastViT-SA36 | 83.6 | 3.5 | SA36(unfused) | fastvit_sa36.mlpackage.zip |
| FastViT-MA36 | 83.9 | 4.6 | MA36(unfused) | fastvit_ma36.mlpackage.zip |

Models trained on ImageNet-1K with knowledge distillation.

| Model | Top-1 Acc. (%) | Latency (ms) | PyTorch Checkpoint | CoreML Model |
|:------|:--------------:|:------------:|:------------------:|:------------:|
| FastViT-T8 | 77.2 | 0.8 | T8(unfused) | fastvit_t8.mlpackage.zip |
| FastViT-T12 | 80.3 | 1.2 | T12(unfused) | fastvit_t12.mlpackage.zip |
| FastViT-S12 | 81.1 | 1.4 | S12(unfused) | fastvit_s12.mlpackage.zip |
| FastViT-SA12 | 81.9 | 1.6 | SA12(unfused) | fastvit_sa12.mlpackage.zip |
| FastViT-SA24 | 83.4 | 2.6 | SA24(unfused) | fastvit_sa24.mlpackage.zip |
| FastViT-SA36 | 84.2 | 3.5 | SA36(unfused) | fastvit_sa36.mlpackage.zip |
| FastViT-MA36 | 84.6 | 4.6 | MA36(unfused) | fastvit_ma36.mlpackage.zip |

Latency Benchmarking

The latency of all models was measured on an iPhone 12 Pro using the ModelBench app. For further details, please contact James Gabriel and Jeff Zhu. All reported numbers are rounded to one decimal place.
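The two tables trade accuracy against on-device latency. A small sketch for picking the most accurate variant that fits a latency budget, with values copied from the Model Zoo tables (top-1 in %, latency in ms on iPhone 12 Pro). `ZOO` and `pick_variant` are hypothetical names used only for illustration.

```python
# (model, top-1 without distillation, top-1 with distillation, latency in ms),
# copied from the Model Zoo tables.
ZOO = [
    ("FastViT-T8", 76.2, 77.2, 0.8),
    ("FastViT-T12", 79.3, 80.3, 1.2),
    ("FastViT-S12", 79.9, 81.1, 1.4),
    ("FastViT-SA12", 80.9, 81.9, 1.6),
    ("FastViT-SA24", 82.7, 83.4, 2.6),
    ("FastViT-SA36", 83.6, 84.2, 3.5),
    ("FastViT-MA36", 83.9, 84.6, 4.6),
]


def pick_variant(budget_ms, distilled=False):
    """Return (name, top1, latency) of the most accurate model within budget_ms, or None."""
    fits = [
        (name, kd if distilled else base, lat)
        for name, base, kd, lat in ZOO
        if lat <= budget_ms
    ]
    return max(fits, key=lambda row: row[1], default=None)


if __name__ == "__main__":
    print(pick_variant(2.0))                  # ('FastViT-SA12', 80.9, 1.6)
    print(pick_variant(2.0, distilled=True))  # ('FastViT-SA12', 81.9, 1.6)
    print(pick_variant(0.5))                  # None
```

Because the latencies are rounded, a budget that sits right at a tabulated value should be treated as approximate.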

Training

Image Classification

Dataset Preparation

Download the ImageNet-1K dataset and structure the data as follows:

/path/to/imagenet-1k/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  validation/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
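Before launching a long training run, it can help to check that the directory tree matches the layout above. A minimal standard-library sketch, assuming one sub-directory per class and the `train`/`validation` split names shown; `summarize_split` and `check_layout` are hypothetical helpers. For ImageNet-1K, each split should contain 1000 class folders.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_split(root, split):
    """Return {class_name: image_count} for root/split, one sub-directory per class."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split directory: {split_dir}")
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        # Suffix check is case-insensitive, so ImageNet's ".JPEG" files are counted.
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_layout(root, splits=("train", "validation")):
    """Summarize each split and check that all splits share the same class folders."""
    summary = {s: summarize_split(root, s) for s in splits}
    class_sets = [set(counts) for counts in summary.values()]
    if any(cs != class_sets[0] for cs in class_sets[1:]):
        raise ValueError("class folders differ between splits")
    return summary
```

For example, `len(check_layout("/path/to/imagenet-1k")["train"])` gives the number of training classes found.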

To train a FastViT variant, use the corresponding command below:

FastViT-T8
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t8 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-T12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_t12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-S12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_s12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-SA12
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa12 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 \
--distillation-type "hard"
FastViT-SA24
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa24 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.05 \
--distillation-type "hard"
FastViT-SA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --mixup 0.2 --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_sa36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.1 \
--distillation-type "hard"
FastViT-MA36
# Without Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.35

# With Distillation
python -m torch.distributed.launch --nproc_per_node=8 train.py \
/path/to/ImageNet/dataset --model fastvit_ma36 -b 128 --lr 1e-3 \
--native-amp --output /path/to/save/results \
--input-size 3 256 256 --drop-path 0.2 \
--distillation-type "hard"
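The training commands above differ only in the model name, the `--mixup`/`--drop-path` values and the distillation switch. A sketch that assembles the same commands from a table of per-variant settings copied from this section, so flags cannot drift between variants or lose a line continuation. `VARIANTS` and `train_command` are hypothetical names. Since `-b` is the per-process batch size, `--nproc_per_node=8` with `-b 128` gives a global batch size of 8 × 128 = 1024.

```python
import shlex

# (mixup, drop_path) per variant, for (without distillation, with distillation),
# copied from the commands above; None means the flag is omitted.
VARIANTS = {
    "fastvit_t8": ((0.2, None), (0.2, None)),
    "fastvit_t12": ((0.2, None), (0.2, None)),
    "fastvit_s12": ((0.2, None), (0.2, None)),
    "fastvit_sa12": ((0.2, 0.1), (None, None)),
    "fastvit_sa24": ((0.2, 0.1), (None, 0.05)),
    "fastvit_sa36": ((0.2, 0.2), (None, 0.1)),
    "fastvit_ma36": ((None, 0.35), (None, 0.2)),
}


def train_command(model, data, output, distill=False, gpus=8):
    """Build the argv list for one FastViT training run, in the README's flag order."""
    mixup, drop_path = VARIANTS[model][1 if distill else 0]
    argv = ["python", "-m", "torch.distributed.launch", f"--nproc_per_node={gpus}",
            "train.py", data, "--model", model, "-b", "128", "--lr", "1e-3", "--native-amp"]
    if mixup is not None:
        argv += ["--mixup", str(mixup)]
    argv += ["--output", output, "--input-size", "3", "256", "256"]
    if drop_path is not None:
        argv += ["--drop-path", str(drop_path)]
    if distill:
        argv += ["--distillation-type", "hard"]
    return argv


if __name__ == "__main__":
    cmd = train_command("fastvit_ma36", "/path/to/ImageNet/dataset",
                        "/path/to/save/results", distill=True)
    print(shlex.join(cmd))  # a single-line shell command, safe to copy and paste
```

To launch a run directly from Python, the list can be passed to `subprocess.run(cmd, check=True)`.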

Evaluation

To run evaluation on ImageNet, follow the example command below:

FastViT-T8
# Evaluate unfused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8.pth.tar

# Evaluate fused checkpoint
python validate.py /path/to/ImageNet/dataset --model fastvit_t8 \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar \
--use-inference-mode
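The Top-1 Acc. column in the Model Zoo tables is top-k accuracy with k = 1. As a reference for what that number means, a minimal sketch of top-k accuracy over per-class scores. `topk_accuracy` is a hypothetical helper; `validate.py` may compute the metric differently (for example batched on the GPU).

```python
def topk_accuracy(scores, labels, k=1):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels) or not labels:
        raise ValueError("need one score row per label, and at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices of the k largest scores in this row.
        top = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top
    return 100.0 * hits / len(labels)


if __name__ == "__main__":
    scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
    labels = [1, 1, 2, 2]
    print(topk_accuracy(scores, labels, k=1))  # 50.0
    print(topk_accuracy(scores, labels, k=2))  # 75.0
```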

Model Export

To export a Core ML package from a PyTorch checkpoint, follow the example command below:

FastViT-T8
python export_model.py --variant fastvit_t8 --output-dir /path/to/save/exported_model \
--checkpoint /path/to/pretrained_checkpoints/fastvit_t8_reparam.pth.tar

Citation

@inproceedings{vasufastvit2023,
  author    = {Pavan Kumar Anasosalu Vasu and James Gabriel and Jeff Zhu and Oncel Tuzel and Anurag Ranjan},
  title     = {FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year      = {2023}
}

Acknowledgements

Our codebase is built on multiple open-source contributions; please see ACKNOWLEDGEMENTS for more details.
