From 11f07dd897562f7a4bf8d6e4845128d7f2cdd2ff Mon Sep 17 00:00:00 2001 From: Paolo Castagna Date: Tue, 29 May 2012 12:00:17 +0100 Subject: [PATCH] Check for convergence and stop when reached required tolerance error. --- .../giraph/pagerank/PageRankVertex.java | 95 +++++++++---------- .../pagerank/PageRankVertexWorkerContext.java | 25 ++--- .../pagerank/{memory => }/TestPageRank.java | 24 +++-- 3 files changed, 73 insertions(+), 71 deletions(-) rename src/test/java/org/apache/jena/grande/giraph/pagerank/{memory => }/TestPageRank.java (85%) diff --git a/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertex.java b/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertex.java index 0a7de4e..868805b 100644 --- a/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertex.java +++ b/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertex.java @@ -32,78 +32,73 @@ public class PageRankVertex extends EdgeListVertex { private static final Logger log = LoggerFactory.getLogger(PageRankVertex.class); - public static final int NUM_ITERATIONS = 30; + + public static final int DEFAULT_NUM_ITERATIONS = 30; + public static final float DEFAULT_TOLERANCE = 10e-9f; + + private int numIterations; + private double tolerance; + private Aggregator danglingCurrentAggegator; + private Aggregator pagerankSumAggegator; + private Aggregator errorPreviousAggegator; @Override public void compute(Iterator msgIterator) { - log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); + log.debug("{}#{} compute() vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); - @SuppressWarnings("unchecked") - Aggregator danglingAggegator = (Aggregator)getAggregator("dangling"); - @SuppressWarnings("unchecked") - Aggregator pagerankAggegator = (Aggregator)getAggregator("pagerank"); - @SuppressWarnings("unchecked") - Aggregator errorCurrentAggegator = (Aggregator)getAggregator("error-current"); - @SuppressWarnings("unchecked") - Aggregator errorPreviousAggegator = (Aggregator)getAggregator("error-previous"); - log.debug("{}#{} - compute(...) errorCurrentAggregator={}", new Object[]{getVertexId(), getSuperstep(), errorCurrentAggegator.getAggregatedValue() }); - log.debug("{}#{} - compute(...) errorPreviousAggregator={}", new Object[]{getVertexId(), getSuperstep(), errorPreviousAggegator.getAggregatedValue() }); + danglingCurrentAggegator = getAggregator("dangling-current"); + @SuppressWarnings("unchecked") Aggregator danglingPreviousAggegator = (Aggregator)getAggregator("dangling-previous"); + @SuppressWarnings("unchecked") Aggregator errorCurrentAggegator = (Aggregator)getAggregator("error-current"); + errorPreviousAggegator = getAggregator("error-previous"); + pagerankSumAggegator = getAggregator("pagerank-sum"); + @SuppressWarnings("unchecked") Aggregator verticesCountAggregator = (Aggregator)getAggregator("vertices-count"); - @SuppressWarnings("unchecked") - Aggregator countVerticesAggegator = (Aggregator)getAggregator("count"); - long numVertices = countVerticesAggegator.getAggregatedValue().get(); - - double danglingNodesContribution = danglingAggegator.getAggregatedValue().get(); + long numVertices = verticesCountAggregator.getAggregatedValue().get(); + double danglingNodesContribution = danglingPreviousAggegator.getAggregatedValue().get(); + numIterations = getConf().getInt("giraph.pagerank.iterations", DEFAULT_NUM_ITERATIONS); + tolerance = getConf().getFloat("giraph.pagerank.tolerance", DEFAULT_TOLERANCE); if ( getSuperstep() == 0 ) { - log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "sending fake messages, just to count vertices including dangling ones"}); + log.debug("{}#{} compute(): sending fake messages to count vertices, including 'implicit' dangling ones", getVertexId(), getSuperstep()); sendMsgToAllEdges ( new DoubleWritable() ); } else if ( getSuperstep() == 1 ) { - log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "counting vertices including dangling ones"}); - countVerticesAggegator.aggregate(new LongWritable(1L)); + log.debug("{}#{} compute(): counting vertices including 'implicit' dangling ones", getVertexId(), getSuperstep()); + verticesCountAggregator.aggregate ( new LongWritable(1L) ); } else if ( getSuperstep() == 2 ) { - log.debug("{}#{} - compute(...): numVertices={}", new Object[]{getVertexId(), getSuperstep(), numVertices}); - log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "initializing pagerank scores to 1/N"}); + log.debug("{}#{} compute(): initializing pagerank scores to 1/N, N={}", new Object[]{getVertexId(), getSuperstep(), numVertices}); DoubleWritable vertexValue = new DoubleWritable ( 1.0 / numVertices ); setVertexValue(vertexValue); - log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); - send( danglingAggegator, pagerankAggegator ); + log.debug("{}#{} compute() vertexValue <-- {}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); + sendMessages(); } else if ( getSuperstep() > 2 ) { - if ( getSuperstep() % 2 == 1 ) { - log.debug("{}#{} - compute(...): numVertices={}", new Object[]{getVertexId(), getSuperstep(), numVertices}); - double sum = 0; - while (msgIterator.hasNext()) { - double msgValue = msgIterator.next().get(); - log.debug("{}#{} - compute(...) <-- {}", new Object[]{getVertexId(), getSuperstep(), msgValue}); - sum += msgValue; - } - log.debug("{}#{} - compute(...) danglingNodesContribution={}", new Object[]{getVertexId(), getSuperstep(), danglingNodesContribution }); - DoubleWritable vertexValue = new DoubleWritable( ( 0.15f / numVertices ) + 0.85f * ( sum + danglingNodesContribution / numVertices ) ); - errorCurrentAggegator.aggregate( new DoubleWritable(Math.abs(vertexValue.get() - getVertexValue().get())) ); - setVertexValue(vertexValue); - log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); + double sum = 0; + while (msgIterator.hasNext()) { + double msgValue = msgIterator.next().get(); + log.debug("{}#{} compute() <-- {}", new Object[]{getVertexId(), getSuperstep(), msgValue}); + sum += msgValue; } - send( danglingAggegator, pagerankAggegator ); + DoubleWritable vertexValue = new DoubleWritable( ( 0.15f / numVertices ) + 0.85f * ( sum + danglingNodesContribution / numVertices ) ); + errorCurrentAggegator.aggregate( new DoubleWritable(Math.abs(vertexValue.get() - getVertexValue().get())) ); + setVertexValue(vertexValue); + log.debug("{}#{} compute() vertexValue <-- {}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()}); + sendMessages(); } } - - private void send( Aggregator danglingAggegator, Aggregator pagerankAggegator ) { - if ( getSuperstep() < NUM_ITERATIONS ) { + + @SuppressWarnings("unchecked") + private void sendMessages() { + double error = ((Aggregator)errorPreviousAggegator).getAggregatedValue().get(); + if ( ( getSuperstep() - 3 < numIterations ) && ( error > tolerance ) ) { long edges = getNumOutEdges(); if ( edges > 0 ) { - log.debug("{}#{} - send(...) numOutEdges={} propagating pagerank values...", new Object[]{getVertexId(), getSuperstep(), edges}); sendMsgToAllEdges ( new DoubleWritable(getVertexValue().get() / edges) ); - } - if ( ( edges == 0 ) && ( getSuperstep() % 2 == 0) ) { - log.debug("{}#{} - send(...) numOutEdges={} updating dangling node contribution...", new Object[]{getVertexId(), getSuperstep(), edges}); - log.debug("{}#{} - send(...) danglingAggregator={}", new Object[]{getVertexId(), getSuperstep(), danglingAggegator.getAggregatedValue().get()}); - danglingAggegator.aggregate( getVertexValue() ); - log.debug("{}#{} - send(...) danglingAggregator={}", new Object[]{getVertexId(), getSuperstep(), danglingAggegator.getAggregatedValue().get()}); + } else { + ((Aggregator)danglingCurrentAggegator).aggregate( getVertexValue() ); } } else { - pagerankAggegator.aggregate ( getVertexValue() ); + ((Aggregator)pagerankSumAggegator).aggregate ( getVertexValue() ); voteToHalt(); - log.debug("{}#{} - compute(...) --> halt", new Object[]{getVertexId(), getSuperstep()}); + log.debug("{}#{} compute() --> halt", getVertexId(), getSuperstep()); } } diff --git a/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertexWorkerContext.java b/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertexWorkerContext.java index 4ea4f08..bd8dab6 100644 --- a/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertexWorkerContext.java +++ b/src/main/java/org/apache/jena/grande/giraph/pagerank/PageRankVertexWorkerContext.java @@ -30,39 +30,42 @@ public class PageRankVertexWorkerContext extends WorkerContext { private static final Logger log = LoggerFactory.getLogger(PageRankVertexWorkerContext.class); + @SuppressWarnings("unchecked") @Override public void preApplication() throws InstantiationException, IllegalAccessException { log.debug("preApplication()"); - registerAggregator("dangling", SumAggregator.class); - registerAggregator("pagerank", SumAggregator.class); + registerAggregator("dangling-current", SumAggregator.class); + registerAggregator("dangling-previous", SumAggregator.class); registerAggregator("error-current", SumAggregator.class); registerAggregator("error-previous", SumAggregator.class); - registerAggregator("count", LongSumAggregator.class); + registerAggregator("pagerank-sum", SumAggregator.class); + registerAggregator("vertices-count", LongSumAggregator.class); + + ((Aggregator)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable( Double.MAX_VALUE ) ); + ((Aggregator)getAggregator("error-current")).setAggregatedValue( new DoubleWritable( Double.MAX_VALUE ) ); } @Override public void postApplication() { log.debug("postApplication()"); - log.debug("postApplication() pagerank={}", getAggregator("pagerank").getAggregatedValue()); + log.debug("postApplication() pagerank-sum={}", getAggregator("pagerank-sum").getAggregatedValue()); } @SuppressWarnings("unchecked") @Override public void preSuperstep() { log.debug("preSuperstep()"); - ((Aggregator)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable(((Aggregator)getAggregator("error-current")).getAggregatedValue().get()) ); - ((Aggregator)getAggregator("error-current")).setAggregatedValue(new DoubleWritable(0L)); - if ( getSuperstep() % 2 == 0 ) { - ((Aggregator)getAggregator("dangling")).setAggregatedValue(new DoubleWritable(0L)); - log.debug("preSuperstep() danglingAggregators={}", getAggregator("dangling").getAggregatedValue()); + if ( getSuperstep() > 2 ) { + ((Aggregator)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable(((Aggregator)getAggregator("error-current")).getAggregatedValue().get()) ); + ((Aggregator)getAggregator("error-current")).setAggregatedValue( new DoubleWritable(0L) ); } + ((Aggregator)getAggregator("dangling-previous")).setAggregatedValue( new DoubleWritable(((Aggregator)getAggregator("dangling-current")).getAggregatedValue().get()) ); + ((Aggregator)getAggregator("dangling-current")).setAggregatedValue( new DoubleWritable(0L) ); } @Override public void postSuperstep() { log.debug("postSuperstep()"); - log.debug("postSuperstep() error-previous={}", getAggregator("error-previous").getAggregatedValue()); - log.debug("postSuperstep() error-current={}", getAggregator("error-current").getAggregatedValue()); } } \ No newline at end of file diff --git a/src/test/java/org/apache/jena/grande/giraph/pagerank/memory/TestPageRank.java b/src/test/java/org/apache/jena/grande/giraph/pagerank/TestPageRank.java similarity index 85% rename from src/test/java/org/apache/jena/grande/giraph/pagerank/memory/TestPageRank.java rename to src/test/java/org/apache/jena/grande/giraph/pagerank/TestPageRank.java index df46b38..1eade09 100644 --- a/src/test/java/org/apache/jena/grande/giraph/pagerank/memory/TestPageRank.java +++ b/src/test/java/org/apache/jena/grande/giraph/pagerank/TestPageRank.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.jena.grande.giraph.pagerank.memory; +package org.apache.jena.grande.giraph.pagerank; import java.io.BufferedReader; import java.io.File; @@ -38,27 +38,32 @@ import org.apache.commons.io.output.ByteArrayOutputStream; import org.apache.jena.grande.giraph.pagerank.RunPageRankVertexLocally; import org.apache.jena.grande.giraph.pagerank.RunSimplePageRankVertexLocally; +import org.apache.jena.grande.giraph.pagerank.memory.JungPageRank; +import org.apache.jena.grande.giraph.pagerank.memory.PlainPageRank; public class TestPageRank extends TestCase { private static final String filename = "src/test/resources/pagerank.txt"; + private static final double dumping_factor = ( 0.85d ); + private static final double alpha = ( 1.0d - dumping_factor ); + private static final int iterations = 30; + private static final double tolerance = 0.0000001d; public void testPlainPageRank() throws IOException { File input = new File (filename); BufferedReader in = new BufferedReader(new FileReader (input)) ; - PlainPageRank pagerank1 = new PlainPageRank (in, 0.85d, 30) ; + PlainPageRank pagerank1 = new PlainPageRank ( in, dumping_factor, iterations ) ; Map result1 = pagerank1.compute() ; - JungPageRank pagerank2 = new JungPageRank(input, 30, 0.0000001d, 0.15d); + JungPageRank pagerank2 = new JungPageRank ( input, iterations, tolerance, alpha ); Map result2 = pagerank2.compute(); check ( result1, result2 ); } public void testPageRankVertex() throws Exception { - File input = new File (filename); - JungPageRank pagerank1 = new JungPageRank(input, 30, 0.0000001d, 0.15d); + JungPageRank pagerank1 = new JungPageRank ( new File (filename), iterations, tolerance, alpha ); Map result1 = pagerank1.compute(); Map result2 = RunPageRankVertexLocally.run(filename); @@ -67,8 +72,7 @@ public void testPageRankVertex() throws Exception { } public void testSimplePageRankVertex() throws Exception { - File input = new File (filename); - JungPageRank pagerank1 = new JungPageRank(input, 30, 0.0000001d, 0.15d); + JungPageRank pagerank1 = new JungPageRank ( new File (filename), iterations, tolerance, alpha ); Map result1 = pagerank1.compute(); Map result2 = RunSimplePageRankVertexLocally.run(filename); @@ -87,13 +91,13 @@ private void check ( Map result1, Map result2 ) } // pagerank values should be a probability distribution, right? Sum should be 1 then. - assertEquals ( dump(result1, result2), 1.0, sum(result1), 0.0001 ); - assertEquals ( dump(result1, result2), 1.0, sum(result2), 0.0001 ); + assertEquals ( dump(result1, result2), 1.0, sum(result1), 0.00001 ); + assertEquals ( dump(result1, result2), 1.0, sum(result2), 0.00001 ); // check actual pagerank values for ( String key : result1.keySet() ) { assertTrue( result2.containsKey(key) ); - assertEquals ( dump(result1, result2), result1.get(key), result2.get(key), 0.000001d ); + assertEquals ( dump(result1, result2), result1.get(key), result2.get(key), 0.00001d ); } }