Skip to content

Commit

Permalink
Make fully supervised code also TF1.14 compatible.
Browse files · Browse the repository at this point in the history
  • Loading branch information
david-berthelot committed Jul 25, 2019
1 parent 014b3c0 commit 8d0e083
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions fully_supervised/lib/train.py
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from absl import flags
from tqdm import trange

from libml import utils
from libml.train import ClassifySemi
import tensorflow as tf
from tqdm import trange

FLAGS = flags.FLAGS

Expand All @@ -36,38 +37,40 @@ def train(self, train_nimg, report_nimg):
self.eval_checkpoint(FLAGS.eval_ckpt)
return
batch = FLAGS.batch
with self.graph.as_default():
train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
train_labeled = train_labeled.make_one_shot_iterator().get_next()
scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
pad_step_number=10))
train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
train_labeled = train_labeled.make_one_shot_iterator().get_next()
scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=FLAGS.keep_ckpt,
pad_step_number=10))

with tf.Session(config=utils.get_config()) as sess:
self.session = sess
self.cache_eval()

with tf.train.MonitoredTrainingSession(
scaffold=scaffold,
checkpoint_dir=self.checkpoint_dir,
config=utils.get_config(),
save_checkpoint_steps=FLAGS.save_kimg << 10,
save_summaries_steps=report_nimg - batch) as train_session:
self.session = train_session._tf_sess()
self.tmp.step = self.session.run(self.step)
while self.tmp.step < train_nimg:
loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
leave=False, unit='img', unit_scale=batch,
desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
for _ in loop:
self.train_step(train_session, train_labeled)
while self.tmp.print_queue:
loop.write(self.tmp.print_queue.pop(0))
while self.tmp.print_queue:
print(self.tmp.print_queue.pop(0))
with tf.train.MonitoredTrainingSession(
scaffold=scaffold,
checkpoint_dir=self.checkpoint_dir,
config=utils.get_config(),
save_checkpoint_steps=FLAGS.save_kimg << 10,
save_summaries_steps=report_nimg - batch) as train_session:
self.session = train_session._tf_sess()
self.tmp.step = self.session.run(self.step)
while self.tmp.step < train_nimg:
loop = trange(self.tmp.step % report_nimg, report_nimg, batch,
leave=False, unit='img', unit_scale=batch,
desc='Epoch %d/%d' % (1 + (self.tmp.step // report_nimg), train_nimg // report_nimg))
for _ in loop:
self.train_step(train_session, train_labeled)
while self.tmp.print_queue:
loop.write(self.tmp.print_queue.pop(0))
while self.tmp.print_queue:
print(self.tmp.print_queue.pop(0))

def tune(self, train_nimg):
batch = FLAGS.batch
with self.graph.as_default():
train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
train_labeled = train_labeled.make_one_shot_iterator().get_next()
train_labeled = self.dataset.train_labeled.batch(batch).prefetch(16)
train_labeled = train_labeled.make_one_shot_iterator().get_next()

for _ in trange(0, train_nimg, batch, leave=False, unit='img', unit_scale=batch, desc='Tuning'):
x = self.session.run([train_labeled])
self.session.run([self.ops.tune_op], feed_dict={self.ops.x: x['image'],
self.ops.label: x['label']})
for _ in trange(0, train_nimg, batch, leave=False, unit='img', unit_scale=batch, desc='Tuning'):
x = self.session.run([train_labeled])
self.session.run([self.ops.tune_op], feed_dict={self.ops.x: x['image'],
self.ops.label: x['label']})

0 comments on commit 8d0e083

Please sign in to comment.