In [1]:
from typing import Any

from pydantic_core import core_schema
from typing_extensions import Annotated

from pydantic import (
    BaseModel,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
    ValidationError,
)
from pydantic.json_schema import JsonSchemaValue


class ThirdPartyType:
    """
    This is meant to represent a type from a third-party library that wasn't designed with Pydantic
    integration in mind, and so doesn't have a `pydantic_core.CoreSchema` or anything.
    """

    x: int

    def __init__(self):
        self.x = 0


class _ThirdPartyTypePydanticAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """
        We return a pydantic_core.CoreSchema that behaves in the following ways:

        * ints will be parsed as `ThirdPartyType` instances with the int as the x attribute
        * `ThirdPartyType` instances will be parsed as `ThirdPartyType` instances without any changes
        * Nothing else will pass validation
        * Serialization will always return just an int
        """

        def validate_from_int(value: int) -> ThirdPartyType:
            result = ThirdPartyType()
            result.x = value
            return result
        print("HOWDY!")
        print(f"\t{_handler.field_name=}")
        from_int_schema = core_schema.chain_schema(
            [
                core_schema.int_schema(),
                core_schema.no_info_plain_validator_function(validate_from_int),
                #core_schema.int_schema()
            ]
        )

        return core_schema.json_or_python_schema(
            json_schema=from_int_schema,
            python_schema=core_schema.union_schema(
                [
                    # check if it's an instance first before doing any further work
                    core_schema.is_instance_schema(ThirdPartyType),
                    from_int_schema,
                ]
            ),
            # serialization=core_schema.plain_serializer_function_ser_schema(
            #     lambda instance: instance.x
            # ),
        )

    # @classmethod
    # def __get_pydantic_json_schema__(
    #     cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    # ) -> JsonSchemaValue:
    #     # Use the same schema that would be used for `int`
    #     return handler(core_schema.int_schema())


# We now create an `Annotated` wrapper that we'll use as the annotation for fields on `BaseModel`s, etc.
PydanticThirdPartyType = Annotated[
    ThirdPartyType, _ThirdPartyTypePydanticAnnotation
]


# Create a model class that uses this annotation as a field
class Model(BaseModel):
    third_party_type: PydanticThirdPartyType


# Demonstrate that this field is handled correctly, that ints are parsed into `ThirdPartyType`, and that
# these instances are also "dumped" directly into ints as expected.
m_int = Model(third_party_type=1)
assert isinstance(m_int.third_party_type, ThirdPartyType)
assert m_int.third_party_type.x == 1
#assert m_int.model_dump() == {'third_party_type': 1}

# Do the same thing where an instance of ThirdPartyType is passed in
instance = ThirdPartyType()
assert instance.x == 0
instance.x = 10

m_instance = Model(third_party_type=instance)
assert isinstance(m_instance.third_party_type, ThirdPartyType)
assert m_instance.third_party_type.x == 10
#assert m_instance.model_dump() == {'third_party_type': 10}

# Demonstrate that validation errors are raised as expected for invalid inputs
try:
    Model(third_party_type='a')
except ValidationError as e:
    print(e)
    """
    2 validation errors for Model
    third_party_type.is-instance[ThirdPartyType]
      Input should be an instance of ThirdPartyType [type=is_instance_of, input_value='a', input_type=str]
    third_party_type.chain[int,function-plain[validate_from_int()]]
      Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
    """


assert Model.model_json_schema() == {
    'properties': {
        'third_party_type': {'title': 'Third Party Type', 'type': 'integer'}
    },
    'required': ['third_party_type'],
    'title': 'Model',
    'type': 'object',
}


HOWDY!
	_handler.field_name='third_party_type'
2 validation errors for Model
third_party_type.is-instance[ThirdPartyType]
  Input should be an instance of ThirdPartyType [type=is_instance_of, input_value='a', input_type=str]
    For further information visit https://errors.pydantic.dev/2.6/v/is_instance_of
third_party_type.chain[int,function-plain[validate_from_int()]]
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
    For further information visit https://errors.pydantic.dev/2.6/v/int_parsing


In [2]:
# Create a model class that uses this annotation as a field
class MyModel(BaseModel):
    third_party_type: PydanticThirdPartyType
    tpt: PydanticThirdPartyType

HOWDY!
	_handler.field_name='third_party_type'
HOWDY!
	_handler.field_name='tpt'


In [3]:
mm = MyModel(third_party_type=1, tpt=8)

In [4]:
mm.model_json_schema()

{'properties': {'third_party_type': {'title': 'Third Party Type',
   'type': 'integer'},
  'tpt': {'title': 'Tpt', 'type': 'integer'}},
 'required': ['third_party_type', 'tpt'],
 'title': 'MyModel',
 'type': 'object'}

In [83]:
from astropy.units import Unit, IrreducibleUnit, UnitBase
class _UnitTypePydanticAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """
        We return a pydantic_core.CoreSchema that behaves in the following ways:

        * ints will be parsed as `ThirdPartyType` instances with the int as the x attribute
        * `ThirdPartyType` instances will be parsed as `ThirdPartyType` instances without any changes
        * Nothing else will pass validation
        * Serialization will always return just an int
        """

        def validate_from_str(value):
            result = source_type(value)  # Unit(value)
            return result
            
        print("HOWDY!")
        print(f"\t{_handler.field_name=}")
        from_str_schema = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(validate_from_str),
                #core_schema.int_schema()
            ]
        )
        from_float_schema = core_schema.chain_schema(
            [
                core_schema.float_schema(),
                core_schema.no_info_plain_validator_function(validate_from_str),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    # check if it's an instance first before doing any further work
                    core_schema.is_instance_schema(UnitBase),
                    core_schema.is_instance_schema(Quantity),
                    from_str_schema,
                    from_float_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: str(instance)
            ),
        )

    # @classmethod
    # def __get_pydantic_json_schema__(
    #     cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    # ) -> JsonSchemaValue:
    #     # Use the same schema that would be used for `int`
    #     return handler(core_schema.int_schema())

