-
Notifications
You must be signed in to change notification settings - Fork 174
/
train.py
153 lines (124 loc) · 5.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# coding: utf-8
from functools import partial
import time
import os
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import bit_common
import bit_hyperrule
import bit_tf2.models as models
import input_pipeline_tf2_or_jax as input_pipeline
def reshape_for_keras(features, batch_size, crop_size):
    """Turn one input-pipeline batch into the ``(images, labels)`` pair Keras fits on.

    The ``"image"`` and ``"label"`` entries of ``features`` are reshaped (and
    written back into ``features``) before being returned as a tuple.

    Args:
      features: dict holding at least ``"image"`` and ``"label"`` tensors.
      batch_size: leading dimension given to both outputs.
      crop_size: side length of the square, 3-channel output images.

    Returns:
      ``(images, labels)`` where images has shape
      ``(batch_size, crop_size, crop_size, 3)`` and labels has shape
      ``(batch_size, k)`` with ``k`` inferred by ``tf.reshape``.
    """
    image_shape = (batch_size, crop_size, crop_size, 3)
    label_shape = (batch_size, -1)
    features["image"] = tf.reshape(features["image"], image_shape)
    features["label"] = tf.reshape(features["label"], label_shape)
    return features["image"], features["label"]
class BiTLRSched(tf.keras.callbacks.Callback):
    """Keras callback that drives the optimizer's learning rate per step.

    Before every training batch it asks ``bit_hyperrule.get_lr`` for the rate
    at the current global step and writes it into the model's optimizer.

    Args:
      base_lr: base learning rate passed to ``bit_hyperrule.get_lr``.
      num_samples: training-set size, passed to ``bit_hyperrule.get_lr``.
    """

    def __init__(self, base_lr, num_samples):
        # Initialise the Keras Callback base class so its internal state is
        # set up the way Keras expects before the callback is used.
        super().__init__()
        # Global step counter. Unlike the `batch` argument Keras passes to
        # on_train_batch_begin (an index within the current epoch), this is
        # never reset between epochs.
        self.step = 0
        self.base_lr = base_lr
        self.num_samples = num_samples

    def on_train_batch_begin(self, batch, logs=None):
        lr = bit_hyperrule.get_lr(self.step, self.num_samples, self.base_lr)
        # NOTE(review): if get_lr can return None (e.g. past the end of the
        # schedule), set_value would fail here -- the caller is expected to
        # keep the total number of steps within the schedule.
        # `learning_rate` is the canonical attribute; `lr` is a legacy alias.
        tf.keras.backend.set_value(self.model.optimizer.learning_rate, lr)
        self.step += 1
def _ensure_pretrained_weights(args, logger):
    """Return the local path of the pretrained BiT checkpoint, downloading it if absent.

    The download is written to a temporary sibling file and only renamed to
    the final name once the copy has completed, so an interrupted download
    does not leave a truncated ``.h5`` that later runs would treat as valid.
    """
    tf.io.gfile.makedirs(args.bit_pretrained_dir)
    bit_model_file = os.path.join(args.bit_pretrained_dir, f'{args.model}.h5')
    if not tf.io.gfile.exists(bit_model_file):
        model_url = models.KNOWN_MODELS[args.model]
        logger.info(f'Downloading the model from {model_url}...')
        partial_file = bit_model_file + '.partial'
        tf.io.gfile.copy(model_url, partial_file, overwrite=True)
        tf.io.gfile.rename(partial_file, bit_model_file, overwrite=True)
    return bit_model_file


def _build_model(args, num_classes, bit_model_file, logger):
    """Build the BiT ResNetV2, load pretrained weights and attach a fresh head.

    Called from inside the distribution strategy's scope in ``main`` so the
    created variables belong to that strategy.
    """
    # Width multiplier derived from the last character of the model name.
    # NOTE(review): assumes every model name ends in a single digit.
    filters_factor = int(args.model[-1]) * 4
    model = models.ResnetV2(
        num_units=models.NUM_UNITS[args.model],
        # Head size of the pretrained checkpoint; the architecture must match
        # the .h5 file for load_weights below. The head is replaced afterwards.
        num_outputs=21843,
        filters_factor=filters_factor,
        name="resnet",
        trainable=True,
        dtype=tf.float32)
    model.build((None, None, None, 3))
    logger.info('Loading weights...')
    model.load_weights(bit_model_file)
    logger.info('Weights loaded into model!')
    # Swap the pretrained head for a zero-initialised classifier sized for the
    # target dataset. NOTE(review): relies on ResnetV2 using its private
    # `_head` attribute in the forward pass.
    model._head = tf.keras.layers.Dense(
        units=num_classes,
        use_bias=True,
        kernel_initializer="zeros",
        trainable=True,
        name="head/dense")
    return model


