
Commit

Add UViM project (+misc. changes)
Co-authored-by: André Susano Pinto <andresp@google.com>
akolesnikoff and andresusanopinto committed Jul 26, 2022
1 parent 6ff6d08 commit 21bd6eb
Showing 51 changed files with 15,447 additions and 91 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -45,13 +45,14 @@ codebase:
Resources: [config](big_vision/configs/vit_s16_i1k.py)
- [UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes](https://arxiv.org/abs/2205.10337), by
  Alexander Kolesnikov*, André Susano Pinto*, Lucas Beyer*, Xiaohua Zhai*, Jeremiah Harmsen*, Neil Houlsby*\
  Resources: [readme](big_vision/configs/proj/uvim/README.md), [configs](big_vision/configs/proj/uvim), [colabs](big_vision/configs/proj/uvim).

### Multimodal research

- [LiT: Zero-Shot Transfer with Locked-image Text Tuning](https://arxiv.org/abs/2111.07991), by
Xiaohua Zhai*, Xiao Wang*, Basil Mustafa*, Andreas Steiner*, Daniel Keysers,
Alexander Kolesnikov, and Lucas Beyer*\
Resources: [trainer](trainers/proj/image_text/contrastive.py), [config](configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).
Resources: [trainer](big_vision/trainers/proj/image_text/contrastive.py), [config](big_vision/configs/proj/image_text/lit_coco.py), [colab](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/lit.ipynb).

### Knowledge distillation

@@ -114,12 +115,12 @@ We have since added the following key features and projects:
- Patient and consistent distillation.
- Scaling ViT.
- MLP-Mixer.
- UViM.

Features and projects we plan to release in the near future, in no particular
order:
- ImageNet-21k in TFDS.
- Loading misc public models used in our publications (NFNet, MoCov3, DINO).
- UViM.
- Memory-efficient Polyak-averaging implementation.
- Advanced JAX compute and memory profiling. We are using internal tools for
this, but may eventually add support for the publicly available ones.
2 changes: 1 addition & 1 deletion big_vision/configs/bit_i1k.py
@@ -40,7 +40,7 @@ def get_config(runlocal=False):

config.seed = 0
config.batch_size = 4096 if not runlocal else 32
config.num_epochs = 90
config.total_epochs = 90

pp_common = '|onehot(1000, key="{lbl}", key_result="labels")'
pp_common += '|value_range(-1, 1)|keep("image", "labels")'
14 changes: 9 additions & 5 deletions big_vision/configs/bit_i21k.py
@@ -36,11 +36,14 @@ def get_config():

config.trial = 0
config.batch_size = 4096
config.num_epochs = 90
config.total_epochs = 90

pp_common = f'|value_range(-1, 1)|onehot({config.num_classes})'
config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'
pp_common_i21k = pp_common.format(onehot_args=f'{config.num_classes}')
pp_common_i1k = pp_common.format(onehot_args='1000, key="label", key_result="labels"')
config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common_i21k
pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common_i21k
pp_eval_i1k = 'decode|resize_small(256)|central_crop(224)' + pp_common_i1k
config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.

config.log_training_steps = 50
@@ -63,7 +66,6 @@ def get_config():
eval_common = dict(
type='classification',
dataset=config.dataset,
data_dir=config.dataset_dir,
pp_fn=pp_eval,
loss_name=config.loss,
log_steps=1000, # Very fast O(seconds) so it's fine to run it often.
@@ -72,6 +74,8 @@
config.evals.test = {**eval_common, 'split': 'full[:25_600]'}
config.evals.val = {**eval_common, 'split': 'full[25_600:51_200]'}
config.evals.train = {**eval_common, 'split': 'full[51_200:76_800]'}

# Few-shot evaluators
config.evals.fewshot = get_fewshot_lsr()
config.evals.fewshot.log_steps = 25_000
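
For reference, a quick sketch of what the two formatted preprocessing strings
above expand to. This is plain Python string substitution; the 21843 class
count is the usual ImageNet-21k label count and is assumed here for
illustration, since the config reads it from `config.num_classes`.

```
pp_common = '|value_range(-1, 1)|onehot({onehot_args})|keep("image", "labels")'

# ImageNet-21k variant (pp_common_i21k above), assuming config.num_classes == 21843:
print(pp_common.format(onehot_args='21843'))
# |value_range(-1, 1)|onehot(21843)|keep("image", "labels")

# ImageNet-1k variant (pp_common_i1k above):
print(pp_common.format(onehot_args='1000, key="label", key_result="labels"'))
# |value_range(-1, 1)|onehot(1000, key="label", key_result="labels")|keep("image", "labels")
```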

5 changes: 5 additions & 0 deletions big_vision/configs/common.py
@@ -117,3 +117,8 @@ def autotype(x):
return float(x) # Returns as float.
except ValueError:
return x # Returns as str.


def pack_arg(**kw):
"""Packs key-word args as a string to be parsed by `parse_arg()`."""
return ','.join([f'{k}={v}' for k, v in kw.items()])
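
As a quick illustration of the new helper, here is a minimal usage sketch
(assuming the module is imported as `bvcc`, as in the other configs; the
`parse_arg` round-trip is the intended use per the docstring):

```
import big_vision.configs.common as bvcc

# pack_arg joins keyword arguments into a 'k=v,k=v' string...
arg = bvcc.pack_arg(res=512, singlehost=True)
assert arg == 'res=512,singlehost=True'
# ...which a get_config(arg) function can unpack again via bvcc.parse_arg(arg, ...).
```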
6 changes: 3 additions & 3 deletions big_vision/configs/mlp_mixer_i1k.py
@@ -55,7 +55,7 @@ def get_config(mode=None):
)

config.batch_size = 4096
config.num_epochs = 300
config.total_epochs = 300

config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.

@@ -107,10 +107,10 @@ def get_config(mode=None):
config.fewshot = get_fewshot_lsr()

if mode == 'gpu8':
config.num_epochs = 60
config.total_epochs = 60
config.batch_size = 512
config.cache_raw = False
if mode == 'regression_test':
config.num_epochs = 60
config.total_epochs = 60

return config
84 changes: 84 additions & 0 deletions big_vision/configs/proj/uvim/README.md
@@ -0,0 +1,84 @@
# UViM: A Unified Modeling Approach for Vision with Learned Guiding Codes

*by Alexander Kolesnikov, André Susano Pinto, Lucas Beyer, Xiaohua Zhai, Jeremiah Harmsen, Neil Houlsby*

We provide pretrained UViM models from the [original paper](https://arxiv.org/abs/2205.10337),
as well as instructions on how to reproduce the core paper experiments.

## Pretrained models

The table below contains UViM models (stage I and II) trained for three
different tasks: panoptic segmentation, colorization and depth prediction.

| task | model | dataset | accuracy | download link |
| --------------------- | ------------------- | ------------------------------------------------------------------------ | ------------ | ----------------------------------------------------------------------------------------- |
| Panoptic segmentation | UViM Stage I model | [COCO(2017)](https://cocodataset.org/#home) | 75.8 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageI_params.npz) |
| Panoptic segmentation | UViM Stage II model | [COCO(2017)](https://cocodataset.org/#home) | 43.1 PQ | [link](https://storage.googleapis.com/big_vision/uvim/panoptic_stageII_params.npz) |
| Colorization           | UViM Stage I model  | [ILSVRC-2012](https://www.image-net.org/)                                 | 15.59 FID    | [link](https://storage.googleapis.com/big_vision/uvim/color_stageI_params.npz)             |
| Colorization | UViM Stage II model | [ILSVRC-2012](https://www.image-net.org/) | 16.99 FID | [link](https://storage.googleapis.com/big_vision/uvim/color_stageII_params.npz) |
| Depth | UViM Stage I model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.155 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageI_params.npz) |
| Depth | UViM Stage II model | [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | 0.463 RMSE | [link](https://storage.googleapis.com/big_vision/uvim/depth_stageII_params.npz) |

All of these models can be interactively explored in our [colabs](configs/proj/uvim).
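
Outside the colabs, a quick way to sanity-check a download is to open the
checkpoint with NumPy. A minimal sketch, assuming the panoptic stage I file
from the table above has been fetched to the current directory (the `.npz`
stores named parameter arrays):

```
import numpy as np

# List a few of the parameter arrays stored in the downloaded checkpoint.
ckpt = np.load('panoptic_stageI_params.npz')
print(len(ckpt.files), sorted(ckpt.files)[:5])
```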

## Running on a single-host TPU machine

Below we provide instructions on how to run UViM training (stage I and
stage II) using a single TPU host with 8 TPU accelerators. These instructions
can be easily adapted to a GPU host or a multi-host TPU setup; see the main
`big_vision` [README file](README.md).

We assume that the user has already created the TPU host machine and `ssh`-ed
into it. The next step is to clone the `big_vision` repository:
`git clone https://github.com/google-research/big_vision.git`.

The next steps are to create a Python virtual environment and install the
Python dependencies:
```
virtualenv bv
source bv/bin/activate
cd big_vision/
pip3 install --upgrade pip
pip3 install -r big_vision/requirements.txt
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

After this, invoke the helper tool to download and prepare the data:
`python3 -m big_vision.tools.download_tfds_datasets coco/2017_panoptic nyu_depth_v2`.
To prepare the ImageNet dataset, consult the main codebase README.

> :warning: TPU machines have 100 GB of disk space. This may not be enough to
> store all the training data (though panoptic-only or depth-only data may fit).
> Consider preparing the data on a separate machine and then copying it to the
> TPU machine's extra persistent disk or to a Google Cloud bucket. See the
> instructions for [creating an extra persistent disk](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
> Remember to set the correct data home directory, e.g. `export DISK=/mnt/disk/persist; export TFDS_DATA_DIR=$DISK/tensorflow_datasets`.

Our panoptic evaluator uses the raw variant of the COCO data, so we move it into
a separate folder. Note that `tfds` has already pre-downloaded the panoptic data,
except for one small JSON file that we fetch manually:
```
mkdir $DISK/coco_data
cd $DISK/coco_data
mv $TFDS_DATA_DIR/downloads/extracted/ZIP.image.cocod.org_annot_panop_annot_train<REPLACE_ME_WITH_THE_HASH_CODE>.zip/annotations/* .
wget https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json
export COCO_DATA_DIR=$DISK/coco_data
```

For the FID evaluator, which is used by the colorization model, set the path to
the directory with image id files, e.g.
`export FID_DATA_DIR=<ROOT>/big_vision/evaluators/proj/uvim/coltran_fid_data`.

As an example, stage I panoptic training can be invoked as follows (note the `:singlehost` config parameter, which selects a lightweight configuration suitable for a single host):
```
python3 -m big_vision.trainers.proj.uvim.vqvae --config big_vision/configs/proj/uvim/vqvae_coco_panoptic.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'`
```
or stage II training:
```
python3 -m big_vision.trainers.proj.uvim.train --config big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py:singlehost --workdir workdirs/`date '+%m-%d_%H%M'`
```

## Acknowledgments
The sampling code in the `models/proj/uvim/decode.py` module is based on
contributions from Anselm Levskaya, Ilya Tolstikhin, and Maxim Neumann.

166 changes: 166 additions & 0 deletions big_vision/configs/proj/uvim/train_coco_panoptic_pretrained.py
@@ -0,0 +1,166 @@
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=line-too-long
r"""A config for training a UViM stage II model for the panoptic task.
This config is expected to reproduce the paper's result and achieve
approximately 43.7 PQ points on the COCO holdout data.
We also provide a low-resource variant of this config, which can be enabled
by adding `:singlehost` postfix to the config name. This one is expected to
achieve 39.4 PQ points on the COCO holdout data.
"""

import big_vision.configs.common as bvcc
from ml_collections import ConfigDict

VTT_MODELS = {
'base': dict(num_layers=12, num_heads=12, mlp_dim=3072, emb_dim=768),
'large': dict(num_layers=24, num_heads=16, mlp_dim=4096, emb_dim=1024),
}

VQVAE_MODELS = {
'base': dict(enc_depth=6, dec_depth=12, num_heads=12, mlp_dim=3072, width=768),
}

RES = 512
PATCH_SIZE = 16
LABEL_RES = 512
LABEL_PATCH_SIZE = 16


def get_config(arg=''):
"""Config for training."""
arg = bvcc.parse_arg(arg, runlocal=False, singlehost=False)
config = ConfigDict()

config.pp_train = (
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
f'randu("fliplr")|det_fliplr(key="image")|det_fliplr(key="labels")|'
f'inception_box|crop_box(key="image")|crop_box(key="labels")|'
f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
f'value_range(-1, 1, key="image_ctx")|'
f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
)
pp_eval = (
f'decode|coco_panoptic|concat(["semantics","instances"], "labels")|'
f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|'
f'resize({RES})|resize({LABEL_RES},key="labels",method="nearest")|'
f'value_range(-1, 1, key="image_ctx")|'
f'value_range(-1, 1)|make_canonical|keep("image","image_ctx","labels")'
)
pp_predict = (
f'resize({LABEL_RES}, inkey="image", outkey="image_ctx")|resize({RES})|'
f'value_range(-1, 1, key="image_ctx")|value_range(-1, 1)|'
f'keep("image","image_ctx","image/id")' # image/id used for rng seeds.
)

config.dataset = 'coco/2017_panoptic'
config.train_split = 'train[4096:]'

config.batch_size = 512
config.total_epochs = 200

config.log_training_steps = 50
config.shuffle_buffer_size = 50_000
config.ckpt_steps = 1000
config.keep_ckpt_steps = 5000
config.ckpt_timeout = 1
config.prefetch_to_device = 2
config.trial = 0

# Optimizer section
config.optax_name = 'big_vision.scale_by_adafactor'
config.optax = dict(beta2_cap=0.95)

config.lr = 0.001
config.wd = 0.000001
config.lr_mults = [
('pos_embedding_encoder.*', 0.1),
('EmbedPatches.*', 0.1),
('encoder.*', 0.1),
('decoder.*', 1.0)
]
config.schedule = dict(decay_type='cosine', warmup_steps=4_000)

# Oracle section
config.oracle = ConfigDict()
config.oracle.task = 'proj.uvim.panoptic_task'
config.oracle.model_init = 'gs://big_vision/uvim/panoptic_stageI_params.npz'
config.oracle.model_name = 'proj.uvim.vit'
config.oracle.model = ConfigDict(VQVAE_MODELS['base'])
config.oracle.model.input_size = (LABEL_RES, LABEL_RES)
config.oracle.model.patch_size = (LABEL_PATCH_SIZE, LABEL_PATCH_SIZE)
config.oracle.model.code_len = 256
config.oracle.model.dict_size = 4096
config.oracle.model.codeword_dim = 768
config.oracle.model.with_encoder_ctx = True
config.oracle.model.with_decoder_ctx = True
config.oracle.model.code_dropout = 'random'
config.oracle.model.bottleneck_resize = True
config.oracle.model.inputs = {
'semantics': (133 + 1, LABEL_PATCH_SIZE**2), # +1 for void label
'instances': (100, LABEL_PATCH_SIZE**2), # COCO: actually 98 train/78 validation.
}
config.oracle.model.outputs = config.oracle.model.inputs

# Model section
config.model_name = 'proj.uvim.vtt'
# config.model_init = {'encoder': 'howto-i21k-B/8'}
config.model_init = {'encoder': 'howto-i21k-L/16'}
config.model = ConfigDict(VTT_MODELS['large'])
config.model.patches = ConfigDict({'size': (PATCH_SIZE, PATCH_SIZE)})
config.model.vocab_size = config.oracle.model.get_ref('dict_size') + 1
config.model.posemb_type = 'learn'
config.model.input_size = (RES, RES)
config.model.seq_len = config.oracle.model.get_ref('code_len')

# Evaluation section
config.evals = {}
config.evals.val = ConfigDict()
config.evals.val.type = 'proj.uvim.compute_mean'
config.evals.val.pred = 'validation'
config.evals.val.dataset = config.dataset
config.evals.val.split = 'train[:4096]'
config.evals.val.pp_fn = pp_eval
config.evals.val.log_steps = 1000

base = {
'type': 'proj.uvim.coco_panoptic',
'pp_fn': pp_predict,
'log_steps': 10_000,
# Filters objects that occupy less than 0.03^2 fraction of all pixels.
# 'predict_kwargs': {'min_fraction': 0.03 ** 2},
}
config.evals.coco_panoptic_train = dict(**base, split='train[4096:8192]')
config.evals.coco_panoptic_holdout = dict(**base, split='train[:4096]')
config.evals.coco_panoptic = dict(**base, split='validation')

# config.evals.save_pred = dict(type='proj.uvim.save_predictions')
# config.evals.save_pred.pp = pp_eval.replace('decode|', '')
# config.evals.save_pred.log_steps = 100_000
# config.evals.save_pred.dataset = config.dataset
# config.evals.save_pred.split = 'validation[:1024]'
# config.evals.save_pred.outfile = 'inference.npz'

if arg.singlehost:
config.batch_size = 32
config.total_epochs = 50
elif arg.runlocal:
config.batch_size = 4
config.shuffle_buffer_size = 10
config.evals.val.split = 'train[:16]'
return config
