-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Open
Labels
question — Further information is requested
Description
cpu环境,M2
同一个音频
在pytorch执行下是有文字输出,但是在onnx执行下就输出空
pytorch代码:
from funasr import AutoModel

# Local Paraformer checkpoint (PyTorch inference path).
MODEL_PATH = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
WAV_PATH = "test.wav"

asr_model = AutoModel(
    model=MODEL_PATH,
    disable_update=False,
    model_revision="v2.0.4",
    device="cpu",
)
asr_model.generate(
    WAV_PATH,
    language="auto",
    hotword="阿里巴巴",
    batch_size_s=300,
)
输出结果:res:[{'key': 'A132916050_1260_3300', 'text': '喂喂喂听到吗'}]
onnx代码:
from funasr_onnx import Paraformer

# Same exported model directory; ONNX inference path.
MODEL_PATH = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
WAV_PATH = "test.wav"

onnx_model = Paraformer(MODEL_PATH, batch_size=1, quantize=False)
# Also tried: onnx_model([WAV_PATH]) and onnx_model(WAV_PATH, hotword="") — same empty result.
onnx_model(WAV_PATH)
导出onnx的代码:
from funasr import AutoModel

MODEL_PATH = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"

# Load the PyTorch checkpoint on CPU.
exporter = AutoModel(
    model=MODEL_PATH,
    device="cpu",
)

