From 7addf8e923d65293500ba07f5a7be5b3aa7f5339 Mon Sep 17 00:00:00 2001 From: javelin Date: Thu, 20 Feb 2025 19:25:41 +0530 Subject: [PATCH] fix: add support for univ endpoints in javelin sdk --- .pre-commit-config.yaml | 13 ---- examples/univ_endpoint_generic.py | 55 +++++++++++++++++ javelin_cli/_internal/commands.py | 3 +- javelin_sdk/client.py | 72 +++++++++++++++++++++-- javelin_sdk/models.py | 21 ++++++- javelin_sdk/services/gateway_service.py | 1 - javelin_sdk/services/modelspec_service.py | 1 - javelin_sdk/services/provider_service.py | 1 - javelin_sdk/services/route_service.py | 52 ++++++++++++++-- javelin_sdk/services/secret_service.py | 1 - javelin_sdk/services/template_service.py | 1 - javelin_sdk/services/trace_service.py | 1 - 12 files changed, 190 insertions(+), 32 deletions(-) create mode 100644 examples/univ_endpoint_generic.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 14fd505..39cc3eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,10 @@ repos: - - repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort - name: isort (python) - - repo: https://github.com/psf/black rev: 24.3.0 hooks: - id: black language_version: python3 - - repo: https://github.com/charliermarsh/ruff-pre-commit - # Ruff version. - rev: "v0.0.265" - hooks: - - id: ruff - args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/python-poetry/poetry rev: "1.4.0" # add version here hooks: diff --git a/examples/univ_endpoint_generic.py b/examples/univ_endpoint_generic.py new file mode 100644 index 0000000..f19aa21 --- /dev/null +++ b/examples/univ_endpoint_generic.py @@ -0,0 +1,55 @@ +import asyncio +import json +import os +from typing import Any, Dict + +from javelin_sdk import JavelinClient, JavelinConfig + + +# Helper function to pretty print responses +def print_response(provider: str, response: Dict[str, Any]) -> None: + print(f"=== Response from {provider} ===") + print(json.dumps(response, indent=2)) + + +# Setup client configuration +config = JavelinConfig( + base_url="https://api-dev.javelin.live", + javelin_api_key=os.getenv("JAVELIN_API_KEY"), + llm_api_key=os.getenv("OPENAI_API_KEY"), +) +client = JavelinClient(config) + +# Example messages in OpenAI format +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What are the three primary colors?"}, +] + +# Define the headers based on the curl command +custom_headers = { + "Content-Type": "application/json", + "x-javelin-route": "openai_univ", + "x-javelin-model": "gpt-4", + "x-javelin-provider": "https://api.openai.com/v1", + "x-api-key": os.getenv("JAVELIN_API_KEY"), # Use environment variable for security + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}", # Use environment variable for security +} + + +async def main(): + try: + query_body = {"messages": messages, "temperature": 0.7} + openai_response = await client.aquery_unified_endpoint( + provider_name="openai", + endpoint_type="chat", + query_body=query_body, + headers=custom_headers, + ) + print_response("OpenAI", openai_response) + except Exception as e: + print(f"OpenAI query failed: {str(e)}") + + +# Run the async function +asyncio.run(main()) diff --git a/javelin_cli/_internal/commands.py b/javelin_cli/_internal/commands.py index 2454b6d..a3a6c57 100644 --- a/javelin_cli/_internal/commands.py +++ b/javelin_cli/_internal/commands.py @@ -2,8 +2,6 @@ import os from pathlib import Path -from pydantic import ValidationError - from javelin_sdk.client import JavelinClient from javelin_sdk.exceptions import ( BadRequest, @@ -29,6 +27,7 @@ Template, Templates, ) +from pydantic import ValidationError def get_javelin_client(): diff --git a/javelin_sdk/client.py b/javelin_sdk/client.py index 711b2e7..77543cf 100644 --- a/javelin_sdk/client.py +++ b/javelin_sdk/client.py @@ -6,9 +6,6 @@ from urllib.parse import unquote, urljoin, urlparse, urlunparse import httpx -from opentelemetry.semconv._incubating.attributes import gen_ai_attributes -from opentelemetry.trace import SpanKind, Status, StatusCode - from javelin_sdk.chat_completions import Chat, Completions from javelin_sdk.models import HttpMethod, JavelinConfig, Request from javelin_sdk.services.gateway_service import GatewayService @@ -19,6 +16,8 @@ from javelin_sdk.services.template_service import TemplateService from javelin_sdk.services.trace_service import TraceService from javelin_sdk.tracing_setup import configure_span_exporter +from opentelemetry.semconv._incubating.attributes import gen_ai_attributes +from opentelemetry.trace import SpanKind, Status, StatusCode API_BASEURL = "https://api-dev.javelin.live" API_BASE_PATH = "/v1" @@ -546,7 +545,7 @@ def get_inference_model(inference_profile_identifier: str) -> str: ) model_id = foundation_model_response["modelDetails"]["modelId"] return model_id - except Exception as e: + except Exception: # Fail silently if the model is not found return None @@ -557,7 +556,7 @@ def get_foundation_model(model_identifier: str) -> str: modelIdentifier=model_identifier ) return response["modelDetails"]["modelId"] - except Exception as e: + except Exception: # Fail silently if the model is not found return None @@ -668,6 +667,7 @@ def _prepare_request(self, request: Request) -> tuple: is_transformation_rules=request.is_transformation_rules, is_model_specs=request.is_model_specs, is_reload=request.is_reload, + univ_model=request.univ_model_config, ) headers = {**self._headers, **(request.headers or {})} return url, headers @@ -708,6 +708,7 @@ def _construct_url( is_transformation_rules: bool = False, is_model_specs: bool = False, is_reload: bool = False, + univ_model: Optional[Dict[str, Any]] = None, ) -> str: url_parts = [self.base_url] @@ -770,6 +771,12 @@ def _construct_url( if query_params: query_string = "&".join(f"{k}={v}" for k, v in query_params.items()) url += f"?{query_string}" + + # Integrate construct_endpoint_url logic + if univ_model: + endpoint_url = self.construct_endpoint_url(univ_model) + url = urljoin(url, endpoint_url) + return url # Gateway methods @@ -876,6 +883,12 @@ def _construct_url( aquery_llama = lambda self, route_name, query_body: self.route_service.aquery_llama( route_name, query_body ) + query_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.query_unified_endpoint( + provider_name, endpoint_type, query_body, headers, query_params + ) + aquery_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.aquery_unified_endpoint( + provider_name, endpoint_type, query_body, headers, query_params + ) # Secret methods create_secret = lambda self, secret: self.secret_service.create_secret(secret) @@ -969,3 +982,52 @@ async def aget_last_n_chronicle_records( ) response = await self._send_request_async(request) return response + + def construct_endpoint_url(self, request_model: Dict[str, Any]) -> str: + """ + Constructs the endpoint URL based on the request model. + + :param base_url: The base URL for the API. + :param request_model: The request model containing endpoint details. + :return: The constructed endpoint URL. + """ + base_url = self.base_url + provider_name = request_model.get("provider_name") + endpoint_type = request_model.get("endpoint_type") + deployment = request_model.get("deployment") + arn = request_model.get("arn") + api_version = request_model.get( + "api_version", "2023-07-01-preview" + ) # Default version + + if not provider_name: + raise ValueError("Provider name is not specified in the request model.") + + if provider_name == "azureopenai" and deployment: + # Handle Azure OpenAI endpoints + if endpoint_type == "chat": + return f"{base_url}/{provider_name}/deployments/{deployment}/chat/completions?api-version={api_version}" + elif endpoint_type == "completion": + return f"{base_url}/{provider_name}/deployments/{deployment}/completions?api-version={api_version}" + elif endpoint_type == "embeddings": + return f"{base_url}/{provider_name}/deployments/{deployment}/embeddings?api-version={api_version}" + elif arn: + # Handle Bedrock endpoints + if endpoint_type == "invoke": + return f"{base_url}/v1/model/{arn}/invoke" + elif endpoint_type == "converse": + return f"{base_url}/v1/model/{arn}/converse" + elif endpoint_type == "invoke_stream": + return f"{base_url}/v1/model/{arn}/invoke-with-response-stream" + elif endpoint_type == "converse_stream": + return f"{base_url}/v1/model/{arn}/converse-stream" + else: + # Handle OpenAI compatible endpoints + if endpoint_type == "chat": + return f"{base_url}/{provider_name}/chat/completions" + elif endpoint_type == "completion": + return f"{base_url}/{provider_name}/completions" + elif endpoint_type == "embeddings": + return f"{base_url}/{provider_name}/embeddings" + + raise ValueError("Invalid request model configuration") diff --git a/javelin_sdk/models.py b/javelin_sdk/models.py index 9001287..6058e9a 100644 --- a/javelin_sdk/models.py +++ b/javelin_sdk/models.py @@ -1,9 +1,8 @@ from enum import Enum, auto from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, field_validator - from javelin_sdk.exceptions import UnauthorizedError +from pydantic import BaseModel, Field, field_validator class GatewayConfig(BaseModel): @@ -478,6 +477,7 @@ def __init__( is_transformation_rules: bool = False, is_model_specs: bool = False, is_reload: bool = False, + univ_model_config: Optional[Dict[str, Any]] = None, ): self.method = method self.gateway = gateway @@ -494,6 +494,7 @@ def __init__( self.is_transformation_rules = is_transformation_rules self.is_model_specs = is_model_specs self.is_reload = is_reload + self.univ_model_config = univ_model_config class Message(BaseModel): @@ -568,3 +569,19 @@ class EndpointType(str, Enum): INVOKE_STREAM = "invoke_stream" CONVERSE_STREAM = "converse_stream" ALL = "all" + + +class UnivModelConfig: + def __init__( + self, + provider_name: str, + endpoint_type: str, + deployment: Optional[str] = None, + arn: Optional[str] = None, + api_version: Optional[str] = None, + ): + self.provider_name = provider_name + self.endpoint_type = endpoint_type + self.deployment = deployment + self.arn = arn + self.api_version = api_version diff --git a/javelin_sdk/services/gateway_service.py b/javelin_sdk/services/gateway_service.py index 361318c..dbb0df7 100644 --- a/javelin_sdk/services/gateway_service.py +++ b/javelin_sdk/services/gateway_service.py @@ -1,7 +1,6 @@ from typing import List import httpx - from javelin_sdk.exceptions import ( BadRequest, GatewayAlreadyExistsError, diff --git a/javelin_sdk/services/modelspec_service.py b/javelin_sdk/services/modelspec_service.py index c8a2c8a..349cafe 100644 --- a/javelin_sdk/services/modelspec_service.py +++ b/javelin_sdk/services/modelspec_service.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Optional import httpx - from javelin_sdk.exceptions import ( BadRequest, InternalServerError, diff --git a/javelin_sdk/services/provider_service.py b/javelin_sdk/services/provider_service.py index 4b73d76..98c55b6 100644 --- a/javelin_sdk/services/provider_service.py +++ b/javelin_sdk/services/provider_service.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List import httpx - from javelin_sdk.exceptions import ( BadRequest, InternalServerError, diff --git a/javelin_sdk/services/route_service.py b/javelin_sdk/services/route_service.py index d28b1ba..b394ea5 100644 --- a/javelin_sdk/services/route_service.py +++ b/javelin_sdk/services/route_service.py @@ -1,10 +1,7 @@ import json -import time from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union import httpx -from jsonpath_ng import parse - from javelin_sdk.exceptions import ( BadRequest, InternalServerError, @@ -13,7 +10,8 @@ RouteNotFoundError, UnauthorizedError, ) -from javelin_sdk.models import HttpMethod, Request, Route, Routes +from javelin_sdk.models import HttpMethod, Request, Route, Routes, UnivModelConfig +from jsonpath_ng import parse class RouteService: @@ -310,3 +308,49 @@ async def areload_route(self, route_name: str) -> str: ) ) return response + + def query_unified_endpoint( + self, + provider_name: str, + endpoint_type: str, + query_body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, + query_params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + univ_model_config = UnivModelConfig( + provider_name=provider_name, + endpoint_type=endpoint_type, + ) + + request = Request( + method=HttpMethod.POST, + data=query_body, + univ_model_config=univ_model_config.__dict__, + headers=headers, + query_params=query_params, + ) + response = self.client._send_request_sync(request) + return response.json() + + async def aquery_unified_endpoint( + self, + provider_name: str, + endpoint_type: str, + query_body: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, + query_params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + univ_model_config = UnivModelConfig( + provider_name=provider_name, + endpoint_type=endpoint_type, + ) + + request = Request( + method=HttpMethod.POST, + data=query_body, + univ_model_config=univ_model_config.__dict__, + headers=headers, + query_params=query_params, + ) + response = await self.client._send_request_async(request) + return response.json() diff --git a/javelin_sdk/services/secret_service.py b/javelin_sdk/services/secret_service.py index 61eba35..957d9eb 100644 --- a/javelin_sdk/services/secret_service.py +++ b/javelin_sdk/services/secret_service.py @@ -1,7 +1,6 @@ from typing import List import httpx - from javelin_sdk.exceptions import ( BadRequest, InternalServerError, diff --git a/javelin_sdk/services/template_service.py b/javelin_sdk/services/template_service.py index 9241467..ba7d9db 100644 --- a/javelin_sdk/services/template_service.py +++ b/javelin_sdk/services/template_service.py @@ -1,7 +1,6 @@ from typing import List import httpx - from javelin_sdk.exceptions import ( BadRequest, InternalServerError, diff --git a/javelin_sdk/services/trace_service.py b/javelin_sdk/services/trace_service.py index 2a9c5cc..7184b4f 100644 --- a/javelin_sdk/services/trace_service.py +++ b/javelin_sdk/services/trace_service.py @@ -1,7 +1,6 @@ from typing import List import httpx - from javelin_sdk.exceptions import ( BadRequest, InternalServerError,