Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions guardrails/llm_providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import asyncio
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Union, cast
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
Union,
cast,
)

from guard_rails_api_client.models.validate_payload_llm_api import ValidatePayloadLlmApi
from pydantic import BaseModel
Expand Down Expand Up @@ -154,7 +165,9 @@ def _invoke_llm(
model: str = "gpt-3.5-turbo",
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
base_model: Optional[BaseModel] = None,
base_model: Optional[
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
function_call: Optional[Any] = None,
*args,
**kwargs,
Expand Down Expand Up @@ -184,13 +197,15 @@ def _invoke_llm(
)

# Configure function calling if applicable (only for non-streaming)
fn_kwargs = {}
if base_model and not kwargs.get("stream", False):
function_params = [convert_pydantic_model_to_openai_fn(base_model)]
if function_call is None:
function_call = {"name": function_params[0]["name"]}
fn_kwargs = {"functions": function_params, "function_call": function_call}
else:
fn_kwargs = {}
function_params = convert_pydantic_model_to_openai_fn(base_model)
if function_call is None and function_params:
function_call = {"name": function_params["name"]}
fn_kwargs = {
"functions": [function_params],
"function_call": function_call,
}

# Call OpenAI
if "api_key" in kwargs:
Expand Down Expand Up @@ -688,7 +703,9 @@ async def invoke_llm(
model: str = "gpt-3.5-turbo",
instructions: Optional[str] = None,
msg_history: Optional[List[Dict]] = None,
base_model: Optional[BaseModel] = None,
base_model: Optional[
Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
] = None,
function_call: Optional[Any] = None,
*args,
**kwargs,
Expand Down Expand Up @@ -718,13 +735,15 @@ async def invoke_llm(
)

# Configure function calling if applicable
fn_kwargs = {}
if base_model:
function_params = [convert_pydantic_model_to_openai_fn(base_model)]
if function_call is None:
function_call = {"name": function_params[0]["name"]}
fn_kwargs = {"functions": function_params, "function_call": function_call}
else:
fn_kwargs = {}
function_params = convert_pydantic_model_to_openai_fn(base_model)
if function_call is None and function_params:
function_call = {"name": function_params["name"]}
fn_kwargs = {
"functions": [function_params],
"function_call": function_call,
}

# Call OpenAI
if "api_key" in kwargs:
Expand Down
11 changes: 11 additions & 0 deletions guardrails/utils/dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pydantic.version

# Version string of the installed pydantic package.
PYDANTIC_VERSION = pydantic.version.VERSION

# Compare the major component exactly: a bare ``startswith("1")`` prefix check
# would also match a hypothetical "10.x" release.
if PYDANTIC_VERSION.split(".", 1)[0] == "1":

    def dataclass(cls=None, **_kwargs):  # type: ignore
        """No-op stand-in for ``dataclasses.dataclass`` under pydantic 1.x.

        Accepts the same call forms as the stdlib decorator chosen in the
        other branch, so decorated code works under either pydantic major:

        * ``@dataclass`` -- ``cls`` is the decorated class.
        * ``@dataclass(...)`` -- ``cls`` is ``None``; options are ignored
          and a pass-through decorator is returned.

        Returns:
            The decorated class unchanged, or a decorator that does so.
        """
        if cls is None:
            return lambda wrapped: wrapped
        return cls

else:
    from dataclasses import dataclass  # type: ignore # noqa
138 changes: 119 additions & 19 deletions guardrails/utils/pydantic_utils/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
from copy import deepcopy
from datetime import date, time
from enum import Enum
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, validator
from pydantic.fields import ModelField
Expand All @@ -22,6 +32,7 @@
from guardrails.datatypes import Object as ObjectDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr

Expand Down Expand Up @@ -228,7 +239,78 @@ def process_validators(vals, fld):
return model_fields


def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
def create_bare_model(model: Type[BaseModel]):
    """Build a new ``BaseModel`` subclass carrying only ``model``'s annotations.

    Args:
        model: Class whose ``__annotations__`` mapping is reused as-is; an
            empty mapping is used when the attribute is missing.

    Returns:
        A freshly defined class named ``BareModel``.
    """
    field_annotations = getattr(model, "__annotations__", {})

    class BareModel(BaseModel):
        __annotations__ = field_annotations

    return BareModel


def reduce_to_annotations(type_annotation: Any) -> Type[Any]:
    """Swap a pydantic model class for its bare counterpart.

    Args:
        type_annotation: Any annotation object.

    Returns:
        ``create_bare_model(type_annotation)`` when the argument is a class
        deriving from ``BaseModel``; otherwise the argument itself,
        untouched (this includes falsy values and non-class objects).
    """
    # Anything falsy or not a class cannot be a model: hand it straight back.
    if not type_annotation or not isinstance(type_annotation, type):
        return type_annotation

    if issubclass(type_annotation, BaseModel):
        return create_bare_model(type_annotation)

    return type_annotation


def find_models_in_type(type_annotation: Any) -> Type[Any]:
    """Rebuild an annotation with nested pydantic models reduced to bare models.

    Walks ``Union``, ``List`` and ``Dict`` annotations recursively and passes
    every leaf through ``reduce_to_annotations``. Container results are
    wrapped in ``Type[...]``.

    Args:
        type_annotation: The annotation to inspect.

    Returns:
        The rebuilt annotation, or the result of ``reduce_to_annotations``
        for anything that is not a ``Union``/``List``/``Dict``.

    Raises:
        ValueError: If a list annotation carries more than one type argument.
    """
    type_origin = get_origin(type_annotation)
    inner_types = get_args(type_annotation)

    if type_origin is Union:
        members = tuple(find_models_in_type(member) for member in inner_types)
        return Type[Union[members]]  # type: ignore

    if type_origin is list:
        if len(inner_types) > 1:
            raise ValueError("List data type must have exactly one child.")
        # No List[List] support; we've already declared that in our types
        item_type = safe_get(inner_types, 0)
        return Type[List[find_models_in_type(item_type)]]

    if type_origin is dict:
        # First arg is the key, which must be primitive, so it is kept as-is.
        # It lives at index 0; reading index 1 here would reuse the value
        # type as the key type.
        key_type = safe_get(inner_types, 0)
        # Second arg is the value, which is potentially a model.
        value_type = safe_get(inner_types, 1)
        return Type[Dict[key_type, find_models_in_type(value_type)]]

    return reduce_to_annotations(type_annotation)


def schema_to_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
    """Rewrite a model's fields so they carry prompts instead of validators.

    For every entry in ``__fields__``:

    * if the field's ``field_info.extra`` holds a ``"validators"`` entry, it
      is removed and replaced by a ``"format"`` list containing
      ``to_prompt()`` of each validator that defines that method;
    * the field's annotation is replaced by
      ``find_models_in_type(annotation)``.

    Args:
        model: The pydantic model class to process.

    Returns:
        The object produced by ``deepcopy(model)`` after the edits above.
    """
    # NOTE(review): the stdlib ``copy`` docs state that deepcopy returns
    # classes unchanged, so ``bare`` is presumably the very same class as
    # ``model`` and the edits below would be visible to the caller -- confirm.
    bare = deepcopy(model)

    for name in bare.__fields__:
        model_field = bare.__fields__.get(name)
        if not model_field:
            continue

        extra = model_field.field_info.extra
        if "validators" in extra:
            validators = extra.pop("validators", [])
            extra["format"] = [
                v.to_prompt() for v in validators if hasattr(v, "to_prompt")
            ]
        model_field.field_info.extra = extra

        model_field.annotation = find_models_in_type(model_field.annotation)
        bare.__fields__[name] = model_field

    return bare


def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.

Args:
Expand All @@ -237,23 +319,41 @@ def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
Returns:
OpenAI function parameters.
"""

# Create a bare model with no extra fields
class BareModel(BaseModel):
__annotations__ = model.__annotations__

# Convert Pydantic model to JSON schema
json_schema = BareModel.schema()

# Create OpenAI function parameters
fn_params = {
"name": json_schema["title"],
"parameters": json_schema,
}
if "description" in json_schema and json_schema["description"] is not None:
fn_params["description"] = json_schema["description"]

return fn_params
return {}

# schema_model = model

# type_origin = get_origin(model)
# if type_origin == list:
# item_types = get_args(model)
# if len(item_types) > 1:
# raise ValueError("List data type must have exactly one child.")
# # No List[List] support; we've already declared that in our types
# schema_model = safe_get(item_types, 0)

# # Create a bare model with no extra fields
# bare_model = schema_to_bare_model(schema_model)

# # Convert Pydantic model to JSON schema
# json_schema = bare_model.schema()
# json_schema["title"] = schema_model.__name__

# if type_origin == list:
# json_schema = {
# "title": f"Array<{json_schema.get('title')}>",
# "type": "array",
# "items": json_schema,
# }

# # Create OpenAI function parameters
# fn_params = {
# "name": json_schema["title"],
# "parameters": json_schema,
# }
# if "description" in json_schema and json_schema["description"] is not None:
# fn_params["description"] = json_schema["description"]

# return fn_params


def field_to_datatype(field: Union[ModelField, Type]) -> Type[DataType]:
Expand Down
48 changes: 37 additions & 11 deletions guardrails/utils/pydantic_utils/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@
from copy import deepcopy
from datetime import date, time
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, get_args
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic.fields import FieldInfo
Expand All @@ -21,6 +33,7 @@
from guardrails.datatypes import Object as ObjectDataType
from guardrails.datatypes import String as StringDataType
from guardrails.datatypes import Time as TimeDataType
from guardrails.utils.safe_get import safe_get
from guardrails.validator_base import Validator
from guardrails.validatorsattr import ValidatorsAttr

Expand Down Expand Up @@ -114,14 +127,9 @@ def is_enum(type_annotation: Any) -> bool:
return False


def _create_bare_model(model: Type[BaseModel]) -> Type[BaseModel]:
    """Build a new ``BaseModel`` subclass carrying only ``model``'s annotations.

    Args:
        model: Class whose ``__annotations__`` mapping is reused as-is; an
            empty mapping is used when the attribute is missing.

    Returns:
        A freshly defined class named ``BareModel``.
    """
    field_annotations = getattr(model, "__annotations__", {})

    class BareModel(BaseModel):
        __annotations__ = field_annotations

    return BareModel


def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
def convert_pydantic_model_to_openai_fn(
model: Union[Type[BaseModel], Type[List[Type[BaseModel]]]]
) -> Dict:
"""Convert a Pydantic BaseModel to an OpenAI function.

Args:
Expand All @@ -131,10 +139,28 @@ def convert_pydantic_model_to_openai_fn(model: BaseModel) -> Dict:
OpenAI function parameters.
"""

bare_model = _create_bare_model(type(model))
schema_model = model

type_origin = get_origin(model)
if type_origin == list:
item_types = get_args(model)
if len(item_types) > 1:
raise ValueError("List data type must have exactly one child.")
# No List[List] support; we've already declared that in our types
schema_model = safe_get(item_types, 0)

schema_model = cast(Type[BaseModel], schema_model)

# Convert Pydantic model to JSON schema
json_schema = bare_model.model_json_schema()
json_schema = schema_model.model_json_schema()
json_schema["title"] = schema_model.__name__

if type_origin == list:
json_schema = {
"title": f"Array<{json_schema.get('title')}>",
"type": "array",
"items": json_schema,
}

# Create OpenAI function parameters
fn_params = {
Expand Down
4 changes: 3 additions & 1 deletion guardrails/validator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from guardrails.classes import InputType
from guardrails.constants import hub
from guardrails.errors import ValidationError
from guardrails.utils.dataclass import dataclass


class Filter:
Expand Down Expand Up @@ -231,10 +232,11 @@ class FailResult(ValidationResult):
fix_value: Optional[Any] = None


@dataclass # type: ignore
class Validator(Runnable):
"""Base class for validators."""

rail_alias: str
rail_alias: str = ""

run_in_separate_process = False
override_value_on_pass = False
Expand Down
Loading