diff --git "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md" "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md"
index 2a0a056b14..c975c421fe 100644
--- "a/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md"
+++ "b/docs/source/BestPractices/Reranker\350\256\255\347\273\203.md"
@@ -66,21 +66,54 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma
 ## 数据集格式
 
 ```json lines
-{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
-{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
-...
+{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]}
 ```
 
 **字段说明:**
-- `query`:查询文本
-- `response`:与查询相关的正例文档
-- `rejected_response`:与查询不相关的负例文档列表,支持多个负例
+- `messages`:查询文本
+- `positive_messages`:与查询相关的正例文档列表,支持多个正例
+- `negative_messages`:与查询不相关的负例文档列表,支持多个负例
+
+**环境变量配置:**
+- `MAX_POSITIVE_SAMPLES`:每个query的最大正例数量(默认:1)
+- `MAX_NEGATIVE_SAMPLES`:每个query的最大负例数量(默认:7)
+
+> 默认会从每条数据中取出`MAX_POSITIVE_SAMPLES`条正样本和`MAX_NEGATIVE_SAMPLES`条负样本,每条正样本会和`MAX_NEGATIVE_SAMPLES`条负样本组成一个group,因此每条数据会扩展成`MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)`条数据。
+> 如果数据中正例/负例数量不足,会取全部正例/负例,如果数据中正例和负例数量超过`MAX_POSITIVE_SAMPLES`和`MAX_NEGATIVE_SAMPLES`,会进行随机采样。
+> **IMPORTANT**:展开后的数据会放在同一个batch中,因此每个设备上的实际批处理大小(effective batch size)将是 `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`)。请注意调整 `per_device_train_batch_size` 以避免显存不足。
 
 ## 脚手架
 
 SWIFT提供了两个脚手架训练脚本:
 
-- [Pointwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh)
-- [Pointwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
-- [Listwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
-- [Listwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
+- [Pointwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh)
+- [Pointwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
+- [Listwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
+- [Listwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
+
+## 高级功能
+
+- Qwen3-Reranker 自定义 Instruction:
+  - 默认模板如下:
+
+```text
+<|im_start|>system
+Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
+<|im_start|>user
+<Instruct>: {Instruction}
+<Query>: {Query}
+<Document>: {Document}<|im_end|>
+<|im_start|>assistant
+
+
+
+
+
+```
+
+- 默认 Instruction:
+  - `Given a web search query, retrieve relevant passages that answer the query`
+
+- Instruction 优先级(就近覆盖):
+  - `positive_messages`/`negative_messages` 内提供的 `system` > 主 `messages` 的 `system` > 默认 Instruction。
+  - 即:若某个 positive/negative 的消息序列内包含 `system`,则优先使用该条;否则若主 `messages` 含 `system` 则使用之;两者都未提供时,使用默认 Instruction。
diff --git a/docs/source_en/BestPractices/Reranker.md b/docs/source_en/BestPractices/Reranker.md
index 9089655be6..3634409859 100644
--- a/docs/source_en/BestPractices/Reranker.md
+++ b/docs/source_en/BestPractices/Reranker.md
@@ -65,21 +65,54 @@ The loss function source code can be found [here](https://github.com/modelscope/
 ## Dataset Format
 
 ```json lines
-{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
-{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
-...
+{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]}
 ```
 
 **Field Description:**
-- `query`: Query text
-- `response`: Positive document relevant to the query
-- `rejected_response`: List of negative documents irrelevant to the query, supports multiple negative examples
+- `messages`: Query text
+- `positive_messages`: List of positive documents relevant to the query, supports multiple positive examples
+- `negative_messages`: List of negative documents irrelevant to the query, supports multiple negative examples
+
+**Environment Variable Configuration:**
+- `MAX_POSITIVE_SAMPLES`: Maximum number of positive examples per query (default: 1)
+- `MAX_NEGATIVE_SAMPLES`: Maximum number of negative examples per query (default: 7)
+
+> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES`x`(1 + MAX_NEGATIVE_SAMPLES)` data points.
+> If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed.
+> **IMPORTANT**: The expanded data will be placed in the same batch. Therefore, the effective batch size on each device will be `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`). Please adjust your `per_device_train_batch_size` accordingly to avoid out-of-memory errors.
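The sampling rule and the batch-size note above can be sketched in a few lines of Python. This is a minimal illustration rather than the collator itself: `sample_indices` and `documented_batch_bound` are hypothetical helper names, the caps are passed as arguments instead of being read from the environment, and the label layout (all positives first, then all negatives) is an assumption.

```python
import random
from typing import List, Optional, Tuple


def sample_indices(labels: List[int], max_pos: int = 1, max_neg: int = 7,
                   rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: choose which documents of one data item are kept.

    Assumes `labels` lists every positive (1) first, then every negative (0).
    """
    rng = rng or random.Random()
    num_pos = sum(labels)
    num_neg = len(labels) - num_pos
    # Too few examples: keep them all. Too many: sample without replacement.
    pos = rng.sample(range(num_pos), min(num_pos, max_pos))
    neg = [j + num_pos for j in rng.sample(range(num_neg), min(num_neg, max_neg))]
    return pos, neg


def documented_batch_bound(per_device_train_batch_size: int,
                           max_pos: int = 1, max_neg: int = 7) -> int:
    """Per-device sequence count, using the formula from the note above."""
    return per_device_train_batch_size * max_pos * (1 + max_neg)


pos, neg = sample_indices([1, 1, 0, 0, 0], rng=random.Random(0))
print(len(pos), len(neg))         # 1 positive kept, all 3 negatives kept
print(documented_batch_bound(4))  # 4 * 1 * (1 + 7) = 32
```

With the default caps, a `per_device_train_batch_size` of 4 already means up to 32 sequences per device, which is why the note recommends lowering it. The collator change further down in this diff appends at most `max_positive + max_negative` sequences per data item; with the default `MAX_POSITIVE_SAMPLES=1` that count matches the documented formula.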
 
 ## Training Scripts
 
 SWIFT provides four training script templates:
 
-- [Pointwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh)
-- [Pointwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
-- [Listwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
-- [Listwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
+- [Pointwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh)
+- [Pointwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
+- [Listwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
+- [Listwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
+
+## Advanced
+
+- Qwen3-Reranker Custom Instruction:
+  - Default template:
+
+```text
+<|im_start|>system
+Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
+<|im_start|>user
+<Instruct>: {Instruction}
+<Query>: {Query}
+<Document>: {Document}<|im_end|>
+<|im_start|>assistant
+
+
+
+
+
+```
+
+- Default instruction:
+  - `Given a web search query, retrieve relevant passages that answer the query`
+
+- Instruction priority (nearest wins):
+  - `system` inside `positive_messages`/`negative_messages` > `system` in main `messages` > default instruction.
+  - That is, if a positive/negative message sequence contains a `system`, it takes precedence; otherwise, if main `messages` has a `system`, use it; if neither is provided, use the default instruction.
diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py
index 58caaaa814..be3afc8135 100644
--- a/swift/llm/dataset/dataset/llm.py
+++ b/swift/llm/dataset/dataset/llm.py
@@ -378,19 +378,18 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         tags=['similarity', '🔥']))
 
 
-class MTEBRerankPreprocessor(ResponsePreprocessor):
+class MTEBRerankPreprocessor(RowPreprocessor):
 
-    def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
         query = row['query']
         positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']]
         negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']]
 
-        expanded_rows = []
-        for positive in positives:
-            expanded_row = {'query': query, 'response': positive, 'rejected_response': negatives}
-            expanded_rows.append(super().preprocess(expanded_row))
+        messages = [{'role': 'user', 'content': query}]
+        positive_messages = [[{'role': 'assistant', 'content': positive}] for positive in positives]
+        negative_messages = [[{'role': 'assistant', 'content': negative}] for negative in negatives]
 
-        return expanded_rows
+        return {'messages': messages, 'positive_messages': positive_messages, 'negative_messages': negative_messages}
 
 
 register_dataset(
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 822805efb8..f1ae9893a6 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -3,7 +3,9 @@
 import inspect
 import math
 import os
+import random
 import re
+from collections import defaultdict
 from contextlib import contextmanager, nullcontext
 from copy import deepcopy
 from dataclasses import asdict
@@ -250,13 +252,6 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
             else:
                 i += 1
 
-    def _preprocess_inputs_reranker(
-        self,
-        inputs: StdTemplateInputs,
-    ) -> None:
-        # TODO: remove
-        return
-
     def _preprocess_inputs(
         self,
         inputs: StdTemplateInputs,
@@ -451,37 +446,29 @@ def split_multi_medias(_inputs):
         return _encoded
 
     def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
-        inputs = inputs.chosen  # TODO: refactor
-        self._preprocess_inputs_reranker(inputs)
-        _encoded = {}
+        chosen = inputs.chosen
+        instruction = chosen.system
+
+        _encoded = defaultdict(list)
         labels = []
 
-        positive = deepcopy(inputs)
-        positive.rejected_response = []
-        if '{doc}' in positive.messages[-2]['content']:
-            positive.messages[-2]['content'] = positive.messages[-2]['content'].replace(
-                '{doc}', inputs.messages[-1]['content'])
-            positive.messages.pop(-1)
-        positive_encoded = self._encode_truncated(positive)
-        for key in positive_encoded:
-            _encoded[f'positive_{key}'] = positive_encoded[key]
-            _encoded[f'negative_{key}'] = []
-        labels.append(1)
-
-        rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
-        for i in range(rejected_len):
-            negative = deepcopy(inputs)
-            if '{doc}' in negative.messages[-2]['content']:
-                negative.messages[-2]['content'] = negative.messages[-2]['content'].replace(
-                    '{doc}', negative.rejected_response[i])
-                negative.messages.pop(-1)
-            else:
-                negative.messages[-1]['content'] = negative.rejected_response[i]
-            negative.rejected_response = []
+        for positive in inputs.positive:
+            if instruction is not None and positive.system is None:
+                positive.system = instruction
+            positive.messages = chosen.messages + positive.messages
+            positive_encoded = self._encode_truncated(positive)
+            labels.append(1)
+            for key in positive_encoded:
+                _encoded[key].append(positive_encoded[key])
+
+        for negative in inputs.negative:
+            if instruction is not None and negative.system is None:
+                negative.system = instruction
+            negative.messages = chosen.messages + negative.messages
             negative_encoded = self._encode_truncated(negative)
-            for key in negative_encoded:
-                _encoded[f'negative_{key}'].append(negative_encoded[key])
             labels.append(0)
+            for key in negative_encoded:
+                _encoded[key].append(negative_encoded[key])
 
         _encoded['labels'] = labels
         return _encoded
@@ -1571,31 +1558,26 @@ def _reranker_data_collator(self,
                                 batch: List[Dict[str, Any]],
                                 *,
                                 padding_to: Optional[int] = None) -> Dict[str, Any]:
-        import os
+        max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1))
         max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7))
-        labels = []
+        labels_list = []
         new_batch = []
         for b in batch:
-            keys = [key for key in b.keys() if 'negative' in key]
-            max_neg = None
-            for key in keys:
-                value_list = b[key]
-                suffix = key[len('negative_'):]
-                max_neg = min(max_negative_samples, len(value_list))
-                for i, value in enumerate(value_list):
-                    b[f'negative{i}_{suffix}'] = value
-                b.pop(key)
-
-            indexes = ['positive_']
-            if max_neg is not None:
-                for i in range(0, max_neg):
-                    indexes.append(f'negative{i}_')
-            for prefix in indexes:
-                new_batch += self._fetch_inputs_startswith([b], prefix)
-            labels.extend(b.get('labels', None)[:max_negative_samples + 1])
+            labels = b.pop('labels')
+            positive_num = sum(labels)
+            negative_num = len(labels) - positive_num
+            max_positive = min(positive_num, max_positive_samples)
+            max_negative = min(negative_num, max_negative_samples)
+            for i in random.sample(range(positive_num), max_positive):
+                new_batch.append({'input_ids': b['input_ids'][i]})
+                labels_list.append(1)
+            for j in random.sample(range(negative_num), max_negative):
+                new_batch.append({'input_ids': b['input_ids'][j + positive_num]})
+                labels_list.append(0)
+
         res = self._data_collator(new_batch, padding_to=padding_to)
-        if labels:
-            res['labels'] = torch.tensor(labels, dtype=torch.long)
+        if labels_list:
+            res['labels'] = torch.tensor(labels_list, dtype=torch.long)
         return res
 
     def _seq_cls_data_collator(self,
diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py
index b1893f8dd7..e167f00740 100644
--- a/swift/llm/template/template/qwen.py
+++ b/swift/llm/template/template/qwen.py
@@ -17,7 +17,7 @@
 from ..base import Template
 from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import register_template
-from ..template_inputs import StdTemplateInputs
+from ..template_inputs import StdTemplateInputs, TemplateInputs
 from ..template_meta import TemplateMeta
 from ..utils import Context, Word, findall
 from ..vision_utils import load_audio, load_batch, load_video_ovis2, load_video_ovis2_5
@@ -72,11 +72,17 @@ class Qwen3Template(ThinkingTemplate):
 class Qwen3RerankerTemplate(Template):
     instruction = 'Given a web search query, retrieve relevant passages that answer the query'
 
-    def _preprocess_inputs_reranker(self, inputs: StdTemplateInputs) -> None:
-        super()._preprocess_inputs_reranker(inputs)
-        query = inputs.messages[-2]['content']
-        user_message = '<Instruct>: ' + self.instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: {doc}'
-        inputs.messages[-2]['content'] = user_message
+    def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
+        if inputs.system is not None:
+            instruction = inputs.system
+            inputs.system = None
+        else:
+            instruction = self.instruction
+        query = inputs.messages[0]['content']
+        document = inputs.messages[1]['content']
+        user_message = '<Instruct>: ' + instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: ' + document
+        inputs.messages = [{'role': 'user', 'content': user_message}]
+        return inputs
 
 
 qwen3_reranker_system = (
diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py
index 1830422e90..b1b83b1cd9 100755
--- a/swift/plugin/loss.py
+++ b/swift/plugin/loss.py
@@ -652,12 +652,7 @@ def listwise_generative_reranker_loss(outputs,
     negative_logits = logits[:, -1, negative_token_id]  # [batch_size]
 
     # Create binary classification logits for each sample
-    # Shape: [batch_size, 2] where dim=1 represents [negative, positive]
-    binary_logits = torch.stack([negative_logits, positive_logits], dim=1)
-
-    # Convert to relevance scores using softmax (probability of positive class)
-    binary_probs = torch.softmax(binary_logits, dim=1)
-    relevance_scores = binary_probs[:, 1]  # Probability of positive class [batch_size]
+    logits = positive_logits - negative_logits
 
     # Find positive sample indices to determine group boundaries
     positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
@@ -684,7 +679,7 @@ def listwise_generative_reranker_loss(outputs,
             group_end = len(labels)
 
         # Extract group relevance scores and labels
-        group_scores = relevance_scores[group_start:group_end]  # [group_size]
+        group_scores = logits[group_start:group_end]  # [group_size]
         group_labels = labels[group_start:group_end]  # [group_size]
 
         # Skip groups that are too small
@@ -695,9 +690,7 @@ def listwise_generative_reranker_loss(outputs,
         if group_labels[0] != 1:
             continue  # Skip malformed groups
 
-        # Convert relevance scores to logits for cross-entropy loss
-        # We use log to convert probabilities back to logits, then apply temperature
-        group_logits = torch.log(group_scores + 1e-8) / temperature  # Add small epsilon for numerical stability
+        group_logits = group_scores / temperature
 
         # The positive document is always at index 0 within the group
         target = torch.tensor(0, dtype=torch.long, device=logits.device)
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 10f598da7f..a3213cd17d 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -161,7 +161,7 @@ def _preprocess_generative_reranker_logits(self, logits, labels):
             positive_logits = last_step_logits[:, positive_token_id]
             negative_logits = last_step_logits[:, negative_token_id]
 
-            logits = torch.stack([negative_logits, positive_logits], dim=1)
+            logits = positive_logits - negative_logits
             return logits
         else:
             # Unexpected shape, return as-is
@@ -174,26 +174,7 @@ def evaluation_loop(self, *args, **kwargs):
 
     def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
         from swift.plugin.loss import calculate_reranker_metrics
-
-        # Check if we're using generative reranker (point-wise or list-wise)
-        if self.args.loss_type in {'generative_reranker', 'listwise_generative_reranker'}:
-            # For generative reranker, predictions are now [batch_size, 2] from preprocessing
-            # We need to handle this differently
-            predictions = eval_prediction.predictions
-            if len(predictions.shape) == 2 and predictions.shape[1] == 2:
-                # Predictions are already preprocessed [batch_size, 2] format
-                # Apply softmax to get probabilities
-                import numpy as np
-                exp_logits = np.exp(predictions - np.max(predictions, axis=1, keepdims=True))
-                probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
-                relevance_scores = probabilities[:, 1]  # Positive class probability
-                return calculate_reranker_metrics(relevance_scores, eval_prediction.label_ids)
-            else:
-                # Fallback to original method if preprocessing didn't work
-                raise ValueError('Unexpected predictions shape')
-        else:
-            # For standard reranker (point-wise or list-wise)
-            return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids)
+        return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids)
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         # Check if we have a custom loss function
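
The `loss.py` and `trainers.py` hunks replace the two-class softmax over `[negative, positive]` with the single score `positive_logits - negative_logits`. The following sketch, written in plain Python with hypothetical helper names and made-up logit values, checks the identity that presumably motivates the change: the softmax probability of the positive class equals the sigmoid of the logit difference, so both scores order documents identically.

```python
import math


def softmax_positive_prob(neg_logit: float, pos_logit: float) -> float:
    """Old-style score: softmax over [negative, positive], positive entry."""
    m = max(neg_logit, pos_logit)  # subtract the max for numerical stability
    e_neg = math.exp(neg_logit - m)
    e_pos = math.exp(pos_logit - m)
    return e_pos / (e_neg + e_pos)


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Illustrative (negative_logit, positive_logit) pairs, not real model outputs.
pairs = [(0.3, 2.1), (1.5, -0.7), (-4.0, -4.0), (2.0, 2.5)]

old_scores = [softmax_positive_prob(neg, pos) for neg, pos in pairs]
new_scores = [pos - neg for neg, pos in pairs]

# softmax([neg, pos])[1] == sigmoid(pos - neg) for every pair.
for old, new in zip(old_scores, new_scores):
    assert math.isclose(old, sigmoid(new), rel_tol=1e-9)

# Because sigmoid is strictly increasing, both scores rank documents the same.
old_rank = sorted(range(len(pairs)), key=lambda i: old_scores[i], reverse=True)
new_rank = sorted(range(len(pairs)), key=lambda i: new_scores[i], reverse=True)
assert old_rank == new_rank
```

Any metric that depends only on the ordering of scores is therefore unaffected by the switch. The listwise loss value itself does change: the removed code fed `log(p) / temperature` into the cross-entropy, where `log(p) = log(sigmoid(d))` for the difference `d`, whereas the new code feeds `d / temperature` directly.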