Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
repos:
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)

- repo: https://github.com/psf/black
rev: 24.3.0
hooks:
- id: black
language_version: python3

- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: "v0.0.265"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]

- repo: https://github.com/python-poetry/poetry
rev: "1.4.0" # add version here
hooks:
Expand Down
55 changes: 55 additions & 0 deletions examples/univ_endpoint_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import asyncio
import json
import os
from typing import Any, Dict

from javelin_sdk import JavelinClient, JavelinConfig


# Helper function to pretty print responses
def print_response(provider: str, response: Dict[str, Any]) -> None:
print(f"=== Response from {provider} ===")
print(json.dumps(response, indent=2))


# Setup client configuration
config = JavelinConfig(
base_url="https://api-dev.javelin.live",
javelin_api_key=os.getenv("JAVELIN_API_KEY"),
llm_api_key=os.getenv("OPENAI_API_KEY"),
)
client = JavelinClient(config)

# Example messages in OpenAI format
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What are the three primary colors?"},
]

# Define the headers based on the curl command
custom_headers = {
"Content-Type": "application/json",
"x-javelin-route": "openai_univ",
"x-javelin-model": "gpt-4",
"x-javelin-provider": "https://api.openai.com/v1",
"x-api-key": os.getenv("JAVELIN_API_KEY"), # Use environment variable for security
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}", # Use environment variable for security
}


async def main():
try:
query_body = {"messages": messages, "temperature": 0.7}
openai_response = await client.aquery_unified_endpoint(
provider_name="openai",
endpoint_type="chat",
query_body=query_body,
headers=custom_headers,
)
print_response("OpenAI", openai_response)
except Exception as e:
print(f"OpenAI query failed: {str(e)}")


# Run the async function
asyncio.run(main())
3 changes: 1 addition & 2 deletions javelin_cli/_internal/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
from pathlib import Path

from pydantic import ValidationError

from javelin_sdk.client import JavelinClient
from javelin_sdk.exceptions import (
BadRequest,
Expand All @@ -29,6 +27,7 @@
Template,
Templates,
)
from pydantic import ValidationError


def get_javelin_client():
Expand Down
72 changes: 67 additions & 5 deletions javelin_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from urllib.parse import unquote, urljoin, urlparse, urlunparse

import httpx
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.trace import SpanKind, Status, StatusCode

from javelin_sdk.chat_completions import Chat, Completions
from javelin_sdk.models import HttpMethod, JavelinConfig, Request
from javelin_sdk.services.gateway_service import GatewayService
Expand All @@ -19,6 +16,8 @@
from javelin_sdk.services.template_service import TemplateService
from javelin_sdk.services.trace_service import TraceService
from javelin_sdk.tracing_setup import configure_span_exporter
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.trace import SpanKind, Status, StatusCode

API_BASEURL = "https://api-dev.javelin.live"
API_BASE_PATH = "/v1"
Expand Down Expand Up @@ -546,7 +545,7 @@ def get_inference_model(inference_profile_identifier: str) -> str:
)
model_id = foundation_model_response["modelDetails"]["modelId"]
return model_id
except Exception as e:
except Exception:
# Fail silently if the model is not found
return None

Expand All @@ -557,7 +556,7 @@ def get_foundation_model(model_identifier: str) -> str:
modelIdentifier=model_identifier
)
return response["modelDetails"]["modelId"]
except Exception as e:
except Exception:
# Fail silently if the model is not found
return None

Expand Down Expand Up @@ -668,6 +667,7 @@ def _prepare_request(self, request: Request) -> tuple:
is_transformation_rules=request.is_transformation_rules,
is_model_specs=request.is_model_specs,
is_reload=request.is_reload,
univ_model=request.univ_model_config,
)
headers = {**self._headers, **(request.headers or {})}
return url, headers
Expand Down Expand Up @@ -708,6 +708,7 @@ def _construct_url(
is_transformation_rules: bool = False,
is_model_specs: bool = False,
is_reload: bool = False,
univ_model: Optional[Dict[str, Any]] = None,
) -> str:
url_parts = [self.base_url]

Expand Down Expand Up @@ -770,6 +771,12 @@ def _construct_url(
if query_params:
query_string = "&".join(f"{k}={v}" for k, v in query_params.items())
url += f"?{query_string}"

# Integrate construct_endpoint_url logic
if univ_model:
endpoint_url = self.construct_endpoint_url(univ_model)
url = urljoin(url, endpoint_url)

return url

# Gateway methods
Expand Down Expand Up @@ -876,6 +883,12 @@ def _construct_url(
aquery_llama = lambda self, route_name, query_body: self.route_service.aquery_llama(
route_name, query_body
)
query_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.query_unified_endpoint(
provider_name, endpoint_type, query_body, headers, query_params
)
aquery_unified_endpoint = lambda self, provider_name, endpoint_type, query_body, headers=None, query_params=None: self.route_service.aquery_unified_endpoint(
provider_name, endpoint_type, query_body, headers, query_params
)

# Secret methods
create_secret = lambda self, secret: self.secret_service.create_secret(secret)
Expand Down Expand Up @@ -969,3 +982,52 @@ async def aget_last_n_chronicle_records(
)
response = await self._send_request_async(request)
return response

def construct_endpoint_url(self, request_model: Dict[str, Any]) -> str:
"""
Constructs the endpoint URL based on the request model.

:param base_url: The base URL for the API.
:param request_model: The request model containing endpoint details.
:return: The constructed endpoint URL.
"""
base_url = self.base_url
provider_name = request_model.get("provider_name")
endpoint_type = request_model.get("endpoint_type")
deployment = request_model.get("deployment")
arn = request_model.get("arn")
api_version = request_model.get(
"api_version", "2023-07-01-preview"
) # Default version

if not provider_name:
raise ValueError("Provider name is not specified in the request model.")

if provider_name == "azureopenai" and deployment:
# Handle Azure OpenAI endpoints
if endpoint_type == "chat":
return f"{base_url}/{provider_name}/deployments/{deployment}/chat/completions?api-version={api_version}"
elif endpoint_type == "completion":
return f"{base_url}/{provider_name}/deployments/{deployment}/completions?api-version={api_version}"
elif endpoint_type == "embeddings":
return f"{base_url}/{provider_name}/deployments/{deployment}/embeddings?api-version={api_version}"
elif arn:
# Handle Bedrock endpoints
if endpoint_type == "invoke":
return f"{base_url}/v1/model/{arn}/invoke"
elif endpoint_type == "converse":
return f"{base_url}/v1/model/{arn}/converse"
elif endpoint_type == "invoke_stream":
return f"{base_url}/v1/model/{arn}/invoke-with-response-stream"
elif endpoint_type == "converse_stream":
return f"{base_url}/v1/model/{arn}/converse-stream"
else:
# Handle OpenAI compatible endpoints
if endpoint_type == "chat":
return f"{base_url}/{provider_name}/chat/completions"
elif endpoint_type == "completion":
return f"{base_url}/{provider_name}/completions"
elif endpoint_type == "embeddings":
return f"{base_url}/{provider_name}/embeddings"

raise ValueError("Invalid request model configuration")
21 changes: 19 additions & 2 deletions javelin_sdk/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from enum import Enum, auto
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, field_validator

from javelin_sdk.exceptions import UnauthorizedError
from pydantic import BaseModel, Field, field_validator


class GatewayConfig(BaseModel):
Expand Down Expand Up @@ -478,6 +477,7 @@ def __init__(
is_transformation_rules: bool = False,
is_model_specs: bool = False,
is_reload: bool = False,
univ_model_config: Optional[Dict[str, Any]] = None,
):
self.method = method
self.gateway = gateway
Expand All @@ -494,6 +494,7 @@ def __init__(
self.is_transformation_rules = is_transformation_rules
self.is_model_specs = is_model_specs
self.is_reload = is_reload
self.univ_model_config = univ_model_config


class Message(BaseModel):
Expand Down Expand Up @@ -568,3 +569,19 @@ class EndpointType(str, Enum):
INVOKE_STREAM = "invoke_stream"
CONVERSE_STREAM = "converse_stream"
ALL = "all"


class UnivModelConfig:
def __init__(
self,
provider_name: str,
endpoint_type: str,
deployment: Optional[str] = None,
arn: Optional[str] = None,
api_version: Optional[str] = None,
):
self.provider_name = provider_name
self.endpoint_type = endpoint_type
self.deployment = deployment
self.arn = arn
self.api_version = api_version
1 change: 0 additions & 1 deletion javelin_sdk/services/gateway_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
GatewayAlreadyExistsError,
Expand Down
1 change: 0 additions & 1 deletion javelin_sdk/services/modelspec_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, Optional

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand Down
1 change: 0 additions & 1 deletion javelin_sdk/services/provider_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Dict, List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand Down
52 changes: 48 additions & 4 deletions javelin_sdk/services/route_service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import json
import time
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union

import httpx
from jsonpath_ng import parse

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand All @@ -13,7 +10,8 @@
RouteNotFoundError,
UnauthorizedError,
)
from javelin_sdk.models import HttpMethod, Request, Route, Routes
from javelin_sdk.models import HttpMethod, Request, Route, Routes, UnivModelConfig
from jsonpath_ng import parse


class RouteService:
Expand Down Expand Up @@ -310,3 +308,49 @@ async def areload_route(self, route_name: str) -> str:
)
)
return response

def query_unified_endpoint(
self,
provider_name: str,
endpoint_type: str,
query_body: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
query_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
univ_model_config = UnivModelConfig(
provider_name=provider_name,
endpoint_type=endpoint_type,
)

request = Request(
method=HttpMethod.POST,
data=query_body,
univ_model_config=univ_model_config.__dict__,
headers=headers,
query_params=query_params,
)
response = self.client._send_request_sync(request)
return response.json()

async def aquery_unified_endpoint(
self,
provider_name: str,
endpoint_type: str,
query_body: Dict[str, Any],
headers: Optional[Dict[str, str]] = None,
query_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
univ_model_config = UnivModelConfig(
provider_name=provider_name,
endpoint_type=endpoint_type,
)

request = Request(
method=HttpMethod.POST,
data=query_body,
univ_model_config=univ_model_config.__dict__,
headers=headers,
query_params=query_params,
)
response = await self.client._send_request_async(request)
return response.json()
1 change: 0 additions & 1 deletion javelin_sdk/services/secret_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand Down
1 change: 0 additions & 1 deletion javelin_sdk/services/template_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand Down
1 change: 0 additions & 1 deletion javelin_sdk/services/trace_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List

import httpx

from javelin_sdk.exceptions import (
BadRequest,
InternalServerError,
Expand Down