Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX check symbols to better identify supported libraries #151

Merged
merged 2 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
3.2.1 (TBD)
===========

- Fixed a bug where an unsupported library would be detected because it shares a common
prefix with one of the supported libraries. Now the symbols are also checked to
identify the supported libraries.
https://github.com/joblib/threadpoolctl/pull/151

3.2.0 (2023-07-13)
==================

Expand Down
1 change: 1 addition & 0 deletions continuous_integration/install_with_blis.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ popd
# build & install numpy
git clone https://github.com/numpy/numpy.git
pushd numpy
git checkout v1.26.0 # pin numpy < 2 for now
git submodule update --init
echo "[blis]
libraries = blis
Expand Down
8 changes: 8 additions & 0 deletions tests/_pyMylib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ class MyThreadedLibController(LibController):
# instance.
filename_prefixes = ("my_threaded_lib",)

# (Optional) Symbols that the linked library is expected to expose. It is used along
# with the `filename_prefixes` to make sure that the correct library is identified.
check_symbols = (
"mylib_get_num_threads",
"mylib_set_num_threads",
"mylib_get_version",
)

def get_num_threads(self):
# This function should return the current maximum number of threads,
# which is reported as "num_threads" by `ThreadpoolController.info`.
Expand Down
50 changes: 42 additions & 8 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,18 @@ class OpenBLASController(LibController):
user_api = "blas"
internal_api = "openblas"
filename_prefixes = ("libopenblas", "libblas")
check_symbols = ("openblas_get_num_threads", "openblas_get_num_threads64_")
check_symbols = (
"openblas_get_num_threads",
"openblas_get_num_threads64_",
"openblas_set_num_threads",
"openblas_set_num_threads64_",
"openblas_get_config",
"openblas_get_config64_",
"openblas_get_parallel",
"openblas_get_parallel64_",
"openblas_get_corename",
"openblas_get_corename64_",
)

def set_additional_attributes(self):
self.threading_layer = self._get_threading_layer()
Expand Down Expand Up @@ -237,7 +248,15 @@ class BLISController(LibController):
user_api = "blas"
internal_api = "blis"
filename_prefixes = ("libblis", "libblas")
check_symbols = ("bli_thread_get_num_threads",)
check_symbols = (
"bli_thread_get_num_threads",
"bli_thread_set_num_threads",
"bli_info_get_version_str",
"bli_info_get_enable_openmp",
"bli_info_get_enable_pthreads",
"bli_arch_query_id",
"bli_arch_string",
)

def set_additional_attributes(self):
self.threading_layer = self._get_threading_layer()
Expand Down Expand Up @@ -266,9 +285,9 @@ def get_version(self):

def _get_threading_layer(self):
"""Return the threading layer of BLIS"""
if self.dynlib.bli_info_get_enable_openmp():
if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)():
return "openmp"
elif self.dynlib.bli_info_get_enable_pthreads():
elif getattr(self.dynlib, "bli_info_get_enable_pthreads", lambda: False)():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why False instead of None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because I want to return a bool since it's passed to elif

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I misread the line.

return "pthreads"
return "disabled"

Expand All @@ -292,7 +311,12 @@ class MKLController(LibController):
user_api = "blas"
internal_api = "mkl"
filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
check_symbols = ("MKL_Get_Max_Threads",)
check_symbols = (
"MKL_Get_Max_Threads",
"MKL_Set_Num_Threads",
"MKL_Get_Version_String",
"MKL_Set_Threading_Layer",
)

def set_additional_attributes(self):
self.threading_layer = self._get_threading_layer()
Expand Down Expand Up @@ -343,6 +367,10 @@ class OpenMPController(LibController):
user_api = "openmp"
internal_api = "openmp"
filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
check_symbols = (
"omp_get_max_threads",
"omp_get_num_threads",
)

def get_num_threads(self):
get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
Expand Down Expand Up @@ -978,11 +1006,17 @@ def _make_controller_from_path(self, filepath):
# duplicate entry in threadpool_info.
continue

# filename matches a prefix. Create and store the library
# filename matches a prefix. Now we check if the library has the symbols we
# are looking for. If none of the symbols exists, it's very likely not the
# expected library (e.g. a library having a common prefix with one of the
# our supported libraries). Otherwise, create and store the library
# controller.

lib_controller = controller_class(filepath=filepath, prefix=prefix)
self.lib_controllers.append(lib_controller)
if not hasattr(controller_class, "check_symbols") or any(
hasattr(lib_controller.dynlib, func)
for func in controller_class.check_symbols
):
self.lib_controllers.append(lib_controller)

def _check_prefix(self, library_basename, filename_prefixes):
"""Return the prefix library_basename starts with
Expand Down