Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
train+evaluate in one script; add cmd-line args
- Loading branch information
1 parent
52cf969
commit 7ed728d
Showing
7 changed files
with
160 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# Custom output files | ||
omniglot_out | ||
model_checkpoint | ||
data | ||
|
||
# Byte-compiled / optimized / DLL files | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Train a model on Omniglot. | ||
""" | ||
|
||
import tensorflow as tf | ||
|
||
from supervised_reptile.args import argument_parser, train_kwargs, evaluate_kwargs | ||
from supervised_reptile.eval import evaluate | ||
from supervised_reptile.models import OmniglotModel | ||
from supervised_reptile.omniglot import read_dataset, split_dataset, augment_dataset | ||
from supervised_reptile.train import train | ||
|
||
DATA_DIR = 'data/omniglot' | ||
SAVE_DIR = 'omniglot_out' | ||
|
||
def main(): | ||
""" | ||
Load data and train a model on it. | ||
""" | ||
args = argument_parser().parse_args() | ||
|
||
train_set, test_set = split_dataset(read_dataset(DATA_DIR)) | ||
train_set = list(augment_dataset(train_set)) | ||
test_set = list(test_set) | ||
|
||
model = OmniglotModel(args.classes) | ||
|
||
with tf.Session() as sess: | ||
print('Training...') | ||
train(sess, model, train_set, test_set, args.checkpoint, **train_kwargs(args)) | ||
|
||
print('Evaluating...') | ||
eval_kwargs = evaluate_kwargs(args) | ||
print('Train accuracy: ' + str(evaluate(sess, model, train_set, **eval_kwargs))) | ||
print('Test accuracy: ' + str(evaluate(sess, model, test_set, **eval_kwargs))) | ||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
Command-line argument parsing. | ||
""" | ||
|
||
import argparse | ||
|
||
def argument_parser(): | ||
""" | ||
Get an argument parser for a training script. | ||
""" | ||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) | ||
parser.add_argument('--checkpoint', help='checkpoint directory', default='model_checkpoint') | ||
parser.add_argument('--classes', help='number of classes per inner task', default=5, type=int) | ||
parser.add_argument('--shots', help='number of examples per class', default=5, type=int) | ||
parser.add_argument('--inner-batch', help='inner batch size', default=5, type=int) | ||
parser.add_argument('--inner-iters', help='inner iterations', default=20, type=int) | ||
parser.add_argument('--meta-step', help='meta-training step size', default=0.1, type=float) | ||
parser.add_argument('--meta-batch', help='meta-training batch size', default=1, type=int) | ||
parser.add_argument('--meta-iters', help='meta-training iterations', default=70000, type=int) | ||
parser.add_argument('--eval-batch', help='eval inner batch size', default=5, type=int) | ||
parser.add_argument('--eval-iters', help='eval inner iterations', default=50, type=int) | ||
parser.add_argument('--eval-samples', help='evaluation samples', default=10000, type=int) | ||
return parser | ||
|
||
def train_kwargs(parsed_args): | ||
""" | ||
Build kwargs for the train() function from the parsed | ||
command-line arguments. | ||
""" | ||
return { | ||
'num_classes': parsed_args.classes, | ||
'num_shots': parsed_args.shots, | ||
'inner_batch_size': parsed_args.inner_batch, | ||
'inner_iters': parsed_args.inner_iters, | ||
'meta_step_size': parsed_args.meta_step, | ||
'meta_batch_size': parsed_args.meta_batch, | ||
'meta_iters': parsed_args.meta_iters, | ||
'eval_inner_batch_size': parsed_args.eval_batch, | ||
'eval_inner_iters': parsed_args.eval_iters, | ||
} | ||
|
||
def evaluate_kwargs(parsed_args): | ||
""" | ||
Build kwargs for the evaluate() function from the | ||
parsed command-line arguments. | ||
""" | ||
return { | ||
'num_classes': parsed_args.classes, | ||
'num_shots': parsed_args.shots, | ||
'eval_inner_batch_size': parsed_args.eval_batch, | ||
'eval_inner_iters': parsed_args.eval_iters, | ||
'num_samples': parsed_args.eval_samples | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
Helpers for evaluating models. | ||
""" | ||
|
||
from .reptile import Reptile | ||
|
||
# pylint: disable=R0913,R0914 | ||
def evaluate(sess, | ||
model, | ||
dataset, | ||
num_classes=5, | ||
num_shots=5, | ||
eval_inner_batch_size=5, | ||
eval_inner_iters=50, | ||
num_samples=10000): | ||
""" | ||
Evaluate a model on a dataset. | ||
""" | ||
reptile = Reptile(sess) | ||
total_correct = 0 | ||
for _ in range(num_samples): | ||
total_correct += reptile.evaluate(dataset, model.input_ph, model.label_ph, | ||
model.minimize_op, model.predictions, | ||
num_classes=num_classes, num_shots=num_shots, | ||
inner_batch_size=eval_inner_batch_size, | ||
inner_iters=eval_inner_iters) | ||
return total_correct / (num_samples * num_classes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.