This repository has been archived by the owner on May 28, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 163
/
train.py
76 lines (65 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl import flags
from tqdm import trange
from libml import utils
from libml.train import ClassifySemi
# Handle to absl's global flag registry. The class below reads FLAGS.batch,
# FLAGS.eval_ckpt, FLAGS.keep_ckpt and FLAGS.save_kimg when its methods run.
FLAGS = flags.FLAGS
class ClassifyFullySupervised(ClassifySemi):
    """Fully supervised classification.

    Trains only on the labeled split (``self.dataset.train_labeled``); no
    unlabeled data is fetched or fed. Graph ops (``self.ops``), ``self.step``,
    ``self.tmp``, ``cache_eval`` and ``eval_checkpoint`` come from
    ``ClassifySemi``.
    """

    def _labeled_batch_op(self, batch):
        """Return a tensor that yields the next labeled batch when evaluated.

        Args:
            batch: number of examples per batch.

        Returns:
            The ``get_next()`` op of a one-shot iterator over
            ``self.dataset.train_labeled``, batched by ``batch`` with 16 batches
            prefetched. Callers index the evaluated result with ``'image'`` and
            ``'label'``.
        """
        data = self.dataset.train_labeled.batch(batch).prefetch(16)
        return data.make_one_shot_iterator().get_next()

    def train_step(self, train_session, data_labeled):
        """Run one optimization step on a single labeled batch.

        Args:
            train_session: session used to run the training ops (``train``
                passes its ``MonitoredTrainingSession``).
            data_labeled: tensor yielding a batch dict with ``'image'`` and
                ``'label'`` entries.

        Side effects:
            Sets ``self.tmp.step`` to the value fetched for
            ``self.ops.update_step``.
        """
        # Materialize the batch first, then feed it into the graph placeholders.
        x = self.session.run(data_labeled)
        fetches = [self.ops.train_op, self.ops.update_step]
        feed = {self.ops.x: x['image'], self.ops.label: x['label']}
        # Index 1 is update_step's value (presumably the updated step counter).
        self.tmp.step = train_session.run(fetches, feed_dict=feed)[1]

    def train(self, train_nimg, report_nimg):
        """Train on labeled data until ``self.tmp.step`` reaches ``train_nimg``.

        If ``FLAGS.eval_ckpt`` is set, that checkpoint is evaluated via
        ``self.eval_checkpoint`` and the method returns without training.

        Args:
            train_nimg: training stops once ``self.tmp.step`` reaches this value.
            report_nimg: span of one progress-bar "epoch"; summaries are saved
                every ``report_nimg - FLAGS.batch`` steps.
        """
        if FLAGS.eval_ckpt:
            self.eval_checkpoint(FLAGS.eval_ckpt)
            return
        batch = FLAGS.batch
        train_labeled = self._labeled_batch_op(batch)
        scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
                                                          pad_step_number=10))

        # Short-lived plain session used only for self.cache_eval(); it is closed
        # before the monitored training session below is created.
        with tf.Session(config=utils.get_config()) as sess:
            self.session = sess
            self.cache_eval()

        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold,
                checkpoint_dir=self.checkpoint_dir,
                config=utils.get_config(),
                # save_kimg << 10 == save_kimg * 1024 steps between checkpoints.
                save_checkpoint_steps=FLAGS.save_kimg << 10,
                save_summaries_steps=report_nimg - batch) as train_session:
            # Underlying raw session (private TF API); train_step and the
            # inherited helpers read it through self.session.
            self.session = train_session._tf_sess()
            # Read the current step so a run restored from checkpoint_dir resumes
            # where it left off (assumes self.step is the saved step variable).
            self.tmp.step = self.session.run(self.step)
            while self.tmp.step < train_nimg:
                # Start the bar mid-epoch when resuming so it stays aligned to
                # report_nimg boundaries.
                loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
                              leave=False, unit='img', unit_scale=batch,
                              desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg),
                                                    train_nimg // report_nimg))
                for _ in loop:
                    self.train_step(train_session, train_labeled)
                    # Emit queued messages via tqdm so they don't corrupt the bar.
                    while self.tmp.print_queue:
                        loop.write(self.tmp.print_queue.pop(0))
            # Flush anything queued after the last progress bar closed.
            while self.tmp.print_queue:
                print(self.tmp.print_queue.pop(0))

    def tune(self, train_nimg):
        """Run ``self.ops.tune_op`` over ``train_nimg`` labeled images.

        Uses ``self.session`` as-is; it must already be an open session (this
        method does not create one).

        Args:
            train_nimg: number of images to process, in steps of ``FLAGS.batch``.
        """
        batch = FLAGS.batch
        train_labeled = self._labeled_batch_op(batch)

        for _ in trange(0, train_nimg, batch, leave=False, unit='img', unit_scale=batch, desc='Tuning'):
            # Fetch the tensor itself, not [tensor]: Session.run mirrors the fetch
            # structure, and x must be the batch dict to be indexed by key below.
            x = self.session.run(train_labeled)
            self.session.run([self.ops.tune_op], feed_dict={self.ops.x: x['image'],
                                                            self.ops.label: x['label']})