diff --git a/ludwig/utils/triton_utils.py b/ludwig/utils/triton_utils.py index 77a2a334e1a..ce20c581cc8 100644 --- a/ludwig/utils/triton_utils.py +++ b/ludwig/utils/triton_utils.py @@ -163,7 +163,7 @@ def _get_model_config(model: LudwigModel) -> str: def export_triton( model: LudwigModel, output_path: str, model_name: str = "ludwig_model", model_version: Union[int, str] = 1 -) -> str: +) -> (str, str): """Exports a torchscript model to a output path that serves as a repository for Triton Inference Server. # Inputs @@ -174,7 +174,7 @@ def export_triton( :param model_name: (Union[int,str]) The optional model verison. # Return - :return: (str) The saved model path. + :return: (str, str) The saved model path, and config path. """ model_ts = generate_triton_torchscript(model) model_dir = os.path.join(output_path, model_name, str(model_version)) @@ -186,4 +186,4 @@ def export_triton( config_path = os.path.join(output_path, model_name, "config.pbtxt") with open(config_path, "w") as f: f.write(_get_model_config(model)) - return model_path + return model_path, config_path diff --git a/tests/integration_tests/test_triton.py b/tests/integration_tests/test_triton.py index b910d3254b4..c25e14493f9 100644 --- a/tests/integration_tests/test_triton.py +++ b/tests/integration_tests/test_triton.py @@ -90,11 +90,13 @@ def test_triton_torchscript(csv_filename, tmpdir): triton_path = os.path.join(tmpdir, "triton") model_name = "test_triton" model_version = 1 - model_path = export_triton(ludwig_model, triton_path, model_name, model_version) + model_path, config_path = export_triton(ludwig_model, triton_path, model_name, model_version) # Validate relative path output_filename = os.path.relpath(model_path, triton_path) assert output_filename == f"{model_name}/{model_version}/model.pt" + config_filename = os.path.relpath(config_path, triton_path) + assert config_filename == f"{model_name}/config.pbtxt" # Restore the torchscript model restored_model = torch.jit.load(model_path)