SMKD

This is the PyTorch implementation of "Supervised Masked Knowledge Distillation for Few-Shot Transformers".

Han Lin*, Guangxing Han*, Jiawei Ma, Shiyuan Huang, Xudong Lin, Shih-Fu Chang

Columbia University, Department of Computer Science

The IEEE / CVF Computer Vision and Pattern Recognition Conference (CVPR), 2023


Installation

The code is tested on Ubuntu 20.04 with Python 3.8, PyTorch 1.11, and CUDA 11.3.

We provide a conda YAML file that contains all the Python dependencies.

conda env create -f environment.yml

To activate this conda environment,

conda activate smkd

We use wandb to log the training stats (optional).
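If you enable it, the logging might look roughly like the sketch below; the project name and logged keys are placeholders, not the repo's actual configuration.

```python
# Minimal sketch of optional wandb logging; "smkd" and the logged keys are
# placeholders, not the repo's actual configuration.
import wandb

wandb.init(project="smkd", config={"arch": "vit_small", "epochs": 1600})
for epoch in range(3):               # stand-in for the real training loop
    train_loss = 1.0 / (epoch + 1)   # stand-in metric
    wandb.log({"epoch": epoch, "train_loss": train_loss})
wandb.finish()
```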

Datasets

We prepare 𝒎𝒊𝒏𝒊ImageNet and 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet, resizing the images following the guidelines from HCTransformers.

  • 𝒎𝒊𝒏𝒊ImageNet

The 𝑚𝑖𝑛𝑖ImageNet dataset was proposed by Vinyals et al. for few-shot learning evaluation. Because it draws on ImageNet images it remains challenging, yet it requires far less compute and infrastructure than the full ImageNet dataset. In total, there are 100 classes with 600 color images per class. These 100 classes are split into 64, 16, and 20 classes from which tasks are sampled for meta-training, meta-validation, and meta-testing, respectively. To generate this dataset from ImageNet, you may use the 𝑚𝑖𝑛𝑖ImageNet tools repository.

Note that in our implementation, images are resized to 480 × 480 because the data augmentation we use requires an image resolution greater than 224 to avoid distortions. Therefore, when generating 𝒎𝒊𝒏𝒊ImageNet, set --image_resize 0 to keep the original size or --image_resize 480 to match our setup (see the resize sketch after this list).

  • 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet

The 𝑡𝑖𝑒𝑟𝑒𝑑ImageNet dataset is a larger subset of ILSVRC-12 with 608 classes (779,165 images) grouped into 34 higher-level nodes of the human-curated ImageNet hierarchy. To generate this dataset from ImageNet, you may use the 𝑡𝑖𝑒𝑟𝑒𝑑ImageNet tools repository.

As with 𝒎𝒊𝒏𝒊ImageNet, set --image_resize 0 to keep the original size or --image_resize 480 to match our setup when generating 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet.

  • CIFAR-FS and FC100

CIFAR-FS and FC100 can be downloaded using the scripts from DeepEMD.
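For reference, here is a minimal sketch of the 480 × 480 resizing mentioned above (what --image_resize 480 does in spirit). The actual resizing is handled by the dataset tools, and the paths below are placeholders.

```python
# Minimal resize sketch, equivalent in spirit to --image_resize 480.
# The dataset tools perform the real resizing; paths here are placeholders.
from pathlib import Path
from PIL import Image

def resize_images(src_dir: str, dst_dir: str, size: int = 480) -> None:
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for img_path in sorted(Path(src_dir).glob("*.jpg")):
        with Image.open(img_path) as img:
            img.convert("RGB").resize((size, size), Image.BICUBIC).save(dst / img_path.name)

resize_images("miniImageNet/train/n01532829", "miniImageNet480/train/n01532829")
```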

Training

Our models are trained on 8 RTX 3090 GPUs (24 GB memory each) by default. Set the --nproc_per_node argument in the provided command files to the number of GPUs available on your server, and increase or decrease --batch_size_per_gpu according to your GPU memory.

  • Phase 1 (self-supervised)

In this phase, we pretrain our model with the self-supervised learning method iBOT. All models are trained for a maximum of 1600 epochs. We evaluate on the validation set every 50 epochs of training and report the best result. 1-shot and 5-shot evaluation results with the prototype method are given in the following table. We also provide full checkpoints, extracted test-set features, and the commands to replicate these results.

--data_path: the location of the training set of the chosen dataset (e.g. 𝒎𝒊𝒏𝒊ImageNet).
--output_dir: where the phase 1 checkpoints and evaluation files are stored.
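As a rough template only (the entry-point script name main_smkd.py is a guess; use the script and batch-size settings from the downloaded command files), a phase 1 launch might look like:

python -m torch.distributed.launch --nproc_per_node 8 main_smkd.py --data_path /path/to/miniImageNet/train --output_dir /path/to/MINI480_phase1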

| Dataset | 1-shot | 5-shot | Download |
| --- | --- | --- | --- |
| 𝒎𝒊𝒏𝒊ImageNet | 60.93% | 80.38% | checkpoint / features / command |
| 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet | 71.36% | 83.28% | checkpoint / features / command |
| CIFAR-FS | 65.70% | 83.45% | checkpoint / features / command |
| FC100 | 44.20% | 61.64% | checkpoint / features / command |
  • Phase 2 (supervised)

In this second phase, we start from the phase 1 checkpoint and further train the model with the supervised knowledge-distillation method proposed in our paper. All models are trained for a maximum of 150 epochs. We evaluate on the validation set every 5 epochs of training and report the best result. As before, 1-shot and 5-shot evaluation results with the prototype method are given in the following table. We also provide checkpoints and features for the pretrained models.

--pretrained_dino_path: the same location as --output_dir in phase 1.
--pretrained_dino_file: the checkpoint file to resume from (e.g. checkpoint1250.pth).
--output_dir: where the phase 2 checkpoints and evaluation files are stored.
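Again as a rough template (main_smkd.py remains a guessed entry-point name; rely on the provided command files for the exact invocation), a phase 2 launch might look like:

python -m torch.distributed.launch --nproc_per_node 8 main_smkd.py --data_path /path/to/miniImageNet/train --pretrained_dino_path /path/to/MINI480_phase1 --pretrained_dino_file checkpoint1250.pth --output_dir /path/to/MINI480_phase2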

| Dataset | 1-shot | 5-shot | Download |
| --- | --- | --- | --- |
| 𝒎𝒊𝒏𝒊ImageNet | 74.28% | 88.82% | checkpoint / features / command |
| 𝒕𝒊𝒆𝒓𝒆𝒅ImageNet | 78.83% | 91.02% | checkpoint / features / command |
| CIFAR-FS | 80.08% | 90.63% | checkpoint / features / command |
| FC100 | 50.38% | 68.37% | checkpoint / features / command |

Evaluation

We use eval_smkd.py to evaluate a trained model (from either phase 1 or phase 2). Before running the evaluation code, specify the image data paths in server_dict inside this Python file.
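The mapping presumably looks something like the sketch below; the key names other than mini and all paths are illustrative guesses, so check eval_smkd.py for the exact keys.

```python
# Illustrative guess at the server_dict mapping in eval_smkd.py:
# each --server key points to a local dataset root. Keys other than
# "mini" and all paths are placeholders.
server_dict = {
    "mini": "/data/miniImageNet480",
    "tiered": "/data/tieredImageNet480",
    "cifar_fs": "/data/CIFAR-FS",
    "fc100": "/data/FC100",
}
```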

For example, the following commands run 5-way 5-shot evaluation of the phase 2 model on 𝒎𝒊𝒏𝒊ImageNet:

  • prototype:
python eval_smkd.py --server mini --num_shots 5 --ckp_path /root/autodl-nas/FSVIT_results/MINI480_phase2 --ckpt_filename checkpoint0040.pth --output_dir /root/autodl-nas/FSVIT_results/MINI480_prototype --evaluation_method cosine --iter_num 10000
  • classifier:
python eval_smkd.py --server mini --num_shots 5 --ckp_path /root/autodl-nas/FSVIT_results/MINI480_phase2 --ckpt_filename checkpoint0040.pth --output_dir /root/autodl-nas/FSVIT_results/MINI480_classifier --evaluation_method classifier --iter_num 1000
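For intuition, here is a self-contained sketch of the prototype (cosine) evaluation idea: average the support embeddings of each class into a prototype, then label each query by its most cosine-similar prototype. This illustrates the general method rather than the repo's exact evaluation code; the feature dimension and episode sizes below are arbitrary.

```python
# Sketch of prototype (cosine) evaluation for one few-shot episode.
import torch
import torch.nn.functional as F

def prototype_accuracy(support: torch.Tensor,        # [n_way * n_shot, dim]
                       support_labels: torch.Tensor,  # [n_way * n_shot]
                       query: torch.Tensor,           # [n_query, dim]
                       query_labels: torch.Tensor,    # [n_query]
                       n_way: int = 5) -> float:
    # Class prototypes: mean of each class's support embeddings.
    protos = torch.stack([support[support_labels == c].mean(0) for c in range(n_way)])
    # Cosine similarity between each query and each prototype.
    sims = F.normalize(query, dim=-1) @ F.normalize(protos, dim=-1).T
    preds = sims.argmax(dim=-1)
    return (preds == query_labels).float().mean().item()

# Toy 5-way 5-shot episode with random features, just to show the call.
dim, n_way, n_shot, n_query = 384, 5, 5, 15
support = torch.randn(n_way * n_shot, dim)
support_labels = torch.arange(n_way).repeat_interleave(n_shot)
query = torch.randn(n_way * n_query, dim)
query_labels = torch.arange(n_way).repeat_interleave(n_query)
print(prototype_accuracy(support, support_labels, query, query_labels, n_way))
```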

Citation

@misc{lin2023supervised,
      title={Supervised Masked Knowledge Distillation for Few-Shot Transformers}, 
      author={Han Lin and Guangxing Han and Jiawei Ma and Shiyuan Huang and Xudong Lin and Shih-Fu Chang},
      year={2023},
      eprint={2303.15466},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

This repo is built on HCTransformers, iBOT, and DINO. Thanks for their wonderful codebases.
