-
Notifications
You must be signed in to change notification settings - Fork 0
/
boss.py
230 lines (201 loc) · 10.7 KB
/
boss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import math
import sys
import numpy as np
import tensorflow as tf
from absl import app
from absl import flags
from tqdm import trange
from cta_boss.cta_remixmatch import CTAReMixMatch
from libml import data, utils, augment, ctaugment
# Module-level handle on absl's global flag registry. Some flags read here are
# defined in the __main__ block of this file; the others (batch, lr, dataset,
# train_kimg, ...) are presumably defined by the imported libml/cta_boss modules.
FLAGS = flags.FLAGS
class AugmentPoolCTACutOut(augment.AugmentPoolCTA):
    """CTAugment pool whose non-probe policies always end with a Cutout op."""

    @staticmethod
    def numpy_apply_policies(arglist):
        """Augment one probe image or a stack of training views.

        Args:
            arglist: ``(images, cta, probe)`` tuple. ``images`` is either a
                single image (``ndim == 3``, presumably HWC) or a stack of
                views along axis 0; ``cta`` provides ``policy(probe=...)``.

        Returns:
            For a single image (``probe`` must be true): a dict with the
            sampled ``policy``, the augmented ``probe`` image and the
            untouched ``image``. For a stack (``probe`` must be false): a dict
            whose ``image`` is a float32 stack where view 0 is left as-is and
            every later view gets its own freshly sampled policy plus Cutout.
        """
        images, cta, probe = arglist
        if images.ndim != 3:
            assert not probe
            augmented = [images[0]]
            for view in images[1:]:
                # Sample a new policy per view, then append a fixed Cutout op.
                view_policy = cta.policy(probe=False) + [ctaugment.OP('cutout', (1,))]
                augmented.append(ctaugment.apply(view, view_policy))
            return dict(image=np.stack(augmented).astype('f'))
        assert probe
        probe_policy = cta.policy(probe=True)
        probe_image = ctaugment.apply(images, probe_policy)
        return dict(policy=probe_policy, probe=probe_image, image=images)
