Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #4058: Add support for mixedbread.ai Embedding API #4060

Merged
merged 3 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/asciidoc/modules/ROOT/nav.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ include::partial$generated-documentation/nav.adoc[]
** xref:ml/vertexai.adoc[]
** xref:ml/openai.adoc[]
** xref:ml/bedrock.adoc[]
** xref:ml/watsonai.adoc[]
** xref:ml/mixedbread.adoc[]

* xref:background-operations/index.adoc[]
** xref::background-operations/apoc-load-directory-async.adoc[]
Expand Down
1 change: 1 addition & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/index.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ This section includes:
* xref::ml/openai.adoc[]
* xref::ml/bedrock.adoc[]
* xref::ml/watsonai.adoc[]
* xref::ml/mixedbread.adoc[]
182 changes: 182 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
[[Mixedbread-api]]
= Mixedbread API Access
:description: This section describes procedures that can be used to access the Mixedbread API.


Here is a list of all available Mixedbread API procedures:


[opts=header, cols="1, 4", separator="|"]
|===
|name| description
|apoc.ml.mixedbread.custom(body, $config)| To create a customizable Mixedbread API call
|apoc.ml.mixedbread.embedding(texts, $config)| To create a Mixedbread API call to generate embeddings
|===

The `$config` parameter coincides with the payload to be passed to the http request,
and additionally the following configuration keys.


.Common configuration parameter

|===
| key | description
| apiType | analogous to `apoc.ml.openai.type` APOC config
| endpoint | analogous to `apoc.ml.openai.url` APOC config
| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config
| path | To customize the url portion added to the base url (defined by the `endpoint` config).
By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures.
| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response.
The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure.
|===

Since embeddings are a super set of the Openai ones,
under-the-hood they leverage the apoc.ml.openai.* procedures,
so we can also create an APOC config `apoc.ml.openai.url` instead of the `endpoint` config.


== Generate Embeddings API

