Track blocks in enrich operators (#102184)
This change tracks Blocks in the enrich server operators. With it, all Blocks in the enrich feature should be properly tracked with BlockFactory.
dnhatn authored and timgrein committed Nov 30, 2023
1 parent b5e9a73 commit dcb5499
Showing 6 changed files with 217 additions and 109 deletions.
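
The diffs below apply the same discipline in each operator: allocate every vector and block through the request's BlockFactory, and if construction fails partway, release whatever was already allocated. Here is a condensed sketch of that pattern, using only the factory and release calls that appear in the diffs; the class and method names in the sketch are illustrative and not part of the commit.

import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;

final class TrackedAllocationSketch {
    // Build a shard/segment/doc page entirely through the BlockFactory so every
    // allocation is accounted for; on any failure, release what was already built.
    static Page buildDocPage(BlockFactory blockFactory, int[] docIds, int leafIndex, int queryPosition) {
        IntVector docs = null, segments = null, shards = null;
        boolean success = false;
        try (IntVector.Builder docsBuilder = blockFactory.newIntVectorBuilder(docIds.length)) {
            for (int doc : docIds) {
                docsBuilder.appendInt(doc);
            }
            docs = docsBuilder.build();
            int positionCount = docs.getPositionCount();
            segments = blockFactory.newConstantIntVector(leafIndex, positionCount);
            shards = blockFactory.newConstantIntVector(0, positionCount);
            var positions = blockFactory.newConstantIntBlockWith(queryPosition, positionCount);
            success = true;
            return new Page(new DocVector(shards, segments, docs, true).asBlock(), positions);
        } finally {
            if (success == false) {
                Releasables.close(docs, shards, segments);
            }
        }
    }
}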
docs/changelog/102184.yaml: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
pr: 102184
summary: Track ESQL enrich memory
area: ES|QL
type: enhancement
issues: []
@@ -251,7 +251,7 @@ private void doLookup(
final SourceOperator queryOperator = switch (matchType) {
case "match", "range" -> {
QueryList queryList = QueryList.termQueryList(fieldType, searchExecutionContext, inputBlock);
yield new EnrichQuerySourceOperator(queryList, searchExecutionContext.getIndexReader());
yield new EnrichQuerySourceOperator(blockFactory, queryList, searchExecutionContext.getIndexReader());
}
default -> throw new EsqlIllegalArgumentException("illegal match type " + matchType);
};
@@ -289,7 +289,7 @@ private void doLookup(
// merging field-values by position
final int[] mergingChannels = IntStream.range(0, extractFields.size()).map(i -> i + 1).toArray();
intermediateOperators.add(
new MergePositionsOperator(singleLeaf, inputPage.getPositionCount(), 0, mergingChannels, mergingTypes)
new MergePositionsOperator(singleLeaf, inputPage.getPositionCount(), 0, mergingChannels, mergingTypes, blockFactory)
);
AtomicReference<Page> result = new AtomicReference<>();
OutputOperator outputOperator = new OutputOperator(List.of(), Function.identity(), result::set);
@@ -16,13 +16,13 @@
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.elasticsearch.compute.data.ConstantIntVector;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.IntArrayVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.SourceOperator;
import org.elasticsearch.core.Releasables;

import java.io.IOException;
import java.io.UncheckedIOException;
@@ -34,14 +34,16 @@
*/
final class EnrichQuerySourceOperator extends SourceOperator {

private final BlockFactory blockFactory;
private final QueryList queryList;
private int queryPosition;
private Weight weight = null;
private final IndexReader indexReader;
private int leafIndex = 0;
private final IndexSearcher searcher;

EnrichQuerySourceOperator(QueryList queryList, IndexReader indexReader) {
EnrichQuerySourceOperator(BlockFactory blockFactory, QueryList queryList, IndexReader indexReader) {
this.blockFactory = blockFactory;
this.queryList = queryList;
this.indexReader = indexReader;
this.searcher = new IndexSearcher(indexReader);
@@ -92,32 +94,39 @@ private Page queryOneLeaf(Weight weight, int leafIndex) throws IOException {
if (scorer == null) {
return null;
}
DocCollector collector = new DocCollector();
scorer.score(collector, leafReaderContext.reader().getLiveDocs());
final int matches = collector.matches;
DocVector docVector = new DocVector(
new ConstantIntVector(0, matches),
new ConstantIntVector(leafIndex, matches),
new IntArrayVector(collector.docs, matches),
true
);
IntBlock positionBlock = new ConstantIntVector(queryPosition, matches).asBlock();
return new Page(docVector.asBlock(), positionBlock);
IntVector docs = null, segments = null, shards = null;
boolean success = false;
try (IntVector.Builder docsBuilder = blockFactory.newIntVectorBuilder(1)) {
scorer.score(new DocCollector(docsBuilder), leafReaderContext.reader().getLiveDocs());
docs = docsBuilder.build();
final int positionCount = docs.getPositionCount();
segments = blockFactory.newConstantIntVector(leafIndex, positionCount);
shards = blockFactory.newConstantIntVector(0, positionCount);
var positions = blockFactory.newConstantIntBlockWith(queryPosition, positionCount);
success = true;
return new Page(new DocVector(shards, segments, docs, true).asBlock(), positions);
} finally {
if (success == false) {
Releasables.close(docs, shards, segments);
}
}
}

private static class DocCollector implements LeafCollector {
int matches = 0;
int[] docs = new int[0];
final IntVector.Builder docIds;

DocCollector(IntVector.Builder docIds) {
this.docIds = docIds;
}

@Override
public void setScorer(Scorable scorer) {

}

@Override
public void collect(int doc) throws IOException {
docs = ArrayUtil.grow(docs, matches + 1);
docs[matches++] = doc;
public void collect(int doc) {
docIds.appendInt(doc);
}
}

@@ -8,10 +8,12 @@
package org.elasticsearch.xpack.esql.enrich;

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;

import java.util.Arrays;
@@ -53,8 +55,16 @@ final class MergePositionsOperator implements Operator {
private PositionBuilder positionBuilder = null;

private Page outputPage;

MergePositionsOperator(boolean singleMode, int positionCount, int positionChannel, int[] mergingChannels, ElementType[] mergingTypes) {
private final BlockFactory blockFactory;

MergePositionsOperator(
boolean singleMode,
int positionCount,
int positionChannel,
int[] mergingChannels,
ElementType[] mergingTypes,
BlockFactory blockFactory
) {
if (mergingChannels.length != mergingTypes.length) {
throw new IllegalArgumentException(
"Merging channels don't match merging types; channels="
@@ -63,14 +73,21 @@ final class MergePositionsOperator implements Operator {
+ Arrays.toString(mergingTypes)
);
}
this.blockFactory = blockFactory;
this.singleMode = singleMode;
this.positionCount = positionCount;
this.positionChannel = positionChannel;
this.mergingChannels = mergingChannels;
this.mergingTypes = mergingTypes;
this.outputBuilders = new Block.Builder[mergingTypes.length];
for (int i = 0; i < mergingTypes.length; i++) {
outputBuilders[i] = mergingTypes[i].newBlockBuilder(positionCount);
try {
for (int i = 0; i < mergingTypes.length; i++) {
outputBuilders[i] = mergingTypes[i].newBlockBuilder(positionCount, blockFactory);
}
} finally {
if (outputBuilders[outputBuilders.length - 1] == null) {
Releasables.close(outputBuilders);
}
}
}

@@ -96,7 +113,7 @@ public void addInput(Page page) {
flushPositionBuilder();
}
if (positionBuilder == null) {
positionBuilder = new PositionBuilder(currentPosition, mergingTypes);
positionBuilder = new PositionBuilder(currentPosition, mergingTypes, blockFactory);
}
positionBuilder.combine(page, mergingChannels);
}
@@ -105,36 +122,53 @@ public void addInput(Page page) {
}
}

static final class PositionBuilder {
static final class PositionBuilder implements Releasable {
private final int position;
private final Block.Builder[] builders;

PositionBuilder(int position, ElementType[] elementTypes) {
PositionBuilder(int position, ElementType[] elementTypes, BlockFactory blockFactory) {
this.position = position;
this.builders = new Block.Builder[elementTypes.length];
for (int i = 0; i < builders.length; i++) {
builders[i] = elementTypes[i].newBlockBuilder(1);
try {
for (int i = 0; i < builders.length; i++) {
builders[i] = elementTypes[i].newBlockBuilder(1, blockFactory);
}
} finally {
if (builders[builders.length - 1] == null) {
Releasables.close(builders);
}
}
}

void combine(Page page, int[] channels) {
for (int i = 0; i < channels.length; i++) {
builders[i].appendAllValuesToCurrentPosition(page.getBlock(channels[i]));
Block block = page.getBlock(channels[i]);
builders[i].appendAllValuesToCurrentPosition(block);
}
}

void buildTo(Block.Builder[] output) {
for (int i = 0; i < output.length; i++) {
output[i].appendAllValuesToCurrentPosition(builders[i].build());
try (var b = builders[i]; Block block = b.build()) {
output[i].appendAllValuesToCurrentPosition(block);
}
}
}

@Override
public void close() {
Releasables.close(builders);
}
}

private void flushPositionBuilder() {
fillNullUpToPosition(positionBuilder.position);
filledPositions++;
positionBuilder.buildTo(outputBuilders);
positionBuilder = null;
try (var p = positionBuilder) {
p.buildTo(outputBuilders);
} finally {
positionBuilder = null;
}
}

private void fillNullUpToPosition(int position) {
@@ -152,14 +186,10 @@ public void finish() {
flushPositionBuilder();
}
fillNullUpToPosition(positionCount);
try {
Block[] blocks = Arrays.stream(outputBuilders).map(Block.Builder::build).toArray(Block[]::new);
outputPage = new Page(blocks);
finished = true;
assert outputPage.getPositionCount() == positionCount;
} finally {
Releasables.closeExpectNoException(outputBuilders);
}
final Block[] blocks = Block.Builder.buildAll(outputBuilders);
outputPage = new Page(blocks);
assert outputPage.getPositionCount() == positionCount;
finished = true;
}

@Override
@@ -176,6 +206,10 @@ public Page getOutput() {

@Override
public void close() {

Releasables.close(Releasables.wrap(outputBuilders), positionBuilder, () -> {
if (outputPage != null) {
outputPage.releaseBlocks();
}
});
}
}
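
Two release idioms recur in the MergePositionsOperator changes above: builder arrays are cleaned up if construction fails partway (the last slot is only non-null once every builder was created), and close() composes the releasables so the output builders, any in-flight position builder, and an unconsumed output page are freed together. A condensed sketch of both, with illustrative class and method names that are not part of the commit:

import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.Releasables;

final class ReleaseIdiomsSketch {
    // Create one builder per element type; if any creation throws, the last slot
    // stays null and the partially filled array is released in the finally block.
    static Block.Builder[] newBuilders(ElementType[] types, int positionCount, BlockFactory blockFactory) {
        Block.Builder[] builders = new Block.Builder[types.length];
        try {
            for (int i = 0; i < types.length; i++) {
                builders[i] = types[i].newBlockBuilder(positionCount, blockFactory);
            }
        } finally {
            if (builders.length > 0 && builders[builders.length - 1] == null) {
                Releasables.close(builders);
            }
        }
        return builders;
    }

    // Releasables.wrap turns the builder array into a single Releasable, and the
    // lambda guards the nullable page, so one close() call frees everything.
    static void closeAll(Block.Builder[] builders, Page outputPage) {
        Releasables.close(Releasables.wrap(builders), () -> {
            if (outputPage != null) {
                outputPage.releaseBlocks();
            }
        });
    }
}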
