diff --git a/src/assertical/fake/generator.py b/src/assertical/fake/generator.py index 0915bff..2f05ee1 100644 --- a/src/assertical/fake/generator.py +++ b/src/assertical/fake/generator.py @@ -11,7 +11,6 @@ Optional, TypeVar, Union, - cast, get_args, get_origin, get_type_hints, @@ -55,6 +54,12 @@ class CollectionType(IntEnum): OPTIONAL_LIST = auto() # For type T - represents list[Optional[T]] REQUIRED_SET = auto() # For type T - represents set[T] OPTIONAL_SET = auto() # For type T - represents set[Optional[T]] + REQUIRED_DICT = auto() + OPTIONAL_DICT = auto() + + +SUPPORTED_COLLECTION_TYPES = {list, dict, set} +TWO_PARAMETER_COLLECTION_TYPES: set[CollectionType] = {CollectionType.OPTIONAL_DICT, CollectionType.REQUIRED_DICT} @dataclass @@ -85,12 +90,22 @@ class PropertyGenerationDetails: # For example, a list[int] would have type_to_generate as int and this property as REQUIRED_LIST collection_type: Optional[CollectionType] + second_type_to_generate: Optional[type] = None + second_is_primitive_type: Optional[bool] = None + second_is_optional: Optional[bool] = None + @dataclass class _PlaceholderDataclassBase: """Dataclass has no base class - instead we fall back to using this as a placeholder""" +@dataclass +class _PlaceholderCollectionBase: + """lists, dicts and sets have no base class other than object + - instead we fall back to using this as a placeholder""" + + AnyType = TypeVar("AnyType") @@ -248,6 +263,9 @@ def get_generatable_class_base(t: type) -> Optional[type]: if optional_arg is not None: target_type = optional_arg + if get_origin(target_type) in SUPPORTED_COLLECTION_TYPES: + return _PlaceholderCollectionBase + if not inspect.isclass(target_type): return None @@ -305,7 +323,7 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails, if t_generatable_base is None: raise Exception(f"Type {t} does not inherit from one of {CLASS_INSTANCE_GENERATORS.keys()}") - type_hints = get_type_hints(t) + type_hints = TYPE_HINT_FETCHER[t_generatable_base](t) for member_name in CLASS_MEMBER_FETCHERS[t_generatable_base](t): @@ -320,8 +338,12 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails, collection_type: Optional[CollectionType] = None is_optional: bool = False is_primitive: bool = False + second_type_to_generate: Optional[type] = None + second_is_primitive: Optional[bool] = None + second_is_optional: Optional[bool] = None + if member_name in type_hints: - declared_type = cast(type, type_hints[member_name]) + declared_type = type_hints[member_name] member_type = remove_passthrough_type(declared_type) optional_arg_type = get_optional_type_argument(member_type) is_optional = optional_arg_type is not None @@ -332,13 +354,37 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails, collection_type = CollectionType.OPTIONAL_LIST elif get_origin(optional_arg_type) == set: collection_type = CollectionType.OPTIONAL_SET + elif get_origin(optional_arg_type) == dict: + collection_type = CollectionType.OPTIONAL_DICT else: if get_origin(member_type) == list: collection_type = CollectionType.REQUIRED_LIST elif get_origin(member_type) == set: collection_type = CollectionType.REQUIRED_SET + elif get_origin(member_type) == dict: + collection_type = CollectionType.REQUIRED_DICT if collection_type is not None: + # Determine second argument (if required) + if collection_type in TWO_PARAMETER_COLLECTION_TYPES: + second_member_type = get_args(optional_arg_type)[1] if is_optional else get_args(member_type)[1] + second_optional_arg_type = get_optional_type_argument(second_member_type) + second_is_optional = second_optional_arg_type is not None + if collection_type in (CollectionType.OPTIONAL_DICT, CollectionType.REQUIRED_DICT): + if is_generatable_type(second_member_type): + second_type_to_generate = get_first_generatable_primitive( + second_member_type, include_optional=False + ) + assert ( + second_type_to_generate is not None + ), f"Error generating member {member_name}. Couldn't find type for {second_member_type}" + second_is_primitive = True + elif get_generatable_class_base(second_member_type) is not None: + second_type_to_generate = ( + second_optional_arg_type if second_is_optional else second_member_type + ) + + # Determine first argument member_type = get_args(optional_arg_type)[0] if is_optional else get_args(member_type)[0] optional_arg_type = get_optional_type_argument(member_type) is_optional = optional_arg_type is not None @@ -373,6 +419,9 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails, is_primitive_type=is_primitive, is_optional=is_optional, collection_type=collection_type, + second_type_to_generate=second_type_to_generate, + second_is_primitive_type=second_is_primitive, + second_is_optional=second_is_optional, ) @@ -381,9 +430,10 @@ def generate_class_instance( # noqa: C901 seed: int = 1, optional_is_none: bool = False, generate_relationships: bool = False, + _return_seed: bool = False, _visited_type_stack: Optional[list[type]] = None, **kwargs: Any, -) -> AnyType: +) -> Union[AnyType, tuple[AnyType, int]]: """Given a child class of a key to CLASS_INSTANCE_GENERATORS - generate an instance of that class with all properties being assigned unique values based off of seed. The values will match type hints @@ -407,7 +457,8 @@ def generate_class_instance( # noqa: C901 if _visited_type_stack is None: _visited_type_stack = [] if t in _visited_type_stack: - return None # type: ignore # This only happens in recursion - the top level object will never be None + # This only happens in recursion - the top level object will never be None + return (None, seed) if _return_seed else None # type: ignore _visited_type_stack.append(t) # We can only generate class instances of classes that inherit from a known base @@ -430,17 +481,46 @@ def generate_class_instance( # noqa: C901 continue if member.type_to_generate is None: - raise Exception( - f"Type {t} has property {member.name} with type {member.declared_type} that cannot be generated" - ) + # Don't raise exception for ungeneratable types if their value is going to be None + if not (optional_is_none and member.is_optional): + raise Exception( + f"Type {t} has property {member.name} with type {member.declared_type} that cannot be generated" + ) generated_value: Any = None empty_collection: bool = False collection_type: Optional[CollectionType] = member.collection_type + def generate_member( + is_primitive_type: bool, type_to_generate: type, current_seed: int, empty_collection: bool + ) -> tuple[Any, int, bool]: + if is_primitive_type: + generated_value = generate_value(type_to_generate, seed=current_seed, optional_is_none=optional_is_none) + current_seed += 1 + else: + generated_value = None + if generate_relationships: + generated_value, current_seed = generate_class_instance( + type_to_generate, + seed=current_seed, + optional_is_none=optional_is_none, + generate_relationships=generate_relationships, + _visited_type_stack=_visited_type_stack, + _return_seed=True, + ) + + # None can be generated when Type A has child B that includes a backreference to A. in these + # circumstances the visited_types short circuit will just return None from generate_class_instance + # (to stop infinite recursion) The way we handle this is to just generate an empty list (if this is + # a list entity) + if generated_value is None: + empty_collection = True + + return generated_value, current_seed, empty_collection + if optional_is_none and ( - member.collection_type == CollectionType.OPTIONAL_LIST - or member.collection_type == CollectionType.OPTIONAL_SET + member.collection_type + in [CollectionType.OPTIONAL_LIST, CollectionType.OPTIONAL_SET, CollectionType.OPTIONAL_DICT] ): # We can short circuit some generation if we know the top level collection should be None # In this case - we just set everything to None @@ -452,38 +532,32 @@ def generate_class_instance( # noqa: C901 # that are None - so we just add a None to the parent collection (or just generate None) generated_value = None current_seed += 1 - elif member.is_primitive_type: - generated_value = generate_value( - member.type_to_generate, seed=current_seed, optional_is_none=optional_is_none - ) - current_seed += 1 else: - if generate_relationships: - generated_value = generate_class_instance( - member.type_to_generate, - seed=current_seed, - optional_is_none=optional_is_none, - generate_relationships=generate_relationships, - _visited_type_stack=_visited_type_stack, - ) - - # None can be generated when Type A has child B that includes a backreference to A. in these - # circumstances the visited_types short circuit will just return None from generate_class_instance - # (to stop infinite recursion) The way we handle this is to just generate an empty list (if this is - # a list entity) - if generated_value is None: - empty_collection = True - # collection_type = CollectionType.REQUIRED_LIST - else: - # In this case we have a complex type but we aren't generating relationships - throw in a placeholder - empty_collection = True - generated_value = None - current_seed += 1000 # Rather than calculating how many seed values were utilised - set it arbitrarily high + generated_value, current_seed, empty_collection = generate_member( + is_primitive_type=member.is_primitive_type, + type_to_generate=member.type_to_generate, # type: ignore + current_seed=current_seed, + empty_collection=empty_collection, + ) if collection_type == CollectionType.REQUIRED_LIST or collection_type == CollectionType.OPTIONAL_LIST: values[member.name] = [] if empty_collection else [generated_value] elif collection_type == CollectionType.REQUIRED_SET or collection_type == CollectionType.OPTIONAL_SET: values[member.name] = set([]) if empty_collection else set([generated_value]) + elif collection_type == CollectionType.REQUIRED_DICT or collection_type == CollectionType.OPTIONAL_DICT: + if optional_is_none and member.second_is_optional: + # In this case the parent collection is NOT able to be set to None but does support adding items + # that are None - so we just add a None to the parent collection (or just generate None) + second_generated_value = None + current_seed += 1 + else: + second_generated_value, current_seed, empty_collection = generate_member( + is_primitive_type=member.second_is_primitive_type, # type: ignore + type_to_generate=member.second_type_to_generate, # type: ignore + current_seed=current_seed, + empty_collection=empty_collection, + ) + values[member.name] = {} if empty_collection else {generated_value: second_generated_value} else: values[member.name] = generated_value @@ -492,7 +566,9 @@ def generate_class_instance( # noqa: C901 raise Exception(f"The following kwargs were unused {expected_kwargs_references.difference(kwargs_references)}") _visited_type_stack.pop() # When we finish generating a type, allow recursion back into that type - return CLASS_INSTANCE_GENERATORS[t_generatable_base](t, values) + + instance = CLASS_INSTANCE_GENERATORS[t_generatable_base](t, values) + return (instance, current_seed) if _return_seed else instance def clone_class_instance(obj: AnyType, ignored_properties: Optional[set[str]] = None) -> AnyType: @@ -627,12 +703,18 @@ def register_value_generator(t: type, generator: Callable[[int], Any]) -> None: BASE_CLASS_PUBLIC_MEMBERS: dict[type, set[str]] = {} DEFAULT_CLASS_INSTANCE_GENERATOR: Callable[[type, dict[str, Any]], Any] = lambda target, kwargs: target(**kwargs) DEFAULT_MEMBER_FETCHER: Callable[[type], list[str]] = lambda target: [name for (name, _) in inspect.getmembers(target)] +DEFAULT_PUBLIC_MEMBER_CHECKER: Callable[[str], bool] = is_member_public + +TYPE_HINT_FETCHER: dict[type, Callable[[type], dict[str, type]]] = {} +DEFAULT_TYPE_HINT_FETCHER: Callable[[type], dict[str, type]] = get_type_hints def register_base_type( base_type: type, instance_generator: Callable[[type, dict[str, Any]], Any], member_fetcher: Callable[[type], list[str]], + public_member_checker: Callable[[str], bool] = DEFAULT_PUBLIC_MEMBER_CHECKER, + type_hint_fetcher: Callable[[type], dict[str, type]] = DEFAULT_TYPE_HINT_FETCHER, ) -> None: """Registers a type that will allow all subclasses to be generated/cloned by functions in this module. @@ -646,7 +728,8 @@ def register_base_type( polluting the global registry""" CLASS_INSTANCE_GENERATORS[base_type] = instance_generator CLASS_MEMBER_FETCHERS[base_type] = member_fetcher - BASE_CLASS_PUBLIC_MEMBERS[base_type] = set([m for m in member_fetcher(base_type) if is_member_public(m)]) + BASE_CLASS_PUBLIC_MEMBERS[base_type] = set([m for m in member_fetcher(base_type) if public_member_checker(m)]) + TYPE_HINT_FETCHER[base_type] = type_hint_fetcher # Base type registration @@ -656,6 +739,16 @@ def register_base_type( lambda target: [f.name for f in fields(target) if f.init], ) +# Handling of collections +register_base_type( + _PlaceholderCollectionBase, + lambda target, kwargs: kwargs["self"], + lambda _: ["self"], + lambda _: False, # "base class" doesn't have any public members + lambda target: {"self": target}, +) + + if "pydantic_xml" in sys.modules: register_base_type( BaseXmlModel, diff --git a/tests/fake/test_generator_common.py b/tests/fake/test_generator_common.py index 5170bbc..7b47948 100644 --- a/tests/fake/test_generator_common.py +++ b/tests/fake/test_generator_common.py @@ -87,6 +87,8 @@ def test_generate_value(): generate_value(RandomOtherClass, 1) with pytest.raises(Exception): generate_value(list[int], 1) + with pytest.raises(Exception): + generate_value(dict[str, int], 1) assert generate_value(str, 1, True) == generate_value(str, 1, True) assert generate_value(str, 1, True) is not generate_value(str, 1, True) @@ -230,6 +232,7 @@ def test_is_passthrough_type(): assert not is_passthrough_type(Union[str, int]) assert not is_passthrough_type(str) assert not is_passthrough_type(list[int]) + assert not is_passthrough_type(dict[str, int]) def test_remove_passthrough_type(): @@ -299,6 +302,7 @@ def test_get_first_generatable_primitive(): assert get_first_generatable_primitive(list[str], include_optional=True) is None assert get_first_generatable_primitive(list[int], include_optional=True) is None assert get_first_generatable_primitive(Mapped[list[str]], include_optional=True) is None + assert get_first_generatable_primitive(dict[str, int], include_optional=True) is None # With include_optional disabled assert get_first_generatable_primitive(int, include_optional=False) == int @@ -318,6 +322,7 @@ def test_get_first_generatable_primitive(): assert get_first_generatable_primitive(list[str], include_optional=False) is None assert get_first_generatable_primitive(list[int], include_optional=False) is None assert get_first_generatable_primitive(Mapped[list[str]], include_optional=False) is None + assert get_first_generatable_primitive(dict[str, int], include_optional=False) is None def test_get_first_generatable_primitive_py310_optional(): diff --git a/tests/fake/test_generator_dataclass.py b/tests/fake/test_generator_dataclass.py index 6a586e2..2d37ef2 100644 --- a/tests/fake/test_generator_dataclass.py +++ b/tests/fake/test_generator_dataclass.py @@ -4,10 +4,11 @@ from dataclasses import dataclass, field from datetime import datetime, time from typing import Generator, Optional +from pathlib import Path import pytest -from assertical.asserts.type import assert_list_type, assert_set_type +from assertical.asserts.type import assert_dict_type, assert_list_type, assert_set_type from assertical.fake.generator import ( CollectionType, PropertyGenerationDetails, @@ -27,6 +28,7 @@ class ParentDataclass: myStr: str myList: list[int] myTime: time + myDict: dict[str, int] @dataclass(frozen=True) @@ -43,11 +45,17 @@ class OptionalCollectionsClass: optional_int_vals: list[Optional[int]] optional_int_list: Optional[list[int]] optional_optional_ints: Optional[list[Optional[int]]] + refs: set[ReferenceDataclass] optional_refs_vals: set[Optional[ReferenceDataclass]] optional_refs_list: Optional[set[ReferenceDataclass]] optional_optional_refs: Optional[set[Optional[ReferenceDataclass]]] + dict_ints: dict[str, int] + dict_optional_ints: dict[str, Optional[int]] + optional_dict: Optional[dict[str, int]] + optional_dict_optional_ints: Optional[dict[str, Optional[int]]] + @dataclass class InitRestrictionsDataclass: @@ -59,6 +67,88 @@ def __post_init__(self): self.myRestrictedInt2 = 2 +@dataclass +class CollectionsDataclass: + l1: list[int] + l2: list[Optional[int]] + l3: Optional[list[int]] + l4: Optional[list[Optional[int]]] + l5: list[ReferenceDataclass] + l6: list[list[ReferenceDataclass]] + + s1: set[int] + s2: set[Optional[int]] + s3: Optional[set[int]] + s4: Optional[set[Optional[int]]] + s5: set[ReferenceDataclass] + + d1: dict[str, int] + d2: dict[str, ReferenceDataclass] + d3: dict[str, list[int]] + d4: dict[int, dict[str, list[int]]] + d5: dict[str, dict[str, int]] + d6: dict[str, dict[str, ReferenceDataclass]] + + +def test_dataclass_with_ungeneratable_type(): + @dataclass + class DataclassWithUngeneratableType: + path: Optional[Path] # assertical can't generate Path values + + with pytest.raises(Exception): + generate_class_instance(DataclassWithUngeneratableType) + + # Check optional_is_none allows by-passing ungeneratable types + generate_class_instance(DataclassWithUngeneratableType, optional_is_none=True) + + +def test_collections_dataclass(): + _ = generate_class_instance(CollectionsDataclass, seed=1, generate_relationships=True) + + +@pytest.mark.parametrize( + "t,optional_is_none,generate_relationships,expected_value", + [ + (list[int], True, True, [1]), + (list[int], True, False, [1]), + (list[int], False, True, [1]), + (list[int], False, False, [1]), + (Optional[list[int]], True, True, None), + (Optional[list[int]], True, False, None), + (Optional[list[int]], False, True, [1]), + (Optional[list[int]], False, False, [1]), + (list[Optional[int]], True, True, [None]), + (list[Optional[int]], True, False, [None]), + (list[Optional[int]], False, True, [1]), + (list[Optional[int]], False, False, [1]), + (Optional[list[Optional[int]]], True, True, None), + (Optional[list[Optional[int]]], True, False, None), + (Optional[list[Optional[int]]], False, True, [1]), + (Optional[list[Optional[int]]], False, False, [1]), + (list[ReferenceDataclass], True, True, [ReferenceDataclass(myOptInt=None, myInt=2)]), + (list[ReferenceDataclass], True, False, []), + (list[ReferenceDataclass], False, True, [ReferenceDataclass(myOptInt=1, myInt=2)]), + (list[ReferenceDataclass], False, False, []), + (dict[str, int], True, True, {"1-str": 2}), + (dict[str, ReferenceDataclass], True, True, {"1-str": ReferenceDataclass(myOptInt=None, myInt=3)}), + (dict[str, list[int]], True, True, {"1-str": [2]}), + (dict[str, dict[str, int]], True, True, {"1-str": {"2-str": 3}}), + (dict[int, dict[str, list[int]]], True, True, {1: {"2-str": [3]}}), + ( + dict[str, dict[str, ReferenceDataclass]], + True, + True, + {"1-str": {"2-str": ReferenceDataclass(myOptInt=None, myInt=4)}}, + ), + ], +) +def test_collections(t: type, optional_is_none: bool, generate_relationships: bool, expected_value): + value = generate_class_instance( + t, seed=1, optional_is_none=optional_is_none, generate_relationships=generate_relationships + ) + assert value == expected_value + + def test_clone_class_instance_dataclass(): original = generate_class_instance(ReferenceDataclass, generate_relationships=True) clone = clone_class_instance(original) @@ -249,6 +339,50 @@ def test_generate_kwargs(): True, CollectionType.OPTIONAL_SET, ), + PropertyGenerationDetails( + "dict_ints", + dict[str, int], + str, + True, + False, + CollectionType.REQUIRED_DICT, + int, + True, + False, + ), + PropertyGenerationDetails( + "dict_optional_ints", + dict[str, Optional[int]], + str, + True, + False, + CollectionType.REQUIRED_DICT, + int, + True, + True, + ), + PropertyGenerationDetails( + "optional_dict", + Optional[dict[str, int]], + str, + True, + False, + CollectionType.OPTIONAL_DICT, + int, + True, + False, + ), + PropertyGenerationDetails( + "optional_dict_optional_ints", + Optional[dict[str, Optional[int]]], + str, + True, + False, + CollectionType.OPTIONAL_DICT, + int, + True, + True, + ), ], ), ], @@ -296,3 +430,13 @@ def test_generate_OptionalCollectionsClass_relationships(): assert len(optional.optional_refs_vals) == 1 and list(optional.optional_refs_vals)[0] is None assert optional.optional_refs_list is None assert optional.optional_optional_refs is None + + assert_dict_type(str, int, all_set.dict_ints, count=1) + assert_dict_type(str, int, all_set.dict_optional_ints, count=1) + assert_dict_type(str, int, all_set.optional_dict, count=1) + assert_dict_type(str, int, all_set.optional_dict_optional_ints, count=1) + + assert_dict_type(str, int, optional.dict_ints, count=1) + assert len(optional.dict_optional_ints) == 1 and list(optional.dict_optional_ints.values()) == [None] + assert optional.optional_dict is None + assert optional.optional_dict_optional_ints is None