Skip to content

Commit

Permalink
Enhance parallel testing junit_tests
Browse files Browse the repository at this point in the history
Surfaces EXPERIMENTAL support for parallel testing of methods within a test class.
 - Adds a 'concurrency' parameter to set test concurrency to 'serial' 'parallel', or 'parallel_method'
 - Adds a 'threads' parameter to junit_tests to control concurrency.
 - Adds a --test-junit-parallel-methods to allow methods to run in parallel for the entire test run
 - Added a number of unit test cases and integration tests to exercise these features

Note that there is a bug in that the @TestSerial annotation is not respected when the -parallel-methods flag is in use.  See pantsbuild#3209

Followup work for this change is to add an `@TestParallelMethods` annotation and come up with a more rational way to
pass concurrency options to the junit-runner backend.

Testing Done:
Integration test and additional unit tests added.
CI is green at https://travis-ci.org/pantsbuild/pants/builds/123712609

Bugs closed: 3191, 3210

Reviewed at https://rbcommons.com/s/twitter/r/3707/
  • Loading branch information
ericzundel committed Apr 20, 2016
1 parent cdca712 commit 96c7ecf
Show file tree
Hide file tree
Showing 33 changed files with 941 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/java/org/pantsbuild/junit/annotations/TestSerial.java
Expand Up @@ -13,6 +13,10 @@
* Annotate that a test class must be run in serial. See usage note in
* {@code org.pantsbuild.tools.junit.impl.ConsoleRunnerImpl}. This annotation takes precedence
* over a {@link TestParallel} annotation if a class has both (including via inheritance).
* <P>
* Note that this annotation is not currently compatible with the PARALLEL_METHODS default
* concurrency setting. See
* <a href="https://github.com/pantsbuild/pants/issues/3209">issue 3209</a>
*/
@Retention(RetentionPolicy.RUNTIME)
@Inherited
Expand Down
Expand Up @@ -657,6 +657,9 @@ class Options {
usage = "Show a description of each test and timer for each test class.")
private boolean perTestTimer;

// TODO(zundel): Combine -default-parallel and -paralel-methods together into a
// single argument: -default-concurrency {serial, parallel, parallel_methods}
// TODO(zundel): Also add a @TestParallelMethods annotation
@Option(name = "-default-parallel",
usage = "Whether to run test classes without @TestParallel or @TestSerial in parallel.")
private boolean defaultParallel;
Expand Down
43 changes: 40 additions & 3 deletions src/python/pants/backend/jvm/targets/java_tests.py
Expand Up @@ -7,6 +7,7 @@

from pants.backend.jvm.subsystems.jvm_platform import JvmPlatform
from pants.backend.jvm.targets.jvm_target import JvmTarget
from pants.base.exceptions import TargetDefinitionException
from pants.base.payload import Payload
from pants.base.payload_field import PrimitiveField

