Skip to content

Commit

Permalink
Add more tests and improve executor shutdown documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
yurloc authored and ge0ffrey committed Oct 16, 2017
1 parent 041b344 commit d725e8b
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 69 deletions.
Expand Up @@ -144,17 +144,27 @@ public void solve(DefaultSolverScope<Solution_> solverScope) {
phaseScope.setLastCompletedStepScope(stepScope);
}
} finally {
// If a partition thread throws an Exception, it is propagated here
// but the other partition threads won't finish any time soon, so we need to ask them to terminate
if (childThreadPlumbingTermination.terminateChildren()) {
logger.info("Shutting down thread pool.");
if (Thread.interrupted()) {
// 1a. If current thread is interrupted, propagate interrupt signal to children by initiating
// abrupt shutdown.
executor.shutdownNow();
} else {
// 1b. Otherwise, initiate graceful shutdown of the executor. This allows partition solvers to finish
// solving upon detecting the termination issued in the next step (2). Shutting down the executor
// service is important because the JVM cannot exit until all nondaemon threads have terminated.
executor.shutdown();
}

// 2. In case one of the partition threads threw an Exception, it is propagated here
// but the other partition threads are not aware of the failure and may continue solving for a long time,
// so we need to ask them to terminate. In case no exception was thrown, this does nothing.
if (!childThreadPlumbingTermination.terminateChildren()) {
logger.info("Termination of children wasn't sucessful.");
}

// 3. Finally, wait until the executor finishes shutting down
try {
// First wait for solvers to terminate voluntarily (because we have just issued children termination)
final int awaitingSeconds = 10; // TODO revert back to 1 second
final int awaitingSeconds = 1;
if (!executor.awaitTermination(awaitingSeconds, TimeUnit.SECONDS)) {
// Some solvers refused to complete. Busy threads will be interrupted in the finally block
logger.warn("{}Partitioned Search threadPoolExecutor didn't terminate within timeout ({} second).",
Expand All @@ -167,6 +177,7 @@ public void solve(DefaultSolverScope<Solution_> solverScope) {
Thread.currentThread().interrupt();
throw new IllegalStateException("Thread pool shutdown was interrupted.", e);
} finally {
// Initiate an abrupt shutdown
executor.shutdownNow();
}
}
Expand Down
Expand Up @@ -15,26 +15,30 @@
*/
package org.optaplanner.core.impl.partitionedsearch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.junit.Test;
import org.junit.*;
import org.optaplanner.core.api.solver.Solver;
import org.optaplanner.core.api.solver.SolverFactory;
import org.optaplanner.core.config.constructionheuristic.ConstructionHeuristicPhaseConfig;
import org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig;
import org.optaplanner.core.config.partitionedsearch.PartitionedSearchPhaseConfig;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;
import org.optaplanner.core.impl.partitionedsearch.partitioner.SolutionPartitioner;
import org.optaplanner.core.impl.partitionedsearch.scope.PartitionedSearchPhaseScope;
import org.optaplanner.core.impl.phase.event.PhaseLifecycleListenerAdapter;
import org.optaplanner.core.impl.phase.scope.AbstractPhaseScope;
import org.optaplanner.core.impl.score.director.ScoreDirector;
import org.optaplanner.core.impl.solver.DefaultSolver;
import org.optaplanner.core.impl.testdata.domain.TestdataEntity;
import org.optaplanner.core.impl.testdata.domain.TestdataSolution;
Expand All @@ -45,15 +49,41 @@

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class DefaultPartitionedSearchPhaseTest {

private static final Logger logger = LoggerFactory.getLogger(DefaultPartitionedSearchPhaseTest.class);

/**
 * Asserts that the cause chain of {@code ex} contains a throwable of type {@code cause},
 * without placing any requirement on that throwable's message.
 *
 * @param ex the exception whose causes are searched
 * @param cause the throwable type expected somewhere in the cause chain
 */
private static void findCauseOrFail(Throwable ex, Class<? extends Throwable> cause) {
    // The empty string is contained in every message, so the message check always passes.
    final String anyMessage = "";
    findCauseOrFail(ex, cause, anyMessage);
}

/**
 * Asserts that the cause chain of {@code ex} contains a throwable of type {@code cause} whose message
 * contains {@code msgSubstring}.
 * <p>
 * The search starts at {@code ex.getCause()} ({@code ex} itself is never a candidate) and stops at the
 * first throwable that is an instance of {@code cause}; only that throwable's message is checked.
 * A {@code null} message is treated as an empty string. If no cause of the requested type exists,
 * {@code ex} is logged and the test fails.
 *
 * @param ex the exception whose causes are searched
 * @param cause the throwable type expected somewhere in the cause chain
 * @param msgSubstring text the matching throwable's message must contain
 */
private static void findCauseOrFail(Throwable ex, Class<? extends Throwable> cause, String msgSubstring) {
    // Walk down the chain until a throwable of the requested type shows up or the chain ends.
    Throwable candidate = ex.getCause();
    while (candidate != null && !cause.isInstance(candidate)) {
        candidate = candidate.getCause();
    }

    if (candidate != null) {
        String actualMessage = Objects.toString(candidate.getMessage(), "");
        assertTrue("Exception message (" + candidate.getMessage() + ") should contain substring: " + msgSubstring,
                actualMessage.contains(msgSubstring));
        return;
    }

    // Nothing of the expected type was found: log the full exception to make the failure diagnosable.
    logger.error("Solver failure was caused by something unexpected:", ex);
    fail("Solver failure should have been caused by " + cause.getCanonicalName()
            + " but was caused by something unexpected.");
}

@Test
public void partCount() {
final int partSize = 3;
final int partCount = 7;
SolverFactory<TestdataSolution> solverFactory = createSolverFactory();
SolverFactory<TestdataSolution> solverFactory = createSolverFactory(false);
setPartSize(solverFactory.getSolverConfig(), partSize);
DefaultSolver<TestdataSolution> solver = (DefaultSolver<TestdataSolution>) solverFactory.buildSolver();
PartitionedSearchPhase<TestdataSolution> phase
Expand All @@ -67,7 +97,7 @@ public void phaseStarted(AbstractPhaseScope<TestdataSolution> phaseScope) {
solver.solve(createSolution(partCount * partSize, 2));
}

private static SolverFactory<TestdataSolution> createSolverFactory() {
private static SolverFactory<TestdataSolution> createSolverFactory(boolean infinite) {
SolverFactory<TestdataSolution> solverFactory = PlannerTestUtils
.buildSolverFactory(TestdataSolution.class, TestdataEntity.class);
SolverConfig solverConfig = solverFactory.getSolverConfig();
Expand All @@ -76,7 +106,9 @@ private static SolverFactory<TestdataSolution> createSolverFactory() {
solverConfig.setPhaseConfigList(Arrays.asList(partitionedSearchPhaseConfig));
ConstructionHeuristicPhaseConfig constructionHeuristicPhaseConfig = new ConstructionHeuristicPhaseConfig();
LocalSearchPhaseConfig localSearchPhaseConfig = new LocalSearchPhaseConfig();
localSearchPhaseConfig.setTerminationConfig(new TerminationConfig().withStepCountLimit(1));
if (!infinite) {
localSearchPhaseConfig.setTerminationConfig(new TerminationConfig().withStepCountLimit(1));
}
partitionedSearchPhaseConfig.setPhaseConfigList(
Arrays.asList(constructionHeuristicPhaseConfig, localSearchPhaseConfig));
return solverFactory;
Expand All @@ -103,70 +135,123 @@ private static void setPartSize(SolverConfig solverConfig, int partSize) {
phaseConfig.setSolutionPartitionerCustomProperties(map);
}

public static class TestdataSolutionPartitioner implements SolutionPartitioner<TestdataSolution> {
@Test
public void exceptionPropagation() {
final int partSize = 7;
final int partCount = 3;

/**
* {@link PartitionedSearchPhaseConfig#solutionPartitionerCustomProperties Custom property}.
*/
private int partSize = 1;
TestdataSolution solution = createSolution(partCount * partSize - 1, 100);
solution.getEntityList().add(new FaultyEntity("XYZ"));
assertEquals(partSize * partCount, solution.getEntityList().size());

public void setPartSize(int partSize) {
this.partSize = partSize;
SolverFactory<TestdataSolution> solverFactory = createSolverFactory(false);
setPartSize(solverFactory.getSolverConfig(), partSize);
Solver<TestdataSolution> solver = solverFactory.buildSolver();
try {
solver.solve(solution);
fail("The exception was not propagated.");
} catch (IllegalStateException ex) {
findCauseOrFail(ex, ArithmeticException.class);
}
}

@Override
public List<TestdataSolution> splitWorkingSolution(
ScoreDirector<TestdataSolution> scoreDirector, Integer runnablePartThreadLimit) {
TestdataSolution workingSolution = scoreDirector.getWorkingSolution();
List<TestdataEntity> allEntities = workingSolution.getEntityList();
if (allEntities.size() % partSize > 0) {
throw new IllegalStateException("This partitioner can only make equally sized partitions.");
}
List<TestdataSolution> partitions = new ArrayList<>();
for (int i = 0; i < allEntities.size() / partSize; i++) {
List<TestdataEntity> partitionEntitites = new ArrayList<>(
allEntities.subList(i * partSize, (i + 1) * partSize)
);
TestdataSolution partition = new TestdataSolution();
partition.setEntityList(partitionEntitites);
partition.setValueList(workingSolution.getValueList());
partitions.add(partition);
}
return partitions;
}
@Test(timeout = 5000)
public void terminateEarly() throws InterruptedException, ExecutionException {
final int partSize = 1;
final int partCount = 2;

}
TestdataSolution solution = createSolution(partCount * partSize, 10);

@Test
public void terminateEarly() {
// TODO?
SolverFactory<TestdataSolution> solverFactory = createSolverFactory(true);
setPartSize(solverFactory.getSolverConfig(), partSize);
Solver<TestdataSolution> solver = solverFactory.buildSolver();

ExecutorService executor = Executors.newSingleThreadExecutor();
Future<TestdataSolution> solvedSolution = executor.submit(() -> {
return solver.solve(solution);
});

while (!solver.isSolving()) {
// wait until solver starts solving before terminating early
}
assertTrue(solver.terminateEarly());
assertTrue(solver.isTerminateEarly());

executor.shutdown();
assertTrue(executor.awaitTermination(100, TimeUnit.MILLISECONDS));
assertNotNull(solvedSolution.get());
}

@Test
public void exceptionPropagation() {
final int partSize = 7;
@Test(timeout = 5000)
// FIXME rename, because this test interrupts the main thread in PQueue iterator!!
// TODO add another test that will interrupt one of the PartSolver threads (will require custom ThreadFactory that
// will provide access to created threads)
public void shutdownAbruptly() throws InterruptedException {
final int partSize = 5;
final int partCount = 3;

TestdataSolution solution = createSolution(partCount * partSize - 1, 100);
solution.getEntityList().add(new FaultyEntity("XYZ"));
assertEquals(partSize * partCount, solution.getEntityList().size());
TestdataSolution solution = createSolution(partCount * partSize - 1, 10);
CountDownLatch latch = new CountDownLatch(1);
solution.getEntityList().add(new BusyEntity("XYZ", latch));

SolverFactory<TestdataSolution> solverFactory = createSolverFactory();
SolverFactory<TestdataSolution> solverFactory = createSolverFactory(true);
setPartSize(solverFactory.getSolverConfig(), partSize);
DefaultSolver<TestdataSolution> solver = (DefaultSolver<TestdataSolution>) solverFactory.buildSolver();
Solver<TestdataSolution> solver = solverFactory.buildSolver();

ExecutorService executor = Executors.newSingleThreadExecutor();
Future<TestdataSolution> solvedSolution = executor.submit(() -> {
return solver.solve(solution);
});

latch.await();
// Now we know the busy entity is busy so we can attempt to shut down.
// This will initiate an abrupt shutdown that will interrupt all busy threads in the pool.
executor.shutdownNow();

// This verifies that solver checks Thread's interrupted flag and terminates solving when it detects the flag.
assertTrue("Executor must terminate successfully when it's shut down abruptly",
executor.awaitTermination(1000, TimeUnit.MILLISECONDS));
// This verifies that solver doesn't clear the interrupted flag
try {
solver.solve(solution);
fail("Test failed");
} catch (IllegalStateException e) {
Throwable t = e.getCause();
while (t != null) {
if (t instanceof PartionSolverInterruptedException) {
break;
} else {
t = t.getCause();
}
solvedSolution.get();
fail("InterruptedException should have been propagated to solver thread.");
} catch (ExecutionException ex) {
findCauseOrFail(ex, IllegalStateException.class, "Solver thread was interrupted in Partitioned Search");
findCauseOrFail(ex, InterruptedException.class);
}
}

/**
 * Test entity that keeps the calling thread busy inside {@link #setValue(TestdataValue)}.
 * <p>
 * Setting a value counts down the configured latch and then spins until the current thread's
 * interrupted flag is set. The {@code shutdownAbruptly} test awaits that latch to know the entity
 * is busy before it shuts the executor down.
 */
public static class BusyEntity extends TestdataEntity {

private static final Logger logger = LoggerFactory.getLogger(BusyEntity.class);
// Counted down once per setValue() call, before the busy wait starts.
private CountDownLatch latch;

/**
 * Creates an entity without a latch.
 */
public BusyEntity() {
// needed for cloning
}

/**
 * Creates an entity with the given code and latch.
 *
 * @param code passed to the {@code TestdataEntity} constructor
 * @param cdl latch that {@link #setValue(TestdataValue)} counts down
 */
public BusyEntity(String code, CountDownLatch cdl) {
super(code);
this.latch = cdl;
}

/**
 * @return the latch counted down by {@link #setValue(TestdataValue)}
 */
public CountDownLatch getLatch() {
return latch;
}

/**
 * @param latch the latch to count down in {@link #setValue(TestdataValue)}
 */
public void setLatch(CountDownLatch latch) {
this.latch = latch;
}

/**
 * Stores the value, counts down the latch and then busy-waits until the current thread is interrupted.
 * The interrupted flag is only read via {@code isInterrupted()}, so this method does not clear it.
 *
 * @param value the value forwarded to {@code super.setValue}
 */
@Override
public void setValue(TestdataValue value) {
super.setValue(value);
// Signal that this entity has started its busy wait.
latch.countDown();
logger.info("SETVALUE... STARTED");
while (!Thread.currentThread().isInterrupted()) {
// busy wait
}
// NOTE(review): `t` is not declared anywhere in this method; the next line looks like residue from the
// removed side of the diff rather than part of this class - confirm against the actual file.
assertNotNull(t);
logger.info("SETVALUE... INTERRUPTED!");
}
}

Expand All @@ -186,11 +271,8 @@ public FaultyEntity(String code) {
public void setValue(TestdataValue value) {
super.setValue(value);
logger.info("SOLVER FAULT");
throw new PartionSolverInterruptedException();
int zero = 0;
logger.info("{}", 1 / zero);
}
}

public static class PartionSolverInterruptedException extends RuntimeException {

}
}
@@ -0,0 +1,57 @@
/*
* Copyright 2017 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.optaplanner.core.impl.partitionedsearch;

import java.util.ArrayList;
import java.util.List;

import org.optaplanner.core.impl.partitionedsearch.partitioner.SolutionPartitioner;
import org.optaplanner.core.impl.score.director.ScoreDirector;
import org.optaplanner.core.impl.testdata.domain.TestdataEntity;
import org.optaplanner.core.impl.testdata.domain.TestdataSolution;

/**
 * Test {@link SolutionPartitioner} that cuts the entity list of a {@link TestdataSolution} into
 * consecutive partitions of {@code partSize} entities each.
 */
public class TestdataSolutionPartitioner implements SolutionPartitioner<TestdataSolution> {

    /**
     * Number of entities per partition. Configurable as a
     * {@link org.optaplanner.core.config.partitionedsearch.PartitionedSearchPhaseConfig#solutionPartitionerCustomProperties
     * custom property}.
     */
    private int partSize = 1;

    /**
     * @param partSize number of entities each partition should contain
     */
    public void setPartSize(int partSize) {
        this.partSize = partSize;
    }

    /**
     * Splits the working solution into partitions of {@code partSize} consecutive entities.
     * Every partition gets its own copy of its slice of the entity list and shares the value list
     * of the working solution.
     *
     * @param scoreDirector provides the working solution to split
     * @param runnablePartThreadLimit not used by this partitioner
     * @return the partitions, in entity list order
     * @throws IllegalStateException if the entity count is not a multiple of {@code partSize}
     */
    @Override
    public List<TestdataSolution> splitWorkingSolution(ScoreDirector<TestdataSolution> scoreDirector,
            Integer runnablePartThreadLimit) {
        TestdataSolution wholeSolution = scoreDirector.getWorkingSolution();
        List<TestdataEntity> entities = wholeSolution.getEntityList();
        if (entities.size() % partSize != 0) {
            throw new IllegalStateException("This partitioner can only make equally sized partitions.");
        }

        int partCount = entities.size() / partSize;
        List<TestdataSolution> result = new ArrayList<>();
        for (int partIndex = 0; partIndex < partCount; partIndex++) {
            int fromIndex = partIndex * partSize;
            result.add(buildPartition(wholeSolution, entities.subList(fromIndex, fromIndex + partSize)));
        }
        return result;
    }

    /**
     * Builds one partition holding a copy of the given entity slice and the value list of {@code wholeSolution}.
     */
    private static TestdataSolution buildPartition(TestdataSolution wholeSolution, List<TestdataEntity> slice) {
        // Copy the slice so the partition does not hold a view onto the original entity list.
        List<TestdataEntity> partEntities = new ArrayList<>(slice);
        TestdataSolution part = new TestdataSolution();
        part.setEntityList(partEntities);
        part.setValueList(wholeSolution.getValueList());
        return part;
    }

}

0 comments on commit d725e8b

Please sign in to comment.