diff --git a/src/main/java/org/numenta/nupic/util/ArrayUtils.java b/src/main/java/org/numenta/nupic/util/ArrayUtils.java index a45671f9..23cd4dc7 100644 --- a/src/main/java/org/numenta/nupic/util/ArrayUtils.java +++ b/src/main/java/org/numenta/nupic/util/ArrayUtils.java @@ -170,6 +170,32 @@ public static double[] concat(double[] first, double[] second) { } return retVal; } + + public static int maxIndex(int[] shape) { + return shape[0] * Math.max(1, initDimensionMultiples(shape)[0]) - 1; + } + + /** + * Returns an array of coordinates calculated from + * a flat index. + * + * @param index specified flat index + * @param shape the array specifying the size of each dimension + * @param isColumnMajor increments row first then column (default: false) + * + * @return a coordinate array + */ + public static int[] toCoordinates(int index, int[] shape, boolean isColumnMajor) { + int[] dimensionMultiples = initDimensionMultiples(shape); + int[] returnVal = new int[shape.length]; + int base = index; + for(int i = 0;i < dimensionMultiples.length; i++) { + int quotient = base / dimensionMultiples[i]; + base %= dimensionMultiples[i]; + returnVal[i] = quotient; + } + return isColumnMajor ? reverse(returnVal) : returnVal; + } /** * Utility to compute a flat index from coordinates. @@ -205,21 +231,246 @@ public static int fromCoordinate(int[] coordinates) { * Initializes internal helper array which is used for multidimensional * index computation. * - * @param dimensions + * @param shape an array specifying sizes of each dimension * @return */ - public static int[] initDimensionMultiples(int[] dimensions) { + public static int[] initDimensionMultiples(int[] shape) { int holder = 1; - int len = dimensions.length; - int[] dimensionMultiples = new int[dimensions.length]; + int len = shape.length; + int[] dimensionMultiples = new int[shape.length]; for (int i = 0; i < len; i++) { - holder *= (i == 0 ? 1 : dimensions[len - i]); + holder *= (i == 0 ? 1 : shape[len - i]); dimensionMultiples[len - 1 - i] = holder; } return dimensionMultiples; } + + /** + * Takes a two-dimensional input array and returns a new array which is "rotated" + * a quarter-turn clockwise. + * + * @param array The array to rotate. + * @return The rotated array. + */ + public static int[][] rotateRight(int[][] array) { + int r = array.length; + if (r == 0) { + return new int[0][0]; // Special case: zero-length array + } + int c = array[0].length; + int[][] result = new int[c][r]; + for (int i = 0; i < r; i++) { + for (int j = 0; j < c; j++) { + result[j][r - 1 - i] = array[i][j]; + } + } + return result; + } + + + /** + * Takes a two-dimensional input array and returns a new array which is "rotated" + * a quarter-turn counterclockwise. + * + * @param array The array to rotate. + * @return The rotated array. + */ + public static int[][] rotateLeft(int[][] array) { + int r = array.length; + if (r == 0) { + return new int[0][0]; // Special case: zero-length array + } + int c = array[0].length; + int[][] result = new int[c][r]; + for (int i = 0; i < r; i++) { + for (int j = 0; j < c; j++) { + result[c - 1 - j][i] = array[i][j]; + } + } + return result; + } + + /** + * Takes a one-dimensional input array of m n numbers and returns a two-dimensional + * array of m rows and n columns. The first n numbers of the given array are copied + * into the first row of the new array, the second n numbers into the second row, + * and so on. This method throws an IllegalArgumentException if the length of the input + * array is not evenly divisible by n. + * + * @param array The values to put into the new array. + * @param n The number of desired columns in the new array. + * @return The new m n array. + * @throws IllegalArgumentException If the length of the given array is not + * a multiple of n. + */ + public static int[][] ravel(int[] array, int n) throws IllegalArgumentException { + if (array.length % n != 0) { + throw new IllegalArgumentException(array.length + " is not evenly divisible by " + n); + } + int length = array.length; + int[][] result = new int[length / n][n]; + for (int i = 0; i < length; i++) { + result[i / n][i % n] = array[i]; + } + return result; + } + + /** + * Takes a m by n two dimensional array and returns a one-dimensional array of size m n + * containing the same numbers. The first n numbers of the new array are copied from the + * first row of the given array, the second n numbers from the second row, and so on. + * + * @param array The array to be unraveled. + * @return The values in the given array. + */ + public static int[] unravel(int[][] array) { + int r = array.length; + if (r == 0) { + return new int[0]; // Special case: zero-length array + } + int c = array[0].length; + int[] result = new int[r * c]; + int index = 0; + for (int i = 0; i < r; i++) { + for (int j = 0; j < c; j++) { + result[index] = array[i][j]; + index++; + } + } + return result; + } /** + * Takes a two-dimensional array of r rows and c columns and reshapes it to + * have (r*c)/n by n columns. The value in location [i][j] of the input array + * is copied into location [j][i] of the new array. + * + * @param array The array of values to be reshaped. + * @param n The number of columns in the created array. + * @return The new (r*c)/n by n array. + * @throws IllegalArgumentException If r*c is not evenly divisible by n. + */ + public static int[][] reshape(int[][] array, int n) throws IllegalArgumentException { + int r = array.length; + if (r == 0) { + return new int[0][0]; // Special case: zero-length array + } + if ((array.length * array[0].length) % n != 0) { + int size = array.length * array[0].length; + throw new IllegalArgumentException(size + " is not evenly divisible by " + n); + } + int c = array[0].length; + int[][] result = new int[(r * c) / n][n]; + int ii = 0; + int jj = 0; + + for (int i = 0; i < r; i++) { + for (int j = 0; j < c; j++) { + result[ii][jj] = array[i][j]; + jj++; + if (jj == n) { + jj = 0; + ii++; + } + } + } + return result; + } + + /** + * Returns an int[] with the dimensions of the input. + * @param inputArray + * @return + */ + public static int[] shape(Object inputArray) { + int nr = 1 + inputArray.getClass().getName().lastIndexOf('['); + Object oa = inputArray; + int[] l = new int[nr]; + for(int i = 0;i < nr;i++) { + int len = l[i] = Array.getLength(oa); + if (0 < len) { oa = Array.get(oa, 0); } + } + + return l; + } + + /** + * Sorts the array, then returns an array containing the indexes of + * those sorted items in the original array. + *

+ * int[] args = argsort(new int[] { 11, 2, 3, 7, 0 }); + * contains: + * [4, 1, 2, 3, 0] + * + * @param in + * @return + */ + public static int[] argsort(int[] in) { + return argsort(in, -1, -1); + } + + /** + * Sorts the array, then returns an array containing the indexes of + * those sorted items in the original array which are between the + * given bounds (start=inclusive, end=exclusive) + *

+ * int[] args = argsort(new int[] { 11, 2, 3, 7, 0 }, 0, 3); + * contains: + * [4, 1, 2] + * + * @param in + * @return the indexes of input elements filtered in the way specified + * + * @see #argsort(int[]) + */ + public static int[] argsort(int[] in, int start, int end) { + if(start == -1 || end == -1) { + return IntStream.of(in).sorted().map(i -> + Arrays.stream(in).boxed().collect(Collectors.toList()).indexOf(i)).toArray(); + } + + return IntStream.of(in).sorted().map(i -> + Arrays.stream(in).boxed().collect(Collectors.toList()).indexOf(i)) + .skip(start).limit(end).toArray(); + } + + /** + * Transforms 2D matrix of doubles to 1D by concatenation + * @param A + * @return + */ + public static double[] to1D(double[][] A){ + + double[] B = new double[A.length * A[0].length]; + int index = 0; + + for(int i = 0;i l = Arrays.stream(substInds).boxed().collect(Collectors.toList()); + return IntStream.range(0, source.length).map( + i -> l.indexOf(i) == -1 ? source[i] : substitutes[i]).toArray(); + } + /** * Returns a sorted unique array of integers * diff --git a/src/main/java/org/numenta/nupic/util/NearestNeighbor.java b/src/main/java/org/numenta/nupic/util/NearestNeighbor.java new file mode 100644 index 00000000..d506feb4 --- /dev/null +++ b/src/main/java/org/numenta/nupic/util/NearestNeighbor.java @@ -0,0 +1,34 @@ +package org.numenta.nupic.util; + + + +public class NearestNeighbor { + //private LinkedList + + /** + * Creates a new {@code NearestNeighbor} with the specified + * rows. Rows must be 0 or greater, and cols must be greater + * than zero (i.e. NearestNeighbor(0, 40) is ok). + * + * @param rows (optional) number of rows + * @param cols number of columns + */ + public NearestNeighbor(int rows, int cols) { + + } + + public double[] vecLpDist(double distanceNorm, int[] inputPattern, boolean takeRoot) { + return null; + } + + public int[] rightVecSumAtNZ(int[] inputVector, int[][] base) { + int[] results = new int[base.length]; + for (int i = 0; i < base.length; i++) { + for (int j = 0;j < base[i].length;j++) { + if (inputVector[j] != 0) + results[i] += (inputVector[j] * base[i][j]); + } + } + return results; + } +} diff --git a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java index 10edd24f..a1233d09 100644 --- a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java +++ b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java @@ -31,11 +31,174 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class ArrayUtilsTest { + @Test + public void testSubst() { + int[] original = new int[] { 30, 30, 30, 30, 30 }; + int[] substitutes = new int[] { 0, 1, 2, 3, 4 }; + int[] substInds = new int[] { 4, 1, 3 }; + + int[] expected = { 30, 1, 30, 3, 4 }; + + assertTrue(Arrays.equals(expected, ArrayUtils.subst(original, substitutes, substInds))); + } + + @Test + public void testMaxIndex() { + int max = ArrayUtils.maxIndex(new int[] { 2, 4, 5 }); + assertEquals(39, max); + } + + @Test + public void testToCoordinates() { + int[] coords = ArrayUtils.toCoordinates(19, new int[] { 2, 4, 5 }, false); + assertTrue(Arrays.equals(new int[] { 0, 3, 4 }, coords)); + + coords = ArrayUtils.toCoordinates(19, new int[] { 2, 4, 5 }, true); + assertTrue(Arrays.equals(new int[] { 4, 3, 0 }, coords)); + } + + @Test + public void testArgsort() { + int[] args = ArrayUtils.argsort(new int[] { 11, 2, 3, 7, 0 }); + assertTrue(Arrays.equals(new int[] {4, 1, 2, 3, 0}, args)); + + args = ArrayUtils.argsort(new int[] { 11, 2, 3, 7, 0 }, -1, -1); + assertTrue(Arrays.equals(new int[] {4, 1, 2, 3, 0}, args)); + + args = ArrayUtils.argsort(new int[] { 11, 2, 3, 7, 0 }, 0, 3); + assertTrue(Arrays.equals(new int[] {4, 1, 2}, args)); + } + + @Test + public void testShape() { + int[][] inputPattern = { { 2, 3, 4, 5 }, { 6, 7, 8, 9} }; + int[] shape = ArrayUtils.shape(inputPattern); + assertTrue(Arrays.equals(new int[] { 2, 4 }, shape)); + } + + @Test + public void testReshape() { + int[][] test = { + { 0, 1, 2, 3, 4, 5 }, + { 6, 7, 8, 9, 10, 11 } + }; + + int[][] expected = { + { 0, 1, 2 }, + { 3, 4, 5 }, + { 6, 7, 8 }, + { 9, 10, 11 } + }; + + int[][] result = ArrayUtils.reshape(test, 3); + for(int i = 0;i < result.length;i++) { + for(int j = 0;j < result[i].length;j++) { + assertEquals(expected[i][j], result[i][j]); + } + } + + // Unhappy case + try { + ArrayUtils.reshape(test, 5); + }catch(Exception e) { + assertTrue(e instanceof IllegalArgumentException); + assertEquals("12 is not evenly divisible by 5", e.getMessage()); + } + + // Test zero-length case + int[] result4 = ArrayUtils.unravel(new int[0][]); + assertNotNull(result4); + assertTrue(result4.length == 0); + } + + @Test + public void testRavelAndUnRavel() { + int[] test = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }; + int[][] expected = { + { 0, 1, 2, 3, 4, 5 }, + { 6, 7, 8, 9, 10, 11 } + }; + + int[][] result = ArrayUtils.ravel(test, 6); + for(int i = 0;i < result.length;i++) { + for(int j = 0;j < result[i].length;j++) { + assertEquals(expected[i][j], result[i][j]); + } + } + + int[] result2 = ArrayUtils.unravel(result); + for(int i = 0;i < result2.length;i++) { + assertEquals(test[i], result2[i]); + } + + // Unhappy case + try { + ArrayUtils.ravel(test, 5); + }catch(Exception e) { + assertTrue(e instanceof IllegalArgumentException); + assertEquals("12 is not evenly divisible by 5", e.getMessage()); + } + + // Test zero-length case + int[] result4 = ArrayUtils.unravel(new int[0][]); + assertNotNull(result4); + assertTrue(result4.length == 0); + } + + @Test + public void testRotateRight() { + int[][] test = new int[][] { + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 } + }; + + int[][] expected = new int[][] { + { 1, 1, 1, 1 }, + { 0, 0, 0, 0 }, + { 1, 1, 1, 1 }, + { 0, 0, 0, 0 } + }; + + int[][] result = ArrayUtils.rotateRight(test); + for(int i = 0;i < result.length;i++) { + for(int j = 0;j < result[i].length;j++) { + assertEquals(result[i][j], expected[i][j]); + } + } + } + + @Test + public void testRotateLeft() { + int[][] test = new int[][] { + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 }, + { 1, 0, 1, 0 } + }; + + int[][] expected = new int[][] { + { 0, 0, 0, 0 }, + { 1, 1, 1, 1 }, + { 0, 0, 0, 0 }, + { 1, 1, 1, 1 } + }; + + int[][] result = ArrayUtils.rotateLeft(test); + for(int i = 0;i < result.length;i++) { + for(int j = 0;j < result[i].length;j++) { + assertEquals(result[i][j], expected[i][j]); + } + } + } + @Test public void testConcat() { // Test happy path