In [1]:
from __future__ import annotations

import logging

import rich
from rich import print as rprint
from rich.logging import RichHandler

# Send all log records through rich: messages only (rich adds its own time/level
# columns), rich-rendered tracebacks, and no file-path links in each record.
logging.basicConfig(
    handlers=[
        RichHandler(
            enable_link_path=False,
            rich_tracebacks=True,
        )
    ],
    datefmt="[%X]",
    format="%(message)s",
    level="INFO",
)

In [12]:
from abc import ABC
from collections.abc import Sequence
from typing import Annotated, Literal

from typing_extensions import TypeAliasType

import nshconfig as C


# This notebook is a simplified reproduction of a test in nshtrainer, which
# currently fails with pydantic>=2.11 but works with pydantic<2.11.
# Tracking issue: https://github.com/pydantic/pydantic/issues/11682
#
# Common base class for every plugin config (also mixes in ABC). Concrete
# plugins below (Plugin1, Plugin2) subclass it and register with `plugin_registry`.
class PluginBaseConfig(C.Config, ABC):
    # Intentionally empty: no shared fields. Documented with comments rather than
    # a docstring because pydantic may surface a class docstring as the JSON-schema
    # description (NOTE(review): confirm whether nshconfig does the same).
    pass


# Registry of PluginBaseConfig subclasses, told apart by their `name` field.
# Registry options: duplicate tags use the "warn-and-replace" policy, and
# `auto_rebuild` is off (RootConfig below opts into rebuilds explicitly via
# `rebuild_on_registers`).
plugin_registry = C.Registry(
    PluginBaseConfig,
    config=dict(
        duplicate_tag_policy="warn-and-replace",
        auto_rebuild=False,
    ),
    discriminator="name",
)

# Named type alias for "any registered plugin config". The Annotated metadata
# comes from the registry; presumably it resolves against whatever plugins are
# registered when the schema is built (TODO confirm against nshconfig docs).
PluginConfig = TypeAliasType(
    "PluginConfig",
    Annotated[PluginBaseConfig, plugin_registry.DynamicResolution()],
)

# Inspect the registry (nothing registered yet) and the schema it produces now.
rprint(plugin_registry)
rprint(plugin_registry.pydantic_schema())


# Top-level config holding an optional sequence of plugin configs.
# `rebuild_on_registers` installs a callback on the registry (visible as
# `_on_register_callbacks` in the registry repr) so RootConfig is rebuilt each
# time a new plugin class is registered.
@plugin_registry.rebuild_on_registers
class RootConfig(C.Config):
    # Entries are typed via the registry-backed `PluginConfig` alias; presumably
    # only registered plugin classes validate here (TODO confirm).
    plugins: Sequence[PluginConfig] | None = None


# Smoke test: an empty plugin list must validate before any plugin is registered.
rprint(RootConfig(plugins=[]))
# The registry repr now also lists the rebuild callback added by the decorator.
rprint(plugin_registry)


In [13]:
# First concrete plugin. Registering it adds an entry tagged "plugin1" to the
# registry (the registry discriminates on the `name` field).
@plugin_registry.register
class Plugin1(PluginBaseConfig):
    name: Literal["plugin1"] = "plugin1"  # discriminator tag
    value: int = 42  # simple payload field with a default


# Inspect the registry and its schema now that Plugin1 is registered.
# Use rich's `rprint` like every other cell so the nested reprs (registry entries,
# callbacks, configs) are pretty-printed instead of dumped on one line.
rprint(plugin_registry)
rprint(plugin_registry.pydantic_schema())

# RootConfig is hooked to rebuild on registration (`rebuild_on_registers`), so it
# should now accept both an empty list and a Plugin1 instance.
rprint(RootConfig(plugins=[]))
rprint(RootConfig(plugins=[Plugin1()]))

Registry(base_cls=<class '__main__.PluginBaseConfig'>, discriminator='name', config={'duplicate_tag_policy': 'warn-and-replace', 'auto_rebuild': False}, _elements=[_RegistryEntry(tag='plugin1', cls=<class '__main__.Plugin1'>)], _on_register_callbacks=[<function Registry.rebuild_on_registers.<locals>._rebuild at 0x7e07a6f92d40>])


plugins=[]
plugins=[Plugin1(name='plugin1', value=42)]


In [14]:
# Regenerate RootConfig's schema and show it. `force=True` is required: without
# it, `model_rebuild()` returns early when the model is already complete, and the
# schema printed below would just be the cached one.
RootConfig.model_rebuild(force=True)
rprint(RootConfig.__pydantic_core_schema__)

In [None]:
from typing import Any, TypeAlias

# Plain `TypeAlias` (not TypeAliasType) over the same registry-backed annotation,
# used to give Plugin2 a field that refers back to the plugin registry.
A: TypeAlias = Annotated[PluginBaseConfig, plugin_registry.DynamicResolution()]

