Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed May 24, 2024
1 parent 15160c7 commit dd5ed21
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 188 deletions.
13 changes: 7 additions & 6 deletions extended-it/src/test/java/apoc/vectordb/ChromaDbTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
import static apoc.vectordb.VectorDbHandler.Type.CHROMA;
import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult;
import static apoc.vectordb.VectorDbTestUtil.assertLondonResult;
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsAndIndexesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.EntityType.*;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
Expand Down Expand Up @@ -203,7 +204,7 @@ public void queryVectorsWithLimit() {
}

@Test
public void queryVectorsWithCreateIndex() {
public void queryVectorsWithCreateNode() {
Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
MAPPING_KEY, map("embeddingProp", "vect",
"label", "Test",
Expand Down Expand Up @@ -246,7 +247,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateIndexUsingExistingNode() {
public void queryVectorsWithCreateNodeUsingExistingNode() {

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

Expand All @@ -273,7 +274,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateRelIndex() {
public void queryVectorsWithCreateRel() {

db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)");

Expand All @@ -297,13 +298,13 @@ MAPPING_KEY, map("embeddingProp", "vect",
assertNotNull(row.get("vector"));
});

assertRelsAndIndexesCreated(db);
assertRelsCreated(db);
}

@Test
public void queryVectorsWithSystemDbStorage() {
db.executeTransactionally("CALL apoc.vectordb.store($vectorName, $host, $credential, $mapping)",
map("vectorName", VectorDbUtil.VectorDbHandler.Type.CHROMA.toString(),
map("vectorName", CHROMA.toString(),
"host", "http://" + HOST,
"credential", null,
"mapping", map("embeddingProp", "vect",
Expand Down
13 changes: 7 additions & 6 deletions extended-it/src/test/java/apoc/vectordb/QdrantTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
import static apoc.vectordb.VectorDbHandler.Type.QDRANT;
import static apoc.vectordb.VectorDbTestUtil.EntityType.NODE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.FALSE;
import static apoc.vectordb.VectorDbTestUtil.EntityType.REL;
import static apoc.vectordb.VectorDbTestUtil.assertBerlinResult;
import static apoc.vectordb.VectorDbTestUtil.assertLondonResult;
import static apoc.vectordb.VectorDbTestUtil.assertNodesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsAndIndexesCreated;
import static apoc.vectordb.VectorDbTestUtil.assertRelsCreated;
import static apoc.vectordb.VectorDbTestUtil.dropAndDeleteAll;
import static apoc.vectordb.VectorDbTestUtil.getAuthHeader;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
Expand Down Expand Up @@ -235,7 +236,7 @@ public void queryVectorsWithLimit() {
}

@Test
public void queryVectorsWithCreateIndex() {
public void queryVectorsWithCreateNode() {

Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
HEADERS_KEY, ADMIN_AUTHORIZATION,
Expand Down Expand Up @@ -281,7 +282,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateIndexUsingExistingNode() {
public void queryVectorsWithCreateNodeUsingExistingNode() {

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

Expand Down Expand Up @@ -310,7 +311,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateRelIndex() {
public void queryVectorsWithCreateRel() {

db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)");

Expand All @@ -334,13 +335,13 @@ MAPPING_KEY, map("embeddingProp", "vect",
assertNotNull(row.get("vector"));
});

assertRelsAndIndexesCreated(db);
assertRelsCreated(db);
}

@Test
public void queryVectorsWithSystemDbStorage() {
db.executeTransactionally("CALL apoc.vectordb.store($vectorName, $host, $credential, $mapping)",
map("vectorName", VectorDbUtil.VectorDbHandler.Type.QDRANT.toString(),
map("vectorName", QDRANT.toString(),
"host", "http://" + HOST,
"credential", ADMIN_KEY,
"mapping", map("embeddingProp", "vect",
Expand Down
13 changes: 7 additions & 6 deletions extended-it/src/test/java/apoc/vectordb/WeaviateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static apoc.util.TestUtil.testCallEmpty;
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static apoc.vectordb.VectorDbHandler.Type.WEAVIATE;
import static apoc.vectordb.VectorDbTestUtil.*;
import static apoc.vectordb.VectorDbTestUtil.EntityType.*;
import static apoc.vectordb.VectorEmbeddingConfig.ALL_RESULTS_KEY;
Expand Down Expand Up @@ -225,7 +226,7 @@ public void queryVectorsWithLimit() {
}

@Test
public void queryVectorsWithCreateIndex() {
public void queryVectorsWithCreateNode() {

Map<String, Object> conf = map(ALL_RESULTS_KEY, true,
"fields", FIELDS,
Expand Down Expand Up @@ -274,7 +275,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateIndexUsingExistingNode() {
public void queryVectorsWithCreateNodeUsingExistingNode() {

db.executeTransactionally("CREATE (:Test {myId: 'one'}), (:Test {myId: 'two'})");

Expand Down Expand Up @@ -304,7 +305,7 @@ MAPPING_KEY, map("embeddingProp", "vect",
}

@Test
public void queryVectorsWithCreateRelIndex() {
public void queryVectorsWithCreateRel() {

db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)");

Expand All @@ -330,11 +331,11 @@ MAPPING_KEY, map("embeddingProp", "vect",
assertNotNull(row.get("vector"));
});

assertRelsAndIndexesCreated(db);
assertRelsCreated(db);
}

@Test
public void queryVectorsWithCreateRelIndexWithoutVectorResult() {
public void queryVectorsWithCreateRelWithoutVectorResult() {

db.executeTransactionally("CREATE (:Start)-[:TEST {myId: 'one'}]->(:End), (:Start)-[:TEST {myId: 'two'}]->(:End)");

Expand Down Expand Up @@ -370,7 +371,7 @@ MAPPING_KEY, map("type", "TEST",
@Test
public void queryVectorsWithSystemDbStorage() {
db.executeTransactionally("CALL apoc.vectordb.store($vectorName, $host, $credential, $mapping)",
map("vectorName", VectorDbUtil.VectorDbHandler.Type.WEAVIATE.toString(),
map("vectorName", WEAVIATE.toString(),
"host", "http://" + HOST + "/v1",
"credential", ADMIN_KEY,
"mapping", map("embeddingProp", "vect",
Expand Down
5 changes: 2 additions & 3 deletions extended/src/main/java/apoc/util/ExtendedUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,8 @@ private static void mapEntryToList(Map<Object, List> map, Map<String, Object> ve
list.add(item);
return list;
}
List list = (List) v;
list.add(v);
return list;
v.add(item);
return v;
});
}

Expand Down
7 changes: 3 additions & 4 deletions extended/src/main/java/apoc/vectordb/ChromaDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import apoc.ml.RestAPIConfig;
import apoc.result.ListResult;
import apoc.result.MapResult;
import apoc.util.UrlResolver;
import org.apache.commons.collections4.CollectionUtils;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
Expand All @@ -26,8 +25,8 @@
import static apoc.util.MapUtil.map;
import static apoc.vectordb.VectorDb.executeRequest;
import static apoc.vectordb.VectorDb.getEmbeddingResultStream;
import static apoc.vectordb.VectorDbHandler.Type.CHROMA;
import static apoc.vectordb.VectorDbUtil.*;
import static apoc.vectordb.VectorDbUtil.VectorDbHandler.Type.CHROMA;
import static apoc.vectordb.VectorEmbeddingConfig.*;

@Extended
Expand Down Expand Up @@ -133,7 +132,7 @@ public Stream<EmbeddingResult> query(@Name("hostOrKey") String hostOrKey,
Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, url);

VectorEmbeddingConfig apiConfig = CHROMA.get().getEmbedding().fromGet(config, procedureCallContext, ids);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx,
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx,
v -> listToMap((Map) v).stream());
}

Expand All @@ -150,7 +149,7 @@ public Stream<EmbeddingResult> query(@Name("hostOrKey") String hostOrKey,
Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, url);

VectorEmbeddingConfig apiConfig = CHROMA.get().getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx,
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx,
v -> listOfListsToMap((Map) v).stream());
}

Expand Down
7 changes: 3 additions & 4 deletions extended/src/main/java/apoc/vectordb/Qdrant.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import apoc.Extended;
import apoc.ml.RestAPIConfig;
import apoc.result.MapResult;
import apoc.util.UrlResolver;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
Expand All @@ -22,8 +21,8 @@
import static apoc.ml.RestAPIConfig.METHOD_KEY;
import static apoc.vectordb.VectorDb.executeRequest;
import static apoc.vectordb.VectorDb.getEmbeddingResultStream;
import static apoc.vectordb.VectorDbHandler.Type.QDRANT;
import static apoc.vectordb.VectorDbUtil.*;
import static apoc.vectordb.VectorDbUtil.VectorDbHandler.Type.QDRANT;

@Extended
public class Qdrant {
Expand Down Expand Up @@ -135,7 +134,7 @@ public Stream<EmbeddingResult> query(@Name("hostOrKey") String hostOrKey,
Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, url);

VectorEmbeddingConfig apiConfig = QDRANT.get().getEmbedding().fromGet(config, procedureCallContext, ids);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx);
}

@Procedure(value = "apoc.vectordb.qdrant.query", mode = Mode.SCHEMA)
Expand All @@ -151,7 +150,7 @@ public Stream<EmbeddingResult> query(@Name("hostOrKey") String hostOrKey,
Map<String, Object> config = getVectorDbInfo(hostOrKey, collection, configuration, url);

VectorEmbeddingConfig apiConfig = QDRANT.get().getEmbedding().fromQuery(config, procedureCallContext, vector, filter, limit, collection);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, db, tx);
return getEmbeddingResultStream(apiConfig, procedureCallContext, urlAccessChecker, tx);
}

private Map<String, Object> getVectorDbInfo(
Expand Down
68 changes: 26 additions & 42 deletions extended/src/main/java/apoc/vectordb/VectorDb.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,20 @@ public Stream<EmbeddingResult> get(@Name("host") String host,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {

getEndpoint(configuration, host);
VectorEmbeddingConfig restAPIConfig = new VectorEmbeddingConfig(configuration);//, Map.of(), Map.of());
return getEmbeddingResultStream(restAPIConfig, procedureCallContext, urlAccessChecker, db, tx);
VectorEmbeddingConfig restAPIConfig = new VectorEmbeddingConfig(configuration);
return getEmbeddingResultStream(restAPIConfig, procedureCallContext, urlAccessChecker, tx);
}

public static Stream<EmbeddingResult> getEmbeddingResultStream(VectorEmbeddingConfig conf,
ProcedureCallContext procedureCallContext,
URLAccessChecker urlAccessChecker,
GraphDatabaseService db,
Transaction tx) throws Exception {
return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, db, tx, v -> ((List<Map>) v).stream());
return getEmbeddingResultStream(conf, procedureCallContext, urlAccessChecker, tx, v -> ((List<Map>) v).stream());
}

public static Stream<EmbeddingResult> getEmbeddingResultStream(VectorEmbeddingConfig conf,
ProcedureCallContext procedureCallContext,
URLAccessChecker urlAccessChecker,
GraphDatabaseService db,
Transaction tx,
Function<Object, Stream<Map>> objectMapper) throws Exception {
List<String> fields = procedureCallContext.outputFields().toList();
Expand All @@ -109,10 +107,10 @@ public static Stream<EmbeddingResult> getEmbeddingResultStream(VectorEmbeddingCo

return resultStream
.flatMap(objectMapper)
.map(m -> getEmbeddingResult(conf, db, tx, hasVector, hasMetadata, mapping, m));
.map(m -> getEmbeddingResult(conf, tx, hasVector, hasMetadata, mapping, m));
}

public static EmbeddingResult getEmbeddingResult(VectorEmbeddingConfig conf, GraphDatabaseService db, Transaction tx, boolean hasEmbedding, boolean hasMetadata, VectorMappingConfig mapping, Map m) {
public static EmbeddingResult getEmbeddingResult(VectorEmbeddingConfig conf, Transaction tx, boolean hasEmbedding, boolean hasMetadata, VectorMappingConfig mapping, Map m) {
Object id = conf.isAllResults() ? m.get(conf.getIdKey()) : null;
List<Double> embedding = hasEmbedding ? (List<Double>) m.get(conf.getVectorKey()) : null;
Map<String, Object> metadata = hasMetadata ? (Map<String, Object>) m.get(conf.getMetadataKey()) : null;
Expand All @@ -121,15 +119,15 @@ public static EmbeddingResult getEmbeddingResult(VectorEmbeddingConfig conf, Gra
Double score = Util.toDouble(m.get(conf.getScoreKey()));
String text = conf.isAllResults() ? (String) m.get(conf.getTextKey()) : null;

Entity entity = handleMapping(tx, db, mapping, metadata, embedding);
Entity entity = handleMapping(tx, mapping, metadata, embedding);
if (entity != null) entity = Util.rebind(tx, entity);
return new EmbeddingResult(id, score, embedding, metadata, text,
mapping.getLabel() == null ? null : (Node) entity,
mapping.getLabel() != null ? null : (Relationship) entity
);
}

private static Entity handleMapping(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map<String, Object> metadata, List<Double> embedding) {
private static Entity handleMapping(Transaction tx, VectorMappingConfig mapping, Map<String, Object> metadata, List<Double> embedding) {
if (mapping.getProp() == null) {
return null;
}
Expand All @@ -138,33 +136,26 @@ private static Entity handleMapping(Transaction tx, GraphDatabaseService db, Vec
}
Map<String, Object> metaProps = new HashMap<>(metadata);
if (mapping.getLabel() != null) {
return handleMappingNode(tx, db, mapping, metaProps, embedding);
return handleMappingNode(tx, mapping, metaProps, embedding);
} else if (mapping.getType() != null) {
return handleMappingRel(tx, db, mapping, metaProps, embedding);
return handleMappingRel(tx, mapping, metaProps, embedding);
} else {
throw new RuntimeException("Mapping conf has to contain either label or type key");
}
}

private static Entity handleMappingNode(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map<String, Object> metaProps, List<Double> embedding) {
String query = "CREATE CONSTRAINT IF NOT EXISTS FOR (n:%s) REQUIRE n.%s IS UNIQUE"
.formatted(mapping.getLabel(), mapping.getProp());
db.executeTransactionally(query);

private static Entity handleMappingNode(Transaction transaction, VectorMappingConfig mapping, Map<String, Object> metaProps, List<Double> embedding) {
try {
Node node;
try (Transaction transaction = db.beginTx()) {
Object propValue = metaProps.get(mapping.getId());
node = transaction.findNode(Label.label(mapping.getLabel()), mapping.getProp(), propValue);
if (node == null && mapping.isCreate()) {
node = transaction.createNode(Label.label(mapping.getLabel()));
node.setProperty(mapping.getProp(), propValue);
}
if (node != null) {
setProperties(node, metaProps);
setVectorProp(tx, db, mapping, embedding, node/*, setVectorQuery*/);
}
transaction.commit();
Object propValue = metaProps.get(mapping.getId());
node = transaction.findNode(Label.label(mapping.getLabel()), mapping.getProp(), propValue);
if (node == null && mapping.isCreate()) {
node = transaction.createNode(Label.label(mapping.getLabel()));
node.setProperty(mapping.getProp(), propValue);
}
if (node != null) {
setProperties(node, metaProps);
setVectorProp(mapping, embedding, node);
}

return node;
Expand All @@ -173,22 +164,15 @@ private static Entity handleMappingNode(Transaction tx, GraphDatabaseService db,
}
}

private static Entity handleMappingRel(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, Map<String, Object> metaProps, List<Double> embedding) {
private static Entity handleMappingRel(Transaction transaction, VectorMappingConfig mapping, Map<String, Object> metaProps, List<Double> embedding) {
try {
String query = "CREATE CONSTRAINT IF NOT EXISTS FOR ()-[r:%s]-() REQUIRE (r.%s) IS UNIQUE"
.formatted(mapping.getType(), mapping.getProp());
db.executeTransactionally(query);

// in this case we cannot auto-create the rel, since we should have to define start and end node as well
Relationship rel;
try (Transaction transaction = db.beginTx()) {
Object propValue = metaProps.get(mapping.getId());
rel = transaction.findRelationship(RelationshipType.withName(mapping.getType()), mapping.getProp(), propValue);
if (rel != null) {
setProperties(rel, metaProps);
setVectorProp(tx, db, mapping, embedding, rel/*, setVectorQuery*/);
}
transaction.commit();
Object propValue = metaProps.get(mapping.getId());
rel = transaction.findRelationship(RelationshipType.withName(mapping.getType()), mapping.getProp(), propValue);
if (rel != null) {
setProperties(rel, metaProps);
setVectorProp(mapping, embedding, rel);
}

return rel;
Expand All @@ -197,7 +181,7 @@ private static Entity handleMappingRel(Transaction tx, GraphDatabaseService db,
}
}

private static <T extends Entity> void setVectorProp(Transaction tx, GraphDatabaseService db, VectorMappingConfig mapping, List<Double> embedding, T entity/*, String setVectorQuery*/) {
private static <T extends Entity> void setVectorProp(VectorMappingConfig mapping, List<Double> embedding, T entity) {
if (mapping.getEmbeddingProp() == null) {
return;
}
Expand Down
Loading

0 comments on commit dd5ed21

Please sign in to comment.