This repository contains the official implementation of Stratified Knowledge-Density Super-Network for Scalable Vision Transformers (AAAI 2026), a novel approach for building scalable Vision Transformers (ViTs) that enable efficient sub-network extraction for diverse deployment scenarios.
Traditional approaches require training and maintaining multiple ViT variants for different resource constraints, which is computationally expensive and inefficient. Our method transforms a pre-trained ViT into a Stratified Knowledge-Density Super-Network, where knowledge is hierarchically organized across weights, allowing flexible extraction of sub-networks that retain maximal knowledge for varying model sizes.
- One-shot Transformation: Convert pre-trained ViTs into scalable super-networks
- Knowledge Stratification: Hierarchical organization of knowledge across parameter dimensions
- Zero-cost Extraction: Extract sub-networks of arbitrary sizes at $\mathcal{O}(1)$ cost
- State-of-the-Art Performance: Outperforms existing model compression and expansion methods
- Easy Deployment: Support for diverse resource constraints, from edge devices to servers
Weighted PCA for Attention Contraction (WPAC) concentrates knowledge into a compact set of critical weights through function-preserving transformations (a concrete sketch follows the list):
- Token-wise Weighted PCA: Applies PCA to intermediate features with Taylor-based importance weighting
- Function Preservation: Mathematical equivalence maintained through transformation matrix injection
- Information Concentration: Knowledge condensed into top principal components
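To make the list above concrete, here is a minimal sketch of token-wise weighted PCA under simplifying assumptions: the two projections compose linearly (written as right-multiplied matrices `w_v` and `w_o`, as in an attention value/output path), and Taylor-based token importances `weights` have already been computed. Injecting the full orthonormal basis ($k = d$) leaves the function mathematically unchanged; truncating to the top $k < d$ components concentrates knowledge in the leading dimensions. This illustrates the principle only and is not the repository's implementation.

```python
import torch

@torch.no_grad()
def weighted_pca_basis(feats, weights, k):
    """Top-k principal directions of token features under importance weighting.

    feats:   (N, d) intermediate token features
    weights: (N,)   Taylor-based token importances (non-negative)
    """
    w = weights / weights.sum()
    mean = (w[:, None] * feats).sum(dim=0)
    centered = feats - mean
    cov = centered.T @ (w[:, None] * centered)   # weighted covariance, (d, d)
    _, eigvecs = torch.linalg.eigh(cov)          # eigenvalues in ascending order
    return eigvecs[:, -k:]                       # (d, k), top-k components

@torch.no_grad()
def contract_pair(w_v, w_o, feats, weights, k):
    """Inject the PCA basis into a linearly composing pair x @ w_v @ w_o.

    With k == d, the orthonormal basis P satisfies P @ P.T == I, so the
    composition (and hence the network function) is preserved exactly;
    with k < d, the leading components carry the concentrated knowledge.
    """
    P = weighted_pca_basis(feats, weights, k)    # (d, k)
    return w_v @ P, P.T @ w_o                    # (d, k) and (k, d)
```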
Progressive Importance-Aware Dropout (PIAD) enhances knowledge stratification through adaptive dropout (sketched after the list):
- Progressive Evaluation: Dynamically assesses importance of weight groups
- Importance-Aware Sampling: Lower dropout probabilities for important parameters
- Hierarchical Training: Promotes knowledge stratification across different model sizes
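Below is a minimal sketch of the importance-aware sampling idea, assuming per-group importance scores are already available; the helper name and the rank-based dropout schedule are illustrative, not the paper's exact formulation.

```python
import torch

def importance_aware_keep_mask(importance, max_drop=0.5):
    """Sample a keep mask over weight groups: higher importance -> lower dropout.

    importance: (G,) scores for G weight groups (e.g., hidden dimensions).
    Groups are ranked so that the least important group is dropped with
    probability max_drop and the most important group is never dropped.
    """
    ranks = importance.argsort().argsort().float()   # 0 = least important
    ranks = ranks / (importance.numel() - 1)         # normalize to [0, 1]
    drop_prob = max_drop * (1.0 - ranks)
    return torch.bernoulli(1.0 - drop_prob)          # 1 = keep, 0 = drop

# During super-network training, such a mask would be resampled each step and
# the per-group importances progressively re-estimated, so that knowledge
# settles into a hierarchy: the most important groups participate in every
# sub-network, the least important only in the largest ones.
```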
Sub-networks are extracted from the trained SKD Super-Network and evaluated directly on ImageNet-1k without any fine-tuning; the extracted sizes below correspond to roughly 5/6, 2/3, 1/2, and 1/3 of each backbone's original MACs. A conceptual sketch of zero-cost extraction follows the table.
| Backbone | Method | MACs | Top-1 Acc (%) |
|---|---|---|---|
| DeiT-Base | Original | 16.88 G | 81.8 |
| DeiT-Base | SKD (Ours) | 14.07 G | 81.5 |
| DeiT-Base | SKD (Ours) | 11.25 G | 80.9 |
| DeiT-Base | SKD (Ours) | 8.44 G | 80.4 |
| DeiT-Base | SKD (Ours) | 5.63 G | 77.0 |
| DeiT-Small | Original | 4.26 G | 79.8 |
| DeiT-Small | SKD (Ours) | 3.55 G | 79.0 |
| DeiT-Small | SKD (Ours) | 2.84 G | 78.2 |
| DeiT-Small | SKD (Ours) | 2.13 G | 76.2 |
| DeiT-Small | SKD (Ours) | 1.42 G | 70.6 |
| DeiT-Tiny | Original | 1.08 G | 72.1 |
| DeiT-Tiny | SKD (Ours) | 0.90 G | 70.0 |
| DeiT-Tiny | SKD (Ours) | 0.72 G | 68.6 |
| DeiT-Tiny | SKD (Ours) | 0.54 G | 65.8 |
| DeiT-Tiny | SKD (Ours) | 0.36 G | 61.4 |
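As referenced above, extraction is conceptually just index slicing: because the stratified dimensions are ordered by knowledge density, a sub-network keeps the leading rows and columns of each weight matrix. The helper below is a hypothetical illustration, not part of the codebase; the repository performs extraction via run_test.sh and mask_table.json.

```python
import torch

@torch.no_grad()
def peel_linear(weight, bias, keep_out, keep_in):
    """Keep the leading (highest knowledge-density) dims of a linear layer.

    weight: (out_features, in_features), as stored by torch.nn.Linear.
    The cost is O(1) in the sense that no search, scoring, or retraining
    is needed to obtain a smaller, ready-to-run layer.
    """
    w = weight[:keep_out, :keep_in].clone()
    b = bias[:keep_out].clone() if bias is not None else None
    return w, b

# Example: shrink a hidden projection to half its output width.
layer = torch.nn.Linear(384, 1536)
w, b = peel_linear(layer.weight, layer.bias, keep_out=768, keep_in=384)
```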
Results after fine-tuning for 30 epochs on ImageNet-1k.
| Backbone | Method | MACs | Params | Top-1 Acc (%) |
|---|---|---|---|---|
| DeiT-Base | Original | 16.88 G | 86.57 M | 81.80 |
| DeiT-Base | SKD (Ours) | 10.57 G | 54.55 M | 81.45 |
| DeiT-Base | SKD (Ours) | 8.47 G | 43.92 M | 81.24 |
| DeiT-Small | Original | 4.26 G | 22.05 M | 79.83 |
| DeiT-Small | SKD (Ours) | 3.07 G | 16.03 M | 79.42 |
| DeiT-Small | SKD (Ours) | 2.43 G | 12.78 M | 78.71 |
| DeiT-Tiny | Original | 1.08 G | 5.72 M | 72.14 |
| DeiT-Tiny | SKD (Ours) | 0.89 G | 4.77 M | 71.40 |
Execute run_reparam.sh to apply Weighted PCA for Attention Contraction (WPAC), which reparameterizes the pre-trained model. Then run run_train.sh to train the reparameterized model into an SKD Super-Network with Progressive Importance-Aware Dropout (PIAD).
Example Commands:
```bash
# Step 1: Reparameterize the model with WPAC
sh run_reparam.sh

# Step 2: Train the SKD Super-Network with PIAD
sh run_train.sh
```
Note: Replace /path/to/ImageNet2012/Data/CLS-LOC with the actual path to your ImageNet2012 dataset.
Execute run_test.sh to comprehensively evaluate the performance of sub-networks extracted from the pre-trained SKD Super-Network. This script facilitates rapid validation across multiple model sizes and configurations without requiring fine-tuning.
We provide pre-trained SKD Super-Network checkpoints at this link. After downloading, place the files under your project's root folder in the following directory structure:
```
output/
└── peelable/
    ├── vit_base/
    │   ├── checkpoint_149.pth
    │   └── mask_table.json
    ├── vit_small/
    │   ├── checkpoint_299.pth
    │   └── mask_table.json
    └── vit_tiny/
        ├── checkpoint_299.pth
        └── mask_table.json
```
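For a quick sanity check after downloading, the checkpoint and its mask table can be loaded directly. This snippet assumes only the paths above; the contents and key names of the files are not documented here, so inspect them interactively:

```python
import json
import torch

# Load one super-network checkpoint and its sub-network mask table.
ckpt = torch.load("output/peelable/vit_base/checkpoint_149.pth", map_location="cpu")
with open("output/peelable/vit_base/mask_table.json") as f:
    mask_table = json.load(f)

# Inspect the top-level structure of each file.
print(list(ckpt) if isinstance(ckpt, dict) else type(ckpt))
print(type(mask_table))
```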
Example Commands:
```bash
# Evaluate sub-network performance directly from the SKD Super-Network
sh run_test.sh
```
Note: Replace /path/to/ImageNet2012/Data/CLS-LOC with the actual path to your ImageNet2012 dataset.
Execute run_train_peeled.sh to fine-tune sub-networks extracted from the SKD Super-Network, optimizing their performance for specific tasks or datasets. Fine-tuning further improves the pre-structured sub-networks while preserving their efficient architecture.
Example Commands:
```bash
# Fine-tune extracted sub-networks for enhanced performance
sh run_train_peeled.sh
```
Note: Replace /path/to/ImageNet2012/Data/CLS-LOC with the actual path to your ImageNet2012 dataset.
If you use this work in your research, please cite our paper:
```bibtex
@inproceedings{li2026stratified,
  title={Stratified Knowledge-Density Super-Network for Scalable Vision Transformers},
  author={Li, Longhua and Qi, Lei and Geng, Xin},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={40},
  number={27},
  pages={22985--22993},
  year={2026}
}
```
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.
Our work is based on the DeiT (Data-efficient Image Transformers) repository. We sincerely thank the authors for open-sourcing their excellent work.