From 0b7db65c51def2f38692258f0b20d91956688ecc Mon Sep 17 00:00:00 2001
From: Christopher Dewan
Date: Fri, 14 Dec 2018 01:13:45 -0800
Subject: [PATCH] Document top-level API a little, make testing from the command line easier. (#106)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/106

Add docstrings and type annotations to a couple of functions exported in the
top-level PyText API. Update the test CLI mode to take a PyTextConfig rather
than a TestConfig, for easier command-line usage.

Reviewed By: ahhegazy

Differential Revision: D13367488

fbshipit-source-id: 2347e63f4e31a737566ef6aabd3399a79e6ff023
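
As a sketch of how the documented top-level API fits together (the file paths
and the "raw_text" input field below are illustrative, not part of this diff):

    import pytext

    # Parse and validate a training config (wrapped or unwrapped JSON).
    config = pytext.load_config("/path/to/config.json")

    # Build a predictor from the caffe2 model produced by the export command.
    predictor = pytext.create_predictor(config, "/path/to/model.c2")

    # The predictor maps one input record to a dict of numpy output arrays.
    result = predictor({"raw_text": "set an alarm for 7 am"})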
+ """ workspace_id = str(uuid.uuid4()) workspace.SwitchWorkspace(workspace_id, True) predict_net = predictor_exporter.prepare_prediction_net( diff --git a/pytext/config/serialize.py b/pytext/config/serialize.py index 269f66f8c..7b3868932 100644 --- a/pytext/config/serialize.py +++ b/pytext/config/serialize.py @@ -7,11 +7,6 @@ from .pytext_config import PyTextConfig, TestConfig -class Mode(Enum): - TRAIN = "train" - TEST = "test" - - class ConfigParseError(Exception): pass @@ -219,13 +214,10 @@ def _get_class_type(cls): return cls.__origin__ if hasattr(cls, "__origin__") else cls -def parse_config(mode, config_json): +def parse_config(config_json): """ Parse PyTextConfig object from parameter string or parameter file """ - config_cls = {Mode.TRAIN: PyTextConfig, Mode.TEST: TestConfig}[mode] - # TODO T32608471 should assume the entire json is PyTextConfig later, right - # now we're matching the file format for pytext trainer.py inside fbl if "config" not in config_json: - return config_from_json(config_cls, config_json) - return config_from_json(config_cls, config_json["config"]) + return config_from_json(PyTextConfig, config_json) + return config_from_json(PyTextConfig, config_json["config"]) diff --git a/pytext/config/test/pytext_all_config_test.py b/pytext/config/test/pytext_all_config_test.py index 32cc2959b..3196f69a1 100644 --- a/pytext/config/test/pytext_all_config_test.py +++ b/pytext/config/test/pytext_all_config_test.py @@ -6,7 +6,7 @@ import unittest from pytext.builtin_task import register_builtin_tasks -from pytext.config.serialize import Mode, parse_config +from pytext.config.serialize import parse_config register_builtin_tasks() @@ -32,6 +32,5 @@ def test_load_all_configs(self): print("--- loading:", filename) with open(filename) as file: config_json = json.load(file) - # Most configs don't work in Mode.TEST - config = parse_config(Mode.TRAIN, config_json) + config = parse_config(config_json) self.assertIsNotNone(config) diff --git a/pytext/main.py b/pytext/main.py index 59484fbd6..91cd1e0a9 100644 --- a/pytext/main.py +++ b/pytext/main.py @@ -10,7 +10,7 @@ import torch from pytext import create_predictor from pytext.config import PyTextConfig, TestConfig -from pytext.config.serialize import Mode, config_from_json, config_to_json, parse_config +from pytext.config.serialize import config_from_json, config_to_json, parse_config from pytext.task import load from pytext.utils.documentation_helper import ( ROOT_CONFIG, @@ -22,7 +22,7 @@ from pytext.workflow import ( batch_predict, export_saved_model_to_caffe2, - test_model, + test_model_from_snapshot_path, train_model, ) from torch.multiprocessing.spawn import spawn @@ -180,33 +180,57 @@ def gen_default_config(context, task_name, options): @main.command() +@click.option( + "--model-snapshot", + default="", + help="load model snapshot and test configuration from this file", +) +@click.option("--test-path", default="", help="path to test data") +@click.option( + "--use-cuda/--no-cuda", + default=None, + help="Run supported parts of the model on GPU if available.", +) @click.pass_context -def test(context): - """Test a trained model snapshot.""" - config_json = context.obj.load_config() - config = parse_config(Mode.TEST, config_json) +def test(context, model_snapshot, test_path, use_cuda): + """Test a trained model snapshot. + + If model-snapshot is provided, the models and configuration will then be loaded from + the snapshot rather than any passed config file. + Otherwise, a config file will be loaded. 
+ """ + if model_snapshot: + print(f"Loading model snapshot and config from {model_snapshot}") + if use_cuda is None: + raise Exception( + "if --model-snapshot is set --use-cuda/--no-cuda must be set" + ) + else: + print(f"No model snapshot provided, loading from config") + config = parse_config(context.obj.load_config()) + model_snapshot = config.save_snapshot_path + use_cuda = config.use_cuda_if_available + print(f"Configured model snapshot {model_snapshot}") print("\n=== Starting testing...") - test_model(config) + test_model_from_snapshot_path(model_snapshot, use_cuda, test_path) @main.command() @click.pass_context def train(context): """Train a model and save the best snapshot.""" - config_json = context.obj.load_config() - config = parse_config(Mode.TRAIN, config_json) + config = parse_config(context.obj.load_config()) print("\n===Starting training...") if config.distributed_world_size == 1: train_model(config) else: train_model_distributed(config) print("\n=== Starting testing...") - test_config = TestConfig( - load_snapshot_path=config.save_snapshot_path, - test_path=config.task.data_handler.test_path, - use_cuda_if_available=config.use_cuda_if_available, + test_model_from_snapshot_path( + config.save_snapshot_path, + config.use_cuda_if_available, + config.task.data_handler.test_path, ) - test_model(test_config) @main.command() @@ -215,7 +239,7 @@ def train(context): @click.pass_context def export(context, model, output_path): """Convert a pytext model snapshot to a caffe2 model.""" - config = parse_config(Mode.TRAIN, context.obj.load_config()) + config = parse_config(context.obj.load_config()) model = model or config.save_snapshot_path output_path = output_path or config.export_caffe2_path print(f"Exporting {model} to {output_path}") @@ -227,7 +251,7 @@ def export(context, model, output_path): @click.pass_context def predict(context, exported_model): """Start a repl executing examples against a caffe2 model.""" - config = parse_config(Mode.TRAIN, context.obj.load_config()) + config = parse_config(context.obj.load_config()) print(f"Loading model from {exported_model or config.export_caffe2_path}") predictor = create_predictor(config, exported_model) diff --git a/pytext/workflow.py b/pytext/workflow.py index 9411c3ea7..aac0aba6d 100644 --- a/pytext/workflow.py +++ b/pytext/workflow.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. 
diff --git a/pytext/workflow.py b/pytext/workflow.py
index 9411c3ea7..aac0aba6d 100644
--- a/pytext/workflow.py
+++ b/pytext/workflow.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import os
-from typing import Any, Dict, List, Tuple, get_type_hints
+from typing import Any, Dict, List, Optional, Tuple, get_type_hints
 
 import torch
 from pytext.config import PyTextConfig, TestConfig
@@ -103,16 +103,28 @@ def export_saved_model_to_caffe2(
 
 
 def test_model(test_config: TestConfig, metrics_channel: Channel = None) -> Any:
-    _set_cuda(test_config.use_cuda_if_available)
+    return test_model_from_snapshot_path(
+        test_config.load_snapshot_path,
+        test_config.use_cuda_if_available,
+        test_config.test_path,
+        metrics_channel,
+    )
+
 
-    task, train_config = load(test_config.load_snapshot_path)
+def test_model_from_snapshot_path(
+    snapshot_path: str,
+    use_cuda_if_available: bool,
+    test_path: Optional[str] = None,
+    metrics_channel: Optional[Channel] = None,
+):
+    _set_cuda(use_cuda_if_available)
+    task, train_config = load(snapshot_path)
+    if not test_path:
+        test_path = train_config.task.data_handler.test_path
     if metrics_channel is not None:
         task.metric_reporter.add_channel(metrics_channel)
 
-    return (
-        task.test(test_config.test_path),
-        train_config.task.metric_reporter.output_path,
-    )
+    return (task.test(test_path), train_config.task.metric_reporter.output_path)
 
 
 def batch_predict(model_file: str, examples: List[Dict[str, Any]]):
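
The new function can also be called directly, which is what both the test CLI
and the TestConfig-based wrapper now do; a minimal sketch, with a hypothetical
snapshot path:

    from pytext.workflow import test_model_from_snapshot_path

    # test_path=None falls back to the test set recorded in the training config.
    test_results, metrics_output_path = test_model_from_snapshot_path(
        "/tmp/model.pt", use_cuda_if_available=False, test_path=None
    )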