Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Document top-level API a little, make testing from the command line e…
Browse files Browse the repository at this point in the history
…asier. (#106)

Summary:
Pull Request resolved: #106

Add docstrings and types to a couple functions exported to the top-level PyText API.

Update the test cli mode to take in a PyTextConfig rather than a TestConfig for easier command line usage.

Reviewed By: ahhegazy

Differential Revision: D13367488

fbshipit-source-id: 2347e63f4e31a737566ef6aabd3399a79e6ff023
  • Loading branch information
m3rlin45 authored and facebook-github-bot committed Dec 14, 2018
1 parent dc81bac commit 0b7db65
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 38 deletions.
27 changes: 26 additions & 1 deletion pytext/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import uuid
from typing import Callable, Mapping, Optional

import numpy as np
from caffe2.python import workspace
from caffe2.python.predictor import predictor_exporter

from .builtin_task import register_builtin_tasks
from .config import PyTextConfig, config_from_json
from .config.component import create_featurizer
from .data.featurizer import InputRecord
from .utils.onnx_utils import CAFFE2_DB_TYPE, convert_caffe2_blob_name
Expand All @@ -15,6 +18,9 @@
register_builtin_tasks()


Predictor = Callable[[Mapping[str, str]], Mapping[str, np.array]]


def _predict(workspace_id, feature_config, predict_net, featurizer, input):
workspace.SwitchWorkspace(workspace_id)
features = featurizer.featurize(InputRecord(**input))
Expand Down Expand Up @@ -49,7 +55,26 @@ def _predict(workspace_id, feature_config, predict_net, featurizer, input):
}


def create_predictor(config, model_file=None):
def load_config(filename: str) -> PyTextConfig:
"""
Load a PyText configuration file from a file path.
See pytext.config.pytext_config for more info on configs.
"""
with open(filename) as file:
config_json = json.loads(file.read())
if "config" not in config_json:
return config_from_json(PyTextConfig, config_json)
return config_from_json(PyTextConfig, config_json["config"])


