From 5d28065f33b5da47c3128be9d5e8d559fd1a1c56 Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Thu, 4 Sep 2025 20:00:27 +0800 Subject: [PATCH 01/13] update --- swift/llm/dataset/dataset/llm.py | 9 +++--- swift/llm/template/base.py | 43 ++++++++++------------------- swift/llm/template/template/qwen.py | 20 ++++++++++---- 3 files changed, 33 insertions(+), 39 deletions(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 58caaaa814..0f7e852f95 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -385,12 +385,11 @@ def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']] negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']] - expanded_rows = [] - for positive in positives: - expanded_row = {'query': query, 'response': positive, 'rejected_response': negatives} - expanded_rows.append(super().preprocess(expanded_row)) + messages = [{"role": "user", "content": query}] + positive_messages = [[{"role": "user", "content": positive}] for positive in positives] + negative_messages = [[{"role": "user", "content": negative}] for negative in negatives] - return expanded_rows + return {"messages": messages, "positive_messages": positive_messages, "negative_messages": negative_messages} register_dataset( diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index c8a86cb3ad..87ba57806f 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -441,39 +441,26 @@ def split_multi_medias(_inputs): return _encoded def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: - inputs = inputs.chosen # TODO: refactor - self._preprocess_inputs_reranker(inputs) - _encoded = {} - labels = [] + from collections import defaultdict + inputs = self._preprocess_inputs_reranker(inputs) + chosen = inputs.chosen + + _encoded = 
defaultdict(list) - positive = deepcopy(inputs) - positive.rejected_response = [] - if '{doc}' in positive.messages[-2]['content']: - positive.messages[-2]['content'] = positive.messages[-2]['content'].replace( - '{doc}', inputs.messages[-1]['content']) - positive.messages.pop(-1) - positive_encoded = self._encode_truncated(positive) - for key in positive_encoded: - _encoded[f'positive_{key}'] = positive_encoded[key] - _encoded[f'negative_{key}'] = [] - labels.append(1) - - rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0 - for i in range(rejected_len): - negative = deepcopy(inputs) - if '{doc}' in negative.messages[-2]['content']: - negative.messages[-2]['content'] = negative.messages[-2]['content'].replace( - '{doc}', negative.rejected_response[i]) - negative.messages.pop(-1) - else: - negative.messages[-1]['content'] = negative.rejected_response[i] - negative.rejected_response = [] + for positive in inputs.positive: + positive.messages = chosen.messages + positive.messages + positive_encoded = self._encode_truncated(positive) + for key in positive_encoded: + _encoded[f'positive_{key}'].append(positive_encoded[key]) + _encoded['labels'].append(1) + + for negative in inputs.negative: + negative.messages = chosen.messages + negative.messages negative_encoded = self._encode_truncated(negative) for key in negative_encoded: _encoded[f'negative_{key}'].append(negative_encoded[key]) - labels.append(0) + _encoded['labels'].append(0) - _encoded['labels'] = labels return _encoded def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index f6a935f168..b148ea6dfa 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -16,7 +16,7 @@ from ..base import Template from ..constant import LLMTemplateType, MLLMTemplateType from ..register import register_template -from ..template_inputs import StdTemplateInputs 
+from ..template_inputs import StdTemplateInputs, TemplateInputs from ..template_meta import TemplateMeta from ..utils import Context, Word, findall from ..vision_utils import load_audio, load_batch, load_video_ovis2, load_video_ovis2_5 @@ -71,12 +71,20 @@ class Qwen3Template(ThinkingTemplate): class Qwen3RerankerTemplate(Template): instruction = 'Given a web search query, retrieve relevant passages that answer the query' - def _preprocess_inputs_reranker(self, inputs: StdTemplateInputs) -> None: + def _preprocess_inputs_reranker(self, inputs: TemplateInputs) -> None: super()._preprocess_inputs_reranker(inputs) - query = inputs.messages[-2]['content'] - user_message = ': ' + self.instruction + '\n' + ': ' + query + '\n' + ': {doc}' - inputs.messages[-2]['content'] = user_message - + if inputs.chosen.system is not None: + instruction = inputs.chosen.system + else: + instruction = self.instruction + query = inputs.chosen.messages[-1]['content'] + for positive in inputs.positive: + user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + positive.messages[-1]['content'] + positive.messages = {'role': 'user', 'content': user_message} + for negative in inputs.negative: + user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + negative.messages[-1]['content'] + negative.messages = {'role': 'user', 'content': user_message} + inputs.chosen.messages = [] qwen3_reranker_system = ( 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. 
' From 815b7ad15ae9e123a1b3e54249b16f48141b6aeb Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Sun, 7 Sep 2025 13:32:05 +0800 Subject: [PATCH 02/13] update --- swift/llm/dataset/dataset/llm.py | 4 +-- swift/llm/template/base.py | 46 +++++++++++++++-------------- swift/llm/template/template/qwen.py | 5 ++-- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 0f7e852f95..0ccd68a3cd 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -386,8 +386,8 @@ def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']] messages = [{"role": "user", "content": query}] - positive_messages = [[{"role": "user", "content": positive}] for positive in positives] - negative_messages = [[{"role": "user", "content": negative}] for negative in negatives] + positive_messages = [[{"role": "assistant", "content": positive}] for positive in positives] + negative_messages = [[{"role": "assistant", "content": negative}] for negative in negatives] return {"messages": messages, "positive_messages": positive_messages, "negative_messages": negative_messages} diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 87ba57806f..26f00a40b1 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -451,14 +451,14 @@ def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: positive.messages = chosen.messages + positive.messages positive_encoded = self._encode_truncated(positive) for key in positive_encoded: - _encoded[f'positive_{key}'].append(positive_encoded[key]) + _encoded[key].append(positive_encoded[key]) _encoded['labels'].append(1) for negative in inputs.negative: negative.messages = chosen.messages + negative.messages negative_encoded = self._encode_truncated(negative) for key in negative_encoded: - 
_encoded[f'negative_{key}'].append(negative_encoded[key]) + _encoded[key].append(negative_encoded[key]) _encoded['labels'].append(0) return _encoded @@ -1546,30 +1546,32 @@ def _reranker_data_collator(self, *, padding_to: Optional[int] = None) -> Dict[str, Any]: import os + import random max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) - labels = [] + labels_list = [] new_batch = [] for b in batch: - keys = [key for key in b.keys() if 'negative' in key] - max_neg = None - for key in keys: - value_list = b[key] - suffix = key[len('negative_'):] - max_neg = min(max_negative_samples, len(value_list)) - for i, value in enumerate(value_list): - b[f'negative{i}_{suffix}'] = value - b.pop(key) - - indexes = ['positive_'] - if max_neg is not None: - for i in range(0, max_neg): - indexes.append(f'negative{i}_') - for prefix in indexes: - new_batch += self._fetch_inputs_startswith([b], prefix) - labels.extend(b.get('labels', None)[:max_negative_samples + 1]) + labels = b.pop('labels') + positive_num = sum(labels) + negative_num = len(labels) - positive_num + for i in range(positive_num): + for key in b.keys(): + new_batch.append({key: b[key][i]}) + labels_list.append(1) + if negative_num > max_negative_samples: + for j in random.sample(range(negative_num), max_negative_samples): + for key in b.keys(): + new_batch.append({key: b[key][j+positive_num]}) + labels_list.append(0) + else: + for j in range(negative_num): + for key in b.keys(): + new_batch.append({key: b[key][j+positive_num]}) + labels_list.append(0) + res = self._data_collator(new_batch, padding_to=padding_to) - if labels: - res['labels'] = torch.tensor(labels, dtype=torch.long) + if labels_list: + res['labels'] = torch.tensor(labels_list, dtype=torch.long) return res def _seq_cls_data_collator(self, diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index b148ea6dfa..e4d22d47c4 100644 --- a/swift/llm/template/template/qwen.py +++ 
b/swift/llm/template/template/qwen.py @@ -80,11 +80,12 @@ def _preprocess_inputs_reranker(self, inputs: TemplateInputs) -> None: query = inputs.chosen.messages[-1]['content'] for positive in inputs.positive: user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + positive.messages[-1]['content'] - positive.messages = {'role': 'user', 'content': user_message} + positive.messages = [{'role': 'user', 'content': user_message}] for negative in inputs.negative: user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + negative.messages[-1]['content'] - negative.messages = {'role': 'user', 'content': user_message} + negative.messages = [{'role': 'user', 'content': user_message}] inputs.chosen.messages = [] + return inputs qwen3_reranker_system = ( 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. ' From a9bf07129ddf487052375816d34d290cfdb008c8 Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Wed, 10 Sep 2025 20:19:59 +0800 Subject: [PATCH 03/13] update --- swift/llm/dataset/dataset/llm.py | 2 +- swift/llm/template/base.py | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 0ccd68a3cd..85dd9321a0 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -378,7 +378,7 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: tags=['similarity', '🔥'])) -class MTEBRerankPreprocessor(ResponsePreprocessor): +class MTEBRerankPreprocessor(RowPreprocessor): def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: query = row['query'] diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 083168498e..5e7daaf04f 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -454,21 +454,24 @@ def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: chosen = 
inputs.chosen _encoded = defaultdict(list) + labels = [] for positive in inputs.positive: positive.messages = chosen.messages + positive.messages positive_encoded = self._encode_truncated(positive) + labels.append(1) for key in positive_encoded: _encoded[key].append(positive_encoded[key]) - _encoded['labels'].append(1) for negative in inputs.negative: negative.messages = chosen.messages + negative.messages negative_encoded = self._encode_truncated(negative) + labels.append(0) for key in negative_encoded: _encoded[key].append(negative_encoded[key]) - _encoded['labels'].append(0) + _encoded['labels'] = labels + _encoded['length'] = len(labels) return _encoded def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: @@ -1566,19 +1569,16 @@ def _reranker_data_collator(self, positive_num = sum(labels) negative_num = len(labels) - positive_num for i in range(positive_num): - for key in b.keys(): - new_batch.append({key: b[key][i]}) - labels_list.append(1) - if negative_num > max_negative_samples: - for j in random.sample(range(negative_num), max_negative_samples): - for key in b.keys(): - new_batch.append({key: b[key][j+positive_num]}) - labels_list.append(0) - else: - for j in range(negative_num): - for key in b.keys(): - new_batch.append({key: b[key][j+positive_num]}) - labels_list.append(0) + new_batch.append({'input_ids': b['input_ids'][i]}) + labels_list.append(1) + if negative_num > max_negative_samples: + for j in random.sample(range(negative_num), max_negative_samples): + new_batch.append({'input_ids': b['input_ids'][j+positive_num]}) + labels_list.append(0) + else: + for j in range(negative_num): + new_batch.append({'input_ids': b['input_ids'][j+positive_num]}) + labels_list.append(0) res = self._data_collator(new_batch, padding_to=padding_to) if labels_list: From aeec54ce741b8c6829069ef136735453f163cd46 Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Thu, 11 Sep 2025 14:36:22 +0800 Subject: [PATCH 04/13] refactor --- 
swift/llm/dataset/dataset/llm.py | 8 ++++---- swift/llm/template/base.py | 23 ++++++++++------------- swift/llm/template/template/qwen.py | 21 +++++++++------------ swift/plugin/loss.py | 13 +++---------- swift/trainers/trainers.py | 23 ++--------------------- 5 files changed, 28 insertions(+), 60 deletions(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 85dd9321a0..ee5d4dee94 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -385,11 +385,11 @@ def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]: positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']] negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']] - messages = [{"role": "user", "content": query}] - positive_messages = [[{"role": "assistant", "content": positive}] for positive in positives] - negative_messages = [[{"role": "assistant", "content": negative}] for negative in negatives] + messages = [{'role': 'user', 'content': query}] + positive_messages = [[{'role': 'assistant', 'content': positive}] for positive in positives] + negative_messages = [[{'role': 'assistant', 'content': negative}] for negative in negatives] - return {"messages": messages, "positive_messages": positive_messages, "negative_messages": negative_messages} + return {'messages': messages, 'positive_messages': positive_messages, 'negative_messages': negative_messages} register_dataset( diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 5f29fcda29..2aa1f2bbc7 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -250,13 +250,6 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None: else: i += 1 - def _preprocess_inputs_reranker( - self, - inputs: StdTemplateInputs, - ) -> None: - # TODO: remove - return - def _preprocess_inputs( self, inputs: StdTemplateInputs, @@ -452,13 +445,15 @@ def split_multi_medias(_inputs): def 
_reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: from collections import defaultdict - inputs = self._preprocess_inputs_reranker(inputs) chosen = inputs.chosen - + instruction = chosen.system + _encoded = defaultdict(list) labels = [] for positive in inputs.positive: + if instruction is not None and positive.system is None: + positive.system = instruction positive.messages = chosen.messages + positive.messages positive_encoded = self._encode_truncated(positive) labels.append(1) @@ -466,6 +461,8 @@ def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: _encoded[key].append(positive_encoded[key]) for negative in inputs.negative: + if instruction is not None and negative.system is None: + negative.system = instruction negative.messages = chosen.messages + negative.messages negative_encoded = self._encode_truncated(negative) labels.append(0) @@ -473,7 +470,6 @@ def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: _encoded[key].append(negative_encoded[key]) _encoded['labels'] = labels - _encoded['length'] = len(labels) return _encoded def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: @@ -1563,6 +1559,7 @@ def _reranker_data_collator(self, padding_to: Optional[int] = None) -> Dict[str, Any]: import os import random + max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1)) max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) labels_list = [] new_batch = [] @@ -1570,16 +1567,16 @@ def _reranker_data_collator(self, labels = b.pop('labels') positive_num = sum(labels) negative_num = len(labels) - positive_num - for i in range(positive_num): + for i in range(min(positive_num, max_positive_samples)): new_batch.append({'input_ids': b['input_ids'][i]}) labels_list.append(1) if negative_num > max_negative_samples: for j in random.sample(range(negative_num), max_negative_samples): - new_batch.append({'input_ids': b['input_ids'][j+positive_num]}) + new_batch.append({'input_ids': 
b['input_ids'][j + positive_num]}) labels_list.append(0) else: for j in range(negative_num): - new_batch.append({'input_ids': b['input_ids'][j+positive_num]}) + new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) labels_list.append(0) res = self._data_collator(new_batch, padding_to=padding_to) diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index b1ed616381..e167f00740 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -72,22 +72,19 @@ class Qwen3Template(ThinkingTemplate): class Qwen3RerankerTemplate(Template): instruction = 'Given a web search query, retrieve relevant passages that answer the query' - def _preprocess_inputs_reranker(self, inputs: TemplateInputs) -> None: - super()._preprocess_inputs_reranker(inputs) - if inputs.chosen.system is not None: - instruction = inputs.chosen.system + def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None: + if inputs.system is not None: + instruction = inputs.system + inputs.system = None else: instruction = self.instruction - query = inputs.chosen.messages[-1]['content'] - for positive in inputs.positive: - user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + positive.messages[-1]['content'] - positive.messages = [{'role': 'user', 'content': user_message}] - for negative in inputs.negative: - user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + negative.messages[-1]['content'] - negative.messages = [{'role': 'user', 'content': user_message}] - inputs.chosen.messages = [] + query = inputs.messages[0]['content'] + document = inputs.messages[1]['content'] + user_message = ': ' + instruction + '\n' + ': ' + query + '\n' + ': ' + document + inputs.messages = [{'role': 'user', 'content': user_message}] return inputs + qwen3_reranker_system = ( 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. 
' 'Note that the answer can only be \"yes\" or \"no\".') diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 1830422e90..b1b83b1cd9 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -652,12 +652,7 @@ def listwise_generative_reranker_loss(outputs, negative_logits = logits[:, -1, negative_token_id] # [batch_size] # Create binary classification logits for each sample - # Shape: [batch_size, 2] where dim=1 represents [negative, positive] - binary_logits = torch.stack([negative_logits, positive_logits], dim=1) - - # Convert to relevance scores using softmax (probability of positive class) - binary_probs = torch.softmax(binary_logits, dim=1) - relevance_scores = binary_probs[:, 1] # Probability of positive class [batch_size] + logits = positive_logits - negative_logits # Find positive sample indices to determine group boundaries positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1) @@ -684,7 +679,7 @@ def listwise_generative_reranker_loss(outputs, group_end = len(labels) # Extract group relevance scores and labels - group_scores = relevance_scores[group_start:group_end] # [group_size] + group_scores = logits[group_start:group_end] # [group_size] group_labels = labels[group_start:group_end] # [group_size] # Skip groups that are too small @@ -695,9 +690,7 @@ def listwise_generative_reranker_loss(outputs, if group_labels[0] != 1: continue # Skip malformed groups - # Convert relevance scores to logits for cross-entropy loss - # We use log to convert probabilities back to logits, then apply temperature - group_logits = torch.log(group_scores + 1e-8) / temperature # Add small epsilon for numerical stability + group_logits = group_scores / temperature # The positive document is always at index 0 within the group target = torch.tensor(0, dtype=torch.long, device=logits.device) diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 10f598da7f..a3213cd17d 100644 --- a/swift/trainers/trainers.py +++ 
b/swift/trainers/trainers.py @@ -161,7 +161,7 @@ def _preprocess_generative_reranker_logits(self, logits, labels): positive_logits = last_step_logits[:, positive_token_id] negative_logits = last_step_logits[:, negative_token_id] - logits = torch.stack([negative_logits, positive_logits], dim=1) + logits = positive_logits - negative_logits return logits else: # Unexpected shape, return as-is @@ -174,26 +174,7 @@ def evaluation_loop(self, *args, **kwargs): def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]: from swift.plugin.loss import calculate_reranker_metrics - - # Check if we're using generative reranker (point-wise or list-wise) - if self.args.loss_type in {'generative_reranker', 'listwise_generative_reranker'}: - # For generative reranker, predictions are now [batch_size, 2] from preprocessing - # We need to handle this differently - predictions = eval_prediction.predictions - if len(predictions.shape) == 2 and predictions.shape[1] == 2: - # Predictions are already preprocessed [batch_size, 2] format - # Apply softmax to get probabilities - import numpy as np - exp_logits = np.exp(predictions - np.max(predictions, axis=1, keepdims=True)) - probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True) - relevance_scores = probabilities[:, 1] # Positive class probability - return calculate_reranker_metrics(relevance_scores, eval_prediction.label_ids) - else: - # Fallback to original method if preprocessing didn't work - raise ValueError('Unexpected predictions shape') - else: - # For standard reranker (point-wise or list-wise) - return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids) + return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Check if we have a custom loss function From 49bb5bc2cd1c448f27da3b812df55d516e3036dd Mon Sep 17 00:00:00 2001 From: 0russwest0 
<1074124719@qq.com> Date: Thu, 11 Sep 2025 15:06:38 +0800 Subject: [PATCH 05/13] update --- swift/llm/template/base.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 2aa1f2bbc7..317f989e43 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1567,17 +1567,14 @@ def _reranker_data_collator(self, labels = b.pop('labels') positive_num = sum(labels) negative_num = len(labels) - positive_num - for i in range(min(positive_num, max_positive_samples)): + max_positive = min(positive_num, max_positive_samples) + max_negative = min(negative_num, max_negative_samples) + for i in random.sample(range(positive_num), max_positive): new_batch.append({'input_ids': b['input_ids'][i]}) labels_list.append(1) - if negative_num > max_negative_samples: - for j in random.sample(range(negative_num), max_negative_samples): - new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) - labels_list.append(0) - else: - for j in range(negative_num): - new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) - labels_list.append(0) + for j in random.sample(range(negative_num), max_negative): + new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) + labels_list.append(0) res = self._data_collator(new_batch, padding_to=padding_to) if labels_list: From c9c10702d46d4d33ee6e7d84cd2ad3ed93518695 Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Thu, 11 Sep 2025 15:18:08 +0800 Subject: [PATCH 06/13] Update dataset format and environment variable configuration in Reranker documentation --- .../Reranker\350\256\255\347\273\203.md" | 26 ++++++++++++------- docs/source_en/BestPractices/Reranker.md | 26 ++++++++++++------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" index 2a0a056b14..4853c98f86 
100644 --- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" @@ -66,21 +66,27 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma ## 数据集格式 ```json lines -{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]} -{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]} -... +{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]} ``` **字段说明:** -- `query`:查询文本 -- `response`:与查询相关的正例文档 -- `rejected_response`:与查询不相关的负例文档列表,支持多个负例 +- `messages`:查询文本 +- `positive_messages`:与查询相关的正例文档列表,支持多个正例 +- `negative_messages`:与查询不相关的负例文档列表,支持多个负例 + +**环境变量配置:** +- `MAX_POSITIVE_SAMPLES`:每个query的最大正例数量(默认:1) +- `MAX_NEGATIVE_SAMPLES`:每个query的最大负例数量(默认:7) + +> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`条数据。 +> 如果数据中正例/负例数量不足,会取全部正例/负例,如果数据中正例和负例数量超过`MAX_POSITIVE_SAMPLES`和`MAX_NEGATIVE_SAMPLES`,会进行随机采样。 +> **IMPORTANT**:展开后的数据会放在同一个batch中,真实的per_device_batch_size会变成`per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`。 ## 脚手架 SWIFT提供了两个脚手架训练脚本: -- [Pointwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh) -- [Pointwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh) -- [Listwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) -- 
[Listwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) +- [Pointwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh) +- [Pointwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh) +- [Listwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) +- [Listwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) diff --git a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md index 9089655be6..282d6cef61 100644 --- a/docs/source_en/BestPractices/Reranker.md +++ b/docs/source_en/BestPractices/Reranker.md @@ -65,21 +65,27 @@ The loss function source code can be found [here](https://github.com/modelscope/ ## Dataset Format ```json lines -{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]} -{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]} -... 
+{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]} ``` **Field Description:** -- `query`: Query text -- `response`: Positive document relevant to the query -- `rejected_response`: List of negative documents irrelevant to the query, supports multiple negative examples +- `messages`: Query text +- `positive_messages`: List of positive documents relevant to the query, supports multiple positive examples +- `negative_messages`: List of negative documents irrelevant to the query, supports multiple negative examples + +**Environment Variable Configuration:** +- `MAX_POSITIVE_SAMPLES`: Maximum number of positive examples per query (default: 1) +- `MAX_NEGATIVE_SAMPLES`: Maximum number of negative examples per query (default: 7) + +> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data, and each data will be expanded into `MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES` data. +> If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed. +> **IMPORTANT**: The expanded data will be placed in the same batch, and the real `per_device_batch_size` will become `per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`. 
## Training Scripts SWIFT provides four training script templates: -- [Pointwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh) -- [Pointwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh) -- [Listwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) -- [Listwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) +- [Pointwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh) +- [Pointwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh) +- [Listwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) +- [Listwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) From 10189e3ee50772ccf0d3021e83a6746a67d480d1 Mon Sep 17 00:00:00 2001 From: russwest404 <80997191+0russwest0@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:37:25 +0800 Subject: [PATCH 07/13] Update swift/llm/dataset/dataset/llm.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- swift/llm/dataset/dataset/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index ee5d4dee94..be3afc8135 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -380,7 +380,7 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: class MTEBRerankPreprocessor(RowPreprocessor): - def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, 
Any]]: + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: query = row['query'] positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']] negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']] From 4e41d3eaa964e3a56e2513ed87296a89f3c264fe Mon Sep 17 00:00:00 2001 From: russwest404 <80997191+0russwest0@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:39:01 +0800 Subject: [PATCH 08/13] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../source/BestPractices/Reranker\350\256\255\347\273\203.md" | 4 ++-- docs/source_en/BestPractices/Reranker.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" index 4853c98f86..b27f02122f 100644 --- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" @@ -78,9 +78,9 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma - `MAX_POSITIVE_SAMPLES`:每个query的最大正例数量(默认:1) - `MAX_NEGATIVE_SAMPLES`:每个query的最大负例数量(默认:7) -> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`条数据。 +> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条正样本会和`MAX_NEGATIVE_SAMPLES`条负样本组成一个listwise group,因此每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`条数据。 > 如果数据中正例/负例数量不足,会取全部正例/负例,如果数据中正例和负例数量超过`MAX_POSITIVE_SAMPLES`和`MAX_NEGATIVE_SAMPLES`,会进行随机采样。 -> **IMPORTANT**:展开后的数据会放在同一个batch中,真实的per_device_batch_size会变成`per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`。 +> **IMPORTANT**:展开后的数据会放在同一个batch中,真实的per_device_batch_size会变成`per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`。 ## 脚手架 diff --git 
a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md index 282d6cef61..e5b62f78b5 100644 --- a/docs/source_en/BestPractices/Reranker.md +++ b/docs/source_en/BestPractices/Reranker.md @@ -77,9 +77,9 @@ The loss function source code can be found [here](https://github.com/modelscope/ - `MAX_POSITIVE_SAMPLES`: Maximum number of positive examples per query (default: 1) - `MAX_NEGATIVE_SAMPLES`: Maximum number of negative examples per query (default: 7) -> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data, and each data will be expanded into `MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES` data. +> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a listwise group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)` data points. > If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed. -> **IMPORTANT**: The expanded data will be placed in the same batch, and the real `per_device_batch_size` will become `per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`MAX_NEGATIVE_SAMPLES`. +> **IMPORTANT**: The expanded data will be placed in the same batch, and the real `per_device_batch_size` will become `per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`. 
## Training Scripts From 69c4a08b33fbf8d67acb21df911d581417647fba Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Thu, 11 Sep 2025 15:43:51 +0800 Subject: [PATCH 09/13] update --- "docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" | 2 +- docs/source_en/BestPractices/Reranker.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" index b27f02122f..5794e99fc9 100644 --- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" @@ -78,7 +78,7 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma - `MAX_POSITIVE_SAMPLES`:每个query的最大正例数量(默认:1) - `MAX_NEGATIVE_SAMPLES`:每个query的最大负例数量(默认:7) -> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条正样本会和`MAX_NEGATIVE_SAMPLES`条负样本组成一个listwise group,因此每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`条数据。 +> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条正样本会和`MAX_NEGATIVE_SAMPLES`条负样本组成一个group,因此每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`条数据。 > 如果数据中正例/负例数量不足,会取全部正例/负例,如果数据中正例和负例数量超过`MAX_POSITIVE_SAMPLES`和`MAX_NEGATIVE_SAMPLES`,会进行随机采样。 > **IMPORTANT**:展开后的数据会放在同一个batch中,真实的per_device_batch_size会变成`per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`。 diff --git a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md index e5b62f78b5..beecb41d2d 100644 --- a/docs/source_en/BestPractices/Reranker.md +++ b/docs/source_en/BestPractices/Reranker.md @@ -77,7 +77,7 @@ The loss function source code can be found [here](https://github.com/modelscope/ - `MAX_POSITIVE_SAMPLES`: Maximum number of positive examples per query (default: 1) - `MAX_NEGATIVE_SAMPLES`: Maximum number of negative examples per query (default: 7) -> By default, `MAX_POSITIVE_SAMPLES` 
positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a listwise group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)` data points. +> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)` data points. > If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed. > **IMPORTANT**: The expanded data will be placed in the same batch, and the real `per_device_batch_size` will become `per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`. 
From 39d027c8da03e4a0379422cd481935373976997d Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Fri, 12 Sep 2025 16:44:14 +0800 Subject: [PATCH 10/13] update --- .../Reranker\350\256\255\347\273\203.md" | 27 +++++++++++++++++++ docs/source_en/BestPractices/Reranker.md | 27 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" index 5794e99fc9..2140745886 100644 --- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" @@ -90,3 +90,30 @@ SWIFT提供了两个脚手架训练脚本: - [Pointwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh) - [Listwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) - [Listwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) + +## 高级功能 + +- Qwen3-Reranker 自定义 Instruction: + - 默认模板如下: + +```text +<|im_start|>system +Judge whether the Document meets the requirements based on the Query and the Instruct provided. 
Note that the answer can only be "yes" or "no".<|im_end|> +<|im_start|>user +<Instruct>: {Instruction} +<Query>: {Query} +<Document>: {Document}<|im_end|> +<|im_start|>assistant +<think> + +</think> + + +``` + +- 默认 Instruction: + - `Given a web search query, retrieve relevant passages that answer the query` + +- Instruction 优先级(就近覆盖): + - `positive_messages`/`negative_messages` 内提供的 `system` > 主 `messages` 的 `system` > 默认 Instruction。 + - 即:若某个 positive/negative 的消息序列内包含 `system`,则优先使用该条;否则若主 `messages` 含 `system` 则使用之;两者都未提供时,使用默认 Instruction。 diff --git a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md index beecb41d2d..73164c0f9d 100644 --- a/docs/source_en/BestPractices/Reranker.md +++ b/docs/source_en/BestPractices/Reranker.md @@ -89,3 +89,30 @@ SWIFT provides four training script templates: - [Pointwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh) - [Listwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) - [Listwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) + +## Advanced + +- Qwen3-Reranker Custom Instruction: + - Default template: + +```text +<|im_start|>system +Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|> +<|im_start|>user +<Instruct>: {Instruction} +<Query>: {Query} +<Document>: {Document}<|im_end|> +<|im_start|>assistant +<think> + +</think> + + +``` + +- Default instruction: + - `Given a web search query, retrieve relevant passages that answer the query` + +- Instruction priority (nearest wins): + - `system` inside `positive_messages`/`negative_messages` > `system` in main `messages` > default instruction.
+ - That is, if a positive/negative message sequence contains a `system`, it takes precedence; otherwise, if main `messages` has a `system`, use it; if neither is provided, use the default instruction. From 1ce3cd3b775ae15906e2769e4bb81434f24d83fa Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Fri, 12 Sep 2025 16:49:50 +0800 Subject: [PATCH 11/13] update --- swift/llm/template/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 317f989e43..60b4b2a2fb 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -3,8 +3,10 @@ import inspect import math import os +import random import re from contextlib import contextmanager, nullcontext +from collections import defaultdict from copy import deepcopy from dataclasses import asdict from functools import partial, wraps @@ -444,7 +446,6 @@ def split_multi_medias(_inputs): return _encoded def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: - from collections import defaultdict chosen = inputs.chosen instruction = chosen.system @@ -1557,8 +1558,6 @@ def _reranker_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: - import os - import random max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1)) max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) labels_list = [] From 0417d82d05899404369c5fe2a7754de1d1a83a19 Mon Sep 17 00:00:00 2001 From: russwest404 <80997191+0russwest0@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:51:08 +0800 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- "docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" | 2 +- docs/source_en/BestPractices/Reranker.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
"a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" index 2140745886..c975c421fe 100644 --- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" +++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" @@ -80,7 +80,7 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma > 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条正样本会和`MAX_NEGATIVE_SAMPLES`条负样本组成一个group,因此每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`条数据。 > 如果数据中正例/负例数量不足,会取全部正例/负例,如果数据中正例和负例数量超过`MAX_POSITIVE_SAMPLES`和`MAX_NEGATIVE_SAMPLES`,会进行随机采样。 -> **IMPORTANT**:展开后的数据会放在同一个batch中,真实的per_device_batch_size会变成`per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`。 +> **IMPORTANT**:展开后的数据会放在同一个batch中,因此每个设备上的实际批处理大小(effective batch size)将是 `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`)。请注意调整 `per_device_train_batch_size` 以避免显存不足。 ## 脚手架 diff --git a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md index 73164c0f9d..3634409859 100644 --- a/docs/source_en/BestPractices/Reranker.md +++ b/docs/source_en/BestPractices/Reranker.md @@ -79,7 +79,7 @@ The loss function source code can be found [here](https://github.com/modelscope/ > By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)` data points. > If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed. 
-> **IMPORTANT**: The expanded data will be placed in the same batch, and the real `per_device_batch_size` will become `per_device_batch_size`x`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`. +> **IMPORTANT**: The expanded data will be placed in the same batch. Therefore, the effective batch size on each device will be `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`). Please adjust your `per_device_train_batch_size` accordingly to avoid out-of-memory errors. ## Training Scripts From d43edda99496cbc0e2413eee2d442d620bdf2ba7 Mon Sep 17 00:00:00 2001 From: 0russwest0 <1074124719@qq.com> Date: Fri, 12 Sep 2025 16:55:05 +0800 Subject: [PATCH 13/13] fix --- swift/llm/template/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 60b4b2a2fb..f1ae9893a6 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -5,8 +5,8 @@ import os import random import re -from contextlib import contextmanager, nullcontext from collections import defaultdict +from contextlib import contextmanager, nullcontext from copy import deepcopy from dataclasses import asdict from functools import partial, wraps