diff --git a/src/java/org/apache/cassandra/index/sai/disk/PostingListKeyRangeIterator.java b/src/java/org/apache/cassandra/index/sai/disk/PostingListKeyRangeIterator.java index 904880462bf9..3afc61f46239 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PostingListKeyRangeIterator.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PostingListKeyRangeIterator.java @@ -115,7 +115,7 @@ protected PrimaryKey computeNext() if (rowId == PostingList.END_OF_STREAM) return endOfData(); - return new PrimaryKeyWithSource(primaryKeyMap, rowId, searcherContext.minimumKey, searcherContext.maximumKey); + return primaryKeyMap.primaryKeyFromRowId(rowId, searcherContext.minimumKey, searcherContext.maximumKey); } catch (Throwable t) { @@ -160,20 +160,11 @@ private long getNextRowId() throws IOException long segmentRowId; if (needsSkipping) { - long targetSstableRowId; - if (skipToToken instanceof PrimaryKeyWithSource - && ((PrimaryKeyWithSource) skipToToken).getSourceSstableId().equals(primaryKeyMap.getSSTableId())) + long targetSstableRowId = primaryKeyMap.ceiling(skipToToken); + // skipToToken is larger than max token in token file + if (targetSstableRowId < 0) { - targetSstableRowId = ((PrimaryKeyWithSource) skipToToken).getSourceRowId(); - } - else - { - targetSstableRowId = primaryKeyMap.ceiling(skipToToken); - // skipToToken is larger than max token in token file - if (targetSstableRowId < 0) - { - return PostingList.END_OF_STREAM; - } + return PostingList.END_OF_STREAM; } int targetSegmentRowId = Math.toIntExact(targetSstableRowId - searcherContext.getSegmentRowIdOffset()); segmentRowId = postingList.advance(targetSegmentRowId); diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java index aaf94c83a709..5a0e5e712545 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyMap.java @@ -21,6 +21,7 @@ 
import java.io.Closeable; import java.io.IOException; +import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; import org.apache.cassandra.index.sai.utils.PrimaryKey; @@ -84,6 +85,24 @@ default void close() throws IOException */ PrimaryKey primaryKeyFromRowId(long sstableRowId); + /** + * Returns a {@link PrimaryKey} for a row Id + * + * Note: the lower and upper bounds are used to avoid reading the primary key from disk in the event + * that compared primary keys are in non-overlapping ranges. The ranges can be within the table, and must + * contain the row id. This requirement is not validated, as validation would remove the performance benefit + * of this optimization. + * + * @param sstableRowId the row Id to lookup + * @param lowerBound the inclusive lower bound of the primary key being created + * @param upperBound the inclusive upper bound of the primary key being created + * @return the {@link PrimaryKey} associated with the row Id + */ + default PrimaryKey primaryKeyFromRowId(long sstableRowId, @Nonnull PrimaryKey lowerBound, @Nonnull PrimaryKey upperBound) + { + return primaryKeyFromRowId(sstableRowId); + } + /** * Returns a row Id for a {@link PrimaryKey}. If there is no such term, returns the `-(next row id) - 1` where * `next row id` is the row id of the next greatest {@link PrimaryKey} in the map. 
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java index 7236512d75c4..b4479b4d51da 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentMetadata.java @@ -38,7 +38,6 @@ import org.apache.cassandra.index.sai.SSTableContext; import org.apache.cassandra.index.sai.disk.ModernResettableByteBuffersIndexOutput; import org.apache.cassandra.index.sai.disk.PostingList; -import org.apache.cassandra.index.sai.disk.PrimaryKeyWithSource; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.disk.io.IndexInput; @@ -181,10 +180,9 @@ private SegmentMetadata(IndexInput input, IndexContext context, Version version, // We need to load eagerly to allow us to close the partition key map. min = pkm.primaryKeyFromRowId(minSSTableRowId).loadDeferred(); max = pkm.primaryKeyFromRowId(maxSSTableRowId).loadDeferred(); + this.minKey = pkm.primaryKeyFromRowId(minSSTableRowId, min, max).loadDeferred(); + this.maxKey = pkm.primaryKeyFromRowId(maxSSTableRowId, min, max).loadDeferred(); } - - this.minKey = new PrimaryKeyWithSource(min, sstableContext.sstable.getId(), minSSTableRowId, min, max); - this.maxKey = new PrimaryKeyWithSource(max, sstableContext.sstable.getId(), maxSSTableRowId, min, max); } else { diff --git a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java b/src/java/org/apache/cassandra/index/sai/disk/v2/PrimaryKeyWithSource.java similarity index 82% rename from src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java rename to src/java/org/apache/cassandra/index/sai/disk/v2/PrimaryKeyWithSource.java index eecf6bde2761..fafb216aa8e7 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/PrimaryKeyWithSource.java +++ 
b/src/java/org/apache/cassandra/index/sai/disk/v2/PrimaryKeyWithSource.java @@ -16,27 +16,28 @@ * limitations under the License. */ -package org.apache.cassandra.index.sai.disk; +package org.apache.cassandra.index.sai.disk.v2; import io.github.jbellis.jvector.util.RamUsageEstimator; import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; -public class PrimaryKeyWithSource implements PrimaryKey +class PrimaryKeyWithSource implements PrimaryKey { - private final PrimaryKeyMap primaryKeyMap; private final SSTableId sourceSstableId; private final long sourceRowId; private PrimaryKey delegatePrimaryKey; + private PrimaryKeyMap primaryKeyMap; private final PrimaryKey sourceSstableMinKey; private final PrimaryKey sourceSstableMaxKey; - public PrimaryKeyWithSource(PrimaryKeyMap primaryKeyMap, long sstableRowId, PrimaryKey sourceSstableMinKey, PrimaryKey sourceSstableMaxKey) + PrimaryKeyWithSource(PrimaryKeyMap primaryKeyMap, long sstableRowId, PrimaryKey sourceSstableMinKey, PrimaryKey sourceSstableMaxKey) { this.primaryKeyMap = primaryKeyMap; this.sourceSstableId = primaryKeyMap.getSSTableId(); @@ -45,20 +46,13 @@ public PrimaryKeyWithSource(PrimaryKeyMap primaryKeyMap, long sstableRowId, Prim this.sourceSstableMaxKey = sourceSstableMaxKey; } - public PrimaryKeyWithSource(PrimaryKey primaryKey, SSTableId sourceSstableId, long sourceRowId, PrimaryKey sourceSstableMinKey, PrimaryKey sourceSstableMaxKey) - { - this.delegatePrimaryKey = primaryKey; - this.primaryKeyMap = null; - this.sourceSstableId = sourceSstableId; - this.sourceRowId = sourceRowId; - this.sourceSstableMinKey = sourceSstableMinKey; - 
this.sourceSstableMaxKey = sourceSstableMaxKey; - } - private PrimaryKey primaryKey() { if (delegatePrimaryKey == null) + { delegatePrimaryKey = primaryKeyMap.primaryKeyFromRowId(sourceRowId); + primaryKeyMap = null; // Removes the no longer needed reference to the primary key map. + } return delegatePrimaryKey; } @@ -74,13 +68,10 @@ public SSTableId getSourceSstableId() } @Override - public PrimaryKeyWithSource forStaticRow() + public PrimaryKey forStaticRow() { - return new PrimaryKeyWithSource(primaryKey().forStaticRow(), - sourceSstableId, - sourceRowId, - sourceSstableMinKey, - sourceSstableMaxKey); + // We cannot use row awareness if we need a static row. + return primaryKey().forStaticRow(); } @Override @@ -104,7 +95,8 @@ public Clustering clustering() @Override public PrimaryKey loadDeferred() { - return primaryKey().loadDeferred(); + primaryKey().loadDeferred(); + return this; } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java index 1664c51634ac..748541c476b9 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyFactory.java @@ -28,6 +28,7 @@ import org.apache.cassandra.db.ClusteringComparator; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.dht.Token; +import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.bytecomparable.ByteComparable; @@ -67,6 +68,11 @@ public PrimaryKey create(DecoratedKey partitionKey, Clustering clustering) return new RowAwarePrimaryKey(partitionKey.getToken(), partitionKey, clustering, null); } + PrimaryKey createWithSource(PrimaryKeyMap primaryKeyMap, long sstableRowId, PrimaryKey sourceSstableMinKey, PrimaryKey 
sourceSstableMaxKey) + { + return new PrimaryKeyWithSource(primaryKeyMap, sstableRowId, sourceSstableMinKey, sourceSstableMaxKey); + } + private class RowAwarePrimaryKey implements PrimaryKey { private Token token; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java index d96611c9dd60..5b95076e3ea5 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/RowAwarePrimaryKeyMap.java @@ -81,10 +81,11 @@ public static class RowAwarePrimaryKeyMapFactory implements Factory private FileHandle termsTrie = null; private final IPartitioner partitioner; private final ClusteringComparator clusteringComparator; - private final PrimaryKey.Factory primaryKeyFactory; + private final RowAwarePrimaryKeyFactory primaryKeyFactory; private final SSTableId sstableId; + private final boolean hasStaticColumns; - public RowAwarePrimaryKeyMapFactory(IndexComponents.ForRead perSSTableComponents, PrimaryKey.Factory primaryKeyFactory, SSTableReader sstable) + public RowAwarePrimaryKeyMapFactory(IndexComponents.ForRead perSSTableComponents, RowAwarePrimaryKeyFactory primaryKeyFactory, SSTableReader sstable) { try { @@ -105,6 +106,7 @@ public RowAwarePrimaryKeyMapFactory(IndexComponents.ForRead perSSTableComponents this.primaryKeyFactory = primaryKeyFactory; this.clusteringComparator = sstable.metadata().comparator; this.sstableId = sstable.getId(); + this.hasStaticColumns = sstable.metadata().hasStaticColumns(); } catch (Throwable t) { @@ -124,7 +126,8 @@ public PrimaryKeyMap newPerSSTablePrimaryKeyMap() partitioner, primaryKeyFactory, clusteringComparator, - sstableId); + sstableId, + hasStaticColumns); } catch (IOException e) { @@ -149,17 +152,19 @@ public void close() throws IOException private final SortedTermsReader sortedTermsReader; private final SortedTermsReader.Cursor cursor; private final 
IPartitioner partitioner; - private final PrimaryKey.Factory primaryKeyFactory; + private final RowAwarePrimaryKeyFactory primaryKeyFactory; private final ClusteringComparator clusteringComparator; private final SSTableId sstableId; + private final boolean hasStaticColumns; private RowAwarePrimaryKeyMap(LongArray rowIdToToken, SortedTermsReader sortedTermsReader, SortedTermsReader.Cursor cursor, IPartitioner partitioner, - PrimaryKey.Factory primaryKeyFactory, + RowAwarePrimaryKeyFactory primaryKeyFactory, ClusteringComparator clusteringComparator, - SSTableId sstableId) + SSTableId sstableId, + boolean hasStaticColumns) { this.rowIdToToken = rowIdToToken; this.sortedTermsReader = sortedTermsReader; @@ -168,6 +173,7 @@ private RowAwarePrimaryKeyMap(LongArray rowIdToToken, this.primaryKeyFactory = primaryKeyFactory; this.clusteringComparator = clusteringComparator; this.sstableId = sstableId; + this.hasStaticColumns = hasStaticColumns; } @Override @@ -188,6 +194,13 @@ public PrimaryKey primaryKeyFromRowId(long sstableRowId) return primaryKeyFactory.createDeferred(partitioner.getTokenFactory().fromLongValue(token), () -> supplier(sstableRowId)); } + @Override + public PrimaryKey primaryKeyFromRowId(long sstableRowId, PrimaryKey lowerBound, PrimaryKey upperBound) + { + return hasStaticColumns ? primaryKeyFromRowId(sstableRowId) + : primaryKeyFactory.createWithSource(this, sstableRowId, lowerBound, upperBound); + } + private long skinnyExactRowIdOrInvertedCeiling(PrimaryKey key) { // Fast path when there is no clustering, i.e., there is one row per partition. 
@@ -212,6 +225,13 @@ private long skinnyExactRowIdOrInvertedCeiling(PrimaryKey key) @Override public long exactRowIdOrInvertedCeiling(PrimaryKey key) { + if (key instanceof PrimaryKeyWithSource) + { + var pkws = (PrimaryKeyWithSource) key; + if (pkws.getSourceSstableId().equals(sstableId)) + return pkws.getSourceRowId(); + } + if (clusteringComparator.size() == 0) return skinnyExactRowIdOrInvertedCeiling(key); @@ -226,6 +246,13 @@ public long exactRowIdOrInvertedCeiling(PrimaryKey key) @Override public long ceiling(PrimaryKey key) { + if (key instanceof PrimaryKeyWithSource) + { + var pkws = (PrimaryKeyWithSource) key; + if (pkws.getSourceSstableId().equals(sstableId)) + return pkws.getSourceRowId(); + } + if (clusteringComparator.size() == 0) { long rowId = skinnyExactRowIdOrInvertedCeiling(key); diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskFormat.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskFormat.java index f3c6ecbbaaf4..4497f724cd84 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskFormat.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2OnDiskFormat.java @@ -99,7 +99,8 @@ public PrimaryKey.Factory newPrimaryKeyFactory(ClusteringComparator comparator) @Override public PrimaryKeyMap.Factory newPrimaryKeyMapFactory(IndexComponents.ForRead perSSTableComponents, PrimaryKey.Factory primaryKeyFactory, SSTableReader sstable) { - return new RowAwarePrimaryKeyMap.RowAwarePrimaryKeyMapFactory(perSSTableComponents, primaryKeyFactory, sstable); + assert primaryKeyFactory instanceof RowAwarePrimaryKeyFactory; + return new RowAwarePrimaryKeyMap.RowAwarePrimaryKeyMapFactory(perSSTableComponents, (RowAwarePrimaryKeyFactory) primaryKeyFactory, sstable); } @Override diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 361749a93ad5..268e1b140a69 100644 --- 
a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -45,7 +45,6 @@ import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.PostingList; import org.apache.cassandra.index.sai.disk.PrimaryKeyMap; -import org.apache.cassandra.index.sai.disk.PrimaryKeyWithSource; import org.apache.cassandra.index.sai.disk.v1.IndexSearcher; import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; @@ -534,12 +533,7 @@ private SegmentRowIdOrdinalPairs flatmapPrimaryKeysToBitsAndRows(List 0) @@ -84,20 +87,41 @@ protected PrimaryKey computeNext() // We jumped over the highest key seen so far, so make it the new highest key. highestKey = nextKey; // Remember this iterator to avoid advancing it again, because it is already at the highest key - alreadyAdvanced = index; + indexOfHighestKey = index; // This iterator jumped over, so the other iterators are lagging behind now, // including the ones already advanced in the earlier cycles of the inner loop. // Therefore, restart the inner loop in order to advance // the other iterators except this one to match the new highest key. continue outer; } + assert comparisonResult == 0 : String.format("skipTo skipped to an item smaller than the target; " + "iterator: %s, target key: %s, returned key: %s", range, highestKey, nextKey); + + // More specific keys should win over full partitions, + // because they match a single row instead of the whole partition. + // However, because this key matches with the earlier keys, we can continue the inner loop. + if (!nextKey.hasEmptyClustering()) + { + highestKey = nextKey; + indexOfHighestKey = index; + } } } - // If we reached here, next() has been called at least once on each range iterator and - // the last call to next() on each iterator returned a value equal to the highestKey. 
+ // If we reached here, we have a match - all iterators are at the same key == highestKey. + + // Now we need to advance the iterators to avoid returning the same key again. + // This is tricky because of empty clustering keys that match the whole partition. + // We must not advance ranges at keys with empty clustering because they + // may still match the next keys returned by other iterators in the next cycles. + // However, if all ranges are at the same partition with empty clustering (highestKey.hasEmptyClustering()), + // we must advance all of them, because we return the key for the whole partition and that partition is done. + for (var range : ranges) + { + if (highestKey.hasEmptyClustering() || !range.peek().hasEmptyClustering()) + range.next(); + } // Move the iterator that was called the least times to the start of the list. // This is an optimisation assuming that iterator is likely a more selective one. @@ -140,16 +164,6 @@ protected void performSkipTo(PrimaryKey nextToken) range.skipTo(nextToken); } - /** - * Fetches the next available item from the iterator, such that the item is not lower than the given key. - * If no such items are available, returns null. - */ - private PrimaryKey nextOrNull(KeyRangeIterator iterator, PrimaryKey minKey) - { - iterator.skipTo(minKey); - return iterator.hasNext() ? 
iterator.next() : null; - } - public void close() throws IOException { ranges.forEach(FileUtils::closeQuietly); diff --git a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIterator.java b/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIterator.java index c1cef8b4109e..340a1f5ab9be 100644 --- a/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIterator.java +++ b/src/java/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIterator.java @@ -24,17 +24,23 @@ import com.google.common.collect.Iterables; +import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.io.util.FileUtils; /** * Range Union Iterator is used to return sorted stream of elements from multiple KeyRangeIterator instances. + * Keys are sorted by natural order of PrimaryKey, however if two keys are equal by their natural order, + * the one with an empty clustering always wins. */ @SuppressWarnings("resource") public class KeyRangeUnionIterator extends KeyRangeIterator { public final List ranges; + // If set, we must first skip this partition. + private DecoratedKey partitionToSkip = null; + private KeyRangeUnionIterator(Builder.Statistics statistics, List ranges) { super(statistics); @@ -43,6 +49,10 @@ private KeyRangeUnionIterator(Builder.Statistics statistics, List 0) + { candidate = range; + } } } + if (candidate == null) return endOfData(); - return candidate.next(); + + var result = candidate.next(); + + // If the winning candidate has an empty clustering, this means it selects the whole partition, so + // advance all other ranges to the end of this partition to avoid duplicates. + // We delay that to the next call to computeNext() though, because if we have a wide partition, it's better + // to first let the caller consume all the rows from this partition - maybe they won't call again. 
+ if (result.hasEmptyClustering()) + partitionToSkip = result.partitionKey(); + + return result; + } + + private void maybeSkipCurrentPartition() + { + if (partitionToSkip != null) + { + for (KeyRangeIterator range : ranges) + skipPartition(range, partitionToSkip); + + partitionToSkip = null; + } + } + + private void skipPartition(KeyRangeIterator iterator, DecoratedKey partitionKey) + { + // TODO: Push this logic down to the iterator where it can be more efficient + while (iterator.hasNext() && iterator.peek().partitionKey() != null && iterator.peek().partitionKey().compareTo(partitionKey) <= 0) + iterator.next(); } protected void performSkipTo(PrimaryKey nextKey) diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index f484a7ec8329..46aa825d18fb 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -345,7 +345,7 @@ private boolean isEqualToLastKey(PrimaryKey key) // filtered and considered as a result multiple times). return lastKey != null && Objects.equals(lastKey.partitionKey(), key.partitionKey()) && - Objects.equals(lastKey.clustering(), key.clustering()); + (lastKey.hasEmptyClustering() || key.hasEmptyClustering() || Objects.equals(lastKey.clustering(), key.clustering())); } private void fillNextSelectedKeysInPartition(DecoratedKey partitionKey, List nextPrimaryKeys) diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NumericIndexMixedVersionTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NumericIndexMixedVersionTest.java new file mode 100644 index 000000000000..4ae698ca13d9 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/NumericIndexMixedVersionTest.java @@ -0,0 +1,177 @@ +/* + * Copyright DataStax, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.index.sai.cql; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.junit.Test; + +import org.apache.cassandra.config.CassandraRelevantProperties; +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; +import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.plan.QueryController; + +import static org.junit.Assert.assertEquals; + +public class NumericIndexMixedVersionTest extends SAITester +{ + // Versions in random order + final static List VERSIONS = getVersions(); + + private static List getVersions() + { + var versions = new ArrayList<>(Version.ALL); + Collections.reverse(versions); + // AA is the earliest version and produces different data for flush vs compaction, so we have + // special logic to hit that and make this first. + assert versions.get(0).equals(Version.AA); + logger.info("Running mixed version test with versions: {}", versions); + return versions; + } + + + // This test does not trigger an issue. It simply confirms that we can query across versions. 
+ @Test + public void testMultiVersionCompatibilityNoClusteringColumns() throws Throwable + { + createTable("CREATE TABLE %s (pk int, val int, PRIMARY KEY(pk))"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + // Note that we do not test the multi-version path where compaction produces different sstables, which is + // the norm in CNDB. If we had a way to compnact individual sstables, we could. + disableCompaction(); + + SAIUtil.setCurrentVersion(Version.AA); + for (int j = 0; j < 500; j++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", j, j); + flush(); + compact(); + + // Insert 500 rows per version, each with a unique pk but overlapping values. + int pk = 0; + for (var version : VERSIONS) + { + SAIUtil.setCurrentVersion(version); + for (int i = 0; i < 500; i++) + execute("INSERT INTO %s (pk, val) VALUES (?, ?)", pk++, i); + flush(); + } + + // Confirm that compaction (aka rebuilding all indexes onto same version) also produces correct results + final int expectedRows = pk; + runThenFlushThenCompact(() -> { + var batchLimit = CassandraRelevantProperties.SAI_PARTITION_ROW_BATCH_SIZE.getInt(); + // Query that will hit all sstables and exceed the cassandra.sai.partition_row_batch_size limit + var rows = executeNetWithPaging("SELECT pk FROM %s WHERE val >= 0 LIMIT 10000", batchLimit / 2); + assertEquals(expectedRows, rows.all().size()); + + rows = executeNetWithPaging("SELECT pk FROM %s WHERE val >= 0 LIMIT 10000", batchLimit); + assertEquals(expectedRows, rows.all().size()); + + rows = executeNetWithPaging("SELECT pk FROM %s WHERE val >= 0 LIMIT 10000", batchLimit * 2); + assertEquals(expectedRows, rows.all().size()); + + // Test without paging + assertNumRows(expectedRows, "SELECT pk FROM %%s WHERE val >= 0 LIMIT 10000"); + }); + } + + @Test + public void testMultiVersionCompatibilityWithClusteringColumns() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, val int, PRIMARY KEY(pk, ck)) WITH CLUSTERING ORDER 
BY (ck ASC)"); + createIndex("CREATE CUSTOM INDEX ON %s(val) USING 'StorageAttachedIndex'"); + + // Note that we do not test the multi-version path where compaction produces different sstables, which is + // the norm in CNDB. If we had a way to compact individual sstables, we could. + disableCompaction(); + + SAIUtil.setCurrentVersion(Version.AA); + int ck = 0; + for (int j = 0; j < 500; j++) + execute("INSERT INTO %s (pk, ck, val) VALUES (1, ?, ?)", ck++, j); + flush(); + compact(); + + // Insert 500 rows per version + for (var version : VERSIONS) + { + SAIUtil.setCurrentVersion(version); + for (int j = 0; j < 500; j++) + execute("INSERT INTO %s (pk, ck, val) VALUES (1, ?, ?)", ck++, j); + flush(); + } + + // Confirm that compaction (aka rebuilding all indexes onto same version) also produces correct results + final int expectedRows = ck; + runThenFlushThenCompact(() -> { + // When using paging, we get an excessive number of results because of logic within the contoller.select + // method that short circuits when one of the indexes is aa (not row aware). + var batchLimit = CassandraRelevantProperties.SAI_PARTITION_ROW_BATCH_SIZE.getInt(); + var rows = executeNetWithPaging("SELECT ck FROM %s WHERE val >= 0 LIMIT 10000", batchLimit / 2); + assertEquals(expectedRows, rows.all().size()); + + rows = executeNetWithPaging("SELECT ck FROM %s WHERE val >= 0 LIMIT 10000", batchLimit); + assertEquals(expectedRows, rows.all().size()); + + rows = executeNetWithPaging("SELECT ck FROM %s WHERE val >= 0 LIMIT 10000", batchLimit * 2); + assertEquals(expectedRows, rows.all().size()); + + // Test without paging. This test actually fails by producing fewer than expected rows because of an issue in + // partition-only primary keys and row aware primary keys that are considered equal. 
When they are unioned + // in the iterator, we take one and leave the other (they evaluate to equal after all) but this behavior + // filters out a result that would have loaded the whole partition and might have returned a unique result. + assertNumRows(expectedRows, "SELECT ck FROM %%s WHERE val >= 0 LIMIT 10000"); + }); + } + + + @Test + public void testMultiVersionCompatibilityWithClustringColumnsIntersection() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + SAIUtil.setCurrentVersion(Version.AA); + + createTable("CREATE TABLE %s (pk int, ck int, val1 int, val2 int, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(val1) USING 'StorageAttachedIndex'"); + disableCompaction(); + + // Insert rows so that all have v1 == 1. Index has AA version, and don't compact to get the AA version where we + // get a single primary key per partition in the internal iterator. + for (int j = 0; j < 500; j++) + { + execute("INSERT INTO %s (pk, ck, val1) VALUES (-1, ?, 1)", j); + execute("INSERT INTO %s (pk, ck, val1) VALUES (?, ?, ?)", j, j, j); + } + flush(); + + // Now, create rows with v2 values and index with all versions + SAIUtil.setCurrentVersion(Version.DB); + createIndex("CREATE CUSTOM INDEX ON %s(val2) USING 'StorageAttachedIndex'"); + + + flush(); // force new memtable classes to get version + for (int j = 0; j < 10; j++) + execute("INSERT INTO %s (pk, ck, val2) VALUES (-1, ?, ?)", j, j); + + beforeAndAfterFlush(() -> { + assertNumRows(10, "SELECT ck FROM %%s WHERE val1 = 1 AND val2 >= 0 LIMIT 1000"); + }); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/iterators/AbstractKeyRangeIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/iterators/AbstractKeyRangeIteratorTest.java index a43afe62a7ea..cd778ad839e9 100644 --- a/test/unit/org/apache/cassandra/index/sai/iterators/AbstractKeyRangeIteratorTest.java +++ b/test/unit/org/apache/cassandra/index/sai/iterators/AbstractKeyRangeIteratorTest.java @@ -17,16 +17,37 @@ */ package 
org.apache.cassandra.index.sai.iterators; +import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Function; import java.util.stream.Collectors; +import javax.annotation.Nullable; + import org.junit.Assert; +import org.junit.Test; +import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.ClusteringComparator; +import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.marshal.LongType; +import org.apache.cassandra.dht.Murmur3Partitioner; +import org.apache.cassandra.index.sai.disk.FileUtils; +import org.apache.cassandra.index.sai.disk.format.Version; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; import org.apache.cassandra.utils.Pair; +import static org.apache.cassandra.io.util.FileUtils.closeQuietly; + public class AbstractKeyRangeIteratorTest extends SaiRandomizedTest { protected long[] arr(long... longArray) @@ -159,4 +180,465 @@ static Pair createRandomIterator() throw new AssertionError(); } } + + + private static final PrimaryKey.Factory TEST_PRIMARY_KEY_FACTORY = Version.current().onDiskFormat() + .newPrimaryKeyFactory(new ClusteringComparator(LongType.instance)); + + /** + * Generates a random list of primary keys with the given average number of partitions and rows per partition. + * Partition keys and clusterings are generated in such a way that when combining two such lists generated with + * same parameters (but different random), there is a high chance both sets would contain many common keys, as well + * as each would contain some keys not present in the other set. + * + * @return list of primary keys in (token, clustering) order. 
+ */ + static List randomPrimaryKeys(int avgPartitions, int avgRowsPerPartition) + { + List keys = new ArrayList<>((int)(avgPartitions * avgRowsPerPartition * 1.5)); + + for (int p = 0; p < avgPartitions * 2; p++) + { + if (randomBoolean()) // skip 50% of partitions + continue; + + if (randomBoolean()) + { + keys.add(makeKey(p, null)); // add partition key only + } + else + { + for (int r = 0; r < avgRowsPerPartition * 2; r++) + { + if (randomBoolean()) // skip 50% of rows + keys.add(makeKey(p, (long) r)); + } + } + } + + // We must sort the keys to recover proper token order + Collections.sort(keys); + return keys; + } + + + /** + * Helper to create PrimaryKey with/without clustering. + * Pass null clustering to create a key with Clustering.EMPTY. + */ + static PrimaryKey makeKey(long partitionKey, @Nullable Long clustering) + { + ByteBuffer pkValue = LongType.instance.getSerializer().serialize(partitionKey); + ByteBuffer clusteringValue = LongType.instance.getSerializer().serialize(clustering); + DecoratedKey pk = Murmur3Partitioner.instance.decorateKey(pkValue); + Clustering c = clustering == null ? Clustering.EMPTY : Clustering.make(clusteringValue); + return TEST_PRIMARY_KEY_FACTORY.create(pk, c); + } + + /** + * Convenience method for comparing arrays of PrimaryKey; we don't use assertEquals to compare arrays + * because its output is one huge line of text that is hard to read when the test fails. 
+ */ + void assertKeysEqual(List expected, List result) + { + int matchesUntil = 0; + try + { + for (int i = 0; i < expected.size() && i < result.size(); i++) + { + PrimaryKey e = expected.get(i); + PrimaryKey r = result.get(i); + assertEquals(e, r); + matchesUntil = i; + } + + if (result.size() < expected.size()) + throw new AssertionError("Missing " + (expected.size() - result.size()) + " key(s) at the end"); + if (result.size() > expected.size()) + throw new AssertionError("Got extra keys at the end: " + result.get(expected.size())); + } + catch (AssertionError e) + { + // Print out all the keys that matched properly before the failure to help debugging + for (int i = 0; i < matchesUntil; i++) + System.err.println("Keys match correctly: " + expected.get(i)); + + throw e; + } + } + + /** + * Checks if the given keys are in increasing order and contain no duplicates. + */ + static void assertIncreasing(Collection keys) + { + PrimaryKey lastPrimaryKey = null; + DecoratedKey lastPartitionKey = null; + Clustering lastClustering = Clustering.EMPTY; + for (PrimaryKey key : keys) + { + if (key.hasEmptyClustering() && key.partitionKey().equals(lastPartitionKey)) + throw new AssertionError("A primary key with empty clustering follows a key in the same partition:\n" + key + "\nafter:\n" + lastPrimaryKey); + + if (!key.hasEmptyClustering() && lastClustering.isEmpty() && key.partitionKey().equals(lastPartitionKey)) + throw new AssertionError("A primary key with non-empty clustering follows a key with empty clustering in the same partition:\n" + key + "\nafter:\n" + lastPrimaryKey); + + if (Objects.equals(key, lastPrimaryKey)) + throw new AssertionError("Duplicate key:\n" + key + " = " + lastPrimaryKey); + + if (lastPrimaryKey != null && key.compareTo(lastPrimaryKey) < 0) + throw new AssertionError("Out of order key:\n" + key + " < " + lastPrimaryKey); + + lastPrimaryKey = key; + lastPartitionKey = key.partitionKey(); + lastClustering = key.clustering(); + } + } + + /** + * 
Helper class to quickly find if a key exists in the set or not. + * We cannot just use a hashmap for that, because keys with no clustering match full partitions. + */ + static class PrimaryKeySet + { + Set partitions = new HashSet<>(); + Set>> rows = new HashSet<>(); + + public PrimaryKeySet(Collection keys) + { + for (PrimaryKey pk : keys) + { + if (pk.hasEmptyClustering()) + partitions.add(pk.partitionKey()); + else + rows.add(Pair.create(pk.partitionKey(), pk.clustering())); + } + } + + public boolean contains(PrimaryKey key) + { + return partitions.contains(key.partitionKey()) || + rows.contains(Pair.create(key.partitionKey(), key.clustering())); + } + } + + + static class PrimaryKeyListIterator extends KeyRangeIterator + { + private final List keys; + private int currentIdx = 0; + + private PrimaryKeyListIterator(List keys) + { + super(keys.isEmpty() ? null : keys.get(0), keys.isEmpty() ? null : keys.get(keys.size() - 1), keys.size()); + this.keys = new ArrayList<>(keys); + + } + + public static PrimaryKeyListIterator create(PrimaryKey... keys) + { + List list = Arrays.asList(keys); + Collections.sort(list); + return new PrimaryKeyListIterator(list); + } + + public static PrimaryKeyListIterator create(List keys) + { + Collections.sort(keys); + return new PrimaryKeyListIterator(keys); + } + + @Override + protected PrimaryKey computeNext() + { + if (currentIdx >= keys.size()) + return endOfData(); + + return keys.get(currentIdx++); + } + + @Override + protected void performSkipTo(PrimaryKey nextToken) + { + while (currentIdx < keys.size() && keys.get(currentIdx).compareTo(nextToken) < 0) + currentIdx++; + } + + @Override + public void close() + {} + } + + + /** + * Prints each key in a separate line to help debugging. + * Useful because those printed keys are very long. + */ + void printKeys(Collection keys) + { + for (PrimaryKey key : keys) + System.err.println(key); + } + + /** + * Generates all permutations of array of integers from 0 to n - 1. + * E.g. 
for n = 3, generates: [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]] + */ + static List permutations(int n) { + int[] indices = new int[n]; + for (int i = 0; i < n; i++) { + indices[i] = i; + } + + List result = new ArrayList<>(); + generatePermutations(indices, 0, result); + return result; + } + + // Recursive function to find all possible permutations + private static void generatePermutations(int[] arr, int idx, List res) { + if (idx == arr.length) + { + res.add(Arrays.copyOf(arr, arr.length)); + return; + } + + for (int i = idx; i < arr.length; i++) { + int temp = arr[idx]; + arr[idx] = arr[i]; + arr[i] = temp; + + generatePermutations(arr, idx + 1, res); + + temp = arr[idx]; + arr[idx] = arr[i]; + arr[i] = temp; + } + } + + + /** + * Performs a merge operation on the primary key lists and validates the result. + * If the validation fails, it will try to first minimize the input lists needed to crash the operation + * or fail the validation. If the operation succeeds, returns normally, otherwise throws the final exception. 
+ * + * @param inputs some arbitrary lists of primary keys + * @param merge operation under test that merges multiple lists into one + * @param validator validation function that checks if the result of the operation is correct, expected to throw if not + */ + void testMerge(List> inputs, + Function>, List> merge, + BiConsumer>, List> validator) throws Throwable + { + List result = null; // last result we obtained from operation (may be valid) + List failedResult = null; // last result that failed validation + Throwable exception = null; // last exception we got from operation or validation + + try + { + result = merge.apply(inputs); + validator.accept(inputs, result); + return; // test passes, nothing to do + } + catch (Throwable e) + { + failedResult = result; + exception = e; + } + + // Run the test with smaller inputs until the test doesn't fail anymore + // or we reach the max number of attempts + int attempt = 0; + + while (attempt < 10 && inputs.stream().anyMatch(l -> !l.isEmpty())) + { + // make a copy of each input with some keys removed + boolean removed; // tracks if we actually removed something + List> minimizedInputs; + do + { + minimizedInputs = new ArrayList<>(); + removed = false; + int totalKeys = inputs.stream().mapToInt(List::size).sum(); + + for (List input : inputs) + { + ArrayList minimized = new ArrayList<>(); + minimizedInputs.add(minimized); + + // We want to remove a constant fraction of keys (~10%) to make sure we converge quickly, + // but we must be carefult when the number of keys gets small, so we don't end up leaving all keys + // unmodified. 
+ for (PrimaryKey key : input) + if (nextInt(Math.min(10, totalKeys)) != 0) + minimized.add(key); + + removed |= minimized.size() < input.size(); + } + } while (!removed); + + try + { + result = null; // must clean result in case operation.apply fails in the next line; + // we don't want to keep a result from a previous run + result = merge.apply(minimizedInputs); + validator.accept(minimizedInputs, result); + attempt++; + } + catch (Throwable e) + { + // if we're still failing, then it's a success! we managed to get a smaller input + attempt = 0; + inputs = minimizedInputs; + failedResult = result; + exception = e; + } + } + + System.err.println("Validation failed"); + for (int i = 0; i < inputs.size(); i++) + { + System.err.println("\nInput " + i + ':'); + printKeys(inputs.get(i)); + } + + + if (failedResult != null) + { + System.err.println("\nResult:"); + printKeys(failedResult); + } + + throw exception; + } + + /** + * Tests skipping support of the given merge operation. + * Works by comparing the results obtained from calling skipTo on the result merge iterator directly, + * with the results obtained by first materializing the full merge result and then applying skipping to + * the list (which is easy and obviously correct). + *

     * This does not test the correctness of the merge operation itself.
     * It only checks if skipping works correctly.
     *

+ * If the validation fails, it will try to first minimize the input lists in the same way as {@link #testMerge}. + * + * @param inputs some arbitrary lists of primary keys + * @param skips the list of positions to skip to + * @param mergeOperation the merge operation under test, e.g. intersection or union + * @throws Throwable when the merge operation or validation of results fails + */ + void testSkipping(List> inputs, + List skips, + Function>, KeyRangeIterator> mergeOperation) throws Throwable + { + int sizeLimit = inputs.stream().mapToInt(List::size).sum() + 10; + + // The test and validation code looks very alike, but the test code performs skipping *directly* + // on the KeyRangeIterator, while the validation logic first materializes the merge + // result to an in-memory list and then applies skipping on the list. + try + { + testMerge(inputs, + inp -> { + KeyRangeIterator iterator = mergeOperation.apply(inp); + return collectKeysSkipping(iterator, skips); + }, + (inp, result) -> { + KeyRangeIterator iterator = mergeOperation.apply(inp); + List merged = collectKeys(iterator, sizeLimit); + List expected = collectKeysSkipping(merged, skips); + assertKeysEqual(expected, result); + }); + } + catch (Throwable e) + { + // Skipping informaion is not printed by testMerge, so print it here to help debugging: + System.err.println("\nSkipping operations:"); + for (Skip skip : skips) + System.err.println(skip); + throw e; + } + } + + /** + * Generates a random list of skip operations to perform on the given keys. + */ + static List randomSkips(List keys) + { + List skipPositions = new ArrayList<>(); + for (int i = 0; i < keys.size(); i++) + skipPositions.add(nextInt(keys.size())); + Collections.sort(skipPositions); + + List skips = new ArrayList<>(); + for (int pos : skipPositions) + skips.add(new Skip(keys.get(pos), nextInt(1, 5))); + + return skips; + } + + /** + * Iterates the given iterator and collects all keys into an array. 
+ */ + static List collectKeys(KeyRangeIterator iterator, int sizeLimit) + { + try + { + List result = new ArrayList<>(); + while (iterator.hasNext() && result.size() < sizeLimit) + result.add(iterator.next()); + return result; + } + finally + { + closeQuietly(iterator); + } + } + + /** + * Iterates the given iterator, skipping to the given keys and collecting a chunk of keys after each skip. + */ + static List collectKeysSkipping(KeyRangeIterator iterator, List skips) + { + try + { + List result = new ArrayList<>(); + for (Skip skip : skips) + { + iterator.skipTo(skip.target); + for (int i = 0; i < skip.chunkSize && iterator.hasNext(); i++) + result.add(iterator.next()); + } + return result; + } + finally + { + closeQuietly(iterator); + } + } + + static List collectKeysSkipping(List keys, List skips) + { + return collectKeysSkipping(PrimaryKeyListIterator.create(keys), skips); + } + + static class Skip + { + public final PrimaryKey target; + public final int chunkSize; + + Skip(PrimaryKey skipToKey, int chunkSize) + { + this.target = skipToKey; + this.chunkSize = chunkSize; + } + + @Override + public String toString() + { + return "Skip: { " + "target: " + target + ", chunkSize: " + chunkSize + " }"; + } + } + } diff --git a/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeIntersectionIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeIntersectionIteratorTest.java index 0d8bcd365ede..cea9884ef424 100644 --- a/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeIntersectionIteratorTest.java +++ b/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeIntersectionIteratorTest.java @@ -18,6 +18,7 @@ package org.apache.cassandra.index.sai.iterators; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Set; @@ -397,4 +398,93 @@ static Pair createRandom(int nRanges) IntStream.range(1, ranges.length).forEach(i -> expectedSet.retainAll(toSet(ranges[i]))); 
return Pair.create(builder.build(), expectedSet.stream().mapToLong(Long::longValue).sorted().toArray()); } + + @Test + public void testRandomized() throws Throwable + { + for (int iteratorCount = 2; iteratorCount <= 5; iteratorCount++) + { + for (int testIteration = 0; testIteration < 200; testIteration++) + { + var inputs = new ArrayList>(iteratorCount); + for (int j = 0; j < iteratorCount; j++) + inputs.add(randomPrimaryKeys(testIteration / 10, testIteration / 10)); + + testMerge(inputs, + KeyRangeIntersectionIteratorTest::intersection, + KeyRangeIntersectionIteratorTest::validateIntersectionResults); + } + } + } + + @Test + public void testSkippingRandomized() throws Throwable + { + for (int iteratorCount = 2; iteratorCount <= 5; iteratorCount++) + { + for (int testIteration = 0; testIteration < 200; testIteration++) + { + var inputs = new ArrayList>(iteratorCount); + for (int j = 0; j < iteratorCount; j++) + inputs.add(randomPrimaryKeys(testIteration / 10, testIteration / 10)); + + // Generate random skip positions. + // Use a different data set so that some skip positions exist in the merged result and some do not. + var skips = randomSkips(randomPrimaryKeys(testIteration / 10, testIteration / 10)); + + testSkipping(inputs, skips, KeyRangeIntersectionIteratorTest::intersectionIterator); + } + } + } + + private static List intersection(List> inputs) + { + // Limit the size of the result to avoid test timeouts. + // We don't need to throw, because excessive results will be checked by validation logic + // and that way we get better diagnostics. If we threw an assertion error here, the results wouldn't be printed. 
+ var sizeLimit = inputs.stream().mapToInt(List::size).sum() + 10; + return collectKeys(intersectionIterator(inputs), sizeLimit); + } + + private static KeyRangeIterator intersectionIterator(List> inputs) + { + var builder = KeyRangeIntersectionIterator.builder(); + for (List input : inputs) + builder.add(PrimaryKeyListIterator.create(input)); + + return builder.build(); + } + + private static void validateIntersectionResults(List> inputs, List result) + { + // Check for order and duplicates: + assertIncreasing(result); + + // Index the keys we got for faster search: + ArrayList inputSets = new ArrayList<>(inputs.size()); + for (List input : inputs) + inputSets.add(new PrimaryKeySet(input)); + + // Check if all keys are present: + PrimaryKeySet resultSet = new PrimaryKeySet(result); + for (List keys : inputs) + { + for (PrimaryKey key : keys) + { + if (!inputSets.stream().allMatch(input -> input.contains(key))) + continue; + + assertTrue("Missing key in intersection result:\n" + key, resultSet.contains(key)); + } + } + + // Check if we inluded only the rows that are present in all inputs; + // excessive rows are likely not a correctness issue, but they may degrade performance: + for (int i = 0; i < inputs.size(); i++) + { + for (PrimaryKey key : result) + assertTrue("Unexpected key in intersection result, not covered by input " + i + + ":\n" + key, inputSets.get(i).contains(key)); + } + } } diff --git a/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIteratorTest.java b/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIteratorTest.java index a4b4152246d9..769000a9826f 100644 --- a/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIteratorTest.java +++ b/test/unit/org/apache/cassandra/index/sai/iterators/KeyRangeUnionIteratorTest.java @@ -17,10 +17,7 @@ */ package org.apache.cassandra.index.sai.iterators; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; +import 
java.util.*; import java.util.stream.Collectors; import org.junit.Assert; @@ -405,4 +402,188 @@ public void testUnionOfRandom() validateWithSkipping(builder.build(), totalOrdered); } } + + @Test + public void testEmptyClusteringTwoWayMerge() { + PrimaryKey[] keysA = { + makeKey(1, 1L), + makeKey(2, 1L), + makeKey(2, 1000L), + makeKey(3, null), + makeKey(3, 30L), + makeKey(3, 31L), + makeKey(3, 32L), + makeKey(3, 33L), + makeKey(4, null) + }; + + PrimaryKey[] keysB = { + makeKey(0, null), + makeKey(1, 2L), + makeKey(2, null), + makeKey(3, 31L), + makeKey(4, null) + }; + + List expected = Arrays.asList( + makeKey(0, null), + makeKey(1, 1L), + makeKey(1, 2L), + makeKey(2, null), + makeKey(3, null), + makeKey(4, null) + ); + + testUnion(expected, keysA, keysB); + } + + @Test + public void testEmptyClusteringThreeWayMerge() { + PrimaryKey[] keysA = { + makeKey(1, 11L), + makeKey(2, 21L), + makeKey(2, 1000L), + makeKey(3, null), + makeKey(3, 0L), + makeKey(3, 1L), + makeKey(3, 2L), + makeKey(4, 41L), + makeKey(6, null), + makeKey(7, 72L), + makeKey(7, 73L) + }; + + PrimaryKey[] keysB = { + makeKey(0, null), + makeKey(1, 13L), + makeKey(2, null), + makeKey(3, 1L), + makeKey(4, 40L), + makeKey(4, 42L), + makeKey(4, 43L), + makeKey(4, 45L), + makeKey(5, 50L), + makeKey(7, 71L), + makeKey(7, 73L), + makeKey(7, 74L) + }; + + PrimaryKey[] keysC = { + makeKey(1, 12L), + makeKey(2, 22L), + makeKey(2, 5L), + makeKey(3, 1L), + makeKey(4, null), + makeKey(6, 60L), + makeKey(7, null) + }; + + List expected = Arrays.asList( + makeKey(0, null), + makeKey(1, 11L), + makeKey(1, 12L), + makeKey(1, 13L), + makeKey(2, null), + makeKey(3, null), + makeKey(4, null), + makeKey(5, 50L), + makeKey(6, null), + makeKey(7, null) + ); + + testUnion(expected, keysA, keysB, keysC); + } + + private void testUnion(List expected, PrimaryKey[]... 
inputs) { + // Test all permutations of input arrays to ensure order of iterators does not matter + for (int[] permutation : permutations(inputs.length)) + { + KeyRangeUnionIterator.Builder builder = KeyRangeUnionIterator.builder(); + + for (int i = 0; i < inputs.length; i++) + builder.add(PrimaryKeyListIterator.create(inputs[permutation[i]])); + + KeyRangeIterator union = builder.build(); + + List result = new ArrayList<>(); + while (union.hasNext()) { + result.add(union.next()); + } + + Collections.sort(expected); + assertKeysEqual(expected, result); + } + } + + @Test + public void testRandomized() throws Throwable + { + for (int iteratorCount = 2; iteratorCount <= 5; iteratorCount++) + { + for (int i = 0; i < 200; i++) + { + var inputs = new ArrayList>(iteratorCount); + for (int j = 0; j < iteratorCount; j++) + inputs.add(randomPrimaryKeys(i / 10, i / 10)); + + testMerge(inputs, + KeyRangeUnionIteratorTest::union, + KeyRangeUnionIteratorTest::validateUnionResults); + } + } + } + + @Test + public void testSkippingRandomized() throws Throwable + { + for (int iteratorCount = 2; iteratorCount <= 5; iteratorCount++) + { + for (int testIteration = 0; testIteration < 200; testIteration++) + { + var inputs = new ArrayList>(iteratorCount); + for (int j = 0; j < iteratorCount; j++) + inputs.add(randomPrimaryKeys(testIteration / 10, testIteration / 10)); + + // Generate random skip positions. + // Use a different data set so that some skip positions exist in the merged result and some do not. + var skips = randomSkips(randomPrimaryKeys(testIteration / 10, testIteration / 10)); + + testSkipping(inputs, skips, KeyRangeUnionIteratorTest::unionIterator); + } + } + } + + + private static List union(List> inputs) + { + var iterator = unionIterator(inputs); + + // Limit the size of the result to avoid test timeouts. + // We don't need to throw, because excessive results will be checked by validation logic + // and that way we get better diagnostics. 
If we threw an assertion error here, the results wouldn't be printed. + var sizeLimit = inputs.stream().mapToInt(List::size).sum() + 10; + return collectKeys(iterator, sizeLimit); + } + + private static KeyRangeIterator unionIterator(List> inputs) + { + var builder = KeyRangeUnionIterator.builder(); + for (List input : inputs) + builder.add(PrimaryKeyListIterator.create(input)); + return builder.build(); + } + + + private static void validateUnionResults(List> inputs, List result) + { + // Check for order and duplicates: + assertIncreasing(result); + + // Check if we're not missing anything - all keys from input lists must be found in the output + PrimaryKeySet resultKeySet = new PrimaryKeySet(result); + for (List input : inputs) + for (PrimaryKey key : input) + assertTrue("Missing key in union result:\n" + key, resultKeySet.contains(key)); + } } +