-
Notifications
You must be signed in to change notification settings - Fork 18
/
transcribe_file_offline.py
58 lines (50 loc) · 2.29 KB
/
transcribe_file_offline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
import argparse
import sys
from pathlib import Path

import grpc

import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
def parse_args() -> argparse.Namespace:
    """Build this script's command-line parser and parse ``sys.argv``.

    Besides the script's own ``--input-file`` option, connection options are
    added by ``add_connection_argparse_parameters``. ASR config options are
    added by ``add_asr_config_argparse_parameters``, with the max-alternatives,
    profanity-filter and word-time-offsets switches enabled.

    Returns:
        The parsed arguments. ``input_file`` is a ``Path`` with ``~`` expanded.
    """
    parser = argparse.ArgumentParser(
        description="Offline file transcription via Riva AI Services. \"Offline\" means that the entire audio "
        "content of `--input-file` is sent in one request and then a transcript for the whole file is received in "
        "one response.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.")
    parser = add_connection_argparse_parameters(parser)
    parser = add_asr_config_argparse_parameters(
        parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
    )
    args = parser.parse_args()
    # Expand "~" ourselves: the shell does not expand it when the path is quoted.
    args.input_file = args.input_file.expanduser()
    return args
def main() -> None:
    """Transcribe ``--input-file`` with one offline recognition request and print the result.

    The whole audio file is read into memory and sent to the Riva ASR service in a
    single ``offline_recognize`` call. The response is printed with
    ``riva.client.print_offline``.

    If the RPC fails, the gRPC error details are written to stderr and the
    process exits with status 1, so callers (shells, CI) can detect the failure.
    """
    args = parse_args()
    auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
    asr_service = riva.client.ASRService(auth)
    config = riva.client.RecognitionConfig(
        language_code=args.language_code,
        max_alternatives=args.max_alternatives,
        profanity_filter=args.profanity_filter,
        enable_automatic_punctuation=args.automatic_punctuation,
        verbatim_transcripts=not args.no_verbatim_transcripts,
        # Word time offsets are forced on when speaker diarization is requested
        # (presumably because speaker tags are reported per word; confirm).
        enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
    )
    riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
    riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
    riva.client.add_endpoint_parameters_to_config(
        config,
        args.start_history,
        args.start_threshold,
        args.stop_history,
        args.stop_history_eou,
        args.stop_threshold
    )
    with args.input_file.open('rb') as fh:
        data = fh.read()
    try:
        response = asr_service.offline_recognize(data, config)
    except grpc.RpcError as e:
        # Report the failure on stderr and exit non-zero; previously the error went
        # to stdout and the script still exited with status 0.
        print(e.details(), file=sys.stderr)
        sys.exit(1)
    riva.client.print_offline(response=response)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()