# This cell reproduces the failure under investigation: its saved output is
# `RecursionError: maximum recursion depth exceeded` (presumably the pydantic>=2.11
# issue tracked at https://github.com/pydantic/pydantic/issues/11682). The error
# is the point of the cell, so catch it and log the full traceback (rich tracebacks
# are enabled in the first cell) instead of aborting "Restart & Run All".
try:
    @plugin_registry.register
    class Plugin2(PluginBaseConfig):
        name: Literal["plugin2"] = "plugin2"

        nested_plugin: A | None = None

    rprint(plugin_registry)
    rprint(plugin_registry.pydantic_schema())

    rprint(RootConfig(plugins=[]))
    rprint(RootConfig(plugins=[Plugin1()]))
    rprint(RootConfig(plugins=[Plugin2()]))
except RecursionError:
    logging.exception(
        "RecursionError while registering/validating Plugin2 "
        "(see pydantic issue #11682)"
    )

RecursionError: maximum recursion depth exceeded

In [5]:
# Core schema cached on the PluginBaseConfig class itself; compare it with the
# TypeAdapter-generated schema in the next cell.
PluginBaseConfig.__pydantic_core_schema__

{'type': 'model',
 'cls': __main__.PluginBaseConfig,
 'schema': {'type': 'model-fields',
  'fields': {},
  'model_name': 'PluginBaseConfig',
  'computed_fields': []},
 'custom_init': False,
 'root_model': False,
 'post_init': 'model_post_init',
 'config': {'title': 'PluginBaseConfig',
  'extra_fields_behavior': 'ignore',
  'strict': True,
  'revalidate_instances': 'always',
  'validate_default': True,
  'validation_error_cause': True},
 'serialization': {'type': 'function-wrap',
  'function': <function nshconfig._src.config.Config.include_literals(self, next_serializer)>,
  'info_arg': False},
 'ref': '__main__.PluginBaseConfig:944843392',
 'metadata': {'pydantic_js_functions': [<bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.PluginBaseConfig'>>]}}

In [6]:
# Core schema for the same class built through a TypeAdapter; in the saved
# outputs it is identical to the class-level `__pydantic_core_schema__` above
# (same `ref`, config and serializer).
C.TypeAdapter(PluginBaseConfig).core_schema

{'type': 'model',
 'cls': __main__.PluginBaseConfig,
 'schema': {'type': 'model-fields',
  'fields': {},
  'model_name': 'PluginBaseConfig',
  'computed_fields': []},
 'custom_init': False,
 'root_model': False,
 'post_init': 'model_post_init',
 'config': {'title': 'PluginBaseConfig',
  'extra_fields_behavior': 'ignore',
  'strict': True,
  'revalidate_instances': 'always',
  'validate_default': True,
  'validation_error_cause': True},
 'serialization': {'type': 'function-wrap',
  'function': <function nshconfig._src.config.Config.include_literals(self, next_serializer)>,
  'info_arg': False},
 'ref': '__main__.PluginBaseConfig:944843392',
 'metadata': {'pydantic_js_functions': [<bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.PluginBaseConfig'>>]}}

In [7]:
# Rebuild PluginBaseConfig's core schema by hand with pydantic's schema generator,
# then compare it with the schema cached on the class.
# NOTE: `pydantic._internal` is private API and may change between releases
# (the issue under investigation is itself version-dependent).
from pprint import pprint  # used below; previously missing -> NameError on a fresh kernel

from pydantic._internal import (
    _config,
    _generate_schema,
    _mock_val_ser,
    _namespace_utils,
    _repr,
    _typing_extra,
    _utils,
)

# Empty namespaces for the resolver; presumably sufficient here because
# PluginBaseConfig declares no fields to resolve (TODO confirm for other models).
localns = {}
globalns = {}
ns_resolver = _namespace_utils.NsResolver(
    namespaces_tuple=_namespace_utils.NamespacesTuple(locals=localns, globals=globalns),
    parent_namespace=localns,
)
config_wrapper = _config.ConfigWrapper({})
schema_generator = _generate_schema.GenerateSchema(
    config_wrapper, ns_resolver=ns_resolver
)
# The raw result is only a `definition-ref` to the model (see saved output).
core_schema = schema_generator.generate_schema(PluginBaseConfig)
pprint(core_schema)

# After `clean_schema` the ref is replaced by the full model schema, which can be
# compared field by field with the class's cached `__pydantic_core_schema__`.
core_schema = schema_generator.clean_schema(core_schema)
pprint(core_schema)
pprint(PluginBaseConfig.__pydantic_core_schema__)

{'schema_ref': '__main__.PluginBaseConfig:944843392', 'type': 'definition-ref'}
{'cls': <class '__main__.PluginBaseConfig'>,
 'config': {'extra_fields_behavior': 'ignore',
            'revalidate_instances': 'always',
            'strict': True,
            'title': 'PluginBaseConfig',
            'validate_default': True,
            'validation_error_cause': True},
 'custom_init': False,
 'metadata': {'pydantic_js_functions': [<bound method BaseModel.__get_pydantic_json_schema__ of <class '__main__.PluginBaseConfig'>>]},
 'post_init': 'model_post_init',
 'ref': '__main__.PluginBaseConfig:944843392',
 'root_model': False,
 'schema': {'computed_fields': [],
            'fields': {},
            'model_name': 'PluginBaseConfig',
            'type': 'model-fields'},
 'serialization': {'function': <function Config.include_literals at 0x7eccac0653a0>,
                   'info_arg': False,
                   'type': 'function-wrap'},
 'type': 'model'}
{'cls': <class '__main__.PluginBaseConf