-
Notifications
You must be signed in to change notification settings - Fork 128
/
vit_i1k_gsam_no_aug.py
134 lines (115 loc) · 4.46 KB
/
vit_i1k_gsam_no_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=line-too-long
r"""Pre-training ViT on ILSVRC-2012 with GSAM in https://arxiv.org/abs/2203.08065
Run training of a B/32 model:
big_vision.trainers.proj.gsam.train \
--config big_vision/configs/proj/gsam/vit_i1k_gsam_no_aug.py \
--workdir gs://[your_bucket]/big_vision/`date '+%m-%d_%H%M'`
"""
import big_vision.configs.common as bvcc
from big_vision.configs.common_fewshot import get_fewshot_lsr
import ml_collections as mlc
def get_config(arg=None):
"""Config for training."""
arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False)
config = mlc.ConfigDict()
config.dataset = 'imagenet2012'
config.train_split = 'train[:99%]'
config.cache_raw = not arg.runlocal # Needs up to 120GB of RAM!
config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
config.num_classes = 1000
config.loss = 'sigmoid_xent'
config.batch_size = 4096
config.num_epochs = 300
pp_common = (
'|value_range(-1, 1)'
'|onehot(1000, key="{lbl}", key_result="labels")'
'|keep("image", "labels")'
)
config.pp_train = (
'decode_jpeg_and_inception_crop(224)|flip_lr|' +
pp_common.format(lbl='label')
)
pp = 'decode|resize_small(256)|central_crop(224)' + pp_common
# Aggressive pre-fetching because our models here are small, so we not only
# can afford it, but we also need it for the smallest models to not be
# bottle-necked by the input pipeline. Play around with it for -L models tho.
config.prefetch_to_host = 8
config.prefetch_to_device = 4
config.log_training_steps = 50
config.checkpoint_steps = 1000
# Model section
config.model_name = 'vit'
config.model = dict(
variant=arg.variant,
rep_size=False,
pool_type='gap',
)
config.init_head_bias = -10.0
# Optimizer section
config.grad_clip_norm = 1.0
config.optax_name = 'scale_by_adam'
config.optax = dict(mu_dtype='float32')
# The modified AdaFactor we introduced in https://arxiv.org/abs/2106.04560
# almost always behaves exactly like adam, but at a fraction of the memory
# cost (specifically, adam_bf16 = +1.5M, adafactor = +0.5M), hence it is a
# good idea to try it when you are memory-bound!
# config.optax_name = 'big_vision.scale_by_adafactor'
# A good flag to play with when hitting instabilities, is the following:
# config.optax = dict(beta2_cap=0.95)
config.lr = 0.003
config.wd = 0.001 # default is 0.0001; paper used 0.3, effective wd=0.3*lr
config.schedule = dict(
warmup_steps=10_000,
decay_type='linear',
linear_end=0.01,
)
# GSAM settings.
# Note: when rho_max=rho_min and alpha=0, GSAM reduces to SAM.
config.gsam = dict(
rho_max=0.6,
rho_min=0.1,
alpha=0.6,
lr_max=config.get_ref('lr'),
lr_min=config.schedule.get_ref('linear_end') * config.get_ref('lr'),
)
# Eval section
eval_common = dict(
type='classification',
dataset='imagenet2012',
pp_fn=pp.format(lbl='label'),
loss_name=config.loss,
log_steps=2500, # Very fast O(seconds) so it's fine to run it often.
)
config.evals = {}
config.evals.train = {**eval_common, 'split': 'train[:2%]'}
config.evals.minival = {**eval_common, 'split': 'train[99%:]'}
config.evals.val = {**eval_common, 'split': 'validation'}
config.evals.v2 = {**eval_common, 'dataset': 'imagenet_v2', 'split': 'test'}
config.evals.real = {**eval_common}
config.evals.real.dataset = 'imagenet2012_real'
config.evals.real.split = 'validation'
config.evals.real.pp_fn = pp.format(lbl='real_label')
config.fewshot = get_fewshot_lsr(runlocal=arg.runlocal)
config.fewshot.log_steps = 10_000
# Make a few things much smaller for quick local debugging testruns.
if arg.runlocal:
config.shuffle_buffer_size = 10
config.batch_size = 8
config.minival.split = 'train[:16]'
config.val.split = 'validation[:16]'
config.real.split = 'validation[:16]'
config.v2.split = 'test[:16]'
return config