Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.DoubleVectorBlock;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.Warnings;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -28,6 +29,7 @@
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -40,11 +42,26 @@
*
*/
public class LinearScoreEvalOperator implements Operator {
public record Factory(int discriminatorPosition, int scorePosition, LinearConfig linearConfig) implements OperatorFactory {
public record Factory(
int discriminatorPosition,
int scorePosition,
LinearConfig linearConfig,
String sourceText,
int sourceLine,
int sourceColumn
) implements OperatorFactory {

@Override
public Operator get(DriverContext driverContext) {
return new LinearScoreEvalOperator(discriminatorPosition, scorePosition, linearConfig);
return new LinearScoreEvalOperator(
driverContext,
discriminatorPosition,
scorePosition,
linearConfig,
sourceText,
sourceLine,
sourceColumn
);
}

@Override
Expand Down Expand Up @@ -74,11 +91,30 @@ public String describe() {
private long rowsReceived = 0;
private long rowsEmitted = 0;

public LinearScoreEvalOperator(int discriminatorPosition, int scorePosition, LinearConfig config) {
private final String sourceText;
private final int sourceLine;
private final int sourceColumn;
private Warnings warnings;
private final DriverContext driverContext;

public LinearScoreEvalOperator(
DriverContext driverContext,
int discriminatorPosition,
int scorePosition,
LinearConfig config,
String sourceText,
int sourceLine,
int sourceColumn
) {
this.scorePosition = scorePosition;
this.discriminatorPosition = discriminatorPosition;
this.config = config;
this.normalizer = createNormalizer(config.normalizer());
this.driverContext = driverContext;

this.sourceText = sourceText;
this.sourceLine = sourceLine;
this.sourceColumn = sourceColumn;

finished = false;
inputPages = new ArrayDeque<>();
Expand Down Expand Up @@ -123,25 +159,54 @@ private void createOutputPages() {

private void processInputPage(Page inputPage) {
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
DoubleVectorBlock initialScoreBlock = inputPage.getBlock(scorePosition);
DoubleBlock initialScoreBlock = inputPage.getBlock(scorePosition);

Page newPage = null;
Block scoreBlock = null;
DoubleVector.Builder scores = null;
DoubleBlock.Builder scores = null;

try {
scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());
scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount());

for (int i = 0; i < inputPage.getPositionCount(); i++) {
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i);

if (discriminatorValue == null) {
warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores"));
scores.appendNull();
continue;
} else if (discriminatorValue instanceof List<?>) {
warnings().registerException(
new IllegalArgumentException("group column contains multivalued entries; assigning null scores")
);
scores.appendNull();
continue;
}
String discriminator = ((BytesRef) discriminatorValue).utf8ToString();

var weight = config.weights().get(discriminator) == null ? 1.0 : config.weights().get(discriminator);

double score = initialScoreBlock.getDouble(i);
initialScoreBlock.doesHaveMultivaluedFields();

Object scoreValue = BlockUtils.toJavaObject(initialScoreBlock, i);
if (scoreValue == null) {
warnings().registerException(new IllegalArgumentException("score column has null values; assigning null scores"));
scores.appendNull();
continue;
} else if (scoreValue instanceof List<?>) {
warnings().registerException(
new IllegalArgumentException("score column contains multivalued entries; assigning null scores")
);
scores.appendNull();
continue;
}

double score = (double) scoreValue;

scores.appendDouble(weight * normalizer.normalize(score, discriminator));
}

scoreBlock = scores.build().asBlock();
scoreBlock = scores.build();
newPage = inputPage.appendBlock(scoreBlock);

int[] projections = new int[newPage.getBlockCount() - 1];
Expand Down Expand Up @@ -270,23 +335,43 @@ private Normalizer createNormalizer(LinearConfig.Normalizer normalizer) {
};
}

private interface Normalizer {
double normalize(double score, String discriminator);
private abstract static class Normalizer {
abstract double normalize(double score, String discriminator);

void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition);
abstract void preprocess(double score, String discriminator);

void finalizePreprocess() {};

void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
for (Page inputPage : inputPages) {
DoubleBlock scoreBlock = inputPage.getBlock(scorePosition);
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);

for (int i = 0; i < inputPage.getPositionCount(); i++) {
Object scoreValue = BlockUtils.toJavaObject(scoreBlock, i);
Object discriminatorValue = BlockUtils.toJavaObject(discriminatorBlock, i);

if (scoreValue instanceof Double score && discriminatorValue instanceof BytesRef discriminator) {
preprocess(score, discriminator.utf8ToString());
}
}
}

finalizePreprocess();
}
}

private class NoneNormalizer implements Normalizer {
private static class NoneNormalizer extends Normalizer {
@Override
public double normalize(double score, String discriminator) {
return score;
}

@Override
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {}
void preprocess(double score, String discriminator) {}
}

private class L2NormNormalizer implements Normalizer {
private static class L2NormNormalizer extends Normalizer {
private final Map<String, Double> l2Norms = new HashMap<>();

@Override
Expand All @@ -297,24 +382,17 @@ public double normalize(double score, String discriminator) {
}

@Override
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
for (Page inputPage : inputPages) {
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);

for (int i = 0; i < inputPage.getPositionCount(); i++) {
double score = scoreBlock.getDouble(i);
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();

l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score);
}
}
void preprocess(double score, String discriminator) {
l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score);
}

@Override
void finalizePreprocess() {
l2Norms.replaceAll((k, v) -> Math.sqrt(v));
}
}

private class MinMaxNormalizer implements Normalizer {
private static class MinMaxNormalizer extends Normalizer {
private final Map<String, Double> minScores = new HashMap<>();
private final Map<String, Double> maxScores = new HashMap<>();

Expand All @@ -334,19 +412,17 @@ public double normalize(double score, String discriminator) {
}

@Override
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
for (Page inputPage : inputPages) {
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);

for (int i = 0; i < inputPage.getPositionCount(); i++) {
double score = scoreBlock.getDouble(i);
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
void preprocess(double score, String discriminator) {
minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score));
maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score));
}
}

minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score));
maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score));
}
}
private Warnings warnings() {
if (warnings == null) {
this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText);
}

return warnings;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockUtils;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.compute.operator.Warnings;
import org.elasticsearch.core.Releasables;

import java.util.HashMap;
import java.util.List;

/**
* Updates the score column with new scores using the RRF formula.
Expand All @@ -27,10 +30,25 @@
*/
public class RrfScoreEvalOperator extends AbstractPageMappingOperator {

public record Factory(int discriminatorPosition, int scorePosition, RrfConfig rrfConfig) implements OperatorFactory {
public record Factory(
int discriminatorPosition,
int scorePosition,
RrfConfig rrfConfig,
String sourceText,
int sourceLine,
int sourceColumn
) implements OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
return new RrfScoreEvalOperator(discriminatorPosition, scorePosition, rrfConfig);
return new RrfScoreEvalOperator(
driverContext,
discriminatorPosition,
scorePosition,
rrfConfig,
sourceText,
sourceLine,
sourceColumn
);
}

@Override
Expand All @@ -48,37 +66,62 @@ public String describe() {
private final int scorePosition;
private final int discriminatorPosition;
private final RrfConfig config;
private Warnings warnings;
private final DriverContext driverContext;
private final String sourceText;
private final int sourceLine;
private final int sourceColumn;

private HashMap<String, Integer> counters = new HashMap<>();

public RrfScoreEvalOperator(int discriminatorPosition, int scorePosition, RrfConfig config) {
public RrfScoreEvalOperator(
DriverContext driverContext,
int discriminatorPosition,
int scorePosition,
RrfConfig config,
String sourceText,
int sourceLine,
int sourceColumn
) {
this.scorePosition = scorePosition;
this.discriminatorPosition = discriminatorPosition;
this.config = config;
this.driverContext = driverContext;
this.sourceText = sourceText;
this.sourceLine = sourceLine;
this.sourceColumn = sourceColumn;
}

@Override
protected Page process(Page page) {
BytesRefBlock discriminatorBlock = (BytesRefBlock) page.getBlock(discriminatorPosition);

DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());
BytesRefBlock discriminatorBlock = page.getBlock(discriminatorPosition);
DoubleBlock.Builder scores = discriminatorBlock.blockFactory().newDoubleBlockBuilder(discriminatorBlock.getPositionCount());

for (int i = 0; i < page.getPositionCount(); i++) {
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();

int rank = counters.getOrDefault(discriminator, 1);
counters.put(discriminator, rank + 1);

var weight = config.weights().getOrDefault(discriminator, 1.0);

scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight);
Object value = BlockUtils.toJavaObject(discriminatorBlock, i);

if (value == null) {
warnings().registerException(new IllegalArgumentException("group column has null values; assigning null scores"));
scores.appendNull();
} else if (value instanceof List<?>) {
warnings().registerException(
new IllegalArgumentException("group column contains multivalued entries; assigning null scores")
);
scores.appendNull();
} else {
String discriminator = ((BytesRef) value).utf8ToString();
int rank = counters.getOrDefault(discriminator, 1);
var weight = config.weights().getOrDefault(discriminator, 1.0);
scores.appendDouble(1.0 / (config.rankConstant() + rank) * weight);
counters.put(discriminator, rank + 1);
}
}

Page newPage = null;
Block scoreBlock = null;

try {
scoreBlock = scores.build().asBlock();
scoreBlock = scores.build();
newPage = page.appendBlock(scoreBlock);

int[] projections = new int[newPage.getBlockCount() - 1];
Expand All @@ -105,4 +148,12 @@ protected Page process(Page page) {
public String toString() {
return "RrfScoreEvalOperator";
}

private Warnings warnings() {
if (warnings == null) {
this.warnings = Warnings.createWarnings(driverContext.warningsMode(), sourceLine, sourceColumn, sourceText);
}

return warnings;
}
}
Loading