diff --git a/src/xAI.Tests/ChatClientTests.cs b/src/xAI.Tests/ChatClientTests.cs index 93daa6d..df0a1d9 100644 --- a/src/xAI.Tests/ChatClientTests.cs +++ b/src/xAI.Tests/ChatClientTests.cs @@ -1,20 +1,13 @@ using System.Text.Json; -using System.Text.Json.Nodes; -using Azure; using Devlooped.Extensions.AI; -using Google.Protobuf; using Grpc.Core; -using Grpc.Core.Interceptors; -using Grpc.Net.Client; using Microsoft.Extensions.AI; using Moq; using OpenAI; using Tests.Client.Helpers; -using xAI; using xAI.Protocol; using static ConfigurationExtensions; using Chat = Devlooped.Extensions.AI.Chat; -using OpenAIClientOptions = OpenAI.OpenAIClientOptions; namespace xAI.Tests; @@ -219,6 +212,23 @@ public async Task GrokInvokesSpecificSearchUrl() Assert.Contains("catedralaltapatagonia.com", citations); } + [SecretsFact("XAI_API_KEY")] + public async Task GrokPerformsInlineFileSearch() + { + var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-1-fast-non-reasoning"); + + var message = new ChatMessage(ChatRole.User, + [ + new DataContent(File.ReadAllBytes("preferences.pdf"), "application/pdf"), + new TextContent("what's my favorite company?") + ]); + + var response = await grok.GetResponseAsync(message); + var text = response.Text; + + Assert.Contains("SpaceX", text); + } + [SecretsFact("XAI_API_KEY")] public async Task GrokInvokesHostedSearchTool() { @@ -742,7 +752,7 @@ public async Task GrokDoesNotAddEmptyContentToToolCallOnlyMessages() } [Fact] - public async Task GrokSendsDataContentAsBase64ImageUrl() + public async Task GrokSendsUriContentAsImageUrl() { GetCompletionsRequest? capturedRequest = null; var client = new Mock(MockBehavior.Strict); @@ -759,11 +769,11 @@ public async Task GrokSendsDataContentAsBase64ImageUrl() } })); - var imageBytes = new byte[] { 1, 2, 3, 4, 5 }; + var imageUri = new Uri("https://example.com/photo.jpg"); var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning"); var messages = new List { - new(ChatRole.User, [new TextContent("What do you see?"), new DataContent(imageBytes, "image/png")]), + new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(imageUri, "image/jpeg")]), }; await grok.GetResponseAsync(messages); @@ -775,11 +785,11 @@ public async Task GrokSendsDataContentAsBase64ImageUrl() Assert.Equal("What do you see?", userMessage.Content[0].Text); var imageContent = userMessage.Content[1].ImageUrl; Assert.NotNull(imageContent); - Assert.Equal($"data:image/png;base64,{Convert.ToBase64String(imageBytes)}", imageContent.ImageUrl); + Assert.Equal(imageUri.ToString(), imageContent.ImageUrl); } [Fact] - public async Task GrokSendsUriContentAsImageUrl() + public async Task GrokSendsDataUriContentAsData() { GetCompletionsRequest? capturedRequest = null; var client = new Mock(MockBehavior.Strict); @@ -791,16 +801,16 @@ public async Task GrokSendsUriContentAsImageUrl() { new CompletionOutput { - Message = new CompletionMessage { Content = "I see an image." } + Message = new CompletionMessage { Content = "I see a PDF." } } } })); - var imageUri = new Uri("https://example.com/photo.jpg"); + var pdfUri = new Uri("https://example.com/data.pdf"); var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning"); var messages = new List { - new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(imageUri, "image/jpeg")]), + new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(pdfUri, "application/pdf")]), }; await grok.GetResponseAsync(messages); @@ -810,9 +820,11 @@ public async Task GrokSendsUriContentAsImageUrl() Assert.NotNull(userMessage); Assert.Equal(2, userMessage.Content.Count); Assert.Equal("What do you see?", userMessage.Content[0].Text); - var imageContent = userMessage.Content[1].ImageUrl; - Assert.NotNull(imageContent); - Assert.Equal(imageUri.ToString(), imageContent.ImageUrl); + var pdfFile = userMessage.Content[1].File; + Assert.NotNull(pdfFile); + + Assert.Equal(pdfUri.ToString(), pdfFile.Url); + Assert.Equal("application/pdf", pdfFile.MimeType); } [Fact] diff --git a/src/xAI.Tests/preferences.pdf b/src/xAI.Tests/preferences.pdf new file mode 100644 index 0000000..4c9bfee Binary files /dev/null and b/src/xAI.Tests/preferences.pdf differ diff --git a/src/xAI.Tests/xAI.Tests.csproj b/src/xAI.Tests/xAI.Tests.csproj index 0dc342c..7213863 100644 --- a/src/xAI.Tests/xAI.Tests.csproj +++ b/src/xAI.Tests/xAI.Tests.csproj @@ -44,6 +44,7 @@ + diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs index eea6bdf..8763ae2 100644 --- a/src/xAI/GrokChatClient.cs +++ b/src/xAI/GrokChatClient.cs @@ -1,14 +1,17 @@ -using System.Text.Json; -using Google.Protobuf; using Grpc.Core; -using Grpc.Net.Client; using Microsoft.Extensions.AI; using xAI.Protocol; using static xAI.Protocol.Chat; namespace xAI; -class GrokChatClient : IChatClient +interface IGrokChatClient : IChatClient +{ + string DefaultModelId { get; } + string? EndUserId { get; } +} + +class GrokChatClient : IGrokChatClient { readonly ChatClientMetadata metadata; readonly ChatClient client; @@ -34,9 +37,12 @@ internal GrokChatClient(ChatClient client, GrokClientOptions clientOptions, stri metadata = new ChatClientMetadata("xai", clientOptions.Endpoint, defaultModelId); } + public string DefaultModelId => defaultModelId; + public string? EndUserId => clientOptions.EndUserId; + public async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { - var request = MapToRequest(messages, options); + var request = this.AsCompletionsRequest(messages, options); var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken); var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault(); @@ -62,7 +68,7 @@ public IAsyncEnumerable GetStreamingResponseAsync(IEnumerabl async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken) { - var request = MapToRequest(messages, options); + var request = this.AsCompletionsRequest(messages, options); var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken); await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken)) @@ -130,162 +136,6 @@ static CitationAnnotation MapCitation(string citation) }; } - GetCompletionsRequest MapToRequest(IEnumerable messages, ChatOptions? options) - { - var request = options?.RawRepresentationFactory?.Invoke(this) as GetCompletionsRequest ?? new GetCompletionsRequest() - { - Model = options?.ModelId ?? defaultModelId, - }; - - if (string.IsNullOrEmpty(request.Model)) - request.Model = options?.ModelId ?? defaultModelId; - - if ((options?.EndUserId ?? clientOptions.EndUserId) is { } user) request.User = user; - if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens; - if (options?.Temperature is { } temperature) request.Temperature = temperature; - if (options?.TopP is { } topP) request.TopP = topP; - if (options?.FrequencyPenalty is { } frequencyPenalty) request.FrequencyPenalty = frequencyPenalty; - if (options?.PresencePenalty is { } presencePenalty) request.PresencePenalty = presencePenalty; - - foreach (var message in messages) - { - if (message.RawRepresentation is Message input) - { - request.Messages.Add(input); - continue; - } - else if (message.RawRepresentation is CompletionMessage completion) - { - request.Messages.Add(completion.AsMessage()); - continue; - } - - var gmsg = new Message { Role = message.Role.Convert() }; - - foreach (var content in message.Contents) - { - if (content.RawRepresentation is CompletionMessage completion) - { - request.Messages.Add(completion.AsMessage()); - continue; - } - - if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text)) - { - gmsg.Content.Add(new Content { Text = textContent.Text }); - } - else if (content is TextReasoningContent reasoning) - { - gmsg.ReasoningContent = reasoning.Text; - gmsg.EncryptedContent = reasoning.ProtectedData; - } - else if (content is DataContent dataContent && dataContent.HasTopLevelMediaType("image")) - { - gmsg.Content.Add(new Content { ImageUrl = new ImageUrlContent { ImageUrl = $"data:{dataContent.MediaType};base64,{Convert.ToBase64String(dataContent.Data.Span)}" } }); - } - else if (content is UriContent uriContent && uriContent.HasTopLevelMediaType("image")) - { - gmsg.Content.Add(new Content { ImageUrl = new ImageUrlContent { ImageUrl = uriContent.Uri.ToString() } }); - } - else if (content.RawRepresentation is ToolCall toolCall) - { - gmsg.ToolCalls.Add(toolCall); - } - else if (content is FunctionCallContent functionCall) - { - gmsg.ToolCalls.Add(new ToolCall - { - Id = functionCall.CallId, - Type = ToolCallType.ClientSideTool, - Function = new FunctionCall - { - Name = functionCall.Name, - Arguments = JsonSerializer.Serialize(functionCall.Arguments) - } - }); - } - else if (content is FunctionResultContent resultContent) - { - var msg = new Message - { - Role = MessageRole.RoleTool, - Content = { new Content { Text = JsonSerializer.Serialize(resultContent.Result) ?? "null" } } - }; - - if (resultContent.CallId is { Length: > 0 } callId) - msg.ToolCallId = callId; - - request.Messages.Add(msg); - } - else if (content is McpServerToolResultContent mcpResult && - mcpResult.RawRepresentation is ToolCall mcpToolCall && - // TODO: what if there are multiple outputs? - mcpResult.Outputs is { Count: 1 } && - mcpResult.Outputs[0] is TextContent mcpText) - { - request.Messages.Add(new Message - { - Role = MessageRole.RoleTool, - ToolCalls = { mcpToolCall }, - Content = { new Content { Text = mcpText.Text } } - }); - } - else if (content is CodeInterpreterToolResultContent codeResult && - codeResult.RawRepresentation is ToolCall codeToolCall && - // TODO: what if there are multiple outputs? - codeResult.Outputs is { Count: 1 } && - codeResult.Outputs[0] is TextContent codeText) - { - request.Messages.Add(new Message - { - Role = MessageRole.RoleTool, - ToolCalls = { codeToolCall }, - Content = { new Content { Text = codeText.Text } } - }); - } - } - - if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0) - continue; - - request.Messages.Add(gmsg); - } - - if (options is GrokChatOptions grokOptions) - { - request.Include.AddRange(grokOptions.Include); - - if (grokOptions.Search.HasFlag(GrokSearch.X)) - { - (options.Tools ??= []).Insert(0, new GrokXSearchTool()); - } - else if (grokOptions.Search.HasFlag(GrokSearch.Web)) - { - (options.Tools ??= []).Insert(0, new GrokSearchTool()); - } - - request.UseEncryptedContent = grokOptions.UseEncryptedContent; - } - - if (options?.Tools is not null) - { - foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options))) - if (tool is not null) request.Tools.Add(tool); - } - - if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat) - { - request.ResponseFormat = new ResponseFormat { FormatType = FormatType.JsonObject }; - if (jsonFormat.Schema != null) - { - request.ResponseFormat.FormatType = FormatType.JsonSchema; - request.ResponseFormat.Schema = jsonFormat.Schema?.ToString(); - } - } - - return request; - } - /// public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch { diff --git a/src/xAI/GrokProtocolExtensions.cs b/src/xAI/GrokProtocolExtensions.cs index bc2780a..2be359c 100644 --- a/src/xAI/GrokProtocolExtensions.cs +++ b/src/xAI/GrokProtocolExtensions.cs @@ -221,6 +221,195 @@ static IEnumerable ToChatMessages(IEnumerable me } } + /// Converts messages and optional options to an xAI protocol completions request. + internal static GetCompletionsRequest AsCompletionsRequest(this IGrokChatClient client, IEnumerable messages, ChatOptions? options = null) + { + var request = options?.RawRepresentationFactory?.Invoke(client) as GetCompletionsRequest ?? new GetCompletionsRequest() + { + Model = options?.ModelId ?? client.DefaultModelId, + }; + + if (string.IsNullOrEmpty(request.Model)) + request.Model = options?.ModelId ?? client.DefaultModelId; + + if ((options?.EndUserId ?? client.EndUserId) is { } user) request.User = user; + if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens; + if (options?.Temperature is { } temperature) request.Temperature = temperature; + if (options?.TopP is { } topP) request.TopP = topP; + if (options?.FrequencyPenalty is { } frequencyPenalty) request.FrequencyPenalty = frequencyPenalty; + if (options?.PresencePenalty is { } presencePenalty) request.PresencePenalty = presencePenalty; + + foreach (var message in messages) + { + if (message.RawRepresentation is Message input) + { + request.Messages.Add(input); + continue; + } + else if (message.RawRepresentation is CompletionMessage completion) + { + request.Messages.Add(completion.AsMessage()); + continue; + } + + var gmsg = new Message { Role = message.Role.Convert() }; + + foreach (var content in message.Contents) + { + if (content.RawRepresentation is CompletionMessage completion) + { + request.Messages.Add(completion.AsMessage()); + continue; + } + if (content.RawRepresentation is Content protoContent) + { + gmsg.Content.Add(protoContent); + continue; + } + + if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text)) + { + gmsg.Content.Add(new Content { Text = textContent.Text }); + } + else if (content is TextReasoningContent reasoning) + { + gmsg.ReasoningContent = reasoning.Text; + gmsg.EncryptedContent = reasoning.ProtectedData; + } + else if (content is DataContent dataContent) + { + gmsg.Content.Add(new Content + { + File = new FileContent + { + Data = Google.Protobuf.ByteString.CopyFrom(dataContent.Data.Span), + MimeType = dataContent.MediaType, + Filename = dataContent.Name ?? "", + } + }); + //gmsg.Content.Add(new Content { ImageUrl = new ImageUrlContent { ImageUrl = $"data:{dataContent.MediaType};base64,{System.Convert.ToBase64String(dataContent.Data.Span)}" } }); + } + else if (content is UriContent uriContent) + { + if (uriContent.HasTopLevelMediaType("image")) + { + gmsg.Content.Add(new Content + { + ImageUrl = new ImageUrlContent { ImageUrl = uriContent.Uri.ToString() }, + }); + } + else + { + gmsg.Content.Add(new Content + { + File = new FileContent + { + Url = uriContent.Uri.ToString(), + MimeType = uriContent.MediaType + } + }); + } + } + else if (content.RawRepresentation is ToolCall toolCall) + { + gmsg.ToolCalls.Add(toolCall); + } + else if (content is FunctionCallContent functionCall) + { + gmsg.ToolCalls.Add(new ToolCall + { + Id = functionCall.CallId, + Type = ToolCallType.ClientSideTool, + Function = new FunctionCall + { + Name = functionCall.Name, + Arguments = JsonSerializer.Serialize(functionCall.Arguments) + } + }); + } + else if (content is FunctionResultContent resultContent) + { + var msg = new Message + { + Role = MessageRole.RoleTool, + Content = { new Content { Text = JsonSerializer.Serialize(resultContent.Result) ?? "null" } } + }; + + if (resultContent.CallId is { Length: > 0 } callId) + msg.ToolCallId = callId; + + request.Messages.Add(msg); + } + else if (content is McpServerToolResultContent mcpResult && + mcpResult.RawRepresentation is ToolCall mcpToolCall && + // TODO: what if there are multiple outputs? + mcpResult.Outputs is { Count: 1 } && + mcpResult.Outputs[0] is TextContent mcpText) + { + request.Messages.Add(new Message + { + Role = MessageRole.RoleTool, + ToolCalls = { mcpToolCall }, + Content = { new Content { Text = mcpText.Text } } + }); + } + else if (content is CodeInterpreterToolResultContent codeResult && + codeResult.RawRepresentation is ToolCall codeToolCall && + // TODO: what if there are multiple outputs? + codeResult.Outputs is { Count: 1 } && + codeResult.Outputs[0] is TextContent codeText) + { + request.Messages.Add(new Message + { + Role = MessageRole.RoleTool, + ToolCalls = { codeToolCall }, + Content = { new Content { Text = codeText.Text } } + }); + } + } + + if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0) + continue; + + request.Messages.Add(gmsg); + } + + if (options is GrokChatOptions grokOptions) + { + request.Include.AddRange(grokOptions.Include); + + if (grokOptions.Search.HasFlag(GrokSearch.X)) + { + (options.Tools ??= []).Insert(0, new GrokXSearchTool()); + } + else if (grokOptions.Search.HasFlag(GrokSearch.Web)) + { + (options.Tools ??= []).Insert(0, new GrokSearchTool()); + } + + request.UseEncryptedContent = grokOptions.UseEncryptedContent; + } + + if (options?.Tools is not null) + { + foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options))) + if (tool is not null) request.Tools.Add(tool); + } + + if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + request.ResponseFormat = new ResponseFormat { FormatType = FormatType.JsonObject }; + if (jsonFormat.Schema != null) + { + request.ResponseFormat.FormatType = FormatType.JsonSchema; + request.ResponseFormat.Schema = jsonFormat.Schema?.ToString(); + } + } + + return request; + } + + internal static IEnumerable AsContents(this IEnumerable toolCalls, string? content = default, List? annotations = default) { foreach (var toolCall in toolCalls)