Skip to content

Commit

Permalink
Merge pull request #7749 from stuartarchibald/fix/7458
Browse files Browse the repository at this point in the history
Refactor threading layer priority tests to not use stdout/stderr
  • Loading branch information
sklam authored and esc committed Jan 27, 2022
1 parent f09cd36 commit ac1d173
Showing 1 changed file with 34 additions and 42 deletions.
76 changes: 34 additions & 42 deletions numba/tests/test_parallel_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import random
import subprocess
import sys
import textwrap
import threading
import unittest

Expand Down Expand Up @@ -528,44 +529,47 @@ class TestThreadingLayerPriority(ThreadLayerTestHelper):

def each_env_var(self, env_var: str):
"""Test setting priority via env var NUMBA_THREADING_LAYER_PRIORITY.
:return: threading_layer_priority, stderr
(containing ``@threading_layer@``)
"""
env = os.environ.copy()
env['NUMBA_THREADING_LAYER'] = 'default'
env['NUMBA_THREADING_LAYER_PRIORITY'] = env_var

code = """import sys
import numba
# trigger threading layer decision
# hence catching invalid THREADING_LAYER_PRIORITY
@numba.jit(
'float64[::1](float64[::1], float64[::1])',
nopython=True,
parallel=True,
)
def plus(x, y):
return x + y
print(' '.join(numba.config.THREADING_LAYER_PRIORITY))
print("@%s@" % numba.threading_layer(), file=sys.stderr)
"""
code = f"""
import numba
# trigger threading layer decision
# hence catching invalid THREADING_LAYER_PRIORITY
@numba.jit(
'float64[::1](float64[::1], float64[::1])',
nopython=True,
parallel=True,
)
def plus(x, y):
return x + y
captured_envvar = list("{env_var}".split())
assert numba.config.THREADING_LAYER_PRIORITY == \
captured_envvar, "priority mismatch"
assert numba.threading_layer() == captured_envvar[0],\
"selected backend mismatch"
"""
cmd = [
sys.executable,
'-c',
code,
textwrap.dedent(code),
]
return self.run_cmd(cmd, env=env)
self.run_cmd(cmd, env=env)

@skip_no_omp
@skip_no_tbb
def test_valid_env_var(self):
default = ['tbb', 'omp', 'workqueue']
for p in itertools.permutations(default):
env_var = ' '.join(p)
threading_layer_priority, _ = self.each_env_var(env_var)
self.assertEqual(threading_layer_priority.strip(), env_var)
self.each_env_var(env_var)

@skip_no_omp
@skip_no_tbb
def test_invalid_env_var(self):
env_var = 'tbb omp workqueue notvalidhere'
with self.assertRaises(AssertionError) as raises:
Expand All @@ -579,22 +583,16 @@ def test_invalid_env_var(self):
@skip_no_omp
def test_omp(self):
for env_var in ("omp tbb workqueue", "omp workqueue tbb"):
threading_layer_priority, out = self.each_env_var(env_var)
self.assertEqual(threading_layer_priority.strip(), env_var)
self.assertIn("@omp@", out)
self.each_env_var(env_var)

@skip_no_tbb
def test_tbb(self):
for env_var in ("tbb omp workqueue", "tbb workqueue omp"):
threading_layer_priority, out = self.each_env_var(env_var)
self.assertEqual(threading_layer_priority.strip(), env_var)
self.assertIn("@tbb@", out)
self.each_env_var(env_var)

def test_workqueue(self):
for env_var in ("workqueue tbb omp", "workqueue omp tbb"):
threading_layer_priority, out = self.each_env_var(env_var)
self.assertEqual(threading_layer_priority.strip(), env_var)
self.assertIn("@workqueue@", out)
self.each_env_var(env_var)


@skip_parfors_unsupported
Expand All @@ -619,16 +617,13 @@ def foo(a, b, c, d, e, f, g, h):
x = np.ones(2**20, np.float32)
foo(*([x]*8))
print("@%s@" % threading_layer())
assert threading_layer() == "omp", "omp not found"
"""
cmdline = [sys.executable, '-c', runme]
env = os.environ.copy()
env['NUMBA_THREADING_LAYER'] = "omp"
env['OMP_STACKSIZE'] = "100K"
out, err = self.run_cmd(cmdline, env=env)
if self._DEBUG:
print(out, err)
self.assertIn("@omp@", out)
self.run_cmd(cmdline, env=env)

@skip_no_tbb
def test_single_thread_tbb(self):
Expand All @@ -647,16 +642,13 @@ def foo(n):
return acc
foo(100)
print("@%s@" % threading_layer())
assert threading_layer() == "tbb", "tbb not found"
"""
cmdline = [sys.executable, '-c', runme]
env = os.environ.copy()
env['NUMBA_THREADING_LAYER'] = "tbb"
env['NUMBA_NUM_THREADS'] = "1"
out, err = self.run_cmd(cmdline, env=env)
if self._DEBUG:
print(out, err)
self.assertIn("@tbb@", out)
self.run_cmd(cmdline, env=env)

def test_workqueue_aborts_on_nested_parallelism(self):
"""
Expand Down

0 comments on commit ac1d173

Please sign in to comment.