Fixes #4004: Given (a set of) queries return the schema + explanation of the subgraph #4065

Merged
merged 2 commits into from
May 16, 2024
78 changes: 78 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Expand Up @@ -504,6 +504,84 @@ overall, this graph database schema provides a simple yet powerful representatio
|===


.Results
[%autowidth, opts=header]
|===
| name | description
| value | the description of the dataset
|===


== Create explanation of the subgraph from a set of queries

The `apoc.ml.fromQueries` procedure returns a natural-language explanation of the subgraph schema covered by the given set of queries.

It uses the `chat/completions` API, which is https://platform.openai.com/docs/api-reference/chat/create[documented here^].

.Query call
[source,cypher]
----
CALL apoc.ml.fromQueries(['MATCH (n:Movie) RETURN n', 'MATCH (n:Person) RETURN n'],
{apiKey: <apiKey>})
YIELD value
RETURN *
----

.Example response
[source, bash]
----
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| value |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| "The database represents movies and people, like in a movie database or social network.
There are no defined relationships between nodes, allowing flexibility for future connections.
The Movie node includes properties like title, tagline, and release year." |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
1 row
----

.Query call with path
[source,cypher]
----
CALL apoc.ml.fromQueries(['MATCH (n:Movie) RETURN n', 'MATCH p=(n:Movie)--() RETURN p'],
{apiKey: <apiKey>})
YIELD value
RETURN *
----

.Example response
[source, bash]
----
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| value |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| "models relationships in the movie industry, connecting :Person nodes to :Movie nodes.
It represents actors, directors, writers, producers, and reviewers connected to movies they are involved with.
Similar to a social network graph but specialized for the entertainment industry.
Each relationship type corresponds to common roles in movie production and reviewing.
Allows for querying and analyzing connections and collaborations within the movie business." |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
1 row
----


.Input Parameters
[%autowidth, opts=header]
|===
| name | description
| queries | The list of queries
| conf | An optional configuration map; see the next section
|===

.Configuration map
[%autowidth, opts=header]
|===
| name | description | mandatory
| apiKey | OpenAI API key | only if `apoc.openai.key` is not defined
| model | The OpenAI model | no, default `gpt-3.5-turbo`
| sample | The number of nodes to skip, e.g. a sample of 1000 reads every 1000th node. It is passed to the `apoc.meta.data.of` procedure, which computes the schema | no, default `1000`
|===
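
To make the `sample` setting concrete, here is a minimal Java sketch that keeps every Nth element of a list. It only illustrates the "every 1000th node" wording in the table above; the class and method names are hypothetical, and APOC's actual sampling logic may differ (for example, it may be randomized).

```java
import java.util.ArrayList;
import java.util.List;

class SampleSketch {
    // Hypothetical helper (not part of APOC): keeps the elements at
    // indexes 0, sample, 2*sample, ... and skips everything in between.
    static <T> List<T> everyNth(List<T> items, long sample) {
        if (sample < 1) {
            throw new IllegalArgumentException("sample must be >= 1");
        }
        List<T> kept = new ArrayList<>();
        for (long i = 0; i < items.size(); i += sample) {
            kept.add(items.get((int) i));
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Integer> nodes = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        // With sample = 3, the elements at indexes 0, 3, 6 and 9 are inspected.
        System.out.println(everyNth(nodes, 3));
    }
}
```

A larger `sample` therefore means fewer nodes are inspected, which is faster but may miss rarely used properties.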

.Results
[%autowidth, opts=header]
|===
50 changes: 41 additions & 9 deletions extended/src/main/java/apoc/ml/Prompt.java
Expand Up @@ -232,18 +232,29 @@ private String prompt(String userQuestion, String systemPrompt, String assistant
return result;
}

- private final static String SCHEMA_QUERY = """
- call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})
- YIELD label, other, elementType, type, property
- WITH label, elementType,\s
- apoc.text.join(collect(case when NOT type = "RELATIONSHIP" then property+": "+type else null end),", ") AS properties, \s
+ private final static String SCHEMA_FROM_META_DATA = """
+ \nWITH label, elementType,
+ apoc.text.join(collect(case when NOT type = "RELATIONSHIP" then property+": "+type else null end),", ") AS properties,
  collect(case when type = "RELATIONSHIP" AND elementType = "node" then "(:" + label + ")-[:" + property + "]->(:" + toString(other[0]) + ")" else null end) as patterns
- with elementType as type,\s
+ with elementType as type,
  apoc.text.join(collect(":"+label+" {"+properties+"}"),"\\n") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),"\\n") as patterns
- return collect(case type when "relationship" then entities end)[0] as relationships,\s
- collect(case type when "node" then entities end)[0] as nodes,\s
- collect(case type when "node" then patterns end)[0] as patterns\s
+ return collect(case type when "relationship" then entities end)[0] as relationships,
+ collect(case type when "node" then entities end)[0] as nodes,
+ collect(case type when "node" then patterns end)[0] as patterns
  """;

+ private final static String SCHEMA_QUERY = """
+ call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})
+ YIELD label, other, elementType, type, property
+ """ + SCHEMA_FROM_META_DATA;

+ private final static String GRAPH_QUERY = """
+ UNWIND $queries AS query
+ CALL apoc.meta.data.of(query, {maxRels: 10, sample: $sample})
+ YIELD label, other, elementType, type, property
+ WITH DISTINCT label, other, elementType, type, property
+ """ + SCHEMA_FROM_META_DATA;

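The refactoring above splits the Cypher into two different prefixes that share one tail. A simplified, standalone sketch of that composition with Java text blocks follows. The class name, constant names and shortened Cypher strings are illustrative assumptions, not the PR's actual code.

```java
class QueryCompositionSketch {
    // Shared tail. The leading \n escape guarantees a line break between
    // the prefix and the tail, whatever the prefix ends with.
    static final String SHARED_TAIL = """
            \nWITH label, property
            RETURN label, collect(property) AS properties
            """;

    // Prefix 1: schema of the whole graph (simplified).
    static final String WHOLE_GRAPH = """
            CALL apoc.meta.data({sample: $sample})
            YIELD label, property
            """ + SHARED_TAIL;

    // Prefix 2: schema of the subgraphs returned by a list of queries (simplified).
    // DISTINCT removes rows reported by more than one query.
    static final String FROM_QUERIES = """
            UNWIND $queries AS query
            CALL apoc.meta.data.of(query, {sample: $sample})
            YIELD label, property
            WITH DISTINCT label, property
            """ + SHARED_TAIL;

    public static void main(String[] args) {
        System.out.println(FROM_QUERIES);
    }
}
```

Keeping the aggregation in one constant means both queries produce the same `nodes`, `relationships` and `patterns` columns, so the downstream prompt formatting does not need to know which prefix was used.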
private final static String SCHEMA_PROMPT = """
nodes:
%s
Expand All @@ -253,6 +264,27 @@ private String prompt(String userQuestion, String systemPrompt, String assistant
%s
""";


@Procedure
public Stream<StringResult> fromQueries(@Name(value = "queries") List<String> queries, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.",
EXPLAIN_SCHEMA_PROMPT, "This database schema ", loadSchemaQueries(tx, queries, conf), conf, List.of());
return Stream.of(new StringResult(schemaExplanation));
}

private String loadSchemaQueries(Transaction tx, List<String> queries, Map<String, Object> conf) {

Map<String, Object> params = Map.of(
"sample", conf.getOrDefault("sample", 1000L),
"queries", queries
);

return tx.execute(GRAPH_QUERY, params)
.stream()
.map(m -> SCHEMA_PROMPT.formatted(m.get("nodes"), m.get("relationships"), m.get("patterns")))
.collect(Collectors.joining("\n"));
}
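
The parameter handling in `loadSchemaQueries` can be tried in isolation. The sketch below mirrors the `Map.of(...)` / `getOrDefault` combination from the method above; the class and method names are hypothetical and no Neo4j classes are involved.

```java
import java.util.List;
import java.util.Map;

class ParamsSketch {
    // Builds the query parameters, falling back to a sample of 1000
    // when the configuration map has no "sample" entry.
    static Map<String, Object> buildParams(List<String> queries, Map<String, Object> conf) {
        return Map.of(
                "sample", conf.getOrDefault("sample", 1000L),
                "queries", queries);
    }

    public static void main(String[] args) {
        List<String> queries = List.of("MATCH (n:Movie) RETURN n");
        // No "sample" key: the default is used.
        System.out.println(buildParams(queries, Map.of()).get("sample"));
        // Explicit "sample" key: the caller's value wins.
        System.out.println(buildParams(queries, Map.of("sample", 50L)).get("sample"));
    }
}
```

One thing to keep in mind with this pattern: `Map.of` rejects `null` values, and `getOrDefault` returns `null` when the key is present but mapped to `null`. If a caller can pass an explicit `sample: null`, a `HashMap` (as `loadSchema` uses) or an explicit null check may be the safer choice.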

private String loadSchema(Transaction tx, Map<String, Object> conf) {
Map<String, Object> params = new HashMap<>();
params.put("sample", conf.get("sample"));
1 change: 1 addition & 0 deletions extended/src/main/resources/extended.txt
Expand Up @@ -111,6 +111,7 @@ apoc.ml.bedrock.image
apoc.ml.bedrock.list
apoc.ml.cypher
apoc.ml.fromCypher
apoc.ml.fromQueries
apoc.ml.query
apoc.ml.schema
apoc.ml.openai.chat
73 changes: 73 additions & 0 deletions extended/src/test/java/apoc/ml/PromptIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static apoc.util.TestUtil.testResult;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;

public class PromptIT {

Expand Down Expand Up @@ -144,4 +145,76 @@ public void testFromCypher() {
});
}

@Test
public void testSchemaFromQueries() {
List<String> queries = List.of("MATCH p=(n:Movie)--() RETURN p", "MATCH (n:Person) RETURN n", "MATCH (n:Movie) RETURN n", "MATCH p=(n)-[r]->() RETURN r");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {

String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsIgnoringCase("movie");
Assertions.assertThat(value).containsAnyOf("person", "people");
});
}

@Test
public void testSchemaFromQueriesWithSingleQuery() {
List<String> queries = List.of("MATCH (n:Movie) RETURN n");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {
String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsIgnoringCase("movie");
Assertions.assertThat(value).doesNotContainIgnoringCase("person", "people");
});
}

@Test
public void testSchemaFromQueriesWithWrongQuery() {
List<String> queries = List.of("MATCH (n:Movie) RETURN a");
try {
testCall(db, """
CALL apoc.ml.fromQueries($queries, {apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> fail());
} catch (Exception e) {
Assertions.assertThat(e.getMessage()).contains(" Variable `a` not defined");
}

}
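
The try/catch in the test above relies on `fail()` being called if a row is returned. An alternative shape that makes "this call must throw" explicit is sketched below in plain Java. `runQuery` is a hypothetical stand-in for executing the Cypher statement, and `expectThrows` is an illustrative helper, not an APOC or JUnit API.

```java
class ExpectFailureSketch {
    // Hypothetical stand-in for running a query that references an undefined variable.
    static void runQuery(String query) {
        if (query.endsWith("RETURN a")) {
            throw new IllegalStateException("Variable `a` not defined");
        }
    }

    // Runs the action and returns the exception it throws.
    // If nothing is thrown, the check itself fails with an AssertionError.
    static Throwable expectThrows(Runnable action) {
        try {
            action.run();
        } catch (Exception e) {
            return e;
        }
        throw new AssertionError("expected an exception, but none was thrown");
    }

    public static void main(String[] args) {
        Throwable error = expectThrows(() -> runQuery("MATCH (n:Movie) RETURN a"));
        System.out.println(error.getMessage());
    }
}
```

Because `AssertionError` is an `Error` rather than an `Exception`, it is not swallowed by the `catch (Exception e)` clause, which is also why the `fail()` call in the original test still surfaces.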

@Test
public void testSchemaFromEmptyQueries() {
List<String> queries = List.of("MATCH (n:Movie) RETURN 1");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {
String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsAnyOf("does not contain", "empty", "undefined", "doesn't have");
});
}

}
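
Since the model's wording varies from run to run, the tests above check for keywords rather than exact text. A dependency-free sketch of that kind of check is shown here; the class and method names are hypothetical, and the PR itself uses AssertJ's `containsAnyOf` and `containsIgnoringCase`.

```java
import java.util.Locale;

class KeywordCheckSketch {
    // Returns true if the text contains at least one keyword, ignoring case.
    static boolean containsAnyIgnoringCase(String text, String... keywords) {
        String lower = text.toLowerCase(Locale.ROOT);
        for (String keyword : keywords) {
            if (lower.contains(keyword.toLowerCase(Locale.ROOT))) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        String answer = "The database represents Movies and People.";
        System.out.println(containsAnyIgnoringCase(answer, "person", "people"));
    }
}
```

Keyword checks like this tolerate rephrasing but can still be flaky, so the accepted keyword list usually needs to cover the variants the model tends to produce.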