diff --git a/datalayer_core/sdk/datalayer.py b/datalayer_core/sdk/datalayer.py index 09341a1b..ce3f7eea 100644 --- a/datalayer_core/sdk/datalayer.py +++ b/datalayer_core/sdk/datalayer.py @@ -220,6 +220,8 @@ def create_runtime( if name is None: name = f"runtime-{environment}-{uuid.uuid4()}" + # print(f"Runtime {name}") + if snapshot_name is not None: snapshots = self.list_snapshots() for snapshot in snapshots: @@ -738,7 +740,9 @@ def _start(self) -> None: if self._kernel_client is None: self._runtime = self._create_runtime(self._environment_name) - runtime: dict[str, str] = self._runtime.get("runtime") # type: ignore + # print(self._runtime) + runtime: dict[str, str] = self._runtime["runtime"] # type: ignore + # print("runtime", runtime) self._ingress = runtime["ingress"] self._kernel_token = runtime["token"] self._pod_name = runtime["pod_name"] @@ -820,7 +824,10 @@ def set_variables(self, variables: dict[str, Any]) -> Response: return Response([]) def execute_file( - self, path: Union[str, Path], variables: Optional[dict[str, Any]] = None + self, + path: Union[str, Path], + variables: Optional[dict[str, Any]] = None, + output: Optional[str] = None, ) -> Response: """ Execute a Python file in the runtime. @@ -831,6 +838,8 @@ def execute_file( Path to the Python file to execute. variables: Optional[dict[str, Any]] Optional variables to set before executing the code. + output: Optional[str] + Optional output variable to return as result. Returns ------- @@ -841,18 +850,26 @@ def execute_file( if variables: self.set_variables(variables) - for _id, cell in _get_cells(fname): - if self._kernel_client: + if self._kernel_client: + outputs = [] + for _id, cell in _get_cells(fname): reply = self._kernel_client.execute_interactive( cell, silent=False, ) - return Response(reply.get("outputs", [])) + outputs.append(reply.get("outputs", [])) + if output is not None: + return self.get_variable(output) + + return Response(outputs) return Response([]) def execute_code( - self, code: str, variables: Optional[dict[str, Any]] = None - ) -> Response: + self, + code: str, + variables: Optional[dict[str, Any]] = None, + output: Optional[str] = None, + ) -> Union[Response, Any]: """ Execute code in the runtime. @@ -862,6 +879,8 @@ def execute_code( The Python code to execute. variables: Optional[dict[str, Any]] Optional variables to set before executing the code. + output: Optional[str] + Optional output variable to return as result. Returns ------- @@ -874,6 +893,8 @@ def execute_code( self.set_variables(variables) reply = self._kernel_client.execute(code) result = reply.get("outputs", {}) + if output is not None: + return self.get_variable(output) else: raise RuntimeError( "Kernel client is not started. Call `start()` first." @@ -884,8 +905,11 @@ def execute_code( return Response([]) def execute( - self, code_or_path: Union[str, Path], variables: Optional[dict[str, Any]] = None - ) -> Response: + self, + code_or_path: Union[str, Path], + variables: Optional[dict[str, Any]] = None, + output: Optional[str] = None, + ) -> Union[Response, Any]: """ Execute code in the runtime. @@ -895,10 +919,13 @@ def execute( The Python code or path to the file to execute. variables: Optional[dict[str, Any]] Optional variables to set before executing the code. + output: Optional[str] + Optional output variable to return as result. Returns ------- - dict: The result of the code execution. + dict: + The result of the code execution. { @@ -916,9 +943,13 @@ def execute( } """ if self._check_file(code_or_path): - return self.execute_file(str(code_or_path), variables) + return self.execute_file( + str(code_or_path), variables=variables, output=output + ) else: - return self.execute_code(str(code_or_path), variables) + return self.execute_code( + str(code_or_path), variables=variables, output=output + ) def terminate(self) -> bool: """Terminate the Runtime.""" diff --git a/datalayer_core/sdk/decorators.py b/datalayer_core/sdk/decorators.py new file mode 100644 index 00000000..8b719fa7 --- /dev/null +++ b/datalayer_core/sdk/decorators.py @@ -0,0 +1,134 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +import functools +import inspect +from typing import Any, Callable, Optional, Union + +from datalayer_core.sdk.datalayer import DatalayerClient + +# TODO: +# - inputs are different from args and kwargs (rename) +# - inputs cannot be kewyword args of the function +# - incorrect number of args + + +def datalayer( + runtime_name: Union[Callable[..., Any], Optional[str]] = None, + inputs: Optional[list[str]] = None, + output: Optional[str] = None, + snapshot_name: Optional[str] = None, +) -> Any: + """ + Decorator to execute a function in a Datalayer runtime. + + Parameters + ---------- + runtime_name : str, optional + The name of the runtime to use. If not provided, a default runtime will be used. + inputs : list[str], optional + A list of input variable names for the function. + output : str, optional + The name of the output variable for the function + snapshot_name : str, optional + The name of the runtime snapshot to use + + Returns + ------- + Callable[..., Any] + A decorator that wraps the function to be executed in a Datalayer runtime. + + Examples + -------- + + >>> from datalayer_core.sdk.decorators import datalayer + >>> @datalayer + ... def example(x: float, y: float) -> float: + ... return x + y + + >>> from datalayer_core.sdk.decorators import datalayer + >>> @datalayer(runtime_name="example-runtime", inputs=["x", "y"], output="z") + ... def example(x: float, y: float) -> float: + ... return x + y + """ + variables = {} + inputs_decorated = inputs + output_decorated = output + snapshot_name_decorated = snapshot_name + + if callable(runtime_name): + runtime_name_decorated = None + else: + runtime_name_decorated = runtime_name + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + if output_decorated is None: + output = f"DATALAYER_RUNTIME_OUTPUT_{func.__name__}".upper() + + sig = inspect.signature(func) + if inputs_decorated is None: + inputs = [] + for name, _param in sig.parameters.items(): + inputs.append(name) + variables[name] = ( + _param.default + if _param.default is not inspect.Parameter.empty + else None + ) + else: + if len(sig.parameters) != len(inputs_decorated): + raise ValueError( + f"Function {func.__name__} has {len(sig.parameters)} parameters, " + f"but {len(inputs_decorated)} inputs were provided." + ) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(func) + mapping = {} + for idx, (name, _param) in enumerate(sig.parameters.items()): + mapping[name] = (inputs_decorated or inputs)[idx] + + for kwarg, kwarg_value in kwargs.items(): + variables[mapping[kwarg]] = kwarg_value + + for idx, (arg_value) in enumerate(args): + kwarg = (inputs_decorated or inputs)[idx] + variables[kwarg] = arg_value + + function_call = ( + f"{output_decorated or output} = {func.__name__}(" + + ", ".join(inputs_decorated or inputs) + + ")" + ) + + start = 0 + func_source_lines = inspect.getsource(func).split("\n") + for start, line in enumerate(func_source_lines): + if line.startswith("def "): + break + function_source = "\n".join(func_source_lines[start:]) + + # print("inputs", inputs_decorated or inputs) + # print("variables", variables) + # print([function_source]) + # print([function_call]) + + client = DatalayerClient() + with client.create_runtime( + name=runtime_name_decorated, snapshot_name=snapshot_name_decorated + ) as runtime: + runtime.execute(function_source) + return runtime.execute( + function_call, + variables=variables, + output=output_decorated or output, + ) + + return wrapper + + # print(f"Using runtime: {runtime_name}, inputs: {inputs}, output: {output}") + if callable(runtime_name): + return decorator(runtime_name) + else: + return decorator diff --git a/datalayer_core/tests/test_sdk_decorators.py b/datalayer_core/tests/test_sdk_decorators.py new file mode 100644 index 00000000..ad4ce24c --- /dev/null +++ b/datalayer_core/tests/test_sdk_decorators.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023-2025 Datalayer, Inc. +# Distributed under the terms of the Modified BSD License. + +import os +import time + +import pytest +from dotenv import load_dotenv + +from datalayer_core.sdk.decorators import datalayer + +load_dotenv() + + +DATALAYER_TEST_TOKEN = os.environ.get("DATALAYER_TEST_TOKEN") + + +def sum_test(x: float, y: float, z: float = 1) -> float: + return x + y + z + + +@pytest.mark.parametrize( + "args,expected_output,decorator", + [ + ([1, 4.5, 2], 7.5, datalayer), + ([1, 4.5, 2], 7.5, datalayer(runtime_name="runtime-test")), + ([1, 4.5, 2], 7.5, datalayer(output="result")), + ([1, 4.5, 2], 7.5, datalayer(inputs=["a", "b", "c"])), + ], +) +@pytest.mark.skipif( + not bool(DATALAYER_TEST_TOKEN), + reason="DATALAYER_TEST_TOKEN is not set, skipping secret tests.", +) +def test_decorator(args, expected_output, decorator): # type: ignore + """ + Test the Datalayer decorator. + """ + time.sleep(10) + func = decorator(sum_test) + assert func(*args) == expected_output + time.sleep(10) diff --git a/examples/sdk.py b/examples/sdk.py index 0de07a6d..2a0691fe 100644 --- a/examples/sdk.py +++ b/examples/sdk.py @@ -1,16 +1,43 @@ # Copyright (c) 2023-2025 Datalayer, Inc. # Distributed under the terms of the Modified BSD License. +import inspect from dotenv import load_dotenv from datalayer_core import DatalayerClient +from datalayer_core.sdk.decorators import datalayer # Using .env file with DATALAYER_RUN_URL and DATALAYER_TOKEN defined load_dotenv() -client = DatalayerClient() -print(client.list_runtimes()) + +# @datalayer +# @datalayer() +# @datalayer(runtime_name="example-runtime") +@datalayer(snapshot_name="snapshot-iris-model") +# @datalayer(runtime_name="example-runtime", output="result") +# @datalayer(runtime_name="example-runtime", inputs=["a", "b", "c"]) +def sum(x: float, y: float, z: int = 1) -> float: + return x + y + + +print([sum(1, 4.5, z=2)]) + +# sig = inspect.signature(example) +# print("\nParameters:") +# for name, param in sig.parameters.items(): +# print(f" Name: {name}") +# print(f" Kind: {param.kind}") +# print(f" Default Value: {param.default}") +# print(f" Annotation: {param.annotation}") +# print("---") + +# print(client.list_runtimes()) # with client.create_runtime() as runtime: +# runtime.execute('x = 1') +# runtime.execute('y = 4.5') +# runtime.execute('def example(x: float, y: float) -> float:\n return x + y\n') +# runtime.execute('print(example(x, y))') # response = runtime.execute("import os;print(len(os.environ['MY_SECRET']))") # print(response.stdout) # response = runtime.execute(