53 changes: 43 additions & 10 deletions docs/source/BestPractices/Reranker训练.md
@@ -66,21 +66,54 @@ The loss source code can be found [here](https://github.com/modelscope/ms-swift/blob/ma
## Dataset Format

```json lines
{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
...
{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]}
```

**Field descriptions:**
- `query`: the query text
- `response`: a positive document relevant to the query
- `rejected_response`: a list of negative documents irrelevant to the query; multiple negatives are supported
- `messages`: the query text
- `positive_messages`: a list of positive documents relevant to the query; multiple positives are supported
- `negative_messages`: a list of negative documents irrelevant to the query; multiple negatives are supported

**Environment variables:**
- `MAX_POSITIVE_SAMPLES`: maximum number of positive examples per query (default: 1)
- `MAX_NEGATIVE_SAMPLES`: maximum number of negative examples per query (default: 7)

> By default, `MAX_POSITIVE_SAMPLES` positive samples and `MAX_NEGATIVE_SAMPLES` negative samples are taken from each data item. Each positive sample is grouped with the `MAX_NEGATIVE_SAMPLES` negative samples, so each data item is expanded into `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`) rows.
> If a data item does not contain enough positives/negatives, all available ones are used; if it contains more than `MAX_POSITIVE_SAMPLES` positives or `MAX_NEGATIVE_SAMPLES` negatives, they are randomly sampled.
> **IMPORTANT**: The expanded rows are placed in the same batch, so the effective batch size on each device will be `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`). Adjust `per_device_train_batch_size` accordingly to avoid running out of GPU memory.
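
For intuition, here is a minimal, self-contained sketch (not the SWIFT implementation; the sample documents and counts below are made up) of how one row expands into a group and what the per-device effective batch size becomes:

```python
import random

MAX_POSITIVE_SAMPLES = 1
MAX_NEGATIVE_SAMPLES = 7
per_device_train_batch_size = 2

def expand_row(positives, negatives):
    """Each sampled positive is grouped with the sampled negatives."""
    pos = random.sample(positives, min(MAX_POSITIVE_SAMPLES, len(positives)))
    neg = random.sample(negatives, min(MAX_NEGATIVE_SAMPLES, len(negatives)))
    return [[p] + neg for p in pos]          # one group per positive

groups = expand_row(['relevant_doc1'], [f'irrelevant_doc{i}' for i in range(10)])
rows_per_item = sum(len(g) for g in groups)                       # 1 * (1 + 7) = 8
effective_batch = per_device_train_batch_size * rows_per_item     # 2 * 8 = 16
```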

## Scaffold Scripts

SWIFT provides the following scaffold training scripts:

- [Pointwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh)
- [Pointwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
- [Listwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
- [Listwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
- [Pointwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh)
- [Pointwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
- [Listwise分类式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
- [Listwise生成式Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)

## Advanced Features

- Qwen3-Reranker Custom Instruction:
- The default template is as follows:

```text
<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {Instruction}
<Query>: {Query}
<Document>: {Document}<|im_end|>
<|im_start|>assistant
<think>

</think>


```

- Default Instruction:
- `Given a web search query, retrieve relevant passages that answer the query`

- Instruction priority (nearest wins):
- A `system` provided inside `positive_messages`/`negative_messages` > the `system` of the main `messages` > the default Instruction.
- That is, if a positive/negative message sequence contains its own `system`, that one is used; otherwise, the `system` of the main `messages` is used if present; if neither is provided, the default Instruction is used.
53 changes: 43 additions & 10 deletions docs/source_en/BestPractices/Reranker.md
@@ -65,21 +65,54 @@ The loss function source code can be found [here](https://github.com/modelscope/
## Dataset Format

```json lines
{"query": "query", "response": "relevant_doc1", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
{"query": "query", "response": "relevant_doc2", "rejected_response": ["irrelevant_doc1", "irrelevant_doc2", ...]}
...
{"messages": [{"role": "user", "content": "query"}], "positive_messages": [[{"role": "assistant", "content": "relevant_doc1"}],[{"role": "assistant", "content": "relevant_doc2"}]], "negative_messages": [[{"role": "assistant", "content": "irrelevant_doc1"}],[{"role": "assistant", "content": "irrelevant_doc2"}], ...]}
```

**Field Description:**
- `query`: Query text
- `response`: Positive document relevant to the query
- `rejected_response`: List of negative documents irrelevant to the query, supports multiple negative examples
- `messages`: Query text
- `positive_messages`: List of positive documents relevant to the query, supports multiple positive examples
- `negative_messages`: List of negative documents irrelevant to the query, supports multiple negative examples

**Environment Variable Configuration:**
- `MAX_POSITIVE_SAMPLES`: Maximum number of positive examples per query (default: 1)
- `MAX_NEGATIVE_SAMPLES`: Maximum number of negative examples per query (default: 7)

> By default, `MAX_POSITIVE_SAMPLES` positive examples and `MAX_NEGATIVE_SAMPLES` negative examples will be extracted from each data item. Each positive example will be grouped with `MAX_NEGATIVE_SAMPLES` negative examples to form a group. Therefore, each data item will be expanded into `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`) data points.
> If the number of positive/negative examples in the data is insufficient, all positive/negative examples will be used. If the number of positive and negative examples in the data exceeds `MAX_POSITIVE_SAMPLES` and `MAX_NEGATIVE_SAMPLES`, random sampling will be performed.
> **IMPORTANT**: The expanded data will be placed in the same batch. Therefore, the effective batch size on each device will be `per_device_train_batch_size` × `MAX_POSITIVE_SAMPLES` × (1 + `MAX_NEGATIVE_SAMPLES`). Please adjust your `per_device_train_batch_size` accordingly to avoid out-of-memory errors.

## Training Scripts

SWIFT provides four training script templates:

- [Pointwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh)
- [Pointwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
- [Listwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
- [Listwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)
- [Pointwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh)
- [Pointwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker.sh)
- [Listwise Classification Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker_listwise.sh)
- [Listwise Generative Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh)

## Advanced

- Qwen3-Reranker Custom Instruction:
- Default template:

```text
<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {Instruction}
<Query>: {Query}
<Document>: {Document}<|im_end|>
<|im_start|>assistant
<think>

</think>


```

- Default instruction:
- `Given a web search query, retrieve relevant passages that answer the query`

- Instruction priority (nearest wins):
- `system` inside `positive_messages`/`negative_messages` > `system` in main `messages` > default instruction.
- That is, if a positive/negative message sequence contains a `system`, it takes precedence; otherwise, if main `messages` has a `system`, use it; if neither is provided, use the default instruction.
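
For example, a dataset line like the following (the documents are hypothetical) would use the per-document `system` as the Instruction for its positive, while the negative falls back to the default instruction:

```json lines
{"messages": [{"role": "user", "content": "what is MoE?"}], "positive_messages": [[{"role": "system", "content": "Judge relevance for a machine learning FAQ"}, {"role": "assistant", "content": "Mixture-of-Experts routes tokens to expert subnetworks."}]], "negative_messages": [[{"role": "assistant", "content": "The weather today is sunny."}]]}
```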
13 changes: 6 additions & 7 deletions swift/llm/dataset/dataset/llm.py
@@ -378,19 +378,18 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
tags=['similarity', '🔥']))


class MTEBRerankPreprocessor(ResponsePreprocessor):
class MTEBRerankPreprocessor(RowPreprocessor):

def preprocess(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
query = row['query']
positives = row['positive'] if isinstance(row['positive'], list) else [row['positive']]
negatives = row['negative'] if isinstance(row['negative'], list) else [row['negative']]

expanded_rows = []
for positive in positives:
expanded_row = {'query': query, 'response': positive, 'rejected_response': negatives}
expanded_rows.append(super().preprocess(expanded_row))
messages = [{'role': 'user', 'content': query}]
positive_messages = [[{'role': 'assistant', 'content': positive}] for positive in positives]
negative_messages = [[{'role': 'assistant', 'content': negative}] for negative in negatives]

return expanded_rows
return {'messages': messages, 'positive_messages': positive_messages, 'negative_messages': negative_messages}


register_dataset(
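For reference, a small sketch of what the rewritten `MTEBRerankPreprocessor.preprocess` now emits for a typical row (the sample texts below are made up):

```python
# Hypothetical input row in the MTEB rerank layout handled above.
row = {'query': 'what is MoE?',
       'positive': ['Mixture-of-Experts routes tokens to expert subnetworks.'],
       'negative': ['The weather today is sunny.', 'Cats are mammals.']}

# Mirrors the new preprocess(): a single user turn for the query, plus one
# assistant-message list per positive/negative document.
out = {
    'messages': [{'role': 'user', 'content': row['query']}],
    'positive_messages': [[{'role': 'assistant', 'content': p}] for p in row['positive']],
    'negative_messages': [[{'role': 'assistant', 'content': n}] for n in row['negative']],
}
```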
92 changes: 37 additions & 55 deletions swift/llm/template/base.py
@@ -3,7 +3,9 @@
import inspect
import math
import os
import random
import re
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import asdict
@@ -250,13 +252,6 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
else:
i += 1

def _preprocess_inputs_reranker(
self,
inputs: StdTemplateInputs,
) -> None:
# TODO: remove
return

def _preprocess_inputs(
self,
inputs: StdTemplateInputs,
@@ -451,37 +446,29 @@ def split_multi_medias(_inputs):
return _encoded

def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]:
inputs = inputs.chosen # TODO: refactor
self._preprocess_inputs_reranker(inputs)
_encoded = {}
chosen = inputs.chosen
instruction = chosen.system

_encoded = defaultdict(list)
labels = []

positive = deepcopy(inputs)
positive.rejected_response = []
if '{doc}' in positive.messages[-2]['content']:
positive.messages[-2]['content'] = positive.messages[-2]['content'].replace(
'{doc}', inputs.messages[-1]['content'])
positive.messages.pop(-1)
positive_encoded = self._encode_truncated(positive)
for key in positive_encoded:
_encoded[f'positive_{key}'] = positive_encoded[key]
_encoded[f'negative_{key}'] = []
labels.append(1)

rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
for i in range(rejected_len):
negative = deepcopy(inputs)
if '{doc}' in negative.messages[-2]['content']:
negative.messages[-2]['content'] = negative.messages[-2]['content'].replace(
'{doc}', negative.rejected_response[i])
negative.messages.pop(-1)
else:
negative.messages[-1]['content'] = negative.rejected_response[i]
negative.rejected_response = []
for positive in inputs.positive:
if instruction is not None and positive.system is None:
positive.system = instruction
positive.messages = chosen.messages + positive.messages
positive_encoded = self._encode_truncated(positive)
labels.append(1)
for key in positive_encoded:
_encoded[key].append(positive_encoded[key])

for negative in inputs.negative:
if instruction is not None and negative.system is None:
negative.system = instruction
negative.messages = chosen.messages + negative.messages
negative_encoded = self._encode_truncated(negative)
for key in negative_encoded:
_encoded[f'negative_{key}'].append(negative_encoded[key])
labels.append(0)
for key in negative_encoded:
_encoded[key].append(negative_encoded[key])

_encoded['labels'] = labels
return _encoded
@@ -1571,31 +1558,26 @@ def _reranker_data_collator(self,
batch: List[Dict[str, Any]],
*,
padding_to: Optional[int] = None) -> Dict[str, Any]:
import os
max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1))
max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7))
labels = []
labels_list = []
new_batch = []
for b in batch:
keys = [key for key in b.keys() if 'negative' in key]
max_neg = None
for key in keys:
value_list = b[key]
suffix = key[len('negative_'):]
max_neg = min(max_negative_samples, len(value_list))
for i, value in enumerate(value_list):
b[f'negative{i}_{suffix}'] = value
b.pop(key)

indexes = ['positive_']
if max_neg is not None:
for i in range(0, max_neg):
indexes.append(f'negative{i}_')
for prefix in indexes:
new_batch += self._fetch_inputs_startswith([b], prefix)
labels.extend(b.get('labels', None)[:max_negative_samples + 1])
labels = b.pop('labels')
positive_num = sum(labels)
negative_num = len(labels) - positive_num
max_positive = min(positive_num, max_positive_samples)
max_negative = min(negative_num, max_negative_samples)
for i in random.sample(range(positive_num), max_positive):
new_batch.append({'input_ids': b['input_ids'][i]})
labels_list.append(1)
for j in random.sample(range(negative_num), max_negative):
new_batch.append({'input_ids': b['input_ids'][j + positive_num]})
labels_list.append(0)

res = self._data_collator(new_batch, padding_to=padding_to)
if labels:
res['labels'] = torch.tensor(labels, dtype=torch.long)
if labels_list:
res['labels'] = torch.tensor(labels_list, dtype=torch.long)
return res

def _seq_cls_data_collator(self,
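As a reading aid, here is a condensed, standalone sketch of the sampling step in the new `_reranker_data_collator` (assuming each batch item carries one `input_ids` entry per encoded document, positives first, with parallel `labels`):

```python
import os
import random

def subsample_reranker_item(item):
    """Keep at most MAX_POSITIVE_SAMPLES positives and MAX_NEGATIVE_SAMPLES negatives.

    `item['labels']` lists positives (1) before negatives (0), matching the
    encoder output above; `item['input_ids']` is a parallel list of token ids.
    """
    max_pos = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1))
    max_neg = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7))
    labels = item['labels']
    pos_n = sum(labels)
    neg_n = len(labels) - pos_n
    picked, new_labels = [], []
    for i in random.sample(range(pos_n), min(pos_n, max_pos)):
        picked.append({'input_ids': item['input_ids'][i]})
        new_labels.append(1)
    for j in random.sample(range(neg_n), min(neg_n, max_neg)):
        picked.append({'input_ids': item['input_ids'][pos_n + j]})
        new_labels.append(0)
    return picked, new_labels
```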
18 changes: 12 additions & 6 deletions swift/llm/template/template/qwen.py
@@ -17,7 +17,7 @@
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import register_template
from ..template_inputs import StdTemplateInputs
from ..template_inputs import StdTemplateInputs, TemplateInputs
from ..template_meta import TemplateMeta
from ..utils import Context, Word, findall
from ..vision_utils import load_audio, load_batch, load_video_ovis2, load_video_ovis2_5
@@ -72,11 +72,17 @@ class Qwen3Template(ThinkingTemplate):
class Qwen3RerankerTemplate(Template):
instruction = 'Given a web search query, retrieve relevant passages that answer the query'

def _preprocess_inputs_reranker(self, inputs: StdTemplateInputs) -> None:
super()._preprocess_inputs_reranker(inputs)
query = inputs.messages[-2]['content']
user_message = '<Instruct>: ' + self.instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: {doc}'
inputs.messages[-2]['content'] = user_message
def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
if inputs.system is not None:
instruction = inputs.system
inputs.system = None
else:
instruction = self.instruction
query = inputs.messages[0]['content']
document = inputs.messages[1]['content']
user_message = '<Instruct>: ' + instruction + '\n' + '<Query>: ' + query + '\n' + '<Document>: ' + document
inputs.messages = [{'role': 'user', 'content': user_message}]
return inputs


qwen3_reranker_system = (
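For reference, a tiny sketch of the user message this template now builds (sample query and document are made up; a dataset-provided `system`, if present, would replace the default instruction):

```python
instruction = 'Given a web search query, retrieve relevant passages that answer the query'
query = 'what is MoE?'
document = 'Mixture-of-Experts routes tokens to expert subnetworks.'

# Mirrors Qwen3RerankerTemplate._preprocess_inputs: the query and the document
# are folded into a single user turn in the <Instruct>/<Query>/<Document> layout.
user_message = ('<Instruct>: ' + instruction + '\n'
                '<Query>: ' + query + '\n'
                '<Document>: ' + document)
messages = [{'role': 'user', 'content': user_message}]
```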
13 changes: 3 additions & 10 deletions swift/plugin/loss.py
@@ -652,12 +652,7 @@ def listwise_generative_reranker_loss(outputs,
negative_logits = logits[:, -1, negative_token_id] # [batch_size]

# Create binary classification logits for each sample
# Shape: [batch_size, 2] where dim=1 represents [negative, positive]
binary_logits = torch.stack([negative_logits, positive_logits], dim=1)

# Convert to relevance scores using softmax (probability of positive class)
binary_probs = torch.softmax(binary_logits, dim=1)
relevance_scores = binary_probs[:, 1] # Probability of positive class [batch_size]
logits = positive_logits - negative_logits

# Find positive sample indices to determine group boundaries
positive_indices = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
@@ -684,7 +679,7 @@ def listwise_generative_reranker_loss(outputs,
group_end = len(labels)

# Extract group relevance scores and labels
group_scores = relevance_scores[group_start:group_end] # [group_size]
group_scores = logits[group_start:group_end] # [group_size]
group_labels = labels[group_start:group_end] # [group_size]

# Skip groups that are too small
@@ -695,9 +690,7 @@ def listwise_generative_reranker_loss(outputs,
if group_labels[0] != 1:
continue # Skip malformed groups

# Convert relevance scores to logits for cross-entropy loss
# We use log to convert probabilities back to logits, then apply temperature
group_logits = torch.log(group_scores + 1e-8) / temperature # Add small epsilon for numerical stability
group_logits = group_scores / temperature

# The positive document is always at index 0 within the group
target = torch.tensor(0, dtype=torch.long, device=logits.device)
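A minimal sketch of the group-wise cross-entropy the simplified loss computes, with the `yes`/`no` logit difference used directly as the relevance score (group layout assumed from the diff: each group starts with its positive document; `temperature` is a hyperparameter):

```python
import torch
import torch.nn.functional as F

def listwise_group_loss(positive_logits, negative_logits, labels, temperature=1.0):
    """Cross-entropy over each [positive, negatives...] group.

    positive_logits / negative_logits: last-token logits of the "yes" / "no"
    tokens, shape [batch_size]; labels: 1/0 LongTensor marking positives,
    with every group starting at a positive document.
    """
    scores = positive_logits - negative_logits            # relevance score per document
    pos_idx = torch.nonzero(labels == 1, as_tuple=False).squeeze(-1)
    losses = []
    for k, start in enumerate(pos_idx.tolist()):
        end = pos_idx[k + 1].item() if k + 1 < len(pos_idx) else len(labels)
        group = scores[start:end] / temperature            # [group_size]
        target = torch.tensor(0, device=scores.device)     # the positive sits at index 0
        losses.append(F.cross_entropy(group.unsqueeze(0), target.unsqueeze(0)))
    return torch.stack(losses).mean() if losses else scores.sum() * 0.0
```

Compared with the earlier version, the softmax-and-log round trip over the [no, yes] pair is dropped and the raw logit difference feeds the group softmax directly.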