From 19c3702f02debf0737739329ab0fd2003d416686 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 2 May 2024 20:17:23 -0700 Subject: [PATCH] [pytorch] Updates PyTorch to 2.3.0 --- .github/workflows/nightly_publish.yml | 2 ++ engines/pytorch/pytorch-engine/build.gradle | 1 + .../pytorch/integration/ALibUtilsTest.java | 3 +++ .../djl/pytorch/integration/IValueTest.java | 5 ++++ .../djl/pytorch/integration/MkldnnTest.java | 2 +- .../ai/djl/pytorch/integration/MpsTest.java | 26 +++++-------------- .../djl/pytorch/integration/ProfilerTest.java | 3 +++ .../djl/pytorch/integration/PtModelTest.java | 3 +++ .../pytorch/integration/PtNDArrayTest.java | 5 ++++ .../pytorch/integration/TorchScriptTest.java | 7 +++++ .../ai/djl/pytorch/jni/IValueUtilsTest.java | 9 +++++++ .../java/ai/djl/pytorch/jni/JniUtilsTest.java | 4 +++ engines/pytorch/pytorch-jni/build.gradle | 6 ++++- .../pytorch/pytorch-model-zoo/build.gradle | 1 + .../nlp/textgeneration/GptTranslatorTest.java | 3 +++ engines/pytorch/pytorch-native/build.gradle | 8 ------ .../ai/djl/integration/IntegrationTests.java | 21 +++++++++------ .../java/ai/djl/testing/TestRequirements.java | 16 ++++++++++++ 18 files changed, 88 insertions(+), 37 deletions(-) diff --git a/.github/workflows/nightly_publish.yml b/.github/workflows/nightly_publish.yml index 6d303bbf767..1de37a4854f 100644 --- a/.github/workflows/nightly_publish.yml +++ b/.github/workflows/nightly_publish.yml @@ -191,6 +191,7 @@ jobs: ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -Psnapshot ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.2 -Psnapshot ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.2.2 -Psnapshot + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.3.0 -Psnapshot ./gradlew clean engines:ml:xgboost:publish -Pgpu -Psnapshot ./gradlew clean publish -Psnapshot cd bom @@ -206,6 +207,7 @@ jobs: ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=1.13.1 -P${{ github.event.inputs.mode }} ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.1.2 -P${{ github.event.inputs.mode }} ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.2.2 -P${{ github.event.inputs.mode }} + ./gradlew clean engines:pytorch:pytorch-jni:publish -Ppt_version=2.3.0 -P${{ github.event.inputs.mode }} ./gradlew clean engines:ml:xgboost:publish -Pgpu -P${{ github.event.inputs.mode }} ./gradlew clean publish -P${{ github.event.inputs.mode }} cd bom diff --git a/engines/pytorch/pytorch-engine/build.gradle b/engines/pytorch/pytorch-engine/build.gradle index 73107ab4572..85773862bc5 100644 --- a/engines/pytorch/pytorch-engine/build.gradle +++ b/engines/pytorch/pytorch-engine/build.gradle @@ -7,6 +7,7 @@ dependencies { exclude group: "junit", module: "junit" } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" + testImplementation(project(":testing")) testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo") testRuntimeOnly project(":engines:pytorch:pytorch-jni") } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java index 569c83e9945..89f3ab19d99 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ALibUtilsTest.java @@ -13,6 +13,7 @@ package ai.djl.pytorch.integration; import ai.djl.engine.Engine; +import ai.djl.testing.TestRequirements; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -23,6 +24,8 @@ public class ALibUtilsTest { @BeforeClass public void setup() { + TestRequirements.notMacX86(); + System.setProperty("ai.djl.pytorch.native_helper", ALibUtilsTest.class.getName()); System.setProperty("STDCXX_LIBRARY_PATH", "/usr/lib/non-exists"); } diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java index 286ac02001c..33f9a70cb1f 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java @@ -23,6 +23,7 @@ import ai.djl.pytorch.jni.IValue; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; import ai.djl.training.util.ProgressBar; import org.testng.Assert; @@ -37,6 +38,8 @@ public class IValueTest { @Test public void testIValue() { + TestRequirements.notMacX86(); + try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { PtNDArray array1 = (PtNDArray) manager.zeros(new Shape(1)); PtNDArray array2 = (PtNDArray) manager.ones(new Shape(1)); @@ -199,6 +202,8 @@ public void testIValue() { @Test public void testIValueModel() throws IOException, ModelException { + TestRequirements.notMacX86(); + Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MkldnnTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MkldnnTest.java index f855efb3e1c..f38ac446315 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MkldnnTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MkldnnTest.java @@ -25,7 +25,7 @@ /** The file is for testing PyTorch MKLDNN functionalities. */ public class MkldnnTest { - @Test + @Test(enabled = false) public void testMkldnn() { if (!"amd64".equals(System.getProperty("os.arch"))) { throw new SkipException("MKLDNN Test requires x86_64 arch."); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java index e8f6e5d405f..968fdcd4e0d 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/MpsTest.java @@ -17,9 +17,9 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.testing.TestRequirements; import org.testng.Assert; -import org.testng.SkipException; import org.testng.annotations.Test; import java.util.Arrays; @@ -27,12 +27,9 @@ public class MpsTest { - @Test + @Test(enabled = false) public void testMps() { - if (!"aarch64".equals(System.getProperty("os.arch")) - || !System.getProperty("os.name").startsWith("Mac")) { - throw new SkipException("MPS test requires M1 macOS."); - } + TestRequirements.macosM1(); Device device = Device.of("mps", -1); try (NDManager manager = NDManager.newBaseManager(device)) { @@ -41,16 +38,9 @@ public void testMps() { } } - private static boolean checkMpsCompatible() { - return "aarch64".equals(System.getProperty("os.arch")) - && System.getProperty("os.name").startsWith("Mac"); - } - - @Test + @Test(enabled = false) public void testToTensorMPS() { - if (!checkMpsCompatible()) { - throw new SkipException("MPS toTensor test requires Apple Silicon macOS."); - } + TestRequirements.macosM1(); // Test that toTensor does not fail on MPS (e.g. due to use of float64 for division) try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) { @@ -60,11 +50,9 @@ public void testToTensorMPS() { } } - @Test + @Test(enabled = false) public void testClassificationsMPS() { - if (!checkMpsCompatible()) { - throw new SkipException("MPS classification test requires Apple Silicon macOS."); - } + TestRequirements.macosM1(); // Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to // float64) diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java index 6c60bf15e8c..80d67c88cc2 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/ProfilerTest.java @@ -28,6 +28,7 @@ import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; @@ -47,6 +48,8 @@ public void testProfiler() ModelNotFoundException, IOException, TranslateException { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { ImageClassificationTranslator translator = ImageClassificationTranslator.builder().addTransform(new ToTensor()).build(); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java index 34868730c06..09713513b68 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtModelTest.java @@ -21,6 +21,7 @@ import ai.djl.ndarray.types.Shape; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; import ai.djl.training.util.ProgressBar; import ai.djl.translate.NoopTranslator; import ai.djl.translate.TranslateException; @@ -36,6 +37,8 @@ public class PtModelTest { @Test public void testLoadFromStream() throws IOException, TranslateException, ModelException { + TestRequirements.notMacX86(); + Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java index 83b01f7a1e8..9e15a736c69 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/PtNDArrayTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; +import ai.djl.testing.TestRequirements; import org.testng.Assert; import org.testng.annotations.Test; @@ -27,6 +28,8 @@ public class PtNDArrayTest { @Test public void testStringTensor() { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { String[] str = {"a", "b", "c"}; NDArray arr = manager.create(str); @@ -40,6 +43,8 @@ public void testStringTensor() { @Test public void testLargeTensor() { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { NDArray array = manager.zeros(new Shape(10 * 2850, 18944), DataType.FLOAT32); Assert.assertThrows(EngineException.class, array::toByteArray); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java index 54b19c7568b..10409eb0030 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/TorchScriptTest.java @@ -24,6 +24,7 @@ import ai.djl.pytorch.jni.JniUtils; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; @@ -42,6 +43,8 @@ public class TorchScriptTest { @Test public void testDictInput() throws ModelException, IOException, TranslateException { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { Criteria criteria = Criteria.builder() @@ -76,6 +79,8 @@ public void testDictInput() throws ModelException, IOException, TranslateExcepti @Test public void testInputOutput() throws IOException, ModelException { + TestRequirements.notMacX86(); + Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) @@ -99,6 +104,8 @@ public void testInputOutput() throws IOException, ModelException { @Test public void testGetMethodNames() throws ModelException, IOException { + TestRequirements.notMacX86(); + Criteria criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/IValueUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/IValueUtilsTest.java index 273cae610fe..0cf25590b07 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/IValueUtilsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/IValueUtilsTest.java @@ -17,6 +17,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; import ai.djl.pytorch.engine.PtNDManager; +import ai.djl.testing.TestRequirements; import ai.djl.util.Pair; import org.testng.Assert; @@ -28,6 +29,8 @@ public class IValueUtilsTest { @Test public void testTuple() { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { NDArray array1 = manager.zeros(new Shape(1)); array1.setName("input1()"); @@ -50,6 +53,8 @@ public void testTuple() { @Test public void testTupleOfTuple() { + TestRequirements.notMacX86(); + try (PtNDManager manager = (PtNDManager) NDManager.newBaseManager()) { NDArray array1 = manager.zeros(new Shape(1)); array1.setName("input1(2,3)"); @@ -91,6 +96,8 @@ public void testTupleOfTuple() { @Test public void testMapOfTensor() { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { NDArray array1 = manager.zeros(new Shape(1)); array1.setName("input1.key1"); @@ -119,6 +126,8 @@ public void testMapOfTensor() { @Test public void testListOfTensor() { + TestRequirements.notMacX86(); + try (NDManager manager = NDManager.newBaseManager()) { NDArray array1 = manager.zeros(new Shape(1)); array1.setName("input1[]"); diff --git a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/JniUtilsTest.java b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/JniUtilsTest.java index 0e7a6ae9868..0ea909d4f95 100644 --- a/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/JniUtilsTest.java +++ b/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/jni/JniUtilsTest.java @@ -12,12 +12,16 @@ */ package ai.djl.pytorch.jni; +import ai.djl.testing.TestRequirements; + import org.testng.annotations.Test; public class JniUtilsTest { @Test public void testClearGpuCache() { + TestRequirements.notMacX86(); + JniUtils.emptyCudaCache(); } } diff --git a/engines/pytorch/pytorch-jni/build.gradle b/engines/pytorch/pytorch-jni/build.gradle index 4a7865288f4..593e051515f 100644 --- a/engines/pytorch/pytorch-jni/build.gradle +++ b/engines/pytorch/pytorch-jni/build.gradle @@ -25,7 +25,11 @@ processResources { "osx-aarch64/cpu/libdjl_torch.dylib", "win-x86_64/cpu/djl_torch.dll" ] - if (ptVersion.startsWith("2.1.") || ptVersion.startsWith("2.2.")) { + if (ptVersion.startsWith("2.3.")) { + files.add("linux-x86_64/cu121/libdjl_torch.so") + files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so") + files.add("win-x86_64/cu121/djl_torch.dll") + } else if (ptVersion.startsWith("2.1.") || ptVersion.startsWith("2.2.")) { files.add("linux-x86_64/cu121/libdjl_torch.so") files.add("linux-x86_64/cu121-precxx11/libdjl_torch.so") files.add("win-x86_64/cu121/djl_torch.dll") diff --git a/engines/pytorch/pytorch-model-zoo/build.gradle b/engines/pytorch/pytorch-model-zoo/build.gradle index f2a465b5070..5614b7f3535 100644 --- a/engines/pytorch/pytorch-model-zoo/build.gradle +++ b/engines/pytorch/pytorch-model-zoo/build.gradle @@ -7,6 +7,7 @@ dependencies { exclude group: "junit", module: "junit" } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" + testImplementation(project(":testing")) } tasks.register('syncS3', Exec) { diff --git a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java index 95d2614f339..6ead0d6d87a 100644 --- a/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java +++ b/engines/pytorch/pytorch-model-zoo/src/test/java/ai/djl/pytorch/zoo/nlp/textgeneration/GptTranslatorTest.java @@ -24,6 +24,7 @@ import ai.djl.nn.LambdaBlock; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ZooModel; +import ai.djl.testing.TestRequirements; import ai.djl.translate.TranslateException; import org.testng.Assert; @@ -38,6 +39,8 @@ public class GptTranslatorTest { @Test public void testGpt2() throws TranslateException, ModelException, IOException { + TestRequirements.notMacX86(); + // This is a fake model that simulates language models like GPT2: NDList(inputIds, posIds, // attnMask) -> NDList(logits(1), pastKv(12*2)[, hiddenStates(13)]) Block block = diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle index 99e9a01a03d..f424f55418f 100644 --- a/engines/pytorch/pytorch-native/build.gradle +++ b/engines/pytorch/pytorch-native/build.gradle @@ -80,7 +80,6 @@ def prepareNativeLib(String binaryRoot, String ver) { def files = [ "cpu/libtorch-cxx11-abi-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/linux-x86_64", - "cpu/libtorch-macos-x86_64-${ver}.zip" : "cpu/osx-x86_64", "cpu/libtorch-macos-arm64-${ver}.zip" : "cpu/osx-aarch64", "cpu/libtorch-win-shared-with-deps-${ver}%2Bcpu.zip" : "cpu/win-x86_64", "${cuda}/libtorch-cxx11-abi-shared-with-deps-${ver}%2B${cuda}.zip": "${cuda}/linux-x86_64", @@ -96,12 +95,6 @@ def prepareNativeLib(String binaryRoot, String ver) { copyNativeLibToOutputDir(files, binaryRoot, officialPytorchUrl) copyNativeLibToOutputDir(aarch64Files, binaryRoot, aarch64PytorchUrl) - exec { - commandLine 'install_name_tool', '-add_rpath', '@loader_path', "${binaryRoot}/cpu/osx-x86_64/native/lib/libtorch_cpu.dylib" - } - exec { - commandLine 'install_name_tool', '-add_rpath', '@loader_path', "${binaryRoot}/cpu/osx-x86_64/native/lib/libtorch.dylib" - } exec { commandLine 'install_name_tool', '-add_rpath', '@loader_path', "${binaryRoot}/cpu/osx-aarch64/native/lib/libtorch_cpu.dylib" } @@ -269,7 +262,6 @@ tasks.register('uploadS3') { def uploadDirs = [ "${BINARY_ROOT}/cpu/linux-x86_64/native/lib/", "${BINARY_ROOT}/cpu/osx-aarch64/native/lib/", - "${BINARY_ROOT}/cpu/osx-x86_64/native/lib/", "${BINARY_ROOT}/cpu/win-x86_64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-aarch64/native/lib/", "${BINARY_ROOT}/cpu-precxx11/linux-x86_64/native/lib/", diff --git a/integration/src/test/java/ai/djl/integration/IntegrationTests.java b/integration/src/test/java/ai/djl/integration/IntegrationTests.java index 7956726f54e..23267808fd7 100644 --- a/integration/src/test/java/ai/djl/integration/IntegrationTests.java +++ b/integration/src/test/java/ai/djl/integration/IntegrationTests.java @@ -20,6 +20,10 @@ import org.testng.Assert; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + public class IntegrationTests { private static final Logger logger = LoggerFactory.getLogger(IntegrationTests.class); @@ -28,23 +32,24 @@ public class IntegrationTests { public void runIntegrationTests() { String[] args = {}; - String[] engines; + List engines = new ArrayList<>(); String defaultEngine = System.getProperty("ai.djl.default_engine"); if (defaultEngine == null) { // TODO: windows CPU build is having OOM issue if 3 engines are loaded and running tests // together if (System.getProperty("os.name").startsWith("Win")) { - engines = new String[] {"MXNet"}; + engines.add("MXNet"); } else if ("aarch64".equals(System.getProperty("os.arch"))) { - engines = new String[] {"PyTorch"}; + engines.add("PyTorch"); } else { - engines = - new String[] { - "MXNet", "PyTorch", "TensorFlow", "OnnxRuntime", "XGBoost", "LightGBM" - }; + engines.addAll( + Arrays.asList("MXNet", "TensorFlow", "OnnxRuntime", "XGBoost", "LightGBM")); + if (!System.getProperty("os.name").startsWith("Mac")) { + engines.add("PyTorch"); + } } } else { - engines = new String[] {defaultEngine}; + engines.add(defaultEngine); } for (String engine : engines) { diff --git a/testing/src/main/java/ai/djl/testing/TestRequirements.java b/testing/src/main/java/ai/djl/testing/TestRequirements.java index 32f242589b9..6efd07409d0 100644 --- a/testing/src/main/java/ai/djl/testing/TestRequirements.java +++ b/testing/src/main/java/ai/djl/testing/TestRequirements.java @@ -58,6 +58,22 @@ public static void gpu() { } } + /** Requires that the test runs on macOS M1. */ + public static void macosM1() { + if (!System.getProperty("os.name").toLowerCase().startsWith("mac") + || !"aarch64".equals(System.getProperty("os.arch"))) { + throw new SkipException("This test requires a macOS M1."); + } + } + + /** Requires that the test runs on macOS M1, Linux or Windows, not mac x86_64. */ + public static void notMacX86() { + if (System.getProperty("os.name").toLowerCase().startsWith("mac") + && "x86_64".equals(System.getProperty("os.arch"))) { + throw new SkipException("macOS x86_64 is not supported."); + } + } + /** Requires that the test runs on OSX or linux, not windows. */ public static void notWindows() { if (System.getProperty("os.name").toLowerCase().startsWith("win")) {