# We now create an `Annotated` wrapper that we'll use as the annotation for fields on `BaseModel`s, etc.
PydanticUnitType = Annotated[
    Unit, _UnitTypePydanticAnnotation
]

In [84]:
class UnitModel(BaseModel):
    length: PydanticUnitType
    time: PydanticUnitType

HOWDY!
	_handler.field_name='length'
HOWDY!
	_handler.field_name='time'


In [85]:
um = UnitModel(length="meter", time="second")

In [86]:
um.model_json_schema()

{'properties': {'length': {'title': 'Length', 'type': 'string'},
  'time': {'title': 'Time', 'type': 'string'}},
 'required': ['length', 'time'],
 'title': 'UnitModel',
 'type': 'object'}

In [87]:
um

UnitModel(length=Unit("m"), time=Unit("s"))

In [88]:
um.model_dump()

{'length': 'm', 'time': 's'}

In [89]:
print(um.model_dump_json(indent=2))

{
  "length": "m",
  "time": "s"
}


In [90]:
um2 = UnitModel(length=Unit("meter"), time=Unit("second"))

In [91]:
um2

UnitModel(length=Unit("m"), time=Unit("s"))

In [92]:
um2.model_json_schema()

{'properties': {'length': {'title': 'Length', 'type': 'string'},
  'time': {'title': 'Time', 'type': 'string'}},
 'required': ['length', 'time'],
 'title': 'UnitModel',
 'type': 'object'}

From https://docs.pydantic.dev/2.5/concepts/json_schema/#modifying-the-schema

In [93]:
um2.model_dump()

{'length': 'm', 'time': 's'}

In [94]:
um3 = UnitModel(length=5, time="year")

In [95]:
um3

UnitModel(length=Unit(dimensionless with a scale of 5.0), time=Unit("yr"))

In [96]:
from astropy.units import Quantity

In [97]:
QuantityType = Annotated[
    Quantity,
    _UnitTypePydanticAnnotation
]

In [98]:
class QuantityModel(BaseModel):
    one: QuantityType
    two: QuantityType

HOWDY!
	_handler.field_name='one'
HOWDY!
	_handler.field_name='two'


In [103]:
qmod = QuantityModel(one="5 meter", two=Quantity("2 second"))

In [104]:
qmod

QuantityModel(one=<Quantity 5. m>, two=<Quantity 2. s>)

In [105]:
qmod.model_json_schema()

{'properties': {'one': {'title': 'One', 'type': 'string'},
  'two': {'title': 'Two', 'type': 'string'}},
 'required': ['one', 'two'],
 'title': 'QuantityModel',
 'type': 'object'}

In [106]:
qmod.model_dump()

{'one': '5.0 m', 'two': '2.0 s'}

In [47]:
q = Quantity(5)
isinstance(q, UnitBase)

False

In [None]:
from typing import Any, Type

from pydantic_core import ValidationError, core_schema
from typing_extensions import Annotated

from pydantic import BaseModel, GetCoreSchemaHandler


class AllowAnySubclass:
    def __get_pydantic_core_schema__(
        self, source: Type[Any], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # we can't call handler since it will fail for arbitrary types
        def validate(value: Any) -> Any:
            if not isinstance(value, source):
                raise ValueError(
                    f'Expected an instance of {source}, got an instance of {type(value)}'
                )

        return core_schema.no_info_plain_validator_function(validate)


class Foo:
    pass

class Goo(Foo):
    pass
    
class Model(BaseModel):
    f: Annotated[Foo, AllowAnySubclass()]


print(Model(f=Foo()))
#> f=None


class NotFoo:
    pass


try:
    Model(f=NotFoo())
except ValidationError as e:
    print(e)
    """
    1 validation error for Model
    f
      Value error, Expected an instance of <class '__main__.Foo'>, got an instance of <class '__main__.NotFoo'> [type=value_error, input_value=<__main__.NotFoo object at 0x0123456789ab>, input_type=NotFoo]
    """


In [None]:
class FooModel(BaseModel):
    f: Annotated[Foo, AllowAnySubclass()]
    g: Annotated[Foo, AllowAnySubclass()]

In [None]:
mmm =FooModel(f=Foo(), g=Goo())

In [None]:
mmm.model_dump_json()

In [None]:
mmm.model_json_schema()

From https://docs.pydantic.dev/2.5/concepts/types/#adding-validation-and-serialization

In [None]:
from pydantic import (
    AfterValidator,
    PlainSerializer,
    TypeAdapter,
    WithJsonSchema,
)

TruncatedFloat = Annotated[
    float,
    AfterValidator(lambda x: round(x, 1)),
    PlainSerializer(lambda x: f'{x:.1e}', return_type=str),
    WithJsonSchema({'type': 'string'}, mode='serialization'),
]

In [None]:
class TFModel(BaseModel):
    one: TruncatedFloat
    two: TruncatedFloat

In [None]:
tf_mod = TFModel(one=2.34, two=3.14)

In [None]:
tf_mod.model_json_schema()