diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..075a469 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,5 @@ +# xAI SDK implementation notes + +- `GrokClient` is primarily backed by generated gRPC protocol clients, but text to speech uses xAI's documented REST/WebSocket voice endpoints because there are no generated TTS protocol types in `src\xAI.Protocol`. +- `AsITextToSpeechClient` returns an `ITextToSpeechClient` implementation that uses `POST /v1/tts` for unary audio and `wss://.../v1/tts` for streaming audio. +- TTS defaults follow xAI docs: voice `eve`, language `en` when omitted by `TextToSpeechOptions`, and MP3 output when no codec is specified. diff --git a/readme.md b/readme.md index 35e6eb6..8405c89 100644 --- a/readme.md +++ b/readme.md @@ -45,6 +45,12 @@ var chat = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) var images = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) .AsIImageGenerator("grok-imagine-image"); + +var speech = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) + .AsITextToSpeechClient(); + +var audio = await speech.GetAudioAsync("Hello! Welcome to xAI text to speech.", + new TextToSpeechOptions { VoiceId = "eve", Language = "en" }); ``` ## File Attachments @@ -393,6 +399,72 @@ var editedImage = (UriContent)result.Contents.First(); Console.WriteLine($"Edited image URL: {editedImage.Uri}"); ``` +## Text to Speech + +Grok supports text to speech via the `ITextToSpeechClient` abstraction from Microsoft.Extensions.AI. +Use `AsITextToSpeechClient` to get a TTS client: + +```csharp +var speech = new GrokClient(Environment.GetEnvironmentVariable("XAI_API_KEY")!) + .AsITextToSpeechClient(); +``` + +### Unary (single response) + +Call `GetAudioAsync` to synthesize speech in a single request. The result contains a `DataContent` +with the audio bytes and media type: + +```csharp +var response = await speech.GetAudioAsync("Hello! Welcome to xAI text to speech.", + new TextToSpeechOptions { VoiceId = "eve", Language = "en" }); + +var audio = (DataContent)response.Contents.First(); +// audio.MediaType == "audio/mpeg" (MP3 by default) +await File.WriteAllBytesAsync("output.mp3", audio.Data.ToArray()); +``` + +Available voices include `ara`, `eve`, `leo`, `rex`, and `sal`. Defaults to `eve` and English when +`VoiceId`/`Language` are not specified. + +### Streaming + +Call `GetStreamingAudioAsync` to receive audio chunks as they are generated, enabling low-latency +playback or progressive file writes: + +```csharp +await using var fileStream = File.Create("output.mp3"); + +await foreach (var update in speech.GetStreamingAudioAsync("Hello from streaming TTS!", + new TextToSpeechOptions { VoiceId = "eve", AudioFormat = "mp3" })) +{ + if (update.Kind == TextToSpeechResponseUpdateKind.AudioUpdating) + { + foreach (var content in update.Contents.OfType()) + await fileStream.WriteAsync(content.Data); + } +} +``` + +### Grok-Specific Options + +Use `GrokTextToSpeechOptions` to control audio quality and streaming behavior beyond the base +`TextToSpeechOptions`: + +```csharp +var options = new GrokTextToSpeechOptions +{ + VoiceId = "rex", + Language = "en", + AudioFormat = "mp3", // mp3 | wav | pcm | mulaw | alaw + SampleRate = 24000, // Hz + BitRate = 128000, // bits per second (MP3 only) + OptimizeStreamingLatency = 1, // 0–4; higher trades quality for lower latency + TextNormalization = true, // expand abbreviations and numbers before synthesis +}; + +var response = await speech.GetAudioAsync("Streaming at 24 kHz, 128 kbps.", options); +``` + # xAI.Protocol diff --git a/src/xAI.Tests/ChatClientTests.cs b/src/xAI.Tests/ChatClientTests.cs index e20f0d3..7feb5d7 100644 --- a/src/xAI.Tests/ChatClientTests.cs +++ b/src/xAI.Tests/ChatClientTests.cs @@ -21,7 +21,9 @@ public async Task OpenAIInvokesTools() { "user", "What day is today?" }, }; - var chat = new OpenAIClient(Configuration["OPENAI_API_KEY"]!).GetChatClient("gpt-5.4").AsIChatClient() + var chat = new OpenAIClient(Configuration["OPENAI_API_KEY"]!) + .GetChatClient("gpt-5.4") + .AsIChatClient() .AsBuilder() .UseFunctionInvocation(configure: client => client.MaximumIterationsPerRequest = 3) .UseLogging(output.AsLoggerFactory()) @@ -96,10 +98,10 @@ public async Task GrokInvokesTools() [SecretsFact("XAI_API_KEY")] public async Task GrokReasoningModelOutputsBothContentAndEncryptedReasoning() { - var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-1-fast"); + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-1-fast-reasoning"); var response = await grok.GetResponseAsync( - "What is 3 + 4? Respond with just the number.", + "What is 3 + 4? Respond with just the number, think about it really well.", new GrokChatOptions { UseEncryptedContent = true diff --git a/src/xAI.Tests/TextToSpeechClientTests.cs b/src/xAI.Tests/TextToSpeechClientTests.cs new file mode 100644 index 0000000..266f451 --- /dev/null +++ b/src/xAI.Tests/TextToSpeechClientTests.cs @@ -0,0 +1,370 @@ +using System.Net; +using System.Net.Http.Headers; +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using Grpc.Net.Client; +using Microsoft.Extensions.AI; + +namespace xAI.Tests; + +public class TextToSpeechClientTests +{ + [Fact] + public void AsITextToSpeechClient_ReturnsMetadata() + { + using var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var tts = client.AsITextToSpeechClient(); + + var metadata = tts.GetService(); + + Assert.NotNull(metadata); + Assert.Equal("xai", metadata.ProviderName); + Assert.Equal(client.Options.Endpoint, metadata.ProviderUri); + Assert.Null(metadata.DefaultModelId); + } + + [Fact] + public async Task GetAudioAsync_MapsRequestAndResponse() + { + var audio = new byte[] { 1, 2, 3 }; + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent(audio) + { + Headers = + { + ContentType = new MediaTypeHeaderValue("audio/wav"), + } + } + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var tts = client.AsITextToSpeechClient(); + + var response = await tts.GetAudioAsync("Hello from Grok.", + new GrokTextToSpeechOptions + { + VoiceId = "rex", + Language = "pt-BR", + AudioFormat = "audio/wav", + SampleRate = 44100, + BitRate = 192000, + OptimizeStreamingLatency = 1, + TextNormalization = true, + ModelId = "test-model", + }); + + Assert.Equal(HttpMethod.Post, handler.Request!.Method); + Assert.Equal(new Uri($"{client.Options.Endpoint}v1/tts"), handler.Request.RequestUri); + Assert.Equal("Bearer", handler.Request.Headers.Authorization?.Scheme); + Assert.Equal("test-api-key", handler.Request.Headers.Authorization?.Parameter); + + using var json = JsonDocument.Parse(handler.RequestBody!); + var root = json.RootElement; + Assert.Equal("Hello from Grok.", root.GetProperty("text").GetString()); + Assert.Equal("rex", root.GetProperty("voice_id").GetString()); + Assert.Equal("pt-BR", root.GetProperty("language").GetString()); + Assert.Equal(1, root.GetProperty("optimize_streaming_latency").GetInt32()); + Assert.True(root.GetProperty("text_normalization").GetBoolean()); + + var outputFormat = root.GetProperty("output_format"); + Assert.Equal("wav", outputFormat.GetProperty("codec").GetString()); + Assert.Equal(44100, outputFormat.GetProperty("sample_rate").GetInt32()); + Assert.Equal(192000, outputFormat.GetProperty("bit_rate").GetInt32()); + + var content = Assert.Single(response.Contents); + var data = Assert.IsType(content); + Assert.Equal("audio/wav", data.MediaType); + Assert.Equal(audio, data.Data.ToArray()); + Assert.Equal("test-model", response.ModelId); + } + + [Theory] + [InlineData(null, "audio/mpeg")] + [InlineData("mp3", "audio/mpeg")] + [InlineData("wav", "audio/wav")] + [InlineData("pcm", "audio/pcm")] + [InlineData("mulaw", "audio/basic")] + [InlineData("alaw", "audio/alaw")] + public async Task GetAudioAsync_MapsCodecToMediaType(string? audioFormat, string expectedMediaType) + { + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent([1]), + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var tts = client.AsITextToSpeechClient(); + + var response = await tts.GetAudioAsync("Hello.", new TextToSpeechOptions { AudioFormat = audioFormat }); + + var data = Assert.IsType(Assert.Single(response.Contents)); + Assert.Equal(expectedMediaType, data.MediaType); + } + + [Fact] + public async Task GetAudioAsync_WithDefaults_SendsRequiredFieldsOnly() + { + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent([1]), + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var tts = client.AsITextToSpeechClient(); + + await tts.GetAudioAsync("Hello."); + + using var json = JsonDocument.Parse(handler.RequestBody!); + var root = json.RootElement; + Assert.Equal("Hello.", root.GetProperty("text").GetString()); + Assert.Equal("eve", root.GetProperty("voice_id").GetString()); + Assert.Equal("en", root.GetProperty("language").GetString()); + Assert.False(root.TryGetProperty("output_format", out _)); + } + + [Fact] + public async Task GetAudioAsync_WithError_ThrowsHttpRequestException() + { + var handler = new CaptureHandler(_ => new HttpResponseMessage(HttpStatusCode.BadRequest) + { + ReasonPhrase = "Bad Request", + Content = new StringContent("""{"error":"invalid language"}"""), + }); + + using var client = new GrokClient("test-api-key", CreateOptions(handler)); + using var tts = client.AsITextToSpeechClient(); + + var exception = await Assert.ThrowsAsync(() => tts.GetAudioAsync("Hello.")); + + Assert.Equal(HttpStatusCode.BadRequest, exception.StatusCode); + Assert.Contains("invalid language", exception.Message); + } + + [Fact] + public async Task GetAudioAsync_WithNullText_ThrowsArgumentNullException() + { + using var client = new GrokClient("test-api-key", CreateOptions(new CaptureHandler())); + using var tts = client.AsITextToSpeechClient(); + + await Assert.ThrowsAsync(() => tts.GetAudioAsync(null!)); + } + + [Fact] + public async Task GetStreamingAudioAsync_MapsWebSocketEvents() + { + var webSocket = new FakeWebSocket( + """{"type":"audio.delta","delta":"AQID"}""", + """{"type":"audio.done","trace_id":"trace-123"}"""); + + Uri? capturedUri = null; + string? capturedApiKey = null; + using var tts = new GrokTextToSpeechClient( + new HttpClient(new CaptureHandler()), + new Uri("https://streaming.test/base/"), + "test-api-key", + (uri, apiKey, _) => + { + capturedUri = uri; + capturedApiKey = apiKey; + return ValueTask.FromResult(webSocket); + }); + + var updates = new List(); + await foreach (var update in tts.GetStreamingAudioAsync("Hello.", + new GrokTextToSpeechOptions + { + VoiceId = "ara", + Language = "auto", + AudioFormat = "mulaw", + SampleRate = 8000, + OptimizeStreamingLatency = 1, + TextNormalization = true, + })) + { + updates.Add(update); + } + + Assert.Equal("test-api-key", capturedApiKey); + Assert.Equal("wss://streaming.test/base/v1/tts?voice=ara&language=auto&codec=mulaw&sample_rate=8000&optimize_streaming_latency=1&text_normalization=true", capturedUri!.ToString()); + + Assert.Collection(webSocket.SentMessages, + message => + { + using var json = JsonDocument.Parse(message); + Assert.Equal("text.delta", json.RootElement.GetProperty("type").GetString()); + Assert.Equal("Hello.", json.RootElement.GetProperty("delta").GetString()); + }, + message => + { + using var json = JsonDocument.Parse(message); + Assert.Equal("text.done", json.RootElement.GetProperty("type").GetString()); + }); + + Assert.Collection(updates, + update => + { + Assert.Equal(TextToSpeechResponseUpdateKind.AudioUpdating, update.Kind); + var data = Assert.IsType(Assert.Single(update.Contents)); + Assert.Equal(new byte[] { 1, 2, 3 }, data.Data.ToArray()); + Assert.Equal("audio/basic", data.MediaType); + }, + update => + { + Assert.Equal(TextToSpeechResponseUpdateKind.SessionClose, update.Kind); + Assert.Equal("trace-123", update.AdditionalProperties?["trace_id"]); + }); + } + + [Fact] + public async Task GetStreamingAudioAsync_WithErrorEvent_ThrowsInvalidOperationException() + { + var webSocket = new FakeWebSocket("""{"type":"error","message":"voice rejected"}"""); + using var tts = new GrokTextToSpeechClient( + new HttpClient(new CaptureHandler()), + new Uri("https://streaming.test/"), + "test-api-key", + (_, _, _) => ValueTask.FromResult(webSocket)); + + var exception = await Assert.ThrowsAsync(async () => + { + await foreach (var _ in tts.GetStreamingAudioAsync("Hello.")) + { + } + }); + + Assert.Contains("voice rejected", exception.Message); + } + + [SecretsTheory("XAI_API_KEY")] + //[InlineData("ara")] + //[InlineData("eve")] + [InlineData("rex")] // 👈 el mejor para Jesus + //[InlineData("sal")] + //[InlineData("leo")] + public async Task GetStreamingAudioAsync_IntegrationTest_SavesAndPlaysAudio(string voiceId) + { + var apiKey = Environment.GetEnvironmentVariable("XAI_API_KEY"); + using var client = new GrokClient(apiKey!); + using var tts = client.AsITextToSpeechClient(); + + var tempFile = System.IO.Path.Combine(System.IO.Path.GetTempPath(), $"xai-tts-{Guid.NewGuid():N}.mp3"); + + await using (var fileStream = System.IO.File.Create(tempFile)) + { + await foreach (var update in tts.GetStreamingAudioAsync( + """ + El que cree en mí, en realidad no cree en mí, sino en aquel que me envió. + Y el que me ve, ve al que me envió. + Yo soy la luz, y he venido al mundo para que todo el que crea en mí no permanezca en las tinieblas. + """, + new GrokTextToSpeechOptions + { + VoiceId = voiceId, + AudioFormat = "mp3", + + })) + { + if (update.Kind == TextToSpeechResponseUpdateKind.AudioUpdating) + { + foreach (var content in update.Contents) + { + if (content is DataContent data) + { + await fileStream.WriteAsync(data.Data); + } + } + } + } + } + + Assert.True(System.IO.File.Exists(tempFile)); + Assert.True(new System.IO.FileInfo(tempFile).Length > 0); + + System.Diagnostics.Process.Start(new System.Diagnostics.ProcessStartInfo + { + FileName = tempFile, + UseShellExecute = true + }); + } + + static GrokClientOptions CreateOptions(HttpMessageHandler handler) => new() + { + Endpoint = new Uri($"https://unit-{Guid.NewGuid():N}.test/"), + ChannelOptions = new GrpcChannelOptions + { + HttpHandler = handler, + }, + }; + + sealed class CaptureHandler(Func? responder = null) : HttpMessageHandler + { + readonly Func responder = responder ?? (_ => new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new ByteArrayContent([1]), + }); + + public HttpRequestMessage? Request { get; private set; } + public string? RequestBody { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Request = request; + RequestBody = request.Content is null ? null : await request.Content.ReadAsStringAsync(cancellationToken); + return responder(request); + } + } + + sealed class FakeWebSocket(params string[] messages) : WebSocket + { + readonly Queue messages = new(messages.Select(Encoding.UTF8.GetBytes)); + WebSocketState state = WebSocketState.Open; + WebSocketCloseStatus? closeStatus; + string? closeStatusDescription; + + public List SentMessages { get; } = []; + + public override WebSocketCloseStatus? CloseStatus => closeStatus; + + public override string? CloseStatusDescription => closeStatusDescription; + + public override WebSocketState State => state; + + public override string? SubProtocol => null; + + public override void Abort() => state = WebSocketState.Aborted; + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + { + this.closeStatus = closeStatus; + closeStatusDescription = statusDescription; + state = WebSocketState.Closed; + return Task.CompletedTask; + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken) + => CloseAsync(closeStatus, statusDescription, cancellationToken); + + public override void Dispose() => state = WebSocketState.Closed; + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + { + if (messages.Count == 0) + { + state = WebSocketState.CloseReceived; + return Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, WebSocketCloseStatus.NormalClosure, "closed")); + } + + var message = messages.Dequeue(); + message.CopyTo(buffer.Array!, buffer.Offset); + return Task.FromResult(new WebSocketReceiveResult(message.Length, WebSocketMessageType.Text, true)); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + { + SentMessages.Add(Encoding.UTF8.GetString(buffer.Array!, buffer.Offset, buffer.Count)); + return Task.CompletedTask; + } + } +} diff --git a/src/xAI/GrokClient.cs b/src/xAI/GrokClient.cs index 2274895..43871ec 100644 --- a/src/xAI/GrokClient.cs +++ b/src/xAI/GrokClient.cs @@ -15,13 +15,14 @@ namespace xAI; /// The options used to configure the client. public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisposable { - static readonly ConcurrentDictionary<(Uri, string), ChannelBase> channels = []; + static readonly ConcurrentDictionary<(Uri, string), (ChannelBase, HttpMessageHandler)> channels = []; /// Initializes a new instance of the class with default options. public GrokClient(string apiKey) : this(apiKey, new GrokClientOptions()) { } - internal GrokClient(ChannelBase channel, GrokClientOptions options) : this("", options) - => channels[(options.Endpoint, "")] = channel; + /// Testing ctor. + internal GrokClient(ChannelBase channel, GrokClientOptions options, string? apiKey = default) : this(apiKey ?? "", options) + => channels[(options.Endpoint, apiKey ?? "")] = (channel, GetHttpHandler(options.ChannelOptions, apiKey ?? "")); /// Gets the API key used for authentication. public string ApiKey { get; } = apiKey; @@ -33,29 +34,47 @@ internal GrokClient(ChannelBase channel, GrokClientOptions options) : this("", o public GrokClientOptions Options { get; } = options; /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Auth.AuthClient GetAuthClient() => new(Channel); + public Auth.AuthClient GetAuthClient() => new(ChannelHandler.Channel); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Chat.ChatClient GetChatClient() => new(Channel, Options); + public Chat.ChatClient GetChatClient() => new(ChannelHandler.Channel, Options); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Documents.DocumentsClient GetDocumentsClient() => new(Channel); + public Documents.DocumentsClient GetDocumentsClient() => new(ChannelHandler.Channel); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Embedder.EmbedderClient GetEmbedderClient() => new(Channel); + public Embedder.EmbedderClient GetEmbedderClient() => new(ChannelHandler.Channel); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Image.ImageClient GetImageClient() => new(Channel, Options); + public Image.ImageClient GetImageClient() => new(ChannelHandler.Channel, Options); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Models.ModelsClient GetModelsClient() => new(Channel); + public Models.ModelsClient GetModelsClient() => new(ChannelHandler.Channel); /// Gets a new instance of that reuses the client configuration details provided to the instance. - public Tokenize.TokenizeClient GetTokenizeClient() => new(Channel); + public Tokenize.TokenizeClient GetTokenizeClient() => new(ChannelHandler.Channel); - internal ChannelBase Channel => channels.GetOrAdd((Endpoint, ApiKey), key => + internal (ChannelBase Channel, HttpMessageHandler Handler) ChannelHandler => channels.GetOrAdd((Endpoint, ApiKey), key => { - var inner = Options.ChannelOptions?.HttpHandler; + var handler = GetHttpHandler(Options.ChannelOptions, key.Item2); + + // Provide some sensible defaults for gRPC channel options, while allowing users to + // override them via GrokClientOptions.ChannelOptions if needed. + var options = Options.ChannelOptions ?? new GrpcChannelOptions + { + DisposeHttpClient = true, + MaxReceiveMessageSize = 128 * 1024 * 1024, // large enough for tool output + MaxSendMessageSize = 16 * 1024 * 1024, + }; + + options.HttpHandler = handler; + + return (GrpcChannel.ForAddress(key.Item1, options), handler); + }); + + static HttpMessageHandler GetHttpHandler(GrpcChannelOptions? options, string apiKey) + { + var inner = options?.HttpHandler; if (inner == null) { // If no custom HttpHandler is provided, we create one with Polly retry @@ -94,24 +113,13 @@ StatusCode.DeadlineExceeded or }; } - var handler = new AuthenticationHeaderHandler(ApiKey) + var handler = string.IsNullOrEmpty(apiKey) ? inner : new AuthenticationHeaderHandler(apiKey) { InnerHandler = inner }; - // Provide some sensible defaults for gRPC channel options, while allowing users to - // override them via GrokClientOptions.ChannelOptions if needed. - var options = Options.ChannelOptions ?? new GrpcChannelOptions - { - DisposeHttpClient = true, - MaxReceiveMessageSize = 128 * 1024 * 1024, // large enough for tool output - MaxSendMessageSize = 16 * 1024 * 1024, - }; - - options.HttpHandler = handler; - - return GrpcChannel.ForAddress(Endpoint, options); - }); + return handler; + } /// Clears the cached list of gRPC channels in the client. public void Dispose() => channels.Clear(); diff --git a/src/xAI/GrokClientExtensions.cs b/src/xAI/GrokClientExtensions.cs index 9bf53fb..f81da7c 100644 --- a/src/xAI/GrokClientExtensions.cs +++ b/src/xAI/GrokClientExtensions.cs @@ -10,7 +10,7 @@ public static class GrokClientExtensions { /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this GrokClient client, string defaultModelId) - => new GrokChatClient(client.Channel, client.Options, defaultModelId); + => new GrokChatClient(client.ChannelHandler.Channel, client.Options, defaultModelId); /// Creates a new from the specified using the given model as the default. public static IChatClient AsIChatClient(this Chat.ChatClient client, string defaultModelId) @@ -18,9 +18,13 @@ public static IChatClient AsIChatClient(this Chat.ChatClient client, string defa /// Creates a new from the specified using the given model as the default. public static IImageGenerator AsIImageGenerator(this GrokClient client, string defaultModelId) - => new GrokImageGenerator(client.Channel, client.Options, defaultModelId); + => new GrokImageGenerator(client.ChannelHandler.Channel, client.Options, defaultModelId); /// Creates a new from the specified using the given model as the default. public static IImageGenerator AsIImageGenerator(this Image.ImageClient client, string defaultModelId) => new GrokImageGenerator(client, defaultModelId); -} \ No newline at end of file + + /// Creates a new from the specified . + public static ITextToSpeechClient AsITextToSpeechClient(this GrokClient client) + => new GrokTextToSpeechClient(client.ChannelHandler.Handler, client.Options, client.ApiKey); +} diff --git a/src/xAI/GrokTextToSpeechClient.cs b/src/xAI/GrokTextToSpeechClient.cs new file mode 100644 index 0000000..faa8702 --- /dev/null +++ b/src/xAI/GrokTextToSpeechClient.cs @@ -0,0 +1,343 @@ +using System.Buffers; +using System.Collections.Specialized; +using System.Net.Http.Json; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Represents an for xAI's Grok text to speech service. +partial class GrokTextToSpeechClient : ITextToSpeechClient +{ + const string DefaultVoice = "eve"; + const string DefaultLanguage = "en"; + const string DefaultCodec = "mp3"; + + + readonly TextToSpeechClientMetadata metadata; + readonly HttpClient httpClient; + readonly Uri endpoint; + readonly string? apiKey; + readonly Func> webSocketFactory; + + internal GrokTextToSpeechClient(HttpMessageHandler handler, GrokClientOptions options, string? apiKey) + : this(new HttpClient(handler, disposeHandler: false), options.Endpoint, apiKey, CreateWebSocketAsync) + { + } + + internal GrokTextToSpeechClient( + HttpClient httpClient, + Uri endpoint, + string? apiKey, + Func> webSocketFactory) + { + this.httpClient = Throw.IfNull(httpClient); + this.endpoint = Throw.IfNull(endpoint); + this.apiKey = apiKey; + this.webSocketFactory = Throw.IfNull(webSocketFactory); + + metadata = new("xai", endpoint); + } + + /// + public async Task GetAudioAsync( + string text, + TextToSpeechOptions? options = null, + CancellationToken cancellationToken = default) + { + var request = CreateRequest(Throw.IfNull(text), options); + using var message = new HttpRequestMessage(HttpMethod.Post, GetHttpEndpoint()) + { + Content = JsonContent.Create(request, JsonContext.Default.GrokTextToSpeechRequest), + }; + + using var response = await httpClient.SendAsync(message, cancellationToken).ConfigureAwait(false); + + if (!response.IsSuccessStatusCode) + await ThrowHttpExceptionAsync(response, cancellationToken).ConfigureAwait(false); + + var audio = await response.Content.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(false); + var mediaType = response.Content.Headers.ContentType?.MediaType ?? GetMediaType(request.OutputFormat?.Codec); + + var raw = new HttpResponseMessage(response.StatusCode); + foreach (var header in response.Headers) + raw.Headers.TryAddWithoutValidation(header.Key, header.Value); + foreach (var header in response.Content.Headers) + raw.Content.Headers.TryAddWithoutValidation(header.Key, header.Value); + + return new TextToSpeechResponse([new DataContent(audio, mediaType)]) + { + ModelId = options?.ModelId, + RawRepresentation = raw, + }; + } + + /// + public async IAsyncEnumerable GetStreamingAudioAsync( + string text, + TextToSpeechOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var request = CreateRequest(Throw.IfNull(text), options); + using var webSocket = await webSocketFactory(GetStreamingEndpoint(request), apiKey, cancellationToken).ConfigureAwait(false); + + await SendJsonAsync(webSocket, new TextDeltaMessage(text), JsonContext.Default.TextDeltaMessage, cancellationToken).ConfigureAwait(false); + await SendJsonAsync(webSocket, TextDoneMessage.Instance, JsonContext.Default.TextDoneMessage, cancellationToken).ConfigureAwait(false); + + while (true) + { + using var json = await ReceiveJsonAsync(webSocket, cancellationToken).ConfigureAwait(false); + var root = json.RootElement; + var rawRepresentation = root.Clone(); + var type = GetRequiredString(root, "type"); + + switch (type) + { + case "audio.delta": + var audio = Convert.FromBase64String(GetRequiredString(root, "delta")); + yield return new TextToSpeechResponseUpdate + { + Kind = TextToSpeechResponseUpdateKind.AudioUpdating, + Contents = [new DataContent(audio, GetMediaType(request.OutputFormat?.Codec))], + ModelId = options?.ModelId, + RawRepresentation = rawRepresentation, + }; + break; + + case "audio.done": + var update = new TextToSpeechResponseUpdate + { + Kind = TextToSpeechResponseUpdateKind.SessionClose, + ModelId = options?.ModelId, + RawRepresentation = rawRepresentation, + }; + + if (root.TryGetProperty("trace_id", out var traceId) && traceId.ValueKind == JsonValueKind.String) + { + update.AdditionalProperties = new() + { + ["trace_id"] = traceId.GetString(), + }; + } + + yield return update; + yield break; + + case "error": + throw new InvalidOperationException($"xAI TTS streaming error: {GetRequiredString(root, "message")}"); + + default: + throw new InvalidOperationException($"Unsupported xAI TTS streaming event type: {type}"); + } + } + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) => serviceKey is not null ? null : serviceType switch + { + Type t when t == typeof(TextToSpeechClientMetadata) => metadata, + Type t when t == typeof(GrokTextToSpeechClient) => this, + Type t when t == typeof(HttpClient) => httpClient, + Type t when t.IsInstanceOfType(this) => this, + _ => null + }; + + /// + public void Dispose() => httpClient.Dispose(); + + static GrokTextToSpeechRequest CreateRequest(string text, TextToSpeechOptions? options) + { + var codec = GetCodec(options?.AudioFormat); + var grokOptions = options as GrokTextToSpeechOptions; + var outputFormat = + codec != DefaultCodec || + grokOptions?.SampleRate is not null || + grokOptions?.BitRate is not null + ? new GrokTextToSpeechOutputFormat(codec, grokOptions?.SampleRate, grokOptions?.BitRate) + : null; + + return new( + text, + options?.VoiceId ?? DefaultVoice, + options?.Language ?? DefaultLanguage, + outputFormat, + grokOptions?.OptimizeStreamingLatency, + grokOptions?.TextNormalization); + } + + Uri GetHttpEndpoint() => GetEndpoint(endpoint, "https", "v1/tts", null); + + Uri GetStreamingEndpoint(GrokTextToSpeechRequest request) + { + var query = new NameValueCollection + { + ["voice"] = request.VoiceId, + ["language"] = request.Language, + ["codec"] = request.OutputFormat?.Codec ?? DefaultCodec, + }; + + if (request.OutputFormat?.SampleRate is int sampleRate) + query["sample_rate"] = sampleRate.ToString(System.Globalization.CultureInfo.InvariantCulture); + + if (request.OutputFormat?.BitRate is int bitRate) + query["bit_rate"] = bitRate.ToString(System.Globalization.CultureInfo.InvariantCulture); + + if (request.OptimizeStreamingLatency is int optimizeStreamingLatency) + query["optimize_streaming_latency"] = optimizeStreamingLatency.ToString(System.Globalization.CultureInfo.InvariantCulture); + + if (request.TextNormalization is bool textNormalization) + query["text_normalization"] = textNormalization ? "true" : "false"; + + return GetEndpoint(endpoint, endpoint.Scheme == Uri.UriSchemeHttp ? "ws" : "wss", "v1/tts", query); + } + + static Uri GetEndpoint(Uri endpoint, string scheme, string relativePath, NameValueCollection? query) => new UriBuilder(endpoint) + { + Scheme = scheme, + Path = CombinePath(endpoint.AbsolutePath, relativePath), + Query = query is null ? "" : ToQueryString(query), + }.Uri; + + static string CombinePath(string basePath, string relativePath) + { + var path = basePath == "/" ? "" : basePath.TrimEnd('/'); + return $"{path}/{relativePath.TrimStart('/')}"; + } + + static string ToQueryString(NameValueCollection query) + { + var builder = new StringBuilder(); + + foreach (string key in query) + { + if (query[key] is not { } value) + continue; + + if (builder.Length > 0) + builder.Append('&'); + + builder + .Append(Uri.EscapeDataString(key)) + .Append('=') + .Append(Uri.EscapeDataString(value)); + } + + return builder.ToString(); + } + + static string GetCodec(string? format) => format?.ToUpperInvariant() switch + { + null or "" => DefaultCodec, + "MP3" or "AUDIO/MPEG" => "mp3", + "WAV" or "AUDIO/WAV" => "wav", + "PCM" or "AUDIO/PCM" or "AUDIO/L16" => "pcm", + "MULAW" or "ULAW" or "AUDIO/BASIC" => "mulaw", + "ALAW" or "AUDIO/ALAW" => "alaw", + _ => format.ToLowerInvariant(), + }; + + static string GetMediaType(string? codec) => codec switch + { + null or "" or "mp3" => "audio/mpeg", + "wav" => "audio/wav", + "pcm" => "audio/pcm", + "mulaw" or "ulaw" => "audio/basic", + "alaw" => "audio/alaw", + _ => "application/octet-stream", + }; + + static async Task ThrowHttpExceptionAsync(HttpResponseMessage response, CancellationToken cancellationToken) + { + var body = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + var message = string.IsNullOrWhiteSpace(body) ? + $"xAI TTS request failed with status code {(int)response.StatusCode} ({response.ReasonPhrase})." : + $"xAI TTS request failed with status code {(int)response.StatusCode} ({response.ReasonPhrase}): {body}"; + + throw new HttpRequestException(message, null, response.StatusCode); + } + + static async ValueTask CreateWebSocketAsync(Uri uri, string? apiKey, CancellationToken cancellationToken) + { + var webSocket = new ClientWebSocket(); + + if (!string.IsNullOrEmpty(apiKey)) + webSocket.Options.SetRequestHeader("Authorization", $"Bearer {apiKey}"); + + await webSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); + return webSocket; + } + + static Task SendJsonAsync(WebSocket webSocket, T value, JsonTypeInfo typeInfo, CancellationToken cancellationToken) + => webSocket.SendAsync(JsonSerializer.SerializeToUtf8Bytes(value, typeInfo), WebSocketMessageType.Text, true, cancellationToken); + + static async Task ReceiveJsonAsync(WebSocket webSocket, CancellationToken cancellationToken) + { + var buffer = ArrayPool.Shared.Rent(8192); + try + { + using var stream = new MemoryStream(); + + while (true) + { + var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); + + if (result.MessageType == WebSocketMessageType.Close) + throw new InvalidOperationException($"xAI TTS streaming connection closed before audio.done: {result.CloseStatusDescription ?? result.CloseStatus?.ToString()}"); + + if (result.MessageType != WebSocketMessageType.Text) + throw new InvalidOperationException($"xAI TTS streaming returned unsupported message type: {result.MessageType}"); + + stream.Write(buffer, 0, result.Count); + + if (result.EndOfMessage) + break; + } + + stream.Position = 0; + return await JsonDocument.ParseAsync(stream, cancellationToken: cancellationToken).ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(buffer); + + } + } + + static string GetRequiredString(JsonElement json, string propertyName) + { + if (!json.TryGetProperty(propertyName, out var property) || property.ValueKind != JsonValueKind.String) + throw new InvalidOperationException($"xAI TTS streaming event is missing required string property '{propertyName}'."); + + return property.GetString()!; + } + + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower)] + [JsonSerializable(typeof(GrokTextToSpeechRequest))] + [JsonSerializable(typeof(TextDeltaMessage))] + [JsonSerializable(typeof(TextDoneMessage))] + partial class JsonContext : JsonSerializerContext { } + + sealed record GrokTextToSpeechRequest(string Text, string VoiceId, string Language, + GrokTextToSpeechOutputFormat? OutputFormat, int? OptimizeStreamingLatency, bool? TextNormalization); + + sealed record GrokTextToSpeechOutputFormat(string Codec, int? SampleRate, int? BitRate); + + sealed record TextDeltaMessage(string Delta) + { + public string Type => "text.delta"; + } + + sealed record TextDoneMessage + { + public static readonly TextDoneMessage Instance = new(); + + public string Type => "text.done"; + } +} diff --git a/src/xAI/GrokTextToSpeechOptions.cs b/src/xAI/GrokTextToSpeechOptions.cs new file mode 100644 index 0000000..3d72a5a --- /dev/null +++ b/src/xAI/GrokTextToSpeechOptions.cs @@ -0,0 +1,44 @@ +using Microsoft.Extensions.AI; + +namespace xAI; + +/// Grok-specific text to speech options that extend the base . +/// +/// These options map to xAI's /v1/tts REST and WebSocket parameters. +/// If not specified, the API defaults to MP3 at 24 kHz / 128 kbps. +/// +public class GrokTextToSpeechOptions : TextToSpeechOptions +{ + /// Initializes a new instance of the class. + public GrokTextToSpeechOptions() + { + } + + /// Initializes a new instance of the class by cloning another instance. + protected GrokTextToSpeechOptions(GrokTextToSpeechOptions? other) + : base(other) + { + if (other is null) + return; + + SampleRate = other.SampleRate; + BitRate = other.BitRate; + OptimizeStreamingLatency = other.OptimizeStreamingLatency; + TextNormalization = other.TextNormalization; + } + + /// Gets or sets the output sample rate in Hz. + public int? SampleRate { get; set; } + + /// Gets or sets the MP3 bit rate in bits per second. + public int? BitRate { get; set; } + + /// Gets or sets the xAI streaming latency optimization level. + public int? OptimizeStreamingLatency { get; set; } + + /// Gets or sets a value indicating whether xAI should normalize written-form text before synthesis. + public bool? TextNormalization { get; set; } + + /// + public override TextToSpeechOptions Clone() => new GrokTextToSpeechOptions(this); +}