Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 139 additions & 82 deletions core/http/endpoints/openai/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ const (
"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
)

// resolveOutputModalities returns the effective output modalities for a
// response: response-level overrides session-level, and the OpenAI Realtime
// spec default is ["audio"] when neither is set.
func resolveOutputModalities(session, response []types.Modality) []types.Modality {
if len(response) > 0 {
return response
}
if len(session) > 0 {
return session
}
return []types.Modality{types.ModalityAudio}
}

// modalitiesContainAudio reports whether the resolved modalities include audio
// output.
func modalitiesContainAudio(m []types.Modality) bool {
for _, x := range m {
if x == types.ModalityAudio {
return true
}
}
return false
}

// A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
// If the model support instead audio-to-audio, we will use the specific gRPC calls instead

Expand Down Expand Up @@ -82,6 +106,10 @@ type Session struct {
InputSampleRate int
OutputSampleRate int
MaxOutputTokens types.IntOrInf
// OutputModalities mirrors the OpenAI Realtime spec field of the same
// name. Empty means "use the spec default" (audio). ["text"] suppresses
// TTS so the client receives only response.output_text.* events.
OutputModalities []types.Modality
// MaxHistoryItems caps the number of MessageItems passed to the LLM each
// turn (0 = unlimited). Small models — especially the LFM2.5-Audio 1.5B
// served via the liquid-audio backend — degrade quickly past a handful
Expand Down Expand Up @@ -162,13 +190,14 @@ func (s *Session) ToServer() types.SessionUnion {
} else {
return types.SessionUnion{
Realtime: &types.RealtimeSession{
ID: s.ID,
Object: "realtime.session",
Model: s.Model,
Instructions: s.Instructions,
Tools: s.Tools,
ToolChoice: s.ToolChoice,
MaxOutputTokens: s.MaxOutputTokens,
ID: s.ID,
Object: "realtime.session",
Model: s.Model,
Instructions: s.Instructions,
Tools: s.Tools,
ToolChoice: s.ToolChoice,
MaxOutputTokens: s.MaxOutputTokens,
OutputModalities: s.OutputModalities,
Audio: &types.RealtimeSessionAudio{
Input: &types.SessionAudioInput{
TurnDetection: s.TurnDetection,
Expand Down Expand Up @@ -1015,6 +1044,10 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode
session.MaxOutputTokens = rt.MaxOutputTokens
}

if len(rt.OutputModalities) > 0 {
session.OutputModalities = rt.OutputModalities
}

return nil
}

Expand Down Expand Up @@ -1654,106 +1687,130 @@ func triggerResponseAtTurn(ctx context.Context, session *Session, conv *Conversa
})
}

// Check for cancellation before TTS
if ctx.Err() != nil {
xlog.Debug("Response cancelled before TTS (barge-in)")
sendCancelledResponse()
return
var audioString string
_, isWebRTC := t.(*WebRTCTransport)
var respMods []types.Modality
if overrides != nil {
respMods = overrides.OutputModalities
}

audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
if err != nil {
modalities := resolveOutputModalities(session.OutputModalities, respMods)
if modalitiesContainAudio(modalities) {
// Check for cancellation before TTS
if ctx.Err() != nil {
xlog.Debug("TTS cancelled (barge-in)")
xlog.Debug("Response cancelled before TTS (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("TTS failed", "error", err)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
return
}
if !res.Success {
xlog.Error("TTS failed", "message", res.Message)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
return
}
defer os.Remove(audioFilePath)

audioBytes, err := os.ReadFile(audioFilePath)
if err != nil {
xlog.Error("failed to read TTS file", "error", err)
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
return
}

// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
if ttsSampleRate == 0 {
ttsSampleRate = localSampleRate
}
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)
audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language)
if err != nil {
if ctx.Err() != nil {
xlog.Debug("TTS cancelled (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("TTS failed", "error", err)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID)
return
}
if !res.Success {
xlog.Error("TTS failed", "message", res.Message)
sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID)
return
}
defer func() { _ = os.Remove(audioFilePath) }()

// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
// Opus encoder, which resamples to 48kHz internally. This avoids a
// lossy intermediate resample through 16kHz.
// XXX: This is a noop in websocket mode; it's included in the JSON instead
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
if ctx.Err() != nil {
xlog.Debug("Audio playback cancelled (barge-in)")
sendCancelledResponse()
audioBytes, err := os.ReadFile(audioFilePath)
if err != nil {
xlog.Error("failed to read TTS file", "error", err)
sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID)
return
}
xlog.Error("failed to send audio via transport", "error", err)
}

_, isWebRTC := t.(*WebRTCTransport)
// Parse WAV header to get raw PCM and the actual sample rate from the TTS backend.
pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes)
if ttsSampleRate == 0 {
ttsSampleRate = localSampleRate
}
xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate)

// SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the
// Opus encoder, which resamples to 48kHz internally. This avoids a
// lossy intermediate resample through 16kHz.
// XXX: This is a noop in websocket mode; it's included in the JSON instead
if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil {
if ctx.Err() != nil {
xlog.Debug("Audio playback cancelled (barge-in)")
sendCancelledResponse()
return
}
xlog.Error("failed to send audio via transport", "error", err)
}

// For WebSocket clients, resample to the session's output rate and
// deliver audio as base64 in JSON events. WebRTC clients already
// received audio over the RTP track, so skip the base64 payload.
var audioString string
if !isWebRTC {
wsPCM := pcmData
if ttsSampleRate != session.OutputSampleRate {
samples := sound.BytesToInt16sLE(pcmData)
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
wsPCM = sound.Int16toBytesLE(resampled)
// For WebSocket clients, resample to the session's output rate and
// deliver audio as base64 in JSON events. WebRTC clients already
// received audio over the RTP track, so skip the base64 payload.
if !isWebRTC {
wsPCM := pcmData
if ttsSampleRate != session.OutputSampleRate {
samples := sound.BytesToInt16sLE(pcmData)
resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate)
wsPCM = sound.Int16toBytesLE(resampled)
}
audioString = base64.StdEncoding.EncodeToString(wsPCM)
}
audioString = base64.StdEncoding.EncodeToString(wsPCM)
}

sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Delta: finalSpeech,
})
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Transcript: finalSpeech,
})
sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Delta: finalSpeech,
})
sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Transcript: finalSpeech,
})

