Skip to content

Commit

Permalink
- Add multidimension support to LowMemoryBinarySparseMatrix.
Browse files Browse the repository at this point in the history
- Update BinarySparseMatrixTest to run test on both implementations.
  • Loading branch information
chelu committed Oct 17, 2015
1 parent 8c1204b commit a242031
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 85 deletions.
13 changes: 1 addition & 12 deletions src/main/java/org/numenta/nupic/util/BitSetMatrix.java
Expand Up @@ -45,24 +45,13 @@ public BitSetMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
this.data = new BitSet(getSize());
}

@Override
public Boolean get(int... coordinates) {
return get(computeIndex(coordinates));
}

@Override
public Boolean get(int index) {
return this.data.get(index);
}

@Override
public Matrix<Boolean> set(int[] coordinates, Boolean value) {
this.data.set(computeIndex(coordinates), value);
return this;
}

@Override
public FlatMatrix<Boolean> set(int index, Boolean value) {
public BitSetMatrix set(int index, Boolean value) {
this.data.set(index, value);
return this;
}
Expand Down
13 changes: 1 addition & 12 deletions src/main/java/org/numenta/nupic/util/FlatArrayMatrix.java
Expand Up @@ -44,20 +44,9 @@ public FlatArrayMatrix(int[] dimensions, boolean useColumnMajorOrdering) {
this.data = (T[]) new Object[getSize()];
}

@Override
public T get(int... indexes) {
return get(computeIndex(indexes));
}

@Override
public T get(int index) {
return data[index];
}

@Override
public FlatArrayMatrix<T> set(int[] indexes, T value) {
set(computeIndex(indexes), value);
return this;
return this.data[index];
}

@Override
Expand Down
20 changes: 18 additions & 2 deletions src/main/java/org/numenta/nupic/util/FlatMatrixSupport.java
Expand Up @@ -23,8 +23,6 @@
package org.numenta.nupic.util;

import java.lang.reflect.Array;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Arrays;

/**
Expand Down Expand Up @@ -184,6 +182,23 @@ public static String print1DArray(Object aObject) {
return "[]";
}

@Override
public abstract T get(int index);

@Override
public abstract FlatMatrixSupport<T> set(int index, T value);

@Override
public T get(int... indexes) {
return get(computeIndex(indexes));
}

@Override
public FlatMatrixSupport<T> set(int[] indexes, T value) {
set(computeIndex(indexes), value);
return this;
}

public int getSize() {
return Arrays.stream(this.dimensions).reduce((n,i) -> n*i).getAsInt();
}
Expand Down Expand Up @@ -211,4 +226,5 @@ public int getNumDimensions() {
public int[] getDimensionMultiples() {
return this.dimensionMultiples;
}

}
Expand Up @@ -22,10 +22,11 @@

package org.numenta.nupic.util;

import java.lang.reflect.Array;
import java.util.Arrays;

/**
* Low Memory implementation of {@link SparseBinaryMatrixSupport} without
* Low Memory implementation of {@link SparseBinaryMatrix} without
* a backing array.
*
* @author Jose Luis Martin
Expand All @@ -43,17 +44,34 @@ public LowMemorySparseBinaryMatrix(int[] dimensions, boolean useColumnMajorOrder
@Override
public Object getSlice(int... coordinates) {
int[] dimensions = getDimensions();
int[] slice = new int[dimensions[1]];
for(int j = 0; j < dimensions[1]; j++) {
slice[j] = get(coordinates[0], j);
// check for valid coordinates
if (coordinates.length >= dimensions.length)
sliceError(coordinates);

int sliceDimensionsLength = dimensions.length - coordinates.length;
int[] sliceDimensions = (int[]) Array.newInstance(int.class, sliceDimensionsLength);

for (int i = coordinates.length ; i < dimensions.length; i++)
sliceDimensions[i - coordinates.length] = dimensions[i];

int[] elementCoordinates = Arrays.copyOf(coordinates, coordinates.length + 1);
Object slice = Array.newInstance(int.class, sliceDimensions);

if (coordinates.length + 1 == dimensions.length) {
// last slice

for (int i = 0; i < dimensions[coordinates.length]; i++) {
elementCoordinates[coordinates.length] = i;
Array.set(slice, i, get(elementCoordinates));
}
}
//Ensure return value is of type Array
if(!slice.getClass().isArray()) {
throw new IllegalArgumentException(
"This method only returns the array holding the specified index: " +
Arrays.toString(coordinates));
else {
for (int i = 0; i < dimensions[sliceDimensionsLength]; i++) {
elementCoordinates[coordinates.length] = i;
Array.set(slice, i, getSlice(elementCoordinates));
}
}

return slice;
}

Expand All @@ -79,10 +97,8 @@ public void rightVecSumAtNZ(int[] inputVector, int[] results) {

@Override
public LowMemorySparseBinaryMatrix set(int value, int... coordinates) {
if (value > 0) {
super.set(value, coordinates);
updateTrueCounts(coordinates);
}
super.set(value, coordinates);
updateTrueCounts(coordinates);

return this;
}
Expand All @@ -96,27 +112,16 @@ public LowMemorySparseBinaryMatrix setForTest(int index, int value) {
return this;
}

/**
* Update the true counts for a coordinates.
* @param coordinates
*/
private void updateTrueCounts(int... coordinates) {
int sum = 0;

for (int j = 0; j < dimensions[1]; j++) {
sum += getIntValue(coordinates[0], j);
}

Object slice = getSlice(coordinates[0]);
int sum = ArrayUtils.aggregateArray(slice);
setTrueCount(coordinates[0],sum);
}

@Override
protected int[] values() {
int[] dense = new int[getMaxIndex()];
for (int i = 0; i <= getMaxIndex(); i++) {
dense[i] = get(i);
}

return dense;
}


@Override
public LowMemorySparseBinaryMatrix set(int index, Object value) {
super.set(index, ((Integer) value).intValue());
Expand Down
8 changes: 3 additions & 5 deletions src/main/java/org/numenta/nupic/util/SparseBinaryMatrix.java
Expand Up @@ -63,18 +63,16 @@ private void back(int val, int... coordinates) {
public Object getSlice(int... coordinates) {
Object slice = backingArray;
for(int i = 0;i < coordinates.length;i++) {
slice = Array.get(slice, coordinates[i]);;
slice = Array.get(slice, coordinates[i]);
}
//Ensure return value is of type Array
if(!slice.getClass().isArray()) {
throw new IllegalArgumentException(
"This method only returns the array holding the specified index: " +
Arrays.toString(coordinates));
sliceError(coordinates);
}

return slice;
}

/**
* Fills the specified results array with the result of the
* matrix vector multiplication.
Expand Down
Expand Up @@ -56,7 +56,18 @@ public SparseBinaryMatrixSupport(int[] dimensions, boolean useColumnMajorOrderin
* an actual value instead of the array holding it.
*/
public abstract Object getSlice(int... coordinates);


/**
* Launch getSlice error, to share it with subclass {@link #getSlice(int...)}
* implementations.
* @param coordinates
*/
protected void sliceError(int... coordinates) {
throw new IllegalArgumentException(
"This method only returns the array holding the specified index: " +
Arrays.toString(coordinates));
}

/**
* Fills the specified results array with the result of the
* matrix vector multiplication.
Expand Down
Expand Up @@ -27,8 +27,9 @@
import org.junit.Assert;

/**
* Test for {@link LowMemorySparseBinaryMatrix}
*
* @author Jose Luis Martin
*
*/
public class LowMemorySparseBinaryMatrixTest {

Expand Down
12 changes: 10 additions & 2 deletions src/test/java/org/numenta/nupic/util/MatrixTest.java
Expand Up @@ -22,8 +22,12 @@

package org.numenta.nupic.util;

import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.util.Arrays;

import org.junit.Test;

/**
* Generic test for Matrix hirearchy.
Expand All @@ -45,6 +49,7 @@ public void testBitSetMatrixSet() {
}

assertArrayEquals(expected, asDense(bsm));
assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(bsm)));
}

@Test
Expand All @@ -56,8 +61,9 @@ public void testFlatArrayMatrixSet() {
for (int index : this.indexes) {
fam.set(index, 1);
}

assertArrayEquals(expected, asDense(fam));
assertEquals(Arrays.toString(expected), FlatArrayMatrix.print1DArray(asDense(fam)));
}

private Object[] asDense(FlatMatrix<?> matrix) {
Expand All @@ -67,6 +73,8 @@ private Object[] asDense(FlatMatrix<?> matrix) {
dense[i] = matrix.get(i);
}



return dense;
}
}

0 comments on commit a242031

Please sign in to comment.