            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
index 2e00d9fc40..cec7f89d8a 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
@@ -4,12 +4,13 @@
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
-import okhttp3.Response;
+import okhttp3.Headers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
@@ -17,11 +18,12 @@
import static dev.langchain4j.model.output.FinishReason.*;
import static java.util.stream.Collectors.toList;
-public class DefaultMistralAiHelper{
+class DefaultMistralAiHelper{
private static final Logger LOGGER = LoggerFactory.getLogger(DefaultMistralAiHelper.class);
static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1";
static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float";
+ private static final Pattern MISTRAI_API_KEY_BEARER_PATTERN = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$");
public static String ensureNotBlankApiKey(String value) {
if (isNullOrBlank(value)) {
@@ -47,21 +49,21 @@ public static MistralChatMessage toMistralAiMessage(ChatMessage message) {
.build();
}
- private static MistralRoleType toMistralAiRole(ChatMessageType chatMessageType) {
+ private static MistralRoleName toMistralAiRole(ChatMessageType chatMessageType) {
switch (chatMessageType) {
case SYSTEM:
- return MistralRoleType.SYSTEM;
+ return MistralRoleName.SYSTEM;
case AI:
- return MistralRoleType.ASSISTANT;
+ return MistralRoleName.ASSISTANT;
case USER:
- return MistralRoleType.USER;
+ return MistralRoleName.USER;
default:
throw new IllegalArgumentException("Unknown chat message type: " + chatMessageType);
}
}
- public static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) {
+ static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) {
if (mistralAiUsage == null) {
return null;
}
@@ -72,7 +74,7 @@ public static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) {
);
}
- public static FinishReason finishReasonFrom(String mistralAiFinishReason) {
+ static FinishReason finishReasonFrom(String mistralAiFinishReason) {
if (mistralAiFinishReason == null) {
return null;
}
@@ -87,35 +89,32 @@ public static FinishReason finishReasonFrom(String mistralAiFinishReason) {
}
}
- public static void logResponse(Response response){
- try {
- LOGGER.debug("Response code: {}", response.code());
- LOGGER.debug("Response body: {}", getResponseBody(response));
- LOGGER.debug("Response headers: {}", getResponseHeaders(response));
- } catch (IOException e) {
- LOGGER.warn("Error while logging response", e);
- }
- }
-
- private static String getResponseBody(Response response) throws IOException {
- return isEventStream(response) ? "" : response.peekBody(Long.MAX_VALUE).string();
- }
-
- private static String getResponseHeaders(Response response){
- return (String) StreamSupport.stream(response.headers().spliterator(),false).map(header -> {
+ static String getHeaders(Headers headers){
+ return StreamSupport.stream(headers.spliterator(),false).map(header -> {
String headerKey = header.component1();
String headerValue = header.component2();
if (headerKey.equals("Authorization")) {
- headerValue = "Bearer " + headerValue.substring(0, 5) + "..." + headerValue.substring(headerValue.length() - 5);
- } else if (headerKey.equals("api-key")) {
- headerValue = headerValue.substring(0, 2) + "..." + headerValue.substring(headerValue.length() - 2);
+ headerValue = maskAuthorizationHeaderValue(headerValue);
}
return String.format("[%s: %s]", headerKey, headerValue);
}).collect(Collectors.joining(", "));
}
- private static boolean isEventStream(Response response){
- String contentType = response.header("Content-Type");
- return contentType != null && contentType.contains("event-stream");
+
+ private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) {
+ try {
+ Matcher matcher = MISTRAI_API_KEY_BEARER_PATTERN.matcher(authorizationHeaderValue);
+ StringBuffer sb = new StringBuffer();
+
+ while (matcher.find()) {
+ String bearer = matcher.group(1);
+ String token = matcher.group(2);
+ matcher.appendReplacement(sb, bearer + " " + token.substring(0, 7) + "..." + token.substring(token.length() - 7));
+ }
+ matcher.appendTail(sb);
+ return sb.toString();
+ } catch (Exception e) {
+ return "Error while masking Authorization header value";
+ }
}
}
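For illustration, a minimal, self-contained sketch of the masking behaviour introduced above in maskAuthorizationHeaderValue, using the same Bearer pattern; the MaskingDemo class and the sample key below are hypothetical and only show the expected output shape (leading and trailing characters kept, middle replaced by "..."):

import java.util.regex.Matcher;
import java.util.regex.Pattern;

class MaskingDemo {
    // Same pattern as in DefaultMistralAiHelper: "Bearer" prefix followed by a 1-32 char alphanumeric key
    private static final Pattern BEARER = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$");

    public static void main(String[] args) {
        String header = "Bearer abcdefghijklmnopqrstuvwxyz123456"; // hypothetical API key
        Matcher matcher = BEARER.matcher(header);
        StringBuffer sb = new StringBuffer();
        while (matcher.find()) {
            String token = matcher.group(2);
            // Keep the first and last 7 characters of the key, mask the rest
            matcher.appendReplacement(sb,
                    matcher.group(1) + " " + token.substring(0, 7) + "..." + token.substring(token.length() - 7));
        }
        matcher.appendTail(sb);
        System.out.println(sb); // prints: Bearer abcdefg...z123456
    }
}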
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java
index 5c6f5f8055..552608ed23 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java
@@ -44,7 +44,11 @@ public class MistralAiChatModel implements ChatLanguageModel {
* @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses
* @param randomSeed the random seed for generating chat responses
* @param timeout the timeout duration for API requests
- * @param maxRetries the maximum number of retries for API requests
+ *
+ * The default value is 60 seconds
+ * @param logRequests a flag indicating whether to log API requests
+ * @param logResponses a flag indicating whether to log API responses
+ * @param maxRetries the maximum number of retries for API requests. It uses the default value 3 if not specified
*/
@Builder
public MistralAiChatModel(String baseUrl,
@@ -56,14 +60,18 @@ public MistralAiChatModel(String baseUrl,
Boolean safePrompt,
Integer randomSeed,
Duration timeout,
+ Boolean logRequests,
+ Boolean logResponses,
Integer maxRetries) {
this.client = MistralAiClient.builder()
.baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL)))
.apiKey(ensureNotBlankApiKey(apiKey))
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
+ .logRequests(getOrDefault(logRequests, false))
+ .logResponses(getOrDefault(logResponses, false))
.build();
- this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString());
+ this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString());
this.temperature = temperature;
this.topP = topP;
this.maxNewTokens = maxNewTokens;
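A usage sketch of the new logRequests/logResponses builder flags, mirroring MistralAiChatModelIT further down; the ChatLoggingDemo wrapper is only illustrative and assumes MISTRAL_AI_API_KEY is set in the environment:

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;

class ChatLoggingDemo {
    public static void main(String[] args) {
        ChatLanguageModel model = MistralAiChatModel.builder()
                .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                .temperature(0.1)
                .logRequests(true)   // adds the request logging interceptor to the OkHttp client
                .logResponses(true)  // adds the response logging interceptor to the OkHttp client
                .build();
        System.out.println(model.generate("What is the capital of Peru?"));
    }
}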
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
index 6383dc8ce2..083aa1578c 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
@@ -7,6 +7,7 @@
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;
import okhttp3.OkHttpClient;
import okhttp3.ResponseBody;
@@ -21,8 +22,8 @@
import java.io.IOException;
import java.time.Duration;
+import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*;
-import static dev.langchain4j.model.output.FinishReason.*;
class MistralAiClient {
@@ -35,14 +36,32 @@ class MistralAiClient {
private final OkHttpClient okHttpClient;
@Builder
- public MistralAiClient(String baseUrl, String apiKey, Duration timeout) {
- okHttpClient = new OkHttpClient.Builder()
- .addInterceptor(new MistralAiApiKeyInterceptor(apiKey))
+ public MistralAiClient(String baseUrl,
+ String apiKey,
+ Duration timeout,
+ Boolean logRequests,
+ Boolean logResponses) {
+ OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
- .writeTimeout(timeout)
- .build();
+ .writeTimeout(timeout);
+ if (isNullOrBlank(apiKey)) {
+ throw new IllegalArgumentException("MistralAI API Key must be defined. It can be generated here: https://console.mistral.ai/user/api-keys/");
+ }else {
+ okHttpClientBuilder.addInterceptor(new MistralAiApiKeyInterceptor(apiKey));
+ // Log raw HTTP requests
+ if (logRequests) {
+ okHttpClientBuilder.addInterceptor(new MistralAiRequestLoggingInterceptor());
+ }
+
+ // Log raw HTTP responses
+ if (logResponses) {
+ okHttpClientBuilder.addInterceptor(new MistralAiResponseLoggingInterceptor());
+ }
+ }
+
+ this.okHttpClient = okHttpClientBuilder.build();
Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
@@ -57,7 +76,6 @@ public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest
try {
            retrofit2.Response<MistralChatCompletionResponse> retrofitResponse
= mistralAiApi.chatCompletion(request).execute();
- LOGGER.debug("MistralChatCompletionResponse: {}", retrofitResponse);
if (retrofitResponse.isSuccessful()) {
LOGGER.debug("ChatCompletionResponseBody: {}", retrofitResponse.body());
return retrofitResponse.body();
@@ -71,14 +89,13 @@ public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest
    public void streamingChatCompletion(MistralChatCompletionRequest request, StreamingResponseHandler<AiMessage> handler) {
EventSourceListener eventSourceListener = new EventSourceListener() {
- StringBuilder contentBuilder = new StringBuilder();
- MistralUsageInfo tokenUsage = new MistralUsageInfo();
- FinishReason lastFinishReason = null;
+ final StringBuffer contentBuilder = new StringBuffer();
+ TokenUsage tokenUsage;
+ FinishReason finishReason;
@Override
public void onOpen(EventSource eventSource, okhttp3.Response response) {
LOGGER.debug("onOpen()");
- logResponse(response);
}
@Override
@@ -88,8 +105,8 @@ public void onEvent(EventSource eventSource, String id, String type, String data
if ("[DONE]".equals(data)) {
                        Response<AiMessage> response = Response.from(
AiMessage.from(contentBuilder.toString()),
- tokenUsageFrom(tokenUsage),
- lastFinishReason
+ tokenUsage,
+ finishReason
);
handler.onComplete(response);
} else {
@@ -100,21 +117,14 @@ public void onEvent(EventSource eventSource, String id, String type, String data
contentBuilder.append(chunk);
handler.onNext(chunk);
- //Retrieving token usage of the last choice
- if(choice.getFinishReason() != null){
- FinishReason finishReason = finishReasonFrom(choice.getFinishReason());
- switch (finishReason){
- case STOP:
- lastFinishReason = STOP;
- tokenUsage = choice.getUsage();
- break;
- case LENGTH:
- lastFinishReason = LENGTH;
- tokenUsage = choice.getUsage();
- break;
- default:
- break;
- }
+ MistralUsageInfo usageInfo = choice.getUsage();
+ if(usageInfo != null){
+ this.tokenUsage = tokenUsageFrom(usageInfo);
+ }
+
+ String finishReasonString = choice.getFinishReason();
+ if(finishReasonString != null){
+ this.finishReason = finishReasonFrom(finishReasonString);
}
} catch (Exception e) {
handler.onError(e);
@@ -127,7 +137,6 @@ public void onEvent(EventSource eventSource, String id, String type, String data
@Override
public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) {
LOGGER.debug("onFailure()", t);
- logResponse(response);
if (t != null){
handler.onError(t);
@@ -153,7 +162,6 @@ public MistralEmbeddingResponse embedding(MistralEmbeddingRequest request) {
try {
            retrofit2.Response<MistralEmbeddingResponse> retrofitResponse
= mistralAiApi.embedding(request).execute();
- LOGGER.debug("MistralEmbeddingResponse: {}", retrofitResponse);
if (retrofitResponse.isSuccessful()) {
LOGGER.debug("EmbeddingResponseBody: {}", retrofitResponse.body());
return retrofitResponse.body();
@@ -169,7 +177,6 @@ public MistralModelResponse listModels() {
try {
            retrofit2.Response<MistralModelResponse> retrofitResponse
= mistralAiApi.models().execute();
- LOGGER.debug("MistralModelResponse: {}", retrofitResponse);
if (retrofitResponse.isSuccessful()) {
LOGGER.debug("ModelResponseBody: {}", retrofitResponse.body());
return retrofitResponse.body();
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java
index 404df16dbc..3294af4d5d 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java
@@ -31,6 +31,10 @@ public class MistralAiEmbeddingModel implements EmbeddingModel {
* @param apiKey the API key for authentication
* @param modelName the name of the embedding model. It uses a default value if not specified
* @param timeout the timeout duration for API requests. It uses a default value of 60 seconds if not specified
+ *
+ * The default value is 60 seconds
+ * @param logRequests a flag indicating whether to log API requests
+ * @param logResponses a flag indicating whether to log API responses
* @param maxRetries the maximum number of retries for API requests. It uses a default value of 3 if not specified
*/
@Builder
@@ -38,13 +42,17 @@ public MistralAiEmbeddingModel(String baseUrl,
String apiKey,
String modelName,
Duration timeout,
+ Boolean logRequests,
+ Boolean logResponses,
Integer maxRetries) {
this.client = MistralAiClient.builder()
.baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL)))
.apiKey(ensureNotBlankApiKey(apiKey))
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
+ .logRequests(getOrDefault(logRequests, false))
+ .logResponses(getOrDefault(logResponses,false))
.build();
- this.modelName = getOrDefault(modelName, MistralEmbeddingModelType.MISTRAL_EMBED.toString());
+ this.modelName = getOrDefault(modelName, MistralEmbeddingModelName.MISTRAL_EMBED.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
}
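The same flags on the embedding model, mirroring MistralAiEmbeddingModelIT below; the EmbeddingLoggingDemo wrapper is only illustrative:

import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel;

class EmbeddingLoggingDemo {
    public static void main(String[] args) {
        EmbeddingModel model = MistralAiEmbeddingModel.builder()
                .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                .logRequests(true)
                .logResponses(true)
                .build();
        // Embed a single text and print the size of the returned vector
        System.out.println(model.embed("hello").content().vector().length);
    }
}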
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java
similarity index 81%
rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java
rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java
index 77fb5eb3e9..c2bd3f8c9e 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java
@@ -11,7 +11,7 @@
@NoArgsConstructor
@AllArgsConstructor
@Builder
-public class MistralModelCard {
+public class MistralAiModelCard {
private String id;
private String object;
@@ -19,6 +19,6 @@ public class MistralModelCard {
private String ownerBy;
private String root;
private String parent;
-    private List<MistralModelPermission> permission;
+    private List<MistralAiModelPermission> permission;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java
similarity index 93%
rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java
rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java
index 812a22bdda..ea60d9b592 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java
@@ -10,7 +10,7 @@
@NoArgsConstructor
@AllArgsConstructor
@Builder
-class MistralModelPermission {
+public class MistralAiModelPermission {
private String id;
private String object;
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java
index e741f76bef..ccf1e7f33b 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java
@@ -9,7 +9,6 @@
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*;
-import static java.util.stream.Collectors.toList;
/**
* Represents a collection of Mistral AI models.
@@ -26,17 +25,23 @@ public class MistralAiModels {
* @param baseUrl the base URL of the Mistral AI API. It uses the default value if not specified
* @param apiKey the API key for authentication
* @param timeout the timeout duration for API requests. It uses the default value of 60 seconds if not specified
+ * @param logRequests a flag whether to log raw HTTP requests
+ * @param logResponses a flag whether to log raw HTTP responses
* @param maxRetries the maximum number of retries for API requests. It uses the default value of 3 if not specified
*/
@Builder
public MistralAiModels(String baseUrl,
String apiKey,
Duration timeout,
+ Boolean logRequests,
+ Boolean logResponses,
Integer maxRetries) {
this.client = MistralAiClient.builder()
.baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL)))
.apiKey(ensureNotBlankApiKey(apiKey))
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
+ .logRequests(getOrDefault(logRequests, false))
+ .logResponses(getOrDefault(logResponses, false))
.build();
this.maxRetries = getOrDefault(maxRetries, 3);
}
@@ -51,35 +56,12 @@ public static MistralAiModels withApiKey(String apiKey) {
return builder().apiKey(apiKey).build();
}
- /**
- * Retrieves the details of a specific model.
- *
- * @param modelId the ID of the model
- * @return the response containing the model details
- */
-    public Response<MistralModelCard> getModelDetails(String modelId){
- return Response.from(
- this.getModels().content().stream().filter(modelCard -> modelCard.getId().equals(modelId)).findFirst().orElse(null)
- );
- }
-
- /**
- * Retrieves the IDs of all available models.
- *
- * @return the response containing the list of model IDs
- */
-    public Response<List<String>> get(){
- return Response.from(
- this.getModels().content().stream().map(MistralModelCard::getId).collect(toList())
- );
- }
-
/**
* Retrieves the list of all available models.
*
* @return the response containing the list of models
*/
-    public Response<List<MistralModelCard>> getModels(){
+    public Response<List<MistralAiModelCard>> availableModels(){
MistralModelResponse response = withRetry(client::listModels, maxRetries);
return Response.from(
response.getData()
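A sketch of the renamed accessor, mirroring MistralAiModelsIT below; it assumes the getId() accessor on MistralAiModelCard (the pre-rename code already used MistralModelCard::getId), and the ModelsDemo wrapper is only illustrative:

import dev.langchain4j.model.mistralai.MistralAiModelCard;
import dev.langchain4j.model.mistralai.MistralAiModels;
import dev.langchain4j.model.output.Response;
import java.util.List;

class ModelsDemo {
    public static void main(String[] args) {
        MistralAiModels models = MistralAiModels.withApiKey(System.getenv("MISTRAL_AI_API_KEY"));
        // Single entry point replacing the removed get() and getModelDetails(..) helpers
        Response<List<MistralAiModelCard>> response = models.availableModels();
        response.content().forEach(card -> System.out.println(card.getId()));
    }
}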
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java
new file mode 100644
index 0000000000..57ec952a0b
--- /dev/null
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java
@@ -0,0 +1,57 @@
+package dev.langchain4j.model.mistralai;
+
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+import okio.Buffer;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders;
+
+class MistralAiRequestLoggingInterceptor implements Interceptor {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiRequestLoggingInterceptor.class);
+
+ @NotNull
+ @Override
+ public Response intercept(@NotNull Chain chain) throws IOException {
+ Request request = chain.request();
+ this.log(request);
+ return chain.proceed(request);
+ }
+
+ private void log(Request request) {
+ String message = "Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}";
+ try {
+ if (LOGGER.isInfoEnabled()) {
+ LOGGER.info(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)});
+ } else if (LOGGER.isWarnEnabled()) {
+ LOGGER.warn(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)});
+ } else if (LOGGER.isErrorEnabled()) {
+ LOGGER.error(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)});
+ } else {
+ LOGGER.debug(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)});
+ }
+ } catch (Exception e) {
+ LOGGER.warn("Error while logging request: {}", e.getMessage());
+ }
+ }
+
+ private static String getBody(Request request){
+ try {
+ Buffer buffer = new Buffer();
+ if (request.body() == null) {
+ return "";
+ }
+ request.body().writeTo(buffer);
+ return buffer.readUtf8();
+ } catch (Exception e) {
+ LOGGER.warn("Exception while getting body", e);
+ return "Exception while getting body: " + e.getMessage();
+ }
+ }
+}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java
new file mode 100644
index 0000000000..fafe65cfc4
--- /dev/null
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java
@@ -0,0 +1,52 @@
+package dev.langchain4j.model.mistralai;
+
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import okhttp3.Response;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders;
+
+class MistralAiResponseLoggingInterceptor implements Interceptor {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiResponseLoggingInterceptor.class);
+
+ @NotNull
+ @Override
+ public Response intercept(@NotNull Chain chain) throws IOException {
+ Request request = chain.request();
+ Response response = chain.proceed(request);
+ this.log(response);
+ return response;
+ }
+
+ private void log(Response response) {
+ String message = "Response:\n- status code: {}\n- headers: {}\n- body: {}";
+ try {
+ if (LOGGER.isInfoEnabled()) {
+ LOGGER.info(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)});
+ } else if (LOGGER.isWarnEnabled()) {
+ LOGGER.warn(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)});
+ } else if (LOGGER.isErrorEnabled()) {
+ LOGGER.error(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)});
+ } else {
+ LOGGER.debug(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)});
+ }
+ } catch (Exception e) {
+ LOGGER.warn("Error while logging response: {}", e.getMessage());
+ }
+ }
+
+ private String getBody(Response response) throws IOException {
+ return isEventStream(response) ? "[skipping response body due to streaming]" : response.peekBody(Long.MAX_VALUE).string();
+ }
+
+ private static boolean isEventStream(Response response){
+ String contentType = response.header("Content-Type");
+ return contentType != null && contentType.contains("event-stream");
+ }
+}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java
index f2dc5ec477..88cea69b66 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java
@@ -43,6 +43,9 @@ public class MistralAiStreamingChatModel implements StreamingChatLanguageModel {
* @param maxNewTokens the maximum number of new tokens to generate in a chat response
* @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses
* @param randomSeed the random seed for generating chat responses
+ * (if not specified, a random number is used)
+ * @param logRequests a flag indicating whether to log raw HTTP requests
+ * @param logResponses a flag indicating whether to log raw HTTP responses
* @param timeout the timeout duration for API requests
*/
@Builder
@@ -54,14 +57,18 @@ public MistralAiStreamingChatModel(String baseUrl,
Integer maxNewTokens,
Boolean safePrompt,
Integer randomSeed,
+ Boolean logRequests,
+ Boolean logResponses,
Duration timeout) {
this.client = MistralAiClient.builder()
.baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL)))
.apiKey(ensureNotBlankApiKey(apiKey))
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
+ .logRequests(getOrDefault(logRequests, false))
+ .logResponses(getOrDefault(logResponses, false))
.build();
- this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString());
+ this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString());
this.temperature = temperature;
this.topP = topP;
this.maxNewTokens = maxNewTokens;
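A streaming usage sketch combining the new logging flags with the StreamingResponseHandler pattern used by the pre-refactor integration tests; StreamingDemo is illustrative and assumes MISTRAL_AI_API_KEY is set:

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.output.Response;

class StreamingDemo {
    public static void main(String[] args) {
        StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
                .apiKey(System.getenv("MISTRAL_AI_API_KEY"))
                .logRequests(true)
                .logResponses(true)
                .build();

        model.generate("What is the capital of Peru?", new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                System.out.print(token); // partial tokens arrive as server-sent events
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("\nfinish reason: " + response.finishReason());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}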
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
similarity index 90%
rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java
rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
index d31762e1b9..e245f5f34c 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
@@ -19,7 +19,7 @@
*
* @see Mistral AI Endpoints
*/
-enum MistralChatCompletionModel {
+public enum MistralChatCompletionModelName {
// powered by Mistral-7B-v0.2
MISTRAL_TINY("mistral-tiny"),
@@ -30,7 +30,7 @@ enum MistralChatCompletionModel {
private final String value;
- private MistralChatCompletionModel(String value) {
+ private MistralChatCompletionModelName(String value) {
this.value = value;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java
index ebc1980c58..99e65024a8 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java
@@ -12,6 +12,6 @@
@Builder
class MistralChatMessage {
- private MistralRoleType role;
+ private MistralRoleName role;
private String content;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java
index 9dae9f54cd..76aa1b402c 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java
@@ -11,7 +11,7 @@
@Builder
class MistralDeltaMessage {
- private MistralRoleType role;
+ private MistralRoleName role;
private String content;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
similarity index 74%
rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java
rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
index 31f2abbb15..bf6c58142b 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
@@ -1,9 +1,9 @@
package dev.langchain4j.model.mistralai;
/**
- * The MistralEmbeddingModelType enum represents the available embedding models in the Mistral AI module.
+ * The MistralEmbeddingModelName enum represents the available embedding models in the Mistral AI module.
*/
-enum MistralEmbeddingModelType {
+public enum MistralEmbeddingModelName {
/**
* The MISTRAL_EMBED model.
@@ -12,7 +12,7 @@ enum MistralEmbeddingModelType {
private final String value;
- private MistralEmbeddingModelType(String value) {
+ private MistralEmbeddingModelName(String value) {
this.value = value;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java
index f9e7afff05..59f9b8654a 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java
@@ -14,6 +14,6 @@
class MistralModelResponse {
private String object;
-    private List<MistralModelCard> data;
+    private List<MistralAiModelCard> data;
}
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java
similarity index 79%
rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java
rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java
index cded6c49c7..ef54b3ad97 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java
@@ -4,12 +4,12 @@
import lombok.Getter;
@Getter
-public enum MistralRoleType {
+public enum MistralRoleName {
@SerializedName("system") SYSTEM,
@SerializedName("user") USER,
@SerializedName("assistant") ASSISTANT;
- private MistralRoleType() {}
+ private MistralRoleName() {}
}
diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java
index d65f64a8a0..f52a49aae1 100644
--- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java
+++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java
@@ -16,6 +16,8 @@ class MistralAiChatModelIT {
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
.temperature(0.1)
+ .logRequests(true)
+ .logResponses(true)
.build();
@Test
@@ -127,7 +129,7 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a
// given - Mistral Small = Mistral-8X7B
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString())
.temperature(0.1)
.build();
@@ -154,7 +156,7 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_
// given - Mistral Small = Mistral-8X7B
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString())
.temperature(0.1)
.build();
@@ -181,7 +183,7 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish
// given - Mistral Medium = currently relies on an internal prototype model.
ChatLanguageModel model = MistralAiChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString())
.maxNewTokens(10)
.build();
diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java
index 627e381a8c..83fc16d0eb 100644
--- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java
+++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java
@@ -16,6 +16,8 @@ class MistralAiEmbeddingModelIT {
EmbeddingModel model = MistralAiEmbeddingModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
+ .logRequests(true)
+ .logResponses(true)
.build();
@Test
diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java
index a3199c6d50..0e9a54edc4 100644
--- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java
+++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java
@@ -12,39 +12,15 @@ class MistralAiModelsIT {
MistralAiModels models = MistralAiModels.withApiKey(System.getenv("MISTRAL_AI_API_KEY"));
//https://docs.mistral.ai/models/
- @Test
- void should_return_all_models() {
- // when
-        Response<List<String>> response = models.get();
-
- // then
- assertThat(response.content()).isNotEmpty();
- assertThat(response.content().size()).isEqualTo(4);
- assertThat(response.content()).contains(MistralChatCompletionModel.MISTRAL_TINY.toString());
-
- }
-
- @Test
- void should_return_one_model_card(){
- // when
-        Response<MistralModelCard> response = models.getModelDetails(MistralChatCompletionModel.MISTRAL_TINY.toString());
-
- // then
- assertThat(response.content()).isNotNull();
- assertThat(response.content()).extracting("id").isEqualTo(MistralChatCompletionModel.MISTRAL_TINY.toString());
- assertThat(response.content()).extracting("object").isEqualTo("model");
- assertThat(response.content()).extracting("permission").isNotNull();
- }
-
@Test
void should_return_all_model_cards(){
// when
-        Response<List<MistralModelCard>> response = models.getModels();
+        Response<List<MistralAiModelCard>> response = models.availableModels();
// then
assertThat(response.content()).isNotEmpty();
assertThat(response.content().size()).isEqualTo(4);
- assertThat(response.content()).extracting("id").contains(MistralChatCompletionModel.MISTRAL_TINY.toString());
+ assertThat(response.content()).extracting("id").contains(MistralChatCompletionModelName.MISTRAL_TINY.toString());
assertThat(response.content()).extracting("object").contains("model");
assertThat(response.content()).extracting("permission").isNotNull();
}
diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java
index c18d6af960..63c04b9f3b 100644
--- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java
+++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java
@@ -2,21 +2,17 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
-import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.chat.TestStreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Test;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeoutException;
-
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.output.FinishReason.LENGTH;
import static dev.langchain4j.model.output.FinishReason.STOP;
import static java.util.Arrays.asList;
-import static java.util.concurrent.TimeUnit.SECONDS;
+import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
class MistralAiStreamingChatModelIT {
@@ -27,46 +23,19 @@ class MistralAiStreamingChatModelIT {
.build();
@Test
- void should_stream_answer_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_and_return_token_usage_and_finish_reason_stop() {
// given
UserMessage userMessage = userMessage("What is the capital of Peru?");
// when
- model.generate(userMessage.text(), new StreamingResponseHandler() {
-
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+ model.generate(singletonList(userMessage), handler);
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("Lima");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(15);
@@ -78,10 +47,7 @@ public void onError(Throwable error) {
}
@Test
- void should_stream_answer_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_and_return_token_usage_and_finish_reason_length() {
// given
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
@@ -93,37 +59,12 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_length() thro
UserMessage userMessage = userMessage("What is the capital of Peru?");
// when
- model.generate(userMessage.text(), new StreamingResponseHandler() {
-
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("Lima");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(15);
@@ -136,10 +77,7 @@ public void onError(Throwable error) {
//https://docs.mistral.ai/platform/guardrailing/
@Test
- void should_stream_answer_and_system_prompt_to_enforce_guardrails() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_and_system_prompt_to_enforce_guardrails() {
// given
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
@@ -151,37 +89,13 @@ void should_stream_answer_and_system_prompt_to_enforce_guardrails() throws Execu
UserMessage userMessage = userMessage("Hello, my name is Carlos");
// then
- model.generate(userMessage.text(), new StreamingResponseHandler() {
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("respect");
- assertThat(chunk).contains("truth");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("respect");
+ assertThat(response.content().text()).containsIgnoringCase("truth");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50);
@@ -190,52 +104,25 @@ public void onError(Throwable error) {
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(STOP);
-
}
@Test
- void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages() {
// given
UserMessage userMessage1 = userMessage("What is the capital of Peru?");
UserMessage userMessage2 = userMessage("What is the capital of France?");
UserMessage userMessage3 = userMessage("What is the capital of Canada?");
- model.generate(asList(userMessage1,userMessage2,userMessage3), new StreamingResponseHandler(){
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        // when
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(asList(userMessage1, userMessage2, userMessage3), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(chunk).contains("Paris");
- assertThat(chunk).contains("Ottawa");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("lima");
+ assertThat(response.content().text()).containsIgnoringCase("paris");
+ assertThat(response.content().text()).containsIgnoringCase("ottawa");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11);
@@ -244,55 +131,27 @@ public void onError(Throwable error) {
.isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
assertThat(response.finishReason()).isEqualTo(STOP);
-
-
}
@Test
- void should_stream_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop() {
// given - Mistral Small = Mistral-8X7B
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString())
.temperature(0.1)
.build();
UserMessage userMessage = userMessage("Quelle est la capitale du Pérou?");
- model.generate(userMessage.text(), new StreamingResponseHandler() {
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        // when
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("lima");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(18);
@@ -304,50 +163,24 @@ public void onError(Throwable error) {
}
@Test
- void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop() {
// given - Mistral Small = Mistral-8X7B
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString())
.temperature(0.1)
.build();
UserMessage userMessage = userMessage("¿Cuál es la capital de Perú?");
- model.generate(userMessage.text(), new StreamingResponseHandler() {
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        // when
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("lima");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(19);
@@ -359,50 +192,24 @@ public void onError(Throwable error) {
}
@Test
- void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException {
-
- CompletableFuture futureAnswer = new CompletableFuture<>();
- CompletableFuture> futureResponse = new CompletableFuture<>();
+ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() {
// given - Mistral Medium = currently relies on an internal prototype model.
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder()
.apiKey(System.getenv("MISTRAL_AI_API_KEY"))
- .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString())
+ .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString())
.maxNewTokens(10)
.build();
UserMessage userMessage = userMessage("What is the capital of Peru?");
- model.generate(userMessage.text(), new StreamingResponseHandler() {
- private final StringBuilder answerBuilder = new StringBuilder();
-
- @Override
- public void onNext(String token) {
- System.out.println("onNext: '" + token + "'");
- answerBuilder.append(token);
-
- }
-
- @Override
- public void onComplete(Response response) {
- System.out.println("onComplete: '" + response + "'");
- futureAnswer.complete(answerBuilder.toString());
- futureResponse.complete(response);
- }
-
- @Override
- public void onError(Throwable error) {
- futureAnswer.completeExceptionally(error);
- futureResponse.completeExceptionally(error);
- }
- });
-
- String chunk = futureAnswer.get(10, SECONDS);
-        Response<AiMessage> response = futureResponse.get(10, SECONDS);
+        // when
+        TestStreamingResponseHandler<AiMessage> handler = new TestStreamingResponseHandler<>();
+        model.generate(singletonList(userMessage), handler);
+        Response<AiMessage> response = handler.get();
// then
- assertThat(chunk).contains("Lima");
- assertThat(response.content().text()).isEqualTo(chunk);
+ assertThat(response.content().text()).containsIgnoringCase("lima");
TokenUsage tokenUsage = response.tokenUsage();
assertThat(tokenUsage.inputTokenCount()).isEqualTo(15);
diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
index 3c78e651bb..6403682677 100644
--- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
+++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java
@@ -376,4 +376,4 @@ void should_accept_text_and_multiple_images_from_different_sources() {
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(189);
}
-}
\ No newline at end of file
+}
From 0d87349d8b4d4bb83a74f2390328a0fa5e3ba03b Mon Sep 17 00:00:00 2001
From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com>
Date: Wed, 24 Jan 2024 11:21:44 -0500
Subject: [PATCH 17/24] Mistral AI token masking until 4 symbols
Co-authored-by: LangChain4j
---
.../dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
index cec7f89d8a..c61bb32fd3 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java
@@ -108,7 +108,7 @@ private static String maskAuthorizationHeaderValue(String authorizationHeaderVal
while (matcher.find()) {
String bearer = matcher.group(1);
String token = matcher.group(2);
- matcher.appendReplacement(sb, bearer + " " + token.substring(0, 7) + "..." + token.substring(token.length() - 7));
+ matcher.appendReplacement(sb, bearer + " " + token.substring(0, 2) + "..." + token.substring(token.length() - 2));
}
matcher.appendTail(sb);
return sb.toString();
From 7da8928e716290fa810bfb0729fb09c4c7ab8a83 Mon Sep 17 00:00:00 2001
From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com>
Date: Wed, 24 Jan 2024 11:41:49 -0500
Subject: [PATCH 18/24] MistralAI update chat model enum
Co-authored-by: LangChain4j
---
.../model/mistralai/MistralChatCompletionModelName.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
index e245f5f34c..79f7250727 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java
@@ -19,7 +19,7 @@
*
* @see Mistral AI Endpoints
*/
-public enum MistralChatCompletionModelName {
+public enum MistralAiChatModelName {
// powered by Mistral-7B-v0.2
MISTRAL_TINY("mistral-tiny"),
From 79237cc0af590c918d76d435d26a02e6c304565f Mon Sep 17 00:00:00 2001
From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com>
Date: Wed, 24 Jan 2024 14:15:04 -0500
Subject: [PATCH 19/24] MistralAI update embedding model enum
Co-authored-by: LangChain4j
---
.../langchain4j/model/mistralai/MistralEmbeddingModelName.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
index bf6c58142b..e3f497b9df 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java
@@ -3,7 +3,7 @@
/**
* The MistralEmbeddingModelName enum represents the available embedding models in the Mistral AI module.
*/
-public enum MistralEmbeddingModelName {
+public enum MistralAiEmbeddingModelName {
/**
* The MISTRAL_EMBED model.
From b51d0a5955f74e5290394898ec08690b973309ef Mon Sep 17 00:00:00 2001
From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com>
Date: Wed, 24 Jan 2024 14:17:04 -0500
Subject: [PATCH 20/24] Mistral AI fix get usageInfo from last chat completion
response
Co-authored-by: LangChain4j
---
.../java/dev/langchain4j/model/mistralai/MistralAiClient.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
index 083aa1578c..9f7f6e91ed 100644
--- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
+++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java
@@ -117,7 +117,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data
contentBuilder.append(chunk);
handler.onNext(chunk);
- MistralUsageInfo usageInfo = choice.getUsage();
+ MistralUsageInfo usageInfo = chatCompletionResponse.getUsage();
if(usageInfo != null){
this.tokenUsage = tokenUsageFrom(usageInfo);
}
From 2078d32eef3da43622e6eb96f33b1669f0972444 Mon Sep 17 00:00:00 2001
From: czelabueno
Date: Wed, 24 Jan 2024 14:31:11 -0500
Subject: [PATCH 21/24] Mistral AI fix logging streaming and rename enums
---
langchain4j-mistral-ai/pom.xml | 2 +-
.../mistralai/DefaultMistralAiHelper.java | 28 ++++++++++++++-----
.../model/mistralai/MistralAiChatModel.java | 2 +-
...lName.java => MistralAiChatModelName.java} | 4 +--
.../model/mistralai/MistralAiClient.java | 25 +++++++++++------
.../mistralai/MistralAiEmbeddingModel.java | 2 +-
....java => MistralAiEmbeddingModelName.java} | 4 +--
.../MistralAiRequestLoggingInterceptor.java | 13 ++-------
.../MistralAiResponseLoggingInterceptor.java | 14 +++-------
.../MistralAiStreamingChatModel.java | 2 +-
.../model/mistralai/MistralAiChatModelIT.java | 6 ++--
.../model/mistralai/MistralAiModelsIT.java | 5 ++--
.../MistralAiStreamingChatModelIT.java | 9 +++---
13 files changed, 62 insertions(+), 54 deletions(-)
rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralChatCompletionModelName.java => MistralAiChatModelName.java} (94%)
rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralEmbeddingModelName.java => MistralAiEmbeddingModelName.java} (73%)
diff --git a/langchain4j-mistral-ai/pom.xml b/langchain4j-mistral-ai/pom.xml
index cc02452f3e..5a9b83da99 100644
--- a/langchain4j-mistral-ai/pom.xml
+++ b/langchain4j-mistral-ai/pom.xml
@@ -55,7 +55,6 @@
2.0.7
-