Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions src/xAI.Tests/ChatClientTests.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
using System.Text.Json;
using System.Text.Json.Nodes;
using Azure;
using Devlooped.Extensions.AI;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
using Moq;
using OpenAI;
using Tests.Client.Helpers;
using xAI;
using xAI.Protocol;
using static ConfigurationExtensions;
using Chat = Devlooped.Extensions.AI.Chat;
using OpenAIClientOptions = OpenAI.OpenAIClientOptions;

namespace xAI.Tests;

Expand Down Expand Up @@ -219,6 +212,23 @@ public async Task GrokInvokesSpecificSearchUrl()
Assert.Contains("catedralaltapatagonia.com", citations);
}

[SecretsFact("XAI_API_KEY")]
public async Task GrokPerformsInlineFileSearch()
{
var grok = new GrokClient(Configuration["XAI_API_KEY"]!).AsIChatClient("grok-4-1-fast-non-reasoning");

var message = new ChatMessage(ChatRole.User,
[
new DataContent(File.ReadAllBytes("preferences.pdf"), "application/pdf"),
new TextContent("what's my favorite company?")
]);

var response = await grok.GetResponseAsync(message);
var text = response.Text;

Assert.Contains("SpaceX", text);
}

[SecretsFact("XAI_API_KEY")]
public async Task GrokInvokesHostedSearchTool()
{
Expand Down Expand Up @@ -742,7 +752,7 @@ public async Task GrokDoesNotAddEmptyContentToToolCallOnlyMessages()
}

[Fact]
public async Task GrokSendsDataContentAsBase64ImageUrl()
public async Task GrokSendsUriContentAsImageUrl()
{
GetCompletionsRequest? capturedRequest = null;
var client = new Mock<xAI.Protocol.Chat.ChatClient>(MockBehavior.Strict);
Expand All @@ -759,11 +769,11 @@ public async Task GrokSendsDataContentAsBase64ImageUrl()
}
}));

var imageBytes = new byte[] { 1, 2, 3, 4, 5 };
var imageUri = new Uri("https://example.com/photo.jpg");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, [new TextContent("What do you see?"), new DataContent(imageBytes, "image/png")]),
new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(imageUri, "image/jpeg")]),
};

await grok.GetResponseAsync(messages);
Expand All @@ -775,11 +785,11 @@ public async Task GrokSendsDataContentAsBase64ImageUrl()
Assert.Equal("What do you see?", userMessage.Content[0].Text);
var imageContent = userMessage.Content[1].ImageUrl;
Assert.NotNull(imageContent);
Assert.Equal($"data:image/png;base64,{Convert.ToBase64String(imageBytes)}", imageContent.ImageUrl);
Assert.Equal(imageUri.ToString(), imageContent.ImageUrl);
}

[Fact]
public async Task GrokSendsUriContentAsImageUrl()
public async Task GrokSendsDataUriContentAsData()
{
GetCompletionsRequest? capturedRequest = null;
var client = new Mock<xAI.Protocol.Chat.ChatClient>(MockBehavior.Strict);
Expand All @@ -791,16 +801,16 @@ public async Task GrokSendsUriContentAsImageUrl()
{
new CompletionOutput
{
Message = new CompletionMessage { Content = "I see an image." }
Message = new CompletionMessage { Content = "I see a PDF." }
}
}
}));

var imageUri = new Uri("https://example.com/photo.jpg");
var pdfUri = new Uri("https://example.com/data.pdf");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(imageUri, "image/jpeg")]),
new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(pdfUri, "application/pdf")]),
};

await grok.GetResponseAsync(messages);
Expand All @@ -810,9 +820,11 @@ public async Task GrokSendsUriContentAsImageUrl()
Assert.NotNull(userMessage);
Assert.Equal(2, userMessage.Content.Count);
Assert.Equal("What do you see?", userMessage.Content[0].Text);
var imageContent = userMessage.Content[1].ImageUrl;
Assert.NotNull(imageContent);
Assert.Equal(imageUri.ToString(), imageContent.ImageUrl);
var pdfFile = userMessage.Content[1].File;
Assert.NotNull(pdfFile);

Assert.Equal(pdfUri.ToString(), pdfFile.Url);
Assert.Equal("application/pdf", pdfFile.MimeType);
}

[Fact]
Expand Down
Binary file added src/xAI.Tests/preferences.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions src/xAI.Tests/xAI.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

<ItemGroup>
<None Update=".env" CopyToOutputDirectory="PreserveNewest" />
<None Update="preferences.pdf" CopyToOutputDirectory="PreserveNewest" />
<None Update="Content\**\*.*;*.md" CopyToOutputDirectory="PreserveNewest" />
<Content Update="Content\**\*.*" CopyToOutputDirectory="PreserveNewest" />
<Content Include="*.json;*.ini;*.toml" Exclude="@(Content)" CopyToOutputDirectory="PreserveNewest" />
Expand Down
174 changes: 12 additions & 162 deletions src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
using System.Text.Json;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
using xAI.Protocol;
using static xAI.Protocol.Chat;

namespace xAI;

class GrokChatClient : IChatClient
interface IGrokChatClient : IChatClient
{
string DefaultModelId { get; }
string? EndUserId { get; }
}

