From 68c20d977c9192bab35efcca43725a5cd5fc7cbb Mon Sep 17 00:00:00 2001 From: vga91 Date: Mon, 29 Apr 2024 12:04:54 +0200 Subject: [PATCH 1/3] Fixes #4058: Add support for mixedbread.ai Embedding API --- docs/asciidoc/modules/ROOT/nav.adoc | 2 + .../asciidoc/modules/ROOT/pages/ml/index.adoc | 1 + .../modules/ROOT/pages/ml/mixedbread.adoc | 182 ++++++++++++ .../modules/ROOT/pages/ml/openai.adoc | 1 + .../src/main/java/apoc/ml/MixedbreadAI.java | 80 ++++++ extended/src/main/java/apoc/ml/OpenAI.java | 37 ++- .../java/apoc/ml/OpenAIRequestHandler.java | 25 +- .../src/test/java/apoc/ml/MixedbreadAIIT.java | 259 ++++++++++++++++++ 8 files changed, 578 insertions(+), 9 deletions(-) create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc create mode 100644 extended/src/main/java/apoc/ml/MixedbreadAI.java create mode 100644 extended/src/test/java/apoc/ml/MixedbreadAIIT.java diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index ae2f30addc..6ba15c480d 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -65,6 +65,8 @@ include::partial$generated-documentation/nav.adoc[] ** xref:ml/vertexai.adoc[] ** xref:ml/openai.adoc[] ** xref:ml/bedrock.adoc[] + ** xref:ml/watsonai.adoc[] + ** xref:ml/mixedbread.adoc[] * xref:background-operations/index.adoc[] ** xref::background-operations/apoc-load-directory-async.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc index 5f82b2c4f7..dfb9c88570 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc @@ -11,3 +11,4 @@ This section includes: * xref::ml/openai.adoc[] * xref::ml/bedrock.adoc[] * xref::ml/watsonai.adoc[] +* xref::ml/mixedbread.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc b/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc new file mode 100644 index 0000000000..cc829f28b1 --- /dev/null +++ 
b/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc @@ -0,0 +1,182 @@ +[[Mixedbread-api]] += Mixedbread API Access +:description: This section describes procedures that can be used to access the Mixedbread API. + + +Here is a list of all available Mixedbread API procedures: + + +[opts=header, cols="1, 4", separator="|"] +|=== +|name| description +|apoc.ml.mixedbread.custom(apiKey, $config)| To create a customizable Mixedbread API call +|apoc.ml.mixedbread.embedding(texts, apiKey, $config)| To create a Mixedbread API call to generate embeddings +|=== + +The `$config` parameter coincides with the payload to be passed to the http request, +and additionally accepts the following configuration keys. + + +.Common configuration parameters + +|=== +| key | description +| apiType | analogous to `apoc.ml.openai.type` APOC config +| endpoint | analogous to `apoc.ml.openai.url` APOC config +| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config +| path | To customize the url portion added to the base url (defined by the `endpoint` config). +By default, it is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures. +| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. +The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure. +|=== + +Since the Mixedbread embeddings are a superset of the OpenAI ones, +under-the-hood they leverage the apoc.ml.openai.* procedures, +so we can also create an APOC config `apoc.ml.openai.url` instead of the `endpoint` config. + + +== Generate Embeddings API + +This procedure `apoc.ml.mixedbread.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1024 element vector (the size produced by the default model). 
+It uses the `/embeddings` API which is https://www.mixedbread.ai/api-reference/endpoints/embeddings#create-embeddings[documented here^]. + +Additional configuration is passed to the API; the default model used is `mxbai-embed-large-v1`. + + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| texts | List of text strings +| apiKey | Mixedbread API key +| configuration | optional map. See the configuration table above +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| index | index entry in original list +| text | line of text from original list +| embedding | the embedding as a list of float / binary values, + or a map of such lists keyed by format, when multiple `encoding_format` values are requested +|=== + + +.Generate Embeddings Call +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text'], $apiKey, {}) yield index, text, embedding; +---- + +.Generate Embeddings Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... 
-0.010693862, -0.005087272] +|=== + + +.Generate Embeddings Call with custom embedding dimension and model +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], + $apiKey, + {model: 'mxbai-embed-2d-large-v1', dimensions: 4} +) +---- + +.Generate Embeddings Example Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [0.019943237, -0.08843994, 0.068603516, 0.034942627] +|1 | "Other Text" | [0.011482239, -0.09069824, 0.05331421, 0.034088135] +|=== + + +.Generate Embeddings Call with multiple encoding formats +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text', 'garpez'], + $apiKey, + {encoding_format: ["float", "binary", "ubinary", "int8", "uint8", "base64"]} +) +---- + +.Generate Embeddings Example Response +[%autowidth, opts=header] +|=== +| index | text | embedding +| 0 | "Some Text" | {binary: [...], ubinary: [...], int8: [...], uint8: [...], base64: "...", float: [...]} +| 1 | "garpez" | {binary: [...], ubinary: [...], int8: [...], uint8: [...], base64: "...", float: [...]} +|=== + + + + +== Custom API + +Via the `apoc.ml.mixedbread.custom` procedure we can create a customizable Mixedbread API request, +returning a generic stream of objects. + +For example, we can use the https://www.mixedbread.ai/api-reference/endpoints/reranking[Reranking API]. + + +.Reranking API Call +[source,cypher] +---- +CALL apoc.ml.mixedbread.custom($apiKey, + { + endpoint: "https://api.mixedbread.ai/v1/reranking", + model: "mixedbread-ai/mxbai-rerank-large-v1", + query: "Who is the author of To Kill a Mockingbird?", + top_k: 3, + input: [ + "To Kill a Mockingbird is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.", + "The novel Moby-Dick was written by Herman Melville and first published in 1851. 
It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.", + "Harper Lee, an American novelist widely known for her novel To Kill a Mockingbird, was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.", + "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.", + "The Harry Potter series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.", + "The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan." + ] + } +) +---- + +.Reranking API Example Response +[%autowidth, opts=header] +|=== +| value +a| +[source,json] +---- +{ + "model": "mixedbread-ai/mxbai-rerank-large-v1", + "return_input": false, + "data": [ + { + "index": 0, + "score": 0.9980469, + "object": "text_document" + }, + { + "index": 2, + "score": 0.9980469, + "object": "text_document" + }, + { + "index": 3, + "score": 0.06915283, + "object": "text_document" + } + ], + "usage": { + "total_tokens": 302, + "prompt_tokens": 302 + }, + "object": "list", + "top_k": 3 +} +---- +|=== diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index e2423f7726..135e90ec00 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -251,6 +251,7 @@ CALL apoc.ml.openai.chat([{"role": "user", "content": "Explain the importance of ---- + == Query with natural language This procedure `apoc.ml.query` takes a question in natural language and returns the results of that query. 
diff --git a/extended/src/main/java/apoc/ml/MixedbreadAI.java b/extended/src/main/java/apoc/ml/MixedbreadAI.java new file mode 100644 index 0000000000..31fa7163fb --- /dev/null +++ b/extended/src/main/java/apoc/ml/MixedbreadAI.java @@ -0,0 +1,80 @@ +package apoc.ml; + +import apoc.ApocConfig; +import apoc.result.ObjectResult; +import org.neo4j.graphdb.security.URLAccessChecker; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; +import static apoc.ml.OpenAI.APOC_ML_OPENAI_URL; +import static apoc.ml.OpenAI.MODEL_CONF_KEY; + +public class MixedbreadAI { + + public static final String ENDPOINT_CONF_KEY = "endpoint"; + public static final String DEFAULT_MODEL_ID = "mxbai-embed-large-v1"; + public static final String MIXEDBREAD_BASE_URL = "https://api.mixedbread.ai/v1"; + public static final String ERROR_MSG_MISSING_ENDPOINT = "The endpoint must be defined via config `%s` or via apoc.conf `%s`" + .formatted(ENDPOINT_CONF_KEY, APOC_ML_OPENAI_URL); + + public static final String ERROR_MSG_MISSING_MODELID = "The model must be defined via config `%s`" + .formatted(MODEL_CONF_KEY); + + + /** + * embedding is an Object instead of List, as with a Mixedbread request having `"encoding_format": []`, + * the result can be e.g. {... "embedding": { "float": [], "base": , } ...} + * instead of e.g. {... 
"embedding": [] ...} + */ + public record EmbeddingResult(long index, String text, Object embedding) { } + + @Context + public URLAccessChecker urlAccessChecker; + + @Context + public ApocConfig apocConfig; + + + @Procedure("apoc.ml.mixedbread.custom") + @Description("apoc.mixedbread.custom(, configuration) - returns the embeddings for a given text") + public Stream custom(@Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + if (!configuration.containsKey(MODEL_CONF_KEY)) { + throw new RuntimeException(ERROR_MSG_MISSING_MODELID); + } + + configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_CUSTOM.name()); + + return OpenAI.executeRequest(apiKey, configuration, + null, null, null, null, null, + apocConfig, + urlAccessChecker) + .map(ObjectResult::new); + } + + + @Procedure("apoc.ml.mixedbread.embedding") + @Description("apoc.mixedbread.mixedbread([texts], api_key, configuration) - returns the embeddings for a given text") + public Stream getEmbedding(@Name("texts") List texts, + @Name("api_key") String apiKey, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + configuration.putIfAbsent(MODEL_CONF_KEY, DEFAULT_MODEL_ID); + + configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_EMBEDDING.name()); + return OpenAI.getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker, + (map, text) -> { + Long index = (Long) map.get("index"); + return new EmbeddingResult(index, text, map.get("embedding")); + }, + m -> new EmbeddingResult(-1, m, List.of()) + ); + + } + +} diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 79772fcde1..461600696e 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -5,6 +5,7 @@ import apoc.result.MapResult; import apoc.util.JsonUtil; import 
com.fasterxml.jackson.core.JsonProcessingException; +import org.jetbrains.annotations.NotNull; import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -17,6 +18,8 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -67,6 +70,9 @@ static Stream executeRequest(String apiKey, Map configur Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove); switch (type) { + case MIXEDBREAD_CUSTOM -> { + // no payload manipulation, taken from the configuration as-is + } case HUGGINGFACE -> { config.putIfAbsent("inputs", inputs); jsonPath = "$[0]"; @@ -111,6 +117,17 @@ public Stream getEmbedding(@Name("texts") List texts, @ "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ + return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker, + (map, text) -> { + Long index = (Long) map.get("index"); + return new EmbeddingResult(index, text, (List) map.get("embedding")); + }, + m -> new EmbeddingResult(-1, m, List.of()) + ); + } + + static Stream getEmbeddingResult(List texts, String apiKey, Map configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker, + BiFunction embeddingMapping, Function nullMapping) throws JsonProcessingException, MalformedURLException { if (texts == null) { throw new RuntimeException(ERROR_NULL_INPUT); } @@ -121,19 +138,25 @@ public Stream getEmbedding(@Name("texts") List texts, @ List nonNullTexts = collect.get(true); Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker); - Stream embeddingResultStream = resultStream +// Function, R> mapRFunction = m -> { +// Long index = (Long) 
m.get("index"); +// return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), m.get("embedding")); +// }; + Stream embeddingResultStream = resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { Long index = (Long) m.get("index"); - return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), (List) m.get("embedding")); + String text = nonNullTexts.get(index.intValue()); + return embeddingMapping.apply(m, text); }); List nullTexts = collect.getOrDefault(false, List.of()); - Stream nullResultStream = nullTexts.stream() - .map(i -> { - // null text return index -1 to indicate that are not coming from `/embeddings` RestAPI - return new EmbeddingResult(-1, i, List.of()); - }); +// Function stringRFunction = i -> { +// // null text return index -1 to indicate that are not coming from `/embeddings` RestAPI +// return new EmbeddingResult(-1, i, List.of()); +// }; + Stream nullResultStream = nullTexts.stream() + .map(nullMapping); return Stream.concat(embeddingResultStream, nullResultStream); } diff --git a/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java b/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java index 940b1737d4..6a8f7a686a 100644 --- a/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java +++ b/extended/src/main/java/apoc/ml/OpenAIRequestHandler.java @@ -10,6 +10,8 @@ import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION; import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL; +import static apoc.ml.MixedbreadAI.ERROR_MSG_MISSING_ENDPOINT; +import static apoc.ml.MixedbreadAI.MIXEDBREAD_BASE_URL; import static apoc.ml.MLUtil.*; abstract class OpenAIRequestHandler { @@ -27,8 +29,12 @@ public String getDefaultUrl() { public abstract void addApiKey(Map headers, String apiKey); public String getEndpoint(Map procConfig, ApocConfig apocConfig) { - return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, - apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl()))); + 
String url = (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, + apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL))); + if (url == null) { + return getDefaultUrl(); + } + return url; } public String getFullUrl(String method, Map procConfig, ApocConfig apocConfig) { @@ -40,6 +46,8 @@ public String getFullUrl(String method, Map procConfig, ApocConf enum Type { AZURE(new Azure(null)), HUGGINGFACE(new OpenAi(null)), + MIXEDBREAD_EMBEDDING(new OpenAi(MIXEDBREAD_BASE_URL)), + MIXEDBREAD_CUSTOM(new Custom()), OPENAI(new OpenAi("https://api.openai.com/v1")); private final OpenAIRequestHandler handler; @@ -85,4 +93,17 @@ public void addApiKey(Map headers, String apiKey) { headers.put("Authorization", "Bearer " + apiKey); } } + + static class Custom extends OpenAi { + + public Custom() { + super(null); + } + + @Override + public String getDefaultUrl() { + throw new RuntimeException(ERROR_MSG_MISSING_ENDPOINT); + + } + } } diff --git a/extended/src/test/java/apoc/ml/MixedbreadAIIT.java b/extended/src/test/java/apoc/ml/MixedbreadAIIT.java new file mode 100644 index 0000000000..fa3539c6e0 --- /dev/null +++ b/extended/src/test/java/apoc/ml/MixedbreadAIIT.java @@ -0,0 +1,259 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static apoc.ml.MixedbreadAI.*; +import static apoc.ml.OpenAI.MODEL_CONF_KEY; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + 
+public class MixedbreadAIIT { + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + private static String apiKey; + + @BeforeClass + public static void setUp() throws Exception { + String keyIdEnv = "MIXEDBREAD_API_KEY"; + apiKey = System.getenv(keyIdEnv); + + Assume.assumeNotNull("No MIXEDBREAD_API_KEY environment configured", apiKey); + + TestUtil.registerProcedure(db, MixedbreadAI.class); + } + + @Test + public void getEmbedding() { + testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", emptyMap()), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 1024); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 1024); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithNulls() { + testResult(db, "CALL apoc.ml.mixedbread.embedding([null, 'Some Text', null, 'Another Text'], $apiKey, $conf)", + Map.of("apiKey", apiKey, "conf", emptyMap()), + (r) -> { + + Map row = r.next(); + assertEquals(1024, ((List) row.get("embedding")).size()); + assertEquals("Some Text", row.get("text")); + + row = r.next(); + assertEquals(1024, ((List) row.get("embedding")).size()); + assertEquals("Another Text", row.get("text")); + + row = r.next(); + assertNullEmbedding(row); + + row = r.next(); + assertNullEmbedding(row); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithCustomMultipleEncodingFormats() { + Set formats = Set.of("float", "binary", "ubinary", "int8", "uint8", "base64"); + Map conf = map("encoding_format", formats); + testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", conf), + r -> { + Map row = r.next(); + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + var embedding = (Map) row.get("embedding"); + assertEquals(formats, embedding.keySet()); + assertTrue(embedding.get("float") 
instanceof List); + assertTrue(embedding.get("binary") instanceof List); + assertTrue(embedding.get("ubinary") instanceof List); + assertTrue(embedding.get("int8") instanceof List); + assertTrue(embedding.get("uint8") instanceof List); + assertTrue(embedding.get("base64") instanceof String); + + row = r.next(); + assertEquals(1L, row.get("index")); + assertEquals("Other Text", row.get("text")); + embedding = (Map) row.get("embedding"); + assertEquals(formats, embedding.keySet()); + assertTrue(embedding.get("float") instanceof List); + assertTrue(embedding.get("binary") instanceof List); + assertTrue(embedding.get("ubinary") instanceof List); + assertTrue(embedding.get("int8") instanceof List); + assertTrue(embedding.get("uint8") instanceof List); + assertTrue(embedding.get("base64") instanceof String); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithCustomEmbeddingSize() { + testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", map("dimensions", 256)), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 256); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 256); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithOtherModel() { + testResult(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", map(MODEL_CONF_KEY, "mxbai-embed-2d-large-v1")), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 1024); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 1024); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithWrongModel() { + try { + testCall(db, "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, + "conf", map(MODEL_CONF_KEY, "wrong-id") + ), + r -> fail("Should fail due to wrong model id")); + } catch (Exception e) { + String 
errMsg = e.getMessage(); + assertTrue("Actual error message is: " + errMsg, + errMsg.contains("Server returned HTTP response code: 422 for URL: https://api.mixedbread.ai/v1/embeddings") + ); + + } + } + + /** + * Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking + */ + @Test + public void customWithReranking() { + List input = List.of("To Kill a Mockingbird is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.", + "The novel Moby-Dick was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.", + "Harper Lee, an American novelist widely known for her novel To Kill a Mockingbird, was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.", + "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.", + "The Harry Potter series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.", + "The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan." 
+ ); + Map conf = map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking", + MODEL_CONF_KEY, "mixedbread-ai/mxbai-rerank-large-v1", + "query", "Who is the author of To Kill a Mockingbird?", + "top_k", 3, + "input", input + ); + testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", + Map.of("apiKey", apiKey, "conf", conf), + row -> { + Map value = (Map) row.get("value"); + + List data = (List) value.get("data"); + assertEquals(3, data.size()); + + Map firstData = map("index", 0L, + "score", 0.9980469, + "object", "text_document"); + assertEquals(firstData, data.get(0)); + + + Map secondData = map( + "index", 2L, + "score", 0.9980469, + "object", "text_document"); + assertEquals(secondData, data.get(1)); + + Map thirdData = map( + "index", 3L, + "score", 0.06915283, + "object", "text_document"); + assertEquals(thirdData, data.get(2)); + + assertEquals("list", value.get("object")); + }); + } + + @Test + public void customWithMissingEndpoint() { + try { + testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", + map("apiKey", apiKey, + "conf", map(MODEL_CONF_KEY, "aModelId") + ), + r -> fail("Should fail due to missing endpoint")); + } catch (Exception e) { + String errMsg = e.getMessage(); + assertTrue("Actual error message is: " + errMsg, + errMsg.contains(ERROR_MSG_MISSING_ENDPOINT) + ); + } + } + + @Test + public void customWithMissingModel() { + try { + testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", + map("apiKey", apiKey, + "conf", map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking", + "foo", "bar") + ), + r -> fail("Should fail due to missing model")); + } catch (Exception e) { + String errMsg = e.getMessage(); + assertTrue("Actual error message is: " + errMsg, + errMsg.contains(ERROR_MSG_MISSING_MODELID) + ); + } + } + + private static void assertEmbedding(Map row, + long expectedIdx, + String expectedText, + Integer expectedSize) { + assertEquals(expectedIdx, row.get("index")); + assertEquals(expectedText, 
row.get("text")); + var embedding = (List) row.get("embedding"); + assertEquals(expectedSize, embedding.size()); + } + + private static void assertNullEmbedding(Map row) { + assertEmbedding(row, -1, null, 0); + } + +} From 1c45783ed793c08fcd0e35721413c104a053ede5 Mon Sep 17 00:00:00 2001 From: vga91 Date: Tue, 14 May 2024 14:36:11 +0200 Subject: [PATCH 2/3] updated extended.txt --- extended/src/main/java/apoc/ml/OpenAI.java | 1 - extended/src/main/resources/extended.txt | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 461600696e..652ffb6b6d 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -5,7 +5,6 @@ import apoc.result.MapResult; import apoc.util.JsonUtil; import com.fasterxml.jackson.core.JsonProcessingException; -import org.jetbrains.annotations.NotNull; import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; diff --git a/extended/src/main/resources/extended.txt b/extended/src/main/resources/extended.txt index b11bcdf95d..f2c3e172c1 100644 --- a/extended/src/main/resources/extended.txt +++ b/extended/src/main/resources/extended.txt @@ -114,6 +114,8 @@ apoc.ml.fromCypher apoc.ml.fromQueries apoc.ml.query apoc.ml.schema +apoc.ml.mixedbread.custom +apoc.ml.mixedbread.embedding apoc.ml.openai.chat apoc.ml.openai.completion apoc.ml.openai.embedding From 7e0897079a9b08e8dbf84b0b7241608c14b92fe8 Mon Sep 17 00:00:00 2001 From: vga91 Date: Thu, 16 May 2024 12:09:26 +0200 Subject: [PATCH 3/3] Changes after rebase --- extended/src/main/java/apoc/ml/MixedbreadAI.java | 2 +- extended/src/main/java/apoc/ml/OpenAI.java | 8 -------- extended/src/test/java/apoc/ml/MixedbreadAIIT.java | 2 +- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/extended/src/main/java/apoc/ml/MixedbreadAI.java 
b/extended/src/main/java/apoc/ml/MixedbreadAI.java index 31fa7163fb..99f4fd7316 100644 --- a/extended/src/main/java/apoc/ml/MixedbreadAI.java +++ b/extended/src/main/java/apoc/ml/MixedbreadAI.java @@ -12,9 +12,9 @@ import java.util.Map; import java.util.stream.Stream; +import static apoc.ml.MLUtil.MODEL_CONF_KEY; import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; import static apoc.ml.OpenAI.APOC_ML_OPENAI_URL; -import static apoc.ml.OpenAI.MODEL_CONF_KEY; public class MixedbreadAI { diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 652ffb6b6d..842ae4267f 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -137,10 +137,6 @@ static Stream getEmbeddingResult(List texts, String apiKey, Map nonNullTexts = collect.get(true); Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker); -// Function, R> mapRFunction = m -> { -// Long index = (Long) m.get("index"); -// return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), m.get("embedding")); -// }; Stream embeddingResultStream = resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { @@ -150,10 +146,6 @@ static Stream getEmbeddingResult(List texts, String apiKey, Map nullTexts = collect.getOrDefault(false, List.of()); -// Function stringRFunction = i -> { -// // null text return index -1 to indicate that are not coming from `/embeddings` RestAPI -// return new EmbeddingResult(-1, i, List.of()); -// }; Stream nullResultStream = nullTexts.stream() .map(nullMapping); return Stream.concat(embeddingResultStream, nullResultStream); diff --git a/extended/src/test/java/apoc/ml/MixedbreadAIIT.java b/extended/src/test/java/apoc/ml/MixedbreadAIIT.java index fa3539c6e0..a9a456fc1c 100644 --- a/extended/src/test/java/apoc/ml/MixedbreadAIIT.java +++ b/extended/src/test/java/apoc/ml/MixedbreadAIIT.java @@ 
-12,8 +12,8 @@ import java.util.Map; import java.util.Set; +import static apoc.ml.MLUtil.MODEL_CONF_KEY; import static apoc.ml.MixedbreadAI.*; -import static apoc.ml.OpenAI.MODEL_CONF_KEY; import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map;