-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #113 from lsst/tickets/DM-33622
DM-33622: Add utility for forcing thread environment variables to set value
- Loading branch information
Showing
8 changed files
with
147 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Add ``lsst.utils.threads`` for control of threads. | ||
Use `lsst.utils.threads.disable_implicit_threading()` to disable implicit threading. | ||
This function should be used in place of ``lsst.base.disableImplicitThreading()`` in all new code. | ||
This package now depends on the `threadpoolctl` package. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# This file is part of utils. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (https://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# Use of this source code is governed by a 3-clause BSD-style | ||
# license that can be found in the LICENSE file. | ||
# | ||
from __future__ import annotations | ||
|
||
"""Support for threading and multi-processing.""" | ||
|
||
__all__ = ["set_thread_envvars", "disable_implicit_threading"] | ||
|
||
import os | ||
|
||
try: | ||
from threadpoolctl import threadpool_limits | ||
except ImportError: | ||
threadpool_limits = None | ||
|
||
|
||
def set_thread_envvars(num_threads: int = 1, override: bool = False) -> None: | ||
"""Set common threading environment variables to the given value. | ||
Parameters | ||
---------- | ||
num_threads : `int`, optional | ||
Number of threads to use when setting the environment variable values. | ||
Default to 1 (disable threading). | ||
override : `bool`, optional | ||
Controls whether a previously set value should be over-ridden. Defaults | ||
to `False`. | ||
""" | ||
envvars = ( | ||
"OPENBLAS_NUM_THREADS", | ||
"GOTO_NUM_THREADS", | ||
"OMP_NUM_THREADS", | ||
"MKL_NUM_THREADS", | ||
"MKL_DOMAIN_NUM_THREADS", | ||
"MPI_NUM_THREADS", | ||
"NUMEXPR_NUM_THREADS", | ||
"NUMEXPR_MAX_THREADS", | ||
) | ||
|
||
for var in envvars: | ||
if override or var not in os.environ: | ||
os.environ[var] = str(num_threads) | ||
|
||
|
||
def disable_implicit_threading() -> None: | ||
"""Do whatever is necessary to try to prevent implicit threading. | ||
Notes | ||
----- | ||
Explicitly limits the number of threads allowed to be used by ``numexpr`` | ||
and attempts to limit the number of threads in all APIs supported by | ||
the ``threadpoolctl`` package. | ||
""" | ||
# Force one thread and force override. | ||
set_thread_envvars(1, True) | ||
|
||
try: | ||
# This must be a deferred import since importing it immediately | ||
# triggers the environment variable examination. | ||
# Catch this in case numexpr is not installed. | ||
import numexpr.utils | ||
except ImportError: | ||
pass | ||
else: | ||
numexpr.utils.set_num_threads(1) | ||
|
||
# Try to set threads for openblas and openmp | ||
if threadpool_limits is not None: | ||
threadpool_limits(limits=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ psutil >= 5.7 | |
deprecated >= 1.2 | ||
pyyaml >5.1 | ||
astropy >= 5.0 | ||
threadpoolctl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# This file is part of utils. | ||
# | ||
# Developed for the LSST Data Management System. | ||
# This product includes software developed by the LSST Project | ||
# (https://www.lsst.org). | ||
# See the COPYRIGHT file at the top-level directory of this distribution | ||
# for details of code ownership. | ||
# | ||
# Use of this source code is governed by a 3-clause BSD-style | ||
# license that can be found in the LICENSE file. | ||
# | ||
|
||
import os | ||
import unittest | ||
|
||
from lsst.utils.threads import disable_implicit_threading, set_thread_envvars | ||
|
||
try: | ||
import numexpr | ||
except ImportError: | ||
numexpr = None | ||
try: | ||
import threadpoolctl | ||
except ImportError: | ||
threadpoolctl = None | ||
|
||
|
||
class ThreadsTestCase(unittest.TestCase): | ||
"""Tests for threads.""" | ||
|
||
def testDisable(self): | ||
set_thread_envvars(2, override=True) | ||
self.assertEqual(os.environ["OMP_NUM_THREADS"], "2") | ||
set_thread_envvars(3, override=False) | ||
self.assertEqual(os.environ["OMP_NUM_THREADS"], "2") | ||
|
||
disable_implicit_threading() | ||
self.assertEqual(os.environ["OMP_NUM_THREADS"], "1") | ||
|
||
# Check that we have only one thread. | ||
if numexpr: | ||
self.assertEqual(numexpr.utils.get_num_threads(), 1) | ||
if threadpoolctl: | ||
info = threadpoolctl.threadpool_info() | ||
for api in info: | ||
self.assertEqual(api["num_threads"], 1, f"API: {api}") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |