diff --git a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java index 726ecf75769..ad20f8d99e3 100644 --- a/api/src/main/java/ai/djl/ndarray/BaseNDManager.java +++ b/api/src/main/java/ai/djl/ndarray/BaseNDManager.java @@ -14,6 +14,7 @@ import ai.djl.Device; import ai.djl.engine.Engine; +import ai.djl.engine.StandardCapabilities; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.util.PairList; @@ -62,8 +63,13 @@ protected BaseNDManager(NDManager parent, Device device) { uid = UUID.randomUUID().toString(); Engine engine = getEngine().getAlternativeEngine(); if (engine != null) { - // Use the default device - alternativeManager = engine.newBaseManager(); + // Use the same device if possible for efficiency + if (this.device.isGpu() && engine.hasCapability(StandardCapabilities.CUDA)) { + alternativeManager = engine.newBaseManager(this.device); + } else { + // Use the default device + alternativeManager = engine.newBaseManager(); + } } }