Skip to content

Commit

Permalink
updating integration testing for function calling
Browse files Browse the repository at this point in the history
  • Loading branch information
czelabueno committed Mar 15, 2024
1 parent 91c8925 commit ba884b0
Show file tree
Hide file tree
Showing 3 changed files with 519 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,43 @@
package dev.langchain4j.model.mistralai;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static dev.langchain4j.model.output.FinishReason.*;
import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;

class MistralAiChatModelIT {

// Tool specification shared by the function-calling tests in this class: a tool named
// "retrieve-payment-status" that declares a single string parameter, "transactionId".
ToolSpecification retrievePaymentStatus = ToolSpecification.builder()
.name("retrieve-payment-status")
.description("Retrieve Payment Status")
.addParameter("transactionId", STRING)
.build();

// Chat model used by the tool-calling tests in this class. The API key is read from the
// MISTRAL_AI_API_KEY environment variable, and request/response logging is switched on.
ChatLanguageModel mistralLargeModel = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST.toString())
.temperature(0.1)
.logRequests(true)
.logResponses(true)
.build();

ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.temperature(0.1)
Expand Down Expand Up @@ -130,8 +154,10 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a
// given - Mistral Small = Mixtral-8x7B
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName(MistralAiChatModelName.MISTRAL_SMALL.toString())
.modelName(MistralAiChatModelName.OPEN_MIXTRAL_8x7B.toString())
.temperature(0.1)
.logRequests(true)
.logResponses(true)
.build();

UserMessage userMessage = userMessage("Quelle est la capitale du Pérou?");
Expand All @@ -157,8 +183,10 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_
// given - Mistral Small = Mixtral-8x7B
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName(MistralAiChatModelName.MISTRAL_SMALL.toString())
.modelName(MistralAiChatModelName.OPEN_MIXTRAL_8x7B.toString())
.temperature(0.1)
.logRequests(true)
.logResponses(true)
.build();

UserMessage userMessage = userMessage("¿Cuál es la capital de Perú?");
Expand All @@ -181,11 +209,13 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_
@Test
void should_generate_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() {

// given - Mistral Medium = currently relies on an internal prototype model.
// given - Mistral Medium 2312.
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.modelName(MistralAiChatModelName.MISTRAL_MEDIUM.toString())
.modelName(MistralAiChatModelName.MISTRAL_MEDIUM_LATEST.toString())
.maxTokens(10)
.logRequests(true)
.logResponses(true)
.build();

UserMessage userMessage = userMessage("What is the capital of Peru?");
Expand All @@ -205,5 +235,224 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish
assertThat(response.finishReason()).isEqualTo(LENGTH);
}

/**
 * Sends one user question together with a single-element list of tool specifications and
 * verifies that the model replies with exactly one tool execution request (and no text),
 * reports token usage, and finishes with {@code TOOL_EXECUTION}.
 */
@Test
void should_return_toolCalls_as_finishReason() {

    // given
    UserMessage userMessage = userMessage("What is the status of transaction T123?");
    List<ToolSpecification> toolSpecifications = singletonList(retrievePaymentStatus);

    // when
    Response<AiMessage> response = mistralLargeModel.generate(singletonList(userMessage), toolSpecifications);

    // then
    AiMessage aiMessage = response.content();
    assertThat(aiMessage.text()).isNullOrEmpty();
    assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

    ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
    assertThat(toolExecutionRequest.name()).isEqualTo("retrieve-payment-status");
    assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}");

    // Token counts are produced by the remote service for a floating "latest" model alias,
    // so only assert that they are reported and internally consistent rather than pinning
    // exact numbers that break whenever the hosted model or its tokenizer changes.
    TokenUsage tokenUsage = response.tokenUsage();
    assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage.totalTokenCount())
            .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

    assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
}

/**
 * Two-step tool-calling conversation with two tools on offer.
 * <p>
 * Step 1: the user asks for a transaction status; the model is expected to request the
 * "retrieve-payment-status" tool. Step 2: the tool result is appended to the conversation
 * and the model is expected to answer in plain text mentioning the transaction and its
 * status, finishing with {@code STOP}.
 */
