Transfer Learning and Fine-tuning Vision Transformers for Image Classification


As the Transformer architecture scaled well in Natural Language Processing, the same architecture was applied to images by splitting each image into small patches and treating the patches as tokens. The result was the Vision Transformer (ViT). Before we get started with transfer learning and fine-tuning, let's compare Convolutional Neural Networks (CNNs) with Vision Transformers.
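To make the idea of patches as tokens concrete, here is a minimal NumPy sketch that cuts an image into non-overlapping patches and flattens each one into a vector. The function name `image_to_patches`, the 224x224 input and the patch size of 16 are assumptions chosen for illustration; a real Vision Transformer additionally applies a learned linear projection and adds position embeddings, which are left out here.

```python
import numpy as np


def image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C).
    """
    height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    rows, cols = height // patch_size, width // patch_size
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    patches = image.reshape(rows, patch_size, cols, patch_size, channels)
    patches = patches.transpose(0, 2, 1, 3, 4)
    # One flat vector per patch: the "token" before any learned projection.
    return patches.reshape(rows * cols, patch_size * patch_size * channels)


image = np.zeros((224, 224, 3), dtype=np.float32)  # dummy RGB image
tokens = image_to_patches(image, patch_size=16)
print(tokens.shape)  # (196, 768)
```

With these assumed sizes, the image becomes a sequence of 14 x 14 = 196 tokens, each holding 16 x 16 x 3 = 768 values.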

CNN vs Vision Transformers: Inductive Bias

Inductive bias is a term used in machine learning to describe the set of assumptions that a learning algorithm uses to make predictions. In simpler terms, inductive bias is like a shortcut that helps a machine learning model make educated guesses based on the information it has seen so far.

Here are a couple of inductive biases we observe in CNNs:

  • Translational Equivariance: shifting an object in the image shifts its feature response by the same amount, so CNNs can detect the object's features wherever it appears.
  • Locality: each pixel in an image interacts mainly with its surrounding pixels to form features.
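Both biases can be seen in a tiny 1D example. The sketch below is only an illustration, not how a CNN layer is implemented: the 3-element `kernel` and the toy `signal` are made up, and `np.correlate` stands in for a convolutional layer. Each output value depends on just three neighbouring inputs (locality), and moving the "object" moves the response by the same offset (translational equivariance).

```python
import numpy as np

# Small local filter: each output value only sees 3 neighbouring inputs.
kernel = np.array([1.0, 0.0, -1.0])

signal = np.zeros(12)
signal[3:5] = 1.0             # an "object" at positions 3 and 4
shifted = np.roll(signal, 4)  # the same object moved 4 positions to the right

response = np.correlate(signal, kernel, mode="valid")
shifted_response = np.correlate(shifted, kernel, mode="valid")

# Equivariance: shifting the input shifts the feature map by the same amount.
print(np.allclose(np.roll(response, 4), shifted_response))  # True
```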

Vision Transformers have far weaker versions of these biases built in. Then how do they perform so well? It's because they're highly scalable and they're pre-trained on massive amounts of images, so they can learn such spatial relationships from data instead of relying on built-in inductive biases.

Using pre-trained Vision Transformers

It's not feasible for everyone to train a Vision Transformer on millions of images to get good performance. Instead, one can use openly available models from places such as the Hugging Face Hub.

What do you do with the pre-trained model? You can apply transfer learning and fine-tune it!

Transfer Learning & Fine-Tuning for Image Classification

The idea of transfer learning is that we can take the features learned by a Vision Transformer trained on a very large dataset and apply them to our own dataset. This can lead to significant improvements in model performance, especially when our dataset has limited data available for training.

Since we are taking advantage of the learned features, we do not need to update the entire model either. By freezing most of the weights and training only certain layers (typically the classification head), we can get good performance with less training time and lower GPU memory usage.
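As a rough sketch of what freezing looks like in PyTorch, the snippet below disables gradients for everything except the classification head. `TinyClassifier` is a hypothetical stand-in for a pre-trained model so that the example runs without downloading weights, and `freeze_all_but_head` is an illustrative helper, not a library function. The head name `classifier` is an assumption; inspect `model.named_parameters()` to find the right name for your own model.

```python
import torch
from torch import nn


def freeze_all_but_head(model: nn.Module, head_name: str = "classifier") -> None:
    """Freeze every parameter except those under the `head_name` submodule."""
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(head_name + ".")


class TinyClassifier(nn.Module):
    """Hypothetical stand-in: a "pre-trained" backbone plus a new head."""

    def __init__(self, hidden_size: int = 32, num_labels: int = 3):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(48, hidden_size), nn.GELU())
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.backbone(x))


model = TinyClassifier()
freeze_all_but_head(model)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")

# Only the parameters that still require gradients are handed to the optimizer.
# The learning rate here is an arbitrary illustrative value.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
```

Because frozen parameters need neither gradients nor optimizer state, this is where the savings in training time and memory come from.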

Multi-class Image Classification

You can go through the transfer learning tutorial using Vision Transformers for image classification in this notebook:

<a target="_blank" href="">Open In Colab</a>

This is what we'll be building: an image classifier to tell apart dog and cat breeds:

<iframe src="" frameborder="0" width="850" height="450" ></iframe>

If the domain of your dataset is not very similar to that of the pre-trained model's dataset, training only a few layers may not be enough. Even then, instead of training a Vision Transformer from scratch, we can update the weights of the entire pre-trained model, albeit with a lower learning rate, which will "fine-tune" the model to perform well on our data.

However, in most scenarios, transfer learning alone is sufficient for Vision Transformers.
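For the full fine-tuning case, one common way to apply a lower learning rate to pre-trained weights is to use optimizer parameter groups. The sketch below is one possible setup, not the notebook's code: the `backbone`/`head` split, the tiny stand-in model and both learning rates are assumptions chosen for illustration.

```python
import torch
from torch import nn

# Hypothetical stand-in: "backbone" plays the pre-trained part, "head" is new.
model = nn.ModuleDict({
    "backbone": nn.Sequential(nn.Linear(48, 32), nn.GELU()),
    "head": nn.Linear(32, 3),
})

# Fine-tuning: every weight stays trainable ...
for param in model.parameters():
    param.requires_grad = True

# ... but pre-trained weights take much smaller steps than the new head.
optimizer = torch.optim.AdamW([
    {"params": model["backbone"].parameters(), "lr": 2e-5},
    {"params": model["head"].parameters(), "lr": 1e-3},
])

# One illustrative training step on random data.
inputs = torch.randn(8, 48)
labels = torch.randint(0, 3, (8,))
logits = model["head"](model["backbone"](inputs))
loss = nn.functional.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
optimizer.zero_grad()

print([group["lr"] for group in optimizer.param_groups])
```

A single low learning rate for the whole model is a simpler alternative; the two-group version just lets the freshly initialised head learn faster.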

Multi-label Image Classification

The tutorial above teaches multi-class image classification, where each image has exactly one class assigned to it. What about scenarios where each image in a multi-class dataset can have multiple labels at once?
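The difference shows up in how the model's raw outputs (logits) are turned into predictions. The following sketch uses made-up class names and logits, and the 0.5 threshold is an assumption that is often tuned in practice. In multi-class classification a softmax makes the classes compete, while in multi-label classification each class gets an independent sigmoid; the targets become multi-hot vectors and are typically trained with a binary cross-entropy loss.

```python
import numpy as np

class_names = ["cat", "dog", "grass"]  # hypothetical label set
logits = np.array([2.0, 1.5, -3.0])    # made-up raw model outputs for one image

# Multi-class: softmax makes the scores compete, so exactly one class wins.
softmax = np.exp(logits - logits.max())
softmax /= softmax.sum()
multi_class_prediction = class_names[int(softmax.argmax())]

# Multi-label: an independent sigmoid per class, so any number of classes
# can exceed the threshold.
sigmoid = 1.0 / (1.0 + np.exp(-logits))
multi_label_prediction = [
    name for name, prob in zip(class_names, sigmoid) if prob > 0.5
]

print(multi_class_prediction)  # cat
print(multi_label_prediction)  # ['cat', 'dog']
```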

This notebook will walk you through a fine-tuning tutorial using a Vision Transformer for multi-label image classification:

<a target="_blank" href="">Open In Colab</a>

We'll also learn how to use Hugging Face Accelerate to write custom training loops. This is what you can expect to see as the outcome of the multi-label classification tutorial:

<iframe src="" frameborder="0" width="850" height="450" ></iframe>

Additional Resources

  • Original Vision Transformer paper: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
  • Swin Transformer paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  • A systematic empirical study to better understand the interplay between the amount of training data, regularization, augmentation, model size and compute budget for Vision Transformers: How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers