From 07ca21a72a063ef720af84322afc002bd16c7b22 Mon Sep 17 00:00:00 2001
From: Jiangzhou He
Date: Tue, 25 Nov 2025 09:09:32 -0800
Subject: [PATCH] fix: skip None input arguments for batched custom functions

---
 python/cocoindex/op.py | 46 ++++++++++++++++++++++++++++++++++++++----
 1 file changed, 42 insertions(+), 4 deletions(-)

diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py
index dd888f26..c14f845a 100644
--- a/python/cocoindex/op.py
+++ b/python/cocoindex/op.py
@@ -369,10 +369,30 @@ async def prepare(self) -> None:
 
     async def __call__(self, *args: Any, **kwargs: Any) -> Any:
         decoded_args = []
-        for arg_info, arg in zip(self._args_info, args):
-            if arg_info.is_required and arg is None:
+        skipped_idx: list[int] | None = None
+        if op_args.batching:
+            if len(args) != 1:
+                raise ValueError(
+                    "Batching is only supported for single argument functions"
+                )
+            arg_info = self._args_info[0]
+            if arg_info.is_required and args[0] is None:
                 return None
-            decoded_args.append(arg_info.decoder(arg))
+            decoded = arg_info.decoder(args[0])
+            if arg_info.is_required:
+                skipped_idx = [i for i, arg in enumerate(decoded) if arg is None]
+                if len(skipped_idx) > 0:
+                    decoded = [v for v in decoded if v is not None]
+                    if len(decoded) == 0:
+                        return [None for _ in range(len(skipped_idx))]
+                else:
+                    skipped_idx = None
+            decoded_args.append(decoded)
+        else:
+            for arg_info, arg in zip(self._args_info, args):
+                if arg_info.is_required and arg is None:
+                    return None
+                decoded_args.append(arg_info.decoder(arg))
 
         decoded_kwargs = {}
         for kwarg_name, arg in kwargs.items():
@@ -387,7 +407,25 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
 
         assert self._acall is not None
         output = await self._acall(*decoded_args, **decoded_kwargs)
-        return self._result_encoder(output)
+
+        if skipped_idx is None:
+            return self._result_encoder(output)
+
+        padded_output: list[Any] = []
+        next_idx = 0
+        for v in output:
+            while next_idx < len(skipped_idx) and skipped_idx[next_idx] == len(
+                padded_output
+            ):
+                next_idx += 1
+                padded_output.append(None)
+            padded_output.append(v)
+
+        while next_idx < len(skipped_idx):
+            padded_output.append(None)
+            next_idx += 1
+
+        return self._result_encoder(padded_output)
 
     def enable_cache(self) -> bool:
         return op_args.cache
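
Illustrative sketch of the behavior the hunks above introduce: for a batched custom function with a required argument, None elements are dropped from the decoded batch before the call, and None placeholders are padded back into the output at the skipped positions. The snippet below mirrors that padding loop in isolation; the helper name pad_skipped and the sample data are hypothetical and not part of the CocoIndex API.

# Minimal sketch (hypothetical data): None inputs are skipped before the
# batched call and None results are re-inserted at the same positions after.
from typing import Any


def pad_skipped(output: list[Any], skipped_idx: list[int]) -> list[Any]:
    # Re-insert None at each skipped position; same loop shape as the patch.
    padded: list[Any] = []
    next_idx = 0
    for v in output:
        while next_idx < len(skipped_idx) and skipped_idx[next_idx] == len(padded):
            next_idx += 1
            padded.append(None)
        padded.append(v)
    while next_idx < len(skipped_idx):
        padded.append(None)
        next_idx += 1
    return padded


batch = ["a", None, "b", None]                           # decoded batched input
skipped = [i for i, v in enumerate(batch) if v is None]  # -> [1, 3]
kept = [v for v in batch if v is not None]               # -> ["a", "b"]
result = [v.upper() for v in kept]                       # stand-in for the batched call
assert pad_skipped(result, skipped) == ["A", None, "B", None]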