def main(args):
    """Fine-tune a pretrained BiT model on ``args.dataset`` with Keras.

    Downloads the pretrained checkpoint if needed, builds the train/test
    pipelines, fine-tunes with SGD whose learning rate is set every step by
    ``BiTLRSched``, and logs test accuracy after each evaluation period.
    """
    tf.io.gfile.makedirs(args.logdir)
    logger = bit_common.setup_logger(args)
    logger.info(f'Available devices: {tf.config.list_physical_devices()}')

    bit_model_file = _ensure_pretrained_weights(args, logger)

    # Set up input pipeline
    dataset_info = input_pipeline.get_dataset_info(
        args.dataset, 'train', args.examples_per_class)

    # Distribute training
    strategy = tf.distribute.MirroredStrategy()
    num_devices = strategy.num_replicas_in_sync
    logger.info(f'Number of devices: {num_devices}')

    # `--batch_eval` is declared as the eval batch size; fall back to the
    # training batch size when the given args do not define it.
    eval_batch = getattr(args, 'batch_eval', None) or args.batch

    resize_size, crop_size = bit_hyperrule.get_resolution_from_dataset(
        args.dataset)
    data_train = input_pipeline.get_data(
        dataset=args.dataset, mode='train',
        repeats=None, batch_size=args.batch,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=args.examples_per_class,
        examples_per_class_seed=args.examples_per_class_seed,
        mixup_alpha=bit_hyperrule.get_mixup(dataset_info['num_examples']),
        num_devices=num_devices,
        tfds_manual_dir=args.tfds_manual_dir)
    data_test = input_pipeline.get_data(
        dataset=args.dataset, mode='test',
        repeats=1, batch_size=eval_batch,
        resize_size=resize_size, crop_size=crop_size,
        examples_per_class=1, examples_per_class_seed=0,
        mixup_alpha=None,
        num_devices=num_devices,
        tfds_manual_dir=args.tfds_manual_dir)

    # Reshape each batch into plain (images, labels) pairs for Keras.
    # NOTE(review): presumably this flattens a per-device split the pipeline
    # applies when given `num_devices` -- confirm in input_pipeline.
    data_train = data_train.map(
        partial(reshape_for_keras, batch_size=args.batch, crop_size=crop_size))
    data_test = data_test.map(
        partial(reshape_for_keras, batch_size=eval_batch, crop_size=crop_size))

    with strategy.scope():
        model = _build_model(
            args, dataset_info['num_classes'], bit_model_file, logger)

        lr_supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])
        schedule_length = lr_supports[-1]
        # NOTE: Let's not do that unless verified necessary and we do the same
        # across all three codebases.
        # schedule_length = schedule_length * 512 / args.batch

        # The optimizer's initial learning rate is irrelevant: BiTLRSched
        # overwrites it before every training batch.
        optimizer = tf.keras.optimizers.SGD(momentum=0.9)
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

    logger.info('Fine-tuning the model...')
    # One Keras "epoch" is one evaluation period. Clamp it to the schedule
    # length so at least one epoch runs and training never steps past the
    # end of the learning-rate schedule.
    steps_per_epoch = min(args.eval_every or schedule_length, schedule_length)
    epochs = schedule_length // steps_per_epoch
    skipped_steps = schedule_length - epochs * steps_per_epoch
    if skipped_steps:
        logger.info(
            f'Evaluation period {steps_per_epoch} does not divide the schedule '
            f'length {schedule_length}; the last {skipped_steps} steps will '
            f'not be run.')
    history = model.fit(
        data_train,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        # The test split is only used to monitor performance, not to tune.
        validation_data=data_test,
        callbacks=[BiTLRSched(args.base_lr, dataset_info['num_examples'])],
    )

    # Validation runs at the end of each epoch, i.e. after
    # (epoch + 1) * steps_per_epoch optimizer steps.
    for epoch, accu in enumerate(history.history['val_accuracy']):
        logger.info(
            f'Step: {(epoch + 1) * steps_per_epoch}, '
            f'Test accuracy: {accu:0.3f}')
if __name__ == "__main__":
    # Start from the shared BiT argument parser and add the flags this
    # script additionally reads from `args`.
    parser = bit_common.argparser(models.KNOWN_MODELS.keys())
    parser.add_argument("--tfds_manual_dir", default=None,
                        help="Path to manually downloaded dataset.")
    parser.add_argument("--batch_eval", default=32, type=int,
                        help="Eval batch size.")
    main(parser.parse_args())