Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anthropic: Support tools #897

Merged
merged 7 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
85 changes: 69 additions & 16 deletions docs/docs/tutorials/6-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,29 +121,82 @@ Please note that tools/function calling is not the same as [JSON mode](/tutorial

## 2 levels of abstraction

LangChain4j provides two levels of abstraction for working with tools.
LangChain4j provides two levels of abstraction for using tools.

### Low level Tool API
At the low level, you can use the `generate(List<ChatMessage>, List<ToolSpecification>)`
and `generate(List<ChatMessage>, ToolSpecification)` methods
of `ChatLanguageModel` (and similar methods of `StreamingChatLanguageModel`).

You'll need to manually create `ToolSpecification` object(s) containing all information about the tool,
or use the `ToolSpecifications.toolSpecificationFrom(Method)` helper method
to convert any Java method into a `ToolSpecification`.
At the low level, you can use the `generate(List<ChatMessage>, List<ToolSpecification>)` method
of the `ChatLanguageModel`. A similar method is also present in the `StreamingChatLanguageModel`.

When the LLM decides to call the tool, the returned `AiMessage` will have data
in a `List<ToolExecutionRequest> toolExecutionRequests` field instead of a `String text` field.
Depending on the LLM, it can contain one or multiple `ToolExecutionRequest`s
`ToolSpecification` is an object that contains all the information about the tool:
- The `name` of the tool
- The `description` of the tool
- The `parameters` of the tool

It is recommended to provide as much information about the tool as possible:
a clear name, a complete description, and a description for each parameter, etc.

There are two ways to create a `ToolSpecification`:

1. Manually
```java
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("getWeather")
.description("Returns the weather forecast for a given city")
.addParameter("city", type("string"), description("The city for which the weather forecast should be returned"))
.addParameter("temperatureUnit", enums(TemperatureUnit.class)) // enum TemperatureUnit { CELSIUS, FAHRENHEIT }
.build();
```

2. Using helper methods:
- `ToolSpecifications.toolSpecificationsFrom(Class)`
- `ToolSpecifications.toolSpecificationsFrom(Object)`
- `ToolSpecifications.toolSpecificationFrom(Method)`

```java
class WeatherTools {

@Tool("Returns the weather forecast for a given city")
String getWeather(
@P("The city for which the weather forecast should be returned") String city,
TemperatureUnit temperatureUnit
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
) {
...
}
}

List<ToolSpecification> toolSpecifications = ToolSpecifications.toolSpecificationsFrom(WeatherTools.class);
```

Once you have a `List<ToolSpecification>`, you can call the model:
```java
UserMessage userMessage = UserMessage.from("What will the weather be like in London tomorrow?");
Response<AiMessage> response = model.generate(singletoneList(userMessage), toolSpecifications);
AiMessage aiMessage = response.content();
```

If the LLM decides to call the tool, the returned `AiMessage` will contain data
in the `toolExecutionRequests` field.
In this case, `AiMessage.hasToolExecutionRequests()` will return `true`.
Depending on the LLM, it can contain one or multiple `ToolExecutionRequest` objects
(some LLMs support calling multiple tools in parallel).

The `ToolExecutionRequest` will include the tool call's `id`, the `name` of the tool to be called,
and `arguments` (a valid JSON containing a value for each tool parameter).
You'll need to manually execute the tool(s) using information from the `ToolExecutionRequest`(s)
and then create a `ToolExecutionResultMessage` containing each tool's execution result.
Each `ToolExecutionRequest` should contain:
- The `id` of the tool call (some LLMs do not provide it)
- The `name` of the tool to be called, for example: `getWeather`
- The `arguments`, for example: `{ "city": "London", "temperatureUnit": "CELSIUS" }`

You'll need to manually execute the tool(s) using information from the `ToolExecutionRequest`(s).

Then, call the LLM with all messages (`UserMessage`, `AiMessage` containing `ToolExecutionRequest`,
`ToolExecutionResultMessage`) to get the final response from the LLM.
If you want to send the result of the tool execution back to the LLM,
you need to create a `ToolExecutionResultMessage` (one for each `ToolExecutionRequest`)
and send it along with all previous messages:
```java
String result = "It is expected to rain in London tomorrow.";
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, result);
List<ChatMessage> messages = List.of(userMessage, aiMessage, toolExecutionResultMessage);
Response<AiMessage> response2 = model.generate(messages, toolSpecifications);
```

### High Level Tool API
At a high level, you can annotate any Java method with the `@Tool` annotation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ interface AnthropicApi {
@Headers({"content-type: application/json"})
Call<AnthropicCreateMessageResponse> createMessage(@Header(X_API_KEY) String apiKey,
@Header("anthropic-version") String version,
@Header("anthropic-beta") String beta,
@Body AnthropicCreateMessageRequest request);

@Streaming
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.model.anthropic;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand All @@ -21,15 +22,15 @@
* More details are available <a href="https://docs.anthropic.com/claude/reference/messages_post">here</a>.
* <br>
* <br>
* It supports tools. See more information <a href="https://docs.anthropic.com/claude/docs/tool-use">here</a>.
* <br>
* <br>
* It supports {@link Image}s as inputs. {@link UserMessage}s can contain one or multiple {@link ImageContent}s.
* {@link Image}s must not be represented as URLs; they should be Base64-encoded strings and include a {@code mimeType}.
* <br>
* <br>
* The content of {@link SystemMessage}s is sent using the "system" parameter.
* If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n).
* <br>
* <br>
* Does not support tools.
*/
public class AnthropicChatModel implements ChatLanguageModel {

Expand All @@ -48,6 +49,7 @@ public class AnthropicChatModel implements ChatLanguageModel {
* @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/"
* @param apiKey The API key for authentication with the Anthropic API.
* @param version The version of the Anthropic API. Default: "2023-06-01"
* @param beta The value of the "anthropic-beta" HTTP header. It is used when tools are present in the request. Default: "tools-2024-04-04"
* @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307"
* @param temperature The temperature
* @param topP The top-P
Expand All @@ -63,6 +65,7 @@ public class AnthropicChatModel implements ChatLanguageModel {
private AnthropicChatModel(String baseUrl,
String apiKey,
String version,
String beta,
String modelName,
Double temperature,
Double topP,
Expand All @@ -77,6 +80,7 @@ private AnthropicChatModel(String baseUrl,
.baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/"))
.apiKey(apiKey)
.version(getOrDefault(version, "2023-06-01"))
.beta(getOrDefault(beta, "tools-2024-04-04"))
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
Expand Down Expand Up @@ -115,6 +119,11 @@ public static AnthropicChatModel withApiKey(String apiKey) {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, (List<ToolSpecification>) null);
}

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");

AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder()
Expand All @@ -127,6 +136,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
.temperature(temperature)
.topP(topP)
.topK(topK)
.tools(toAnthropicTools(toolSpecifications))
.build();

AnthropicCreateMessageResponse response = withRetry(() -> client.createMessage(request), maxRetries);
Expand All @@ -137,4 +147,6 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
toFinishReason(response.stopReason)
);
}

// TODO forcing tool use?
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import java.time.Duration;

public abstract class AnthropicClient {

public abstract AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request);

public abstract void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler<AiMessage> handler);

@SuppressWarnings("rawtypes")
Expand All @@ -20,9 +22,11 @@ public static AnthropicClient.Builder builder() {
}

public abstract static class Builder<T extends AnthropicClient, B extends Builder<T, B>> {

public String baseUrl;
public String apiKey;
public String version;
public String beta;
public Duration timeout;
public Boolean logRequests;
public Boolean logResponses;
Expand All @@ -39,7 +43,8 @@ public B baseUrl(String baseUrl) {

public B apiKey(String apiKey) {
if (apiKey == null || apiKey.trim().isEmpty()) {
throw new IllegalArgumentException("Anthropic API Key must be defined.");
throw new IllegalArgumentException("Anthropic API Key must be defined. " +
"It can be generated here: https://console.anthropic.com/settings/keys");
}
this.apiKey = apiKey;
return (B) this;
Expand All @@ -53,6 +58,14 @@ public B version(String version) {
return (B) this;
}

public B beta(String beta) {
if (beta == null) {
throw new IllegalArgumentException("beta cannot be null or empty");
}
this.beta = beta;
return (B) this;
}

public B timeout(Duration timeout) {
if (timeout == null) {
throw new IllegalArgumentException("timeout cannot be null");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
package dev.langchain4j.model.anthropic;

import lombok.Builder;

import java.util.Map;

@Builder
public class AnthropicContent {

public String type;

// when type = "text"
public String text;

// when type = "tool_use"
public String id;
public String name;
public Map<String, Object> input;
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ public class AnthropicCreateMessageRequest {
Double temperature;
Double topP;
Integer topK;
List<AnthropicTool> tools;
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package dev.langchain4j.model.anthropic;

public class AnthropicImageContent {
import lombok.EqualsAndHashCode;
import lombok.ToString;

@ToString
@EqualsAndHashCode(callSuper = true)
public class AnthropicImageContent extends AnthropicMessageContent {

public String type = "image";
public AnthropicImageContentSource source;

public AnthropicImageContent(String mediaType, String data) {
super("image");
this.source = new AnthropicImageContentSource("base64", mediaType, data);
}
}