diff --git a/docs/images/wechat.png b/docs/images/wechat.png
index b5d9a5535..491a1c3df 100644
Binary files a/docs/images/wechat.png and b/docs/images/wechat.png differ
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 22b1ac0ec..047e652a9 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -429,6 +429,10 @@ def inference_with_vad(self, input, input_len=None, **cfg):
             #     f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
             #     f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
 
+            if len(results_sorted) != n:
+                results_ret_list.append({"key": key, "text": "", "timestamp": []})
+                logging.info("decoding, utt: {}, empty result".format(key))
+                continue
             restored_data = [0] * n
             for j in range(n):
                 index = sorted_data[j][1]
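A note on the `inference_with_vad` guard above: when decoding yields no hypothesis for any VAD sub-segment, `results_sorted` comes back shorter than the segment count `n`, and the restore loop below it would index past the available results; the guard now emits a placeholder entry instead. A minimal sketch of the consumer side, assuming the usual `AutoModel` pipeline (the model aliases and file name are examples, not pinned to this PR):

```python
from funasr import AutoModel

# Example aliases only; any ASR + VAD model pairing goes through the same path.
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad")

results = model.generate(input="long_audio.wav")
for r in results:
    # With the guard, every utterance key still yields exactly one entry;
    # utterances whose sub-segments all decoded to nothing carry
    # text == "" and timestamp == [] rather than raising mid-batch.
    print(r["key"], repr(r["text"]))
```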
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index c93988328..93534fef0 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -125,6 +125,7 @@ def forward(self, x, ilens=None):
             olens = None
         olens = (ilens - 1) // self.k + 1
         masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+        if self.blocks is not None:
             for layer, block in enumerate(self.blocks):
                 x, masks = block(x, masks)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 8b1a9bbdb..7490310c9 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@ def forward(
             hidden, alphas, token_num, mask=mask
         )
 
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
 
         if target_length is None and self.tail_threshold > 0.0:
             token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@ def forward(
             hidden, alphas, token_num, mask=None
         )
 
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
         if target_length is None and self.tail_threshold > 0.0:
             token_num_int = torch.max(token_num).type(torch.int32).item()
             acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@ def forward(
             mask = mask.transpose(-1, -2).float()
             mask = mask.squeeze(-1)
         hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
@@ -494,7 +494,60 @@ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
         token_num_floor = torch.floor(token_num)
 
         return hidden, alphas, token_num_floor
 
 
+@torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
 
 @torch.jit.script
 def cif_export(hidden, alphas, threshold: float):
@@ -608,6 +661,74 @@ def cif(hidden, alphas, threshold):
 
     return torch.stack(list_ls, 0), fires
 
 
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+    batch_size, len_time = alphas.size()
+    device = alphas.device
+    dtype = alphas.dtype
+
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+    if return_fire_idxs:
+        return fires, fire_idxs
+    return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
+
 def cif_wo_hidden(alphas, threshold):
     batch_size, len_time = alphas.size()
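The `cif_v1` family above is a vectorized reformulation of the loop-based `cif`: a token fires wherever the running prefix sum of `alphas` crosses an integer boundary (detected by comparing the floors of adjacent prefix sums), and each token embedding is recovered from the prefix sum of `alphas * hidden` by subtracting the accumulation at the previous fire and correcting for the fractional weight carried across the boundary, instead of integrating frame by frame in Python. A toy comparison against the original `cif`, assuming both functions are imported from the module patched above; padded output lengths can differ, but the shared prefix should agree up to floating-point error:

```python
import torch

from funasr.models.paraformer.cif_predictor import cif, cif_v1

torch.manual_seed(0)
hidden = torch.randn(2, 10, 4)    # (batch, time, dim) encoder outputs
alphas = torch.rand(2, 10) * 0.8  # per-frame integration weights

# Note: the v1 variants fire on integer boundaries of the prefix sum,
# i.e. they effectively assume a threshold of 1.0.
emb_loop, fires_loop = cif(hidden, alphas, 1.0)   # original Python loop
emb_vec, fires_vec = cif_v1(hidden, alphas, 1.0)  # prefix-sum version

# Compare the integrated token embeddings on the common prefix.
n = min(emb_loop.size(1), emb_vec.size(1))
print((emb_loop[:, :n] - emb_vec[:, :n]).abs().max())
```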
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2be1..85967af3e 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
@@ -452,6 +454,7 @@ def inference(
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@ def inference(
                 predictor_outs[2],
                 predictor_outs[3],
             )
+            pre_token_length = pre_token_length.round().long()
             if torch.max(pre_token_length) < 1:
                 return []
 
@@ -564,10 +568,22 @@ def inference(
                 # Change integer-ids to tokens
                 token = tokenizer.ids2tokens(token_int)
                 text_postprocessed = tokenizer.tokens2text(token)
-                if not hasattr(tokenizer, "bpemodel"):
-                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-                result_i = {"key": key[i], "text": text_postprocessed}
+
+                if pred_timestamp:
+                    timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                        pre_peak_index[i],
+                        alphas[i],
+                        copy.copy(token),
+                        vad_offset=kwargs.get("begin_time", 0),
+                        upsample_rate=1,
+                    )
+                    if not hasattr(tokenizer, "bpemodel"):
+                        text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                    result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed}
+                else:
+                    if not hasattr(tokenizer, "bpemodel"):
+                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                    result_i = {"key": key[i], "text": text_postprocessed}
 
                 if ibest_writer is not None:
                     ibest_writer["token"][key[i]] = " ".join(token)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d77357..af61e5a8f 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@ def cif_wo_hidden(alphas, threshold):
 
 
 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
+    MAX_TOKEN_DURATION = 12  # 3 times upsampled
+    TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
diff --git a/runtime/docs/SDK_advanced_guide_offline_zh.md b/runtime/docs/SDK_advanced_guide_offline_zh.md
index 1cecb8881..902e16908 100644
--- a/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -149,6 +149,7 @@ python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode offline \
 --port 10095 port the service is deployed on
 --wav-path audio file to transcribe (a file path is supported)
 --hotword hotword file, one hotword per line, format (hotword weight): 阿里巴巴 20
+--thread-num number of client threads
 --use-itn whether to apply ITN; defaults to 1 (enabled), set to 0 to disable
 ```
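Taken together, the `model.py` and `timestamp_tools.py` changes let a plain Paraformer checkpoint emit token-level timestamps directly from its CIF peaks (`upsample_rate=1`, i.e. without the 3x upsampling used by dedicated timestamp models). A minimal usage sketch, assuming `pred_timestamp` is forwarded through `generate`'s kwargs as the diff implies (the model alias and audio file are placeholders):

```python
from funasr import AutoModel

# Placeholder alias; any Paraformer model with a CIF predictor should apply.
model = AutoModel(model="paraformer-zh")

res = model.generate(input="example.wav", pred_timestamp=True)
# Each result should now carry a "timestamp" list with one [start, end]
# pair per token (milliseconds, offset by begin_time when VAD supplies one).
print(res[0]["text"])
print(res[0]["timestamp"])
```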