From 7f94c1ad2be815549cd3d6168ac7e3c2e17558db Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 4 Dec 2023 20:25:54 -0800 Subject: [PATCH] [api] Use folk java process to avoid jvm consume GPU memory (#2882) --- .../main/java/ai/djl/util/cuda/CudaUtils.java | 108 +++++++++++++++++- .../java/ai/djl/util/SecurityManagerTest.java | 9 +- .../java/ai/djl/util/cuda/CudaUtilsTest.java | 21 +++- 3 files changed, 126 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java index b0b8e3e4247..edcbf40eef3 100644 --- a/api/src/main/java/ai/djl/util/cuda/CudaUtils.java +++ b/api/src/main/java/ai/djl/util/cuda/CudaUtils.java @@ -22,7 +22,11 @@ import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; +import java.io.InputStream; import java.lang.management.MemoryUsage; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; import java.util.regex.Pattern; @@ -33,6 +37,8 @@ public final class CudaUtils { private static final CudaLibrary LIB = loadLibrary(); + private static String[] gpuInfo; + private CudaUtils() {} /** @@ -49,7 +55,15 @@ public static boolean hasCuda() { * * @return the number of GPUs available in the system */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getGpuCount() { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); // NOPMD + } + return Integer.parseInt(gpuInfo[0]); + } + if (LIB == null) { return 0; } @@ -79,7 +93,19 @@ public static int getGpuCount() { * * @return the version of CUDA runtime */ + @SuppressWarnings("PMD.NonThreadSafeSingleton") public static int getCudaVersion() { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + if (gpuInfo == null) { + gpuInfo = execute(-1); + } + int version = Integer.parseInt(gpuInfo[1]); + if (version == -1) { + throw new IllegalArgumentException("No cuda device found."); + } + return version; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -95,9 +121,6 @@ public static int getCudaVersion() { * @return the version string of CUDA runtime */ public static String getCudaVersionString() { - if (LIB == null) { - throw new IllegalStateException("No cuda library is loaded."); - } int version = getCudaVersion(); int major = version / 1000; int minor = (version / 10) % 10; @@ -111,6 +134,14 @@ public static String getCudaVersionString() { * @return the CUDA compute capability */ public static String getComputeCapability(int device) { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + String[] ret = execute(device); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + return ret[0]; + } + if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); } @@ -137,6 +168,16 @@ public static MemoryUsage getGpuMemory(Device device) { throw new IllegalArgumentException("Only GPU device is allowed."); } + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + String[] ret = execute(device.getDeviceId()); + if (ret.length != 3) { + throw new IllegalArgumentException(ret[0]); + } + long total = Long.parseLong(ret[1]); + long used = Long.parseLong(ret[2]); + return new MemoryUsage(-1, used, used, total); + } + if (LIB == null) { throw new IllegalStateException("No GPU device detected."); } @@ -155,8 +196,42 @@ public static MemoryUsage getGpuMemory(Device device) { return new MemoryUsage(-1, committed, committed, total[0]); } + /** + * The main entrypoint to get CUDA information with command line. + * + * @param args the command line arguments. + */ + @SuppressWarnings("PMD.SystemPrintln") + public static void main(String[] args) { + int gpuCount = getGpuCount(); + if (args.length == 0) { + if (gpuCount <= 0) { + System.out.println("0,-1"); + return; + } + int cudaVersion = getCudaVersion(); + System.out.println(gpuCount + "," + cudaVersion); + return; + } + try { + int deviceId = Integer.parseInt(args[0]); + if (deviceId < 0 || deviceId >= gpuCount) { + System.out.println("Invalid device: " + deviceId); + return; + } + MemoryUsage mem = getGpuMemory(Device.gpu(deviceId)); + String cc = getComputeCapability(deviceId); + System.out.println(cc + ',' + mem.getMax() + ',' + mem.getUsed()); + } catch (NumberFormatException e) { + System.out.println("Invalid device: " + args[0]); + } + } + private static CudaLibrary loadLibrary() { try { + if (Boolean.getBoolean("ai.djl.util.cuda.folk")) { + return null; + } if (System.getProperty("os.name").startsWith("Win")) { String path = Utils.getenv("PATH"); if (path == null) { @@ -199,6 +274,33 @@ private static CudaLibrary loadLibrary() { } } + private static String[] execute(int deviceId) { + try { + String javaHome = System.getProperty("java.home"); + String classPath = System.getProperty("java.class.path"); + String os = System.getProperty("os.name"); + List cmd = new ArrayList<>(4); + if (os.startsWith("Win")) { + cmd.add(javaHome + "\\bin\\java.exe"); + } else { + cmd.add(javaHome + "/bin/java"); + } + cmd.add("-cp"); + cmd.add(classPath); + cmd.add("ai.djl.util.cuda.CudaUtils"); + if (deviceId >= 0) { + cmd.add(String.valueOf(deviceId)); + } + Process ps = new ProcessBuilder(cmd).redirectErrorStream(true).start(); + try (InputStream is = ps.getInputStream()) { + String line = Utils.toString(is).trim(); + return line.split(","); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed get GPU information", e); + } + } + private static void checkCall(int ret) { if (LIB == null) { throw new IllegalStateException("No cuda library is loaded."); diff --git a/api/src/test/java/ai/djl/util/SecurityManagerTest.java b/api/src/test/java/ai/djl/util/SecurityManagerTest.java index fd9b5db72bc..1e9eb17f63c 100644 --- a/api/src/test/java/ai/djl/util/SecurityManagerTest.java +++ b/api/src/test/java/ai/djl/util/SecurityManagerTest.java @@ -74,8 +74,11 @@ public void checkPermission(Permission perm) { } }; System.setSecurityManager(sm); - - Assert.assertFalse(CudaUtils.hasCuda()); - Assert.assertEquals(CudaUtils.getGpuCount(), 0); + try { + Assert.assertFalse(CudaUtils.hasCuda()); + Assert.assertEquals(CudaUtils.getGpuCount(), 0); + } finally { + System.setSecurityManager(null); + } } } diff --git a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java index de1c5cb4a20..a598d8482e6 100644 --- a/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java +++ b/api/src/test/java/ai/djl/util/cuda/CudaUtilsTest.java @@ -20,8 +20,6 @@ import org.testng.annotations.Test; import java.lang.management.MemoryUsage; -import java.util.Arrays; -import java.util.List; public class CudaUtilsTest { @@ -30,6 +28,9 @@ public class CudaUtilsTest { @Test public void testCudaUtils() { if (!CudaUtils.hasCuda()) { + Assert.assertThrows(CudaUtils::getCudaVersionString); + Assert.assertThrows(() -> CudaUtils.getComputeCapability(0)); + Assert.assertThrows(() -> CudaUtils.getGpuMemory(Device.gpu())); return; } // Possible to have CUDA and not have a GPU. @@ -37,16 +38,24 @@ public void testCudaUtils() { return; } - int cudaVersion = CudaUtils.getCudaVersion(); + String cudaVersion = CudaUtils.getCudaVersionString(); String smVersion = CudaUtils.getComputeCapability(0); MemoryUsage memoryUsage = CudaUtils.getGpuMemory(Device.gpu()); logger.info("CUDA runtime version: {}, sm: {}", cudaVersion, smVersion); logger.info("Memory usage: {}", memoryUsage); - Assert.assertTrue(cudaVersion >= 9020, "cuda 9.2+ required."); + Assert.assertNotNull(cudaVersion); + Assert.assertNotNull(smVersion); + } - List supportedSm = Arrays.asList("37", "52", "60", "61", "70", "75"); - Assert.assertTrue(supportedSm.contains(smVersion), "Unsupported cuda sm: " + smVersion); + @Test + public void testCudaUtilsWithFolk() { + System.setProperty("ai.djl.util.cuda.folk", "true"); + try { + testCudaUtils(); + } finally { + System.clearProperty("ai.djl.util.cuda.folk"); + } } }