diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index 5567ea5ad..bb87e0538 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -7,6 +7,9 @@ class ProgramMeta(type): + """ + Metaclass for the Program class. + """ pass # def __call__(cls, *args, **kwargs): # obj = super(ProgramMeta, cls).__call__(*args, **kwargs) @@ -18,17 +21,43 @@ class ProgramMeta(type): class Module(BaseModule, metaclass=ProgramMeta): + """ + The Module class represents a module in the DSPy framework. + It provides methods for managing predictors and handling calls. + """ def _base_init(self): + """ + Initialize the base attributes of the Module instance. + """ self._compiled = False def __init__(self): + """ + Initialize a new instance of the Module class. + """ self._compiled = False def __call__(self, *args, **kwargs): + """ + Call the Module instance. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Any: The result of the forward method. + """ return self.forward(*args, **kwargs) def named_predictors(self): + """ + Get the named predictors of the Module instance. + + Returns: + list: A list of tuples, where each tuple contains the name of a predictor and the predictor itself. + """ from dspy.predict.predict import Predict named_parameters = self.named_parameters() @@ -39,9 +68,21 @@ def named_predictors(self): ] def predictors(self): + """ + Get the predictors of the Module instance. + + Returns: + list: A list of predictors. + """ return [param for _, param in self.named_predictors()] def __repr__(self): + """ + Get a string representation of the Module instance. + + Returns: + str: A string representation of the Module instance. + """ s = [] for name, param in self.named_predictors(): @@ -50,7 +91,15 @@ def __repr__(self): return "\n".join(s) def map_named_predictors(self, func): - """Applies a function to all named predictors.""" + """ + Apply a function to all named predictors of the Module instance. + + Args: + func (function): The function to apply. + + Returns: + Module: The Module instance itself. + """ for name, predictor in self.named_predictors(): set_attribute_by_name(self, name, func(predictor)) return self @@ -81,6 +130,14 @@ def set_attribute_by_name(obj, name, value): list_pattern = re.compile(r"^([^\[]+)\[([0-9]+)\]$") dict_pattern = re.compile(r"^([^\[]+)\['([^']+)'\]$") + # Match for module.attribute pattern + module_match = module_pattern.match(name) + if module_match: + # Regular expressions for different patterns + module_pattern = re.compile(r"^([^.]+)\.(.+)$") + list_pattern = re.compile(r"^([^\[]+)\[([0-9]+)\]$") + dict_pattern = re.compile(r"^([^\[]+)\['([^']+)'\]$") + # Match for module.attribute pattern module_match = module_pattern.match(name) if module_match: