-
Notifications
You must be signed in to change notification settings - Fork 4
/
infer.py
46 lines (37 loc) · 1.66 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python
# coding: utf-8
import torch
import model_handling
from data_handling import DataCollatorForNormSeq2Seq
from model_handling import EncoderDecoderSpokenNorm
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Init tokenizer and model
tokenizer = model_handling.init_tokenizer()
model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm', cache_dir=model_handling.cache_dir)
data_collator = DataCollatorForNormSeq2Seq(tokenizer)
# Infer sample
bias_list = ['scotland', 'covid', 'delta', 'beta']
input_str = 'ngày hai tám tháng tư cô vít bùng phát ở sờ cốt lờn chiếm tám mươi phần trăm là biến chủng đen ta và bê ta'
inputs = tokenizer([input_str])
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
if len(bias_list) > 0:
bias = data_collator.encode_list_string(bias_list)
bias_input_ids = bias['input_ids']
bias_attention_mask = bias['attention_mask']
else:
bias_input_ids = None
bias_attention_mask = None
inputs = {
"input_ids": torch.tensor(input_ids),
"attention_mask": torch.tensor(attention_mask),
"bias_input_ids": bias_input_ids,
"bias_attention_mask": bias_attention_mask,
}
# Format input text **with** bias phrases
outputs = model.generate(**inputs, output_attentions=True, num_beams=1, num_return_sequences=1)
for output in outputs.cpu().detach().numpy().tolist():
# print('\n', tokenizer.decode(output, skip_special_tokens=True).split(), '\n')
print(tokenizer.sp_model.DecodePieces(tokenizer.decode(output, skip_special_tokens=True).split()))
# output: 28/4 covid bùng phát ở scotland chiếm 80 % là biến chủng delta và beta