Concurrency changes, primarily around thread interrupted status and pooling #4794

Merged: 3 commits, Mar 22, 2018
@@ -13,10 +13,12 @@
* This is simplified ConcurrentHashSet implementation
*
* PLEASE NOTE: This class does NOT implement real equals & hashCode
*
* @deprecated Please use {@code Collections.newSetFromMap(new ConcurrentHashMap<>())}
*
* @author raver119@gmail.com
*/
// TODO: add equals/hashcode if needed
@Deprecated
public class ConcurrentHashSet<E> implements Set<E>, Serializable {
private static final long serialVersionUID = 123456789L;

7 changes: 7 additions & 0 deletions deeplearning4j-graph/pom.xml
@@ -17,9 +17,16 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.threadly</groupId>
<artifactId>threadly</artifactId>
<version>${threadly.version}</version>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>

</dependencies>
@@ -12,6 +12,8 @@
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import org.threadly.concurrent.future.FutureUtils;

import java.util.ArrayList;
import java.util.List;
@@ -37,7 +39,6 @@ public class DeepWalk<V, E> extends GraphVectorsImpl<V, E> {
private double learningRate;
private boolean initCalled = false;
private long seed;
private ExecutorService executorService;
private int nThreads = Runtime.getRuntime().availableProcessors();
private transient AtomicLong walkCounter = new AtomicLong(0);

@@ -116,37 +117,25 @@ public void fit(GraphWalkIteratorProvider<V> iteratorProvider) {
throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
List<GraphWalkIterator<V>> iteratorList = iteratorProvider.getGraphWalkIterators(nThreads);

- executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() {
-     @Override
-     public Thread newThread(Runnable r) {
-         Thread t = new Thread(r);
-         t.setDaemon(true);
-         return t;
-     }
- });
+ PriorityScheduler scheduler = new PriorityScheduler(nThreads);

List<Future<Void>> list = new ArrayList<>(iteratorList.size());
//log.info("Fitting Graph with {} threads", Math.max(nThreads,iteratorList.size()));
for (GraphWalkIterator<V> iter : iteratorList) {
LearningCallable c = new LearningCallable(iter);
- list.add(executorService.submit(c));
+ list.add(scheduler.submit(c));
}

- executorService.shutdown();
+ scheduler.shutdown(); // won't shut down until already-submitted tasks complete

try {
- executorService.awaitTermination(999, TimeUnit.DAYS);
+ FutureUtils.blockTillAllCompleteOrFirstError(list);
} catch (InterruptedException e) {
- throw new RuntimeException("ExecutorService interrupted", e);
- }
-
- //Don't need to block on futures to get a value out, but we want to re-throw any exceptions encountered
- for (Future<Void> f : list) {
-     try {
-         f.get();
-     } catch (Exception e) {
-         throw new RuntimeException(e);
-     }
- }
+ // should not be possible when blocking until the scheduler terminates
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+     throw new RuntimeException(e);
+ }
}
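
The rewritten fit() above leans on two Threadly behaviours that the inline comments mention: shutdown() lets already-submitted work finish, and FutureUtils.blockTillAllCompleteOrFirstError(...) blocks until every future completes or rethrows the first failure, which is why the old awaitTermination and per-future get() loop is gone. A standalone sketch of that pattern, assuming only the threadly artifact added to the poms in this PR; the class and task names are illustrative, not part of the project:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.Future;
    import org.threadly.concurrent.PriorityScheduler;
    import org.threadly.concurrent.future.FutureUtils;

    public class SchedulerSketch {
        public static void main(String[] args) {
            PriorityScheduler scheduler = new PriorityScheduler(4);
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < 8; i++) {
                final int id = i;
                futures.add(scheduler.submit(new Runnable() {
                    @Override
                    public void run() {
                        System.out.println("task " + id + " on " + Thread.currentThread().getName());
                    }
                }));
            }
            scheduler.shutdown(); // already-submitted tasks still run; new submissions are rejected
            try {
                // Blocks until every future finishes, or throws on the first failure
                FutureUtils.blockTillAllCompleteOrFirstError(futures);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // restore the interrupt flag before rethrowing
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e);
            }
        }
    }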

@@ -23,9 +23,15 @@
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.threadly</groupId>
<artifactId>threadly</artifactId>
<version>${threadly.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
@@ -40,7 +40,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executor;

/**
*
@@ -61,7 +61,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
private int currentIteration = 0;
private ClusterSet clusterSet;
private List<Point> initialPoints;
private transient ExecutorService exec;
private transient Executor exec;



@@ -31,6 +31,7 @@
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;

/**
Expand All @@ -43,7 +44,7 @@ public class ClusterUtils {

/** Classify the set of points base on cluster centers. This also adds each point to the ClusterSet */
public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
ExecutorService executorService) {
Executor executorService) {
final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);

List<Runnable> tasks = new ArrayList<>();
@@ -71,7 +72,7 @@ public static PointClassification classifyPoint(ClusterSet clusterSet, Point poi
}

public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
ExecutorService executorService) {
Executor executorService) {
List<Runnable> tasks = new ArrayList<>();
int nClusters = clusterSet.getClusterCount();
for (int i = 0; i < nClusters; i++) {
@@ -136,7 +137,7 @@ public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
* @return
*/
public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
final List<Point> points, INDArray previousDxs, Executor executorService) {
final int pointsCount = points.size();
final INDArray dxs = Nd4j.create(pointsCount);
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
@@ -181,7 +182,7 @@ public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
return info;
}

public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, Executor executorService) {
final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
int clusterCount = clusterSet.getClusterCount();

@@ -266,7 +267,7 @@ public static ClusterInfo computeClusterInfos(Cluster cluster, String distanceFu
* @return
*/
public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, ExecutorService executor) {
ClusterSetInfo clusterSetInfo, Executor executor) {

if (optimization.isClusteringOptimizationType(
ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
@@ -368,7 +369,7 @@ public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan
* @return
*/
public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
ExecutorService executorService) {
Executor executorService) {
List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
return clustersToSplit.size();
@@ -383,7 +384,7 @@ public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetIn
* @return
*/
public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, Executor executorService) {
List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
maxWithinClusterDistance);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
@@ -399,7 +400,7 @@ public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(Cluster
* @return
*/
public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, Executor executorService) {
List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
maxWithinClusterDistance);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
@@ -414,7 +415,7 @@ public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(Cluster
* @param executorService
*/
public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
ExecutorService executorService) {
Executor executorService) {
List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
}
@@ -428,7 +429,7 @@ public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetI
* @param executorService
*/
public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
List<Cluster> clusters, final double maxDistance, Executor executorService) {
final Random random = new Random();
List<Runnable> tasks = new ArrayList<>();
for (final Cluster cluster : clusters) {
@@ -459,7 +460,7 @@ public void run() {
* @param executorService
*/
public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
List<Cluster> clusters, ExecutorService executorService) {
List<Cluster> clusters, Executor executorService) {
final Random random = new Random();
List<Runnable> tasks = new ArrayList<>();
for (final Cluster cluster : clusters) {
@@ -20,6 +20,8 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import org.threadly.concurrent.wrapper.compatibility.PrioritySchedulerServiceWrapper;

import java.util.List;
import java.util.concurrent.*;
@@ -28,24 +30,14 @@ public class MultiThreadUtils {

private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class);

private static ExecutorService instance;

private MultiThreadUtils() {}

- public static synchronized ExecutorService newExecutorService() {
+ public static ExecutorService newExecutorService() {
int nThreads = Runtime.getRuntime().availableProcessors();
- return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(),
-     new ThreadFactory() {
-         @Override
-         public Thread newThread(Runnable r) {
-             Thread t = Executors.defaultThreadFactory().newThread(r);
-             t.setDaemon(true);
-             return t;
-         }
-     });
+ return new PrioritySchedulerServiceWrapper(new PriorityScheduler(nThreads));
}

- public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) {
+ public static void parallelTasks(final List<Runnable> tasks, Executor executorService) {
int tasksCount = tasks.size();
final CountDownLatch latch = new CountDownLatch(tasksCount);
for (int i = 0; i < tasksCount; i++) {
@@ -62,10 +54,10 @@ public void run() {
}
});
}

try {
latch.await();
- } catch (Exception e) {
+ } catch (InterruptedException e) {
+     Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}
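
Since parallelTasks now accepts the plain Executor interface, any pool can drive it, including the PrioritySchedulerServiceWrapper returned above. The latch-and-interrupt pattern it relies on can be sketched on its own as follows; this is an illustration of the idea under those assumptions, not the project's implementation:

    import java.util.Arrays;
    import java.util.List;
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.Executor;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    public class ParallelTasksSketch {
        // Submit every task to the Executor and wait for all of them,
        // restoring the interrupt flag if the wait itself is interrupted.
        static void runAll(List<Runnable> tasks, Executor executor) {
            final CountDownLatch latch = new CountDownLatch(tasks.size());
            for (final Runnable task : tasks) {
                executor.execute(() -> {
                    try {
                        task.run();
                    } finally {
                        latch.countDown(); // count down even if the task throws
                    }
                });
            }
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }

        public static void main(String[] args) {
            ExecutorService pool = Executors.newFixedThreadPool(2);
            runAll(Arrays.asList(
                    () -> System.out.println("task A"),
                    () -> System.out.println("task B")), pool);
            pool.shutdown();
        }
    }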
@@ -577,17 +577,10 @@ public Node(int index, float threshold) {
public void fetchFutures() {
try {
if (futureLeft != null) {
while (!futureLeft.isDone())
Thread.sleep(100);


left = futureLeft.get();
}

if (futureRight != null) {
while (!futureRight.isDone())
Thread.sleep(100);

right = futureRight.get();
}

@@ -597,7 +590,10 @@ public void fetchFutures() {

if (right != null)
right.fetchFutures();
- } catch (Exception e) {
+ } catch (InterruptedException e) {
+     Thread.currentThread().interrupt();
throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+     throw new RuntimeException(e);
}
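
The removed while/Thread.sleep(100) polling was redundant because Future.get() already blocks until the computation finishes, and the new catch clauses restore the thread's interrupt status instead of swallowing it. A small self-contained sketch of that behaviour, using only java.util.concurrent and illustrative names:

    import java.util.concurrent.Callable;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.Future;

    public class BlockingGetSketch {
        public static void main(String[] args) {
            ExecutorService pool = Executors.newSingleThreadExecutor();
            Future<Integer> future = pool.submit(new Callable<Integer>() {
                @Override
                public Integer call() throws Exception {
                    Thread.sleep(500); // simulate work
                    return 42;
                }
            });
            try {
                // No polling loop needed: get() blocks until the value is ready
                Integer result = future.get();
                System.out.println("result = " + result);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt for callers
                throw new RuntimeException(e);
            } catch (ExecutionException e) {
                throw new RuntimeException(e); // the task itself failed
            } finally {
                pool.shutdown();
            }
        }
    }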

7 changes: 7 additions & 0 deletions deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml
@@ -53,9 +53,16 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>org.threadly</groupId>
<artifactId>threadly</artifactId>
<version>${threadly.version}</version>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>

<dependency>
Expand Down