Release distillation and scaling ViT projects.
And a bunch of small fixes and improvements we made over time.

Co-authored-by: Xiaohua Zhai <xzhai@google.com>
Co-authored-by: Alexander Kolesnikov <alexanderkolesnikoff@gmail.com>
3 people committed Jun 22, 2022
1 parent e9fb55d commit 2f3f493
Showing 18 changed files with 1,656 additions and 24 deletions.
33 changes: 27 additions & 6 deletions README.md
@@ -33,14 +33,20 @@ codebase:
Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer,
Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby*
- [Scaling Vision Transformers](https://arxiv.org/abs/2106.04560), by
Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*
Xiaohua Zhai*, Alexander Kolesnikov*, Neil Houlsby, and Lucas Beyer*\
Resources: [config](configs/proj/scaling_laws/train_vit_g.py).
- [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270), by
Andreas Steiner*, Alexander Kolesnikov*, Xiaohua Zhai*, Ross Wightman,
Jakob Uszkoreit, and Lucas Beyer*
- [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601), by
Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*,
Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner,
Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy
- [Better plain ViT baselines for ImageNet-1k](https://arxiv.org/abs/2205.01580), by
Lucas Beyer, Xiaohua Zhai, Alexander Kolesnikov\
Resources: [config](big_vision/configs/vit_s16_i1k.py)
- [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by
Alexander Kolesnikov*, André Susano Pinto*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*

### Multimodal research
- [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by
@@ -50,7 +56,8 @@ codebase:
### Knowledge distillation
- [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237), by
Lucas Beyer*, Xiaohua Zhai*, Amélie Royer*, Larisa Markeeva*, Rohan Anil,
and Alexander Kolesnikov*
and Alexander Kolesnikov*\
Resources: [README](big_vision/configs/proj/distill/README.md), [trainer](big_vision/trainers/proj/distill/distill.py), [colab](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing).

### Misc
- [Are we done with ImageNet?](https://arxiv.org/abs/2006.07159), by
@@ -90,18 +97,26 @@ gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all

See instructions below for more details on how to use Google Cloud TPUs.

All runs write checkpoints and logfiles. The logfiles are a list of JSON
objects, and we provide a short and straightforward [example colab to read
and display the logs and checkpoints](https://colab.research.google.com/drive/1R_lvV542WUp8Q2y8sbyooZOGCplkn7KI?usp=sharing).
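
For a quick look without the colab, a minimal sketch along these lines should
also work, assuming a logfile has been copied locally and contains one JSON
object per line (the filename and the `step` field below are assumptions, not
part of this repo):

```
import json

# Hypothetical local copy of a run's logfile; adjust to your workdir layout.
with open("big_vision_metrics.txt") as f:
  records = [json.loads(line) for line in f if line.strip()]

# Each record is a dict of metrics logged at some training step.
for r in records[:5]:
  print(r.get("step"), {k: v for k, v in r.items() if k != "step"})
```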

# Current and future contents

The first release contains the core part of pre-training, transferring, and
evaluating classification models at scale on Cloud TPU VMs.

We have since added the following key features and projects:
- Patient and consistent distillation.
- Scaling ViT.

Features and projects we plan to release in the near future, in no particular
order:
- ImageNet-21k in TFDS.
- MLP-Mixer.
- Loading misc public models used in our publications (NFNet, MoCov3, DINO).
- Contrastive Image-Text model training and evaluation as in LiT and CLIP.
- "Patient and consistent" distillation.
- UViM.
- Memory-efficient Polyak-averaging implementation.
- Advanced JAX compute and memory profiling. We are using internal tools for
this, but may eventually add support for the publicly available ones.
@@ -154,7 +169,7 @@ dependencies.

```
git clone --branch=master https://github.com/google-research/big_vision
gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --worker=all --zone=$ZONE
gcloud alpha compute tpus tpu-vm scp --recurse big_vision/big_vision $NAME: --zone=$ZONE --worker=all
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "bash big_vision/run_tpu.sh"
```

@@ -165,8 +180,9 @@ also do it on your local machine and copy the result to the cloud bucket. For
convenience, we provide instructions on how to prepare data using Cloud TPUs.

Download and prepare TFDS datasets using a single worker. Seven TFDS datasets
used during evaluations will be generated under `~/tensorflow_datasets/` (should
take 10-15 minutes in total).
used during evaluations will be generated under `~/tensorflow_datasets/` (by
default; the location can be overridden via the `TFDS_DATA_DIR` env variable).
This should take 10-15 minutes in total.

```
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=0 --command "bash big_vision/run_tpu.sh big_vision.tools.download_tfds_datasets cifar10 cifar100 oxford_iiit_pet oxford_flowers102 cars196 dtd uc_merced"
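# (Illustrative alternative, not part of the original README.) The same datasets
# can also be prepared on a local machine with the tensorflow-datasets package
# and then copied to the cloud bucket, e.g.:
python3 -c "import tensorflow_datasets as tfds; [tfds.builder(d).download_and_prepare() for d in ['cifar10', 'cifar100', 'oxford_iiit_pet', 'oxford_flowers102', 'cars196', 'dtd', 'uc_merced']]"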
@@ -206,6 +222,11 @@ run the following command line.
gcloud alpha compute tpus tpu-vm ssh $NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$GS_BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/bit_i1k.py --workdir gs://$GS_BUCKET_NAME/big_vision/workdir/`date '+%m-%d_%H%M'`"
```

## Sometimes useful gcloud commands

- Destroy the TPU machines: `gcloud alpha compute tpus tpu-vm delete $NAME --zone $ZONE`
- Remove all big_vision-related folders on all hosts: `gcloud alpha compute tpus tpu-vm ssh $NAME --zone $ZONE --worker=all --command 'rm -rf ~/big_vision ~/bv_venv'`

# ViT baseline

We provide a well-tuned ViT-S/16 baseline in the config file named
22 changes: 22 additions & 0 deletions big_vision/configs/load_and_eval.py
@@ -114,6 +114,28 @@ def vit_i1k(config):
  )


def mlp_mixer_i1k(config):
  # We could omit init_{shapes,types} if we wanted, as they are the default.
  config.init_shapes = [(1, 224, 224, 3)]
  config.init_types = ['float32']
  config.num_classes = 1000

  config.model_name = 'mlp_mixer'
  config.model_init = ''  # Will be set in sweep.
  config.model = dict(variant='L/16')

  config.evals = {}
  config.evals.fewshot = get_fewshot_lsr()
  config.evals.val = dict(
      type='classification',
      dataset='imagenet2012',
      split='validation',
      pp_fn='decode|resize_small(256)|central_crop(224)|value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")',
      loss_name='softmax_xent',
      cache_final=False,  # Only run once, on low-mem machine.
  )


def vit_i21k(config):
  # We could omit init_{shapes,types} if we wanted, as they are the default.
  config.init_shapes = [(1, 224, 224, 3)]
43 changes: 43 additions & 0 deletions big_vision/configs/proj/distill/README.md
@@ -0,0 +1,43 @@
# Knowledge distillation: A good teacher is patient and consistent
*by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov*

## Introduction
We publish all teacher models and configurations for the main experiments of
the paper, as well as training logs and student models.

Please read the main [big_vision README](/README.md) to learn how to run
configs, and remember that each config file contains an example invocation in
the top-level comment.

## Results

We provide the following [colab to read and plot the logfiles](https://colab.research.google.com/drive/1nMykzUzsfQ_uAxfj3k35DYsATnG_knPl?usp=sharing)
of a few runs that we reproduced on Cloud.

### ImageNet-1k

The file [bit_i1k.py](bit_i1k.py) is the configuration that reproduces our
distillation runs on ImageNet-1k reported in Figures 1 and 5 (left) and in the
first row of Table 1.

We release both student and teacher models:

| Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
| :--- | :---: | :---: | :---: |
| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_160.npz) | 160 | 80.5 |
| BiT-R50x1 | [link](https://storage.googleapis.com/bit_models/distill/R50x1_224.npz) | 224 | 82.8 |
| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz) | 224 | 83.0 |
| BiT-R152x2 | [link](https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz) | 384 | 84.3 |
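
As a quick sanity check, these checkpoints can be downloaded and inspected with
plain NumPy; a minimal sketch (the URL is copied from the table above, everything
else is an assumption about your local setup):

```
import urllib.request
import numpy as np

# Download one of the distilled students listed above.
url = "https://storage.googleapis.com/bit_models/distill/R50x1_224.npz"
urllib.request.urlretrieve(url, "R50x1_224.npz")

# The .npz file is expected to be a flat mapping from parameter names to arrays.
params = np.load("R50x1_224.npz")
print(len(params.files), "arrays")
for name in list(params.files)[:5]:
  print(name, params[name].shape)
```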

### Flowers/Pet/Food/Sun

The files [bigsweep_flowers_pet.py](bigsweep_flowers_pet.py) and
[bigsweep_food_sun.py](bigsweep_food_sun.py) can be used to reproduce the
distillation runs on these datasets shown in Figures 3, 4, 9-12, and Table 4.

While our open-source release does not currently support doing hyper-parameter
sweeps, we still provide an example of the sweeps at the end of the configs
for reference.

### Teacher models
Links to all teacher models we used can be found in [common.py](common.py).
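
For example, the Flowers/Pet config added in this commit selects its teacher
checkpoint via `cd.inits[f'BiT-M R152x2 {arg.data} rc128']`, where `cd` is
`big_vision.configs.proj.distill.common` (see `bigsweep_flowers_pet.py` below).
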
159 changes: 159 additions & 0 deletions big_vision/configs/proj/distill/bigsweep_flowers_pet.py
@@ -0,0 +1,159 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""Distilling BiT-R152x2 into BiT-R50x1 on Flowers/Pet as in https://arxiv.org/abs/2106.05237

While many epochs are required, this is a small dataset, and thus overall it
is still fast and possible to run on the relatively small v3-8 TPUs (or GPUs).

This configuration contains the recommended settings from Fig 3/Tab 4 of the
paper, which can be selected via the fast/medium/long config argument.
(best settings were selected on a 10% minival)

For Flowers:
- The `fast` variant takes ~1h10m on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_fast_06-18_2008/big_vision_metrics.txt
- The `long` variant takes ~25h on a v3-32 TPU.
  Example logs at gs://big_vision/distill/bit_flowers_long_06-19_0524/big_vision_metrics.txt
For Pet:
- The `fast` variant takes ~28min on a v2-8 TPU.
  Example logs at gs://big_vision/distill/bit_pet_fast_06-16_2338/big_vision_metrics.txt
- The `long` variant takes ~11h on a v2-8 and ~8h on a v3-32.
  Example logs at gs://big_vision/distill/bit_pet_long_06-17_0050/big_vision_metrics.txt

Example invocation:

big_vision.trainers.proj.distill.distill \
    --config big_vision/configs/proj/distill/bigsweep_flowers_pet.py:data=flowers,variant=fast \
    --workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'` \
"""

import big_vision.configs.common as bvcc
import big_vision.configs.proj.distill.common as cd
import ml_collections as mlc

NCLS = dict(flowers=102, pet=37)


def get_config(arg=None):
  """Config for massive hypothesis-test on pet."""
  arg = bvcc.parse_arg(arg, runlocal=False, data='flowers', variant='medium', crop='inception_crop(128)')
  config = mlc.ConfigDict()

  config.dataset = dict(flowers='oxford_flowers102', pet='oxford_iiit_pet')[arg.data]
  config.cache_raw = True
  config.prefetch_to_device = 4
  config.train_split = dict(flowers='train', pet='train[:90%]')[arg.data]
  config.num_classes = NCLS[arg.data]

  config.batch_size = 512
  config.num_epochs = {
      'flowers': {'fast': 10_000, 'medium': 100_000, 'long': 1_000_000},
      'pet': {'fast': 1000, 'medium': 3000, 'long': 30_000},
  }[arg.data][arg.variant]
  config.shuffle_buffer_size = 50_000

  config.log_training_steps = 100
  config.checkpoint_steps = 2500

  # Model section
  config.student_name = 'bit_paper'
  config.student = dict(depth=50, width=1)

  config.teachers = ['prof_m']
  config.prof_m_name = 'bit_paper'
  config.prof_m_init = cd.inits[f'BiT-M R152x2 {arg.data} rc128']
  config.prof_m = dict(depth=152, width=2)

  # Preprocessing pipeline for student & teacher.
  pp_common = (
      '|value_range(-1, 1)'
      f'|onehot({config.num_classes}, key="label", key_result="labels")'
      '|keep("image", "labels")'
  )
  config.pp_train = f'decode|{arg.crop}|flip_lr' + pp_common
  ppv = 'decode|resize_small(160)|central_crop(128)' + pp_common

  config.mixup = dict(p=1.0, n=2)

  # Distillation settings
  config.distance = 'kl'
  config.distance_kw = dict(t={
      'flowers': {'fast': 10., 'medium': 1., 'long': 1.},
      'pet': {'fast': 5., 'medium': 10., 'long': 2.},
  }[arg.data][arg.variant])
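  # Note (illustrative, not part of the original file): with distance='kl', the
  # temperature `t` above is assumed to act as the usual distillation softmax
  # temperature, i.e. roughly KL(softmax(teacher/t) || softmax(student/t)).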

  # Optimizer section
  config.grad_clip_norm = 1.0
  config.optax_name = 'scale_by_adam'
  config.optax = dict(mu_dtype='bfloat16')

  config.lr = {
      'flowers': {'fast': 0.003, 'medium': 0.001, 'long': 0.0003},
      'pet': {'fast': 0.01, 'medium': 0.003, 'long': 0.003},
  }[arg.data][arg.variant]
  config.wd = {
      'flowers': {'fast': 3e-4, 'medium': 1e-4, 'long': 1e-5},
      'pet': {'fast': 1e-3, 'medium': 3e-4, 'long': 1e-5},
  }[arg.data][arg.variant]
  config.schedule = dict(warmup_steps=1500, decay_type='cosine')
  config.optim_name = 'adam_hp'

  # Eval section
  minitrain_split = 'train[:512]' if not arg.runlocal else 'train[:16]'
  if arg.data == 'flowers':
    val_split = 'validation' if not arg.runlocal else 'validation[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'
  elif arg.data == 'pet':
    val_split = 'train[90%:]' if not arg.runlocal else 'train[:16]'
    test_split = 'test' if not arg.runlocal else 'test[:16]'

  base = dict(
      type='classification',
      pred='student_fwd',
      dataset=config.dataset,
      pp_fn=ppv,
      loss_name='softmax_xent',
      log_steps=500,
  )
  config.evals = {}
  config.evals.student_train = {**base, 'split': minitrain_split}
  config.evals.student_val = {**base, 'split': val_split}
  config.evals.student_test = {**base, 'split': test_split}

  # Teacher is fixed, so rare evals.
  teacher = dict(log_steps=100_000, pred='prof_m_fwd')
  config.evals.teacher_train = {**config.evals.student_train, **teacher}
  config.evals.teacher_val = {**config.evals.student_val, **teacher}
  config.evals.teacher_test = {**config.evals.student_test, **teacher}

  # Could in principle also look at agreement on other datasets!
  dist = dict(
      type='proj.distill.distance',
      pred='student_prof_m_fwd',
      dataset=config.dataset,
      pp_fn=ppv + '|keep("image")',
      log_steps=1000,
      distances=({'kind': 'kl'}, {'kind': 'euclidean'},
                 {'kind': 'agree', 'k': 1}, {'kind': 'agree', 'k': 5}),
  )
  config.evals.dist_train = {**dist, 'split': minitrain_split}
  config.evals.dist_val = {**dist, 'split': val_split}
  config.evals.dist_test = {**dist, 'split': test_split}

  # Make a few things much smaller for quick local debugging testruns.
  if arg.runlocal:
    config.shuffle_buffer_size = 10
    config.batch_size = 8

  return config