Skip to content

Commit

Permalink
Cleanup Truss APIs around predict URLs.
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Apr 16, 2024
1 parent 1303747 commit 3f1e89c
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 101 deletions.
29 changes: 24 additions & 5 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum
from typing import Optional
from typing import Any, Optional

import requests
from truss.remote.baseten.auth import AuthService
Expand All @@ -9,6 +9,18 @@

logger = logging.getLogger(__name__)

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenApi:
"""
Expand All @@ -23,14 +35,21 @@ class BasetenApi:
class GraphQLErrorCodes(Enum):
RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND"

def __init__(
self, graphql_api_url: str, rest_api_url: str, auth_service: AuthService
):
def __init__(self, remote_url: str, auth_service: AuthService) -> None:
graphql_api_url = f"{remote_url}/graphql/"
# Ensure we strip off trailing '/' to normalize URLs.
rest_api_url = API_URL_MAPPING.get(remote_url.strip("/"), DEFAULT_API_DOMAIN)

self._remote_url = remote_url
self._graphql_api_url = graphql_api_url
self._rest_api_url = rest_api_url
self._auth_service = auth_service
self._auth_token = self._auth_service.authenticate()

@property
def remote_url(self) -> str:
    """Return the remote URL this API client was constructed with, unmodified."""
    return self._remote_url

def _post_graphql_query(self, query_string: str) -> dict:
headers = self._auth_token.header()
resp = requests.post(
Expand Down Expand Up @@ -254,7 +273,7 @@ def patch_draft_truss(self, model_name, patch_request):
resp = self._post_graphql_query(query_string)
return resp["data"]["patch_draft_truss"]

def get_deployment(self, model_id: str, deployment_id: str) -> str:
def get_deployment(self, model_id: str, deployment_id: str) -> Any:
headers = self._auth_token.header()
resp = requests.get(
f"{self._rest_api_url}/v1/models/{model_id}/deployments/{deployment_id}",
Expand Down
6 changes: 3 additions & 3 deletions truss/remote/baseten/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
class ApiKey:
value: str

def __init__(self, value: str):
def __init__(self, value: str) -> None:
self.value = value

def header(self):
return {"Authorization": f"Api-Key {self.value}"}


class AuthService:
def __init__(self, api_key: Optional[str] = None):
def __init__(self, api_key: Optional[str] = None) -> None:
if not api_key:
api_key = os.environ.get("BASETEN_API_KEY", None)
self._api_key = api_key

def validate(self):
def validate(self) -> None:
if not self._api_key:
raise AuthorizationError("No API key provided.")

Expand Down
37 changes: 5 additions & 32 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,18 @@
from truss.remote.baseten.error import ApiError
from truss.remote.baseten.service import BasetenService
from truss.remote.baseten.utils.transfer import base64_encoded_json_str
from truss.remote.truss_remote import TrussRemote, TrussService
from truss.remote.truss_remote import TrussRemote
from truss.truss_config import ModelServer
from truss.truss_handle import TrussHandle
from truss.util.path import is_ignored, load_trussignore_patterns
from watchfiles import watch

API_URL_MAPPING = {
"https://app.baseten.co": "https://api.baseten.co",
"https://app.staging.baseten.co": "https://api.staging.baseten.co",
"https://app.dev.baseten.co": "https://api.staging.baseten.co",
# For local development, this is how we map URLs
"http://localhost:8000": "http://api.localhost:8000",
}

# If a non-standard domain is used with the baseten remote, default to
# using the production api routes
DEFAULT_API_DOMAIN = "https://api.baseten.co"


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
super().__init__(remote_url, **kwargs)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(
f"{self._remote_url}/graphql/",
# Ensure we strip off trailing '/' to denormalize
# URLs.
API_URL_MAPPING.get(self._remote_url.strip("/"), DEFAULT_API_DOMAIN),
self._auth_service,
)

def authenticate(self):
return self._auth_service.validate()
self._api = BasetenApi(remote_url, self._auth_service)

def push( # type: ignore
self,
Expand All @@ -70,7 +49,7 @@ def push( # type: ignore
promote: bool = False,
preserve_previous_prod_deployment: bool = False,
deployment_name: Optional[str] = None,
):
) -> BasetenService:
if model_name.isspace():
raise ValueError("Model name cannot be empty")

Expand Down Expand Up @@ -214,17 +193,11 @@ def get_service(self, **kwargs) -> BasetenService:
api=self._api,
)

def get_remote_logs_url(
self,
service: TrussService,
) -> str:
return service.logs_url(self._remote_url)

def sync_truss_to_dev_version_by_name(
self,
model_name: str,
target_directory: str,
):
) -> None:
# verify that development deployment exists for given model name
dev_version = get_dev_version(
self._api, model_name
Expand Down Expand Up @@ -257,7 +230,7 @@ def patch(
self,
watch_path: Path,
truss_ignore_patterns: List[str],
):
) -> None:
from truss.cli.console import console, error_console

try:
Expand Down
42 changes: 35 additions & 7 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Dict, Optional
import urllib.parse
from typing import Any, Dict, Iterator, Optional

import requests
from tenacity import retry, stop_after_delay, wait_fixed
Expand All @@ -12,6 +13,14 @@
DEFAULT_STREAM_ENCODING = "utf-8"


def _add_model_subdomain(rest_api_url: str, model_subdomain: str) -> str:
"""E.g. `https://api.baseten.co` -> `https://{model_subdomain}.api.baseten.co`"""
parsed_url = urllib.parse.urlparse(rest_api_url)
new_netloc = f"{model_subdomain}.{parsed_url.netloc}"
model_url = parsed_url._replace(netloc=new_netloc)
return str(urllib.parse.urlunparse(model_url))


class BasetenService(TrussService):
def __init__(
self,
Expand Down Expand Up @@ -51,9 +60,9 @@ def invocation_url(self) -> str:
def predict(
self,
model_request_body: Dict,
):
) -> Any:
response = self._send_request(
self.invocation_url, "POST", data=model_request_body, stream=True
self.predict_url, "POST", data=model_request_body, stream=True
)

if response.headers.get("transfer-encoding") == "chunked":
Expand Down Expand Up @@ -85,14 +94,33 @@ def decode_content():
def authenticate(self) -> dict:
return self._auth_service.authenticate().header()

def logs_url(self, base_url: str) -> str:
return f"{base_url}/models/{self._model_id}/logs/{self._model_version_id}"
@property
def logs_url(self) -> str:
return (
f"{self._api.remote_url}/models/{self._model_id}/"
f"logs/{self._model_version_id}"
)

@property
def predict_url(self) -> str:
    """
    Get the URL for the prediction endpoint.

    The model gets its own subdomain on the REST API host. Draft services
    use the `/development/predict` path; all others use
    `/deployment/{model_version_id}/predict`.
    """
    # E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co`.
    # The base must be the REST API URL (the app URL mapped through
    # `API_URL_MAPPING`), not the app-facing `remote_url`, and the subdomain
    # must be a bare label: a leading "/" would end up inside the host.
    # TODO: expose the REST API URL as a public property on `BasetenApi`
    # instead of reading the private attribute here.
    url = _add_model_subdomain(self._api._rest_api_url, f"model-{self.model_id}")
    if self.is_draft:
        # "https://model-{model_id}.api.baseten.co/development/predict".
        return f"{url}/development/predict"
    # "https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict".
    return f"{url}/deployment/{self.model_version_id}/predict"

@retry(stop=stop_after_delay(60), wait=wait_fixed(1), reraise=True)
def _fetch_deployment(self):
def _fetch_deployment(self) -> Any:
return self._api.get_deployment(self._model_id, self._model_version_id)

def poll_deployment_status(self):
def poll_deployment_status(self) -> Iterator[str]:
"""
Wait for the service to be deployed.
"""
Expand Down
63 changes: 24 additions & 39 deletions truss/remote/truss_remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Any, Dict, Iterator, Optional

import requests
from truss.truss_handle import TrussHandle
Expand All @@ -23,7 +23,7 @@ class TrussService(ABC):
"""

