Skip to content

Commit

Permalink
feat: Updated dspy/primitives/box.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Dec 20, 2023
1 parent 215893d commit 1068536
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions dspy/primitives/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class BoxType(type):
'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rmod',
'rpow', 'rlshift', 'rrshift', 'rand', 'ror', 'rxor',
# Sequence operations
'getitem', 'setitem', 'delitem', 'contains',
# 'getitem', 'setitem', 'delitem', 'contains', # Handled separately to enable async operations
# Unary and other operations
'neg', 'pos', 'abs', 'invert', 'round', 'len',
'getitem', 'setitem', 'delitem', 'contains', 'iter',
Expand All @@ -111,13 +111,14 @@ class BoxType(type):

def __init__(cls, name, bases, attrs):
def create_method(op):
def method(self, other=None):
async def method(self, other=None):
if op in ['len', 'keys', 'values', 'items']:
return getattr(self._value, op)()
return await getattr(self._value, op)()
# 'getitem' and 'setitem' are special cases and must be handled outside of this loop
elif isinstance(other, Box):
return Box(getattr(self._value, f'__{op}__')(other._value))
return Box(await getattr(self._value, f'__{op}__')(other._value))
elif other is not None:
return Box(getattr(self._value, f'__{op}__')(other))
return Box(await getattr(self._value, f'__{op}__')(other))
else:
return NotImplemented
return method
Expand All @@ -127,6 +128,13 @@ def method(self, other=None):

super().__init__(name, bases, attrs)

async def __getitem__(self, index):
return await self._value.__getitem__(index)

async def __setitem__(self, index, value):
await self._value.__setitem__(index, value)


class Box(metaclass=BoxType):
def __init__(self, value, source=False):
self._value = value
Expand All @@ -142,8 +150,8 @@ def __bool__(self):
return bool(self._value)

# if method is missing just call it on the _value
def __getattr__(self, name):
return Box(getattr(self._value, name))
async def __getattr__(self, name):
return Box(await getattr(self._value, name))

# # Unlike the others, this one collapses to a bool directly
# def __eq__(self, other):
Expand Down

0 comments on commit 1068536

Please sign in to comment.