
[FEATURE] Mistral AI supports function calling #767

Merged: 50 commits, Mar 25, 2024
Commits:
160d53d Add Mistral AI model provider (czelabueno, Jan 15, 2024)
0cf6b07 MistralAI chat completions req/resp (czelabueno, Jan 15, 2024)
d2a47dc Mistral AI embeddings req/resp (czelabueno, Jan 15, 2024)
f656f61 Mistral AI Taken usage (czelabueno, Jan 15, 2024)
62be0b2 Mistral AI models req/resp (czelabueno, Jan 15, 2024)
9c95a28 Mistral AI Client code (czelabueno, Jan 15, 2024)
4e6f008 Mistral AI Chat model (czelabueno, Jan 15, 2024)
c858389 Mistral AI Chat Streaming model support (czelabueno, Jan 15, 2024)
a3c9fc8 Mistral AI embedding model support (czelabueno, Jan 15, 2024)
2293142 Mistral Ai get models from API (czelabueno, Jan 15, 2024)
06e7cce Mistral AI chat model tests (czelabueno, Jan 15, 2024)
13f056b Mistral AI embeddings model tests (czelabueno, Jan 15, 2024)
ba19306 Mistral AI chat streaming model tests (czelabueno, Jan 15, 2024)
f09458f Mistral AI get models tests (czelabueno, Jan 15, 2024)
63d6924 Merge branch 'main' into main (langchain4j, Jan 16, 2024)
9117ee3 MistralAI - renamed classes to the project convention names to avoid … (czelabueno, Jan 17, 2024)
2c1a22f Mistral AI logRequestResponse and commit suggestions (czelabueno, Jan 19, 2024)
6b2e2b1 Merge branch 'main' into main (langchain4j, Jan 24, 2024)
0d87349 Mistral AI token masking until 4 symbols (czelabueno, Jan 24, 2024)
7da8928 MistralAI update chat model enum (czelabueno, Jan 24, 2024)
79237cc MistralAI update embedding model enum (czelabueno, Jan 24, 2024)
b51d0a5 Mistral AI fix get usageInfo from last chat completion response (czelabueno, Jan 24, 2024)
2078d32 Mistral AI fix logging streaming and rename enums (czelabueno, Jan 24, 2024)
49dd2e4 Merge branch 'main' into main (langchain4j, Jan 25, 2024)
4d44c3b Merge conflict with upstream repo (czelabueno, Jan 29, 2024)
134a9dd Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 5, 2024)
d982a78 Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 10, 2024)
004e942 Merge remote-tracking branch 'upstream/main' (czelabueno, Feb 23, 2024)
08938d7 Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 8, 2024)
1d902d7 Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 11, 2024)
9a60c5a Adding function calling feature (czelabueno, Mar 15, 2024)
91c8925 adding factories for each model (czelabueno, Mar 15, 2024)
ba884b0 updating integration testing for function calling (czelabueno, Mar 15, 2024)
056a172 updating load factories strategy (czelabueno, Mar 15, 2024)
f7028e7 Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 15, 2024)
9ddb2ce update overview integration table (czelabueno, Mar 16, 2024)
53a6840 Update docs/docs/integrations/index.mdx (LizeRaes, Mar 16, 2024)
507e748 Update docs/docs/integrations/index.mdx (LizeRaes, Mar 16, 2024)
b07f87b improve responseFormat javadocs (czelabueno, Mar 24, 2024)
faa959c rename test method name properly (czelabueno, Mar 24, 2024)
f597119 rename test method name for tool choice (czelabueno, Mar 24, 2024)
69b5385 rename test method name for multiple tools (czelabueno, Mar 24, 2024)
253b0ae rename test method name for finish reason (czelabueno, Mar 24, 2024)
7d97c54 uncomment assert for streaming chat (czelabueno, Mar 24, 2024)
8671041 rename test method when tool choice is auto (czelabueno, Mar 24, 2024)
58b1bc7 rename test method for multiple tools (czelabueno, Mar 24, 2024)
3cdc567 Fixing the comments of mistralai function calling support (czelabueno, Mar 24, 2024)
3316d52 Update AiServicesWithToolIT to parameterized tests to supply OpenAI a… (czelabueno, Mar 25, 2024)
1f21d10 Merge remote-tracking branch 'upstream/main' (czelabueno, Mar 25, 2024)
aaf4041 Merge with changes of BuildFactory and adding to TestStreamingRespons… (czelabueno, Mar 25, 2024)
@@ -29,7 +29,13 @@ public void onComplete(Response<T> response) {

String expectedTextContent = textContentBuilder.toString();
if (response.content() instanceof AiMessage) {
assertThat(((AiMessage) response.content()).text()).isEqualTo(expectedTextContent);
AiMessage aiMessage = (AiMessage) response.content();
if (aiMessage.hasToolExecutionRequests()){
assertThat(aiMessage.toolExecutionRequests().size()).isGreaterThan(0);
assertThat(aiMessage.text()).isNull();
} else {
assertThat(aiMessage.text()).isEqualTo(expectedTextContent);
}
} else if (response.content() instanceof String) {
assertThat(response.content()).isEqualTo(expectedTextContent);
} else {

@@ -3,12 +3,12 @@
import com.google.gson.FieldNamingPolicy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
@@ -20,11 +20,10 @@
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.finishReasonFrom;
import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.tokenUsageFrom;
import static dev.langchain4j.internal.Utils.*;
import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*;

class DefaultMistralAiClient extends MistralAiClient {

@@ -101,6 +100,7 @@ public MistralAiChatCompletionResponse chatCompletion(MistralAiChatCompletionReq
public void streamingChatCompletion(MistralAiChatCompletionRequest request, StreamingResponseHandler<AiMessage> handler) {
EventSourceListener eventSourceListener = new EventSourceListener() {
final StringBuffer contentBuilder = new StringBuffer();
List<ToolExecutionRequest> toolExecutionRequests;
TokenUsage tokenUsage;
FinishReason finishReason;

@@ -117,8 +117,15 @@ public void onEvent(EventSource eventSource, String id, String type, String data
LOGGER.debug("onEvent() {}", data);
}
if ("[DONE]".equals(data)) {
AiMessage aiMessage;
if (!isNullOrEmpty(toolExecutionRequests)){
aiMessage = AiMessage.from(toolExecutionRequests);
} else {
aiMessage = AiMessage.from(contentBuilder.toString());
}

Response<AiMessage> response = Response.from(
AiMessage.from(contentBuilder.toString()),
aiMessage,
tokenUsage,
finishReason
);
@@ -127,9 +134,17 @@ public void onEvent(EventSource eventSource, String id, String type, String data
try {
MistralAiChatCompletionResponse chatCompletionResponse = GSON.fromJson(data, MistralAiChatCompletionResponse.class);
MistralAiChatCompletionChoice choice = chatCompletionResponse.getChoices().get(0);

String chunk = choice.getDelta().getContent();
contentBuilder.append(chunk);
handler.onNext(chunk);
if (isNotNullOrBlank(chunk)) {
contentBuilder.append(chunk);
handler.onNext(chunk);
}

List<MistralAiToolCall> toolCalls = choice.getDelta().getToolCalls();
if (!isNullOrEmpty(toolCalls)) {
toolExecutionRequests = toToolExecutionRequests(toolCalls);
}

MistralAiUsage usageInfo = chatCompletionResponse.getUsage();
if (usageInfo != null) {
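The streaming listener above buffers text chunks and collects tool-call deltas separately, then builds the final message from whichever arrived once the `[DONE]` sentinel appears. A minimal standalone sketch of that accumulation rule, with hypothetical class and method names and plain strings standing in for the langchain4j types:

```java
import java.util.ArrayList;
import java.util.List;

// Simplified stand-in for the streaming EventSourceListener's state:
// text chunks append to a buffer, tool-call deltas are captured, and on
// completion the tool calls win over accumulated text when present.
class StreamingAccumulator {
    private final StringBuilder content = new StringBuilder();
    private final List<String> toolCalls = new ArrayList<>();

    void onChunk(String textChunk, List<String> toolCallDeltas) {
        if (textChunk != null && !textChunk.isBlank()) {
            content.append(textChunk);   // mirrors isNotNullOrBlank(chunk) guard
        }
        if (toolCallDeltas != null && !toolCallDeltas.isEmpty()) {
            toolCalls.addAll(toolCallDeltas);   // mirrors !isNullOrEmpty(toolCalls)
        }
    }

    // Mirrors the [DONE] branch: tool calls take precedence over text.
    String complete() {
        if (!toolCalls.isEmpty()) {
            return "TOOL_CALLS:" + String.join(",", toolCalls);
        }
        return "TEXT:" + content;
    }
}
```

This matches why the streaming test earlier in the diff asserts `aiMessage.text()` is null whenever tool execution requests are present.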
@@ -1,5 +1,8 @@
package dev.langchain4j.model.mistralai;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
@@ -11,8 +14,10 @@
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.stream.Collectors.toList;

public class DefaultMistralAiHelper {
@@ -29,41 +34,70 @@ static List<MistralAiChatMessage> toMistralAiMessages(List<ChatMessage> messages
}

static MistralAiChatMessage toMistralAiMessage(ChatMessage message) {
return MistralAiChatMessage.builder()
.role(toMistralAiRole(message.type()))
.content(toMistralChatMessageContent(message))
.build();
}

private static MistralAiRole toMistralAiRole(ChatMessageType chatMessageType) {
switch (chatMessageType) {
case SYSTEM:
return MistralAiRole.SYSTEM;
case AI:
return MistralAiRole.ASSISTANT;
case USER:
return MistralAiRole.USER;
default:
throw new IllegalArgumentException("Unknown chat message type: " + chatMessageType);
}
}

private static String toMistralChatMessageContent(ChatMessage message) {
if (message instanceof SystemMessage) {
return ((SystemMessage) message).text();
return MistralAiChatMessage.builder()
.role(MistralAiRole.SYSTEM)
.content(((SystemMessage) message).text())
.build();
}

if (message instanceof AiMessage) {
return ((AiMessage) message).text();
AiMessage aiMessage = (AiMessage) message;

if (!aiMessage.hasToolExecutionRequests()) {
return MistralAiChatMessage.builder()
.role(MistralAiRole.ASSISTANT)
.content(aiMessage.text())
.build();
}

List<MistralAiToolCall> toolCalls = aiMessage.toolExecutionRequests().stream()
.map(DefaultMistralAiHelper::toMistralAiToolCall)
.collect(toList());

if (isNullOrBlank(aiMessage.text())){
return MistralAiChatMessage.builder()
.role(MistralAiRole.ASSISTANT)
.content(null)
.toolCalls(toolCalls)
.build();
}

return MistralAiChatMessage.builder()
.role(MistralAiRole.ASSISTANT)
.content(aiMessage.text())
.toolCalls(toolCalls)
.build();
}

if (message instanceof UserMessage) {
return message.text(); // MistralAI support Text Content only as String
return MistralAiChatMessage.builder()
.role(MistralAiRole.USER)
.content(message.text()) // MistralAI support Text Content only as String
.build();
}

if (message instanceof ToolExecutionResultMessage){
return MistralAiChatMessage.builder()
.role(MistralAiRole.TOOL)
.name(((ToolExecutionResultMessage) message).toolName())
.content(((ToolExecutionResultMessage) message).text())
.build();
}

throw new IllegalArgumentException("Unknown message type: " + message.type());
}

static MistralAiToolCall toMistralAiToolCall(ToolExecutionRequest toolExecutionRequest) {
return MistralAiToolCall.builder()
.id(toolExecutionRequest.id())
.function(MistralAiFunctionCall.builder()
.name(toolExecutionRequest.name())
.arguments(toolExecutionRequest.arguments())
.build())
.build();
}

public static TokenUsage tokenUsageFrom(MistralAiUsage mistralAiUsage) {
if (mistralAiUsage == null) {
return null;
@@ -84,12 +118,76 @@ public static FinishReason finishReasonFrom(String mistralAiFinishReason) {
return STOP;
case "length":
return LENGTH;
case "tool_calls":
return TOOL_EXECUTION;
case "content_filter":
return CONTENT_FILTER;
case "model_length":
default:
return null;
}
}

public static AiMessage aiMessageFrom(MistralAiChatCompletionResponse response) {
MistralAiChatMessage aiMistralMessage = response.getChoices().get(0).getMessage();
List<MistralAiToolCall> toolCalls = aiMistralMessage.getToolCalls();
if (!isNullOrEmpty(toolCalls)){
return AiMessage.from(toToolExecutionRequests(toolCalls));
}
return AiMessage.from(aiMistralMessage.getContent());
}

public static List<ToolExecutionRequest> toToolExecutionRequests(List<MistralAiToolCall> mistralAiToolCalls) {
return mistralAiToolCalls.stream()
.filter(toolCall -> toolCall.getType() == MistralAiToolType.FUNCTION)
.map(DefaultMistralAiHelper::toToolExecutionRequest)
.collect(toList());
}

public static ToolExecutionRequest toToolExecutionRequest(MistralAiToolCall mistralAiToolCall) {
return ToolExecutionRequest.builder()
.id(mistralAiToolCall.getId())
.name(mistralAiToolCall.getFunction().getName())
.arguments(mistralAiToolCall.getFunction().getArguments())
.build();
}

static List<MistralAiTool> toMistralAiTools(List<ToolSpecification> toolSpecifications) {
return toolSpecifications.stream()
.map(DefaultMistralAiHelper::toMistralAiTool)
.collect(toList());
}

static MistralAiTool toMistralAiTool(ToolSpecification toolSpecification) {
MistralAiFunction function = MistralAiFunction.builder()
.name(toolSpecification.name())
.description(toolSpecification.description())
.parameters(toMistralAiParameters(toolSpecification.parameters()))
.build();
return MistralAiTool.from(function);
}

static MistralAiParameters toMistralAiParameters(ToolParameters parameters){
if (parameters == null) {
return MistralAiParameters.builder().build();
}
return MistralAiParameters.from(parameters);
}

static MistralAiResponseFormat toMistralAiResponseFormat(String responseFormat) {
if (responseFormat == null) {
return null;
}
switch (responseFormat) {
case "text":
return MistralAiResponseFormat.fromType(MistralAiResponseFormatType.TEXT);
case "json_object":
return MistralAiResponseFormat.fromType(MistralAiResponseFormatType.JSON_OBJECT);
default:
throw new IllegalArgumentException("Unknown response format: " + responseFormat);
}
}

static String getHeaders(Headers headers) {
return StreamSupport.stream(headers.spliterator(), false).map(header -> {
String headerKey = header.component1();
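The extended `finishReasonFrom` above maps Mistral's `tool_calls` and `content_filter` finish reasons while deliberately leaving `model_length` unmapped (it falls through to the default). A self-contained re-statement of that mapping, with a local enum standing in for langchain4j's `FinishReason` (constant names assumed from the diff):

```java
// Local stand-in for langchain4j's FinishReason enum.
enum FinishReason { STOP, LENGTH, TOOL_EXECUTION, CONTENT_FILTER }

class FinishReasonMapper {
    // Maps a Mistral finish_reason string to the library-level enum;
    // unknown reasons and "model_length" return null, as in the diff.
    static FinishReason fromMistral(String mistralAiFinishReason) {
        if (mistralAiFinishReason == null) return null;
        switch (mistralAiFinishReason) {
            case "stop": return FinishReason.STOP;
            case "length": return FinishReason.LENGTH;
            case "tool_calls": return FinishReason.TOOL_EXECUTION;
            case "content_filter": return FinishReason.CONTENT_FILTER;
            case "model_length": // falls through: no distinct mapping
            default: return null;
        }
    }
}
```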
@@ -21,4 +21,8 @@ public class MistralAiChatCompletionRequest {
private Boolean stream;
private Boolean safePrompt;
private Integer randomSeed;
private List<MistralAiTool> tools;
private MistralAiToolChoiceName toolChoice;
private MistralAiResponseFormat responseFormat;

}
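The client's Gson instance is built with `FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES` (visible in the imports earlier in the diff), so the new `tools`, `toolChoice`, and `responseFormat` fields serialize as `tools`, `tool_choice`, and `response_format` on the wire. A minimal sketch of that naming rule for simple camelCase names (hypothetical helper, not part of the PR):

```java
// Approximates Gson's LOWER_CASE_WITH_UNDERSCORES policy for plain
// camelCase field names: each uppercase letter becomes '_' + lowercase.
class SnakeCase {
    static String from(String camelCaseName) {
        StringBuilder sb = new StringBuilder();
        for (char c : camelCaseName.toCharArray()) {
            if (Character.isUpperCase(c)) {
                sb.append('_').append(Character.toLowerCase(c));
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }
}
```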
@@ -5,6 +5,8 @@
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

@Data
@NoArgsConstructor
@AllArgsConstructor
Expand All @@ -13,4 +15,6 @@ public class MistralAiChatMessage {

private MistralAiRole role;
private String content;
private String name;
private List<MistralAiToolCall> toolCalls;
}
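With `name` and `toolCalls` added, `MistralAiChatMessage` can now carry assistant tool calls as well as `tool`-role results, and `toMistralAiMessage` in the helper dispatches each langchain4j message type to the matching Mistral role. A toy sketch of that dispatch with strings in place of the real enums (hypothetical names, simplified from the diff):

```java
// Maps a langchain4j ChatMessageType name to the Mistral role string,
// mirroring the branches of toMistralAiMessage in DefaultMistralAiHelper.
class RoleMapper {
    static String roleFor(String chatMessageType) {
        switch (chatMessageType) {
            case "SYSTEM": return "system";
            case "AI": return "assistant";
            case "USER": return "user";
            case "TOOL_EXECUTION_RESULT": return "tool";   // new in this PR
            default:
                throw new IllegalArgumentException("Unknown message type: " + chatMessageType);
        }
    }
}
```

Tool results additionally carry the tool's `name` and its textual output as `content`, as shown in the `ToolExecutionResultMessage` branch of the diff.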