Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion guardrails_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.5"
21 changes: 14 additions & 7 deletions guardrails_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.otel import otel_is_disabled, initialize
from guardrails_api.utils.trace_server_start_if_enabled import trace_server_start_if_enabled
from guardrails_api.utils.trace_server_start_if_enabled import (
trace_server_start_if_enabled,
)
from guardrails_api.clients.cache_client import CacheClient
from rich.console import Console
from rich.rule import Rule
Expand Down Expand Up @@ -84,7 +86,7 @@ def create_app(

@app.before_request
def basic_cors():
if request.method.lower() == 'options':
if request.method.lower() == "options":
return Response()

app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
Expand Down Expand Up @@ -112,20 +114,25 @@ def basic_cors():
app.register_blueprint(root_bp)
app.register_blueprint(guards_bp)

console.print(f"\n:rocket: Guardrails API is available at {self_endpoint}")
console.print(
f"\n:rocket: Guardrails API is available at {self_endpoint}"
f":book: Visit {self_endpoint}/docs to see available API endpoints.\n"
)
console.print(f":book: Visit {self_endpoint}/docs to see available API endpoints.\n")

console.print(":green_circle: Active guards and OpenAI compatible endpoints:")

with app.app_context():
from guardrails_api.blueprints.guards import guard_client

for g in guard_client.get_guards():
g = g.to_dict()
console.print(f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1")
console.print(
f"- Guard: [bold white]{g.get('name')}[/bold white] {self_endpoint}/guards/{g.get('name')}/openai/v1"
)

console.print("")
console.print(Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white"))
console.print(
Rule("[bold grey]Server Logs[/bold grey]", characters="=", style="white")
)

return app
return app
23 changes: 16 additions & 7 deletions guardrails_api/blueprints/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from guardrails_api.clients.postgres_client import postgres_is_enabled
from guardrails_api.utils.handle_error import handle_error
from guardrails_api.utils.get_llm_callable import get_llm_callable
from guardrails_api.utils.openai import outcome_to_chat_completion, outcome_to_stream_response
from guardrails_api.utils.openai import (
outcome_to_chat_completion,
outcome_to_stream_response,
)

guards_bp = Blueprint("guards", __name__, url_prefix="/guards")

Expand Down Expand Up @@ -272,7 +275,6 @@ def validate(guard_name: str):
# ) as validate_span:
# guard: Guard = guard_struct.to_guard(openai_api_key, otel_tracer)


# validate_span.set_attribute("guardName", decoded_guard_name)
if llm_api is not None:
llm_api = get_llm_callable(llm_api)
Expand All @@ -295,7 +297,7 @@ def validate(guard_name: str):
else:
guard: Guard = Guard.from_dict(guard_struct.to_dict())
elif is_async:
guard:Guard = AsyncGuard.from_dict(guard_struct.to_dict())
guard: Guard = AsyncGuard.from_dict(guard_struct.to_dict())

if llm_api is None and num_reasks and num_reasks > 1:
raise HttpError(
Expand All @@ -322,6 +324,7 @@ def validate(guard_name: str):
)
else:
if stream:

def guard_streamer():
guard_stream = guard(
llm_api=llm_api,
Expand Down Expand Up @@ -452,24 +455,30 @@ async def async_validate_streamer(guard_iter):
cache_key = f"{guard.name}-{final_validation_output.call_id}"
cache_client.set(cache_key, serialized_history, 300)
yield f"{final_output_json}\n"

# apropos of https://stackoverflow.com/questions/73949570/using-stream-with-context-as-async
def iter_over_async(ait, loop):
ait = ait.__aiter__()

async def get_next():
try:
try:
obj = await ait.__anext__()
return False, obj
except StopAsyncIteration:
except StopAsyncIteration:
return True, None

while True:
done, obj = loop.run_until_complete(get_next())
if done:
if done:
break
yield obj

if is_async:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
iter = iter_over_async(async_validate_streamer(async_guard_streamer()), loop)
iter = iter_over_async(
async_validate_streamer(async_guard_streamer()), loop
)
else:
iter = validate_streamer(guard_streamer())
return Response(
Expand Down
1 change: 1 addition & 0 deletions guardrails_api/cli/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from guardrails_api.app import create_app
from guardrails_api.utils.configuration import valid_configuration


@cli.command("start")
def start(
env: Optional[str] = typer.Option(
Expand Down
27 changes: 19 additions & 8 deletions guardrails_api/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,32 @@
from typing import Optional
import os

def valid_configuration(config: Optional[str]=""):

def valid_configuration(config: Optional[str] = ""):
default_config_file = os.path.join(os.getcwd(), "./config.py")

default_config_file_path = os.path.abspath(default_config_file)
# If config.py is not present and
# If config.py is not present and
# if a config filepath is not passed and
# if postgres is not there (i.e. we’re using in-mem db)
# if postgres is not there (i.e. we’re using in-mem db)
# then raise ConfigurationError
has_default_config_file = os.path.isfile(default_config_file_path)

has_config_file = (config != "" and config is not None) and os.path.isfile(os.path.abspath(config))
if not has_default_config_file and not has_config_file and not postgres_is_enabled():
raise ConfigurationError("Can not start. Configuration not provided and default"
" configuration not found and postgres is not enabled.")
has_config_file = (config != "" and config is not None) and os.path.isfile(
os.path.abspath(config)
)

if (
not has_default_config_file
and not has_config_file
and not postgres_is_enabled()
):
raise ConfigurationError(
"Can not start. Configuration not provided and default"
" configuration not found and postgres is not enabled."
)
return True


class ConfigurationError(Exception):
pass
pass
27 changes: 21 additions & 6 deletions guardrails_api/utils/handle_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,35 @@ def decorator(*args, **kwargs):
return fn(*args, **kwargs)
except ValidationError as validation_error:
logger.error(validation_error)
traceback.print_exception(type(validation_error), validation_error, validation_error.__traceback__)
return str(validation_error), 400
traceback.print_exception(
type(validation_error), validation_error, validation_error.__traceback__
)
resp_body = {"status_code": 400, "detail": str(validation_error)}
return resp_body, 400
except HttpError as http_error:
logger.error(http_error)
traceback.print_exception(type(http_error), http_error, http_error.__traceback__)
return http_error.to_dict(), http_error.status
traceback.print_exception(
type(http_error), http_error, http_error.__traceback__
)
resp_body = http_error.to_dict()
resp_body["status_code"] = http_error.status
resp_body["detail"] = http_error.message
return resp_body, http_error.status
except HTTPException as http_exception:
logger.error(http_exception)
traceback.print_exception(http_exception)
http_error = HttpError(http_exception.code, http_exception.description)
return http_error.to_dict(), http_error.status
resp_body = http_error.to_dict()
resp_body["status_code"] = http_error.status
resp_body["detail"] = http_error.message

return resp_body, http_error.status
except Exception as e:
logger.error(e)
traceback.print_exception(e)
return HttpError(500, "Internal Server Error").to_dict(), 500
resp_body = HttpError(500, "Internal Server Error").to_dict()
resp_body["status_code"] = 500
resp_body["detail"] = "Internal Server Error"
return resp_body, 500

return decorator
2 changes: 1 addition & 1 deletion guardrails_api/utils/has_internet_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ def has_internet_connection() -> bool:
res.raise_for_status()
return True
except requests.ConnectionError:
return False
return False
1 change: 1 addition & 0 deletions guardrails_api/utils/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from guardrails.classes import ValidationOutcome


def outcome_to_stream_response(validation_outcome: ValidationOutcome):
stream_chunk_template = {
"choices": [
Expand Down
3 changes: 2 additions & 1 deletion guardrails_api/utils/trace_server_start_if_enabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def trace_server_start_if_enabled():
config = Credentials.from_rc_file()
if config.enable_metrics is True and has_internet_connection():
from guardrails.utils.hub_telemetry_utils import HubTelemetry

HubTelemetry().create_new_span(
"guardrails-api/start",
[
Expand All @@ -21,4 +22,4 @@ def trace_server_start_if_enabled():
],
True,
False,
)
)
23 changes: 15 additions & 8 deletions tests/blueprints/test_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def test_validate__call(mocker):

del os.environ["PGHOST"]


def test_validate__call_throws_validation_error(mocker):
os.environ["PGHOST"] = "localhost"

Expand Down Expand Up @@ -610,19 +611,24 @@ def test_validate__call_throws_validation_error(mocker):
prompt="Hello world!",
)

assert response == ('Test guard validation error', 400)
assert response == (
{"status_code": 400, "detail": "Test guard validation error"},
400,
)

del os.environ["PGHOST"]


def test_openai_v1_chat_completions__raises_404(mocker):
from guardrails_api.blueprints.guards import openai_v1_chat_completions

os.environ["PGHOST"] = "localhost"
mock_guard = None

mock_request = MockRequest(
"POST",
json={
"messages": [{"role":"user", "content":"Hello world!"}],
"messages": [{"role": "user", "content": "Hello world!"}],
},
headers={"x-openai-api-key": "mock-key"},
)
Expand All @@ -637,15 +643,16 @@ def test_openai_v1_chat_completions__raises_404(mocker):

response = openai_v1_chat_completions("My%20Guard's%20Name")
assert response[1] == 404
assert response[0]["message"] == 'NotFound'

assert response[0]["message"] == "NotFound"

mock_get_guard.assert_called_once_with("My Guard's Name")

del os.environ["PGHOST"]


def test_openai_v1_chat_completions__call(mocker):
from guardrails_api.blueprints.guards import openai_v1_chat_completions

os.environ["PGHOST"] = "localhost"
mock_guard = MockGuardStruct()
mock_outcome = ValidationOutcome(
Expand All @@ -664,7 +671,7 @@ def test_openai_v1_chat_completions__call(mocker):
mock_request = MockRequest(
"POST",
json={
"messages": [{"role":"user", "content":"Hello world!"}],
"messages": [{"role": "user", "content": "Hello world!"}],
},
headers={"x-openai-api-key": "mock-key"},
)
Expand All @@ -687,7 +694,7 @@ def test_openai_v1_chat_completions__call(mocker):
)
mock_status.return_value = "fail"
mock_call = Call()
mock_call.iterations= Stack(Iteration('some-id', 1))
mock_call.iterations = Stack(Iteration("some-id", 1))
mock_guard.history = Stack(mock_call)

response = openai_v1_chat_completions("My%20Guard's%20Name")
Expand All @@ -698,7 +705,7 @@ def test_openai_v1_chat_completions__call(mocker):

mock___call__.assert_called_once_with(
num_reasks=0,
messages=[{"role":"user", "content":"Hello world!"}],
messages=[{"role": "user", "content": "Hello world!"}],
)

assert response == {
Expand All @@ -716,4 +723,4 @@ def test_openai_v1_chat_completions__call(mocker):
},
}

del os.environ["PGHOST"]
del os.environ["PGHOST"]
2 changes: 2 additions & 0 deletions tests/cli/test_start.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import MagicMock
import os


def test_start(mocker):
mocker.patch("guardrails_api.cli.start.cli")

Expand All @@ -10,6 +11,7 @@ def test_start(mocker):
)

from guardrails_api.cli.start import start

# pg enabled
os.environ["PGHOST"] = "localhost"
start("env", "config", 8000)
Expand Down
1 change: 1 addition & 0 deletions tests/mocks/mock_guard_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import ConfigDict
from guardrails.classes.generic import Stack


class MockGuardStruct(GuardStruct):
# Pydantic Config
model_config = ConfigDict(arbitrary_types_allowed=True)
Expand Down
7 changes: 4 additions & 3 deletions tests/utils/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import pytest
from guardrails_api.utils.configuration import valid_configuration, ConfigurationError


def test_valid_configuration(mocker):
with pytest.raises(ConfigurationError):
valid_configuration()

# pg enabled
os.environ["PGHOST"] = "localhost"
valid_configuration("config.py")
os.environ.pop("PGHOST")

# custom config
mock_isfile = mocker.patch("os.path.isfile")
mock_isfile.side_effect = [False, True]
Expand All @@ -20,7 +21,7 @@ def test_valid_configuration(mocker):
mock_isfile.side_effect = [False, False]
with pytest.raises(ConfigurationError):
valid_configuration("")

# default config
mock_isfile = mocker.patch("os.path.isfile")
mock_isfile.side_effect = [True, False]
Expand Down