diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8ef4d3291b72..034a855930fb 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -14,7 +14,6 @@ transformers>=4.8.1 sentencepiece<1.0.0 webdataset>=0.1.48,<=0.1.62 tqdm>=4.41.0 -opencc -pangu -jieba numba +grpcio +grpcio-tools diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 8dbad1f43d39..31cde32a66e8 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -14,3 +14,6 @@ sacremoses>=0.0.43 nltk==3.6.2 wordninja==2.0.0 fasttext +opencc +pangu +jieba diff --git a/tools/nmt_grpc_service/README.md b/tools/nmt_grpc_service/README.md new file mode 100644 index 000000000000..c151b756d94e --- /dev/null +++ b/tools/nmt_grpc_service/README.md @@ -0,0 +1,63 @@ +# NMT gRPC Server Getting Started + +## Starting the NMT server + +Start the server by specifying multiple models (.nemo files) via the `--model` argument: + +``` +python server.py --model models/en-es.nemo --model models/en-de.nemo --model models/en-fr.nemo +``` + +If working with the outputs of a speech recognition system without punctuation and capitalization, you can provide the path to a .nemo model file that performs punctuation and capitalization ex: https://ngc.nvidia.com/catalog/models/nvidia:nemo:punctuation_en_bert via the `--punctuation_model` flag. + +NOTE: The server will throw an error if NMT models do not have have src_language and tgt_language attributes. + +## Notes + +Port can be overridden with `--port` flag. Default is 50052. Beam decoder parameters can also be set at server start time. See `--help` for more details. + +## Example Text Client + +``` +python client.py --target_language de --source_language en --text Hello +``` + +# ASR with Riva + Translation with NeMo cascade + +Below, we'll describe how to use Riva's ASR models and NeMo's NMT models to do speech translation via a cascade pipeline. + +## Installing Riva and python APIs + +Follow instructions in https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html to setup and install Riva along with the python whl. + +For latest setup instructions follow the link above, since instructions below may not be up-to-date. + +```bash +ngc registry resource download-version nvidia/riva/riva_quickstart:1.4.0-beta +``` + +```bash +cd riva_quickstart_v1.4.0-beta +bash riva_init.sh +bash riva_start.sh + +pip install riva_api-1.4.0b0-py3-none-any.whl +``` + +This will start a Riva Speech Recognition service and `nvidia-smi` should show `tritonserver` running on GPU0. + +## ASR + NMT + +Start the NeMo translation server using instructions in the previous section (with or without a punctuation and capitalization model). + +Run the cascade client using a single channel audio wav file specifying the target language to translate into. By default, Riva ASR is in Englisha and so we specify only the target language to translate into. + +```bash +python asr_nmt_client.py --audio-file recording.mono.wav --asr_punctuation --target_language de +``` + +To view ASR outputs only + +```bash +python asr_nmt_client.py --audio-file recording.mono.wav --asr_punctuation --target_language de --asr_only +``` diff --git a/tools/nmt_grpc_service/api/nmt_pb2.py b/tools/nmt_grpc_service/api/nmt_pb2.py new file mode 100644 index 000000000000..14edb2e8bc10 --- /dev/null +++ b/tools/nmt_grpc_service/api/nmt_pb2.py @@ -0,0 +1,286 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: nmt.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='nmt.proto', + package='nvidia.riva.nmt', + syntax='proto3', + serialized_options=b'Z\026nvidia.com/riva_speech\370\001\001', + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\tnmt.proto\x12\x0fnvidia.riva.nmt\"W\n\x14TranslateTextRequest\x12\r\n\x05texts\x18\x01 \x03(\t\x12\x17\n\x0fsource_language\x18\x03 \x01(\t\x12\x17\n\x0ftarget_language\x18\x04 \x01(\t\"4\n\x0bTranslation\x12\x13\n\x0btranslation\x18\x01 \x01(\t\x12\x10\n\x08language\x18\x02 \x01(\t\"K\n\x15TranslateTextResponse\x12\x32\n\x0ctranslations\x18\x01 \x03(\x0b\x32\x1c.nvidia.riva.nmt.Translation2q\n\rRivaTranslate\x12`\n\rTranslateText\x12%.nvidia.riva.nmt.TranslateTextRequest\x1a&.nvidia.riva.nmt.TranslateTextResponse\"\x00\x42\x1bZ\x16nvidia.com/riva_speech\xf8\x01\x01\x62\x06proto3', +) + + +_TRANSLATETEXTREQUEST = _descriptor.Descriptor( + name='TranslateTextRequest', + full_name='nvidia.riva.nmt.TranslateTextRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='texts', + full_name='nvidia.riva.nmt.TranslateTextRequest.texts', + index=0, + number=1, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='source_language', + full_name='nvidia.riva.nmt.TranslateTextRequest.source_language', + index=1, + number=3, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='target_language', + full_name='nvidia.riva.nmt.TranslateTextRequest.target_language', + index=2, + number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=30, + serialized_end=117, +) + + +_TRANSLATION = _descriptor.Descriptor( + name='Translation', + full_name='nvidia.riva.nmt.Translation', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='translation', + full_name='nvidia.riva.nmt.Translation.translation', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + _descriptor.FieldDescriptor( + name='language', + full_name='nvidia.riva.nmt.Translation.language', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=b"".decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=119, + serialized_end=171, +) + + +_TRANSLATETEXTRESPONSE = _descriptor.Descriptor( + name='TranslateTextResponse', + full_name='nvidia.riva.nmt.TranslateTextResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='translations', + full_name='nvidia.riva.nmt.TranslateTextResponse.translations', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[], + serialized_start=173, + serialized_end=248, +) + +_TRANSLATETEXTRESPONSE.fields_by_name['translations'].message_type = _TRANSLATION +DESCRIPTOR.message_types_by_name['TranslateTextRequest'] = _TRANSLATETEXTREQUEST +DESCRIPTOR.message_types_by_name['Translation'] = _TRANSLATION +DESCRIPTOR.message_types_by_name['TranslateTextResponse'] = _TRANSLATETEXTRESPONSE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TranslateTextRequest = _reflection.GeneratedProtocolMessageType( + 'TranslateTextRequest', + (_message.Message,), + { + 'DESCRIPTOR': _TRANSLATETEXTREQUEST, + '__module__': 'nmt_pb2' + # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextRequest) + }, +) +_sym_db.RegisterMessage(TranslateTextRequest) + +Translation = _reflection.GeneratedProtocolMessageType( + 'Translation', + (_message.Message,), + { + 'DESCRIPTOR': _TRANSLATION, + '__module__': 'nmt_pb2' + # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.Translation) + }, +) +_sym_db.RegisterMessage(Translation) + +TranslateTextResponse = _reflection.GeneratedProtocolMessageType( + 'TranslateTextResponse', + (_message.Message,), + { + 'DESCRIPTOR': _TRANSLATETEXTRESPONSE, + '__module__': 'nmt_pb2' + # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextResponse) + }, +) +_sym_db.RegisterMessage(TranslateTextResponse) + + +DESCRIPTOR._options = None + +_RIVATRANSLATE = _descriptor.ServiceDescriptor( + name='RivaTranslate', + full_name='nvidia.riva.nmt.RivaTranslate', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=250, + serialized_end=363, + methods=[ + _descriptor.MethodDescriptor( + name='TranslateText', + full_name='nvidia.riva.nmt.RivaTranslate.TranslateText', + index=0, + containing_service=None, + input_type=_TRANSLATETEXTREQUEST, + output_type=_TRANSLATETEXTRESPONSE, + serialized_options=None, + create_key=_descriptor._internal_create_key, + ), + ], +) +_sym_db.RegisterServiceDescriptor(_RIVATRANSLATE) + +DESCRIPTOR.services_by_name['RivaTranslate'] = _RIVATRANSLATE + +# @@protoc_insertion_point(module_scope) diff --git a/tools/nmt_grpc_service/api/nmt_pb2_grpc.py b/tools/nmt_grpc_service/api/nmt_pb2_grpc.py new file mode 100644 index 000000000000..637f2786e32e --- /dev/null +++ b/tools/nmt_grpc_service/api/nmt_pb2_grpc.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import nmt_pb2 as nmt__pb2 + + +class RivaTranslateStub(object): + """Riva NLP Services implement task-specific APIs for popular NLP tasks including + intent recognition (as well as slot filling), and entity extraction. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.TranslateText = channel.unary_unary( + '/nvidia.riva.nmt.RivaTranslate/TranslateText', + request_serializer=nmt__pb2.TranslateTextRequest.SerializeToString, + response_deserializer=nmt__pb2.TranslateTextResponse.FromString, + ) + + +class RivaTranslateServicer(object): + """Riva NLP Services implement task-specific APIs for popular NLP tasks including + intent recognition (as well as slot filling), and entity extraction. + """ + + def TranslateText(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_RivaTranslateServicer_to_server(servicer, server): + rpc_method_handlers = { + 'TranslateText': grpc.unary_unary_rpc_method_handler( + servicer.TranslateText, + request_deserializer=nmt__pb2.TranslateTextRequest.FromString, + response_serializer=nmt__pb2.TranslateTextResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler('nvidia.riva.nmt.RivaTranslate', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class RivaTranslate(object): + """Riva NLP Services implement task-specific APIs for popular NLP tasks including + intent recognition (as well as slot filling), and entity extraction. + """ + + @staticmethod + def TranslateText( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + '/nvidia.riva.nmt.RivaTranslate/TranslateText', + nmt__pb2.TranslateTextRequest.SerializeToString, + nmt__pb2.TranslateTextResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/tools/nmt_grpc_service/asr_nmt_client.py b/tools/nmt_grpc_service/asr_nmt_client.py new file mode 100644 index 000000000000..fbdc4ea5933b --- /dev/null +++ b/tools/nmt_grpc_service/asr_nmt_client.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys +import wave + +import api.nmt_pb2 as nmt +import api.nmt_pb2_grpc as nmtsrv +import grpc +import pyaudio +import riva_api.audio_pb2 as riva +import riva_api.riva_asr_pb2 as rivaasr +import riva_api.riva_asr_pb2_grpc as rivaasr_srv + + +def get_args(): + parser = argparse.ArgumentParser(description="Streaming transcription via Riva AI Speech Services") + parser.add_argument("--riva-server", default="localhost:50051", type=str, help="URI to GRPC server endpoint") + parser.add_argument("--audio-file", required=True, help="path to local file to stream") + parser.add_argument("--output-device", type=int, default=None, help="output device to use") + parser.add_argument("--list-devices", action="store_true", help="list output devices indices") + parser.add_argument("--nmt-server", default="localhost:50052", help="port on which NMT server runs") + parser.add_argument("--asr_only", action="store_true", help="Whether to skip MT and just display") + parser.add_argument("--target_language", default="es", help="Target language to translate into.") + parser.add_argument( + "--asr_punctuation", + action="store_true", + help="Whether to use Riva's punctuation model for ASR transcript postprocessing.", + ) + return parser.parse_args() + + +def listen_print_loop(responses, nmt_stub, target_language, asr_only=False): + num_chars_printed = 0 + prev_utterances = [] + for response in responses: + if not response.results: + continue + result = response.results[0] + if not result.alternatives: + continue + transcript = result.alternatives[0].transcript + original_transcript = transcript + if not asr_only: + req = nmt.TranslateTextRequest(texts=[transcript], source_language='en', target_language=target_language) + translation = nmt_stub.TranslateText(req).translations[0].translation + transcript = translation + overwrite_chars = ' ' * (num_chars_printed - len(transcript)) + if not result.is_final: + sys.stdout.write(">> " + transcript + overwrite_chars + '\r') + sys.stdout.flush() + num_chars_printed = len(transcript) + 3 + else: + print("## " + transcript + overwrite_chars + "\n") + num_chars_printed = 0 + prev_utterances.append(original_transcript) + + +CHUNK = 1024 +args = get_args() +wf = wave.open(args.audio_file, 'rb') +channel = grpc.insecure_channel(args.riva_server) +client = rivaasr_srv.RivaSpeechRecognitionStub(channel) +nmt_channel = grpc.insecure_channel(args.nmt_server) +nmt_stub = nmtsrv.RivaTranslateStub(nmt_channel) +config = rivaasr.RecognitionConfig( + encoding=riva.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wf.getframerate(), + language_code="en-US", + max_alternatives=1, + enable_automatic_punctuation=args.asr_punctuation, +) +streaming_config = rivaasr.StreamingRecognitionConfig(config=config, interim_results=True) + +# instantiate PyAudio (1) +p = pyaudio.PyAudio() +if args.list_devices: + for i in range(p.get_device_count()): + info = p.get_device_info_by_index(i) + if info['maxOutputChannels'] < 1: + continue + print(f"{info['index']}: {info['name']}") + sys.exit(0) + +# open stream (2) +stream = p.open( + output_device_index=args.output_device, + format=p.get_format_from_width(wf.getsampwidth()), + channels=wf.getnchannels(), + rate=wf.getframerate(), + output=True, +) + +# read data +def generator(w, s): + d = w.readframes(CHUNK) + yield rivaasr.StreamingRecognizeRequest(streaming_config=s) + while len(d) > 0: + yield rivaasr.StreamingRecognizeRequest(audio_content=d) + stream.write(d) + d = w.readframes(CHUNK) + return + + +responses = client.StreamingRecognize(generator(wf, streaming_config)) +listen_print_loop(responses, nmt_stub, target_language=args.target_language, asr_only=args.asr_only) +# stop stream (4) +stream.stop_stream() +stream.close() +# close PyAudio (5) +p.terminate() diff --git a/tools/nmt_grpc_service/client.py b/tools/nmt_grpc_service/client.py new file mode 100644 index 000000000000..4b6fde82dbdf --- /dev/null +++ b/tools/nmt_grpc_service/client.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from time import time + +import api.nmt_pb2 as nmt +import api.nmt_pb2_grpc as nmtsrv +import grpc + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--target_language", default="es", type=str, required=True) + parser.add_argument("--source_language", default="en", type=str, required=True) + parser.add_argument("--text", default="Hello!", type=str, required=True) + parser.add_argument("--port", default=50052, type=int, required=False) + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = get_args() + with grpc.insecure_channel(f'localhost:{args.port}') as channel: + stub = nmtsrv.RivaTranslateStub(channel) + + iterations = 1 + start_time = time() + for _ in range(iterations): + req = nmt.TranslateTextRequest( + texts=[args.text], source_language=args.source_language, target_language=args.target_language + ) + result = stub.TranslateText(req) + end_time = time() + print(f"Time to complete {iterations} synchronous requests: {end_time-start_time}") + print(result) diff --git a/tools/nmt_grpc_service/nmt.proto b/tools/nmt_grpc_service/nmt.proto new file mode 100644 index 000000000000..aa5bc84b1d93 --- /dev/null +++ b/tools/nmt_grpc_service/nmt.proto @@ -0,0 +1,45 @@ +// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package nvidia.riva.nmt; + +option cc_enable_arenas = true; +option go_package = "nvidia.com/riva_speech"; + +// Riva NLP Services implement task-specific APIs for popular NLP tasks including +// intent recognition (as well as slot filling), and entity extraction. +service RivaTranslate { + rpc TranslateText(TranslateTextRequest) returns (TranslateTextResponse) {} + +} + +message TranslateTextRequest { + repeated string texts = 1; + + string source_language = 3; + string target_language = 4; +} + +message Translation { + string translation = 1; + + string language = 2; +} + +message TranslateTextResponse { + repeated Translation translations = 1; + +} diff --git a/tools/nmt_grpc_service/server.py b/tools/nmt_grpc_service/server.py new file mode 100644 index 000000000000..58af60b71b06 --- /dev/null +++ b/tools/nmt_grpc_service/server.py @@ -0,0 +1,209 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from concurrent import futures + +import api.nmt_pb2 as nmt +import api.nmt_pb2_grpc as nmtsrv +import grpc +import torch + +import nemo.collections.nlp as nemo_nlp +from nemo.utils import logging + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + required=True, + action="append", + help="List of .nemo files specified by using --model xyz.nemo multiple times.", + ) + parser.add_argument( + "--punctuation_model", + default="", + type=str, + help="Optionally provide a path a .nemo file for punctation and capitalization (recommend if working with Riva speech recognition outputs)", + ) + parser.add_argument("--port", default=50052, type=int, required=False) + parser.add_argument( + "--lang_directions", + required=False, + action="append", + help="Use this arg if any of your models don't have the src_language or tgt_language attributes.", + ) + parser.add_argument("--batch_size", type=int, default=256, help="Maximum number of batches to process") + parser.add_argument("--beam_size", type=int, default=1, help="Beam Size") + parser.add_argument("--len_pen", type=float, default=0.6, help="Length Penalty") + parser.add_argument("--max_delta_length", type=int, default=5, help="Max Delta Generation Length.") + + args = parser.parse_args() + return args + + +def batches(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +class RivaTranslateServicer(nmtsrv.RivaTranslateServicer): + """Provides methods that implement functionality of route guide server.""" + + def __init__( + self, + model_paths, + punctuation_model_path, + beam_size=1, + len_pen=0.6, + max_delta_length=5, + batch_size=256, + lang_directions=[], + ): + self._models = {} + self._beam_size = beam_size + self._len_pen = len_pen + self._max_delta_length = max_delta_length + self._batch_size = batch_size + self._punctuation_model_path = punctuation_model_path + self._model_paths = model_paths + self._lang_directions = lang_directions + + # __TODO__ make this easier to use in the future. + # If lang direction is provided for one model, it must be provided for all of them. + + if len(self._lang_directions) != 0: + if len(self._lang_directions) == len(self._model_paths): + raise ValueError( + f"Found a different number of models ({len(self._model_paths)}) and language directions ({len(self._lang_directions)})" + ) + + for idx, model_path in enumerate(self._model_paths): + assert os.path.exists(model_path) + logging.info(f"Loading model {model_path}") + if len(self._lang_directions) > 0: + src_language = self._lang_directions.split('-')[0] + tgt_language = self._lang_directions.split('-')[1] + else: + src_language, tgt_language = None, None + self._load_model(model_path, src_language, tgt_language) + + if self._punctuation_model_path != "": + assert os.path.exists(punctuation_model_path) + logging.info(f"Loading punctuation model {model_path}") + self._load_puncutation_model(punctuation_model_path) + + logging.info("Models loaded. Ready for inference requests.") + + def _load_puncutation_model(self, punctuation_model_path): + if punctuation_model_path.endswith(".nemo"): + self.punctuation_model = nemo_nlp.models.PunctuationCapitalizationModel.restore_from( + restore_path=punctuation_model_path + ) + self.punctuation_model.eval() + else: + raise NotImplemented(f"Only support .nemo files, but got: {punctuation_model_path}") + + if torch.cuda.is_available(): + self.punctuation_model = self.punctuation_model.cuda() + + def _load_model(self, model_path, src_language=None, tgt_language=None): + if model_path.endswith(".nemo"): + logging.info("Attempting to initialize from .nemo file") + model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=model_path) + model = model.eval() + model.beam_search.beam_size = self._beam_size + model.beam_search.len_pen = self._len_pen + model.beam_search.max_delta_length = self._max_delta_length + if torch.cuda.is_available(): + model = model.cuda() + else: + raise NotImplemented(f"Only support .nemo files, but got: {model_path}") + + if (not hasattr(model, "src_language") or not hasattr(model, "tgt_language")) and ( + src_language is None or tgt_language is None + ): + raise ValueError( + f"Could not find src_language and tgt_language in model attributes nor in --lang_directions. Please specify --lang_directions for all models. Ex: --lang_directions en-es --lang_directions en-de etc." + ) + + if src_language is not None and tgt_language is not None: + model.src_language = src_language + model.tgt_language = tgt_language + else: + src_language = model.src_language + tgt_language = model.tgt_language + + if src_language not in self._models: + self._models[src_language] = {} + + if tgt_language not in self._models[src_language]: + self._models[src_language][tgt_language] = model + if torch.cuda.is_available(): + self._models[src_language][tgt_language] = self._models[src_language][tgt_language].cuda() + else: + raise ValueError(f"Already found model for language pair {src_language}-{tgt_language}") + + def TranslateText(self, request, context): + logging.info(f"Request received w/ {len(request.texts)} utterances") + results = [] + + if request.source_language not in self._models: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details( + f"Could not find source-target language pair {request.source_language}-{request.target_language} in list of models." + ) + return nmt.TranslateTextResponse() + + if request.target_language not in self._models[request.source_language]: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details( + f"Could not find source-target language pair {request.source_language}-{request.target_language} in list of models." + ) + return nmt.TranslateTextResponse() + + request_strings = [x for x in request.texts] + + for batch in batches(request_strings, self._batch_size): + if self._punctuation_model_path != "": + batch = self.punctuation_model.add_punctuation_capitalization(batch) + batch_results = self._models[request.source_language][request.target_language].translate(text=batch) + translations = [nmt.Translation(translation=x) for x in batch_results] + results.extend(translations) + + return nmt.TranslateTextResponse(translations=results) + + +def serve(): + args = get_args() + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + servicer = RivaTranslateServicer( + model_paths=args.model, + punctuation_model_path=args.punctuation_model, + beam_size=args.beam_size, + len_pen=args.len_pen, + batch_size=args.batch_size, + max_delta_length=args.max_delta_length, + ) + nmtsrv.add_RivaTranslateServicer_to_server(servicer, server) + server.add_insecure_port('[::]:' + str(args.port)) + server.start() + server.wait_for_termination() + + +if __name__ == '__main__': + serve() diff --git a/examples/nlp/machine_translation/nmt_webapp/README.rst b/tools/nmt_webapp/README.rst similarity index 100% rename from examples/nlp/machine_translation/nmt_webapp/README.rst rename to tools/nmt_webapp/README.rst diff --git a/examples/nlp/machine_translation/nmt_webapp/config.json b/tools/nmt_webapp/config.json similarity index 100% rename from examples/nlp/machine_translation/nmt_webapp/config.json rename to tools/nmt_webapp/config.json diff --git a/examples/nlp/machine_translation/nmt_webapp/index.html b/tools/nmt_webapp/index.html similarity index 100% rename from examples/nlp/machine_translation/nmt_webapp/index.html rename to tools/nmt_webapp/index.html diff --git a/examples/nlp/machine_translation/nmt_webapp/nmt_service.py b/tools/nmt_webapp/nmt_service.py similarity index 99% rename from examples/nlp/machine_translation/nmt_webapp/nmt_service.py rename to tools/nmt_webapp/nmt_service.py index 3661fcdb4a80..ce8adf068745 100644 --- a/examples/nlp/machine_translation/nmt_webapp/nmt_service.py +++ b/tools/nmt_webapp/nmt_service.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json import time diff --git a/examples/nlp/machine_translation/nmt_webapp/requirements.txt b/tools/nmt_webapp/requirements.txt similarity index 100% rename from examples/nlp/machine_translation/nmt_webapp/requirements.txt rename to tools/nmt_webapp/requirements.txt diff --git a/examples/nlp/machine_translation/nmt_webapp/style.css b/tools/nmt_webapp/style.css similarity index 100% rename from examples/nlp/machine_translation/nmt_webapp/style.css rename to tools/nmt_webapp/style.css