diff --git a/algo/src/main/java/org/neo4j/graphalgo/PageRankProc.java b/algo/src/main/java/org/neo4j/graphalgo/PageRankProc.java index ecdfc340b..6ae8edd53 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/PageRankProc.java +++ b/algo/src/main/java/org/neo4j/graphalgo/PageRankProc.java @@ -33,6 +33,7 @@ import org.neo4j.graphalgo.impl.PageRankAlgorithm; import org.neo4j.graphalgo.results.PageRankScore; import org.neo4j.graphdb.Direction; +import org.neo4j.graphdb.Node; import org.neo4j.kernel.api.KernelTransaction; import org.neo4j.kernel.internal.GraphDatabaseAPI; import org.neo4j.logging.Log; @@ -42,6 +43,8 @@ import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.stream.IntStream; import java.util.stream.LongStream; @@ -104,7 +107,7 @@ public Stream pageRankStream( @Name(value = "relationship", defaultValue = "") String relationship, @Name(value = "config", defaultValue = "{}") Map config) { - ProcedureConfiguration configuration = ProcedureConfiguration.create(config); + ProcedureConfiguration configuration = ProcedureConfiguration.create(config); PageRankScore.Stats.Builder statsBuilder = new PageRankScore.Stats.Builder(); AllocationTracker tracker = AllocationTracker.create(); @@ -150,13 +153,19 @@ private Graph load( AllocationTracker tracker, Class graphFactory, PageRankScore.Stats.Builder statsBuilder, ProcedureConfiguration configuration) { - GraphLoader graphLoader = new GraphLoader(api, Pools.DEFAULT) .init(log, label, relationship, configuration) .withAllocationTracker(tracker) - .withDirection(Direction.OUTGOING) .withoutRelationshipWeights(); + Direction direction = configuration.getDirection(Direction.OUTGOING); + if (direction == Direction.BOTH) { + graphLoader.asUndirected(true); + } else { + graphLoader.withDirection(direction); + } + + try (ProgressTimer timer = statsBuilder.timeLoad()) { Graph graph = graphLoader.load(graphFactory); statsBuilder.withNodes(graph.nodeCount()); @@ -177,10 +186,14 @@ private PageRankResult evaluate( final int concurrency = configuration.getConcurrency(Pools.getNoThreadsInDefaultPool()); log.debug("Computing page rank with damping of " + dampingFactor + " and " + iterations + " iterations."); + + List sourceNodes = configuration.get("sourceNodes", new ArrayList<>()); + LongStream sourceNodeIds = sourceNodes.stream().mapToLong(Node::getId); PageRankAlgorithm prAlgo = PageRankAlgorithm.of( tracker, graph, dampingFactor, + sourceNodeIds, Pools.DEFAULT, concurrency, batchSize); @@ -189,6 +202,7 @@ private PageRankResult evaluate( .withLog(log) .withTerminationFlag(terminationFlag); + statsBuilder.timeEval(() -> prAlgo.compute(iterations)); statsBuilder diff --git a/algo/src/main/java/org/neo4j/graphalgo/impl/HugePageRank.java b/algo/src/main/java/org/neo4j/graphalgo/impl/HugePageRank.java index 1c25c6522..f6e7f931b 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/impl/HugePageRank.java +++ b/algo/src/main/java/org/neo4j/graphalgo/impl/HugePageRank.java @@ -21,11 +21,7 @@ import com.carrotsearch.hppc.IntArrayList; import com.carrotsearch.hppc.LongArrayList; import org.neo4j.collection.primitive.PrimitiveLongIterator; -import org.neo4j.graphalgo.api.HugeDegrees; -import org.neo4j.graphalgo.api.HugeIdMapping; -import org.neo4j.graphalgo.api.HugeNodeIterator; -import org.neo4j.graphalgo.api.HugeRelationshipConsumer; -import org.neo4j.graphalgo.api.HugeRelationshipIterator; +import org.neo4j.graphalgo.api.*; import org.neo4j.graphalgo.core.utils.ParallelUtil; import org.neo4j.graphalgo.core.utils.paged.AllocationTracker; import org.neo4j.graphalgo.core.write.Exporter; @@ -39,14 +35,11 @@ import java.util.Iterator; import java.util.List; import java.util.concurrent.ExecutorService; +import java.util.stream.LongStream; import static org.neo4j.graphalgo.core.utils.ArrayUtil.binaryLookup; import static org.neo4j.graphalgo.core.utils.paged.AllocationTracker.humanReadable; -import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.shallowSizeOfInstance; -import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.sizeOfDoubleArray; -import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.sizeOfIntArray; -import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.sizeOfLongArray; -import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.sizeOfObjectArray; +import static org.neo4j.graphalgo.core.utils.paged.MemoryUsage.*; /** @@ -111,6 +104,8 @@ public class HugePageRank extends Algorithm implements PageRankAlg private final HugeRelationshipIterator relationshipIterator; private final HugeDegrees degrees; private final double dampingFactor; + private final HugeGraph graph; + private LongStream sourceNodeIds; private Log log; private ComputeSteps computeSteps; @@ -121,21 +116,17 @@ public class HugePageRank extends Algorithm implements PageRankAlg */ HugePageRank( AllocationTracker tracker, - HugeIdMapping idMapping, - HugeNodeIterator nodeIterator, - HugeRelationshipIterator relationshipIterator, - HugeDegrees degrees, - double dampingFactor) { + HugeGraph graph, + double dampingFactor, + LongStream sourceNodeIds) { this( null, -1, ParallelUtil.DEFAULT_BATCH_SIZE, tracker, - idMapping, - nodeIterator, - relationshipIterator, - degrees, - dampingFactor); + graph, + dampingFactor, + sourceNodeIds); } /** @@ -148,20 +139,20 @@ public class HugePageRank extends Algorithm implements PageRankAlg int concurrency, int batchSize, AllocationTracker tracker, - HugeIdMapping idMapping, - HugeNodeIterator nodeIterator, - HugeRelationshipIterator relationshipIterator, - HugeDegrees degrees, - double dampingFactor) { + HugeGraph graph, + double dampingFactor, + LongStream sourceNodeIds) { this.executor = executor; this.concurrency = concurrency; this.batchSize = batchSize; this.tracker = tracker; - this.idMapping = idMapping; - this.nodeIterator = nodeIterator; - this.relationshipIterator = relationshipIterator; - this.degrees = degrees; + this.idMapping = graph; + this.nodeIterator = graph; + this.relationshipIterator = graph; + this.degrees = graph; + this.graph = graph; this.dampingFactor = dampingFactor; + this.sourceNodeIds = sourceNodeIds; } /** @@ -209,6 +200,7 @@ private void initializeSteps() { concurrency, idMapping.nodeCount(), dampingFactor, + sourceNodeIds.map(graph::toHugeMappedNodeId).filter(mappedId -> mappedId != -1L).toArray(), relationshipIterator, degrees, partitions, @@ -246,6 +238,7 @@ private ComputeSteps createComputeSteps( int concurrency, long nodeCount, double dampingFactor, + long[] sourceNodeIds, HugeRelationshipIterator relationshipIterator, HugeDegrees degrees, List partitions, @@ -281,6 +274,7 @@ private ComputeSteps createComputeSteps( computeSteps.add(new ComputeStep( dampingFactor, + sourceNodeIds, relationshipIterator, degrees, tracker, @@ -542,6 +536,7 @@ private static final class ComputeStep implements Runnable, HugeRelationshipCons private long[] starts; private int[] lengths; + private long[] sourceNodeIds; private final HugeRelationshipIterator relationshipIterator; private final HugeDegrees degrees; private final AllocationTracker tracker; @@ -562,6 +557,7 @@ private static final class ComputeStep implements Runnable, HugeRelationshipCons ComputeStep( double dampingFactor, + long[] sourceNodeIds, HugeRelationshipIterator relationshipIterator, HugeDegrees degrees, AllocationTracker tracker, @@ -569,6 +565,7 @@ private static final class ComputeStep implements Runnable, HugeRelationshipCons long startNode) { this.dampingFactor = dampingFactor; this.alpha = 1.0 - dampingFactor; + this.sourceNodeIds = sourceNodeIds; this.relationshipIterator = relationshipIterator.concurrentCopy(); this.degrees = degrees; this.tracker = tracker; @@ -606,8 +603,21 @@ private void initialize() { }); tracker.add(sizeOfDoubleArray(partitionSize) << 1); + double[] partitionRank = new double[partitionSize]; - Arrays.fill(partitionRank, alpha); + if(sourceNodeIds.length == 0) { + Arrays.fill(partitionRank, alpha); + } else { + Arrays.fill(partitionRank,0); + + long[] partitionSourceNodeIds = LongStream.of(sourceNodeIds) + .filter(sourceNodeId -> sourceNodeId >= startNode && sourceNodeId <= endNode) + .toArray(); + + for (long sourceNodeId : partitionSourceNodeIds) { + partitionRank[Math.toIntExact(sourceNodeId - this.startNode)] = alpha; + } + } this.pageRank = partitionRank; this.deltas = Arrays.copyOf(partitionRank, partitionSize); diff --git a/algo/src/main/java/org/neo4j/graphalgo/impl/PageRank.java b/algo/src/main/java/org/neo4j/graphalgo/impl/PageRank.java index 3120e0068..38d5d7172 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/impl/PageRank.java +++ b/algo/src/main/java/org/neo4j/graphalgo/impl/PageRank.java @@ -1,18 +1,18 @@ /** * Copyright (c) 2017 "Neo4j, Inc." - * + *

* This file is part of Neo4j Graph Algorithms . - * + *

* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. - * + *

* This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. - * + *

* You should have received a copy of the GNU General Public License * along with this program. If not, see . */ @@ -20,11 +20,7 @@ import com.carrotsearch.hppc.IntArrayList; import org.neo4j.collection.primitive.PrimitiveIntIterator; -import org.neo4j.graphalgo.api.Degrees; -import org.neo4j.graphalgo.api.IdMapping; -import org.neo4j.graphalgo.api.NodeIterator; -import org.neo4j.graphalgo.api.RelationshipConsumer; -import org.neo4j.graphalgo.api.RelationshipIterator; +import org.neo4j.graphalgo.api.*; import org.neo4j.graphalgo.core.utils.ParallelUtil; import org.neo4j.graphalgo.core.utils.Pools; import org.neo4j.graphalgo.core.write.Exporter; @@ -32,12 +28,10 @@ import org.neo4j.graphalgo.core.write.Translators; import org.neo4j.graphdb.Direction; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; +import java.util.*; import java.util.concurrent.ExecutorService; +import java.util.stream.IntStream; +import java.util.stream.LongStream; import static org.neo4j.graphalgo.core.utils.ArrayUtil.binaryLookup; @@ -101,21 +95,16 @@ public class PageRank extends Algorithm implements PageRankAlgorithm { * Forces sequential use. If you want parallelism, prefer * {@link #PageRank(ExecutorService, int, int, IdMapping, NodeIterator, RelationshipIterator, Degrees, double)} */ - PageRank( - IdMapping idMapping, - NodeIterator nodeIterator, - RelationshipIterator relationshipIterator, - Degrees degrees, - double dampingFactor) { + PageRank(Graph graph, + double dampingFactor, + LongStream sourceNodeIds) { this( null, -1, ParallelUtil.DEFAULT_BATCH_SIZE, - idMapping, - nodeIterator, - relationshipIterator, - degrees, - dampingFactor); + graph, + dampingFactor, + sourceNodeIds); } /** @@ -127,28 +116,27 @@ public class PageRank extends Algorithm implements PageRankAlgorithm { ExecutorService executor, int concurrency, int batchSize, - IdMapping idMapping, - NodeIterator nodeIterator, - RelationshipIterator relationshipIterator, - Degrees degrees, - double dampingFactor) { + Graph graph, + double dampingFactor, + LongStream sourceNodeIds) { List partitions; if (ParallelUtil.canRunInParallel(executor)) { partitions = partitionGraph( adjustBatchSize(batchSize), - idMapping, - nodeIterator, - degrees); + graph, + graph, + graph); } else { executor = null; - partitions = createSinglePartition(idMapping, degrees); + partitions = createSinglePartition(graph, graph); } computeSteps = createComputeSteps( concurrency, dampingFactor, - relationshipIterator, - degrees, + sourceNodeIds.mapToInt(graph::toMappedNodeId).filter(mappedId -> mappedId != -1L).toArray(), + graph, + graph, partitions, executor); } @@ -220,6 +208,7 @@ private List createSinglePartition( private ComputeSteps createComputeSteps( int concurrency, double dampingFactor, + int[] sourceNodeIds, RelationshipIterator relationshipIterator, Degrees degrees, List partitions, @@ -252,6 +241,7 @@ private ComputeSteps createComputeSteps( computeSteps.add(new ComputeStep( dampingFactor, + sourceNodeIds, relationshipIterator, degrees, partitionCount, @@ -389,6 +379,7 @@ private static final class ComputeStep implements Runnable, RelationshipConsumer private int[] starts; private int[] lengths; + private int[] sourceNodeIds; private final RelationshipIterator relationshipIterator; private final Degrees degrees; @@ -408,12 +399,14 @@ private static final class ComputeStep implements Runnable, RelationshipConsumer ComputeStep( double dampingFactor, + int[] sourceNodeIds, RelationshipIterator relationshipIterator, Degrees degrees, int partitionSize, int startNode) { this.dampingFactor = dampingFactor; this.alpha = 1.0 - dampingFactor; + this.sourceNodeIds = sourceNodeIds; this.relationshipIterator = relationshipIterator; this.degrees = degrees; this.partitionSize = partitionSize; @@ -446,7 +439,21 @@ private void initialize() { Arrays.setAll(nextScores, i -> new int[lengths[i]]); double[] partitionRank = new double[partitionSize]; - Arrays.fill(partitionRank, alpha); + + if(sourceNodeIds.length == 0) { + Arrays.fill(partitionRank, alpha); + } else { + Arrays.fill(partitionRank,0); + + int[] partitionSourceNodeIds = IntStream.of(sourceNodeIds) + .filter(sourceNodeId -> sourceNodeId >= startNode && sourceNodeId < endNode) + .toArray(); + + for (int sourceNodeId : partitionSourceNodeIds) { + partitionRank[sourceNodeId - this.startNode] = alpha; + } + } + this.pageRank = partitionRank; this.deltas = Arrays.copyOf(partitionRank, partitionSize); @@ -509,6 +516,7 @@ private void synchronizeScores(int[] allScores) { int length = allScores.length; for (int i = 0; i < length; i++) { int sum = allScores[i]; + double delta = dampingFactor * (sum / 100_000.0); pageRank[i] += delta; deltas[i] = delta; diff --git a/algo/src/main/java/org/neo4j/graphalgo/impl/PageRankAlgorithm.java b/algo/src/main/java/org/neo4j/graphalgo/impl/PageRankAlgorithm.java index 3aa223444..f3ecabe80 100644 --- a/algo/src/main/java/org/neo4j/graphalgo/impl/PageRankAlgorithm.java +++ b/algo/src/main/java/org/neo4j/graphalgo/impl/PageRankAlgorithm.java @@ -23,6 +23,8 @@ import org.neo4j.graphalgo.core.utils.paged.AllocationTracker; import java.util.concurrent.ExecutorService; +import java.util.stream.LongStream; +import java.util.stream.Stream; public interface PageRankAlgorithm { @@ -33,35 +35,39 @@ public interface PageRankAlgorithm { Algorithm algorithm(); static PageRankAlgorithm of( - Graph graph, - double dampingFactor) { - return of(AllocationTracker.EMPTY, graph, dampingFactor); + Graph graph, + double dampingFactor, + LongStream sourceNodeIds) { + return of(AllocationTracker.EMPTY, dampingFactor, sourceNodeIds, graph); } static PageRankAlgorithm of( AllocationTracker tracker, - Graph graph, - double dampingFactor) { + double dampingFactor, + LongStream sourceNodeIds, + Graph graph) { if (graph instanceof HugeGraph) { HugeGraph huge = (HugeGraph) graph; - return new HugePageRank(tracker, huge, huge, huge, huge, dampingFactor); + return new HugePageRank(tracker, huge, dampingFactor, sourceNodeIds); } - return new PageRank(graph, graph, graph, graph, dampingFactor); + return new PageRank(graph, dampingFactor, sourceNodeIds); } static PageRankAlgorithm of( Graph graph, double dampingFactor, + LongStream sourceNodeIds, ExecutorService pool, int concurrency, int batchSize) { - return of(AllocationTracker.EMPTY, graph, dampingFactor, pool, concurrency, batchSize); + return of(AllocationTracker.EMPTY, graph, dampingFactor, sourceNodeIds, pool, concurrency, batchSize); } static PageRankAlgorithm of( AllocationTracker tracker, Graph graph, double dampingFactor, + LongStream sourceNodeIds, ExecutorService pool, int concurrency, int batchSize) { @@ -73,19 +79,16 @@ static PageRankAlgorithm of( batchSize, tracker, huge, - huge, - huge, - huge, - dampingFactor); + dampingFactor, + sourceNodeIds + ); } return new PageRank( pool, concurrency, batchSize, graph, - graph, - graph, - graph, - dampingFactor); + dampingFactor, + sourceNodeIds); } } diff --git a/doc/asciidoc/pagerank.adoc b/doc/asciidoc/pagerank.adoc index 8c7685465..aa9c99d7f 100644 --- a/doc/asciidoc/pagerank.adoc +++ b/doc/asciidoc/pagerank.adoc @@ -124,6 +124,42 @@ As we might expect, the Home page has the highest PageRank because it has incomi We can also see that it's not only the number of incoming links that is important, but also the importance of the pages behind those links. // end::stream-sample-graph-explanation[] +=== Personalized PageRank + +Personalized PageRank is a variation of PageRank which is biased towards a set of `sourceNodes`. +This variant of PageRank is often used as part of https://www.r-bloggers.com/from-random-walks-to-personalized-pagerank/[recommender systems^]. + +The following examples show how to run PageRank centered around 'Site A'. + +.The following will run the algorithm and stream results: +[source,cypher] +---- +include::scripts/pagerank.cypher[tag=ppr-stream-sample-graph] +---- + +.The following will run the algorithm and write back results: +[source,cypher] +---- +include::scripts/pagerank.cypher[tag=ppr-write-sample-graph] +---- + +// tag::ppr-stream-graph-result[] +.Results +[opts="header",cols="1,1"] +|=== +| Name | PageRank +| Home | 0.399 +| Site A | 0.169 +| About | 0.112 +| Product | 0.112 +| Links | 0.112 +| Site B | 0.019 +| Site C | 0.019 +| Site D | 0.019 +|=== +// end::ppr-stream-graph-result[] + + [[algorithms-pagerank-example]] == Example usage diff --git a/doc/asciidoc/scripts/pagerank.cypher b/doc/asciidoc/scripts/pagerank.cypher index 813167f87..ee6b8326e 100644 --- a/doc/asciidoc/scripts/pagerank.cypher +++ b/doc/asciidoc/scripts/pagerank.cypher @@ -46,6 +46,28 @@ YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, // end::write-sample-graph[] +// tag::ppr-stream-sample-graph[] +MATCH (siteA:Page {name: "Site A"}) + +CALL algo.pageRank.stream('Page', 'LINKS', {iterations:20, dampingFactor:0.85, sourceNodes: [siteA]}) +YIELD nodeId, score + +MATCH (node) WHERE id(node) = nodeId + +RETURN node.name AS page,score +ORDER BY score DESC + +// end::ppr-stream-sample-graph[] + +// tag::ppr-write-sample-graph[] + +MATCH (siteA:Page {name: "Site A"}) +CALL algo.pageRank('Page', 'LINKS', +{iterations:20, dampingFactor:0.85, sourceNodes: [siteA], write: true, writeProperty:"ppr"}) +YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty +RETURN * +// end::ppr-write-sample-graph[] + // tag::cypher-loading[] CALL algo.pageRank( diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/PageRankProcIntegrationTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/PageRankProcIntegrationTest.java index cb4000cd1..18f0ce50a 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/algo/PageRankProcIntegrationTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/PageRankProcIntegrationTest.java @@ -207,7 +207,14 @@ public void testPageRankParallelExecution() throws Exception { private static void runQuery( String query, Consumer check) { - try (Result result = db.execute(query)) { + runQuery(query, new HashMap<>(), check); + } + + private static void runQuery( + String query, + Map params, + Consumer check) { + try (Result result = db.execute(query, params)) { result.accept(row -> { check.accept(row); return true; diff --git a/tests/src/test/java/org/neo4j/graphalgo/algo/PersonalizedPageRankProcIntegrationTest.java b/tests/src/test/java/org/neo4j/graphalgo/algo/PersonalizedPageRankProcIntegrationTest.java new file mode 100644 index 000000000..82447144a --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/algo/PersonalizedPageRankProcIntegrationTest.java @@ -0,0 +1,254 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.algo; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.neo4j.graphalgo.PageRankProc; +import org.neo4j.graphalgo.TestDatabaseCreator; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.Node; +import org.neo4j.graphdb.Result; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.exceptions.KernelException; +import org.neo4j.kernel.impl.proc.Procedures; +import org.neo4j.kernel.internal.GraphDatabaseAPI; + +import java.util.*; +import java.util.function.Consumer; + +import static org.junit.Assert.*; + +@RunWith(Parameterized.class) +public class PersonalizedPageRankProcIntegrationTest { + + private static GraphDatabaseAPI db; + private static Map expected = new HashMap<>(); + + private static final String DB_CYPHER = "" + + "CREATE (a:Label1 {name:\"a\"})\n" + + "CREATE (b:Label1 {name:\"b\"})\n" + + "CREATE (c:Label1 {name:\"c\"})\n" + + "CREATE (d:Label1 {name:\"d\"})\n" + + "CREATE (e:Label1 {name:\"e\"})\n" + + "CREATE (f:Label1 {name:\"f\"})\n" + + "CREATE (g:Label1 {name:\"g\"})\n" + + "CREATE (h:Label1 {name:\"h\"})\n" + + "CREATE (i:Label1 {name:\"i\"})\n" + + "CREATE (j:Label1 {name:\"j\"})\n" + + "CREATE (k:Label2 {name:\"k\"})\n" + + "CREATE (l:Label2 {name:\"l\"})\n" + + "CREATE (m:Label2 {name:\"m\"})\n" + + "CREATE (n:Label2 {name:\"n\"})\n" + + "CREATE (o:Label2 {name:\"o\"})\n" + + "CREATE (p:Label2 {name:\"p\"})\n" + + "CREATE (q:Label2 {name:\"q\"})\n" + + "CREATE (r:Label2 {name:\"r\"})\n" + + "CREATE (s:Label2 {name:\"s\"})\n" + + "CREATE (t:Label2 {name:\"t\"})\n" + + "CREATE\n" + + " (b)-[:TYPE1{foo:1.0}]->(c),\n" + + " (c)-[:TYPE1{foo:1.2}]->(b),\n" + + " (d)-[:TYPE1{foo:1.3}]->(a),\n" + + " (d)-[:TYPE1{foo:1.7}]->(b),\n" + + " (e)-[:TYPE1{foo:1.1}]->(b),\n" + + " (e)-[:TYPE1{foo:2.2}]->(d),\n" + + " (e)-[:TYPE1{foo:1.5}]->(f),\n" + + " (f)-[:TYPE1{foo:3.5}]->(b),\n" + + " (f)-[:TYPE1{foo:2.9}]->(e),\n" + + " (g)-[:TYPE2{foo:3.2}]->(b),\n" + + " (g)-[:TYPE2{foo:5.3}]->(e),\n" + + " (h)-[:TYPE2{foo:9.5}]->(b),\n" + + " (h)-[:TYPE2{foo:0.3}]->(e),\n" + + " (i)-[:TYPE2{foo:5.4}]->(b),\n" + + " (i)-[:TYPE2{foo:3.2}]->(e),\n" + + " (j)-[:TYPE2{foo:9.5}]->(e),\n" + + " (k)-[:TYPE2{foo:4.2}]->(e)\n"; + + @AfterClass + public static void tearDown() throws Exception { + if (db != null) db.shutdown(); + } + + @BeforeClass + public static void setup() throws KernelException { + db = TestDatabaseCreator.createTestDatabase(); + try (Transaction tx = db.beginTx()) { + db.execute(DB_CYPHER).close(); + tx.success(); + } + + db.getDependencyResolver() + .resolveDependency(Procedures.class) + .registerProcedure(PageRankProc.class); + + + try (Transaction tx = db.beginTx()) { + final Label label = Label.label("Label1"); + expected.put(db.findNode(label, "name", "a").getId(), 0.243); + expected.put(db.findNode(label, "name", "b").getId(), 1.844); + expected.put(db.findNode(label, "name", "c").getId(), 1.777); + expected.put(db.findNode(label, "name", "d").getId(), 0.218); + expected.put(db.findNode(label, "name", "e").getId(), 0.243); + expected.put(db.findNode(label, "name", "f").getId(), 0.218); + expected.put(db.findNode(label, "name", "g").getId(), 0.150); + expected.put(db.findNode(label, "name", "h").getId(), 0.150); + expected.put(db.findNode(label, "name", "i").getId(), 0.150); + expected.put(db.findNode(label, "name", "j").getId(), 0.150); + tx.success(); + } + } + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList( + new Object[]{"Heavy"}, + new Object[]{"Light"}, + new Object[]{"Kernel"}, + new Object[]{"Huge"} + ); + } + + @Parameterized.Parameter + public String graphImpl; + + @Test + public void testPageRankStream() throws Exception { + final Map actual = new HashMap<>(); + runQuery( + "CALL algo.pageRank.stream('Label1', 'TYPE1', {graph:'"+graphImpl+"'}) YIELD node, score", + row -> actual.put( + row.getNode("node").getId(), + (Double) row.get("score"))); + + assertMapEquals(expected, actual); + } + + @Test + public void testPageRankWriteBack() throws Exception { + runQuery( + "CALL algo.pageRank('Label1', 'TYPE1', {graph:'"+graphImpl+"'}) YIELD writeMillis, write, writeProperty", + row -> { + assertTrue(row.getBoolean("write")); + assertEquals("pagerank", row.getString("writeProperty")); + assertTrue( + "write time not set", + row.getNumber("writeMillis").intValue() >= 0); + }); + + assertResult("pagerank"); + } + + @Test + public void testPageRankWriteBackUnderDifferentProperty() throws Exception { + runQuery( + "CALL algo.pageRank('Label1', 'TYPE1', {writeProperty:'foobar', graph:'"+graphImpl+"'}) YIELD writeMillis, write, writeProperty", + row -> { + assertTrue(row.getBoolean("write")); + assertEquals("foobar", row.getString("writeProperty")); + assertTrue( + "write time not set", + row.getNumber("writeMillis").intValue() >= 0); + }); + + assertResult("foobar"); + } + + @Test + public void testPageRankParallelWriteBack() throws Exception { + runQuery( + "CALL algo.pageRank('Label1', 'TYPE1', {batchSize:3, write:true, graph:'"+graphImpl+"'}) YIELD writeMillis, write, writeProperty", + row -> assertTrue( + "write time not set", + row.getNumber("writeMillis").intValue() >= 0)); + + assertResult("pagerank"); + } + + @Test + public void testPageRankParallelExecution() throws Exception { + final Map actual = new HashMap<>(); + runQuery( + "CALL algo.pageRank.stream('Label1', 'TYPE1', {batchSize:2, graph:'"+graphImpl+"'}) YIELD nodeId, node, score", + row -> { + final long nodeId = row.getNumber("nodeId").longValue(); + final Node node = row.getNode("node"); + assertEquals(node.getId(), nodeId); + actual.put(nodeId, (Double) row.get("score")); + }); + assertMapEquals(expected, actual); + } + + private static void runQuery( + String query, + Consumer check) { + runQuery(query, new HashMap<>(), check); + } + + private static void runQuery( + String query, + Map params, + Consumer check) { + try (Result result = db.execute(query, params)) { + result.accept(row -> { + check.accept(row); + return true; + }); + } + } + + private void assertResult(final String scoreProperty) { + try (Transaction tx = db.beginTx()) { + for (Map.Entry entry : expected.entrySet()) { + double score = ((Number) db + .getNodeById(entry.getKey()) + .getProperty(scoreProperty)).doubleValue(); + assertEquals( + "score for " + entry.getKey(), + entry.getValue(), + score, + 0.1); + } + tx.success(); + } + } + + private static void assertMapEquals( + Map expected, + Map actual) { + assertEquals("number of elements", expected.size(), actual.size()); + HashSet expectedKeys = new HashSet<>(expected.keySet()); + for (Map.Entry entry : actual.entrySet()) { + assertTrue( + "unknown key " + entry.getKey(), + expectedKeys.remove(entry.getKey())); + assertEquals( + "value for " + entry.getKey(), + expected.get(entry.getKey()), + entry.getValue(), + 0.1); + } + for (Long expectedKey : expectedKeys) { + fail("missing key " + expectedKey); + } + } +} diff --git a/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankTest.java b/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankTest.java index a9c0c96a4..9104b9d75 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankTest.java @@ -41,6 +41,8 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; @@ -157,7 +159,7 @@ public void test() throws Exception { } final PageRankResult rankResult = PageRankAlgorithm - .of(graph, 0.85) + .of(graph, 0.85, LongStream.empty()) .compute(40) .result(); diff --git a/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankWikiTest.java b/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankWikiTest.java index 236c97acc..f4a0f6b4b 100644 --- a/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankWikiTest.java +++ b/tests/src/test/java/org/neo4j/graphalgo/impl/PageRankWikiTest.java @@ -39,6 +39,8 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; @@ -146,7 +148,7 @@ public void test() throws Exception { .load(graphImpl); final PageRankResult rankResult = PageRankAlgorithm - .of(graph, 0.85) + .of(graph, 0.85, LongStream.empty()) .compute(40) .result(); diff --git a/tests/src/test/java/org/neo4j/graphalgo/impl/PersonalizedPageRankTest.java b/tests/src/test/java/org/neo4j/graphalgo/impl/PersonalizedPageRankTest.java new file mode 100644 index 000000000..999a98205 --- /dev/null +++ b/tests/src/test/java/org/neo4j/graphalgo/impl/PersonalizedPageRankTest.java @@ -0,0 +1,168 @@ +/** + * Copyright (c) 2017 "Neo4j, Inc." + * + * This file is part of Neo4j Graph Algorithms . + * + * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.graphalgo.impl; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.neo4j.graphalgo.TestDatabaseCreator; +import org.neo4j.graphalgo.api.Graph; +import org.neo4j.graphalgo.api.GraphFactory; +import org.neo4j.graphalgo.core.GraphLoader; +import org.neo4j.graphalgo.core.heavyweight.HeavyCypherGraphFactory; +import org.neo4j.graphalgo.core.heavyweight.HeavyGraphFactory; +import org.neo4j.graphalgo.core.huge.HugeGraphFactory; +import org.neo4j.graphalgo.core.neo4jview.GraphViewFactory; +import org.neo4j.graphalgo.core.utils.Pools; +import org.neo4j.graphdb.Direction; +import org.neo4j.graphdb.Label; +import org.neo4j.graphdb.Node; +import org.neo4j.graphdb.Transaction; +import org.neo4j.kernel.internal.GraphDatabaseAPI; + +import java.util.*; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +public final class PersonalizedPageRankTest { + + private Class graphImpl; + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList( + new Object[]{HeavyGraphFactory.class, "HeavyGraphFactory"}, + new Object[]{HeavyCypherGraphFactory.class, "HeavyCypherGraphFactory"}, + new Object[]{HugeGraphFactory.class, "HugeGraphFactory"} + ); + } + private static final String DB_CYPHER = "" + + "CREATE (iphone:Product {name:\"iPhone5\"})\n" + + "CREATE (kindle:Product {name:\"Kindle Fire\"})\n" + + "CREATE (fitbit:Product {name:\"Fitbit Flex Wireless\"})\n" + + "CREATE (potter:Product {name:\"Harry Potter\"})\n" + + "CREATE (hobbit:Product {name:\"Hobbit\"})\n" + + + "CREATE (todd:Person {name:\"Todd\"})\n" + + "CREATE (mary:Person {name:\"Mary\"})\n" + + "CREATE (jill:Person {name:\"Jill\"})\n" + + "CREATE (john:Person {name:\"John\"})\n" + + + "CREATE\n" + + " (john)-[:PURCHASED]->(iphone),\n" + + " (john)-[:PURCHASED]->(kindle),\n" + + " (mary)-[:PURCHASED]->(iphone),\n" + + " (mary)-[:PURCHASED]->(kindle),\n" + + " (mary)-[:PURCHASED]->(fitbit),\n" + + " (jill)-[:PURCHASED]->(iphone),\n" + + " (jill)-[:PURCHASED]->(kindle),\n" + + " (jill)-[:PURCHASED]->(fitbit),\n" + + " (todd)-[:PURCHASED]->(fitbit),\n" + + " (todd)-[:PURCHASED]->(potter),\n" + + " (todd)-[:PURCHASED]->(hobbit)"; + + private static GraphDatabaseAPI db; + + @BeforeClass + public static void setupGraph() { + db = TestDatabaseCreator.createTestDatabase(); + try (Transaction tx = db.beginTx()) { + db.execute(DB_CYPHER).close(); + tx.success(); + } + } + + @AfterClass + public static void shutdownGraph() throws Exception { + if (db!=null) db.shutdown(); + } + + public PersonalizedPageRankTest( + Class graphImpl, + String nameIgnoredOnlyForTestName) { + this.graphImpl = graphImpl; + } + + @Test + public void test() throws Exception { + Label personLabel = Label.label("Person"); + Label productLabel = Label.label("Product"); + final Map expected = new HashMap<>(); + + try (Transaction tx = db.beginTx()) { + + expected.put(db.findNode(personLabel, "name", "John").getId(), 0.24851499999999993); + expected.put(db.findNode(personLabel, "name", "Jill").getId(), 0.12135449999999998); + expected.put(db.findNode(personLabel, "name", "Mary").getId(), 0.12135449999999998); + expected.put(db.findNode(personLabel, "name", "Todd").getId(), 0.043511499999999995); + + expected.put(db.findNode(productLabel, "name", "Kindle Fire").getId(), 0.17415649999999996); + expected.put(db.findNode(productLabel, "name", "iPhone5").getId(), 0.17415649999999996); + expected.put(db.findNode(productLabel, "name", "Fitbit Flex Wireless").getId(), 0.08085200000000001); + expected.put(db.findNode(productLabel, "name", "Harry Potter").getId(), 0.01224); + expected.put(db.findNode(productLabel, "name", "Hobbit").getId(), 0.01224); + tx.close(); + } + + final Graph graph; + if (graphImpl.isAssignableFrom(HeavyCypherGraphFactory.class)) { + graph = new GraphLoader(db) + .withLabel("MATCH (n) RETURN id(n) as id") + .withRelationshipType("MATCH (n)-[:PURCHASED]-(m) RETURN id(n) as source,id(m) as target") + .load(graphImpl); + + } else { + graph = new GraphLoader(db) + .withDirection(Direction.BOTH) + .withRelationshipType("PURCHASED") + .asUndirected(true) + .load(graphImpl); + } + + LongStream sourceNodeIds; + try(Transaction tx = db.beginTx()) { + Node node = db.findNode(personLabel, "name", "John"); + sourceNodeIds = LongStream.of(node.getId()); + } + + final PageRankResult rankResult = PageRankAlgorithm + .of(graph,0.85, sourceNodeIds, Pools.DEFAULT, 2, 1) + .compute(40) + .result(); + + IntStream.range(0, expected.size()).forEach(i -> { + final long nodeId = graph.toOriginalNodeId(i); + assertEquals( + "Node#" + nodeId, + expected.get(nodeId), + rankResult.score(i), + 1e-2 + ); + }); + + } + + +}