Skip to content

Commit

Permalink
SDVariable Improvements (#7371)
Browse files Browse the repository at this point in the history
* invertPermutation helpers

* Unit tests

* SDVariable operations: dot, reshape(int, long), permute

* Kotlin interop methods
  • Loading branch information
rnett authored and AlexDBlack committed Mar 27, 2019
1 parent a092919 commit 88d4acb
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
Expand Up @@ -639,6 +639,26 @@ public SDVariable mmul(String name, SDVariable other, @NonNull MMulTranspose mMu
}


/**
* See {@link #dot(String, SDVariable, int...)}
*/
public SDVariable dot(SDVariable other, int... dimensions){
return dot(null, other, dimensions);
}

/**
* Matrix dot product: out = dot(this,other, dimensions)
*
* @param name Name of the output variable
* @param other Other variable to perform matrix multiplication with
* @return Output variable (result of mmul)
*/
public SDVariable dot(String name, SDVariable other, int... dimensions){
return sameDiff.dot(name, this, other, dimensions);
}



/**
* See {@link #add(String, double)}
*/
Expand Down Expand Up @@ -680,6 +700,22 @@ public SDVariable add(String name, SDVariable x) {
return sameDiff.updateVariableNameAndReference(result, name);
}

/**
* For Kotlin operator interop
* @see #add(String, SDVariable)
*/
public SDVariable plus(SDVariable other){
return add(other);
}

/**
* For Kotlin operator interop
* @see #add(String, double)
*/
public SDVariable plus(double other){
return add(other);
}

/**
* See {@link #sub(String, double)}
*/
Expand Down Expand Up @@ -721,6 +757,22 @@ public SDVariable sub(String name, SDVariable x) {
return sameDiff.updateVariableNameAndReference(result,name);
}

/**
* For Kotlin operator interop
* @see #sub(String, SDVariable)
*/
public SDVariable minus(SDVariable other){
return sub(other);
}

/**
* For Kotlin operator interop
* @see #sub(String, double)
*/
public SDVariable minus(double other){
return sub(other);
}

/**
* See {@link #div(String,double)}
*/
Expand Down Expand Up @@ -804,6 +856,22 @@ public SDVariable mul(String name, SDVariable x) {
return sameDiff.updateVariableNameAndReference(result,name);
}

/**
* For Kotlin operator interop
* @see #mul(String, SDVariable)
*/
public SDVariable times(SDVariable other){
return mul(other);
}

/**
* For Kotlin operator interop
* @see #mul(String, double)
*/
public SDVariable times(double other){
return mul(other);
}

/**
* See {@link #pow(String, double)}
*/
Expand Down Expand Up @@ -1651,6 +1719,41 @@ public SDVariable reshape(SDVariable newShape){
return sameDiff.reshape(this, newShape);
}

/**
* Reshape the current variable to the specified shape. The output variable will have the same values as the
* input, but with the specified shape.<br>
* Note that prod(shape) must match length(input) == prod(input.shape)
*
* @param newShape New shape for variable
* @return Output variable
*/
public SDVariable reshape(int... newShape){
return sameDiff.reshape(this, newShape);
}

/**
* Reshape the current variable to the specified shape. The output variable will have the same values as the
* input, but with the specified shape.<br>
* Note that prod(shape) must match length(input) == prod(input.shape)
*
* @param newShape New shape for variable
* @return Output variable
*/
public SDVariable reshape(long... newShape){
return sameDiff.reshape(this, newShape);
}

/**
* Permute the dimensions of the current variable according to the specified permutation indices.<br>
* Example: if the current variable has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
*
* @param dimensions The new dimension order
* @return Output variable (permuted input)
*/
public SDVariable permute(int... dimensions){
return sameDiff.permute(this, dimensions);
}

/**
* Associate the specified array with this variable
* @param array Array to associate with this variable
Expand Down
34 changes: 34 additions & 0 deletions nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java
Expand Up @@ -3401,4 +3401,38 @@ public static int arrayLength(Object current){
} else
throw new IllegalStateException("Unknown array type (or not an array): " + current.getClass()); //Should never happen
}

/**
* Compute the inverse permutation indices for a permutation operation<br>
* Example: if input is [2, 0, 1] then output is [1, 2, 0]<br>
* The idea is that x.permute(input).permute(invertPermutation(input)) == x
*
* @param input 1D indices for permutation
* @return 1D inverted permutation
*/
public static int[] invertPermutation(int... input){
int[] target = new int[input.length];

for(int i = 0 ; i < input.length ; i++){
target[input[i]] = i;
}

return target;
}

/**
* @see #invertPermutation(int...)
*
* @param input 1D indices for permutation
* @return 1D inverted permutation
*/
public static long[] invertPermutation(long... input){
long[] target = new long[input.length];

for(int i = 0 ; i < input.length ; i++){
target[(int) input[i]] = i;
}

return target;
}
}
@@ -0,0 +1,26 @@
package org.nd4j.linalg.util;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import org.junit.Test;

public class ArrayUtilTest {

@Test
public void testInvertPermutationInt(){
assertArrayEquals(
new int[]{ 2, 4, 3, 0, 1 },
ArrayUtil.invertPermutation(3, 4, 0, 2, 1)
);
}

@Test
public void testInvertPermutationLong(){
assertArrayEquals(
new long[]{ 2, 4, 3, 0, 1 },
ArrayUtil.invertPermutation(3L, 4L, 0L, 2L, 1L)
);
}

}

0 comments on commit 88d4acb

Please sign in to comment.