/
flow.py
executable file
·604 lines (526 loc) · 26.1 KB
/
flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
#!/usr/bin/env python
# ==============================================================================
# \file flow.py
# \author chenghuige
# \date 2016-08-17 10:48:46.141744
# \Description
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
flags = tf.app.flags
#import gflags as flags
FLAGS = flags.FLAGS
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
import tensorflow.contrib.slim as slim
from tensorflow.contrib import tpu
try:
#import horovod.tensorflow as hvd
from mpi4py import MPI
comm = MPI.COMM_WORLD
except Exception:
pass
import os, sys, traceback
import melt
import gezi
from melt.utils import logging
import glob
import inspect
import traceback
import numpy as np
from tqdm import tqdm
def tf_flow(process_once, model_dir=None, num_steps=None, sess=None):
"""
basic flow for tf records, allow most freedom for usage, if not tfrecords no need for flow
Args:
train_once: function with 2 inputs sess and step
"""
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
if sess is None:
sess = tf.InteractiveSession()
if not model_dir:
sess.run(init_op)
else:
melt.restore(sess, model_dir)
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# try:
# # dataset
# print('tf_flow using datset')
# step = 0
# try:
# while True:
# stop = process_once(sess, step)
# if stop is True:
# print('Early stop running %d stpes'%(step))
# raise tf.errors.OutOfRangeError(None, None,'Early stop running %d stpes'%(step))
# step += 1
# if num_steps and step == num_steps:
# raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
# except tf.errors.OutOfRangeError:
# print('Done training for %d steps.' % (step))
# except Exception:
# # old queue method
# print('tf_flow using queue')
try:
step = 0
#while not coord.should_stop():
while True:
stop = process_once(sess, step)
if stop is True:
print('Early stop running %d stpes'%(step))
raise tf.errors.OutOfRangeError(None, None,'Early stop running %d stpes'%(step))
step += 1
if num_steps and step == num_steps:
raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
except tf.errors.OutOfRangeError:
print('Done training for %d steps.' % (step))
# finally:
# coord.request_stop()
# coord.join(threads)
sess.close()
return step
def _get_model_path(model_dir, save_model=False):
if os.path.exists(model_dir + '.index') and os.path.exists(model_dir + '.meta'):
# input is real model path
return model_dir
if not os.path.exists(model_dir):
if save_model:
gezi.try_mkdir(model_dir)
return None
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt and ckpt.model_checkpoint_path:
#input valid dir and return latest model
return os.path.join(model_dir, os.path.basename(ckpt.model_checkpoint_path))
elif os.path.isdir(model_dir):
#input valid dir but no models
return None
else:
#this might be user specified model like ./model/model-100.ckpt
#the file exists and we NOTICE we do not check if it is valid model file!
return model_dir
def _get_checkpoint_path(checkpoint_path, step=None, num_steps_per_epoch=None, epoch=None):
if not num_steps_per_epoch:
return checkpoint_path
return '%s-%.2f'%(checkpoint_path, epoch or step / float(num_steps_per_epoch))
def tf_train_flow(train_once_fn,
model_dir=None,
log_dir=None,
max_models_keep=1,
save_interval_seconds=600,
save_interval_steps=1000,
num_epochs=None,
num_steps=None,
save_model=True,
save_interval_epochs=None,
freeze_graph=False,
num_steps_per_epoch=0,
restore_from_latest=True,
metric_eval_fn=None,
valid_interval_epochs=0,
inference_fn=None,
inference_interval_epochs=0,
init_fn=None,
restore_fn=None,
restore_include=None,
restore_exclude=None,
save_all_scope=False, #TODO save load from restore scope only but svae all
variables_to_restore=None,
variables_to_save=None, #by default will be the same as variables_to_restore
output_collection_names=None,
output_node_names=None,
learning_rate=None, #not use yet, just use in train_once
learning_rate_patience=None,
learning_rate_decay_factor=None,
write_during_train=True,
model=None,
sess=None):
"""
similary flow as tf_flow, but add model try reload and save
"""
use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
if use_horovod:
if FLAGS.torch:
import horovod.torch as hvd
else:
import horovod.tensorflow as hvd
model_dir_ = model_dir
if use_horovod and hvd.rank() != 0:
model_dir = None
if sess is None:
#TODO melt.get_session is global session but may cause non close at last
sess = melt.get_session()
if FLAGS.use_tpu:
sess.run(tpu.initialize_system())
#logging.info('tf_train_flow start')
#logging.info('max_models_keep:', max_models_keep)
#logging.info('save_interval_seconds:', save_interval_seconds)
if model_dir:
if model:
checkpoint = tf.train.Checkpoint(model=model)
ckpt_dir = model_dir + '/ckpt'
checkpoint_prefix = os.path.join(ckpt_dir, 'ckpt')
#this is usefull for you use another model with another scope, and just load and restore/save initalize your scope vars!
#this is not for finetune but mainly for like using another model as in predict like this introducing graph other model scope and ignore here
# var_list = None if not restore_scope else tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
# #logging.info('-------------var_list', var_list)
# if not variables_to_restore:
# variables_to_restore = var_list
if not variables_to_restore:
variables_to_restore = slim.get_variables_to_restore(include=restore_include, exclude=restore_exclude)
if not variables_to_save:
variables_to_save = variables_to_restore
if save_all_scope:
variables_to_save = None
#if variables_to_restore is None:
logging.info('variables_to_restore from %s' % model_dir)
#load all var in checkpoint try to save all var(might more then original checkpoint) if not specifiy variables_to_save
varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
#logging.info('varnames_in_checkpoint: {}'.format(varnames_in_checkpoint))
# TODO has someproblem say tf.Variable 'r_net/text_encoder/cudnn_rnn/cu_dnngru/recurrent_kernel/adam_v:0' even though in checkpoint I have renated it as ignore/rnet
variables_to_restore_from_model = slim.get_variables_to_restore(include=varnames_in_checkpoint)
#logging.info('variables_to_restore_from_model: {}'.format(variables_to_restore_from_model))
if not variables_to_restore:
variables_to_restore = variables_to_restore_from_model
else:
variables_to_restore = [v for v in variables_to_restore if v in variables_to_restore_from_model]
if restore_exclude:
for excl in restore_exclude:
variables_to_restore = [v for v in variables_to_restore if not excl in v.name]
#--tf 1.6 adadelta will have same vars...
variables_to_restore = list(set(variables_to_restore))
#logging.info('variables_to_restore', variables_to_restore[:100])
logging.info('variables_to_restore', [x for x in variables_to_restore if not 'OptimizeLoss' in x.name][:100])
##finally remove global_step since melt.apps.train will handle it!
global_step = tf.train.get_or_create_global_step()
#variables_to_restore = [v for v in variables_to_restore if not tf.GraphKeys.GLOBAL_STEP in v.name]
#variables_to_restore = [v for v in variables_to_restore if not 'learning_rate' in v.name]
# TODO fixme if step, step2.. and in checkpoint step then here will be step2...
#print('------------', [v for v in variables_to_restore if 'step' in v.name])
loader = tf.train.Saver(var_list=variables_to_restore)
logging.info('max models to keep {}, keep every {} hours'.format(max_models_keep, save_interval_seconds / 3600.0))
saver = tf.train.Saver(
max_to_keep=max_models_keep,
keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
var_list=variables_to_save)
epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
best_epoch_saver = tf.train.Saver(var_list=variables_to_save)
#logging.info('variables_to_save:{}'.format(variables_to_save))
# # #TODO for safe restore all init will be ok ?
# # if variables_to_restore is None:
init_op = tf.group(tf.global_variables_initializer(), #variables_initializer(global_variables())
tf.local_variables_initializer()) #variables_initializer(local_variables())
# # else:
# # init_op = tf.group(tf.variables_initializer(variables_to_restore),
# # tf.local_variables_initializer())
##--mostly this will be fine except for using assistant predictor, initialize again! will make assistant predictor wrong
##so assume to all run init op! if using assistant predictor, make sure it use another session
# https://stackoverflow.com/questions/35164529/in-tensorflow-is-there-any-way-to-just-initialize-uninitialised-variables
# def guarantee_initialized_variables(session, list_of_variables = None):
# if list_of_variables is None:
# list_of_variables = tf.global_variables()
# uninitialized_variables = list(tf.get_variable(name) for name in
# session.run(tf.report_uninitialized_variables(list_of_variables)))
# return unintialized_variables
# unintialized_variables = guarantee_initialized_variables(sess)
# init_op = tf.group(tf.initialize_variables(uninitialized_vars), tf.local_variables_initializer())
timer = gezi.Timer('sess run init_op in melt.tf_train_flow')
#model.save('./weights')
# notice
sess.run(init_op)
timer.print_elapsed()
#melt.init_uninitialized_variables(sess)
#pre_step means the step last saved, train without pretrained,then -1
pre_step = -1
fixed_pre_step = -1 #fixed pre step is for epoch num to be correct if you change batch size
#print(model_dir)
pre_epoch = None
if model_dir:
model_path = _get_model_path(model_dir, save_model)
# if not model_path:
# model_path = _get_model_path(os.path.join(model_dir, 'epoch'))
#print(model_path)
model_dir = gezi.get_dir(model_dir) #incase you pass ./model/model-ckpt1000 -> ./model
if model_path is not None:
if not restore_from_latest:
logging.info('using recent but not latest model')
model_path = melt.recent_checkpoint(model_dir)
model_name = os.path.basename(model_path)
timer = gezi.Timer('Loading and training from existing model [%s]' % model_path)
if restore_fn is not None:
restore_fn(sess)
loader.restore(sess, model_path)
## not supported
#model.save()
#model.save_weights('./weights')
timer.print()
#pre_step = melt.get_model_step(model_path) - 1 if FLAGS.global_step is None else FLAGS.global_step -1
# TODO check ..
pre_step = sess.run(tf.train.get_global_step()) - 1
pre_epoch = melt.get_model_epoch(model_path) if FLAGS.global_epoch is None else FLAGS.global_epoch
fixed_pre_step = pre_step
# if pre_epoch is not None:
# #like using batch size 32, then reload train using batch size 64
# if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
# fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
# logging.info('Warning, epoch is diff with pre_step / num_steps_per_epoch:{}, pre_epoch:{},maybe you change batch size and we will adjust to set pre_step as {}'\
# .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
else:
latest_checkpoint = None
if not use_horovod: #now will hang
try:
latest_checkpoint = tf.train.latest_checkpoint(ckpt_dir)
if latest_checkpoint:
logging.info('Try start from eager trained mode, latest checkpoint:', latest_checkpoint)
checkpoint.restore(latest_checkpoint).run_restore_ops(session=sess)
pre_epoch = int(latest_checkpoint.split('-')[-1])
#pre_step = pre_epoch * num_steps_per_epoch - 1
# TODO check
pre_step = sess.run(tf.train.get_global_step()) - 1
fixed_pre_step = pre_step
logging.info('Start step is:', pre_step)
except Exception:
logging.info('Something wrong with restore from eager trained model')
if latest_checkpoint is None:
logging.info('Train all start step 0')
#https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
#tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
#tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
#which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
#init_op = tf.group(tf.global_variables_initializer(),
# tf.local_variables_initializer())
#[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
#sess.run(init_op)
#like use image model, build image graph, reload first train, and then will go to same checkpoint all varaible just restore will ok
#for finetune from loading other model init
if init_fn is not None:
init_fn(sess)
#sess.run(tf.assign(global_step, tf.constant(global_step_val, dtype=tf.int64)))
try:
learning_rate = tf.get_collection('learning_rate')[-1]
learning_rate_weight = tf.get_collection('learning_rate_weight')[-1]
sess.run(tf.assign(learning_rate, learning_rate * learning_rate_weight))
except Exception:
# if not using weight_decay but using optimizer decay then will go here as learning rate is a tensor can not assign
pass
try:
logging.info('Actual start global step:', sess.run(global_step), 'learning rate:', sess.run(learning_rate), 'learning_rate_weight:', sess.run(learning_rate_weight))
except Exception:
pass
if use_horovod:
bcast = hvd.broadcast_global_variables(0)
sess.run(bcast)
# before below
if gezi.has_env('METRIC'):
# safe for using horovod
model_path_ = _get_model_path(model_dir_)
l = metric_eval_fn(model_path_)
if not use_horovod or hvd.rank() == 0:
print(list(zip(l[1], l[0])))
exit(0)
if model_dir_:
#if save_interval_epochs and num_steps_per_epoch and num_steps >= 0:
epoch_dir = os.path.join(model_dir_, 'epoch')
if not use_horovod or hvd.rank() == 0:
gezi.try_mkdir(epoch_dir)
checkpoint_path = os.path.join(model_dir_, 'model.ckpt')
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
only_one_step = False
if use_horovod:
## TODO FIXME why bcast here not work ? simple test work see tests/bcast.py
#comm.bcast(pre_step, root=0)
temp = np.array([pre_step, fixed_pre_step])
comm.Bcast(temp, root=0)
pre_step, fixed_pre_step = temp[0], temp[1]
step = start = pre_step + 1
fixed_step = fixed_pre_step + 1
#first = True
#hack just for save one model after load
if num_steps < 0 or (num_steps and num_steps < step):
logging.info('just load and resave then exit')
model_path_ = _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch, epoch=pre_epoch)
saver.save(sess, model_path_, global_step=step + 1)
if freeze_graph:
melt.freeze_graph(sess, model_path_, step + 1, output_collection_names, output_node_names)
sess.close()
exit(0)
if num_epochs < 0:
only_one_step = True
logging.info('just run one step')
if FLAGS.work_mode != 'train':
assert not os.path.isdir(FLAGS.model_dir), FLAGS.model_dir
if 'valid' in FLAGS.work_mode:
vals, names = metric_eval_fn(FLAGS.model_dir)
logging.info(list(zip(names, vals)))
if 'test' in FLAGS.work_mode:
inference_fn(FLAGS.model_dir)
exit(0)
#early_stop = True #TODO allow config
num_bad_epochs = 0
pre_epoch_eval_loss = 1e20
best_epoch_eval_loss = 1e20
num_allowed_bad_epochs = 4 #allow 5 non decrease eval loss epochs before stop
epoch_saved_step = 0
num_epochs = num_epochs if num_epochs else 1024
#-------------------------------main loop
try:
for epoch in range(num_epochs):
for _ in tqdm(range(num_steps_per_epoch), ascii=True):
model_step_path = None
if model_dir_:
model_path_ = os.path.join(epoch_dir,'model.ckpt-%.2f'%(fixed_step / float(num_steps_per_epoch)))
model_step_path_ = model_path_ + '-' + str(step)
if (write_during_train and metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0):
model_step_path = model_step_path_
else:
model_step_path = None
if step == 0:
model_step_path = None
#print('--------------------step', step)
stop = train_once_fn(sess,
step,
is_start=(step==start),
fixed_step=fixed_step,
num_epochs=num_epochs,
model_path=model_step_path,
use_horovod=use_horovod,
## TODO FIXME this line will cause tensorflow.python.framework.errors_impl.NotFoundError: Resource localhost/save_counter/N10tensorflow3VarE does not exist.
)
#first = False
if only_one_step:
stop = True
step += 1
fixed_step += 1
if save_model and step and model_dir:
#step 0 is also saved! actually train one step and save
is_step_save = step % save_interval_steps == 0
is_epoch_save = save_interval_steps and num_steps_per_epoch and fixed_step % int(num_steps_per_epoch * save_interval_epochs) == 0
is_step_save = is_step_save or is_epoch_save
if is_step_save:
timer = gezi.Timer('save model step %d to %s'%(step, checkpoint_path), False)
model_path_ = _get_checkpoint_path(checkpoint_path, fixed_step, num_steps_per_epoch)
saver.save(sess, model_path_, global_step=step)
if freeze_graph:
melt.freeze_graph(sess, model_path_, step, output_collection_names, output_node_names)
if log_dir != model_dir:
assert log_dir
command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
print(command, file=sys.stderr)
os.system(command)
timer.print_elapsed()
if is_epoch_save:
# TODO only epoch in name not sep ?
epoch_saved_step = step
model_path_ = os.path.join(epoch_dir,'model.ckpt-%.2f'%(fixed_step / float(num_steps_per_epoch)))
model_step_path = model_path_ + '-' + str(step)
timer = gezi.Timer('epoch save to {}'.format(model_path_))
epoch_saver.save(sess, model_path_, global_step=step)
#logging.info(timer.elapsed())
timer.print_elapsed()
# TODO FIXME if add keras save below wil hang, might due to rank 0 save so slower then others
# fixed by adding allreduce on each loop end
# [1,0]<stderr>:Stalled ranks:
# [1,0]<stderr>:0: [HorovodAllreduce_Const_9_0]
if model and not use_horovod:
#if model:
#model.save_weights(epoch_dir + '/ckpt-%.2f' % (fixed_step / float(num_steps_per_epoch)))
# TODO FIXME if restart will save from 1... again..
timer = gezi.Timer('keras epoch save to {}'.format(checkpoint_prefix))
checkpoint.save(checkpoint_prefix, session=sess)
#print(sess.run(checkpoint.save_counter))
#logging.info(timer.elapsed())
timer.print_elapsed()
if freeze_graph:
melt.freeze_graph(sess, model_path_, step, output_collection_names, output_node_names)
if write_during_train:
if inference_fn is not None and inference_interval_epochs and fixed_step % int(num_steps_per_epoch * inference_interval_epochs) == 0:
model_step_path = model_path_ + '-' + str(step)
try:
#print('--------------inference fn')
inference_fn(model_path=model_step_path)
except Exception:
logging.info(traceback.format_exc())
## metric eval move to train_once which means you first save then eval, so might need 1 more step
# if metric_eval_fn is not None and valid_interval_epochs and fixed_step % int(num_steps_per_epoch * valid_interval_epochs) == 0:
# model_step_path = model_path_ + '-' + str(step)
# try:
# metric_eval_fn(model_path=model_step_path)
# except Exception:
# logging.info(traceback.format_exc())
if stop is True:
print('Early stop running %d stpes'%(step), file=sys.stderr)
raise tf.errors.OutOfRangeError(None, None,'Early stop running %d stpes'%(step))
if num_steps and (step + 1) == start + num_steps:
raise tf.errors.OutOfRangeError(None, None,'Reached max num steps')
#max_num_epochs = 1000
max_num_epochs = num_epochs
#if max_num_epochs and num_steps_per_epoch and fixed_step // num_steps_per_epoch >= max_num_epochs:
if max_num_epochs and num_steps_per_epoch and fixed_step / num_steps_per_epoch > max_num_epochs:
raise tf.errors.OutOfRangeError(None, None,'Reached max num epochs of %d'%max_num_epochs)
raise tf.errors.OutOfRangeError(None, None,'Reached max num epochs of %d'%max_num_epochs)
except tf.errors.OutOfRangeError:
# if run 2 epoch and we have just epoch saved, do not need to save only 1 step more model
if (step - epoch_saved_step > 1) and not (step==start) and save_model and step % save_interval_steps != 0 and model_dir:
model_path_ = _get_checkpoint_path(checkpoint_path, step, num_steps_per_epoch)
saver.save(sess, model_path_, global_step=step)
if freeze_graph:
melt.freeze_graph(sess, model_path_, step, output_collection_names, output_node_names)
## this is for train which save files to hdfs then HACK (set log_dir different then model_dir log save to local and sync to model_dir when saving model)
if log_dir != model_dir:
assert log_dir
command = 'rsync -l -r -t %s/* %s' % (log_dir, model_dir)
print(command, file=sys.stderr)
os.system(command)
if only_one_step:
# TODO strange logging.info will not show to screen if using horovod
logging.info('Done one step')
exit(0)
## since we have saved model then eval no need to do here, expect for train with specific steps but not common use
# if (step - epoch_saved_step > 1) and metric_eval_fn is not None:
# metric_eval_fn(model_path=model_step_path)
if (num_epochs and fixed_step / num_steps_per_epoch >= num_epochs) or (num_steps and step == start + num_steps) :
logging.info('Done training for %.3f epochs, %d steps.' % (fixed_step / num_steps_per_epoch, step))
#FIXME becase coord.join seems not work, RuntimeError: Coordinator stopped with threads still running: Thread-9
exit(0)
else:
logging.info('Should not stop, but stopped at epoch: %.3f'%(fixed_step / num_steps_per_epoch))
logging.info(traceback.format_exc())
#raise e
if FLAGS.use_tpu:
sess.run(tpu.shutdown_system())
sess.close()
#@TODO not tested yet
def tf_test_flow(test_once, model_dir='./model',
model_name=None, num_epochs=1, num_steps=0,
sess=None):
"""
basic flow for tf records, allow most freedom for usage, if not tfrecords no need for flow
Args:
test_once: function with 2 inputs sess and step
model_dir: can be dir like ./model will fetch lates model in model dir , or be real model path like ./model/model.0.ckpt
"""
if sess is None:
sess = tf.InteractiveSession()
melt.restore(sess, model_dir, model_name)
if not os.path.isdir(model_dir):
model_dir = os.path.dirname(model_dir)
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter(model_dir, sess.graph)
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
step = 0
#while not coord.should_stop():
while True:
test_once(sess, step)
step += 1
if num_steps and step == num_steps:
raise tf.errors.OutOfRangeError(None, None, 'Reached max num steps')
except tf.errors.OutOfRangeError:
print('Done testing for %d epochs, %d steps.' % (num_epochs, step))