From f0f6fce73243fd26e34663951fb44362e823d532 Mon Sep 17 00:00:00 2001 From: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> Date: Fri, 22 Mar 2024 16:11:34 -0700 Subject: [PATCH 1/6] Relax version constraints. (#875) --- poetry.lock | 78 +++++++++++++++++-------- pyproject.toml | 24 ++++---- truss/templates/server/requirements.txt | 12 ++-- 3 files changed, 73 insertions(+), 41 deletions(-) diff --git a/poetry.lock b/poetry.lock index 41ec27280..5a30dd11e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -203,17 +203,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.67" +version = "1.34.69" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.67-py3-none-any.whl", hash = "sha256:473febdf2606cf36f14c470dc3ff1b986efac15f69e37eb0fd728d42749065dd"}, - {file = "boto3-1.34.67.tar.gz", hash = "sha256:950161d438ae1bf31374f04175e5f2624a5de8109674ff80f4de5d962313072a"}, + {file = "boto3-1.34.69-py3-none-any.whl", hash = "sha256:2e25ef6bd325217c2da329829478be063155897d8d3b29f31f7f23ab548519b1"}, + {file = "boto3-1.34.69.tar.gz", hash = "sha256:898a5fed26b1351352703421d1a8b886ef2a74be6c97d5ecc92432ae01fda203"}, ] [package.dependencies] -botocore = ">=1.34.67,<1.35.0" +botocore = ">=1.34.69,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -222,13 +222,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.67" +version = "1.34.69" description = "Low-level, data-driven core of boto 3." 
optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.67-py3-none-any.whl", hash = "sha256:56002d7c046ec134811dd079469692ab82919c9840ea4e1c1c373a1d228e37ae"}, - {file = "botocore-1.34.67.tar.gz", hash = "sha256:fc094c055a6ac151820a4d8e28f7b30d03e02695ce180527520a5e219b14e8a1"}, + {file = "botocore-1.34.69-py3-none-any.whl", hash = "sha256:d3802d076d4d507bf506f9845a6970ce43adc3d819dd57c2791f5c19ed6e5950"}, + {file = "botocore-1.34.69.tar.gz", hash = "sha256:d1ab2bff3c2fd51719c2021d9fa2f30fbb9ed0a308f69e9a774ac92c8091380a"}, ] [package.dependencies] @@ -639,13 +639,13 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.109.2" +version = "0.110.0" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"}, - {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"}, + {file = "fastapi-0.110.0-py3-none-any.whl", hash = "sha256:87a1f6fb632a218222c5984be540055346a8f5d8a68e8f6fb647b1dc9934de4b"}, + {file = "fastapi-0.110.0.tar.gz", hash = "sha256:266775f0dcc95af9d3ef39bad55cff525329a931d5fd51930aadd4f428bf7ff3"}, ] [package.dependencies] @@ -762,18 +762,19 @@ tqdm = ["tqdm"] [[package]] name = "google-api-core" -version = "2.17.1" +version = "2.18.0" description = "Google API client core library" optional = false python-versions = ">=3.7" files = [ - {file = "google-api-core-2.17.1.tar.gz", hash = "sha256:9df18a1f87ee0df0bc4eea2770ebc4228392d8cc4066655b320e2cfccb15db95"}, - {file = "google_api_core-2.17.1-py3-none-any.whl", hash = "sha256:610c5b90092c360736baccf17bd3efbcb30dd380e7a6dc28a71059edb8bd0d8e"}, + {file = "google-api-core-2.18.0.tar.gz", hash = "sha256:62d97417bfc674d6cef251e5c4d639a9655e00c45528c4364fbfebb478ce72a9"}, + {file = "google_api_core-2.18.0-py3-none-any.whl", hash = "sha256:5a63aa102e0049abe85b5b88cb9409234c1f70afcda21ce1e40b285b9629c1d6"}, ] [package.dependencies] google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" +proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -1711,13 +1712,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.16.2" +version = "7.16.3" description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)." 
optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.16.2-py3-none-any.whl", hash = "sha256:0c01c23981a8de0220255706822c40b751438e32467d6a686e26be08ba784382"}, - {file = "nbconvert-7.16.2.tar.gz", hash = "sha256:8310edd41e1c43947e4ecf16614c61469ebc024898eb808cce0999860fc9fb16"}, + {file = "nbconvert-7.16.3-py3-none-any.whl", hash = "sha256:ddeff14beeeedf3dd0bc506623e41e4507e551736de59df69a91f86700292b3b"}, + {file = "nbconvert-7.16.3.tar.gz", hash = "sha256:a6733b78ce3d47c3f85e504998495b07e6ea9cf9bf6ec1c98dda63ec6ad19142"}, ] [package.dependencies] @@ -1744,7 +1745,7 @@ docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sp qtpdf = ["nbconvert[qtpng]"] qtpng = ["pyqtwebengine (>=5.15)"] serve = ["tornado (>=6.1)"] -test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest"] +test = ["flaky", "ipykernel", "ipywidgets (>=7.5)", "pytest (>=7)"] webpdf = ["playwright"] [[package]] @@ -1990,6 +1991,23 @@ files = [ [package.dependencies] wcwidth = "*" +[[package]] +name = "proto-plus" +version = "1.23.0" +description = "Beautiful, Pythonic protocol buffers." +optional = false +python-versions = ">=3.6" +files = [ + {file = "proto-plus-1.23.0.tar.gz", hash = "sha256:89075171ef11988b3fa157f5dbd8b9cf09d65fffee97e29ce403cd8defba19d2"}, + {file = "proto_plus-1.23.0-py3-none-any.whl", hash = "sha256:a829c79e619e1cf632de091013a4173deed13a55f326ef84f05af6f50ff4c82c"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<5.0.0dev" + +[package.extras] +testing = ["google-api-core[grpc] (>=1.31.5)"] + [[package]] name = "protobuf" version = "4.25.3" @@ -2963,18 +2981,32 @@ files = [ [[package]] name = "types-requests" -version = "2.31.0.2" +version = "2.31.0.6" description = "Typing stubs for requests" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "types-requests-2.31.0.2.tar.gz", hash = "sha256:6aa3f7faf0ea52d728bb18c0a0d1522d9bfd8c72d26ff6f61bfc3d06a411cf40"}, - {file = "types_requests-2.31.0.2-py3-none-any.whl", hash = "sha256:56d181c85b5925cbc59f4489a57e72a8b2166f18273fd8ba7b6fe0c0b986f12a"}, + {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, + {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, ] [package.dependencies] types-urllib3 = "*" +[[package]] +name = "types-requests" +version = "2.31.0.20240311" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, + {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = "sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = "types-setuptools" version = "69.2.0.20240317" @@ -3229,4 +3261,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "2340fedcace299c5d71ed0f87c3a334907b3df1268e214da1bdf8d82933394ec" +content-hash = "f247b02f095febf25b29d5b61dded3ce0f195ba51b734a5e771394f75db1ce25" diff --git a/pyproject.toml b/pyproject.toml index 0e5a53f0b..4485f0ad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ keywords = [ [tool.poetry.dependencies] blake3 = "^0.3.3" boto3 = "^1.26.157" -fastapi = "^0.109.1" +fastapi 
= ">=0.109.1" google-cloud-storage = "2.10.0" httpx = "^0.24.1" huggingface_hub = ">=0.19.4" @@ -32,29 +32,27 @@ inquirerpy = "^0.3.4" Jinja2 = "^3.1.2" loguru = ">=0.7.2" msgpack = ">=1.0.2" -msgpack-numpy = ">=0.4.7.1" +msgpack-numpy = ">=0.4.8" numpy = ">=1.23.5" packaging = ">=20.9" pathspec = ">=0.9.0" -psutil = "^5.9.4" +psutil = ">=5.9.4" pydantic = ">=1.10.0" python = ">=3.8,<3.12" python-json-logger = ">=2.0.2" python-on-whales = "^0.68.0" -PyYAML = "^6.0" +PyYAML = ">=6.0" rich = "^13.4.2" rich-click = "^1.6.1" single-source = "^0.3.0" tenacity = "^8.0.1" -uvicorn = "^0.24.0" -uvloop = "^0.19.0" watchfiles = "^0.19.0" [tool.poetry.group.builder.dependencies] blake3 = "^0.3.3" boto3 = "^1.26.157" click = "^8.0.3" -fastapi = "^0.109.1" +fastapi = ">=0.109.1" google-cloud-storage = "2.10.0" httpx = "^0.24.1" huggingface_hub = ">=0.19.4" @@ -62,11 +60,11 @@ Jinja2 = "^3.1.2" loguru = ">=0.7.2" packaging = ">=20.9" pathspec = ">=0.9.0" -psutil = "^5.9.4" +psutil = ">=5.9.4" python = ">=3.8,<3.12" python-json-logger = ">=2.0.2" -PyYAML = "^6.0" -requests = "^2.28.1" +PyYAML = ">=6.0" +requests = ">=2.31" single-source = "^0.3.0" tenacity = "^8.0.1" uvicorn = "^0.24.0" @@ -95,8 +93,10 @@ flask = "^2.3.3" httpx = { extras = ["cli"], version = "^0.24.1" } mypy = "^1.0.0" pytest-split = "^0.8.1" -requests-mock = "^1.11.0" -types-requests = "2.31.0.2" +requests-mock = ">=1.11.0" +types-requests = ">=2.31.0.2" +uvicorn = ">=0.24.0" +uvloop = ">=0.17.0" [build-system] build-backend = "poetry.core.masonry.api" diff --git a/truss/templates/server/requirements.txt b/truss/templates/server/requirements.txt index f166536e6..04d16ed80 100644 --- a/truss/templates/server/requirements.txt +++ b/truss/templates/server/requirements.txt @@ -1,16 +1,16 @@ -i https://pypi.org/simple -argparse==1.4.0 aiocontextvars==0.2.2 +argparse==1.4.0 cython==3.0.5 +fastapi==0.109.1 +joblib==1.2.0 +loguru==0.7.2 msgpack-numpy==0.4.8 msgpack==1.0.2 +psutil==5.9.4 python-json-logger==2.0.2 pyyaml==6.0.0 -fastapi==0.109.1 +requests==2.31.0 uvicorn==0.24.0 uvloop==0.17.0 -psutil==5.9.4 -joblib==1.2.0 -requests==2.31.0 -loguru==0.7.2 From f2c36dc0c356cc193ff11f701483862593bbcae9 Mon Sep 17 00:00:00 2001 From: Sidharth Shanker Date: Tue, 26 Mar 2024 11:41:19 -0400 Subject: [PATCH 2/6] Add new docker auth settings to truss config. (#876) * Add new docker auth settings to truss config. * Return string. * Update enum value. * Change auth types. * Responded to pr feedback, fix tests. * Update tests. 
--- pyproject.toml | 2 +- truss/tests/test_config.py | 37 ++++++++++++++++++--- truss/truss_config.py | 67 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4485f0ad8..6efe1e1ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.6" +version = "0.9.7rc6" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index f2eea2a76..106b9a957 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -10,6 +10,8 @@ Accelerator, AcceleratorSpec, BaseImage, + DockerAuthSettings, + DockerAuthType, ModelCache, ModelRepo, Resources, @@ -100,10 +102,7 @@ def test_acc_spec_from_str(input_str, expected_acc): ( {}, BaseImage(), - { - "image": "", - "python_executable_path": "", - }, + {"image": "", "python_executable_path": "", "docker_auth": None}, ), ( {"image": "custom_base_image", "python_executable_path": "/path/python"}, @@ -111,6 +110,36 @@ def test_acc_spec_from_str(input_str, expected_acc): { "image": "custom_base_image", "python_executable_path": "/path/python", + "docker_auth": None, + }, + ), + ( + { + "image": "custom_base_image", + "python_executable_path": "/path/python", + "docker_auth": { + "auth_method": "GCS_SERVICE_ACCOUNT_JSON", + "secret_name": "some-secret-name", + "registry": "some-docker-registry", + }, + }, + BaseImage( + image="custom_base_image", + python_executable_path="/path/python", + docker_auth=DockerAuthSettings( + auth_method=DockerAuthType.GCS_SERVICE_ACCOUNT_JSON, + secret_name="some-secret-name", + registry="some-docker-registry", + ), + ), + { + "image": "custom_base_image", + "python_executable_path": "/path/python", + "docker_auth": { + "auth_method": "GCS_SERVICE_ACCOUNT_JSON", + "secret_name": "some-secret-name", + "registry": "some-docker-registry", + }, }, ), ], diff --git a/truss/truss_config.py b/truss/truss_config.py index ed7cd86bb..9c2531a5b 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -297,25 +297,84 @@ def to_list(self) -> List[Dict[str, str]]: return [item.to_dict() for item in self.items] +class DockerAuthType(Enum): + """ + This enum will express all of the types of registry + authentication we support. + """ + + GCS_SERVICE_ACCOUNT_JSON = "GCS_SERVICE_ACCOUNT_JSON" + + +@dataclass +class DockerAuthSettings: + """ + Provides information about how to authenticate to the docker registry containing + the custom base image. + """ + + auth_method: DockerAuthType + secret_name: str + registry: Optional[str] = "" + + @staticmethod + def from_dict(d: Dict[str, str]): + auth_method = d.get("auth_method") + secret_name = d.get("secret_name") + + if auth_method: + # Capitalize the auth method so that we support this field passed + # as "gcs_service_account". 
+ auth_method = auth_method.upper() + + if ( + not secret_name + or not auth_method + or auth_method not in [auth_type.value for auth_type in DockerAuthType] + ): + raise ValueError("Please provide a `secret_name`, and valid `auth_method`") + + return DockerAuthSettings( + auth_method=DockerAuthType[auth_method], + secret_name=secret_name, + registry=d.get("registry"), + ) + + def to_dict(self): + return { + "auth_method": self.auth_method.value, + "secret_name": self.secret_name, + "registry": self.registry, + } + + @dataclass class BaseImage: image: str = "" python_executable_path: str = "" + docker_auth: Optional[DockerAuthSettings] = None @staticmethod def from_dict(d): image = d.get("image", "") python_executable_path = d.get("python_executable_path", "") + docker_auth = d.get("docker_auth") validate_python_executable_path(python_executable_path) return BaseImage( image=image, python_executable_path=python_executable_path, + docker_auth=DockerAuthSettings.from_dict(docker_auth) + if docker_auth + else None, ) def to_dict(self): return { "image": self.image, "python_executable_path": self.python_executable_path, + "docker_auth": transform_optional( + self.docker_auth, lambda docker_auth: docker_auth.to_dict() + ), } @@ -584,6 +643,14 @@ def obj_to_dict(obj, verbose: bool = False): d["trt_llm"] = transform_optional( field_curr_value, lambda data: data.dict() ) + elif isinstance(field_curr_value, BaseImage): + d["base_image"] = transform_optional( + field_curr_value, lambda data: data.to_dict() + ) + elif isinstance(field_curr_value, DockerAuthSettings): + d["docker_auth"] = transform_optional( + field_curr_value, lambda data: data.to_dict() + ) else: d[field_name] = field_curr_value From 9615729b1493783180868f2c412697e7bedbf120 Mon Sep 17 00:00:00 2001 From: Vlad Shulman Date: Tue, 26 Mar 2024 13:40:56 -0700 Subject: [PATCH 3/6] Revert "removing separate caching step (#872)" (#878) This reverts commit 41807058f7ec6862bae5add322a204a2adf1864e. 
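
For context, the restored `copy_cache_files.Dockerfile.jinja` template expands each cached file into a `COPY --from=cache_warmer` instruction. A minimal, illustrative rendering with jinja2 is sketched below; the `cached_files` entries are hypothetical, since truss derives the real ones from the `model_cache` config at build time.

```
import jinja2

# Same loop as the restored template; the cached_files entries are made up.
template = jinja2.Template(
    "{% for file in cached_files %}\n"
    "COPY --from=cache_warmer {{file.source}} {{file.dst}}\n"
    "{% endfor %}"
)


class CachedFile:
    def __init__(self, source: str, dst: str) -> None:
        self.source = source
        self.dst = dst


print(
    template.render(
        cached_files=[
            CachedFile(
                "/app/model_cache/repo/config.json",
                "/app/model_cache/repo/config.json",
            )
        ]
    )
)
```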
--- truss/templates/cache.Dockerfile.jinja | 2 ++ truss/templates/copy_cache_files.Dockerfile.jinja | 3 +++ truss/templates/server.Dockerfile.jinja | 11 ++++++++--- 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 truss/templates/copy_cache_files.Dockerfile.jinja diff --git a/truss/templates/cache.Dockerfile.jinja b/truss/templates/cache.Dockerfile.jinja index 439b6a17e..5ed3dd169 100644 --- a/truss/templates/cache.Dockerfile.jinja +++ b/truss/templates/cache.Dockerfile.jinja @@ -1,3 +1,5 @@ +FROM python:3.11-slim as cache_warmer + RUN mkdir -p /app/model_cache WORKDIR /app diff --git a/truss/templates/copy_cache_files.Dockerfile.jinja b/truss/templates/copy_cache_files.Dockerfile.jinja new file mode 100644 index 000000000..3772acf44 --- /dev/null +++ b/truss/templates/copy_cache_files.Dockerfile.jinja @@ -0,0 +1,3 @@ +{% for file in cached_files %} +COPY --from=cache_warmer {{file.source}} {{file.dst}} +{% endfor %} diff --git a/truss/templates/server.Dockerfile.jinja b/truss/templates/server.Dockerfile.jinja index 98956a544..690980bcd 100644 --- a/truss/templates/server.Dockerfile.jinja +++ b/truss/templates/server.Dockerfile.jinja @@ -1,3 +1,7 @@ +{%- if model_cache %} +{%- include "cache.Dockerfile.jinja" %} +{%- endif %} + {% extends "base.Dockerfile.jinja" %} {% block base_image_patch %} @@ -47,9 +51,10 @@ RUN pip install -r {{server_requirements_filename}} --no-cache-dir && rm -rf /ro {% block app_copy %} - {%- if model_cache %} - {%- include "cache.Dockerfile.jinja" %} - {%- endif %} +{%- if model_cache %} +# Copy data before code for better caching + {%- include "copy_cache_files.Dockerfile.jinja"%} +{%- endif %} {%- if external_data_files %} {% for url, dst in external_data_files %} From 9cd097aacf1eed346011089545e0e4124f76cd03 Mon Sep 17 00:00:00 2001 From: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> Date: Tue, 26 Mar 2024 16:03:40 -0700 Subject: [PATCH 4/6] Orchestration (#840) * Draft for workflow DX. * Draft for workflow DX. * Lalilu. * Implement local mode. * Started Code Gen. * Reorganized example. * Reorganized example. * Stub Gen almost there. * Rewrite Remote Processor. * Separate default config and context for initializing (remote) processors. * Rope move processor to new file * Almost able to generate and deploy truss. Dependencies are still a mess and don't work. Try to optimistically prune workflow file next. * Trusses without stubs/deps work. Stub files are still missing type defs and won't work. * WIP. * Version fixes. * Clean-ups. * Add MVP workflow without pydantic types. Add programmatic deployment * Deployed as dev models successfully call each other. * Cleanup * Fix formatting and remove main section * Adaptive truss config control. Docker build broken due to conflicting uvloop version requirements. * Use requirements file instead of list, this fixes version conflicts. * Fix templated user config type (parsing). * Fix mistral model. * Restructure, better deps resolutions, use constants. * Add docstrings, address revivew comments. * Address revivew comments. Add support for relative paths. 
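
For orientation, a condensed sketch of the workflow DX this change introduces, modeled on `slay-examples/text_to_num/workflow.py` below: processors are classes with a `run` endpoint, dependencies are injected with `slay.provide`, and `slay.deploy_remotely` generates and deploys a truss per processor. The `Shout` processor and the workflow name here are made up for brevity.

```
import slay

IMAGE = slay.Image().pip_requirements_file(slay.make_abs_path_here("requirements.txt"))


class Shout(slay.ProcessorBase):
    default_config = slay.Config(image=IMAGE)

    def run(self, text: str) -> str:
        return text.upper()


class Workflow(slay.ProcessorBase):
    default_config = slay.Config(image=IMAGE)

    def __init__(
        self,
        context: slay.Context = slay.provide_context(),
        shout: Shout = slay.provide(Shout),
    ) -> None:
        super().__init__(context)
        self._shout = shout

    def run(self, text: str) -> str:
        return self._shout.run(text)


if __name__ == "__main__":
    # Each processor becomes its own deployed model; generated stubs wire up the calls.
    slay.deploy_remotely(Workflow, workflow_name="ExampleWorkflow")
```

For local testing without deployment, the example also shows constructing processors inside `with slay.run_local(): ...` and passing fake dependencies directly.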
--- .gitignore | 3 + .pre-commit-config.yaml | 11 +- poetry.lock | 262 +++++++- pyproject.toml | 7 + slay-examples/text_to_num/requirements.txt | 4 + .../text_to_num/user_package/__init__.py | 0 .../user_package/shared_processor.py | 18 + slay-examples/text_to_num/workflow.py | 201 ++++++ slay/__init__.py | 17 + slay/code_gen.py | 359 ++++++++++ slay/definitions.py | 347 ++++++++++ slay/framework.py | 621 ++++++++++++++++++ slay/public_api.py | 62 ++ slay/stub.py | 62 ++ slay/truss_adapter/__init__.py | 0 slay/truss_adapter/code_gen.py | 121 ++++ slay/truss_adapter/deploy.py | 337 ++++++++++ slay/truss_adapter/model_skeleton.py | 36 + slay/utils.py | 55 ++ 19 files changed, 2497 insertions(+), 26 deletions(-) create mode 100644 slay-examples/text_to_num/requirements.txt create mode 100644 slay-examples/text_to_num/user_package/__init__.py create mode 100644 slay-examples/text_to_num/user_package/shared_processor.py create mode 100644 slay-examples/text_to_num/workflow.py create mode 100644 slay/__init__.py create mode 100644 slay/code_gen.py create mode 100644 slay/definitions.py create mode 100644 slay/framework.py create mode 100644 slay/public_api.py create mode 100644 slay/stub.py create mode 100644 slay/truss_adapter/__init__.py create mode 100644 slay/truss_adapter/code_gen.py create mode 100644 slay/truss_adapter/deploy.py create mode 100644 slay/truss_adapter/model_skeleton.py create mode 100644 slay/utils.py diff --git a/.gitignore b/.gitignore index 61159e0ab..5c238aa67 100644 --- a/.gitignore +++ b/.gitignore @@ -72,3 +72,6 @@ yarn.lock .eslintcache /.pytest_cache/ + +# Slay workflow generated code. +**/.slay_generated/** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f89f8823..4e905d524 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,5 +41,14 @@ repos: entry: poetry run mypy language: python types: [python] - exclude: ^examples/|^truss/test.+/|model.py$ + exclude: ^examples/|^truss/test.+/|model.py$|^slay.* + pass_filenames: true + - id: mypy + name: mypy-local (3.9) + entry: poetry run mypy + language: python + types: [python] + files: ^slay.* + args: + - --python-version=3.9 pass_filenames: true diff --git a/poetry.lock b/poetry.lock index 5a30dd11e..2610be4cd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -47,6 +47,34 @@ files = [ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] +[[package]] +name = "argcomplete" +version = "3.2.3" +description = "Bash tab completion for argparse" +optional = false +python-versions = ">=3.8" +files = [ + {file = "argcomplete-3.2.3-py3-none-any.whl", hash = "sha256:c12355e0494c76a2a7b73e3a59b09024ca0ba1e279fb9ed6c1b82d5b74b6a70c"}, + {file = "argcomplete-3.2.3.tar.gz", hash = "sha256:bf7900329262e481be5a15f56f19736b376df6f82ed27576fa893652c5de6c23"}, +] + +[package.extras] +test = ["coverage", "mypy", "pexpect", "ruff", "wheel"] + +[[package]] +name = "astroid" +version = "3.1.0" +description = "An abstract syntax tree for Python with inference support." 
+optional = false +python-versions = ">=3.8.0" +files = [ + {file = "astroid-3.1.0-py3-none-any.whl", hash = "sha256:951798f922990137ac090c53af473db7ab4e70c770e6d7fae0cec59f74411819"}, + {file = "astroid-3.1.0.tar.gz", hash = "sha256:ac248253bfa4bd924a0de213707e7ebeeb3138abeb48d798784ead1e56d419d4"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + [[package]] name = "attrs" version = "23.2.0" @@ -66,6 +94,21 @@ tests = ["attrs[tests-no-zope]", "zope-interface"] tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"] tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] +[[package]] +name = "autoflake" +version = "1.7.8" +description = "Removes unused imports and unused variables" +optional = false +python-versions = ">=3.7" +files = [ + {file = "autoflake-1.7.8-py3-none-any.whl", hash = "sha256:46373ef69b6714f5064c923bb28bd797c4f8a9497f557d87fc36665c6d956b39"}, + {file = "autoflake-1.7.8.tar.gz", hash = "sha256:e7e46372dee46fa1c97acf310d99d922b63d369718a270809d7c278d34a194cf"}, +] + +[package.dependencies] +pyflakes = ">=1.1.0,<3" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + [[package]] name = "backcall" version = "0.2.0" @@ -203,17 +246,17 @@ files = [ [[package]] name = "boto3" -version = "1.34.69" +version = "1.34.70" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.69-py3-none-any.whl", hash = "sha256:2e25ef6bd325217c2da329829478be063155897d8d3b29f31f7f23ab548519b1"}, - {file = "boto3-1.34.69.tar.gz", hash = "sha256:898a5fed26b1351352703421d1a8b886ef2a74be6c97d5ecc92432ae01fda203"}, + {file = "boto3-1.34.70-py3-none-any.whl", hash = "sha256:8d7902e2c0c62837457ba18146e3feaf1dec62018617edc5c0336b65b305b682"}, + {file = "boto3-1.34.70.tar.gz", hash = "sha256:54150a52eb93028b8e09df00319e8dcb68be7459333d5da00d706d75ba5130d6"}, ] [package.dependencies] -botocore = ">=1.34.69,<1.35.0" +botocore = ">=1.34.70,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -222,13 +265,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.34.69" +version = "1.34.70" description = "Low-level, data-driven core of boto 3." 
optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.69-py3-none-any.whl", hash = "sha256:d3802d076d4d507bf506f9845a6970ce43adc3d819dd57c2791f5c19ed6e5950"}, - {file = "botocore-1.34.69.tar.gz", hash = "sha256:d1ab2bff3c2fd51719c2021d9fa2f30fbb9ed0a308f69e9a774ac92c8091380a"}, + {file = "botocore-1.34.70-py3-none-any.whl", hash = "sha256:c86944114e85c8a8d5da06fb84f2609ed3bd23cd2fc06b30250bef7e37e8c589"}, + {file = "botocore-1.34.70.tar.gz", hash = "sha256:fa03d4972cd57d505e6c0eb5d7c7a1caeb7dd49e84f963f7ebeca41fe8ab736e"}, ] [package.dependencies] @@ -545,6 +588,39 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "datamodel-code-generator" +version = "0.25.5" +description = "Datamodel Code Generator" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "datamodel_code_generator-0.25.5-py3-none-any.whl", hash = "sha256:3b62b42c8ebf2bb98cfbc24467b523c5b76780c585b72f4ac2fc1f1f576702ab"}, + {file = "datamodel_code_generator-0.25.5.tar.gz", hash = "sha256:545f897481a94781e32b3c26a452ce049320b091310729f7fc6fa780f6a87898"}, +] + +[package.dependencies] +argcomplete = ">=1.10,<4.0" +black = ">=19.10b0" +genson = ">=1.2.1,<2.0" +inflect = ">=4.1.0,<6.0" +isort = ">=4.3.21,<6.0" +jinja2 = ">=2.10.1,<4.0" +packaging = "*" +pydantic = [ + {version = ">=1.5.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version < \"3.10\""}, + {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"4.0\""}, + {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, +] +pyyaml = ">=6.0.1" +toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +debug = ["PySnooper (>=0.4.1,<2.0.0)"] +graphql = ["graphql-core (>=3.2.3,<4.0.0)"] +http = ["httpx"] +validation = ["openapi-spec-validator (>=0.2.8,<0.7.0)", "prance (>=0.18.2)"] + [[package]] name = "debugpy" version = "1.8.1" @@ -609,6 +685,26 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "dnspython" +version = "2.6.1" +description = "DNS toolkit" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, + {file = "dnspython-2.6.1.tar.gz", hash = "sha256:e8f0f9c23a7b7cb99ded64e6c3a6f3e701d78f50c55e002b839dea7225cff7cc"}, +] + +[package.extras] +dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "mypy (>=1.8)", "pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "sphinx (>=7.2.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] +dnssec = ["cryptography (>=41)"] +doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] +doq = ["aioquic (>=0.9.25)"] +idna = ["idna (>=3.6)"] +trio = ["trio (>=0.23)"] +wmi = ["wmi (>=1.5.1)"] + [[package]] name = "dockerfile" version = "3.3.1" @@ -623,6 +719,21 @@ files = [ {file = "dockerfile-3.3.1.tar.gz", hash = "sha256:4790b3d96d1018302b27661f9624d851a4b7113bce1dbb2d7509991e81a387a9"}, ] +[[package]] +name = "email-validator" +version = "2.1.1" +description = "A robust email address syntax and deliverability validation library." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "email_validator-2.1.1-py3-none-any.whl", hash = "sha256:97d882d174e2a65732fb43bfce81a3a834cbc1bde8bf419e30ef5ea976370a05"}, + {file = "email_validator-2.1.1.tar.gz", hash = "sha256:200a70680ba08904be6d1eef729205cc0d687634399a5924d842533efb824b84"}, +] + +[package.dependencies] +dnspython = ">=2.0.0" +idna = ">=2.0.0" + [[package]] name = "exceptiongroup" version = "1.2.0" @@ -672,18 +783,18 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.13.1" +version = "3.13.3" description = "A platform independent file lock." optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, - {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, + {file = "filelock-3.13.3-py3-none-any.whl", hash = "sha256:5ffa845303983e7a0b7ae17636509bc97997d58afeafa72fb141a17b152284cb"}, + {file = "filelock-3.13.3.tar.gz", hash = "sha256:a79895a25bbefdf55d1a2a0a80968f7dbb28edcd6d4234a0afb3f37ecde4b546"}, ] [package.extras] -docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -760,6 +871,16 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "genson" +version = "1.2.2" +description = "GenSON is a powerful, user-friendly JSON Schema generator." 
+optional = false +python-versions = "*" +files = [ + {file = "genson-1.2.2.tar.gz", hash = "sha256:8caf69aa10af7aee0e1a1351d1d06801f4696e005f06cedef438635384346a16"}, +] + [[package]] name = "google-api-core" version = "2.18.0" @@ -1020,13 +1141,13 @@ socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.21.4" +version = "0.22.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" files = [ - {file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"}, - {file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"}, + {file = "huggingface_hub-0.22.0-py3-none-any.whl", hash = "sha256:72dea96299751699180184c06a4689e54cbfacecb1a3d08ac7a269c884bb17c3"}, + {file = "huggingface_hub-0.22.0.tar.gz", hash = "sha256:304f1e235c68c0a9f58bced47f13d6df241a5b4e3678f4981aa1e4f4bce63f6d"}, ] [package.dependencies] @@ -1039,15 +1160,16 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] -inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"] -quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"] +inference = ["aiohttp", "minijinja (>=1.0)"] +quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"] tensorflow = ["graphviz", "pydot", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", 
"urllib3 (<2.0)"] +tensorflow-testing = ["keras (<3.0)", "tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["safetensors", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] @@ -1113,6 +1235,21 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] +[[package]] +name = "inflect" +version = "5.6.2" +description = "Correctly generate plurals, singular nouns, ordinals, indefinite articles; convert numbers to words" +optional = false +python-versions = ">=3.7" +files = [ + {file = "inflect-5.6.2-py3-none-any.whl", hash = "sha256:b45d91a4a28a4e617ff1821117439b06eaa86e2a4573154af0149e9be6687238"}, + {file = "inflect-5.6.2.tar.gz", hash = "sha256:aadc7ed73928f5e014129794bbac03058cca35d0a973a5fc4eb45c7fa26005f9"}, +] + +[package.extras] +docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"] +testing = ["pygments", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1391,6 +1528,54 @@ files = [ {file = "jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d"}, ] +[[package]] +name = "libcst" +version = "1.1.0" +description = "A concrete syntax tree with AST-like properties for Python 3.5, 3.6, 3.7, 3.8, 3.9, and 3.10 programs." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "libcst-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:63f75656fd733dc20354c46253fde3cf155613e37643c3eaf6f8818e95b7a3d1"}, + {file = "libcst-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ae11eb1ea55a16dc0cdc61b41b29ac347da70fec14cc4381248e141ee2fbe6c"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4bc745d0c06420fe2644c28d6ddccea9474fb68a2135904043676deb4fa1e6bc"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c1f2da45f1c45634090fd8672c15e0159fdc46853336686959b2d093b6e10fa"}, + {file = "libcst-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:003e5e83a12eed23542c4ea20fdc8de830887cc03662432bb36f84f8c4841b81"}, + {file = "libcst-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:3ebbb9732ae3cc4ae7a0e97890bed0a57c11d6df28790c2b9c869f7da653c7c7"}, + {file = "libcst-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d68c34e3038d3d1d6324eb47744cbf13f2c65e1214cf49db6ff2a6603c1cd838"}, + {file = "libcst-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9dffa1795c2804d183efb01c0f1efd20a7831db6a21a0311edf90b4100d67436"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc9b6ac36d7ec9db2f053014ea488086ca2ed9c322be104fbe2c71ca759da4bb"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b7a38ec4c1c009ac39027d51558b52851fb9234669ba5ba62283185963a31c"}, + {file = "libcst-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5297a16e575be8173185e936b7765c89a3ca69d4ae217a4af161814a0f9745a7"}, + {file = "libcst-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:7ccaf53925f81118aeaadb068a911fac8abaff608817d7343da280616a5ca9c1"}, + {file = "libcst-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:75816647736f7e09c6120bdbf408456f99b248d6272277eed9a58cf50fb8bc7d"}, + {file = "libcst-1.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c8f26250f87ca849a7303ed7a4fd6b2c7ac4dec16b7d7e68ca6a476d7c9bfcdb"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d37326bd6f379c64190a28947a586b949de3a76be00176b0732c8ee87d67ebe"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d8cf974cfa2487b28f23f56c4bff90d550ef16505e58b0dca0493d5293784b"}, + {file = "libcst-1.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82d1271403509b0a4ee6ff7917c2d33b5a015f44d1e208abb1da06ba93b2a378"}, + {file = "libcst-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:bca1841693941fdd18371824bb19a9702d5784cd347cb8231317dbdc7062c5bc"}, + {file = "libcst-1.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f36f592e035ef84f312a12b75989dde6a5f6767fe99146cdae6a9ee9aff40dd0"}, + {file = "libcst-1.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f561c9a84eca18be92f4ad90aa9bd873111efbea995449301719a1a7805dbc5c"}, + {file = "libcst-1.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97fbc73c87e9040e148881041fd5ffa2a6ebf11f64b4ccb5b52e574b95df1a15"}, + {file = "libcst-1.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99fdc1929703fd9e7408aed2e03f58701c5280b05c8911753a8d8619f7dfdda5"}, + 
{file = "libcst-1.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0bf69cbbab5016d938aac4d3ae70ba9ccb3f90363c588b3b97be434e6ba95403"}, + {file = "libcst-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:fe41b33aa73635b1651f64633f429f7aa21f86d2db5748659a99d9b7b1ed2a90"}, + {file = "libcst-1.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:73c086705ed34dbad16c62c9adca4249a556c1b022993d511da70ea85feaf669"}, + {file = "libcst-1.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3a07ecfabbbb8b93209f952a365549e65e658831e9231649f4f4e4263cad24b1"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c653d9121d6572d8b7f8abf20f88b0a41aab77ff5a6a36e5a0ec0f19af0072e8"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f1cd308a4c2f71d5e4eec6ee693819933a03b78edb2e4cc5e3ad1afd5fb3f07"}, + {file = "libcst-1.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8afb6101b8b3c86c5f9cec6b90ab4da16c3c236fe7396f88e8b93542bb341f7c"}, + {file = "libcst-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:d22d1abfe49aa60fc61fa867e10875a9b3024ba5a801112f4d7ba42d8d53242e"}, + {file = "libcst-1.1.0.tar.gz", hash = "sha256:0acbacb9a170455701845b7e940e2d7b9519db35a86768d86330a0b0deae1086"}, +] + +[package.dependencies] +pyyaml = ">=5.2" +typing-extensions = ">=3.7.4.2" +typing-inspect = ">=0.4.0" + +[package.extras] +dev = ["Sphinx (>=5.1.1)", "black (==23.9.1)", "build (>=0.10.0)", "coverage (>=4.5.4)", "fixit (==2.0.0.post1)", "flake8 (>=3.7.8,<5)", "hypothesis (>=4.36.0)", "hypothesmith (>=0.0.4)", "jinja2 (==3.1.2)", "jupyter (>=1.0.0)", "maturin (>=0.8.3,<0.16)", "nbsphinx (>=0.4.2)", "prompt-toolkit (>=2.0.9)", "pyre-check (==0.9.18)", "setuptools-rust (>=1.5.2)", "setuptools-scm (>=6.0.1)", "slotscheck (>=0.7.1)", "sphinx-rtd-theme (>=0.4.3)", "ufmt (==2.2.0)", "usort (==1.0.7)"] + [[package]] name = "loguru" version = "0.7.2" @@ -2127,6 +2312,7 @@ files = [ [package.dependencies] annotated-types = ">=0.4.0" +email-validator = {version = ">=2.0.0", optional = true, markers = "extra == \"email\""} pydantic-core = "2.16.3" typing-extensions = ">=4.6.1" @@ -2881,6 +3067,17 @@ webencodings = ">=0.4" doc = ["sphinx", "sphinx_rtd_theme"] test = ["flake8", "isort", "pytest"] +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -2949,13 +3146,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "typer" -version = "0.9.0" +version = "0.10.0" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
optional = false python-versions = ">=3.6" files = [ - {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, - {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, + {file = "typer-0.10.0-py3-none-any.whl", hash = "sha256:b8a587aa06d3c5422c09c2e9935eb80b4c9de8605fd5ab702b2f92d72246ca48"}, + {file = "typer-0.10.0.tar.gz", hash = "sha256:597f974754520b091665f993f88abdd088bb81c56b3042225434ced0b50a788b"}, ] [package.dependencies] @@ -2966,7 +3163,7 @@ typing-extensions = ">=3.7.4.3" all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] -test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] [[package]] name = "types-pyyaml" @@ -3040,6 +3237,21 @@ files = [ {file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"}, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +description = "Runtime inspection utilities for typing module." 
+optional = false +python-versions = "*" +files = [ + {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"}, + {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"}, +] + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "urllib3" version = "1.26.18" @@ -3261,4 +3473,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "f247b02f095febf25b29d5b61dded3ce0f195ba51b734a5e771394f75db1ce25" +content-hash = "d9973afc3502f26a26564a5fba4054f46eb62a7b78d071f61963899cafce8076" diff --git a/pyproject.toml b/pyproject.toml index 6efe1e1ca..70613b143 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,13 @@ types-requests = ">=2.31.0.2" uvicorn = ">=0.24.0" uvloop = ">=0.17.0" +[tool.poetry.group.slay.dependencies] +astroid = "^3.1.0" +datamodel-code-generator = "^0.25.4" +libcst = "<1.2.0" +autoflake = "<=2.2" + + [build-system] build-backend = "poetry.core.masonry.api" requires = ["poetry-core>=1.2.1"] diff --git a/slay-examples/text_to_num/requirements.txt b/slay-examples/text_to_num/requirements.txt new file mode 100644 index 000000000..418da39d6 --- /dev/null +++ b/slay-examples/text_to_num/requirements.txt @@ -0,0 +1,4 @@ +git+https://github.com/basetenlabs/truss.git +httpx +libcst +pydantic diff --git a/slay-examples/text_to_num/user_package/__init__.py b/slay-examples/text_to_num/user_package/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slay-examples/text_to_num/user_package/shared_processor.py b/slay-examples/text_to_num/user_package/shared_processor.py new file mode 100644 index 000000000..d2524cf21 --- /dev/null +++ b/slay-examples/text_to_num/user_package/shared_processor.py @@ -0,0 +1,18 @@ +import slay + +IMAGE_NUMPY = ( + slay.Image() + .pip_requirements_file(slay.make_abs_path_here("../requirements.txt")) + .pip_requirements(["numpy"]) +) + + +class SplitText(slay.ProcessorBase): + + default_config = slay.Config(image=IMAGE_NUMPY) + + async def run(self, data: str, num_partitions: int) -> tuple[list[str], int]: + import numpy as np + + parts = np.array_split(np.array(list(data)), num_partitions) + return ["".join(part) for part in parts], 123 diff --git a/slay-examples/text_to_num/workflow.py b/slay-examples/text_to_num/workflow.py new file mode 100644 index 000000000..4f41bbadb --- /dev/null +++ b/slay-examples/text_to_num/workflow.py @@ -0,0 +1,201 @@ +# This logging is needed for debuggin class initalization. 
+# import logging + +# log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s" +# date_format = "%m%d %H:%M:%S" +# logging.basicConfig(level=logging.DEBUG, format=log_format, datefmt=date_format) + + +import random +import string +from typing import Protocol + +import pydantic +import slay +from truss import truss_config +from user_package import shared_processor + +IMAGE_COMMON = slay.Image().pip_requirements_file( + slay.make_abs_path_here("requirements.txt") +) + + +class GenerateData(slay.ProcessorBase): + + default_config = slay.Config(image=IMAGE_COMMON) + + def run(self, length: int) -> str: + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) + + +IMAGE_TRANSFORMERS_GPU = ( + slay.Image() + .pip_requirements_file(slay.make_abs_path_here("requirements.txt")) + .pip_requirements( + ["transformers==4.38.1", "torch==2.0.1", "sentencepiece", "accelerate"] + ) +) + + +MISTRAL_HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" +MISTRAL_CACHE = truss_config.ModelRepo( + repo_id=MISTRAL_HF_MODEL, allow_patterns=["*.json", "*.safetensors", ".model"] +) + + +class MistraLLMConfig(pydantic.BaseModel): + hf_model_name: str + + +class MistralLLM(slay.ProcessorBase[MistraLLMConfig]): + + default_config = slay.Config( + image=IMAGE_TRANSFORMERS_GPU, + compute=slay.Compute().cpu(2).gpu("A10G"), + assets=slay.Assets().cached([MISTRAL_CACHE]), + user_config=MistraLLMConfig(hf_model_name=MISTRAL_HF_MODEL), + ) + # default_config = slay.Config(config_path="mistral_config.yaml") + + def __init__( + self, + context: slay.Context[MistraLLMConfig] = slay.provide_context(), + ) -> None: + super().__init__(context) + import torch + import transformers + + model_name = self.user_config.hf_model_name + + self._model = transformers.AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16, + device_map="auto", + ) + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.float16, + ) + + self._generate_args = { + "max_new_tokens": 512, + "temperature": 1.0, + "top_p": 0.95, + "top_k": 50, + "repetition_penalty": 1.0, + "no_repeat_ngram_size": 0, + "use_cache": True, + "do_sample": True, + "eos_token_id": self._tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + } + + def run(self, data: str) -> str: + import torch + + formatted_prompt = f"[INST] {data} [/INST]" + input_ids = self._tokenizer( + formatted_prompt, return_tensors="pt" + ).input_ids.cuda() + with torch.no_grad(): + output = self._model.generate(inputs=input_ids, **self._generate_args) + result = self._tokenizer.decode(output[0]) + return result + + +class MistralP(Protocol): + def __init__(self, context: slay.Context) -> None: + ... + + def run(self, data: str) -> str: + ... 
+ + +class TextToNum(slay.ProcessorBase): + default_config = slay.Config(image=IMAGE_COMMON) + + def __init__( + self, + context: slay.Context = slay.provide_context(), + mistral: MistralP = slay.provide(MistralLLM), + ) -> None: + super().__init__(context) + self._mistral = mistral + + def run(self, data: str) -> int: + number = 0 + generated_text = self._mistral.run(data) + for char in generated_text: + number += ord(char) + + return number + + +class Workflow(slay.ProcessorBase): + default_config = slay.Config(image=IMAGE_COMMON) + + def __init__( + self, + context: slay.Context = slay.provide_context(), + data_generator: GenerateData = slay.provide(GenerateData), + splitter: shared_processor.SplitText = slay.provide(shared_processor.SplitText), + text_to_num: TextToNum = slay.provide(TextToNum), + ) -> None: + super().__init__(context) + self._data_generator = data_generator + self._data_splitter = splitter + self._text_to_num = text_to_num + + async def run(self, length: int, num_partitions: int) -> tuple[int, str, int]: + data = self._data_generator.run(length) + text_parts, number = await self._data_splitter.run(data, num_partitions) + value = 0 + for part in text_parts: + value += self._text_to_num.run(part) + return value, data, number + + +if __name__ == "__main__": + import logging + + from slay import utils + + # from slay.truss_compat import deploy + + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + log_format = "%(levelname).1s%(asctime)s %(filename)s:%(lineno)d] %(message)s" + date_format = "%m%d %H:%M:%S" + formatter = logging.Formatter(fmt=log_format, datefmt=date_format) + for handler in root_logger.handlers: + handler.setFormatter(formatter) + + # class FakeMistralLLM(slay.ProcessorBase): + # def run(self, data: str) -> str: + # return data.upper() + + # import asyncio + # with slay.run_local(): + # text_to_num = TextToNum(mistral=FakeMistralLLM()) + # wf = Workflow(text_to_num=text_to_num) + # result = asyncio.run(wf.run(length=123, num_partitions=123)) + # print(result) + + with utils.log_level(logging.DEBUG): + remote = slay.deploy_remotely( + Workflow, workflow_name="Test", generate_only=False + ) + + # remote = slay.definitions.BasetenRemoteDescriptor( + # b10_model_id="7qk59gdq", + # b10_model_version_id="woz52g3", + # b10_model_name="Workflow", + # b10_model_url="https://model-7qk59gdq.api.baseten.co/production", + # ) + # with utils.log_level(logging.INFO): + # response = deploy.call_workflow_dbg( + # remote, {"length": 1000, "num_partitions": 100} + # ) + # print(response) + # print(response.json()) diff --git a/slay/__init__.py b/slay/__init__.py new file mode 100644 index 000000000..2dbc287d5 --- /dev/null +++ b/slay/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa F401 +from slay.definitions import ( + Assets, + Compute, + Config, + Context, + Image, + UsageError, + make_abs_path_here, +) +from slay.public_api import ( + ProcessorBase, + deploy_remotely, + provide, + provide_context, + run_local, +) diff --git a/slay/code_gen.py b/slay/code_gen.py new file mode 100644 index 000000000..306823fb4 --- /dev/null +++ b/slay/code_gen.py @@ -0,0 +1,359 @@ +import ast +import logging +import pathlib +import shlex +import subprocess +import textwrap +from typing import Any, Iterable, Optional + +import libcst +from slay import definitions, utils +from slay.truss_adapter import code_gen + +INDENT = " " * 4 + +STUB_MODULE = "remote_stubs" + + +def _indent(text: str) -> str: + return textwrap.indent(text, INDENT) + + +class 
_MainRemover(ast.NodeTransformer): + """Removes main-section from module AST.""" + + def visit_If(self, node): + """Robustly matches variations of `if __name__ == "__main__":`.""" + if ( + isinstance(node.test, ast.Compare) + and any( + isinstance(c, ast.Name) and c.id == "__name__" + for c in ast.walk(node.test.left) + ) + and any( + isinstance(c, ast.Constant) and c.value == "__main__" + for c in ast.walk(node.test) + ) + ): + return None + return self.generic_visit(node) + + +def _remove_main_section(source_code: str) -> str: + """Removes main-section from module source.""" + parsed_code = ast.parse(source_code) + transformer = _MainRemover() + transformed_ast = transformer.visit(parsed_code) + return ast.unparse(transformed_ast) + + +def _run_simple_subprocess(cmd: str) -> None: + process = subprocess.Popen( + shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + _, stderr = process.communicate() + if process.returncode != 0: + raise ChildProcessError(f"Error: {stderr.decode()}") + + +def _format_python_file(file_path: pathlib.Path) -> None: + _run_simple_subprocess( + f"autoflake --in-place --remove-all-unused-imports {file_path}" + ) + _run_simple_subprocess(f"black {file_path}") + _run_simple_subprocess(f"isort {file_path}") + + +def make_processor_dir( + workflow_root: pathlib.Path, + workflow_name: str, + processor_descriptor: definitions.ProcessorAPIDescriptor, +) -> pathlib.Path: + processor_name = processor_descriptor.cls_name + file_name = f"processor_{processor_name}" + processor_dir = ( + workflow_root / definitions.GENERATED_CODE_DIR / workflow_name / file_name + ) + processor_dir.mkdir(exist_ok=True, parents=True) + return processor_dir + + +# Stub Gen ############################################################################# + +# TODO: I/O types: +# * Generate models from pydantic JSON schema. +# * Handle if multiple stubs use same pydantic models -> deduplicate or suffix. +# * Use a serialization / deserialization helper instead of directly generating this +# code. + + +def _endpoint_signature_src(endpoint: definitions.EndpointAPIDescriptor): + """ + E.g.: `async def run(self, data: str, num_partitions: int) -> tuple[list, int]:` + """ + if endpoint.is_generator: + # TODO: implement generator. + raise NotImplementedError("Generator.") + + def_str = "async def" if endpoint.is_async else "def" + args = ", ".join( + f"{arg_name}: {arg_type.as_src_str()}" + for arg_name, arg_type in endpoint.input_names_and_tyes + ) + if len(endpoint.output_types) == 1: + output_type = f"{endpoint.output_types[0].as_src_str()}" + else: + output_type = ( + f"tuple[{', '.join(t.as_src_str() for t in endpoint.output_types)}]" + ) + return f"""{def_str} {endpoint.name}(self, {args}) -> {output_type}:""" + + +def _gen_protocol_src(processor: definitions.ProcessorAPIDescriptor): + """Generates source code for a Protocol that matches the processor.""" + imports = ["from typing import Protocol"] + src_parts = [ + f""" +class {processor.cls_name}P(Protocol): +""" + ] + signature = _endpoint_signature_src(processor.endpoint) + src_parts.append(_indent(f"{signature}\n{INDENT}...\n")) + return "\n".join(src_parts), imports + + +def _endpoint_body_src(endpoint: definitions.EndpointAPIDescriptor): + """Generates source code for calling the stub and wrapping the I/O types. 
+ + E.g.: + ``` + json_args = {"data": data, "num_partitions": num_partitions} + json_result = await self._remote.predict_async(json_args) + return (json_result[0], json_result[1]) + ``` + """ + if endpoint.is_generator: + raise NotImplementedError("Generator") + + json_arg_parts = ( + ( + f"'{arg_name}': {arg_name}.json()" + if arg_type.is_pydantic + else f"'{arg_name}': {arg_name}" + ) + for arg_name, arg_type in endpoint.input_names_and_tyes + ) + + json_args = f"{{{', '.join(json_arg_parts)}}}" + remote_call = ( + "await self._remote.predict_async(json_args)" + if endpoint.is_async + else "self._remote.predict_sync(json_args)" + ) + + if len(endpoint.output_types) == 1: + output_type = utils.expect_one(endpoint.output_types) + if output_type.is_pydantic: + ret = f"{output_type.as_src_str()}.parse_obj(json_result)" + else: + ret = "json_result" + else: + ret_parts = ", ".join( + ( + f"{output_type.as_src_str()}.parse_obj(json_result[{i}])" + if output_type.is_pydantic + else f"json_result[{i}]" + ) + for i, output_type in enumerate(endpoint.output_types) + ) + ret = f"({ret_parts})" + + body = f""" +json_args = {json_args} +json_result = {remote_call} +return {ret} +""" + return body + + +def _gen_stub_src(processor: definitions.ProcessorAPIDescriptor): + """Generates stub class source, e.g: + + ``` + from slay import stub + + class SplitText(stub.StubBase): + def __init__(self, url: str, api_key: str) -> None: + self._remote = stub.BasetenSession(url, api_key) + + async def run(self, data: str, num_partitions: int) -> tuple[list, int]: + json_args = {"data": data, "num_partitions": num_partitions} + json_result = await self._remote.predict_async(json_args) + return (json_result[0], json_result[1]) + ``` + """ + imports = ["from slay import stub"] + + src_parts = [ + f""" +class {processor.cls_name}(stub.StubBase): + + def __init__(self, url: str, api_key: str) -> None: + self._remote = stub.BasetenSession(url, api_key) +""" + ] + body = _indent(_endpoint_body_src(processor.endpoint)) + src_parts.append( + _indent( + f"{_endpoint_signature_src(processor.endpoint)}{body}\n", + ) + ) + return "\n".join(src_parts), imports + + +def generate_stubs_for_deps( + processor_dir: pathlib.Path, + dependencies: Iterable[definitions.ProcessorAPIDescriptor], +) -> Optional[pathlib.Path]: + """Generates a source file with stub classes.""" + # TODO: user-defined I/O types are not imported / included correctly. 
+ imports = set() + src_parts = [] + for dep in dependencies: + # protocol_src, new_deps = gen_protocol_src(dep) + # imports.update(new_deps) + # src_parts.append(protocol_src) + stub_src, new_deps = _gen_stub_src(dep) + imports.update(new_deps) + src_parts.append(stub_src) + + if not (imports or src_parts): + return None + + out_file_path = processor_dir / f"{STUB_MODULE}.py" + with out_file_path.open("w") as fp: + fp.writelines("\n".join(imports)) + fp.writelines(src_parts) + + _format_python_file(out_file_path) + return out_file_path + + +# Remote Processor Gen ################################################################# + + +class _InitRewriter(libcst.CSTTransformer): + """Removes processors from init args and instead initializes corresponding stubs.""" + + def __init__(self, cls_name: str, replacements): + super().__init__() + self._cls_name = cls_name + self._replacements = replacements + + def leave_ClassDef( + self, original_node: libcst.ClassDef, updated_node: libcst.ClassDef + ) -> libcst.ClassDef: + # Target only the Workflow class + if original_node.name.value != self._cls_name: + return updated_node + + new_methods: list[Any] = [] + for method in updated_node.body.body: + if ( + isinstance(method, libcst.FunctionDef) + and method.name.value == "__init__" + ): + new_method = self._modify_init_method(method) + new_methods.append(new_method) + else: + new_methods.append(method) + return updated_node.with_changes( + body=updated_node.body.with_changes(body=new_methods) + ) + + def _modify_init_method(self, method: libcst.FunctionDef) -> libcst.FunctionDef: + keep_params_names = {definitions.SELF_ARG_NAME, definitions.CONTEXT_ARG_NAME} + if method.name.value == "__init__": + # Drop other params - assumes that we have verified that all arguments + # are processors. + keep_params = [] + for param in method.params.params: + if param.name.value in keep_params_names: + keep_params.append(param) + else: + if param.name.value not in self._replacements: + raise ValueError( + f"For argument `{param.name.value}` no processor was " + f"mappend. 
Available {list(self._replacements.keys())}" + ) + + new_params = method.params.with_changes(params=keep_params) + + processor_assignments = [ + libcst.parse_statement( + f"{name} = stub.stub_factory({stub_cls_ref}, context)" + ) + for name, stub_cls_ref in self._replacements.items() + ] + + # Create new statements for the method body + new_body = method.body.with_changes( + body=processor_assignments + list(method.body.body) + ) + + return method.with_changes(params=new_params, body=new_body) + return method + + +def _rewrite_processor_inits( + source_tree: libcst.Module, processor_desrciptor: definitions.ProcessorAPIDescriptor +): + """Removes processors from init args and instead initializes corresponding stubs.""" + replacements = {} + for name, proc_cls in processor_desrciptor.depdendencies.items(): + replacements[name] = f"{STUB_MODULE}.{proc_cls.__name__}" + + if not replacements: + return source_tree + + logging.debug(f"Adding stub inits to `{processor_desrciptor.cls_name}`.") + + modified_tree = source_tree.visit( + _InitRewriter(processor_desrciptor.cls_name, replacements) + ) + + new_imports = [ + libcst.parse_statement(f"import {STUB_MODULE}"), + libcst.parse_statement("from slay import stub"), + ] + + modified_tree = modified_tree.with_changes( + body=new_imports + list(modified_tree.body) + ) + return modified_tree + + +######################################################################################## + + +def generate_processor_source( + file_path: pathlib.Path, + processor_desrciptor: definitions.ProcessorAPIDescriptor, +): + """Generates code that wraps a processor as a truss-compatible model.""" + sourc_code = _remove_main_section(file_path.read_text()) + source_tree = libcst.parse_module(sourc_code) + source_tree = _rewrite_processor_inits(source_tree, processor_desrciptor) + + # TODO: Processor isolation: either prune file or generate a new file. + # At least remove main section. + + model_def, imports, userconfig_pin = code_gen.generate_truss_model( + processor_desrciptor + ) + new_body: list[libcst.BaseStatement] = ( + imports + list(source_tree.body) + [userconfig_pin, model_def] # type: ignore[assignment, misc, list-item] + ) + source_tree = source_tree.with_changes(body=new_body) + file_path.write_text(source_tree.code) + _format_python_file(file_path) diff --git a/slay/definitions.py b/slay/definitions.py new file mode 100644 index 000000000..3d4c2eb98 --- /dev/null +++ b/slay/definitions.py @@ -0,0 +1,347 @@ +# TODO: this file contains too much implementaiton -> restructure. +import abc +import inspect +import logging +import os +from types import GenericAlias +from typing import Any, ClassVar, Generic, Mapping, Optional, Type, TypeVar + +import pydantic +from pydantic import generics + +UserConfigT = TypeVar("UserConfigT", bound=Optional[pydantic.BaseModel]) + +BASTEN_API_SECRET_NAME = "baseten_api_key" +TRUSS_CONFIG_SLAY_KEY = "slay_metadata" + +ENDPOINT_METHOD_NAME = "run" # Referring to processor method name exposed as endpoint. +# Below arg names must correspond to `definitions.ABCProcessor`. +CONTEXT_ARG_NAME = "context" # Referring to processors `__init__` signature. 
+SELF_ARG_NAME = "self" + +GENERATED_CODE_DIR = ".slay_generated" +PREDICT_ENDPOINT_NAME = "/predict" +PROCESSOR_MODULE = "processor" + + +class APIDefinitonError(TypeError): + """Raised when user-defined processors do not adhere to API constraints.""" + + +class MissingDependencyError(TypeError): + """Raised when a needed resource could not be found or is not defined.""" + + +class UsageError(Exception): + """Raised when components are not used the expected way at runtime.""" + + +class AbsPath: + _abs_file_path: str + _creating_module: str + _original_path: str + + def __init__( + self, abs_file_path: str, creating_module: str, original_path: str + ) -> None: + self._abs_file_path = abs_file_path + self._creating_module = creating_module + self._original_path = original_path + + def raise_if_not_exists(self) -> None: + if not os.path.isfile(self._abs_file_path): + raise MissingDependencyError( + f"With the file path `{self._original_path}` an absolute path relative " + f"to the calling module `{self._creating_module}` was created, " + f"resulting `{self._abs_file_path}` - but no file was found." + ) + + @property + def abs_path(self) -> str: + return self._abs_file_path + + +def make_abs_path_here(file_path: str) -> AbsPath: + """Helper to specify file paths relative to the *immediately calling* module. + + E.g. in you have a project structure like this" + + root/ + workflow.py + common_requirements.text + sub_package/ + processor.py + processor_requirements.txt + + Not in `root/sub_package/processor.py` you can point to the requirements + file like this: + + ``` + shared = RelativePathToHere("../common_requirements.text") + specific = RelativePathToHere("processor_requirements.text") + ``` + + Caveat: this helper uses the directory of the immediately calling module as an + absolute reference point for resolving the file location. + Therefore you MUST NOT wrap the instantiation of `RelativePathToHere` into a + function (e.g. applying decorators) or use dynamic code execution. + + Ok: + ``` + def foo(path: AbsPath): + abs_path = path.abs_path + + + foo(make_abs_path_here("blabla")) + ``` + + Not Ok: + ``` + def foo(path: str): + badbadbad = make_abs_path_here(path).abs_path + + foo("blabla")) + ``` + """ + # TODO: the absolute path resoultion below uses the calling module as a + # reference point. This would not work if users wrap this call in a funciton + # - we hope the naming makes clear that this should not be done. + caller_frame = inspect.stack()[1] + module_path = caller_frame.filename + if not os.path.isabs(file_path): + module_dir = os.path.dirname(os.path.abspath(module_path)) + abs_file_path = os.path.normpath(os.path.join(module_dir, file_path)) + logging.info(f"Inferring absolute path for `{file_path}` as `{abs_file_path}`.") + else: + abs_file_path = file_path + + return AbsPath(abs_file_path, module_path, file_path) + + +class ImageSpec(pydantic.BaseModel): + class Config: + arbitrary_types_allowed = True + + # TODO: this is not stable yet and might change or refer back to truss. 
+ base_image: str = "python:3.11-slim" + pip_requirements_file: Optional[AbsPath] = None + pip_requirements: list[str] = [] + apt_requirements: list[str] = [] + + +class Image: + """Builder to create image spec.""" + + _spec: ImageSpec + + def __init__(self) -> None: + self._spec = ImageSpec() + + def pip_requirements_file(self, file_path: AbsPath) -> "Image": + self._spec.pip_requirements_file = file_path + return self + + def pip_requirements(self, requirements: list[str]) -> "Image": + self._spec.pip_requirements = requirements + return self + + def apt_requirements(self, requirements: list[str]) -> "Image": + self._spec.apt_requirements = requirements + return self + + def get_spec(self) -> ImageSpec: + return self._spec.copy(deep=True) + + +class ComputeSpec(pydantic.BaseModel): + # TODO: this is not stable yet and might change or refer back to truss. + cpu: str = "1" + memory: str = "2Gi" + gpu: Optional[str] = None + + +class Compute: + """Builder to create compute spec.""" + + _spec: ComputeSpec + + def __init__(self) -> None: + self._spec = ComputeSpec() + + def cpu(self, cpu: int) -> "Compute": + self._spec.cpu = str(cpu) + return self + + def memory(self, memory: str) -> "Compute": + self._spec.memory = memory + return self + + def gpu(self, kind: str, count: int = 1) -> "Compute": + self._spec.gpu = f"{kind}:{count}" + return self + + def get_spec(self) -> ComputeSpec: + return self._spec.copy(deep=True) + + +class AssetSpec(pydantic.BaseModel): + # TODO: this is not stable yet and might change or refer back to truss. + secrets: dict[str, str] = {} + cached: list[Any] = [] + + +class Assets: + """Builder to create asset spec.""" + + _spec: AssetSpec + + def __init__(self) -> None: + self._spec = AssetSpec() + + def add_secret(self, key: str) -> "Assets": + self._spec.secrets[key] = "***" # Actual value is provided in deployment. + return self + + def cached(self, value: list[Any]) -> "Assets": + self._spec.cached = value + return self + + def get_spec(self) -> AssetSpec: + return self._spec.copy(deep=True) + + +class Config(generics.GenericModel, Generic[UserConfigT]): + """Bundles config values needed to deploy a processor.""" + + class Config: + arbitrary_types_allowed = True + + name: Optional[str] = None + image: Image = Image() + compute: Compute = Compute() + assets: Assets = Assets() + user_config: UserConfigT = pydantic.Field(default=None) + + def get_image_spec(self) -> ImageSpec: + return self.image.get_spec() + + def get_compute_spec(self) -> ComputeSpec: + return self.compute.get_spec() + + def get_asset_spec(self) -> AssetSpec: + return self.assets.get_spec() + + +class Context(generics.GenericModel, Generic[UserConfigT]): + """Bundles config values needed to instantiate a processor in deployment.""" + + class Config: + arbitrary_types_allowed = True + + user_config: UserConfigT = pydantic.Field(default=None) + stub_cls_to_url: Mapping[str, str] = {} + # secrets: Optional[secrets_resolver.Secrets] = None + # TODO: above type results in `truss.server.shared.secrets_resolver.Secrets` + # due to the templating, at runtime the object passed will be from + # `shared.secrets_resolver` and give pydantic validation error. 
+ secrets: Optional[Any] = None + + def get_stub_url(self, stub_cls_name: str) -> str: + if stub_cls_name not in self.stub_cls_to_url: + raise MissingDependencyError(f"{stub_cls_name}") + return self.stub_cls_to_url[stub_cls_name] + + def get_baseten_api_key(self) -> str: + if not self.secrets: + raise UsageError(f"Secrets not set in `{self.__class__.__name__}` object.") + if BASTEN_API_SECRET_NAME not in self.secrets: + raise MissingDependencyError( + "For using workflows, it is required to setup a an API key with name " + f"`{BASTEN_API_SECRET_NAME}` on baseten to allow workflow processor to " + "call other processors." + ) + + api_key = self.secrets[BASTEN_API_SECRET_NAME] + return api_key + + +class TrussMetadata(generics.GenericModel, Generic[UserConfigT]): + """Plugin for the truss config (in config["model_metadata"]["slay_metadata"]).""" + + class Config: + arbitrary_types_allowed = True + + user_config: UserConfigT = pydantic.Field(default=None) + stub_cls_to_url: Mapping[str, str] = {} + + +class ABCProcessor(Generic[UserConfigT], abc.ABC): + default_config: ClassVar[Config] + _init_is_patched: ClassVar[bool] = False + _context: Context[UserConfigT] + + @abc.abstractmethod + def __init__(self, context: Context[UserConfigT]) -> None: + ... + + # Cannot add this abstract method to API, because we want to allow arbitraty + # arg/kwarg names and specifying any function signature here would give type errors + # @abc.abstractmethod + # def predict(self, *args, **kwargs) -> Any: ... + + @property + @abc.abstractmethod + def user_config(self) -> UserConfigT: + ... + + +class TypeDescriptor(pydantic.BaseModel): + """For describing I/O types of processors.""" + + # TODO: Supporting pydantic types. + + raw: Any # The raw type annotation object (could be a type or GenericAlias). 
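+    # Illustrative sketch of the expected behavior (not normative;
+    # `SomePydanticModel` stands for a hypothetical user-defined pydantic model):
+    #   TypeDescriptor(raw=int).as_src_str()        -> "int"
+    #   TypeDescriptor(raw=list[str]).as_src_str()  -> "list[str]"
+    #   TypeDescriptor(raw=SomePydanticModel).is_pydantic  -> True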
+ + def as_src_str(self) -> str: + if isinstance(self.raw, type): + return self.raw.__name__ + else: + return str(self.raw) + + @property + def is_pydantic(self) -> bool: + return ( + isinstance(self.raw, type) + and not isinstance(self.raw, GenericAlias) + and issubclass(self.raw, pydantic.BaseModel) + ) + + +class EndpointAPIDescriptor(pydantic.BaseModel): + name: str = ENDPOINT_METHOD_NAME + input_names_and_tyes: list[tuple[str, TypeDescriptor]] + output_types: list[TypeDescriptor] + is_async: bool + is_generator: bool + + +class ProcessorAPIDescriptor(pydantic.BaseModel): + processor_cls: Type[ABCProcessor] + src_path: str + depdendencies: Mapping[str, Type[ABCProcessor]] + endpoint: EndpointAPIDescriptor + user_config_type: TypeDescriptor + + def __hash__(self) -> int: + return hash(self.processor_cls) + + @property + def cls_name(self) -> str: + return self.processor_cls.__name__ + + +class BasetenRemoteDescriptor(pydantic.BaseModel): + b10_model_id: str + b10_model_version_id: str + b10_model_name: str + b10_model_url: str diff --git a/slay/framework.py b/slay/framework.py new file mode 100644 index 000000000..75599c253 --- /dev/null +++ b/slay/framework.py @@ -0,0 +1,621 @@ +import collections +import contextlib +import inspect +import logging +import os +import pathlib +import shutil +import sys +import types +from typing import ( + Any, + Callable, + Iterable, + Mapping, + MutableMapping, + Optional, + Protocol, + Type, + get_args, + get_origin, +) + +import pydantic +from slay import code_gen, definitions, utils +from slay.truss_adapter import deploy + +_SIMPLE_TYPES = {int, float, complex, bool, str, bytes, None} +_SIMPLE_CONTAINERS = {list, dict} + + +# Checking of processor class definition ############################################### + + +def _validate_io_type(param: inspect.Parameter) -> None: + """ + For processor I/O (both data or parameters) we allow simple types + (int, str, float...) and `list` or `dict` containers of these. + Any deeper nested and structured data must be typed as a pydnatic model. + """ + anno = param.annotation + if anno in _SIMPLE_TYPES: + return + if isinstance(anno, types.GenericAlias): + if get_origin(anno) not in _SIMPLE_CONTAINERS: + raise definitions.APIDefinitonError( + f"For generic types, only containers {_SIMPLE_CONTAINERS} are " + f"allowed, but got `{param}`." + ) + args = get_args(anno) + for arg in args: + if arg not in _SIMPLE_TYPES: + raise definitions.APIDefinitonError( + f"For generic types, only arg types {_SIMPLE_TYPES} are " + f"allowed, but got `{param}`." + ) + return + if issubclass(anno, pydantic.BaseModel): + try: + anno.schema() + except Exception as e: + raise definitions.APIDefinitonError( + "Pydantic annotations must be able to generate a schema. " + f"Please fix `{param}`." + ) from e + return + + raise definitions.APIDefinitonError(anno) + + +def _validate_endpoint_params( + params: list[inspect.Parameter], cls_name: str +) -> list[tuple[str, definitions.TypeDescriptor]]: + if len(params) == 0: + raise definitions.APIDefinitonError( + f"`{cls_name}.{definitions.ENDPOINT_METHOD_NAME}` must be a method, i.e. " + "with `self` argument." + ) + if params[0].name != definitions.SELF_ARG_NAME: + raise definitions.APIDefinitonError( + f"`{cls_name}.{definitions.ENDPOINT_METHOD_NAME}` must be a method, i.e. " + "with `self` argument." + ) + input_name_and_types = [] + for param in params[1:]: # Skip self argument. 
+ if param.annotation == inspect.Parameter.empty: + raise definitions.APIDefinitonError( + "Inputs of endpoints must have type annotations. " + f"For `{cls_name}` got:\n{param}" + ) + _validate_io_type(param) + type_descriptor = definitions.TypeDescriptor(raw=param.annotation) + input_name_and_types.append((param.name, type_descriptor)) + return input_name_and_types + + +def _validate_and_describe_endpoint( + cls: Type[definitions.ABCProcessor], +) -> definitions.EndpointAPIDescriptor: + """The "endpoint method" of a processor must have the follwing signature: + + ``` + [async] def run( + self, [param_0: anno_0, param_1: anno_1 = default_1, ...]) -> ret_anno: + ``` + + * The name must be `run`. + * It can be sync or async or def. + * The number and names of parameters are arbitrary, both positional and named + parameters are ok. + * All parameters and the return value must have type annotations. See + `_validate_io_type` for valid types. + * Generators are allowed, too (but not yet supported). + """ + if not hasattr(cls, definitions.ENDPOINT_METHOD_NAME): + raise definitions.APIDefinitonError( + f"`{cls.__name__}` must have a {definitions.ENDPOINT_METHOD_NAME}` method." + ) + endpoint_method = getattr( + cls, definitions.ENDPOINT_METHOD_NAME + ) # This is the unbound method. + if not inspect.isfunction(endpoint_method): + raise definitions.APIDefinitonError( + f"`{cls.__name__}.{definitions.ENDPOINT_METHOD_NAME}` must be a method." + ) + + signature = inspect.signature(endpoint_method) + input_name_and_types = _validate_endpoint_params( + list(signature.parameters.values()), cls.__name__ + ) + + if signature.return_annotation == inspect.Parameter.empty: + raise definitions.APIDefinitonError( + f"Return values of endpoints must be type annotated. Got:\n{signature}" + ) + if get_origin(signature.return_annotation) is tuple: + output_types = list( + definitions.TypeDescriptor(raw=arg) + for arg in get_args(signature.return_annotation) + ) + else: + output_types = [definitions.TypeDescriptor(raw=signature.return_annotation)] + + if inspect.isasyncgenfunction(endpoint_method): + is_async = True + is_generator = True + elif inspect.iscoroutinefunction(endpoint_method): + is_async = True + is_generator = False + else: + is_async = False + is_generator = inspect.isgeneratorfunction(endpoint_method) + + return definitions.EndpointAPIDescriptor( + input_names_and_tyes=input_name_and_types, + output_types=output_types, + is_async=is_async, + is_generator=is_generator, + ) + + +def _get_generic_class_type(var): + """Extracts `SomeGeneric` from `SomeGeneric` or `SomeGeneric[T]` uniformly.""" + origin = get_origin(var) + return origin if origin is not None else var + + +def _validate_dependency_arg(param) -> Type[definitions.ABCProcessor]: + # TODO: handle subclasses, unions, optionals, check default value etc. + if not isinstance(param.default, ProcessorProvisionPlaceholder): + raise definitions.APIDefinitonError( + f"Any extra arguments of a processor's __init__ must have a default " + f"value of type `{ProcessorProvisionPlaceholder}` (created with the " + f"`provide` directive). Got `{param.default}` for `{param.name}`." + ) + processor_cls = param.default.processor_cls + if not ( + # TODO: `Protocol` is not a proper class and this might be version dependent. + # Find a better way to inspect this. 
+        issubclass(param.annotation, Protocol)  # type: ignore[arg-type]
+        or issubclass(processor_cls, param.annotation)
+    ):
+        raise definitions.APIDefinitonError(
+            f"The type annotation for `{param.name}` must either be a `{Protocol}` "
+            "or a class/subclass of the processor type used as default value. "
+            f"Got `{param.default}`."
+        )
+    if not issubclass(processor_cls, definitions.ABCProcessor):
+        raise definitions.APIDefinitonError(
+            f"`{processor_cls}` must be a subclass of `{definitions.ABCProcessor}`."
+        )
+    return processor_cls
+
+
+class _ProcessorInitParams:
+    def __init__(self, params: list[inspect.Parameter]) -> None:
+        self._params = params
+        self._validate_self_arg()
+        self._validate_context_arg()
+
+    def _validate_self_arg(self) -> None:
+        if len(self._params) == 0:
+            raise definitions.APIDefinitonError(
+                "Methods must have first argument `self`."
+            )
+
+        if self._params[0].name != definitions.SELF_ARG_NAME:
+            raise definitions.APIDefinitonError(
+                "Methods must have first argument `self`."
+            )
+
+    def _validate_context_arg(self) -> None:
+        context_exception = definitions.APIDefinitonError(
+            f"`{definitions.ABCProcessor}` must have "
+            f"`{definitions.CONTEXT_ARG_NAME}` argument of type "
+            f"`{definitions.Context}`."
+        )
+        if len(self._params) < 2:
+            raise context_exception
+        if self._params[1].name != definitions.CONTEXT_ARG_NAME:
+            raise context_exception
+
+        param = self._params[1]
+        param_type = _get_generic_class_type(param.annotation)
+        if not issubclass(param_type, definitions.Context):
+            raise context_exception
+        if not isinstance(param.default, ContextProvisionPlaceholder):
+            raise definitions.APIDefinitonError(
+                f"The default value for the `context` argument of a processor's "
+                f"__init__ must be of type `{ContextProvisionPlaceholder}` (created "
+                f"with the `provide_context` directive). Got `{param.default}`."
+            )
+
+    def validated_dependencies(self) -> Mapping[str, Type[definitions.ABCProcessor]]:
+        used_classes = set()
+        dependencies = {}
+        for param in self._params[2:]:  # Skip self and context.
+            processor_cls = _validate_dependency_arg(param)
+            if processor_cls in used_classes:
+                raise definitions.APIDefinitonError(
+                    f"The same processor class cannot be used multiple times for "
+                    f"different arguments. Got previously used `{processor_cls}` "
+                    f"for `{param.name}`."
+                )
+            dependencies[param.name] = processor_cls
+            used_classes.add(processor_cls)
+        return dependencies
+
+
+def _validate_init_and_get_dependencies(
+    cls: Type[definitions.ABCProcessor],
+) -> Mapping[str, Type[definitions.ABCProcessor]]:
+    """The `__init__`-method of a processor must have the following signature:
+    ```
+    def __init__(
+        self,
+        context: slay.Context = slay.provide_context(),
+        [dep_0: dep_0_type = slay.provide(dep_0_proc_class),]
+        [dep_1: dep_1_type = slay.provide(dep_1_proc_class),]
+        ...
+    ) -> None:
+    ```
+
+    * The context argument is required and must have a default constructed with the
+      `provide_context` directive. The type can be templated by a user-defined config
+      e.g. `slay.Context[UserConfig]`.
+    * The names and number of other - "dependency" - arguments are arbitrary.
+    * Default values for dependencies must be constructed with the `provide` directive
+      to make the dependency injection work. The argument to `provide` must be a
+      processor class.
+    * The type annotation for dependencies can be a processor class, but it can also be
+      a `Protocol` with an equivalent `run` method (e.g. for getting correct type
+      checks when providing fake processors for local testing).
+ """ + params = _ProcessorInitParams( + list(inspect.signature(cls.__init__).parameters.values()) + ) + return params.validated_dependencies() + + +def _validate_variable_access(cls: Type[definitions.ABCProcessor]) -> None: + # TODO ensure that processors are only accessed via `provided` in `__init__`,` + # not from manual instantiations on module-level or nested in a processor. + # See other constraints listed in: + # https://www.notion.so/ml-infra/WIP-Orchestration-a8cb4dad00dd488191be374b469ffd0a?pvs=4#7df299eb008f467a80f7ee3c0eccf0f0 + ... + + +def check_and_register_class(cls: Type[definitions.ABCProcessor]) -> None: + processor_descriptor = definitions.ProcessorAPIDescriptor( + processor_cls=cls, + depdendencies=_validate_init_and_get_dependencies(cls), + endpoint=_validate_and_describe_endpoint(cls), + src_path=os.path.abspath(inspect.getfile(cls)), + user_config_type=definitions.TypeDescriptor( + raw=type(cls.default_config.user_config) + ), + ) + logging.debug(f"Descriptor for {cls}:\n{processor_descriptor}\n") + _validate_variable_access(cls) + _global_processor_registry.register_processor(processor_descriptor) + + +# Dependency-Injection / Registry ###################################################### + + +class _BaseProvisionPlaceholder: + """A marker for object to be depdenency injected by the framework.""" + + +class ProcessorProvisionPlaceholder(_BaseProvisionPlaceholder): + # TODO: extend with RPC customization, e.g. timeouts, retries etc. + processor_cls: Type[definitions.ABCProcessor] + + def __init__(self, processor_cls: Type[definitions.ABCProcessor]) -> None: + self.processor_cls = processor_cls + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self._cls_name})" + + @property + def _cls_name(self) -> str: + return self.processor_cls.__name__ + + +class ContextProvisionPlaceholder(_BaseProvisionPlaceholder): + def __str__(self) -> str: + return f"{self.__class__.__name__}" + + +class _ProcessorRegistry: + # Because dependencies are required to be present when registering a processor, + # this dict contains natively a topological sorting of the dependency graph. + _processors: collections.OrderedDict[ + Type[definitions.ABCProcessor], definitions.ProcessorAPIDescriptor + ] + _name_to_cls: MutableMapping[str, Type] + + def __init__(self) -> None: + self._processors = collections.OrderedDict() + self._name_to_cls = {} + + def register_processor( + self, processor_descriptor: definitions.ProcessorAPIDescriptor + ): + for dep in processor_descriptor.depdendencies.values(): + # To depend on a processor, the class must be defined (module initialized) + # which entails that is has already been added to the registry. + assert dep in self._processors, dep + + # Because class are globally unique, to prevent re-use / overwriting of names, + # We must check this in addition. + if processor_descriptor.cls_name in self._name_to_cls: + conflict = self._name_to_cls[processor_descriptor.cls_name] + existing_source_path = self._processors[conflict].src_path + raise definitions.APIDefinitonError( + f"A processor with name `{processor_descriptor.cls_name}` was already " + f"defined, processors names must be uniuqe. 
The pre-existing name " + f"comes from:\n`{existing_source_path}`\nNew conflicting from\n " + f"{processor_descriptor.src_path}" + ) + + self._processors[processor_descriptor.processor_cls] = processor_descriptor + self._name_to_cls[ + processor_descriptor.cls_name + ] = processor_descriptor.processor_cls + + @property + def processor_descriptors(self) -> list[definitions.ProcessorAPIDescriptor]: + return list(self._processors.values()) + + def get_descriptor( + self, processor_cls: Type[definitions.ABCProcessor] + ) -> definitions.ProcessorAPIDescriptor: + return self._processors[processor_cls] + + def get_dependencies( + self, processor: definitions.ProcessorAPIDescriptor + ) -> Iterable[definitions.ProcessorAPIDescriptor]: + return [ + self._processors[desc] + for desc in self._processors[processor.processor_cls].depdendencies.values() + ] + + +_global_processor_registry = _ProcessorRegistry() + + +# Processor class runtime utils ######################################################## + + +def _determine_arguments(func: Callable, **kwargs): + """Merges proivded and default arguments to effective invocation arguments.""" + sig = inspect.signature(func) + bound_args = sig.bind_partial(**kwargs) + bound_args.apply_defaults() + return bound_args.arguments + + +def ensure_args_are_injected(cls, original_init: Callable, kwargs) -> None: + """Asserts all placeholder markers are replaced by actual objects.""" + final_args = _determine_arguments(original_init, **kwargs) + for name, value in final_args.items(): + if isinstance(value, _BaseProvisionPlaceholder): + raise definitions.UsageError( + f"When initializing class `{cls.__name__}`, for " + f"default argument `{name}` a symbolic placeholder value " + f"was passed (`{value}`). Processors must be either a) locally " + f"instantiated in `{run_local.__name__}` context or b) deployed " + "remotely. Naive instantiations are prohibited." + ) + + +# Local Deployment ##################################################################### + + +def _create_local_context( + processor_cls: Type[definitions.ABCProcessor], +) -> definitions.Context: + if hasattr(processor_cls, "default_config"): + defaults = processor_cls.default_config + return definitions.Context(user_config=defaults.user_config) + return definitions.Context() + + +def _create_modified_init_for_local( + processor_descriptor: definitions.ProcessorAPIDescriptor, + cls_to_instance: MutableMapping[ + Type[definitions.ABCProcessor], definitions.ABCProcessor + ], +): + """Replaces the default argument values with local processor instantiations. + + If this patch is used, processors can be functionally instantiated without + any init args (because the patched defaults are sufficient). + """ + original_init = processor_descriptor.processor_cls.__init__ + + def init_for_local(self: definitions.ABCProcessor, **kwargs) -> None: + logging.debug(f"Patched `__init__` of `{processor_descriptor.cls_name}`.") + kwargs_mod = dict(kwargs) + if definitions.CONTEXT_ARG_NAME not in kwargs_mod: + context = _create_local_context(processor_descriptor.processor_cls) + kwargs_mod[definitions.CONTEXT_ARG_NAME] = context + else: + logging.debug( + f"Use explicitly given context for `{self.__class__.__name__}`." + ) + for arg_name, dep_cls in processor_descriptor.depdendencies.items(): + if arg_name in kwargs_mod: + logging.debug( + f"Use explicitly given instance for `{arg_name}` of " + f"type `{dep_cls.__name__}`." 
+ ) + continue + if dep_cls in cls_to_instance: + logging.debug( + f"Use previously created instace for `{arg_name}` of type " + f"`{dep_cls.__name__}`." + ) + instance = cls_to_instance[dep_cls] + else: + logging.debug( + f"Create new instace for `{arg_name}` of type `{dep_cls.__name__}`." + ) + assert dep_cls._init_is_patched + instance = dep_cls() # type: ignore # Here init args are patched. + cls_to_instance[dep_cls] = instance + + kwargs_mod[arg_name] = instance + + original_init(self, **kwargs_mod) + + return init_for_local + + +@contextlib.contextmanager +def run_local() -> Any: + """Context to run processors with depenedency injection from local instances.""" + type_to_instance: MutableMapping[ + Type[definitions.ABCProcessor], definitions.ABCProcessor + ] = {} + original_inits: MutableMapping[Type[definitions.ABCProcessor], Callable] = {} + + for processor_descriptor in _global_processor_registry.processor_descriptors: + original_inits[ + processor_descriptor.processor_cls + ] = processor_descriptor.processor_cls.__init__ + init_for_local = _create_modified_init_for_local( + processor_descriptor, type_to_instance + ) + processor_descriptor.processor_cls.__init__ = init_for_local # type: ignore[method-assign] + processor_descriptor.processor_cls._init_is_patched = True + try: + yield + finally: + # Restore original classes to unpatched state. + for processor_cls, original_init in original_inits.items(): + processor_cls.__init__ = original_init # type: ignore[method-assign] + processor_cls._init_is_patched = False + + +# Remote Deployment #################################################################### + + +def _create_remote_service( + baseten_client: deploy.BasetenClient, + processor_dir: pathlib.Path, + workflow_root: pathlib.Path, + processor_descriptor: definitions.ProcessorAPIDescriptor, + stub_cls_to_url: Mapping[str, str], + maybe_stub_file: Optional[pathlib.Path], + worfklow_name: str, + generate_only: bool, +) -> definitions.BasetenRemoteDescriptor: + processor_filepath = shutil.copy( + processor_descriptor.src_path, + os.path.join(processor_dir, f"{definitions.PROCESSOR_MODULE}.py"), + ) + code_gen.generate_processor_source( + pathlib.Path(processor_filepath), processor_descriptor + ) + # Only add needed stub URLs. + stub_cls_to_url = { + stub_cls.__name__: stub_cls_to_url[stub_cls.__name__] + for stub_cls in processor_descriptor.depdendencies.values() + } + # Convert to truss and deploy. + # TODO: support file-based config (and/or merge file and python-src configvalues). 
+ slay_config = processor_descriptor.processor_cls.default_config + processor_name = slay_config.name or processor_descriptor.cls_name + model_name = f"{worfklow_name}.{processor_name}" + truss_dir = deploy.make_truss( + processor_dir, + workflow_root, + slay_config, + model_name, + stub_cls_to_url, + maybe_stub_file, + ) + if generate_only: + remote_descriptor = definitions.BasetenRemoteDescriptor( + b10_model_id="dummy", + b10_model_name=model_name, + b10_model_version_id="dymmy", + b10_model_url="https://dummy", + ) + else: + with utils.log_level(logging.INFO): + remote_descriptor = baseten_client.deploy_truss(truss_dir) + logging.debug(remote_descriptor) + return remote_descriptor + + +def _get_ordered_processor_descriptors( + processors: Iterable[Type[definitions.ABCProcessor]], +) -> Iterable[definitions.ProcessorAPIDescriptor]: + """Gather all processors needed and returns a topologically ordered list.""" + needed_processors: set[definitions.ProcessorAPIDescriptor] = set() + + def add_needed_procssors(proc: definitions.ProcessorAPIDescriptor): + needed_processors.add(proc) + for processor_descriptor in _global_processor_registry.get_dependencies(proc): + needed_processors.add(processor_descriptor) + add_needed_procssors(processor_descriptor) + + for proc_cls in processors: + proc = _global_processor_registry.get_descriptor(proc_cls) + add_needed_procssors(proc) + + # Iterating over the registry ensures topological ordering. + return [ + processor_descriptor + for processor_descriptor in _global_processor_registry.processor_descriptors + if processor_descriptor in needed_processors + ] + + +def deploy_remotely( + entrypoint: Type[definitions.ABCProcessor], + worfklow_name: str, + baseten_url: str = "https://app.baseten.co", + generate_only: bool = False, +) -> definitions.BasetenRemoteDescriptor: + """ + * Gathers dependencies of `entrypoint. + * Generates stubs. + * Generates modifies processors to use these stubs. + * Generates truss models and deploys them to baseten. + """ + # TODO: more control e.g. publish vs. draft. + workflow_root = pathlib.Path(sys.argv[0]).absolute().parent + api_key = deploy.get_api_key_from_trussrc() + baseten_client = deploy.BasetenClient(baseten_url, api_key) + entrypoint_descr = _global_processor_registry.get_descriptor(entrypoint) + + ordered_descriptors = _get_ordered_processor_descriptors([entrypoint]) + stub_cls_to_url: dict[str, str] = {} + entrypoint_remote: Optional[definitions.BasetenRemoteDescriptor] = None + for processor_descriptor in ordered_descriptors: + processor_dir = code_gen.make_processor_dir( + workflow_root, worfklow_name, processor_descriptor + ) + maybe_stub_file = code_gen.generate_stubs_for_deps( + processor_dir, + _global_processor_registry.get_dependencies(processor_descriptor), + ) + remote_descriptor = _create_remote_service( + baseten_client, + processor_dir, + workflow_root, + processor_descriptor, + stub_cls_to_url, + maybe_stub_file, + worfklow_name, + generate_only, + ) + stub_cls_to_url[processor_descriptor.cls_name] = remote_descriptor.b10_model_url + if processor_descriptor == entrypoint_descr: + entrypoint_remote = remote_descriptor + + assert entrypoint_remote is not None + return entrypoint_remote diff --git a/slay/public_api.py b/slay/public_api.py new file mode 100644 index 000000000..43e2bf7b6 --- /dev/null +++ b/slay/public_api.py @@ -0,0 +1,62 @@ +""" +TODO: + * Shim to call already hosted basteten model. + * Helper to create a `Processor` from a truss dir. 
+""" + + +from typing import Any, ContextManager, Type, final + +from slay import definitions, framework + + +def provide_context() -> Any: + """Sets a 'symbolic marker' for injecting a Context object at runtime.""" + return framework.ContextProvisionPlaceholder() + + +def provide(processor_cls: Type[definitions.ABCProcessor]) -> Any: + """Sets a 'symbolic marker' for injecting a stub or local processor at runtime.""" + # TODO: consider adding retry or timeout configuraiton here. + return framework.ProcessorProvisionPlaceholder(processor_cls) + + +class ProcessorBase(definitions.ABCProcessor[definitions.UserConfigT]): + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + framework.check_and_register_class(cls) + + original_init = cls.__init__ + + def init_with_arg_check(self, *args, **kwargs): + if args: + raise definitions.UsageError("Only kwargs are allowed.") + framework.ensure_args_are_injected(cls, original_init, kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init_with_arg_check # type: ignore[method-assign] + + def __init__( + self, context: definitions.Context[definitions.UserConfigT] = provide_context() + ) -> None: + self._context = context + + @final + @property + def user_config(self) -> definitions.UserConfigT: + return self._context.user_config + + +def deploy_remotely( + entrypoint: Type[definitions.ABCProcessor], + workflow_name: str, + generate_only: bool = False, +) -> definitions.BasetenRemoteDescriptor: + return framework.deploy_remotely( + entrypoint, workflow_name, generate_only=generate_only + ) + + +def run_local() -> ContextManager[None]: + """Context manager for using in-process instantiations of processor dependencies.""" + return framework.run_local() diff --git a/slay/stub.py b/slay/stub.py new file mode 100644 index 000000000..37ade9551 --- /dev/null +++ b/slay/stub.py @@ -0,0 +1,62 @@ +import abc +import functools +from typing import Type, TypeVar + +import httpx +from slay import definitions + + +def _handle_respose(response: httpx.Response): + # TODO: improve error handling, extract context from response and include in + # re-raised exception. Consider re-raising same exception or if not a use a + # generic "RPCError" exception class or similar. + if response.is_server_error: + raise ValueError(response) + if response.is_client_error: + raise ValueError(response) + return response.json() + + +class BasetenSession: + """Helper to invoke predict method on baseten deployments.""" + + # TODO: make timeout, retries etc. configurable. + def __init__(self, url: str, api_key: str) -> None: + self._auth_header = {"Authorization": f"Api-Key {api_key}"} + self._url = url + + @functools.cached_property + def _client_sync(self) -> httpx.Client: + return httpx.Client(base_url=self._url, headers=self._auth_header) + + @functools.cached_property + def _client_async(self) -> httpx.AsyncClient: + return httpx.AsyncClient(base_url=self._url, headers=self._auth_header) + + def predict_sync(self, json_paylod): + return _handle_respose( + self._client_sync.post(definitions.PREDICT_ENDPOINT_NAME, json=json_paylod) + ) + + async def predict_async(self, json_paylod): + return _handle_respose( + await self._client_async.post( + definitions.PREDICT_ENDPOINT_NAME, json=json_paylod + ) + ) + + +class StubBase(abc.ABC): + @abc.abstractmethod + def __init__(self, url: str, api_key: str) -> None: + ... 
+ + +StubT = TypeVar("StubT", bound=StubBase) + + +def stub_factory(stub_cls: Type[StubT], context: definitions.Context) -> StubT: + return stub_cls( + url=context.get_stub_url(stub_cls.__name__), + api_key=context.get_baseten_api_key(), + ) diff --git a/slay/truss_adapter/__init__.py b/slay/truss_adapter/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slay/truss_adapter/code_gen.py b/slay/truss_adapter/code_gen.py new file mode 100644 index 000000000..3b44ab373 --- /dev/null +++ b/slay/truss_adapter/code_gen.py @@ -0,0 +1,121 @@ +import logging +import pathlib +from typing import Any + +import libcst +from slay import definitions, utils +from slay.truss_adapter import model_skeleton + + +class _SpecifyProcessorTypeAnnotation(libcst.CSTTransformer): + def __init__(self, new_annotaiton: str) -> None: + super().__init__() + self._new_annotaiton = new_annotaiton + + def leave_SimpleStatementLine( + self, + original_node: libcst.SimpleStatementLine, + updated_node: libcst.SimpleStatementLine, + ) -> libcst.SimpleStatementLine: + new_body: list[Any] = [] + for statement in updated_node.body: + if ( + isinstance(statement, libcst.AnnAssign) + and isinstance(statement.target, libcst.Name) + and statement.target.value == "_processor" + ): + new_annotation = libcst.Annotation( + annotation=libcst.Name(value=self._new_annotaiton) + ) + new_statement = statement.with_changes(annotation=new_annotation) + new_body.append(new_statement) + else: + new_body.append(statement) + + return updated_node.with_changes(body=tuple(new_body)) + + +def generate_truss_model( + processor_desrciptor: definitions.ProcessorAPIDescriptor, +) -> tuple[libcst.CSTNode, list[libcst.SimpleStatementLine], libcst.CSTNode]: + logging.info(f"Generating Baseten model for `{processor_desrciptor.cls_name}`.") + skeleton_tree = libcst.parse_module( + pathlib.Path(model_skeleton.__file__).read_text() + ) + + imports = [ + node + for node in skeleton_tree.body + if isinstance(node, libcst.SimpleStatementLine) + and any( + isinstance(stmt, libcst.Import) or isinstance(stmt, libcst.ImportFrom) + for stmt in node.body + ) + ] + + class_definition: libcst.ClassDef = utils.expect_one( + node + for node in skeleton_tree.body + if isinstance(node, libcst.ClassDef) + and node.name.value == model_skeleton.ProcessorModel.__name__ + ) + + load_def = libcst.parse_statement( + f""" +def load(self) -> None: + self._processor = {processor_desrciptor.cls_name}(context=self._context) +""" + ) + + endpoint_descriptor = processor_desrciptor.endpoint + def_str = "async def" if endpoint_descriptor.is_async else "def" + # Convert json payload dict to processor args. 
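+    # E.g. for a hypothetical endpoint `run(self, data: str, num_partitions: int)`
+    # this roughly renders to:
+    #   "data=payload['data'], num_partitions=payload['num_partitions']"
+    # (with `.parse_obj(...)` wrapping for pydantic-typed arguments).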
+ obj_arg_parts = ", ".join( + ( + f"{arg_name}={arg_type.as_src_str()}.parse_obj(payload['{arg_name}'])" + if arg_type.is_pydantic + else f"{arg_name}=payload['{arg_name}']" + ) + for arg_name, arg_type in endpoint_descriptor.input_names_and_tyes + ) + + if len(endpoint_descriptor.output_types) == 1: + output_type = endpoint_descriptor.output_types[0] + result = "result.dict()" if output_type.is_pydantic else "result" + else: + result_parts = [ + f"result[{i}].dict()" if t.is_pydantic else f"result[{i}]" + for i, t in enumerate(endpoint_descriptor.output_types) + ] + result = f"({', '.join(result_parts)})" + + maybe_await = "await " if endpoint_descriptor.is_async else "" + + predict_def = libcst.parse_statement( + f""" +{def_str} predict(self, payload): + result = {maybe_await}self._processor.{endpoint_descriptor.name}({obj_arg_parts}) + return {result} + +""" + ) + new_body: list[libcst.BaseStatement] = list( # type: ignore[assignment,misc] + class_definition.body.body + ) + [ + load_def, + predict_def, + ] + new_block = libcst.IndentedBlock(body=new_body) + class_definition = class_definition.with_changes(body=new_block) + class_definition = class_definition.visit( # type: ignore[assignment] + _SpecifyProcessorTypeAnnotation(processor_desrciptor.cls_name) + ) + + if issubclass(processor_desrciptor.user_config_type.raw, type(None)): + userconfig_pin = libcst.parse_statement("UserConfigT = None") + else: + userconfig_pin = libcst.parse_statement( + f"UserConfigT = {processor_desrciptor.user_config_type.as_src_str()}" + ) + + return class_definition, imports, userconfig_pin diff --git a/slay/truss_adapter/deploy.py b/slay/truss_adapter/deploy.py new file mode 100644 index 000000000..f7af64a1d --- /dev/null +++ b/slay/truss_adapter/deploy.py @@ -0,0 +1,337 @@ +import configparser +import enum +import logging +import os +import pathlib +import shutil +import time +from pathlib import Path +from typing import Any, Mapping, Optional, cast + +import httpx +import requests +import slay +import truss +from slay import definitions, utils +from slay.utils import ConditionStatus +from truss import truss_config +from truss.contexts.image_builder import serving_image_builder +from truss.remote import remote_factory, truss_remote +from truss.remote.baseten import service as b10_service + +_REQUIREMENTS_FILENAME = "pip_requirements.txt" +_MODEL_CLASS_FILENAME = "processor.py" +_MODEL_CLASS_NAME = "ProcessorModel" +_TRUSS_DIR = ".truss_gen" + + +def _copy_python_source_files(root_dir: pathlib.Path, dest_dir: pathlib.Path) -> None: + """Copy all python files under root recursively, but skips generated code.""" + + def python_files_only(path, names): + return [ + name + for name in names + if os.path.isfile(os.path.join(path, name)) + and not name.endswith(".py") + or definitions.GENERATED_CODE_DIR in name + ] + + shutil.copytree(root_dir, dest_dir, ignore=python_files_only, dirs_exist_ok=True) + + +def _make_truss_config( + truss_dir: pathlib.Path, + slay_config: definitions.Config, + stub_cls_to_url: Mapping[str, str], + model_name: str, +) -> truss_config.TrussConfig: + """Generate a truss config for a processor.""" + config = truss_config.TrussConfig() + config.model_name = model_name + config.model_class_filename = _MODEL_CLASS_FILENAME + config.model_class_name = _MODEL_CLASS_NAME + # Compute. 
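+    # E.g. (illustrative) a config built with `Compute().cpu(4).gpu("A10G")` arrives
+    # here as ComputeSpec(cpu="4", memory="2Gi", gpu="A10G:1") and is mapped onto the
+    # truss resource fields below.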
+ compute = slay_config.get_compute_spec() + config.resources.cpu = compute.cpu + config.resources.accelerator = truss_config.AcceleratorSpec.from_str(compute.gpu) + config.resources.use_gpu = bool(compute.gpu) + # Image. + image = slay_config.get_image_spec() + config.base_image = truss_config.BaseImage(image=image.base_image) + pip_requirements: list[str] = [] + if image.pip_requirements_file: + image.pip_requirements_file.raise_if_not_exists() + pip_requirements.extend( + req + for req in pathlib.Path(image.pip_requirements_file.abs_path) + .read_text() + .splitlines() + if not req.strip().startswith("#") + ) + pip_requirements.extend(image.pip_requirements) + # `pip_requirements` will add server requirements which give version conflicts. + # config.requirements = pip_requirements + pip_requirements_file_path = truss_dir / _REQUIREMENTS_FILENAME + pip_requirements_file_path.write_text("\n".join(pip_requirements)) + # TODO: apparently absolute paths don't work with remote build (but work in local). + config.requirements_file = _REQUIREMENTS_FILENAME # str(pip_requirements_file_path) + config.system_packages = image.apt_requirements + # Assets. + assets = slay_config.get_asset_spec() + config.secrets = assets.secrets + if definitions.BASTEN_API_SECRET_NAME not in config.secrets: + config.secrets[definitions.BASTEN_API_SECRET_NAME] = "***" + else: + logging.info( + f"Workflows automatically add {definitions.BASTEN_API_SECRET_NAME} " + "to secrets - no need to manually add it." + ) + config.model_cache.models = assets.cached + # Metadata. + slay_metadata: definitions.TrussMetadata = definitions.TrussMetadata( + user_config=slay_config.user_config, + stub_cls_to_url=stub_cls_to_url, + ) + config.model_metadata[definitions.TRUSS_CONFIG_SLAY_KEY] = slay_metadata.dict() + return config + + +def make_truss( + processor_dir: pathlib.Path, + workflow_root: pathlib.Path, + slay_config: definitions.Config, + model_name: str, + stub_cls_to_url: Mapping[str, str], + maybe_stub_file: Optional[pathlib.Path], +) -> pathlib.Path: + truss_dir = processor_dir / _TRUSS_DIR + truss_dir.mkdir(exist_ok=True) + config = _make_truss_config(truss_dir, slay_config, stub_cls_to_url, model_name) + + config.write_to_yaml_file( + truss_dir / serving_image_builder.CONFIG_FILE, verbose=False + ) + + # Copy other sources. + model_dir = truss_dir / truss_config.DEFAULT_MODEL_MODULE_DIR + model_dir.mkdir(parents=True, exist_ok=True) + shutil.copy( + processor_dir / f"{definitions.PROCESSOR_MODULE}.py", + model_dir / _MODEL_CLASS_FILENAME, + ) + + pkg_dir = truss_dir / truss_config.DEFAULT_BUNDLED_PACKAGES_DIR + pkg_dir.mkdir(parents=True, exist_ok=True) + if maybe_stub_file is not None: + shutil.copy(maybe_stub_file, pkg_dir) + # TODO This assume all imports are absolute w.r.t workflow root (or site-packages). + # Also: apparently packages need an `__init__`, or crash. + _copy_python_source_files(workflow_root, pkg_dir / pkg_dir) + + # TODO Truss package contains this from `{ include = "slay", from = "." }` + # pyproject.toml. But for quick dev loop just copy from local. + shutil.copytree( + os.path.dirname(slay.__file__), + pkg_dir / "slay", + dirs_exist_ok=True, + ) + return truss_dir + + +def get_api_key_from_trussrc() -> str: + try: + return remote_factory.load_config().get("baseten", "api_key") + except configparser.Error as e: + raise definitions.MissingDependencyError( + "You must have a `trussrc` file with a baseten API key." 
+ ) from e + + +class _BasetenEnv(enum.Enum): + LOCAL = enum.auto() + STAGING = enum.auto() + PROD = enum.auto() + DEV = enum.auto() + + +def _infer_env(baseten_url: str) -> _BasetenEnv: + if baseten_url in {"localhost", "127.0.0.1", "0.0.0.0"}: + return _BasetenEnv.LOCAL + + if "staging" in baseten_url: + return _BasetenEnv.STAGING + + if "dev" in baseten_url: + return _BasetenEnv.DEV + + return _BasetenEnv.PROD + + +def _model_url(baseten_env: _BasetenEnv, model_id: str, production: bool) -> str: + # TODO: get URLs from REST API instead. + if baseten_env == _BasetenEnv.LOCAL: + return f"http://localhost:8000/models/{model_id}" + + model_env = "production" if production else "development" + + if baseten_env in {_BasetenEnv.STAGING, _BasetenEnv.DEV}: + env_str = f".{str(baseten_env).lower()}" + else: + env_str = "" + + return f"https://model-{model_id}.api{env_str}.baseten.co/{model_env}" + + +class BasetenClient: + """Helper to deploy models on baseten and inquire their status.""" + + # TODO: use rest APIs where possible in stead of graphql_query. + def __init__(self, baseten_url: str, baseten_api_key: str) -> None: + self._baseten_url = baseten_url + self._baseten_env = _infer_env(baseten_url) + self._baseten_api_key = baseten_api_key + self._remote_provider: truss_remote.TrussRemote = self._create_remote_provider() + + def deploy_truss(self, truss_root: Path) -> definitions.BasetenRemoteDescriptor: + # TODO: add intentional control of pushing as dev or prod model. + tr = truss.load(str(truss_root)) + model_name = tr.spec.config.model_name + assert model_name is not None + + logging.info(f"Deploying model `{model_name}`.") + production = False + # Models must be trusted to use the API KEY secret. + service = self._remote_provider.push( + tr, model_name=model_name, trusted=True, publish=production + ) + if service is None: + raise ValueError() + service = cast(b10_service.BasetenService, service) + + model_service = definitions.BasetenRemoteDescriptor( + b10_model_id=service.model_id, + b10_model_version_id=service.model_version_id, + b10_model_name=model_name, + b10_model_url=_model_url(self._baseten_env, service.model_id, production), + ) + return model_service + + def get_model(self, model_name: str) -> definitions.BasetenRemoteDescriptor: + query_string = f""" + {{ + model_version(name: "{model_name}") {{ + oracle{{ + id + name + versions{{ + id + current_deployment_status + }} + }} + }} + }} + """ + try: + resp = self._post_graphql_query(query_string, retries=True)["data"][ + "model_version" + ]["oracle"] + except Exception as e: + raise definitions.MissingDependencyError("Model cout not be found.") from e + + model_id = resp["id"] + model_version_id = resp["versions"][0]["id"] + return definitions.BasetenRemoteDescriptor( + b10_model_id=model_id, + b10_model_version_id=model_version_id, + b10_model_url=_model_url(self._baseten_env, model_id, False), + b10_model_name=model_name, + ) + + def _create_remote_provider(self) -> truss_remote.TrussRemote: + remote_config = truss_remote.RemoteConfig( + name="baseten", + configs={ + "remote_provider": "baseten", + "api_key": self._baseten_api_key, + "remote_url": self._baseten_url, + }, + ) + remote_factory.RemoteFactory.update_remote_config(remote_config) + return remote_factory.RemoteFactory.create(remote="baseten") + + def _wait_for_model_to_be_ready(self, model_version_id: str) -> None: + logging.info(f"Waiting for model {model_version_id} to be ready") + + def is_model_ready() -> ConditionStatus: + query_string = f""" + {{ + 
model_version(id: "{model_version_id}") {{ + current_model_deployment_status {{ + status + reason + }} + }} + }} + """ + resp = self._post_graphql_query(query_string, retries=True) + status = resp["data"]["model_version"]["current_model_deployment_status"][ + "status" + ] + logging.info(f"Model status: {status}") + if status == "MODEL_READY": + return ConditionStatus.SUCCESS + if "FAILED" in status: + return ConditionStatus.FAILURE + return ConditionStatus.NOT_DONE + + is_ready = utils.wait_for_condition(is_model_ready, 1800) + if not is_ready: + raise RuntimeError("Model failed to be ready in 30 minutes") + + def _post_graphql_query(self, query_string: str, retries: bool = False) -> dict: + headers = {"Authorization": f"Api-Key {self._baseten_api_key}"} + while True: + resp = requests.post( + f"{self._baseten_url}/graphql/", + data={"query": query_string}, + headers=headers, + timeout=120, + ) + if not resp.ok: + if not retries: + logging.error( + f"GraphQL endpoint failed with error: {resp.content.decode()}" + ) + resp.raise_for_status() + else: + logging.info( + f"GraphQL endpoint failed with error: {resp.content.decode()}, " + "retries are on, ignore" + ) + else: + resp_dict = resp.json() + errors = resp_dict.get("errors") + if errors: + raise RuntimeError(errors[0]["message"], resp) + return resp_dict + + +def call_workflow_dbg( + remote: definitions.BasetenRemoteDescriptor, + payload: Any, + max_retries: int = 100, + retry_wait_sec: int = 3, +) -> httpx.Response: + """For debugging only: tries calling a workflow.""" + api_key = get_api_key_from_trussrc() + session = httpx.Client( + base_url=remote.b10_model_url, headers={"Authorization": f"Api-Key {api_key}"} + ) + for _ in range(max_retries): + try: + response = session.post(definitions.PREDICT_ENDPOINT_NAME, json=payload) + return response + except Exception: + time.sleep(retry_wait_sec) + raise diff --git a/slay/truss_adapter/model_skeleton.py b/slay/truss_adapter/model_skeleton.py new file mode 100644 index 000000000..44d52403d --- /dev/null +++ b/slay/truss_adapter/model_skeleton.py @@ -0,0 +1,36 @@ +import pathlib + +import pydantic +from slay import definitions +from truss.templates.shared import secrets_resolver + +# Better: in >=3.10 use `TypeAlias`. +UserConfigT = pydantic.BaseModel + + +class ProcessorModel: + _context: definitions.Context[UserConfigT] + _processor: definitions.ABCProcessor + + def __init__( + self, config: dict, data_dir: pathlib.Path, secrets: secrets_resolver.Secrets + ) -> None: + truss_metadata: definitions.TrussMetadata[ + UserConfigT + ] = definitions.TrussMetadata[UserConfigT].parse_obj( + config["model_metadata"][definitions.TRUSS_CONFIG_SLAY_KEY] + ) + self._context = definitions.Context[UserConfigT]( + user_config=truss_metadata.user_config, + stub_cls_to_url=truss_metadata.stub_cls_to_url, + secrets=secrets, + ) + + # Below illustrated code will be added by code generation. + + # def load(self): + # self._processor = {ProcssorCls}(self._context) + + # Sync or async. 
+ # def predict(self, payload): + # return self._processor.{method_name}(payload) diff --git a/slay/utils.py b/slay/utils.py new file mode 100644 index 000000000..10e56232e --- /dev/null +++ b/slay/utils.py @@ -0,0 +1,55 @@ +import contextlib +import enum +import logging +import time +from typing import Callable, Iterable, TypeVar + +T = TypeVar("T") + + +@contextlib.contextmanager +def log_level(level: int): + """Change loglevel for code in this context.""" + current_logging_level = logging.getLogger().getEffectiveLevel() + logging.getLogger().setLevel(level) + try: + yield + finally: + logging.getLogger().setLevel(current_logging_level) + + +def expect_one(it: Iterable[T]) -> T: + """Assert that an iterable has exactly on element and return it.""" + it = iter(it) + try: + element = next(it) + except StopIteration: + raise ValueError("Iterable is empty.") + + try: + _ = next(it) + except StopIteration: + return element + + raise ValueError("Iterable has more than one element.") + + +class ConditionStatus(enum.Enum): + SUCCESS = enum.auto() + FAILURE = enum.auto() + NOT_DONE = enum.auto() + + +def wait_for_condition( + condition: Callable[[], ConditionStatus], + retries: int = 10, + sleep_between_retries_secs: int = 1, +) -> bool: + for _ in range(retries): + cond_status = condition() + if cond_status == ConditionStatus.SUCCESS: + return True + if cond_status == ConditionStatus.FAILURE: + return False + time.sleep(sleep_between_retries_secs) + return False From c4b2d19ee71f041ff2f93ddb78a31369d0cf23de Mon Sep 17 00:00:00 2001 From: Sidharth Shanker Date: Wed, 27 Mar 2024 08:55:50 -0400 Subject: [PATCH 5/6] Rename GCS -> GCP. (#880) --- truss/tests/test_config.py | 6 +++--- truss/truss_config.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py index 106b9a957..9a578ba51 100644 --- a/truss/tests/test_config.py +++ b/truss/tests/test_config.py @@ -118,7 +118,7 @@ def test_acc_spec_from_str(input_str, expected_acc): "image": "custom_base_image", "python_executable_path": "/path/python", "docker_auth": { - "auth_method": "GCS_SERVICE_ACCOUNT_JSON", + "auth_method": "GCP_SERVICE_ACCOUNT_JSON", "secret_name": "some-secret-name", "registry": "some-docker-registry", }, @@ -127,7 +127,7 @@ def test_acc_spec_from_str(input_str, expected_acc): image="custom_base_image", python_executable_path="/path/python", docker_auth=DockerAuthSettings( - auth_method=DockerAuthType.GCS_SERVICE_ACCOUNT_JSON, + auth_method=DockerAuthType.GCP_SERVICE_ACCOUNT_JSON, secret_name="some-secret-name", registry="some-docker-registry", ), @@ -136,7 +136,7 @@ def test_acc_spec_from_str(input_str, expected_acc): "image": "custom_base_image", "python_executable_path": "/path/python", "docker_auth": { - "auth_method": "GCS_SERVICE_ACCOUNT_JSON", + "auth_method": "GCP_SERVICE_ACCOUNT_JSON", "secret_name": "some-secret-name", "registry": "some-docker-registry", }, diff --git a/truss/truss_config.py b/truss/truss_config.py index 9c2531a5b..247328b93 100644 --- a/truss/truss_config.py +++ b/truss/truss_config.py @@ -303,7 +303,7 @@ class DockerAuthType(Enum): authentication we support. 
""" - GCS_SERVICE_ACCOUNT_JSON = "GCS_SERVICE_ACCOUNT_JSON" + GCP_SERVICE_ACCOUNT_JSON = "GCP_SERVICE_ACCOUNT_JSON" @dataclass From d389b1b64edfda6c328fa7ad80e1224ff2b0db36 Mon Sep 17 00:00:00 2001 From: basetenbot <96544894+basetenbot@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:20:56 +0000 Subject: [PATCH 6/6] Bump version to 0.9.7 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 70613b143..fd8a13c5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.7rc6" +version = "0.9.7" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md"