Commit: update with main (#1817)
* add cmakelist

* add paraformer-torch

* add debug for funasr-onnx-offline

* fix redefinition of jieba StdExtension.hpp

* add loading torch models

* update funasr-onnx-offline

* add SwitchArg for wss-server

* add SwitchArg for funasr-onnx-offline

* update cmakelist

* update funasr-onnx-offline-rtf

* add define condition

* add gpu define for offline-stream

* update com define

* update offline-stream

* update cmakelist

* update func CompileHotwordEmbedding

* add timestamp for paraformer-torch

* add C10_USE_GLOG for paraformer-torch

* update paraformer-torch

* fix func FunASRWfstDecoderInit

* update model.h

* fix func FunASRWfstDecoderInit

* fix tpass_stream

* update paraformer-torch

* add bladedisc for funasr-onnx-offline

* update comdefine

* update funasr-wss-server

* add log for torch

* fix GetValue BLADEDISC

* fix log

* update cmakelist

* update warmup to 10

* update funasrruntime

* add batch_size for wss-server

* add batch for bins

* add batch for offline-stream

* add batch for paraformer

* add batch for offline-stream

* fix func SetBatchSize

* add SetBatchSize for model

* add SetBatchSize for model

* fix func Forward

* fix padding

* update funasrruntime

* add dec reset for batch

* set batch default value

* add argv for CutSplit

* sort frame_queue

* sorted msgs

* fix FunOfflineInfer

* add dynamic batch for fetch

* fix FetchDynamic

* update run_server.sh

* update run_server.sh

* C++ HTTP POST server support (#1739)

* add cpp http server

* add some comment

* remove some comments

* del debug infos

* restore run_server.sh

* adapt to new model struct

* Fix onnxruntime build failure on macOS (#1748)

* Add files via upload

Add macOS build support

* Add files via upload

Add macOS support

* Add files via upload

target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
Wrap with an if(APPLE) guard

---------

Co-authored-by: Yabin Li <wucong.lyb@alibaba-inc.com>

* Delete docs/images/wechat.png

* Add files via upload

* fixed the issues about seaco-onnx timestamp

* fix bug (#1764)

When the ASR result contains `http`, punctuation prediction treats it as a URL

* fix empty asr result (#1765)

For speech segments whose decoding result is empty, set `text` to an empty string

* update export

* update export

* docs

* docs

* update export name

* docs

* update

* docs

* docs

* keep empty speech result (#1772)

* docs

* docs

* update wechat QRcode

* Add python funasr api support for websocket srv (#1777)

* add python funasr_api support

* minor change to README.md

* add core tools stream

* minor modifications

* fix bug for timeout

* support for buffer decode

* add ffmpeg decode for buffer

* libtorch demo

* update libtorch infer

* update utils

* update demo

* update demo

* update libtorch inference

* update model class

* update seaco paraformer

* bug fix

* bug fix

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* auto frontend

* Dev gzf exp (#1785)

* resume from step

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* batch

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* train_loss_avg train_acc_avg

* log step

* wav does not exist

* wav does not exist

* decoding

* decoding

* decoding

* wechat

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* decoding key

* dynamic batch

* start_data_split_i=0

* total_time/accum_grad

* total_time/accum_grad

* total_time/accum_grad

* update avg slice

* update avg slice

* sensevoice sanm

* sensevoice sanm

* sensevoice sanm

---------

Co-authored-by: 北念 <lzr265946@alibaba-inc.com>

* auto frontend

* update paraformer timestamp

* [Optimization] support bladedisc fp16 optimization (#1790)

* add cif_v1 and cif_export

* Update SDK_advanced_guide_offline_zh.md

* add cif_wo_hidden_v1

* [fix] fix empty asr result (#1794)

* English timestamps for vanilla paraformer

* wechat

* [fix] better solution for handling empty result (#1796)

* update scripts

* modify the qformer adaptor (#1804)

Co-authored-by: nichongjia-2007 <nichongjia@gmail.com>

* add ctc inference code (#1806)

Co-authored-by: haoneng.lhn <haoneng.lhn@alibaba-inc.com>

* Update auto_model.py

Fix a bug where `raw_text` was undefined when an empty string entered the speaker model

* Update auto_model.py

Fix undefined variables in `spk_model` after an empty recognition result

* update model name

* fix parameter 'quantize' unused issue (#1813)

Co-authored-by: ZihanLiao <liaozihan1@xdf.cn>

* wechat

* Update cif_predictor.py (#1811)

* Update cif_predictor.py

* modify cif_v1_export

In extreme cases, max_label_len calculated from batch_len misaligns with token_num

* Update cif_predictor.py

torch.cumsum precision degrades; use float64 instead

* update code

---------

Co-authored-by: 雾聪 <wucong.lyb@alibaba-inc.com>
Co-authored-by: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Co-authored-by: szsteven008 <97944818+szsteven008@users.noreply.github.com>
Co-authored-by: Ephemeroptera <605686962@qq.com>
Co-authored-by: 彭震东 <zhendong.peng@qq.com>
Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
Co-authored-by: 北念 <lzr265946@alibaba-inc.com>
Co-authored-by: xiaowan0322 <wanchen.swc@alibaba-inc.com>
Co-authored-by: zhuangzhong <zhuangzhong@corp.netease.com>
Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>
Co-authored-by: nichongjia-2007 <nichongjia@gmail.com>
Co-authored-by: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Co-authored-by: liugz18 <57401541+liugz18@users.noreply.github.com>
Co-authored-by: Marlowe <54339989+ZihanLiao@users.noreply.github.com>
Co-authored-by: ZihanLiao <liaozihan1@xdf.cn>
Co-authored-by: zhong zhuang <zhuangz@lamda.nju.edu.cn>
18 people committed Jun 19, 2024
1 parent de0b35b commit ad99b26
Showing 39 changed files with 514 additions and 168 deletions.
Binary file modified docs/images/wechat.png
18 changes: 9 additions & 9 deletions examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -12,17 +12,17 @@
device="cpu",
)

res = model.export(type="onnx", quantize=False)
res = model.export(type="torchscripts", quantize=False)
print(res)


# method2, inference from local path
from funasr import AutoModel
# # method2, inference from local path
# from funasr import AutoModel

model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
device="cpu",
)
# model = AutoModel(
# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
# device="cpu",
# )

res = model.export(type="onnx", quantize=False)
print(res)
# res = model.export(type="onnx", quantize=False)
# print(res)
1 change: 1 addition & 0 deletions examples/industrial_data_pretraining/ctc/demo.py
@@ -6,6 +6,7 @@
import sys
from funasr import AutoModel


model_dir = "/Users/zhifu/Downloads/modelscope_models/ctc_model"
input_file = (
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
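For orientation, a minimal sketch of running this CTC demo end to end, assuming the standard FunASR AutoModel.generate API; the local model path is a placeholder:

from funasr import AutoModel

# Placeholder path; point this at a downloaded CTC model directory.
model = AutoModel(model="/path/to/ctc_model", device="cpu")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
)
print(res)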
19 changes: 10 additions & 9 deletions examples/industrial_data_pretraining/paraformer/export.py
@@ -10,19 +10,20 @@
from funasr import AutoModel

model = AutoModel(
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)

res = model.export(type="onnx", quantize=False)
res = model.export(type="torchscripts", quantize=False)
# res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
print(res)


# method2, inference from local path
from funasr import AutoModel
# # method2, inference from local path
# from funasr import AutoModel

model = AutoModel(
model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
)
# model = AutoModel(
# model="/Users/zhifu/.cache/modelscope/hub/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
# )

res = model.export(type="onnx", quantize=False)
print(res)
# res = model.export(type="onnx", quantize=False)
# print(res)
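Taken together, the two export.py demos now default to TorchScript export instead of ONNX. A consolidated usage sketch, assuming only the AutoModel.export types already shown in these diffs:

from funasr import AutoModel

model = AutoModel(
    model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)

# TorchScript export, the new default in these demos
res = model.export(type="torchscripts", quantize=False)

# ONNX export, still available but commented out in the demos
# res = model.export(type="onnx", quantize=False)

# BladeDISC export traces against a sample input
# res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
print(res)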
35 changes: 13 additions & 22 deletions funasr/auto/auto_model.py
@@ -324,7 +324,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
)
end_vad = time.time()

# FIX(gcf): concat the vad clips for the sense voice model for better AED
if kwargs.get("merge_vad", False):
for i in range(len(res)):
@@ -467,23 +467,20 @@ def inference_with_vad(self, input, input_len=None, **cfg):
else:
result[k] += restored_data[j][k]

if not len(result["text"].strip()):
continue
return_raw_text = kwargs.get("return_raw_text", False)
# step.3 compute punc model
raw_text = None
if self.punc_model is not None:
if not len(result["text"].strip()):
if return_raw_text:
result["raw_text"] = ""
else:
deep_update(self.punc_kwargs, cfg)
punc_res = self.inference(
result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
)
raw_text = copy.copy(result["text"])
if return_raw_text:
result["raw_text"] = raw_text
result["text"] = punc_res[0]["text"]
else:
raw_text = None
deep_update(self.punc_kwargs, cfg)
punc_res = self.inference(
result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
)
raw_text = copy.copy(result["text"])
if return_raw_text:
result["raw_text"] = raw_text
result["text"] = punc_res[0]["text"]

# speaker embedding cluster after resorted
if self.spk_model is not None and kwargs.get("return_spk_res", True):
@@ -605,12 +602,6 @@ def export(self, input=None, **cfg):
)

with torch.no_grad():

if type == "onnx":
export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs)
else:
export_dir = export_utils.export_torchscripts(
model=model, data_in=data_list, **kwargs
)
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)

return export_dir
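The export path above replaces the inline onnx/torchscripts branch with a single export_utils.export call. A hedged sketch of what such a dispatcher plausibly looks like, reusing the export_onnx and export_torchscripts helpers from the removed branch; the actual funasr.utils.export_utils.export internals may differ:

def export(model, data_in=None, **kwargs):
    # Dispatch on the requested export type (sketch; unknown types could raise).
    export_type = kwargs.get("type", "onnx")
    if export_type == "onnx":
        return export_onnx(model=model, data_in=data_in, **kwargs)
    # "torchscripts" (and possibly "bladedisc") fall through to TorchScript tracing.
    return export_torchscripts(model=model, data_in=data_in, **kwargs)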
2 changes: 0 additions & 2 deletions funasr/datasets/llm_datasets/datasets.py
@@ -64,8 +64,6 @@ def __len__(self):

def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
2 changes: 0 additions & 2 deletions funasr/datasets/llm_datasets_qwenaudio/datasets.py
@@ -66,8 +66,6 @@ def __len__(self):

def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
2 changes: 0 additions & 2 deletions funasr/datasets/llm_datasets_vicuna/datasets.py
@@ -66,8 +66,6 @@ def __len__(self):

def __getitem__(self, index):
item = self.index_ds[index]
# import pdb;
# pdb.set_trace()
source = item["source"]
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
2 changes: 0 additions & 2 deletions funasr/datasets/sense_voice_datasets/datasets.py
@@ -72,8 +72,6 @@ def __len__(self):
return len(self.index_ds)

def __getitem__(self, index):
# import pdb;
# pdb.set_trace()

output = None
for idx in range(self.retry):
1 change: 0 additions & 1 deletion funasr/frontends/default.py
@@ -235,7 +235,6 @@ def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
# import pdb;pdb.set_trace()
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
2 changes: 1 addition & 1 deletion funasr/models/bicif_paraformer/cif_predictor.py
@@ -198,7 +198,7 @@ def forward(
output2 = self.upsample_cnn(_output)
output2 = output2.transpose(1, 2)
output2, _ = self.self_attn(output2, mask)
# import pdb; pdb.set_trace()

alphas2 = torch.sigmoid(self.cif_output2(output2))
alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
# repeat the mask in the T dimension to match the upsampled length
3 changes: 2 additions & 1 deletion funasr/models/bicif_paraformer/export_meta.py
@@ -29,7 +29,8 @@ def export_rebuild_model(model, **kwargs):
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)

model.export_name = "model"

return model

1 change: 0 additions & 1 deletion funasr/models/contextual_paraformer/decoder.py
@@ -424,7 +424,6 @@ def forward(
# contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
contextual_mask = self.make_pad_mask(contextual_length)
contextual_mask, _ = self.prepare_mask(contextual_mask)
# import pdb; pdb.set_trace()
contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
cx, tgt_mask, _, _, _ = self.bias_decoder(
x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask
19 changes: 18 additions & 1 deletion funasr/models/contextual_paraformer/export_meta.py
@@ -16,6 +16,21 @@ def __init__(self, model, **kwargs):
self.embedding = model.bias_embed
model.bias_encoder.batch_first = False
self.bias_encoder = model.bias_encoder

def export_dummy_inputs(self):
hotword = torch.tensor(
[
[10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
[100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
[100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
],
dtype=torch.int32,
)
# hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
return (hotword)


def export_rebuild_model(model, **kwargs):
@@ -59,7 +74,9 @@ def export_rebuild_model(model, **kwargs):
backbone_model.export_dynamic_axes = types.MethodType(
export_backbone_dynamic_axes, backbone_model
)
backbone_model.export_name = types.MethodType(export_backbone_name, backbone_model)

embedder_model.export_name = "model_eb"
backbone_model.export_name = "model"

return backbone_model, embedder_model

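The new export_dummy_inputs above supplies a padded hotword batch for tracing the bias embedder. A sketch of how such dummy inputs typically feed torch.onnx.export; the input/output names are illustrative assumptions, while "model_eb" follows the export_name set in this diff:

import torch

dummy_hotword = embedder_model.export_dummy_inputs()  # the (6, 10) int32 batch above
torch.onnx.export(
    embedder_model,
    (dummy_hotword,),
    "model_eb.onnx",                     # file name follows export_name = "model_eb"
    input_names=["hotword"],             # assumed name
    output_names=["hotword_embedding"],  # assumed name
)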
2 changes: 0 additions & 2 deletions funasr/models/lcbnet/model.py
@@ -23,8 +23,6 @@
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables

import pdb


@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
2 changes: 0 additions & 2 deletions funasr/models/llm_asr/model.py
@@ -168,8 +168,6 @@ def forward(
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
2 changes: 0 additions & 2 deletions funasr/models/llm_asr_nar/model.py
@@ -166,8 +166,6 @@ def forward(
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
2 changes: 0 additions & 2 deletions funasr/models/mfcca/mfcca_encoder.py
@@ -34,7 +34,6 @@
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math


@@ -363,7 +362,6 @@ def forward(
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
# pdb.set_trace()
if channel_size < 8:
repeat_num = math.ceil(8 / channel_size)
xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
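The retained MFCCA logic above tiles inputs with fewer than 8 microphone channels up to exactly 8 before cross-channel attention. A standalone sketch with hypothetical shapes:

import math
import torch

xs_pad = torch.randn(4, 3, 100, 80)  # (batch, channels=3, time, feature), hypothetical
channel_size = xs_pad.size(1)
if channel_size < 8:
    repeat_num = math.ceil(8 / channel_size)  # 3 channels: repeat 3x, then crop to 8
    xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
assert xs_pad.size(1) == 8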
31 changes: 10 additions & 21 deletions funasr/models/paraformer/cif_predictor.py
@@ -494,6 +494,8 @@ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
token_num_floor = torch.floor(token_num)

return hidden, alphas, token_num_floor


@torch.jit.script
def cif_v1_export(hidden, alphas, threshold: float):
device = hidden.device
@@ -516,9 +518,7 @@ def cif_v1_export(hidden, alphas, threshold: float):
fires[fire_idxs] = 1
fires = fires + prefix_sum - prefix_sum_floor

prefix_sum_hidden = torch.cumsum(
alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)

frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -530,9 +530,7 @@ def cif_v1_export(hidden, alphas, threshold: float):
shift_frames[shift_batch_idxs] = 0

remains = fires - torch.floor(fires)
remain_frames = (
remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]
)
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]

shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
@@ -541,14 +539,13 @@ def cif_v1_export(hidden, alphas, threshold: float):

max_label_len = batch_len.max()

frame_fires = torch.zeros(
batch_size, max_label_len, hidden_size, dtype=dtype, device=device
)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
return frame_fires, fires


@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
batch_size, len_time, hidden_size = hidden.size()
@@ -692,11 +689,8 @@ def cif_v1(hidden, alphas, threshold):
device = hidden.device
dtype = hidden.dtype
batch_size, len_time, hidden_size = hidden.size()
frames = torch.zeros(batch_size, len_time, hidden_size,
dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(
alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1
)
frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device)
prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1)

frames = prefix_sum_hidden[fire_idxs]
shift_frames = torch.roll(frames, 1, dims=0)
@@ -708,10 +702,7 @@ def cif_v1(hidden, alphas, threshold):
shift_frames[shift_batch_idxs] = 0

remains = fires - torch.floor(fires)
remain_frames = (
remains[fire_idxs].unsqueeze(-1).tile((1,
hidden_size)) * hidden[fire_idxs]
)
remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs]

shift_remain_frames = torch.roll(remain_frames, 1, dims=0)
shift_remain_frames[shift_batch_idxs] = 0
@@ -720,9 +711,7 @@ def cif_v1(hidden, alphas, threshold):

max_label_len = batch_len.max()

frame_fires = torch.zeros(
batch_size, max_label_len, hidden_size, dtype=dtype, device=device
)
frame_fires = torch.zeros(batch_size, max_label_len, hidden_size, dtype=dtype, device=device)
indices = torch.arange(max_label_len, device=device).expand(batch_size, -1)
frame_fires_idxs = indices < batch_len.unsqueeze(1)
frame_fires[frame_fires_idxs] = frames
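For orientation, a scalar reference of the continuous integrate-and-fire (CIF) rule that the vectorized cif_v1/cif_v1_export above reproduce with prefix sums; a simplified sketch, not the library implementation. It accumulates in float64, echoing the commit note that torch.cumsum precision degrades:

import torch

def cif_reference(hidden, alphas, threshold: float = 1.0):
    # hidden: (B, T, D) acoustic frames; alphas: (B, T) per-frame firing weights.
    batch, T, dim = hidden.shape
    outputs = []
    for b in range(batch):
        integrate = torch.zeros((), dtype=torch.float64)
        frame = torch.zeros(dim, dtype=torch.float64)
        fired = []
        for t in range(T):
            a = alphas[b, t].double()
            if integrate + a < threshold:
                # Below threshold: keep integrating this frame's weight.
                integrate = integrate + a
                frame = frame + a * hidden[b, t].double()
            else:
                # Fire: spend just enough weight to reach the threshold,
                part = threshold - integrate
                fired.append(frame + part * hidden[b, t].double())
                # then carry the remainder into the next output frame.
                integrate = a - part
                frame = integrate * hidden[b, t].double()
        out = torch.stack(fired).to(hidden.dtype) if fired else hidden.new_zeros(0, dim)
        outputs.append(out)
    return outputs  # one (num_fires_b, D) tensor per batch item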
1 change: 1 addition & 0 deletions funasr/models/paraformer/export_meta.py
@@ -31,6 +31,7 @@ def export_rebuild_model(model, **kwargs):
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)

model.export_name = 'model'
return model


4 changes: 0 additions & 4 deletions funasr/models/paraformer_streaming/model.py
@@ -50,8 +50,6 @@ def __init__(

super().__init__(*args, **kwargs)

# import pdb;
# pdb.set_trace()
self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)

self.scama_mask = None
@@ -83,8 +81,6 @@ def forward(
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
decoding_ind = kwargs.get("decoding_ind")
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
2 changes: 1 addition & 1 deletion funasr/models/sanm/attention.py
@@ -780,7 +780,7 @@ def forward_qkv(self, x, memory):
return q, k, v

def forward_attention(self, value, scores, mask, ret_attn):
scores = scores + mask
scores = scores + mask.to(scores.device)

self.attn = torch.softmax(scores, dim=-1)
context_layer = torch.matmul(self.attn, value) # (batch, head, time1, d_k)
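The one-line attention fix above guards against the mask living on a different device (typically CPU) than the scores (typically GPU). A minimal repro sketch with hypothetical shapes:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
scores = torch.randn(2, 4, 10, 10, device=device)  # (batch, heads, time1, time2)
mask = torch.zeros(2, 1, 10, 10)                   # built on CPU elsewhere

# On GPU, `scores + mask` raises a device-mismatch RuntimeError;
# moving the mask to scores' device is safe on CPU and GPU alike.
scores = scores + mask.to(scores.device)
attn = torch.softmax(scores, dim=-1)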