From 55251d31d4f59e8ad16d57538c24faba365e4d68 Mon Sep 17 00:00:00 2001 From: blob42 Date: Tue, 23 Apr 2024 15:13:50 +0200 Subject: [PATCH] TTS endpoint: add optional lang parameter --- backend/backend.proto | 3 +- backend/python/coqui/coqui_server.py | 2 +- core/backend/tts.go | 3 +- core/cli/tts.go | 3 +- core/http/endpoints/elevenlabs/tts.go | 2 +- core/http/endpoints/localai/tts.go | 2 +- core/schema/localai.go | 119 +++++++++++++------------- 7 files changed, 69 insertions(+), 65 deletions(-) diff --git a/backend/backend.proto b/backend/backend.proto index 62e1a1a6444..fb3166cda69 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -238,6 +238,7 @@ message TTSRequest { string model = 2; string dst = 3; string voice = 4; + optional string lang = 5; } message TokenizationResponse { @@ -264,4 +265,4 @@ message StatusResponse { message Message { string role = 1; string content = 2; -} \ No newline at end of file +} diff --git a/backend/python/coqui/coqui_server.py b/backend/python/coqui/coqui_server.py index d4dc731da60..332cc0fbe24 100644 --- a/backend/python/coqui/coqui_server.py +++ b/backend/python/coqui/coqui_server.py @@ -67,7 +67,7 @@ def LoadModel(self, request, context): def TTS(self, request, context): try: # if model is multilangual add language from request or env as fallback - lang = request.Lang or COQUI_LANGUAGE + lang = request.lang or COQUI_LANGUAGE if self.tts.is_multi_lingual and lang is None: return backend_pb2.Result(success=False, message=f"Model is multi-lingual, but no language was provided") diff --git a/core/backend/tts.go b/core/backend/tts.go index de3e56a57a5..7ac6e4c1fe3 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -29,7 +29,7 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { +func ModelTTS(backend, 
text, modelFile, voice string, lang string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) { bb := backend if bb == "" { bb = model.PiperBackend @@ -83,6 +83,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, Model: modelPath, Voice: voice, Dst: filePath, + Lang: &lang, }) // return RPC error if any diff --git a/core/cli/tts.go b/core/cli/tts.go index 1d8fd3a39ec..f5ca21414bf 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -18,6 +18,7 @@ type TTSCMD struct { Backend string `short:"b" default:"piper" help:"Backend to run the TTS model"` Model string `short:"m" required:"" help:"Model name to run the TTS"` Voice string `short:"v" help:"Voice name to run the TTS"` + Lang string `short:"l" help:"Language to use with the TTS"` OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` @@ -45,7 +46,7 @@ func (t *TTSCMD) Run(ctx *Context) error { options := config.BackendConfig{} options.SetDefaults() - filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options) + filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Lang, ml, opts, options) if err != nil { return err } diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index 841f9b5f784..685d7ff45ed 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -52,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, 
appConfi } log.Debug().Msgf("Request for model: %s", modelFile) - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, "", ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index e862b656796..67b7922fe73 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -52,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi cfg.Backend = input.Backend } - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, input.Lang, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/schema/localai.go b/core/schema/localai.go index e9b61cf3d50..eeb2fb4afbc 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -1,59 +1,60 @@ -package schema - -import ( - gopsutil "github.com/shirou/gopsutil/v3/process" -) - -type BackendMonitorRequest struct { - Model string `json:"model" yaml:"model"` -} - -type BackendMonitorResponse struct { - MemoryInfo *gopsutil.MemoryInfoStat - MemoryPercent float32 - CPUPercent float64 -} - -type TTSRequest struct { - Model string `json:"model" yaml:"model"` - Input string `json:"input" yaml:"input"` - Voice string `json:"voice" yaml:"voice"` - Backend string `json:"backend" yaml:"backend"` -} - -type StoresSet struct { - Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Keys [][]float32 `json:"keys" yaml:"keys"` - Values []string `json:"values" yaml:"values"` -} - -type StoresDelete struct { - Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Keys [][]float32 `json:"keys"` -} - -type StoresGet struct { - Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Keys
[][]float32 `json:"keys" yaml:"keys"` -} - -type StoresGetResponse struct { - Keys [][]float32 `json:"keys" yaml:"keys"` - Values []string `json:"values" yaml:"values"` -} - -type StoresFind struct { - Store string `json:"store,omitempty" yaml:"store,omitempty"` - - Key []float32 `json:"key" yaml:"key"` - Topk int `json:"topk" yaml:"topk"` -} - -type StoresFindResponse struct { - Keys [][]float32 `json:"keys" yaml:"keys"` - Values []string `json:"values" yaml:"values"` - Similarities []float32 `json:"similarities" yaml:"similarities"` -} +package schema + +import ( + gopsutil "github.com/shirou/gopsutil/v3/process" +) + +type BackendMonitorRequest struct { + Model string `json:"model" yaml:"model"` +} + +type BackendMonitorResponse struct { + MemoryInfo *gopsutil.MemoryInfoStat + MemoryPercent float32 + CPUPercent float64 +} + +type TTSRequest struct { + Model string `json:"model" yaml:"model"` + Input string `json:"input" yaml:"input"` + Voice string `json:"voice" yaml:"voice"` + Backend string `json:"backend" yaml:"backend"` + Lang string `json:"lang,omitempty" yaml:"lang,omitempty"` +} + +type StoresSet struct { + Store string `json:"store,omitempty" yaml:"store,omitempty"` + + Keys [][]float32 `json:"keys" yaml:"keys"` + Values []string `json:"values" yaml:"values"` +} + +type StoresDelete struct { + Store string `json:"store,omitempty" yaml:"store,omitempty"` + + Keys [][]float32 `json:"keys"` +} + +type StoresGet struct { + Store string `json:"store,omitempty" yaml:"store,omitempty"` + + Keys [][]float32 `json:"keys" yaml:"keys"` +} + +type StoresGetResponse struct { + Keys [][]float32 `json:"keys" yaml:"keys"` + Values []string `json:"values" yaml:"values"` +} + +type StoresFind struct { + Store string `json:"store,omitempty" yaml:"store,omitempty"` + + Key []float32 `json:"key" yaml:"key"` + Topk int `json:"topk" yaml:"topk"` +} + +type StoresFindResponse struct { + Keys [][]float32 `json:"keys" yaml:"keys"` + Values []string `json:"values" yaml:"values"` + 
Similarities []float32 `json:"similarities" yaml:"similarities"` +}