diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index b143e28343..0bb6a198d7 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -2,7 +2,7 @@ = OpenAI API Access :description: This section describes procedures that can be used to access the OpenAI API. -NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incurr costs on your OpenAI account. +NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the api key globally by defining the `apoc.openai.key` configuration in `apoc.conf` == Generate Embeddings API @@ -120,3 +120,214 @@ choices=[{finish_reason="stop", index=0, message={role="assistant", content="Ear |name | description | value | result entry from OpenAI (containing created, id, model, object, usage(tokens), choices(message, index, finish_reason)) |=== + + +== Query with natural language + +This procedure `apoc.ml.query` takes a question in natural language and returns the results of that query. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.query("What movies did Tom Hanks play in?") yield value, query +RETURN * +---- + +.Example response +[source, bash] +---- ++------------------------------------------------------------------------------------------------------------------------------+ +| value | query | ++------------------------------------------------------------------------------------------------------------------------------+ +| {m.title -> "You've Got Mail"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Apollo 13"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Joe Versus the Volcano"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "That Thing You Do"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Cloud Atlas"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Da Vinci Code"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Sleepless in Seattle"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "A League of Their Own"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Green Mile"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Charlie Wilson's War"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Cast Away"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Polar Express"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | 
++------------------------------------------------------------------------------------------------------------------------------+ +12 rows +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description +| question | The question in natural language +| conf | An optional configuration map, please check the next section +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| retries | The number of retries in case of API call failures | no, default `3` +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The OpenAI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is `(number of nodes / 1000) + 1` +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| value | the result of the query +| query | the generated Cypher query used to compute the result +|=== + + +== Describe the graph model with natural language + +This procedure `apoc.ml.schema` returns a description, in natural language, of the underlying dataset. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.schema() yield value +RETURN * +---- + +.Example response +[source, bash] +---- ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| value | ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| "The graph database schema represents a system where users can follow other users and review movies. Users (:Person) can either follow other users (:Person) or review movies (:Movie). The relationships allow users to express their preferences and opinions about movies. This schema can be compared to social media platforms where users can follow each other and leave reviews or ratings for movies they have watched. It can also be related to movie recommendation systems where user preferences and reviews play a crucial role in generating personalized recommendations." 
| ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +1 row +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description +| conf | An optional configuration map, please check the next section +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The Open AI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| value | the description of the dataset +|=== + + +== Create cypher queries from a natural language query + +This procedure `apoc.ml.cypher` takes a natural language question and transforms it into a number of requested cypher queries. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.cypher("Who are the actors which also directed a movie?", {count: 4}) yield query +RETURN * +---- + +.Example response +[source, bash] +---- ++----------------------------------------------------------------------------------------------------------------+ +| query | ++----------------------------------------------------------------------------------------------------------------+ +| " +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(d:Person) +RETURN a.name as actor, d.name as director +" | +| "cypher +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(a) +RETURN a.name +" | +| " +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(d:Person) +RETURN a.name +" | +| "cypher +MATCH (a:Person)-[:ACTED_IN]->(:Movie)<-[:DIRECTED]-(a) +RETURN DISTINCT a.name +" | ++----------------------------------------------------------------------------------------------------------------+ +4 rows +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description | mandatory +| question | The question in natural language | yes +| conf | An optional configuration map, please check the next section | no +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| count | The number of queries to retrieve | no, default `1` +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The OpenAI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. 
+ It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is `(number of nodes / 1000) + 1` +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| query | the generated Cypher query +|=== diff --git a/extended/src/main/java/apoc/ExtendedApocConfig.java b/extended/src/main/java/apoc/ExtendedApocConfig.java index f29c037587..707cebb0d6 100644 --- a/extended/src/main/java/apoc/ExtendedApocConfig.java +++ b/extended/src/main/java/apoc/ExtendedApocConfig.java @@ -29,6 +29,7 @@ public class ExtendedApocConfig extends LifecycleAdapter public static final String APOC_UUID_ENABLED = "apoc.uuid.enabled"; public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s"; public static final String APOC_UUID_FORMAT = "apoc.uuid.format"; + public static final String APOC_OPENAI_KEY = "apoc.openai.key"; public enum UuidFormatType { hex, base64 } // These were earlier added via the Neo4j config using the ApocSettings.java class @@ -177,6 +178,9 @@ public boolean containsKey(String key) { public boolean getBoolean(String key, boolean defaultValue) { return getConfig().getBoolean(key, defaultValue); } + public String getString(String key, String defaultValue) { + return getConfig().getString(key, defaultValue); + } public > T getEnumProperty(String key, Class cls, T defaultValue) { var value = config.getString(key, defaultValue.toString()).trim(); diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 1f696a1540..894701b14a 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -1,8 +1,10 @@ package apoc.ml; +import apoc.ApocConfig; import apoc.Extended; import apoc.util.JsonUtil; import com.fasterxml.jackson.core.JsonProcessingException; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; @@ -18,9 +20,13 @@ import 
com.fasterxml.jackson.databind.ObjectMapper; +import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY; + @Extended public class OpenAI { + @Context + public ApocConfig apocConfig; public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; @@ -36,7 +42,8 @@ public EmbeddingResult(long index, String text, List embedding) { } } - private static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath) throws JsonProcessingException, MalformedURLException { + static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig) throws JsonProcessingException, MalformedURLException { + apiKey = apocConfig.getString(APOC_OPENAI_KEY, apiKey); if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/"); @@ -71,7 +78,7 @@ public Stream getEmbedding(@Name("texts") List texts, @ "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ - Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data"); + Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig); return resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { @@ -92,14 +99,14 @@ public Stream completion(@Name("prompt") String prompt, @Name("api_ke "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } } */ - return executeRequest(apiKey, configuration, "completions", "text-davinci-003", "prompt", prompt, "$") + return executeRequest(apiKey, configuration, "completions", "text-davinci-003", "prompt", prompt, "$", apocConfig) .map(v -> (Map)v).map(MapResult::new); } @Procedure("apoc.ml.openai.chat") 
@Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { - return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$") + return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$", apocConfig) .map(v -> (Map)v).map(MapResult::new); // https://platform.openai.com/docs/api-reference/chat/create /* diff --git a/extended/src/main/java/apoc/ml/Prompt.java b/extended/src/main/java/apoc/ml/Prompt.java new file mode 100644 index 0000000000..3aec7b355d --- /dev/null +++ b/extended/src/main/java/apoc/ml/Prompt.java @@ -0,0 +1,208 @@ +package apoc.ml; + +import apoc.ApocConfig; +import apoc.Extended; +import apoc.result.StringResult; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.jetbrains.annotations.NotNull; +import org.neo4j.graphdb.QueryExecutionException; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.logging.Log; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +@Extended +public class Prompt { + + @Context + public Transaction tx; + @Context + public Log log; + @Context + public ApocConfig apocConfig; + @Context + public ProcedureCallContext procedureCallContext; + + public static final String BACKTICKS = "```"; + public 
static final String EXPLAIN_SCHEMA_PROMPT = """ + You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains. + Explain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable. + Keep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm. + """; + + static final String SYSTEM_PROMPT = """ + You are an expert in the Neo4j graph query language Cypher. + Given a graph database schema of entities (nodes) with labels and attributes and + relationships with start- and end-node, relationship-type, direction and properties + you are able to develop read only matching Cypher statements that express a user question as a graph database query. + Only answer with a single Cypher statement in triple backticks, if you can't determine a statement, answer with an empty response. + Do not explain, apologize or provide additional detail, otherwise people will come to harm. 
+ """; + + public class PromptMapResult { + public final Map value; + public final String query; + + public PromptMapResult(Map value, String query) { + this.value = value; + this.query = query; + } + + public PromptMapResult(Map value) { + this.value = value; + this.query = null; + } + } + + public class QueryResult { + public final String query; + // todo re-add when it's actually working + // private final String error; + // private final String type; + + public QueryResult(String query, String error, String type) { + this.query = query; + // this.error = error; + // this.type = type; + } + + public boolean hasError() { + return false; + // return error != null && !error.isBlank(); + } + } + + @Procedure(mode = Mode.READ) + public Stream query(@Name("question") String question, + @Name(value = "conf", defaultValue = "{}") Map conf) { + String schema = loadSchema(tx, conf); + String query = ""; + long retries = (long) conf.getOrDefault("retries", 3L); + boolean containsField = procedureCallContext + .outputFields() + .collect(Collectors.toSet()) + .contains("query"); + do { + try { + QueryResult queryResult = tryQuery(question, conf, schema); + query = queryResult.query; + // just let it fail so that retries can work if (queryResult.query.isBlank()) return Stream.empty(); + /* + if (queryResult.hasError()) + throw new QueryExecutionException(queryResult.error, null, queryResult.type); + */ + return tx.execute(queryResult.query) + .stream() + .map(row -> containsField ? 
new PromptMapResult(row, queryResult.query) : new PromptMapResult(row)); + } catch (QueryExecutionException quee) { + if (log.isDebugEnabled()) + log.debug("Generated query for question %s\n%s\nfailed with %s".formatted(question, query, quee.getMessage())); + retries--; + if (retries <= 0) throw quee; + } + } while (true); + } + + @Procedure + public Stream schema(@Name(value = "conf", defaultValue = "{}") Map conf) throws MalformedURLException, JsonProcessingException { + String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", + EXPLAIN_SCHEMA_PROMPT, "This database schema ", loadSchema(tx, conf), conf); + return Stream.of(new StringResult(schemaExplanation)); + } + + @Procedure(mode = Mode.READ) + public Stream cypher(@Name("question") String question, + @Name(value = "conf", defaultValue = "{}") Map conf) { + String schema = loadSchema(tx, conf); + long count = (long) conf.getOrDefault("count", 1L); + return LongStream.rangeClosed(1, count).mapToObj(i -> tryQuery(question, conf, schema)); + } + + @NotNull + private QueryResult tryQuery(String question, Map conf, String schema) { + String query = ""; + try { + query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf); + // doesn't work right now, fails with security context error + // tx.execute("EXPLAIN " + query).close(); // TODO query plan / estimated rows? 
+ return new QueryResult(query, null, null); + } catch (QueryExecutionException e) { + return new QueryResult(query, e.getMessage(), e.getStatusCode()); + } catch (Exception e) { + return new QueryResult(query, e.getMessage(), e.getClass().getSimpleName()); + } + } + + @NotNull + private String prompt(String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map conf) throws JsonProcessingException, MalformedURLException { + List> prompt = new ArrayList<>(); + if (systemPrompt != null && !systemPrompt.isBlank()) prompt.add(Map.of("role", "system", "content", systemPrompt)); + if (schema != null && !schema.isBlank()) prompt.add(Map.of("role", "system", "content", "The graph database schema consists of these elements\n" + schema)); + if (userQuestion != null && !userQuestion.isBlank()) prompt.add(Map.of("role", "user", "content", userQuestion)); + if (assistantPrompt != null && !assistantPrompt.isBlank()) prompt.add(Map.of("role", "assistant", "content", assistantPrompt)); + String apiKey = (String) conf.get("apiKey"); + String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo"); + String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions", + model, "messages", prompt, "$", apocConfig) + .map(v -> (Map) v) + .flatMap(m -> ((List>) m.get("choices")).stream()) + .map(m -> (String) (((Map) m.get("message")).get("content"))) + .filter(s -> !(s == null || s.isBlank())) + .map(s -> s.contains(BACKTICKS) ? s.substring(s.indexOf(BACKTICKS) + 3, s.lastIndexOf(BACKTICKS)) : s) + .collect(Collectors.joining(" ")).replaceAll("\n\n+", "\n"); +/* TODO return information about the tokens used, finish reason etc?? 
+{ 'id': 'chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve', 'object': 'chat.completion', 'created': 1677649420, 'model': 'gpt-3.5-turbo', + 'usage': {'prompt_tokens': 56, 'completion_tokens': 31, 'total_tokens': 87}, + 'choices': [ { + 'message': { 'role': 'assistant', 'finish_reason': 'stop', 'index': 0, + 'content': 'The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers.'} + } ] } +*/ + if (log.isDebugEnabled()) log.debug("Generated query for question %s\n%s".formatted(userQuestion, result)); + return result; + } + + private final static String SCHEMA_QUERY = """ + call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)}) + YIELD label, other, elementType, type, property + WITH label, elementType,\s + apoc.text.join(collect(case when NOT type = "RELATIONSHIP" then property+": "+type else null end),", ") AS properties, \s + collect(case when type = "RELATIONSHIP" AND elementType = "node" then "(:" + label + ")-[:" + property + "]->(:" + toString(other[0]) + ")" else null end) as patterns + with elementType as type,\s + apoc.text.join(collect(":"+label+" {"+properties+"}"),"\\n") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),"\\n") as patterns + return collect(case type when "relationship" then entities end)[0] as relationships,\s + collect(case type when "node" then entities end)[0] as nodes,\s + collect(case type when "node" then patterns end)[0] as patterns\s + """; + private final static String SCHEMA_PROMPT = """ + nodes: + %s + relationships: + %s + patterns: + %s + """; + + private String loadSchema(Transaction tx, Map conf) { + Map params = new HashMap<>(); + params.put("sample", conf.get("sample")); + return tx.execute(SCHEMA_QUERY, params) + .stream() + .map(m -> SCHEMA_PROMPT.formatted(m.get("nodes"), m.get("relationships"), m.get("patterns"))) + .collect(Collectors.joining("\n")); + } +} diff --git 
a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index 4487e3edf8..59deba1863 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -13,6 +13,7 @@ import static apoc.util.TestUtil.testCall; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; public class OpenAIIT { @@ -85,7 +86,7 @@ public void chatCompletion() { assertEquals(true, result.containsKey("usage")); assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); - assertEquals("gpt-3.5-turbo-0301", result.get("model")); + assertTrue(result.get("model").toString().startsWith("gpt-3.5-turbo")); assertEquals("chat.completion", result.get("object")); }); diff --git a/extended/src/test/java/apoc/ml/PromptIT.java b/extended/src/test/java/apoc/ml/PromptIT.java new file mode 100644 index 0000000000..21f54eafed --- /dev/null +++ b/extended/src/test/java/apoc/ml/PromptIT.java @@ -0,0 +1,106 @@ +package apoc.ml; + +import apoc.coll.Coll; +import apoc.meta.Meta; +import apoc.text.Strings; +import apoc.util.TestUtil; +import apoc.util.Util; +import org.apache.commons.lang3.StringUtils; +import org.assertj.core.api.Assertions; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.graphdb.Transaction; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static apoc.util.TestUtil.testResult; + +public class PromptIT { + + private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + @BeforeClass + public static void check() { + Assume.assumeNotNull("No OPENAI_KEY environment configured", OPENAI_KEY); + } + + @Before + public void setUp() { + 
TestUtil.registerProcedure(db, Prompt.class, Meta.class, Strings.class, Coll.class); + String movies = Util.readResourceFile("movies.cypher"); + try (Transaction tx = db.beginTx()) { + tx.execute(movies); + tx.commit(); + } + } + + @Test + public void testQuery() { + testResult(db, """ + CALL apoc.ml.query($query, {retries: $retries, apiKey: $apiKey}) + """, + Map.of( + "query", "What movies did Tom Hanks play in?", + "retries", 2L, + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().toList(); + Assertions.assertThat(list).hasSize(12); + Assertions.assertThat(list.stream() + .map(m -> m.get("query")) + .filter(Objects::nonNull) + .map(Object::toString) + .map(String::trim)) + .isNotEmpty(); + }); + } + + @Test + public void testSchema() { + testResult(db, """ + CALL apoc.ml.schema({apiKey: $apiKey}) + """, + Map.of( + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().toList(); + Assertions.assertThat(list).hasSize(1); + }); + } + + @Test + public void testCypher() { + long numOfQueries = 4L; + testResult(db, """ + CALL apoc.ml.cypher($query, {count: $numOfQueries, apiKey: $apiKey}) + """, + Map.of( + "query", "Who are the actors which also directed a movie?", + "numOfQueries", numOfQueries, + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().toList(); + Assertions.assertThat(list).hasSize((int) numOfQueries); + Assertions.assertThat(list.stream() + .map(m -> m.get("query")) + .filter(Objects::nonNull) + .map(Object::toString) + .filter(StringUtils::isNotEmpty)) + .hasSize((int) numOfQueries); + }); + } + +}