diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
index 2d16b5949..01d63677a 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -50,7 +50,6 @@ public class PyEnv {
     private int predictTimeout;
     private int modelLoadingTimeout;
     private int tensorParallelDegree;
-    private int mpiProcesses;
     private Map<String, String> envs;
     private Map<String, String> initParameters;
     private boolean initialized;
@@ -219,6 +218,12 @@ public void setPythonExecutable(String pythonExecutable) {
      * @return the tensor parallel degree
      */
     public int getTensorParallelDegree() {
+        if (tensorParallelDegree == 0) {
+            String value = Utils.getenv("TENSOR_PARALLEL_DEGREE");
+            if (value != null) {
+                tensorParallelDegree = Integer.parseInt(value);
+            }
+        }
         return tensorParallelDegree;
     }
 
@@ -229,12 +234,11 @@ public int getTensorParallelDegree() {
      */
     public void setTensorParallelDegree(int tensorParallelDegree) {
         this.tensorParallelDegree = tensorParallelDegree;
-        int gpuCount = CudaUtils.getGpuCount();
-        mpiProcesses = gpuCount / tensorParallelDegree;
     }
 
     int getMpiWorkers() {
-        return mpiProcesses;
+        int gpuCount = CudaUtils.getGpuCount();
+        return gpuCount / getTensorParallelDegree();
     }
 
     /**
diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
index 82393b408..c839119a5 100644
--- a/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
+++ b/engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -74,9 +74,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
         }
         String entryPoint = null;
         if (options != null) {
+            logger.debug("options in serving.properties for model: {}", modelName);
             for (Map.Entry<String, ?> entry : options.entrySet()) {
                 String key = entry.getKey();
                 String value = (String) entry.getValue();
+                logger.debug("{}={}", key, value);
                 switch (key) {
                     case "pythonExecutable":
                         pyEnv.setPythonExecutable(value);
@@ -112,6 +114,9 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
                     case "parallel_loading":
                         parallelLoading = Boolean.parseBoolean(value);
                         break;
+                    case "tensor_parallel_degree":
+                        pyEnv.setTensorParallelDegree(Integer.parseInt(value));
+                        break;
                     case "handler":
                         pyEnv.setHandler(value);
                         break;
@@ -134,10 +139,8 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
         }
         pyEnv.setEntryPoint(entryPoint);
         if (pyEnv.isMpiMode()) {
-            int partitions;
-            if (System.getenv("TENSOR_PARALLEL_DEGREE") != null) {
-                partitions = Integer.parseInt(System.getenv("TENSOR_PARALLEL_DEGREE"));
-            } else {
+            int partitions = pyEnv.getTensorParallelDegree();
+            if (partitions == 0) {
                 // TODO: avoid use hardcoded "partitioned_model_" name
                 try (Stream<Path> stream = Files.list(modelPath)) {
                     partitions =
@@ -154,10 +157,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
                     throw new FileNotFoundException(
                             "partitioned_model_ file not found in: " + modelPath);
                 }
+                pyEnv.setTensorParallelDegree(partitions);
             }
             logger.info("Loading model in MPI model with TP: {}.", partitions);
-            pyEnv.setTensorParallelDegree(partitions);
 
             int mpiWorkers = pyEnv.getMpiWorkers();
             if (mpiWorkers <= 0) {
                 throw new EngineException(
@@ -235,6 +238,7 @@ private Path findModelFile(String prefix) {
     }
 
     private void createAllPyProcesses(int mpiWorkers) {
+        long begin = System.currentTimeMillis();
         ExecutorService pool = null;
         List<Future<?>> futures = new ArrayList<>();
         if (parallelLoading) {
@@ -245,6 +249,7 @@ private void createAllPyProcesses(int mpiWorkers) {
             PyProcess worker = new PyProcess(this, pyEnv, i);
             workerQueue.offer(worker);
             if (pool != null) {
+                logger.debug("Submitting to pool: {}", i);
                 futures.add(pool.submit(worker::startPythonProcess));
             } else {
                 worker.startPythonProcess();
@@ -264,6 +269,8 @@ private void createAllPyProcesses(int mpiWorkers) {
                 }
             }
         }
+        long duration = System.currentTimeMillis() - begin;
+        logger.info("{} model loaded in {} ms.", modelName, duration);
     }
 
     private void shutdown() {