OpenAI: Support parallel tool calling #338

Merged · 5 commits · Dec 8, 2023
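For callers, the visible effect of this PR is that an AiMessage can now carry a list of ToolExecutionRequests (each with an id) instead of a single toolExecutionRequest(). Below is a minimal sketch of handling several requests from one response; the executeTool dispatcher is a hypothetical helper, not part of this PR, and the result-message factory mirrors the usage in the updated integration test further down.

```java
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;

import java.util.ArrayList;
import java.util.List;

class ParallelToolCallSketch {

    // Turns one assistant message with N tool calls into the messages to send back:
    // the AiMessage itself plus one ToolExecutionResultMessage per request.
    static List<ChatMessage> handleToolCalls(AiMessage aiMessage) {
        List<ChatMessage> followUp = new ArrayList<>();
        followUp.add(aiMessage);
        for (ToolExecutionRequest request : aiMessage.toolExecutionRequests()) {
            String result = executeTool(request); // hypothetical dispatcher
            // Passing the whole request (with its id) ties the result back to the call.
            followUp.add(ToolExecutionResultMessage.toolExecutionResultMessage(request, result));
        }
        return followUp;
    }

    private static String executeTool(ToolExecutionRequest request) {
        // Placeholder: look up the tool by request.name() and run it with request.arguments().
        return "...";
    }
}
```

The returned messages would then be sent back to the model together with the earlier conversation, as the updated AzureOpenAiChatModelIT test does below.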
1 change: 1 addition & 0 deletions langchain4j-azure-open-ai/pom.xml
@@ -67,6 +67,7 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
@@ -14,13 +14,15 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.toFunctions;
import static java.util.Collections.singletonList;

/**
@@ -135,7 +137,7 @@ private void generate(List<ChatMessage> messages,
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);

if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
options.setFunctions(InternalAzureOpenAiHelper.toFunctions(toolSpecifications));
options.setFunctions(toFunctions(toolSpecifications));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
@@ -156,7 +158,8 @@ private void generate(List<ChatMessage> messages,
responseBuilder.append(chatCompletions);
handle(chatCompletions, handler);
});
handler.onComplete(responseBuilder.build());
Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
handler.onComplete(response);
} catch (Exception exception) {
handler.onError(exception);
}
@@ -126,7 +126,7 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
handle(completions, handler);
});

Response<AiMessage> response = responseBuilder.build();
Response<AiMessage> response = responseBuilder.build(tokenizer, false);
handler.onComplete(Response.from(
response.content().text(),
response.tokenUsage(),
@@ -3,13 +3,14 @@
import com.azure.ai.openai.models.*;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
import static java.util.Collections.singletonList;

/**
* This class needs to be thread safe because it is called when a streaming result comes back
@@ -21,9 +22,9 @@ class AzureOpenAiStreamingResponseBuilder {
private final StringBuffer contentBuilder = new StringBuffer();
private final StringBuffer toolNameBuilder = new StringBuffer();
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
private volatile CompletionsFinishReason finishReason;

private final Integer inputTokenCount;
private final AtomicInteger outputTokenCount = new AtomicInteger();
private volatile String finishReason;

public AzureOpenAiStreamingResponseBuilder(Integer inputTokenCount) {
this.inputTokenCount = inputTokenCount;
@@ -46,7 +47,7 @@ public void append(ChatCompletions completions) {

CompletionsFinishReason finishReason = chatCompletionChoice.getFinishReason();
if (finishReason != null) {
this.finishReason = finishReason.toString();
this.finishReason = finishReason;
}

com.azure.ai.openai.models.ChatMessage delta = chatCompletionChoice.getDelta();
@@ -57,20 +58,17 @@ public void append(ChatCompletions completions) {
String content = delta.getContent();
if (content != null) {
contentBuilder.append(content);
outputTokenCount.incrementAndGet();
return;
}

FunctionCall functionCall = delta.getFunctionCall();
if (functionCall != null) {
if (functionCall.getName() != null) {
toolNameBuilder.append(functionCall.getName());
outputTokenCount.incrementAndGet();
}

if (functionCall.getArguments() != null) {
toolArgumentsBuilder.append(functionCall.getArguments());
outputTokenCount.incrementAndGet();
}
}
}
@@ -92,39 +90,63 @@ public void append(Completions completions) {

CompletionsFinishReason completionsFinishReason = completionChoice.getFinishReason();
if (completionsFinishReason != null) {
this.finishReason = completionsFinishReason.toString();
this.finishReason = completionsFinishReason;
}

String token = completionChoice.getText();
if (token != null) {
contentBuilder.append(token);
outputTokenCount.incrementAndGet();
}
}

public Response<AiMessage> build() {
public Response<AiMessage> build(Tokenizer tokenizer, boolean forcefulToolExecution) {

String content = contentBuilder.toString();
if (!content.isEmpty()) {
return Response.from(
AiMessage.from(content),
new TokenUsage(inputTokenCount, outputTokenCount.get()),
tokenUsage(content, tokenizer),
finishReasonFrom(finishReason)
);
}

String toolName = toolNameBuilder.toString();
if (!toolName.isEmpty()) {
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
.name(toolName)
.arguments(toolArgumentsBuilder.toString())
.build();
return Response.from(
AiMessage.from(ToolExecutionRequest.builder()
.name(toolName)
.arguments(toolArgumentsBuilder.toString())
.build()),
new TokenUsage(inputTokenCount, outputTokenCount.get()),
AiMessage.from(toolExecutionRequest),
tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution),
finishReasonFrom(finishReason)
);
}

return null;
}

private TokenUsage tokenUsage(String content, Tokenizer tokenizer) {
if (tokenizer == null) {
return null;
}
int outputTokenCount = tokenizer.estimateTokenCountInText(content);
return new TokenUsage(inputTokenCount, outputTokenCount);
}

private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) {
if (tokenizer == null) {
return null;
}

int outputTokenCount = 0;
if (forcefulToolExecution) {
// OpenAI calculates output tokens differently when tool is executed forcefully
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
} else {
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
}

return new TokenUsage(inputTokenCount, outputTokenCount);
}
}
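For context on the builder change above: output tokens are no longer counted per streamed chunk with an AtomicInteger, they are estimated once streaming has finished. A rough sketch of that estimation in isolation, assuming OpenAiTokenizer and a hard-coded model name purely for illustration:

```java
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.TokenUsage;

class StreamedTokenUsageSketch {

    // Tokenize the accumulated text after streaming completes instead of
    // incrementing a counter for every received chunk.
    static TokenUsage estimate(Integer inputTokenCount, String streamedContent) {
        Tokenizer tokenizer = new OpenAiTokenizer("gpt-3.5-turbo"); // model name assumed for illustration
        int outputTokenCount = tokenizer.estimateTokenCountInText(streamedContent);
        return new TokenUsage(inputTokenCount, outputTokenCount);
    }
}
```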
@@ -105,9 +105,10 @@ private static String nameFrom(ChatMessage message) {
private static FunctionCall functionCallFrom(ChatMessage message) {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;
if (aiMessage.toolExecutionRequest() != null) {
return new FunctionCall(aiMessage.toolExecutionRequest().name(),
aiMessage.toolExecutionRequest().arguments());
if (aiMessage.hasToolExecutionRequests()) {
// TODO switch to tools once supported
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments());
}
}

@@ -22,9 +22,9 @@
import static dev.langchain4j.model.output.FinishReason.STOP;
import static org.assertj.core.api.Assertions.assertThat;

public class AzureOpenAIChatModelIT {
public class AzureOpenAiChatModelIT {

Logger logger = LoggerFactory.getLogger(AzureOpenAIChatModelIT.class);
Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);

@Test
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {
@@ -102,14 +102,17 @@ void should_call_function_with_argument() {

Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), toolSpecification);

assertThat(response.content().text()).isBlank();
assertThat(response.content().toolExecutionRequest().name()).isEqualTo(toolName);
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isBlank();

assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);

// We should get a response telling how to call the "getCurrentWeather" function, with the correct parameters in JSON format.
logger.info(response.toString());

// We can now call the function with the correct parameters.
ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequest();
WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
int currentWeather = 0;
currentWeather = getCurrentWeather(weatherLocation);
@@ -121,12 +124,13 @@ void should_call_function_with_argument() {
assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius.");

// Now that we know the function's result, we can call the model again with the result as input.
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolName, weather);
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, weather);
SystemMessage systemMessage = SystemMessage.systemMessage("If the weather is above 30 degrees celsius, recommend the user wears a t-shirt and shorts.");

List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(systemMessage);
chatMessages.add(userMessage);
chatMessages.add(aiMessage);
chatMessages.add(toolExecutionResultMessage);

Response<AiMessage> response2 = model.generate(chatMessages);
@@ -119,7 +119,8 @@ public void onError(Throwable error) {
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();

ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequest();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");

@@ -3,6 +3,7 @@
import java.util.Objects;

import static dev.langchain4j.internal.Utils.quoted;
import static java.util.Collections.singletonMap;

public class JsonSchemaProperty {

@@ -95,4 +96,8 @@ public static JsonSchemaProperty enums(Class<?> enumClass) {

return from("enum", enumClass.getEnumConstants());
}

public static JsonSchemaProperty items(JsonSchemaProperty type) {
return from("items", singletonMap(type.key, type.value));
}
}
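The new items() helper nests a type under a JSON-schema "items" key, which is what an array-typed tool parameter needs. A hedged usage sketch follows; the tool name, parameter name, and the exact ToolSpecification builder chain are illustrative assumptions, not taken from this PR.

```java
import dev.langchain4j.agent.tool.ToolSpecification;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;

class ArrayParameterSketch {

    // "cities" ends up described as {"type": "array", "items": {"type": "string"}}
    // in the tool's parameter schema.
    static final ToolSpecification WEATHER_TOOL = ToolSpecification.builder()
            .name("getWeatherForecasts")                                   // hypothetical tool
            .description("Returns the forecast for each of the given cities")
            .addParameter("cities", ARRAY, items(STRING), description("City names"))
            .build();
}
```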
@@ -1,18 +1,25 @@
package dev.langchain4j.agent.tool;

import static dev.langchain4j.internal.Utils.quoted;
import java.util.Objects;

import static dev.langchain4j.internal.Utils.quoted;

public class ToolExecutionRequest {

private final String id;
private final String name;
private final String arguments;

private ToolExecutionRequest(Builder builder) {
this.id = builder.id;
this.name = builder.name;
this.arguments = builder.arguments;
}

public String id() {
return id;
}

public String name() {
return name;
}
@@ -29,13 +36,15 @@ public boolean equals(Object another) {
}

private boolean equalTo(ToolExecutionRequest another) {
return Objects.equals(name, another.name)
return Objects.equals(id, another.id)
&& Objects.equals(name, another.name)
&& Objects.equals(arguments, another.arguments);
}

@Override
public int hashCode() {
int h = 5381;
h += (h << 5) + Objects.hashCode(id);
h += (h << 5) + Objects.hashCode(name);
h += (h << 5) + Objects.hashCode(arguments);
return h;
@@ -44,7 +53,8 @@ public int hashCode() {
@Override
public String toString() {
return "ToolExecutionRequest {"
+ " name = " + quoted(name)
+ " id = " + quoted(id)
+ ", name = " + quoted(name)
+ ", arguments = " + quoted(arguments)
+ " }";
}
@@ -55,12 +65,18 @@ public static Builder builder() {

public static final class Builder {

private String id;
private String name;
private String arguments;

private Builder() {
}

public Builder id(String id) {
this.id = id;
return this;
}

public Builder name(String name) {
this.name = name;
return this;
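ToolExecutionRequest gaining an id is what lets a ToolExecutionResultMessage be matched to a specific call when the model issues several calls in one response. A small sketch with placeholder values; the id and the argument JSON are hand-written here only for illustration, in practice they come from the model's response.

```java
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.ToolExecutionResultMessage;

class ToolCallIdSketch {

    static ToolExecutionResultMessage answer() {
        ToolExecutionRequest request = ToolExecutionRequest.builder()
                .id("call_abc123")                                // placeholder id
                .name("getCurrentWeather")
                .arguments("{\"location\": \"Paris, France\"}")   // illustrative argument shape
                .build();

        // Building the result from the whole request keeps the id attached to it.
        return ToolExecutionResultMessage.toolExecutionResultMessage(request, "35");
    }
}
```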