
Merge pull request #78 from Gabriel2409/feature/on-schedule-job-callback
Feature/on schedule job callback
marrrcin committed Sep 29, 2023
2 parents 5ae7bcf + d2d0906 commit 1a7f8dd
Showing 6 changed files with 211 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

## [Unreleased]

- Added `--on-job-scheduled` argument to `kedro azureml run` to plug in custom behaviour after the Azure ML job is scheduled [@Gabriel2409](https://github.com/Gabriel2409)

## [0.6.0] - 2023-09-01

- Added ability to mark a node as deterministic (enables caching on Azure ML) by [@tomasvanpottelbergh](https://github.com/tomasvanpottelbergh)
15 changes: 15 additions & 0 deletions docs/source/03_quickstart.rst
@@ -439,6 +439,21 @@ In case you need to customize pipeline run context, modifying configuration file
- ``--pipeline`` allows selecting the pipeline to run (by default, the ``__default__`` pipeline is started),
- ``--params`` takes a JSON string with a parameters override (a JSONed version of ``conf/*/parameters.yml``, not Kedro's ``params:`` syntax),
- ``--env-var KEY=VALUE`` sets the OS environment variable injected into the steps during runtime (can be used multiple times),
- ``--load-versions`` specifies a particular dataset version (timestamp) for loading (similar behavior to Kedro),
- ``--on-job-scheduled path.to.module:my_function`` specifies a callback function to be called when the Azure ML pipeline job is scheduled (example below).

.. code:: python

    # src/mymodule/myfile.py
    def save_output_callback(job):
        """saves the pipeline job name to a file"""
        with open("myfile.txt", "w") as f:
            f.write(job.name)

.. code:: console

    kedro azureml run --on-job-scheduled mymodule.myfile:save_output_callback
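
As another illustration (a minimal sketch with a hypothetical module path; only ``job.name`` and ``job.studio_url`` are attributes referenced elsewhere in this change), a callback can simply echo where to find the scheduled run:

.. code:: python

    # src/mymodule/myfile.py (illustrative)
    def log_job_callback(job):
        """Prints the scheduled job's name and its Azure ML Studio URL."""
        print(f"Scheduled Azure ML job: {job.name}")
        print(f"Studio URL: {job.studio_url}")

It would be registered the same way, e.g. ``kedro azureml run --on-job-scheduled mymodule.myfile:log_job_callback``.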
.. |br| raw:: html

22 changes: 20 additions & 2 deletions kedro_azureml/cli.py
@@ -2,14 +2,16 @@
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import click
from kedro.framework.cli.project import LOAD_VERSION_HELP
from kedro.framework.cli.utils import _split_load_versions
from kedro.framework.startup import ProjectMetadata

from kedro_azureml.cli_functions import (
    default_job_callback,
    dynamic_import_job_schedule_func_from_str,
    get_context_and_pipeline,
    parse_extra_env_params,
    parse_extra_params,
@@ -216,6 +218,18 @@ def init(
    help=LOAD_VERSION_HELP,
    callback=_split_load_versions,
)
@click.option(
    "--on-job-scheduled",
    "on_job_scheduled",
    callback=dynamic_import_job_schedule_func_from_str,
    help="""Specify a function to execute when the azureml pipeline job
    is scheduled. The function should be in the format 'path.to.module:function' with
    'path.to.module' being the relative path starting from the src folder created on
    kedro initialisation.
    The function will be called with the scheduled pipeline job as an argument just
    after the job creation. Return values will be discarded.
    Defaults to echoing the job.studio_url""",
)
@click.pass_obj
@click.pass_context
def run(
@@ -229,6 +243,7 @@ def run(
    wait_for_completion: bool,
    env_var: Tuple[str],
    load_versions: Dict[str, str],
    on_job_scheduled: Optional[Callable],
):
    """Runs the specified pipeline in Azure ML Pipelines; Additional parameters can be passed from command line.
    Can be used with --wait-for-completion param to block the caller until the pipeline finishes in Azure ML.
@@ -255,10 +270,13 @@
    ):
        az_client = AzureMLPipelinesClient(az_pipeline, subscription_id)

        if not on_job_scheduled:
            on_job_scheduled = default_job_callback

        is_ok = az_client.run(
            mgr.plugin_config.azure,
            wait_for_completion,
            lambda job: click.echo(job.studio_url),
            on_job_scheduled,
        )

        if is_ok:
69 changes: 68 additions & 1 deletion kedro_azureml/cli_functions.py
@@ -1,10 +1,11 @@
import importlib
import json
import logging
import os
import re
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, Optional

import click

@@ -148,3 +149,69 @@ def parse_extra_env_params(extra_env):
            raise Exception(f"Invalid env-var: {entry}, expected format: KEY=VALUE")

    return {(e := entry.split("=", maxsplit=1))[0]: e[1] for entry in extra_env}
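
A quick illustration of the comprehension above (a self-contained sketch; the sample value mirrors the existing test parametrization): `split("=", maxsplit=1)` ensures only the first `=` separates the key from the value.

extra_env = ["A=CDE=F123"]
parsed = {(e := entry.split("=", maxsplit=1))[0]: e[1] for entry in extra_env}
assert parsed == {"A": "CDE=F123"}  # everything after the first "=" stays in the value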


def dynamic_import_job_schedule_func_from_str(
    ctx: click.Context,
    param: click.Parameter,
    import_str: str,
) -> Optional[Callable]:
    """
    Dynamically import and retrieve a function from a specified module.
    The function must have exactly one parameter of type azure.ai.ml.entities.Job.
    Note that there is no check on the parameter type.
    This function is designed to be used in Click-based command-line applications.
    :param ctx: The Click context.
    :type ctx: click.Context
    :param param: The Click parameter associated with this function.
    :type param: click.Parameter
    :param import_str: A string in the format 'path.to.file:function'
        specifying the module and function to import.
    :type import_str: str
    :returns: The imported function.
    :rtype: Optional[Callable]
    :raises click.BadParameter: If the `import_str` is not in the correct format,
        if the specified module cannot be imported,
        if the specified attribute cannot be retrieved from the module,
        or if the retrieved attribute is not a callable function.
    Example usage:
    >>> instance = dynamic_import_job_schedule_func_from_str(
    ...     ctx, param, "my_module:my_function"
    ... )
    Inspired by the `uvicorn/importer.py` module's `import_from_string` function.
    """
    # base case: no callback
    if import_str is None:
        return

    # check format
    module_str, _, attrs_str = import_str.partition(":")
    if not module_str or not attrs_str:
        raise click.BadParameter(
            "import_str must be in format <module>:<function>", param=param
        )

    try:
        module = importlib.import_module(module_str)
        instance = getattr(module, attrs_str)

        # fails if we try to import an attribute that is not a function
        if not callable(instance):
            raise click.BadParameter(
                f"The attribute '{attrs_str}' is not a callable function.", param=param
            )

        return instance
    except (ImportError, AttributeError, ValueError) as e:
        # catches errors if module or attribute does not exist
        raise click.BadParameter(f"Error: {e}", param=param)


def default_job_callback(job):
    click.echo(job.studio_url)
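
For orientation, a minimal sketch of what this resolution amounts to outside of Click (assuming the plugin is installed; the import string points at the `default_job_callback` defined above):

import importlib

# Resolve a "<module>:<function>" string the same way the Click callback does:
# import the module, then look up the attribute and verify it is callable.
import_str = "kedro_azureml.cli_functions:default_job_callback"
module_str, _, attr_str = import_str.partition(":")
callback = getattr(importlib.import_module(module_str), attr_str)
assert callable(callback)

# az_client.run(...) later invokes the resolved callable with the scheduled
# pipeline job, e.g. callback(job), which by default echoes job.studio_url.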
15 changes: 15 additions & 0 deletions tests/helpers/on_job_scheduled_helper.py
@@ -0,0 +1,15 @@
"""Main purpose of this file is to help test the --on-job-scheduled argument in
`kedro azureml run`. The reason it is needed is to avoid having to patch the
importlib.import_module method to return a dummy module, which for some reason
causes the rest of the run to fail.
"""


def existing_function(job):
    """Purpose of this function is to be mocked. It must still exist so that
    getattr does not raise an error after importing the module with importlib"""
    return


# Purpose of this variable is to test what happens when we pass a non-callable attribute
existing_attr = True
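
To illustrate why the non-callable `existing_attr` is kept around (a sketch assuming the repository root, and therefore the `tests` package, is importable), feeding it through the resolver raises `click.BadParameter`:

import click

from kedro_azureml.cli_functions import dynamic_import_job_schedule_func_from_str

try:
    # `existing_attr` is a module-level boolean, so the callable check fails.
    dynamic_import_job_schedule_func_from_str(
        None, None, "tests.helpers.on_job_scheduled_helper:existing_attr"
    )
except click.BadParameter as exc:
    print(exc.message)  # The attribute 'existing_attr' is not a callable function.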
93 changes: 91 additions & 2 deletions tests/test_cli.py
@@ -256,6 +256,19 @@ def test_can_invoke_execute_cli(
        (["A=CDE=F123"], {"A": "CDE=F123"}),
    ),
)
@pytest.mark.parametrize(
    "on_job_scheduled",
    # we pass args as True/False but the test will generate a valid string
    # if we pass True
    (
        False,
        True,
    ),
    ids=(
        "no_job_scheduled",
        "existing_callback",
    ),
)
def test_can_invoke_run(
    patched_kedro_package,
    cli_context,
@@ -267,6 +280,7 @@ def test_can_invoke_run(
    amlignore: str,
    gitignore: str,
    extra_env: list,
    on_job_scheduled: str,
):
    create_kedro_conf_dirs(tmp_path)
    with patch.dict(
@@ -279,7 +293,11 @@
        "kedro_azureml.auth.utils.InteractiveBrowserCredential"
    ) as interactive_credentials, patch.dict(
        os.environ, {"AZURE_STORAGE_ACCOUNT_KEY": "dummy_key"}
    ):
    ), patch(
        # mock of existing_function to test --on-job-scheduled
        "tests.helpers.on_job_scheduled_helper.existing_function",
        return_value=MagicMock(),
    ) as on_job_scheduled_callback_mock:
        if not use_default_credentials:
            default_credentials.side_effect = ValueError()

@@ -299,7 +317,15 @@
            ["-s", "subscription_id"]
            + (["--wait-for-completion"] if wait_for_completion else [])
            + (["--aml-env", aml_env] if aml_env else [])
            + (sum([["--env-var", k] for k in extra_env[0]], [])),
            + (sum([["--env-var", k] for k in extra_env[0]], []))
            + (
                [
                    "--on-job-scheduled",
                    "tests.helpers.on_job_scheduled_helper:existing_function",
                ]
                if on_job_scheduled
                else []
            ),
            obj=cli_context,
        )
        assert result.exit_code == 0
@@ -308,6 +334,11 @@
        ml_client.jobs.create_or_update.assert_called_once()
        ml_client.compute.get.assert_called_once()

        if on_job_scheduled:
            on_job_scheduled_callback_mock.assert_called_once()
        else:
            on_job_scheduled_callback_mock.assert_not_called()

        if wait_for_completion:
            ml_client.jobs.stream.assert_called_once()

@@ -443,3 +474,61 @@ def test_fail_if_invalid_env_provided_in_run(
        str(result.exception)
        == f"Invalid env-var: {env_var}, expected format: KEY=VALUE"
    )


@pytest.mark.parametrize(
    "on_job_scheduled",
    (
        "bad_str_format",
        "nonexistant_module:func",
        "tests.helpers.on_job_scheduled_helper:absent_attr",
        "tests.helpers.on_job_scheduled_helper:existing_attr",
    ),
    ids=(
        "bad_str_format",
        "no_module",
        "no_attr",
        "not_callable",
    ),
)
def test_fail_if_invalid_on_job_scheduled_provided_in_run(
    patched_kedro_package,
    cli_context,
    dummy_pipeline,
    tmp_path: Path,
    on_job_scheduled: str,
):
    create_kedro_conf_dirs(tmp_path)
    with patch.dict(
        "kedro.framework.project.pipelines", {"__default__": dummy_pipeline}
    ), patch.object(Path, "cwd", return_value=tmp_path), patch(
        "kedro_azureml.client.MLClient"
    ) as ml_client_patched, patch(
        "kedro_azureml.auth.utils.DefaultAzureCredential"
    ), patch.dict(
        os.environ, {"AZURE_STORAGE_ACCOUNT_KEY": "dummy_key"}
    ):
        ml_client = ml_client_patched.from_config()
        ml_client.jobs.stream.side_effect = ValueError()

        runner = CliRunner()
        result = runner.invoke(
            cli.run, ["--on-job-scheduled", on_job_scheduled], obj=cli_context
        )
        assert result.exit_code != 0
        assert result.exception, "Exception should have been raised"

        if on_job_scheduled == "bad_str_format":
            assert "import_str must be in format <module>:<function>" in result.output
        elif on_job_scheduled == "nonexistant_module:func":
            assert "No module named 'nonexistant_module'" in result.output
        elif on_job_scheduled == "tests.helpers.on_job_scheduled_helper:absent_attr":
            assert (
                "module 'tests.helpers.on_job_scheduled_helper' has no attribute 'absent_attr'"
                in result.output
            )
        elif on_job_scheduled == "tests.helpers.on_job_scheduled_helper:existing_attr":
            assert (
                "The attribute 'existing_attr' is not a callable function"
                in result.output
            )
