Skip to content

Commit

Permalink
Update api to return both model path and config path
Browse files Browse the repository at this point in the history
  • Loading branch information
brightsparc committed Jun 1, 2022
1 parent 1b14b21 commit c156cb2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions ludwig/utils/triton_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _get_model_config(model: LudwigModel) -> str:

def export_triton(
model: LudwigModel, output_path: str, model_name: str = "ludwig_model", model_version: Union[int, str] = 1
) -> str:
) -> (str, str):
"""Exports a torchscript model to a output path that serves as a repository for Triton Inference Server.
# Inputs
Expand All @@ -174,7 +174,7 @@ def export_triton(
:param model_name: (Union[int,str]) The optional model verison.
# Return
:return: (str) The saved model path.
:return: (str, str) The saved model path, and config path.
"""
model_ts = generate_triton_torchscript(model)
model_dir = os.path.join(output_path, model_name, str(model_version))
Expand All @@ -186,4 +186,4 @@ def export_triton(
config_path = os.path.join(output_path, model_name, "config.pbtxt")
with open(config_path, "w") as f:
f.write(_get_model_config(model))
return model_path
return model_path, config_path
4 changes: 3 additions & 1 deletion tests/integration_tests/test_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,13 @@ def test_triton_torchscript(csv_filename, tmpdir):
triton_path = os.path.join(tmpdir, "triton")
model_name = "test_triton"
model_version = 1
model_path = export_triton(ludwig_model, triton_path, model_name, model_version)
model_path, config_path = export_triton(ludwig_model, triton_path, model_name, model_version)

# Validate relative path
output_filename = os.path.relpath(model_path, triton_path)
assert output_filename == f"{model_name}/{model_version}/model.pt"
config_filename = os.path.relpath(config_path, triton_path)
assert config_filename == f"{model_name}/config.pbtxt"

# Restore the torchscript model
restored_model = torch.jit.load(model_path)
Expand Down

0 comments on commit c156cb2

Please sign in to comment.