From 78faee3090f7f9f8a8e324734be0f0372a520435 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 14 Nov 2025 15:17:39 +0800
Subject: [PATCH 1/5] fix kto apo_zero_unpaired

---
 swift/llm/template/base.py | 11 ++++++++---
 swift/llm/train/kto.py     | 21 +++++++++++----------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index fce20eb7d2..2f6f4bd7ec 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -357,11 +357,16 @@ def get_base_model(model):
         else:
             return model

-    def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+    def _rlhf_encode(self, inputs: TemplateInputs, check_rejected=True) -> Dict[str, Any]:
         chosen = inputs.chosen
         margin = chosen.margin
         chosen_encoded = self._encode_truncated(chosen)
-        rejected_encoded = self._encode_truncated(inputs.rejected)
+        if check_rejected and inputs.rejected is None:
+            raise ValueError('inputs.rejected is None')
+        if inputs.rejected is None:
+            rejected_encoded = {}
+        else:
+            rejected_encoded = self._encode_truncated(inputs.rejected)

         encoded = {}
         for prefix in ['chosen', 'rejected']:
@@ -373,7 +378,7 @@ def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         return encoded

     def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
-        encoded = self._rlhf_encode(inputs)
+        encoded = self._rlhf_encode(inputs, check_rejected=False)
         encoded['label'] = bool(inputs.chosen.label)
         return encoded

diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py
index 966c11cb61..b84db2c32a 100644
--- a/swift/llm/train/kto.py
+++ b/swift/llm/train/kto.py
@@ -41,16 +41,17 @@ def _get_kl_dataset(dataset: Optional[HfDataset],


 def prepare_kto_dataset(args, train_dataset, val_dataset):
-    world_size = get_dist_setting()[2]
-    if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
-        total_batch_size = args.global_batch_size
-    else:
-        total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
-    if total_batch_size <= 1:
-        raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
-                         'will be equivalent to the implied reward.')
-    train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
-    val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
+    if args.loss_type != 'apo_zero_unpaired':
+        world_size = get_dist_setting()[2]
+        if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
+            total_batch_size = args.global_batch_size
+        else:
+            total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
+        if total_batch_size <= 1:
+            raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
+                         'will be equivalent to the implied reward.')
+        train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
+        val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)

     label = train_dataset['label']
     num_desirable = max(sum(label), 1)
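
Note: the hunks in this first patch share one goal. With loss_type='apo_zero_unpaired', a KTO sample carries only a single completion plus a desirable/undesirable label, so there is no rejected response to encode and no KL dataset to build from mismatched completions. The standalone sketch below illustrates that control flow with toy helpers; the function names, the byte-level "encoding" and the rotation that stands in for _get_kl_dataset are illustrative assumptions, not the real ms-swift API.

    # Standalone sketch (not the real ms-swift API) of the unpaired KTO data flow.
    from typing import Any, Dict, List, Optional


    def rlhf_encode(chosen: str, rejected: Optional[str], check_rejected: bool = True) -> Dict[str, Any]:
        # Mirrors the new _rlhf_encode logic: rejected is only mandatory when check_rejected=True.
        if check_rejected and rejected is None:
            raise ValueError('inputs.rejected is None')
        if rejected is None:
            rejected_encoded: Dict[str, Any] = {}
        else:
            rejected_encoded = {'input_ids': list(rejected.encode())}
        encoded = {'chosen_input_ids': list(chosen.encode())}
        encoded.update({f'rejected_{k}': v for k, v in rejected_encoded.items()})
        return encoded


    def prepare_kto_dataset(loss_type: str, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Mirrors the new gating: only paired KTO builds a KL reference by mismatching
        # completions; apo_zero_unpaired keeps the rows as-is. The rotation below is a
        # toy stand-in for _get_kl_dataset.
        if loss_type != 'apo_zero_unpaired':
            rows = [{**row, 'kl_completion': rows[(i + 1) % len(rows)]['completion']}
                    for i, row in enumerate(rows)]
        return rows


    print(rlhf_encode('good answer', None, check_rejected=False))
    print(prepare_kto_dataset('apo_zero_unpaired', [{'completion': 'a'}, {'completion': 'b'}]))

In the real code path, _kto_encode now calls _rlhf_encode(inputs, check_rejected=False), so an unpaired sample simply yields an empty rejected encoding instead of raising.
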
From 4aac2de26afe2ec4d58837a105d429c03a8e66e8 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 14 Nov 2025 15:22:10 +0800
Subject: [PATCH 2/5] fix

---
 swift/llm/template/base.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 2f6f4bd7ec..e2129f8f7f 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1490,7 +1490,10 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
         kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')

         res = self._data_collator(new_batch, padding_to=padding_to)
-        kl_res = self._data_collator(kl_batch, padding_to=padding_to)
+        if any(kl_batch):
+            kl_res = self._data_collator(kl_batch, padding_to=padding_to)
+        else:
+            kl_res = {}

         res = {
             **{f'completion_{k}': v for k, v in res.items()},

From 8a8588c5a823bd5b94cd89348cb665628938e8b5 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 14 Nov 2025 16:00:52 +0800
Subject: [PATCH 3/5] update

---
 swift/llm/template/base.py             | 4 ++--
 swift/megatron/trainers/kto_trainer.py | 6 +++++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index e2129f8f7f..a402e56af6 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -361,9 +361,9 @@ def _rlhf_encode(self, inputs: TemplateInputs, check_rejected=True) -> Dict[str,
         chosen = inputs.chosen
         margin = chosen.margin
         chosen_encoded = self._encode_truncated(chosen)
-        if check_rejected and inputs.rejected is None:
-            raise ValueError('inputs.rejected is None')
         if inputs.rejected is None:
+            if check_rejected:
+                raise ValueError('inputs.rejected is None')
             rejected_encoded = {}
         else:
             rejected_encoded = self._encode_truncated(inputs.rejected)
diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py
index d0a385aa41..eaf0d8c2c1 100644
--- a/swift/megatron/trainers/kto_trainer.py
+++ b/swift/megatron/trainers/kto_trainer.py
@@ -155,7 +155,11 @@ def _prepare_batch(self, data, vp_stage):
         num_samples = data.pop('num_samples')
         for key in ['completion_', 'KL_completion_']:
             _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)}
-            res.append(super()._prepare_batch(_data, vp_stage, num_samples))
+            if not self.args.calculate_KL and key == 'KL_completion_':
+                _data = {}
+            else:
+                _data = super()._prepare_batch(_data, vp_stage, num_samples)
+            res.append(_data)
         res[0]['label'] = data['label']
         return res

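
Note: patches 2 and 3 handle the downstream effects of that empty rejected branch. The template collator only pads a KL batch when it actually contains features, and the Megatron trainer skips _prepare_batch for the KL_completion_ inputs when calculate_KL is disabled, leaving an empty dict in that slot. A minimal sketch of the collator-side guard follows; the toy collate function and the chosen_ key prefix are assumptions made for illustration (only the rejected_ prefix and the completion_/KL_completion_ output keys appear in the hunks above).

    # Toy sketch of the any(kl_batch) guard: unpaired rows contribute empty dicts,
    # so no KL_completion_* entries are emitted at all.
    from typing import Any, Dict, List


    def collate(features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Stand-in for _data_collator: just gathers each key across the batch.
        return {k: [f[k] for f in features] for k in features[0]}


    def kto_data_collator(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        new_batch = [{k[len('chosen_'):]: v for k, v in row.items() if k.startswith('chosen_')} for row in batch]
        kl_batch = [{k[len('rejected_'):]: v for k, v in row.items() if k.startswith('rejected_')} for row in batch]
        res = collate(new_batch)
        # any(kl_batch) is False when every per-sample dict is empty (unpaired data).
        kl_res = collate(kl_batch) if any(kl_batch) else {}
        return {
            **{f'completion_{k}': v for k, v in res.items()},
            **{f'KL_completion_{k}': v for k, v in kl_res.items()},
        }


    print(kto_data_collator([{'chosen_input_ids': [1, 2]}, {'chosen_input_ids': [3]}]))
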
From a41a4d3b10e3d2b2ce768ce43ff37793c862fb33 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 14 Nov 2025 16:08:46 +0800
Subject: [PATCH 4/5] lint pass

---
 swift/llm/train/kto.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py
index b84db2c32a..63da51013b 100644
--- a/swift/llm/train/kto.py
+++ b/swift/llm/train/kto.py
@@ -49,7 +49,7 @@ def prepare_kto_dataset(args, train_dataset, val_dataset):
             total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
         if total_batch_size <= 1:
             raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
-                         'will be equivalent to the implied reward.')
+                             'will be equivalent to the implied reward.')
         train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
         val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)


From b76e1b603f966c5e3a13cf09b728424151f91e93 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 14 Nov 2025 16:17:59 +0800
Subject: [PATCH 5/5] fix

---
 swift/megatron/trainers/kto_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py
index 7d3b228d88..9ddb8ab343 100644
--- a/swift/megatron/trainers/kto_trainer.py
+++ b/swift/megatron/trainers/kto_trainer.py
@@ -76,7 +76,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label):
         loss = loss.mean()
         mean_metric = {
             'loss': loss.detach().clone(),
-            'kl': kl.detach(),
+            'kl': kl.squeeze().detach(),
         }
         metric = self._all_reduce_metric(mean_metric)
         sum_metric = {
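
Note: the final patch only affects metric logging. If kl reaches loss_func with a singleton dimension, for example as a placeholder zero tensor when the KL term is skipped, logging it unsqueezed would mix scalar and (1,)-shaped entries in the reduced metrics. Squeezing first keeps every logged value a scalar. A tiny PyTorch illustration under that assumption (the (1,)-shaped kl is hypothetical, not taken from the trainer):

    import torch

    kl = torch.zeros(1)           # assumed: kl arrives with a singleton dimension
    loss = torch.tensor(0.7)

    mean_metric = {
        'loss': loss.detach().clone(),    # shape ()
        'kl': kl.squeeze().detach(),      # shape () instead of (1,)
    }
    print({k: tuple(v.shape) for k, v in mean_metric.items()})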