Skip to content

Commit

Permalink
TTS endpoint: add optional language parameter
Browse files Browse the repository at this point in the history
Signed-off-by: blob42 <contact@blob42.xyz>
  • Loading branch information
blob42 committed Apr 29, 2024
1 parent 96a6ade commit 1a2d0cb
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 69 deletions.
1 change: 1 addition & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ message TTSRequest {
string model = 2;
string dst = 3;
string voice = 4;
optional string language = 5;
}

message TokenizationResponse {
Expand Down
2 changes: 1 addition & 1 deletion backend/python/coqui/coqui_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def LoadModel(self, request, context):
def TTS(self, request, context):
try:
# if the model is multi-lingual, take the language from the request, falling back to the env default
lang = request.Lang or COQUI_LANGUAGE
lang = request.language or COQUI_LANGUAGE
if self.tts.is_multi_lingual and lang is None:
return backend_pb2.Result(success=False, message=f"Model is multi-lingual, but no language was provided")

Expand Down
9 changes: 8 additions & 1 deletion core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func generateUniqueFileName(dir, baseName, ext string) string {
}
}

func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
func ModelTTS(backend, text, modelFile, voice string, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
Expand Down Expand Up @@ -83,6 +83,7 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,
Model: modelPath,
Voice: voice,
Dst: filePath,
Language: &language,
})

// return RPC error if any
Expand All @@ -92,3 +93,9 @@ func ModelTTS(backend, text, modelFile, voice string, loader *model.ModelLoader,

return filePath, res, err
}

// ModelTTSInfo is a placeholder that shares ModelTTS's parameter list and
// result types. It currently performs no work: every argument is ignored and
// the results are an empty string, a nil *proto.Result and a nil error.
func ModelTTSInfo(backend, text, modelFile, voice, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (string, *proto.Result, error) {
	return "", nil, nil
}
3 changes: 2 additions & 1 deletion core/cli/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type TTSCMD struct {
Backend string `short:"b" default:"piper" help:"Backend to run the TTS model"`
Model string `short:"m" required:"" help:"Model name to run the TTS"`
Voice string `short:"v" help:"Voice name to run the TTS"`
Language string `short:"l" help:"Language to use with the TTS"`
OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
Expand Down Expand Up @@ -45,7 +46,7 @@ func (t *TTSCMD) Run(ctx *Context) error {
options := config.BackendConfig{}
options.SetDefaults()

filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, ml, opts, options)
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion core/http/endpoints/elevenlabs/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
}
log.Debug().Msgf("Request for model: %s", modelFile)

filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, ml, appConfig, *cfg)
// ModelTTS takes (..., voice, language, ...): voiceID belongs in the voice slot; this endpoint passes no language.
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, voiceID, "", ml, appConfig, *cfg)
if err != nil {
return err
}
Expand Down
14 changes: 8 additions & 6 deletions core/http/endpoints/localai/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ import (
)

// TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech
// @Summary Generates audio from the input text.
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
// @Summary Generates audio from the input text.
// @Accept json
// @Produce audio/x-wav
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

Expand Down Expand Up @@ -52,7 +54,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
cfg.Backend = input.Backend
}

filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, ml, appConfig, *cfg)
filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, input.Voice, input.Language, ml, appConfig, *cfg)
if err != nil {
return err
}
Expand Down
120 changes: 61 additions & 59 deletions core/schema/localai.go
Original file line number Diff line number Diff line change
@@ -1,59 +1,61 @@
package schema

import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)

type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"`
}

type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}

type TTSRequest struct {
Model string `json:"model" yaml:"model"`
Input string `json:"input" yaml:"input"`
Voice string `json:"voice" yaml:"voice"`
Backend string `json:"backend" yaml:"backend"`
}

type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
}

type StoresDelete struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Keys [][]float32 `json:"keys"`
}

type StoresGet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Keys [][]float32 `json:"keys" yaml:"keys"`
}

type StoresGetResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
}

type StoresFind struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"`

Key []float32 `json:"key" yaml:"key"`
Topk int `json:"topk" yaml:"topk"`
}

type StoresFindResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"`
Values []string `json:"values" yaml:"values"`
Similarities []float32 `json:"similarities" yaml:"similarities"`
}
package schema

import (
gopsutil "github.com/shirou/gopsutil/v3/process"
)

// BackendMonitorRequest is a request payload carrying a single model identifier.
type BackendMonitorRequest struct {
Model string `json:"model" yaml:"model"` // presumably the model whose backend should be inspected — confirm against the handler
}

// BackendMonitorResponse groups process resource figures: gopsutil memory
// statistics plus memory and CPU percentages. The fields declare no json/yaml
// tags, so encoding/json emits the Go field names as keys.
type BackendMonitorResponse struct {
MemoryInfo *gopsutil.MemoryInfoStat
MemoryPercent float32
CPUPercent float64
}

// TTSRequest is the request body accepted by the TTS endpoint.
// @Description TTS request body
type TTSRequest struct {
Model string `json:"model" yaml:"model"` // model name or full path
Input string `json:"input" yaml:"input"` // text input
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
Backend string `json:"backend" yaml:"backend"` // backend to run the model with; NOTE(review): appears to override the configured backend — confirm in the handler
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
}

// StoresSet is the payload for a store "set" operation (presumably inserting
// or updating entries — confirm against the handler).
// NOTE(review): Keys and Values look like parallel lists (Values[i] paired with Keys[i]) — confirm.
type StoresSet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"` // store name; left out of the encoded form when empty

Keys [][]float32 `json:"keys" yaml:"keys"` // one float32 vector per key
Values []string `json:"values" yaml:"values"`
}

// StoresDelete is the payload for a store "delete" operation (presumably
// removing the entries stored under Keys — confirm against the handler).
type StoresDelete struct {
	Store string `json:"store,omitempty" yaml:"store,omitempty"` // store name; left out of the encoded form when empty

	// Keys holds one float32 vector per key. The yaml tag is spelled out so the
	// field is declared the same way as Keys in the other Stores* types.
	Keys [][]float32 `json:"keys" yaml:"keys"`
}

// StoresGet is the payload for a store "get" operation (presumably looking up
// the entries stored under Keys — confirm against the handler).
type StoresGet struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"` // store name; left out of the encoded form when empty

Keys [][]float32 `json:"keys" yaml:"keys"` // one float32 vector per key
}

// StoresGetResponse carries keys and string values back to the caller.
// NOTE(review): Keys and Values look like parallel lists (Values[i] paired with Keys[i]) — confirm.
type StoresGetResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"` // one float32 vector per key
Values []string `json:"values" yaml:"values"`
}

// StoresFind is the payload for a store "find" operation: a single query
// vector plus a result count.
type StoresFind struct {
Store string `json:"store,omitempty" yaml:"store,omitempty"` // store name; left out of the encoded form when empty

Key []float32 `json:"key" yaml:"key"` // the single query vector
Topk int `json:"topk" yaml:"topk"` // presumably the maximum number of matches to return — confirm
}

// StoresFindResponse carries the keys, values and similarity scores of a find.
// NOTE(review): the three slices look index-aligned (entry i in each describes
// the same match) — confirm.
type StoresFindResponse struct {
Keys [][]float32 `json:"keys" yaml:"keys"` // one float32 vector per key
Values []string `json:"values" yaml:"values"`
Similarities []float32 `json:"similarities" yaml:"similarities"`
}

0 comments on commit 1a2d0cb

Please sign in to comment.