Merge pull request #1763 from coolEphemeroptera/main
Fixed the issues with the seaco-onnx timestamps
R1ckShi committed May 28, 2024
2 parents 2c4ae54 + 18bbf14 commit 50b2668
Showing 2 changed files with 46 additions and 5 deletions.
10 changes: 8 additions & 2 deletions funasr/models/seaco_paraformer/export_meta.py
@@ -163,7 +163,11 @@ def export_backbone_forward(
    dha_ids = dha_pred.max(-1)[-1]
    dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
    decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
-    return decoder_out, pre_token_length, alphas
+
+    # get predicted timestamps
+    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+    return decoder_out, pre_token_length, us_alphas, us_cif_peak


def export_backbone_dummy_inputs(self):
@@ -178,7 +182,7 @@ def export_backbone_input_names(self):


def export_backbone_output_names(self):
return ["logits", "token_num", "alphas"]
return ["logits", "token_num", "us_alphas", "us_cif_peak"]


def export_backbone_dynamic_axes(self):
@@ -190,6 +194,8 @@ def export_backbone_dynamic_axes(self):
"bias_embed": {0: "batch_size", 1: "num_hotwords"},
"logits": {0: "batch_size", 1: "logits_length"},
"pre_acoustic_embeds": {1: "feats_length1"},
"us_alphas": {0: "batch_size", 1: "alphas_length"},
"us_cif_peak": {0: "batch_size", 1: "alphas_length"},
}


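After this export change, the ONNX backbone emits four outputs instead of three, in the order listed by export_backbone_output_names above. Below is a minimal sketch of reading those outputs with onnxruntime; the model path, input names, and feature shapes are illustrative assumptions for this commit, not values it defines.

import numpy as np
import onnxruntime as ort

# Hypothetical exported model path and dummy inputs; real features come from the FunASR frontend.
sess = ort.InferenceSession("seaco_paraformer/model.onnx")

feats = np.random.randn(1, 100, 560).astype(np.float32)     # assumed LFR+CMVN feature shape
feats_len = np.array([100], dtype=np.int32)
bias_embed = np.random.randn(1, 1, 512).astype(np.float32)  # assumed hotword embedding shape

# Output order follows export_backbone_output_names(): logits, token_num, us_alphas, us_cif_peak.
logits, token_num, us_alphas, us_cif_peak = sess.run(
    ["logits", "token_num", "us_alphas", "us_cif_peak"],
    {"speech": feats, "speech_lengths": feats_len, "bias_embed": bias_embed},
)
print(logits.shape, token_num, us_alphas.shape, us_cif_peak.shape)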
41 changes: 38 additions & 3 deletions runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -326,6 +326,9 @@ def __init__(
    def __call__(
        self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
    ) -> List:
+        # def __call__(
+        #     self, waveform_list:list, hotwords: str, **kwargs
+        # ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
@@ -345,15 +348,47 @@ def __call__(
            try:
                outputs = self.bb_infer(feats, feats_len, bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
+
+                if len(outputs) == 4:
+                    # for BiCifParaformer Inference
+                    us_alphas, us_peaks = outputs[2], outputs[3]
+                else:
+                    us_alphas, us_peaks = None, None
+
            except ONNXRuntimeError:
                # logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = [""]
            else:
                preds = self.decode(am_scores, valid_token_lens)
-                for pred in preds:
-                    pred = sentence_postprocess(pred)
-                    asr_res.append({"preds": pred})
+                if us_peaks is None:
+                    for pred in preds:
+                        if self.language == "en-bpe":
+                            pred = sentence_postprocess_sentencepiece(pred)
+                        else:
+                            pred = sentence_postprocess(pred)
+                        asr_res.append({"preds": pred})
+                else:
+                    for pred, us_peaks_ in zip(preds, us_peaks):
+                        raw_tokens = pred
+                        timestamp, timestamp_raw = time_stamp_lfr6_onnx(
+                            us_peaks_, copy.copy(raw_tokens)
+                        )
+                        text_proc, timestamp_proc, _ = sentence_postprocess(
+                            raw_tokens, timestamp_raw
+                        )
+                        # logging.warning(timestamp)
+                        if len(self.plot_timestamp_to):
+                            self.plot_wave_timestamp(
+                                waveform_list[0], timestamp, self.plot_timestamp_to
+                            )
+                        asr_res.append(
+                            {
+                                "preds": text_proc,
+                                "timestamp": timestamp_proc,
+                                "raw_tokens": raw_tokens,
+                            }
+                        )
        return asr_res

    def proc_hotword(self, hotwords):
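At the runtime level, the net effect is that each SeacoParaformer result gains "timestamp" and "raw_tokens" fields whenever the ONNX backbone also returns us_alphas/us_cif_peak. A rough usage sketch follows, assuming SeacoParaformer is importable from funasr_onnx as defined in this file; the model directory, audio path, and hotword string are placeholders.

from funasr_onnx import SeacoParaformer  # assumed import path; the class is defined in paraformer_bin.py

# Placeholder model directory and audio file; any exported seaco-paraformer ONNX model applies.
model = SeacoParaformer("./seaco-paraformer-onnx", batch_size=1)

results = model("example.wav", hotwords="魔搭 阿里巴巴")
for res in results:
    print(res["preds"])
    # Only present when the backbone exposes the upsampled CIF peaks (four outputs):
    print(res.get("timestamp"), res.get("raw_tokens"))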