This procedure `apoc.ml.mixedbread.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector.
It uses the `/embeddings/create` API which is https://www.mixedbread.ai/api-reference/endpoints/embeddings#create-embeddings[documented here^].

Additional configuration is passed to the API, the default model used is `mxbai-embed-large-v1`.


.Parameters
[%autowidth, opts=header]
|===
|name | description
| texts | List of text strings
| apiKey | OpenAI API key
| configuration | optional map. See `Configuration` table above
|===

.Results
[%autowidth, opts=header]
|===
|name | description
| index | index entry in original list
| text | line of text from original list
| embedding | embedding list of floatings/binary,
or map of embedding lists of floatings / binaries, in case of multiple encoding_format
|===


.Generate Embeddings Call
[source,cypher]
----
CALL apoc.ml.mixedbread.embedding(['Some Text'], $apiKey, {}) yield index, text, embedding;
----

.Generate Embeddings Response
[%autowidth, opts=header]
|===
|index | text | embedding
|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... -0.010693862, -0.005087272]
|===


.Generate Embeddings Call with custom embedding dimension and model
[source,cypher]
----
CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'],
$apiKey,
{model: 'mxbai-embed-2d-large-v1', dimensions: 4}
)
----

.Generate Embeddings Example Response
[%autowidth, opts=header]
|===
|index | text | embedding
|0 | "Some Text" | [0.019943237, -0.08843994, 0.068603516, 0.034942627]
|1 | "Other Text" | [0.011482239, -0.09069824, 0.05331421, 0.034088135]
|===


.Generate Embeddings Call with custom embedding dimension, model and encoding_format
[source,cypher]
----
CALL apoc.ml.mixedbread.embedding(['Some Text', 'garpez'],
$apiKey,
{encoding_format: ["float", "binary", "ubinary", "int8", "uint8", "base64"]}
)
----

.Generate Embeddings Example Response
[%autowidth, opts=header]
|===
| index | text | embedding
| 0 | "Some Text" | {binary: <binaryResult>, ubinary: <ubinaryResult>, int8: <int8Result>, uint8: <uint8Result>, base64: <base64Result>, float: <floatResult>}
| 0 | "garpez" | {binary: <binaryResult>, ubinary: <ubinaryResult>, int8: <int8Result>, uint8: <uint8Result>, base64: <base64Result>, float: <floatResult>}
|===




== Custom API

Via the `apoc.ml.mixedbread.custom` we can create a customizable Mixedbread API Request,
returning a generic stream of objects.

For example, we can use the https://www.mixedbread.ai/api-reference/endpoints/reranking[Reranking API].


.Reranking API Call
[source,cypher]
----
CALL apoc.ml.mixedbread.custom($apiKey,
{
endpoint: "https://api.mixedbread.ai/v1/reranking",
model: "mixedbread-ai/mxbai-rerank-large-v1",
query: "Who is the author of To Kill a Mockingbird?",
top_k: 3,
input: [
"To Kill a Mockingbird is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.",
"The novel Moby-Dick was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.",
"Harper Lee, an American novelist widely known for her novel To Kill a Mockingbird, was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.",
"Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.",
"The Harry Potter series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.",
"The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
]
}
)
----

.Generate Embeddings Example Response
[%autowidth, opts=header]
|===
| value
a|
[source,json]
----
{
"model": "mixedbread-ai/mxbai-rerank-large-v1",
"return_input": false,
"data": [
{
"index": 0,
"score": 0.9980469,
"object": "text_document"
},
{
"index": 2,
"score": 0.9980469,
"object": "text_document"
},
{
"index": 3,
"score": 0.06915283,
"object": "text_document"
}
],
"usage": {
"total_tokens": 302,
"prompt_tokens": 302
},
"object": "list",
"top_k": 3
}
----
|===
1 change: 1 addition & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ CALL apoc.ml.openai.chat([{"role": "user", "content": "Explain the importance of
----



== Query with natural language

This procedure `apoc.ml.query` takes a question in natural language and returns the results of that query.
Expand Down
80 changes: 80 additions & 0 deletions extended/src/main/java/apoc/ml/MixedbreadAI.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package apoc.ml;

import apoc.ApocConfig;
import apoc.result.ObjectResult;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static apoc.ml.MLUtil.MODEL_CONF_KEY;
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.APOC_ML_OPENAI_URL;

public class MixedbreadAI {

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String DEFAULT_MODEL_ID = "mxbai-embed-large-v1";
public static final String MIXEDBREAD_BASE_URL = "https://api.mixedbread.ai/v1";
public static final String ERROR_MSG_MISSING_ENDPOINT = "The endpoint must be defined via config `%s` or via apoc.conf `%s`"
.formatted(ENDPOINT_CONF_KEY, APOC_ML_OPENAI_URL);

public static final String ERROR_MSG_MISSING_MODELID = "The model must be defined via config `%s`"
.formatted(MODEL_CONF_KEY);


/**
* embedding is an Object instead of List<Double>, as with a Mixedbread request having `"encoding_format": [<multipleFormat>]`,
* the result can be e.g. {... "embedding": { "float": [<floatEmbedding>], "base": <base64Embedding>, } ...}
* instead of e.g. {... "embedding": [<floatEmbedding>] ...}
*/
public record EmbeddingResult(long index, String text, Object embedding) { }

@Context
public URLAccessChecker urlAccessChecker;

@Context
public ApocConfig apocConfig;


@Procedure("apoc.ml.mixedbread.custom")
@Description("apoc.mixedbread.custom(, configuration) - returns the embeddings for a given text")
public Stream<ObjectResult> custom(@Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (!configuration.containsKey(MODEL_CONF_KEY)) {
throw new RuntimeException(ERROR_MSG_MISSING_MODELID);
}

configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_CUSTOM.name());

return OpenAI.executeRequest(apiKey, configuration,
null, null, null, null, null,
apocConfig,
urlAccessChecker)
.map(ObjectResult::new);
}


@Procedure("apoc.ml.mixedbread.embedding")
@Description("apoc.mixedbread.mixedbread([texts], api_key, configuration) - returns the embeddings for a given text")
public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts,
@Name("api_key") String apiKey,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
configuration.putIfAbsent(MODEL_CONF_KEY, DEFAULT_MODEL_ID);

configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_EMBEDDING.name());
return OpenAI.getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
(map, text) -> {
Long index = (Long) map.get("index");
return new EmbeddingResult(index, text, map.get("embedding"));
},
m -> new EmbeddingResult(-1, m, List.of())
);

}

}
28 changes: 21 additions & 7 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -67,6 +69,9 @@ static Stream<Object> executeRequest(String apiKey, Map<String, Object> configur
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove);

switch (type) {
case MIXEDBREAD_CUSTOM -> {
// no payload manipulation, taken from the configuration as-is
}
case HUGGINGFACE -> {
config.putIfAbsent("inputs", inputs);
jsonPath = "$[0]";
Expand Down Expand Up @@ -111,6 +116,17 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
"model": "text-embedding-ada-002",
"usage": { "prompt_tokens": 8, "total_tokens": 8 } }
*/
return getEmbeddingResult(texts, apiKey, configuration, apocConfig, urlAccessChecker,
(map, text) -> {
Long index = (Long) map.get("index");
return new EmbeddingResult(index, text, (List<Double>) map.get("embedding"));
},
m -> new EmbeddingResult(-1, m, List.of())
);
}

static <T> Stream<T> getEmbeddingResult(List<String> texts, String apiKey, Map<String, Object> configuration, ApocConfig apocConfig, URLAccessChecker urlAccessChecker,
BiFunction<Map, String, T> embeddingMapping, Function<String, T> nullMapping) throws JsonProcessingException, MalformedURLException {
if (texts == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
Expand All @@ -121,19 +137,17 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
List<String> nonNullTexts = collect.get(true);

Stream<Object> resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", nonNullTexts, "$.data", apocConfig, urlAccessChecker);
Stream<EmbeddingResult> embeddingResultStream = resultStream
Stream<T> embeddingResultStream = resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(m -> {
Long index = (Long) m.get("index");
return new EmbeddingResult(index, nonNullTexts.get(index.intValue()), (List<Double>) m.get("embedding"));
String text = nonNullTexts.get(index.intValue());
return embeddingMapping.apply(m, text);
});

List<String> nullTexts = collect.getOrDefault(false, List.of());
Stream<EmbeddingResult> nullResultStream = nullTexts.stream()
.map(i -> {
// null text return index -1 to indicate that are not coming from `/embeddings` RestAPI
return new EmbeddingResult(-1, i, List.of());
});
Stream<T> nullResultStream = nullTexts.stream()
.map(nullMapping);
return Stream.concat(embeddingResultStream, nullResultStream);
}

Expand Down
25 changes: 23 additions & 2 deletions extended/src/main/java/apoc/ml/OpenAIRequestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.MixedbreadAI.ERROR_MSG_MISSING_ENDPOINT;
import static apoc.ml.MixedbreadAI.MIXEDBREAD_BASE_URL;
import static apoc.ml.MLUtil.*;

abstract class OpenAIRequestHandler {
Expand All @@ -27,8 +29,12 @@ public String getDefaultUrl() {
public abstract void addApiKey(Map<String, Object> headers, String apiKey);

public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) {
return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, getDefaultUrl())));
String url = (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL)));
if (url == null) {
return getDefaultUrl();
}
return url;
}

public String getFullUrl(String method, Map<String, Object> procConfig, ApocConfig apocConfig) {
Expand All @@ -40,6 +46,8 @@ public String getFullUrl(String method, Map<String, Object> procConfig, ApocConf
enum Type {
AZURE(new Azure(null)),
HUGGINGFACE(new OpenAi(null)),
MIXEDBREAD_EMBEDDING(new OpenAi(MIXEDBREAD_BASE_URL)),
MIXEDBREAD_CUSTOM(new Custom()),
OPENAI(new OpenAi("https://api.openai.com/v1"));

private final OpenAIRequestHandler handler;
Expand Down Expand Up @@ -85,4 +93,17 @@ public void addApiKey(Map<String, Object> headers, String apiKey) {
headers.put("Authorization", "Bearer " + apiKey);
}
}

static class Custom extends OpenAi {

public Custom() {
super(null);
}

@Override
public String getDefaultUrl() {
throw new RuntimeException(ERROR_MSG_MISSING_ENDPOINT);

}
}
}