In [None]:
from typing import Any, List, Union, Dict
from pydantic import BaseModel, field_validator, ValidationError, PrivateAttr


class SingleTypeMixin:
    """Mixin to enforce only single float values for `param1` and `param2`."""

    # @field_validator("param1", "param2", mode="before")
    # @classmethod
    # def enforce_single_type(cls, value):
    #     if isinstance(value, list):
    #         raise ValueError("Lists are not allowed for this class.")
    #     if not isinstance(value, (float, int)):  # Allowing int to be coerced to float
    #         raise ValueError("Must be a float.")
    #     return float(value)   
    # 
    

def validate_nested_single(cls):
    """Ensures nested attributes (including inside dictionaries) only contain floats."""
    @field_validator("nested", mode="before")
    @classmethod
    def enforce_nested_single(cls, value):
        if isinstance(value, BaseModel):  # Validate Pydantic object
            for field, field_value in value.model_dump().items():
                if isinstance(field_value, list):
                    raise ValueError(f"Nested attribute '{field}' must not be a list.")
                # if not isinstance(field_value, (float, int)):
                #     raise ValueError(f"Nested attribute '{field}' must be a float.")

        elif isinstance(value, dict):  # Validate dictionary contents
            for key, dict_value in value.items():
                if isinstance(dict_value, BaseModel):  # Recursively validate BaseModel objects
                    for field, field_value in dict_value.model_dump().items():
                        if isinstance(field_value, list):
                            raise ValueError(f"Nested dictionary attribute '{key}.{field}' must not be a list.")
                        # if not isinstance(field_value, (float, int)):
                        #     raise ValueError(f"Nested dictionary attribute '{key}.{field}' must be a float.")

        return value

    cls.model_rebuild()  # Rebuild model after adding validator
    return cls





from itertools import product
class MultiTemplate(BaseModel):
    """
    """
    _multi_params: dict = PrivateAttr(default={})  # Private storage
    _sonata_config: dict = PrivateAttr(default={})
    
    _single_version_class_name: str = PrivateAttr(default="")
    
    
    @property
    def multi_params(self) -> dict:
        
        for attr_name, attr_value in self.__dict__.items():

            print("KEY:", attr_name, "VALUE:", attr_value)

            if isinstance(attr_value, dict) and all(isinstance(dict_val, MultiTemplate) for dict_key, dict_val in attr_value.items()):
                for dict_key, dict_val in attr_value.items():
                    for key, value in dict_val.__dict__.items():
                        if not isinstance(value, BaseModel) and isinstance(value, list) and len(value) > 1:
                            self._multi_params[f"{attr_name}.{dict_key}.{key}"] = {
                                "coord_param_keys": [attr_name, dict_key, key],
                                "coord_param_values": value
                            }
                            
        return self._multi_params


    def generate_grid_scan_coords(self) -> list:

        all_tuples = []
        for key, value in self.multi_params.items():
            tups = []
            for k, v in zip([value["coord_param_keys"] for i in range(len(value['coord_param_values']))], value['coord_param_values']):
                tups.append((k, v))

            all_tuples.append(tups)

        coords = [coord for coord in product(*all_tuples)]
        return coords
    
    def cast_to_single_instance(self):
        # NOT AN IDEAL WAY OF FINDING THE NEW CLASS (Is there an alternative?)
        class_to_cast_to = globals()[self._single_version_class_name] 

        # BOTH BELOW WORK, not sure which is better
        # single_instance = class_to_cast_to(**self.model_dump())
        single_instance = class_to_cast_to.model_validate(self.model_dump())
        return single_instance



import copy
class Campaign(BaseModel):

    template_instance: MultiTemplate = None

    _coord_instances: list = PrivateAttr(default=[])

    @property
    def coord_instances(self) -> list[MultiTemplate]:

        if len(self._coord_instances) > 0: return self._coord_instances

        for coord in self.template_instance.generate_grid_scan_coords():

            coord_template_instance = copy.deepcopy(self.template_instance)

            for param in coord:
                keys = param[0]
                val = param[1]

                current_level = coord_template_instance
                for i, key in enumerate(keys):

                    if isinstance(current_level, MultiTemplate):

                        if i == len(keys) - 1:
                            current_level.__dict__[key] = val
                        else:
                            current_level = current_level.__dict__[key]
                
                    elif isinstance(current_level, dict):
                        current_level = current_level[key]


            try:
                coord_instance = coord_template_instance.cast_to_single_instance()
                
            except ValidationError as e:
                print("Validation Error:", e)

            self._coord_instances.append(coord_instance)
        
        return self._coord_instances




    
class Stimulus(MultiTemplate):
    """A nested model containing numeric attributes."""
    nested_param1: Union[float, List[float]]
    nested_param2: Union[float, List[float]]

class SimulationCampaignTemplate(MultiTemplate):
    """Base simulation model that contains a generic nested object."""
    # param1: Union[float, List[float]]
    # param2: Union[float, List[float]]
    nested: dict[str, Stimulus]

    _single_version_class_name: str = PrivateAttr(default="Simulation")

@validate_nested_single
class Simulation(SimulationCampaignTemplate, SingleTypeMixin):
    """Only allows single float values and ensures nested attributes follow the same rule."""
    pass