def create_predictor(
config: PyTextConfig, model_file: Optional[str] = None
) -> Predictor:
"""
Create a simple prediction API from a training config and an exported caffe2
model file. This model file should be created by calling export on a trained
model snapshot.
"""
workspace_id = str(uuid.uuid4())
workspace.SwitchWorkspace(workspace_id, True)
predict_net = predictor_exporter.prepare_prediction_net(
Expand Down
14 changes: 3 additions & 11 deletions pytext/config/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
from .pytext_config import PyTextConfig, TestConfig


class Mode(Enum):
TRAIN = "train"
TEST = "test"


class ConfigParseError(Exception):
pass

Expand Down Expand Up @@ -219,13 +214,10 @@ def _get_class_type(cls):
return cls.__origin__ if hasattr(cls, "__origin__") else cls


def parse_config(mode, config_json):
def parse_config(config_json):
"""
Parse PyTextConfig object from parameter string or parameter file
"""
config_cls = {Mode.TRAIN: PyTextConfig, Mode.TEST: TestConfig}[mode]
# TODO T32608471 should assume the entire json is PyTextConfig later, right
# now we're matching the file format for pytext trainer.py inside fbl
if "config" not in config_json:
return config_from_json(config_cls, config_json)
return config_from_json(config_cls, config_json["config"])
return config_from_json(PyTextConfig, config_json)
return config_from_json(PyTextConfig, config_json["config"])
5 changes: 2 additions & 3 deletions pytext/config/test/pytext_all_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import unittest

from pytext.builtin_task import register_builtin_tasks
from pytext.config.serialize import Mode, parse_config
from pytext.config.serialize import parse_config


register_builtin_tasks()
Expand All @@ -32,6 +32,5 @@ def test_load_all_configs(self):
print("--- loading:", filename)
with open(filename) as file:
config_json = json.load(file)
# Most configs don't work in Mode.TEST
config = parse_config(Mode.TRAIN, config_json)
config = parse_config(config_json)
self.assertIsNotNone(config)
56 changes: 40 additions & 16 deletions pytext/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from pytext import create_predictor
from pytext.config import PyTextConfig, TestConfig
from pytext.config.serialize import Mode, config_from_json, config_to_json, parse_config
from pytext.config.serialize import config_from_json, config_to_json, parse_config
from pytext.task import load
from pytext.utils.documentation_helper import (
ROOT_CONFIG,
Expand All @@ -22,7 +22,7 @@
from pytext.workflow import (
batch_predict,
export_saved_model_to_caffe2,
test_model,
test_model_from_snapshot_path,
train_model,
)
from torch.multiprocessing.spawn import spawn
Expand Down Expand Up @@ -180,33 +180,57 @@ def gen_default_config(context, task_name, options):


@main.command()
@click.option(
"--model-snapshot",
default="",
help="load model snapshot and test configuration from this file",
)
@click.option("--test-path", default="", help="path to test data")
@click.option(
"--use-cuda/--no-cuda",
default=None,
help="Run supported parts of the model on GPU if available.",
)
@click.pass_context
def test(context):
"""Test a trained model snapshot."""
config_json = context.obj.load_config()
config = parse_config(Mode.TEST, config_json)
def test(context, model_snapshot, test_path, use_cuda):
"""Test a trained model snapshot.
If model-snapshot is provided, the models and configuration will then be loaded from
the snapshot rather than any passed config file.
Otherwise, a config file will be loaded.
"""
if model_snapshot:
print(f"Loading model snapshot and config from {model_snapshot}")
if use_cuda is None:
raise Exception(
"if --model-snapshot is set --use-cuda/--no-cuda must be set"
)
else:
print(f"No model snapshot provided, loading from config")
config = parse_config(context.obj.load_config())
model_snapshot = config.save_snapshot_path
use_cuda = config.use_cuda_if_available
print(f"Configured model snapshot {model_snapshot}")
print("\n=== Starting testing...")
test_model(config)
test_model_from_snapshot_path(model_snapshot, use_cuda, test_path)


@main.command()
@click.pass_context
def train(context):
"""Train a model and save the best snapshot."""
config_json = context.obj.load_config()
config = parse_config(Mode.TRAIN, config_json)
config = parse_config(context.obj.load_config())
print("\n===Starting training...")
if config.distributed_world_size == 1:
train_model(config)
else:
train_model_distributed(config)
print("\n=== Starting testing...")
test_config = TestConfig(
load_snapshot_path=config.save_snapshot_path,
test_path=config.task.data_handler.test_path,
use_cuda_if_available=config.use_cuda_if_available,
test_model_from_snapshot_path(
config.save_snapshot_path,
config.use_cuda_if_available,
config.task.data_handler.test_path,
)
test_model(test_config)


@main.command()
Expand All @@ -215,7 +239,7 @@ def train(context):
@click.pass_context
def export(context, model, output_path):
"""Convert a pytext model snapshot to a caffe2 model."""
config = parse_config(Mode.TRAIN, context.obj.load_config())
config = parse_config(context.obj.load_config())
model = model or config.save_snapshot_path
output_path = output_path or config.export_caffe2_path
print(f"Exporting {model} to {output_path}")
Expand All @@ -227,7 +251,7 @@ def export(context, model, output_path):
@click.pass_context
def predict(context, exported_model):
"""Start a repl executing examples against a caffe2 model."""
config = parse_config(Mode.TRAIN, context.obj.load_config())
config = parse_config(context.obj.load_config())
print(f"Loading model from {exported_model or config.export_caffe2_path}")
predictor = create_predictor(config, exported_model)

Expand Down
26 changes: 19 additions & 7 deletions pytext/workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
from typing import Any, Dict, List, Tuple, get_type_hints
from typing import Any, Dict, List, Optional, Tuple, get_type_hints

import torch
from pytext.config import PyTextConfig, TestConfig
Expand Down Expand Up @@ -103,16 +103,28 @@ def export_saved_model_to_caffe2(


def test_model(test_config: TestConfig, metrics_channel: Channel = None) -> Any:
_set_cuda(test_config.use_cuda_if_available)
return test_model_from_snapshot_path(
test_config.load_snapshot_path,
test_config.use_cuda_if_available,
test_config.test_path,
metrics_channel,
)


task, train_config = load(test_config.load_snapshot_path)
def test_model_from_snapshot_path(
snapshot_path: str,
use_cuda_if_available: bool,
test_path: Optional[str] = None,
metrics_channel: Optional[Channel] = None,
):
_set_cuda(use_cuda_if_available)
task, train_config = load(snapshot_path)
if not test_path:
test_path = train_config.task.data_handler.test_path
if metrics_channel is not None:
task.metric_reporter.add_channel(metrics_channel)

return (
task.test(test_config.test_path),
train_config.task.metric_reporter.output_path,
)
return (task.test(test_path), train_config.task.metric_reporter.output_path)


def batch_predict(model_file: str, examples: List[Dict[str, Any]]):
Expand Down

0 comments on commit 0b7db65

Please sign in to comment.