if !isWebRTC {
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
if !isWebRTC {
sendEvent(t, types.ResponseOutputAudioDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Delta: audioString,
})
sendEvent(t, types.ResponseOutputAudioDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
})
}
} else {
// Text-only mode: skip TTS, emit only the text events.
sendEvent(t, types.ResponseOutputTextDeltaEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Delta: audioString,
Delta: finalSpeech,
})
sendEvent(t, types.ResponseOutputAudioDoneEvent{
sendEvent(t, types.ResponseOutputTextDoneEvent{
ServerEventBase: types.ServerEventBase{},
ResponseID: responseID,
ItemID: item.Assistant.ID,
OutputIndex: 0,
ContentIndex: 0,
Text: finalSpeech,
})
}

Expand Down
39 changes: 39 additions & 0 deletions core/http/endpoints/openai/realtime_modality_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package openai

import (
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("resolveOutputModalities", func() {
It("defaults to audio when neither session nor response specify", func() {
got := resolveOutputModalities(nil, nil)
Expect(got).To(ConsistOf(types.ModalityAudio))
})

It("uses session modalities when response omits them", func() {
sess := []types.Modality{types.ModalityText}
got := resolveOutputModalities(sess, nil)
Expect(got).To(ConsistOf(types.ModalityText))
})

It("response modalities override session", func() {
sess := []types.Modality{types.ModalityAudio}
resp := []types.Modality{types.ModalityText}
got := resolveOutputModalities(sess, resp)
Expect(got).To(ConsistOf(types.ModalityText))
})

It("returns false from modalitiesContainAudio for text-only", func() {
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText})).To(BeFalse())
})

It("returns true from modalitiesContainAudio for audio (default)", func() {
Expect(modalitiesContainAudio([]types.Modality{types.ModalityAudio})).To(BeTrue())
})

It("returns true when both audio and text are present", func() {
Expect(modalitiesContainAudio([]types.Modality{types.ModalityText, types.ModalityAudio})).To(BeTrue())
})
})
Loading