Add basic grpc MT server (NVIDIA#1807)
* Add basic grpc MT server

Add readme, server updates

Signed-off-by: Ryan Leary <rleary@nvidia.com>

* style fix

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* fixing license headers

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* Add punctuation model into NMT service

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix merge conflicts

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* style fixes to unblock CI

Signed-off-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>

* Add a Jarvis ASR + NeMo NMT client

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Refactor gRPC service

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Update license headers

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Update one more license header

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Whitespace in header

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix grpc requirement

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Update license headers

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Add option to specify src/tgt lang and import fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Style fixes

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix unused imports

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Renaming variables

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Co-authored-by: Ryan Leary <rleary@nvidia.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <okuchaiev@nvidia.com>
Co-authored-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
5 people authored and jfsantos committed Nov 18, 2021
1 parent 0514ceb commit 9467ab3
Showing 15 changed files with 877 additions and 3 deletions.
5 changes: 2 additions & 3 deletions requirements/requirements.txt
@@ -14,7 +14,6 @@ transformers>=4.8.1
 sentencepiece<1.0.0
 webdataset>=0.1.48,<=0.1.62
 tqdm>=4.41.0
-opencc
-pangu
-jieba
 numba
+grpcio
+grpcio-tools
3 changes: 3 additions & 0 deletions requirements/requirements_nlp.txt
@@ -14,3 +14,6 @@ sacremoses>=0.0.43
 nltk==3.6.2
 wordninja==2.0.0
 fasttext
+opencc
+pangu
+jieba
63 changes: 63 additions & 0 deletions tools/nmt_grpc_service/README.md
@@ -0,0 +1,63 @@
# NMT gRPC Server Getting Started

## Starting the NMT server

Start the server by specifying multiple models (.nemo files) via the `--model` argument:

```
python server.py --model models/en-es.nemo --model models/en-de.nemo --model models/en-fr.nemo
```

If you are working with speech recognition output that lacks punctuation and capitalization, you can pass the path to a .nemo model that restores them (for example, https://ngc.nvidia.com/catalog/models/nvidia:nemo:punctuation_en_bert) via the `--punctuation_model` flag.

NOTE: The server will throw an error if NMT models do not have `src_language` and `tgt_language` attributes.
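
The check described in the note can be pictured with a small sketch. This is not the code in `server.py`, which is not shown here; `build_model_registry` and the `SimpleNamespace` stand-ins are hypothetical names used only for illustration. It assumes the server keeps one model per (source, target) language pair and rejects any model that lacks either attribute.

```python
from types import SimpleNamespace


def build_model_registry(models):
    """Index models by (src_language, tgt_language).

    Hypothetical helper: raises ValueError if a model lacks either attribute,
    mirroring the behavior described in the note above.
    """
    registry = {}
    for model in models:
        src = getattr(model, "src_language", None)
        tgt = getattr(model, "tgt_language", None)
        if src is None or tgt is None:
            raise ValueError(f"{model!r} must define src_language and tgt_language")
        registry[(src, tgt)] = model
    return registry


# Stand-ins for loaded .nemo models (assumption: real models are NeMo objects).
en_es = SimpleNamespace(name="en-es", src_language="en", tgt_language="es")
en_de = SimpleNamespace(name="en-de", src_language="en", tgt_language="de")

registry = build_model_registry([en_es, en_de])
print(sorted(registry))  # [('en', 'de'), ('en', 'es')]
```

Keying on the language pair makes lookup for a request with `source_language` and `target_language` a single dictionary access.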

## Notes

The port can be overridden with the `--port` flag; the default is 50052. Beam decoder parameters can also be set at server start time. See `--help` for more details.
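
As a rough illustration of how a repeatable `--model` flag and a default port can be declared, here is a minimal `argparse` sketch. The flag names and the default of 50052 come from this README; the parser itself, including `build_parser`, is an assumption and may differ from what `server.py` actually does. Beam decoder flags are omitted because their names are not given here.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags documented in this README."""
    parser = argparse.ArgumentParser(description="NMT gRPC server (illustrative sketch)")
    # action="append" collects every occurrence of --model into a list.
    parser.add_argument("--model", action="append", required=True,
                        help="Path to a .nemo NMT model; may be given multiple times")
    parser.add_argument("--punctuation_model", default=None,
                        help="Optional path to a .nemo punctuation and capitalization model")
    parser.add_argument("--port", type=int, default=50052,
                        help="Port to listen on")
    return parser


args = build_parser().parse_args(
    ["--model", "models/en-es.nemo", "--model", "models/en-de.nemo"]
)
print(args.model)  # ['models/en-es.nemo', 'models/en-de.nemo']
print(args.port)   # 50052
```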

## Example Text Client

```
python client.py --target_language de --source_language en --text Hello
```

# ASR with Riva + Translation with NeMo cascade

Below, we'll describe how to use Riva's ASR models and NeMo's NMT models to do speech translation via a cascade pipeline.

## Installing Riva and python APIs

Follow the instructions in https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html to set up and install Riva along with the Python wheel.

For the latest setup instructions, follow the link above, since the instructions below may not be up to date.

```bash
ngc registry resource download-version nvidia/riva/riva_quickstart:1.4.0-beta
```

```bash
cd riva_quickstart_v1.4.0-beta
bash riva_init.sh
bash riva_start.sh

pip install riva_api-1.4.0b0-py3-none-any.whl
```

This will start a Riva Speech Recognition service, and `nvidia-smi` should show `tritonserver` running on GPU 0.

## ASR + NMT

Start the NeMo translation server using the instructions in the previous section (with or without a punctuation and capitalization model).

Run the cascade client on a single-channel audio WAV file, specifying the target language to translate into. By default, Riva ASR is in English, so we specify only the target language.

```bash
python asr_nmt_client.py --audio-file recording.mono.wav --asr_punctuation --target_language de
```
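
Since the client expects single-channel audio, it can help to check a file before sending it. The sketch below uses only Python's standard `wave` module; `is_mono_wav` and `make_silent_wav` are hypothetical helpers and are not part of `asr_nmt_client.py`.

```python
import io
import wave


def is_mono_wav(source):
    """Return True if the WAV file (path or file-like object) has exactly one channel."""
    with wave.open(source, "rb") as wav:
        return wav.getnchannels() == 1


def make_silent_wav(channels, frames=160, rate=16000):
    """Build a short in-memory 16-bit WAV of silence, used here only for demonstration."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(channels)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav.setframerate(rate)
        wav.writeframes(b"\x00\x00" * channels * frames)
    buffer.seek(0)
    return buffer


print(is_mono_wav(make_silent_wav(1)))  # True
print(is_mono_wav(make_silent_wav(2)))  # False
```

For a file on disk, `is_mono_wav("recording.mono.wav")` performs the same check.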

To view ASR outputs only:

```bash
python asr_nmt_client.py --audio-file recording.mono.wav --asr_punctuation --target_language de --asr_only
```
286 changes: 286 additions & 0 deletions tools/nmt_grpc_service/api/nmt_pb2.py
@@ -0,0 +1,286 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: nmt.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor.FileDescriptor(
    name='nmt.proto',
    package='nvidia.riva.nmt',
    syntax='proto3',
    serialized_options=b'Z\026nvidia.com/riva_speech\370\001\001',
    create_key=_descriptor._internal_create_key,
    serialized_pb=b'\n\tnmt.proto\x12\x0fnvidia.riva.nmt\"W\n\x14TranslateTextRequest\x12\r\n\x05texts\x18\x01 \x03(\t\x12\x17\n\x0fsource_language\x18\x03 \x01(\t\x12\x17\n\x0ftarget_language\x18\x04 \x01(\t\"4\n\x0bTranslation\x12\x13\n\x0btranslation\x18\x01 \x01(\t\x12\x10\n\x08language\x18\x02 \x01(\t\"K\n\x15TranslateTextResponse\x12\x32\n\x0ctranslations\x18\x01 \x03(\x0b\x32\x1c.nvidia.riva.nmt.Translation2q\n\rRivaTranslate\x12`\n\rTranslateText\x12%.nvidia.riva.nmt.TranslateTextRequest\x1a&.nvidia.riva.nmt.TranslateTextResponse\"\x00\x42\x1bZ\x16nvidia.com/riva_speech\xf8\x01\x01\x62\x06proto3',
)


_TRANSLATETEXTREQUEST = _descriptor.Descriptor(
    name='TranslateTextRequest',
    full_name='nvidia.riva.nmt.TranslateTextRequest',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    create_key=_descriptor._internal_create_key,
    fields=[
        _descriptor.FieldDescriptor(
            name='texts',
            full_name='nvidia.riva.nmt.TranslateTextRequest.texts',
            index=0,
            number=1,
            type=9,
            cpp_type=9,
            label=3,
            has_default_value=False,
            default_value=[],
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
        _descriptor.FieldDescriptor(
            name='source_language',
            full_name='nvidia.riva.nmt.TranslateTextRequest.source_language',
            index=1,
            number=3,
            type=9,
            cpp_type=9,
            label=1,
            has_default_value=False,
            default_value=b"".decode('utf-8'),
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
        _descriptor.FieldDescriptor(
            name='target_language',
            full_name='nvidia.riva.nmt.TranslateTextRequest.target_language',
            index=2,
            number=4,
            type=9,
            cpp_type=9,
            label=1,
            has_default_value=False,
            default_value=b"".decode('utf-8'),
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=30,
    serialized_end=117,
)


_TRANSLATION = _descriptor.Descriptor(
    name='Translation',
    full_name='nvidia.riva.nmt.Translation',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    create_key=_descriptor._internal_create_key,
    fields=[
        _descriptor.FieldDescriptor(
            name='translation',
            full_name='nvidia.riva.nmt.Translation.translation',
            index=0,
            number=1,
            type=9,
            cpp_type=9,
            label=1,
            has_default_value=False,
            default_value=b"".decode('utf-8'),
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
        _descriptor.FieldDescriptor(
            name='language',
            full_name='nvidia.riva.nmt.Translation.language',
            index=1,
            number=2,
            type=9,
            cpp_type=9,
            label=1,
            has_default_value=False,
            default_value=b"".decode('utf-8'),
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=119,
    serialized_end=171,
)


_TRANSLATETEXTRESPONSE = _descriptor.Descriptor(
    name='TranslateTextResponse',
    full_name='nvidia.riva.nmt.TranslateTextResponse',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    create_key=_descriptor._internal_create_key,
    fields=[
        _descriptor.FieldDescriptor(
            name='translations',
            full_name='nvidia.riva.nmt.TranslateTextResponse.translations',
            index=0,
            number=1,
            type=11,
            cpp_type=10,
            label=3,
            has_default_value=False,
            default_value=[],
            message_type=None,
            enum_type=None,
            containing_type=None,
            is_extension=False,
            extension_scope=None,
            serialized_options=None,
            file=DESCRIPTOR,
            create_key=_descriptor._internal_create_key,
        ),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=173,
    serialized_end=248,
)

_TRANSLATETEXTRESPONSE.fields_by_name['translations'].message_type = _TRANSLATION
DESCRIPTOR.message_types_by_name['TranslateTextRequest'] = _TRANSLATETEXTREQUEST
DESCRIPTOR.message_types_by_name['Translation'] = _TRANSLATION
DESCRIPTOR.message_types_by_name['TranslateTextResponse'] = _TRANSLATETEXTRESPONSE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

TranslateTextRequest = _reflection.GeneratedProtocolMessageType(
    'TranslateTextRequest',
    (_message.Message,),
    {
        'DESCRIPTOR': _TRANSLATETEXTREQUEST,
        '__module__': 'nmt_pb2'
        # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextRequest)
    },
)
_sym_db.RegisterMessage(TranslateTextRequest)

Translation = _reflection.GeneratedProtocolMessageType(
    'Translation',
    (_message.Message,),
    {
        'DESCRIPTOR': _TRANSLATION,
        '__module__': 'nmt_pb2'
        # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.Translation)
    },
)
_sym_db.RegisterMessage(Translation)

TranslateTextResponse = _reflection.GeneratedProtocolMessageType(
    'TranslateTextResponse',
    (_message.Message,),
    {
        'DESCRIPTOR': _TRANSLATETEXTRESPONSE,
        '__module__': 'nmt_pb2'
        # @@protoc_insertion_point(class_scope:nvidia.riva.nmt.TranslateTextResponse)
    },
)
_sym_db.RegisterMessage(TranslateTextResponse)


DESCRIPTOR._options = None

_RIVATRANSLATE = _descriptor.ServiceDescriptor(
    name='RivaTranslate',
    full_name='nvidia.riva.nmt.RivaTranslate',
    file=DESCRIPTOR,
    index=0,
    serialized_options=None,
    create_key=_descriptor._internal_create_key,
    serialized_start=250,
    serialized_end=363,
    methods=[
        _descriptor.MethodDescriptor(
            name='TranslateText',
            full_name='nvidia.riva.nmt.RivaTranslate.TranslateText',
            index=0,
            containing_service=None,
            input_type=_TRANSLATETEXTREQUEST,
            output_type=_TRANSLATETEXTRESPONSE,
            serialized_options=None,
            create_key=_descriptor._internal_create_key,
        ),
    ],
)
_sym_db.RegisterServiceDescriptor(_RIVATRANSLATE)

DESCRIPTOR.services_by_name['RivaTranslate'] = _RIVATRANSLATE

# @@protoc_insertion_point(module_scope)
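
To make the field numbers in the `TranslateTextRequest` descriptor concrete, the following sketch hand-encodes a request in the protobuf wire format using only the standard library. It is an illustration, not how the client in this commit serializes messages (generated classes provide `SerializeToString()` for that). The helper names are hypothetical. The field numbers (`texts`=1, `source_language`=3, `target_language`=4) and the string type (`type=9`) are taken from the descriptor above.

```python
def encode_varint(value):
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow
        else:
            out.append(byte)
            return bytes(out)


def encode_string_field(field_number, text):
    """Encode one length-delimited (wire type 2) string field."""
    payload = text.encode("utf-8")
    tag = (field_number << 3) | 2
    return encode_varint(tag) + encode_varint(len(payload)) + payload


def encode_translate_text_request(texts, source_language="", target_language=""):
    """Hand-rolled encoder for TranslateTextRequest (illustrative only)."""
    parts = [encode_string_field(1, text) for text in texts]  # repeated string texts = 1
    if source_language:  # proto3 omits empty strings on the wire
        parts.append(encode_string_field(3, source_language))
    if target_language:
        parts.append(encode_string_field(4, target_language))
    return b"".join(parts)


wire = encode_translate_text_request(["Hello"], "en", "de")
print(wire.hex())  # 0a0548656c6c6f1a02656e22026465
```

Each field starts with a one-byte tag (`0x0a`, `0x1a`, `0x22` for fields 1, 3, and 4), followed by the string length and its UTF-8 bytes.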