From 160d53d9316a1db535e7353c4da008fa0acacffc Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:25:17 -0500 Subject: [PATCH 01/24] Add Mistral AI model provider --- langchain4j-mistral-ai/pom.xml | 113 +++++++++++++++++++++++++++++++++ pom.xml | 1 + 2 files changed, 114 insertions(+) create mode 100644 langchain4j-mistral-ai/pom.xml diff --git a/langchain4j-mistral-ai/pom.xml b/langchain4j-mistral-ai/pom.xml new file mode 100644 index 0000000000..a651175b47 --- /dev/null +++ b/langchain4j-mistral-ai/pom.xml @@ -0,0 +1,113 @@ + + 4.0.0 + + + dev.langchain4j + langchain4j-parent + 0.26.0-SNAPSHOT + ../langchain4j-parent/pom.xml + + + langchain4j-mistral-ai + jar + + LangChain4j integration with MistralAI + + + + + dev.langchain4j + langchain4j-core + + + + com.squareup.retrofit2 + retrofit + + + + com.squareup.retrofit2 + converter-gson + + + + com.google.code.gson + gson + + + + + + com.squareup.okhttp3 + okhttp + + + + com.squareup.okhttp3 + okhttp-sse + 4.10.0 + + + + org.slf4j + slf4j-api + 2.0.7 + + + + + org.projectlombok + lombok + provided + + + + org.junit.jupiter + junit-jupiter-engine + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.assertj + assertj-core + test + + + + org.tinylog + tinylog-impl + test + + + org.tinylog + slf4j-tinylog + test + + + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + + + czelabueno + c.zelabueno@gmail.com + https://github.com/czelabueno + + + + diff --git a/pom.xml b/pom.xml index dc39e03175..3541cccf3e 100644 --- a/pom.xml +++ b/pom.xml @@ -28,6 +28,7 @@ langchain4j-open-ai langchain4j-vertex-ai langchain4j-vertex-ai-gemini + langchain4j-mistral-ai langchain4j-cassandra From 0cf6b07e0ef6bf1e2c31a31fc3e0b7109a00e8d7 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:30:07 -0500 Subject: [PATCH 02/24] MistralAI chat completions req/resp --- .../model/mistralai/ChatCompletionChoice.java | 20 
+++++++++ .../model/mistralai/ChatCompletionModel.java | 42 +++++++++++++++++++ .../mistralai/ChatCompletionRequest.java | 26 ++++++++++++ .../mistralai/ChatCompletionResponse.java | 21 ++++++++++ .../model/mistralai/ChatMessage.java | 17 ++++++++ .../langchain4j/model/mistralai/Delta.java | 17 ++++++++ .../dev/langchain4j/model/mistralai/Role.java | 15 +++++++ 7 files changed, 158 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java new file mode 100644 index 0000000000..bb0d6328bb --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ChatCompletionChoice { + + private Integer index; + private ChatMessage message; + private Delta delta; + private String finishReason; + private UsageInfo usage; //usageInfo is returned only when the prompt is finished in 
stream mode +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java new file mode 100644 index 0000000000..ab697dd46a --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java @@ -0,0 +1,42 @@ +package dev.langchain4j.model.mistralai; + +/** + * Represents the available chat completion models for Mistral AI. + * + *

+ * The chat completion models are used to generate responses for chat-based applications. + * Each model has a specific power and capability level. + *

+ * + *

+ * The available chat completion models are: + *

+ *

+ * + * @see Mistral AI Endpoints + */ +enum ChatCompletionModel { + + // powered by Mistral-7B-v0.2 + MISTRAL_TINY("mistral-tiny"), + // powered Mixtral-8X7B-v0.1 + MISTRAL_SMALL("mistral-small"), + // currently relies on an internal prototype model + MISTRAL_MEDIUM("mistral-medium"); + + private final String value; + + private ChatCompletionModel(String value) { + this.value = value; + } + + public String toString() { + return this.value; + } + + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java new file mode 100644 index 0000000000..ead10e7697 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java @@ -0,0 +1,26 @@ +package dev.langchain4j.model.mistralai; + + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ChatCompletionRequest { + + private String model; + private List messages; + private Double temperature; + private Double topP; + private Integer maxTokens; + private Boolean stream; + private Boolean safePrompt; + private Integer randomSeed; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java new file mode 100644 index 0000000000..66138a3526 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java @@ -0,0 +1,21 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class 
ChatCompletionResponse { + private String id; + private String object; + private Integer created; + private String model; + private List choices; + private UsageInfo usage; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java new file mode 100644 index 0000000000..739127edb8 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java @@ -0,0 +1,17 @@ +package dev.langchain4j.model.mistralai; + + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ChatMessage { + + private Role role; + private String content; +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java new file mode 100644 index 0000000000..c7a81f20a9 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java @@ -0,0 +1,17 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class Delta { + + private Role role; + private String content; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java new file mode 100644 index 0000000000..f2802a3804 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java @@ -0,0 +1,15 @@ +package dev.langchain4j.model.mistralai; + +import com.google.gson.annotations.SerializedName; +import lombok.Getter; + +@Getter +public enum Role { + + @SerializedName("system") SYSTEM, + 
@SerializedName("user") USER, + @SerializedName("assistant") ASSISTANT; + + private Role() {} + +} From d2a47dcc13c8aab6bea461345d911b90af69e9f7 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:31:22 -0500 Subject: [PATCH 03/24] Mistral AI embeddings req/resp --- .../model/mistralai/EmbeddingModel.java | 27 +++++++++++++++++++ .../model/mistralai/EmbeddingObject.java | 20 ++++++++++++++ .../model/mistralai/EmbeddingRequest.java | 20 ++++++++++++++ .../model/mistralai/EmbeddingResponse.java | 22 +++++++++++++++ 4 files changed, 89 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java new file mode 100644 index 0000000000..4f0e8a6d4b --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java @@ -0,0 +1,27 @@ +package dev.langchain4j.model.mistralai; + +/** + * The EmbeddingModel enum represents the available embedding models in the Mistral AI module. + */ +enum EmbeddingModel { + + /** + * The MISTRAL_EMBED model. + */ + MISTRAL_EMBED("mistral-embed"); + + private final String value; + + private EmbeddingModel(String value) { + this.value = value; + } + + /** + * Returns the string representation of the embedding model. 
+ * + * @return the string representation of the embedding model + */ + public String toString() { + return this.value; + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java new file mode 100644 index 0000000000..f13f779b3f --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class EmbeddingObject { + + private String object; + private List embedding; + private Integer index; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java new file mode 100644 index 0000000000..21b8e922ee --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java @@ -0,0 +1,20 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class EmbeddingRequest { + + private String model; + private List input; + private String encodingFormat; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java new file mode 100644 index 0000000000..5dce95c35e --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java @@ -0,0 +1,22 @@ +package 
dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class EmbeddingResponse { + + private String id; + private String object; + private String model; + private List data; + private UsageInfo usage; + +} From f656f615ed8b04dc76112eed4d5cc9f08e041dd5 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:31:57 -0500 Subject: [PATCH 04/24] Mistral AI Taken usage --- .../langchain4j/model/mistralai/UsageInfo.java | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java new file mode 100644 index 0000000000..b14fbb6939 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java @@ -0,0 +1,18 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class UsageInfo { + + private Integer promptTokens; + private Integer totalTokens; + private Integer completionTokens; + +} From 62be0b265b0173a5f1020cb697ba48aa99bc32ed Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:32:46 -0500 Subject: [PATCH 05/24] Mistral AI models req/resp --- .../model/mistralai/ModelCard.java | 24 ++++++++++++++++ .../model/mistralai/ModelPermission.java | 28 +++++++++++++++++++ .../model/mistralai/ModelResponse.java | 19 +++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java create mode 100644 
langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java new file mode 100644 index 0000000000..2bdd049540 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class ModelCard{ + + private String id; + private String object; + private Integer created; + private String ownerBy; + private String root; + private String parent; + private List permission; + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java new file mode 100644 index 0000000000..39550c7812 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.mistralai; + + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ModelPermission{ + + private String id; + private String object; + private Integer created; + private Boolean allowCreateEngine; + private Boolean allowSampling; + private Boolean allowLogprobs; + private Boolean allowSearchIndices; + private Boolean allowView; + private Boolean allowFineTuning; + private String organization; + private String group; + private Boolean isBlocking; + +} diff --git 
a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java new file mode 100644 index 0000000000..c070f87e5b --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java @@ -0,0 +1,19 @@ +package dev.langchain4j.model.mistralai; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +class ModelResponse { + + private String object; + private List data; + +} From 9c95a28deb562904694749db7fe20cfbe38f145d Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:34:44 -0500 Subject: [PATCH 06/24] Mistral AI Client code --- .../mistralai/DefaultMistralAiHelper.java | 121 +++++++++++ .../model/mistralai/MistralAiApi.java | 26 +++ .../mistralai/MistralAiApiKeyInterceptor.java | 28 +++ .../model/mistralai/MistralAiClient.java | 199 ++++++++++++++++++ 4 files changed, 374 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java new file mode 100644 index 0000000000..bfb444e13a --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java @@ -0,0 +1,121 @@ +package 
dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.TokenUsage; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static dev.langchain4j.internal.Utils.isNullOrBlank; +import static dev.langchain4j.model.output.FinishReason.*; +import static java.util.stream.Collectors.toList; + +public class DefaultMistralAiHelper{ + + private static final Logger LOGGER = LoggerFactory.getLogger(DefaultMistralAiHelper.class); + static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1"; + static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float"; + + public static String ensureNotBlankApiKey(String value) { + if (isNullOrBlank(value)) { + throw new IllegalArgumentException("MistralAI API Key must be defined. It can be generated here: https://console.mistral.ai/user/api-keys/"); + } + return value; + } + + public static String formattedURLForRetrofit(String baseUrl) { + return baseUrl.endsWith("/") ? 
baseUrl : baseUrl + "/"; + } + + public static List toMistralAiMessages(List messages) { + return messages.stream() + .map(DefaultMistralAiHelper::toMistralAiMessage) + .collect(toList()); + } + + public static dev.langchain4j.model.mistralai.ChatMessage toMistralAiMessage(ChatMessage message) { + return dev.langchain4j.model.mistralai.ChatMessage.builder() + .role(toMistralAiRole(message.type())) + .content(message.text()) + .build(); + } + + private static Role toMistralAiRole(ChatMessageType chatMessageType) { + switch (chatMessageType) { + case SYSTEM: + return Role.SYSTEM; + case AI: + return Role.ASSISTANT; + case USER: + return Role.USER; + default: + throw new IllegalArgumentException("Unknown chat message type: " + chatMessageType); + } + + } + + public static TokenUsage tokenUsageFrom(UsageInfo mistralAiUsage) { + if (mistralAiUsage == null) { + return null; + } + return new TokenUsage( + mistralAiUsage.getPromptTokens(), + mistralAiUsage.getCompletionTokens(), + mistralAiUsage.getTotalTokens() + ); + } + + public static FinishReason finishReasonFrom(String mistralAiFinishReason) { + if (mistralAiFinishReason == null) { + return null; + } + switch (mistralAiFinishReason) { + case "stop": + return STOP; + case "length": + return LENGTH; + case "model_length": + default: + return null; + } + } + + public static void logResponse(Response response){ + try { + LOGGER.debug("Response code: {}", response.code()); + LOGGER.debug("Response body: {}", getResponseBody(response)); + LOGGER.debug("Response headers: {}", getResponseHeaders(response)); + } catch (IOException e) { + LOGGER.warn("Error while logging response", e); + } + } + + private static String getResponseBody(Response response) throws IOException { + return isEventStream(response) ? 
"" : response.peekBody(Long.MAX_VALUE).string(); + } + + private static String getResponseHeaders(Response response){ + return (String) StreamSupport.stream(response.headers().spliterator(),false).map(header -> { + String headerKey = header.component1(); + String headerValue = header.component2(); + if (headerKey.equals("Authorization")) { + headerValue = "Bearer " + headerValue.substring(0, 5) + "..." + headerValue.substring(headerValue.length() - 5); + } else if (headerKey.equals("api-key")) { + headerValue = headerValue.substring(0, 2) + "..." + headerValue.substring(headerValue.length() - 2); + } + return String.format("[%s: %s]", headerKey, headerValue); + }).collect(Collectors.joining(", ")); + } + private static boolean isEventStream(Response response){ + String contentType = response.header("Content-Type"); + return contentType != null && contentType.contains("event-stream"); + } + +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java new file mode 100644 index 0000000000..8657f29b95 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java @@ -0,0 +1,26 @@ +package dev.langchain4j.model.mistralai; + +import okhttp3.ResponseBody; +import retrofit2.Call; +import retrofit2.http.*; + +interface MistralAiApi { + + @POST("chat/completions") + @Headers({"Content-Type: application/json"}) + Call chatCompletion(@Body ChatCompletionRequest request); + + @POST("chat/completions") + @Headers({"Content-Type: application/json"}) + @Streaming + Call streamingChatCompletion(@Body ChatCompletionRequest request); + + @POST("embeddings") + @Headers({"Content-Type: application/json"}) + Call embedding(@Body EmbeddingRequest request); + + @GET("models") + @Headers({"Content-Type: application/json"}) + Call models(); + +} diff --git 
a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java new file mode 100644 index 0000000000..411431381e --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApiKeyInterceptor.java @@ -0,0 +1,28 @@ +package dev.langchain4j.model.mistralai; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.jetbrains.annotations.NotNull; + +import java.io.IOException; + +class MistralAiApiKeyInterceptor implements Interceptor { + + private final String apiKey; + + MistralAiApiKeyInterceptor(String apiKey) { + this.apiKey = apiKey; + } + + + @NotNull + @Override + public Response intercept(@NotNull Chain chain) throws IOException { + Request request = chain.request() + .newBuilder() + .addHeader("Authorization", "Bearer " + apiKey) + .build(); + return chain.proceed(request); + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java new file mode 100644 index 0000000000..3dc13dcf1e --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java @@ -0,0 +1,199 @@ +package dev.langchain4j.model.mistralai; + +import com.google.gson.FieldNamingPolicy; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import lombok.Builder; +import okhttp3.OkHttpClient; +import okhttp3.ResponseBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; 
+import retrofit2.Retrofit; +import retrofit2.converter.gson.GsonConverterFactory; + +import java.io.IOException; +import java.time.Duration; + +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static dev.langchain4j.model.output.FinishReason.*; + +class MistralAiClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiClient.class); + private static final Gson GSON = new GsonBuilder() + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .setPrettyPrinting() + .create(); + private final MistralAiApi mistralAiApi; + private final OkHttpClient okHttpClient; + + @Builder + public MistralAiClient(String baseUrl, String apiKey, Duration timeout) { + okHttpClient = new OkHttpClient.Builder() + .addInterceptor(new MistralAiApiKeyInterceptor(apiKey)) + .callTimeout(timeout) + .connectTimeout(timeout) + .readTimeout(timeout) + .writeTimeout(timeout) + .build(); + + Retrofit retrofit = new Retrofit.Builder() + .baseUrl(baseUrl) + .client(okHttpClient) + .addConverterFactory(GsonConverterFactory.create(GSON)) + .build(); + + mistralAiApi = retrofit.create(MistralAiApi.class); + } + + public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { + try { + retrofit2.Response retrofitResponse + = mistralAiApi.chatCompletion(request).execute(); + LOGGER.debug("ChatCompletionResponse: {}", retrofitResponse); + if (retrofitResponse.isSuccessful()) { + LOGGER.error("ChatCompletionResponseBody: {}", retrofitResponse.body()); + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void streamingChatCompletion(ChatCompletionRequest request, StreamingResponseHandler handler) { + EventSourceListener eventSourceListener = new EventSourceListener() { + StringBuilder contentBuilder = new StringBuilder(); + UsageInfo tokenUsage = new UsageInfo(); + FinishReason lastFinishReason = null; + + 
@Override + public void onOpen(EventSource eventSource, okhttp3.Response response) { + LOGGER.debug("onOpen()"); + logResponse(response); + } + + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + + LOGGER.debug("onEvent() {}", data); + if ("[DONE]".equals(data)) { + Response response = Response.from( + AiMessage.from(contentBuilder.toString()), + tokenUsageFrom(tokenUsage), + lastFinishReason + ); + handler.onComplete(response); + } else { + try { + ChatCompletionResponse chatCompletionResponse = GSON.fromJson(data, ChatCompletionResponse.class); + ChatCompletionChoice choice = chatCompletionResponse.getChoices().get(0); + String chunk = choice.getDelta().getContent(); + contentBuilder.append(chunk); + handler.onNext(chunk); + + //Retrieving token usage of the last choice + if(choice.getFinishReason() != null){ + FinishReason finishReason = finishReasonFrom(choice.getFinishReason()); + switch (finishReason){ + case STOP: + lastFinishReason = STOP; + tokenUsage = choice.getUsage(); + break; + case LENGTH: + lastFinishReason = LENGTH; + tokenUsage = choice.getUsage(); + break; + default: + break; + } + } + } catch (Exception e) { + handler.onError(e); + throw new RuntimeException(e); + } + + } + } + + @Override + public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) { + LOGGER.debug("onFailure()", t); + logResponse(response); + + if (t != null){ + handler.onError(t); + } else { + handler.onError(new RuntimeException(String.format("status code: %s; body: %s", response.code(), response.body()))); + } + } + + @Override + public void onClosed(EventSource eventSource) { + LOGGER.debug("onClosed()"); + } + + }; + + EventSources.createFactory(this.okHttpClient) + .newEventSource( + mistralAiApi.streamingChatCompletion(request).request(), + eventSourceListener); + } + + public EmbeddingResponse embedding(EmbeddingRequest request) { + try { + retrofit2.Response retrofitResponse + = 
mistralAiApi.embedding(request).execute(); + LOGGER.debug("EmbeddingResponse: {}", retrofitResponse); + if (retrofitResponse.isSuccessful()) { + LOGGER.debug("EmbeddingResponseBody: {}", retrofitResponse.body()); + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public ModelResponse listModels() { + try { + retrofit2.Response retrofitResponse + = mistralAiApi.models().execute(); + LOGGER.debug("ModelResponse: {}", retrofitResponse); + if (retrofitResponse.isSuccessful()) { + LOGGER.debug("ModelResponseBody: {}", retrofitResponse.body()); + return retrofitResponse.body(); + } else { + throw toException(retrofitResponse); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private RuntimeException toException(retrofit2.Response retrofitResponse) throws IOException { + int code = retrofitResponse.code(); + if (code >= 400) { + ResponseBody errorBody = retrofitResponse.errorBody(); + if (errorBody != null) { + String errorBodyString = errorBody.string(); + String errorMessage = String.format("status code: %s; body: %s", code, errorBodyString); + LOGGER.error("Error response: {}", errorMessage); + return new RuntimeException(errorMessage); + } + } + return new RuntimeException(retrofitResponse.message()); + } + + +} From 4e6f0084559daa3c7bc5e3f2a65e6530cad18f44 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:35:16 -0500 Subject: [PATCH 07/24] Mistral AI Chat model --- .../model/mistralai/MistralAiChatModel.java | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java new file mode 100644 index 
0000000000..321fff9c84 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java @@ -0,0 +1,113 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.data.message.AiMessage.aiMessage; +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; + +/** + * Represents a Mistral AI Chat Model with a chat completion interface, such as mistral-tiny and mistral-small. + * This model allows generating chat completion of a sync way based on a list of chat messages. + * You can find description of parameters here. + */ +public class MistralAiChatModel implements ChatLanguageModel { + + private final MistralAiClient client; + private final String modelName; + private final Double temperature; + private final Double topP; + private final Integer maxNewTokens; + private final Boolean safePrompt; + private final Integer randomSeed; + + private final Integer maxRetries; + + /** + * Constructs a MistralAiChatModel with the specified parameters. + * + * @param baseUrl the base URL of the Mistral AI API. 
It uses the default value if not specified + * @param apiKey the API key for authentication + * @param modelName the name of the Mistral AI model to use + * @param temperature the temperature parameter for generating chat responses + * @param topP the top-p parameter for generating chat responses + * @param maxNewTokens the maximum number of new tokens to generate in a chat response + * @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses + * @param randomSeed the random seed for generating chat responses + * @param timeout the timeout duration for API requests + * @param maxRetries the maximum number of retries for API requests + */ + @Builder + public MistralAiChatModel(String baseUrl, + String apiKey, + String modelName, + Double temperature, + Double topP, + Integer maxNewTokens, + Boolean safePrompt, + Integer randomSeed, + Duration timeout, + Integer maxRetries) { + + this.client = MistralAiClient.builder() + .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) + .apiKey(ensureNotBlankApiKey(apiKey)) + .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .build(); + this.modelName = getOrDefault(modelName, ChatCompletionModel.MISTRAL_TINY.toString()); + this.temperature = temperature; + this.topP = topP; + this.maxNewTokens = maxNewTokens; + this.safePrompt = safePrompt; + this.randomSeed = randomSeed; + this.maxRetries = getOrDefault(maxRetries, 3); + + } + + /** + * Creates a MistralAiChatModel with the specified API key. + * + * @param apiKey the API key for authentication + * @return a MistralAiChatModel instance + */ + public static MistralAiChatModel withApiKey(String apiKey) { + return builder().apiKey(apiKey).build(); + } + + /** + * Generates chat response based on the given list of messages. 
+ * + * @param messages the list of chat messages + */ + @Override + public Response generate(List messages) { + ensureNotEmpty(messages, "messages"); + + ChatCompletionRequest request = ChatCompletionRequest.builder() + .model(this.modelName) + .messages(toMistralAiMessages(messages)) + .temperature(this.temperature) + .maxTokens(this.maxNewTokens) + .topP(this.topP) + .randomSeed(this.randomSeed) + .safePrompt(this.safePrompt) + .stream(false) + .build(); + + ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries); + return Response.from( + aiMessage(response.getChoices().get(0).getMessage().getContent()), + tokenUsageFrom(response.getUsage()), + finishReasonFrom(response.getChoices().get(0).getFinishReason()) + ); + } +} From c8583890e3cfebff2e501d43d526f5528c299bc0 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:35:28 -0500 Subject: [PATCH 08/24] Mistral AI Chat Streaming model support --- .../MistralAiStreamingChatModel.java | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java new file mode 100644 index 0000000000..2af5483174 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java @@ -0,0 +1,105 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import 
static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; + + +/** + * Represents a Mistral AI Chat Model with a chat completion interface, such as mistral-tiny and mistral-small. + * The model's response is streamed token by token and should be handled with {@link StreamingResponseHandler}. + * You can find description of parameters here. + */ +/** + * Represents a streaming chat model for Mistral AI. + */ +public class MistralAiStreamingChatModel implements StreamingChatLanguageModel { + + private final MistralAiClient client; + private final String modelName; + private final Double temperature; + private final Double topP; + private final Integer maxNewTokens; + private final Boolean safePrompt; + private final Integer randomSeed; + + /** + * Constructs a MistralAiStreamingChatModel with the specified parameters. + * + * @param baseUrl the base URL of the Mistral AI API. It uses the default value if not specified + * @param apiKey the API key for authentication + * @param modelName the name of the Mistral AI model to use + * @param temperature the temperature parameter for generating chat responses + * @param topP the top-p parameter for generating chat responses + * @param maxNewTokens the maximum number of new tokens to generate in a chat response + * @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses + * @param randomSeed the random seed for generating chat responses + * @param timeout the timeout duration for API requests + */ + @Builder + public MistralAiStreamingChatModel(String baseUrl, + String apiKey, + String modelName, + Double temperature, + Double topP, + Integer maxNewTokens, + Boolean safePrompt, + Integer randomSeed, + Duration timeout) { + + this.client = MistralAiClient.builder() + .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) + .apiKey(ensureNotBlankApiKey(apiKey)) + 
.timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .build(); + this.modelName = getOrDefault(modelName, ChatCompletionModel.MISTRAL_TINY.toString()); + this.temperature = temperature; + this.topP = topP; + this.maxNewTokens = maxNewTokens; + this.safePrompt = safePrompt; + this.randomSeed = randomSeed; + } + + /** + * Creates a MistralAiStreamingChatModel with the specified API key. + * + * @param apiKey the API key for authentication + * @return a MistralAiStreamingChatModel instance + */ + public static MistralAiStreamingChatModel withApiKey(String apiKey) { + return builder().apiKey(apiKey).build(); + } + + /** + * Generates streamed token response based on the given list of messages. + * + * @param messages the list of chat messages + * @param handler the response handler for processing the generated chat chunk responses + */ + @Override + public void generate(List messages, StreamingResponseHandler handler) { + ensureNotEmpty(messages, "messages"); + + ChatCompletionRequest request = ChatCompletionRequest.builder() + .model(this.modelName) + .messages(toMistralAiMessages(messages)) + .temperature(this.temperature) + .maxTokens(this.maxNewTokens) + .topP(this.topP) + .randomSeed(this.randomSeed) + .safePrompt(this.safePrompt) + .stream(true) + .build(); + + client.streamingChatCompletion(request, handler); + } +} From a3c9fc8992e6df685b6ed0484604b9d55cb4d49c Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:35:41 -0500 Subject: [PATCH 09/24] Mistral AI embedding model support --- .../mistralai/MistralAiEmbeddingModel.java | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java new file mode 100644 index 
0000000000..b31834e861 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java @@ -0,0 +1,87 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static java.util.stream.Collectors.toList; + +/** + * Represents a Mistral AI embedding model, such as mistral-embed. + * You can find description of parameters here. + */ +public class MistralAiEmbeddingModel implements EmbeddingModel { + + private final MistralAiClient client; + private final String modelName; + private final Integer maxRetries; + + /** + * Constructs a new MistralAiEmbeddingModel instance. + * + * @param baseUrl the base URL of the Mistral AI API. It use a default value if not specified + * @param apiKey the API key for authentication + * @param modelName the name of the embedding model. It uses a default value if not specified + * @param timeout the timeout duration for API requests. It uses a default value of 60 seconds if not specified + * @param maxRetries the maximum number of retries for API requests. 
It uses a default value of 3 if not specified + */ + @Builder + public MistralAiEmbeddingModel(String baseUrl, + String apiKey, + String modelName, + Duration timeout, + Integer maxRetries) { + this.client = MistralAiClient.builder() + .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) + .apiKey(ensureNotBlankApiKey(apiKey)) + .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .build(); + this.modelName = getOrDefault(modelName, dev.langchain4j.model.mistralai.EmbeddingModel.MISTRAL_EMBED.toString()); + this.maxRetries = getOrDefault(maxRetries, 3); + } + + /** + * Creates a new MistralAiEmbeddingModel instance with the specified API key. + * + * @param apiKey the Mistral AI API key for authentication + * @return a new MistralAiEmbeddingModel instance + */ + public static MistralAiEmbeddingModel withApiKey(String apiKey) { + return builder().apiKey(apiKey).build(); + } + + /** + * Embeds a list of text segments using the Mistral AI embedding model. + * + * @param textSegments the list of text segments to embed + * @return a Response object containing the embeddings and token usage information + */ + @Override + public Response> embedAll(List textSegments) { + + EmbeddingRequest request = EmbeddingRequest.builder() + .model(modelName) + .input(textSegments.stream().map(TextSegment::text).collect(toList())) + .encodingFormat(MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT) + .build(); + + EmbeddingResponse response = withRetry(() -> client.embedding(request), maxRetries); + + List embeddings = response.getData().stream() + .map(mistralAiEmbedding -> Embedding.from(mistralAiEmbedding.getEmbedding())) + .collect(toList()); + + return Response.from( + embeddings, + tokenUsageFrom(response.getUsage()) + ); + } +} From 2293142ff7596d972dc7f775f35cf54dc6a70b23 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:35:58 -0500 Subject: [PATCH 10/24] Mistral Ai get models from API --- .../model/mistralai/MistralAiModels.java | 88 
+++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java new file mode 100644 index 0000000000..80f75defee --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java @@ -0,0 +1,88 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.model.output.Response; +import lombok.Builder; + +import java.time.Duration; +import java.util.List; + +import static dev.langchain4j.internal.RetryUtils.withRetry; +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; +import static java.util.stream.Collectors.toList; + +/** + * Represents a collection of Mistral AI models. + * You can find description of parameters here. + */ +public class MistralAiModels { + + private final MistralAiClient client; + private final Integer maxRetries; + + /** + * Constructs a new instance of MistralAiModels. + * + * @param baseUrl the base URL of the Mistral AI API. It uses the default value if not specified + * @param apiKey the API key for authentication + * @param timeout the timeout duration for API requests. It uses the default value of 60 seconds if not specified + * @param maxRetries the maximum number of retries for API requests. 
It uses the default value of 3 if not specified + */ + @Builder + public MistralAiModels(String baseUrl, + String apiKey, + Duration timeout, + Integer maxRetries) { + this.client = MistralAiClient.builder() + .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) + .apiKey(ensureNotBlankApiKey(apiKey)) + .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .build(); + this.maxRetries = getOrDefault(maxRetries, 3); + } + + /** + * Creates a new instance of MistralAiModels with the specified API key. + * + * @param apiKey the API key for authentication + * @return a new instance of MistralAiModels + */ + public static MistralAiModels withApiKey(String apiKey) { + return builder().apiKey(apiKey).build(); + } + + /** + * Retrieves the details of a specific model. + * + * @param modelId the ID of the model + * @return the response containing the model details + */ + public Response getModelDetails(String modelId){ + return Response.from( + this.getModels().content().stream().filter(modelCard -> modelCard.getId().equals(modelId)).findFirst().orElse(null) + ); + } + + /** + * Retrieves the IDs of all available models. + * + * @return the response containing the list of model IDs + */ + public Response> get(){ + return Response.from( + this.getModels().content().stream().map(ModelCard::getId).collect(toList()) + ); + } + + /** + * Retrieves the list of all available models. 
+ * + * @return the response containing the list of models + */ + public Response> getModels(){ + ModelResponse response = withRetry(client::listModels, maxRetries); + return Response.from( + response.getData() + ); + } +} From 06e7ccee13b06a1e8f704922f91f903802f8ff47 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:36:33 -0500 Subject: [PATCH 11/24] Mistral AI chat model tests --- .../model/mistralai/MistralAiChatModelIT.java | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) create mode 100644 langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java new file mode 100644 index 0000000000..a217143e42 --- /dev/null +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java @@ -0,0 +1,206 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.Test; + +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static org.assertj.core.api.Assertions.assertThat; +import static dev.langchain4j.model.output.FinishReason.*; + +class MistralAiChatModelIT { + + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .temperature(0.1) + .build(); + + @Test + void should_generate_answer_and_return_token_usage_and_finish_reason_stop(){ + + // given + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + // when + Response response = model.generate(userMessage); + + // then + 
assertThat(response.content().text()).contains("Lima"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_generate_answer_and_return_token_usage_and_finish_reason_length(){ + + // given + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .maxNewTokens(4) + .build(); + + // given + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + // when + Response response = model.generate(userMessage); + + // then + assertThat(response.content().text()).isNotBlank(); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(4); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(LENGTH); + } + + //https://docs.mistral.ai/platform/guardrailing/ + @Test + void should_generate_system_prompt_to_enforce_guardrails(){ + // given + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .safePrompt(true) + .build(); + + // given + UserMessage userMessage = userMessage("Hello, my name is Carlos"); + + // then + Response response = model.generate(userMessage); + + // then + AiMessage aiMessage = response.content(); + assertThat(aiMessage.text()).contains("respect"); + assertThat(aiMessage.text()).contains("truth"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + 
assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + + } + + @Test + void should_generate_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages(){ + + // given + UserMessage userMessage1 = userMessage("What is the capital of Peru?"); + UserMessage userMessage2 = userMessage("What is the capital of France?"); + UserMessage userMessage3 = userMessage("What is the capital of Canada?"); + + // when + Response response = model.generate(userMessage1, userMessage2, userMessage3); + + // then + assertThat(response.content().text()).contains("Lima"); + assertThat(response.content().text()).contains("Paris"); + assertThat(response.content().text()).contains("Ottawa"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_generate_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop(){ + + // given - Mistral Small = Mistral-8X7B + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .temperature(0.1) + .build(); + + UserMessage userMessage = userMessage("Quelle est la capitale du Pérou?"); + + // when + Response response = model.generate(userMessage); + + // then + assertThat(response.content().text()).contains("Lima"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(18); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + 
tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop(){ + + // given - Mistral Small = Mistral-8X7B + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .temperature(0.1) + .build(); + + UserMessage userMessage = userMessage("¿Cuál es la capital de Perú?"); + + // when + Response response = model.generate(userMessage); + + // then + assertThat(response.content().text()).contains("Lima"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(19); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_generate_answer_using_model_medium_and_return_token_usage_and_finish_reason_length(){ + + // given - Mistral Medium = currently relies on an internal prototype model. 
+ ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_MEDIUM.toString()) + .maxNewTokens(10) + .build(); + + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + // when + Response response = model.generate(userMessage); + + // then + assertThat(response.content().text()).contains("Lima"); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(LENGTH); + } + + +} From 13f056bd613f93b46b171534f74a35a9c690a2d1 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:36:50 -0500 Subject: [PATCH 12/24] Mistral AI embeddings model tests --- .../mistralai/MistralAiEmbeddingModelIT.java | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java new file mode 100644 index 0000000000..627e381a8c --- /dev/null +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java @@ -0,0 +1,64 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.util.Arrays.asList; +import static 
org.assertj.core.api.Assertions.assertThat; + +class MistralAiEmbeddingModelIT { + + EmbeddingModel model = MistralAiEmbeddingModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .build(); + + @Test + void should_embed_and_return_token_usage() { + + // given + TextSegment textSegment = TextSegment.from("Embed this sentence."); + + // when + Response response = model.embed(textSegment); + + // then + assertThat(response.content().vector()).hasSize(1024); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(7); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); + assertThat(tokenUsage.totalTokenCount()).isEqualTo(7); + + assertThat(response.finishReason()).isNull(); + } + + @Test + void should_embed_and_return_token_usage_with_multiple_inputs(){ + + // given + TextSegment textSegment1 = TextSegment.from("Embed this sentence."); + TextSegment textSegment2 = TextSegment.from("As well as this one."); + + // when + Response> response = model.embedAll(asList(textSegment1, textSegment2)); + + // then + assertThat(response.content().size()).isEqualTo(2); + assertThat(response.content().get(0).vector()).hasSize(1024); + assertThat(response.content().get(1).vector()).hasSize(1024); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(7 + 8); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(0); + assertThat(tokenUsage.totalTokenCount()).isEqualTo(7 + 8); + + assertThat(response.finishReason()).isNull(); + + } +} From ba1930611cf9af1df46473de96666ba9a611fc8d Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:37:04 -0500 Subject: [PATCH 13/24] Mistral AI chat streaming model tests --- .../MistralAiStreamingChatModelIT.java | 415 ++++++++++++++++++ 1 file changed, 415 insertions(+) create mode 100644 langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java diff --git 
a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java new file mode 100644 index 0000000000..f3f1e35a93 --- /dev/null +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java @@ -0,0 +1,415 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.output.FinishReason.LENGTH; +import static dev.langchain4j.model.output.FinishReason.STOP; +import static java.util.Arrays.asList; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +class MistralAiStreamingChatModelIT { + + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .temperature(0.1) + .build(); + + @Test + void should_stream_answer_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + // when + model.generate(userMessage.text(), new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new 
StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("Lima"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_stream_answer_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .maxNewTokens(10) + .build(); + + // given + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + // when + model.generate(userMessage.text(), new StreamingResponseHandler() { + + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + 
System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("Lima"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(LENGTH); + } + + //https://docs.mistral.ai/platform/guardrailing/ + @Test + void should_stream_answer_and_system_prompt_to_enforce_guardrails() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .safePrompt(true) + .build(); + + // given + UserMessage userMessage = userMessage("Hello, my name is Carlos"); + + // then + model.generate(userMessage.text(), new StreamingResponseHandler() { + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + 
futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("respect"); + assertThat(chunk).contains("truth"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + + } + + @Test + void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given + UserMessage userMessage1 = userMessage("What is the capital of Peru?"); + UserMessage userMessage2 = userMessage("What is the capital of France?"); + UserMessage userMessage3 = userMessage("What is the capital of Canada?"); + + model.generate(asList(userMessage1,userMessage2,userMessage3), new StreamingResponseHandler(){ + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = 
futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("Lima"); + assertThat(chunk).contains("Paris"); + assertThat(chunk).contains("Ottawa"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + + + } + + @Test + void should_stream_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given - Mistral Small = Mistral-8X7B + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .temperature(0.1) + .build(); + + UserMessage userMessage = userMessage("Quelle est la capitale du Pérou?"); + + model.generate(userMessage.text(), new StreamingResponseHandler() { + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + 
assertThat(chunk).contains("Lima"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(18); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given - Mistral Small = Mistral-8X7B + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .temperature(0.1) + .build(); + + UserMessage userMessage = userMessage("¿Cuál es la capital de Perú?"); + + model.generate(userMessage.text(), new StreamingResponseHandler() { + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("Lima"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + 
assertThat(tokenUsage.inputTokenCount()).isEqualTo(19); + assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0); + assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(STOP); + } + + @Test + void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException { + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + + // given - Mistral Medium = currently relies on an internal prototype model. + StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(ChatCompletionModel.MISTRAL_MEDIUM.toString()) + .maxNewTokens(10) + .build(); + + UserMessage userMessage = userMessage("What is the capital of Peru?"); + + model.generate(userMessage.text(), new StreamingResponseHandler() { + private final StringBuilder answerBuilder = new StringBuilder(); + + @Override + public void onNext(String token) { + System.out.println("onNext: '" + token + "'"); + answerBuilder.append(token); + + } + + @Override + public void onComplete(Response response) { + System.out.println("onComplete: '" + response + "'"); + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable error) { + futureAnswer.completeExceptionally(error); + futureResponse.completeExceptionally(error); + } + }); + + String chunk = futureAnswer.get(10, SECONDS); + Response response = futureResponse.get(10, SECONDS); + + // then + assertThat(chunk).contains("Lima"); + assertThat(response.content().text()).isEqualTo(chunk); + + TokenUsage tokenUsage = response.tokenUsage(); + assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); + assertThat(tokenUsage.outputTokenCount()).isEqualTo(10); + 
assertThat(tokenUsage.totalTokenCount()) + .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); + + assertThat(response.finishReason()).isEqualTo(LENGTH); + } +} From f09458f03844e5f7796ab079f1e8ff9ec6428044 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Mon, 15 Jan 2024 17:37:16 -0500 Subject: [PATCH 14/24] Mistral AI get models tests --- .../model/mistralai/MistralAiModelsIT.java | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java new file mode 100644 index 0000000000..4a059e2cd2 --- /dev/null +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java @@ -0,0 +1,51 @@ +package dev.langchain4j.model.mistralai; + +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +class MistralAiModelsIT { + + MistralAiModels models = MistralAiModels.withApiKey(System.getenv("MISTRAL_AI_API_KEY")); + + //https://docs.mistral.ai/models/ + @Test + void should_return_all_models() { + // when + Response> response = models.get(); + + // then + assertThat(response.content()).isNotEmpty(); + assertThat(response.content().size()).isEqualTo(4); + assertThat(response.content()).contains(ChatCompletionModel.MISTRAL_TINY.toString()); + + } + + @Test + void should_return_one_model_card(){ + // when + Response response = models.getModelDetails(ChatCompletionModel.MISTRAL_TINY.toString()); + + // then + assertThat(response.content()).isNotNull(); + assertThat(response.content()).extracting("id").isEqualTo(ChatCompletionModel.MISTRAL_TINY.toString()); + 
assertThat(response.content()).extracting("object").isEqualTo("model"); + assertThat(response.content()).extracting("permission").isNotNull(); + } + + @Test + void should_return_all_model_cards(){ + // when + Response> response = models.getModels(); + + // then + assertThat(response.content()).isNotEmpty(); + assertThat(response.content().size()).isEqualTo(4); + assertThat(response.content()).extracting("id").contains(ChatCompletionModel.MISTRAL_TINY.toString()); + assertThat(response.content()).extracting("object").contains("model"); + assertThat(response.content()).extracting("permission").isNotNull(); + } +} From 9117ee338d0c0bc93b43554b93c46cb2fac8087a Mon Sep 17 00:00:00 2001 From: czelabueno Date: Tue, 16 Jan 2024 23:55:48 -0500 Subject: [PATCH 15/24] MistralAI - renamed classes to the project convention names to avoid collisions with core libs --- .../mistralai/DefaultMistralAiHelper.java | 16 +++++------ .../model/mistralai/MistralAiApi.java | 8 +++--- .../model/mistralai/MistralAiChatModel.java | 6 ++-- .../model/mistralai/MistralAiClient.java | 28 +++++++++---------- .../mistralai/MistralAiEmbeddingModel.java | 6 ++-- .../model/mistralai/MistralAiModels.java | 8 +++--- .../MistralAiStreamingChatModel.java | 4 +-- ....java => MistralChatCompletionChoice.java} | 8 +++--- ...l.java => MistralChatCompletionModel.java} | 4 +-- ...java => MistralChatCompletionRequest.java} | 4 +-- ...ava => MistralChatCompletionResponse.java} | 6 ++-- ...atMessage.java => MistralChatMessage.java} | 4 +-- .../{Delta.java => MistralDeltaMessage.java} | 4 +-- ...el.java => MistralEmbeddingModelType.java} | 6 ++-- ...bject.java => MistralEmbeddingObject.java} | 2 +- ...uest.java => MistralEmbeddingRequest.java} | 2 +- ...nse.java => MistralEmbeddingResponse.java} | 6 ++-- .../{ModelCard.java => MistralModelCard.java} | 4 +-- ...ssion.java => MistralModelPermission.java} | 2 +- ...esponse.java => MistralModelResponse.java} | 4 +-- .../{Role.java => MistralRoleType.java} | 4 +-- 
.../{UsageInfo.java => MistralUsageInfo.java} | 2 +- .../model/mistralai/MistralAiChatModelIT.java | 6 ++-- .../model/mistralai/MistralAiModelsIT.java | 10 +++---- .../MistralAiStreamingChatModelIT.java | 6 ++-- 25 files changed, 80 insertions(+), 80 deletions(-) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ChatCompletionChoice.java => MistralChatCompletionChoice.java} (54%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ChatCompletionModel.java => MistralChatCompletionModel.java} (92%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ChatCompletionRequest.java => MistralChatCompletionRequest.java} (84%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ChatCompletionResponse.java => MistralChatCompletionResponse.java} (72%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ChatMessage.java => MistralChatMessage.java} (79%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{Delta.java => MistralDeltaMessage.java} (79%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{EmbeddingModel.java => MistralEmbeddingModelType.java} (68%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{EmbeddingObject.java => MistralEmbeddingObject.java} (91%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{EmbeddingRequest.java => MistralEmbeddingRequest.java} (91%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{EmbeddingResponse.java => MistralEmbeddingResponse.java} (72%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ModelCard.java => MistralModelCard.java} (82%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ModelPermission.java => MistralModelPermission.java} (95%) rename 
langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{ModelResponse.java => MistralModelResponse.java} (78%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{Role.java => MistralRoleType.java} (79%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{UsageInfo.java => MistralUsageInfo.java} (92%) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java index bfb444e13a..2e00d9fc40 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java @@ -34,34 +34,34 @@ public static String formattedURLForRetrofit(String baseUrl) { return baseUrl.endsWith("/") ? baseUrl : baseUrl + "/"; } - public static List toMistralAiMessages(List messages) { + public static List toMistralAiMessages(List messages) { return messages.stream() .map(DefaultMistralAiHelper::toMistralAiMessage) .collect(toList()); } - public static dev.langchain4j.model.mistralai.ChatMessage toMistralAiMessage(ChatMessage message) { - return dev.langchain4j.model.mistralai.ChatMessage.builder() + public static MistralChatMessage toMistralAiMessage(ChatMessage message) { + return MistralChatMessage.builder() .role(toMistralAiRole(message.type())) .content(message.text()) .build(); } - private static Role toMistralAiRole(ChatMessageType chatMessageType) { + private static MistralRoleType toMistralAiRole(ChatMessageType chatMessageType) { switch (chatMessageType) { case SYSTEM: - return Role.SYSTEM; + return MistralRoleType.SYSTEM; case AI: - return Role.ASSISTANT; + return MistralRoleType.ASSISTANT; case USER: - return Role.USER; + return MistralRoleType.USER; default: throw new IllegalArgumentException("Unknown chat message type: " + 
chatMessageType); } } - public static TokenUsage tokenUsageFrom(UsageInfo mistralAiUsage) { + public static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) { if (mistralAiUsage == null) { return null; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java index 8657f29b95..92bf7d954d 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiApi.java @@ -8,19 +8,19 @@ interface MistralAiApi { @POST("chat/completions") @Headers({"Content-Type: application/json"}) - Call chatCompletion(@Body ChatCompletionRequest request); + Call chatCompletion(@Body MistralChatCompletionRequest request); @POST("chat/completions") @Headers({"Content-Type: application/json"}) @Streaming - Call streamingChatCompletion(@Body ChatCompletionRequest request); + Call streamingChatCompletion(@Body MistralChatCompletionRequest request); @POST("embeddings") @Headers({"Content-Type: application/json"}) - Call embedding(@Body EmbeddingRequest request); + Call embedding(@Body MistralEmbeddingRequest request); @GET("models") @Headers({"Content-Type: application/json"}) - Call models(); + Call models(); } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java index 321fff9c84..5c6f5f8055 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java @@ -63,7 +63,7 @@ public MistralAiChatModel(String baseUrl, .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .build(); - this.modelName = 
getOrDefault(modelName, ChatCompletionModel.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; @@ -92,7 +92,7 @@ public static MistralAiChatModel withApiKey(String apiKey) { public Response generate(List messages) { ensureNotEmpty(messages, "messages"); - ChatCompletionRequest request = ChatCompletionRequest.builder() + MistralChatCompletionRequest request = MistralChatCompletionRequest.builder() .model(this.modelName) .messages(toMistralAiMessages(messages)) .temperature(this.temperature) @@ -103,7 +103,7 @@ public Response generate(List messages) { .stream(false) .build(); - ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries); + MistralChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries); return Response.from( aiMessage(response.getChoices().get(0).getMessage().getContent()), tokenUsageFrom(response.getUsage()), diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java index 3dc13dcf1e..6383dc8ce2 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java @@ -53,13 +53,13 @@ public MistralAiClient(String baseUrl, String apiKey, Duration timeout) { mistralAiApi = retrofit.create(MistralAiApi.class); } - public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { + public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest request) { try { - retrofit2.Response retrofitResponse + retrofit2.Response retrofitResponse = mistralAiApi.chatCompletion(request).execute(); - LOGGER.debug("ChatCompletionResponse: {}", 
retrofitResponse); + LOGGER.debug("MistralChatCompletionResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { - LOGGER.error("ChatCompletionResponseBody: {}", retrofitResponse.body()); + LOGGER.debug("ChatCompletionResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); } else { throw toException(retrofitResponse); @@ -69,10 +69,10 @@ public ChatCompletionResponse chatCompletion(ChatCompletionRequest request) { } } - public void streamingChatCompletion(ChatCompletionRequest request, StreamingResponseHandler handler) { + public void streamingChatCompletion(MistralChatCompletionRequest request, StreamingResponseHandler handler) { EventSourceListener eventSourceListener = new EventSourceListener() { StringBuilder contentBuilder = new StringBuilder(); - UsageInfo tokenUsage = new UsageInfo(); + MistralUsageInfo tokenUsage = new MistralUsageInfo(); FinishReason lastFinishReason = null; @Override @@ -94,8 +94,8 @@ public void onEvent(EventSource eventSource, String id, String type, String data handler.onComplete(response); } else { try { - ChatCompletionResponse chatCompletionResponse = GSON.fromJson(data, ChatCompletionResponse.class); - ChatCompletionChoice choice = chatCompletionResponse.getChoices().get(0); + MistralChatCompletionResponse chatCompletionResponse = GSON.fromJson(data, MistralChatCompletionResponse.class); + MistralChatCompletionChoice choice = chatCompletionResponse.getChoices().get(0); String chunk = choice.getDelta().getContent(); contentBuilder.append(chunk); handler.onNext(chunk); @@ -149,11 +149,11 @@ public void onClosed(EventSource eventSource) { eventSourceListener); } - public EmbeddingResponse embedding(EmbeddingRequest request) { + public MistralEmbeddingResponse embedding(MistralEmbeddingRequest request) { try { - retrofit2.Response retrofitResponse + retrofit2.Response retrofitResponse = mistralAiApi.embedding(request).execute(); - LOGGER.debug("EmbeddingResponse: {}", retrofitResponse); + 
LOGGER.debug("MistralEmbeddingResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { LOGGER.debug("EmbeddingResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); @@ -165,11 +165,11 @@ public EmbeddingResponse embedding(EmbeddingRequest request) { } } - public ModelResponse listModels() { + public MistralModelResponse listModels() { try { - retrofit2.Response retrofitResponse + retrofit2.Response retrofitResponse = mistralAiApi.models().execute(); - LOGGER.debug("ModelResponse: {}", retrofitResponse); + LOGGER.debug("MistralModelResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { LOGGER.debug("ModelResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java index b31834e861..404df16dbc 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java @@ -44,7 +44,7 @@ public MistralAiEmbeddingModel(String baseUrl, .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .build(); - this.modelName = getOrDefault(modelName, dev.langchain4j.model.mistralai.EmbeddingModel.MISTRAL_EMBED.toString()); + this.modelName = getOrDefault(modelName, MistralEmbeddingModelType.MISTRAL_EMBED.toString()); this.maxRetries = getOrDefault(maxRetries, 3); } @@ -67,13 +67,13 @@ public static MistralAiEmbeddingModel withApiKey(String apiKey) { @Override public Response> embedAll(List textSegments) { - EmbeddingRequest request = EmbeddingRequest.builder() + MistralEmbeddingRequest request = MistralEmbeddingRequest.builder() .model(modelName) .input(textSegments.stream().map(TextSegment::text).collect(toList())) 
.encodingFormat(MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT) .build(); - EmbeddingResponse response = withRetry(() -> client.embedding(request), maxRetries); + MistralEmbeddingResponse response = withRetry(() -> client.embedding(request), maxRetries); List embeddings = response.getData().stream() .map(mistralAiEmbedding -> Embedding.from(mistralAiEmbedding.getEmbedding())) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java index 80f75defee..e741f76bef 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java @@ -57,7 +57,7 @@ public static MistralAiModels withApiKey(String apiKey) { * @param modelId the ID of the model * @return the response containing the model details */ - public Response getModelDetails(String modelId){ + public Response getModelDetails(String modelId){ return Response.from( this.getModels().content().stream().filter(modelCard -> modelCard.getId().equals(modelId)).findFirst().orElse(null) ); @@ -70,7 +70,7 @@ public Response getModelDetails(String modelId){ */ public Response> get(){ return Response.from( - this.getModels().content().stream().map(ModelCard::getId).collect(toList()) + this.getModels().content().stream().map(MistralModelCard::getId).collect(toList()) ); } @@ -79,8 +79,8 @@ public Response> get(){ * * @return the response containing the list of models */ - public Response> getModels(){ - ModelResponse response = withRetry(client::listModels, maxRetries); + public Response> getModels(){ + MistralModelResponse response = withRetry(client::listModels, maxRetries); return Response.from( response.getData() ); diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java 
b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java index 2af5483174..f2dc5ec477 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java @@ -61,7 +61,7 @@ public MistralAiStreamingChatModel(String baseUrl, .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .build(); - this.modelName = getOrDefault(modelName, ChatCompletionModel.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; @@ -89,7 +89,7 @@ public static MistralAiStreamingChatModel withApiKey(String apiKey) { public void generate(List messages, StreamingResponseHandler handler) { ensureNotEmpty(messages, "messages"); - ChatCompletionRequest request = ChatCompletionRequest.builder() + MistralChatCompletionRequest request = MistralChatCompletionRequest.builder() .model(this.modelName) .messages(toMistralAiMessages(messages)) .temperature(this.temperature) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionChoice.java similarity index 54% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionChoice.java index bb0d6328bb..6b26d8ea2a 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionChoice.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionChoice.java @@ -10,11 +10,11 @@ @NoArgsConstructor @AllArgsConstructor 
@Builder -class ChatCompletionChoice { +class MistralChatCompletionChoice { private Integer index; - private ChatMessage message; - private Delta delta; + private MistralChatMessage message; + private MistralDeltaMessage delta; private String finishReason; - private UsageInfo usage; //usageInfo is returned only when the prompt is finished in stream mode + private MistralUsageInfo usage; //usageInfo is returned only when the prompt is finished in stream mode } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java similarity index 92% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java index ab697dd46a..d31762e1b9 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java @@ -19,7 +19,7 @@ * * @see Mistral AI Endpoints */ -enum ChatCompletionModel { +enum MistralChatCompletionModel { // powered by Mistral-7B-v0.2 MISTRAL_TINY("mistral-tiny"), @@ -30,7 +30,7 @@ enum ChatCompletionModel { private final String value; - private ChatCompletionModel(String value) { + private MistralChatCompletionModel(String value) { this.value = value; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionRequest.java similarity index 84% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionRequest.java index 
ead10e7697..f08955ebb6 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionRequest.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionRequest.java @@ -12,10 +12,10 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class ChatCompletionRequest { +class MistralChatCompletionRequest { private String model; - private List messages; + private List messages; private Double temperature; private Double topP; private Integer maxTokens; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionResponse.java similarity index 72% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionResponse.java index 66138a3526..7106e80309 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatCompletionResponse.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionResponse.java @@ -11,11 +11,11 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class ChatCompletionResponse { +class MistralChatCompletionResponse { private String id; private String object; private Integer created; private String model; - private List choices; - private UsageInfo usage; + private List choices; + private MistralUsageInfo usage; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java similarity index 79% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java index 
739127edb8..ebc1980c58 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ChatMessage.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java @@ -10,8 +10,8 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class ChatMessage { +class MistralChatMessage { - private Role role; + private MistralRoleType role; private String content; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java similarity index 79% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java index c7a81f20a9..9dae9f54cd 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Delta.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java @@ -9,9 +9,9 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class Delta { +class MistralDeltaMessage { - private Role role; + private MistralRoleType role; private String content; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java similarity index 68% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java index 4f0e8a6d4b..31f2abbb15 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java @@ -1,9 +1,9 @@ package dev.langchain4j.model.mistralai; /** - * The EmbeddingModel enum 
represents the available embedding models in the Mistral AI module. + * The MistralEmbeddingModelType enum represents the available embedding models in the Mistral AI module. */ -enum EmbeddingModel { +enum MistralEmbeddingModelType { /** * The MISTRAL_EMBED model. @@ -12,7 +12,7 @@ enum EmbeddingModel { private final String value; - private EmbeddingModel(String value) { + private MistralEmbeddingModelType(String value) { this.value = value; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingObject.java similarity index 91% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingObject.java index f13f779b3f..b4568d4ed0 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingObject.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingObject.java @@ -11,7 +11,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class EmbeddingObject { +class MistralEmbeddingObject { private String object; private List embedding; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingRequest.java similarity index 91% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingRequest.java index 21b8e922ee..3b4f3b5bac 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingRequest.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingRequest.java @@ -11,7 +11,7 @@ @NoArgsConstructor 
@AllArgsConstructor @Builder -class EmbeddingRequest { +class MistralEmbeddingRequest { private String model; private List input; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingResponse.java similarity index 72% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingResponse.java index 5dce95c35e..bf2c2fcbfc 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/EmbeddingResponse.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingResponse.java @@ -11,12 +11,12 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class EmbeddingResponse { +class MistralEmbeddingResponse { private String id; private String object; private String model; - private List data; - private UsageInfo usage; + private List data; + private MistralUsageInfo usage; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java similarity index 82% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java index 2bdd049540..77fb5eb3e9 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelCard.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java @@ -11,7 +11,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -public class ModelCard{ +public class MistralModelCard { private String id; private String object; @@ -19,6 +19,6 @@ public class ModelCard{ private String ownerBy; private String root; private 
String parent; - private List permission; + private List permission; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java similarity index 95% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java index 39550c7812..812a22bdda 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelPermission.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java @@ -10,7 +10,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class ModelPermission{ +class MistralModelPermission { private String id; private String object; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java similarity index 78% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java index c070f87e5b..f9e7afff05 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/ModelResponse.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java @@ -11,9 +11,9 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class ModelResponse { +class MistralModelResponse { private String object; - private List data; + private List data; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java similarity index 79% rename from 
langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java index f2802a3804..cded6c49c7 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/Role.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java @@ -4,12 +4,12 @@ import lombok.Getter; @Getter -public enum Role { +public enum MistralRoleType { @SerializedName("system") SYSTEM, @SerializedName("user") USER, @SerializedName("assistant") ASSISTANT; - private Role() {} + private MistralRoleType() {} } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralUsageInfo.java similarity index 92% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralUsageInfo.java index b14fbb6939..4b132f700f 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/UsageInfo.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralUsageInfo.java @@ -9,7 +9,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class UsageInfo { +class MistralUsageInfo { private Integer promptTokens; private Integer totalTokens; diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java index a217143e42..d65f64a8a0 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java @@ -127,7 +127,7 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a // 
given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -154,7 +154,7 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_ // given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -181,7 +181,7 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish // given - Mistral Medium = currently relies on an internal prototype model. ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(ChatCompletionModel.MISTRAL_MEDIUM.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java index 4a059e2cd2..a3199c6d50 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java @@ -20,18 +20,18 @@ void should_return_all_models() { // then assertThat(response.content()).isNotEmpty(); assertThat(response.content().size()).isEqualTo(4); - assertThat(response.content()).contains(ChatCompletionModel.MISTRAL_TINY.toString()); + assertThat(response.content()).contains(MistralChatCompletionModel.MISTRAL_TINY.toString()); } @Test void should_return_one_model_card(){ // when - Response response = 
models.getModelDetails(ChatCompletionModel.MISTRAL_TINY.toString()); + Response response = models.getModelDetails(MistralChatCompletionModel.MISTRAL_TINY.toString()); // then assertThat(response.content()).isNotNull(); - assertThat(response.content()).extracting("id").isEqualTo(ChatCompletionModel.MISTRAL_TINY.toString()); + assertThat(response.content()).extracting("id").isEqualTo(MistralChatCompletionModel.MISTRAL_TINY.toString()); assertThat(response.content()).extracting("object").isEqualTo("model"); assertThat(response.content()).extracting("permission").isNotNull(); } @@ -39,12 +39,12 @@ void should_return_one_model_card(){ @Test void should_return_all_model_cards(){ // when - Response> response = models.getModels(); + Response> response = models.getModels(); // then assertThat(response.content()).isNotEmpty(); assertThat(response.content().size()).isEqualTo(4); - assertThat(response.content()).extracting("id").contains(ChatCompletionModel.MISTRAL_TINY.toString()); + assertThat(response.content()).extracting("id").contains(MistralChatCompletionModel.MISTRAL_TINY.toString()); assertThat(response.content()).extracting("object").contains("model"); assertThat(response.content()).extracting("permission").isNotNull(); } diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java index f3f1e35a93..c18d6af960 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java @@ -257,7 +257,7 @@ void should_stream_answer_in_french_using_model_small_and_return_token_usage_and // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - 
.modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -312,7 +312,7 @@ void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_an // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(ChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -367,7 +367,7 @@ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_r // given - Mistral Medium = currently relies on an internal prototype model. StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(ChatCompletionModel.MISTRAL_MEDIUM.toString()) + .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); From 2c1a22f9d033fe6ecddb4e303a5b731578bd42fd Mon Sep 17 00:00:00 2001 From: czelabueno Date: Fri, 19 Jan 2024 00:29:58 -0500 Subject: [PATCH 16/24] Mistral AI logRequestResponse and commit suggestions --- langchain4j-bom/pom.xml | 8 +- langchain4j-mistral-ai/pom.xml | 8 + .../mistralai/DefaultMistralAiHelper.java | 61 ++-- .../model/mistralai/MistralAiChatModel.java | 12 +- .../model/mistralai/MistralAiClient.java | 69 +++-- .../mistralai/MistralAiEmbeddingModel.java | 10 +- ...ModelCard.java => MistralAiModelCard.java} | 4 +- ...ion.java => MistralAiModelPermission.java} | 2 +- .../model/mistralai/MistralAiModels.java | 32 +- .../MistralAiRequestLoggingInterceptor.java | 57 ++++ .../MistralAiResponseLoggingInterceptor.java | 52 ++++ .../MistralAiStreamingChatModel.java | 9 +- ...va => MistralChatCompletionModelName.java} | 4 +- .../model/mistralai/MistralChatMessage.java | 2 +- .../model/mistralai/MistralDeltaMessage.java | 2 +- ...pe.java => 
MistralEmbeddingModelName.java} | 6 +- .../model/mistralai/MistralModelResponse.java | 2 +- ...tralRoleType.java => MistralRoleName.java} | 4 +- .../model/mistralai/MistralAiChatModelIT.java | 8 +- .../mistralai/MistralAiEmbeddingModelIT.java | 2 + .../model/mistralai/MistralAiModelsIT.java | 28 +- .../MistralAiStreamingChatModelIT.java | 287 +++--------------- .../model/openai/OpenAiChatModelIT.java | 2 +- 23 files changed, 296 insertions(+), 375 deletions(-) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralModelCard.java => MistralAiModelCard.java} (81%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralModelPermission.java => MistralAiModelPermission.java} (93%) create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java create mode 100644 langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralChatCompletionModel.java => MistralChatCompletionModelName.java} (90%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralEmbeddingModelType.java => MistralEmbeddingModelName.java} (74%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralRoleType.java => MistralRoleName.java} (79%) diff --git a/langchain4j-bom/pom.xml b/langchain4j-bom/pom.xml index 0143b72d65..3d79e74a5d 100644 --- a/langchain4j-bom/pom.xml +++ b/langchain4j-bom/pom.xml @@ -93,6 +93,12 @@ ${project.version} + + dev.langchain4j + langchain4j-mistral-ai + ${project.version} + + @@ -200,4 +206,4 @@ - \ No newline at end of file + diff --git a/langchain4j-mistral-ai/pom.xml b/langchain4j-mistral-ai/pom.xml index a651175b47..cc02452f3e 100644 --- a/langchain4j-mistral-ai/pom.xml +++ b/langchain4j-mistral-ai/pom.xml @@ -74,6 +74,14 @@ test + + dev.langchain4j + 
langchain4j-core + tests + test-jar + test + + org.assertj assertj-core diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java index 2e00d9fc40..cec7f89d8a 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java @@ -4,12 +4,13 @@ import dev.langchain4j.data.message.ChatMessageType; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; -import okhttp3.Response; +import okhttp3.Headers; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -17,11 +18,12 @@ import static dev.langchain4j.model.output.FinishReason.*; import static java.util.stream.Collectors.toList; -public class DefaultMistralAiHelper{ +class DefaultMistralAiHelper{ private static final Logger LOGGER = LoggerFactory.getLogger(DefaultMistralAiHelper.class); static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1"; static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float"; + private static final Pattern MISTRAI_API_KEY_BEARER_PATTERN = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$"); public static String ensureNotBlankApiKey(String value) { if (isNullOrBlank(value)) { @@ -47,21 +49,21 @@ public static MistralChatMessage toMistralAiMessage(ChatMessage message) { .build(); } - private static MistralRoleType toMistralAiRole(ChatMessageType chatMessageType) { + private static MistralRoleName toMistralAiRole(ChatMessageType chatMessageType) { switch (chatMessageType) { case SYSTEM: - return MistralRoleType.SYSTEM; + return 
MistralRoleName.SYSTEM; case AI: - return MistralRoleType.ASSISTANT; + return MistralRoleName.ASSISTANT; case USER: - return MistralRoleType.USER; + return MistralRoleName.USER; default: throw new IllegalArgumentException("Unknown chat message type: " + chatMessageType); } } - public static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) { + static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) { if (mistralAiUsage == null) { return null; } @@ -72,7 +74,7 @@ public static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) { ); } - public static FinishReason finishReasonFrom(String mistralAiFinishReason) { + static FinishReason finishReasonFrom(String mistralAiFinishReason) { if (mistralAiFinishReason == null) { return null; } @@ -87,35 +89,32 @@ public static FinishReason finishReasonFrom(String mistralAiFinishReason) { } } - public static void logResponse(Response response){ - try { - LOGGER.debug("Response code: {}", response.code()); - LOGGER.debug("Response body: {}", getResponseBody(response)); - LOGGER.debug("Response headers: {}", getResponseHeaders(response)); - } catch (IOException e) { - LOGGER.warn("Error while logging response", e); - } - } - - private static String getResponseBody(Response response) throws IOException { - return isEventStream(response) ? "" : response.peekBody(Long.MAX_VALUE).string(); - } - - private static String getResponseHeaders(Response response){ - return (String) StreamSupport.stream(response.headers().spliterator(),false).map(header -> { + static String getHeaders(Headers headers){ + return StreamSupport.stream(headers.spliterator(),false).map(header -> { String headerKey = header.component1(); String headerValue = header.component2(); if (headerKey.equals("Authorization")) { - headerValue = "Bearer " + headerValue.substring(0, 5) + "..." + headerValue.substring(headerValue.length() - 5); - } else if (headerKey.equals("api-key")) { - headerValue = headerValue.substring(0, 2) + "..." 
+ headerValue.substring(headerValue.length() - 2); + headerValue = maskAuthorizationHeaderValue(headerValue); } return String.format("[%s: %s]", headerKey, headerValue); }).collect(Collectors.joining(", ")); } - private static boolean isEventStream(Response response){ - String contentType = response.header("Content-Type"); - return contentType != null && contentType.contains("event-stream"); + + private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) { + try { + Matcher matcher = MISTRAI_API_KEY_BEARER_PATTERN.matcher(authorizationHeaderValue); + StringBuffer sb = new StringBuffer(); + + while (matcher.find()) { + String bearer = matcher.group(1); + String token = matcher.group(2); + matcher.appendReplacement(sb, bearer + " " + token.substring(0, 7) + "..." + token.substring(token.length() - 7)); + } + matcher.appendTail(sb); + return sb.toString(); + } catch (Exception e) { + return "Error while masking Authorization header value"; + } } } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java index 5c6f5f8055..552608ed23 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java @@ -44,7 +44,11 @@ public class MistralAiChatModel implements ChatLanguageModel { * @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses * @param randomSeed the random seed for generating chat responses * @param timeout the timeout duration for API requests - * @param maxRetries the maximum number of retries for API requests + *

+ * The default value is 60 seconds + * @param logRequests a flag indicating whether to log API requests + * @param logResponses a flag indicating whether to log API responses + * @param maxRetries the maximum number of retries for API requests. It uses the default value 3 if not specified */ @Builder public MistralAiChatModel(String baseUrl, @@ -56,14 +60,18 @@ public MistralAiChatModel(String baseUrl, Boolean safePrompt, Integer randomSeed, Duration timeout, + Boolean logRequests, + Boolean logResponses, Integer maxRetries) { this.client = MistralAiClient.builder() .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) .build(); - this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java index 6383dc8ce2..083aa1578c 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java @@ -7,6 +7,7 @@ import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; import lombok.Builder; import okhttp3.OkHttpClient; import okhttp3.ResponseBody; @@ -21,8 +22,8 @@ import java.io.IOException; import java.time.Duration; +import static dev.langchain4j.internal.Utils.isNullOrBlank; import 
static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; -import static dev.langchain4j.model.output.FinishReason.*; class MistralAiClient { @@ -35,14 +36,32 @@ class MistralAiClient { private final OkHttpClient okHttpClient; @Builder - public MistralAiClient(String baseUrl, String apiKey, Duration timeout) { - okHttpClient = new OkHttpClient.Builder() - .addInterceptor(new MistralAiApiKeyInterceptor(apiKey)) + public MistralAiClient(String baseUrl, + String apiKey, + Duration timeout, + Boolean logRequests, + Boolean logResponses) { + OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() .callTimeout(timeout) .connectTimeout(timeout) .readTimeout(timeout) - .writeTimeout(timeout) - .build(); + .writeTimeout(timeout); + if (isNullOrBlank(apiKey)) { + throw new IllegalArgumentException("MistralAI API Key must be defined. It can be generated here: https://console.mistral.ai/user/api-keys/"); + }else { + okHttpClientBuilder.addInterceptor(new MistralAiApiKeyInterceptor(apiKey)); + // Log raw HTTP requests + if (logRequests) { + okHttpClientBuilder.addInterceptor(new MistralAiRequestLoggingInterceptor()); + } + + // Log raw HTTP responses + if (logResponses) { + okHttpClientBuilder.addInterceptor(new MistralAiResponseLoggingInterceptor()); + } + } + + this.okHttpClient = okHttpClientBuilder.build(); Retrofit retrofit = new Retrofit.Builder() .baseUrl(baseUrl) @@ -57,7 +76,6 @@ public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest try { retrofit2.Response retrofitResponse = mistralAiApi.chatCompletion(request).execute(); - LOGGER.debug("MistralChatCompletionResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { LOGGER.debug("ChatCompletionResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); @@ -71,14 +89,13 @@ public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest public void streamingChatCompletion(MistralChatCompletionRequest request, 
StreamingResponseHandler handler) { EventSourceListener eventSourceListener = new EventSourceListener() { - StringBuilder contentBuilder = new StringBuilder(); - MistralUsageInfo tokenUsage = new MistralUsageInfo(); - FinishReason lastFinishReason = null; + final StringBuffer contentBuilder = new StringBuffer(); + TokenUsage tokenUsage; + FinishReason finishReason; @Override public void onOpen(EventSource eventSource, okhttp3.Response response) { LOGGER.debug("onOpen()"); - logResponse(response); } @Override @@ -88,8 +105,8 @@ public void onEvent(EventSource eventSource, String id, String type, String data if ("[DONE]".equals(data)) { Response response = Response.from( AiMessage.from(contentBuilder.toString()), - tokenUsageFrom(tokenUsage), - lastFinishReason + tokenUsage, + finishReason ); handler.onComplete(response); } else { @@ -100,21 +117,14 @@ public void onEvent(EventSource eventSource, String id, String type, String data contentBuilder.append(chunk); handler.onNext(chunk); - //Retrieving token usage of the last choice - if(choice.getFinishReason() != null){ - FinishReason finishReason = finishReasonFrom(choice.getFinishReason()); - switch (finishReason){ - case STOP: - lastFinishReason = STOP; - tokenUsage = choice.getUsage(); - break; - case LENGTH: - lastFinishReason = LENGTH; - tokenUsage = choice.getUsage(); - break; - default: - break; - } + MistralUsageInfo usageInfo = choice.getUsage(); + if(usageInfo != null){ + this.tokenUsage = tokenUsageFrom(usageInfo); + } + + String finishReasonString = choice.getFinishReason(); + if(finishReasonString != null){ + this.finishReason = finishReasonFrom(finishReasonString); } } catch (Exception e) { handler.onError(e); @@ -127,7 +137,6 @@ public void onEvent(EventSource eventSource, String id, String type, String data @Override public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) { LOGGER.debug("onFailure()", t); - logResponse(response); if (t != null){ handler.onError(t); @@ 
-153,7 +162,6 @@ public MistralEmbeddingResponse embedding(MistralEmbeddingRequest request) { try { retrofit2.Response retrofitResponse = mistralAiApi.embedding(request).execute(); - LOGGER.debug("MistralEmbeddingResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { LOGGER.debug("EmbeddingResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); @@ -169,7 +177,6 @@ public MistralModelResponse listModels() { try { retrofit2.Response retrofitResponse = mistralAiApi.models().execute(); - LOGGER.debug("MistralModelResponse: {}", retrofitResponse); if (retrofitResponse.isSuccessful()) { LOGGER.debug("ModelResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java index 404df16dbc..3294af4d5d 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java @@ -31,6 +31,10 @@ public class MistralAiEmbeddingModel implements EmbeddingModel { * @param apiKey the API key for authentication * @param modelName the name of the embedding model. It uses a default value if not specified * @param timeout the timeout duration for API requests. It uses a default value of 60 seconds if not specified + *

+ * The default value is 60 seconds + * @param logRequests a flag indicating whether to log API requests + * @param logResponses a flag indicating whether to log API responses * @param maxRetries the maximum number of retries for API requests. It uses a default value of 3 if not specified */ @Builder @@ -38,13 +42,17 @@ public MistralAiEmbeddingModel(String baseUrl, String apiKey, String modelName, Duration timeout, + Boolean logRequests, + Boolean logResponses, Integer maxRetries) { this.client = MistralAiClient.builder() .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses,false)) .build(); - this.modelName = getOrDefault(modelName, MistralEmbeddingModelType.MISTRAL_EMBED.toString()); + this.modelName = getOrDefault(modelName, MistralEmbeddingModelName.MISTRAL_EMBED.toString()); this.maxRetries = getOrDefault(maxRetries, 3); } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java similarity index 81% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java index 77fb5eb3e9..c2bd3f8c9e 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelCard.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelCard.java @@ -11,7 +11,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -public class MistralModelCard { +public class MistralAiModelCard { private String id; private String object; @@ -19,6 +19,6 @@ public class MistralModelCard { private String ownerBy; private String root; private String 
parent; - private List permission; + private List permission; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java similarity index 93% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java index 812a22bdda..ea60d9b592 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelPermission.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModelPermission.java @@ -10,7 +10,7 @@ @NoArgsConstructor @AllArgsConstructor @Builder -class MistralModelPermission { +public class MistralAiModelPermission { private String id; private String object; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java index e741f76bef..ccf1e7f33b 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiModels.java @@ -9,7 +9,6 @@ import static dev.langchain4j.internal.RetryUtils.withRetry; import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; -import static java.util.stream.Collectors.toList; /** * Represents a collection of Mistral AI models. @@ -26,17 +25,23 @@ public class MistralAiModels { * @param baseUrl the base URL of the Mistral AI API. It uses the default value if not specified * @param apiKey the API key for authentication * @param timeout the timeout duration for API requests. 
It uses the default value of 60 seconds if not specified + * @param logRequests a flag whether to log raw HTTP requests + * @param logResponses a flag whether to log raw HTTP responses * @param maxRetries the maximum number of retries for API requests. It uses the default value of 3 if not specified */ @Builder public MistralAiModels(String baseUrl, String apiKey, Duration timeout, + Boolean logRequests, + Boolean logResponses, Integer maxRetries) { this.client = MistralAiClient.builder() .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) .build(); this.maxRetries = getOrDefault(maxRetries, 3); } @@ -51,35 +56,12 @@ public static MistralAiModels withApiKey(String apiKey) { return builder().apiKey(apiKey).build(); } - /** - * Retrieves the details of a specific model. - * - * @param modelId the ID of the model - * @return the response containing the model details - */ - public Response getModelDetails(String modelId){ - return Response.from( - this.getModels().content().stream().filter(modelCard -> modelCard.getId().equals(modelId)).findFirst().orElse(null) - ); - } - - /** - * Retrieves the IDs of all available models. - * - * @return the response containing the list of model IDs - */ - public Response> get(){ - return Response.from( - this.getModels().content().stream().map(MistralModelCard::getId).collect(toList()) - ); - } - /** * Retrieves the list of all available models. 
* * @return the response containing the list of models */ - public Response> getModels(){ + public Response> availableModels(){ MistralModelResponse response = withRetry(client::listModels, maxRetries); return Response.from( response.getData() diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java new file mode 100644 index 0000000000..57ec952a0b --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java @@ -0,0 +1,57 @@ +package dev.langchain4j.model.mistralai; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okio.Buffer; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders; + +class MistralAiRequestLoggingInterceptor implements Interceptor { + + private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiRequestLoggingInterceptor.class); + + @NotNull + @Override + public Response intercept(@NotNull Chain chain) throws IOException { + Request request = chain.request(); + this.log(request); + return chain.proceed(request); + } + + private void log(Request request) { + String message = "Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}"; + try { + if (LOGGER.isInfoEnabled()) { + LOGGER.info(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); + } else if (LOGGER.isWarnEnabled()) { + LOGGER.warn(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); + } else if (LOGGER.isErrorEnabled()) { + LOGGER.error(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), 
getBody(request)}); + } else { + LOGGER.debug(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); + } + } catch (Exception e) { + LOGGER.warn("Error while logging request: {}", e.getMessage()); + } + } + + private static String getBody(Request request){ + try { + Buffer buffer = new Buffer(); + if (request.body() == null) { + return ""; + } + request.body().writeTo(buffer); + return buffer.readUtf8(); + } catch (Exception e) { + LOGGER.warn("Exception while getting body", e); + return "Exception while getting body: " + e.getMessage(); + } + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java new file mode 100644 index 0000000000..fafe65cfc4 --- /dev/null +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java @@ -0,0 +1,52 @@ +package dev.langchain4j.model.mistralai; + +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.getHeaders; + +class MistralAiResponseLoggingInterceptor implements Interceptor { + + private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiResponseLoggingInterceptor.class); + + @NotNull + @Override + public Response intercept(@NotNull Chain chain) throws IOException { + Request request = chain.request(); + Response response = chain.proceed(request); + this.log(response); + return response; + } + + private void log(Response response) { + String message = "Response:\n- status code: {}\n- headers: {}\n- body: {}"; + try { + if (LOGGER.isInfoEnabled()) { + LOGGER.info(message, new Object[]{response.code(), 
getHeaders(response.headers()), this.getBody(response)}); + } else if (LOGGER.isWarnEnabled()) { + LOGGER.warn(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); + } else if (LOGGER.isErrorEnabled()) { + LOGGER.error(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); + } else { + LOGGER.debug(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); + } + } catch (Exception e) { + LOGGER.warn("Error while logging response: {}", e.getMessage()); + } + } + + private String getBody(Response response) throws IOException { + return isEventStream(response) ? "[skipping response body due to streaming]" : response.peekBody(Long.MAX_VALUE).string(); + } + + private static boolean isEventStream(Response response){ + String contentType = response.header("Content-Type"); + return contentType != null && contentType.contains("event-stream"); + } +} diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java index f2dc5ec477..88cea69b66 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java @@ -43,6 +43,9 @@ public class MistralAiStreamingChatModel implements StreamingChatLanguageModel { * @param maxNewTokens the maximum number of new tokens to generate in a chat response * @param safePrompt a flag indicating whether to use a safe prompt for generating chat responses * @param randomSeed the random seed for generating chat responses + * (if not specified, a random number is used) + * @param logRequests a flag indicating whether to log raw HTTP requests + * @param logResponses a flag indicating whether to log raw HTTP responses * 
@param timeout the timeout duration for API requests */ @Builder @@ -54,14 +57,18 @@ public MistralAiStreamingChatModel(String baseUrl, Integer maxNewTokens, Boolean safePrompt, Integer randomSeed, + Boolean logRequests, + Boolean logResponses, Duration timeout) { this.client = MistralAiClient.builder() .baseUrl(formattedURLForRetrofit(getOrDefault(baseUrl, MISTRALAI_API_URL))) .apiKey(ensureNotBlankApiKey(apiKey)) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)) .build(); - this.modelName = getOrDefault(modelName, MistralChatCompletionModel.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java similarity index 90% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java index d31762e1b9..e245f5f34c 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java @@ -19,7 +19,7 @@ * * @see Mistral AI Endpoints */ -enum MistralChatCompletionModel { +public enum MistralChatCompletionModelName { // powered by Mistral-7B-v0.2 MISTRAL_TINY("mistral-tiny"), @@ -30,7 +30,7 @@ enum MistralChatCompletionModel { private final String value; - private MistralChatCompletionModel(String value) { + private MistralChatCompletionModelName(String value) { this.value = value; } diff 
--git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java index ebc1980c58..99e65024a8 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatMessage.java @@ -12,6 +12,6 @@ @Builder class MistralChatMessage { - private MistralRoleType role; + private MistralRoleName role; private String content; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java index 9dae9f54cd..76aa1b402c 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralDeltaMessage.java @@ -11,7 +11,7 @@ @Builder class MistralDeltaMessage { - private MistralRoleType role; + private MistralRoleName role; private String content; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java similarity index 74% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java index 31f2abbb15..bf6c58142b 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelType.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java @@ -1,9 +1,9 @@ package dev.langchain4j.model.mistralai; /** - * The MistralEmbeddingModelType enum represents the available embedding 
models in the Mistral AI module. + * The MistralEmbeddingModelName enum represents the available embedding models in the Mistral AI module. */ -enum MistralEmbeddingModelType { +public enum MistralEmbeddingModelName { /** * The MISTRAL_EMBED model. @@ -12,7 +12,7 @@ enum MistralEmbeddingModelType { private final String value; - private MistralEmbeddingModelType(String value) { + private MistralEmbeddingModelName(String value) { this.value = value; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java index f9e7afff05..59f9b8654a 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralModelResponse.java @@ -14,6 +14,6 @@ class MistralModelResponse { private String object; - private List data; + private List data; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java similarity index 79% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java index cded6c49c7..ef54b3ad97 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleType.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralRoleName.java @@ -4,12 +4,12 @@ import lombok.Getter; @Getter -public enum MistralRoleType { +public enum MistralRoleName { @SerializedName("system") SYSTEM, @SerializedName("user") USER, @SerializedName("assistant") ASSISTANT; - private MistralRoleType() {} + private MistralRoleName() {} } diff --git 
a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java index d65f64a8a0..f52a49aae1 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java @@ -16,6 +16,8 @@ class MistralAiChatModelIT { ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) .temperature(0.1) + .logRequests(true) + .logResponses(true) .build(); @Test @@ -127,7 +129,7 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a // given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -154,7 +156,7 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_ // given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -181,7 +183,7 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish // given - Mistral Medium = currently relies on an internal prototype model. 
ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java index 627e381a8c..83fc16d0eb 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelIT.java @@ -16,6 +16,8 @@ class MistralAiEmbeddingModelIT { EmbeddingModel model = MistralAiEmbeddingModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .logRequests(true) + .logResponses(true) .build(); @Test diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java index a3199c6d50..0e9a54edc4 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java @@ -12,39 +12,15 @@ class MistralAiModelsIT { MistralAiModels models = MistralAiModels.withApiKey(System.getenv("MISTRAL_AI_API_KEY")); //https://docs.mistral.ai/models/ - @Test - void should_return_all_models() { - // when - Response> response = models.get(); - - // then - assertThat(response.content()).isNotEmpty(); - assertThat(response.content().size()).isEqualTo(4); - assertThat(response.content()).contains(MistralChatCompletionModel.MISTRAL_TINY.toString()); - - } - - @Test - void should_return_one_model_card(){ - // when - Response response = 
models.getModelDetails(MistralChatCompletionModel.MISTRAL_TINY.toString()); - - // then - assertThat(response.content()).isNotNull(); - assertThat(response.content()).extracting("id").isEqualTo(MistralChatCompletionModel.MISTRAL_TINY.toString()); - assertThat(response.content()).extracting("object").isEqualTo("model"); - assertThat(response.content()).extracting("permission").isNotNull(); - } - @Test void should_return_all_model_cards(){ // when - Response> response = models.getModels(); + Response> response = models.availableModels(); // then assertThat(response.content()).isNotEmpty(); assertThat(response.content().size()).isEqualTo(4); - assertThat(response.content()).extracting("id").contains(MistralChatCompletionModel.MISTRAL_TINY.toString()); + assertThat(response.content()).extracting("id").contains(MistralChatCompletionModelName.MISTRAL_TINY.toString()); assertThat(response.content()).extracting("object").contains("model"); assertThat(response.content()).extracting("permission").isNotNull(); } diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java index c18d6af960..63c04b9f3b 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java @@ -2,21 +2,17 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; -import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.TestStreamingResponseHandler; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; -import java.util.concurrent.CompletableFuture; -import 
java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeoutException; - import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.model.output.FinishReason.LENGTH; import static dev.langchain4j.model.output.FinishReason.STOP; import static java.util.Arrays.asList; -import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.Collections.singletonList; import static org.assertj.core.api.Assertions.assertThat; class MistralAiStreamingChatModelIT { @@ -27,46 +23,19 @@ class MistralAiStreamingChatModelIT { .build(); @Test - void should_stream_answer_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_and_return_token_usage_and_finish_reason_stop() { // given UserMessage userMessage = userMessage("What is the capital of Peru?"); // when - model.generate(userMessage.text(), new StreamingResponseHandler() { - - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); 
- assertThat(response.content().text()).isEqualTo(chunk); + assertThat(response.content().text()).containsIgnoringCase("Lima"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); @@ -78,10 +47,7 @@ public void onError(Throwable error) { } @Test - void should_stream_answer_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_and_return_token_usage_and_finish_reason_length() { // given StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() @@ -93,37 +59,12 @@ void should_stream_answer_and_return_token_usage_and_finish_reason_length() thro UserMessage userMessage = userMessage("What is the capital of Peru?"); // when - model.generate(userMessage.text(), new StreamingResponseHandler() { - - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); - assertThat(response.content().text()).isEqualTo(chunk); + 
assertThat(response.content().text()).containsIgnoringCase("Lima"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); @@ -136,10 +77,7 @@ public void onError(Throwable error) { //https://docs.mistral.ai/platform/guardrailing/ @Test - void should_stream_answer_and_system_prompt_to_enforce_guardrails() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_and_system_prompt_to_enforce_guardrails() { // given StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() @@ -151,37 +89,13 @@ void should_stream_answer_and_system_prompt_to_enforce_guardrails() throws Execu UserMessage userMessage = userMessage("Hello, my name is Carlos"); // then - model.generate(userMessage.text(), new StreamingResponseHandler() { - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); + Response response = handler.get(); // then - assertThat(chunk).contains("respect"); - assertThat(chunk).contains("truth"); - assertThat(response.content().text()).isEqualTo(chunk); + 
assertThat(response.content().text()).containsIgnoringCase("respect"); + assertThat(response.content().text()).containsIgnoringCase("truth"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); @@ -190,52 +104,25 @@ public void onError(Throwable error) { .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); assertThat(response.finishReason()).isEqualTo(STOP); - } @Test - void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_and_return_token_usage_and_finish_reason_stop_with_multiple_messages() { // given UserMessage userMessage1 = userMessage("What is the capital of Peru?"); UserMessage userMessage2 = userMessage("What is the capital of France?"); UserMessage userMessage3 = userMessage("What is the capital of Canada?"); - model.generate(asList(userMessage1,userMessage2,userMessage3), new StreamingResponseHandler(){ - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(asList(userMessage1, userMessage2, userMessage3), 
handler); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); - assertThat(chunk).contains("Paris"); - assertThat(chunk).contains("Ottawa"); - assertThat(response.content().text()).isEqualTo(chunk); + assertThat(response.content().text()).containsIgnoringCase("lima"); + assertThat(response.content().text()).containsIgnoringCase("paris"); + assertThat(response.content().text()).containsIgnoringCase("ottawa"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(11 + 11 + 11); @@ -244,55 +131,27 @@ public void onError(Throwable error) { .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); assertThat(response.finishReason()).isEqualTo(STOP); - - } @Test - void should_stream_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_in_french_using_model_small_and_return_token_usage_and_finish_reason_stop() { // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); UserMessage userMessage = userMessage("Quelle est la capitale du Pérou?"); - model.generate(userMessage.text(), new StreamingResponseHandler() { - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - 
futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); - assertThat(response.content().text()).isEqualTo(chunk); + assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(18); @@ -304,50 +163,24 @@ public void onError(Throwable error) { } @Test - void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_and_finish_reason_stop() { // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_SMALL.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); UserMessage userMessage = userMessage("¿Cuál es la capital de Perú?"); - model.generate(userMessage.text(), new StreamingResponseHandler() { - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void 
onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); - assertThat(response.content().text()).isEqualTo(chunk); + assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(19); @@ -359,50 +192,24 @@ public void onError(Throwable error) { } @Test - void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() throws ExecutionException, InterruptedException, TimeoutException { - - CompletableFuture futureAnswer = new CompletableFuture<>(); - CompletableFuture> futureResponse = new CompletableFuture<>(); + void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_reason_length() { // given - Mistral Medium = currently relies on an internal prototype model. 
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModel.MISTRAL_MEDIUM.toString()) + .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); UserMessage userMessage = userMessage("What is the capital of Peru?"); - model.generate(userMessage.text(), new StreamingResponseHandler() { - private final StringBuilder answerBuilder = new StringBuilder(); - - @Override - public void onNext(String token) { - System.out.println("onNext: '" + token + "'"); - answerBuilder.append(token); - - } - - @Override - public void onComplete(Response response) { - System.out.println("onComplete: '" + response + "'"); - futureAnswer.complete(answerBuilder.toString()); - futureResponse.complete(response); - } - - @Override - public void onError(Throwable error) { - futureAnswer.completeExceptionally(error); - futureResponse.completeExceptionally(error); - } - }); - - String chunk = futureAnswer.get(10, SECONDS); - Response response = futureResponse.get(10, SECONDS); + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), handler); + Response response = handler.get(); // then - assertThat(chunk).contains("Lima"); - assertThat(response.content().text()).isEqualTo(chunk); + assertThat(response.content().text()).containsIgnoringCase("lima"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isEqualTo(15); diff --git a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java index 3c78e651bb..6403682677 100644 --- a/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java +++ b/langchain4j-open-ai/src/test/java/dev/langchain4j/model/openai/OpenAiChatModelIT.java @@ -376,4 +376,4 @@ void 
should_accept_text_and_multiple_images_from_different_sources() { assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(189); } -} \ No newline at end of file +} From 0d87349d8b4d4bb83a74f2390328a0fa5e3ba03b Mon Sep 17 00:00:00 2001 From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:21:44 -0500 Subject: [PATCH 17/24] Mistral AI token masking until 4 symbols Co-authored-by: LangChain4j --- .../dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java index cec7f89d8a..c61bb32fd3 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java @@ -108,7 +108,7 @@ private static String maskAuthorizationHeaderValue(String authorizationHeaderVal while (matcher.find()) { String bearer = matcher.group(1); String token = matcher.group(2); - matcher.appendReplacement(sb, bearer + " " + token.substring(0, 7) + "..." + token.substring(token.length() - 7)); + matcher.appendReplacement(sb, bearer + " " + token.substring(0, 2) + "..." 
+ token.substring(token.length() - 2)); } matcher.appendTail(sb); return sb.toString(); From 7da8928e716290fa810bfb0729fb09c4c7ab8a83 Mon Sep 17 00:00:00 2001 From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:41:49 -0500 Subject: [PATCH 18/24] MistralAI update chat model enum Co-authored-by: LangChain4j --- .../model/mistralai/MistralChatCompletionModelName.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java index e245f5f34c..79f7250727 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java @@ -19,7 +19,7 @@ * * @see Mistral AI Endpoints */ -public enum MistralChatCompletionModelName { +public enum MistralAiChatModelName { // powered by Mistral-7B-v0.2 MISTRAL_TINY("mistral-tiny"), From 79237cc0af590c918d76d435d26a02e6c304565f Mon Sep 17 00:00:00 2001 From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com> Date: Wed, 24 Jan 2024 14:15:04 -0500 Subject: [PATCH 19/24] MistralAI update embedding model enum Co-authored-by: LangChain4j --- .../langchain4j/model/mistralai/MistralEmbeddingModelName.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java index bf6c58142b..e3f497b9df 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java @@ 
-3,7 +3,7 @@ /** * The MistralEmbeddingModelName enum represents the available embedding models in the Mistral AI module. */ -public enum MistralEmbeddingModelName { +public enum MistralAiEmbeddingModelName { /** * The MISTRAL_EMBED model. From b51d0a5955f74e5290394898ec08690b973309ef Mon Sep 17 00:00:00 2001 From: Carlos Zela Bueno <1715122+czelabueno@users.noreply.github.com> Date: Wed, 24 Jan 2024 14:17:04 -0500 Subject: [PATCH 20/24] Mistral AI fix get usageInfo from last chat completion response Co-authored-by: LangChain4j --- .../java/dev/langchain4j/model/mistralai/MistralAiClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java index 083aa1578c..9f7f6e91ed 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java @@ -117,7 +117,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data contentBuilder.append(chunk); handler.onNext(chunk); - MistralUsageInfo usageInfo = choice.getUsage(); + MistralUsageInfo usageInfo = chatCompletionResponse.getUsage(); if(usageInfo != null){ this.tokenUsage = tokenUsageFrom(usageInfo); } From 2078d32eef3da43622e6eb96f33b1669f0972444 Mon Sep 17 00:00:00 2001 From: czelabueno Date: Wed, 24 Jan 2024 14:31:11 -0500 Subject: [PATCH 21/24] Mistral AI fix logging streaming and rename enums --- langchain4j-mistral-ai/pom.xml | 2 +- .../mistralai/DefaultMistralAiHelper.java | 28 ++++++++++++++----- .../model/mistralai/MistralAiChatModel.java | 2 +- ...lName.java => MistralAiChatModelName.java} | 4 +-- .../model/mistralai/MistralAiClient.java | 25 +++++++++++------ .../mistralai/MistralAiEmbeddingModel.java | 2 +- ....java => MistralAiEmbeddingModelName.java} | 4 +-- 
.../MistralAiRequestLoggingInterceptor.java | 13 ++------- .../MistralAiResponseLoggingInterceptor.java | 14 +++------- .../MistralAiStreamingChatModel.java | 2 +- .../model/mistralai/MistralAiChatModelIT.java | 6 ++-- .../model/mistralai/MistralAiModelsIT.java | 5 ++-- .../MistralAiStreamingChatModelIT.java | 9 +++--- 13 files changed, 62 insertions(+), 54 deletions(-) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralChatCompletionModelName.java => MistralAiChatModelName.java} (94%) rename langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/{MistralEmbeddingModelName.java => MistralAiEmbeddingModelName.java} (73%) diff --git a/langchain4j-mistral-ai/pom.xml b/langchain4j-mistral-ai/pom.xml index cc02452f3e..5a9b83da99 100644 --- a/langchain4j-mistral-ai/pom.xml +++ b/langchain4j-mistral-ai/pom.xml @@ -55,7 +55,6 @@ 2.0.7 - org.projectlombok lombok @@ -93,6 +92,7 @@ tinylog-impl test + org.tinylog slf4j-tinylog diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java index c61bb32fd3..7afc0ab8e2 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/DefaultMistralAiHelper.java @@ -1,12 +1,9 @@ package dev.langchain4j.model.mistralai; -import dev.langchain4j.data.message.ChatMessage; -import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.data.message.*; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.TokenUsage; import okhttp3.Headers; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.List; import java.util.regex.Matcher; @@ -15,12 +12,12 @@ import java.util.stream.StreamSupport; import static dev.langchain4j.internal.Utils.isNullOrBlank; -import static 
dev.langchain4j.model.output.FinishReason.*; +import static dev.langchain4j.model.output.FinishReason.LENGTH; +import static dev.langchain4j.model.output.FinishReason.STOP; import static java.util.stream.Collectors.toList; class DefaultMistralAiHelper{ - private static final Logger LOGGER = LoggerFactory.getLogger(DefaultMistralAiHelper.class); static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1"; static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float"; private static final Pattern MISTRAI_API_KEY_BEARER_PATTERN = Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$"); @@ -45,7 +42,7 @@ public static List toMistralAiMessages(List mes public static MistralChatMessage toMistralAiMessage(ChatMessage message) { return MistralChatMessage.builder() .role(toMistralAiRole(message.type())) - .content(message.text()) + .content(toMistralChatMessageContent(message)) .build(); } @@ -63,6 +60,23 @@ private static MistralRoleName toMistralAiRole(ChatMessageType chatMessageType) } + private static String toMistralChatMessageContent(ChatMessage message) { + if (message instanceof SystemMessage) { + return ((SystemMessage) message).text(); + } + + if(message instanceof AiMessage){ + return ((AiMessage) message).text(); + } + + if(message instanceof UserMessage){ + return ((UserMessage) message).text(); // MistralAI support Text Content only as String + } + + throw new IllegalArgumentException("Unknown message type: " + message.type()); + + } + static TokenUsage tokenUsageFrom(MistralUsageInfo mistralAiUsage) { if (mistralAiUsage == null) { return null; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java index 552608ed23..f757ab7331 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java +++ 
b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModel.java @@ -71,7 +71,7 @@ public MistralAiChatModel(String baseUrl, .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses, false)) .build(); - this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralAiChatModelName.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java similarity index 94% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java index 79f7250727..831274f7fb 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralChatCompletionModelName.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiChatModelName.java @@ -10,12 +10,12 @@ * *

 * The available chat completion models are:
+ * <p>
 * <ul>
 *   <li>{@link #MISTRAL_TINY} - powered by Mistral-7B-v0.2</li>
 *   <li>{@link #MISTRAL_SMALL} - powered by Mixtral-8X7B-v0.1</li>
 *   <li>{@link #MISTRAL_MEDIUM} - currently relies on an internal prototype model</li>
 * </ul>
- *

* * @see Mistral AI Endpoints */ @@ -30,7 +30,7 @@ public enum MistralAiChatModelName { private final String value; - private MistralChatCompletionModelName(String value) { + private MistralAiChatModelName(String value) { this.value = value; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java index 9f7f6e91ed..0e5eab56a5 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiClient.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.time.Duration; +import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.model.mistralai.DefaultMistralAiHelper.*; @@ -34,6 +35,7 @@ class MistralAiClient { .create(); private final MistralAiApi mistralAiApi; private final OkHttpClient okHttpClient; + private final boolean logStreamingResponses; @Builder public MistralAiClient(String baseUrl, @@ -48,7 +50,7 @@ public MistralAiClient(String baseUrl, .writeTimeout(timeout); if (isNullOrBlank(apiKey)) { throw new IllegalArgumentException("MistralAI API Key must be defined. 
It can be generated here: https://console.mistral.ai/user/api-keys/"); - }else { + } else { okHttpClientBuilder.addInterceptor(new MistralAiApiKeyInterceptor(apiKey)); // Log raw HTTP requests if (logRequests) { @@ -61,6 +63,7 @@ public MistralAiClient(String baseUrl, } } + this.logStreamingResponses = logResponses; this.okHttpClient = okHttpClientBuilder.build(); Retrofit retrofit = new Retrofit.Builder() @@ -77,7 +80,6 @@ public MistralChatCompletionResponse chatCompletion(MistralChatCompletionRequest retrofit2.Response retrofitResponse = mistralAiApi.chatCompletion(request).execute(); if (retrofitResponse.isSuccessful()) { - LOGGER.debug("ChatCompletionResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); } else { throw toException(retrofitResponse); @@ -95,13 +97,16 @@ public void streamingChatCompletion(MistralChatCompletionRequest request, Stream @Override public void onOpen(EventSource eventSource, okhttp3.Response response) { - LOGGER.debug("onOpen()"); + if (logStreamingResponses) { + LOGGER.debug("onOpen()"); + } } @Override public void onEvent(EventSource eventSource, String id, String type, String data) { - - LOGGER.debug("onEvent() {}", data); + if (logStreamingResponses){ + LOGGER.debug("onEvent() {}", data); + } if ("[DONE]".equals(data)) { Response response = Response.from( AiMessage.from(contentBuilder.toString()), @@ -136,7 +141,9 @@ public void onEvent(EventSource eventSource, String id, String type, String data @Override public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) { - LOGGER.debug("onFailure()", t); + if (logStreamingResponses){ + LOGGER.debug("onFailure()", t); + } if (t != null){ handler.onError(t); @@ -147,7 +154,9 @@ public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response res @Override public void onClosed(EventSource eventSource) { - LOGGER.debug("onClosed()"); + if (logStreamingResponses){ + LOGGER.debug("onClosed()"); + } } }; @@ -163,7 +172,6 @@ public 
MistralEmbeddingResponse embedding(MistralEmbeddingRequest request) { retrofit2.Response retrofitResponse = mistralAiApi.embedding(request).execute(); if (retrofitResponse.isSuccessful()) { - LOGGER.debug("EmbeddingResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); } else { throw toException(retrofitResponse); @@ -178,7 +186,6 @@ public MistralModelResponse listModels() { retrofit2.Response retrofitResponse = mistralAiApi.models().execute(); if (retrofitResponse.isSuccessful()) { - LOGGER.debug("ModelResponseBody: {}", retrofitResponse.body()); return retrofitResponse.body(); } else { throw toException(retrofitResponse); diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java index 3294af4d5d..d7d71b0a22 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModel.java @@ -52,7 +52,7 @@ public MistralAiEmbeddingModel(String baseUrl, .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses,false)) .build(); - this.modelName = getOrDefault(modelName, MistralEmbeddingModelName.MISTRAL_EMBED.toString()); + this.modelName = getOrDefault(modelName, MistralAiEmbeddingModelName.MISTRAL_EMBED.toString()); this.maxRetries = getOrDefault(maxRetries, 3); } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelName.java similarity index 73% rename from langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java rename to langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelName.java index 
e3f497b9df..b0c7a43f97 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralEmbeddingModelName.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiEmbeddingModelName.java @@ -1,7 +1,7 @@ package dev.langchain4j.model.mistralai; /** - * The MistralEmbeddingModelName enum represents the available embedding models in the Mistral AI module. + * The MistralAiEmbeddingModelName enum represents the available embedding models in the Mistral AI module. */ public enum MistralAiEmbeddingModelName { @@ -12,7 +12,7 @@ public enum MistralAiEmbeddingModelName { private final String value; - private MistralEmbeddingModelName(String value) { + private MistralAiEmbeddingModelName(String value) { this.value = value; } diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java index 57ec952a0b..ec8c65d6f3 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiRequestLoggingInterceptor.java @@ -4,7 +4,6 @@ import okhttp3.Request; import okhttp3.Response; import okio.Buffer; -import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -16,9 +15,9 @@ class MistralAiRequestLoggingInterceptor implements Interceptor { private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiRequestLoggingInterceptor.class); - @NotNull + @Override - public Response intercept(@NotNull Chain chain) throws IOException { + public Response intercept(Chain chain) throws IOException { Request request = chain.request(); this.log(request); return chain.proceed(request); @@ -27,13 +26,7 @@ public Response intercept(@NotNull Chain chain) throws IOException { private void 
log(Request request) { String message = "Request:\n- method: {}\n- url: {}\n- headers: {}\n- body: {}"; try { - if (LOGGER.isInfoEnabled()) { - LOGGER.info(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); - } else if (LOGGER.isWarnEnabled()) { - LOGGER.warn(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); - } else if (LOGGER.isErrorEnabled()) { - LOGGER.error(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); - } else { + if (LOGGER.isDebugEnabled()){ LOGGER.debug(message, new Object[]{request.method(), request.url(), getHeaders(request.headers()), getBody(request)}); } } catch (Exception e) { diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java index fafe65cfc4..a31f138517 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiResponseLoggingInterceptor.java @@ -3,7 +3,8 @@ import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; -import org.jetbrains.annotations.NotNull; +import okio.Buffer; +import okio.BufferedSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -15,9 +16,8 @@ class MistralAiResponseLoggingInterceptor implements Interceptor { private static final Logger LOGGER = LoggerFactory.getLogger(MistralAiResponseLoggingInterceptor.class); - @NotNull @Override - public Response intercept(@NotNull Chain chain) throws IOException { + public Response intercept(Chain chain) throws IOException { Request request = chain.request(); Response response = chain.proceed(request); this.log(response); @@ -27,13 +27,7 @@ public Response 
intercept(@NotNull Chain chain) throws IOException { private void log(Response response) { String message = "Response:\n- status code: {}\n- headers: {}\n- body: {}"; try { - if (LOGGER.isInfoEnabled()) { - LOGGER.info(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); - } else if (LOGGER.isWarnEnabled()) { - LOGGER.warn(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); - } else if (LOGGER.isErrorEnabled()) { - LOGGER.error(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); - } else { + if (LOGGER.isDebugEnabled()) { LOGGER.debug(message, new Object[]{response.code(), getHeaders(response.headers()), this.getBody(response)}); } } catch (Exception e) { diff --git a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java index 88cea69b66..f2e11438b4 100644 --- a/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java +++ b/langchain4j-mistral-ai/src/main/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModel.java @@ -68,7 +68,7 @@ public MistralAiStreamingChatModel(String baseUrl, .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses, false)) .build(); - this.modelName = getOrDefault(modelName, MistralChatCompletionModelName.MISTRAL_TINY.toString()); + this.modelName = getOrDefault(modelName, MistralAiChatModelName.MISTRAL_TINY.toString()); this.temperature = temperature; this.topP = topP; this.maxNewTokens = maxNewTokens; diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java index f52a49aae1..78695442bd 100644 --- 
a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiChatModelIT.java @@ -129,7 +129,7 @@ void should_generate_answer_in_french_using_model_small_and_return_token_usage_a // given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) + .modelName(MistralAiChatModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -156,7 +156,7 @@ void should_generate_answer_in_spanish_using_model_small_and_return_token_usage_ // given - Mistral Small = Mistral-8X7B ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) + .modelName(MistralAiChatModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -183,7 +183,7 @@ void should_generate_answer_using_model_medium_and_return_token_usage_and_finish // given - Mistral Medium = currently relies on an internal prototype model. 
ChatLanguageModel model = MistralAiChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString()) + .modelName(MistralAiChatModelName.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java index 0e9a54edc4..4d6546744a 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiModelsIT.java @@ -18,9 +18,8 @@ void should_return_all_model_cards(){ Response> response = models.availableModels(); // then - assertThat(response.content()).isNotEmpty(); - assertThat(response.content().size()).isEqualTo(4); - assertThat(response.content()).extracting("id").contains(MistralChatCompletionModelName.MISTRAL_TINY.toString()); + assertThat(response.content().size()).isGreaterThan(0); + assertThat(response.content()).extracting("id").contains(MistralAiChatModelName.MISTRAL_TINY.toString()); assertThat(response.content()).extracting("object").contains("model"); assertThat(response.content()).extracting("permission").isNotNull(); } diff --git a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java index 63c04b9f3b..fb16f1f419 100644 --- a/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java +++ b/langchain4j-mistral-ai/src/test/java/dev/langchain4j/model/mistralai/MistralAiStreamingChatModelIT.java @@ -20,6 +20,8 @@ class MistralAiStreamingChatModelIT { StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) .temperature(0.1) 
+ .logResponses(true) + .logRequests(true) .build(); @Test @@ -95,7 +97,6 @@ void should_stream_answer_and_system_prompt_to_enforce_guardrails() { // then assertThat(response.content().text()).containsIgnoringCase("respect"); - assertThat(response.content().text()).containsIgnoringCase("truth"); TokenUsage tokenUsage = response.tokenUsage(); assertThat(tokenUsage.inputTokenCount()).isGreaterThan(50); @@ -139,7 +140,7 @@ void should_stream_answer_in_french_using_model_small_and_return_token_usage_and // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) + .modelName(MistralAiChatModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -168,7 +169,7 @@ void should_stream_answer_in_spanish_using_model_small_and_return_token_usage_an // given - Mistral Small = Mistral-8X7B StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_SMALL.toString()) + .modelName(MistralAiChatModelName.MISTRAL_SMALL.toString()) .temperature(0.1) .build(); @@ -197,7 +198,7 @@ void should_stream_answer_using_model_medium_and_return_token_usage_and_finish_r // given - Mistral Medium = currently relies on an internal prototype model. 
StreamingChatLanguageModel model = MistralAiStreamingChatModel.builder() .apiKey(System.getenv("MISTRAL_AI_API_KEY")) - .modelName(MistralChatCompletionModelName.MISTRAL_MEDIUM.toString()) + .modelName(MistralAiChatModelName.MISTRAL_MEDIUM.toString()) .maxNewTokens(10) .build(); From 9ddb2ce7e5256301fd3d605593cd6dc8db962b8a Mon Sep 17 00:00:00 2001 From: czelabueno Date: Fri, 15 Mar 2024 20:08:46 -0400 Subject: [PATCH 22/24] update overview integration table --- README.md | 34 +++++++++++----------- docs/docs/integrations/index.mdx | 48 +++++++++++++++++--------------- 2 files changed, 43 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index cb3c851cac..9961a7fb7b 100644 --- a/README.md +++ b/README.md @@ -295,22 +295,24 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main System.out.println(answer); // Hello! How can I assist you today? ``` ## Supported LLM Integrations ([Docs](https://docs.langchain4j.dev/category/integrations)) -| Provider | Native Image | [Completion](https://docs.langchain4j.dev/category/language-models) | [Streaming](https://docs.langchain4j.dev/integrations/language-models/response-streaming) | [Async Completion](https://docs.langchain4j.dev/category/language-models) | [Async Streaming](https://docs.langchain4j.dev/integrations/language-models/response-streaming) | [Embedding](https://docs.langchain4j.dev/category/embedding-models) | [Image Generation](https://docs.langchain4j.dev/category/image-models) | [ReRanking](https://docs.langchain4j.dev/category/reranking-models) -|---------------------------------------------------------------------------------------------------------| ------------- | ----------- | ------------- | --------- |--------------------------------| ------------ |---------------------------------------------------------------------------------------------|---------------| -| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/openai) | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ 
| -| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-openai) | | ✅ | ✅ | | | ✅ | ✅ | -| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/huggingface) | | ✅ | | ✅ | | ✅ | | | -| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | | | | ✅ | -| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | ✅ | ✅ | | | -| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | | ✅ | ✅ | -| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistralai) | | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DashScope](https://docs.langchain4j.dev/integrations/language-models/dashscope) | | ✅ | ✅ | | ✅ | ✅ | -| [LocalAI](https://docs.langchain4j.dev/integrations/language-models/localai) | | ✅ | ✅ | ✅ | | ✅ | | -| [Ollama](https://docs.langchain4j.dev/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [Cohere](https://docs.langchain4j.dev/integrations/reranking-models/cohere) | | | | | | | | ✅ | -| [Qianfan](https://docs.langchain4j.dev/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [ChatGLM](https://docs.langchain4j.dev/integrations/language-models/chatglm) | | ✅ | | | | | -| [Nomic](https://docs.langchain4j.dev/integrations/language-models/nomic) | | | | | | ✅ | | +| Provider | Native Image | [Sync Completion](https://docs.langchain4j.dev/category/language-models) | [Streaming Completion](https://docs.langchain4j.dev/integrations/language-models/response-streaming) | [Embedding](https://docs.langchain4j.dev/category/embedding-models) | [Image Generation](https://docs.langchain4j.dev/category/image-models) | [Scoring](https://docs.langchain4j.dev/category/scoring-models) | [Function Calling](https://docs.langchain4j.dev/tutorials/tools) +|----------------------------------------------------------------------------------------------------| 
------------- |---------------------------------------------------------------------| ----------- | ------ |-------------------------------| ------ |--------------------------------------------------------------------------------------------| +| [OpenAI](https://docs.langchain4j.dev/integrations/language-models/openai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | +| [Azure OpenAI](https://docs.langchain4j.dev/integrations/language-models/azure-openai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | +| [Hugging Face](https://docs.langchain4j.dev/integrations/language-models/huggingface) | | ✅ | | ✅ | | | | | +| [Amazon Bedrock](https://docs.langchain4j.dev/integrations/language-models/amazon-bedrock) | | ✅ | |✅ | ✅ | | | +| [Google Vertex AI Gemini](https://docs.langchain4j.dev/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ | +| [Google Vertex AI](https://docs.langchain4j.dev/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | | +| [Mistral AI](https://docs.langchain4j.dev/integrations/language-models/mistralai) | | ✅ | ✅ | ✅ | | |✅ | +| [DashScope](https://docs.langchain4j.dev/integrations/language-models/dashscope) | | ✅ | ✅ |✅ | | | | +| [LocalAI](https://docs.langchain4j.dev/integrations/language-models/localai) | | ✅ | ✅ | ✅ | | | ✅ | +| [Ollama](https://docs.langchain4j.dev/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | | | | +| [Cohere](https://docs.langchain4j.dev/integrations/reranking-models/cohere) | | | | | | ✅| | +| [Qianfan](https://docs.langchain4j.dev/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | | |✅ | +| [ChatGLM](https://docs.langchain4j.dev/integrations/language-models/chatglm) | | ✅ | | | | | +| [Nomic](https://docs.langchain4j.dev/integrations/language-models/nomic) | | | |✅ | | | | +| [Anthropic](https://docs.langchain4j.dev/integrations/language-models/anthropic) | |✅ | | | | | | +| [Zhipu AI](https://docs.langchain4j.dev/integrations/language-models/zhipuai) | |✅| ✅| ✅| | |✅ | ## Disclaimer diff --git a/docs/docs/integrations/index.mdx 
b/docs/docs/integrations/index.mdx index 8baa50df84..275d85a238 100644 --- a/docs/docs/integrations/index.mdx +++ b/docs/docs/integrations/index.mdx @@ -10,33 +10,35 @@ We are making a great effort to have most of the functions enabled according to ## Capabilities 1. **Native image:** You can use this LLM integration for AOT compilation using GraalVM CE or GraalVM Oracle for [native image](https://www.graalvm.org/latest/reference-manual/native-image/) generation. -2. **Completion:** Supports the implementation of `text-completion` and `chat-completion` models in a synchronous way. This is most common usage. View examples [here](/tutorials/connect-to-llm) -3. **Streaming:** Supports `streaming` the model response back for `text-completion` or `chat-completion` models handling each event in `StreamingResponseHandler` class. View examples [here](/tutorials/response-streaming) -4. **Async Completion:** Provide an [asynchronous](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/CompletableFuture.html) version of the completion feature. -5. **Async Streaming:** Provide an [asynchronous](https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/CompletableFuture.html) version of the streaming feature. -6. **Embeddings:** Supports the implementation of `text-embedding` models. Embeddings make it easy to add custom data without fine-tuning. Generally used with Retrieval-Augmented Generation (`RAG`) tools, `Vector Stores` and `Search`. View examples [here](/tutorials/embedding-store) -7. **Image Generation:** Supports the implementation of `text-to-image` models to create realistic and coherent images from scratch. View examples [here](/tutorials/image-models) -8. **ReRanking:** Understands the implementation of re-ranking models to improve created models by re-organizing their results based on certain parameters. View examples [here](/tutorials/reranking-models) +2. 
**Sync Completion:** Supports the implementation of `text-completion` and `chat-completion` models in a synchronous way. This is most common usage. View examples [here](/tutorials/connect-to-llm) +3. **Streaming Completion:** Supports `streaming` the model response back for `text-completion` or `chat-completion` models handling each event in `StreamingResponseHandler` class. View examples [here](/tutorials/response-streaming) +4. **Embeddings:** Supports the implementation of `text-embedding` models. Embeddings make it easy to add custom data without fine-tuning. Generally used with Retrieval-Augmented Generation (`RAG`) tools, `Vector Stores` and `Search`. View examples [here](/tutorials/embedding-store) +5. **Image Generation:** Supports the implementation of `text-to-image` models to create realistic and coherent images from scratch. View examples [here](/tutorials/image-models) +6. **Scoring:** Understands the implementation of scoring models to improve created models by re-organizing their results based on certain parameters. View examples [here](/tutorials/reranking-models) +7. **Function Calling:** Supports the implementation of `function-calling` models to call a function as a `Tool`. View examples [here](/tutorials/tools) :::note of course some LLM providers offer large multimodal model (accepting text or image inputs) and it would cover more than one capability. 
::: ## Supported LLM Integrations -| Provider | [Native Image](/category/code-execution-engines) | [Completion](/tutorials/connect-to-llm) | [Streaming](/docs/tutorials/response-streaming) | [Async Completion](/docs/tutorials/connect-to-llm) | [Async Streaming](/docs/tutorials/response-streaming) | [Embeddings](/category/embedding-models) | [Image Generation](/docs/category/image-models) | [ReRanking](/docs/category/reranking-models) -|---------------------------------------------------------------------------| ------------- | ----------- | ------------- | --------- |--------------------------------| ------------ |----------------------------------------------------------------------------------------------|---------------| -| [OpenAI](/integrations/language-models/openai) | | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [Azure OpenAI](/integrations/language-models/azure-openai) | | ✅ | ✅ | | | ✅ | ✅ | -| [Hugging Face](/integrations/language-models/huggingface) | | ✅ | | ✅ | | ✅ | | | -| [Amazon Bedrock](/integrations/language-models/amazon-bedrock) | | ✅ | | | | ✅ | -| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | | ✅ | ✅ | ✅ | ✅ | | | -| [Google Vertex AI](/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | | ✅ | ✅ | -| [Mistral AI](/integrations/language-models/mistralai) | | ✅ | ✅ | ✅ | ✅ | ✅ | -| [DashScope](/integrations/language-models/dashscope) | | ✅ | ✅ | | ✅ | ✅ | -| [LocalAI](/integrations/language-models/localai) | | ✅ | ✅ | ✅ | | ✅ | | -| [Ollama](/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [Cohere](/integrations/reranking-models/cohere) | | | | | | | | ✅ | -| [Qianfan](/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | ✅ | ✅ | | -| [ChatGLM](/integrations/language-models/chatglm) | | ✅ | | | | | -| [Nomic](/integrations/language-models/nomic) | | | | | | ✅ | | +| Provider | [Native Image](/category/code-execution-engines) | [Sync Completion](/category/language-models) | [Streaming 
Completion](/tutorials/response-streaming) | [Embedding](/category/embedding-models) | [Image Generation](/category/image-models) | [Scoring](https://docs.langchain4j.dev/category/scoring-models) | [Function Calling](/tutorials/tools) +|--------------------------------------------| ------------- |---------------------------------------------------------------------| ----------- | ------ |-------------------------------| ------ |--------------------------------------------------------------------------------------------| +| [OpenAI](/integrations/language-models/openai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | +| [Azure OpenAI](/integrations/language-models/azure-openai) | | ✅ | ✅ | ✅ | ✅ | | ✅ | +| [Hugging Face](/integrations/language-models/huggingface) | | ✅ | | ✅ | | | | | +| [Amazon Bedrock](/integrations/language-models/amazon-bedrock) | | ✅ | |✅ | ✅ | | | +| [Google Vertex AI Gemini](/integrations/language-models/google-gemini) | | ✅ | ✅ | | ✅ | | ✅ | +| [Google Vertex AI](/integrations/language-models/google-palm) | ✅ | ✅ | | ✅ | ✅ | | | +| [Mistral AI](/integrations/language-models/mistralai) | | ✅ | ✅ | ✅ | | |✅ | +| [DashScope](/integrations/language-models/dashscope) | | ✅ | ✅ |✅ | | | | +| [LocalAI](/integrations/language-models/localai) | | ✅ | ✅ | ✅ | | | ✅ | +| [Ollama](/integrations/language-models/ollama) | | ✅ | ✅ | ✅ | | | | +| [Cohere](/integrations/reranking-models/cohere) | | | | | | ✅| | +| [Qianfan](/integrations/language-models/qianfan) | | ✅ | ✅ | ✅ | | |✅ | +| [ChatGLM](/integrations/language-models/chatglm) | | ✅ | | | | | +| [Nomic](/integrations/language-models/nomic) | | | |✅ | | | | +| [Anthropic](/integrations/language-models/anthropic) | |✅ | | | | | | +| [Zhipu AI](/integrations/language-models/zhipuai) | |✅| ✅| ✅| | |✅ | + From 53a6840a45c112998720f011b751e0752d093f08 Mon Sep 17 00:00:00 2001 From: Lize Raes <49833622+LizeRaes@users.noreply.github.com> Date: Sat, 16 Mar 2024 15:02:40 +0100 Subject: [PATCH 23/24] Update 
docs/docs/integrations/index.mdx Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- docs/docs/integrations/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/integrations/index.mdx b/docs/docs/integrations/index.mdx index 275d85a238..42871f8cf7 100644 --- a/docs/docs/integrations/index.mdx +++ b/docs/docs/integrations/index.mdx @@ -10,7 +10,7 @@ We are making a great effort to have most of the functions enabled according to ## Capabilities 1. **Native image:** You can use this LLM integration for AOT compilation using GraalVM CE or GraalVM Oracle for [native image](https://www.graalvm.org/latest/reference-manual/native-image/) generation. -2. **Sync Completion:** Supports the implementation of `text-completion` and `chat-completion` models in a synchronous way. This is most common usage. View examples [here](/tutorials/connect-to-llm) +2. **Sync Completion:** Supports the implementation of `text-completion` and `chat-completion` models in a synchronous way. This is the most common usage. View examples [here](/tutorials/connect-to-llm) 3. **Streaming Completion:** Supports `streaming` the model response back for `text-completion` or `chat-completion` models handling each event in `StreamingResponseHandler` class. View examples [here](/tutorials/response-streaming) 4. **Embeddings:** Supports the implementation of `text-embedding` models. Embeddings make it easy to add custom data without fine-tuning. Generally used with Retrieval-Augmented Generation (`RAG`) tools, `Vector Stores` and `Search`. View examples [here](/tutorials/embedding-store) 5. **Image Generation:** Supports the implementation of `text-to-image` models to create realistic and coherent images from scratch. 
View examples [here](/tutorials/image-models) From 507e7481dbf5c643f4318911150a88f1bb9d683d Mon Sep 17 00:00:00 2001 From: Lize Raes <49833622+LizeRaes@users.noreply.github.com> Date: Sat, 16 Mar 2024 15:03:22 +0100 Subject: [PATCH 24/24] Update docs/docs/integrations/index.mdx Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- docs/docs/integrations/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/integrations/index.mdx b/docs/docs/integrations/index.mdx index 42871f8cf7..19c7fdc240 100644 --- a/docs/docs/integrations/index.mdx +++ b/docs/docs/integrations/index.mdx @@ -11,7 +11,7 @@ We are making a great effort to have most of the functions enabled according to ## Capabilities 1. **Native image:** You can use this LLM integration for AOT compilation using GraalVM CE or GraalVM Oracle for [native image](https://www.graalvm.org/latest/reference-manual/native-image/) generation. 2. **Sync Completion:** Supports the implementation of `text-completion` and `chat-completion` models in a synchronous way. This is the most common usage. View examples [here](/tutorials/connect-to-llm) -3. **Streaming Completion:** Supports `streaming` the model response back for `text-completion` or `chat-completion` models handling each event in `StreamingResponseHandler` class. View examples [here](/tutorials/response-streaming) +3. **Streaming Completion:** Supports `streaming` the model response back for `text-completion` or `chat-completion` models, handling each event in `StreamingResponseHandler` class. View examples [here](/tutorials/response-streaming) 4. **Embeddings:** Supports the implementation of `text-embedding` models. Embeddings make it easy to add custom data without fine-tuning. Generally used with Retrieval-Augmented Generation (`RAG`) tools, `Vector Stores` and `Search`. View examples [here](/tutorials/embedding-store) 5. 
**Image Generation:** Supports the implementation of `text-to-image` models to create realistic and coherent images from scratch. View examples [here](/tutorials/image-models) 6. **Scoring:** Understands the implementation of scoring models to improve created models by re-organizing their results based on certain parameters. View examples [here](/tutorials/reranking-models)