Skip to content

Commit

Permalink
[api] Fixed NDArray.toDevice() missing name issue (#2751)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Aug 17, 2023
1 parent 21c3d8e commit 59af05f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
9 changes: 7 additions & 2 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ public NDArray toDevice(Device device, boolean copy) {
}
return this;
}
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
NDArray array = getManager().create(getShape(), getDataType(), device);
array.setName(getName());
copyTo(array);
return array;
}

/** {@inheritDoc} */
Expand All @@ -160,7 +163,9 @@ public NDArray toType(DataType dataType, boolean copy) {
}
Number[] numbers = toArray();
ByteBuffer bb = toTypeInternal(numbers, dataType);
return manager.create(bb, getShape(), dataType);
NDArray array = manager.create(bb, getShape(), dataType);
array.setName(getName());
return array;
}

private ByteBuffer toTypeInternal(Number[] numbers, DataType dataType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ public PtNDArray toDevice(Device device, boolean copy) {
if (device.equals(getDevice()) && !copy) {
return this;
}
return JniUtils.to(this, getDataType(), device);
PtNDArray array = JniUtils.to(this, getDataType(), device);
array.setName(getName());
return array;
}

/** {@inheritDoc} */
Expand All @@ -171,7 +173,9 @@ public PtNDArray toType(DataType dataType, boolean copy) {
if (dataType.equals(getDataType()) && !copy) {
return this;
}
return JniUtils.to(this, dataType, getDevice());
PtNDArray array = JniUtils.to(this, dataType, getDevice());
array.setName(array.getName());
return array;
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -366,7 +370,9 @@ public void detach() {
/** {@inheritDoc} */
@Override
public NDArray duplicate() {
return JniUtils.clone(this);
NDArray array = JniUtils.clone(this);
array.setName(getName());
return array;
}

/** {@inheritDoc} */
Expand Down

0 comments on commit 59af05f

Please sign in to comment.