train+evaluate in one script; add cmd-line args
unixpickle committed Dec 6, 2017
1 parent 52cf969 commit 7ed728d
Showing 7 changed files with 160 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,5 +1,5 @@
 # Custom output files
-omniglot_out
+model_checkpoint
 data

 # Byte-compiled / optimized / DLL files
38 changes: 38 additions & 0 deletions run_omniglot.py
@@ -0,0 +1,38 @@
"""
Train a model on Omniglot.
"""

import tensorflow as tf

from supervised_reptile.args import argument_parser, train_kwargs, evaluate_kwargs
from supervised_reptile.eval import evaluate
from supervised_reptile.models import OmniglotModel
from supervised_reptile.omniglot import read_dataset, split_dataset, augment_dataset
from supervised_reptile.train import train

DATA_DIR = 'data/omniglot'
SAVE_DIR = 'omniglot_out'

def main():
"""
Load data and train a model on it.
"""
args = argument_parser().parse_args()

train_set, test_set = split_dataset(read_dataset(DATA_DIR))
train_set = list(augment_dataset(train_set))
test_set = list(test_set)

model = OmniglotModel(args.classes)

with tf.Session() as sess:
print('Training...')
train(sess, model, train_set, test_set, args.checkpoint, **train_kwargs(args))

print('Evaluating...')
eval_kwargs = evaluate_kwargs(args)
print('Train accuracy: ' + str(evaluate(sess, model, train_set, **eval_kwargs)))
print('Test accuracy: ' + str(evaluate(sess, model, test_set, **eval_kwargs)))

if __name__ == '__main__':
main()
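For context, the new script would be launched from the repository root roughly as follows. This is an illustrative invocation, not part of the commit; the flag names and values are simply the defaults defined in supervised_reptile/args.py below, and it assumes the Omniglot data is already in data/omniglot:

    python run_omniglot.py --classes 5 --shots 5 --inner-batch 5 --inner-iters 20 \
        --meta-step 0.1 --meta-batch 1 --meta-iters 70000 --checkpoint model_checkpoint

Omitting every flag gives the same behavior, since these are the argparse defaults.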
53 changes: 53 additions & 0 deletions supervised_reptile/args.py
@@ -0,0 +1,53 @@
"""
Command-line argument parsing.
"""

import argparse

def argument_parser():
"""
Get an argument parser for a training script.
"""
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--checkpoint', help='checkpoint directory', default='model_checkpoint')
parser.add_argument('--classes', help='number of classes per inner task', default=5, type=int)
parser.add_argument('--shots', help='number of examples per class', default=5, type=int)
parser.add_argument('--inner-batch', help='inner batch size', default=5, type=int)
parser.add_argument('--inner-iters', help='inner iterations', default=20, type=int)
parser.add_argument('--meta-step', help='meta-training step size', default=0.1, type=float)
parser.add_argument('--meta-batch', help='meta-training batch size', default=1, type=int)
parser.add_argument('--meta-iters', help='meta-training iterations', default=70000, type=int)
parser.add_argument('--eval-batch', help='eval inner batch size', default=5, type=int)
parser.add_argument('--eval-iters', help='eval inner iterations', default=50, type=int)
parser.add_argument('--eval-samples', help='evaluation samples', default=10000, type=int)
return parser

def train_kwargs(parsed_args):
"""
Build kwargs for the train() function from the parsed
command-line arguments.
"""
return {
'num_classes': parsed_args.classes,
'num_shots': parsed_args.shots,
'inner_batch_size': parsed_args.inner_batch,
'inner_iters': parsed_args.inner_iters,
'meta_step_size': parsed_args.meta_step,
'meta_batch_size': parsed_args.meta_batch,
'meta_iters': parsed_args.meta_iters,
'eval_inner_batch_size': parsed_args.eval_batch,
'eval_inner_iters': parsed_args.eval_iters,
}

def evaluate_kwargs(parsed_args):
"""
Build kwargs for the evaluate() function from the
parsed command-line arguments.
"""
return {
'num_classes': parsed_args.classes,
'num_shots': parsed_args.shots,
'eval_inner_batch_size': parsed_args.eval_batch,
'eval_inner_iters': parsed_args.eval_iters,
'num_samples': parsed_args.eval_samples
}
27 changes: 27 additions & 0 deletions supervised_reptile/eval.py
@@ -0,0 +1,27 @@
"""
Helpers for evaluating models.
"""

from .reptile import Reptile

# pylint: disable=R0913,R0914
def evaluate(sess,
model,
dataset,
num_classes=5,
num_shots=5,
eval_inner_batch_size=5,
eval_inner_iters=50,
num_samples=10000):
"""
Evaluate a model on a dataset.
"""
reptile = Reptile(sess)
total_correct = 0
for _ in range(num_samples):
total_correct += reptile.evaluate(dataset, model.input_ph, model.label_ph,
model.minimize_op, model.predictions,
num_classes=num_classes, num_shots=num_shots,
inner_batch_size=eval_inner_batch_size,
inner_iters=eval_inner_iters)
return total_correct / (num_samples * num_classes)
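A note on the return value: Reptile.evaluate (in supervised_reptile/reptile.py) appears to return the number of correctly classified test inputs for one sampled task, which is at most num_classes, so dividing by num_samples * num_classes gives overall per-example accuracy. A tiny sanity check with made-up numbers, not taken from the commit:

    num_samples, num_classes = 10000, 5
    total_correct = 47500          # hypothetical sum of per-task correct counts
    accuracy = total_correct / (num_samples * num_classes)
    print(accuracy)                # 0.95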
20 changes: 10 additions & 10 deletions supervised_reptile/reptile.py
@@ -25,12 +25,12 @@ def train_step(self,
                    input_ph,
                    label_ph,
                    minimize_op,
-                   num_classes=5,
-                   num_shots=5,
-                   inner_batch_size=5,
-                   inner_iters=20,
-                   meta_step_size=0.1,
-                   meta_batch_size=1):
+                   num_classes,
+                   num_shots,
+                   inner_batch_size,
+                   inner_iters,
+                   meta_step_size,
+                   meta_batch_size):
         """
         Perform a Reptile training step.

@@ -67,10 +67,10 @@ def evaluate(self,
                  label_ph,
                  minimize_op,
                  predictions,
-                 num_classes=5,
-                 num_shots=5,
-                 inner_batch_size=5,
-                 inner_iters=50):
+                 num_classes,
+                 num_shots,
+                 inner_batch_size,
+                 inner_iters):
         """
         Run a single evaluation of the model.
61 changes: 31 additions & 30 deletions supervised_reptile/train.py
@@ -9,48 +9,49 @@
 from .reptile import Reptile

 # pylint: disable=R0913,R0914
-def train(model,
+def train(sess,
+          model,
           train_set,
           test_set,
           save_dir,
-          num_outer_iters=70000,
           num_classes=5,
           num_shots=5,
           inner_batch_size=5,
           inner_iters=20,
           meta_step_size=0.1,
           meta_batch_size=1,
+          meta_iters=70000,
           eval_inner_batch_size=5,
-          eval_inner_iters=50):
+          eval_inner_iters=50,
+          log_fn=print):
     """
     Train a model on a dataset.
     """
     if not os.path.exists(save_dir):
         os.mkdir(save_dir)
-    with tf.Session() as sess:
-        saver = tf.train.Saver()
-        reptile = Reptile(sess)
-        accuracy_ph = tf.placeholder(tf.float32, shape=())
-        tf.summary.scalar('accuracy', accuracy_ph)
-        merged = tf.summary.merge_all()
-        train_writer = tf.summary.FileWriter(os.path.join(save_dir, 'train'), sess.graph)
-        test_writer = tf.summary.FileWriter(os.path.join(save_dir, 'test'), sess.graph)
-        tf.global_variables_initializer().run()
-        sess.run(tf.global_variables_initializer())
-        for i in range(num_outer_iters):
-            print('batch %d' % i)
-            reptile.train_step(train_set, model.input_ph, model.label_ph, model.minimize_op,
-                               num_classes=num_classes, num_shots=num_shots,
-                               inner_batch_size=inner_batch_size, inner_iters=inner_iters,
-                               meta_step_size=meta_step_size, meta_batch_size=meta_batch_size)
-            for dataset, writer in [(train_set, train_writer), (test_set, test_writer)]:
-                correct = reptile.evaluate(dataset, model.input_ph, model.label_ph,
-                                           model.minimize_op, model.predictions,
-                                           num_classes=num_classes, num_shots=num_shots,
-                                           inner_batch_size=eval_inner_batch_size,
-                                           inner_iters=eval_inner_iters)
-                summary = sess.run(merged, feed_dict={accuracy_ph: correct/num_classes})
-                writer.add_summary(summary, i)
-                writer.flush()
-            if i % 100 == 0:
-                saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)
+    saver = tf.train.Saver()
+    reptile = Reptile(sess)
+    accuracy_ph = tf.placeholder(tf.float32, shape=())
+    tf.summary.scalar('accuracy', accuracy_ph)
+    merged = tf.summary.merge_all()
+    train_writer = tf.summary.FileWriter(os.path.join(save_dir, 'train'), sess.graph)
+    test_writer = tf.summary.FileWriter(os.path.join(save_dir, 'test'), sess.graph)
+    tf.global_variables_initializer().run()
+    sess.run(tf.global_variables_initializer())
+    for i in range(meta_iters):
+        log_fn('batch %d' % i)
+        reptile.train_step(train_set, model.input_ph, model.label_ph, model.minimize_op,
+                           num_classes=num_classes, num_shots=num_shots,
+                           inner_batch_size=inner_batch_size, inner_iters=inner_iters,
+                           meta_step_size=meta_step_size, meta_batch_size=meta_batch_size)
+        for dataset, writer in [(train_set, train_writer), (test_set, test_writer)]:
+            correct = reptile.evaluate(dataset, model.input_ph, model.label_ph,
+                                       model.minimize_op, model.predictions,
+                                       num_classes=num_classes, num_shots=num_shots,
+                                       inner_batch_size=eval_inner_batch_size,
+                                       inner_iters=eval_inner_iters)
+            summary = sess.run(merged, feed_dict={accuracy_ph: correct/num_classes})
+            writer.add_summary(summary, i)
+            writer.flush()
+        if i % 100 == 0:
+            saver.save(sess, os.path.join(save_dir, 'model.ckpt'), global_step=i)
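The net effect of this hunk: train() no longer opens its own tf.Session, so the caller constructs the session and passes it in (exactly what run_omniglot.py above does); the outer-loop count is renamed from num_outer_iters to meta_iters, and progress logging goes through an injectable log_fn. A minimal sketch of the new calling pattern, assuming the dataset already sits in data/omniglot (meta_iters is reduced here purely for illustration):

    import tensorflow as tf

    from supervised_reptile.models import OmniglotModel
    from supervised_reptile.omniglot import read_dataset, split_dataset
    from supervised_reptile.train import train

    train_set, test_set = split_dataset(read_dataset('data/omniglot'))
    model = OmniglotModel(5)  # 5-way tasks
    with tf.Session() as sess:
        # The caller owns the session now; 'model_checkpoint' is the save_dir.
        train(sess, model, list(train_set), list(test_set), 'model_checkpoint',
              meta_iters=1000, log_fn=print)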
21 changes: 0 additions & 21 deletions train_omniglot.py

This file was deleted.
