154 changes: 57 additions & 97 deletions src/gradient/_compat.py
@@ -14,74 +14,69 @@

# --------------- Pydantic v2, v3 compatibility ---------------

# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false

PYDANTIC_V1 = pydantic.VERSION.startswith("1.")

if TYPE_CHECKING:

def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
...

def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
...

def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
...

def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001
...

def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001
...

def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001
...

def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001
...
def parse_date(value: date | StrBytesIntFloat) -> date: ...
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...
def get_args(t: type[Any]) -> tuple[Any, ...]: ...
def is_union(tp: type[Any] | None) -> bool: ...
def get_origin(t: type[Any]) -> type[Any] | None: ...
def is_literal_type(type_: type[Any]) -> bool: ...
def is_typeddict(type_: type[Any]) -> bool: ...

else:
# v1 re-exports
if PYDANTIC_V1:
# Pydantic v1 re-exports
from pydantic.typing import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
is_typeddict as is_typeddict,
is_literal_type as is_literal_type,
get_args,
is_union,
get_origin,
is_typeddict,
is_literal_type,
)
from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
from pydantic.datetime_parse import parse_date, parse_datetime
else:
from ._utils import (
get_args as get_args,
is_union as is_union,
get_origin as get_origin,
parse_date as parse_date,
is_typeddict as is_typeddict,
parse_datetime as parse_datetime,
is_literal_type as is_literal_type,
)


# Refactored: prefer the shared helpers from ._utils, with a guarded fallback so the
# compat layer still imports when ._utils is unavailable.
try:
from ._utils import (
get_args,
is_union,
get_origin,
parse_date,
is_typeddict,
parse_datetime,
is_literal_type,
)
except ModuleNotFoundError:
# fallback to pydantic builtins if _utils missing
from pydantic import TypeAdapter
def get_args(t: Any) -> tuple[Any, ...]: return getattr(t, "__args__", ())
def get_origin(t: Any) -> Any: return getattr(t, "__origin__", None)
def parse_date(value): return TypeAdapter(date).validate_python(value)
def parse_datetime(value): return TypeAdapter(datetime).validate_python(value)
# Accept either a bare origin (typing.Union) or a full annotation such as Union[int, str].
def is_union(tp): return tp is Union or getattr(tp, "__origin__", None) is Union
# Compare the origin against typing.Literal itself; a missing origin simply isn't a Literal.
def is_literal_type(t): return getattr(t, "__origin__", None) is Literal
# Heuristic: only TypedDict subclasses define both __required_keys__ and __optional_keys__.
def is_typeddict(t): return hasattr(t, "__required_keys__") and hasattr(t, "__optional_keys__")
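
# --- Illustrative sketch (not part of this diff) ---------------------------------
# A hypothetical, never-called self-check showing how the helpers above are expected
# to behave regardless of which branch supplied them. The local imports and the
# _compat_introspection_sketch name are inventions for illustration only.
def _compat_introspection_sketch() -> None:
    from datetime import date
    from typing import List, Union, Literal

    assert get_origin(List[int]) is list
    assert get_args(List[int]) == (int,)
    assert is_union(get_origin(Union[int, str]))
    assert is_literal_type(Literal["a", "b"])
    assert parse_date("2024-01-02") == date(2024, 1, 2)
# ----------------------------------------------------------------------------------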


# ---------------- ConfigDict handling ----------------
if TYPE_CHECKING:
from pydantic import ConfigDict as ConfigDict
from pydantic import ConfigDict
else:
if PYDANTIC_V1:
# TODO: provide an error message here?
ConfigDict = None
else:
from pydantic import ConfigDict as ConfigDict
from pydantic import ConfigDict


# ---------------- Core compatibility helpers ----------------

# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
if PYDANTIC_V1:
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
else:
return model.model_validate(value)
return cast(_ModelT, model.parse_obj(value))
return model.model_validate(value)


def field_is_required(field: FieldInfo) -> bool:
@@ -95,40 +90,27 @@ def field_get_default(field: FieldInfo) -> Any:
if PYDANTIC_V1:
return value
from pydantic_core import PydanticUndefined

if value == PydanticUndefined:
return None
return value
return None if value == PydanticUndefined else value
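
# --- Illustrative sketch (not part of this diff) ---------------------------------
# Hypothetical, never-called example of reading defaults through the compat helpers;
# _Example is an invented model, and get_model_fields is defined further down this file.
def _field_default_sketch() -> None:
    class _Example(pydantic.BaseModel):
        name: str = "anonymous"
        age: int  # required: no default

    fields = get_model_fields(_Example)
    assert field_get_default(fields["name"]) == "anonymous"
    assert field_get_default(fields["age"]) is None  # undefined defaults collapse to None
# ----------------------------------------------------------------------------------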


def field_outer_type(field: FieldInfo) -> Any:
if PYDANTIC_V1:
return field.outer_type_ # type: ignore
return field.annotation
return field.outer_type_ if PYDANTIC_V1 else field.annotation


def get_model_config(model: type[pydantic.BaseModel]) -> Any:
if PYDANTIC_V1:
return model.__config__ # type: ignore
return model.model_config
return model.__config__ if PYDANTIC_V1 else model.model_config


def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
if PYDANTIC_V1:
return model.__fields__ # type: ignore
return model.model_fields
return model.__fields__ if PYDANTIC_V1 else model.model_fields


def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
if PYDANTIC_V1:
return model.copy(deep=deep) # type: ignore
return model.model_copy(deep=deep)
return model.copy(deep=deep) if PYDANTIC_V1 else model.model_copy(deep=deep)


def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
if PYDANTIC_V1:
return model.json(indent=indent) # type: ignore
return model.model_dump_json(indent=indent)
return model.json(indent=indent) if PYDANTIC_V1 else model.model_dump_json(indent=indent)


def model_dump(
@@ -140,18 +122,18 @@ def model_dump(
warnings: bool = True,
mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
if (not PYDANTIC_V1) or hasattr(model, "model_dump"):
if hasattr(model, "model_dump"): # Pydantic v2+
return model.model_dump(
mode=mode,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
# warnings are not supported in Pydantic v1
warnings=True if PYDANTIC_V1 else warnings,
warnings=warnings,
)
# Pydantic v1 fallback
return cast(
"dict[str, Any]",
model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
model.dict(
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
@@ -160,60 +142,38 @@


def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
if PYDANTIC_V1:
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
return model.model_validate(data)
return model.parse_obj(data) if PYDANTIC_V1 else model.model_validate(data)
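
# --- Illustrative sketch (not part of this diff) ---------------------------------
# Hypothetical, never-called round trip through the renamed-method shims above;
# _Point is an invented model used only for illustration.
def _model_roundtrip_sketch() -> None:
    class _Point(pydantic.BaseModel):
        x: int
        y: int

    point = model_parse(_Point, {"x": 1, "y": 2})
    assert model_dump(point) == {"x": 1, "y": 2}
    assert isinstance(model_json(point, indent=2), str)
    copied = model_copy(point, deep=True)
    assert model_dump(copied) == model_dump(point)
# ----------------------------------------------------------------------------------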


# generic models
# ---------------- GenericModel compatibility ----------------
if TYPE_CHECKING:

class GenericModel(pydantic.BaseModel): ...

else:
if PYDANTIC_V1:
import pydantic.generics

class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
else:
# there no longer needs to be a distinction in v2 but
# we still have to create our own subclass to avoid
# inconsistent MRO ordering errors
class GenericModel(pydantic.BaseModel): ...
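
# --- Illustrative sketch (not part of this diff) ---------------------------------
# With the compat class above, downstream code can write one generic model definition
# for both Pydantic majors, e.g. (hypothetical names):
#
#     class Page(GenericModel, Generic[_T]):
#         items: List[_T]
#
# On v1 this routes through pydantic.generics.GenericModel; on v2 it is a plain
# BaseModel subclass kept only so the MRO stays consistent.
# ----------------------------------------------------------------------------------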


# cached properties
# ---------------- cached_property handling ----------------
if TYPE_CHECKING:
cached_property = property

# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.

class typed_cached_property(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None

def __init__(self, func: Callable[[Any], _T]) -> None: ...

@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()

def __set_name__(self, owner: type[Any], name: str) -> None: ...

# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None: ...
else:
from functools import cached_property as cached_property

typed_cached_property = cached_property
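
# --- Illustrative sketch (not part of this diff) ---------------------------------
# Hypothetical, never-called use of the cached_property shim: the value is computed
# once, and instances are settable, which is the behaviour typed_cached_property
# spells out for type checkers. _Client is an invented class.
def _cached_property_sketch() -> None:
    class _Client:
        def __init__(self) -> None:
            self.calls = 0

        @typed_cached_property
        def token(self) -> str:
            self.calls += 1
            return "computed-token"

    client = _Client()
    assert client.token == "computed-token"
    assert client.token == "computed-token"
    assert client.calls == 1      # computed only once, then cached on the instance
    client.token = "override"     # settable, unlike a plain @property
    assert client.token == "override"
# ----------------------------------------------------------------------------------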