diff --git a/AGENTS.md b/AGENTS.md index 075a469..ab361ef 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,8 @@ # xAI SDK implementation notes -- `GrokClient` is primarily backed by generated gRPC protocol clients, but text to speech uses xAI's documented REST/WebSocket voice endpoints because there are no generated TTS protocol types in `src\xAI.Protocol`. +- `GrokClient` is primarily backed by generated gRPC protocol clients, but voice features use xAI's documented REST/WebSocket endpoints because there are no generated voice protocol types in `src\xAI.Protocol`. +- Voice REST calls use `GrokClient.HttpHandler` (backed by `httpHandlers` cache) — a plain `SocketsHttpHandler`+Polly pipeline separate from the gRPC channel. `ChannelHandler` returns `ChannelBase` only; there is no `.Handler` property on it. - `AsITextToSpeechClient` returns an `ITextToSpeechClient` implementation that uses `POST /v1/tts` for unary audio and `wss://.../v1/tts` for streaming audio. +- `AsISpeechToTextClient` returns an `ISpeechToTextClient` implementation that uses `POST /v1/stt` for file transcription and `wss://.../v1/stt` for raw-audio streaming transcription. - TTS defaults follow xAI docs: voice `eve`, language `en` when omitted by `TextToSpeechOptions`, and MP3 output when no codec is specified. +- STT streaming defaults follow xAI docs: encoding `pcm` and sample rate `16000` when omitted; WebSocket input must be raw encoded audio, not MP3/WAV container bytes. diff --git a/readme.md b/readme.md index 8dba024..b8c03fc 100644 --- a/readme.md +++ b/readme.md @@ -51,6 +51,12 @@ var speech = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) var audio = await speech.GetAudioAsync("Hello! Welcome to xAI text to speech.", new TextToSpeechOptions { VoiceId = "eve", Language = "en" }); + +var transcription = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) 
+ .AsISpeechToTextClient(); + +var text = await transcription.GetTextAsync(File.OpenRead("audio.mp3"), + new SpeechToTextOptions { TextLanguage = "en" }); ``` ## File Attachments @@ -402,6 +408,8 @@ Console.WriteLine($"Edited image URL: {editedImage.Uri}"); ## Text to Speech Grok supports text to speech via the `ITextToSpeechClient` abstraction from Microsoft.Extensions.AI. +See the [xAI text to speech docs](https://docs.x.ai/developers/model-capabilities/audio/text-to-speech) +for supported voices, formats, and streaming details. Use `AsITextToSpeechClient` to get a TTS client: ```csharp @@ -465,6 +473,87 @@ var options = new GrokTextToSpeechOptions var response = await speech.GetAudioAsync("Streaming at 24 kHz, 128 kbps.", options); ``` +## Speech to Text + +Grok supports speech to text via the `ISpeechToTextClient` abstraction from Microsoft.Extensions.AI. +See the [xAI speech to text docs](https://docs.x.ai/developers/model-capabilities/audio/speech-to-text) +for supported languages, audio formats, diarization, multichannel audio, and streaming details. +Use `AsISpeechToTextClient` to get an STT client: + +```csharp +var transcription = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) + .AsISpeechToTextClient(); +``` + +### Unary (single response) + +Call `GetTextAsync` to transcribe an audio file in a single request. The result contains transcript +text, timing information, and the raw xAI response: + +```csharp +await using var audio = File.OpenRead("meeting.mp3"); + +var response = await transcription.GetTextAsync(audio, + new GrokSpeechToTextOptions + { + TextLanguage = "en", + Format = true, + }); + +Console.WriteLine(response.Text); +``` + +Set `Format = true` with `TextLanguage` to enable xAI's inverse text normalization, such as converting +spoken numbers and currencies into written form. + +### Streaming + +Call `GetStreamingTextAsync` to stream raw audio and receive transcript updates as speech is processed. 
+The xAI streaming endpoint expects raw encoded audio such as PCM, µ-law, or A-law rather than MP3/WAV +container bytes: + +```csharp +await using var audio = File.OpenRead("audio.pcm"); + +await foreach (var update in transcription.GetStreamingTextAsync(audio, + new GrokSpeechToTextOptions + { + AudioFormat = "pcm", + SpeechSampleRate = 16000, + TextLanguage = "en", + InterimResults = true, + })) +{ + if (update.Kind is SpeechToTextResponseUpdateKind.TextUpdating or + SpeechToTextResponseUpdateKind.TextUpdated) + { + Console.WriteLine(update.Text); + } +} +``` + +### Grok-Specific Options + +Use `GrokSpeechToTextOptions` to control xAI transcription behavior beyond the base +`SpeechToTextOptions`: + +```csharp +var options = new GrokSpeechToTextOptions +{ + TextLanguage = "en", + SpeechSampleRate = 16000, + Format = true, // normalize spoken numbers, currencies, and units + AudioFormat = "pcm", // pcm | mulaw | alaw for raw audio + Diarize = true, // include speaker IDs on words when available + Multichannel = true, // transcribe each channel independently + Channels = 2, + InterimResults = true, // streaming only + Endpointing = 10, // streaming silence duration in milliseconds +}; + +var response = await transcription.GetTextAsync(File.OpenRead("call.pcm"), options); +``` + # xAI.Protocol diff --git a/src/xAI.Tests/SanityChecks.cs b/src/xAI.Tests/SanityChecks.cs index af73fba..a972bb3 100644 --- a/src/xAI.Tests/SanityChecks.cs +++ b/src/xAI.Tests/SanityChecks.cs @@ -1,13 +1,8 @@ using System.Text.Json; -using Devlooped.Extensions.AI; -using DotNetEnv; -using Grpc.Core; -using Grpc.Net.Client.Configuration; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using xAI.Protocol; -using Xunit.Abstractions; -using Xunit.Sdk; +using static ConfigurationExtensions; using ChatConversation = Devlooped.Extensions.AI.Chat; namespace xAI.Tests; @@ -18,7 +13,7 @@ public class SanityChecks(ITestOutputHelper output) public async Task 
NoEmbeddingModels() { var services = new ServiceCollection() - .AddxAIProtocol(Environment.GetEnvironmentVariable("CI_XAI_API_KEY")!) + .AddxAIProtocol(Configuration["CI_XAI_API_KEY"]!) .BuildServiceProvider(); var client = services.GetRequiredService(); @@ -33,7 +28,7 @@ public async Task NoEmbeddingModels() public async Task ListModelsAsync() { var services = new ServiceCollection() - .AddxAIProtocol(Environment.GetEnvironmentVariable("CI_XAI_API_KEY")!) + .AddxAIProtocol(Configuration["CI_XAI_API_KEY"]!) .BuildServiceProvider(); var client = services.GetRequiredService(); @@ -50,7 +45,7 @@ public async Task ListModelsAsync() public async Task ExecuteLocalFunctionWithWebSearch() { var services = new ServiceCollection() - .AddxAIProtocol(Environment.GetEnvironmentVariable("CI_XAI_API_KEY")!) + .AddxAIProtocol(Configuration["CI_XAI_API_KEY"]!) .BuildServiceProvider(); var client = services.GetRequiredService(); @@ -161,7 +156,7 @@ public async Task ExecuteLocalFunctionWithWebSearch() public async Task ClientSideFunction(bool streaming) { var getDateCalls = 0; - var grok = new GrokClient(Env.GetString("CI_XAI_API_KEY")!) + var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!) .AsIChatClient("grok-4-1-fast") .AsBuilder() .UseFunctionInvocation() @@ -203,7 +198,7 @@ What is today's date? Use the get_date tool. [InlineData(true)] public async Task AgenticWebSearch(bool streaming) { - var grok = new GrokClient(Env.GetString("CI_XAI_API_KEY")!) + var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!) .AsIChatClient("grok-4-1-fast"); var options = new GrokChatOptions @@ -249,7 +244,7 @@ What is the current price of Tesla (TSLA) stock? Use web search (Yahoo Finance o [InlineData(true)] public async Task AgenticXSearch(bool streaming) { - var grok = new GrokClient(Env.GetString("CI_XAI_API_KEY")!) + var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!) 
.AsIChatClient("grok-4-1-fast"); var options = new GrokChatOptions @@ -288,7 +283,7 @@ What is the top news from Tesla on X? Use the X search tool. [InlineData(true)] public async Task AgenticMcpServer(bool streaming) { - var grok = new GrokClient(Env.GetString("CI_XAI_API_KEY")!) + var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!) .AsIChatClient("grok-4-1-fast"); var options = new GrokChatOptions @@ -299,7 +294,7 @@ public async Task AgenticMcpServer(bool streaming) [ new HostedMcpServerTool("GitHub", "https://api.githubcopilot.com/mcp/") { - Headers = new Dictionary < string, string > {["Authorization"] = Env.GetString("GITHUB_TOKEN") ! }, + Headers = new Dictionary < string, string > {["Authorization"] = Configuration["GITHUB_TOKEN"] ! }, AllowedTools = ["list_releases", "get_release_by_tag"], } ] @@ -340,7 +335,7 @@ What is the latest release version of the {{ThisAssembly.Git.Url}} repository? U [InlineData(true)] public async Task AgenticFileSearch(bool streaming) { - var grok = new GrokClient(Env.GetString("CI_XAI_API_KEY")!) + var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!) .AsIChatClient("grok-4-1-fast"); var options = new GrokChatOptions @@ -406,7 +401,7 @@ Use the collection search tool. [InlineData(true)] public async Task AgenticCodeInterpreter(bool streaming) { - var client = new GrokClient(Env.GetString("CI_XAI_API_KEY")!); + var client = new GrokClient(Configuration["CI_XAI_API_KEY"]!); var grok = client.AsIChatClient("grok-4-1-fast"); @@ -451,6 +446,72 @@ parseable by a decimal parser. 
output.WriteLine($"Code interpreter calls: {codeInterpreterCalls.Count}"); } + [SecretsTheory("CI_XAI_API_KEY")] + [InlineData("rex")] + public async Task TextToSpeech_SpeechToText(string voiceId) + { + using var client = new GrokClient(Configuration["CI_XAI_API_KEY"]!); + using var tts = client.AsITextToSpeechClient(); + + var expected = "El que cree en mí, en realidad no cree en mí, sino en aquel que me envió."; + var tempFile = System.IO.Path.Combine(System.IO.Path.GetTempPath(), $"xai-tts-{Guid.NewGuid():N}.pcm"); + + try + { + await using (var fileStream = System.IO.File.Create(tempFile)) + { + await foreach (var update in tts.GetStreamingAudioAsync( + expected, + new TextToSpeechOptions + { + VoiceId = voiceId, + Language = "es-ES", + // uses mp3 by default + })) + { + if (update.Kind == TextToSpeechResponseUpdateKind.AudioUpdating) + { + foreach (var content in update.Contents) + { + if (content is DataContent data) + { + await fileStream.WriteAsync(data.Data); + } + } + } + } + } + + Assert.True(System.IO.File.Exists(tempFile)); + Assert.True(new System.IO.FileInfo(tempFile).Length > 0); + + using var stt = client.AsISpeechToTextClient(); + await using var audioStream = System.IO.File.OpenRead(tempFile); + + // auto-detect format from content + var transcription = await stt.GetTextAsync(audioStream); + + Assert.Equal( + NormalizeTranscription(expected), + NormalizeTranscription(transcription.Text), + ignoreCase: true); + } + finally + { + if (System.IO.File.Exists(tempFile)) + System.IO.File.Delete(tempFile); + } + } + + static string NormalizeTranscription(string? text) + { + var withoutPunctuation = new string((text ?? string.Empty) + .Select(character => char.IsPunctuation(character) ? 
' ' : character) + .ToArray()); + + return string.Join(" ", withoutPunctuation.Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries)); + } + static async Task GetResponseAsync(IChatClient client, ChatConversation chat, GrokChatOptions options, bool streaming) { if (!streaming) diff --git a/src/xAI.Tests/SpeechToTextClientTests.cs b/src/xAI.Tests/SpeechToTextClientTests.cs new file mode 100644 index 0000000..4aac82d --- /dev/null +++ b/src/xAI.Tests/SpeechToTextClientTests.cs @@ -0,0 +1,380 @@ +using System.Net; +using System.Net.Http.Headers; +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using Grpc.Net.Client; +using Microsoft.Extensions.AI; + +namespace xAI.Tests; + +public class SpeechToTextClientTests +{ + [Fact] + public void AsISpeechToTextClient_ReturnsMetadata() + { + using var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var stt = client.AsISpeechToTextClient(); + + var metadata = stt.GetService(); + + Assert.NotNull(metadata); + Assert.Equal("xai", metadata.ProviderName); + Assert.Equal(client.Options.Endpoint, metadata.ProviderUri); + Assert.Null(metadata.DefaultModelId); + } + + [Fact] + public async Task GetTextAsync_MapsRequestAndResponse() + { + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent( + """ + { + "text": "Hello world", + "language": "English", + "duration": 1.25, + "words": [ + { "text": "Hello", "start": 0.10, "end": 0.50 }, + { "text": "world", "start": 0.60, "end": 1.10 } + ] + } + """, Encoding.UTF8, "application/json"), + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var stt = client.AsISpeechToTextClient(); + + var response = await stt.GetTextAsync(new MemoryStream([1, 2, 3]), + new GrokSpeechToTextOptions + { + TextLanguage = "en", + SpeechSampleRate = 16000, + Format = true, + AudioFormat = "pcm", + Multichannel = true, + Channels = 2, + Diarize = 
true, + ModelId = "test-model", + }); + + Assert.Equal(HttpMethod.Post, handler.Request!.Method); + Assert.Equal(new Uri($"{client.Options.Endpoint}v1/stt"), handler.Request.RequestUri); + Assert.Equal("Bearer", handler.Request.Headers.Authorization?.Scheme); + Assert.Equal("test-api-key", handler.Request.Headers.Authorization?.Parameter); + + var body = handler.RequestBody!; + AssertFieldOrder(body, "format", "language", "sample_rate", "audio_format", "multichannel", "channels", "diarize", "file"); + Assert.Contains("format", GetField(body, "format")); + Assert.Contains("true", body); + Assert.Contains("language", GetField(body, "language")); + Assert.Contains("en", body); + Assert.Contains("sample_rate", GetField(body, "sample_rate")); + Assert.Contains("16000", body); + Assert.Contains("audio_format", GetField(body, "audio_format")); + Assert.Contains("pcm", body); + Assert.Contains("audio.mp3", body); + + Assert.Equal("Hello world", response.Text); + Assert.Null(response.ModelId); + Assert.Equal(TimeSpan.FromSeconds(0.10), response.StartTime); + Assert.Equal(TimeSpan.FromSeconds(1.10), response.EndTime); + Assert.Equal("English", response.AdditionalProperties?["language"]); + Assert.Equal(1.25, response.AdditionalProperties?["duration"]); + } + + [Fact] + public async Task GetTextAsync_WithError_ThrowsHttpRequestException() + { + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.BadRequest) + { + ReasonPhrase = "Bad Request", + Content = new StringContent("""{"error":"missing file"}"""), + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var stt = client.AsISpeechToTextClient(); + + var exception = await Assert.ThrowsAsync(() => stt.GetTextAsync(new MemoryStream([1]))); + + Assert.Equal(HttpStatusCode.BadRequest, exception.StatusCode); + Assert.Contains("missing file", exception.Message); + } + + [Fact] + public async Task GetTextAsync_WithNullStream_ThrowsArgumentNullException() + { + using 
var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var stt = client.AsISpeechToTextClient(); + + await Assert.ThrowsAsync(() => stt.GetTextAsync(null!)); + } + + [Fact] + public async Task GetTextAsync_WithTranslation_ThrowsNotSupportedException() + { + using var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var stt = client.AsISpeechToTextClient(); + + await Assert.ThrowsAsync(() => stt.GetTextAsync(new MemoryStream([1]), + new SpeechToTextOptions + { + SpeechLanguage = "en", + TextLanguage = "fr", + })); + } + + [Fact] + public async Task GetTextAsync_WithFormatAndNoLanguage_ThrowsArgumentException() + { + using var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var stt = client.AsISpeechToTextClient(); + + await Assert.ThrowsAsync(() => stt.GetTextAsync(new MemoryStream([1]), + new GrokSpeechToTextOptions { Format = true })); + } + + [Fact] + public async Task GetStreamingTextAsync_MapsWebSocketEvents() + { + var webSocket = new FakeWebSocket( + """{"type":"transcript.created"}""", + """{"type":"transcript.partial","text":"Hel","is_final":false,"speech_final":false,"start":0.0,"duration":0.4}""", + """{"type":"transcript.partial","text":"Hello","is_final":true,"speech_final":true,"start":0.0,"duration":0.8,"channel_index":1}""", + """{"type":"transcript.done","text":"Hello world","duration":1.2}"""); + + Uri? capturedUri = null; + string? 
capturedApiKey = null; + using var stt = new GrokSpeechToTextClient( + new HttpClient(new CaptureHandler()), + new Uri("https://streaming.test/base/"), + "test-api-key", + (uri, apiKey, _) => + { + capturedUri = uri; + capturedApiKey = apiKey; + return ValueTask.FromResult(webSocket); + }); + + var updates = new List(); + await foreach (var update in stt.GetStreamingTextAsync(new MemoryStream([1, 2, 3, 4]), + new GrokSpeechToTextOptions + { + AudioFormat = "mulaw", + SpeechSampleRate = 8000, + TextLanguage = "en", + InterimResults = true, + Endpointing = 5, + Diarize = true, + Multichannel = true, + Channels = 2, + ModelId = "ignored-model", + })) + { + updates.Add(update); + } + + Assert.Equal("test-api-key", capturedApiKey); + Assert.Equal("wss://streaming.test/base/v1/stt?sample_rate=8000&encoding=mulaw&interim_results=true&endpointing=5&language=en&diarize=true&multichannel=true&channels=2", capturedUri!.ToString()); + + Assert.Collection(webSocket.SentBinaryMessages, + message => Assert.Equal(new byte[] { 1, 2, 3, 4 }, message)); + + Assert.Collection(webSocket.SentTextMessages, + message => + { + using var json = JsonDocument.Parse(message); + Assert.Equal("audio.done", json.RootElement.GetProperty("type").GetString()); + }); + + Assert.Collection(updates, + update => + { + Assert.Equal(SpeechToTextResponseUpdateKind.SessionOpen, update.Kind); + Assert.Null(update.ModelId); + }, + update => + { + Assert.Equal(SpeechToTextResponseUpdateKind.TextUpdating, update.Kind); + Assert.Null(update.ModelId); + Assert.Equal("Hel", update.Text); + Assert.Equal(TimeSpan.Zero, update.StartTime); + Assert.Equal(TimeSpan.FromSeconds(0.4), update.EndTime); + }, + update => + { + Assert.Equal(SpeechToTextResponseUpdateKind.TextUpdated, update.Kind); + Assert.Null(update.ModelId); + Assert.Equal("Hello", update.Text); + Assert.Equal(1, update.AdditionalProperties?["channel_index"]); + }, + update => + { + Assert.Equal(SpeechToTextResponseUpdateKind.TextUpdated, update.Kind); + 
Assert.Null(update.ModelId); + Assert.Equal("Hello world", update.Text); + }, + update => + { + Assert.Equal(SpeechToTextResponseUpdateKind.SessionClose, update.Kind); + Assert.Null(update.ModelId); + }); + } + + [Fact] + public async Task GetStreamingTextAsync_WithErrorEvent_YieldsErrorUpdate() + { + var webSocket = new FakeWebSocket( + """{"type":"transcript.created"}""", + """{"type":"error","message":"bad audio"}""", + """{"type":"transcript.done","duration":0}"""); + + using var stt = new GrokSpeechToTextClient( + new HttpClient(new CaptureHandler()), + new Uri("https://streaming.test/"), + "test-api-key", + (_, _, _) => ValueTask.FromResult(webSocket)); + + var updates = new List(); + await foreach (var update in stt.GetStreamingTextAsync(new MemoryStream([1]))) + { + updates.Add(update); + } + + Assert.Contains(updates, update => update.Kind == SpeechToTextResponseUpdateKind.Error && update.Text == "bad audio"); + } + + [Fact] + public async Task GetStreamingTextAsync_WithUnsupportedEncoding_ThrowsArgumentException() + { + using var stt = new GrokSpeechToTextClient( + new HttpClient(new CaptureHandler()), + new Uri("https://streaming.test/"), + "test-api-key", + (_, _, _) => throw new InvalidOperationException("Should not connect.")); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in stt.GetStreamingTextAsync(new MemoryStream([1]), + new GrokSpeechToTextOptions { AudioFormat = "mp3" })) + { + } + }); + } + + static GrokClientOptions CreateOptions(HttpMessageHandler handler) => new() + { + Endpoint = new Uri($"https://unit-{Guid.NewGuid():N}.test/"), + ChannelOptions = new GrpcChannelOptions + { + HttpHandler = handler, + }, + }; + + static void AssertFieldOrder(string body, params string[] fields) + { + var previous = -1; + foreach (var field in fields) + { + var current = IndexOfField(body, field); + Assert.True(current >= 0, $"Field '{field}' was not found in multipart body."); + Assert.True(current > previous, $"Field '{field}' was 
not in the expected multipart order."); + previous = current; + } + } + + static int IndexOfField(string body, string field) + { + var index = body.IndexOf($"name=\"{field}\"", StringComparison.Ordinal); + return index >= 0 ? index : body.IndexOf($"name={field}", StringComparison.Ordinal); + } + + static string GetField(string body, string field) + { + var index = IndexOfField(body, field); + Assert.True(index >= 0, $"Field '{field}' was not found in multipart body."); + return body[index..Math.Min(body.Length, index + 100)]; + } + + sealed class CaptureHandler(Func? responder = null) : HttpMessageHandler + { + readonly Func responder = responder ?? (_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("""{"text":"ok","duration":0}""", Encoding.UTF8, "application/json"), + }); + + public HttpRequestMessage? Request { get; private set; } + public string? RequestBody { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Request = request; + RequestBody = request.Content is null ? null : await request.Content.ReadAsStringAsync(cancellationToken); + return responder(request); + } + } + + sealed class FakeWebSocket(params string[] messages) : WebSocket + { + readonly Queue messages = new(messages.Select(Encoding.UTF8.GetBytes)); + WebSocketState state = WebSocketState.Open; + WebSocketCloseStatus? closeStatus; + string? closeStatusDescription; + + public List SentTextMessages { get; } = []; + public List SentBinaryMessages { get; } = []; + + public override WebSocketCloseStatus? CloseStatus => closeStatus; + + public override string? CloseStatusDescription => closeStatusDescription; + + public override WebSocketState State => state; + + public override string? SubProtocol => null; + + public override void Abort() => state = WebSocketState.Aborted; + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? 
statusDescription, CancellationToken cancellationToken) + { + this.closeStatus = closeStatus; + closeStatusDescription = statusDescription; + state = WebSocketState.Closed; + return Task.CompletedTask; + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + => CloseAsync(closeStatus, statusDescription, cancellationToken); + + public override void Dispose() => state = WebSocketState.Closed; + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + if (messages.Count == 0) + { + state = WebSocketState.CloseReceived; + return Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, WebSocketCloseStatus.NormalClosure, "closed")); + } + + var message = messages.Dequeue(); + message.CopyTo(buffer.Array!, buffer.Offset); + return Task.FromResult(new WebSocketReceiveResult(message.Length, WebSocketMessageType.Text, true)); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + if (messageType == WebSocketMessageType.Text) + { + SentTextMessages.Add(Encoding.UTF8.GetString(buffer.Array!, buffer.Offset, buffer.Count)); + } + else + { + var copy = new byte[buffer.Count]; + Array.Copy(buffer.Array!, buffer.Offset, copy, 0, buffer.Count); + SentBinaryMessages.Add(copy); + } + + return Task.CompletedTask; + } + } +} diff --git a/src/xAI.Tests/TextToSpeechClientTests.cs b/src/xAI.Tests/TextToSpeechClientTests.cs index 266f451..6087074 100644 --- a/src/xAI.Tests/TextToSpeechClientTests.cs +++ b/src/xAI.Tests/TextToSpeechClientTests.cs @@ -77,7 +77,7 @@ public async Task GetAudioAsync_MapsRequestAndResponse() var data = Assert.IsType(content); Assert.Equal("audio/wav", data.MediaType); Assert.Equal(audio, data.Data.ToArray()); - Assert.Equal("test-model", response.ModelId); + Assert.Null(response.ModelId); } [Theory] @@ -181,6 
+181,7 @@ public async Task GetStreamingAudioAsync_MapsWebSocketEvents() SampleRate = 8000, OptimizeStreamingLatency = 1, TextNormalization = true, + ModelId = "ignored-model", })) { updates.Add(update); @@ -206,6 +207,7 @@ public async Task GetStreamingAudioAsync_MapsWebSocketEvents() update => { Assert.Equal(TextToSpeechResponseUpdateKind.AudioUpdating, update.Kind); + Assert.Null(update.ModelId); var data = Assert.IsType(Assert.Single(update.Contents)); Assert.Equal(new byte[] { 1, 2, 3 }, data.Data.ToArray()); Assert.Equal("audio/basic", data.MediaType); @@ -213,6 +215,7 @@ public async Task GetStreamingAudioAsync_MapsWebSocketEvents() update => { Assert.Equal(TextToSpeechResponseUpdateKind.SessionClose, update.Kind); + Assert.Null(update.ModelId); Assert.Equal("trace-123", update.AdditionalProperties?["trace_id"]); }); } @@ -237,58 +240,6 @@ public async Task GetStreamingAudioAsync_WithErrorEvent_ThrowsInvalidOperationEx Assert.Contains("voice rejected", exception.Message); } - [SecretsTheory("XAI_API_KEY")] - //[InlineData("ara")] - //[InlineData("eve")] - [InlineData("rex")] // 👈 el mejor para Jesus - //[InlineData("sal")] - //[InlineData("leo")] - public async Task GetStreamingAudioAsync_IntegrationTest_SavesAndPlaysAudio(string voiceId) - { - var apiKey = Environment.GetEnvironmentVariable("XAI_API_KEY"); - using var client = new GrokClient(apiKey!); - using var tts = client.AsITextToSpeechClient(); - - var tempFile = System.IO.Path.Combine(System.IO.Path.GetTempPath(), $"xai-tts-{Guid.NewGuid():N}.mp3"); - - await using (var fileStream = System.IO.File.Create(tempFile)) - { - await foreach (var update in tts.GetStreamingAudioAsync( - """ - El que cree en mí, en realidad no cree en mí, sino en aquel que me envió. - Y el que me ve, ve al que me envió. - Yo soy la luz, y he venido al mundo para que todo el que crea en mí no permanezca en las tinieblas. 
- """, - new GrokTextToSpeechOptions - { - VoiceId = voiceId, - AudioFormat = "mp3", - - })) - { - if (update.Kind == TextToSpeechResponseUpdateKind.AudioUpdating) - { - foreach (var content in update.Contents) - { - if (content is DataContent data) - { - await fileStream.WriteAsync(data.Data); - } - } - } - } - } - - Assert.True(System.IO.File.Exists(tempFile)); - Assert.True(new System.IO.FileInfo(tempFile).Length > 0); - - System.Diagnostics.Process.Start(new System.Diagnostics.ProcessStartInfo - { - FileName = tempFile, - UseShellExecute = true - }); - } - static GrokClientOptions CreateOptions(HttpMessageHandler handler) => new() { Endpoint = new Uri($"https://unit-{Guid.NewGuid():N}.test/"), diff --git a/src/xAI/GrokClient.cs b/src/xAI/GrokClient.cs index 43871ec..a2c627c 100644 --- a/src/xAI/GrokClient.cs +++ b/src/xAI/GrokClient.cs @@ -1,4 +1,4 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using System.Diagnostics; using System.Net.Http.Headers; using Grpc.Core; @@ -15,14 +15,16 @@ namespace xAI; /// The options used to configure the client. public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisposable { - static readonly ConcurrentDictionary<(Uri, string), (ChannelBase, HttpMessageHandler)> channels = []; + static readonly ConcurrentDictionary<(Uri, string), ChannelBase> channels = []; + static readonly ConcurrentDictionary<(Uri, string), HttpMessageHandler> httpHandlers = []; + readonly HttpMessageHandler? configuredHttpHandler = options.ChannelOptions?.HttpHandler; /// Initializes a new instance of the class with default options. public GrokClient(string apiKey) : this(apiKey, new GrokClientOptions()) { } /// Testing ctor. internal GrokClient(ChannelBase channel, GrokClientOptions options, string? apiKey = default) : this(apiKey ?? "", options) - => channels[(options.Endpoint, apiKey ?? "")] = (channel, GetHttpHandler(options.ChannelOptions, apiKey ?? 
"")); + => channels[(options.Endpoint, apiKey ?? "")] = channel; /// Gets the API key used for authentication. public string ApiKey { get; } = apiKey; @@ -34,29 +36,29 @@ internal GrokClient(ChannelBase channel, GrokClientOptions options, string? apiK public GrokClientOptions Options { get; } = options; /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Auth.AuthClient GetAuthClient() => new(ChannelHandler.Channel); + public Auth.AuthClient GetAuthClient() => new(ChannelHandler); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Chat.ChatClient GetChatClient() => new(ChannelHandler.Channel, Options); + public Chat.ChatClient GetChatClient() => new(ChannelHandler, Options); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Documents.DocumentsClient GetDocumentsClient() => new(ChannelHandler.Channel); + public Documents.DocumentsClient GetDocumentsClient() => new(ChannelHandler); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Embedder.EmbedderClient GetEmbedderClient() => new(ChannelHandler.Channel); + public Embedder.EmbedderClient GetEmbedderClient() => new(ChannelHandler); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Image.ImageClient GetImageClient() => new(ChannelHandler.Channel, Options); + public Image.ImageClient GetImageClient() => new(ChannelHandler, Options); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Models.ModelsClient GetModelsClient() => new(ChannelHandler.Channel); + public Models.ModelsClient GetModelsClient() => new(ChannelHandler); /// Gets a new instance of that reuses the client configuration details provided to the instance. 
- public Tokenize.TokenizeClient GetTokenizeClient() => new(ChannelHandler.Channel); + public Tokenize.TokenizeClient GetTokenizeClient() => new(ChannelHandler); - internal (ChannelBase Channel, HttpMessageHandler Handler) ChannelHandler => channels.GetOrAdd((Endpoint, ApiKey), key => + internal ChannelBase ChannelHandler => channels.GetOrAdd((Endpoint, ApiKey), key => { - var handler = GetHttpHandler(Options.ChannelOptions, key.Item2); + var handler = GetHttpHandler(configuredHttpHandler, key.Item2); // Provide some sensible defaults for gRPC channel options, while allowing users to // override them via GrokClientOptions.ChannelOptions if needed. @@ -69,12 +71,14 @@ internal GrokClient(ChannelBase channel, GrokClientOptions options, string? apiK options.HttpHandler = handler; - return (GrpcChannel.ForAddress(key.Item1, options), handler); + return GrpcChannel.ForAddress(key.Item1, options); }); - static HttpMessageHandler GetHttpHandler(GrpcChannelOptions? options, string apiKey) + internal HttpMessageHandler HttpHandler => + httpHandlers.GetOrAdd((Endpoint, ApiKey), key => GetHttpHandler(configuredHttpHandler, key.Item2)); + + static HttpMessageHandler GetHttpHandler(HttpMessageHandler? inner, string apiKey) { - var inner = options?.HttpHandler; if (inner == null) { // If no custom HttpHandler is provided, we create one with Polly retry @@ -122,7 +126,11 @@ StatusCode.DeadlineExceeded or } /// Clears the cached list of gRPC channels in the client. 
- public void Dispose() => channels.Clear(); + public void Dispose() + { + channels.Clear(); + httpHandlers.Clear(); + } class AuthenticationHeaderHandler(string apiKey) : DelegatingHandler { diff --git a/src/xAI/GrokClientExtensions.cs b/src/xAI/GrokClientExtensions.cs index f81da7c..9cb9e5d 100644 --- a/src/xAI/GrokClientExtensions.cs +++ b/src/xAI/GrokClientExtensions.cs @@ -1,4 +1,4 @@ -using System.ComponentModel; +using System.ComponentModel; using Microsoft.Extensions.AI; using xAI.Protocol; @@ -10,7 +10,7 @@ public static class GrokClientExtensions { /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this GrokClient client, string defaultModelId) - => new GrokChatClient(client.ChannelHandler.Channel, client.Options, defaultModelId); + => new GrokChatClient(client.ChannelHandler, client.Options, defaultModelId); /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this Chat.ChatClient client, string defaultModelId) @@ -18,7 +18,7 @@ public static IChatClient AsIChatClient(this Chat.ChatClient client, string defa /// Creates a new from the specified using the given model as the default. public static IImageGenerator AsIImageGenerator(this GrokClient client, string defaultModelId) - => new GrokImageGenerator(client.ChannelHandler.Channel, client.Options, defaultModelId); + => new GrokImageGenerator(client.ChannelHandler, client.Options, defaultModelId); /// Creates a new from the specified using the given model as the default. public static IImageGenerator AsIImageGenerator(this Image.ImageClient client, string defaultModelId) @@ -26,5 +26,9 @@ public static IImageGenerator AsIImageGenerator(this Image.ImageClient client, s /// Creates a new from the specified . 
public static ITextToSpeechClient AsITextToSpeechClient(this GrokClient client)
-        => new GrokTextToSpeechClient(client.ChannelHandler.Handler, client.Options, client.ApiKey);
+        => new GrokTextToSpeechClient(client.HttpHandler, client.Options, client.ApiKey);
+
+    /// Creates a new <see cref="ISpeechToTextClient"/> from the specified <see cref="GrokClient"/>.
+    public static ISpeechToTextClient AsISpeechToTextClient(this GrokClient client)
+        => new GrokSpeechToTextClient(client.HttpHandler, client.Options, client.ApiKey);
 }
diff --git a/src/xAI/GrokSpeechToTextClient.cs b/src/xAI/GrokSpeechToTextClient.cs
new file mode 100644
index 0000000..c34106f
--- /dev/null
+++ b/src/xAI/GrokSpeechToTextClient.cs
@@ -0,0 +1,503 @@
+using System.Buffers;
+using System.Collections.Specialized;
+using System.Globalization;
+using System.Net.Http.Headers;
+using System.Net.Http.Json;
+using System.Net.WebSockets;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using System.Text.Json.Serialization.Metadata;
+using Microsoft.Extensions.AI;
+
+namespace xAI;
+
+/// Represents an <see cref="ISpeechToTextClient"/> for xAI's Grok speech to text service.
+partial class GrokSpeechToTextClient : ISpeechToTextClient
+{
+    const string DefaultFilename = "audio.mp3";
+    const string DefaultStreamingEncoding = "pcm";
+    const int DefaultStreamingSampleRate = 16000;
+    const int DefaultStreamingChunkSize = 8192;
+
+    static readonly Dictionary<string, string> extensionToMediaType = new(StringComparer.OrdinalIgnoreCase)
+    {
+        [".wav"] = "audio/wav",
+        [".mp3"] = "audio/mpeg",
+        [".ogg"] = "audio/ogg",
+        [".opus"] = "audio/opus",
+        [".flac"] = "audio/flac",
+        [".aac"] = "audio/aac",
+        [".mp4"] = "audio/mp4",
+        [".m4a"] = "audio/mp4",
+        [".mkv"] = "video/x-matroska",
+    };
+
+    readonly SpeechToTextClientMetadata metadata;
+    readonly HttpClient httpClient;
+    readonly Uri endpoint;
+    readonly string?
apiKey;
+    readonly Func<Uri, string?, CancellationToken, ValueTask<WebSocket>> webSocketFactory;
+
+    internal GrokSpeechToTextClient(HttpMessageHandler handler, GrokClientOptions options, string? apiKey)
+        : this(new HttpClient(handler, disposeHandler: false), options.Endpoint, apiKey, CreateWebSocketAsync)
+    {
+    }
+
+    internal GrokSpeechToTextClient(
+        HttpClient httpClient,
+        Uri endpoint,
+        string? apiKey,
+        Func<Uri, string?, CancellationToken, ValueTask<WebSocket>> webSocketFactory)
+    {
+        this.httpClient = Throw.IfNull(httpClient);
+        this.endpoint = Throw.IfNull(endpoint);
+        this.apiKey = apiKey;
+        this.webSocketFactory = Throw.IfNull(webSocketFactory);
+
+        metadata = new("xai", endpoint);
+    }
+
+    /// <inheritdoc />
+    public async Task<SpeechToTextResponse> GetTextAsync(
+        Stream audioSpeechStream,
+        SpeechToTextOptions? options = null,
+        CancellationToken cancellationToken = default)
+    {
+        _ = Throw.IfNull(audioSpeechStream);
+
+        using var message = new HttpRequestMessage(HttpMethod.Post, GetHttpEndpoint())
+        {
+            Content = CreateMultipartContent(audioSpeechStream, options),
+        };
+
+        using var response = await httpClient.SendAsync(message, cancellationToken).ConfigureAwait(false);
+
+        if (!response.IsSuccessStatusCode)
+            await ThrowHttpExceptionAsync(response, cancellationToken).ConfigureAwait(false);
+
+        var transcript = await response.Content.ReadFromJsonAsync(SpeechToTextJsonContext.Default.GrokSpeechToTextResponse, cancellationToken).ConfigureAwait(false)
+            ?? throw new InvalidOperationException("xAI STT response body was empty.");
+
+        return ToSpeechToTextResponse(transcript, options);
+    }
+
+    /// <inheritdoc />
+    public async IAsyncEnumerable<SpeechToTextResponseUpdate> GetStreamingTextAsync(
+        Stream audioSpeechStream,
+        SpeechToTextOptions?
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(audioSpeechStream); + + using var webSocket = await webSocketFactory(GetStreamingEndpoint(options), apiKey, cancellationToken).ConfigureAwait(false); + + using (var ready = await ReceiveJsonAsync(webSocket, cancellationToken).ConfigureAwait(false)) + { + var root = ready.RootElement; + var rawRepresentation = root.Clone(); + var type = GetRequiredString(root, "type"); + + if (type != "transcript.created") + throw new InvalidOperationException($"Expected xAI STT streaming event type 'transcript.created' but received '{type}'."); + + yield return new SpeechToTextResponseUpdate + { + Kind = SpeechToTextResponseUpdateKind.SessionOpen, + RawRepresentation = rawRepresentation, + }; + } + + await SendAudioAsync(webSocket, audioSpeechStream, cancellationToken).ConfigureAwait(false); + await SendJsonAsync(webSocket, AudioDoneMessage.Instance, SpeechToTextJsonContext.Default.AudioDoneMessage, cancellationToken).ConfigureAwait(false); + + while (true) + { + using var json = await ReceiveJsonAsync(webSocket, cancellationToken).ConfigureAwait(false); + var root = json.RootElement; + var rawRepresentation = root.Clone(); + var type = GetRequiredString(root, "type"); + + switch (type) + { + case "transcript.partial": + yield return CreateTextUpdate(root, rawRepresentation, options); + break; + + case "transcript.done": + if (TryGetString(root, "text") is { Length: > 0 }) + yield return CreateTextUpdate(root, rawRepresentation, options, SpeechToTextResponseUpdateKind.TextUpdated); + + yield return new SpeechToTextResponseUpdate + { + Kind = SpeechToTextResponseUpdateKind.SessionClose, + RawRepresentation = rawRepresentation, + AdditionalProperties = CreateStreamingAdditionalProperties(root), + }; + yield break; + + case "error": + yield return new SpeechToTextResponseUpdate + { + Kind = SpeechToTextResponseUpdateKind.Error, + RawRepresentation = rawRepresentation, + 
Contents = [new TextContent(GetRequiredString(root, "message"))], + }; + break; + + default: + throw new InvalidOperationException($"Unsupported xAI STT streaming event type: {type}"); + } + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceKey is not null ? null : serviceType switch + { + Type t when t == typeof(SpeechToTextClientMetadata) => metadata, + Type t when t == typeof(GrokSpeechToTextClient) => this, + Type t when t == typeof(HttpClient) => httpClient, + Type t when t.IsInstanceOfType(this) => this, + _ => null + }; + + /// + public void Dispose() => httpClient.Dispose(); + + static MultipartFormDataContent CreateMultipartContent(Stream audioSpeechStream, SpeechToTextOptions? options) + { + var content = new MultipartFormDataContent(); + var grokOptions = options as GrokSpeechToTextOptions; + var language = GetLanguage(options); + + if (grokOptions?.Format is bool format) + { + if (format && language is null) + throw new ArgumentException("xAI STT requires a language when Format is true.", nameof(options)); + + content.Add(new StringContent(format ? "true" : "false"), "format"); + } + + if (language is not null) + content.Add(new StringContent(language), "language"); + + if (options?.SpeechSampleRate is int sampleRate) + content.Add(new StringContent(sampleRate.ToString(CultureInfo.InvariantCulture)), "sample_rate"); + + if (grokOptions?.AudioFormat is { Length: > 0 } audioFormat) + content.Add(new StringContent(GetRawAudioFormat(audioFormat)), "audio_format"); + + if (grokOptions?.Multichannel is bool multichannel) + content.Add(new StringContent(multichannel ? "true" : "false"), "multichannel"); + + if (grokOptions?.Channels is int channels) + content.Add(new StringContent(channels.ToString(CultureInfo.InvariantCulture)), "channels"); + + if (grokOptions?.Diarize is bool diarize) + content.Add(new StringContent(diarize ? 
"true" : "false"), "diarize"); + + var filename = GetFilename(audioSpeechStream); + var streamContent = new StreamContent(audioSpeechStream); + streamContent.Headers.ContentType = new MediaTypeHeaderValue(GetMediaType(filename)); + content.Add(streamContent, "file", filename); + + return content; + } + + Uri GetHttpEndpoint() => GetEndpoint(endpoint, "https", "v1/stt", null); + + Uri GetStreamingEndpoint(SpeechToTextOptions? options) + { + var grokOptions = options as GrokSpeechToTextOptions; + var query = new NameValueCollection + { + ["sample_rate"] = (options?.SpeechSampleRate ?? DefaultStreamingSampleRate).ToString(CultureInfo.InvariantCulture), + ["encoding"] = GetStreamingEncoding(grokOptions?.AudioFormat), + }; + + if (grokOptions?.InterimResults is bool interimResults) + query["interim_results"] = interimResults ? "true" : "false"; + + if (grokOptions?.Endpointing is int endpointing) + query["endpointing"] = endpointing.ToString(CultureInfo.InvariantCulture); + + if (GetLanguage(options) is { } language) + query["language"] = language; + + if (grokOptions?.Diarize is bool diarize) + query["diarize"] = diarize ? "true" : "false"; + + if (grokOptions?.Multichannel is bool multichannel) + query["multichannel"] = multichannel ? "true" : "false"; + + if (grokOptions?.Channels is int channels) + query["channels"] = channels.ToString(CultureInfo.InvariantCulture); + + return GetEndpoint(endpoint, endpoint.Scheme == Uri.UriSchemeHttp ? "ws" : "wss", "v1/stt", query); + } + + static SpeechToTextResponse ToSpeechToTextResponse(GrokSpeechToTextResponse transcript, SpeechToTextOptions? options) + { + var response = new SpeechToTextResponse([new TextContent(transcript.Text ?? 
"")]) + { + RawRepresentation = transcript, + AdditionalProperties = CreateResponseAdditionalProperties(transcript), + }; + + if (transcript.Words is { Count: > 0 } words) + { + response.StartTime = TimeSpan.FromSeconds(words[0].Start); + response.EndTime = TimeSpan.FromSeconds(words[^1].End); + } + else if (transcript.Duration is double duration) + { + response.StartTime = TimeSpan.Zero; + response.EndTime = TimeSpan.FromSeconds(duration); + } + + return response; + } + + static SpeechToTextResponseUpdate CreateTextUpdate( + JsonElement root, + JsonElement rawRepresentation, + SpeechToTextOptions? options, + SpeechToTextResponseUpdateKind? kind = null) + { + var update = new SpeechToTextResponseUpdate + { + Kind = kind ?? (GetBoolean(root, "is_final") == true ? SpeechToTextResponseUpdateKind.TextUpdated : SpeechToTextResponseUpdateKind.TextUpdating), + RawRepresentation = rawRepresentation, + Contents = TryGetString(root, "text") is { } text ? [new TextContent(text)] : [], + AdditionalProperties = CreateStreamingAdditionalProperties(root), + }; + + if (TryGetDouble(root, "start") is double start) + update.StartTime = TimeSpan.FromSeconds(start); + + if (TryGetDouble(root, "duration") is double duration) + update.EndTime = TimeSpan.FromSeconds((update.StartTime?.TotalSeconds ?? 0) + duration); + + return update; + } + + static AdditionalPropertiesDictionary? CreateResponseAdditionalProperties(GrokSpeechToTextResponse transcript) + { + AdditionalPropertiesDictionary? properties = null; + + AddProperty(ref properties, "language", transcript.Language); + AddProperty(ref properties, "duration", transcript.Duration); + AddProperty(ref properties, "words", transcript.Words); + AddProperty(ref properties, "channels", transcript.Channels); + + return properties; + } + + static AdditionalPropertiesDictionary? CreateStreamingAdditionalProperties(JsonElement root) + { + AdditionalPropertiesDictionary? 
properties = null; + + AddProperty(ref properties, "channel_index", TryGetInt(root, "channel_index")); + AddProperty(ref properties, "is_final", GetBoolean(root, "is_final")); + AddProperty(ref properties, "speech_final", GetBoolean(root, "speech_final")); + AddProperty(ref properties, "duration", TryGetDouble(root, "duration")); + + return properties; + } + + static void AddProperty(ref AdditionalPropertiesDictionary? properties, string name, object? value) + { + if (value is null) + return; + + (properties ??= [])[name] = value; + } + + static string? GetLanguage(SpeechToTextOptions? options) + { + if (options?.TextLanguage is { Length: > 0 } textLanguage && + options.SpeechLanguage is { Length: > 0 } speechLanguage && + !string.Equals(textLanguage, speechLanguage, StringComparison.OrdinalIgnoreCase)) + throw new NotSupportedException("xAI STT does not support translation between different speech and text languages."); + + return options?.TextLanguage ?? options?.SpeechLanguage; + } + + static string GetFilename(Stream audioSpeechStream) => + audioSpeechStream is FileStream fileStream ? Path.GetFileName(fileStream.Name) : DefaultFilename; + + static string GetMediaType(string filename) => + extensionToMediaType.TryGetValue(Path.GetExtension(filename), out var mediaType) ? mediaType : "application/octet-stream"; + + static string GetRawAudioFormat(string format) => format.ToLowerInvariant() switch + { + "pcm" or "audio/pcm" or "audio/l16" => "pcm", + "mulaw" or "ulaw" or "audio/basic" => "mulaw", + "alaw" or "audio/alaw" => "alaw", + _ => format.ToLowerInvariant(), + }; + + static string GetStreamingEncoding(string? format) + { + var encoding = string.IsNullOrWhiteSpace(format) ? 
DefaultStreamingEncoding : GetRawAudioFormat(format);
+
+        return encoding switch
+        {
+            "pcm" or "mulaw" or "alaw" => encoding,
+            _ => throw new ArgumentException($"Unsupported xAI STT streaming encoding: {format}", nameof(format)),
+        };
+    }
+
+    static Uri GetEndpoint(Uri endpoint, string scheme, string relativePath, NameValueCollection? query) => new UriBuilder(endpoint)
+    {
+        Scheme = scheme,
+        Path = CombinePath(endpoint.AbsolutePath, relativePath),
+        Query = query is null ? "" : ToQueryString(query),
+    }.Uri;
+
+    static string CombinePath(string basePath, string relativePath)
+    {
+        var path = basePath == "/" ? "" : basePath.TrimEnd('/');
+        return $"{path}/{relativePath.TrimStart('/')}";
+    }
+
+    static string ToQueryString(NameValueCollection query)
+    {
+        var builder = new StringBuilder();
+
+        foreach (string key in query)
+        {
+            if (query[key] is not { } value)
+                continue;
+
+            if (builder.Length > 0)
+                builder.Append('&');
+
+            builder
+                .Append(Uri.EscapeDataString(key))
+                .Append('=')
+                .Append(Uri.EscapeDataString(value));
+        }
+
+        return builder.ToString();
+    }
+
+    static async Task ThrowHttpExceptionAsync(HttpResponseMessage response, CancellationToken cancellationToken)
+    {
+        var body = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
+        var message = string.IsNullOrWhiteSpace(body) ?
+            $"xAI STT request failed with status code {(int)response.StatusCode} ({response.ReasonPhrase})." :
+            $"xAI STT request failed with status code {(int)response.StatusCode} ({response.ReasonPhrase}): {body}";
+
+        throw new HttpRequestException(message, null, response.StatusCode);
+    }
+
+    static async ValueTask<WebSocket> CreateWebSocketAsync(Uri uri, string?
apiKey, CancellationToken cancellationToken)
+    {
+        var webSocket = new ClientWebSocket();
+
+        if (!string.IsNullOrEmpty(apiKey))
+            webSocket.Options.SetRequestHeader("Authorization", $"Bearer {apiKey}");
+
+        await webSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false);
+        return webSocket;
+    }
+
+    static async Task SendAudioAsync(WebSocket webSocket, Stream audioSpeechStream, CancellationToken cancellationToken)
+    {
+        var buffer = ArrayPool<byte>.Shared.Rent(DefaultStreamingChunkSize);
+        try
+        {
+            int bytesRead;
+            while ((bytesRead = await audioSpeechStream.ReadAsync(buffer.AsMemory(0, buffer.Length), cancellationToken).ConfigureAwait(false)) > 0)
+            {
+                await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, bytesRead), WebSocketMessageType.Binary, true, cancellationToken).ConfigureAwait(false);
+            }
+        }
+        finally
+        {
+            ArrayPool<byte>.Shared.Return(buffer);
+        }
+    }
+
+    static Task SendJsonAsync<T>(WebSocket webSocket, T value, JsonTypeInfo<T> typeInfo, CancellationToken cancellationToken)
+        => webSocket.SendAsync(JsonSerializer.SerializeToUtf8Bytes(value, typeInfo), WebSocketMessageType.Text, true, cancellationToken);
+
+    static async Task<JsonDocument> ReceiveJsonAsync(WebSocket webSocket, CancellationToken cancellationToken)
+    {
+        var buffer = ArrayPool<byte>.Shared.Rent(8192);
+        try
+        {
+            using var stream = new MemoryStream();
+
+            while (true)
+            {
+                var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken).ConfigureAwait(false);
+
+                if (result.MessageType == WebSocketMessageType.Close)
+                    throw new InvalidOperationException($"xAI STT streaming connection closed before transcript.done: {result.CloseStatusDescription ??
result.CloseStatus?.ToString()}"); + + if (result.MessageType != WebSocketMessageType.Text) + throw new InvalidOperationException($"xAI STT streaming returned unsupported message type: {result.MessageType}"); + + stream.Write(buffer, 0, result.Count); + + if (result.EndOfMessage) + break; + } + + stream.Position = 0; + return await JsonDocument.ParseAsync(stream, cancellationToken: cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + static string GetRequiredString(JsonElement json, string propertyName) + { + if (!json.TryGetProperty(propertyName, out var property) || property.ValueKind != JsonValueKind.String) + throw new InvalidOperationException($"xAI STT streaming event is missing required string property '{propertyName}'."); + + return property.GetString()!; + } + + static string? TryGetString(JsonElement json, string propertyName) => + json.TryGetProperty(propertyName, out var property) && property.ValueKind == JsonValueKind.String ? property.GetString() : null; + + static bool? GetBoolean(JsonElement json, string propertyName) => + json.TryGetProperty(propertyName, out var property) && property.ValueKind is JsonValueKind.True or JsonValueKind.False ? property.GetBoolean() : null; + + static double? TryGetDouble(JsonElement json, string propertyName) => + json.TryGetProperty(propertyName, out var property) && property.ValueKind == JsonValueKind.Number ? property.GetDouble() : null; + + static int? TryGetInt(JsonElement json, string propertyName) => + json.TryGetProperty(propertyName, out var property) && property.ValueKind == JsonValueKind.Number ? 
property.GetInt32() : null;
+
+    [JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
+        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+        PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower)]
+    [JsonSerializable(typeof(GrokSpeechToTextResponse))]
+    [JsonSerializable(typeof(AudioDoneMessage))]
+    partial class SpeechToTextJsonContext : JsonSerializerContext { }
+
+    sealed record GrokSpeechToTextResponse(
+        string? Text,
+        string? Language,
+        double? Duration,
+        IReadOnlyList<GrokSpeechToTextWord>? Words,
+        IReadOnlyList<GrokSpeechToTextChannel>? Channels);
+
+    sealed record GrokSpeechToTextWord(string Text, double Start, double End, int? Speaker);
+
+    sealed record GrokSpeechToTextChannel(int Index, string Text, IReadOnlyList<GrokSpeechToTextWord>? Words);
+
+    sealed record AudioDoneMessage
+    {
+        public static readonly AudioDoneMessage Instance = new();
+
+        public string Type => "audio.done";
+    }
+}
diff --git a/src/xAI/GrokSpeechToTextOptions.cs b/src/xAI/GrokSpeechToTextOptions.cs
new file mode 100644
index 0000000..649091b
--- /dev/null
+++ b/src/xAI/GrokSpeechToTextOptions.cs
@@ -0,0 +1,55 @@
+using Microsoft.Extensions.AI;
+
+namespace xAI;
+
+/// Grok-specific speech to text options that extend the base <see cref="SpeechToTextOptions"/>.
+/// <remarks>
+/// These options map to xAI's /v1/stt REST and WebSocket parameters.
+/// </remarks>
+public class GrokSpeechToTextOptions : SpeechToTextOptions
+{
+    /// Initializes a new instance of the <see cref="GrokSpeechToTextOptions"/> class.
+    public GrokSpeechToTextOptions()
+    {
+    }
+
+    /// Initializes a new instance of the <see cref="GrokSpeechToTextOptions"/> class by cloning another instance.
+    protected GrokSpeechToTextOptions(GrokSpeechToTextOptions? other)
+        : base(other)
+    {
+        if (other is null)
+            return;
+
+        Format = other.Format;
+        AudioFormat = other.AudioFormat;
+        Multichannel = other.Multichannel;
+        Channels = other.Channels;
+        Diarize = other.Diarize;
+        InterimResults = other.InterimResults;
+        Endpointing = other.Endpointing;
+    }
+
+    /// Gets or sets a value indicating whether xAI should apply inverse text normalization to the transcript.
+    public bool?
Format { get; set; } + + /// Gets or sets the raw input audio format hint or streaming encoding, such as pcm, mulaw, or alaw. + public string? AudioFormat { get; set; } + + /// Gets or sets a value indicating whether xAI should transcribe each channel independently. + public bool? Multichannel { get; set; } + + /// Gets or sets the number of audio channels. + public int? Channels { get; set; } + + /// Gets or sets a value indicating whether xAI should include speaker diarization data. + public bool? Diarize { get; set; } + + /// Gets or sets a value indicating whether xAI streaming should emit interim partial transcripts. + public bool? InterimResults { get; set; } + + /// Gets or sets the silence duration in milliseconds before xAI emits an utterance-final event. + public int? Endpointing { get; set; } + + /// + public override SpeechToTextOptions Clone() => new GrokSpeechToTextOptions(this); +} diff --git a/src/xAI/GrokTextToSpeechClient.cs b/src/xAI/GrokTextToSpeechClient.cs index faa8702..f2070bd 100644 --- a/src/xAI/GrokTextToSpeechClient.cs +++ b/src/xAI/GrokTextToSpeechClient.cs @@ -72,7 +72,6 @@ public async Task GetAudioAsync( return new TextToSpeechResponse([new DataContent(audio, mediaType)]) { - ModelId = options?.ModelId, RawRepresentation = raw, }; } @@ -104,7 +103,6 @@ public async IAsyncEnumerable GetStreamingAudioAsync { Kind = TextToSpeechResponseUpdateKind.AudioUpdating, Contents = [new DataContent(audio, GetMediaType(request.OutputFormat?.Codec))], - ModelId = options?.ModelId, RawRepresentation = rawRepresentation, }; break; @@ -113,7 +111,6 @@ public async IAsyncEnumerable GetStreamingAudioAsync var update = new TextToSpeechResponseUpdate { Kind = TextToSpeechResponseUpdateKind.SessionClose, - ModelId = options?.ModelId, RawRepresentation = rawRepresentation, };