From 6d0735598612910810faf8bd44ba27af310c550b Mon Sep 17 00:00:00 2001 From: Marcus Schrattenholzer Date: Fri, 18 Aug 2023 11:51:55 +0200 Subject: [PATCH] Fix isIn(Collection) (issue #1227) --- .../java/tech/tablesaw/api/NumericColumn.java | 6 +++- .../tech/tablesaw/api/NumberColumnTest.java | 29 ++++++++++++------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/tech/tablesaw/api/NumericColumn.java b/core/src/main/java/tech/tablesaw/api/NumericColumn.java index fd4834220..71972fb90 100644 --- a/core/src/main/java/tech/tablesaw/api/NumericColumn.java +++ b/core/src/main/java/tech/tablesaw/api/NumericColumn.java @@ -8,6 +8,8 @@ import java.text.NumberFormat; import java.util.Collection; import java.util.Optional; +import java.util.SortedSet; +import java.util.TreeSet; import java.util.function.BiPredicate; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleFunction; @@ -77,9 +79,11 @@ default Selection eval(final BiPredicate predicate, final Number /** {@inheritDoc} */ @Override default Selection isIn(Collection numbers) { + final SortedSet doubleNumbers = new TreeSet<>(); + numbers.forEach(n -> doubleNumbers.add(n.doubleValue())); final Selection results = new BitmapBackedSelection(); for (int i = 0; i < size(); i++) { - if (numbers.contains(getDouble(i))) { + if (doubleNumbers.contains(getDouble(i))) { results.add(i); } } diff --git a/core/src/test/java/tech/tablesaw/api/NumberColumnTest.java b/core/src/test/java/tech/tablesaw/api/NumberColumnTest.java index 4a5808388..066220e18 100644 --- a/core/src/test/java/tech/tablesaw/api/NumberColumnTest.java +++ b/core/src/test/java/tech/tablesaw/api/NumberColumnTest.java @@ -33,6 +33,7 @@ import java.text.NumberFormat; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; @@ -203,21 +204,29 @@ public void createFromNumbers() { } @Test - public void testIsIn() { - Number[] originalValues = {32, 42, 40, 57, 52, -2}; - Number[] resultValues = {10.0, -2.0, 57.0, -5.0}; - List inValues = Arrays.asList(resultValues); + public void testIsInIntegerColumn() { + List inValues = Arrays.asList(new Number[]{10d, (short) -2, 57, 42L, 40f, 52d, -5d}); - DoubleColumn initial = DoubleColumn.create("Test"); - Table t = Table.create("t", initial); + Table t = Table.create("t", IntColumn.create("Test", 32, 42, 40, 57, 52, -2, 11, 25)); - for (Number value : originalValues) { - initial.append(value); - } + Selection filter = t.numberColumn("Test").isIn(inValues); + Table result = t.where(filter); + assertEquals( + new HashSet<>(Arrays.asList(-2, 57, 42, 40, 52)), + result.numberColumn("Test").asSet()); + } + + @Test + public void testIsInDoubleColumn() { + List inValues = Arrays.asList(new Number[]{10d, (short) -2, 57, 42L, 40f, 52d, -5d}); + + Table t = Table.create("t", DoubleColumn.create("Test", 32d, 42d, 40d, 57d, 52d, -2d, 11d, 25d)); Selection filter = t.numberColumn("Test").isIn(inValues); Table result = t.where(filter); - assertEquals(2, result.rowCount()); + assertEquals( + new HashSet<>(Arrays.asList(-2d, 40d, 42d, 52d, 57d)), + result.numberColumn("Test").asSet()); } @Test