[NeuralChat] OpenAI compatible audio API (#1226)
Spycsh committed Feb 2, 2024
1 parent 3385c42 commit d62ff9e
Showing 8 changed files with 212 additions and 15 deletions.
@@ -61,6 +61,7 @@ You can customize the configuration file 'audio_service.yaml' to match your envi
| tts.args.voice | "default" |
| tts.args.stream_mode | false |
| tts.args.output_audio_path | "./output_audio.wav" |
| tts.args.speedup | 1.0 |
| tasks_list | ['plugin_audio'] |


@@ -57,6 +57,7 @@ You can customize the configuration file 'textbot.yaml' to match your environmen
| tts.args.voice | "default" |
| tts.args.stream_mode | true |
| tts.args.output_audio_path | "./output_audio" |
| tts.args.speedup | 1.0 |
| tasks_list | ['voicechat'] |


@@ -67,3 +68,48 @@ To start the VoiceChat server, use the following command:
```shell
nohup bash run.sh &
```

# Quick test with OpenAI compatible endpoints (audio)

To make our audio service compatible with the OpenAI audio [endpoints](https://platform.openai.com/docs/api-reference/audio/), we offer the following three endpoints:

```
/v1/audio/speech
/v1/audio/transcriptions
/v1/audio/translations
```

To test whether the talkingbot server can serve your requests correctly, you can use `curl` as follows:

```shell
curl http://localhost:8888/v1/audio/translations \
-H "Content-Type: multipart/form-data" \
-F file="@sample_1.wav"
curl http://localhost:8888/v1/audio/transcriptions \
-H "Content-Type: multipart/form-data" \
-F file="@sample_zh_cn.wav"
curl http://localhost:8888/v1/audio/speech \
-H "Content-Type: application/json" \
-d '{
"model": "speecht5",
"input": "The quick brown fox jumped over the lazy dog.",
"voice": "default"
}' \
--output speech.mp3
```
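
If you prefer Python over `curl`, the following sketch calls the same endpoints with the `requests` library. It assumes the server listens on `http://localhost:8888` and that `sample_1.wav` exists in the working directory; adjust the host, port, and file names to your setup.

```python
import requests

base_url = "http://localhost:8888"  # assumed host/port of the talkingbot server

# Speech-to-text: transcribe a local audio file
with open("sample_1.wav", "rb") as audio_file:
    resp = requests.post(f"{base_url}/v1/audio/transcriptions",
                         files={"file": audio_file})
print(resp.json())  # {"asr_result": "..."}

# Text-to-speech: synthesize speech and save the returned audio bytes
resp = requests.post(
    f"{base_url}/v1/audio/speech",
    json={
        "model": "speecht5",
        "input": "The quick brown fox jumped over the lazy dog.",
        "voice": "default",
    },
)
with open("speech.mp3", "wb") as out_file:
    out_file.write(resp.content)
```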

# Customized endpoints of an audio-input-audio-output pipeline

You can check `intel_extension_for_transformers/neural_chat/server/restful/voicechat_api.py` to see the other customized endpoints we offer:

```
/v1/talkingbot/asr
/v1/talkingbot/llm_tts
/v1/talkingbot/create_embedding
```

`/v1/talkingbot/asr` is equivalent to `/v1/audio/transcriptions`; we simply keep it for backward compatibility for speech-to-text conversion.

`/v1/talkingbot/llm_tts` merges two tasks, LLM text generation and text-to-speech, into one process. It is designed specifically for steadily converting the LLM streaming outputs to speech.

`/v1/talkingbot/create_embedding` is used to create a SpeechT5 speaker embedding for zero-shot voice cloning. Although voice cloning with SpeechT5 is relatively weak, we keep this endpoint for a quick start. If you want to clone your voice properly, please check the current best practices for few-shot voice-cloning finetuning of SpeechT5 in this [repo](../../../../finetuning/tts/).
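
As a minimal sketch of the embedding endpoint (again with `requests`, assuming the server runs on `http://localhost:8888` and using a hypothetical `my_voice.wav` recording of the target speaker), a request could look like this; the `spk_id` field in the response comes from the handler added in this commit:

```python
import requests

base_url = "http://localhost:8888"  # assumed host/port of the talkingbot server

# Upload a short recording of the target speaker to build a speaker embedding
with open("my_voice.wav", "rb") as voice_file:  # hypothetical sample recording
    resp = requests.post(f"{base_url}/v1/talkingbot/create_embedding",
                         files={"file": voice_file})
print(resp.json())  # {"spk_id": "..."}
```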
@@ -45,7 +45,7 @@ class TextToSpeech():
3) Customized voice (Original model + User's customized input voice embedding)
"""
def __init__(self, output_audio_path="./response.wav", voice="default", stream_mode=False, device="cpu",
reduce_noise=False):
reduce_noise=False, speedup=1.0):
"""Make sure your export LD_PRELOAD=<path to libiomp5.so and libtcmalloc> beforehand."""
# default setting
if device == "auto":
@@ -101,6 +101,7 @@ def __init__(self, output_audio_path="./response.wav", voice="default", stream_m

self.normalizer = EnglishNormalizer()
self.noise_reducer = NoiseReducer() if reduce_noise else None
self.speedup = speedup

def _audiosegment_to_librosawav(self, audiosegment):
# https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples
@@ -175,8 +176,13 @@ def _batch_long_text(self, text, batch_length):
res = [i + "." for i in res] # avoid unexpected end of sequence
return res

def _speedup(self, path, speed):
from pydub import AudioSegment
from pydub.effects import speedup
sound = AudioSegment.from_file(path)
speedup(sound, playback_speed=speed).export(path)

def text2speech(self, text, output_audio_path, voice="default",
def text2speech(self, text, output_audio_path="./response.wav", voice="default", speedup=1.0,
do_batch_tts=False, batch_length=400):
"""Text to speech.
@@ -220,12 +226,14 @@ def text2speech(self, text, output_audio_path, voice="default",
sf.write(output_audio_path, all_speech, samplerate=16000)
if self.noise_reducer:
output_audio_path = self.noise_reducer.reduce_audio_amplify(output_audio_path, all_speech)
if speedup != 1.0:
self._speedup(output_audio_path, speedup)
return output_audio_path

def stream_text2speech(self, generator, output_audio_path, voice="default"):
def stream_text2speech(self, generator, output_audio_path, voice="default", speedup=1.0):
"""Stream the generation of audios with an LLM text generator."""
for idx, response in enumerate(generator):
yield self.text2speech(response, f"{output_audio_path}_{idx}.wav", voice)
yield self.text2speech(response, f"{output_audio_path}_{idx}.wav", voice, speedup)


def post_llm_inference_actions(self, text_or_generator):
@@ -249,6 +257,7 @@ def cache_words_into_sentences():
# output the trailing sequence
if len(buffered_texts) > 0:
yield ''.join(buffered_texts)
return self.stream_text2speech(cache_words_into_sentences(), self.output_audio_path, self.voice)
return self.stream_text2speech(
cache_words_into_sentences(), self.output_audio_path, self.voice, self.speedup)
else:
return self.text2speech(text_or_generator, self.output_audio_path, self.voice)
return self.text2speech(text_or_generator, self.output_audio_path, self.voice, self.speedup)
@@ -115,6 +115,13 @@ class ChatCompletionResponse(BaseModel):
usage: UsageInfo


class CreateSpeechRequest(BaseModel):
model: str = "speecht5",
input: str = "hello.",
voice: str = "default",
response_format: Optional[str] = None,
speed: Optional[Union[float, int]] = 1.0

class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
@@ -31,18 +31,24 @@ def __init__(self) -> None:
super().__init__()
self.chatbot = None

def handle_voice_asr_request(self, filename: str) -> str:
def handle_voice_asr_request(self, filename: str, language: str = "auto") -> str:
asr = get_plugin_instance("asr")
asr.language = language
try:
return asr.audio2text(filename)
except Exception as e:
raise Exception(e)

async def handle_voice_tts_request(self, text: str, voice: str, audio_output_path: Optional[str]=None) -> str:
async def handle_voice_tts_request(self,
text: str,
voice: str,
audio_output_path: Optional[str] = None,
speedup: float = 1.0) -> str:

plugins.tts.args['voice'] = voice
plugins.tts.args['output_audio_path'] = audio_output_path
tts = get_plugin_instance("tts")
tts.speedup = speedup
try:
result = tts.post_llm_inference_actions(text)
def audio_file_generate(result):
@@ -78,7 +84,7 @@ async def handle_create_speaker_embedding(self, spk_id):


@router.post("/plugin/audio/asr")
async def handle_talkingbot_asr(file: UploadFile = File(...)):
async def handle_talkingbot_asr(file: UploadFile = File(...), language: str = "auto"):
file_name = file.filename
logger.info(f'Received file: {file_name}')
with open("tmp_audio_bytes", 'wb') as fout:
@@ -89,7 +95,7 @@ async def handle_talkingbot_asr(file: UploadFile = File(...)):
# bytes to wav
file_name = file_name +'.wav'
audio.export(f"{file_name}", format="wav")
asr_result = router.handle_voice_asr_request(file_name)
asr_result = router.handle_voice_asr_request(file_name, language=language)
return {"asr_result": asr_result}


@@ -98,11 +104,12 @@ async def talkingbot(request: Request):
data = await request.json()
text = data["text"]
voice = data["voice"]
speedup = float(data["speed"]) if "speed" in data else 1.0
audio_output_path = data["audio_output_path"] if "audio_output_path" in data else "output_audio.wav"

logger.info(f'Received prompt: {text}, and use voice: {voice}')

return await router.handle_voice_tts_request(text, voice, audio_output_path)
return await router.handle_voice_tts_request(text, voice, audio_output_path, speedup)


@router.post("/plugin/audio/create_embedding")
@@ -16,15 +16,15 @@
# limitations under the License.

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from typing import Optional
from fastapi.responses import StreamingResponse, FileResponse
from typing import Optional, Union
from ...cli.log import logger
from fastapi import File, UploadFile, Form
from ...config import GenerationConfig
from ...plugins import plugins
import base64
import torch

from ...server.restful.openai_protocol import CreateSpeechRequest
class VoiceChatAPIRouter(APIRouter):

def __init__(self) -> None:
@@ -44,8 +44,9 @@ def get_chatbot(self):
raise RuntimeError("Chatbot instance has not been set.")
return self.chatbot

def handle_voice_asr_request(self, filename: str) -> str:
def handle_voice_asr_request(self, filename: str, language: str = "auto") -> str:
chatbot = self.get_chatbot()
chatbot.asr.language = language
try:
return chatbot.asr.audio2text(filename)
except Exception as e:
@@ -68,6 +69,26 @@ def audio_file_generate(result):
except Exception as e:
raise Exception(e)

async def handle_create_speech_request(self,
model: str,
input: str,
voice: str,
response_format: str = None,
speed: float = 1.0
) -> str:
chatbot = self.get_chatbot()
try:
if model == "speecht5":
result_path: str = chatbot.tts.text2speech(text=input,
output_audio_path="speech.{}".format(response_format),
voice=voice,
speedup=speed)
return FileResponse(result_path)
else:
raise Exception("More models to be supported soon!")
except Exception as e:
raise Exception(e)

async def handle_create_speaker_embedding(self, spk_id):
chatbot = self.get_chatbot()
try:
@@ -130,3 +151,60 @@ async def create_speaker_embedding(file: UploadFile = File(...)):

await router.handle_create_speaker_embedding(spk_id)
return {"spk_id": spk_id}


# https://platform.openai.com/docs/api-reference/audio/createSpeech
@router.post("/v1/audio/speech")
async def create_speech(request: CreateSpeechRequest):
response = await router.handle_create_speech_request(model=request.model,
input=request.input,
voice=request.voice,
response_format="mp3" if request.response_format == (None,) \
else request.response_format,
speed=float(request.speed),
)
return response

# https://platform.openai.com/docs/api-reference/audio/createTranscription
@router.post("/v1/audio/transcriptions")
async def create_transcription(file: UploadFile = File(...),
model: str = "whisper",
language: Optional[str] = "auto",
prompt: Optional[str] = None,
response_format: Optional[str] = "text",
temperature: Optional[float] = 0.0):
file_name = file.filename
logger.info(f'Received file: {file_name}')
with open("tmp_audio_bytes", 'wb') as fout:
content = await file.read()
fout.write(content)
from pydub import AudioSegment
audio = AudioSegment.from_file("tmp_audio_bytes")
audio = audio.set_frame_rate(16000)
# bytes to wav
file_name = file_name +'.wav'
audio.export(f"{file_name}", format="wav")
asr_result = router.handle_voice_asr_request(file_name, language=language)
return {"asr_result": asr_result}

# https://platform.openai.com/docs/api-reference/audio/createTranslation
# The difference from /v1/audio/transcriptions is that this endpoint is specifically for English ASR
@router.post("/v1/audio/translations")
async def create_translation(file: UploadFile = File(...),
model: str = "whisper",
prompt: Optional[str] = None,
response_format: Optional[str] = "text",
temperature: Optional[float] = 0.0):
file_name = file.filename
logger.info(f'Received file: {file_name}')
with open("tmp_audio_bytes", 'wb') as fout:
content = await file.read()
fout.write(content)
from pydub import AudioSegment
audio = AudioSegment.from_file("tmp_audio_bytes")
audio = audio.set_frame_rate(16000)
# bytes to wav
file_name = file_name +'.wav'
audio.export(f"{file_name}", format="wav")
asr_result = router.handle_voice_asr_request(file_name, language=None)
return {"asr_result": asr_result}
@@ -131,5 +131,19 @@ def test_tts_messy_input(self):
result = self.asr.audio2text(output_audio_path)
self.assertEqual("please refer to the following responses to this inquiry", result.lower())

def test_tts_speedup(self):
text = "hello there."
set_seed(555)
output_audio_path1 = os.path.join(os.getcwd(), "tmp_audio/7.wav")
output_audio_path1 = self.tts_noise_reducer.text2speech(text, output_audio_path1, voice="default", speedup=1.0,)
set_seed(555)
output_audio_path2 = os.path.join(os.getcwd(), "tmp_audio/8.wav")
output_audio_path2 = self.tts_noise_reducer.text2speech(text, output_audio_path2, voice="default", speedup=2.0,)
self.assertTrue(os.path.exists(output_audio_path2))
from pydub import AudioSegment
waveform1 = AudioSegment.from_file(output_audio_path1).set_frame_rate(16000)
waveform2 = AudioSegment.from_file(output_audio_path2).set_frame_rate(16000)
self.assertNotEqual(len(waveform1), len(waveform2))

if __name__ == "__main__":
unittest.main()
@@ -108,5 +108,40 @@ def test_create_speaker_embedding(self):
assert response.status_code == 200
assert "spk_id" in response.json()

def test_create_speech(self):
tts_data = {
"model": "speecht5",
"input": "Hello, this is a test.",
"voice": "default",
"speed": 1.5
}
response = client.post("/v1/audio/speech", json=tts_data)
assert response.status_code == 200

def test_create_transcription(self):
# Create a test audio file for ASR
with open("./sample_audio.wav", "rb") as audio_file:
files = {
"file": ("test_audio.wav", audio_file, "audio/wav"),
"model": "whisper",
"language": "auto"
}
response = client.post("/v1/audio/transcriptions", files=files)

assert response.status_code == 200
assert "asr_result" in response.json()

def test_create_translations(self):
# Create a test audio file for ASR
with open("./sample_audio.wav", "rb") as audio_file:
files = {
"file": ("test_audio.wav", audio_file, "audio/wav"),
"model": "whisper",
}
response = client.post("/v1/audio/translations", files=files)

assert response.status_code == 200
assert "asr_result" in response.json()

if __name__ == "__main__":
unittest.main()
