# Distilling DeepSeek R1


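## Reusing eval.py for inference

Distillation works the same way as evaluation: run batch inference to collect chains of thought, then use rejection sampling to keep only the samples whose answers are correct.

To keep costs down, each training example is sampled only once this time.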

```bash
CURATOR_VIEWER=1 python eval.py --provider qianfan --data_path data/train.jsonl --model_name deepseek-r1 --temperature 0.6 --max_tokens 8192 --is_reasoning True --output_prefix distill --max_concurrent_requests 20
```
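The rejection-sampling step can be sketched as a filter over the generated samples. The following is a minimal sketch, not the logic in `eval.py`: the helper names `is_correct` and `rejection_sample` are hypothetical, and the check assumes each sample carries the `completion`, `numbers`, and `target` fields shown in the outputs later in this notebook. It is strict by design: an answer that also contains `= 524` is rejected because the result would count as an extra number.

```python
import ast
import operator
import re
from collections import Counter
from fractions import Fraction

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _eval(node):
    """Evaluate an AST made only of integers and + - * / (exact arithmetic)."""
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return Fraction(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    raise ValueError("unsupported expression")


def is_correct(completion, numbers, target):
    """True if the last <answer> block uses every number once and hits the target."""
    answers = ANSWER_RE.findall(completion)
    if not answers:
        return False
    expr = answers[-1].strip()
    # Every given number must appear exactly once in the expression.
    used = [int(tok) for tok in re.findall(r"\d+", expr)]
    if Counter(used) != Counter(numbers):
        return False
    try:
        value = _eval(ast.parse(expr, mode="eval"))
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False
    return value == target


def rejection_sample(samples):
    """Keep only the samples whose answer passes the check."""
    return [
        s for s in samples
        if is_correct(s["completion"], s["numbers"], s["target"])
    ]
```

Using `Fraction` avoids floating-point error in answers that rely on division, and walking the AST avoids calling `eval` on model output.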

## Using the provider's batch inference feature

In [28]:
import json
from eval import Reasoner, EvalResult

In [29]:
llm = Reasoner(backend="vllm", model_name="deepseek-r1", is_reasoning=True)

data = []
input_data = {}
with open("data/train.jsonl", "r") as f:
    for i, line in enumerate(f):
        d = json.loads(line)

        messages = llm.prompt(d)
        data.append({
            "custom_id": f"request-{i}",
            "body": {"messages": messages, "max_tokens": 8192, "temperature": 0.6},
        })
        input_data[f"request-{i}"] = d

print(json.dumps(data[0], indent=4))

{
    "custom_id": "request-0",
    "body": {
        "messages": [
            {
                "role": "user",
                "content": "Using the numbers 29, 11, 27, 10, create an equation that equals 524. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>."
            }
        ],
        "max_tokens": 8192,
        "temperature": 0.6
    }
}


In [30]:
with open("data/distill_batch_requests.jsonl", "w") as f:
    for d in data:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")
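Before uploading the request file, it can help to check that every line is well formed. The sketch below is an assumption about what is worth checking, not a provider requirement: the helper name `validate_batch_file` is hypothetical, and it only verifies the `custom_id` and `body.messages` fields used in the requests above.

```python
import json


def validate_batch_file(path):
    """Return the number of requests; raise ValueError on the first bad line."""
    seen = set()
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            # json.JSONDecodeError is a subclass of ValueError
            request = json.loads(line)
            custom_id = request.get("custom_id")
            if not custom_id or custom_id in seen:
                raise ValueError(
                    f"line {lineno}: missing or duplicate custom_id {custom_id!r}"
                )
            seen.add(custom_id)
            if not request.get("body", {}).get("messages"):
                raise ValueError(f"line {lineno}: empty messages")
    return len(seen)
```

Unique `custom_id` values matter here because the results are later joined back to the inputs through `input_data[custom_id]`.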

In [31]:
# Place the downloaded results.jsonl under output/

results = []

with open("output/results.jsonl", "r") as f:
    for line in f:
        d = json.loads(line)
        custom_id = d["custom_id"]
        sample = input_data[custom_id]  # avoid shadowing the built-in `input`
        # `response` may be missing or null when the request failed
        response = (d.get("response") or {}).get("body")
        if not response:
            results.append({
                "correct": False,
                "correct_reason": "call llm failed",
                "target": sample["target"],
                "numbers": sample["numbers"],
                "prompt": "",
                "completion": "",
                "reasoning": "",
            })
            # Nothing to parse for a failed request
            continue
        results.append(llm.parse(sample, response)[0])

results[0]


{'target': 67,
 'numbers': [93, 14, 3, 98],
 'prompt': 'Using the numbers 93, 14, 3, 98, create an equation that equals 67. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>.',
 'completion': "<think>Okay, let's see. I need to use all the numbers 93, 14, 3, and 98 to make an equation that equals 67. Hmm. Let's start by thinking about the operations. Maybe subtract some numbers first. Since 98 is a big number, if I subtract something from it, maybe that's part of the solution. Let me try 98 minus something. Let's see, 98 minus 14 is 84, but that's still bigger than 67. If I subtract 93 from 98, that's 5. Then maybe add 14 and 3? Wait, but I have to use all numbers. Let me write down the numbers: 93, 14, 3, 98.\n\nAnother approach: maybe combine some numbers with multiplication or division. Let's s

In [32]:
total_count = len(results)
correct_count = sum([1 for r in results if r["correct"]])
accuracy = correct_count / total_count
print(f"Accuracy: {correct_count}/{total_count} ({accuracy:.2%})")

Accuracy: 886/1000 (88.60%)
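To see why the remaining samples were rejected, the incorrect results can be grouped by their `correct_reason` field. This is a minimal sketch: the helper name `failure_breakdown` is hypothetical, and apart from `"call llm failed"`, the reason string in the sample data is made up for illustration.

```python
from collections import Counter


def failure_breakdown(results):
    """Count incorrect results by their `correct_reason` field."""
    return Counter(
        r.get("correct_reason", "unknown")
        for r in results
        if not r["correct"]
    )


# Illustrative data only; "wrong result" is a hypothetical reason string.
sample_results = [
    {"correct": True, "correct_reason": ""},
    {"correct": False, "correct_reason": "call llm failed"},
    {"correct": False, "correct_reason": "wrong result"},
    {"correct": False, "correct_reason": "wrong result"},
]

for reason, count in failure_breakdown(sample_results).most_common():
    print(f"{reason}: {count}")
```

In the notebook, passing `results` instead of `sample_results` would give the breakdown for the real run.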


In [33]:
no_reasoning_llm = Reasoner(backend="vllm", model_name="deepseek-r1", is_reasoning=False)

sft_data = []
sft_data_reasoning = []

for d in results:
    if not d["correct"]:
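        # Skip samples with an incorrect answer; they do not go into the SFT data
        continue
    if len(d["completion"]) >= 7500:
        # Drop completions of 7500 characters or more.
        # 7500 is chosen so that the total length stays within 8192.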
        continue
    messages = no_reasoning_llm.prompt(d)
    messages.append({
        "role": "assistant",
        "content": d["completion"]
    })
    sft_data.append({
        "messages": messages,
    })
    messages_reasoning = llm.prompt(d)
    messages_reasoning.append({
        "role": "assistant",
        "content": d["completion"]
    })
    sft_data_reasoning.append({
        "messages": messages_reasoning,
    })

print(f"SFT Data: {len(sft_data)}")
with open("data/train_sft_distill.jsonl", "w") as f:
    for d in sft_data:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")

with open("data/train_sft_distill_reasoning.jsonl", "w") as f:
    for d in sft_data_reasoning:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")

# TODO: consistently keep or consistently strip the equation in the answers so the data is uniform

SFT Data: 712
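One possible way to address the TODO above is to strip the result part from every answer, so that `<answer>a + b = c</answer>` becomes `<answer>a + b</answer>`. This is a sketch under assumptions, not the author's implementation: the helper name `strip_result` is hypothetical, and it assumes the expression always sits to the left of the `=` sign.

```python
import re

ANSWER_BLOCK = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def strip_result(completion):
    """Rewrite each <answer> block to keep only the text left of the first '='."""

    def _rewrite(match):
        expression = match.group(1).split("=")[0].strip()
        return f"<answer>{expression}</answer>"

    return ANSWER_BLOCK.sub(_rewrite, completion)
```

It could be applied to `d["completion"]` before the messages are assembled. Answers written with the result first, such as `524 = ...`, would need separate handling.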
