In [1]:
from pydantic import BaseModel

In [2]:
from pathlib import Path
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, ValidatorFunctionWrapHandler, model_validator


def get_subclasses_recursive[T](cls: type[T]) -> list[type[T]]:
    """Returns all the subclasses of a given class."""
    subclasses = []
    for subclass in cls.__subclasses__():
        subclasses.append(subclass)
        subclasses.extend(get_subclasses_recursive(subclass))
    return subclasses


def get_subclass_recursive[T](cls: type[T], name: str) -> type[T]:
    return next(c for c in get_subclasses_recursive(cls=cls) if c.__qualname__ == name)


class OBIBaseModel(BaseModel):
    """Sets `type` fields for model_dump which are then used for desserialization.

    Sets encoder for EntitySDK Entities
    """


    type: str = ""

    model_config = ConfigDict(json_encoders={Path: str})

    @model_validator(mode="before")
    @classmethod
    def set_type(cls, data: Any) -> dict[str, Any]:
        """Automatically sets `type` when instantiated in Python if a dictionary."""
        if isinstance(data, dict) and "type" not in data:
            data["type"] = cls.__qualname__
        return data

    def __init_subclass__(cls, **kwargs) -> None:
        """Dynamically set the `type` field to the class name."""
        super().__init_subclass__(**kwargs)
        cls.__annotations__["type"] = Literal[cls.__qualname__]

    def __str__(self) -> str:
        """Return a string representation of the OBIBaseModel object."""
        return self.__repr__()


    @model_validator(mode="wrap")
    @classmethod
    def retrieve_type_on_deserialization(cls, 
                                        value: Any, 
                                        handler: ValidatorFunctionWrapHandler
                ) -> "OBIBaseModel":

        if isinstance(value, dict):

            # `sub_cls(**modified_value)` will trigger a recursion, and thus we need to
            # remove `type` from the dictionary before passing it to the subclass constructor
            
            modified_value = value.copy()
            sub_cls_name = modified_value.pop("type", None)
            
            if sub_cls_name is not None:

                sub_cls = get_subclass_recursive(
                    cls=OBIBaseModel,
                    name=sub_cls_name,  # , allow_same_class=True
                )

                return sub_cls(**modified_value)

            return handler(value)

        return handler(value)


In [3]:
from pydantic import PrivateAttr

from obi_one.core.base import OBIBaseModel
from obi_one.core.param import MultiValueScanParam


class Block(OBIBaseModel):
    """Defines a component of a Form.

    Parameters can be of type | list[type]
    when a list is used it is used as a dimension in a multi-dimensional parameter scan.
    Tuples should be used when list-like parameter is needed.
    """

    _multiple_value_parameters: list[MultiValueScanParam] = PrivateAttr(default=[])

    def multiple_value_parameters(
        self, category_name: str, block_key: str = ""
    ) -> list[MultiValueScanParam]:
        """Return a list of MultiValueScanParam objects for the block."""
        self._multiple_value_parameters = []

        for key, value in self.__dict__.items():
            if isinstance(value, list):  # and len(value) > 1:
                multi_values = value
                if block_key:
                    self._multiple_value_parameters.append(
                        MultiValueScanParam(
                            location_list=[category_name, block_key, key], values=multi_values
                        )
                    )
                else:
                    self._multiple_value_parameters.append(
                        MultiValueScanParam(location_list=[category_name, key], values=multi_values)
                    )

        return self._multiple_value_parameters

    def enforce_no_lists(self) -> None:
        """Raise a TypeError if any attribute is a list."""
        for key, value in self.__dict__.items():
            if isinstance(value, list):
                msg = f"Attribute '{key}' must not be a list."
                raise TypeError(msg)


In [4]:
from typing import Union, ClassVar, get_args
from pydantic import Field, field_validator
import abc
class BlockReference(OBIBaseModel, abc.ABC):
    block_dict_name: str = Field(default="")
    block_name: str = Field(default="")

    allowed_block_types: ClassVar[Any] = None

    _block: Block = None

    def allowed_block_type_names(allowed_block_types) -> list[str]:
        """Returns a list of allowed block type names."""
        if allowed_block_types is None:
            raise ValueError("allowed_block_types must be set in the subclass.")
            
        return [block.__name__ for block in get_args(allowed_block_types)]
    
    @property
    def block(self) -> Block:
        """Returns the block associated with this reference."""
        if self._block is None:
            raise ValueError("Block has not been set.")
        return self._block

    @block.setter
    def block(self, value: Block) -> None:
        """Sets the block associated with this reference."""
        if not isinstance(value, self.allowed_block_types):
            raise TypeError(f"Value must be of type {self.block_type.__name__}.")
        self._block = value