def __init__(self, service_url: str, is_draft: bool, **kwargs):
def __init__(self, service_url: str, is_draft: bool, **kwargs) -> None:
self._service_url = service_url
self._is_draft = is_draft

Expand All @@ -34,7 +34,7 @@ def _send_request(
headers: Optional[Dict] = None,
data: Optional[Dict] = None,
stream: Optional[bool] = False,
):
) -> Any:
"""
Send a HTTP request.
Expand Down Expand Up @@ -75,7 +75,7 @@ def _send_request(
return response

@property
def is_draft(self):
def is_draft(self) -> bool:
"""
Check if the service is in draft mode.
Expand Down Expand Up @@ -112,7 +112,7 @@ def is_ready(self) -> bool:
response = self._send_request(readiness_url, "GET", {})
return response.status_code == 200

def predict(self, model_request_body: Dict):
def predict(self, model_request_body: Dict) -> Any:
"""
Send a prediction request to the service.
Expand All @@ -122,11 +122,9 @@ def predict(self, model_request_body: Dict):
Returns:
A Response object resulting from the prediction request.
"""
invocation_url = f"{self._service_url}/v1/models/model:predict"
response = self._send_request(invocation_url, "POST", data=model_request_body)
return response
return self._send_request(self.predict_url, "POST", data=model_request_body)

