Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,30 @@ async def prepare(self) -> None:

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
decoded_args = []
for arg_info, arg in zip(self._args_info, args):
if arg_info.is_required and arg is None:
skipped_idx: list[int] | None = None
if op_args.batching:
if len(args) != 1:
raise ValueError(
"Batching is only supported for single argument functions"
)
arg_info = self._args_info[0]
if arg_info.is_required and args[0] is None:
return None
decoded_args.append(arg_info.decoder(arg))
decoded = arg_info.decoder(args[0])
if arg_info.is_required:
skipped_idx = [i for i, arg in enumerate(decoded) if arg is None]
if len(skipped_idx) > 0:
decoded = [v for v in decoded if v is not None]
if len(decoded) == 0:
return [None for _ in range(len(skipped_idx))]
else:
skipped_idx = None
decoded_args.append(decoded)
else:
for arg_info, arg in zip(self._args_info, args):
if arg_info.is_required and arg is None:
return None
decoded_args.append(arg_info.decoder(arg))

decoded_kwargs = {}
for kwarg_name, arg in kwargs.items():
Expand All @@ -387,7 +407,25 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:

assert self._acall is not None
output = await self._acall(*decoded_args, **decoded_kwargs)
return self._result_encoder(output)

if skipped_idx is None:
return self._result_encoder(output)

padded_output: list[Any] = []
next_idx = 0
for v in output:
while next_idx < len(skipped_idx) and skipped_idx[next_idx] == len(
padded_output
):
next_idx += 1
padded_output.append(None)
padded_output.append(v)

while next_idx < len(skipped_idx):
padded_output.append(None)
next_idx += 1

return self._result_encoder(padded_output)

def enable_cache(self) -> bool:
return op_args.cache
Expand Down
Loading