Commit

update with main (#1800)
* add cmakelist

* add paraformer-torch

* add debug for funasr-onnx-offline

* fix redefinition of jieba StdExtension.hpp

* add loading torch models

* update funasr-onnx-offline

* add SwitchArg for wss-server

* add SwitchArg for funasr-onnx-offline

* update cmakelist

* update funasr-onnx-offline-rtf

* add define condition

* add gpu define for offline-stream

* update com define

* update offline-stream

* update cmakelist

* update func CompileHotwordEmbedding

* add timestamp for paraformer-torch

* add C10_USE_GLOG for paraformer-torch

* update paraformer-torch

* fix func FunASRWfstDecoderInit

* update model.h

* fix func FunASRWfstDecoderInit

* fix tpass_stream

* update paraformer-torch

* add bladedisc for funasr-onnx-offline

* update com define

* update funasr-wss-server

* add log for torch

* fix GetValue BLADEDISC

* fix log

* update cmakelist

* update warmup to 10

* update funasrruntime

* add batch_size for wss-server

* add batch for bins

* add batch for offline-stream

* add batch for paraformer

* add batch for offline-stream

* fix func SetBatchSize

* add SetBatchSize for model

* add SetBatchSize for model

* fix func Forward

* fix padding

* update funasrruntime

* add dec reset for batch

* set batch default value

* add argv for CutSplit

* sort frame_queue

* sorted msgs

* fix FunOfflineInfer

* add dynamic batch for fetch

* fix FetchDynamic

* update run_server.sh

* update run_server.sh

* cpp http post server support (#1739)

* add cpp http server

* add some comment

* remove some comments

* del debug infos

* restore run_server.sh

* adapt to new model struct

* Fixed the onnxruntime build failure on macOS (#1748)

* Add files via upload

Add macOS build support

* Add files via upload

Add macOS support

* Add files via upload

target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
Wrap in an if(APPLE) guard

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>

* Delete docs/images/wechat.png

* Add files via upload

* fixed the issues about seaco-onnx timestamp

* fix bug (#1764)

When the ASR result contains `http`, punctuation prediction treats it as a URL

* fix empty asr result (#1765)

For speech segments whose decoding result is empty, use an empty string for text

* docs

* docs

* docs

* docs

* docs

* keep empty speech result (#1772)

* docs

* docs

* update wechat QRcode

* Add python funasr api support for websocket srv (#1777)

* add python funasr_api support

* minor change to README.md

* add core tools stream

* minor modifications

* fix bug for timeout

* support for buffer decode

* add ffmpeg decode for buffer

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* Dev gzf exp (#1785)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav does not exist

* wav does not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* dynamic batch

* start_data_split_i=0

* total_time/accum_grad

* total_time/accum_grad

* total_time/accum_grad

* update avg slice

* update avg slice

* sensevoice sanm

* sensevoice sanm

* sensevoice sanm

---------

Co-authored-by: 北念 <lzr265946@alibaba-inc.com>

* auto frontend

* update paraformer timestamp

* add cif_v1 and cif_export

* Update SDK_advanced_guide_offline_zh.md

* add cif_wo_hidden_v1

* [fix] fix empty asr result (#1794)

* wechat

* [fix] better solution for handling empty result (#1796)

---------

Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com>
Co-authored-by: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Co-authored-by: szsteven008 <97944818+szsteven008@users.noreply.github.com>
Co-authored-by: Ephemeroptera <605686962@qq.com>
Co-authored-by: 彭震东 <zhendong.peng@qq.com>
Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
Co-authored-by: zhuangzhong <zhuangzhong@corp.netease.com>
Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>
11 people committed Jun 11, 2024
1 parent a8653d8 commit 20aa072
Showing 7 changed files with 153 additions and 10 deletions.
Binary file modified docs/images/wechat.png
4 changes: 4 additions & 0 deletions funasr/auto/auto_model.py
@@ -429,6 +429,10 @@ def inference_with_vad(self, input, input_len=None, **cfg):
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")

if len(results_sorted) != n:
results_ret_list.append({"key": key, "text": "", "timestamp": []})
logging.info("decoding, utt: {}, empty result".format(key))
continue
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
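The guard above implements the "keep empty speech result" behavior (#1772, #1794): when decoding returns fewer per-segment results than the n segments produced by VAD, the utterance still gets an output entry with empty text and a stable key instead of being dropped. A minimal sketch of the pattern, using hypothetical helper and variable names:

    import logging

    def finalize_utterance(key, results_sorted, n):
        # Sketch only: fall back to an empty transcript when decoding
        # produced fewer segment results than VAD emitted.
        if len(results_sorted) != n:
            logging.info("decoding, utt: %s, empty result", key)
            return {"key": key, "text": "", "timestamp": []}
        return {"key": key, "text": " ".join(r["text"] for r in results_sorted)}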
1 change: 1 addition & 0 deletions funasr/models/llm_asr/adaptor.py
@@ -125,6 +125,7 @@ def forward(self, x, ilens=None):
             olens = None
         olens = (ilens - 1) // self.k + 1
         masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+
         if self.blocks is not None:
             for layer, block in enumerate(self.blocks):
                 x, masks = block(x, masks)
127 changes: 124 additions & 3 deletions funasr/models/paraformer/cif_predictor.py
@@ -80,7 +80,7 @@ def forward(
             hidden, alphas, token_num, mask=mask
         )
 
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
 
         if target_length is None and self.tail_threshold > 0.0:
             token_num_int = torch.max(token_num).type(torch.int32).item()
@@ -245,7 +245,7 @@ def forward(
             hidden, alphas, token_num, mask=None
         )
 
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold)
         if target_length is None and self.tail_threshold > 0.0:
             token_num_int = torch.max(token_num).type(torch.int32).item()
             acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
@@ -449,7 +449,7 @@ def forward(
         mask = mask.transpose(-1, -2).float()
         mask = mask.squeeze(-1)
         hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+        acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold)
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
@@ -494,7 +494,60 @@ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
         token_num_floor = torch.floor(token_num)
 
         return hidden, alphas, token_num_floor
 
+
+@torch.jit.script
+def cif_v1_export(hidden, alphas, threshold: float):
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
 
 @torch.jit.script
 def cif_export(hidden, alphas, threshold: float):
@@ -608,6 +661,74 @@ def cif(hidden, alphas, threshold):
     return torch.stack(list_ls, 0), fires
 
 
+def cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=False):
+    batch_size, len_time = alphas.size()
+    device = alphas.device
+    dtype = alphas.dtype
+
+    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+    fires = torch.zeros(batch_size, len_time, dtype=dtype, device=device)
+
+    prefix_sum = torch.cumsum(alphas, dim=1)
+    prefix_sum_floor = torch.floor(prefix_sum)
+    dislocation_prefix_sum = torch.roll(prefix_sum, 1, dims=1)
+    dislocation_prefix_sum_floor = torch.floor(dislocation_prefix_sum)
+
+    dislocation_prefix_sum_floor[:, 0] = 0
+    dislocation_diff = prefix_sum_floor - dislocation_prefix_sum_floor
+
+    fire_idxs = dislocation_diff > 0
+    fires[fire_idxs] = 1
+    fires = fires + prefix_sum - prefix_sum_floor
+    if return_fire_idxs:
+        return fires, fire_idxs
+    return fires
+
+
+def cif_v1(hidden, alphas, threshold):
+    fires, fire_idxs = cif_wo_hidden_v1(alphas, threshold, return_fire_idxs=True)
+
+    device = hidden.device
+    dtype = hidden.dtype
+    batch_size, len_time, hidden_size = hidden.size()
+    frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
+    prefix_sum_hidden = torch.cumsum(
+        alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
+    )
+
+    frames = prefix_sum_hidden[fire_idxs]
+    shift_frames = torch.roll(frames, 1, dims=0)
+
+    batch_len = fire_idxs.sum(1)
+    batch_idxs = torch.cumsum(batch_len, dim=0)
+    shift_batch_idxs = torch.roll(batch_idxs, 1, dims=0)
+    shift_batch_idxs[0] = 0
+    shift_frames[shift_batch_idxs] = 0
+
+    remains = fires - torch.floor(fires)
+    remain_frames = (
+        remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
+    )
+
+    shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
+    shift_remain_frames[shift_batch_idxs] = 0
+
+    frames = frames - shift_frames + shift_remain_frames - remain_frames
+
+    max_label_len = batch_len.max()
+
+    frame_fires = torch.zeros(
+        batch_size, max_label_len, hidden_size, dtype=dtype, device=device
+    )
+    indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
+    frame_fires_idxs = indices < batch_len.unsqueeze(1)
+    frame_fires[frame_fires_idxs] = frames
+    return frame_fires, fires
+
+
 def cif_wo_hidden(alphas, threshold):
     batch_size, len_time = alphas.size()
 
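The two new entry points replace the per-frame Python loop in cif()/cif_export() with tensor ops: a cumulative sum of alphas locates the frames where the integrated weight crosses an integer boundary (the "fires"), and a cumulative sum of the alpha-weighted hidden states recovers each emitted token's embedding between consecutive fires. A quick smoke test, with illustrative shapes and the usual threshold of 1.0 (the floor-based trick assumes an integer threshold):

    import torch
    from funasr.models.paraformer.cif_predictor import cif, cif_v1

    torch.manual_seed(0)
    hidden = torch.randn(2, 30, 8)    # (batch, time, hidden)
    alphas = torch.rand(2, 30) * 0.5  # per-frame integration weights

    # Both return (acoustic_embeds, fires); cif_v1 computes them batch-wise.
    emb_loop, fires_loop = cif(hidden, alphas, 1.0)
    emb_vec, fires_vec = cif_v1(hidden, alphas, 1.0)
    print(emb_loop.shape, emb_vec.shape)  # padded token counts may differ slightly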
24 changes: 20 additions & 4 deletions funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
 # MIT License (https://opensource.org/licenses/MIT)
 
 import time
+import copy
 import torch
 import logging
 from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


@@ -452,6 +454,7 @@ def inference(
         is_use_lm = (
             kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
             predictor_outs[2],
             predictor_outs[3],
         )
+
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
@@ -564,10 +568,22 @@
             # Change integer-ids to tokens
             token = tokenizer.ids2tokens(token_int)
             text_postprocessed = tokenizer.tokens2text(token)
-            if not hasattr(tokenizer, "bpemodel"):
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
-            result_i = {"key": key[i], "text": text_postprocessed}
+
+            if pred_timestamp:
+                timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                    pre_peak_index[i],
+                    alphas[i],
+                    copy.copy(token),
+                    vad_offset=kwargs.get("begin_time", 0),
+                    upsample_rate=1,
+                )
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed}
+            else:
+                if not hasattr(tokenizer, "bpemodel"):
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "text": text_postprocessed}
 
             if ibest_writer is not None:
                 ibest_writer["token"][key[i]] = " ".join(token)
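With this change the plain Paraformer model can return token-level timestamps at inference time: pred_timestamp=True routes the predictor's peaks through ts_prediction_lfr6_standard with upsample_rate=1. A hedged usage sketch through the AutoModel front end — the model id and wav path are placeholders, and kwargs are assumed to reach inference() unchanged:

    from funasr import AutoModel

    model = AutoModel(model="paraformer-zh")  # placeholder model id
    res = model.generate(input="asr_example.wav", pred_timestamp=True)
    print(res[0]["text"])
    print(res[0].get("timestamp"))  # per-token [start_ms, end_ms] pairs when enabled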
6 changes: 3 additions & 3 deletions funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@ def cif_wo_hidden(alphas, threshold):


 def ts_prediction_lfr6_standard(
-    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+    us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
 ):
     if not len(char_list):
         return "", []
     START_END_THRESHOLD = 5
-    MAX_TOKEN_DURATION = 12
-    TIME_RATE = 10.0 * 6 / 1000 / 3  # 3 times upsampled
+    MAX_TOKEN_DURATION = 12  # 3 times upsampled
+    TIME_RATE = 10.0 * 6 / 1000 / upsample_rate
     if len(us_alphas.shape) == 2:
         alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
     else:
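The new upsample_rate parameter generalizes the previously hard-coded 3x upsampling: with 10 ms frontend frames and LFR-6 stacking (60 ms per encoder step), one predictor step spans 10 * 6 / 1000 / upsample_rate seconds. A quick sanity check of the constant, assuming those frame settings:

    def time_rate(upsample_rate: int) -> float:
        """Seconds per predictor step: 10 ms frames with LFR-6 stacking."""
        return 10.0 * 6 / 1000 / upsample_rate

    print(time_rate(3))  # 0.02 s, the old hard-coded value
    print(time_rate(1))  # 0.06 s, matching the upsample_rate=1 call added in model.py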
1 change: 1 addition & 0 deletions runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -149,6 +149,7 @@ python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode offline \
 --port 10095  deployment port
 --wav-path    path of the audio file to transcribe
 --hotword     hotword file, one hotword per line, format (hotword weight): 阿里巴巴 20
+--thread-num  number of client threads
 --use-itn     whether to apply ITN; default 1 (enabled), set 0 to disable
```

