Skip to content

Commit

Permalink
Minor fixes to improve Apple Silicon MPS support (#2873)
Browse files Browse the repository at this point in the history
* Support 32-bit toTensor()

This fixes an issue where calling NDArrayEx.toTensor() fails on Apple Silicon due to a lack of support for float64.

* Avoid float64 conversion in Classifications constructor

Don't convert probabilities from float32 to float64, because this causes a failure on Apple Silicon.

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
  • Loading branch information
petebankhead and frankfliu committed Nov 28, 2023
1 parent ce441c1 commit 36d4aec
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 14 deletions.
16 changes: 12 additions & 4 deletions api/src/main/java/ai/djl/modality/Classifications.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,18 @@ public Classifications(List<String> classNames, NDArray probabilities) {
*/
public Classifications(List<String> classNames, NDArray probabilities, int topK) {
this.classNames = classNames;
NDArray array = probabilities.toType(DataType.FLOAT64, false);
this.probabilities =
Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList());
array.close();
if (probabilities.getDataType() == DataType.FLOAT32) {
// Avoid converting float32 to float64 as this is not supported on MPS device
this.probabilities = new ArrayList<>();
for (float prob : probabilities.toFloatArray()) {
this.probabilities.add((double) prob);
}
} else {
NDArray array = probabilities.toType(DataType.FLOAT64, false);
this.probabilities =
Arrays.stream(array.toDoubleArray()).boxed().collect(Collectors.toList());
array.close();
}
this.topK = topK;
}

Expand Down
7 changes: 6 additions & 1 deletion api/src/main/java/ai/djl/ndarray/internal/NDArrayEx.java
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,12 @@ default NDArray toTensor() {
if (dim == 3) {
result = result.expandDims(0);
}
result = result.div(255.0).transpose(0, 3, 1, 2);
// For Apple Silicon MPS it is important not to switch to 64-bit float here
if (result.getDataType() == DataType.FLOAT32) {
result = result.div(255.0f).transpose(0, 3, 1, 2);
} else {
result = result.div(255.0).transpose(0, 3, 1, 2);
}
if (dim == 3) {
result = result.squeeze(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Map;

/**
Expand Down Expand Up @@ -118,8 +119,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
byte[] buf = Utils.toByteArray(is);
try (NDArray array =
manager.create(
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1), DataType.UINT8)) {
array.set(buf);
ByteBuffer.wrap(buf),
new Shape(length, IMAGE_WIDTH, IMAGE_HEIGHT, 1),
DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand All @@ -132,8 +134,8 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
array.set(buf);
try (NDArray array =
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Map;

/**
Expand Down Expand Up @@ -111,8 +112,9 @@ private NDArray readData(Artifact.Item item, long length) throws IOException {
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(new Shape(length, 28, 28, 1), DataType.UINT8)) {
array.set(buf);
try (NDArray array =
manager.create(
ByteBuffer.wrap(buf), new Shape(length, 28, 28, 1), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand All @@ -123,10 +125,9 @@ private NDArray readLabel(Artifact.Item item) throws IOException {
if (is.skip(8) != 8) {
throw new AssertionError("Failed skip data.");
}

byte[] buf = Utils.toByteArray(is);
try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
array.set(buf);
try (NDArray array =
manager.create(ByteBuffer.wrap(buf), new Shape(buf.length), DataType.UINT8)) {
return array.toType(DataType.FLOAT32, false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.pytorch.integration;

import ai.djl.Device;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
Expand All @@ -21,6 +22,10 @@
import org.testng.SkipException;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class MpsTest {

@Test
Expand All @@ -36,4 +41,39 @@ public void testMps() {
Assert.assertEquals(array.getDevice().getDeviceType(), "mps");
}
}

private static boolean checkMpsCompatible() {
return "aarch64".equals(System.getProperty("os.arch"))
&& System.getProperty("os.name").startsWith("Mac");
}

@Test
public void testToTensorMPS() {
if (!checkMpsCompatible()) {
throw new SkipException("MPS toTensor test requires Apple Silicon macOS.");
}

// Test that toTensor does not fail on MPS (e.g. due to use of float64 for division)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
NDArray array = manager.create(127f).reshape(1, 1, 1, 1);
NDArray tensor = array.getNDArrayInternal().toTensor();
Assert.assertEquals(tensor.toFloatArray(), new float[] {127f / 255f});
}
}

@Test
public void testClassificationsMPS() {
if (!checkMpsCompatible()) {
throw new SkipException("MPS classification test requires Apple Silicon macOS.");
}

// Test that classifications do not fail on MPS (e.g. due to conversion of probabilities to
// float64)
try (NDManager manager = NDManager.newBaseManager(Device.fromName("mps"))) {
List<String> names = Arrays.asList("First", "Second", "Third", "Fourth", "Fifth");
NDArray tensor = manager.create(new float[] {0f, 0.125f, 1f, 0.5f, 0.25f});
Classifications classifications = new Classifications(names, tensor);
Assert.assertEquals(classifications.topK(1), Collections.singletonList("Third"));
}
}
}

0 comments on commit 36d4aec

Please sign in to comment.