Permalink
Browse files

get rid of jcpp shapeInfo calls (#5693)

  • Loading branch information...
raver119 committed Jun 23, 2018
1 parent ebec55f commit 70dc867fe8e1cc21ad47473d5c99517537286728
@@ -111,7 +111,8 @@
protected transient boolean compressed = false;
// this field holds jvm copy of shapeInfo
protected long[] javaShapeInformation;
protected transient JvmShapeInfo jvmShapeInfo;
//Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over
@@ -1063,7 +1064,7 @@ public INDArray javaTensorAlongDimension(int index, int... dimension) {
private void setShapeInformation(Pair<DataBuffer, long[]> shapeInfo) {
this.shapeInformation = shapeInfo.getFirst();
this.javaShapeInformation = shapeInfo.getSecond();
this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond());
}
@@ -1181,10 +1182,10 @@ public long vectorsAlongDimension(int dimension) {
long length = length();
if (dimension >= Shape.rank(javaShapeInformation)) {
if (length / size(Shape.rank(javaShapeInformation) - 1) >= Integer.MAX_VALUE)
if (dimension >= jvmShapeInfo.rank) {
if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
return (int) (length / size(Shape.rank(javaShapeInformation) - 1));
return (int) (length / size(jvmShapeInfo.rank - 1));
}
if (length / size(dimension) >= Integer.MAX_VALUE)
throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE");
@@ -1201,10 +1202,10 @@ public long vectorsAlongDimension(int dimension) {
@Override
public INDArray vectorAlongDimension(int index, int dimension) {
if (dimension < 0)
dimension = Shape.rank(javaShapeInformation) + dimension;
dimension = jvmShapeInfo.getRank() + dimension;
//return the whole thing
if (dimension == Shape.rank(javaShapeInformation) - 1 && size(dimension) == 1 && rank() > 2
if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2
|| rank() > 2 && dimension == 0 && size(dimension) == 1) {
return this;
}
@@ -1516,7 +1517,7 @@ public INDArray putScalar(int[] indexes, double value) {
return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value);
} else {
autoProcessScalarCall();
long offset = Shape.getOffset(javaShapeInformation, indexes);
long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
data.put(offset, value);
}
return this;
@@ -1542,7 +1543,7 @@ public INDArray putScalar(long[] indexes, double value) {
return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value);
} else {
autoProcessScalarCall();
long offset = Shape.getOffset(javaShapeInformation, indexes);
long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
data.put(offset, value);
}
return this;
@@ -1560,7 +1561,7 @@ public INDArray putScalar(long row, long col, double value) {
if (rank() > 2)
throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray");
long offset = Shape.getOffsetUnsafe(javaShapeInformation, row, col);
long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col);
data.put(offset, value);
return this;
}
@@ -1574,16 +1575,16 @@ public INDArray putScalar(long dim0, long dim1, long dim2, double value) {
throw new IllegalStateException(
"Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray");
long offset = 0; // Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2);
long size_0 = javaShapeInformation[1];
long size_1 = javaShapeInformation[1 + 1];
long size_2 = javaShapeInformation[1 + 2];
long size_0 = jvmShapeInfo.javaShapeInformation[1];
long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1];
long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2];
if (size_0 != 1)
offset += dim0 * javaShapeInformation[1 + 0 + 3];
offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 0 + 3];
if (size_1 != 1)
offset += dim1 * javaShapeInformation[1 + 1 + 3];
offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3];
if (size_2 != 1)
offset += dim2 * javaShapeInformation[1 + 2 + 3];
offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 3];
data.put(offset, value);
return this;
@@ -1597,7 +1598,7 @@ public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double val
if (rank() != 4)
throw new IllegalStateException(
"Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray");
long offset = Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2, dim3);
long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3);
data.put(offset, value);
return this;
}
@@ -2031,13 +2032,13 @@ public boolean isScalar() {
if (isEmpty())
return false;
if (Shape.rank(javaShapeInformation) == 0) {
if (jvmShapeInfo.rank == 0) {
return true;
} else if (Shape.rank(javaShapeInformation) > 2) {
} else if (jvmShapeInfo.rank > 2) {
return false;
} else if (Shape.rank(javaShapeInformation) == 1) {
return shapeOf().getInt(0) == 1;
} else if (Shape.rank(javaShapeInformation) == 2) {
} else if (jvmShapeInfo.rank == 1) {
return shape()[0] == 1;
} else if (jvmShapeInfo.rank == 2) {
return shape()[0] == 1 && shape()[1] == 1 || length() == 1;
}
@@ -2588,8 +2589,8 @@ public boolean isView() {
And it's possible to be not a view, and have non-empty originalBuffer
*/
// length/data.length can be different in case of Threshold conversion
return Shape.offset(javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataBuffer.Type.INT)
return Shape.offset(jvmShapeInfo.javaShapeInformation) > 0
|| (length() < data().length() && data.dataType() != DataBuffer.Type.LONG)
|| data().originalDataBuffer() != null;
}
@@ -2682,7 +2683,7 @@ public INDArray subArray(long[] offsets, int[] shape, int[] stride) {
long[] dotProductOffsets = offsets;
int[] dotProductStride = stride;
long offset = Shape.offset(javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
long offset = Shape.offset(jvmShapeInfo.javaShapeInformation) + NDArrayIndex.offset(dotProductStride, dotProductOffsets);
if (offset >= data().length())
offset = ArrayUtil.sumLong(offsets);
@@ -2749,7 +2750,7 @@ public INDArray getScalar(long i) {
if (i < 0)
i += this.length();
long idx = this.isVector() ? i : Shape.getOffset(this.javaShapeInformation, Shape.ind2subC(this.shape(), i));
long idx = this.isVector() ? i : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i));
val buffer = Nd4j.createBuffer(this.data(), idx, 1);
val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],0,1,'c');
return Nd4j.createArrayFromShapeBuffer(buffer, shape);
@@ -3198,10 +3199,10 @@ protected DataBuffer strideOf() {
@Override
public int stride(int dimension) {
int rank = Shape.rank(javaShapeInformation);
int rank = jvmShapeInfo.rank;
if (dimension < 0)
return strideOf().getInt(dimension + rank);
return strideOf().getInt(dimension);
return (int) stride()[dimension + rank];
return (int) stride()[dimension];
}
@Override
@@ -4163,12 +4164,12 @@ public INDArray replaceWhere(INDArray arr, Condition condition) {
@Deprecated
public long linearIndex(long i) {
long idx = i;
for (int j = 0; j < Shape.rank(javaShapeInformation) - 1; j++) {
for (int j = 0; j < jvmShapeInfo.rank - 1; j++) {
if (size((int) i) == 1)
continue;
idx += i * stride(j);
}
return Shape.offset(javaShapeInformation) + (idx);
return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx);
}
@@ -4192,7 +4193,7 @@ public INDArray slice(long slice) {
if (slice >= slices)
throw new IllegalArgumentException("Illegal slice " + slice);
if (Shape.rank(javaShapeInformation) == 0 || isVector()) {
if (jvmShapeInfo.rank == 0 || isVector()) {
if (slice == 0 || isVector()) {
return createScalarForIndex(slice, true);
}
@@ -4269,7 +4270,7 @@ public INDArray slice(long slice, int dimension) {
if (slice >= slices)
throw new IllegalArgumentException("Illegal slice " + slice);
if (Shape.rank(javaShapeInformation) == 0) {
if (jvmShapeInfo.rank == 0) {
if (slice == 0)
return createScalarForIndex(slice, true);
else
@@ -4305,7 +4306,7 @@ public INDArray getScalar(int[] indexes) {
if (indexes[i] < 0)
indexes[i] += this.size(i);
}
long idx = Shape.getOffset(this.javaShapeInformation, indexes);
long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
val buffer = Nd4j.createBuffer(this.data(), idx, 1);
val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],0,1,'c');
return Nd4j.createArrayFromShapeBuffer(buffer, shape);
@@ -4321,7 +4322,7 @@ public INDArray getScalar(long... indexes) {
indexes[i] += this.size(i);
}
long idx = Shape.getOffset(this.javaShapeInformation, indexes);
long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes);
val buffer = Nd4j.createBuffer(this.data(), idx, 1);
val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],0,1,'c');
return Nd4j.createArrayFromShapeBuffer(buffer, shape);
@@ -4695,8 +4696,8 @@ public void checkDimensions(INDArray other) {
Shape.stride(shapeInformation)) : " Other array should have been stride: "
+ Shape.toString(Shape.stride(shapeInformation)) + " but was "
+ Arrays.toString(other.stride());
assert Shape.offset(javaShapeInformation) == other.offset() : "Offset of this array is "
+ Shape.offset(javaShapeInformation) + " but other was " + other.offset();
assert Shape.offset(jvmShapeInfo.javaShapeInformation) == other.offset() : "Offset of this array is "
+ Shape.offset(jvmShapeInfo.javaShapeInformation) + " but other was " + other.offset();
}
@@ -5279,7 +5280,7 @@ public LongBuffer shapeInfo() {
* @return the shape of this matrix
*/
public long[] shape() {
return Shape.shape(javaShapeInformation);
return jvmShapeInfo.shape;
}
/**
@@ -5300,7 +5301,7 @@ public String shapeInfoToString() {
*/
@Override
public long[] stride() {
return Shape.stride(javaShapeInformation);
return jvmShapeInfo.stride;
}
@@ -5314,7 +5315,7 @@ public long offset() {
@Override
public char ordering() {
return Shape.order(javaShapeInformation);
return jvmShapeInfo.order;
}
/**
@@ -5326,29 +5327,26 @@ public char ordering() {
*/
@Override
public long size(int dimension) {
if (dimension < 0)
dimension += jvmShapeInfo.rank;
if (isScalar()) {
if (dimension == 0 || dimension == 1 || dimension < 0)
return length();
else
throw new IllegalArgumentException("Illegal dimension for scalar " + dimension);
}
if (dimension < 0) {
return shapeOf().getInt(dimension + Shape.rank(javaShapeInformation));
}
if (dimension >= rank())
throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank "
+ rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")");
val _shapeInfo = shapeInfoDataBuffer();
val _shape = shapeOf();
return shapeOf().getInt(dimension);
return jvmShapeInfo.shape[dimension];
}
@Override
public int rank() {
return Shape.rank(javaShapeInformation);
return jvmShapeInfo.rank;
}
/**
@@ -5358,7 +5356,7 @@ public int rank() {
*/
@Override
public long length() {
return Shape.length(javaShapeInformation);
return jvmShapeInfo.length;
}
/**
@@ -5369,7 +5367,7 @@ public long length() {
@Override
@Deprecated
public long lengthLong() {
return Shape.length(javaShapeInformation);
return jvmShapeInfo.length;
}
@Override
@@ -5390,7 +5388,7 @@ public INDArray broadcast(INDArray result) {
boolean compatible = true;
int count = shape.length - 1;
int thisCount = Shape.rank(javaShapeInformation) - 1;
int thisCount = jvmShapeInfo.rank - 1;
for (int i = shape.length - 1; i > 0; i--) {
if (count < 0 || thisCount < 0)
break;
@@ -5511,7 +5509,7 @@ public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCa
public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) {
Nd4j.getCompressor().autoDecompress(this);
if (broadCastable.length != Shape.rank(javaShapeInformation))
if (broadCastable.length != jvmShapeInfo.rank)
throw new IllegalArgumentException(
"The broadcastable dimensions must be the same length as the current shape");
@@ -5621,7 +5619,7 @@ public INDArray permute(int... rearrange) {
return dup();
boolean alreadyInOrder = true;
//IntBuffer shapeInfo = shapeInfo();
int rank = Shape.rank(javaShapeInformation);
int rank = jvmShapeInfo.rank;
for (int i = 0; i < rank; i++) {
if (rearrange[i] != i) {
alreadyInOrder = false;
@@ -5654,7 +5652,7 @@ public INDArray permute(int... rearrange) {
public INDArray permutei(int... rearrange) {
boolean alreadyInOrder = true;
val shapeInfo = shapeInfo();
int rank = Shape.rank(javaShapeInformation);
int rank = jvmShapeInfo.rank;
for (int i = 0; i < rank; i++) {
if (rearrange[i] != i) {
alreadyInOrder = false;
@@ -5730,7 +5728,7 @@ public INDArray permutei(int... rearrange) {
protected void checkArrangeArray(int[] arr) {
assert arr.length == Shape.rank(javaShapeInformation) : "Invalid rearrangement: number of arrangement != shape";
assert arr.length == jvmShapeInfo.rank : "Invalid rearrangement: number of arrangement != shape";
for (int i = 0; i < arr.length; i++) {
if (arr[i] >= arr.length)
throw new IllegalArgumentException("The specified dimensions can't be swapped. Given element " + i
@@ -5760,7 +5758,7 @@ protected void autoProcessScalarCall() {
*/
@Override
public boolean isVector() {
if (Shape.rank(javaShapeInformation) == 1)
if (jvmShapeInfo.rank == 1)
return true;
return isRowVector() || isColumnVector();
@@ -6588,6 +6586,6 @@ public INDArray convertToDoubles() {
*/
@Override
public boolean isEmpty() {
return ArrayOptionsHelper.arrayType(javaShapeInformation) == ArrayType.EMPTY;
return ArrayOptionsHelper.arrayType(jvmShapeInfo.javaShapeInformation) == ArrayType.EMPTY;
}
}
@@ -0,0 +1,28 @@
package org.nd4j.linalg.api.ndarray;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.api.shape.Shape;
public class JvmShapeInfo {
@Getter protected long[] javaShapeInformation;
@Getter protected long[] shape;
@Getter protected long[] stride;
@Getter protected long length;
@Getter protected long ews;
@Getter protected long extras;
@Getter protected char order;
@Getter protected int rank;
public JvmShapeInfo(@NonNull long[] javaShapeInformation) {
this.javaShapeInformation = javaShapeInformation;
this.shape = Shape.shape(javaShapeInformation);
this.stride = Shape.stride(javaShapeInformation);
this.length = Shape.length(javaShapeInformation);
this.ews = Shape.elementWiseStride(javaShapeInformation);
this.extras = Shape.extras(javaShapeInformation);
this.order = Shape.order(javaShapeInformation);
this.rank = Shape.rank(javaShapeInformation);
}
}
Oops, something went wrong.

0 comments on commit 70dc867

Please sign in to comment.