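# Inference JIT Optimization
This section shows how to use MindSpore's JIT (Just-In-Time) compilation to optimize inference for the `DeepSeek-R1-Distill-Qwen-1.5B` model. Enabling JIT compilation reduces the latency of each inference step, which improves chat response speed and the user experience.
For a reference implementation, see the example code: [deepseek-r1-distill-qwen-1.5b-jit.py](https://github.com/mindspore-courses/orange-pi-mindspore/blob/master/Online/training/01-DeepSeek-R1-Distill-Qwen-1.5B/deepseek-r1-distill-qwen-1.5b-jit.py)

>This tutorial applies only to the single-card environment of the MindSpore large-model platform (昇思大模型平台). When running on an Ascend development board, follow the example code linked above.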

In [1]:
%%capture captured_output
# The environment ships with mindspore==2.6.0 preinstalled; to use a different version, change the MINDSPORE_VERSION variable below
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.6.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

In [2]:
# Check the current mindspore version
!pip show mindspore

Name: mindspore
Version: 2.6.0
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, dill, numpy, packaging, pillow, protobuf, psutil, safetensors, scipy
Required-by: mindnlp


In [3]:
%%capture captured_output
# Install mindnlp 0.4.1
!pip uninstall mindnlp -y
!pip install https://xihe.mindspore.cn/coderepo/web/v1/file/MindSpore/mindnlp/main/media/mindnlp-0.4.1-py3-none-any.whl

In [4]:
import mindspore
from mindnlp.transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
from mindnlp.core import ops
from mindnlp.configs import set_pyboost
import time
import numpy as np

# Enable O2-level JIT optimization and graph kernel fusion
mindspore.set_context(
    enable_graph_kernel=True,
    mode=mindspore.GRAPH_MODE,
    jit_config={
        "jit_level": "O2",
    },
)



In [5]:
def sample_top_p(probs, p=0.9):
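    """
    Top-p sampling function, used to choose the next token during text generation.
    A NumPy-based implementation is preferred over native MindSpore here because it runs more efficiently on the Orange Pi.
    """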
    probs_np = probs.asnumpy()
    # Sort by probability in descending order
    sorted_indices = np.argsort(-probs_np, axis=-1)
    sorted_probs = np.take_along_axis(probs_np, sorted_indices, axis=-1)
    # Compute cumulative probabilities and build the mask
    cumulative_probs = np.cumsum(sorted_probs, axis=-1)
    mask = cumulative_probs - sorted_probs > p
    sorted_probs[mask] = 0.0
    sorted_probs = sorted_probs / np.sum(sorted_probs, axis=-1, keepdims=True)
    # Convert back to MindSpore Tensors
    sorted_probs_tensor = mindspore.Tensor(sorted_probs, dtype=mindspore.float32)
    sorted_indices_tensor = mindspore.Tensor(sorted_indices, dtype=mindspore.int32)
    next_token_idx = ops.multinomial(sorted_probs_tensor, 1)
    batch_size = probs.shape[0]
    batch_indices = ops.arange(0, batch_size, dtype=mindspore.int32).reshape(-1, 1)
    # The mindspore.ops-based implementation is used here because it has the best compatibility on the Orange Pi
    # next_token = sorted_indices_tensor[batch_indices, next_token_idx]
    next_token = mindspore.ops.gather(sorted_indices_tensor, next_token_idx, axis=1, batch_dims=1)
    # next_token = mindspore.mint.gather(sorted_indices_tensor, dim=1, index=next_token_idx)
    return next_token
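
The masking rule in `sample_top_p` can be hard to read at a glance, so the following is a minimal pure-Python sketch of the same idea on a single toy distribution. The helper names `top_p_filter` and `sample_top_p_py` are hypothetical and not part of the notebook or of MindSpore; the sketch only mirrors the `cumulative_probs - sorted_probs > p` test and the renormalization step.

```python
import random
from itertools import accumulate


def top_p_filter(probs, p=0.9):
    """Return (token ids sorted by probability, renormalized kept probabilities)."""
    # Sort token ids by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    sorted_probs = [probs[i] for i in order]
    # A token is dropped when the probability mass *before* it already exceeds p.
    # This is the same test as `cumulative_probs - sorted_probs > p`.
    cumulative = list(accumulate(sorted_probs))
    kept = [q if c - q <= p else 0.0 for q, c in zip(sorted_probs, cumulative)]
    total = sum(kept)
    return order, [q / total for q in kept]


def sample_top_p_py(probs, p=0.9, rng=random):
    """Sample one token id from the top-p nucleus (hypothetical helper)."""
    order, kept = top_p_filter(probs, p)
    # Draw a position in the sorted list, then map it back to the token id.
    position = rng.choices(range(len(order)), weights=kept, k=1)[0]
    return order[position]


# Toy distribution over four token ids (illustrative values only).
probs = [0.05, 0.5, 0.15, 0.3]
order, kept = top_p_filter(probs, p=0.7)
print(order)                          # token ids, most probable first
print([round(q, 4) for q in kept])    # renormalized probabilities in that order
```

With `p=0.7`, only the two most probable tokens (ids 1 and 3) keep a non-zero probability, and their probabilities are rescaled to sum to 1 before sampling.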

In [6]:
# This task uses the DeepSeek-R1-Distill-Qwen-1.5B model to complete the given prompts
prompts = [
    "请介绍一下自己。<think>",
    "My favorite all time favorite condiment is ketchup.",
]

# Generation parameters
NUM_TOKENS_TO_GENERATE = 40  # number of tokens to generate per input
TEMPERATURE = 0.8            # temperature (controls generation diversity)
TOP_P = 0.8                  # top-p sampling threshold

model_id = "MindSpore-Lab/DeepSeek-R1-Distill-Qwen-1.5B-FP16"
tokenizer = AutoTokenizer.from_pretrained(model_id, mirror="modelers")
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, mirror="modelers")

# Use model.jit() to compile the whole model as a static graph
model.jit()

inputs = tokenizer(prompts, return_tensors="ms", padding=True)
set_pyboost(False)

Qwen2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.


In [7]:
# Wrap the model inference function with the @mindspore.jit decorator
@mindspore.jit(jit_config=mindspore.JitConfig(jit_syntax_level='STRICT'))
def get_decode_one_tokens_logits(model, cur_token, input_pos, cache_position, past_key_values, temperature=TEMPERATURE, top_p=TOP_P):
    """Decode a single token and return the logits; this function can be optimized with JIT."""
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    return logits
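
As a rough intuition for why a JIT-compiled decode step pairs well with a fixed-shape cache such as the `StaticCache` used later in this notebook, here is a toy sketch. It assumes, as is typical for graph compilers, that a new input shape triggers a new compilation. `ToyJit` is a hypothetical stand-in written for this illustration; it is not MindSpore's implementation and does no real compilation.

```python
class ToyJit:
    """Toy stand-in for a JIT: 'compiles' once per distinct input shape."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled_shapes = set()  # shapes that have already been "compiled"
        self.compile_count = 0

    def __call__(self, values):
        shape = (len(values),)
        if shape not in self.compiled_shapes:
            # A real compiler would trace and optimize the graph here (slow).
            self.compiled_shapes.add(shape)
            self.compile_count += 1
        return self.fn(values)


@ToyJit
def total(values):
    return sum(values)


# Fixed-size buffer (analogous to a static cache): one shape, one compilation.
buffer = [0] * 8
for step in range(5):
    buffer[step] = step + 1
    total(buffer)
fixed_compiles = total.compile_count

# Growing list (analogous to a dynamic cache): a new shape at every step.
grow = ToyJit(sum)
history = []
for step in range(5):
    history.append(step + 1)
    grow(history)

print(fixed_compiles, grow.compile_count)  # prints "1 5"
```

In this toy model the fixed-size buffer pays the compilation cost once, while the growing list pays it at every step.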

In [8]:
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values, temperature=TEMPERATURE, top_p=TOP_P):
    """Decode a single token: choose the next token from the logits using the temperature and top-p settings."""
    logits = get_decode_one_tokens_logits(model, cur_token, input_pos, cache_position, past_key_values, temperature, top_p)

    if temperature > 0:
        probs = mindspore.mint.softmax(logits[:, -1] / temperature, dim=-1)
        new_token = sample_top_p(probs, top_p)
    else:
        new_token = mindspore.mint.argmax(logits[:, -1], dim=-1)[:, None]

    return new_token


batch_size, seq_length = inputs["input_ids"].shape

# Create a static cache (used to speed up autoregressive generation)
past_key_values = StaticCache(
    config=model.config, max_batch_size=2, max_cache_len=512, dtype=model.dtype
)
cache_position = ops.arange(seq_length)
generated_ids = ops.zeros(
    batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=mindspore.int32
)
generated_ids[:, cache_position] = inputs["input_ids"].to(mindspore.int32)

# Initial forward pass to obtain the first logits
logits = model(
    **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0]

# Generate the first new token
if TEMPERATURE > 0:
    probs = mindspore.mint.softmax(logits[:, -1] / TEMPERATURE, dim=-1)
    next_token = sample_top_p(probs, TOP_P)
else:
    next_token = mindspore.mint.argmax(logits[:, -1], dim=-1)[:, None]

generated_ids[:, seq_length] = next_token[:, 0]

# Autoregressive generation loop
cache_position = mindspore.tensor([seq_length + 1])
for i in range(1, NUM_TOKENS_TO_GENERATE):
    s = time.time()
    next_token = decode_one_tokens(model, next_token, None, cache_position, past_key_values)
    generated_ids[:, cache_position] = next_token.int()
    cache_position += 1
    t = time.time()
    # Print the time taken by this generation step
    print("[%d]:" % i, t - s)

text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
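
The loop above prints one latency per step. When reading such numbers it helps to separate the first step, which in the recorded run is far slower than the rest (consistent with a one-time compilation cost), from the steady-state steps. The following is a minimal sketch of that split; `summarize_latencies` is a hypothetical helper, and the sample values are rounded from the first five steps of the output recorded in this notebook.

```python
from statistics import mean


def summarize_latencies(latencies, warmup_steps=1):
    """Split per-step latencies (seconds) into warm-up and steady-state parts."""
    warmup = latencies[:warmup_steps]
    steady = latencies[warmup_steps:]
    if not steady:
        raise ValueError("need at least one step after warm-up")
    steady_mean = mean(steady)
    return {
        "warmup_total_s": sum(warmup),      # time spent in the warm-up step(s)
        "steady_mean_s": steady_mean,       # average latency of later steps
        "steps_per_s": 1.0 / steady_mean,   # decode steps per second
    }


# Rounded from the first five steps of the recorded run.
latencies = [18.388, 0.114, 0.112, 0.112, 0.113]
stats = summarize_latencies(latencies)
print("warm-up: %.3f s" % stats["warmup_total_s"])
print("steady-state mean: %.4f s/step" % stats["steady_mean_s"])
print("throughput: %.2f steps/s" % stats["steps_per_s"])
```

Each decode step produces one token per prompt in the batch, so tokens per second is the steps-per-second figure multiplied by the batch size.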

[ERROR] CORE(1765,ffffbdb5c020,python3.9):2025-07-11-01:46:21.129.015 [mindspore/core/utils/file_utils.cc:253] GetRealPath] Get realpath failed, path[/tmp/ipykernel_1765/2065961418.py]


[1]: 18.387741088867188
[2]: 0.11385655403137207
[3]: 0.1119377613067627
[4]: 0.11188721656799316
[5]: 0.11326718330383301
[6]: 0.1119697093963623
[7]: 0.11219620704650879
[8]: 0.11276006698608398
[9]: 0.11196351051330566
[10]: 0.11230587959289551
[11]: 0.11212468147277832
[12]: 0.11254668235778809
[13]: 0.11221146583557129
[14]: 0.1121068000793457
[15]: 0.1147160530090332
[16]: 0.11426210403442383
[17]: 0.11398482322692871
[18]: 0.11723160743713379
[19]: 0.11890697479248047
[20]: 0.1141505241394043
[21]: 0.11414146423339844
[22]: 0.11372566223144531
[23]: 0.11393380165100098
[24]: 0.11448025703430176
[25]: 0.115509033203125
[26]: 0.12166905403137207
[27]: 0.12461304664611816
[28]: 0.11643409729003906
[29]: 0.11528921127319336
[30]: 0.11597490310668945
[31]: 0.11604595184326172
[32]: 0.11636590957641602
[33]: 0.11526799201965332
[34]: 0.11614346504211426
[35]: 0.11723947525024414
[36]: 0.11544680595397949
[37]: 0.11590099334716797
[38]: 0.11560893058776855
[39]: 0.11474037170410156
['