Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions datalayer_core/sdk/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def create_runtime(
if name is None:
name = f"runtime-{environment}-{uuid.uuid4()}"

# print(f"Runtime {name}")

if snapshot_name is not None:
snapshots = self.list_snapshots()
for snapshot in snapshots:
Expand Down Expand Up @@ -738,7 +740,9 @@ def _start(self) -> None:

if self._kernel_client is None:
self._runtime = self._create_runtime(self._environment_name)
runtime: dict[str, str] = self._runtime.get("runtime") # type: ignore
# print(self._runtime)
runtime: dict[str, str] = self._runtime["runtime"] # type: ignore
# print("runtime", runtime)
self._ingress = runtime["ingress"]
self._kernel_token = runtime["token"]
self._pod_name = runtime["pod_name"]
Expand Down Expand Up @@ -820,7 +824,10 @@ def set_variables(self, variables: dict[str, Any]) -> Response:
return Response([])

def execute_file(
self, path: Union[str, Path], variables: Optional[dict[str, Any]] = None
self,
path: Union[str, Path],
variables: Optional[dict[str, Any]] = None,
output: Optional[str] = None,
) -> Response:
"""
Execute a Python file in the runtime.
Expand All @@ -831,6 +838,8 @@ def execute_file(
Path to the Python file to execute.
variables: Optional[dict[str, Any]]
Optional variables to set before executing the code.
output: Optional[str]
Optional output variable to return as result.

Returns
-------
Expand All @@ -841,18 +850,26 @@ def execute_file(
if variables:
self.set_variables(variables)

for _id, cell in _get_cells(fname):
if self._kernel_client:
if self._kernel_client:
outputs = []
for _id, cell in _get_cells(fname):
reply = self._kernel_client.execute_interactive(
cell,
silent=False,
)
return Response(reply.get("outputs", []))
outputs.append(reply.get("outputs", []))
if output is not None:
return self.get_variable(output)

return Response(outputs)
return Response([])

def execute_code(
self, code: str, variables: Optional[dict[str, Any]] = None
) -> Response:
self,
code: str,
variables: Optional[dict[str, Any]] = None,
output: Optional[str] = None,
) -> Union[Response, Any]:
"""
Execute code in the runtime.

Expand All @@ -862,6 +879,8 @@ def execute_code(
The Python code to execute.
variables: Optional[dict[str, Any]]
Optional variables to set before executing the code.
output: Optional[str]
Optional output variable to return as result.

Returns
-------
Expand All @@ -874,6 +893,8 @@ def execute_code(
self.set_variables(variables)
reply = self._kernel_client.execute(code)
result = reply.get("outputs", {})
if output is not None:
return self.get_variable(output)
else:
raise RuntimeError(
"Kernel client is not started. Call `start()` first."
Expand All @@ -884,8 +905,11 @@ def execute_code(
return Response([])

def execute(
self, code_or_path: Union[str, Path], variables: Optional[dict[str, Any]] = None
) -> Response:
self,
code_or_path: Union[str, Path],
variables: Optional[dict[str, Any]] = None,
output: Optional[str] = None,
) -> Union[Response, Any]:
"""
Execute code in the runtime.

Expand All @@ -895,10 +919,13 @@ def execute(
The Python code or path to the file to execute.
variables: Optional[dict[str, Any]]
Optional variables to set before executing the code.
output: Optional[str]
Optional output variable to return as result.

Returns
-------
dict: The result of the code execution.
dict:
The result of the code execution.


{
Expand All @@ -916,9 +943,13 @@ def execute(
}
"""
if self._check_file(code_or_path):
return self.execute_file(str(code_or_path), variables)
return self.execute_file(
str(code_or_path), variables=variables, output=output
)
else:
return self.execute_code(str(code_or_path), variables)
return self.execute_code(
str(code_or_path), variables=variables, output=output
)

def terminate(self) -> bool:
"""Terminate the Runtime."""
Expand Down
134 changes: 134 additions & 0 deletions datalayer_core/sdk/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) 2023-2025 Datalayer, Inc.
# Distributed under the terms of the Modified BSD License.

import functools
import inspect
from typing import Any, Callable, Optional, Union

from datalayer_core.sdk.datalayer import DatalayerClient

# TODO:
# - inputs are different from args and kwargs (rename)
# - inputs cannot be kewyword args of the function
# - incorrect number of args


def datalayer(
runtime_name: Union[Callable[..., Any], Optional[str]] = None,
inputs: Optional[list[str]] = None,
output: Optional[str] = None,
snapshot_name: Optional[str] = None,
) -> Any:
"""
Decorator to execute a function in a Datalayer runtime.

Parameters
----------
runtime_name : str, optional
The name of the runtime to use. If not provided, a default runtime will be used.
inputs : list[str], optional
A list of input variable names for the function.
output : str, optional
The name of the output variable for the function
snapshot_name : str, optional
The name of the runtime snapshot to use

Returns
-------
Callable[..., Any]
A decorator that wraps the function to be executed in a Datalayer runtime.

Examples
--------

>>> from datalayer_core.sdk.decorators import datalayer
>>> @datalayer
... def example(x: float, y: float) -> float:
... return x + y

>>> from datalayer_core.sdk.decorators import datalayer
>>> @datalayer(runtime_name="example-runtime", inputs=["x", "y"], output="z")
... def example(x: float, y: float) -> float:
... return x + y
"""
variables = {}
inputs_decorated = inputs
output_decorated = output
snapshot_name_decorated = snapshot_name

if callable(runtime_name):
runtime_name_decorated = None
else:
runtime_name_decorated = runtime_name

def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
if output_decorated is None:
output = f"DATALAYER_RUNTIME_OUTPUT_{func.__name__}".upper()

sig = inspect.signature(func)
if inputs_decorated is None:
inputs = []
for name, _param in sig.parameters.items():
inputs.append(name)
variables[name] = (
_param.default
if _param.default is not inspect.Parameter.empty
else None
)
else:
if len(sig.parameters) != len(inputs_decorated):
raise ValueError(
f"Function {func.__name__} has {len(sig.parameters)} parameters, "
f"but {len(inputs_decorated)} inputs were provided."
)

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
sig = inspect.signature(func)
mapping = {}
for idx, (name, _param) in enumerate(sig.parameters.items()):
mapping[name] = (inputs_decorated or inputs)[idx]

for kwarg, kwarg_value in kwargs.items():
variables[mapping[kwarg]] = kwarg_value

for idx, (arg_value) in enumerate(args):
kwarg = (inputs_decorated or inputs)[idx]
variables[kwarg] = arg_value

function_call = (
f"{output_decorated or output} = {func.__name__}("
+ ", ".join(inputs_decorated or inputs)
+ ")"
)

start = 0
func_source_lines = inspect.getsource(func).split("\n")
for start, line in enumerate(func_source_lines):
if line.startswith("def "):
break
function_source = "\n".join(func_source_lines[start:])

# print("inputs", inputs_decorated or inputs)
# print("variables", variables)
# print([function_source])
# print([function_call])

client = DatalayerClient()
with client.create_runtime(
name=runtime_name_decorated, snapshot_name=snapshot_name_decorated
) as runtime:
runtime.execute(function_source)
return runtime.execute(
function_call,
variables=variables,
output=output_decorated or output,
)

return wrapper

# print(f"Using runtime: {runtime_name}, inputs: {inputs}, output: {output}")
if callable(runtime_name):
return decorator(runtime_name)
else:
return decorator
42 changes: 42 additions & 0 deletions datalayer_core/tests/test_sdk_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2023-2025 Datalayer, Inc.
# Distributed under the terms of the Modified BSD License.

import os
import time

import pytest
from dotenv import load_dotenv

from datalayer_core.sdk.decorators import datalayer

load_dotenv()


DATALAYER_TEST_TOKEN = os.environ.get("DATALAYER_TEST_TOKEN")


def sum_test(x: float, y: float, z: float = 1) -> float:
return x + y + z


@pytest.mark.parametrize(
"args,expected_output,decorator",
[
([1, 4.5, 2], 7.5, datalayer),
([1, 4.5, 2], 7.5, datalayer(runtime_name="runtime-test")),
([1, 4.5, 2], 7.5, datalayer(output="result")),
([1, 4.5, 2], 7.5, datalayer(inputs=["a", "b", "c"])),
],
)
@pytest.mark.skipif(
not bool(DATALAYER_TEST_TOKEN),
reason="DATALAYER_TEST_TOKEN is not set, skipping secret tests.",
)
def test_decorator(args, expected_output, decorator): # type: ignore
"""
Test the Datalayer decorator.
"""
time.sleep(10)
func = decorator(sum_test)
assert func(*args) == expected_output
time.sleep(10)
31 changes: 29 additions & 2 deletions examples/sdk.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
# Copyright (c) 2023-2025 Datalayer, Inc.
# Distributed under the terms of the Modified BSD License.
import inspect

from dotenv import load_dotenv

from datalayer_core import DatalayerClient
from datalayer_core.sdk.decorators import datalayer

# Using .env file with DATALAYER_RUN_URL and DATALAYER_TOKEN defined
load_dotenv()

client = DatalayerClient()
print(client.list_runtimes())

# @datalayer
# @datalayer()
# @datalayer(runtime_name="example-runtime")
@datalayer(snapshot_name="snapshot-iris-model")
# @datalayer(runtime_name="example-runtime", output="result")
# @datalayer(runtime_name="example-runtime", inputs=["a", "b", "c"])
def sum(x: float, y: float, z: int = 1) -> float:
return x + y


print([sum(1, 4.5, z=2)])

# sig = inspect.signature(example)
# print("\nParameters:")
# for name, param in sig.parameters.items():
# print(f" Name: {name}")
# print(f" Kind: {param.kind}")
# print(f" Default Value: {param.default}")
# print(f" Annotation: {param.annotation}")
# print("---")

# print(client.list_runtimes())
# with client.create_runtime() as runtime:
# runtime.execute('x = 1')
# runtime.execute('y = 4.5')
# runtime.execute('def example(x: float, y: float) -> float:\n return x + y\n')
# runtime.execute('print(example(x, y))')
# response = runtime.execute("import os;print(len(os.environ['MY_SECRET']))")
# print(response.stdout)
# response = runtime.execute(
Expand Down
Loading