class Boss(CTAReMixMatch):
    """FixMatch-style trainer with optional class balancing of pseudo-labels.

    Uses ``AugmentPoolCTACutOut`` so every strongly augmented unlabeled view
    also receives a Cutout op. The ``balance`` argument of ``model`` selects
    how the per-sample pseudo-label mask is built (``n_c`` is the number of
    unlabeled samples in the batch whose arg-max pseudo-label is class ``c``):

    * ``<= 0`` -- plain FixMatch: keep samples whose max probability is
      ``>= confidence``.
    * ``1`` -- per-class threshold ``confidence - delT * (1 - n_c / max_c n_c)``,
      so rarer pseudo-label classes get a lower threshold.
    * ``2`` (and any other value > 1 not listed) -- fixed threshold, then
      weight each kept sample by ``1 / n_c`` and rescale so the mask sum is
      unchanged.
    * ``3`` -- like 2, but ``n_c`` counts only samples that passed the threshold.
    * ``4`` -- the per-class threshold of 1 combined with the weighting of 3.
    """
    AUGMENT_POOL_CLASS = AugmentPoolCTACutOut

    def train(self, train_nimg, report_nimg):
        """Train until the global step reaches ``train_nimg``.

        If ``FLAGS.eval_ckpt`` is set, that checkpoint is evaluated instead and
        no training happens.

        Args:
            train_nimg: total training length, compared against ``self.step``
                (presumably counted in images, as in the FixMatch base code).
            report_nimg: length of one reporting "epoch"; also drives the
                progress bar range and the summary-saving interval.
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        train_labeled = self.dataset.train_labeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
        train_labeled = train_labeled.batch(batch).prefetch(16).make_one_shot_iterator().get_next()
        train_unlabeled = self.dataset.train_unlabeled.repeat().shuffle(FLAGS.shuffle).parse().augment()
        train_unlabeled = train_unlabeled.batch(batch * self.params['uratio']).prefetch(16)
        train_unlabeled = train_unlabeled.make_one_shot_iterator().get_next()
        scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
                                                          pad_step_number=10))

        # A short-lived plain session is used only to cache evaluation data.
        with tf.Session(config=utils.get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=utils.get_config(),
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            self.session = train_session._tf_sess()
            gen_labeled = self.gen_labeled_fn(train_labeled)
            gen_unlabeled = self.gen_unlabeled_fn(train_unlabeled)
            self.tmp.step = self.session.run(self.step)
            while self.tmp.step < train_nimg:
                # Resume the progress bar mid-epoch when restarting from a checkpoint.
                loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                              leave=False, unit='img', unit_scale=batch,
                              desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, gen_labeled, gen_unlabeled)
                    while self.tmp.print_queue:
                        loop.write(self.tmp.print_queue.pop(0))
            while self.tmp.print_queue:
                print(self.tmp.print_queue.pop(0))

    def model(self, batch, lr, wd, wu, mom, delT, confidence, balance, uratio, ema=0.999, **kwargs):
        """Build the training and inference graph.

        Args:
            batch: labeled batch size; the unlabeled batch is ``batch * uratio``.
            lr: base learning rate, cosine-decayed over ``FLAGS.train_kimg``.
            wd: L2 coefficient applied to variables whose name contains ``kernel``.
            wu: weight of the unlabeled (pseudo-label) loss.
            mom: Nesterov momentum coefficient.
            delT: maximum threshold reduction used by ``balance`` modes 1 and 4.
            confidence: base pseudo-label confidence threshold.
            balance: pseudo-label balancing mode (see the class docstring).
            uratio: number of unlabeled samples per labeled sample.
            ema: decay of the exponential moving average of model variables.
            **kwargs: forwarded to ``self.classifier``.

        Returns:
            ``utils.EasyDict`` with the input placeholders (``xt``, ``x``,
            ``y``, ``label``), ``train_op``, and softmax outputs computed
            without (``classify_raw``) and with (``classify_op``) EMA weights.
        """
        hwc = [self.dataset.height, self.dataset.width, self.dataset.colors]
        xt_in = tf.placeholder(tf.float32, [batch] + hwc, 'xt')  # Training labeled
        x_in = tf.placeholder(tf.float32, [None] + hwc, 'x')  # Eval images
        y_in = tf.placeholder(tf.float32, [batch * uratio, 2] + hwc, 'y')  # Training unlabeled (weak, strong)
        l_in = tf.placeholder(tf.int32, [batch], 'labels')  # Labels

        # Cosine decay: lr shrinks to cos(7*pi/16) ~= 0.2x its base value at the end of training.
        lrate = tf.clip_by_value(tf.to_float(self.step) / (FLAGS.train_kimg << 10), 0, 1)
        lr *= tf.cos(lrate * (7 * np.pi) / (2 * 8))
        tf.summary.scalar('monitors/lr', lr)

        # Compute logits for xt_in and y_in in one interleaved forward pass.
        classifier = lambda x, **kw: self.classifier(x, **kw, **kwargs).logits
        # Update ops created by the training forward pass (presumably batch-norm
        # statistics) are collected so they run after the optimizer step.
        skip_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        x = utils.interleave(tf.concat([xt_in, y_in[:, 0], y_in[:, 1]], 0), 2 * uratio + 1)
        logits = utils.para_cat(lambda x: classifier(x, training=True), x)
        logits = utils.de_interleave(logits, 2 * uratio + 1)
        post_ops = [v for v in tf.get_collection(tf.GraphKeys.UPDATE_OPS) if v not in skip_ops]
        logits_x = logits[:batch]
        logits_weak, logits_strong = tf.split(logits[batch:], 2)
        del logits, skip_ops

        # Labeled cross-entropy
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=l_in, logits=logits_x)
        loss_xe = tf.reduce_mean(loss_xe)
        tf.summary.scalar('losses/xe', loss_xe)

        # Pseudo-labels come from the weak view and supervise the strong view.
        pseudo_labels = tf.stop_gradient(tf.nn.softmax(logits_weak))
        pseudo_confidence = tf.reduce_max(pseudo_labels, axis=1)
        p_labels = tf.cast(tf.argmax(pseudo_labels, axis=1), tf.int32)
        loss_xeu = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=p_labels, logits=logits_strong)

        nclass = self.dataset.nclass
        if balance > 0:
            class_count = self._class_counts(p_labels, nclass)
            if balance == 1 or balance == 4:
                # Lower the threshold by up to delT for classes that are rare in this batch.
                rarity = 1.0 - tf.math.divide_no_nan(class_count, tf.reduce_max(class_count))
                thresholds = confidence - delT * rarity
                threshold = tf.gather(thresholds, p_labels)
            else:
                threshold = confidence
            keep = pseudo_confidence >= threshold
            if balance == 3 or balance == 4:
                # Count only the pseudo-labels that survived the threshold.
                class_count = self._class_counts(tf.boolean_mask(p_labels, keep), nclass)
            pseudo_mask = tf.cast(keep, tf.float32)
            if balance > 1:
                # Weight kept samples by 1/n_c, then rescale so that
                # sum(pseudo_mask) equals the number of kept samples.
                inv_count = tf.math.divide_no_nan(tf.ones_like(class_count), class_count)
                weights = tf.gather(inv_count, p_labels)
                kept_total = tf.reduce_sum(pseudo_mask)
                pseudo_mask = pseudo_mask * weights
                pseudo_mask = tf.math.divide_no_nan(kept_total * pseudo_mask, tf.reduce_sum(pseudo_mask))
        else:
            pseudo_mask = tf.cast(pseudo_confidence >= confidence, tf.float32)

        tf.summary.scalar('monitors/mask', tf.reduce_mean(pseudo_mask))
        loss_xeu = tf.reduce_mean(loss_xeu * pseudo_mask)
        tf.summary.scalar('losses/xeu', loss_xeu)

        # L2 regularization
        loss_wd = sum(tf.nn.l2_loss(v) for v in utils.model_vars('classify') if 'kernel' in v.name)
        tf.summary.scalar('losses/wd', loss_wd)

        ema = tf.train.ExponentialMovingAverage(decay=ema)
        ema_op = ema.apply(utils.model_vars())
        ema_getter = functools.partial(utils.getter_ema, ema)
        post_ops.append(ema_op)

        train_op = tf.train.MomentumOptimizer(lr, mom, use_nesterov=True).minimize(
            loss_xe + wu * loss_xeu + wd * loss_wd, colocate_gradients_with_ops=True)
        # Run the collected update ops and the EMA update only after the gradient step.
        with tf.control_dependencies([train_op]):
            train_op = tf.group(*post_ops)

        return utils.EasyDict(
            xt=xt_in, x=x_in, y=y_in, label=l_in, train_op=train_op,
            classify_raw=tf.nn.softmax(classifier(x_in, training=False)),  # No EMA, for debugging.
            classify_op=tf.nn.softmax(classifier(x_in, getter=ema_getter, training=False)))

    @staticmethod
    def _class_counts(labels, nclass):
        """Return a float32 ``[nclass]`` histogram of the integer class ids in ``labels``."""
        classes, _, counts = tf.unique_with_counts(labels)
        indices = tf.reshape(tf.cast(classes, tf.int32), [-1, 1])
        histogram = tf.scatter_nd(indices, counts, tf.constant([nclass]))
        return tf.cast(histogram, tf.float32)
def main(argv):
    """absl entry point: build the dataset and a ``Boss`` model from flags, then train."""
    utils.setup_main()
    del argv  # Unused.
    dataset = data.PAIR_DATASETS()[FLAGS.dataset]()
    log_width = utils.ilog2(dataset.width)
    model = Boss(
        # The run directory is keyed on the CTAugment configuration; this file
        # never imports FixMatch, so the name must come from Boss itself.
        os.path.join(FLAGS.train_dir, dataset.name, Boss.cta_name()),
        dataset,
        lr=FLAGS.lr,
        wd=FLAGS.wd,
        arch=FLAGS.arch,
        batch=FLAGS.batch,
        nclass=dataset.nclass,
        wu=FLAGS.wu,
        mom=FLAGS.mom,
        delT=FLAGS.delT,
        confidence=FLAGS.confidence,
        balance=FLAGS.balance,
        uratio=FLAGS.uratio,
        # Default depth: two fewer 2x downscalings than log2(image width).
        scales=FLAGS.scales or (log_width - 2),
        filters=FLAGS.filters,
        repeat=FLAGS.repeat)
    model.train(FLAGS.train_kimg << 10, FLAGS.report_kimg << 10)
if __name__ == '__main__':
    utils.setup_tf()
    # Script-specific flags; the remaining flags (batch, lr, dataset, ...) are
    # defined elsewhere and only have their defaults overridden below.
    flags.DEFINE_float('confidence', 0.95, 'Confidence threshold.')
    flags.DEFINE_float('wd', 0.0005, 'Weight decay.')
    flags.DEFINE_float('wu', 1, 'Pseudo label loss weight.')
    flags.DEFINE_float('mom', 0.9, 'Momentum coefficient.')
    flags.DEFINE_float('delT', 0.2, 'Maximum amount by which balance=1 or balance=4 lowers the confidence '
                                    'threshold for pseudo-label classes that are rare in the batch.')
    flags.DEFINE_integer('filters', 32, 'Filter size of convolutions.')
    flags.DEFINE_integer('repeat', 4, 'Number of residual layers per stage.')
    flags.DEFINE_integer('scales', 0, 'Number of 2x2 downscalings in the classifier.')
    flags.DEFINE_integer('uratio', 7, 'Unlabeled batch size ratio.')
    flags.DEFINE_integer('balance', 0, 'Pseudo-label class-balancing method: 0=none (FixMatch), '
                                       '1=per-class confidence threshold, 2=reweight by inverse class count, '
                                       '3=reweight by inverse count of confident pseudo-labels, '
                                       '4=per-class threshold plus reweighting as in 3.')
    FLAGS.set_default('augment', 'd.d.d')
    FLAGS.set_default('dataset', 'cifar10.3@250-1')
    FLAGS.set_default('batch', 64)
    FLAGS.set_default('lr', 0.03)
    FLAGS.set_default('train_kimg', 1 << 16)
    app.run(main)