Skip to content

Commit

Permalink
Merge branch 'main' into wanhan/fix_c_sharp_meta_generation
Browse files Browse the repository at this point in the history
  • Loading branch information
D-W- committed May 10, 2024
2 parents 941c565 + d9c62a2 commit 35f3b35
Show file tree
Hide file tree
Showing 12 changed files with 604 additions and 19 deletions.
27 changes: 27 additions & 0 deletions src/promptflow-azure/promptflow/azure/_entities/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from promptflow._utils.flow_utils import dump_flow_yaml_to_existing_path, load_flow_dag, resolve_flow_path
from promptflow._utils.logger_utils import LoggerFactory
from promptflow.azure._ml import AdditionalIncludesMixin, Code
from promptflow.core._model_configuration import MODEL_CONFIG_NAME_2_CLASS
from promptflow.exceptions import UserErrorException

from .._constants._flow import ADDITIONAL_INCLUDES, DEFAULT_STORAGE, ENVIRONMENT, PYTHON_REQUIREMENTS_TXT
from .._restclient.flow.models import FlowDto
Expand Down Expand Up @@ -67,6 +69,8 @@ def __init__(
self.code = kwargs.get("flow_resource_id", None)
elif self._flow_source == AzureFlowSource.INDEX:
self.code = kwargs.get("entity_id", None)
# set this in runtime to validate against signature
self._init_kwargs = None

def _validate_flow_from_source(self, source: Union[str, PathLike]) -> Path:
"""Validate flow from source.
Expand Down Expand Up @@ -162,6 +166,9 @@ def _try_build_local_code(self) -> Optional[Code]:
flow_file=flow_directory / flow_file, working_dir=flow_directory
)
dag_updated = update_signatures(code=flow_dir, data=flow_dag) or dag_updated
# validate init kwargs with signature
self._validate_init_kwargs(init_signatures=flow_dag.get("init"), init_kwargs=self._init_kwargs)
# validate and resolve environment
self._environment = self._resolve_environment(flow_dir, flow_dag)
if dag_updated:
dump_flow_yaml_to_existing_path(flow_dag, flow_dir)
Expand Down Expand Up @@ -248,3 +255,23 @@ def _to_dict(self):
@property
def language(self):
return self._flow_dict.get("language", FlowLanguage.Python)

@classmethod
def _validate_init_kwargs(cls, init_signatures: dict, init_kwargs: dict):
init_kwargs = init_kwargs or {}
if not isinstance(init_kwargs, dict):
raise UserErrorException(f"Init kwargs should be a dict, got {type(init_kwargs)}")
# validate init kwargs against signature
for param_name, param_value in init_kwargs.items():
if param_name not in init_signatures:
raise UserErrorException(
f"Init kwargs {param_name} is not in the flow signature. Current signatures: {init_signatures}"
)
param_signature = init_signatures[param_name]
param_type = param_signature.get("type")
if param_type in MODEL_CONFIG_NAME_2_CLASS:
if pydash.get(param_value, "connection") is None:
raise UserErrorException(
f"Init kwargs {param_name} with type {param_type} is missing connection. "
"Only connection model configs with connection is supported in cloud."
)
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,8 @@ def _resolve_flow_and_session_id(self, run: Run) -> Tuple[str, Optional[str]]:
if run._use_remote_flow:
return self._resolve_flow_definition_resource_id(run=run), None
flow = load_flow(run.flow)
# set init kwargs for validation
flow._init_kwargs = run.init
self._flow_operations._resolve_arm_id_or_upload_dependencies(
flow=flow,
# ignore .promptflow/dag.tools.json only for run submission scenario in python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from promptflow._sdk._constants import FLOW_TOOLS_JSON, PROMPT_FLOW_DIR_NAME, DownloadedRun, RunStatus
from promptflow._sdk._errors import InvalidRunError, InvalidRunStatusError, RunNotFoundError
from promptflow._sdk._load_functions import load_run
from promptflow._sdk.entities import Run
from promptflow._sdk.entities import AzureOpenAIConnection, Run
from promptflow._utils.flow_utils import get_flow_lineage_id
from promptflow._utils.yaml_utils import dump_yaml, load_yaml
from promptflow.azure import PFClient
Expand Down Expand Up @@ -1367,41 +1367,70 @@ def assert_func(details_dict):
@pytest.mark.skipif(not is_live(), reason="Content change in submission time which lead to recording issue.")
def test_model_config_obj_in_init(self, pf):
def assert_func(details_dict):
return details_dict["outputs.azure_open_ai_model_config_azure_endpoint"] != [None, None,] and details_dict[
"outputs.azure_open_ai_model_config_connection"
] == [None, None]
return details_dict["outputs.azure_open_ai_model_config_azure_endpoint"] != [None, None]

flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_model_config")
flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_single_model_config")
# init with model config object
config1 = AzureOpenAIModelConfiguration(azure_deployment="my_deployment", connection="azure_open_ai")
config2 = OpenAIModelConfiguration(model="my_model", base_url="fake_base_url")
run = pf.run(
flow=flow_path,
data=f"{EAGER_FLOWS_DIR}/basic_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1, "open_ai_model_config": config2},
data=f"{EAGER_FLOWS_DIR}/basic_single_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1},
)
assert "azure_open_ai_model_config" in run.properties["azureml.promptflow.init_kwargs"]
assert_batch_run_result(run, pf, assert_func)

@pytest.mark.skipif(not is_live(), reason="Content change in submission time which lead to recording issue.")
def test_model_config_dict_in_init(self, pf):
def assert_func(details_dict):
return details_dict["outputs.azure_open_ai_model_config_azure_endpoint"] != [None, None,] and details_dict[
"outputs.azure_open_ai_model_config_connection"
] == [None, None]
return details_dict["outputs.azure_open_ai_model_config_azure_endpoint"] != [None, None]

flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_model_config")
flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_single_model_config")
# init with model config dict
config1 = dict(azure_deployment="my_deployment", connection="azure_open_ai")
config2 = dict(model="my_model", base_url="fake_base_url")
run = pf.run(
flow=flow_path,
data=f"{EAGER_FLOWS_DIR}/basic_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1, "open_ai_model_config": config2},
data=f"{EAGER_FLOWS_DIR}/basic_single_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1},
)
assert "azure_open_ai_model_config" in run.properties["azureml.promptflow.init_kwargs"]
assert_batch_run_result(run, pf, assert_func)

def test_exception_in_model_config(self, pf):
flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_model_config")
error_msg = "Init kwargs open_ai_model_config with type OpenAIModelConfiguration is missing connection."

# init with model config dict
config1 = dict(azure_deployment="my_deployment", connection="azure_open_ai")
config2 = dict(model="my_model", base_url="fake_base_url")
with pytest.raises(UserErrorException) as e:
pf.run(
flow=flow_path,
data=f"{EAGER_FLOWS_DIR}/basic_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1, "open_ai_model_config": config2},
)
assert error_msg in str(e.value)

# init with model config object
config1 = AzureOpenAIModelConfiguration(azure_deployment="my_deployment", connection="azure_open_ai")
config2 = OpenAIModelConfiguration(model="my_model", base_url="fake_base_url")
with pytest.raises(UserErrorException) as e:
pf.run(
flow=flow_path,
data=f"{EAGER_FLOWS_DIR}/basic_model_config/inputs.jsonl",
init={"azure_open_ai_model_config": config1, "open_ai_model_config": config2},
)
assert error_msg in str(e.value)

# invalid model config value, non-json serializable object is not supported.
with pytest.raises(UserErrorException) as e:
pf.run(
flow=Path(f"{EAGER_FLOWS_DIR}/basic_callable_class"),
data=f"{EAGER_FLOWS_DIR}/basic_callable_class/inputs.jsonl",
init={"obj_input": AzureOpenAIConnection(api_base="fake_api_base")},
)
assert "Invalid init kwargs:" in str(e.value)


def assert_batch_run_result(run: Run, pf: PFClient, assert_func):
run = pf.runs.stream(run)
Expand Down
15 changes: 15 additions & 0 deletions src/promptflow-devkit/promptflow/_sdk/entities/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def __init__(
# TODO: such run is not resumable, not sure if we need specific error message for this case.
self._dynamic_callable = kwargs.get("dynamic_callable", None)
if init:
# validate if provided init kwargs for early exception
self._validate_init(init)
self._properties[FlowRunProperties.INIT_KWARGS] = init

def _copy(self, **kwargs):
Expand Down Expand Up @@ -578,6 +580,19 @@ def _get_flow_dir(self) -> Path:
def _get_schema_cls(self):
return RunSchema

@classmethod
def _validate_init(cls, init: Dict[str, Any]) -> Dict[str, Any]:
"""Validate and parse init kwargs."""
if not init:
return {}
if not isinstance(init, dict):
raise UserErrorException(f"Invalid init kwargs: {init}. Expecting a dictionary.")

try:
json.dumps(init, default=asdict)
except Exception as e:
raise UserErrorException(f"Invalid init kwargs: {init}. Expecting a json serializable dictionary.") from e

@classmethod
def _to_rest_init(cls, init):
"""Convert init to rest object."""
Expand Down
17 changes: 17 additions & 0 deletions src/promptflow-devkit/tests/sdk_cli_test/unittests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow._utils.context_utils import inject_sys_path
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import AzureOpenAIConnection
from promptflow.exceptions import UserErrorException, ValidationException

FLOWS_DIR = Path("./tests/test_configs/flows")
Expand Down Expand Up @@ -268,3 +269,19 @@ def static_method():
for entry in [non_callable, function, obj.method, obj.class_method, obj.static_method, MyClass.class_method]:
with pytest.raises(UserErrorException):
callable_to_entry_string(entry)

@pytest.mark.parametrize(
"init_val, expected_error_msg",
[
("val", "Invalid init kwargs: val"),
(
{"obj_input": AzureOpenAIConnection(api_base="fake_api_base")},
"Expecting a json serializable dictionary.",
),
],
)
def test_invalid_init_kwargs(self, pf, init_val, expected_error_msg):
flow_path = Path(f"{EAGER_FLOWS_DIR}/basic_callable_class")
with pytest.raises(UserErrorException) as e:
pf.run(flow=flow_path, data=f"{EAGER_FLOWS_DIR}/basic_callable_class/inputs.jsonl", init=init_val)
assert expected_error_msg in str(e.value)

0 comments on commit 35f3b35

Please sign in to comment.