Release 0.9.9 #893

Merged · 4 commits · Apr 8, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.8"
+version = "0.9.9"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
```
81 changes: 73 additions & 8 deletions slay/definitions.py

```diff
@@ -3,6 +3,7 @@
 import inspect
 import logging
 import os
+import traceback
 from types import GenericAlias
 from typing import Any, ClassVar, Generic, Mapping, Optional, Type, TypeVar
@@ -11,7 +12,8 @@
 
 UserConfigT = TypeVar("UserConfigT", bound=Optional[pydantic.BaseModel])
 
-BASETEN_API_SECRET_NAME = "baseten_api_key"
+BASETEN_API_SECRET_NAME = "baseten_workflow_api_key"
+SECRET_DUMMY = "***"
 TRUSS_CONFIG_SLAY_KEY = "slay_metadata"
 
 ENDPOINT_METHOD_NAME = "run"  # Referring to processor method name exposed as endpoint.
@@ -199,7 +201,8 @@ def __init__(self) -> None:
         self._spec = AssetSpec()
 
     def add_secret(self, key: str) -> "Assets":
-        self._spec.secrets[key] = "***"  # Actual value is provided in deployment.
+        # Actual value is provided in deployment.
+        self._spec.secrets[key] = SECRET_DUMMY
         return self
 
     def cached(self, value: list[Any]) -> "Assets":
@@ -257,16 +260,22 @@ def get_service_descriptor(self, stub_cls_name: str) -> ServiceDescriptor:
         return self.stub_cls_to_service[stub_cls_name]
 
     def get_baseten_api_key(self) -> str:
-        if not self.secrets:
+        if self.secrets is None:
             raise UsageError(f"Secrets not set in `{self.__class__.__name__}` object.")
+        error_msg = (
+            "For using workflows, it is required to set up an API key with name "
+            f"`{BASETEN_API_SECRET_NAME}` on baseten to allow workflow processors to "
+            "call other processors. For local execution, secrets can be provided "
+            "to `run_local`."
+        )
         if BASETEN_API_SECRET_NAME not in self.secrets:
-            raise MissingDependencyError(
-                "For using workflows, it is required to setup a an API key with name "
-                f"`{BASETEN_API_SECRET_NAME}` on baseten to allow workflow processor to "
-                "call other processors."
-            )
+            raise MissingDependencyError(error_msg)
 
         api_key = self.secrets[BASETEN_API_SECRET_NAME]
+        if api_key == SECRET_DUMMY:
+            raise MissingDependencyError(
+                f"{error_msg}. Retrieved dummy value of `{api_key}`."
+            )
         return api_key
@@ -357,4 +366,60 @@ class DeploymentOptionsBaseten(DeploymentOptions):
 
 
 class DeploymentOptionsLocalDocker(DeploymentOptions):
+    # Local docker-to-docker requests don't need auth, but we need to set a
+    # value different from `SECRET_DUMMY` to not trigger the check that the secret
+    # is unset. Additionally, if local docker containers make calls to models deployed
+    # on baseten, a real API key must be provided (i.e. the default must be overridden).
+    baseten_workflow_api_key: str = "docker_dummy_key"
+
+
+class StackFrame(pydantic.BaseModel):
+    filename: str
+    lineno: Optional[int]
+    name: str
+    line: Optional[str]
+
+    @classmethod
+    def from_frame_summary(cls, frame: traceback.FrameSummary):
+        return cls(
+            filename=frame.filename,
+            lineno=frame.lineno,
+            name=frame.name,
+            line=frame.line,
+        )
+
+    def to_frame_summary(self) -> traceback.FrameSummary:
+        return traceback.FrameSummary(
+            filename=self.filename, lineno=self.lineno, name=self.name, line=self.line
+        )
+
+
+class RemoteErrorDetail(pydantic.BaseModel):
+    remote_name: str
+    exception_class_name: str
+    exception_module_name: Optional[str]
+    exception_message: str
+    user_stack_trace: list[StackFrame]
+
+    def to_stack_summary(self) -> traceback.StackSummary:
+        return traceback.StackSummary.from_list(
+            frame.to_frame_summary() for frame in self.user_stack_trace
+        )
+
+    def format(self) -> str:
+        stack = "".join(traceback.format_list(self.to_stack_summary()))
+        exc_info = (
+            f"\n(Exception class defined in `{self.exception_module_name}`.)"
+            if self.exception_module_name
+            else ""
+        )
+        error = (
+            f"{RemoteErrorDetail.__name__} in `{self.remote_name}`\n"
+            f"Traceback (most recent call last):\n"
+            f"{stack}{self.exception_class_name}: {self.exception_message}{exc_info}"
+        )
+        return error
+
+
+class GenericRemoteException(Exception):
+    ...
```
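For a sense of what the new error plumbing produces, here is a minimal sketch; the processor name, frame contents, and exception values are invented for illustration, and only the `StackFrame`/`RemoteErrorDetail` API comes from the diff above:

```python
from slay import definitions

# A hand-built frame, standing in for one captured remotely via
# `StackFrame.from_frame_summary(...)`.
frame = definitions.StackFrame(
    filename="workflow.py", lineno=42, name="run", line="return helper(x)"
)
detail = definitions.RemoteErrorDetail(
    remote_name="TextToNum",  # hypothetical processor name
    exception_class_name="ValueError",
    exception_module_name="builtins",
    exception_message="invalid literal",
    user_stack_trace=[frame],
)
# Re-renders the remote stack in familiar traceback style, roughly:
#   RemoteErrorDetail in `TextToNum`
#   Traceback (most recent call last):
#     File "workflow.py", line 42, in run
#       return helper(x)
#   ValueError: invalid literal
#   (Exception class defined in `builtins`.)
print(detail.format())
```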
16 changes: 10 additions & 6 deletions slay/framework.py

```diff
@@ -415,11 +415,11 @@ def ensure_args_are_injected(cls, original_init: Callable, kwargs) -> None:
 
 
 def _create_local_context(
-    processor_cls: Type[definitions.ABCProcessor],
+    processor_cls: Type[definitions.ABCProcessor], secrets: Mapping[str, str]
 ) -> definitions.Context:
     if hasattr(processor_cls, "default_config"):
         defaults = processor_cls.default_config
-        return definitions.Context(user_config=defaults.user_config)
+        return definitions.Context(user_config=defaults.user_config, secrets=secrets)
     return definitions.Context()
@@ -428,6 +428,7 @@ def _create_modified_init_for_local(
     cls_to_instance: MutableMapping[
         Type[definitions.ABCProcessor], definitions.ABCProcessor
     ],
+    secrets: Mapping[str, str],
 ):
     """Replaces the default argument values with local processor instantiations.
@@ -440,7 +441,7 @@ def init_for_local(self: definitions.ABCProcessor, **kwargs) -> None:
         logging.debug(f"Patched `__init__` of `{processor_descriptor.cls_name}`.")
         kwargs_mod = dict(kwargs)
         if definitions.CONTEXT_ARG_NAME not in kwargs_mod:
-            context = _create_local_context(processor_descriptor.processor_cls)
+            context = _create_local_context(processor_descriptor.processor_cls, secrets)
             kwargs_mod[definitions.CONTEXT_ARG_NAME] = context
         else:
             logging.debug(
@@ -475,7 +476,7 @@ def init_for_local(self: definitions.ABCProcessor, **kwargs) -> None:
 
 
 @contextlib.contextmanager
-def run_local() -> Any:
+def run_local(secrets: Optional[Mapping[str, str]] = None) -> Any:
     """Context to run processors with dependency injection from local instances."""
     type_to_instance: MutableMapping[
         Type[definitions.ABCProcessor], definitions.ABCProcessor
@@ -487,7 +488,7 @@ def run_local() -> Any:
             processor_descriptor.processor_cls
         ] = processor_descriptor.processor_cls.__init__
         init_for_local = _create_modified_init_for_local(
-            processor_descriptor, type_to_instance
+            processor_descriptor, type_to_instance, secrets or {}
        )
        processor_descriptor.processor_cls.__init__ = init_for_local  # type: ignore[method-assign]
        processor_descriptor.processor_cls._init_is_patched = True
@@ -518,7 +519,7 @@ def _create_remote_service(
    code_gen.generate_processor_source(
        pathlib.Path(processor_filepath), processor_descriptor
    )
-    # Only add needed stub URLs.
+    # Filter only needed services.
    stub_cls_to_service = {
        stub_cls.__name__: stub_cls_to_service[stub_cls.__name__]
        for stub_cls in processor_descriptor.dependencies.values()
@@ -545,6 +546,9 @@ def _create_remote_service(
    elif isinstance(options, definitions.DeploymentOptionsLocalDocker):
        port = utils.get_free_port()
        tr = truss_handle.TrussHandle(truss_dir)
+        tr.add_secret(
+            definitions.BASETEN_API_SECRET_NAME, options.baseten_workflow_api_key
+        )
        _ = tr.docker_run(
            local_port=port, detach=True, wait_for_server_ready=True, network="host"
        )
```
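As a usage sketch of the new secret plumbing for local docker deployments: the entrypoint class and key value below are placeholders, and only `DeploymentOptionsLocalDocker` and `deploy_remotely` (see `slay/public_api.py` next) appear in this PR:

```python
from slay import definitions, public_api

options = definitions.DeploymentOptionsLocalDocker(
    # The default "docker_dummy_key" satisfies the unset-secret check for pure
    # docker-to-docker calls; override it only when the local containers must
    # call real models deployed on baseten.
    baseten_workflow_api_key="<real-api-key>",
)
# `MyEntrypoint` is a stand-in for a user-defined entrypoint processor.
service = public_api.deploy_remotely(MyEntrypoint, options)
```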
7 changes: 3 additions & 4 deletions slay/public_api.py

```diff
@@ -3,8 +3,7 @@
 * Shim to call already hosted baseten model.
 * Helper to create a `Processor` from a truss dir.
 """
-
-from typing import Any, ContextManager, Type, final
+from typing import Any, ContextManager, Mapping, Optional, Type, final
 
 from slay import definitions, framework, utils
@@ -62,6 +61,6 @@ def deploy_remotely(
     return framework.deploy_remotely(entrypoint, options)
 
 
-def run_local() -> ContextManager[None]:
+def run_local(secrets: Optional[Mapping[str, str]] = None) -> ContextManager[None]:
     """Context manager for using in-process instantiations of processor dependencies."""
-    return framework.run_local()
+    return framework.run_local(secrets)
```
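The extended `run_local` can then be used like this (a minimal sketch; `MyWorkflow` is a hypothetical processor whose endpoint method is named `run`):

```python
from slay import definitions, public_api

# The key name must match definitions.BASETEN_API_SECRET_NAME
# ("baseten_workflow_api_key"); in deployment it would be injected instead.
with public_api.run_local(
    secrets={definitions.BASETEN_API_SECRET_NAME: "<test-key>"}
):
    workflow = MyWorkflow()  # dependencies are instantiated in-process
    result = workflow.run("some input")
```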
24 changes: 8 additions & 16 deletions slay/stub.py

```diff
@@ -4,18 +4,9 @@
 from typing import Type, TypeVar, final
 
 import httpx
-from slay import definitions
+from slay import definitions, utils
 
-
-def _handle_response(response: httpx.Response):
-    # TODO: improve error handling, extract context from response and include in
-    # re-raised exception. Consider re-raising same exception or if not a use a
-    # generic "RPCError" exception class or similar.
-    if response.is_server_error:
-        raise ValueError(response)
-    if response.is_client_error:
-        raise ValueError(response)
-    return response.json()
+DEFAULT_TIMEOUT_SEC = 600
 
 
 class BasetenSession:
@@ -26,28 +17,29 @@ def __init__(
         self, service_descriptor: definitions.ServiceDescriptor, api_key: str
     ) -> None:
         logging.info(
-            f"Stub session for {service_descriptor.name} with predict URL `{service_descriptor.predict_url}`."
+            f"Stub session for {service_descriptor.name} with predict URL "
+            f"`{service_descriptor.predict_url}`."
         )
         self._auth_header = {"Authorization": f"Api-Key {api_key}"}
         self._service_descriptor = service_descriptor
 
     @functools.cached_property
     def _client_sync(self) -> httpx.Client:
-        return httpx.Client(headers=self._auth_header)
+        return httpx.Client(headers=self._auth_header, timeout=DEFAULT_TIMEOUT_SEC)
 
     @functools.cached_property
     def _client_async(self) -> httpx.AsyncClient:
-        return httpx.AsyncClient(headers=self._auth_header)
+        return httpx.AsyncClient(headers=self._auth_header, timeout=DEFAULT_TIMEOUT_SEC)
 
     def predict_sync(self, json_payload):
-        return _handle_response(
+        return utils.handle_response(
             self._client_sync.post(
                 self._service_descriptor.predict_url, json=json_payload
             )
         )
 
     async def predict_async(self, json_payload):
-        return _handle_response(
+        return utils.handle_response(
             await self._client_async.post(
                 self._service_descriptor.predict_url, json=json_payload
             )
```
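`utils.handle_response` itself is not part of this diff. Presumably it subsumes the removed `_handle_response` and surfaces the `RemoteErrorDetail` payload from `definitions.py`; the following is a speculative sketch, not the actual implementation:

```python
import httpx

from slay import definitions


def handle_response(response: httpx.Response):
    # Speculative: mirrors the removed `_handle_response`, plus unpacking of
    # structured remote error details if the server attached them.
    if response.is_server_error or response.is_client_error:
        try:
            detail = definitions.RemoteErrorDetail(**response.json()["error"])
        except (KeyError, TypeError, ValueError):
            # No structured details; fall back to the raw response body.
            raise definitions.GenericRemoteException(response.text) from None
        raise definitions.GenericRemoteException(detail.format()) from None
    return response.json()
```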
15 changes: 8 additions & 7 deletions slay/truss_adapter/code_gen.py

```diff
@@ -43,7 +43,7 @@ def generate_truss_model(
         pathlib.Path(model_skeleton.__file__).read_text()
     )
 
-    imports: list[libcst.SimpleStatementLine] = [
+    imports: list[Any] = [
         node
         for node in skeleton_tree.body
         if isinstance(node, libcst.SimpleStatementLine)
@@ -52,6 +52,8 @@
             for stmt in node.body
         )
     ]
+    imports.append(libcst.parse_statement("import logging"))
+    imports.append(libcst.parse_statement("from slay import utils"))
 
     class_definition: libcst.ClassDef = utils.expect_one(
         node
@@ -67,7 +69,6 @@ def load(self) -> None:
     self._processor = {processor_descriptor.cls_name}(context=self._context)
 """
     )
-    imports.append(libcst.parse_statement("import logging"))  # type: ignore[arg-type]
 
     endpoint_descriptor = processor_descriptor.endpoint
     def_str = "async def" if endpoint_descriptor.is_async else "def"
@@ -96,14 +97,14 @@ def load(self) -> None:
     predict_def = libcst.parse_statement(
         f"""
 {def_str} predict(self, payload):
-    result = {maybe_await}self._processor.{endpoint_descriptor.name}({obj_arg_parts})
-    return {result}
+    with utils.exception_to_http_error(
+        include_stack=True, processor_name="{processor_descriptor.cls_name}"):
+        result = {maybe_await}self._processor.{endpoint_descriptor.name}({obj_arg_parts})
+        return {result}
 
 """
     )
-    new_body: list[libcst.BaseStatement] = list(  # type: ignore[assignment,misc]
-        class_definition.body.body
-    ) + [
+    new_body: list[Any] = list(class_definition.body.body) + [
         load_def,
         predict_def,
     ]
```
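Substituting into the template, the generated `predict` for a hypothetical async processor `TextToNum` would look roughly like this; the exact payload/result marshalling (`{obj_arg_parts}`, `{result}`) comes from parts of the generator not shown in this diff:

```python
# Hypothetical generator output for `TextToNum` with an async `run` endpoint.
async def predict(self, payload):
    with utils.exception_to_http_error(
        include_stack=True, processor_name="TextToNum"):
        result = await self._processor.run(data=payload["data"])
        return result
```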
2 changes: 1 addition & 1 deletion slay/truss_adapter/deploy.py

```diff
@@ -75,7 +75,7 @@ def _make_truss_config(
     assets = slay_config.get_asset_spec()
     config.secrets = assets.secrets
     if definitions.BASETEN_API_SECRET_NAME not in config.secrets:
-        config.secrets[definitions.BASETEN_API_SECRET_NAME] = "***"
+        config.secrets[definitions.BASETEN_API_SECRET_NAME] = definitions.SECRET_DUMMY
     else:
         logging.info(
             f"Workflows automatically add {definitions.BASETEN_API_SECRET_NAME} "
```