Expand All @@ -17,8 +18,11 @@ class JavaTests(JvmTarget):
:API: public
"""

_VALID_CONCURRENCY_OPTS = ['serial', 'parallel', 'parallel_methods']

def __init__(self, cwd=None, test_platform=None, payload=None, timeout=None,
extra_jvm_options=None, extra_env_vars=None, **kwargs):
extra_jvm_options=None, extra_env_vars=None, concurrency=None,
threads=None, **kwargs):
"""
:param str cwd: working directory (relative to the build root) for the tests under this
target. If unspecified (None), the working directory will be controlled by junit_run's --cwd.
Expand All @@ -31,8 +35,12 @@ def __init__(self, cwd=None, test_platform=None, payload=None, timeout=None,
tests. Example: ['-Dexample.property=1'] If unspecified, no extra jvm options will be added.
:param dict extra_env_vars: A map of environment variables to set when running the tests, e.g.
{ 'FOOBAR': 12 }. Using `None` as the value will cause the variable to be unset.
:param string concurrency: One of 'serial', 'parallel', or 'parallel_methods'. Overrides
the setting of --test-junit-default-parallel or --test-junit-parallel-methods options.
:param int threads: Use the specified number of threads when running the test. Overrides
the setting of --test-junit-parallel-threads.
"""
self.cwd = cwd

payload = payload or Payload()

if extra_env_vars is None:
Expand All @@ -43,12 +51,29 @@ def __init__(self, cwd=None, test_platform=None, payload=None, timeout=None,

payload.add_fields({
'test_platform': PrimitiveField(test_platform),
# TODO(zundel): Do extra_jvm_options and extra_env_vars really need to be fingerprinted?
'extra_jvm_options': PrimitiveField(tuple(extra_jvm_options or ())),
'extra_env_vars': PrimitiveField(tuple(extra_env_vars.items())),
})
self._timeout = timeout
super(JavaTests, self).__init__(payload=payload, **kwargs)

# These parameters don't need to go into the fingerprint:
self._concurrency = concurrency
self._cwd = cwd
self._threads = None
self._timeout = timeout

try:
if threads is not None:
self._threads = int(threads)
except ValueError:
raise TargetDefinitionException(self,
"The value for 'threads' must be an integer, got " + threads)
if concurrency and concurrency not in self._VALID_CONCURRENCY_OPTS:
raise TargetDefinitionException(self,
"The value for 'concurrency' must be one of "
+ repr(self._VALID_CONCURRENCY_OPTS) + " got: " + concurrency)

# TODO(John Sirois): These could be scala, clojure, etc. 'jvm' and 'tests' are the only truly
# applicable labels - fixup the 'java' misnomer.
self.add_labels('java', 'tests')
Expand All @@ -59,6 +84,18 @@ def test_platform(self):
return JvmPlatform.global_instance().get_platform_by_name(self.payload.test_platform)
return self.platform

@property
def concurrency(self):
return self._concurrency

@property
def cwd(self):
return self._cwd

@property
def threads(self):
return self._threads

@property
def timeout(self):
return self._timeout
1 change: 1 addition & 0 deletions src/python/pants/backend/jvm/tasks/BUILD
Expand Up @@ -326,6 +326,7 @@ python_library(
'src/python/pants/build_graph',
'src/python/pants/java:util',
'src/python/pants/task',
'src/python/pants/util:argutil',
'src/python/pants/util:contextutil',
'src/python/pants/util:dirutil',
'src/python/pants/util:process_handler',
Expand Down
53 changes: 49 additions & 4 deletions src/python/pants/backend/jvm/tasks/junit_run.py
Expand Up @@ -31,6 +31,7 @@
from pants.java.distribution.distribution import DistributionLocator
from pants.java.executor import SubprocessExecutor
from pants.task.testrunner_task_mixin import TestRunnerTaskMixin
from pants.util.argutil import ensure_arg, remove_arg
from pants.util.contextutil import environment_as
from pants.util.strutil import pluralize
from pants.util.xml_parser import XmlParser
Expand Down Expand Up @@ -62,6 +63,14 @@ class JUnitRun(TestRunnerTaskMixin, JvmToolTaskMixin, JvmTask):
"""
:API: public
"""

_CONCURRENCY_PARALLEL = 'PARALLEL'
_CONCURRENCY_PARALLEL_METHODS = 'PARALLEL_METHODS'
_CONCURRENCY_SERIAL = 'SERIAL'
_CONCURRENCY_CHOICES = [
_CONCURRENCY_PARALLEL, _CONCURRENCY_PARALLEL_METHODS, _CONCURRENCY_SERIAL
]

_MAIN = 'org.pantsbuild.tools.junit.ConsoleRunner'

@classmethod
Expand All @@ -73,7 +82,12 @@ def register_options(cls, register):
help='Force running of just these tests. Tests can be specified using any of: '
'[classname], [classname]#[methodname], [filename] or [filename]#[methodname]')
register('--per-test-timer', type=bool, help='Show progress and timer for each test.')
register('--default-concurrency', advanced=True,
choices=cls._CONCURRENCY_CHOICES, default=cls._CONCURRENCY_SERIAL,
help='Set the default concurrency mode for running tests not annotated with'
+ ' @TestParallel or @TestSerial.')
register('--default-parallel', advanced=True, type=bool,
deprecated_hint='Use --concurrency instead.', deprecated_version='0.0.86',
help='Run classes without @TestParallel or @TestSerial annotations in parallel.')
register('--parallel-threads', advanced=True, type=int, default=0,
help='Number of threads to run tests in parallel. 0 for autoset.')
Expand Down Expand Up @@ -101,7 +115,7 @@ def register_options(cls, register):
cls.register_jvm_tool(register,
'junit',
classpath=[
JarDependency(org='org.pantsbuild', name='junit-runner', rev='1.0.4'),
JarDependency(org='org.pantsbuild', name='junit-runner', rev='1.0.5'),
],
main=JUnitRun._MAIN,
# TODO(John Sirois): Investigate how much less we can get away with.
Expand Down Expand Up @@ -174,11 +188,23 @@ def __init__(self, *args, **kwargs):
self._args.append('-fail-fast')
self._args.append('-outdir')
self._args.append(self.workdir)

if options.per_test_timer:
self._args.append('-per-test-timer')

# TODO(zundel): Simply remove when --default_parallel finishes deprecation
if options.default_parallel:
self._args.append('-default-parallel')

if options.default_concurrency == self._CONCURRENCY_PARALLEL_METHODS:
self._args.append('-default-parallel')
self._args.append('-parallel-methods')
elif options.default_concurrency == self._CONCURRENCY_PARALLEL:
self._args.append('-default-parallel')
elif options.default_concurrency == self._CONCURRENCY_SERIAL:
# TODO(zundel): we can't do anything here yet while the --default-parallel
# option is in deprecation mode.
pass

self._args.append('-parallel-threads')
self._args.append(str(options.parallel_threads))

Expand Down Expand Up @@ -328,13 +354,16 @@ def _run_tests(self, tests_to_targets):
lambda target: target.test_platform,
lambda target: target.payload.extra_jvm_options,
lambda target: target.payload.extra_env_vars,
lambda target: target.concurrency,
lambda target: target.threads
)

# the below will be None if not set, and we'll default back to runtime_classpath
classpath_product = self.context.products.get_data('instrument_classpath')

result = 0
for (workdir, platform, target_jvm_options, target_env_vars), tests in tests_by_properties.items():
for properties, tests in tests_by_properties.items():
(workdir, platform, target_jvm_options, target_env_vars, concurrency, threads) = properties
for batch in self._partition(tests):
# Batches of test classes will likely exist within the same targets: dedupe them.
relevant_targets = set(map(tests_to_targets.get, batch))
Expand All @@ -345,6 +374,22 @@ def _run_tests(self, tests_to_targets):
classpath_product=classpath_product))
complete_classpath.update(classpath_append)
distribution = JvmPlatform.preferred_jvm_distribution([platform], self._strict_jvm_version)

# Override cmdline args with values from junit_test() target that specify concurrency:
args = self._args + [u'-xmlreport']

# TODO(zundel): Combine these together into a single -concurrency choices style argument
if concurrency == 'serial':
args = remove_arg(args, '-default-parallel')
if concurrency == 'parallel':
args = ensure_arg(args, '-default-parallel')
if concurrency == 'parallel_methods':
args = ensure_arg(args, '-default-parallel')
args = ensure_arg(args, '-parallel-methods')
if threads is not None:
args = remove_arg(args, '-parallel-threads', has_param=True)
args += ['-parallel-threads', str(threads)]

with binary_util.safe_args(batch, self.get_options()) as batch_tests:
self.context.log.debug('CWD = {}'.format(workdir))
self.context.log.debug('platform = {}'.format(platform))
Expand All @@ -355,7 +400,7 @@ def _run_tests(self, tests_to_targets):
classpath=complete_classpath,
main=JUnitRun._MAIN,
jvm_options=self.jvm_options + extra_jvm_options + list(target_jvm_options),
args=self._args + batch_tests + [u'-xmlreport'],
args=args + batch_tests,
workunit_factory=self.context.new_workunit,
workunit_name='run',
workunit_labels=[WorkUnitLabel.TEST],
Expand Down
5 changes: 5 additions & 0 deletions src/python/pants/util/BUILD
@@ -1,6 +1,11 @@
# Copyright 2014 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

python_library(
name = 'argutil',
sources = ['argutil.py'],
)

python_library(
name = 'contextutil',
sources = ['contextutil.py'],
Expand Down
52 changes: 52 additions & 0 deletions src/python/pants/util/argutil.py
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from __future__ import (absolute_import, division, generators, nested_scopes, print_function,
unicode_literals, with_statement)


def ensure_arg(args, arg, param=None):
"""Make sure the arg is present in the list of args.
If arg is not present, adds the arg and the optional param.
If present and param != None, sets the parameter following the arg to param.
:param list args: strings representing an argument list.
:param string arg: argument to make sure is present in the list.
:param string param: parameter to add or update after arg in the list.
:return: possibly modified list of args.
"""
found = False
for idx, found_arg in enumerate(args):
if found_arg == arg:
if param is not None:
args[idx + 1] = param
return args

if not found:
args += [arg]
if param is not None:
args += [param]
return args


def remove_arg(args, arg, has_param=False):
"""Removes the first instance of the specified arg from the list of args.
If the arg is present and has_param is set, also removes the parameter that follows
the arg.
:param list args: strings representing an argument list.
:param staring arg: argument to remove from the list.
:param bool has_param: if true, also remove the parameter that follows arg in the list.
:return: possibly modified list of args.
"""
for idx, found_arg in enumerate(args):
if found_arg == arg:
if has_param:
slice_idx = idx + 2
else:
slice_idx = idx + 1
args = args[:idx] + args[slice_idx:]
break
return args
@@ -0,0 +1,43 @@
// Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).
package org.pantsbuild.testproject.parallel;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.Test;
import org.pantsbuild.junit.annotations.TestParallel;

import static org.junit.Assert.assertTrue;

/**
* This test is designed to exercise the TestParallel annotation.
* A similar test runs in tests/java/... to exercise junit-runner standalone.
* <p>
* For all methods in AnnotatedParallelTest1 and AnnotatedParallelTest2
* to succeed, both test classes must be running at the same time with the flag:
* <pre>
* --test-junit-parallel-threads 2
* <pre>
* when running with just these two classes as specs.
* <p>
* Runs in on the order of 10 milliseconds locally, but it may take longer on a CI machine to spin
* up 2 threads, so it has a generous timeout set.
*/
@TestParallel
public class AnnotatedParallelTest1 {
private static final int NUM_CONCURRENT_TESTS = 2;
private static final int RETRY_TIMEOUT_MS = 3000;
private static CountDownLatch latch = new CountDownLatch(NUM_CONCURRENT_TESTS);

@Test
public void aptest1() throws Exception {
awaitLatch("aptest1");
}

static void awaitLatch(String methodName) throws Exception {
System.out.println("start " + methodName);
latch.countDown();
assertTrue(latch.await(RETRY_TIMEOUT_MS, TimeUnit.MILLISECONDS));
System.out.println("end " + methodName);
}
}
@@ -0,0 +1,18 @@
// Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).
package org.pantsbuild.testproject.parallel;

import org.junit.Test;
import org.pantsbuild.junit.annotations.TestParallel;

/**
* See {@link AnnotatedParallelTest1}
*/
@TestParallel
public class AnnotatedParallelTest2 {

@Test
public void aptest2() throws Exception {
AnnotatedParallelTest1.awaitLatch("aptest2");
}
}

0 comments on commit 96c7ecf

Please sign in to comment.