From edb4e6de5a3d585fb3431b8d1a3da403c30d1d4c Mon Sep 17 00:00:00 2001 From: Michael Hunger Date: Wed, 17 May 2023 16:54:16 +0200 Subject: [PATCH] Add OpenAI (LLM) procedures (#3575) * WIP * Add completion API * add chatCompletion * prettificiation * WIP * prettify * Refactoring, todo docs & tests * Added Tests & Docs for OpenAI procs (WIP) * Update openai.adoc From @tomasonjo --------- Co-authored-by: Tomaz Bratanic --- docs/asciidoc/modules/ROOT/nav.adoc | 3 + .../asciidoc/modules/ROOT/pages/ml/index.adoc | 10 ++ .../modules/ROOT/pages/ml/openai.adoc | 122 ++++++++++++++++++ extended/src/main/java/apoc/ml/OpenAI.java | 114 ++++++++++++++++ extended/src/test/java/apoc/ml/OpenAIIT.java | 120 +++++++++++++++++ .../src/test/java/apoc/ml/OpenAITest.java | 114 ++++++++++++++++ extended/src/test/resources/chat/completions | 21 +++ extended/src/test/resources/completions | 19 +++ extended/src/test/resources/embeddings | 15 +++ 9 files changed, 538 insertions(+) create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/index.adoc create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/openai.adoc create mode 100644 extended/src/main/java/apoc/ml/OpenAI.java create mode 100644 extended/src/test/java/apoc/ml/OpenAIIT.java create mode 100644 extended/src/test/java/apoc/ml/OpenAITest.java create mode 100644 extended/src/test/resources/chat/completions create mode 100644 extended/src/test/resources/completions create mode 100644 extended/src/test/resources/embeddings diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index 5400be2b24..fb977f9594 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -56,6 +56,9 @@ include::partial$generated-documentation/nav.adoc[] ** xref:nlp/aws.adoc[] ** xref:nlp/azure.adoc[] +* xref:ml/index.adoc[] + ** xref:ml/openai.adoc[] + * xref:background-operations/index.adoc[] ** xref::background-operations/apoc-load-directory-async.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc new file mode 100644 index 0000000000..73dd5e3169 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc @@ -0,0 +1,10 @@ +[[ml]] += Machine Learning (ML +:description: This chapter describes procedures that can be used for adding Machine Learning (ML) functionality to graph applications. + +The procedures described in this chapter act as wrappers around cloud based Machine Learning APIs. +These procedures generate embeddings, analyze text, complete text, complete chat conversations and more. + +This section includes: + +* xref::ml/openai.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc new file mode 100644 index 0000000000..b143e28343 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -0,0 +1,122 @@ +[[openai-api]] += OpenAI API Access +:description: This section describes procedures that can be used to access the OpenAI API. + +NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incurr costs on your OpenAI account. + +== Generate Embeddings API + +This procedure `apoc.ml.openai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector. +It uses the `/embeddings/create` API which is https://platform.openai.com/docs/api-reference/embeddings/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `text-embedding-ada-002`. + +.Generate Embeddings Call +[source,cypher] +---- +CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, embedding; +---- + +.Generate Embeddings Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... -0.010693862, -0.005087272] +|=== + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| texts | List of text strings +| apiKey | OpenAI API key +| configuration | optional map for entries like model and other request parameters +|=== + + +.Results +[%autowidth, opts=header] +|=== +|name | description +| index | index entry in original list +| text | line of text from original list +| embedding | 1536 element floating point embedding vector for ada-002 model +|=== + +== Text Completion API + +This procedure `apoc.ml.openai.completion` can continue/complete a given text. + +It uses the `/completions/create` API which is https://platform.openai.com/docs/api-reference/completions/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `text-davinci-003`. + +.Text Completion Call +[source,cypher] +---- +CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, {config}) yield value; +---- + +.Text Completion Response +---- +{ created=1684248202, model="text-davinci-003", id="cmpl-7GqBWwX49yMJljdmnLkWxYettZoOy", + usage={completion_tokens=2, prompt_tokens=12, total_tokens=14}, + choices=[{finish_reason="stop", index=0, text="Blue", logprobs=null}], object="text_completion"} +---- + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| prompt | Text to complete +| apiKey | OpenAI API key +| configuration | optional map for entries like model, temperature, and other request parameters +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| value | result entry from OpenAI (containing) +|=== + +== Chat Completion API + +This procedure `apoc.ml.openai.chat` takes a list of maps of chat exchanges between assistant and user (with optional system message), and will return the next message in the flow. + +It uses the `/chat/create` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `gpt-3.5-turbo`. + +.Chat Completion Call +[source,cypher] +---- +CALL apoc.ml.openai.chat([ +{role:"system", content:"Only answer with a single word"}, +{role:"user", content:"What planet do humans live on?"} +], $apiKey) yield value +---- + +.Chat Completion Response +---- +{created=1684248203, id="chatcmpl-7GqBXZr94avd4fluYDi2fWEz7DIHL", +object="chat.completion", model="gpt-3.5-turbo-0301", +usage={completion_tokens=2, prompt_tokens=26, total_tokens=28}, +choices=[{finish_reason="stop", index=0, message={role="assistant", content="Earth."}}]} +---- + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| messages | List of maps of instructions with `{role:"assistant|user|system", content:"text}` +| apiKey | OpenAI API key +| configuration | optional map for entries like model, temperature, and other request parameters +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| value | result entry from OpenAI (containing created, id, model, object, usage(tokens), choices(message, index, finish_reason)) +|=== diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java new file mode 100644 index 0000000000..1f696a1540 --- /dev/null +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -0,0 +1,114 @@ +package apoc.ml; + +import apoc.Extended; +import apoc.util.JsonUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.List; +import java.util.stream.Stream; + +import apoc.result.MapResult; + +import com.fasterxml.jackson.databind.ObjectMapper; + + +@Extended +public class OpenAI { + + public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; + + public static class EmbeddingResult { + public final long index; + public final String text; + public final List embedding; + + public EmbeddingResult(long index, String text, List embedding) { + this.index = index; + this.text = text; + this.embedding = embedding; + } + } + + private static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath) throws JsonProcessingException, MalformedURLException { + if (apiKey == null || apiKey.isBlank()) + throw new IllegalArgumentException("API Key must not be empty"); + String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/"); + Map headers = Map.of( + "Content-Type", "application/json", + "Authorization", "Bearer " + apiKey + ); + + var config = new HashMap<>(configuration); + config.putIfAbsent("model", model); + config.put(key, inputs); + + String payload = new ObjectMapper().writeValueAsString(config); + + var url = new URL(new URL(endpoint), path).toString(); + return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of()); + } + + @Procedure("apoc.ml.openai.embedding") + @Description("apoc.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text") + public Stream getEmbedding(@Name("texts") List texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + // https://platform.openai.com/docs/api-reference/embeddings/create + /* + { "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ 0.0023064255, -0.009327292, .... (1536 floats total for ada-002) -0.0028842222 ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { "prompt_tokens": 8, "total_tokens": 8 } } + */ + Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data"); + return resultStream + .flatMap(v -> ((List>) v).stream()) + .map(m -> { + Long index = (Long) m.get("index"); + return new EmbeddingResult(index, texts.get(index.intValue()), (List) m.get("embedding")); + }); + } + + + @Procedure("apoc.ml.openai.completion") + @Description("apoc.ml.openai.completion(prompt, api_key, configuration) - prompts the completion API") + public Stream completion(@Name("prompt") String prompt, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + // https://platform.openai.com/docs/api-reference/completions/create + /* + { "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", "created": 1589478378, "model": "text-davinci-003", + "choices": [ { "text": "\n\nThis is indeed a test", "index": 0, "logprobs": null, "finish_reason": "length" } ], + "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } + } + */ + return executeRequest(apiKey, configuration, "completions", "text-davinci-003", "prompt", prompt, "$") + .map(v -> (Map)v).map(MapResult::new); + } + + @Procedure("apoc.ml.openai.chat") + @Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") + public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$") + .map(v -> (Map)v).map(MapResult::new); + // https://platform.openai.com/docs/api-reference/chat/create + /* + { 'id': 'chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve', 'object': 'chat.completion', 'created': 1677649420, 'model': 'gpt-3.5-turbo', + 'usage': {'prompt_tokens': 56, 'completion_tokens': 31, 'total_tokens': 87}, + 'choices': [ { + 'message': { 'role': 'assistant', 'finish_reason': 'stop', 'index': 0, + 'content': 'The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers.'} + } ] } + */ + } +} \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java new file mode 100644 index 0000000000..d66caeaf67 --- /dev/null +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -0,0 +1,120 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; + +import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED; +import static apoc.ApocConfig.apocConfig; +import static apoc.util.TestUtil.getUrlFileName; +import static apoc.util.TestUtil.testCall; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OpenAIIT { + + private String openaiKey; + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + public OpenAIIT() { + } + + @Before + public void setUp() throws Exception { + openaiKey = System.getenv("OPENAI_KEY"); + Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); + TestUtil.registerProcedure(db, OpenAI.class); + } + + @Test + public void getEmbedding() { + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),(row) -> { + System.out.println("row = " + row); + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + var embedding = (List) row.get("embedding"); + assertEquals(1536, embedding.size()); + assertEquals(true, embedding.stream().allMatch(d -> d instanceof Double)); + }); + } + + @Test + public void completion() { + testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)", + Map.of("apiKey",openaiKey),(row) -> { + System.out.println("row = " + row); + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + var finishReason = (String)((List) result.get("choices")).get(0).get("finish_reason"); + assertEquals(true, finishReason.matches("stop|length")); + String text = (String) ((List) result.get("choices")).get(0).get("text"); + assertEquals(true, text != null && !text.isBlank()); + assertEquals(true, text.toLowerCase().contains("blue")); + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("text-davinci-003", result.get("model")); + assertEquals("text_completion", result.get("object")); + }); + } + + @Test + public void chatCompletion() { + testCall(db, """ +CALL apoc.ml.openai.chat([ +{role:"system", content:"Only answer with a single word"}, +{role:"user", content:"What planet do humans live on?"} +], $apiKey) +""", Map.of("apiKey",openaiKey), (row) -> { + System.out.println("row = " + row); + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + + Map message = ((List>) result.get("choices")).get(0).get("message"); + assertEquals("assistant", message.get("role")); + // assertEquals("stop", message.get("finish_reason")); + String text = (String) message.get("content"); + assertEquals(true, text != null && !text.isBlank()); + + + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("gpt-3.5-turbo-0301", result.get("model")); + assertEquals("chat.completion", result.get("object")); + }); + + /* + { + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers." + } + } + ] +} + */ + } +} \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/OpenAITest.java b/extended/src/test/java/apoc/ml/OpenAITest.java new file mode 100644 index 0000000000..87dc84b100 --- /dev/null +++ b/extended/src/test/java/apoc/ml/OpenAITest.java @@ -0,0 +1,114 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; + +import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED; +import static apoc.ApocConfig.apocConfig; +import static apoc.util.TestUtil.getUrlFileName; +import static apoc.util.TestUtil.testCall; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OpenAITest { + + private String openaiKey; + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + public OpenAITest() { + } + + @Before + public void setUp() throws Exception { + // openaiKey = System.getenv("OPENAI_KEY"); + // Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); + var path = Paths.get(getUrlFileName("embeddings").toURI()).getParent().toUri(); + System.setProperty(OpenAI.APOC_ML_OPENAI_URL, path.toString()); + apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); + TestUtil.registerProcedure(db, OpenAI.class); + } + + @Test + public void getEmbedding() { + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], 'fake-api-key')", (row) -> { + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + assertEquals(List.of(0.0023064255, -0.009327292, -0.0028842222), row.get("embedding")); + }); + } + + @Test + public void completion() { + testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer: ', 'fake-api-key')", (row) -> { + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + assertEquals("stop", ((List)result.get("choices")).get(0).get("finish_reason")); + String text = (String) ((List) result.get("choices")).get(0).get("text"); + assertEquals(true, text != null && !text.isBlank()); + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("text-davinci-003", result.get("model")); + assertEquals("text_completion", result.get("object")); + }); + } + + @Test + public void chatCompletion() { + testCall(db, """ +CALL apoc.ml.openai.chat([ +{role:"system", content:"Only answer with a single word"}, +{role:"user", content:"What planet do humans live on?"} +], 'fake-api-key') +""", (row) -> { + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + + Map message = ((List>) result.get("choices")).get(0).get("message"); + assertEquals("assistant", message.get("role")); + assertEquals("stop", message.get("finish_reason")); + String text = (String) message.get("content"); + assertEquals(true, text != null && !text.isBlank()); + + + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("gpt-3.5-turbo-0301", result.get("model")); + assertEquals("chat.completion", result.get("object")); + }); + + /* + { + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers." + } + } + ] +} + */ + } +} \ No newline at end of file diff --git a/extended/src/test/resources/chat/completions b/extended/src/test/resources/chat/completions new file mode 100644 index 0000000000..eeb37e6a46 --- /dev/null +++ b/extended/src/test/resources/chat/completions @@ -0,0 +1,21 @@ +{ + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo-0301", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers." + } + } + ] +} \ No newline at end of file diff --git a/extended/src/test/resources/completions b/extended/src/test/resources/completions new file mode 100644 index 0000000000..7fb6cc274f --- /dev/null +++ b/extended/src/test/resources/completions @@ -0,0 +1,19 @@ +{ + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "text-davinci-003", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } +} diff --git a/extended/src/test/resources/embeddings b/extended/src/test/resources/embeddings new file mode 100644 index 0000000000..f1450b25de --- /dev/null +++ b/extended/src/test/resources/embeddings @@ -0,0 +1,15 @@ +{ + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.0023064255, -0.009327292, -0.0028842222], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } +} \ No newline at end of file