diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java index eb793967b..3142ecdae 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java @@ -261,6 +261,14 @@ public void setTensorParallelDegree(int tensorParallelDegree) { int getMpiWorkers() { int gpuCount = CudaUtils.getGpuCount(); + String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES"); + if (gpuCount > 0 && visibleDevices != null) { + int visibleCount = visibleDevices.split(",").length; + if (visibleCount > gpuCount || visibleCount < 1) { + throw new AssertionError("Invalid CUDA_VISIBLE_DEVICES: " + visibleDevices); + } + gpuCount = visibleCount; + } return gpuCount / getTensorParallelDegree(); }