diff --git a/.azure_pipeline.yml b/.azure_pipeline.yml index df5133f3..955a222a 100644 --- a/.azure_pipeline.yml +++ b/.azure_pipeline.yml @@ -59,6 +59,7 @@ jobs: VERSION_PYTHON: '*' CC_OUTER_LOOP: 'clang-8' CC_INNER_LOOP: 'gcc' + MKL_THREADING_LAYER: 'INTEL' # Linux + Python 3.7 with numpy / scipy installed with pip from PyPI and # heterogeneous openmp runtimes. py37_pip_openblas_gcc_clang: diff --git a/CHANGES.md b/CHANGES.md index 4850acd5..35fe4605 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,11 @@ +1.2.0 (TBD) +=========== + +- Expose MKL threading layer in informations displayed by `threadpool_info`. + This information is referenced in the `threading_layer` field. + https://github.com/joblib/threadpoolctl/pull/48 + + 1.1.0 (2019-09-12) ================== diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index 201c5ab4..75d88895 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -1,3 +1,4 @@ +import os import re import ctypes import pytest @@ -302,3 +303,19 @@ def test_get_original_num_threads(limit): expected = min( [module['num_threads'] for module in original_infos]) assert original_num_threads['blas'] == expected + + +def test_mkl_threading_layer(): + # Check that threadpool_info correctly recovers the threading layer used + # by mkl + mkl_info = [module for module in threadpool_info() + if module['internal_api'] == 'mkl'] + + if not mkl_info: + pytest.skip("requires MKL") + + expected_layer = os.getenv("MKL_THREADING_LAYER") + actual_layer = mkl_info[0]['threading_layer'] + + if expected_layer: + assert actual_layer == expected_layer.lower() diff --git a/threadpoolctl.py b/threadpoolctl.py index 2c6dc515..b4004519 100644 --- a/threadpoolctl.py +++ b/threadpoolctl.py @@ -243,6 +243,9 @@ def threadpool_info(): # we map it to 1 for consistency with other libraries. if module['num_threads'] == -1 and module['internal_api'] == 'blis': module['num_threads'] = 1 + if module['internal_api'] == 'mkl': + layer = _get_mkl_threading_layer(module['dynlib']) + module['threading_layer'] = layer # Remove the wrapper for the module and its function del module['set_num_threads'], module['get_num_threads'] del module['dynlib'] @@ -251,6 +254,17 @@ def threadpool_info(): return infos +def _get_mkl_threading_layer(mkl_dynlib): + """Return the threading layer of MKL""" + # The function mkl_set_threading_layer returns the current threading layer + # Calling it with an invalid threading layer allows us to safely get + # the threading layer + set_threading_layer = getattr(mkl_dynlib, "MKL_Set_Threading_Layer") + layer_map = {0: "intel", 1: "sequential", 2: "pgi", + 3: "gnu", 4: "tbb", -1: "not specified"} + return layer_map[set_threading_layer(-1)] + + def _get_version(dynlib, internal_api): if internal_api == "mkl": return _get_mkl_version(dynlib)