Skip to content

Commit

Permalink
Fixes #4057: Add Watson Embedding API
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed May 2, 2024
1 parent ce01da6 commit 35498bf
Show file tree
Hide file tree
Showing 23 changed files with 473 additions and 221 deletions.
41 changes: 37 additions & 4 deletions docs/asciidoc/modules/ROOT/pages/ml/watsonai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,40 @@ With a result like this:

== Generate Embeddings API

Watson does not have built-in embedding APIs,
so if we want to use embeddings, we have to use one of the other Apoc ML integrations,
such as xref::ml/openai.adoc[apoc.ml.openai.embeddings] or xref::ml/vertexai.adoc[apoc.ml.vertexai.embeddings],
or https://huggingface.co/docs/hub/sentence-transformers[HuggingFace embeddings as sentence transformers].
This procedure `apoc.ml.watson.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 512 element vector.


Additional configuration is passed to the API, the default model used is `ibm/slate-30m-english-rtrvr`.
See https://www.ibm.com/products/watsonx-ai/foundation-models[here] to check the list of the current models.

.Generate Embeddings Call
[source,cypher]
----
CALL apoc.ml.watson.embedding(['Some Text'], $accessToken, $project, {}) yield index, text, embedding;
----

.Generate Embeddings Response
[%autowidth, opts=header]
|===
|index | text | embedding
|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... -0.010693862, -0.005087272]
|===

.Parameters
[%autowidth, opts=header]
|===
|name | description
| texts | List of text strings
| accessToken | Watson access token
| configuration | optional config map, equivalent to other procedures
|===


.Results
[%autowidth, opts=header]
|===
|name | description
| index | index entry in original list
| text | line of text from original list
| embedding | floating point embedding vector for ada-002 model
|===
5 changes: 5 additions & 0 deletions extended/src/main/java/apoc/ml/MLUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@

public class MLUtil {
public static final String ERROR_NULL_INPUT = "The input provided is null. Please specify a valid input";

public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";
public static final String REGION_CONF_KEY = "region";
public static final String MODEL_CONF_KEY = "model";
}
5 changes: 1 addition & 4 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE;
import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY;
import static apoc.ml.MLUtil.ERROR_NULL_INPUT;
import static apoc.ml.MLUtil.*;


@Extended
public class OpenAI {
public static final String API_TYPE_CONF_KEY = "apiType";
public static final String APIKEY_CONF_KEY = "apiKey";
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";
public static final String JSON_PATH_CONF_KEY = "jsonPath";
public static final String PATH_CONF_KEY = "path";
public static final String MODEL_CONF_KEY = "model";

@Context
public ApocConfig apocConfig;
Expand Down
3 changes: 1 addition & 2 deletions extended/src/main/java/apoc/ml/OpenAIRequestHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_AZURE_VERSION;
import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_URL;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.*;

abstract class OpenAIRequestHandler {

Expand Down
3 changes: 1 addition & 2 deletions extended/src/main/java/apoc/ml/VertexAIHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
import java.util.Objects;

import static apoc.ExtendedApocConfig.APOC_ML_VERTEXAI_URL;
import static apoc.ml.MLUtil.*;
import static apoc.ml.VertexAI.getParameters;
import static org.apache.commons.lang3.StringUtils.isBlank;

public abstract class VertexAIHandler {
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String MODEL_CONF_KEY = "model";
public static final String RESOURCE_CONF_KEY = "resource";

public static final String STREAM_RESOURCE = "streamGenerateContent";
Expand Down
9 changes: 4 additions & 5 deletions extended/src/main/java/apoc/ml/aws/AWSConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import static apoc.ApocConfig.apocConfig;
import static apoc.ExtendedApocConfig.APOC_AWS_KEY_ID;
import static apoc.ExtendedApocConfig.APOC_AWS_SECRET_KEY;
import static apoc.ml.MLUtil.*;

public abstract class AWSConfig {

Expand All @@ -17,8 +18,6 @@ public abstract class AWSConfig {
public static final String JSON_PATH = "jsonPath";
public static final String SECRET_KEY = "secretKey";
public static final String KEY_ID = "keyId";
public static final String REGION_KEY = "region";
public static final String ENDPOINT_KEY = "endpoint";
public static final String METHOD_KEY = "method";

private final String keyId;
Expand All @@ -37,7 +36,7 @@ protected AWSConfig(Map<String, Object> config) {
this.keyId = apocConfig().getString(APOC_AWS_KEY_ID, (String) config.get(KEY_ID));
this.secretKey = apocConfig().getString(APOC_AWS_SECRET_KEY, (String) config.get(SECRET_KEY));

this.region = (String) config.getOrDefault(REGION_KEY, "us-east-1");
this.region = (String) config.getOrDefault(REGION_CONF_KEY, "us-east-1");
this.endpoint = getEndpoint(config, getDefaultEndpoint(config));

this.method = (String) config.getOrDefault(METHOD_KEY, getDefaultMethod());
Expand All @@ -49,7 +48,7 @@ protected AWSConfig(Map<String, Object> config) {

private String getEndpoint(Map<String, Object> config, String defaultEndpoint) {

String endpointConfig = (String) config.get(ENDPOINT_KEY);
String endpointConfig = (String) config.get(ENDPOINT_CONF_KEY);
if (endpointConfig != null) {
return endpointConfig;
}
Expand All @@ -58,7 +57,7 @@ private String getEndpoint(Map<String, Object> config, String defaultEndpoint) {
}
String errMessage = String.format("An endpoint could not be retrieved.\n" +
"Explicit the %s config",
ENDPOINT_KEY);
ENDPOINT_CONF_KEY);
throw new RuntimeException(errMessage);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package apoc.ml;
package apoc.ml.watson;

import apoc.ApocConfig;
import apoc.Extended;
Expand All @@ -11,36 +11,65 @@
import org.neo4j.procedure.Procedure;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static apoc.ExtendedApocConfig.APOC_ML_WATSON_URL;
import static apoc.ExtendedApocConfig.APOC_ML_WATSON_PROJECT_ID;
import static apoc.ml.MLUtil.ERROR_NULL_INPUT;

@Extended
public class Watson {
private static final String PROJECT_ID_KEY = "project_id";
private static final String SPACE_ID_KEY = "space_id";
private static final String WML_INSTANCE_CRN_KEY = "wml_instance_crn";
private static final String MODEL_ID_KEY = "model_id";
private static final String DEFAULT_MODEL_ID = "ibm/granite-13b-chat-v2";
static final String PROJECT_ID_KEY = "project_id";
static final String SPACE_ID_KEY = "space_id";
static final String MODEL_ID_KEY = "model_id";
static final String WML_INSTANCE_CRN_KEY = "wml_instance_crn";
static final String DEFAULT_COMPLETION_MODEL_ID = "ibm/granite-13b-chat-v2";
static final String DEFAULT_EMBEDDING_MODEL_ID = "ibm/slate-30m-english-rtrvr";


// The version date currently used in IBM Prompt Lab endpoints (apr 2024) 2024-04-04
static final String DEFAULT_VERSION_DATE = "2023-05-29";
static final String DEFAULT_REGION = "eu-de";

@Context
public ApocConfig apocConfig;

@Context
public URLAccessChecker urlAccessChecker;

public static final String ENDPOINT_CONF_KEY = "endpoint";

public record EmbeddingResult(long index, String text, List<Double> embedding) {}



@Procedure("apoc.ml.watson.embedding")
@Description("apoc.ml.watson.embedding([texts], $configuration) - returns the embeddings for a given text")
public Stream<EmbeddingResult> embedding(@Name(value = "texts") List<String> texts,
@Name("accessToken") String accessToken,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {
if (texts == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}

AtomicInteger idx = new AtomicInteger();

return executeRequest(texts, accessToken, configuration, WatsonHandler.Type.EMBEDDING.get())
.flatMap(v -> ((List<Map>) v.get("results")).stream())
.map(i -> {
int index = idx.getAndIncrement();
List<Double> embedding = (List<Double>) i.get("embedding");
return new EmbeddingResult(index, texts.get(index), embedding);
});
}

@Procedure("apoc.ml.watson.chat")
@Description("apoc.ml.watson.chat(messages, accessToken, $configuration) - prompts the completion API")
public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> messages, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (messages == null) {
return Stream.of(new MapResult(null));
throw new RuntimeException(ERROR_NULL_INPUT);
}
String prompt = messages.stream()
.map(message -> {
Expand All @@ -60,12 +89,13 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
@Description("apoc.ml.watson.completion(prompt, accessToken, $configuration) - prompts the completion API")
public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("accessToken") String accessToken, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
if (prompt == null) {
return Stream.of(new MapResult(null));
throw new RuntimeException(ERROR_NULL_INPUT);
}
return executeRequest(prompt, accessToken, configuration);
return executeRequest(prompt, accessToken, configuration, WatsonHandler.Type.COMPLETION.get())
.map(MapResult::new);
}

private Stream<MapResult> executeRequest(Object input, String accessToken, Map<String, Object> configuration) {
private Stream<Map> executeRequest(Object input, String accessToken, Map<String, Object> configuration, WatsonHandler type) {
try {
// the body request has to contain space_id or project_id or wml_instance_crn,
// in case is missing we put the project_id from apoc.conf, otherwise we throw an exception
Expand All @@ -79,32 +109,20 @@ private Stream<MapResult> executeRequest(Object input, String accessToken, Map<S
configuration.put(PROJECT_ID_KEY, apocConfProjectId);
}

String endpoint = getEndpoint(configuration);

var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL_ID_KEY, DEFAULT_MODEL_ID);
config.put("input", input);
String endpoint = type.getEndpoint(configuration);

Map<String, Object> headers = Map.of("Content-Type", "application/json",
"accept", "application/json",
"Authorization", "Bearer " + accessToken);

String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config);

Map<String, Object> payloadMap = type.getPayload(configuration, input);
String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(payloadMap);

return JsonUtil.loadJson(endpoint, headers, payload, "$", true, List.of(), urlAccessChecker)
.map(v -> (Map<String, Object>) v)
.map(MapResult::new);
.map(v -> (Map<String, Object>) v);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

public String getEndpoint(Map<String, Object> config) {
Object remove = config.remove(ENDPOINT_CONF_KEY);
if (remove != null) {
return (String) remove;
}
return apocConfig.getString(APOC_ML_WATSON_URL, "https://eu-de.ml.cloud.ibm.com/ml/v1-beta/generation/text?version=2023-05-29");
}

}
91 changes: 91 additions & 0 deletions extended/src/main/java/apoc/ml/watson/WatsonHandler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package apoc.ml.watson;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static apoc.ApocConfig.apocConfig;
import static apoc.ExtendedApocConfig.APOC_ML_WATSON_URL;
import static apoc.ml.MLUtil.*;
import static apoc.ml.watson.Watson.*;

public interface WatsonHandler {

enum Type {
EMBEDDING(new EmbeddingHandler()),
COMPLETION(new CompletionHandler());

private final WatsonHandler handler;

Type(WatsonHandler handler) {
this.handler = handler;
}

public WatsonHandler get() {
return handler;
}
}

// -- interface methods

String getDefaultMethod();
Map<String, Object> getPayload(Map<String, Object> configuration, Object input);

default String getEndpoint(Map<String, Object> config) {
var endpoint = config.remove(ENDPOINT_CONF_KEY);
if (endpoint != null) {
return (String) endpoint;
}

var version = Objects.requireNonNullElse(
config.remove(API_VERSION_CONF_KEY),
DEFAULT_VERSION_DATE
);

var region = Objects.requireNonNullElse(
config.remove(REGION_CONF_KEY),
REGION_CONF_KEY
);

String url = "https://%s.ml.cloud.ibm.com/ml/v1/%s?version=%s".formatted(
region, getDefaultMethod(), version
);

return apocConfig().getString(APOC_ML_WATSON_URL, url);
}


// -- concrete implementations

class EmbeddingHandler implements WatsonHandler {
@Override
public String getDefaultMethod() {
return "text/embeddings";
}

@Override
public Map<String, Object> getPayload(Map<String, Object> configuration, Object input) {
var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL_ID_KEY, DEFAULT_EMBEDDING_MODEL_ID);
config.put("inputs", input);
return config;
}

}

class CompletionHandler implements WatsonHandler {
@Override
public String getDefaultMethod() {
return "text/generation";
}

@Override
public Map<String, Object> getPayload(Map<String, Object> configuration, Object input) {
var config = new HashMap<>(configuration);
config.putIfAbsent(MODEL_ID_KEY, DEFAULT_COMPLETION_MODEL_ID);
config.put("input", input);
return config;
}
}

}
3 changes: 1 addition & 2 deletions extended/src/test/java/apoc/ml/OpenAIAnyScaleIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import java.util.List;
import java.util.Map;

import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAI.MODEL_CONF_KEY;
import static apoc.ml.MLUtil.*;
import static apoc.ml.OpenAITestResultUtils.*;
import static apoc.util.TestUtil.testCall;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down
3 changes: 1 addition & 2 deletions extended/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import java.util.stream.Stream;

import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.MLUtil.*;
import static apoc.ml.OpenAITestResultUtils.CHAT_COMPLETION_QUERY;
import static apoc.ml.OpenAITestResultUtils.COMPLETION_QUERY;
import static apoc.ml.OpenAITestResultUtils.EMBEDDING_QUERY;
Expand Down

0 comments on commit 35498bf

Please sign in to comment.