Skip to content

Commit

Permalink
FEAT-#6914: Add a config for setting a number of threads per Dask wor…
Browse files Browse the repository at this point in the history
…ker (#6915)

Signed-off-by: Igoshev, Iaroslav <iaroslav.igoshev@intel.com>
  • Loading branch information
YarShev committed Feb 6, 2024
1 parent 91c8301 commit 2cdb534
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
3 changes: 3 additions & 0 deletions modin/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CIAWSAccessKeyID,
CIAWSSecretAccessKey,
CpuCount,
DaskThreadsPerWorker,
DoUseCalcite,
Engine,
EnvironmentVariable,
Expand Down Expand Up @@ -73,6 +74,8 @@
"RayRedisPassword",
"TestRayClient",
"LazyExecution",
# Dask specific
"DaskThreadsPerWorker",
# Partitioning
"NPartitions",
"MinPartitionSize",
Expand Down
7 changes: 7 additions & 0 deletions modin/config/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,13 @@ class LazyExecution(EnvironmentVariable, type=bool):
default = False


class DaskThreadsPerWorker(EnvironmentVariable, type=int):
"""Number of threads per Dask worker."""

varname = "MODIN_DASK_THREADS_PER_WORKER"
default = None


def _check_vars() -> None:
"""
Check validity of environment variables.
Expand Down
8 changes: 7 additions & 1 deletion modin/core/execution/dask/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
CIAWSAccessKeyID,
CIAWSSecretAccessKey,
CpuCount,
DaskThreadsPerWorker,
GithubCI,
Memory,
NPartitions,
Expand All @@ -44,12 +45,17 @@ def _disable_warnings():
from distributed import Client

num_cpus = CpuCount.get()
threads_per_worker = DaskThreadsPerWorker.get()
memory_limit = Memory.get()
worker_memory_limit = memory_limit // num_cpus if memory_limit else "auto"

# when the client is initialized, environment variables are inherited
with set_env(PYTHONWARNINGS="ignore::FutureWarning"):
client = Client(n_workers=num_cpus, memory_limit=worker_memory_limit)
client = Client(
n_workers=num_cpus,
threads_per_worker=threads_per_worker,
memory_limit=worker_memory_limit,
)

if GithubCI.get():
# set these keys to run tests that write to the mock s3 service. this seems
Expand Down

0 comments on commit 2cdb534

Please sign in to comment.