Skip to content

Commit

Permalink
Cleanup Truss APIs around predict URLs. (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Apr 16, 2024
1 parent 1303747 commit 2e412cc
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 113 deletions.
3 changes: 1 addition & 2 deletions slay/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ def _create_remote_service(
baseten_service = baseten_client.deploy_truss(
truss_dir, publish=options.publish, promote=options.promote
)
logs_url = baseten_client.get_logs_url(baseten_service)
# Assuming baseten_url is like "https://app.baseten.co" or "https://app.dev.baseten.co",
deploy_url = options.baseten_url.replace(
"https://", f"https://model-{baseten_service.model_id}."
Expand All @@ -577,7 +576,7 @@ def _create_remote_service(
service = definitions.ServiceDescriptor(
name=model_name, predict_url=f"{deploy_url}/predict"
)
logging.info(f"🪵 View logs for your deployment at {logs_url}.")
logging.info(f"🪵 View logs for your deployment at {baseten_service.logs_url}.")
else:
raise NotImplementedError(options)

Expand Down
3 changes: 0 additions & 3 deletions slay/truss_adapter/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,3 @@ def deploy_truss(
if service is None:
raise ValueError()
return cast(b10_service.BasetenService, service)

def get_logs_url(self, service: b10_service.BasetenService) -> str:
return self._remote_provider.get_remote_logs_url(service)
6 changes: 2 additions & 4 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def watch(
sys.exit(1)

service = remote_provider.get_service(model_identifier=ModelName(model_name))
logs_url = remote_provider.get_remote_logs_url(service)
rich.print(f"🪵 View logs for your deployment at {logs_url}")
rich.print(f"🪵 View logs for your deployment at {service.logs_url}")
remote_provider.sync_truss_to_dev_version_by_name(model_name, target_directory)


Expand Down Expand Up @@ -537,8 +536,7 @@ def push(
it will become the next production deployment of your model."""
console.print(promotion_text, style="green")

logs_url = remote_provider.get_remote_logs_url(service) # type: ignore[attr-defined]
rich.print(f"🪵 View logs for your deployment at {logs_url}")
rich.print(f"🪵 View logs for your deployment at {service.logs_url}")
if wait:
start_time = time.time()
with console.status("[bold green]Deploying...") as status:
Expand Down
29 changes: 24 additions & 5 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum
from typing import Optional
from typing import Any, Optional

import requests
from truss.remote.baseten.auth import AuthService
Expand All @@ -9,6 +9,18 @@

logger = logging.getLogger(__name__)

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenApi:
"""
Expand All @@ -23,14 +35,21 @@ class BasetenApi:
class GraphQLErrorCodes(Enum):
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"

def __init__(
self, graphql_api_url: str, rest_api_url: str, auth_service: AuthService
):
def __init__(self, remote_url: str, auth_service: AuthService) -> None:
    """Derive the GraphQL and REST endpoints from `remote_url` and authenticate.

    Args:
        remote_url: Base URL of the remote, e.g. `https://app.baseten.co`.
            Trailing `/` characters are tolerated.
        auth_service: Provides the auth token; `authenticate()` is called
            once here and the result is kept for later requests.
    """
    # Normalize trailing slashes once, so the GraphQL URL does not end up as
    # `...//graphql/` and the REST API lookup sees the canonical form.
    # `rstrip` (not `strip`) because only trailing slashes are meant.
    base_url = remote_url.rstrip("/")

    # The value is stored exactly as given; only derived URLs are normalized.
    self._remote_url = remote_url
    self._graphql_api_url = f"{base_url}/graphql/"
    # Domains missing from the mapping fall back to DEFAULT_API_DOMAIN.
    self._rest_api_url = API_URL_MAPPING.get(base_url, DEFAULT_API_DOMAIN)
    self._auth_service = auth_service
    self._auth_token = self._auth_service.authenticate()

@property
def remote_url(self) -> str:
    """The remote base URL as it was passed to the constructor."""
    return self._remote_url

def _post_graphql_query(self, query_string: str) -> dict:
headers = self._auth_token.header()
resp = requests.post(
Expand Down Expand Up @@ -254,7 +273,7 @@ def patch_draft_truss(self, model_name, patch_request):
resp = self._post_graphql_query(query_string)
return resp["data"]["patch_draft_truss"]

def get_deployment(self, model_id: str, deployment_id: str) -> str:
def get_deployment(self, model_id: str, deployment_id: str) -> Any:
headers = self._auth_token.header()
resp = requests.get(
f"{self._rest_api_url}/v1/models/{model_id}/deployments/{deployment_id}",
Expand Down
6 changes: 3 additions & 3 deletions truss/remote/baseten/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
class ApiKey:
value: str

def __init__(self, value: str):
def __init__(self, value: str) -> None:
self.value = value

def header(self) -> dict:
    """Return the HTTP `Authorization` header dict carrying this API key."""
    return {"Authorization": f"Api-Key {self.value}"}


class AuthService:
def __init__(self, api_key: Optional[str] = None):
def __init__(self, api_key: Optional[str] = None) -> None:
    """Store `api_key`, falling back to the `BASETEN_API_KEY` env variable.

    Any falsy `api_key` (None or an empty string) triggers the fallback;
    if the variable is unset as well, the stored key is None.
    """
    self._api_key = api_key or os.environ.get("BASETEN_API_KEY", None)

def validate(self):
def validate(self) -> None:
    """Check that an API key is available.

    Raises:
        AuthorizationError: If the stored API key is missing or empty.
    """
    if not self._api_key:
        raise AuthorizationError("No API key provided.")

Expand Down
37 changes: 5 additions & 32 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,18 @@
from truss.remote.baseten.error import ApiError
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote, TrussService
from truss.remote.truss_remote import TrussRemote
from truss.truss_config import ModelServer
from truss.truss_handle import TrussHandle
from truss.util.path import is_ignored, load_trussignore_patterns
from watchfiles import watch

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
super().__init__(remote_url, **kwargs)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(
f"{self._remote_url}/graphql/",
# Ensure we strip off trailing '/' to normalize
# URLs.
API_URL_MAPPING.get(self._remote_url.strip("/"), DEFAULT_API_DOMAIN),
self._auth_service,
)

def authenticate(self):
return self._auth_service.validate()
self._api = BasetenApi(remote_url, self._auth_service)

def push( # type: ignore
self,
Expand All @@ -70,7 +49,7 @@ def push( # type: ignore
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
):
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")

Expand Down Expand Up @@ -214,17 +193,11 @@ def get_service(self, **kwargs) -> BasetenService:
api=self._api,
)

def get_remote_logs_url(
self,
service: TrussService,
) -> str:
return service.logs_url(self._remote_url)

def sync_truss_to_dev_version_by_name(
self,
model_name: str,
target_directory: str,
):
) -> None:
# verify that development deployment exists for given model name
dev_version = get_dev_version(
self._api, model_name
Expand Down Expand Up @@ -257,7 +230,7 @@ def patch(
self,
watch_path: Path,
truss_ignore_patterns: List[str],
):
) -> None:
from truss.cli.console import console, error_console

try:
Expand Down
42 changes: 35 additions & 7 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Dict, Optional
import urllib.parse
from typing import Any, Dict, Iterator, Optional

import requests
from tenacity import retry, stop_after_delay, wait_fixed
Expand All @@ -12,6 +13,14 @@
DEFAULT_STREAM_ENCODING = "utf-8"


def _add_model_subdomain(rest_api_url: str, model_subdomain: str) -> str:
    """Return `rest_api_url` with `model_subdomain` prepended to its host.

    Only the network-location component of the URL is replaced.

    E.g. `https://api.baseten.co` -> `https://{model_subdomain}.api.baseten.co`
    """
    url_parts = urllib.parse.urlparse(rest_api_url)
    with_subdomain = url_parts._replace(
        netloc=f"{model_subdomain}.{url_parts.netloc}"
    )
    return str(urllib.parse.urlunparse(with_subdomain))


class BasetenService(TrussService):
def __init__(
self,
Expand Down Expand Up @@ -51,9 +60,9 @@ def invocation_url(self) -> str:
def predict(
self,
model_request_body: Dict,
):
) -> Any:
response = self._send_request(
self.invocation_url, "POST", data=model_request_body, stream=True
self.predict_url, "POST", data=model_request_body, stream=True
)

if response.headers.get("transfer-encoding") == "chunked":
Expand Down Expand Up @@ -85,14 +94,33 @@ def decode_content():
def authenticate(self) -> dict:
    """Authenticate via the auth service and return the token's header dict."""
    return self._auth_service.authenticate().header()

def logs_url(self, base_url: str) -> str:
return f"{base_url}/models/{self._model_id}/logs/{self._model_version_id}"
@property
def logs_url(self) -> str:
    """URL of the logs page for this model version on the remote."""
    base_url = self._api.remote_url
    return f"{base_url}/models/{self._model_id}/logs/{self._model_version_id}"

@property
def predict_url(self) -> str:
    """
    Get the URL for the prediction endpoint.

    The model is addressed through a `model-{model_id}` subdomain that is
    prepended to the host of `self._api.remote_url`. Draft services use the
    `/development/predict` path; all others use
    `/deployment/{model_version_id}/predict`.
    """
    # The subdomain label must not start with "/": it is spliced directly in
    # front of the host, so a leading slash would produce
    # `https:///model-...` (empty host) instead of `https://model-...`.
    # NOTE(review): the earlier inline examples used the REST API domain
    # (`https://api.baseten.co`), while this passes `remote_url`. Confirm
    # which base URL is intended for predictions.
    url = _add_model_subdomain(self._api.remote_url, f"model-{self.model_id}")
    if self.is_draft:
        # "{scheme}://model-{model_id}.{host}/development/predict"
        return f"{url}/development/predict"
    # "{scheme}://model-{model_id}.{host}/deployment/{deployment_id}/predict"
    return f"{url}/deployment/{self.model_version_id}/predict"

@retry(stop=stop_after_delay(60), wait=wait_fixed(1), reraise=True)
def _fetch_deployment(self):
def _fetch_deployment(self) -> Any:
return self._api.get_deployment(self._model_id, self._model_version_id)

def poll_deployment_status(self):
def poll_deployment_status(self) -> Iterator[str]:
"""
Wait for the service to be deployed.
"""
Expand Down

0 comments on commit 2e412cc

Please sign in to comment.