
martinagvilas/vit-cls_emb


Analyzing Vision Transformers in Class Embedding Space (NeurIPS '23)

by Martina G. Vilas, Timothy Schaumlöffel and Gemma Roig

Links: Paper | Video presentation (coming soon) | Poster (coming soon)

Abstract: Despite the growing use of transformer models in computer vision, a mechanistic understanding of these networks is still needed. This work introduces a method to reverse-engineer Vision Transformers trained to solve image classification tasks. Inspired by previous research in NLP, we demonstrate how the inner representations at any level of the hierarchy can be projected onto the learned class embedding space to uncover how these networks build categorical representations for their predictions. We use our framework to show how image tokens develop class-specific representations that depend on attention mechanisms and contextual information, and give insights on how self-attention and MLP layers differentially contribute to this categorical composition. We additionally demonstrate that this method (1) can be used to determine the parts of an image that would be important for detecting the class of interest, and (2) exhibits significant advantages over traditional linear probing approaches. Taken together, our results position our proposed framework as a powerful tool for mechanistic interpretability and explainability research.
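The core idea of projecting an intermediate hidden state onto the learned class embedding space can be sketched in plain Python. This is a minimal illustration, not the repository's implementation: it assumes the projection reuses the model's final LayerNorm (here without its learned scale and shift) followed by the classification head's weight matrix, whose rows act as class embeddings, and it omits the head bias. `normalized_rank` is a hypothetical helper that scores how highly the correct class is ranked.

```python
import math


def layer_norm(x, eps=1e-6):
    """LayerNorm over the hidden dimension, without the learned scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def project_to_class_space(hidden, class_embeddings):
    """Score a hidden state against each class embedding (one row per class)."""
    h = layer_norm(hidden)
    return [sum(e_d * h_d for e_d, h_d in zip(emb, h)) for emb in class_embeddings]


def normalized_rank(logits, target):
    """Fraction of other classes scored below the target: 1.0 = ranked first, 0.0 = last."""
    below = sum(1 for s in logits if s < logits[target])
    return below / (len(logits) - 1)


# Toy example: a 4-dimensional hidden state and 3 classes.
logits = project_to_class_space([1.0, 2.0, 3.0, 4.0],
                                [[1.0, 0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.0, 1.0],
                                 [0.0, 1.0, 0.0, 0.0]])
print(normalized_rank(logits, target=1))  # → 1.0
```

The same projection can be applied to any token (the [CLS] token or an individual image token) at any layer, which is what makes layer-by-layer comparisons possible without training anything new.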


Schematic of our framework

📎 Contents

Tutorial

You can access a tutorial of our method here: Open In Colab

Running the experiments

Step 1: Get a local working copy of this code

1.1. Clone this repository to your local machine.

1.2. Install the required software using conda by running:

conda create --name vit-cls python=3.9
conda activate vit-cls
pip install -r requirements.txt
pip install .
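A quick way to confirm that the environment from step 1.2 is usable is to check the interpreter version and whether key packages can be found. This is a hypothetical helper, not part of the repository; the package list is an assumption (timm is named in the Acknowledgements, and torch is a dependency of timm), so adjust it to match requirements.txt.

```python
import importlib.util
import sys


def missing_packages(names):
    """Return the top-level package names that cannot be found in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Step 1.2 creates the environment with Python 3.9.
    if sys.version_info[:2] != (3, 9):
        print(f"Note: running Python {sys.version.split()[0]}; the environment was created with 3.9")
    missing = missing_packages(["timm", "torch"])  # assumed subset of requirements.txt
    print("All checked packages found." if not missing else f"Missing: {', '.join(missing)}")
```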

Step 2: Download the dataset and model checkpoints

2.1. Download the ImageNet-S dataset from here.

2.2. Download the stimuli info file from here, and place it inside the ImageNet-S/ImageNetS919 folder downloaded in the previous step.

2.3. Download the model checkpoint folder from here, and place it inside the project folder.
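Before running any experiment, it can help to verify that the downloads from step 2 ended up where expected. In this sketch only the `ImageNetS919` folder name comes from the instructions above; `check_layout` is a hypothetical helper, and the stimuli file name and checkpoint folder name are placeholders you must replace with the actual names of the files you downloaded.

```python
from pathlib import Path


def check_layout(dataset_dir, stimuli_file, project_dir, checkpoint_dir):
    """Return a list of problems found; an empty list means the layout looks as expected.

    dataset_dir:    path to the downloaded ImageNet-S folder.
    stimuli_file:   name of the stimuli info file from step 2.2 (hypothetical placeholder).
    project_dir:    path to the cloned repository.
    checkpoint_dir: name of the checkpoint folder from step 2.3 (hypothetical placeholder).
    """
    problems = []
    split_dir = Path(dataset_dir) / "ImageNetS919"
    if not split_dir.is_dir():
        problems.append(f"missing dataset folder: {split_dir}")
    elif not (split_dir / stimuli_file).is_file():
        problems.append(f"stimuli info file {stimuli_file!r} not found in {split_dir}")
    checkpoints = Path(project_dir) / checkpoint_dir
    if not checkpoints.is_dir():
        problems.append(f"checkpoint folder not found: {checkpoints}")
    return problems
```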

Step 3: Run the experiments

3.1. Project the hidden states onto the class embedding space and save the key coefficients by running:

python extractor.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -pretrained
  • The model can be one of vit_b_32, vit_b_16, vit_large_16, vit_cifar_16, vit_miil_16, deit_ensemble_16 (Refinement model), or vit_gap_16.
  • To reproduce the results for the randomly initialized model, remove the -pretrained flag.
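To run the extractor for every listed model, a small driver can generate the commands shown in step 3.1. This sketch assumes it is run from the project folder where extractor.py lives and reuses only the flags shown in that command; `extractor_commands` and `run_all` are hypothetical helpers. With `include_random=True` it also emits each command without -pretrained, for the random-model results.

```python
import subprocess

# Model names accepted by -m, as listed above.
MODELS = ("vit_b_32", "vit_b_16", "vit_large_16", "vit_cifar_16",
          "vit_miil_16", "deit_ensemble_16", "vit_gap_16")


def extractor_commands(src_path, data_path, models=MODELS, include_random=False):
    """Build one extractor.py command per model; optionally add the random-model variant."""
    commands = []
    for model in models:
        base = ["python", "extractor.py", "-pp", src_path, "-dp", data_path, "-m", model]
        commands.append(base + ["-pretrained"])
        if include_random:
            commands.append(base)  # without -pretrained: random model
    return commands


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Replace the placeholder paths before setting dry_run=False.
    run_all(extractor_commands("/path/to/source", "/path/to/dataset", include_random=True))
```

Defaulting to a dry run makes it easy to inspect the exact commands before committing to long extraction jobs.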

3.2. Run the attention perturbation studies by running:

python perturbation/attn_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -pt {PERTURBATION TYPE}
  • The perturbation type can be either self_only or no_cls.
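The two perturbation types can be pictured as masks applied to the attention scores before the softmax. The sketch below assumes token 0 is the [CLS] token, that self_only restricts each image token to attending only to itself, that no_cls prevents image tokens from attending to [CLS], and that the [CLS] row itself is left unchanged. These semantics are an interpretation of the option names; the script's actual implementation may differ.

```python
import math

PERTURBATIONS = ("self_only", "no_cls")


def attention_mask(num_tokens, perturbation):
    """allowed[i][j] is True if token i may attend to token j (token 0 is [CLS])."""
    if perturbation not in PERTURBATIONS:
        raise ValueError(f"unknown perturbation: {perturbation}")
    allowed = [[True] * num_tokens for _ in range(num_tokens)]
    for i in range(1, num_tokens):  # image tokens only; the [CLS] row is left intact
        for j in range(num_tokens):
            if perturbation == "self_only":
                allowed[i][j] = i == j
            else:  # "no_cls"
                allowed[i][j] = j != 0
    return allowed


def masked_softmax(scores, allowed_row):
    """Softmax over one row of attention scores, giving zero weight to disallowed tokens."""
    top = max(s for s, ok in zip(scores, allowed_row) if ok)
    exps = [math.exp(s - top) if ok else 0.0 for s, ok in zip(scores, allowed_row)]
    total = sum(exps)
    return [e / total for e in exps]
```

In a tensor implementation the same effect is usually obtained by adding negative infinity to the disallowed scores before the softmax.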

3.3. Run the context perturbation studies by running:

python perturbation/tokens_perturbation.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m vit_b_32 -mt {MASK TYPE}
  • The mask type can be either context or class label.
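The context perturbation depends on knowing which image patches cover the labelled object. A plausible sketch, assuming a binary ImageNet-S-style segmentation map whose height and width are multiples of the patch size, computes the foreground fraction of each patch and then selects which patch tokens to mask. The 0.5 threshold, the helper names, and the meaning of each mask type (context masks background patches, class label masks object patches) are assumptions for illustration; the string values mirror the wording above and may not match the script's accepted values.

```python
def patch_foreground_fraction(seg, patch):
    """Fraction of object pixels in each patch, in row-major patch order.

    seg is a 2D list of 0/1 values (1 = pixel belongs to the labelled object).
    """
    height, width = len(seg), len(seg[0])
    fractions = []
    for r in range(0, height, patch):
        for c in range(0, width, patch):
            block = [seg[y][x] for y in range(r, r + patch) for x in range(c, c + patch)]
            fractions.append(sum(block) / len(block))
    return fractions


def tokens_to_mask(fractions, mask_type, threshold=0.5):
    """Indices of patch tokens (excluding [CLS]) to perturb for the given mask type."""
    if mask_type == "context":
        return [i for i, f in enumerate(fractions) if f < threshold]
    if mask_type == "class label":
        return [i for i, f in enumerate(fractions) if f >= threshold]
    raise ValueError(f"unknown mask type: {mask_type}")
```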

3.4. Run the memory extractor by running:

python memories.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -m {MODEL} -lt {LAYER TYPE}
  • The layer type can be either attn or mlp.
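The memory extractor builds on viewing a layer as a set of key-value memories. For the mlp case, a minimal sketch: each hidden unit i has a key k_i (row i of the first layer's weight matrix) and a value v_i (column i of the second layer's weight matrix, stored here as a row). The key coefficient, assumed here to correspond to the coefficients saved in step 3.1, is the GELU activation of x · k_i + b_i, and the layer output is the coefficient-weighted sum of the values. Projecting a value vector onto the class embeddings then shows which classes that memory promotes. All helper names are hypothetical, the output bias is omitted, and the attn case is not covered.

```python
import math


def gelu(x):
    """Exact GELU, the activation used in ViT MLP blocks."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def key_coefficients(x, keys, key_bias):
    """m_i = GELU(x . k_i + b_i): how strongly memory i fires for input x."""
    return [gelu(sum(a * b for a, b in zip(x, k)) + bias) for k, bias in zip(keys, key_bias)]


def mlp_as_memories(x, keys, key_bias, values):
    """MLP output written as a coefficient-weighted sum of value vectors."""
    coeffs = key_coefficients(x, keys, key_bias)
    dim = len(values[0])
    return [sum(c * v[d] for c, v in zip(coeffs, values)) for d in range(dim)]


def top_classes_for_value(value, class_embeddings, k=1):
    """Indices of the k classes whose embeddings align best with a value vector."""
    scores = [sum(e_d * v_d for e_d, v_d in zip(emb, value)) for emb in class_embeddings]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
```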

3.5. Run the comparison with a linear probing approach by running:

python linear_probing/prober.py -pp {PATH TO SOURCE CODE} -dp {PATH TO DATASET} -l {LAYER INDEX}
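For contrast with the projection approach, a linear probe trains a fresh classifier on a layer's hidden states instead of reusing the model's class embeddings. The sketch below is a plain multinomial logistic-regression probe trained with per-sample gradient descent; it illustrates the general technique and is not the repository's prober.py. The learning rate and epoch count are arbitrary choices.

```python
import math


def softmax(z):
    """Numerically stable softmax."""
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def train_linear_probe(features, labels, num_classes, lr=0.5, epochs=200):
    """Fit weights W (num_classes x dim) and biases b by minimizing cross-entropy."""
    dim = len(features[0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        for x, y in zip(features, labels):
            scores = [sum(w * xi for w, xi in zip(W[c], x)) + b[c] for c in range(num_classes)]
            probs = softmax(scores)
            for c in range(num_classes):
                grad = probs[c] - (1.0 if c == y else 0.0)  # d(loss)/d(score_c)
                b[c] -= lr * grad
                for d in range(dim):
                    W[c][d] -= lr * grad * x[d]
    return W, b


def predict(W, b, x):
    """Index of the highest-scoring class."""
    scores = [sum(w * xi for w, xi in zip(Wc, x)) + bc for Wc, bc in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)
```

The key design difference is that the probe learns new parameters for every layer, whereas the class-embedding projection has no trainable parameters and reads each layer through the model's own classifier.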

Step 4: Reproduce the results

After running the above code, head to the notebooks section to reproduce and visualize the reported results.

Citing our work

Please cite this work as:

@inproceedings{vilas2023analyzing_vit,
 title = {Analyzing Vision Transformers for Image Classification in Class Embedding Space},
 author = {Vilas, Martina G. and Schauml\"{o}ffel, Timothy and Roig, Gemma},
 booktitle = {Advances in Neural Information Processing Systems},
 pages = {40030--40041},
 volume = {36},
 year = {2023},
 url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/7dd309df03d37643b96f5048b44da798-Paper-Conference.pdf},
}

Acknowledgements

  • The pre-trained models are extracted from the timm library.
  • Our readme is inspired by IPViT.

About

Accompanying code for "Analyzing Vision Transformers in Class Embedding Space" (NeurIPS '23)
