Skip to content

Commit

Permalink
Check for convergence and stop when reached required tolerance error.
Browse files Browse the repository at this point in the history
  • Loading branch information
castagna committed May 29, 2012
1 parent 9e9198b commit 11f07dd
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 71 deletions.
Expand Up @@ -32,78 +32,73 @@
public class PageRankVertex extends EdgeListVertex<Text, DoubleWritable, NullWritable, DoubleWritable> {

private static final Logger log = LoggerFactory.getLogger(PageRankVertex.class);
public static final int NUM_ITERATIONS = 30;

public static final int DEFAULT_NUM_ITERATIONS = 30;
public static final float DEFAULT_TOLERANCE = 10e-9f;

private int numIterations;
private double tolerance;
private Aggregator<?> danglingCurrentAggegator;
private Aggregator<?> pagerankSumAggegator;
private Aggregator<?> errorPreviousAggegator;

@Override
public void compute(Iterator<DoubleWritable> msgIterator) {
log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});
log.debug("{}#{} compute() vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});

@SuppressWarnings("unchecked")
Aggregator<DoubleWritable> danglingAggegator = (Aggregator<DoubleWritable>)getAggregator("dangling");
@SuppressWarnings("unchecked")
Aggregator<DoubleWritable> pagerankAggegator = (Aggregator<DoubleWritable>)getAggregator("pagerank");
@SuppressWarnings("unchecked")
Aggregator<DoubleWritable> errorCurrentAggegator = (Aggregator<DoubleWritable>)getAggregator("error-current");
@SuppressWarnings("unchecked")
Aggregator<DoubleWritable> errorPreviousAggegator = (Aggregator<DoubleWritable>)getAggregator("error-previous");
log.debug("{}#{} - compute(...) errorCurrentAggregator={}", new Object[]{getVertexId(), getSuperstep(), errorCurrentAggegator.getAggregatedValue() });
log.debug("{}#{} - compute(...) errorPreviousAggregator={}", new Object[]{getVertexId(), getSuperstep(), errorPreviousAggegator.getAggregatedValue() });
danglingCurrentAggegator = getAggregator("dangling-current");
@SuppressWarnings("unchecked") Aggregator<DoubleWritable> danglingPreviousAggegator = (Aggregator<DoubleWritable>)getAggregator("dangling-previous");
@SuppressWarnings("unchecked") Aggregator<DoubleWritable> errorCurrentAggegator = (Aggregator<DoubleWritable>)getAggregator("error-current");
errorPreviousAggegator = getAggregator("error-previous");
pagerankSumAggegator = getAggregator("pagerank-sum");
@SuppressWarnings("unchecked") Aggregator<LongWritable> verticesCountAggregator = (Aggregator<LongWritable>)getAggregator("vertices-count");

@SuppressWarnings("unchecked")
Aggregator<LongWritable> countVerticesAggegator = (Aggregator<LongWritable>)getAggregator("count");
long numVertices = countVerticesAggegator.getAggregatedValue().get();

double danglingNodesContribution = danglingAggegator.getAggregatedValue().get();
long numVertices = verticesCountAggregator.getAggregatedValue().get();
double danglingNodesContribution = danglingPreviousAggegator.getAggregatedValue().get();
numIterations = getConf().getInt("giraph.pagerank.iterations", DEFAULT_NUM_ITERATIONS);
tolerance = getConf().getFloat("giraph.pagerank.tolerance", DEFAULT_TOLERANCE);

if ( getSuperstep() == 0 ) {
log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "sending fake messages, just to count vertices including dangling ones"});
log.debug("{}#{} compute(): sending fake messages to count vertices, including 'implicit' dangling ones", getVertexId(), getSuperstep());
sendMsgToAllEdges ( new DoubleWritable() );
} else if ( getSuperstep() == 1 ) {
log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "counting vertices including dangling ones"});
countVerticesAggegator.aggregate(new LongWritable(1L));
log.debug("{}#{} compute(): counting vertices including 'implicit' dangling ones", getVertexId(), getSuperstep());
verticesCountAggregator.aggregate ( new LongWritable(1L) );
} else if ( getSuperstep() == 2 ) {
log.debug("{}#{} - compute(...): numVertices={}", new Object[]{getVertexId(), getSuperstep(), numVertices});
log.debug("{}#{} - compute(...): {}", new Object[]{getVertexId(), getSuperstep(), "initializing pagerank scores to 1/N"});
log.debug("{}#{} compute(): initializing pagerank scores to 1/N, N={}", new Object[]{getVertexId(), getSuperstep(), numVertices});
DoubleWritable vertexValue = new DoubleWritable ( 1.0 / numVertices );
setVertexValue(vertexValue);
log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});
send( danglingAggegator, pagerankAggegator );
log.debug("{}#{} compute() vertexValue <-- {}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});
sendMessages();
} else if ( getSuperstep() > 2 ) {
if ( getSuperstep() % 2 == 1 ) {
log.debug("{}#{} - compute(...): numVertices={}", new Object[]{getVertexId(), getSuperstep(), numVertices});
double sum = 0;
while (msgIterator.hasNext()) {
double msgValue = msgIterator.next().get();
log.debug("{}#{} - compute(...) <-- {}", new Object[]{getVertexId(), getSuperstep(), msgValue});
sum += msgValue;
}
log.debug("{}#{} - compute(...) danglingNodesContribution={}", new Object[]{getVertexId(), getSuperstep(), danglingNodesContribution });
DoubleWritable vertexValue = new DoubleWritable( ( 0.15f / numVertices ) + 0.85f * ( sum + danglingNodesContribution / numVertices ) );
errorCurrentAggegator.aggregate( new DoubleWritable(Math.abs(vertexValue.get() - getVertexValue().get())) );
setVertexValue(vertexValue);
log.debug("{}#{} - compute(...) vertexValue={}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});
double sum = 0;
while (msgIterator.hasNext()) {
double msgValue = msgIterator.next().get();
log.debug("{}#{} compute() <-- {}", new Object[]{getVertexId(), getSuperstep(), msgValue});
sum += msgValue;
}
send( danglingAggegator, pagerankAggegator );
DoubleWritable vertexValue = new DoubleWritable( ( 0.15f / numVertices ) + 0.85f * ( sum + danglingNodesContribution / numVertices ) );
errorCurrentAggegator.aggregate( new DoubleWritable(Math.abs(vertexValue.get() - getVertexValue().get())) );
setVertexValue(vertexValue);
log.debug("{}#{} compute() vertexValue <-- {}", new Object[]{getVertexId(), getSuperstep(), getVertexValue()});
sendMessages();
}
}

private void send( Aggregator<DoubleWritable> danglingAggegator, Aggregator<DoubleWritable> pagerankAggegator ) {
if ( getSuperstep() < NUM_ITERATIONS ) {

@SuppressWarnings("unchecked")
private void sendMessages() {
double error = ((Aggregator<DoubleWritable>)errorPreviousAggegator).getAggregatedValue().get();
if ( ( getSuperstep() - 3 < numIterations ) && ( error > tolerance ) ) {
long edges = getNumOutEdges();
if ( edges > 0 ) {
log.debug("{}#{} - send(...) numOutEdges={} propagating pagerank values...", new Object[]{getVertexId(), getSuperstep(), edges});
sendMsgToAllEdges ( new DoubleWritable(getVertexValue().get() / edges) );
}
if ( ( edges == 0 ) && ( getSuperstep() % 2 == 0) ) {
log.debug("{}#{} - send(...) numOutEdges={} updating dangling node contribution...", new Object[]{getVertexId(), getSuperstep(), edges});
log.debug("{}#{} - send(...) danglingAggregator={}", new Object[]{getVertexId(), getSuperstep(), danglingAggegator.getAggregatedValue().get()});
danglingAggegator.aggregate( getVertexValue() );
log.debug("{}#{} - send(...) danglingAggregator={}", new Object[]{getVertexId(), getSuperstep(), danglingAggegator.getAggregatedValue().get()});
} else {
((Aggregator<DoubleWritable>)danglingCurrentAggegator).aggregate( getVertexValue() );
}
} else {
pagerankAggegator.aggregate ( getVertexValue() );
((Aggregator<DoubleWritable>)pagerankSumAggegator).aggregate ( getVertexValue() );
voteToHalt();
log.debug("{}#{} - compute(...) --> halt", new Object[]{getVertexId(), getSuperstep()});
log.debug("{}#{} compute() --> halt", getVertexId(), getSuperstep());
}
}

Expand Down
Expand Up @@ -30,39 +30,42 @@ public class PageRankVertexWorkerContext extends WorkerContext {

private static final Logger log = LoggerFactory.getLogger(PageRankVertexWorkerContext.class);

@SuppressWarnings("unchecked")
@Override
public void preApplication() throws InstantiationException, IllegalAccessException {
log.debug("preApplication()");
registerAggregator("dangling", SumAggregator.class);
registerAggregator("pagerank", SumAggregator.class);
registerAggregator("dangling-current", SumAggregator.class);
registerAggregator("dangling-previous", SumAggregator.class);
registerAggregator("error-current", SumAggregator.class);
registerAggregator("error-previous", SumAggregator.class);
registerAggregator("count", LongSumAggregator.class);
registerAggregator("pagerank-sum", SumAggregator.class);
registerAggregator("vertices-count", LongSumAggregator.class);

((Aggregator<DoubleWritable>)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable( Double.MAX_VALUE ) );
((Aggregator<DoubleWritable>)getAggregator("error-current")).setAggregatedValue( new DoubleWritable( Double.MAX_VALUE ) );
}

@Override
public void postApplication() {
log.debug("postApplication()");
log.debug("postApplication() pagerank={}", getAggregator("pagerank").getAggregatedValue());
log.debug("postApplication() pagerank-sum={}", getAggregator("pagerank-sum").getAggregatedValue());
}

@SuppressWarnings("unchecked")
@Override
public void preSuperstep() {
log.debug("preSuperstep()");
((Aggregator<DoubleWritable>)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable(((Aggregator<DoubleWritable>)getAggregator("error-current")).getAggregatedValue().get()) );
((Aggregator<DoubleWritable>)getAggregator("error-current")).setAggregatedValue(new DoubleWritable(0L));
if ( getSuperstep() % 2 == 0 ) {
((Aggregator<DoubleWritable>)getAggregator("dangling")).setAggregatedValue(new DoubleWritable(0L));
log.debug("preSuperstep() danglingAggregators={}", getAggregator("dangling").getAggregatedValue());
if ( getSuperstep() > 2 ) {
((Aggregator<DoubleWritable>)getAggregator("error-previous")).setAggregatedValue( new DoubleWritable(((Aggregator<DoubleWritable>)getAggregator("error-current")).getAggregatedValue().get()) );
((Aggregator<DoubleWritable>)getAggregator("error-current")).setAggregatedValue( new DoubleWritable(0L) );
}
((Aggregator<DoubleWritable>)getAggregator("dangling-previous")).setAggregatedValue( new DoubleWritable(((Aggregator<DoubleWritable>)getAggregator("dangling-current")).getAggregatedValue().get()) );
((Aggregator<DoubleWritable>)getAggregator("dangling-current")).setAggregatedValue( new DoubleWritable(0L) );
}

@Override
public void postSuperstep() {
log.debug("postSuperstep()");
log.debug("postSuperstep() error-previous={}", getAggregator("error-previous").getAggregatedValue());
log.debug("postSuperstep() error-current={}", getAggregator("error-current").getAggregatedValue());
}

}
Expand Up @@ -16,7 +16,7 @@
* limitations under the License.
*/

package org.apache.jena.grande.giraph.pagerank.memory;
package org.apache.jena.grande.giraph.pagerank;

import java.io.BufferedReader;
import java.io.File;
Expand All @@ -38,27 +38,32 @@
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.jena.grande.giraph.pagerank.RunPageRankVertexLocally;
import org.apache.jena.grande.giraph.pagerank.RunSimplePageRankVertexLocally;
import org.apache.jena.grande.giraph.pagerank.memory.JungPageRank;
import org.apache.jena.grande.giraph.pagerank.memory.PlainPageRank;

public class TestPageRank extends TestCase {

private static final String filename = "src/test/resources/pagerank.txt";
private static final double dumping_factor = ( 0.85d );
private static final double alpha = ( 1.0d - dumping_factor );
private static final int iterations = 30;
private static final double tolerance = 0.0000001d;

public void testPlainPageRank() throws IOException {
File input = new File (filename);
BufferedReader in = new BufferedReader(new FileReader (input)) ;

PlainPageRank pagerank1 = new PlainPageRank (in, 0.85d, 30) ;
PlainPageRank pagerank1 = new PlainPageRank ( in, dumping_factor, iterations ) ;
Map<String, Double> result1 = pagerank1.compute() ;

JungPageRank pagerank2 = new JungPageRank(input, 30, 0.0000001d, 0.15d);
JungPageRank pagerank2 = new JungPageRank ( input, iterations, tolerance, alpha );
Map<String, Double> result2 = pagerank2.compute();

check ( result1, result2 );
}

public void testPageRankVertex() throws Exception {
File input = new File (filename);
JungPageRank pagerank1 = new JungPageRank(input, 30, 0.0000001d, 0.15d);
JungPageRank pagerank1 = new JungPageRank ( new File (filename), iterations, tolerance, alpha );
Map<String, Double> result1 = pagerank1.compute();

Map<String, Double> result2 = RunPageRankVertexLocally.run(filename);
Expand All @@ -67,8 +72,7 @@ public void testPageRankVertex() throws Exception {
}

public void testSimplePageRankVertex() throws Exception {
File input = new File (filename);
JungPageRank pagerank1 = new JungPageRank(input, 30, 0.0000001d, 0.15d);
JungPageRank pagerank1 = new JungPageRank ( new File (filename), iterations, tolerance, alpha );
Map<String, Double> result1 = pagerank1.compute();

Map<String, Double> result2 = RunSimplePageRankVertexLocally.run(filename);
Expand All @@ -87,13 +91,13 @@ private void check ( Map<String, Double> result1, Map<String, Double> result2 )
}

// pagerank values should be a probability distribution, right? Sum should be 1 then.
assertEquals ( dump(result1, result2), 1.0, sum(result1), 0.0001 );
assertEquals ( dump(result1, result2), 1.0, sum(result2), 0.0001 );
assertEquals ( dump(result1, result2), 1.0, sum(result1), 0.00001 );
assertEquals ( dump(result1, result2), 1.0, sum(result2), 0.00001 );

// check actual pagerank values
for ( String key : result1.keySet() ) {
assertTrue( result2.containsKey(key) );
assertEquals ( dump(result1, result2), result1.get(key), result2.get(key), 0.000001d );
assertEquals ( dump(result1, result2), result1.get(key), result2.get(key), 0.00001d );
}
}

Expand Down

0 comments on commit 11f07dd

Please sign in to comment.