From b10d9ca10f4db0d1fa5d7b2c96819a49db6c8482 Mon Sep 17 00:00:00 2001 From: Mattias Persson Date: Mon, 20 Apr 2015 09:42:08 +0200 Subject: [PATCH] Workers utility for managing worker threads in IdMapper --- .../idmapping/string/EncodingIdMapper.java | 37 ++--- .../cache/idmapping/string/ParallelSort.java | 94 ++++-------- .../cache/idmapping/string/Workers.java | 138 ++++++++++++++++++ 3 files changed, 186 insertions(+), 83 deletions(-) create mode 100644 community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/Workers.java diff --git a/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/EncodingIdMapper.java b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/EncodingIdMapper.java index 7c5b9bd03bf6c..778fdb34e10a1 100644 --- a/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/EncodingIdMapper.java +++ b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/EncodingIdMapper.java @@ -241,33 +241,28 @@ public boolean needsPreparation() public void prepare( InputIterable ids, Collector collector, ProgressListener progress ) { endPreviousGroup(); - synchronized ( this ) + dataCache = dataCache.fixate(); + trackerCache = cacheFactory.newIntArray( dataCacheStats.highestIndex()+1, -1 ); + + try { - dataCache = dataCache.fixate(); - trackerCache = cacheFactory.newIntArray( dataCacheStats.highestIndex()+1, -1 ); + sortBuckets = new ParallelSort( radix, dataCache, dataCacheStats, trackerCache, trackerStats, + processorsForSorting, progress, DEFAULT ).run(); - // Synchronized since there's this concern that a couple of other threads are changing trackerCache - // and it's nice to go through a memory barrier afterwards to ensure this CPU see correct data. - try + if ( detectAndMarkCollisions( progress ) > 0 ) { - sortBuckets = new ParallelSort( radix, dataCache, dataCacheStats, trackerCache, trackerStats, - processorsForSorting, progress, DEFAULT ).run(); - - if ( detectAndMarkCollisions( progress ) > 0 ) + try ( InputIterator idIterator = ids.iterator() ) { - try ( InputIterator idIterator = ids.iterator() ) - { - buildCollisionInfo( idIterator, collector, progress ); - } + buildCollisionInfo( idIterator, collector, progress ); } } - catch ( InterruptedException e ) - { - Thread.interrupted(); - throw new RuntimeException( "Got interrupted while preparing the index. Throwing this exception " - + "onwards will cause a chain reaction which will cause a panic in the whole import, " - + "so mission accomplished" ); - } + } + catch ( InterruptedException e ) + { + Thread.interrupted(); + throw new RuntimeException( "Got interrupted while preparing the index. Throwing this exception " + + "onwards will cause a chain reaction which will cause a panic in the whole import, " + + "so mission accomplished" ); } readyForUse = true; } diff --git a/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/ParallelSort.java b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/ParallelSort.java index 947b8f19246e9..eb21f558821eb 100644 --- a/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/ParallelSort.java +++ b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/ParallelSort.java @@ -67,7 +67,7 @@ public ParallelSort( Radix radix, LongArray dataCache, NumberArrayStats dataStat this.threads = threads; } - public long[][] run() throws InterruptedException + public synchronized long[][] run() throws InterruptedException { int[][] sortParams = sortRadix(); int threadsNeeded = 0; @@ -80,8 +80,7 @@ public long[][] run() throws InterruptedException threadsNeeded++; } CountDownLatch waitSignal = new CountDownLatch( 1 ); - CountDownLatch doneSignal = new CountDownLatch( threadsNeeded ); - SortWorker[] sortWorker = new SortWorker[threadsNeeded]; + Workers sortWorkers = new Workers<>( "SortWorker" ); progress.started( "SORT" ); for ( int i = 0; i < threadsNeeded; i++ ) { @@ -89,13 +88,12 @@ public long[][] run() throws InterruptedException { break; } - sortWorker[i] = new SortWorker( i, sortParams[i][0], sortParams[i][1], waitSignal, doneSignal ); - sortWorker[i].start(); + sortWorkers.start( new SortWorker( sortParams[i][0], sortParams[i][1], waitSignal ) ); } waitSignal.countDown(); try { - doneSignal.await(); + sortWorkers.awaitAndThrowOnError(); } finally { @@ -108,7 +106,7 @@ private int[][] sortRadix() throws InterruptedException { int[][] rangeParams = new int[threads][2]; int[] bucketRange = new int[threads]; - TrackerInitializer[] initializers = new TrackerInitializer[threads]; + Workers initializers = new Workers<>( "TrackerInitializer" ); sortBuckets = new long[threads][2]; int bucketSize = safeCastLongToInt( dataStats.size() / threads ); int count = 0, fullCount = 0 + 0; @@ -134,9 +132,9 @@ private int[][] sortRadix() throws InterruptedException fullCount += radixIndexCount[i]; progress.add( radixIndexCount[i] ); } - initializers[threadIndex] = new TrackerInitializer( threadIndex, rangeParams[threadIndex], + initializers.start( new TrackerInitializer( threadIndex, rangeParams[threadIndex], threadIndex > 0 ? bucketRange[threadIndex-1] : -1, bucketRange[threadIndex], - sortBuckets[threadIndex] ); + sortBuckets[threadIndex] ) ); threadIndex++; } else @@ -148,9 +146,9 @@ private int[][] sortRadix() throws InterruptedException bucketRange[threadIndex] = radixIndexCount.length; rangeParams[threadIndex][0] = fullCount; rangeParams[threadIndex][1] = safeCastLongToInt( dataStats.size() - fullCount ); - initializers[threadIndex] = new TrackerInitializer( threadIndex, rangeParams[threadIndex], + initializers.start( new TrackerInitializer( threadIndex, rangeParams[threadIndex], threadIndex > 0 ? bucketRange[threadIndex-1] : -1, bucketRange[threadIndex], - sortBuckets[threadIndex] ); + sortBuckets[threadIndex] ) ); break; } } @@ -159,23 +157,15 @@ private int[][] sortRadix() throws InterruptedException // In the loop above where we split up radixes into buckets, we start one thread per bucket whose // job is to populate trackerCache and sortBuckets where each thread will not touch the same // data indexes as any other thread. Here we wait for them all to finish. + Throwable error = initializers.await(); int[] bucketIndex = new int[threads]; - Throwable error = null; long highestIndex = -1, size = 0; - for ( int i = 0; i < initializers.length; i++ ) + int i = 0; + for ( TrackerInitializer initializer : initializers ) { - TrackerInitializer initializer = initializers[i]; - if ( initializer != null ) - { - Throwable initializerError = initializer.await(); - if ( initializerError != null ) - { - error = initializerError; - } - bucketIndex[i] = initializer.bucketIndex; - highestIndex = max( highestIndex, initializer.highestIndex ); - size += initializer.size; - } + bucketIndex[i++] = initializer.bucketIndex; + highestIndex = max( highestIndex, initializer.highestIndex ); + size += initializer.size; } trackerStats.set( size, highestIndex ); if ( error != null ) @@ -301,20 +291,17 @@ public boolean ge( long right, long pivot ) * instead trackerCache is updated to point to the right indexes. Only touches a designated part of trackerCache * so that many can run in parallel on their own part without synchronization. */ - private class SortWorker extends Thread + private class SortWorker implements Runnable { private final int start, size; - private final CountDownLatch doneSignal, waitSignal; - private int workerId = -1; + private final CountDownLatch waitSignal; private int threadLocalProgress; - SortWorker( int workerId, int startRange, int size, CountDownLatch wait, CountDownLatch done ) + SortWorker( int startRange, int size, CountDownLatch wait ) { this.start = startRange; this.size = size; - this.doneSignal = done; this.waitSignal = wait; - this.workerId = workerId; } void incrementProgress( int diff ) @@ -336,7 +323,6 @@ private void reportProgress() public void run() { Random random = ThreadLocalRandom.current(); - this.setName( "SortWorker-" + workerId ); try { waitSignal.await(); @@ -347,7 +333,6 @@ public void run() } recursiveQsort( start, start + size, random, this ); reportProgress(); - doneSignal.countDown(); } } @@ -355,7 +340,7 @@ public void run() * Sets the initial tracker indexes pointing to data indexes. Only touches a designated part of trackerCache * so that many can run in parallel on their own part without synchronization. */ - private class TrackerInitializer extends Thread + private class TrackerInitializer implements Runnable { private final int[] rangeParams; private final int lowBucketRange; @@ -363,7 +348,6 @@ private class TrackerInitializer extends Thread private final int threadIndex; private int bucketIndex; private final long[] result; - private volatile Throwable error; private long highestIndex = -1; private long size; @@ -374,46 +358,32 @@ private class TrackerInitializer extends Thread this.lowBucketRange = lowBucketRange; this.highBucketRange = highBucketRange; this.result = result; - start(); } @Override public void run() { - try + long max = dataStats.highestIndex(); + for ( long i = 0; i <= max; i++ ) { - long dataSize = dataStats.size(); - for ( long i = 0; i < dataSize; i++ ) + int rIndex = radixCalculator.radixOf( dataCache.get( i ) ); + if ( rIndex > lowBucketRange && rIndex <= highBucketRange ) { - int rIndex = radixCalculator.radixOf( dataCache.get( i ) ); - if ( rIndex > lowBucketRange && rIndex <= highBucketRange ) + long temp = (rangeParams[0] + bucketIndex++); + assert tracker.get( temp ) == -1 : "Overlapping buckets i:" + i + ", k:" + threadIndex; + tracker.set( temp, (int) i ); + if ( bucketIndex == rangeParams[1] ) { - long temp = (rangeParams[0] + bucketIndex++); - assert tracker.get( temp ) == -1 : "Overlapping buckets i:" + i + ", k:" + threadIndex; - tracker.set( temp, (int) i ); - if ( bucketIndex == rangeParams[1] ) - { - result[0] = highBucketRange; - result[1] = rangeParams[0]; - } + result[0] = highBucketRange; + result[1] = rangeParams[0]; } } - if ( bucketIndex > 0 ) - { - highestIndex = rangeParams[0] + bucketIndex - 1; - } - size = bucketIndex; } - catch ( Throwable t ) + if ( bucketIndex > 0 ) { - error = t; + highestIndex = rangeParams[0] + bucketIndex - 1; } - } - - private synchronized Throwable await() throws InterruptedException - { - join(); - return error; + size = bucketIndex; } } } diff --git a/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/Workers.java b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/Workers.java new file mode 100644 index 0000000000000..d493d7aeef746 --- /dev/null +++ b/community/kernel/src/main/java/org/neo4j/unsafe/impl/batchimport/cache/idmapping/string/Workers.java @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2002-2015 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.unsafe.impl.batchimport.cache.idmapping.string; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ExecutorService; + +import org.neo4j.function.Function; +import org.neo4j.helpers.Exceptions; +import org.neo4j.helpers.collection.Iterables; + +/** + * Utility for running a handful of {@link Runnable} in parallel, each in its own thread. + * {@link Runnable} instances are {@link #start(Runnable) added and started} and the caller can + * {@link #await()} them all to finish, returning a {@link Throwable error} if any thread encountered one so + * that the caller can decide how to handle that error. Or caller can use {@link #awaitAndThrowOnError()} + * where error from any worker would be thrown from that method. + * + * It's basically like using an {@link ExecutorService}, but without that "baggage" and an easier usage + * and less code in the scenario described above. + */ +public class Workers implements Iterable +{ + private final List workers = new ArrayList<>(); + private final String names; + + public Workers( String names ) + { + this.names = names; + } + + /** + * Starts a thread to run {@code toRun}. + */ + public void start( R toRun ) + { + Worker worker = new Worker( names + "-" + workers.size(), toRun ); + worker.start(); + workers.add( worker ); + } + + public Throwable await() throws InterruptedException + { + Throwable error = null; + for ( Worker worker : workers ) + { + Throwable anError = worker.await(); + if ( error == null ) + { + error = anError; + } + } + return error; + } + + public void awaitAndThrowOnError( Class launderingException ) + throws EXCEPTION, InterruptedException + { + Throwable error = await(); + if ( error != null ) + { + throw Exceptions.launderedException( launderingException, error ); + } + } + + public void awaitAndThrowOnError() throws InterruptedException + { + Throwable error = await(); + if ( error != null ) + { + throw Exceptions.launderedException( error ); + } + } + + @Override + public Iterator iterator() + { + return Iterables.map( new Function() + { + @Override + public R apply( Worker worker ) throws RuntimeException + { + return worker.toRun; + } + }, workers.iterator() ); + } + + private class Worker extends Thread + { + private volatile Throwable error; + private final R toRun; + + Worker( String name, R toRun ) + { + super( name ); + this.toRun = toRun; + } + + @Override + public void run() + { + try + { + toRun.run(); + } + catch ( Throwable t ) + { + error = t; + throw Exceptions.launderedException( t ); + } + } + + protected synchronized Throwable await() throws InterruptedException + { + join(); + return error; + } + } +}