class GrokChatClient : IGrokChatClient
{
readonly ChatClientMetadata metadata;
readonly ChatClient client;
Expand All @@ -34,9 +37,12 @@ internal GrokChatClient(ChatClient client, GrokClientOptions clientOptions, stri
metadata = new ChatClientMetadata("xai", clientOptions.Endpoint, defaultModelId);
}

public string DefaultModelId => defaultModelId;
public string? EndUserId => clientOptions.EndUserId;

public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var request = MapToRequest(messages, options);
var request = this.AsCompletionsRequest(messages, options);
var response = await client.GetCompletionAsync(request, cancellationToken: cancellationToken);
var lastOutput = response.Outputs.OrderByDescending(x => x.Index).FirstOrDefault();

Expand All @@ -62,7 +68,7 @@ public IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(IEnumerabl

async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable<ChatMessage> messages, ChatOptions? options, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken)
{
var request = MapToRequest(messages, options);
var request = this.AsCompletionsRequest(messages, options);
var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken);

await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken))
Expand Down Expand Up @@ -130,162 +136,6 @@ static CitationAnnotation MapCitation(string citation)
};
}

GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOptions? options)
{
var request = options?.RawRepresentationFactory?.Invoke(this) as GetCompletionsRequest ?? new GetCompletionsRequest()
{
Model = options?.ModelId ?? defaultModelId,
};

if (string.IsNullOrEmpty(request.Model))
request.Model = options?.ModelId ?? defaultModelId;

if ((options?.EndUserId ?? clientOptions.EndUserId) is { } user) request.User = user;
if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens;
if (options?.Temperature is { } temperature) request.Temperature = temperature;
if (options?.TopP is { } topP) request.TopP = topP;
if (options?.FrequencyPenalty is { } frequencyPenalty) request.FrequencyPenalty = frequencyPenalty;
if (options?.PresencePenalty is { } presencePenalty) request.PresencePenalty = presencePenalty;

foreach (var message in messages)
{
if (message.RawRepresentation is Message input)
{
request.Messages.Add(input);
continue;
}
else if (message.RawRepresentation is CompletionMessage completion)
{
request.Messages.Add(completion.AsMessage());
continue;
}

var gmsg = new Message { Role = message.Role.Convert() };

foreach (var content in message.Contents)
{
if (content.RawRepresentation is CompletionMessage completion)
{
request.Messages.Add(completion.AsMessage());
continue;
}

if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text))
{
gmsg.Content.Add(new Content { Text = textContent.Text });
}
else if (content is TextReasoningContent reasoning)
{
gmsg.ReasoningContent = reasoning.Text;
gmsg.EncryptedContent = reasoning.ProtectedData;
}
else if (content is DataContent dataContent && dataContent.HasTopLevelMediaType("image"))
{
gmsg.Content.Add(new Content { ImageUrl = new ImageUrlContent { ImageUrl = $"data:{dataContent.MediaType};base64,{Convert.ToBase64String(dataContent.Data.Span)}" } });
}
else if (content is UriContent uriContent && uriContent.HasTopLevelMediaType("image"))
{
gmsg.Content.Add(new Content { ImageUrl = new ImageUrlContent { ImageUrl = uriContent.Uri.ToString() } });
}
else if (content.RawRepresentation is ToolCall toolCall)
{
gmsg.ToolCalls.Add(toolCall);
}
else if (content is FunctionCallContent functionCall)
{
gmsg.ToolCalls.Add(new ToolCall
{
Id = functionCall.CallId,
Type = ToolCallType.ClientSideTool,
Function = new FunctionCall
{
Name = functionCall.Name,
Arguments = JsonSerializer.Serialize(functionCall.Arguments)
}
});
}
else if (content is FunctionResultContent resultContent)
{
var msg = new Message
{
Role = MessageRole.RoleTool,
Content = { new Content { Text = JsonSerializer.Serialize(resultContent.Result) ?? "null" } }
};

if (resultContent.CallId is { Length: > 0 } callId)
msg.ToolCallId = callId;

request.Messages.Add(msg);
}
else if (content is McpServerToolResultContent mcpResult &&
mcpResult.RawRepresentation is ToolCall mcpToolCall &&
// TODO: what if there are multiple outputs?
mcpResult.Outputs is { Count: 1 } &&
mcpResult.Outputs[0] is TextContent mcpText)
{
request.Messages.Add(new Message
{
Role = MessageRole.RoleTool,
ToolCalls = { mcpToolCall },
Content = { new Content { Text = mcpText.Text } }
});
}
else if (content is CodeInterpreterToolResultContent codeResult &&
codeResult.RawRepresentation is ToolCall codeToolCall &&
// TODO: what if there are multiple outputs?
codeResult.Outputs is { Count: 1 } &&
codeResult.Outputs[0] is TextContent codeText)
{
request.Messages.Add(new Message
{
Role = MessageRole.RoleTool,
ToolCalls = { codeToolCall },
Content = { new Content { Text = codeText.Text } }
});
}
}

if (gmsg.Content.Count == 0 && gmsg.ToolCalls.Count == 0)
continue;

request.Messages.Add(gmsg);
}

if (options is GrokChatOptions grokOptions)
{
request.Include.AddRange(grokOptions.Include);

if (grokOptions.Search.HasFlag(GrokSearch.X))
{
(options.Tools ??= []).Insert(0, new GrokXSearchTool());
}
else if (grokOptions.Search.HasFlag(GrokSearch.Web))
{
(options.Tools ??= []).Insert(0, new GrokSearchTool());
}

request.UseEncryptedContent = grokOptions.UseEncryptedContent;
}

if (options?.Tools is not null)
{
foreach (var tool in options.Tools.Select(x => x.AsProtocolTool(options)))
if (tool is not null) request.Tools.Add(tool);
}

if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat)
{
request.ResponseFormat = new ResponseFormat { FormatType = FormatType.JsonObject };
if (jsonFormat.Schema != null)
{
request.ResponseFormat.FormatType = FormatType.JsonSchema;
request.ResponseFormat.Schema = jsonFormat.Schema?.ToString();
}
}

return request;
}

/// <inheritdoc />
public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
{
Expand Down
Loading
Loading