@Test
void should_execute_function_when_toolChoice_is_auto_and_answer() {
    // given
    ToolSpecification retrievePaymentDate = ToolSpecification.builder()
            .name("retrieve-payment-date")
            .description("Retrieve Payment Date")
            .addParameter("transactionId", STRING)
            .build();

    List<ChatMessage> chatMessages = new ArrayList<>();
    UserMessage userMessage = userMessage("What is the status of transaction T123?");

    chatMessages.add(userMessage);
    List<ToolSpecification> toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate);

    // when
    Response<AiMessage> response = mistralLargeModel.generate(chatMessages, toolSpecifications);

    // then
    AiMessage aiMessage = response.content();
    assertThat(aiMessage.text()).isNullOrEmpty();
    assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

    ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
    assertThat(toolExecutionRequest.name()).isEqualTo("retrieve-payment-status");
    assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}");

    assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

    // The AI message carrying the tool request must be part of the history sent back.
    chatMessages.add(aiMessage);

    // given
    ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "{\"status\": \"PAID\"}");
    chatMessages.add(toolExecutionResultMessage);

    // when
    Response<AiMessage> response2 = mistralLargeModel.generate(chatMessages);

    // then
    AiMessage aiMessage2 = response2.content();
    assertThat(aiMessage2.text()).containsIgnoringCase("T123");
    assertThat(aiMessage2.text()).containsIgnoringCase("paid");
    assertThat(aiMessage2.toolExecutionRequests()).isNull();

    // Token counts come from the remote service for a floating "latest" model alias; assert
    // they are present and consistent instead of pinning an exact, tokenizer-specific value.
    TokenUsage tokenUsage2 = response2.tokenUsage();
    assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.totalTokenCount())
            .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());

    assertThat(response2.finishReason()).isEqualTo(STOP);
}

/**
 * Two-step tool-calling conversation where a single tool specification is passed directly
 * to {@code generate}.
 * <p>
 * Step 1: the model is expected to request the "retrieve-payment-date" tool. Step 2: the
 * tool result is appended to the conversation and the model is expected to answer in plain
 * text mentioning the transaction and the date, finishing with {@code STOP}.
 * <p>
 * NOTE(review): per the test name this overload presumably forces the tool call
 * (tool choice "any") - confirm against MistralAiChatModel.
 */
@Test
void should_execute_tool_forcefully_when_toolChoice_is_any_and_answer() {
    // given
    ToolSpecification retrievePaymentDate = ToolSpecification.builder()
            .name("retrieve-payment-date")
            .description("Retrieve Payment Date")
            .addParameter("transactionId", STRING)
            .build();

    List<ChatMessage> chatMessages = new ArrayList<>();
    UserMessage userMessage = userMessage("What is the payment date of transaction T123?");
    chatMessages.add(userMessage);

    // when
    Response<AiMessage> response = mistralLargeModel.generate(singletonList(userMessage), retrievePaymentDate);

    // then
    AiMessage aiMessage = response.content();
    // Accept both null and "" for the text, matching the other tool-calling tests in this class.
    assertThat(aiMessage.text()).isNullOrEmpty();
    assertThat(aiMessage.toolExecutionRequests()).hasSize(1);

    ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
    assertThat(toolExecutionRequest.name()).isEqualTo("retrieve-payment-date");
    assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}");

    // Token counts come from the remote service for a floating "latest" model alias; assert
    // they are present and consistent instead of pinning exact, tokenizer-specific values.
    TokenUsage tokenUsage = response.tokenUsage();
    assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage.totalTokenCount())
            .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());

    assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);
    chatMessages.add(aiMessage);

    // given
    ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(toolExecutionRequest, "{\"date\": \"2024-03-11\"}");
    chatMessages.add(toolExecutionResultMessage);

    // when
    Response<AiMessage> response2 = mistralLargeModel.generate(chatMessages);

    // then
    AiMessage aiMessage2 = response2.content();
    assertThat(aiMessage2.text()).containsIgnoringCase("T123");
    assertThat(aiMessage2.text()).containsIgnoringWhitespaces("March 11, 2024");
    assertThat(aiMessage2.toolExecutionRequests()).isNull();

    TokenUsage tokenUsage2 = response2.tokenUsage();
    assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.totalTokenCount())
            .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());

    assertThat(response2.finishReason()).isEqualTo(STOP);
}

/**
 * Builds a model configured with the JSON_OBJECT response format, asks it for a two-field
 * JSON document, and checks that the reply equals the expected JSON when whitespace is ignored.
 */
@Test
void should_return_valid_json_object() {

    // given - a dedicated model instance with the JSON response format enabled
    ChatLanguageModel jsonModel = MistralAiChatModel.builder()
            .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
            .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST.toString())
            .temperature(0.1)
            .responseFormat(MistralAiResponseFormatType.JSON_OBJECT.toString())
            .logRequests(true)
            .logResponses(true)
            .build();

    String prompt = "Return JSON with two fields: transactionId and status with the values T123 and paid.";
    String expected = "{\"transactionId\":\"T123\",\"status\":\"paid\"}";

    // when
    String actual = jsonModel.generate(prompt);

    // then
    assertThat(actual).isEqualToIgnoringWhitespace(expected);
}

/**
 * Two-step conversation in which the model is expected to request two tools at once.
 * <p>
 * Step 1: the user asks for both the status and the payment date of a transaction; the
 * model is expected to return two tool execution requests, status first and date second.
 * Step 2: both tool results are appended to the conversation and the model is expected to
 * answer in plain text mentioning the transaction, the status and the date, finishing with
 * {@code STOP}.
 */
@Test
void should_multiple_tools_then_answer() {
    // given
    ToolSpecification retrievePaymentDate = ToolSpecification.builder()
            .name("retrieve-payment-date")
            .description("Retrieve Payment Date")
            .addParameter("transactionId", STRING)
            .build();

    List<ChatMessage> chatMessages = new ArrayList<>();
    UserMessage userMessage = userMessage("What is the status and the payment date of transaction T123?");

    chatMessages.add(userMessage);
    List<ToolSpecification> toolSpecifications = asList(retrievePaymentStatus, retrievePaymentDate);

    // when
    Response<AiMessage> response = mistralLargeModel.generate(chatMessages, toolSpecifications);

    // then
    AiMessage aiMessage = response.content();
    assertThat(aiMessage.text()).isNullOrEmpty();
    assertThat(aiMessage.toolExecutionRequests()).hasSize(2);

    ToolExecutionRequest toolExecutionRequest1 = aiMessage.toolExecutionRequests().get(0);
    assertThat(toolExecutionRequest1.name()).isEqualTo("retrieve-payment-status");
    assertThat(toolExecutionRequest1.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}");

    ToolExecutionRequest toolExecutionRequest2 = aiMessage.toolExecutionRequests().get(1);
    assertThat(toolExecutionRequest2.name()).isEqualTo("retrieve-payment-date");
    assertThat(toolExecutionRequest2.arguments()).isEqualToIgnoringWhitespace("{\"transactionId\":\"T123\"}");

    assertThat(response.finishReason()).isEqualTo(TOOL_EXECUTION);

    // The AI message carrying the tool requests must be part of the history sent back.
    chatMessages.add(aiMessage);

    // given - one result message per tool request
    ToolExecutionResultMessage toolExecutionResultMessage1 = ToolExecutionResultMessage.from(toolExecutionRequest1, "{\"status\": \"PAID\"}");
    chatMessages.add(toolExecutionResultMessage1);
    ToolExecutionResultMessage toolExecutionResultMessage2 = ToolExecutionResultMessage.from(toolExecutionRequest2, "{\"date\": \"2024-03-11\"}");
    chatMessages.add(toolExecutionResultMessage2);

    // when
    Response<AiMessage> response2 = mistralLargeModel.generate(chatMessages);

    // then
    AiMessage aiMessage2 = response2.content();
    assertThat(aiMessage2.text()).contains("T123");
    assertThat(aiMessage2.text()).containsIgnoringCase("paid");
    assertThat(aiMessage2.text()).containsIgnoringWhitespaces("March 11, 2024");
    assertThat(aiMessage2.toolExecutionRequests()).isNull();

    // Token counts come from the remote service for a floating "latest" model alias; assert
    // they are present and consistent instead of pinning an exact, tokenizer-specific value.
    TokenUsage tokenUsage2 = response2.tokenUsage();
    assertThat(tokenUsage2.inputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.outputTokenCount()).isGreaterThan(0);
    assertThat(tokenUsage2.totalTokenCount())
            .isEqualTo(tokenUsage2.inputTokenCount() + tokenUsage2.outputTokenCount());

    assertThat(response2.finishReason()).isEqualTo(STOP);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void should_return_all_model_cards() {

// then
assertThat(response.content().size()).isGreaterThan(0);
assertThat(response.content()).extracting("id").contains(MistralAiChatModelName.MISTRAL_TINY.toString());
assertThat(response.content()).extracting("id").contains(MistralAiChatModelName.OPEN_MISTRAL_7B.toString());
assertThat(response.content()).extracting("object").contains("model");
assertThat(response.content()).extracting("permission").isNotNull();
}
Expand Down
Loading

0 comments on commit ba884b0

Please sign in to comment.