def patch(self):
def patch(self) -> None:
"""
Patch the service. TrussServices in draft mode can be patched.
"""
Expand All @@ -142,15 +140,24 @@ def authenticate(self) -> dict:
"""
return {}

@property
@abstractmethod
def logs_url(self, base_url: str) -> str:
def logs_url(self) -> str:
"""
Get the URL for the service logs.
"""
pass

@property
@abstractmethod
def predict_url(self) -> str:
"""
Get the URL for the prediction endpoint.
"""
pass

@abstractmethod
def poll_deployment_status(self):
def poll_deployment_status(self) -> Iterator[str]:
"""
Poll for a deployment status.
"""
Expand All @@ -173,11 +180,11 @@ class TrussRemote(ABC):
"""

def __init__(self, remote_url: str, **kwargs):
def __init__(self, remote_url: str, **kwargs) -> None:
self._remote_url = remote_url

@abstractmethod
def push(self, truss_handle: TrussHandle, **kwargs):
def push(self, truss_handle: TrussHandle, **kwargs) -> TrussService:
"""
Push a TrussHandle to the remote service.
Expand All @@ -192,21 +199,7 @@ def push(self, truss_handle: TrussHandle, **kwargs):
pass

@abstractmethod
def authenticate(self, **kwargs):
"""
Authenticate the user to push to the remote service.
This method should be implemented in subclasses. It should check whether
the user has valid authentication credentials to push to the remote service.
If not, it should raise an exception.
Args:
**kwargs: Additional keyword arguments for the authentication operation.
"""
pass

@abstractmethod
def get_service(self, **kwargs):
def get_service(self, **kwargs) -> TrussService:
"""
Get a TrussService object for interacting with the remote service.
Expand All @@ -221,17 +214,9 @@ def get_service(self, **kwargs):
pass

@abstractmethod
def get_remote_logs_url(self, service: TrussService) -> str:
"""
Get the URL for the remote service logs.
Args:
service: The TrussService object for interacting with the remote service.
"""
pass

@abstractmethod
def sync_truss_to_dev_version_by_name(self, model_name: str, target_directory: str):
def sync_truss_to_dev_version_by_name(
self, model_name: str, target_directory: str
) -> None:
"""
This method watches for changes to files in the `target_directory`,
and syncs them to the development version of the model, identified
Expand Down
5 changes: 1 addition & 4 deletions truss/tests/remote/test_remote_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from truss.remote.remote_factory import RemoteFactory
from truss.remote.truss_remote import RemoteConfig, TrussRemote, TrussService
from truss.remote.truss_remote import RemoteConfig, TrussRemote

SAMPLE_CONFIG = {"api_key": "test_key", "remote_url": "http://test.com"}

Expand Down Expand Up @@ -36,9 +36,6 @@ def authenticate(self):
def push(self):
return {"status": "success"}

def get_remote_logs_url(self, service: TrussService) -> str:
raise NotImplementedError

def get_service(self, **kwargs):
raise NotImplementedError

Expand Down

0 comments on commit 3f1e89c

Please sign in to comment.