Skip to content

Commit

Permalink
[NeuralChat] Fix tts crash with messy retrieval input and enhance nor…
Browse files Browse the repository at this point in the history
…malizer (#1088)

Fix tts crash with messy retrieval input and enhance normalizer
  • Loading branch information
Spycsh committed Jan 5, 2024
1 parent df26e1d commit 4d8d9a2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from datasets import load_dataset, Audio, Dataset, Features, ClassLabel
import os
import torch
from speechbrain.pretrained import EncoderClassifier
from typing import Any, Dict, List, Union
from transformers import SpeechT5HifiGan
import soundfile as sf
Expand Down Expand Up @@ -59,6 +58,7 @@ def __init__(self, output_audio_path="./response.wav", voice="default", stream_m
self.stream_mode = stream_mode
self.spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
try:
from speechbrain.pretrained import EncoderClassifier
self.speaker_model = EncoderClassifier.from_hparams(
source=self.spk_model_name,
run_opts={"device": "cpu"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
class EnglishNormalizer:
def __init__(self):
self.correct_dict = {
"A": "Eigh",
"A": "eigh",
"B": "bee",
"C": "cee",
"D": "dee",
"E": "yee",
"F": "ef",
"G": "jee",
"H": "aitch",
"I": "I",
"I": "eye",
"J": "jay",
"K": "kay",
"L": "el",
Expand All @@ -58,8 +58,7 @@ def __init__(self):
def correct_abbreviation(self, text):
# TODO mixed abbreviation or proper noun like i7, ffmpeg, BTW should be supported

# words = text.split() # CVPR-15 will be upper but 1 and 5 will be splitted to two numbers
words = re.split(' |_|/', text)
words = re.split(r' |_|/|\*|\#', text) # ignore the characters that not break sentence
results = []
for idx, word in enumerate(words):
if word.startswith("-"): # bypass negative number
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_correct_conjunctions(self):
text = "CVPR-15 ICML-21 PM2.5"
text = self.normalizer.correct_abbreviation(text)
result = self.normalizer.correct_number(text)
self.assertEqual(result, "cee vee pea ar fifteen I cee em el twenty-one pea em two point five.")
self.assertEqual(result, "cee vee pea ar fifteen eye cee em el twenty-one pea em two point five.")

if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def test_tts_long_text(self):
output_audio_path = os.path.join(os.getcwd(), "tmp_audio/2.wav")
set_seed(555)
output_audio_path = self.tts.text2speech(text, output_audio_path, voice="default", do_batch_tts=True, batch_length=120)
result = self.asr.audio2text(output_audio_path)
self.assertTrue(os.path.exists(output_audio_path))
self.assertEqual("intel extension for transformers is an innovative toolkit to accelerate transformer based " + \
"models on intel platforms in particular effective on 4th intel xeon scalable processor " + \
"sapphire rapids codenamed sapphire rapids", result)

def test_create_speaker_embedding(self):
driven_audio_path = \
Expand All @@ -117,5 +121,15 @@ def test_tts_remove_noise(self):
result = self.asr.audio2text(output_audio_path)
self.assertEqual(text.lower(), result.lower())

def test_tts_messy_input(self):
text = "Please refer to the following responses to this inquiry:\n" + 244 * "* " + "*"
output_audio_path = os.path.join(os.getcwd(), "tmp_audio/6.wav")
set_seed(555)
output_audio_path = self.tts_noise_reducer.text2speech(text, output_audio_path, voice="default")
self.assertTrue(os.path.exists(output_audio_path))
# verify accuracy
result = self.asr.audio2text(output_audio_path)
self.assertEqual("please refer to the following responses to this inquiry", result.lower())

if __name__ == "__main__":
unittest.main()

0 comments on commit 4d8d9a2

Please sign in to comment.