# # Valid Simulation instance (SingleType) with direct nested object
# nested_single = Stimulus(nested_param1=1.5, nested_param2=2.5)  # ✅ Must be single floats
# print(nested_single)
# simulation = Simulation(param1=2.0, param2=3.0, nested={"config": nested_single})
# print(simulation)

# Valid SimulationCampaignTemplate instance (MultiType) with direct nested object
nested_multi = Stimulus(nested_param1=[1.5, 2.5], nested_param2=[3.5, 7.0])  # ✅ Allows lists
simulation_campaign_template = SimulationCampaignTemplate(nested={"config": nested_multi})
# print(simulation_campaign_template)

campaign = Campaign(template_instance=simulation_campaign_template)
print(campaign.coord_instances)

# # Valid Simulation with nested dictionary containing BaseModel
# simulation_with_dict = Simulation(
#     param1=2.0, 
#     param2=3.0, 
#     nested={"config": Stimulus(nested_param1=2.2, nested_param2=4.4)}
# )
# print(simulation_with_dict)

# # Valid SimulationCampaignTemplate with nested dictionary containing BaseModel
# simulation_campaign_template_with_dict = SimulationCampaignTemplate(
#     param1=[2.0, 4.0], 
#     param2=3.0, 
#     nested={"config": Stimulus(nested_param1=[1.1, 2.2], nested_param2=3.3)}
# )
# print(simulation_campaign_template_with_dict)

# print("\nNew Test\n")


# print(simulation_campaign_template.multi_params)
# print(simulation_campaign_template.generate_grid_scan_coords())
# print(simulation_campaign_template_with_dict.multi_params)

# coord_simulation = copy.deepcopy(simulation_campaign_template_with_dict)





KEY: nested VALUE: {'config': Stimulus(nested_param1=[1.5, 2.5], nested_param2=[3.5, 7.0])}
[Simulation(nested={'config': Stimulus(nested_param1=1.5, nested_param2=3.5)}), Simulation(nested={'config': Stimulus(nested_param1=1.5, nested_param2=7.0)}), Simulation(nested={'config': Stimulus(nested_param1=2.5, nested_param2=3.5)}), Simulation(nested={'config': Stimulus(nested_param1=2.5, nested_param2=7.0)})]


In [12]:
# ❌ Nested dictionary contains a BaseModel with a list (not allowed in Simulation)
try:
    invalid_simulation_dict = Simulation(
        param1=2.0, 
        param2=3.0, 
        nested={"config": Stimulus(nested_param1=[1.5, 2.5], nested_param2=3.5)}
    )
except ValueError as e:
    print(e)

# ❌ Nested dictionary contains a BaseModel with a string (not allowed in SimulationCampaignTemplate)
try:
    invalid_simulation_campaign_template_dict = SimulationCampaignTemplate(
        param1=[2.0, 4.0], 
        param2=3.0, 
        nested={"config": Stimulus(nested_param1="invalid", nested_param2=3.5)}
    )
except ValueError as e:
    print(e)

2 validation errors for Stimulus
nested_param1.float
  Input should be a valid number, unable to parse string as a number [type=float_parsing, input_value='invalid', input_type=str]
    For further information visit https://errors.pydantic.dev/2.10/v/float_parsing
nested_param1.list[float]
  Input should be a valid list [type=list_type, input_value='invalid', input_type=str]
    For further information visit https://errors.pydantic.dev/2.10/v/list_type


In [None]:
# class MultiTypeMixin:
#     """Mixin to allow both single float and list of floats for `param1` and `param2`."""

#     @field_validator("param1", "param2", mode="before")
#     @classmethod
#     def allow_list_type(cls, value):
#         if isinstance(value, (float, int)):
#             return float(value)  # Convert int to float if necessary
#         if isinstance(value, list) and all(isinstance(v, (float, int)) for v in value):
#             return [float(v) for v in value]  # Convert all ints to floats
#         raise ValueError("Must be a float or a list of floats.")

# def validate_nested_multi(cls):
#     """Ensures nested attributes (including inside dictionaries) allow floats and lists of floats."""
#     @field_validator("nested", mode="before")
#     @classmethod
#     def enforce_nested_multi(cls, value):
#         if isinstance(value, BaseModel):  # Validate Pydantic object
#             for field, field_value in value.model_dump().items():
#                 if not (isinstance(field_value, (float, int)) or
#                         (isinstance(field_value, list) and all(isinstance(v, (float, int)) for v in field_value))):
#                     raise ValueError(f"Nested attribute '{field}' must be a float or list of floats.")

#         elif isinstance(value, dict):  # Validate dictionary contents
#             for key, dict_value in value.items():
#                 if isinstance(dict_value, BaseModel):  # Recursively validate BaseModel objects
#                     for field, field_value in dict_value.model_dump().items():
#                         if not (isinstance(field_value, (float, int)) or
#                                 (isinstance(field_value, list) and all(isinstance(v, (float, int)) for v in field_value))):
#                             raise ValueError(f"Nested dictionary attribute '{key}.{field}' must be a float or list of floats.")

    #     return value

    # cls.model_rebuild()  # Rebuild model after adding validator
    # return cls

# @validate_nested_multi
# class SimulationCampaignTemplate(MultiTypeMixin, BaseTemplate):
#     """Allows both single float and list of floats, ensuring nested attributes follow the same rule."""
#     pass