diff --git a/PACKAGES.md b/PACKAGES.md index a6e2dc7a..eab2cdf0 100644 --- a/PACKAGES.md +++ b/PACKAGES.md @@ -39,6 +39,9 @@ A BOM is provided that can be used to define the versions of all Semantic Kernel `semantickernel-aiservices-openai` : Provides a connector that can be used to interact with the OpenAI API. +`semantickernel-aiservices-voyageai` +: Provides connectors for VoyageAI's embedding and reranking services, including text embeddings, contextualized embeddings, multimodal embeddings, and document reranking. + ## Example Configurations ### Example: OpenAI + SQLite @@ -72,5 +75,36 @@ POM XML for a simple project that uses OpenAI. ``` +### Example: VoyageAI Embeddings and Reranking + +POM XML for a project that uses VoyageAI for embeddings and reranking. + +```xml + + + + + + com.microsoft.semantic-kernel + semantickernel-bom + ${semantickernel.version} + import + pom + + + + + + com.microsoft.semantic-kernel + semantickernel-api + + + com.microsoft.semantic-kernel + semantickernel-aiservices-voyageai + + + +``` + diff --git a/aiservices/voyageai/pom.xml b/aiservices/voyageai/pom.xml new file mode 100644 index 00000000..db0d06a4 --- /dev/null +++ b/aiservices/voyageai/pom.xml @@ -0,0 +1,89 @@ + + + 4.0.0 + + com.microsoft.semantic-kernel + semantickernel-parent + 1.4.4-RC3-SNAPSHOT + ../../pom.xml + + + semantickernel-aiservices-voyageai + Semantic Kernel VoyageAI Services + VoyageAI services for Semantic Kernel + + + + com.microsoft.semantic-kernel + semantickernel-api + + + com.microsoft.semantic-kernel + semantickernel-api-builders + + + com.microsoft.semantic-kernel + semantickernel-api-ai-services + + + com.microsoft.semantic-kernel + semantickernel-api-textembedding-services + + + com.microsoft.semantic-kernel + semantickernel-api-exceptions + + + com.microsoft.semantic-kernel + semantickernel-api-localization + + + + com.fasterxml.jackson.core + jackson-databind + compile + + + com.fasterxml.jackson.core + jackson-core + compile + + + 
com.fasterxml.jackson.core + jackson-annotations + compile + + + + + com.squareup.okhttp3 + okhttp + 4.12.0 + + + + + io.projectreactor + reactor-core + + + + + org.slf4j + slf4j-api + + + + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + + + + diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/contextualizedembedding/VoyageAIContextualizedEmbeddingGenerationService.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/contextualizedembedding/VoyageAIContextualizedEmbeddingGenerationService.java new file mode 100644 index 00000000..e04f6794 --- /dev/null +++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/contextualizedembedding/VoyageAIContextualizedEmbeddingGenerationService.java @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.aiservices.voyageai.contextualizedembedding; + +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient; +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels; +import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; +import com.microsoft.semantickernel.services.textembedding.Embedding; +import com.microsoft.semantickernel.services.textembedding.TextEmbeddingGenerationService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * VoyageAI contextualized embedding generation service. + * Generates embeddings that capture both local chunk details and global document-level metadata. + * Supports models like voyage-3. 
+ */ +public class VoyageAIContextualizedEmbeddingGenerationService implements TextEmbeddingGenerationService { + + private static final Logger LOGGER = LoggerFactory.getLogger(VoyageAIContextualizedEmbeddingGenerationService.class); + + private final VoyageAIClient client; + private final String modelId; + private final String serviceId; + + /** + * Creates a new instance of VoyageAI contextualized embedding generation service. + * + * @param client VoyageAI client + * @param modelId Model ID (e.g., "voyage-3") + * @param serviceId Optional service ID + */ + public VoyageAIContextualizedEmbeddingGenerationService( + VoyageAIClient client, + String modelId, + @Nullable String serviceId) { + + if (client == null) { + throw new IllegalArgumentException("Client cannot be null"); + } + if (modelId == null || modelId.trim().isEmpty()) { + throw new IllegalArgumentException("Model ID cannot be null or empty"); + } + + this.client = client; + this.modelId = modelId; + this.serviceId = serviceId != null ? serviceId : PromptExecutionSettings.DEFAULT_SERVICE_ID; + } + + @Override + public String getServiceId() { + return serviceId; + } + + @Override + public String getModelId() { + return modelId; + } + + /** + * Generates contextualized embeddings for document chunks. 
+     *
+     * @param inputs List of lists where each inner list contains the chunks of one document
+     * @return A Mono containing one flat list of embeddings for all chunks across all documents,
+     *         in the order the nested response lists are iterated; an empty list when
+     *         {@code inputs} is null or empty
+     */
+    public Mono<List<Embedding>> generateContextualizedEmbeddingsAsync(List<List<String>> inputs) {
+        if (inputs == null || inputs.isEmpty()) {
+            return Mono.just(Collections.emptyList());
+        }
+
+        LOGGER.debug("Generating contextualized embeddings for {} document groups using model {}",
+            inputs.size(), modelId);
+
+        VoyageAIModels.ContextualizedEmbeddingRequest request =
+            new VoyageAIModels.ContextualizedEmbeddingRequest();
+        request.setInputs(inputs);
+        request.setModel(modelId);
+
+        return client.sendRequestAsync(
+            "contextualizedembeddings",
+            request,
+            VoyageAIModels.ContextualizedEmbeddingResponse.class)
+            .map(response -> {
+                List<Embedding> embeddings = new ArrayList<>();
+                // Parse nested data structure: {"data":[{"data":[{"embedding":[...]}]}]}
+                // Either level of "data" may be null if the payload omitted it; skip those
+                // instead of failing the whole Mono with a NullPointerException.
+                List<VoyageAIModels.ContextualizedEmbeddingDataList> documents = response.getData();
+                if (documents != null) {
+                    for (VoyageAIModels.ContextualizedEmbeddingDataList dataList : documents) {
+                        if (dataList == null || dataList.getData() == null) {
+                            continue;
+                        }
+                        for (VoyageAIModels.EmbeddingDataItem item : dataList.getData()) {
+                            embeddings.add(new Embedding(item.getEmbedding()));
+                        }
+                    }
+                }
+
+                LOGGER.debug("Received {} contextualized embeddings from VoyageAI", embeddings.size());
+                return embeddings;
+            });
+    }
+
+    /**
+     * Generates embeddings for the given text.
+     * For standard text embedding, wraps the data as a single input.
+     *
+     * @param data The text to generate embeddings for
+     * @return A Mono that completes with the first embedding returned for {@code data},
+     *         or completes empty when the service returned no embeddings
+     */
+    @Override
+    public Mono<Embedding> generateEmbeddingAsync(String data) {
+        return generateEmbeddingsAsync(Collections.singletonList(data))
+            .flatMap(embeddings -> {
+                if (embeddings.isEmpty()) {
+                    return Mono.empty();
+                }
+                return Mono.just(embeddings.get(0));
+            });
+    }
+
+    /**
+     * Generates embeddings for the given texts.
+     * Each text is treated as a separate document for contextualized embeddings.
+ * + * @param data The texts to generate embeddings for + * @return A Mono that completes with the list of embeddings + */ + @Override + public Mono> generateEmbeddingsAsync(List data) { + if (data == null || data.isEmpty()) { + return Mono.just(Collections.emptyList()); + } + + // Convert each string to a single-element list for contextualized embeddings + List> inputs = new ArrayList<>(); + for (String text : data) { + inputs.add(Arrays.asList(text)); + } + + return generateContextualizedEmbeddingsAsync(inputs); + } + + /** + * Creates a builder for VoyageAI contextualized embedding generation service. + * + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link VoyageAIContextualizedEmbeddingGenerationService}. + */ + public static class Builder { + private VoyageAIClient client; + private String modelId; + private String serviceId; + + /** + * Sets the VoyageAI client. + * + * @param client VoyageAI client + * @return This builder + */ + public Builder withClient(VoyageAIClient client) { + this.client = client; + return this; + } + + /** + * Sets the model ID. + * + * @param modelId Model ID (e.g., "voyage-3") + * @return This builder + */ + public Builder withModelId(String modelId) { + this.modelId = modelId; + return this; + } + + /** + * Sets the service ID. + * + * @param serviceId Service ID + * @return This builder + */ + public Builder withServiceId(String serviceId) { + this.serviceId = serviceId; + return this; + } + + /** + * Builds the VoyageAI contextualized embedding generation service. 
+ * + * @return A new instance of VoyageAIContextualizedEmbeddingGenerationService + */ + public VoyageAIContextualizedEmbeddingGenerationService build() { + return new VoyageAIContextualizedEmbeddingGenerationService(client, modelId, serviceId); + } + } +} diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIClient.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIClient.java new file mode 100644 index 00000000..95188ea5 --- /dev/null +++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIClient.java @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.aiservices.voyageai.core; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.microsoft.semantickernel.exceptions.AIException; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +/** + * HTTP client for VoyageAI API. + */ +public class VoyageAIClient { + private static final Logger LOGGER = LoggerFactory.getLogger(VoyageAIClient.class); + private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); + private static final String DEFAULT_ENDPOINT = "https://api.voyageai.com/v1"; + + private final OkHttpClient httpClient; + private final String apiKey; + private final String endpoint; + private final ObjectMapper objectMapper; + + /** + * Creates a new VoyageAI client. 
+ * + * @param apiKey VoyageAI API key + * @param endpoint Optional API endpoint (defaults to https://api.voyageai.com/v1) + * @param httpClient Optional HTTP client + */ + public VoyageAIClient( + String apiKey, + @Nullable String endpoint, + @Nullable OkHttpClient httpClient) { + + if (apiKey == null || apiKey.trim().isEmpty()) { + throw new IllegalArgumentException("API key cannot be null or empty"); + } + + this.apiKey = apiKey; + this.endpoint = endpoint != null ? endpoint : DEFAULT_ENDPOINT; + this.httpClient = httpClient != null ? httpClient : createDefaultHttpClient(); + this.objectMapper = createObjectMapper(); + } + + /** + * Creates a new VoyageAI client with default HTTP client and endpoint. + * + * @param apiKey VoyageAI API key + */ + public VoyageAIClient(String apiKey) { + this(apiKey, null, null); + } + + private static OkHttpClient createDefaultHttpClient() { + return new OkHttpClient.Builder() + .connectTimeout(30, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .writeTimeout(30, TimeUnit.SECONDS) + .build(); + } + + private static ObjectMapper createObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); + return mapper; + } + + /** + * Sends a request to the VoyageAI API. 
+     *
+     * The underlying OkHttp call is synchronous, so the work is subscribed on Reactor's
+     * bounded-elastic scheduler; the subscribing thread is never blocked by network I/O.
+     *
+     * @param path API path (e.g., "embeddings", "rerank")
+     * @param requestBody Request body object, serialized to JSON with the client's ObjectMapper
+     * @param responseType Response type class
+     * @param <T> Response type
+     * @return Mono containing the deserialized response; it fails with an {@link AIException}
+     *         when the HTTP status is not successful or the body deserializes to null
+     */
+    public <T> Mono<T> sendRequestAsync(
+        String path,
+        Object requestBody,
+        Class<T> responseType) {
+
+        Mono<T> call = Mono.fromCallable(() -> {
+            // Tolerate a configured endpoint that ends with '/', which would otherwise
+            // produce a "//path" request URI.
+            String base = endpoint.endsWith("/")
+                ? endpoint.substring(0, endpoint.length() - 1)
+                : endpoint;
+            String requestUri = base + "/" + path;
+
+            LOGGER.debug("Sending VoyageAI request to {}", requestUri);
+
+            String json = objectMapper.writeValueAsString(requestBody);
+            LOGGER.trace("Request body: {}", json);
+
+            RequestBody body = RequestBody.create(json, JSON);
+
+            Request request = new Request.Builder()
+                .url(requestUri)
+                .addHeader("Authorization", "Bearer " + apiKey)
+                .addHeader("Accept", "application/json")
+                .post(body)
+                .build();
+
+            // try-with-resources closes the response (and its body) on every path.
+            try (Response response = httpClient.newCall(request).execute()) {
+                String responseBody = response.body() != null ? response.body().string() : "";
+
+                if (!response.isSuccessful()) {
+                    LOGGER.error("VoyageAI API request failed with status {}: {}",
+                        response.code(), responseBody);
+                    throw new AIException(AIException.ErrorCodes.SERVICE_ERROR,
+                        String.format("VoyageAI API request failed with status %d: %s",
+                            response.code(), responseBody));
+                }
+
+                LOGGER.trace("Response body: {}", responseBody);
+
+                T result = objectMapper.readValue(responseBody, responseType);
+                if (result == null) {
+                    throw new AIException(AIException.ErrorCodes.SERVICE_ERROR,
+                        "Failed to deserialize VoyageAI response: " + responseBody);
+                }
+
+                return result;
+            }
+        });
+
+        // execute() blocks; keep it off whatever thread happens to subscribe.
+        return call.subscribeOn(reactor.core.scheduler.Schedulers.boundedElastic());
+    }
+}
diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIModels.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIModels.java
new file mode 100644
index 00000000..b196d9e0
--- /dev/null
+++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/core/VoyageAIModels.java
@@ -0,0 +1,606 @@
+// Copyright
(c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.aiservices.voyageai.core; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * VoyageAI API request and response models. + */ +public class VoyageAIModels { + + // Embedding Models + + /** + * Request model for text embeddings. + */ + public static class EmbeddingRequest { + @JsonProperty("input") + private List input; + + @JsonProperty("model") + private String model; + + @JsonProperty("input_type") + private String inputType; + + @JsonProperty("truncation") + private Boolean truncation; + + @JsonProperty("output_dimension") + private Integer outputDimension; + + @JsonProperty("output_dtype") + private String outputDtype; + + public List getInput() { + return input; + } + + public void setInput(List input) { + this.input = input; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getInputType() { + return inputType; + } + + public void setInputType(String inputType) { + this.inputType = inputType; + } + + public Boolean getTruncation() { + return truncation; + } + + public void setTruncation(Boolean truncation) { + this.truncation = truncation; + } + + public Integer getOutputDimension() { + return outputDimension; + } + + public void setOutputDimension(Integer outputDimension) { + this.outputDimension = outputDimension; + } + + public String getOutputDtype() { + return outputDtype; + } + + public void setOutputDtype(String outputDtype) { + this.outputDtype = outputDtype; + } + } + + /** + * Response model for embeddings. 
+ */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class EmbeddingResponse { + @JsonProperty("data") + private List data; + + @JsonProperty("usage") + private EmbeddingUsage usage; + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + + public EmbeddingUsage getUsage() { + return usage; + } + + public void setUsage(EmbeddingUsage usage) { + this.usage = usage; + } + } + + /** + * Embedding data item. + */ + public static class EmbeddingDataItem { + @JsonProperty("object") + private String object; + + @JsonProperty("embedding") + private float[] embedding; + + @JsonProperty("index") + private int index; + + public String getObject() { + return object; + } + + public void setObject(String object) { + this.object = object; + } + + public float[] getEmbedding() { + return embedding; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + } + + /** + * Usage information. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class EmbeddingUsage { + @JsonProperty("total_tokens") + private int totalTokens; + + public int getTotalTokens() { + return totalTokens; + } + + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } + } + + // Reranking Models + + /** + * Request model for reranking. 
+     * Optional fields that are left null (for example {@code topK}) are omitted from the
+     * serialized JSON rather than being sent as explicit nulls, matching the other
+     * request models in this class.
+     */
+    @JsonInclude(JsonInclude.Include.NON_NULL)
+    public static class RerankRequest {
+        @JsonProperty("query")
+        private String query;
+
+        @JsonProperty("documents")
+        private List<String> documents;
+
+        @JsonProperty("model")
+        private String model;
+
+        // Optional: number of top results to return; omitted from the payload when null.
+        @JsonProperty("top_k")
+        private Integer topK;
+
+        // Optional: truncation flag; omitted from the payload when null.
+        @JsonProperty("truncation")
+        private Boolean truncation;
+
+        public String getQuery() {
+            return query;
+        }
+
+        public void setQuery(String query) {
+            this.query = query;
+        }
+
+        public List<String> getDocuments() {
+            return documents;
+        }
+
+        public void setDocuments(List<String> documents) {
+            this.documents = documents;
+        }
+
+        public String getModel() {
+            return model;
+        }
+
+        public void setModel(String model) {
+            this.model = model;
+        }
+
+        public Integer getTopK() {
+            return topK;
+        }
+
+        public void setTopK(Integer topK) {
+            this.topK = topK;
+        }
+
+        public Boolean getTruncation() {
+            return truncation;
+        }
+
+        public void setTruncation(Boolean truncation) {
+            this.truncation = truncation;
+        }
+    }
+
+    /**
+     * Response model for reranking.
+     * Unknown top-level properties in the payload are ignored during deserialization.
+     */
+    @JsonIgnoreProperties(ignoreUnknown = true)
+    public static class RerankResponse {
+        @JsonProperty("data")
+        private List<RerankDataItem> data;
+
+        @JsonProperty("usage")
+        private EmbeddingUsage usage;
+
+        public List<RerankDataItem> getData() {
+            return data;
+        }
+
+        public void setData(List<RerankDataItem> data) {
+            this.data = data;
+        }
+
+        public EmbeddingUsage getUsage() {
+            return usage;
+        }
+
+        public void setUsage(EmbeddingUsage usage) {
+            this.usage = usage;
+        }
+    }
+
+    /**
+     * Rerank data item.
+ */ + public static class RerankDataItem { + @JsonProperty("index") + private int index; + + @JsonProperty("relevance_score") + private double relevanceScore; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + + public double getRelevanceScore() { + return relevanceScore; + } + + public void setRelevanceScore(double relevanceScore) { + this.relevanceScore = relevanceScore; + } + } + + // Contextualized Embedding Models + + /** + * Request model for contextualized embeddings. + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class ContextualizedEmbeddingRequest { + @JsonProperty("inputs") + private List> inputs; + + @JsonProperty("model") + private String model; + + @JsonProperty("input_type") + private String inputType; + + @JsonProperty("truncation") + private Boolean truncation; + + @JsonProperty("output_dimension") + private Integer outputDimension; + + @JsonProperty("output_dtype") + private String outputDtype; + + public List> getInputs() { + return inputs; + } + + public void setInputs(List> inputs) { + this.inputs = inputs; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getInputType() { + return inputType; + } + + public void setInputType(String inputType) { + this.inputType = inputType; + } + + public Boolean getTruncation() { + return truncation; + } + + public void setTruncation(Boolean truncation) { + this.truncation = truncation; + } + + public Integer getOutputDimension() { + return outputDimension; + } + + public void setOutputDimension(Integer outputDimension) { + this.outputDimension = outputDimension; + } + + public String getOutputDtype() { + return outputDtype; + } + + public void setOutputDtype(String outputDtype) { + this.outputDtype = outputDtype; + } + } + + /** + * Response model for contextualized embeddings. 
+ * VoyageAI returns a nested list structure: {"object":"list","data":[{"object":"list","data":[...]}]} + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ContextualizedEmbeddingResponse { + @JsonProperty("data") + private List data; + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + } + + /** + * Nested data list for contextualized embeddings. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class ContextualizedEmbeddingDataList { + @JsonProperty("data") + private List data; + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + } + + /** + * Contextualized embedding result. + */ + public static class ContextualizedEmbeddingResult { + @JsonProperty("embeddings") + private List embeddings; + + public List getEmbeddings() { + return embeddings; + } + + public void setEmbeddings(List embeddings) { + this.embeddings = embeddings; + } + } + + /** + * Embedding item with chunk information. + */ + public static class EmbeddingItem { + @JsonProperty("embedding") + private float[] embedding; + + @JsonProperty("chunk") + private String chunk; + + @JsonProperty("index") + private int index; + + public float[] getEmbedding() { + return embedding; + } + + public void setEmbedding(float[] embedding) { + this.embedding = embedding; + } + + public String getChunk() { + return chunk; + } + + public void setChunk(String chunk) { + this.chunk = chunk; + } + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + } + + // Multimodal Embedding Models + + /** + * Content item for multimodal input (text or image). 
+ */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class MultimodalContentItem { + @JsonProperty("type") + private String type; // "text" or "image_url" + + @JsonProperty("text") + private String text; + + @JsonProperty("image_url") + private String imageUrl; + + public MultimodalContentItem() { + // Default constructor for Jackson + } + + public MultimodalContentItem(String type, String value) { + this.type = type; + if ("text".equals(type)) { + this.text = value; + } else if ("image_url".equals(type)) { + this.imageUrl = value; + } + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public String getText() { + return text; + } + + public void setText(String text) { + this.text = text; + } + + public String getImageUrl() { + return imageUrl; + } + + public void setImageUrl(String imageUrl) { + this.imageUrl = imageUrl; + } + } + + /** + * Input for multimodal embedding (contains a list of content items). + */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class MultimodalInput { + @JsonProperty("content") + private List content; + + public MultimodalInput() { + // Default constructor for Jackson + } + + public MultimodalInput(List content) { + this.content = content; + } + + public List getContent() { + return content; + } + + public void setContent(List content) { + this.content = content; + } + } + + /** + * Request model for multimodal embeddings. 
+ */ + @JsonInclude(JsonInclude.Include.NON_NULL) + public static class MultimodalEmbeddingRequest { + @JsonProperty("inputs") + private List inputs; + + @JsonProperty("model") + private String model; + + @JsonProperty("input_type") + private String inputType; + + @JsonProperty("truncation") + private Boolean truncation; + + public List getInputs() { + return inputs; + } + + public void setInputs(List inputs) { + this.inputs = inputs; + } + + public String getModel() { + return model; + } + + public void setModel(String model) { + this.model = model; + } + + public String getInputType() { + return inputType; + } + + public void setInputType(String inputType) { + this.inputType = inputType; + } + + public Boolean getTruncation() { + return truncation; + } + + public void setTruncation(Boolean truncation) { + this.truncation = truncation; + } + } + + /** + * Response model for multimodal embeddings. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public static class MultimodalEmbeddingResponse { + @JsonProperty("data") + private List data; + + @JsonProperty("usage") + private EmbeddingUsage usage; + + public List getData() { + return data; + } + + public void setData(List data) { + this.data = data; + } + + public EmbeddingUsage getUsage() { + return usage; + } + + public void setUsage(EmbeddingUsage usage) { + this.usage = usage; + } + } +} diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/multimodalembedding/VoyageAIMultimodalEmbeddingGenerationService.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/multimodalembedding/VoyageAIMultimodalEmbeddingGenerationService.java new file mode 100644 index 00000000..239d29b5 --- /dev/null +++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/multimodalembedding/VoyageAIMultimodalEmbeddingGenerationService.java @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft. All rights reserved. 
+package com.microsoft.semantickernel.aiservices.voyageai.multimodalembedding; + +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient; +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels; +import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; +import com.microsoft.semantickernel.services.textembedding.Embedding; +import com.microsoft.semantickernel.services.textembedding.TextEmbeddingGenerationService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +/** + * VoyageAI multimodal embedding generation service. + * Generates embeddings for text, images, or interleaved text and images. + * Supports the voyage-multimodal-3 model. + *

+ * Constraints: + * - Maximum 1,000 inputs per request + * - Images: ≤16 million pixels, ≤20 MB + * - Total tokens per input: ≤32,000 (560 pixels = 1 token) + * - Aggregate tokens across inputs: ≤320,000 + */ +public class VoyageAIMultimodalEmbeddingGenerationService implements TextEmbeddingGenerationService { + + private static final Logger LOGGER = LoggerFactory.getLogger(VoyageAIMultimodalEmbeddingGenerationService.class); + + private final VoyageAIClient client; + private final String modelId; + private final String serviceId; + + /** + * Creates a new instance of VoyageAI multimodal embedding generation service. + * + * @param client VoyageAI client + * @param modelId Model ID (e.g., "voyage-multimodal-3") + * @param serviceId Optional service ID + */ + public VoyageAIMultimodalEmbeddingGenerationService( + VoyageAIClient client, + String modelId, + @Nullable String serviceId) { + + if (client == null) { + throw new IllegalArgumentException("Client cannot be null"); + } + if (modelId == null || modelId.trim().isEmpty()) { + throw new IllegalArgumentException("Model ID cannot be null or empty"); + } + + this.client = client; + this.modelId = modelId; + this.serviceId = serviceId != null ? serviceId : PromptExecutionSettings.DEFAULT_SERVICE_ID; + } + + @Override + public String getServiceId() { + return serviceId; + } + + @Override + public String getModelId() { + return modelId; + } + + /** + * Generates multimodal embeddings for text and/or images. 
+     *
+     * @param inputs List of multimodal inputs
+     * @return A Mono containing a list of multimodal embeddings ordered by the {@code index}
+     *         reported for each item; an empty list when {@code inputs} is null or empty
+     */
+    public Mono<List<Embedding>> generateMultimodalEmbeddingsAsync(List<VoyageAIModels.MultimodalInput> inputs) {
+        if (inputs == null || inputs.isEmpty()) {
+            return Mono.just(Collections.emptyList());
+        }
+
+        LOGGER.debug("Generating multimodal embeddings for {} inputs using model {}", inputs.size(), modelId);
+
+        VoyageAIModels.MultimodalEmbeddingRequest request = new VoyageAIModels.MultimodalEmbeddingRequest();
+        request.setInputs(inputs);
+        request.setModel(modelId);
+
+        return client.sendRequestAsync("multimodalembeddings", request, VoyageAIModels.MultimodalEmbeddingResponse.class)
+            .map(response -> {
+                // "data" is null when the payload omitted it; report an empty result rather
+                // than failing the Mono with a NullPointerException.
+                List<VoyageAIModels.EmbeddingDataItem> items = response.getData();
+                if (items == null) {
+                    items = Collections.emptyList();
+                }
+
+                LOGGER.debug("Received {} multimodal embeddings from VoyageAI", items.size());
+
+                // Sort by index so the output lines up with the order of the inputs.
+                return items.stream()
+                    .sorted(Comparator.comparingInt(VoyageAIModels.EmbeddingDataItem::getIndex))
+                    .map(item -> new Embedding(item.getEmbedding()))
+                    .collect(Collectors.toList());
+            });
+    }
+
+    /**
+     * Generates embeddings for the given text.
+     * For text-only input, converts to multimodal format.
+     *
+     * @param data The text to generate embeddings for
+     * @return A Mono that completes with the first embedding returned for {@code data},
+     *         or completes empty when the service returned no embeddings
+     */
+    @Override
+    public Mono<Embedding> generateEmbeddingAsync(String data) {
+        return generateEmbeddingsAsync(Collections.singletonList(data))
+            .flatMap(embeddings -> {
+                if (embeddings.isEmpty()) {
+                    return Mono.empty();
+                }
+                return Mono.just(embeddings.get(0));
+            });
+    }
+
+    /**
+     * Generates embeddings for the given texts.
+     * Converts text-only inputs to multimodal format.
+     * Each text becomes one multimodal input holding a single "text" content item, and the
+     * request itself is delegated to
+     * {@link #generateMultimodalEmbeddingsAsync(List)} so both entry points share one
+     * request/response path.
+     *
+     * @param data The texts to generate embeddings for
+     * @return A Mono that completes with the list of embeddings; an empty list when
+     *         {@code data} is null or empty
+     */
+    @Override
+    public Mono<List<Embedding>> generateEmbeddingsAsync(List<String> data) {
+        if (data == null || data.isEmpty()) {
+            return Mono.just(Collections.emptyList());
+        }
+
+        // Convert each text to multimodal input format
+        List<VoyageAIModels.MultimodalInput> inputs = new ArrayList<>(data.size());
+        for (String text : data) {
+            VoyageAIModels.MultimodalContentItem contentItem =
+                new VoyageAIModels.MultimodalContentItem("text", text);
+            inputs.add(new VoyageAIModels.MultimodalInput(Collections.singletonList(contentItem)));
+        }
+
+        return generateMultimodalEmbeddingsAsync(inputs);
+    }
+
+    /**
+     * Creates a builder for VoyageAI multimodal embedding generation service.
+     *
+     * @return A new builder instance
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link VoyageAIMultimodalEmbeddingGenerationService}.
+     */
+    public static class Builder {
+        private VoyageAIClient client;
+        private String modelId;
+        private String serviceId;
+
+        /**
+         * Sets the VoyageAI client.
+ * + * @param client VoyageAI client + * @return This builder + */ + public Builder withClient(VoyageAIClient client) { + this.client = client; + return this; + } + + /** + * Sets the model ID. + * + * @param modelId Model ID (e.g., "voyage-multimodal-3") + * @return This builder + */ + public Builder withModelId(String modelId) { + this.modelId = modelId; + return this; + } + + /** + * Sets the service ID. + * + * @param serviceId Service ID + * @return This builder + */ + public Builder withServiceId(String serviceId) { + this.serviceId = serviceId; + return this; + } + + /** + * Builds the VoyageAI multimodal embedding generation service. + * + * @return A new instance of VoyageAIMultimodalEmbeddingGenerationService + */ + public VoyageAIMultimodalEmbeddingGenerationService build() { + return new VoyageAIMultimodalEmbeddingGenerationService(client, modelId, serviceId); + } + } +} diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/reranking/VoyageAITextRerankingService.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/reranking/VoyageAITextRerankingService.java new file mode 100644 index 00000000..6b0f59fa --- /dev/null +++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/reranking/VoyageAITextRerankingService.java @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft. All rights reserved. 
+package com.microsoft.semantickernel.aiservices.voyageai.reranking;
+
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels;
+import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
+import com.microsoft.semantickernel.services.reranking.RerankResult;
+import com.microsoft.semantickernel.services.reranking.TextRerankingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
+import javax.annotation.Nullable;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * VoyageAI implementation of {@link TextRerankingService}.
+ * Supports models like rerank-2, rerank-2-lite.
+ */
+public class VoyageAITextRerankingService implements TextRerankingService {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(VoyageAITextRerankingService.class);
+
+    private final VoyageAIClient client;
+    private final String modelId;
+    private final String serviceId;
+    // May be null; it is forwarded to the request as-is.
+    private final Integer topK;
+
+    /**
+     * Creates a new instance of VoyageAI text reranking service.
+     *
+     * @param client VoyageAI client
+     * @param modelId Model ID (e.g., "rerank-2")
+     * @param serviceId Optional service ID; when null,
+     *                  {@link PromptExecutionSettings#DEFAULT_SERVICE_ID} is used
+     * @param topK Optional top K results to return
+     * @throws IllegalArgumentException if client is null, or modelId is null or blank
+     */
+    public VoyageAITextRerankingService(
+        VoyageAIClient client,
+        String modelId,
+        @Nullable String serviceId,
+        @Nullable Integer topK) {
+
+        if (client == null) {
+            throw new IllegalArgumentException("Client cannot be null");
+        }
+        if (modelId == null || modelId.trim().isEmpty()) {
+            throw new IllegalArgumentException("Model ID cannot be null or empty");
+        }
+
+        this.client = client;
+        this.modelId = modelId;
+        this.serviceId = serviceId != null ? serviceId : PromptExecutionSettings.DEFAULT_SERVICE_ID;
+        this.topK = topK;
+    }
+
+    @Override
+    public String getServiceId() {
+        return serviceId;
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+
+    /**
+     * Reranks documents based on their relevance to the query.
+     *
+     * <p>A null or empty document list completes with an empty list without calling the client.
+     *
+     * @param query The query to rank documents against
+     * @param documents The list of documents to rerank
+     * @return A Mono containing a list of {@link RerankResult} sorted by relevance score in descending order
+     * @throws IllegalArgumentException if query is null or blank (thrown synchronously, not through the Mono)
+     */
+    @Override
+    public Mono<List<RerankResult>> rerankAsync(String query, List<String> documents) {
+        if (query == null || query.trim().isEmpty()) {
+            throw new IllegalArgumentException("Query cannot be null or empty");
+        }
+        if (documents == null || documents.isEmpty()) {
+            return Mono.just(Collections.emptyList());
+        }
+
+        LOGGER.debug("Reranking {} documents using model {}", documents.size(), modelId);
+
+        VoyageAIModels.RerankRequest request = new VoyageAIModels.RerankRequest();
+        request.setQuery(query);
+        request.setDocuments(documents);
+        request.setModel(modelId);
+        request.setTopK(topK);
+        request.setTruncation(true);
+
+        return client.sendRequestAsync("rerank", request, VoyageAIModels.RerankResponse.class)
+            .map(response -> {
+                LOGGER.debug("Received {} reranked results from VoyageAI", response.getData().size());
+
+                // Each item's index is used to look up the original document text
+                // in the caller-supplied list.
+                return response.getData().stream()
+                    .sorted(Comparator.comparingDouble(VoyageAIModels.RerankDataItem::getRelevanceScore).reversed())
+                    .map(item -> new RerankResult(
+                        item.getIndex(),
+                        documents.get(item.getIndex()),
+                        item.getRelevanceScore()
+                    ))
+                    .collect(Collectors.toList());
+            });
+    }
+
+    /**
+     * Creates a builder for VoyageAI text reranking service.
+     *
+     * @return A new builder instance
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link VoyageAITextRerankingService}.
+     * Values are passed unchanged to the service constructor by {@link #build()}.
+     */
+    public static class Builder {
+        private VoyageAIClient client;
+        private String modelId;
+        private String serviceId;
+        private Integer topK;
+
+        /**
+         * Sets the VoyageAI client.
+         *
+         * @param client VoyageAI client
+         * @return This builder
+         */
+        public Builder withClient(VoyageAIClient client) {
+            this.client = client;
+            return this;
+        }
+
+        /**
+         * Sets the model ID.
+         *
+         * @param modelId Model ID (e.g., "rerank-2")
+         * @return This builder
+         */
+        public Builder withModelId(String modelId) {
+            this.modelId = modelId;
+            return this;
+        }
+
+        /**
+         * Sets the service ID.
+         *
+         * @param serviceId Service ID
+         * @return This builder
+         */
+        public Builder withServiceId(String serviceId) {
+            this.serviceId = serviceId;
+            return this;
+        }
+
+        /**
+         * Sets the top K results to return.
+         *
+         * @param topK Top K results
+         * @return This builder
+         */
+        public Builder withTopK(Integer topK) {
+            this.topK = topK;
+            return this;
+        }
+
+        /**
+         * Builds the VoyageAI text reranking service.
+         *
+         * @return A new instance of VoyageAITextRerankingService
+         */
+        public VoyageAITextRerankingService build() {
+            return new VoyageAITextRerankingService(client, modelId, serviceId, topK);
+        }
+    }
+}
diff --git a/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/textembedding/VoyageAITextEmbeddingGenerationService.java b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/textembedding/VoyageAITextEmbeddingGenerationService.java
new file mode 100644
index 00000000..22a39353
--- /dev/null
+++ b/aiservices/voyageai/src/main/java/com/microsoft/semantickernel/aiservices/voyageai/textembedding/VoyageAITextEmbeddingGenerationService.java
@@ -0,0 +1,174 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.aiservices.voyageai.textembedding;
+
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels;
+import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
+import com.microsoft.semantickernel.services.textembedding.Embedding;
+import com.microsoft.semantickernel.services.textembedding.TextEmbeddingGenerationService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import reactor.core.publisher.Mono;
+
+import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * VoyageAI implementation of {@link TextEmbeddingGenerationService}.
+ * Supports models like voyage-3-large, voyage-3.5, voyage-code-3, voyage-finance-2, voyage-law-2.
+ */
+public class VoyageAITextEmbeddingGenerationService implements TextEmbeddingGenerationService {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(VoyageAITextEmbeddingGenerationService.class);
+
+    private final VoyageAIClient client;
+    private final String modelId;
+    private final String serviceId;
+
+    /**
+     * Creates a new instance of VoyageAI text embedding generation service.
+     *
+     * @param client VoyageAI client
+     * @param modelId Model ID (e.g., "voyage-3-large")
+     * @param serviceId Optional service ID; when null,
+     *                  {@link PromptExecutionSettings#DEFAULT_SERVICE_ID} is used
+     * @throws IllegalArgumentException if client is null, or modelId is null or blank
+     */
+    public VoyageAITextEmbeddingGenerationService(
+        VoyageAIClient client,
+        String modelId,
+        @Nullable String serviceId) {
+
+        if (client == null) {
+            throw new IllegalArgumentException("Client cannot be null");
+        }
+        if (modelId == null || modelId.trim().isEmpty()) {
+            throw new IllegalArgumentException("Model ID cannot be null or empty");
+        }
+
+        this.client = client;
+        this.modelId = modelId;
+        this.serviceId = serviceId != null ? serviceId : PromptExecutionSettings.DEFAULT_SERVICE_ID;
+    }
+
+    @Override
+    public String getServiceId() {
+        return serviceId;
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+
+    /**
+     * Generates embeddings for the given text.
+     *
+     * <p>Delegates to {@link #generateEmbeddingsAsync(List)} with a one-element list and
+     * returns the first embedding; the Mono completes empty when no embedding comes back.
+     *
+     * @param data The text to generate embeddings for
+     * @return A Mono that completes with the embedding
+     */
+    @Override
+    public Mono<Embedding> generateEmbeddingAsync(String data) {
+        return generateEmbeddingsAsync(Arrays.asList(data))
+            .flatMap(embeddings ->
+                embeddings.isEmpty() ? Mono.<Embedding>empty() : Mono.just(embeddings.get(0)));
+    }
+
+    /**
+     * Generates embeddings for the given texts.
+     *
+     * <p>A null or empty list completes with an empty list without calling the client.
+     *
+     * @param data The texts to generate embeddings for
+     * @return A Mono that completes with the list of embeddings
+     */
+    @Override
+    public Mono<List<Embedding>> generateEmbeddingsAsync(List<String> data) {
+        if (data == null || data.isEmpty()) {
+            return Mono.just(Collections.emptyList());
+        }
+
+        LOGGER.debug("Generating embeddings for {} texts using model {}", data.size(), modelId);
+
+        VoyageAIModels.EmbeddingRequest request = new VoyageAIModels.EmbeddingRequest();
+        request.setInput(data);
+        request.setModel(modelId);
+        request.setTruncation(true);
+
+        return client.sendRequestAsync("embeddings", request, VoyageAIModels.EmbeddingResponse.class)
+            .map(response -> {
+                LOGGER.debug("Received {} embeddings from VoyageAI", response.getData().size());
+
+                // Order by each item's reported index (presumably the input position -
+                // TODO confirm against the VoyageAI response contract).
+                return response.getData().stream()
+                    .sorted(Comparator.comparingInt(VoyageAIModels.EmbeddingDataItem::getIndex))
+                    .map(item -> new Embedding(item.getEmbedding()))
+                    .collect(Collectors.toList());
+            });
+    }
+
+    /**
+     * Creates a builder for VoyageAI text embedding generation service.
+     *
+     * @return A new builder instance
+     */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /**
+     * Builder for {@link VoyageAITextEmbeddingGenerationService}.
+     * Values are passed unchanged to the service constructor by {@link #build()}.
+     */
+    public static class Builder {
+        private VoyageAIClient client;
+        private String modelId;
+        private String serviceId;
+
+        /**
+         * Sets the VoyageAI client.
+         *
+         * @param client VoyageAI client
+         * @return This builder
+         */
+        public Builder withClient(VoyageAIClient client) {
+            this.client = client;
+            return this;
+        }
+
+        /**
+         * Sets the model ID.
+         *
+         * @param modelId Model ID (e.g., "voyage-3-large")
+         * @return This builder
+         */
+        public Builder withModelId(String modelId) {
+            this.modelId = modelId;
+            return this;
+        }
+
+        /**
+         * Sets the service ID.
+         *
+         * @param serviceId Service ID
+         * @return This builder
+         */
+        public Builder withServiceId(String serviceId) {
+            this.serviceId = serviceId;
+            return this;
+        }
+
+        /**
+         * Builds the VoyageAI text embedding generation service.
+         *
+         * @return A new instance of VoyageAITextEmbeddingGenerationService
+         */
+        public VoyageAITextEmbeddingGenerationService build() {
+            return new VoyageAITextEmbeddingGenerationService(client, modelId, serviceId);
+        }
+    }
+}
diff --git a/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIContextualizedEmbeddingGenerationServiceTest.java b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIContextualizedEmbeddingGenerationServiceTest.java
new file mode 100644
index 00000000..8b4d2260
--- /dev/null
+++ b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIContextualizedEmbeddingGenerationServiceTest.java
@@ -0,0 +1,142 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.aiservices.voyageai;
+
+import com.microsoft.semantickernel.aiservices.voyageai.contextualizedembedding.VoyageAIContextualizedEmbeddingGenerationService;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels;
+import com.microsoft.semantickernel.services.textembedding.Embedding;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link VoyageAIContextualizedEmbeddingGenerationService} using a mocked
+ * {@link VoyageAIClient}, so no network call is made.
+ */
+public class VoyageAIContextualizedEmbeddingGenerationServiceTest {
+
+    /** Two single-chunk documents should yield two embeddings, in document order. */
+    @Test
+    public void testGenerateContextualizedEmbeddings() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIModels.ContextualizedEmbeddingResponse mockResponse =
+            new VoyageAIModels.ContextualizedEmbeddingResponse();
+
+        VoyageAIModels.EmbeddingDataItem item1 = new VoyageAIModels.EmbeddingDataItem();
+        item1.setEmbedding(new float[]{0.1f, 0.2f});
+        item1.setIndex(0);
+
+        // Also index 0: each item lives in its own per-document data list.
+        VoyageAIModels.EmbeddingDataItem item2 = new VoyageAIModels.EmbeddingDataItem();
+        item2.setEmbedding(new float[]{0.3f, 0.4f});
+        item2.setIndex(0);
+
+        VoyageAIModels.ContextualizedEmbeddingDataList dataList1 =
+            new VoyageAIModels.ContextualizedEmbeddingDataList();
+        dataList1.setData(Arrays.asList(item1));
+
+        VoyageAIModels.ContextualizedEmbeddingDataList dataList2 =
+            new VoyageAIModels.ContextualizedEmbeddingDataList();
+        dataList2.setData(Arrays.asList(item2));
+
+        mockResponse.setData(Arrays.asList(dataList1, dataList2));
+
+        when(mockClient.sendRequestAsync(
+            eq("contextualizedembeddings"),
+            any(),
+            eq(VoyageAIModels.ContextualizedEmbeddingResponse.class)))
+            .thenReturn(Mono.just(mockResponse));
+
+        VoyageAIContextualizedEmbeddingGenerationService service =
+            new VoyageAIContextualizedEmbeddingGenerationService(mockClient, "voyage-3", null);
+
+        List<List<String>> inputs = Arrays.asList(
+            Arrays.asList("chunk1"),
+            Arrays.asList("chunk2")
+        );
+
+        List<Embedding> results = service.generateContextualizedEmbeddingsAsync(inputs).block();
+
+        assertNotNull(results);
+        assertEquals(2, results.size());
+        List<Float> expected1 = Arrays.asList(0.1f, 0.2f);
+        List<Float> expected2 = Arrays.asList(0.3f, 0.4f);
+        assertEquals(expected1, results.get(0).getVector());
+        assertEquals(expected2, results.get(1).getVector());
+    }
+
+    /** The single-text entry point should return the one embedding in the response. */
+    @Test
+    public void testGenerateEmbedding() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIModels.ContextualizedEmbeddingResponse mockResponse =
+            new VoyageAIModels.ContextualizedEmbeddingResponse();
+
+        VoyageAIModels.EmbeddingDataItem item = new VoyageAIModels.EmbeddingDataItem();
+        item.setEmbedding(new float[]{0.1f, 0.2f, 0.3f});
+        item.setIndex(0);
+
+        VoyageAIModels.ContextualizedEmbeddingDataList dataList =
+            new VoyageAIModels.ContextualizedEmbeddingDataList();
+        dataList.setData(Arrays.asList(item));
+
+        mockResponse.setData(Arrays.asList(dataList));
+
+        when(mockClient.sendRequestAsync(
+            eq("contextualizedembeddings"),
+            any(),
+            eq(VoyageAIModels.ContextualizedEmbeddingResponse.class)))
+            .thenReturn(Mono.just(mockResponse));
+
+        VoyageAIContextualizedEmbeddingGenerationService service =
+            new VoyageAIContextualizedEmbeddingGenerationService(mockClient, "voyage-3", null);
+
+        Embedding result = service.generateEmbeddingAsync("test text").block();
+
+        assertNotNull(result);
+        List<Float> expected = Arrays.asList(0.1f, 0.2f, 0.3f);
+        assertEquals(expected, result.getVector());
+    }
+
+    /** Constructor arguments should be exposed through the getters. */
+    @Test
+    public void testServiceIdAndModelId() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIContextualizedEmbeddingGenerationService service =
+            new VoyageAIContextualizedEmbeddingGenerationService(mockClient, "voyage-3", "test-service");
+
+        assertEquals("test-service", service.getServiceId());
+        assertEquals("voyage-3", service.getModelId());
+    }
+
+    /** The builder should produce a service carrying the configured IDs. */
+    @Test
+    public void testBuilderPattern() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIContextualizedEmbeddingGenerationService service =
+            VoyageAIContextualizedEmbeddingGenerationService.builder()
+                .withClient(mockClient)
+                .withModelId("voyage-3")
+                .withServiceId("test-service")
+                .build();
+
+        assertNotNull(service);
+        assertEquals("test-service", service.getServiceId());
+        assertEquals("voyage-3", service.getModelId());
+    }
+
+    @Test
+    public void testNullClientThrowsException() {
+        assertThrows(IllegalArgumentException.class, () ->
+            new VoyageAIContextualizedEmbeddingGenerationService(null, "voyage-3", null));
+    }
+
+    @Test
+    public void testNullModelIdThrowsException() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+        assertThrows(IllegalArgumentException.class, () ->
+            new VoyageAIContextualizedEmbeddingGenerationService(mockClient, null, null));
+    }
+}
diff --git a/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIIntegrationTest.java b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIIntegrationTest.java
new file mode 100644
index 00000000..b9631d64
--- /dev/null
+++ b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIIntegrationTest.java
@@ -0,0 +1,216 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.aiservices.voyageai;
+
+import com.microsoft.semantickernel.aiservices.voyageai.contextualizedembedding.VoyageAIContextualizedEmbeddingGenerationService;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient;
+import com.microsoft.semantickernel.aiservices.voyageai.multimodalembedding.VoyageAIMultimodalEmbeddingGenerationService;
+import com.microsoft.semantickernel.aiservices.voyageai.reranking.VoyageAITextRerankingService;
+import com.microsoft.semantickernel.aiservices.voyageai.textembedding.VoyageAITextEmbeddingGenerationService;
+import com.microsoft.semantickernel.services.reranking.RerankResult;
+import com.microsoft.semantickernel.services.textembedding.Embedding;
+import org.junit.jupiter.api.Assumptions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * Integration tests for VoyageAI services.
+ * Requires VOYAGE_API_KEY environment variable to be set.
+ * When the variable is missing or empty, every test is skipped via a JUnit assumption.
+ */
+public class VoyageAIIntegrationTest {
+
+    private static final String API_KEY_ENV_VAR = "VOYAGE_API_KEY";
+    private static final String DEFAULT_EMBEDDING_MODEL = "voyage-3-large";
+    private static final String DEFAULT_CONTEXTUALIZED_MODEL = "voyage-context-3";
+    private static final String DEFAULT_MULTIMODAL_MODEL = "voyage-multimodal-3";
+    private static final String DEFAULT_RERANK_MODEL = "rerank-2";
+
+    private String apiKey;
+
+    /** Reads the API key and skips the test when it is not configured. */
+    @BeforeEach
+    public void setUp() {
+        apiKey = System.getenv(API_KEY_ENV_VAR);
+        Assumptions.assumeTrue(
+            apiKey != null && !apiKey.isEmpty(),
+            "Skipping integration test: " + API_KEY_ENV_VAR + " environment variable not set"
+        );
+    }
+
+    @Test
+    public void testTextEmbeddingGeneration() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAITextEmbeddingGenerationService service =
+            VoyageAITextEmbeddingGenerationService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_EMBEDDING_MODEL)
+                .build();
+
+        Embedding embedding = service.generateEmbeddingAsync("Hello, world!").block();
+
+        assertNotNull(embedding, "Embedding should not be null");
+        assertNotNull(embedding.getVector(), "Embedding vector should not be null");
+        assertFalse(embedding.getVector().isEmpty(), "Embedding vector should not be empty");
+
+        System.out.println("Generated embedding with dimension: " + embedding.getVector().size());
+    }
+
+    @Test
+    public void testMultipleTextEmbeddings() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAITextEmbeddingGenerationService service =
+            VoyageAITextEmbeddingGenerationService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_EMBEDDING_MODEL)
+                .build();
+
+        List<String> texts = Arrays.asList(
+            "Hello, world!",
+            "Semantic Kernel is awesome",
+            "VoyageAI provides great embeddings"
+        );
+
+        List<Embedding> embeddings = service.generateEmbeddingsAsync(texts).block();
+
+        assertNotNull(embeddings, "Embeddings should not be null");
+        assertEquals(3, embeddings.size(), "Should generate 3 embeddings");
+
+        for (Embedding embedding : embeddings) {
+            assertNotNull(embedding.getVector(), "Each embedding vector should not be null");
+            assertFalse(embedding.getVector().isEmpty(), "Each embedding vector should not be empty");
+        }
+
+        System.out.println("Generated " + embeddings.size() + " embeddings");
+    }
+
+    @Test
+    public void testTextReranking() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAITextRerankingService service =
+            VoyageAITextRerankingService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_RERANK_MODEL)
+                .build();
+
+        String query = "What is the capital of France?";
+        List<String> documents = Arrays.asList(
+            "Paris is the capital and most populous city of France.",
+            "Berlin is the capital of Germany.",
+            "The Eiffel Tower is located in Paris.",
+            "London is the capital of the United Kingdom."
+        );
+
+        List<RerankResult> results = service.rerankAsync(query, documents).block();
+
+        assertNotNull(results, "Rerank results should not be null");
+        assertEquals(4, results.size(), "Should have 4 reranked results");
+
+        // The first result should have the highest relevance score
+        assertTrue(results.get(0).getRelevanceScore() >= results.get(1).getRelevanceScore(),
+            "Results should be sorted by relevance score descending");
+
+        System.out.println("Reranking results:");
+        for (int i = 0; i < results.size(); i++) {
+            RerankResult result = results.get(i);
+            System.out.printf("%d. [Index: %d, Score: %.4f] %s%n",
+                i + 1, result.getIndex(), result.getRelevanceScore(), result.getText());
+        }
+
+        // The most relevant document should be about Paris being the capital
+        assertEquals(0, results.get(0).getIndex(),
+            "Most relevant document should be the one about Paris being the capital");
+    }
+
+    @Test
+    public void testRerankingWithTopK() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAITextRerankingService service =
+            VoyageAITextRerankingService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_RERANK_MODEL)
+                .withTopK(2)
+                .build();
+
+        String query = "Machine learning";
+        List<String> documents = Arrays.asList(
+            "Machine learning is a subset of artificial intelligence.",
+            "Cooking is an art form.",
+            "Deep learning uses neural networks.",
+            "The weather is nice today."
+        );
+
+        List<RerankResult> results = service.rerankAsync(query, documents).block();
+
+        assertNotNull(results, "Rerank results should not be null");
+        // VoyageAI might return all results sorted, or just top K
+        assertTrue(results.size() >= 2, "Should have at least 2 results");
+
+        System.out.println("Top K reranking results:");
+        for (RerankResult result : results) {
+            System.out.printf("[Index: %d, Score: %.4f] %s%n",
+                result.getIndex(), result.getRelevanceScore(), result.getText());
+        }
+    }
+
+    @Test
+    public void testContextualizedEmbeddings() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAIContextualizedEmbeddingGenerationService service =
+            VoyageAIContextualizedEmbeddingGenerationService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_CONTEXTUALIZED_MODEL)
+                .build();
+
+        // Create document chunks with context
+        List<List<String>> inputs = Arrays.asList(
+            Arrays.asList("Introduction to semantic kernel", "Semantic kernel is a framework"),
+            Arrays.asList("VoyageAI provides embeddings", "VoyageAI is an AI company")
+        );
+
+        List<Embedding> embeddings = service.generateContextualizedEmbeddingsAsync(inputs).block();
+
+        assertNotNull(embeddings, "Contextualized embeddings should not be null");
+        // Each input document has 2 chunks, so we expect 4 embeddings total (2 documents * 2 chunks each)
+        assertEquals(4, embeddings.size(), "Should generate 4 embeddings (2 per document)");
+
+        for (Embedding embedding : embeddings) {
+            assertNotNull(embedding.getVector(), "Each embedding vector should not be null");
+            assertFalse(embedding.getVector().isEmpty(), "Each embedding vector should not be empty");
+        }
+
+        System.out.println("Generated " + embeddings.size() + " contextualized embeddings");
+    }
+
+    @Test
+    public void testMultimodalEmbeddings() {
+        VoyageAIClient client = new VoyageAIClient(apiKey);
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            VoyageAIMultimodalEmbeddingGenerationService.builder()
+                .withClient(client)
+                .withModelId(DEFAULT_MULTIMODAL_MODEL)
+                .build();
+
+        // Test using generateEmbeddingsAsync which handles text conversion
+        List<String> texts = Arrays.asList(
+            "This is a text description",
+            "Another text example"
+        );
+
+        List<Embedding> embeddings = service.generateEmbeddingsAsync(texts).block();
+
+        assertNotNull(embeddings, "Multimodal embeddings should not be null");
+        assertEquals(2, embeddings.size(), "Should generate 2 embeddings");
+
+        for (Embedding embedding : embeddings) {
+            assertNotNull(embedding.getVector(), "Each embedding vector should not be null");
+            assertFalse(embedding.getVector().isEmpty(), "Each embedding vector should not be empty");
+        }
+
+        System.out.println("Generated " + embeddings.size() + " multimodal embeddings");
+        System.out.println("Embedding dimension: " + embeddings.get(0).getVector().size());
+    }
+}
diff --git a/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIMultimodalEmbeddingGenerationServiceTest.java b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIMultimodalEmbeddingGenerationServiceTest.java
new file mode 100644
index 00000000..537958d2
--- /dev/null
+++
b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAIMultimodalEmbeddingGenerationServiceTest.java
@@ -0,0 +1,181 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.aiservices.voyageai;
+
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient;
+import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels;
+import com.microsoft.semantickernel.aiservices.voyageai.multimodalembedding.VoyageAIMultimodalEmbeddingGenerationService;
+import com.microsoft.semantickernel.services.textembedding.Embedding;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link VoyageAIMultimodalEmbeddingGenerationService} using a mocked
+ * {@link VoyageAIClient}, so no network call is made.
+ */
+public class VoyageAIMultimodalEmbeddingGenerationServiceTest {
+
+    /** Explicit multimodal inputs should map to embeddings in index order. */
+    @Test
+    public void testGenerateMultimodalEmbeddings() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIModels.MultimodalEmbeddingResponse mockResponse =
+            new VoyageAIModels.MultimodalEmbeddingResponse();
+
+        VoyageAIModels.EmbeddingDataItem item1 = new VoyageAIModels.EmbeddingDataItem();
+        item1.setEmbedding(new float[]{0.1f, 0.2f});
+        item1.setIndex(0);
+
+        VoyageAIModels.EmbeddingDataItem item2 = new VoyageAIModels.EmbeddingDataItem();
+        item2.setEmbedding(new float[]{0.3f, 0.4f});
+        item2.setIndex(1);
+
+        mockResponse.setData(Arrays.asList(item1, item2));
+
+        VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage();
+        usage.setTotalTokens(20);
+        mockResponse.setUsage(usage);
+
+        when(mockClient.sendRequestAsync(
+            eq("multimodalembeddings"),
+            any(),
+            eq(VoyageAIModels.MultimodalEmbeddingResponse.class)))
+            .thenReturn(Mono.just(mockResponse));
+
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            new VoyageAIMultimodalEmbeddingGenerationService(mockClient, "voyage-multimodal-3", null);
+
+        // Create properly structured multimodal inputs
+        VoyageAIModels.MultimodalContentItem content1 = new VoyageAIModels.MultimodalContentItem("text", "text1");
+        VoyageAIModels.MultimodalContentItem content2 = new VoyageAIModels.MultimodalContentItem("text", "text2");
+        VoyageAIModels.MultimodalInput input1 = new VoyageAIModels.MultimodalInput(Arrays.asList(content1));
+        VoyageAIModels.MultimodalInput input2 = new VoyageAIModels.MultimodalInput(Arrays.asList(content2));
+        List<VoyageAIModels.MultimodalInput> inputs = Arrays.asList(input1, input2);
+
+        List<Embedding> results = service.generateMultimodalEmbeddingsAsync(inputs).block();
+
+        assertNotNull(results);
+        assertEquals(2, results.size());
+        List<Float> expected1 = Arrays.asList(0.1f, 0.2f);
+        List<Float> expected2 = Arrays.asList(0.3f, 0.4f);
+        assertEquals(expected1, results.get(0).getVector());
+        assertEquals(expected2, results.get(1).getVector());
+    }
+
+    /** The single-text entry point should return the one embedding in the response. */
+    @Test
+    public void testGenerateEmbedding() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIModels.MultimodalEmbeddingResponse mockResponse =
+            new VoyageAIModels.MultimodalEmbeddingResponse();
+
+        VoyageAIModels.EmbeddingDataItem item = new VoyageAIModels.EmbeddingDataItem();
+        item.setEmbedding(new float[]{0.1f, 0.2f, 0.3f});
+        item.setIndex(0);
+
+        mockResponse.setData(Arrays.asList(item));
+
+        VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage();
+        usage.setTotalTokens(10);
+        mockResponse.setUsage(usage);
+
+        when(mockClient.sendRequestAsync(
+            eq("multimodalembeddings"),
+            any(),
+            eq(VoyageAIModels.MultimodalEmbeddingResponse.class)))
+            .thenReturn(Mono.just(mockResponse));
+
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            new VoyageAIMultimodalEmbeddingGenerationService(mockClient, "voyage-multimodal-3", null);
+
+        Embedding result = service.generateEmbeddingAsync("test text").block();
+
+        assertNotNull(result);
+        List<Float> expected = Arrays.asList(0.1f, 0.2f, 0.3f);
+        assertEquals(expected, result.getVector());
+    }
+
+    /** Plain text inputs should be accepted and mapped to embeddings in index order. */
+    @Test
+    public void testGenerateEmbeddingsFromTextList() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIModels.MultimodalEmbeddingResponse mockResponse =
+            new VoyageAIModels.MultimodalEmbeddingResponse();
+
+        VoyageAIModels.EmbeddingDataItem item1 = new VoyageAIModels.EmbeddingDataItem();
+        item1.setEmbedding(new float[]{0.1f, 0.2f});
+        item1.setIndex(0);
+
+        VoyageAIModels.EmbeddingDataItem item2 = new VoyageAIModels.EmbeddingDataItem();
+        item2.setEmbedding(new float[]{0.3f, 0.4f});
+        item2.setIndex(1);
+
+        mockResponse.setData(Arrays.asList(item1, item2));
+
+        VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage();
+        usage.setTotalTokens(20);
+        mockResponse.setUsage(usage);
+
+        when(mockClient.sendRequestAsync(
+            eq("multimodalembeddings"),
+            any(),
+            eq(VoyageAIModels.MultimodalEmbeddingResponse.class)))
+            .thenReturn(Mono.just(mockResponse));
+
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            new VoyageAIMultimodalEmbeddingGenerationService(mockClient, "voyage-multimodal-3", null);
+
+        List<Embedding> results = service.generateEmbeddingsAsync(
+            Arrays.asList("text1", "text2")).block();
+
+        assertNotNull(results);
+        assertEquals(2, results.size());
+        List<Float> expected1 = Arrays.asList(0.1f, 0.2f);
+        List<Float> expected2 = Arrays.asList(0.3f, 0.4f);
+        assertEquals(expected1, results.get(0).getVector());
+        assertEquals(expected2, results.get(1).getVector());
+    }
+
+    /** Constructor arguments should be exposed through the getters. */
+    @Test
+    public void testServiceIdAndModelId() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            new VoyageAIMultimodalEmbeddingGenerationService(mockClient, "voyage-multimodal-3", "test-service");
+
+        assertEquals("test-service", service.getServiceId());
+        assertEquals("voyage-multimodal-3", service.getModelId());
+    }
+
+    /** The builder should produce a service carrying the configured IDs. */
+    @Test
+    public void testBuilderPattern() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+
+        VoyageAIMultimodalEmbeddingGenerationService service =
+            VoyageAIMultimodalEmbeddingGenerationService.builder()
+                .withClient(mockClient)
+                .withModelId("voyage-multimodal-3")
+                .withServiceId("test-service")
+                .build();
+
+        assertNotNull(service);
+        assertEquals("test-service", service.getServiceId());
+        assertEquals("voyage-multimodal-3", service.getModelId());
+    }
+
+    @Test
+    public void testNullClientThrowsException() {
+        assertThrows(IllegalArgumentException.class, () ->
+            new VoyageAIMultimodalEmbeddingGenerationService(null, "voyage-multimodal-3", null));
+    }
+
+    @Test
+    public void testNullModelIdThrowsException() {
+        VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class);
+        assertThrows(IllegalArgumentException.class, () ->
+            new VoyageAIMultimodalEmbeddingGenerationService(mockClient, null, null));
+    }
+}
diff --git a/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextEmbeddingGenerationServiceTest.java b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextEmbeddingGenerationServiceTest.java
new file mode 100644
index 00000000..962a340b
--- /dev/null
+++ b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextEmbeddingGenerationServiceTest.java
@@ -0,0 +1,131 @@
+// Copyright (c) Microsoft. All rights reserved.
+package com.microsoft.semantickernel.aiservices.voyageai; + +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient; +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels; +import com.microsoft.semantickernel.aiservices.voyageai.textembedding.VoyageAITextEmbeddingGenerationService; +import com.microsoft.semantickernel.services.textembedding.Embedding; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.when; + +public class VoyageAITextEmbeddingGenerationServiceTest { + + @Test + public void testGenerateEmbedding() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAIModels.EmbeddingResponse mockResponse = new VoyageAIModels.EmbeddingResponse(); + VoyageAIModels.EmbeddingDataItem item = new VoyageAIModels.EmbeddingDataItem(); + item.setEmbedding(new float[]{0.1f, 0.2f, 0.3f}); + item.setIndex(0); + mockResponse.setData(Arrays.asList(item)); + + VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage(); + usage.setTotalTokens(10); + mockResponse.setUsage(usage); + + when(mockClient.sendRequestAsync( + eq("embeddings"), + any(), + eq(VoyageAIModels.EmbeddingResponse.class))) + .thenReturn(Mono.just(mockResponse)); + + VoyageAITextEmbeddingGenerationService service = + new VoyageAITextEmbeddingGenerationService(mockClient, "voyage-3-large", null); + + Embedding result = service.generateEmbeddingAsync("test text").block(); + + assertNotNull(result); + List<Float> expected = Arrays.asList(0.1f, 0.2f, 0.3f); + assertEquals(expected, result.getVector()); + } + + @Test + public void testGenerateMultipleEmbeddings() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAIModels.EmbeddingResponse
mockResponse = new VoyageAIModels.EmbeddingResponse(); + + VoyageAIModels.EmbeddingDataItem item1 = new VoyageAIModels.EmbeddingDataItem(); + item1.setEmbedding(new float[]{0.1f, 0.2f}); + item1.setIndex(0); + + VoyageAIModels.EmbeddingDataItem item2 = new VoyageAIModels.EmbeddingDataItem(); + item2.setEmbedding(new float[]{0.3f, 0.4f}); + item2.setIndex(1); + + mockResponse.setData(Arrays.asList(item1, item2)); + + VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage(); + usage.setTotalTokens(20); + mockResponse.setUsage(usage); + + when(mockClient.sendRequestAsync( + eq("embeddings"), + any(), + eq(VoyageAIModels.EmbeddingResponse.class))) + .thenReturn(Mono.just(mockResponse)); + + VoyageAITextEmbeddingGenerationService service = + new VoyageAITextEmbeddingGenerationService(mockClient, "voyage-3-large", null); + + List<Embedding> results = service.generateEmbeddingsAsync( + Arrays.asList("text1", "text2")).block(); + + assertNotNull(results); + assertEquals(2, results.size()); + List<Float> expected1 = Arrays.asList(0.1f, 0.2f); + List<Float> expected2 = Arrays.asList(0.3f, 0.4f); + assertEquals(expected1, results.get(0).getVector()); + assertEquals(expected2, results.get(1).getVector()); + } + + @Test + public void testServiceIdAndModelId() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAITextEmbeddingGenerationService service = + new VoyageAITextEmbeddingGenerationService(mockClient, "voyage-3-large", "test-service"); + + assertEquals("test-service", service.getServiceId()); + assertEquals("voyage-3-large", service.getModelId()); + } + + @Test + public void testBuilderPattern() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAITextEmbeddingGenerationService service = + VoyageAITextEmbeddingGenerationService.builder() + .withClient(mockClient) + .withModelId("voyage-3-large") + .withServiceId("test-service") + .build(); + + assertNotNull(service); + assertEquals("test-service", service.getServiceId()); +
assertEquals("voyage-3-large", service.getModelId()); + } + + @Test + public void testNullClientThrowsException() { + assertThrows(IllegalArgumentException.class, () -> + new VoyageAITextEmbeddingGenerationService(null, "voyage-3-large", null)); + } + + @Test + public void testNullModelIdThrowsException() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + assertThrows(IllegalArgumentException.class, () -> + new VoyageAITextEmbeddingGenerationService(mockClient, null, null)); + } +} diff --git a/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextRerankingServiceTest.java b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextRerankingServiceTest.java new file mode 100644 index 00000000..9feddd27 --- /dev/null +++ b/aiservices/voyageai/src/test/java/com/microsoft/semantickernel/aiservices/voyageai/VoyageAITextRerankingServiceTest.java @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft. All rights reserved. 
+package com.microsoft.semantickernel.aiservices.voyageai; + +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIClient; +import com.microsoft.semantickernel.aiservices.voyageai.core.VoyageAIModels; +import com.microsoft.semantickernel.aiservices.voyageai.reranking.VoyageAITextRerankingService; +import com.microsoft.semantickernel.services.reranking.RerankResult; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.when; + +public class VoyageAITextRerankingServiceTest { + + @Test + public void testRerank() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAIModels.RerankResponse mockResponse = new VoyageAIModels.RerankResponse(); + + VoyageAIModels.RerankDataItem item1 = new VoyageAIModels.RerankDataItem(); + item1.setIndex(1); + item1.setRelevanceScore(0.9); + + VoyageAIModels.RerankDataItem item2 = new VoyageAIModels.RerankDataItem(); + item2.setIndex(0); + item2.setRelevanceScore(0.5); + + mockResponse.setData(Arrays.asList(item1, item2)); + + VoyageAIModels.EmbeddingUsage usage = new VoyageAIModels.EmbeddingUsage(); + usage.setTotalTokens(20); + mockResponse.setUsage(usage); + + when(mockClient.sendRequestAsync( + eq("rerank"), + any(), + eq(VoyageAIModels.RerankResponse.class))) + .thenReturn(Mono.just(mockResponse)); + + VoyageAITextRerankingService service = + new VoyageAITextRerankingService(mockClient, "rerank-2", null, null); + + List<String> documents = Arrays.asList("Document A", "Document B"); + List<RerankResult> results = service.rerankAsync("test query", documents).block(); + + assertNotNull(results); + assertEquals(2, results.size()); + + // Results should be sorted by relevance score descending + assertEquals(1, results.get(0).getIndex()); + assertEquals("Document B",
results.get(0).getText()); + assertEquals(0.9, results.get(0).getRelevanceScore(), 0.001); + + assertEquals(0, results.get(1).getIndex()); + assertEquals("Document A", results.get(1).getText()); + assertEquals(0.5, results.get(1).getRelevanceScore(), 0.001); + } + + @Test + public void testServiceIdAndModelId() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAITextRerankingService service = + new VoyageAITextRerankingService(mockClient, "rerank-2", "test-service", null); + + assertEquals("test-service", service.getServiceId()); + assertEquals("rerank-2", service.getModelId()); + } + + @Test + public void testBuilderPattern() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + + VoyageAITextRerankingService service = + VoyageAITextRerankingService.builder() + .withClient(mockClient) + .withModelId("rerank-2") + .withServiceId("test-service") + .withTopK(5) + .build(); + + assertNotNull(service); + assertEquals("test-service", service.getServiceId()); + assertEquals("rerank-2", service.getModelId()); + } + + @Test + public void testNullClientThrowsException() { + assertThrows(IllegalArgumentException.class, () -> + new VoyageAITextRerankingService(null, "rerank-2", null, null)); + } + + @Test + public void testNullModelIdThrowsException() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + assertThrows(IllegalArgumentException.class, () -> + new VoyageAITextRerankingService(mockClient, null, null, null)); + } + + @Test + public void testNullQueryThrowsException() { + VoyageAIClient mockClient = Mockito.mock(VoyageAIClient.class); + VoyageAITextRerankingService service = + new VoyageAITextRerankingService(mockClient, "rerank-2", null, null); + + assertThrows(IllegalArgumentException.class, () -> + service.rerankAsync(null, Arrays.asList("doc")).block()); + } +} diff --git a/pom.xml b/pom.xml index 8399f1c7..b54d9e0e 100644 --- a/pom.xml +++ b/pom.xml @@ -74,6 +74,7 @@ <module>aiservices/openai</module> <module>aiservices/google</module>
<module>aiservices/huggingface</module> + <module>aiservices/voyageai</module> <module>data/semantickernel-data-azureaisearch</module> <module>data/semantickernel-data-jdbc</module> <module>data/semantickernel-data-redis</module> diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/RerankResult.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/RerankResult.java new file mode 100644 index 00000000..9d365b76 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/RerankResult.java @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.services.reranking; + +/** + * Represents a single reranking result containing a document and its relevance score. + */ +public class RerankResult { + private final int index; + private final String text; + private final double relevanceScore; + + /** + * Initializes a new instance of the {@link RerankResult} class. + * + * @param index The index of the document in the original input list + * @param text The document text + * @param relevanceScore The relevance score (higher scores indicate greater relevance) + */ + public RerankResult(int index, String text, double relevanceScore) { + if (text == null) { + throw new IllegalArgumentException("Text cannot be null"); + } + this.index = index; + this.text = text; + this.relevanceScore = relevanceScore; + } + + /** + * Gets the index of the document in the original input list. + * + * @return The index + */ + public int getIndex() { + return index; + } + + /** + * Gets the document text. + * + * @return The text + */ + public String getText() { + return text; + } + + /** + * Gets the relevance score assigned by the reranker. + * Higher scores indicate greater relevance to the query.
+ * + * @return The relevance score + */ + public double getRelevanceScore() { + return relevanceScore; + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/TextRerankingService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/TextRerankingService.java new file mode 100644 index 00000000..84fe66c0 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/reranking/TextRerankingService.java @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.services.reranking; + +import com.microsoft.semantickernel.services.AIService; +import reactor.core.publisher.Mono; + +import java.util.List; + +/** + * Interface for text reranking services that can reorder documents based on relevance to a query. + */ +public interface TextRerankingService extends AIService { + + /** + * Reranks a list of documents based on their relevance to a query. + * + * @param query The query to rank documents against + * @param documents The list of documents to rerank + * @return A Mono containing a list of {@link RerankResult} sorted by relevance score in descending order + */ + Mono<List<RerankResult>> rerankAsync(String query, List<String> documents); +}