Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup Truss APIs around predict URLs. #898

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions slay/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,6 @@ def _create_remote_service(
baseten_service = baseten_client.deploy_truss(
truss_dir, publish=options.publish, promote=options.promote
)
logs_url = baseten_client.get_logs_url(baseten_service)
# Assuming baseten_url is like "https://app.baseten.co" or "https://app.dev.baseten.co".
deploy_url = options.baseten_url.replace(
"https://", f"https://model-{baseten_service.model_id}."
Expand All @@ -577,7 +576,7 @@ def _create_remote_service(
service = definitions.ServiceDescriptor(
name=model_name, predict_url=f"{deploy_url}/predict"
)
logging.info(f"馃 View logs for your deployment at {logs_url}.")
logging.info(f"馃 View logs for your deployment at {baseten_service.logs_url}.")
else:
raise NotImplementedError(options)

Expand Down
3 changes: 0 additions & 3 deletions slay/truss_adapter/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,3 @@ def deploy_truss(
if service is None:
raise ValueError()
return cast(b10_service.BasetenService, service)

def get_logs_url(self, service: b10_service.BasetenService) -> str:
return self._remote_provider.get_remote_logs_url(service)
6 changes: 2 additions & 4 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def watch(
sys.exit(1)

service = remote_provider.get_service(model_identifier=ModelName(model_name))
logs_url = remote_provider.get_remote_logs_url(service)
rich.print(f"馃 View logs for your deployment at {logs_url}")
rich.print(f"馃 View logs for your deployment at {service.logs_url}")
remote_provider.sync_truss_to_dev_version_by_name(model_name, target_directory)


Expand Down Expand Up @@ -537,8 +536,7 @@ def push(
it will become the next production deployment of your model."""
console.print(promotion_text, style="green")

logs_url = remote_provider.get_remote_logs_url(service) # type: ignore[attr-defined]
rich.print(f"馃 View logs for your deployment at {logs_url}")
rich.print(f"馃 View logs for your deployment at {service.logs_url}")
if wait:
start_time = time.time()
with console.status("[bold green]Deploying...") as status:
Expand Down
29 changes: 24 additions & 5 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum
from typing import Optional
from typing import Any, Optional

import requests
from truss.remote.baseten.auth import AuthService
Expand All @@ -9,6 +9,18 @@

logger = logging.getLogger(__name__)

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenApi:
"""
Expand All @@ -23,14 +35,21 @@ class BasetenApi:
class GraphQLErrorCodes(Enum):
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"

def __init__(
self, graphql_api_url: str, rest_api_url: str, auth_service: AuthService
):
def __init__(self, remote_url: str, auth_service: AuthService) -> None:
graphql_api_url = f"{remote_url}/graphql/"
# Ensure we strip off trailing '/' to normalize URLs.
rest_api_url = API_URL_MAPPING.get(remote_url.strip("/"), DEFAULT_API_DOMAIN)

self._remote_url = remote_url
self._graphql_api_url = graphql_api_url
self._rest_api_url = rest_api_url
self._auth_service = auth_service
self._auth_token = self._auth_service.authenticate()

@property
def remote_url(self) -> str:
return self._remote_url

def _post_graphql_query(self, query_string: str) -> dict:
headers = self._auth_token.header()
resp = requests.post(
Expand Down Expand Up @@ -254,7 +273,7 @@ def patch_draft_truss(self, model_name, patch_request):
resp = self._post_graphql_query(query_string)
return resp["data"]["patch_draft_truss"]

def get_deployment(self, model_id: str, deployment_id: str) -> str:
def get_deployment(self, model_id: str, deployment_id: str) -> Any:
headers = self._auth_token.header()
resp = requests.get(
f"{self._rest_api_url}/v1/models/{model_id}/deployments/{deployment_id}",
Expand Down
6 changes: 3 additions & 3 deletions truss/remote/baseten/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
class ApiKey:
value: str

def __init__(self, value: str):
def __init__(self, value: str) -> None:
self.value = value

def header(self):
return {"Authorization": f"Api-Key {self.value}"}


class AuthService:
def __init__(self, api_key: Optional[str] = None):
def __init__(self, api_key: Optional[str] = None) -> None:
if not api_key:
api_key = os.environ.get("BASETEN_API_KEY", None)
self._api_key = api_key

def validate(self):
def validate(self) -> None:
if not self._api_key:
raise AuthorizationError("No API key provided.")

Expand Down
37 changes: 5 additions & 32 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,18 @@
from truss.remote.baseten.error import ApiError
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote, TrussService
from truss.remote.truss_remote import TrussRemote
from truss.truss_config import ModelServer
from truss.truss_handle import TrussHandle
from truss.util.path import is_ignored, load_trussignore_patterns
from watchfiles import watch

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
super().__init__(remote_url, **kwargs)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(
f"{self._remote_url}/graphql/",
# Ensure we strip off trailing '/' to normalize
# URLs.
API_URL_MAPPING.get(self._remote_url.strip("/"), DEFAULT_API_DOMAIN),
self._auth_service,
)

def authenticate(self):
return self._auth_service.validate()
self._api = BasetenApi(remote_url, self._auth_service)

def push( # type: ignore
self,
Expand All @@ -70,7 +49,7 @@ def push( # type: ignore
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
):
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")

Expand Down Expand Up @@ -214,17 +193,11 @@ def get_service(self, **kwargs) -> BasetenService:
api=self._api,
)

def get_remote_logs_url(
self,
service: TrussService,
) -> str:
return service.logs_url(self._remote_url)

def sync_truss_to_dev_version_by_name(
self,
model_name: str,
target_directory: str,
):
) -> None:
# verify that development deployment exists for given model name
dev_version = get_dev_version(
self._api, model_name
Expand Down Expand Up @@ -257,7 +230,7 @@ def patch(
self,
watch_path: Path,
truss_ignore_patterns: List[str],
):
) -> None:
from truss.cli.console import console, error_console

try:
Expand Down
42 changes: 35 additions & 7 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Dict, Optional
import urllib.parse
from typing import Any, Dict, Iterator, Optional

import requests
from tenacity import retry, stop_after_delay, wait_fixed
Expand All @@ -12,6 +13,14 @@
DEFAULT_STREAM_ENCODING = "utf-8"


def _add_model_subdomain(rest_api_url: str, model_subdomain: str) -> str:
"""E.g. `https://api.baseten.co` -> `https://{model_subdomain}.api.baseten.co`"""
parsed_url = urllib.parse.urlparse(rest_api_url)
new_netloc = f"{model_subdomain}.{parsed_url.netloc}"
model_url = parsed_url._replace(netloc=new_netloc)
return str(urllib.parse.urlunparse(model_url))


class BasetenService(TrussService):
def __init__(
self,
Expand Down Expand Up @@ -51,9 +60,9 @@ def invocation_url(self) -> str:
def predict(
self,
model_request_body: Dict,
):
) -> Any:
response = self._send_request(
self.invocation_url, "POST", data=model_request_body, stream=True
self.predict_url, "POST", data=model_request_body, stream=True
)

if response.headers.get("transfer-encoding") == "chunked":
Expand Down Expand Up @@ -85,14 +94,33 @@ def decode_content():
def authenticate(self) -> dict:
return self._auth_service.authenticate().header()

def logs_url(self, base_url: str) -> str:
return f"{base_url}/models/{self._model_id}/logs/{self._model_version_id}"
@property
def logs_url(self) -> str:
    """URL of the logs page for this service's model version.

    Built from the API object's `remote_url` plus the model id and the
    model version id held by this service.
    """
    base_url = self._api.remote_url
    return f"{base_url}/models/{self._model_id}/logs/{self._model_version_id}"

@property
def predict_url(self) -> str:
    """
    Get the URL for the prediction endpoint.

    The subdomain `model-{model_id}` is prepended to the host of the API
    object's `remote_url`. Draft services resolve to `/development/predict`;
    all other services resolve to `/deployment/{model_version_id}/predict`.
    """
    # The subdomain label must NOT start with "/": it is spliced directly into
    # the URL's network location, where a leading slash would produce a
    # malformed `https:///model-...` URL.
    # NOTE(review): the example below shows the `api.` domain, while the value
    # passed here is `remote_url` — confirm which base URL is intended.
    # E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co`
    url = _add_model_subdomain(self._api.remote_url, f"model-{self.model_id}")
    if self.is_draft:
        # "https://model-{model_id}.api.baseten.co/development/predict".
        url = f"{url}/development/predict"
    else:
        # "https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict".
        url = f"{url}/deployment/{self.model_version_id}/predict"
    return url

@retry(stop=stop_after_delay(60), wait=wait_fixed(1), reraise=True)
def _fetch_deployment(self):
def _fetch_deployment(self) -> Any:
return self._api.get_deployment(self._model_id, self._model_version_id)

def poll_deployment_status(self):
def poll_deployment_status(self) -> Iterator[str]:
"""
Wait for the service to be deployed.
"""
Expand Down