Skip to content

Commit

Permalink
__get_pydantic_core_schema__
Browse files (browse the repository at this point in the history)
  • Loading branch information
rok committed Jan 6, 2024
1 parent 2e3d2ac commit a83267c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 12 deletions.
69 changes: 59 additions & 10 deletions python/lancedb/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,21 @@
import pydantic
import semver
from pydantic.fields import FieldInfo
from lance.arrow import (
EncodedImageScalar,
ImageURIScalar,
ImageURIArray,
EncodedImageType,
EncodedImageArray,
)

from .embeddings import EmbeddingFunctionRegistry

PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
Expand Down Expand Up @@ -213,20 +222,42 @@ def value_arrow_type() -> pa.DataType:
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.str_schema(),
def validate_from_bytes(value: bytes) -> EncodedImageScalar:
return EncodedImageScalar(value)

from_bytes_schema = core_schema.chain_schema(
[
core_schema.bytes_schema(),
core_schema.no_info_plain_validator_function(validate_from_bytes),
]
)

return core_schema.json_or_python_schema(
json_schema=from_bytes_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(EncodedImageArray),
from_bytes_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)

@classmethod
def __get_pydantic_json_schema__(
    cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
    """Return the JSON Schema for this type, described as plain bytes.

    The core schema passed in by pydantic is ignored; ``handler`` is asked
    to render a fresh ``core_schema.bytes_schema()`` instead.
    """
    bytes_core_schema = core_schema.bytes_schema()
    return handler(bytes_core_schema)

@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
    """Yield the validator callables for this type (pydantic v1 hook).

    Exactly one validator is produced: the class's own ``validate``.
    """
    validator = cls.validate
    yield validator

# For pydantic v1
# For pydantic v2
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, EncodedImageType, EncodedImageArray

if isinstance(v, ImageURIArray):
v = v.read_uris()
if isinstance(v, pa.BinaryArray):
Expand Down Expand Up @@ -278,16 +309,34 @@ def value_arrow_type() -> pa.DataType:
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls,
core_schema.str_schema(),
def validate_from_str(value: str) -> ImageURIScalar:
return ImageURIScalar(value)

from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)

return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(ImageURIArray),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.values
),
)

@classmethod
def __get_validators__(cls) -> Generator[Callable, None, None]:
    """Yield the validator callables for this type (pydantic v1 hook).

    Exactly one validator is produced: the class's own ``validate``.
    """
    validator = cls.validate
    yield validator

# For pydantic v1
# For pydantic v2
@classmethod
def validate(cls, v):
from lance.arrow import ImageURIArray, ImageURIType
Expand Down
7 changes: 5 additions & 2 deletions python/tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from datetime import date, datetime
from typing import List, Optional, Tuple

import numpy as np
from pathlib import Path
import pyarrow as pa
import pydantic
import pytest
Expand Down Expand Up @@ -237,6 +237,9 @@ def test_lance_model_with_lance_types():
png_uris = [
"file://" + os.path.join(os.path.dirname(__file__), "images/1.png"),
]
if os.name == "nt":
png_uris = [str(Path(x)) for x in png_uris]

default_image_uris = ImageURIArray.from_uris(png_uris)
default_encoded_images = default_image_uris.read_uris()

Expand All @@ -256,6 +259,6 @@ class TestModel(LanceModel):
assert expected_model == actual_model

actual_model = TestModel(
encoded_images=default_image_uris, image_uris=default_image_uris
encoded_images=default_encoded_images, image_uris=default_image_uris
)
assert expected_model == actual_model

0 comments on commit a83267c

Please sign in to comment.