From 02cd5ea0c9edf4d04c2feec6ba001a25a21936ea Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Oct 2025 17:12:20 +0800
Subject: [PATCH 1/3] fix infer pt dp

---
 swift/llm/infer/infer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 9017aecf00..6e6ac8e1b4 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -247,7 +247,10 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
         while idx < len(val_dataset):
             shard_size = min(args.write_batch_size, len(val_dataset) - idx)
             shard_dataset = val_dataset.select(range(idx, idx + shard_size))
-            result_list += self._batch_infer(shard_dataset, request_config)
+            result = self._batch_infer(shard_dataset, request_config)
+            if self.jsonl_writer:
+                self.jsonl_writer.append(result, gather_obj=True)
+            result_list += result
             idx += shard_size
             prog_bar.update(shard_size)
         prog_bar.close()
@@ -267,6 +270,10 @@ def _batch_infer(self, val_dataset, request_config):
             data_parallel_size = args.global_world_size // args.vllm_tensor_parallel_size
         else:
             rank, data_parallel_size = args.rank, args.global_world_size
+        if len(val_dataset) < data_parallel_size:
+            data_parallel_size = len(val_dataset)
+            if rank >= len(val_dataset):
+                return []
         if rank >= 0 and data_parallel_size > 1:
             val_dataset = val_dataset.shard(data_parallel_size, rank, contiguous=True)
         val_dataset = list(val_dataset)
@@ -279,14 +286,13 @@ def _batch_infer(self, val_dataset, request_config):
             labels_list.append(labels)

         resp_list = self.infer(
             val_dataset, request_config, template=self.template, use_tqdm=True, **self.infer_kwargs)
-        if not (args.infer_backend == 'vllm' and rank >= 0 and args.rank % args.vllm_tensor_parallel_size != 0):
+        if not (args.infer_backend == 'vllm' and rank >= 0
+                and args.rank % args.vllm_tensor_parallel_size != 0):  # DP & TP
             for data, resp, labels in zip(val_dataset, resp_list, labels_list):
                 response = resp.choices[0].message.content
                 data['messages'].append({'role': 'assistant', 'content': response})
                 data = {'response': response, 'labels': labels, 'logprobs': resp.choices[0].logprobs, **data}
                 result_list.append(data)
-        if self.jsonl_writer:
-            self.jsonl_writer.append(result_list, gather_obj=True)

         return result_list

From 993851aeb2ca71efbbe262fad982ca44e8dcede4 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Oct 2025 17:20:33 +0800
Subject: [PATCH 2/3] update

---
 swift/llm/infer/infer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 6e6ac8e1b4..e7c5b281e8 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -270,7 +270,8 @@ def _batch_infer(self, val_dataset, request_config):
             data_parallel_size = args.global_world_size // args.vllm_tensor_parallel_size
         else:
             rank, data_parallel_size = args.rank, args.global_world_size
+        # The dataset is insufficient for DP partitioning
         if len(val_dataset) < data_parallel_size:
             data_parallel_size = len(val_dataset)
             if rank >= len(val_dataset):
                 return []

From 2fc7bb90ecb0c4005281cdb29772659f8ea67750 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 20 Oct 2025 17:22:06 +0800
Subject: [PATCH 3/3] fix

---
 swift/llm/infer/infer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 6e6ac8e1b4..02acaa05d5 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -271,9 +271,9 @@ def _batch_infer(self, val_dataset, request_config):
         else:
             rank, data_parallel_size = args.rank, args.global_world_size
         if len(val_dataset) < data_parallel_size:
-            data_parallel_size = len(val_dataset)
             if rank >= len(val_dataset):
                 return []
+            data_parallel_size = len(val_dataset)
         if rank >= 0 and data_parallel_size > 1:
             val_dataset = val_dataset.shard(data_parallel_size, rank, contiguous=True)
         val_dataset = list(val_dataset)
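
For reference, the guard these three patches converge on can be exercised in isolation. The sketch below is a standalone approximation, not the code from swift/llm/infer/infer.py: a plain Python list stands in for the HuggingFace datasets.Dataset, and the contiguous_shard helper is a hypothetical stand-in for Dataset.shard(num_shards, index, contiguous=True).

# Assumptions: a plain list replaces datasets.Dataset, and contiguous_shard
# imitates Dataset.shard(num_shards, index, contiguous=True).
from typing import Any, List


def contiguous_shard(dataset: List[Any], num_shards: int, index: int) -> List[Any]:
    # Contiguous split: the first (len % num_shards) shards get one extra row.
    div, mod = divmod(len(dataset), num_shards)
    start = index * div + min(index, mod)
    return dataset[start:start + div + (1 if index < mod else 0)]


def shard_for_rank(val_dataset: List[Any], rank: int, data_parallel_size: int) -> List[Any]:
    # The dataset is insufficient for DP partitioning: surplus ranks return
    # early with an empty shard; the remaining ranks shrink the DP group.
    if len(val_dataset) < data_parallel_size:
        if rank >= len(val_dataset):
            return []
        data_parallel_size = len(val_dataset)
    if rank >= 0 and data_parallel_size > 1:
        val_dataset = contiguous_shard(val_dataset, data_parallel_size, rank)
    return val_dataset


# 3 samples on 4 DP ranks: rank 3 now gets [] instead of asking for a
# 4-way shard of a 3-row dataset.
samples = [{'id': i} for i in range(3)]
assert shard_for_rank(samples, rank=3, data_parallel_size=4) == []
assert shard_for_rank(samples, rank=0, data_parallel_size=4) == [{'id': 0}]

Note that PATCH 3 only moves the early return ahead of the data_parallel_size shrink; both orderings behave the same here, since the early-return condition checks len(val_dataset) rather than data_parallel_size.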