This repository has been archived by the owner on Jun 16, 2021. It is now read-only.

Update train/test to use minibatches, for larger test datasets
brilee committed Sep 7, 2016
1 parent 4a05765 commit 9f65a05
Showing 1 changed file with 62 additions and 21 deletions.
83 changes: 62 additions & 21 deletions policy.py
@@ -39,6 +39,7 @@ def __init__(self, num_input_planes, k=32, num_int_conv_layers=3):
self.training_summary_writer = None
self.session = tf.Session()
self.set_up_network()
self.set_up_summaries()

def set_up_network(self):
# a global_step variable allows epoch counts to persist through multiple training sessions
@@ -85,22 +86,30 @@ def _conv2d(x, W):
was_correct = tf.equal(tf.argmax(output, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(was_correct, tf.float32))

saver = tf.train.Saver()

weight_summaries = tf.merge_summary([
tf.histogram_summary(weight_var.name, weight_var)
for weight_var in [W_conv_init] + W_conv_intermediate + [W_conv_final, b_conv_final]],
name="weight_summaries"
)
activation_summaries = tf.merge_summary([
tf.histogram_summary(act_var.name, act_var)
for act_var in [h_conv_init] + h_conv_intermediate, [h_conv_final]],
for act_var in [h_conv_init] + h_conv_intermediate + [h_conv_final]],
name="activation_summaries"
)
_accuracy = tf.scalar_summary("accuracy", accuracy)
_cost = tf.scalar_summary("log_likelihood_cost", log_likelihood_cost)
accuracy_summaries = tf.merge_summary([_accuracy, _cost], name="accuracy_summaries")
saver = tf.train.Saver()

# save everything to self.
for name, thing in locals().items():
if not name.startswith('_'):
setattr(self, name, thing)

def set_up_summaries(self):
# See summarize() for why things are set up this way
accuracy_summary = tf.placeholder(tf.float32, [])
cost_summary = tf.placeholder(tf.float32, [])
_accuracy = tf.scalar_summary("accuracy", accuracy_summary)
_cost = tf.scalar_summary("log_likelihood_cost", cost_summary)
accuracy_summaries = tf.merge_summary([_accuracy, _cost], name="accuracy_summaries")
# save everything to self.
for name, thing in locals().items():
if not name.startswith('_'):
@@ -110,6 +119,16 @@ def initialize_logging(self, tensorboard_logdir):
self.test_summary_writer = tf.train.SummaryWriter(os.path.join(tensorboard_logdir, "test"), self.session.graph)
self.training_summary_writer = tf.train.SummaryWriter(os.path.join(tensorboard_logdir, "training"), self.session.graph)

def summarize(self, accuracy, cost):
# Accuracy and cost cannot be calculated with the full test dataset
# in one pass, so they must be computed in batches. Unfortunately,
# the built-in TF summary nodes cannot be told to aggregate multiple
# executions. Therefore, we aggregate the accuracy/cost ourselves at
# the python level, and then shove it through the accuracy/cost summary
# nodes to generate the appropriate summary protobufs for writing.
return self.session.run(self.accuracy_summaries,
feed_dict={self.accuracy_summary: accuracy, self.cost_summary: cost})
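
# --- Editor's sketch (not part of this commit) ---------------------------------
# Minimal standalone illustration of the pattern summarize() relies on, assuming
# tensorflow is imported as `tf` (as elsewhere in this file) and the same TF 0.x
# summary API: per-batch scalars are averaged in plain Python, then fed through a
# placeholder-backed summary op so one summary protobuf covers a whole multi-batch
# evaluation. In practice the ops would be built once, as set_up_summaries() does.
def _demo_placeholder_summary(session, writer, per_batch_accuracies, step):
    acc_ph = tf.placeholder(tf.float32, [])
    summary_op = tf.scalar_summary("demo_accuracy", acc_ph)
    avg = sum(per_batch_accuracies) / len(per_batch_accuracies)  # aggregate in Python
    proto = session.run(summary_op, feed_dict={acc_ph: avg})
    writer.add_summary(proto, step)  # writer: a tf.train.SummaryWriter
# --------------------------------------------------------------------------------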

def initialize_variables(self, save_file=None):
if save_file is None:
self.session.run(tf.initialize_all_variables())
@@ -124,18 +143,28 @@ def save_variables(self, save_file):

def train(self, training_data, batch_size=32):
num_minibatches = training_data.data_size // batch_size
aggregate_accuracy, aggregate_cost = 0, 0
for i in range(num_minibatches):
batch_x, batch_y = training_data.get_batch(batch_size)
global_step = self.get_global_step()
if global_step % 100 == 0:
accuracy_summaries, activation_summaries, train_accuracy = self.session.run(
[self.accuracy_summaries, self.activation_summaries, self.accuracy],
feed_dict={self.x: batch_x, self.y: batch_y})
if self.training_summary_writer is not None:
self.training_summary_writer.add_summary(accuracy_summaries, global_step)
self.training_summary_writer.add_summary(activation_summaries, global_step)
print("Step %d, training data accuracy: %g" % (global_step, train_accuracy))
self.session.run(self.train_step, feed_dict={self.x: batch_x, self.y: batch_y})
_, accuracy, cost = self.session.run(
[self.train_step, self.accuracy, self.log_likelihood_cost],
feed_dict={self.x: batch_x, self.y: batch_y})
aggregate_accuracy += accuracy
aggregate_cost += cost

avg_accuracy = aggregate_accuracy / num_minibatches
avg_cost = aggregate_cost / num_minibatches
global_step = self.get_global_step()
aggregate_accuracy, aggregate_cost = 0, 0
print("Step %d training data accuracy: %g; cost: %g" % (global_step, avg_accuracy, avg_cost))
if self.training_summary_writer is not None:
activation_summaries = self.session.run(
self.activation_summaries,
feed_dict={self.x: batch_x, self.y: batch_y})
accuracy_summaries = self.summarize(avg_accuracy, avg_cost)
self.training_summary_writer.add_summary(activation_summaries, global_step)
self.training_summary_writer.add_summary(accuracy_summaries, global_step)
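
# --- Editor's note (not part of this commit) ------------------------------------
# Because num_minibatches = data_size // batch_size, any leftover examples beyond
# the last full batch are skipped, so a plain mean of the per-batch accuracy/cost
# is exact: every batch has the same size. If a final partial batch were ever
# included, a size-weighted mean would be the safer aggregation; a small sketch:
def _weighted_mean(per_batch_values, per_batch_sizes):
    total = sum(v * n for v, n in zip(per_batch_values, per_batch_sizes))
    return total / float(sum(per_batch_sizes))
# ---------------------------------------------------------------------------------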


def run(self, position):
'Return a sorted list of (probability, move) tuples'
@@ -144,14 +173,26 @@ def run(self, position):
move_probs = [(prob, utils.unflatten_coords(i)) for i, prob in enumerate(probabilities)]
return sorted(move_probs, reverse=True)

def check_accuracy(self, test_data):
def check_accuracy(self, test_data, batch_size=128):
num_minibatches = test_data.data_size // batch_size
weight_summaries = self.session.run(self.weight_summaries)
accuracy_summaries, test_accuracy = self.session.run(
[self.accuracy_summaries, self.accuracy],
feed_dict={self.x: test_data.pos_features, self.y: test_data.next_moves})

aggregate_accuracy, aggregate_cost = 0, 0
for i in range(num_minibatches):
batch_x, batch_y = test_data.get_batch(batch_size)
accuracy, cost = self.session.run(
[self.accuracy, self.log_likelihood_cost],
feed_dict={self.x: batch_x, self.y: batch_y})
aggregate_accuracy += accuracy
aggregate_cost += cost

avg_accuracy = aggregate_accuracy / num_minibatches
avg_cost = aggregate_cost / num_minibatches
accuracy_summaries = self.summarize(avg_accuracy, avg_cost)
global_step = self.get_global_step()
print("Step %s test data accuracy: %g; cost: %g" % (global_step, avg_accuracy, avg_cost))

if self.test_summary_writer is not None:
self.test_summary_writer.add_summary(weight_summaries, global_step)
self.test_summary_writer.add_summary(accuracy_summaries, global_step)
print("Step %s test data accuracy: %g" % (global_step, test_accuracy))
