This repository has been archived by the owner on Oct 31, 2022. It is now read-only.

Commit

use ModelTester object during training cycle
matpalm committed Aug 19, 2018
1 parent 4a1cf6d commit f3eac31
Showing 3 changed files with 43 additions and 66 deletions.
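At a glance: train.py no longer builds its own full-resolution test model; it constructs a test.ModelTester (the renamed Tester class) and, after saving a checkpoint, calls tester.test(run) to get precision/recall/f1 plus a composite debug image for the test summaries. A minimal sketch of the new flow, reconstructed from the hunks below — opts, sess, step, train_model and the summary writers are assumed from the surrounding train.py code:

    import test
    import util as u

    # built once, alongside the training model
    tester = test.ModelTester(opts.test_image_dir, opts.label_dir,
                              opts.batch_size, opts.width, opts.height,
                              opts.no_use_skip_connections, opts.base_filter_size,
                              opts.no_use_batch_norm)

    # inside the training loop: save a checkpoint, then let the tester reload it
    train_model.save(sess, "ckpts/%s" % opts.run)
    stats = tester.test(opts.run)

    # scalar summaries come from the returned stats dict ...
    tag_values = {k: stats[k] for k in ['precision', 'recall', 'f1']}
    test_summaries_writer.add_summary(u.explicit_summaries(tag_values), step)
    # ... as does the composite debug image built inside ModelTester
    debug_img_summary = u.pil_image_to_tf_summary(stats['debug_img'])
    test_summaries_writer.add_summary(debug_img_summary, step)
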
29 changes: 14 additions & 15 deletions test.py
@@ -5,11 +5,11 @@
import data
import model
import numpy as np
import random
import tensorflow as tf
import util as u


class Tester(object):
class ModelTester(object):

def __init__(self, image_dir, label_dir, batch_size, width, height,
no_use_skip_connections, base_filter_size, no_use_batch_norm):
@@ -49,23 +49,22 @@ def test(self, run):
debug_img = None # created on first call
while True:
try:

iterator_batch_size = true_bitmaps.shape[0] # note: final one may be < batch_size

if debug_img is None:
# fetch imgs as well to create debug_img
imgs, true_bitmaps, predicted_bitmaps, xent_loss = sess.run([self.test_imgs,
self.test_xys_bitmaps,
self.model.output,
self.model.xent_loss])
debug_img = u.debug_img(imgs, true_bitmaps, predicted_bitmaps)
# choose a random element from batch
idx = random.randint(0, true_bitmaps.shape[0]-1)
debug_img = u.debug_img(imgs[idx], true_bitmaps[idx], predicted_bitmaps[idx])
else:
true_bitmaps, predicted_bitmaps, xent_loss = sess.run([self.test_xys_bitmaps,
self.model.output,
self.model.xent_loss])

xent_losses.append(xent_loss)

iterator_batch_size = true_bitmaps.shape[0]
num_imgs += iterator_batch_size

for idx in range(iterator_batch_size):
@@ -102,14 +101,14 @@ def test(self, run):
opts = parser.parse_args()
print(opts)

tester = Tester(image_dir=opts.image_dir,
label_dir=opts.label_dir,
batch_size=opts.batch_size,
width=opts.width,
height=opts.height,
no_use_skip_connections=opts.no_use_skip_connections,
base_filter_size=opts.base_filter_size,
no_use_batch_norm=opts.no_use_batch_norm)
tester = ModelTester(image_dir=opts.image_dir,
label_dir=opts.label_dir,
batch_size=opts.batch_size,
width=opts.width,
height=opts.height,
no_use_skip_connections=opts.no_use_skip_connections,
base_filter_size=opts.base_filter_size,
no_use_batch_norm=opts.no_use_batch_norm)

stats = tester.test(opts.run)

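The test() loop above drains the test iterator inside a while True / try block; the matching except clause is hidden by the fold. A sketch of the standard pattern it presumably follows — an assumption, with names as used inside ModelTester.test above and the debug-image branch omitted:

    import tensorflow as tf

    xent_losses, num_imgs = [], 0
    while True:
        try:
            true_bitmaps, predicted_bitmaps, xent_loss = sess.run(
                [self.test_xys_bitmaps, self.model.output, self.model.xent_loss])
            xent_losses.append(xent_loss)
            num_imgs += true_bitmaps.shape[0]  # final batch may be < batch_size
        except tf.errors.OutOfRangeError:
            break  # dataset exhausted; one full pass over the test set is done
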
60 changes: 20 additions & 40 deletions train.py
@@ -9,6 +9,7 @@
import tensorflow as tf
import tensorflow.contrib.slim as slim
import util as u
import test
import time

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -35,7 +36,7 @@

np.set_printoptions(precision=2, threshold=10000, suppress=True, linewidth=10000)

# Build readers for train and test data.
# Build readers / model for training
train_imgs, train_xys_bitmaps = data.img_xys_iterator(image_dir=opts.train_image_dir,
label_dir=opts.label_dir,
batch_size=opts.batch_size,
@@ -45,20 +46,7 @@
random_rotation=opts.random_rotate,
repeat=True,
width=opts.width, height=opts.height)
test_imgs, test_xys_bitmaps = data.img_xys_iterator(image_dir=opts.test_image_dir,
label_dir=opts.label_dir,
batch_size=opts.batch_size,
patch_width_height=None, # i.e. no patches
distort_rgb=False,
flip_left_right=False,
random_rotation=False,
repeat=True,
width=opts.width, height=opts.height)
print(test_imgs.get_shape())
print(test_xys_bitmaps.get_shape())

# Build training and test model with same params.
# TODO: opts for skip and base filters

print("patch train model...")
train_model = model.Model(train_imgs,
is_training=True,
@@ -68,12 +56,10 @@
train_model.calculate_losses_wrt(labels=train_xys_bitmaps)

print("full res test model...")
test_model = model.Model(test_imgs,
is_training=False,
use_skip_connections=not opts.no_use_skip_connections,
base_filter_size=opts.base_filter_size,
use_batch_norm=not opts.no_use_batch_norm)
test_model.calculate_losses_wrt(labels=test_xys_bitmaps)
tester = test.ModelTester(opts.test_image_dir, opts.label_dir,
opts.batch_size, opts.width, opts.height,
opts.no_use_skip_connections, opts.base_filter_size,
opts.no_use_batch_norm)

global_step = tf.train.get_or_create_global_step()

@@ -101,10 +87,10 @@

# train a bit.
for _ in range(opts.train_steps):
_, xl = sess.run([train_op, train_model.xent_loss])
sess.run(train_op)

# fetch global_step
step = sess.run(global_step)
# fetch global_step & xent_loss
step, xl = sess.run([global_step, train_model.xent_loss])

# report one liner
print("step %d/%d\ttime %d\txent_loss %f" % (step, opts.steps,
@@ -113,35 +99,29 @@

# train / test summaries
# includes loss summaries as well as a hand rolled debug image

# ...train
i, bm, logits, o, xl = sess.run([train_imgs, train_xys_bitmaps,
train_model.logits, train_model.output,
train_model.xent_loss])
train_summaries_writer.add_summary(u.explicit_summaries({"xent": xl}), step)
debug_img_summary = u.pil_image_to_tf_summary(u.debug_img(i, bm, o))
debug_img_summary = u.pil_image_to_tf_summary(u.debug_img(i[0], bm[0], o[0]))
train_summaries_writer.add_summary(debug_img_summary, step)
train_summaries_writer.flush()

# save checkpoint (to be reloaded by test)
# TODO: this is clumsy; need to refactor test to use current session instead
# of loading entirely new one... will do for now.
train_model.save(sess, "ckpts/%s" % opts.run)

# ... test
i, bm, o, xl = sess.run([test_imgs, test_xys_bitmaps, test_model.output,
test_model.xent_loss])

set_comparison = u.SetComparison()
for batch_idx in range(bm.shape[0]):
true_centroids = u.centroids_of_connected_components(bm[batch_idx])
predicted_centroids = u.centroids_of_connected_components(o[batch_idx])
set_comparison.compare_sets(true_centroids, predicted_centroids)
precision, recall, f1 = set_comparison.precision_recall_f1()
tag_values = {"xent": xl, "precision": precision, "recall": recall, "f1": f1}
stats = tester.test(opts.run)
tag_values = {k: stats[k] for k in ['precision', 'recall', 'f1']}
test_summaries_writer.add_summary(u.explicit_summaries(tag_values), step)

debug_img_summary = u.pil_image_to_tf_summary(u.debug_img(i, bm, o))
debug_img_summary = u.pil_image_to_tf_summary(stats['debug_img'])
test_summaries_writer.add_summary(debug_img_summary, step)
test_summaries_writer.flush()

# save checkpoint
train_model.save(sess, "ckpts/%s" % opts.run)

# check if done by steps or time
if step >= opts.steps:
done = True
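The per-batch SetComparison logic removed from train.py above presumably now lives inside ModelTester.test(), which is where the precision/recall/f1 values in stats would come from. A sketch under that assumption (the relevant test.py hunk is folded):

    set_comparison = u.SetComparison()
    for idx in range(iterator_batch_size):
        true_centroids = u.centroids_of_connected_components(true_bitmaps[idx])
        predicted_centroids = u.centroids_of_connected_components(predicted_bitmaps[idx])
        set_comparison.compare_sets(true_centroids, predicted_centroids)
    precision, recall, f1 = set_comparison.precision_recall_f1()
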
20 changes: 9 additions & 11 deletions util.py
@@ -30,26 +30,24 @@ def xys_to_bitmap(xys, height, width, rescale=1.0):
raise e
return bitmap

def debug_img(i, bm, o):
def debug_img(img, bitmap, logistic_output):
# create a debug image with three columns; 1) original RGB. 2) black/white
# bitmap of labels 3) black/white bitmap of predictions (with centroids coloured
# red. we expect i, bm and o to be a batch but we just use first element
_bs, h, w, _c = bm.shape
# red.
h, w, _channels = bitmap.shape
canvas = Image.new('RGB', (w*3, h), (50, 50, 50))
# original input image on left
i = zero_centered_array_to_pil_image(i[0])
i = i.resize((w, h))
canvas.paste(i, (0, 0))
img = zero_centered_array_to_pil_image(img)
img = img.resize((w, h))
canvas.paste(img, (0, 0))
# label bitmap in center
bm = bitmap_to_pil_image(bm[0])
canvas.paste(bm, (w, 0))
canvas.paste(bitmap_to_pil_image(bitmap), (w, 0))
# logistic output on right
logistic_output = bitmap_to_pil_image(o[0])
canvas.paste(logistic_output, (w*2, 0))
canvas.paste(bitmap_to_pil_image(logistic_output), (w*2, 0))
# draw red dots on right hand side image corresponding to
# final thresholded prediction
draw = ImageDraw.Draw(canvas)
for y, x in centroids_of_connected_components(o[0]):
for y, x in centroids_of_connected_components(logistic_output):
draw.rectangle((w*2+x,y,w*2+x,y), fill='red')
# finally draw blue lines between the three to delimit boundaries
draw.line([w,0,w,h], fill='blue')
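debug_img now takes a single image / label bitmap / prediction triple rather than a batch, so callers index into the batch before calling it (as test.py and train.py do above). A minimal usage sketch, assuming imgs, true_bitmaps and predicted_bitmaps are numpy batches fetched via sess.run; the return statement is folded out of the diff, but the canvas is consumed as a PIL image by u.pil_image_to_tf_summary, so it is presumably returned:

    import random
    import util as u

    # pick one element of the batch to visualise
    idx = random.randint(0, true_bitmaps.shape[0] - 1)
    canvas = u.debug_img(imgs[idx], true_bitmaps[idx], predicted_bitmaps[idx])
    # three columns: original RGB | label bitmap | prediction (centroids in red)
    canvas.save("debug.png")
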
