diff --git a/src/main/java/org/numenta/nupic/algorithms/SpatialPooler.java b/src/main/java/org/numenta/nupic/algorithms/SpatialPooler.java index 078879af..dd7c43b7 100644 --- a/src/main/java/org/numenta/nupic/algorithms/SpatialPooler.java +++ b/src/main/java/org/numenta/nupic/algorithms/SpatialPooler.java @@ -25,7 +25,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.numenta.nupic.Connections; import org.numenta.nupic.model.Column; @@ -898,11 +902,19 @@ public int[] inhibitColumns(Connections c, double[] overlaps) { * @return */ public int[] inhibitColumnsGlobal(Connections c, double[] overlaps, double density) { - int numCols = c.getNumColumns(); - int numActive = (int)(density * numCols); - int[] winners = ArrayUtils.nGreatest(overlaps, numActive); - Arrays.sort(winners); - return winners; + int numCols = c.getNumColumns(); + int numActive = (int)(density * numCols); + return IntStream.range(0, overlaps.length) + .boxed() + .collect(Collectors.toMap(index->index, index->overlaps[index])) + .entrySet() + .stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .limit(numActive) + .map(Entry::getKey) + .sorted() + .mapToInt(Integer::intValue) + .toArray(); } /** diff --git a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java index f42a8ed8..10edd24f 100644 --- a/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java +++ b/src/test/java/org/numenta/nupic/util/ArrayUtilsTest.java @@ -323,4 +323,10 @@ public void testIsSparse() { assertFalse(ArrayUtils.isSparse(t)); assertTrue(ArrayUtils.isSparse(t1)); } + + @Test + public void testNGreatest() { + double[] overlaps = new double[] { 1, 2, 1, 4, 8, 3, 12, 5, 4, 1 }; + assertTrue(Arrays.equals(new int[] { 6, 4, 7 }, ArrayUtils.nGreatest(overlaps, 3))); + } }