Clone Repository for Big Transfer (BiT): General Visual Representation Learning
This repository provides the code and models from the paper "Big Transfer (BiT): General Visual Representation Learning" by Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, and Neil Houlsby.
Update 13/02/2023: Code base refactored to modular structure, see Maté.
To install the source base, you can clone the repo, or use maté:
mate install oalee/big_transfer/experiments/bit -yo pip
mate train bit learn
Update 18/06/2021: We release new high performing BiT-R50x1 models, which were distilled from BiT-M-R152x2, see this section. More details in our paper "Knowledge distillation: A good teacher is patient and consistent".
Update 08/02/2021: We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.
In this repository we release multiple models from the Big Transfer (BiT): General Visual Representation Learning paper that were pre-trained on the ILSVRC-2012 and ImageNet-21k datasets. We provide code for fine-tuning the released models in the major deep learning frameworks TensorFlow 2, PyTorch and Jax/Flax.
We hope that the computer vision community will benefit by employing more powerful ImageNet-21k pretrained models as opposed to conventional models pre-trained on the ILSVRC-2012 dataset.
We also provide colabs for a more exploratory interactive use: a TensorFlow 2 colab, a PyTorch colab, and a Jax colab.
Make sure you have Python>=3.6 installed on your machine.
To set up TensorFlow 2, PyTorch or Jax, follow the instructions provided in the corresponding repository linked here.
In addition, install Python dependencies by running (please select tf2, pytorch or jax in the command below):
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
First, download the BiT model. We provide models pre-trained on ILSVRC-2012 (BiT-S) or ImageNet-21k (BiT-M) for 5 different architectures: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3, and ResNet-152x4.
For example, if you would like to download the ResNet-50x1 pre-trained on ImageNet-21k, run the following command:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Other models can be downloaded accordingly by plugging the name of the model (BiT-S or BiT-M) and architecture in the above command.
Note that we provide models in two formats: npz (for PyTorch and Jax) and h5 (for TF2). By default we expect that model weights are stored in the root folder of this repository.
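The naming scheme above can be captured in a small helper. This is only a sketch: `bit_model_url` and the `FORMATS` mapping are hypothetical names that are not part of this repository, and the URL pattern is inferred from the `wget` example above.

```python
# Base URL taken from the wget example above.
BASE_URL = "https://storage.googleapis.com/bit_models"

# Model families named in this README: BiT-S (ILSVRC-2012), BiT-M (ImageNet-21k).
FAMILIES = ("BiT-S", "BiT-M")

# Weight format per framework, as stated above: npz for PyTorch and Jax, h5 for TF2.
FORMATS = {"pytorch": "npz", "jax": "npz", "tf2": "h5"}


def bit_model_url(family, architecture, framework):
    """Hypothetical helper: build the download URL of a pre-trained BiT model."""
    if family not in FAMILIES:
        raise ValueError(f"unknown model family: {family!r}")
    if framework not in FORMATS:
        raise ValueError(f"unknown framework: {framework!r}")
    # The architecture string (e.g. "R50x1") is passed through unchecked.
    return f"{BASE_URL}/{family}-{architecture}.{FORMATS[framework]}"


if __name__ == "__main__":
    print(bit_model_url("BiT-M", "R50x1", "pytorch"))
```

The helper only builds the URL; downloading the file is left to `wget` or any HTTP client.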
Then, you can fine-tune the downloaded model on your dataset of interest in any of the three frameworks. All frameworks share the same command-line interface:
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
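When launching several runs from a script, the same command can be assembled in Python. The sketch below uses only the flags shown in this README; `build_train_command` is a hypothetical helper, not something the repository provides.

```python
import sys
from datetime import datetime

FRAMEWORKS = ("pytorch", "jax", "tf2")


def build_train_command(framework, model, dataset,
                        logdir="/tmp/bit_logs", extra_flags=()):
    """Hypothetical helper: assemble the fine-tuning command as an argument list."""
    if framework not in FRAMEWORKS:
        raise ValueError(f"unknown framework: {framework!r}")
    # Same timestamp layout as `date +%F_%H%M%S` in the shell command above.
    run_name = f"{dataset}_{datetime.now().strftime('%Y-%m-%d_%H%M%S')}"
    return [
        sys.executable, "-m", f"bit_{framework}.train",
        "--name", run_name,
        "--model", model,
        "--logdir", logdir,
        "--dataset", dataset,
        *extra_flags,
    ]


if __name__ == "__main__":
    print(" ".join(build_train_command("pytorch", "BiT-M-R50x1", "cifar10")))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the repository root, so that the `bit_*` packages are importable.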
Currently, all frameworks will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated: in TF2 and JAX we rely on the extensible tensorflow datasets library. In PyTorch, we use torchvision’s data input pipeline.
Note that our code uses all available GPUs for fine-tuning.
We also support training in the low-data regime: the --examples_per_class <K> option will randomly draw K samples per class for training.
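To make the low-data regime concrete, here is a minimal sketch of drawing K random examples per class from a list of labels. It illustrates the idea only and is not the sampling code used by the training scripts; the function name and the `seed` parameter are assumptions.

```python
import random
from collections import defaultdict


def sample_examples_per_class(labels, k, seed=0):
    """Return sorted dataset indices with at most k randomly drawn examples per class."""
    rng = random.Random(seed)

    # Group dataset indices by their class label.
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)

    # Draw k indices per class without replacement (fewer if the class is small).
    selected = []
    for indices in indices_by_class.values():
        selected.extend(rng.sample(indices, min(k, len(indices))))
    return sorted(selected)


if __name__ == "__main__":
    toy_labels = [0, 0, 0, 1, 1, 2]
    print(sample_examples_per_class(toy_labels, k=2))
```

The returned indices could then be used to build a subset of the training set, for example with `torch.utils.data.Subset` in PyTorch.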
To see a detailed list of all available flags, run python3 -m bit_{pytorch|jax|tf2}.train --help.
For convenience, we provide BiT-M models that were already fine-tuned on the ILSVRC-2012 dataset. The models can be downloaded by adding the -ILSVRC2012 postfix, e.g.
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
We release all architectures mentioned in the paper, so that you can trade off accuracy against speed: R50x1, R101x1, R50x3, R101x3, R152x4.
In the above path to the model file, simply replace R50x1 by your architecture of choice.
We further investigated more architectures after the paper's publication and found R152x2 to have a nice trade-off between speed and accuracy, hence we also include this in the release and provide a few numbers below.
We also release the fine-tuned models for each of the 19 tasks included in the VTAB-1k benchmark. We ran each model three times and release each of these runs. This means we release a total of 5x19x3=285 models, and hope these can be useful in further analysis of transfer learning.
The files can be downloaded via the following pattern:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
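The brace pattern above can also be expanded programmatically, which is handy for downloading a subset or for checking the count of 285. The constant and function names in this sketch are illustrative and do not come from the repository.

```python
from itertools import product

VTAB_BASE_URL = "https://storage.googleapis.com/bit_models/vtab"

ARCHITECTURES = ["R50x1", "R101x1", "R50x3", "R101x3", "R152x4"]
RUNS = [0, 1, 2]
TASKS = [
    "caltech101", "diabetic_retinopathy", "dtd", "oxford_flowers102",
    "oxford_iiit_pet", "resisc45", "sun397", "cifar100", "eurosat",
    "patch_camelyon", "smallnorb-elevation", "svhn", "dsprites-orientation",
    "smallnorb-azimuth", "clevr-distance", "clevr-count", "dmlab",
    "kitti-distance", "dsprites-xpos",
]


def vtab_model_urls():
    """Expand the download pattern into one URL per (architecture, run, task)."""
    return [
        f"{VTAB_BASE_URL}/BiT-M-{arch}-run{run}-{task}.npz"
        for arch, run, task in product(ARCHITECTURES, RUNS, TASKS)
    ]


if __name__ == "__main__":
    urls = vtab_model_urls()
    print(len(urls))  # 5 architectures x 3 runs x 19 tasks
    print(urls[0])
```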
We did not convert these models to TF2 (hence there is no corresponding .h5 file). However, we also uploaded TFHub models, which can be used in TF1 and TF2. An example sequence of commands for downloading one such model is:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
For reproducibility, our training script uses the hyper-parameters (BiT-HyperRule) that were used in the original paper. Note, however, that BiT models were trained and fine-tuned on Cloud TPU hardware, so on a typical GPU setup our default hyper-parameters may require too much memory or result in very slow progress. Moreover, BiT-HyperRule is designed to generalize across many datasets, so it is typically possible to devise more efficient application-specific hyper-parameters. Thus, we encourage the user to try more lightweight settings, as they require far fewer resources and often result in similar accuracy.
For example, we tested our code using an 8xV100 GPU machine on the CIFAR-10 and CIFAR-100 datasets, while reducing batch size from 512 to 128 and learning rate from 0.003 to 0.001. This setup resulted in nearly identical performance (see Expected results below) in comparison to BiT-HyperRule, despite being less computationally demanding.
Below, we provide more suggestions on how to optimize our paper's setup.
The default BiT-HyperRule was developed on Cloud TPUs and is quite memory-hungry. This is mainly due to the large batch-size (512) and image resolution (up to 480x480). Here are some tips if you are running out of memory:
- In bit_hyperrule.py we specify the input resolution. By reducing it, one can save a lot of memory and compute, at the expense of accuracy.
- The batch-size can be reduced in order to reduce memory consumption. However, one then also needs to play with learning-rate and schedule (steps) in order to maintain the desired accuracy.
- The PyTorch codebase supports a batch-splitting technique ("micro-batching") via the --batch_split option. For example, running the fine-tuning with --batch_split 8 reduces memory requirement by a factor of 8.
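The idea behind micro-batching can be illustrated with a toy example. The sketch below assumes the technique amounts to gradient accumulation (process the batch in sequential chunks and sum their scaled gradients), which is the usual way to implement it; it uses a one-parameter linear model in plain Python and is not the code from bit_pytorch.

```python
import math


def squared_error_gradient_sum(w, xs, ys):
    """Sum over examples of d/dw (w * x - y)^2 for the toy model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))


def full_batch_gradient(w, xs, ys):
    """Gradient of the mean squared error computed on the whole batch at once."""
    return squared_error_gradient_sum(w, xs, ys) / len(xs)


def micro_batched_gradient(w, xs, ys, batch_split):
    """Same gradient, accumulated over `batch_split` sequential micro-batches."""
    if len(xs) % batch_split != 0:
        raise ValueError("batch size must be divisible by batch_split")
    micro_size = len(xs) // batch_split
    accumulated = 0.0
    for start in range(0, len(xs), micro_size):
        chunk_x = xs[start:start + micro_size]
        chunk_y = ys[start:start + micro_size]
        # Only one chunk is processed at a time; its gradient is scaled by the
        # full batch size so that the accumulated sum equals the batch mean.
        accumulated += squared_error_gradient_sum(w, chunk_x, chunk_y) / len(xs)
    return accumulated


if __name__ == "__main__":
    xs = [float(i) for i in range(16)]
    ys = [3.0 * x + 1.0 for x in xs]
    print(full_batch_gradient(0.5, xs, ys))
    print(micro_batched_gradient(0.5, xs, ys, batch_split=8))
```

Both functions return the same value up to floating-point rounding, which is why the optimizer update is unchanged while only one micro-batch has to be held in memory at a time.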
We verified that when using the BiT-HyperRule, the code in this repository reproduces the paper's results.
For these common benchmarks, the aforementioned changes to the BiT-HyperRule (--batch 128 --base_lr 0.001) lead to the following, very similar results.
The table shows the min←median→max result of at least five runs.
NOTE: This is not a comparison of frameworks, just evidence that all code-bases can be trusted to reproduce results.
Dataset | Ex/cls | TF2 | Jax | PyTorch |
---|---|---|---|---|
CIFAR10 | 1 | 52.5 ← 55.8 → 60.2 | 48.7 ← 53.9 → 65.0 | 56.4 ← 56.7 → 73.1 |
CIFAR10 | 5 | 85.3 ← 87.2 → 89.1 | 80.2 ← 85.8 → 88.6 | 84.8 ← 85.8 → 89.6 |
CIFAR10 | full | 98.5 | 98.4 | 98.5 ← 98.6 → 98.6 |
CIFAR100 | 1 | 34.8 ← 35.7 → 37.9 | 32.1 ← 35.0 → 37.1 | 31.6 ← 33.8 → 36.9 |
CIFAR100 | 5 | 68.8 ← 70.4 → 71.4 | 68.6 ← 70.8 → 71.6 | 70.6 ← 71.6 → 71.7 |
CIFAR100 | full | 90.8 | 91.2 | 91.1 ← 91.2 → 91.4 |
Dataset | Ex/cls | Jax | PyTorch |
---|---|---|---|
CIFAR10 | 1 | 44.0 ← 56.7 → 65.0 | 50.9 ← 55.5 → 59.5 |
CIFAR10 | 5 | 85.3 ← 87.0 → 88.2 | 85.3 ← 85.8 → 88.6 |
CIFAR10 | full | 98.5 | 98.5 ← 98.5 → 98.6 |
CIFAR100 | 1 | 36.4 ← 37.2 → 38.9 | 34.3 ← 36.8 → 39.0 |
CIFAR100 | 5 | 69.3 ← 70.5 → 72.0 | 70.3 ← 72.0 → 72.3 |
CIFAR100 | full | 91.2 | 91.2 ← 91.3 → 91.4 |
(TF2 models not yet available.)
Dataset | Ex/cls | TF2 | Jax | PyTorch |
---|---|---|---|---|
CIFAR10 | 1 | 49.9 ← 54.4 → 60.2 | 48.4 ← 54.1 → 66.1 | 45.8 ← 57.9 → 65.7 |
CIFAR10 | 5 | 80.8 ← 83.3 → 85.5 | 76.7 ← 82.4 → 85.4 | 80.3 ← 82.3 → 84.9 |
CIFAR10 | full | 97.2 | 97.3 | 97.4 |
CIFAR100 | 1 | 35.3 ← 37.1 → 38.2 | 32.0 ← 35.2 → 37.8 | 34.6 ← 35.2 → 38.6 |
CIFAR100 | 5 | 63.8 ← 65.0 → 66.5 | 63.4 ← 64.8 → 66.5 | 64.7 ← 65.5 → 66.0 |
CIFAR100 | full | 86.5 | 86.4 | 86.6 |
These results were obtained using BiT-HyperRule.
However, because this results in large batch-size and large resolution, memory can be an issue.
The PyTorch code supports batch-splitting, and hence we can still run things there without resorting to Cloud TPUs by adding the --batch_split N option, where N is a power of two.
For instance, the following command produces a validation accuracy of 80.68 on a machine with 8 V100 GPUs:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Increase this further to --batch_split 8 when running with 4 V100 GPUs, and so on.
Full results achieved that way in some test runs were:
Ex/cls | R50x1 | R152x2 | R101x3 |
---|---|---|---|
1 | 18.36 | 24.5 | 25.55 |
5 | 50.64 | 64.5 | 64.18 |
full | 80.68 | 85.15 | WIP |
These are re-runs and not the exact paper models. The expected VTAB scores for two of the models are:
Model | Full | Natural | Structured | Specialized |
---|---|---|---|---|
BiT-M-R152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
BiT-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
In Appendix G of our paper, we investigate whether BiT improves out-of-context robustness. To do this, we created a dataset comprising foreground objects corresponding to 21 ILSVRC-2012 classes pasted onto 41 miscellaneous backgrounds.
To download the dataset, run
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Images from each of the 21 classes are kept in a directory with the name of the class.
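After unzipping, the one-directory-per-class layout can be sanity-checked with a few lines of Python. This is a sketch under assumptions: the path of the extracted folder is whatever you chose when unzipping, and `count_images_per_class` is a hypothetical helper, not part of this repository.

```python
from collections import Counter
from pathlib import Path


def count_images_per_class(dataset_root):
    """Map each class directory name to the number of files it contains."""
    root = Path(dataset_root)
    counts = Counter()
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        # Every regular file directly inside a class directory is counted.
        counts[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return counts


if __name__ == "__main__":
    import sys

    # Usage: python3 count_classes.py <path to the extracted dataset>
    for name, n in count_images_per_class(sys.argv[1]).items():
        print(f"{name}: {n}")
```

For the extracted dataset, the result should contain 21 entries, one per class.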
We release top-performing compressed BiT models from our paper "Knowledge distillation: A good teacher is patient and consistent". In particular, we distill the BiT-M-R152x2 model (which was pre-trained on ImageNet-21k) to BiT-R50x1 models. As a result, we obtain compact models with very competitive performance.
Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
---|---|---|---|
BiT-R50x1 | link | 224 | 82.8 |
BiT-R50x1 | link | 160 | 80.5 |
For reproducibility, we also release weights of two BiT-M-R152x2 teacher models: pretrained at resolution 224 and resolution 384. See the paper for details on how these teachers were used.
We have no concrete plans for publishing the distillation code, as the recipe is simple and we imagine most people would integrate it in their existing training code. However, Sayak Paul has independently re-implemented the distillation setup in TensorFlow and nearly reproduced our results in several settings.