# Export the loaded model to ONNX alongside the original weights.
exporter.export(
    output_dir=MODEL_PATH,
    type="onnx",
    quantize=False,
    opset=14,
)
排除音频问题,onnx可正常初始化
Gemini给出的解决方案,但是无法真实应用
import os
import json
import numpy as np
import onnxruntime as ort
import librosa
import time
from numpy.lib.stride_tricks import as_strided
class HighPerformanceASR:
    """Standalone onnxruntime wrapper for an exported Paraformer model.

    NOTE(review): this is a sketch from the issue thread. The frontend below
    (librosa mel spectrogram + log) is NOT guaranteed to reproduce the
    Kaldi-style fbank the model was trained on (window type, mel scale,
    dither, pre-emphasis all differ) — if the model emits an empty result,
    verify the features against funasr_onnx's own WavFrontend first.
    """

    def __init__(self, model_dir, threads=4):
        """Load the ONNX session, token list, and CMVN statistics.

        :param model_dir: directory with model.onnx / tokens.json / am.mvn
        :param threads: intra-op thread count for onnxruntime
        """
        # 1. Artifact paths.
        self.onnx_path = os.path.join(model_dir, "model.onnx")
        self.tokens_path = os.path.join(model_dir, "tokens.json")
        self.am_mvn_path = os.path.join(model_dir, "am.mvn")

        # 2. Token list (loaded once, not counted toward inference time).
        with open(self.tokens_path, "r", encoding="utf-8") as f:
            self.token_list = json.load(f)

        # 3. CMVN statistics.
        # BUG FIX: funasr exports am.mvn as a Kaldi-nnet *text* file
        # (<AddShift>/<Rescale> blocks), not a pickle, so the original
        # pickle.load() raised UnpicklingError. Parse the text format.
        neg_mean, rescale = self._load_kaldi_cmvn(self.am_mvn_path)
        # The <AddShift> vector stores the *negated* mean; negate it back so
        # predict() can keep the "(fbank - means) * istd" convention.
        self.means = (-neg_mean).astype(np.float32)
        self.istd = rescale.astype(np.float32)

        # 4. onnxruntime session tuning.
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = threads  # intra-op parallelism
        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # On macOS, CoreML can be enabled if needed:
        # providers = [('CoreMLExecutionProvider', {'MLComputeUnits': 'ALL'}), 'CPUExecutionProvider']
        self.session = ort.InferenceSession(
            self.onnx_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        self.input_names = [i.name for i in self.session.get_inputs()]

    @staticmethod
    def _load_kaldi_cmvn(path):
        """Parse a Kaldi-nnet text am.mvn file.

        Returns ``(add_shift, rescale)`` float64 vectors: the values after
        <AddShift> are the negated per-dim means, the values after <Rescale>
        are 1/stddev. Raises ValueError if either block is missing.
        """
        with open(path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        add_shift, rescale = None, None
        for i, line in enumerate(lines):
            fields = line.split()
            if not fields:
                continue
            if fields[0] in ("<AddShift>", "<Rescale>"):
                vals = lines[i + 1].split()
                if vals and vals[0] == "<LearnRateCoef>":
                    # Layout: <LearnRateCoef> 0 [ v0 v1 ... vN ]
                    vec = np.array(vals[3:-1], dtype=np.float64)
                    if fields[0] == "<AddShift>":
                        add_shift = vec
                    else:
                        rescale = vec
        if add_shift is None or rescale is None:
            raise ValueError(f"could not parse CMVN stats from {path}")
        return add_shift, rescale

    def _apply_lfr_vectorized(self, fbank, m=7, n=6):
        """Stack m consecutive frames every n frames via a strided view.

        Returns a [T, m*D] array, or None when fewer than m frames exist.
        NOTE(review): funasr's LFR additionally left-pads with (m-1)//2
        copies of the first frame and pads the tail — this version drops
        both, so frame alignment differs slightly; confirm against the
        training frontend.
        """
        # as_strided needs a contiguous buffer for the stride math to hold.
        fbank = np.ascontiguousarray(fbank)
        L, D = fbank.shape
        T = (L - m) // n + 1
        if T <= 0:
            return None
        # Overlapping zero-copy view; reshape below materializes the copy once.
        feat = as_strided(
            fbank,
            shape=(T, m, D),
            strides=(fbank.strides[0] * n, fbank.strides[0], fbank.strides[1]),
        )
        return feat.reshape(T, -1)  # shape [T, m*D], e.g. [T, 560] for D=80

    def predict(self, wav_path):
        """Run end-to-end recognition on one wav file.

        :param wav_path: path to an audio file readable by librosa
        :return: (recognized_text, total_seconds)
        """
        start_time = time.time()

        # A. Frontend: log-mel fbank (see class-level frontend caveat).
        y, sr = librosa.load(wav_path, sr=16000)
        fbank = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=80, n_fft=400, hop_length=160, win_length=400, center=False
        ).T
        fbank = np.log(fbank + 1e-6).astype(np.float32)

        # B. CMVN normalization (vectorized).
        fbank = (fbank - self.means) * self.istd

        # C. LFR frame stacking (vectorized).
        feat = self._apply_lfr_vectorized(fbank)
        if feat is None:  # audio shorter than one LFR window
            return "", 0

        feat = np.expand_dims(feat, axis=0)
        feat_len = np.array([feat.shape[1]], dtype=np.int32)

        # D. Optional hotword bias input (all-zeros = no biasing).
        inputs = {"speech": feat, "speech_lengths": feat_len}
        if "bias_embed" in self.input_names:
            inputs["bias_embed"] = np.zeros((1, 1, 512), dtype=np.float32)

        # E. ONNX inference (the dominant cost; runs in native code).
        outputs = self.session.run(None, inputs)

        # F. Greedy decode; ids <= 2 are assumed blank/sos/eos — verify
        # against tokens.json ordering.
        token_ids = np.argmax(outputs[0], axis=-1)[0]
        res_text = "".join(
            self.token_list[tid].replace("@@", "") for tid in token_ids if tid > 2
        )
        return res_text, time.time() - start_time
if __name__ == "__main__":
    MODEL_DIR = "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    WAV_FILE = "test.wav"

    # First pass warms up the session; its timing is discarded.
    engine = HighPerformanceASR(MODEL_DIR, threads=4)
    engine.predict(WAV_FILE)

    # Second pass is the measured run.
    text, duration = engine.predict(WAV_FILE)
    print(f"识别结果: {text}")
    print(f"推理耗时: {duration:.4f} 秒")
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
question — Further information is requested