In [5]:
from typing import ClassVar
class Form(OBIBaseModel):
    """A Form is a configuration for single or multi-dimensional parameter scans.

    A Form is composed of Blocks, which either appear at the root level
    or within dictionaries of Blocks where the dictionary is takes a Union of Block types.
    """

    name: ClassVar[str] = "Add a name class' name variable"
    description: ClassVar[str] = """Add a description to the class' description variable"""
    single_coord_class_name: ClassVar[str] = ""

    @model_validator(mode="after")
    def reference_blocks(self):

        for attr_name, attr_value in self.__dict__.items():
            # Check if the attribute is a dictionary of Block instances
            if isinstance(attr_value, dict) and all(
                isinstance(dict_val, Block) for dict_key, dict_val in attr_value.items()
            ):
                category_blocks_dict = attr_value

                # If so iterate through the dictionary's Block instances
                for block_key, block in category_blocks_dict.items():
                    for block_attr_name, block_attr_value in block.__dict__.items():
                        # If the Block instance has a `block` attribute, set it to the Form instance
                        if isinstance(block_attr_value, BlockReference):
                            block_reference = block_attr_value

                            block_reference.block = self.__dict__[block_reference.block_dict_name][block_reference.block_name]
        return self

    def cast_to_single_coord(self) -> OBIBaseModel:
        """Cast the form to a single coordinate object."""
        module = __import__(self.__module__)
        class_to_cast_to = getattr(module, self.single_coord_class_name)
        single_coord = class_to_cast_to.model_construct(**self.__dict__)
        single_coord.type = self.single_coord_class_name
        return single_coord

    @property
    def single_coord_scan_default_subpath(self) -> str:
        return self.single_coord_class_name + "/"

In [6]:


class PredefinedNeuronSet(Block):
    """Any predefined set of neurons."""
    param_a: str = ""

class IDNeuronSet(Block):
    """Any predefined set of neurons."""
    param_a: str = ""

NeuronSetUnion = (
    PredefinedNeuronSet |
    IDNeuronSet
)

# neuron_set_union_names = [cls.__name__ for cls in NeuronSetUnion]
# print(neuron_set_union_names)
# print(NeuronSetUnion.__name__)





    
class NeuronSetBlockReference(BlockReference):

    """"""
    
    allowed_block_types: ClassVar[Any] = NeuronSetUnion

    block_name: str = Field(default="", allowed_block_types=BlockReference.allowed_block_type_names(allowed_block_types))


    
    

class SpikeStimulus(Block):
    """A stimulus is a component of a Form that defines a stimulus to be applied."""

    neuron_set: NeuronSetBlockReference = None
    param_a: float = Field(default=0.0, description="Parameter A for the stimulus")

class RateStimulus(Block):
    """A stimulus is a component of a Form that defines a stimulus to be applied."""

    neuron_set: NeuronSetBlockReference = None
    param_a: float = Field(default=0.0, description="Parameter A for the stimulus")
    
StimulusUnion = (
    SpikeStimulus |
    RateStimulus
)




In [7]:
class SimulationsForm(Form):
    """A Form for defining simulations."""

    name: ClassVar[str] = "Simulations Form"
    description: ClassVar[str] = "A form for defining simulations."
    single_coord_class_name: ClassVar[str] = "Simulation"

    neuron_sets: dict[str, NeuronSetUnion] = {}
    stimuli: dict[str, StimulusUnion] = {}

# Show the JSON schema of SimulationsForm
import json
print(json.dumps(SimulationsForm.model_json_schema(), indent=2))
    

{
  "$defs": {
    "IDNeuronSet": {
      "description": "Any predefined set of neurons.",
      "properties": {
        "type": {
          "const": "IDNeuronSet",
          "title": "Type",
          "type": "string"
        },
        "param_a": {
          "default": "",
          "title": "Param A",
          "type": "string"
        }
      },
      "required": [
        "type"
      ],
      "title": "IDNeuronSet",
      "type": "object"
    },
    "NeuronSetBlockReference": {
      "properties": {
        "type": {
          "const": "NeuronSetBlockReference",
          "title": "Type",
          "type": "string"
        },
        "block_dict_name": {
          "default": "",
          "title": "Block Dict Name",
          "type": "string"
        },
        "block_name": {
          "allowed_block_types": [
            "PredefinedNeuronSet",
            "IDNeuronSet"
          ],
          "default": "",
          "title": "Block Name",
          "type": "string"
        }
  

In [8]:
simulations_form = SimulationsForm(
    neuron_sets={
        "example_predefined_neuron_set": PredefinedNeuronSet(param_a="j"),
        "example_id_neuron_set": IDNeuronSet(param_a="j")
    },
    stimuli={
        "example_spike_stimulus": SpikeStimulus(
            param_a=0.5,
            neuron_set=NeuronSetBlockReference(
                block_name="example_predefined_neuron_set",
                block_dict_name="neuron_sets",
            )
        ),
        "example_rate_stimulus": RateStimulus(
            parma_a=0.5,
            neuron_set=NeuronSetBlockReference(
                block_name="example_id_neuron_set",
                block_dict_name="neuron_sets",
            )
        )
    }
)



In [9]:
simulations_form.stimuli["example_spike_stimulus"].neuron_set.block

PredefinedNeuronSet(type='PredefinedNeuronSet', param_a='j')

In [10]:
import json
print(json.dumps(simulations_form.model_dump(), indent=2))


{
  "type": "SimulationsForm",
  "neuron_sets": {
    "example_predefined_neuron_set": {
      "type": "PredefinedNeuronSet",
      "param_a": "j"
    },
    "example_id_neuron_set": {
      "type": "IDNeuronSet",
      "param_a": "j"
    }
  },
  "stimuli": {
    "example_spike_stimulus": {
      "type": "SpikeStimulus",
      "neuron_set": {
        "type": "NeuronSetBlockReference",
        "block_dict_name": "neuron_sets",
        "block_name": "example_predefined_neuron_set"
      },
      "param_a": 0.5
    },
    "example_rate_stimulus": {
      "type": "RateStimulus",
      "neuron_set": {
        "type": "NeuronSetBlockReference",
        "block_dict_name": "neuron_sets",
        "block_name": "example_id_neuron_set"
      },
      "param_a": 0.0
    }
  }
}
