diff --git a/docs/docs/integrations/index.mdx b/docs/docs/integrations/index.mdx index 19c7fdc240..6be7384ac3 100644 --- a/docs/docs/integrations/index.mdx +++ b/docs/docs/integrations/index.mdx @@ -38,7 +38,7 @@ of course some LLM providers offer large multimodal model (accepting text or ima | [Qianfan](/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | | |✅ | | [ChatGLM](/integrations/language-models/chatglm) | | ✅ | | | | | | [Nomic](/integrations/language-models/nomic) | | | |✅ | | | | -| [Anthropic](/integrations/language-models/anthropic) | |✅ | | | | | | +| [Anthropic](/integrations/language-models/anthropic) | |✅ | | | | | ✅ | | [Zhipu AI](/integrations/language-models/zhipuai) | |✅| ✅| ✅| | |✅ | diff --git a/docs/docs/integrations/language-models/anthropic.md b/docs/docs/integrations/language-models/anthropic.md index 2ee3f3fc63..8847f14c75 100644 --- a/docs/docs/integrations/language-models/anthropic.md +++ b/docs/docs/integrations/language-models/anthropic.md @@ -4,7 +4,10 @@ sidebar_position: 2 # Anthropic -[Anthropic](https://www.anthropic.com/) +- [Anthropic Documentation](https://docs.anthropic.com/claude/docs) +- [Anthropic API Reference](https://docs.anthropic.com/claude/reference) + +## Maven Dependency ```xml @@ -14,12 +17,89 @@ sidebar_position: 2 ``` +## AnthropicChatModel + +```java +AnthropicChatModel model = AnthropicChatModel.withApiKey(System.getenv("ANTHROPIC_API_KEY")); +String answer = model.generate("Say 'Hello World'"); +System.out.println(answer); +``` + +### Customizing +```java +AnthropicChatModel model = AnthropicChatModel.builder() + .baseUrl(...) + .apiKey(...) + .version(...) + .beta(...) + .modelName(...) + .temperature(...) + .topP(...) + .topK(...) + .maxTokens(...) + .stopSequences(...) + .timeout(...) + .maxRetries(...) + .logRequests(...) + .logResponses(...) + .build(); +``` +See the description of some of the parameters above [here](https://docs.anthropic.com/claude/reference/messages_post). + +## AnthropicStreamingChatModel ```java -ChatLanguageModel model = AnthropicChatModel.withApiKey(System.getenv("ANTHROPIC_API_KEY")); +AnthropicStreamingChatModel model = AnthropicStreamingChatModel.withApiKey(System.getenv("ANTHROPIC_API_KEY")); -String answer = model.generate("What is the capital of Germany?"); +model.generate("Say 'Hello World'", new StreamingResponseHandler() { -System.out.println(answer); // Berlin + @Override + public void onNext(String token) { + // this method is called when a new token is available + } + + @Override + public void onComplete(Response response) { + // this method is called when the model has completed responding + } + + @Override + public void onError(Throwable error) { + // this method is called when an error occurs + } +}); ``` -More info is coming soon +### Customizing + +Identical to the `AnthropicChatModel`, see above. + +## Tools + +Anthropic supports [tools](/tutorials/tools), but only in a non-streaming mode. + +Anthropic documentation on tools can be found [here](https://docs.anthropic.com/claude/docs/tool-use). + +## Quarkus + +TODO + +## Spring Boot + +Import Spring Boot starter for Anthropic: +```xml + + dev.langchain4j + langchain4j-anthropic-spring-boot-starter + 0.29.1 + +``` + +Configure `AnthropicChatModel` bean: +``` +langchain4j.anthropic.chat-model.api-key = ${ANTHROPIC_API_KEY} +``` + +Configure `AnthropicStreamingChatModel` bean: +``` +langchain4j.anthropic.streaming-chat-model.api-key = ${ANTHROPIC_API_KEY} +``` \ No newline at end of file diff --git a/docs/docs/tutorials/6-tools.md b/docs/docs/tutorials/6-tools.md index e7b64de3f8..15011dbb7b 100644 --- a/docs/docs/tutorials/6-tools.md +++ b/docs/docs/tutorials/6-tools.md @@ -121,29 +121,82 @@ Please note that tools/function calling is not the same as [JSON mode](/tutorial ## 2 levels of abstraction -LangChain4j provides two levels of abstraction for working with tools. +LangChain4j provides two levels of abstraction for using tools. ### Low level Tool API -At the low level, you can use the `generate(List, List)` -and `generate(List, ToolSpecification)` methods -of `ChatLanguageModel` (and similar methods of `StreamingChatLanguageModel`). -You'll need to manually create `ToolSpecification` object(s) containing all information about the tool, -or use the `ToolSpecifications.toolSpecificationFrom(Method)` helper method -to convert any Java method into a `ToolSpecification`. +At the low level, you can use the `generate(List, List)` method +of the `ChatLanguageModel`. A similar method is also present in the `StreamingChatLanguageModel`. -When the LLM decides to call the tool, the returned `AiMessage` will have data -in a `List toolExecutionRequests` field instead of a `String text` field. -Depending on the LLM, it can contain one or multiple `ToolExecutionRequest`s +`ToolSpecification` is an object that contains all the information about the tool: +- The `name` of the tool +- The `description` of the tool +- The `parameters` (arguments) of the tool and their descriptions + +It is recommended to provide as much information about the tool as possible: +a clear name, a comprehensive description, and a description for each parameter, etc. + +There are two ways to create a `ToolSpecification`: + +1. Manually +```java +ToolSpecification toolSpecification = ToolSpecification.builder() + .name("getWeather") + .description("Returns the weather forecast for a given city") + .addParameter("city", type("string"), description("The city for which the weather forecast should be returned")) + .addParameter("temperatureUnit", enums(TemperatureUnit.class)) // enum TemperatureUnit { CELSIUS, FAHRENHEIT } + .build(); +``` + +2. Using helper methods: +- `ToolSpecifications.toolSpecificationsFrom(Class)` +- `ToolSpecifications.toolSpecificationsFrom(Object)` +- `ToolSpecifications.toolSpecificationFrom(Method)` + +```java +class WeatherTools { + + @Tool("Returns the weather forecast for a given city") + String getWeather( + @P("The city for which the weather forecast should be returned") String city, + TemperatureUnit temperatureUnit + ) { + ... + } +} + +List toolSpecifications = ToolSpecifications.toolSpecificationsFrom(WeatherTools.class); +``` + +Once you have a `List`, you can call the model: +```java +UserMessage userMessage = UserMessage.from("What will the weather be like in London tomorrow?"); +Response response = model.generate(singletoneList(userMessage), toolSpecifications); +AiMessage aiMessage = response.content(); +``` + +If the LLM decides to call the tool, the returned `AiMessage` will contain data +in the `toolExecutionRequests` field. +In this case, `AiMessage.hasToolExecutionRequests()` will return `true`. +Depending on the LLM, it can contain one or multiple `ToolExecutionRequest` objects (some LLMs support calling multiple tools in parallel). -The `ToolExecutionRequest` will include the tool call's `id`, the `name` of the tool to be called, -and `arguments` (a valid JSON containing a value for each tool parameter). -You'll need to manually execute the tool(s) using information from the `ToolExecutionRequest`(s) -and then create a `ToolExecutionResultMessage` containing each tool's execution result. +Each `ToolExecutionRequest` should contain: +- The `id` of the tool call (some LLMs do not provide it) +- The `name` of the tool to be called, for example: `getWeather` +- The `arguments`, for example: `{ "city": "London", "temperatureUnit": "CELSIUS" }` + +You'll need to manually execute the tool(s) using information from the `ToolExecutionRequest`(s). -Then, call the LLM with all messages (`UserMessage`, `AiMessage` containing `ToolExecutionRequest`, -`ToolExecutionResultMessage`) to get the final response from the LLM. +If you want to send the result of the tool execution back to the LLM, +you need to create a `ToolExecutionResultMessage` (one for each `ToolExecutionRequest`) +and send it along with all previous messages: +```java +String result = "It is expected to rain in London tomorrow."; +ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, result); +List messages = List.of(userMessage, aiMessage, toolExecutionResultMessage); +Response response2 = model.generate(messages, toolSpecifications); +``` ### High Level Tool API At a high level, you can annotate any Java method with the `@Tool` annotation diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java index 4a54f28c7a..843128a45a 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicApi.java @@ -12,6 +12,7 @@ interface AnthropicApi { @Headers({"content-type: application/json"}) Call createMessage(@Header(X_API_KEY) String apiKey, @Header("anthropic-version") String version, + @Header("anthropic-beta") String beta, @Body AnthropicCreateMessageRequest request); @Streaming diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java index 54b40f145c..193f2f5b04 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java @@ -1,5 +1,6 @@ package dev.langchain4j.model.anthropic; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -21,15 +22,15 @@ * More details are available here. *
*
+ * It supports tools. See more information here. + *
+ *
* It supports {@link Image}s as inputs. {@link UserMessage}s can contain one or multiple {@link ImageContent}s. * {@link Image}s must not be represented as URLs; they should be Base64-encoded strings and include a {@code mimeType}. *
*
* The content of {@link SystemMessage}s is sent using the "system" parameter. * If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n). - *
- *
- * Does not support tools. */ public class AnthropicChatModel implements ChatLanguageModel { @@ -48,6 +49,7 @@ public class AnthropicChatModel implements ChatLanguageModel { * @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/" * @param apiKey The API key for authentication with the Anthropic API. * @param version The version of the Anthropic API. Default: "2023-06-01" + * @param beta The value of the "anthropic-beta" HTTP header. It is used when tools are present in the request. Default: "tools-2024-04-04" * @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307" * @param temperature The temperature * @param topP The top-P @@ -63,6 +65,7 @@ public class AnthropicChatModel implements ChatLanguageModel { private AnthropicChatModel(String baseUrl, String apiKey, String version, + String beta, String modelName, Double temperature, Double topP, @@ -77,6 +80,7 @@ private AnthropicChatModel(String baseUrl, .baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/")) .apiKey(apiKey) .version(getOrDefault(version, "2023-06-01")) + .beta(getOrDefault(beta, "tools-2024-04-04")) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses, false)) @@ -115,6 +119,11 @@ public static AnthropicChatModel withApiKey(String apiKey) { @Override public Response generate(List messages) { + return generate(messages, (List) null); + } + + @Override + public Response generate(List messages, List toolSpecifications) { ensureNotEmpty(messages, "messages"); AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder() @@ -127,6 +136,7 @@ public Response generate(List messages) { .temperature(temperature) .topP(topP) .topK(topK) + .tools(toAnthropicTools(toolSpecifications)) .build(); AnthropicCreateMessageResponse response = withRetry(() -> client.createMessage(request), maxRetries); @@ -137,4 +147,6 @@ public Response generate(List messages) { toFinishReason(response.stopReason) ); } + + // TODO forcing tool use? } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java index 710f8d276b..83116a1752 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicClient.java @@ -7,7 +7,9 @@ import java.time.Duration; public abstract class AnthropicClient { + public abstract AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request); + public abstract void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler handler); @SuppressWarnings("rawtypes") @@ -20,9 +22,11 @@ public static AnthropicClient.Builder builder() { } public abstract static class Builder> { + public String baseUrl; public String apiKey; public String version; + public String beta; public Duration timeout; public Boolean logRequests; public Boolean logResponses; @@ -39,7 +43,8 @@ public B baseUrl(String baseUrl) { public B apiKey(String apiKey) { if (apiKey == null || apiKey.trim().isEmpty()) { - throw new IllegalArgumentException("Anthropic API Key must be defined."); + throw new IllegalArgumentException("Anthropic API key must be defined. " + + "It can be generated here: https://console.anthropic.com/settings/keys"); } this.apiKey = apiKey; return (B) this; @@ -53,6 +58,14 @@ public B version(String version) { return (B) this; } + public B beta(String beta) { + if (beta == null) { + throw new IllegalArgumentException("beta cannot be null or empty"); + } + this.beta = beta; + return (B) this; + } + public B timeout(Duration timeout) { if (timeout == null) { throw new IllegalArgumentException("timeout cannot be null"); diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java index 3960a1bdbc..59dc693288 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicContent.java @@ -1,7 +1,19 @@ package dev.langchain4j.model.anthropic; +import lombok.Builder; + +import java.util.Map; + +@Builder public class AnthropicContent { public String type; + + // when type = "text" public String text; + + // when type = "tool_use" + public String id; + public String name; + public Map input; } \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java index 47420a00b1..3798183d87 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicCreateMessageRequest.java @@ -22,4 +22,5 @@ public class AnthropicCreateMessageRequest { Double temperature; Double topP; Integer topK; + List tools; } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java index 2d4ef0548d..23c2c032b3 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicImageContent.java @@ -1,11 +1,16 @@ package dev.langchain4j.model.anthropic; -public class AnthropicImageContent { +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@ToString +@EqualsAndHashCode(callSuper = true) +public class AnthropicImageContent extends AnthropicMessageContent { - public String type = "image"; public AnthropicImageContentSource source; public AnthropicImageContent(String mediaType, String data) { + super("image"); this.source = new AnthropicImageContentSource("base64", mediaType, data); } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java index 482bf84e19..716930177f 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMapper.java @@ -1,16 +1,22 @@ package dev.langchain4j.model.anthropic; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.image.Image; import dev.langchain4j.data.message.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; +import java.util.ArrayList; import java.util.List; -import static dev.langchain4j.data.message.ChatMessageType.SYSTEM; import static dev.langchain4j.internal.Exceptions.illegalArgument; -import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.model.anthropic.AnthropicRole.ASSISTANT; +import static dev.langchain4j.model.anthropic.AnthropicRole.USER; +import static dev.langchain4j.model.anthropic.DefaultAnthropicClient.GSON; +import static dev.langchain4j.model.anthropic.DefaultAnthropicClient.MAP_TYPE; import static dev.langchain4j.model.output.FinishReason.*; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @@ -18,17 +24,81 @@ public class AnthropicMapper { static List toAnthropicMessages(List messages) { - return messages.stream() - .filter(message -> message.type() != SYSTEM) - .map(AnthropicMapper::toAnthropicMessage) - .collect(toList()); + + List anthropicMessages = new ArrayList<>(); + List toolContents = new ArrayList<>(); + + for (ChatMessage message : messages) { + + if (message instanceof ToolExecutionResultMessage) { + toolContents.add(toAnthropicToolResultContent((ToolExecutionResultMessage) message)); + } else { + if (!toolContents.isEmpty()) { + anthropicMessages.add(new AnthropicMessage(USER, toolContents)); + toolContents = new ArrayList<>(); + } + + if (message instanceof UserMessage) { + List contents = toAnthropicMessageContents((UserMessage) message); + anthropicMessages.add(new AnthropicMessage(USER, contents)); + } else if (message instanceof AiMessage) { + List contents = toAnthropicMessageContents((AiMessage) message); + anthropicMessages.add(new AnthropicMessage(ASSISTANT, contents)); + } + } + } + + if (!toolContents.isEmpty()) { + anthropicMessages.add(new AnthropicMessage(USER, toolContents)); + } + + return anthropicMessages; } - static AnthropicMessage toAnthropicMessage(ChatMessage message) { - return AnthropicMessage.builder() - .role(toAnthropicRole(message.type())) - .content(toAnthropicContent(message)) - .build(); + private static AnthropicToolResultContent toAnthropicToolResultContent(ToolExecutionResultMessage message) { + return new AnthropicToolResultContent(message.id(), message.text(), null); // TODO propagate isError + } + + private static List toAnthropicMessageContents(UserMessage message) { + return message.contents().stream() + .map(content -> { + if (content instanceof TextContent) { + return new AnthropicTextContent(((TextContent) content).text()); + } else if (content instanceof ImageContent) { + Image image = ((ImageContent) content).image(); + if (image.url() != null) { + throw illegalArgument("Anthropic does not support images as URLs, " + + "only as Base64-encoded strings"); + } + return new AnthropicImageContent( + ensureNotBlank(image.mimeType(), "mimeType"), + ensureNotBlank(image.base64Data(), "base64Data") + ); + } else { + throw illegalArgument("Unknown content type: " + content); + } + }).collect(toList()); + } + + private static List toAnthropicMessageContents(AiMessage message) { + List contents = new ArrayList<>(); + + if (isNotNullOrBlank(message.text())) { + contents.add(new AnthropicTextContent(message.text())); + } + + if (message.hasToolExecutionRequests()) { + List toolUseContents = message.toolExecutionRequests().stream() + .map(toolExecutionRequest -> AnthropicToolUseContent.builder() + .id(toolExecutionRequest.id()) + .name(toolExecutionRequest.name()) + .input(GSON.fromJson(toolExecutionRequest.arguments(), MAP_TYPE)) + .build()) + .collect(toList()); + contents.addAll(toolUseContents); + } + + return contents; } static String toAnthropicSystemPrompt(List messages) { @@ -44,52 +114,29 @@ static String toAnthropicSystemPrompt(List messages) { } } - private static AnthropicRole toAnthropicRole(ChatMessageType chatMessageType) { - switch (chatMessageType) { - case AI: - return AnthropicRole.ASSISTANT; - case USER: - return AnthropicRole.USER; - default: - throw new IllegalArgumentException("Unknown chat message type: " + chatMessageType); - } - } - - private static Object toAnthropicContent(ChatMessage message) { - if (message instanceof AiMessage) { - AiMessage aiMessage = (AiMessage) message; - return aiMessage.text(); - } else if (message instanceof UserMessage) { - UserMessage userMessage = (UserMessage) message; - return userMessage.contents().stream() - .map(content -> { - if (content instanceof TextContent) { - return new AnthropicTextContent(((TextContent) content).text()); - } else if (content instanceof ImageContent) { - Image image = ((ImageContent) content).image(); - if (image.url() != null) { - throw illegalArgument("Anthropic does not support images as URLs, " + - "only as Base64-encoded strings"); - } - return new AnthropicImageContent( - ensureNotBlank(image.mimeType(), "mimeType"), - ensureNotBlank(image.base64Data(), "base64Data") - ); - } else { - throw illegalArgument("Unknown content type: " + content); - } - }).collect(toList()); - } else { - throw new IllegalArgumentException("Unknown message type: " + message.type()); - } - } - public static AiMessage toAiMessage(List contents) { + String text = contents.stream() .filter(content -> "text".equals(content.type)) .map(content -> content.text) .collect(joining("\n")); - return AiMessage.from(text); + + List toolExecutionRequests = contents.stream() + .filter(content -> "tool_use".equals(content.type)) + .map(content -> ToolExecutionRequest.builder() + .id(content.id) + .name(content.name) + .arguments(GSON.toJson(content.input)) + .build()) + .collect(toList()); + + if (isNotNullOrBlank(text) && !isNullOrEmpty(toolExecutionRequests)) { + return new AiMessage(text, toolExecutionRequests); + } else if (!isNullOrEmpty(toolExecutionRequests)) { + return AiMessage.from(toolExecutionRequests); + } else { + return AiMessage.from(text); + } } public static TokenUsage toTokenUsage(AnthropicUsage anthropicUsage) { @@ -110,8 +157,30 @@ public static FinishReason toFinishReason(String anthropicStopReason) { return LENGTH; case "stop_sequence": return OTHER; // TODO + case "tool_use": + return TOOL_EXECUTION; default: return null; // TODO } } + + static List toAnthropicTools(List toolSpecifications) { + if (toolSpecifications == null) { + return null; + } + return toolSpecifications.stream() + .map(AnthropicMapper::toAnthropicTool) + .collect(toList()); + } + + static AnthropicTool toAnthropicTool(ToolSpecification toolSpecification) { + return AnthropicTool.builder() + .name(toolSpecification.name()) + .description(toolSpecification.description()) + .inputSchema(AnthropicToolSchema.builder() + .properties(toolSpecification.parameters().properties()) + .required(toolSpecification.parameters().required()) + .build()) + .build(); + } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java index 417c8dc19f..f130ca2ced 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessage.java @@ -1,10 +1,18 @@ package dev.langchain4j.model.anthropic; +import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import java.util.List; @Builder +@ToString +@EqualsAndHashCode +@AllArgsConstructor public class AnthropicMessage { AnthropicRole role; - Object content; + List content; } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java new file mode 100644 index 0000000000..1cae26cc51 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicMessageContent.java @@ -0,0 +1,13 @@ +package dev.langchain4j.model.anthropic; + +import lombok.EqualsAndHashCode; + +@EqualsAndHashCode +public abstract class AnthropicMessageContent { + + public String type; + + public AnthropicMessageContent(String type) { + this.type = type; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java index af23f43631..5bd677d55d 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTextContent.java @@ -1,11 +1,16 @@ package dev.langchain4j.model.anthropic; -public class AnthropicTextContent { +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@ToString +@EqualsAndHashCode(callSuper = true) +public class AnthropicTextContent extends AnthropicMessageContent { - public String type = "text"; public String text; public AnthropicTextContent(String text) { + super("text"); this.text = text; } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java new file mode 100644 index 0000000000..0173108974 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTool.java @@ -0,0 +1,11 @@ +package dev.langchain4j.model.anthropic; + +import lombok.Builder; + +@Builder +public class AnthropicTool { + + public String name; + public String description; + public AnthropicToolSchema inputSchema; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java new file mode 100644 index 0000000000..055aa19dde --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolResultContent.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.anthropic; + +import lombok.EqualsAndHashCode; +import lombok.ToString; + +@ToString +@EqualsAndHashCode(callSuper = true) +public class AnthropicToolResultContent extends AnthropicMessageContent { + + public String toolUseId; + public String content; + public Boolean isError; + + public AnthropicToolResultContent(String toolUseId, String content, Boolean isError) { + super("tool_result"); + this.toolUseId = toolUseId; + this.content = content; + this.isError = isError; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java new file mode 100644 index 0000000000..8b9cc330eb --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolSchema.java @@ -0,0 +1,15 @@ +package dev.langchain4j.model.anthropic; + +import lombok.Builder; + +import java.util.List; +import java.util.Map; + +@Builder +public class AnthropicToolSchema { + + @Builder.Default + public String type = "object"; + public Map> properties; + public List required; +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java new file mode 100644 index 0000000000..a6fa290379 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicToolUseContent.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.anthropic; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.ToString; + +import java.util.Map; + +@ToString +@EqualsAndHashCode(callSuper = true) +public class AnthropicToolUseContent extends AnthropicMessageContent { + + public String id; + public String name; + public Map input; + + @Builder + public AnthropicToolUseContent(String id, String name, Map input) { + super("tool_use"); + this.id = id; + this.name = name; + this.input = input; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java index 574e985a7b..d06ebd1e2f 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/DefaultAnthropicClient.java @@ -1,8 +1,8 @@ package dev.langchain4j.model.anthropic; -import com.google.gson.FieldNamingPolicy; import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.output.Response; @@ -19,28 +19,38 @@ import retrofit2.converter.gson.GsonConverterFactory; import java.io.IOException; +import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; -import static dev.langchain4j.internal.Utils.isNotNullOrEmpty; -import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; +import static com.google.gson.ToNumberPolicy.LONG_OR_DOUBLE; +import static dev.langchain4j.internal.Utils.*; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.anthropic.AnthropicMapper.toFinishReason; import static java.util.Collections.synchronizedList; public class DefaultAnthropicClient extends AnthropicClient { + private static final Logger LOGGER = LoggerFactory.getLogger(DefaultAnthropicClient.class); - private static final Gson GSON = new GsonBuilder() - .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + + static final Gson GSON = new GsonBuilder() + .setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) + .setObjectToNumberStrategy(LONG_OR_DOUBLE) .setPrettyPrinting() .create(); + static final Type MAP_TYPE = new TypeToken>() { + }.getType(); + private final AnthropicApi anthropicApi; private final OkHttpClient okHttpClient; private final String apiKey; private final String version; + private final String beta; private final boolean logResponses; public static Builder builder() { @@ -48,6 +58,7 @@ public static Builder builder() { } public static class Builder extends AnthropicClient.Builder { + public DefaultAnthropicClient build() { return new DefaultAnthropicClient(this); } @@ -61,6 +72,7 @@ public DefaultAnthropicClient build() { this.apiKey = builder.apiKey; this.version = ensureNotBlank(builder.version, "version"); + this.beta = builder.beta; this.logResponses = builder.logResponses; OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() @@ -91,7 +103,7 @@ public DefaultAnthropicClient build() { public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request) { try { retrofit2.Response retrofitResponse - = anthropicApi.createMessage(apiKey, version, request).execute(); + = anthropicApi.createMessage(apiKey, version, toBeta(request), request).execute(); if (retrofitResponse.isSuccessful()) { return retrofitResponse.body(); } else { @@ -107,6 +119,17 @@ public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageReques } } + private String toBeta(AnthropicCreateMessageRequest request) { + return hasTools(request) ? beta : null; + } + + private static boolean hasTools(AnthropicCreateMessageRequest request) { + return !isNullOrEmpty(request.tools) || request.messages.stream() + .flatMap(message -> message.content.stream()) + .anyMatch(content -> + (content instanceof AnthropicToolUseContent) || (content instanceof AnthropicToolResultContent)); + } + @Override public void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler handler) { diff --git a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json index 52339f1422..f250ab6e83 100644 --- a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json +++ b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json @@ -62,6 +62,15 @@ "allDeclaredFields": true, "allPublicFields": true }, + { + "name": "dev.langchain4j.model.anthropic.AnthropicMessageContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, { "name": "dev.langchain4j.model.anthropic.AnthropicResponseMessage", "allDeclaredConstructors": true, @@ -98,6 +107,42 @@ "allDeclaredFields": true, "allPublicFields": true }, + { + "name": "dev.langchain4j.model.anthropic.AnthropicTool", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.AnthropicToolResultContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.AnthropicToolSchema", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.AnthropicToolUseContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, { "name": "dev.langchain4j.model.anthropic.AnthropicUsage", "allDeclaredConstructors": true, diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java index e5965a2b5b..90fb11ffa8 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java @@ -1,5 +1,7 @@ package dev.langchain4j.model.anthropic; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.*; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.output.Response; @@ -7,26 +9,32 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; import java.time.Duration; import java.util.Base64; import java.util.List; +import java.util.stream.Stream; +import static dev.langchain4j.agent.tool.JsonSchemaProperty.*; +import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.output.FinishReason.*; import static java.lang.System.getenv; import static java.util.Arrays.asList; +import static java.util.Arrays.stream; import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class AnthropicChatModelIT { static final String CAT_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/e/e9/Felis_silvestris_silvestris_small_gradual_decrease_of_quality.png"; - static final String DICE_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"; ChatLanguageModel model = AnthropicChatModel.builder() .apiKey(System.getenv("ANTHROPIC_API_KEY")) @@ -42,6 +50,20 @@ class AnthropicChatModelIT { .logResponses(true) .build(); + ToolSpecification calculator = ToolSpecification.builder() + .name("calculator") + .description("returns a sum of two numbers") + .addParameter("first", INTEGER) + .addParameter("second", INTEGER) + .build(); + + ToolSpecification weather = ToolSpecification.builder() + .name("weather") + .description("returns a weather forecast for a given location") + // TODO simplify defining nested properties + .addParameter("location", OBJECT, property("properties", singletonMap("city", singletonMap("type", "string")))) + .build(); + @AfterEach void afterEach() throws InterruptedException { Thread.sleep(10_000L); // to avoid hitting rate limits @@ -291,4 +313,190 @@ void should_fail_with_rate_limit_error() { .hasRootCauseExactlyInstanceOf(AnthropicHttpException.class) .hasMessageContaining("rate_limit_error"); } + + @ParameterizedTest + @MethodSource("models_supporting_tools") + void should_execute_a_tool_then_answer(AnthropicChatModelName modelName) { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(modelName) + .temperature(0.0) + .logRequests(true) + .logResponses(true) + .build(); + + List toolSpecifications = singletonList(calculator); + + UserMessage userMessage = userMessage("2+2=?"); + + // when + Response response = model.generate(singletonList(userMessage), toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.id()).isNotBlank(); + assertThat(toolExecutionRequest.name()).isEqualTo("calculator"); + assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + // given + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "4"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + // when + Response secondResponse = model.generate(messages, toolSpecifications); + + // then + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("4"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest + @MethodSource("models_supporting_tools") + void should_execute_multiple_tools_in_parallel_then_answer(AnthropicChatModelName modelName) { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(modelName) + .temperature(0.0) + .logRequests(true) + .logResponses(true) + .build(); + + List toolSpecifications = singletonList(calculator); + + UserMessage userMessage = userMessage("How much is 2+2 and 3+3? Call tools in parallel!"); + + // when + Response response = model.generate(singletonList(userMessage), toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(2); + + ToolExecutionRequest toolExecutionRequest1 = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest1.name()).isEqualTo("calculator"); + assertThat(toolExecutionRequest1.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}"); + + ToolExecutionRequest toolExecutionRequest2 = aiMessage.toolExecutionRequests().get(1); + assertThat(toolExecutionRequest2.name()).isEqualTo("calculator"); + assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"first\": 3, \"second\": 3}"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + // given + ToolExecutionResultMessage toolExecutionResultMessage1 = from(toolExecutionRequest1, "4"); + ToolExecutionResultMessage toolExecutionResultMessage2 = from(toolExecutionRequest2, "6"); + + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage1, toolExecutionResultMessage2); + + // when + Response secondResponse = model.generate(messages, toolSpecifications); + + // then + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("4", "6"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + + @ParameterizedTest + @MethodSource("models_supporting_tools") + void should_execute_a_tool_with_nested_properties_then_answer(AnthropicChatModelName modelName) { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .modelName(modelName) + .temperature(0.0) + .logRequests(true) + .logResponses(true) + .build(); + + List toolSpecifications = singletonList(weather); + + UserMessage userMessage = userMessage("What is the weather in Munich?"); + + // when + Response response = model.generate(singletonList(userMessage), toolSpecifications); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.toolExecutionRequests()).hasSize(1); + + ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0); + assertThat(toolExecutionRequest.id()).isNotBlank(); + assertThat(toolExecutionRequest.name()).isEqualTo("weather"); + assertThat(toolExecutionRequest.arguments()) + .isEqualToIgnoringWhitespace("{\"location\": {\"city\": \"Munich\"}}"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION); + + // given + ToolExecutionResultMessage toolExecutionResultMessage = from(toolExecutionRequest, "Super hot, 42 Celsius"); + List messages = asList(userMessage, aiMessage, toolExecutionResultMessage); + + // when + Response secondResponse = model.generate(messages, toolSpecifications); + + // then + AiMessage secondAiMessage = secondResponse.content(); + assertThat(secondAiMessage.text()).contains("42"); + assertThat(secondAiMessage.toolExecutionRequests()).isNull(); + + TokenUsage secondTokenUsage = secondResponse.tokenUsage(); + assertThat(secondTokenUsage.inputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(secondTokenUsage.totalTokenCount()) + .isEqualTo(secondTokenUsage.inputTokenCount() + secondTokenUsage.outputTokenCount()); + + assertThat(secondResponse.finishReason()).isEqualTo(STOP); + } + + static Stream models_supporting_tools() { + return stream(AnthropicChatModelName.values()) + .filter(modelName -> modelName.toString().startsWith("claude-3")) + .map(Arguments::of); + } } \ No newline at end of file diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java new file mode 100644 index 0000000000..decd3890a1 --- /dev/null +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java @@ -0,0 +1,214 @@ +package dev.langchain4j.model.anthropic; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.AbstractMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static dev.langchain4j.model.anthropic.AnthropicMapper.toAnthropicMessages; +import static dev.langchain4j.model.anthropic.AnthropicRole.ASSISTANT; +import static dev.langchain4j.model.anthropic.AnthropicRole.USER; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toMap; +import static org.assertj.core.api.Assertions.assertThat; + +class AnthropicMapperTest { + + @ParameterizedTest + @MethodSource + void test_toAnthropicMessages(List messages, List expectedAnthropicMessages) { + + // when + List anthropicMessages = toAnthropicMessages(messages); + + //then + assertThat(anthropicMessages).containsExactlyElementsOf(expectedAnthropicMessages); + } + + static Stream test_toAnthropicMessages() { + return Stream.of( + Arguments.of( + singletonList(UserMessage.from("Hello")), + singletonList(new AnthropicMessage(USER, singletonList(new AnthropicTextContent("Hello")))) + ), + Arguments.of( + asList( + SystemMessage.from("Ignored"), + UserMessage.from("Hello") + ), + singletonList(new AnthropicMessage(USER, singletonList(new AnthropicTextContent("Hello")))) + ), + Arguments.of( + asList( + UserMessage.from("Hello"), + SystemMessage.from("Ignored") + ), + singletonList(new AnthropicMessage(USER, singletonList(new AnthropicTextContent("Hello")))) + ), + Arguments.of( + asList( + UserMessage.from("Hello"), + AiMessage.from("Hi"), + UserMessage.from("How are you?") + ), + asList( + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("Hello"))), + new AnthropicMessage(ASSISTANT, singletonList(new AnthropicTextContent("Hi"))), + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("How are you?"))) + ) + ), + Arguments.of( + asList( + UserMessage.from("How much is 2+2?"), + AiMessage.from( + ToolExecutionRequest.builder() + .id("12345") + .name("calculator") + .arguments("{\"first\": 2, \"second\": 2}") + .build() + ), + ToolExecutionResultMessage.from("12345", "calculator", "4") + ), + asList( + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("How much is 2+2?"))), + new AnthropicMessage(ASSISTANT, singletonList( + AnthropicToolUseContent.builder() + .id("12345") + .name("calculator") + .input(mapOf(entry("first", 2L), entry("second", 2L))) + .build() + )), + new AnthropicMessage(USER, singletonList( + new AnthropicToolResultContent("12345", "4", null) + )) + ) + ), + Arguments.of( + asList( + UserMessage.from("How much is 2+2?"), + new AiMessage( + "I need to use the calculator tool", + singletonList(ToolExecutionRequest.builder() + .id("12345") + .name("calculator") + .arguments("{\"first\": 2, \"second\": 2}") + .build()) + ), + ToolExecutionResultMessage.from("12345", "calculator", "4") + ), + asList( + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("How much is 2+2?"))), + new AnthropicMessage(ASSISTANT, asList( + new AnthropicTextContent("I need to use the calculator tool"), + AnthropicToolUseContent.builder() + .id("12345") + .name("calculator") + .input(mapOf(entry("first", 2L), entry("second", 2L))) + .build() + )), + new AnthropicMessage(USER, singletonList( + new AnthropicToolResultContent("12345", "4", null) + )) + ) + ), + Arguments.of( + asList( + UserMessage.from("How much is 2+2 and 3+3?"), + AiMessage.from( + ToolExecutionRequest.builder() + .id("12345") + .name("calculator") + .arguments("{\"first\": 2, \"second\": 2}") + .build(), + ToolExecutionRequest.builder() + .id("67890") + .name("calculator") + .arguments("{\"first\": 3, \"second\": 3}") + .build() + ), + ToolExecutionResultMessage.from("12345", "calculator", "4"), + ToolExecutionResultMessage.from("67890", "calculator", "6") + ), + asList( + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("How much is 2+2 and 3+3?"))), + new AnthropicMessage(ASSISTANT, asList( + AnthropicToolUseContent.builder() + .id("12345") + .name("calculator") + .input(mapOf(entry("first", 2L), entry("second", 2L))) + .build(), + AnthropicToolUseContent.builder() + .id("67890") + .name("calculator") + .input(mapOf(entry("first", 3L), entry("second", 3L))) + .build() + )), + new AnthropicMessage(USER, asList( + new AnthropicToolResultContent("12345", "4", null), + new AnthropicToolResultContent("67890", "6", null) + )) + ) + ), + Arguments.of( + asList( + UserMessage.from("How much is 2+2 and 3+3?"), + AiMessage.from( + ToolExecutionRequest.builder() + .id("12345") + .name("calculator") + .arguments("{\"first\": 2, \"second\": 2}") + .build() + ), + ToolExecutionResultMessage.from("12345", "calculator", "4"), + AiMessage.from( + ToolExecutionRequest.builder() + .id("67890") + .name("calculator") + .arguments("{\"first\": 3, \"second\": 3}") + .build() + ), + ToolExecutionResultMessage.from("67890", "calculator", "6") + ), + asList( + new AnthropicMessage(USER, singletonList(new AnthropicTextContent("How much is 2+2 and 3+3?"))), + new AnthropicMessage(ASSISTANT, singletonList( + AnthropicToolUseContent.builder() + .id("12345") + .name("calculator") + .input(mapOf(entry("first", 2L), entry("second", 2L))) + .build() + )), + new AnthropicMessage(USER, singletonList( + new AnthropicToolResultContent("12345", "4", null) + )), + new AnthropicMessage(ASSISTANT, singletonList( + AnthropicToolUseContent.builder() + .id("67890") + .name("calculator") + .input(mapOf(entry("first", 3L), entry("second", 3L))) + .build() + )), + new AnthropicMessage(USER, singletonList( + new AnthropicToolResultContent("67890", "6", null) + )) + ) + ) + ); + } + + @SafeVarargs + private static Map mapOf(Map.Entry... entries) { + return Stream.of(entries).collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + private static Map.Entry entry(K key, V value) { + return new AbstractMap.SimpleEntry<>(key, value); + } +} \ No newline at end of file diff --git a/langchain4j-core/src/main/java/dev/langchain4j/agent/tool/ToolSpecifications.java b/langchain4j-core/src/main/java/dev/langchain4j/agent/tool/ToolSpecifications.java index ecf3a92c53..fadce38888 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/agent/tool/ToolSpecifications.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/agent/tool/ToolSpecifications.java @@ -24,20 +24,31 @@ private ToolSpecifications() { } /** - * Get the {@link ToolSpecification}s for each {@link Tool} method of the given object. + * Returns {@link ToolSpecification}s for all methods annotated with @{@link Tool} within the specified class. * - * @param objectWithTools the object. + * @param classWithTools the class. * @return the {@link ToolSpecification}s. */ - public static List toolSpecificationsFrom(Object objectWithTools) { - return stream(objectWithTools.getClass().getDeclaredMethods()) + public static List toolSpecificationsFrom(Class classWithTools) { + return stream(classWithTools.getDeclaredMethods()) .filter(method -> method.isAnnotationPresent(Tool.class)) .map(ToolSpecifications::toolSpecificationFrom) .collect(toList()); } /** - * Get the {@link ToolSpecification} for the given {@link Tool} method. + * Returns {@link ToolSpecification}s for all methods annotated with @{@link Tool} + * within the class of the specified object. + * + * @param objectWithTools the object. + * @return the {@link ToolSpecification}s. + */ + public static List toolSpecificationsFrom(Object objectWithTools) { + return toolSpecificationsFrom(objectWithTools.getClass()); + } + + /** + * Returns the {@link ToolSpecification} for the given method annotated with @{@link Tool}. * * @param method the method. * @return the {@link ToolSpecification}. diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java index 4c5e1c742d..e029b5c261 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/AiMessage.java @@ -8,8 +8,7 @@ import static dev.langchain4j.data.message.ChatMessageType.AI; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.Utils.quoted; -import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static dev.langchain4j.internal.ValidationUtils.*; import static java.util.Arrays.asList; /** @@ -42,6 +41,17 @@ public AiMessage(List toolExecutionRequests) { this.toolExecutionRequests = ensureNotEmpty(toolExecutionRequests, "toolExecutionRequests"); } + /** + * Create a new {@link AiMessage} with the given text and tool execution requests. + * + * @param text the text of the message. + * @param toolExecutionRequests the tool execution requests of the message. + */ + public AiMessage(String text, List toolExecutionRequests) { + this.text = ensureNotBlank(text, "text"); + this.toolExecutionRequests = ensureNotEmpty(toolExecutionRequests, "toolExecutionRequests"); + } + /** * Get the text of the message. *