diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index fce20eb7d2..a402e56af6 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -357,11 +357,16 @@ def get_base_model(model):
         else:
             return model
 
-    def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
+    def _rlhf_encode(self, inputs: TemplateInputs, check_rejected=True) -> Dict[str, Any]:
         chosen = inputs.chosen
         margin = chosen.margin
         chosen_encoded = self._encode_truncated(chosen)
-        rejected_encoded = self._encode_truncated(inputs.rejected)
+        if inputs.rejected is None:
+            if check_rejected:
+                raise ValueError('inputs.rejected is None')
+            rejected_encoded = {}
+        else:
+            rejected_encoded = self._encode_truncated(inputs.rejected)
 
         encoded = {}
         for prefix in ['chosen', 'rejected']:
@@ -373,7 +378,7 @@ def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
         return encoded
 
     def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
-        encoded = self._rlhf_encode(inputs)
+        encoded = self._rlhf_encode(inputs, check_rejected=False)
         encoded['label'] = bool(inputs.chosen.label)
         return encoded
 
@@ -1485,7 +1490,10 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
         kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')
 
         res = self._data_collator(new_batch, padding_to=padding_to)
-        kl_res = self._data_collator(kl_batch, padding_to=padding_to)
+        if any(kl_batch):
+            kl_res = self._data_collator(kl_batch, padding_to=padding_to)
+        else:
+            kl_res = {}
         res = {
             **{f'completion_{k}': v for k, v in res.items()},
             **{f'KL_completion_{k}': v for k, v in kl_res.items()},
diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py
index 966c11cb61..63da51013b 100644
--- a/swift/llm/train/kto.py
+++ b/swift/llm/train/kto.py
@@ -41,16 +41,17 @@ def _get_kl_dataset(dataset: Optional[HfDataset],
     return dataset
 
 
 def prepare_kto_dataset(args, train_dataset, val_dataset):
-    world_size = get_dist_setting()[2]
-    if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
-        total_batch_size = args.global_batch_size
-    else:
-        total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
-    if total_batch_size <= 1:
-        raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
-                         'will be equivalent to the implied reward.')
-    train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
-    val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
+    if args.loss_type != 'apo_zero_unpaired':
+        world_size = get_dist_setting()[2]
+        if hasattr(args, 'global_batch_size') and args.global_batch_size is not None:
+            total_batch_size = args.global_batch_size
+        else:
+            total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps)
+        if total_batch_size <= 1:
+            raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term '
+                             'will be equivalent to the implied reward.')
+        train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
+        val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed)
     label = train_dataset['label']
     num_desirable = max(sum(label), 1)
diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py
index f201767d3e..9ddb8ab343 100644
--- a/swift/megatron/trainers/kto_trainer.py
+++ b/swift/megatron/trainers/kto_trainer.py
@@ -76,7 +76,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label):
         loss = loss.mean()
         mean_metric = {
             'loss': loss.detach().clone(),
-            'kl': kl.detach(),
+            'kl': kl.squeeze().detach(),
         }
         metric = self._all_reduce_metric(mean_metric)
         sum_metric = {
@@ -159,7 +159,11 @@ def _prepare_batch(self, data, vp_stage):
         num_samples = data.pop('num_samples')
         for key in ['completion_', 'KL_completion_']:
             _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)}
-            res.append(super()._prepare_batch(_data, vp_stage, num_samples))
+            if not self.args.calculate_KL and key == 'KL_completion_':
+                _data = {}
+            else:
+                _data = super()._prepare_batch(_data, vp_stage, num_samples)
+            res.append(_data)
         res[0]['label'] = data['label']
         return res
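
Note: a minimal, self-contained sketch of the contract the new `check_rejected` flag is meant to enforce (hypothetical names, not the swift API). Paired RLHF encoding still fails fast when the rejected response is missing; the KTO path passes check_rejected=False, emits no rejected_* keys, and the collator's `if any(kl_batch)` guard then produces an empty KL batch instead of crashing.

    from typing import Any, Dict, Optional


    def rlhf_encode(chosen: str, rejected: Optional[str], check_rejected: bool = True) -> Dict[str, Any]:
        # Stand-in for the template's real tokenization step.
        encoded: Dict[str, Any] = {'chosen_input_ids': list(chosen.encode())}
        if rejected is None:
            if check_rejected:
                # Paired losses (e.g. DPO) cannot proceed without a rejected response.
                raise ValueError('inputs.rejected is None')
            # Unpaired mode: emit no rejected_* keys at all.
            return encoded
        encoded['rejected_input_ids'] = list(rejected.encode())
        return encoded


    print(rlhf_encode('good', 'bad'))                       # chosen_* and rejected_* keys
    print(rlhf_encode('good', None, check_rejected=False))  # chosen_* keys only
    try:
        rlhf_encode('good', None)  # paired mode still raises
    except ValueError as e:
        print(e)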
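
Note: an illustrative sketch (assumed behavior, not swift's `_get_kl_dataset`) of why prepare_kto_dataset requires a total batch size above 1 whenever the KL term is computed. The KL dataset pairs each prompt with another sample's completion from the same batch; with a single-element batch the "mismatched" completion is identical to the real one, so the KL estimate collapses into the implied reward, which is exactly what the ValueError guards against. Under loss_type='apo_zero_unpaired' no KL dataset is needed, hence the new gate.

    from typing import List, Tuple


    def make_kl_pairs(prompts: List[str], completions: List[str]) -> List[Tuple[str, str]]:
        # Rotate completions one position within the batch so each prompt
        # is paired with some other sample's completion.
        shifted = completions[-1:] + completions[:-1]
        return list(zip(prompts, shifted))


    print(make_kl_pairs(['p0', 'p1', 'p2'], ['c0', 'c1', 'c2']))
    # [('p0', 'c2'), ('p1', 'c0'), ('p2', 'c1')]
    print(make_kl_pairs(['p0'], ['c0']))
    # [('p0', 'c0')]  <- with batch size 1 the rotation is a no-op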