Permalink
Browse files

TaskThreadPool: await until current tasks have finished, refactored t…

…ests for better reproducability
  • Loading branch information...
orfjackal committed Nov 27, 2008
1 parent 4e377dc commit a01e06166d49995fd703206a5d023cf370705ed2
@@ -34,8 +34,8 @@
import net.orfjackal.dimdwarf.tasks.TaskExecutor;
import org.slf4j.*;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.*;
+import java.util.concurrent.*;
/**
* @author Esko Luontola
@@ -49,7 +49,7 @@
private final TaskProducer producer;
private final Thread consumer;
private final ExecutorService workers;
- private final AtomicInteger runningTasks = new AtomicInteger(0);
+ private final Set<CountDownLatch> runningTasks = Collections.synchronizedSet(new HashSet<CountDownLatch>());
public TaskThreadPool(TaskExecutor taskContext, TaskProducer producer, ExecutorService threadPool) {
this(taskContext, producer, threadPool, DEFAULT_LOGGER);
@@ -58,7 +58,7 @@ public TaskThreadPool(TaskExecutor taskContext, TaskProducer producer, ExecutorS
public TaskThreadPool(TaskExecutor taskContext, TaskProducer producer, ExecutorService threadPool, Logger logger) {
this.taskContext = taskContext;
this.producer = producer;
- this.consumer = new Thread(new TaskConsumer());
+ this.consumer = new Thread(new TaskConsumer(), "TaskConsumer");
this.workers = threadPool;
this.logger = logger;
}
@@ -68,7 +68,7 @@ public void start() {
}
public int getRunningTasks() {
- return runningTasks.get();
+ return runningTasks.size();
}
public void shutdown() {
@@ -83,6 +83,19 @@ public void shutdown() {
logger.info("Shutdown finished");
}
+ @SuppressWarnings({"ToArrayCallWithZeroLengthArrayArgument"})
+ public void awaitForCurrentTasksToFinish() throws InterruptedException {
+ // It would be dangerous to pass an array larger than 0 to the toArray() method,
+ // because there is a small chance that between the calls to size() and toArray()
+ // an entry is removed from the collection, and the returned array would be too
+ // big and would contain a null entry (toArray() does not shrink the array parameter
+ // if it's too big).
+ CountDownLatch[] snapshotOfRunningTasks = runningTasks.toArray(new CountDownLatch[0]);
+ for (CountDownLatch taskHasFinished : snapshotOfRunningTasks) {
+ taskHasFinished.await();
+ }
+ }
+
private class TaskConsumer implements Runnable {
public void run() {
@@ -106,13 +119,15 @@ public TaskContextSetup(Runnable task) {
}
public void run() {
+ CountDownLatch taskHasFinished = new CountDownLatch(1);
try {
- runningTasks.incrementAndGet();
+ runningTasks.add(taskHasFinished);
taskContext.execute(task);
} catch (Throwable t) {
logger.error("Task threw an exception", t);
} finally {
- runningTasks.decrementAndGet();
+ runningTasks.remove(taskHasFinished);
+ taskHasFinished.countDown();
}
}
}
@@ -80,6 +80,33 @@ public TaskBootstrap takeNextTask() throws InterruptedException {
pool.start();
}
+ @SuppressWarnings("ThrowableResultOfMethodCallIgnored")
+ public void destroy() throws Exception {
+ checking(new Expectations() {{
+ allowing(logger).info(with(any(String.class)));
+ allowing(logger).info(with(any(String.class)), with(any(Throwable.class)));
+ }});
+ pool.shutdown();
+ }
+
+ private static void executeAfterCurrentThreadIsNotRunning(final Runnable command) {
+ final Thread currentThread = Thread.currentThread();
+ Thread t = new Thread(new Runnable() {
+ public void run() {
+ Thread.State state = currentThread.getState();
+ if (state.equals(Thread.State.RUNNABLE)) {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ command.run();
+ }
+ });
+ t.start();
+ }
+
public class WhenTasksAreAddedToTheQueue {
@@ -124,7 +151,6 @@ public void theyAreExecutedInsideTaskContext() throws InterruptedException {
private CountDownLatch step1 = new CountDownLatch(1);
private CountDownLatch step2 = new CountDownLatch(1);
private CountDownLatch step3 = new CountDownLatch(1);
- private CountDownLatch stepEnd = new CountDownLatch(1);
private volatile Integer runningTasks0 = null;
private volatile Integer runningTasks1 = null;
@@ -151,7 +177,6 @@ public void run() {
runningTasks2 = pool.getRunningTasks();
step2.countDown();
step3.await();
- stepEnd.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
@@ -164,9 +189,8 @@ public void run() {
taskQueue.add(new SimpleTaskBootstrap(task2));
step3.await(TEST_TIMEOUT, TimeUnit.MILLISECONDS);
- Thread.yield();
+ pool.awaitForCurrentTasksToFinish();
runningTasksEnd = pool.getRunningTasks();
- stepEnd.countDown();
}
public void theyAreExecutedInParallel() throws InterruptedException {
@@ -179,7 +203,71 @@ public void thePoolKnowsTheNumberOfRunningTasks() {
specify(runningTasks0, should.equal(0));
specify(runningTasks1, should.equal(1));
specify(runningTasks2, should.equal(2));
- specify(runningTasksEnd, should.equal(1));
+ specify(runningTasksEnd, should.equal(0));
+ }
+ }
+
+ public class WhenAClientWaitsForTheCurrentlyExecutingTasksToFinish {
+
+ private CountDownLatch firstTaskIsExecuting = new CountDownLatch(1);
+ private CountDownLatch clientIsWaitingForTasksToFinish = new CountDownLatch(1);
+ private CountDownLatch secondTaskIsExecuting = new CountDownLatch(1);
+ private CountDownLatch testHasEnded = new CountDownLatch(1);
+
+ private volatile boolean aNewTaskIsRunning = false;
+
+ public void create() throws InterruptedException {
+ final Runnable task2 = new Runnable() {
+ public void run() {
+ try {
+ aNewTaskIsRunning = true;
+ secondTaskIsExecuting.countDown();
+ testHasEnded.await();
+ aNewTaskIsRunning = false;
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+ Runnable task1 = new Runnable() {
+ public void run() {
+ try {
+ firstTaskIsExecuting.countDown();
+ clientIsWaitingForTasksToFinish.await();
+ taskQueue.add(new SimpleTaskBootstrap(task2));
+ secondTaskIsExecuting.await();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+ taskQueue.add(new SimpleTaskBootstrap(task1));
+ firstTaskIsExecuting.await(TEST_TIMEOUT, TimeUnit.MILLISECONDS);
+
+ executeAfterCurrentThreadIsNotRunning(new Runnable() {
+ public void run() {
+ // Let's hope that this gets executed *after* the client begins waiting.
+ // There is no guarantee that this thread won't be executed first...
+ clientIsWaitingForTasksToFinish.countDown();
+ }
+ });
+ pool.awaitForCurrentTasksToFinish();
+ }
+
+ public void destroy() {
+ testHasEnded.countDown();
+ }
+
+ public void afterWaitingAllThePreviouslyExecutingTasksHaveFinished() {
+ if (aNewTaskIsRunning) {
+ specify(pool.getRunningTasks(), should.equal(1));
+ } else {
+ specify(pool.getRunningTasks(), should.equal(0));
+ }
+ }
+
+ public void afterWaitingOtherNewTasksMayBeExecuting() {
+ specify(aNewTaskIsRunning);
}
}
@@ -198,7 +286,7 @@ public void run() {
};
taskQueue.add(new SimpleTaskBootstrap(task));
end.await(TEST_TIMEOUT, TimeUnit.MILLISECONDS);
- Thread.yield();
+ pool.awaitForCurrentTasksToFinish();
}
public Expectations theExceptionIsLogged() {
@@ -214,7 +302,7 @@ public void theNumberOfRunningTasksIsDecrementedCorrectly() {
public class WhenThePoolIsShutDown {
- public void create() {
+ public void create() throws InterruptedException {
checking(theShutdownIsLogged());
pool.shutdown();
taskQueue.add(new SimpleTaskBootstrap(new Runnable() {
@@ -229,7 +317,7 @@ public void run() {
private Expectations theShutdownIsLogged() {
return new Expectations() {{
one(logger).info("Shutting down...");
- one(logger).info(with(equal("Task consumer was interrupted")), with(aNonNull(InterruptedException.class)));
+ allowing(logger).info(with(equal("Task consumer was interrupted")), with(aNonNull(InterruptedException.class)));
one(logger).info("Shutdown finished");
}};
}
@@ -252,6 +340,5 @@ public Runnable getTaskInsideTransaction() {
}
}
- // TODO: knows which tasks are executing and can tell when all currently executing tasks have finished (needed for GC)
// TODO: give access to the current task's ScheduledFuture?
}

0 comments on commit a01e061

Please sign in to comment.