Skip to content

Commit

Permalink
feat: Updated dspy/primitives/program.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Dec 20, 2023
1 parent 62d69ad commit 8b7b1b2
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion dspy/primitives/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


class ProgramMeta(type):
"""
Metaclass for the Program class.
"""
pass
# def __call__(cls, *args, **kwargs):
# obj = super(ProgramMeta, cls).__call__(*args, **kwargs)
Expand All @@ -18,17 +21,43 @@ class ProgramMeta(type):


class Module(BaseModule, metaclass=ProgramMeta):
"""
The Module class represents a module in the DSPy framework.
It provides methods for managing predictors and handling calls.
"""

def _base_init(self):
"""
Initialize the base attributes of the Module instance.
"""
self._compiled = False

def __init__(self):
"""
Initialize a new instance of the Module class.
"""
self._compiled = False

def __call__(self, *args, **kwargs):
"""
Call the Module instance.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
Any: The result of the forward method.
"""
return self.forward(*args, **kwargs)

def named_predictors(self):
"""
Get the named predictors of the Module instance.
Returns:
list: A list of tuples, where each tuple contains the name of a predictor and the predictor itself.
"""
from dspy.predict.predict import Predict

named_parameters = self.named_parameters()
Expand All @@ -39,9 +68,21 @@ def named_predictors(self):
]

def predictors(self):
"""
Get the predictors of the Module instance.
Returns:
list: A list of predictors.
"""
return [param for _, param in self.named_predictors()]

def __repr__(self):
"""
Get a string representation of the Module instance.
Returns:
str: A string representation of the Module instance.
"""
s = []

for name, param in self.named_predictors():
Expand All @@ -50,7 +91,15 @@ def __repr__(self):
return "\n".join(s)

def map_named_predictors(self, func):
"""Applies a function to all named predictors."""
"""
Apply a function to all named predictors of the Module instance.
Args:
func (function): The function to apply.
Returns:
Module: The Module instance itself.
"""
for name, predictor in self.named_predictors():
set_attribute_by_name(self, name, func(predictor))
return self
Expand Down Expand Up @@ -81,6 +130,14 @@ def set_attribute_by_name(obj, name, value):
list_pattern = re.compile(r"^([^\[]+)\[([0-9]+)\]$")
dict_pattern = re.compile(r"^([^\[]+)\['([^']+)'\]$")

# Match for module.attribute pattern
module_match = module_pattern.match(name)
if module_match:
# Regular expressions for different patterns
module_pattern = re.compile(r"^([^.]+)\.(.+)$")
list_pattern = re.compile(r"^([^\[]+)\[([0-9]+)\]$")
dict_pattern = re.compile(r"^([^\[]+)\['([^']+)'\]$")

# Match for module.attribute pattern
module_match = module_pattern.match(name)
if module_match:
Expand Down

0 comments on commit 8b7b1b2

Please sign in to comment.