Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add details section for dcg ranking metric #31177

Merged
merged 4 commits into from
Jun 15, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;

Expand Down Expand Up @@ -129,26 +130,31 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits,
.collect(Collectors.toList());
List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
int unratedResults = 0;
for (RatedSearchHit hit : ratedHits) {
// unknownDocRating might be null, which means it will be unrated docs are
// ignored in the dcg calculation
// we still need to add them as a placeholder so the rank of the subsequent
// ratings is correct
// unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation.
// we still need to add them as a placeholder so the rank of the subsequent ratings is correct
ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
if (hit.getRating().isPresent() == false) {
unratedResults++;
}
}
double dcg = computeDCG(ratingsInSearchHits);
final double dcg = computeDCG(ratingsInSearchHits);
double result = dcg;
double idcg = 0;

if (normalize) {
Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder()));
double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
if (idcg > 0) {
dcg = dcg / idcg;
idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
if (idcg != 0) {
result = dcg / idcg;
} else {
dcg = 0;
result = 0;
}
}
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg);
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result);
evalQueryQuality.addHitsAndRatings(ratedHits);
evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults));
return evalQueryQuality;
}

Expand All @@ -167,7 +173,7 @@ private static double computeDCG(List<Integer> ratings) {
private static final ParseField K_FIELD = new ParseField("k");
private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at", false,
private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg", false,
args -> {
Boolean normalized = (Boolean) args[0];
Integer optK = (Integer) args[2];
Expand Down Expand Up @@ -217,4 +223,118 @@ public final boolean equals(Object obj) {
public final int hashCode() {
return Objects.hash(normalize, unknownDocRating, k);
}

public static final class Detail implements MetricDetail {

private static ParseField DCG_FIELD = new ParseField("dcg");
private static ParseField IDCG_FIELD = new ParseField("ideal_dcg");
private static ParseField NDCG_FIELD = new ParseField("normalized_dcg");
private static ParseField UNRATED_FIELD = new ParseField("unrated_docs");
private final double dcg;
private final double idcg;
private final int unratedDocs;

Detail(double dcg, double idcg, int unratedDocs) {
this.dcg = dcg;
this.idcg = idcg;
this.unratedDocs = unratedDocs;
}

Detail(StreamInput in) throws IOException {
this.dcg = in.readDouble();
this.idcg = in.readDouble();
this.unratedDocs = in.readVInt();
}

@Override
public
String getMetricName() {
return NAME;
}

@Override
public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For normalized DCG, the REST output this adds to the response for each rated request is e.g.:

{
		"dcg": 0.55,
		"unrated_docs": 3
}

for non-normalized dcg and something like this for the normalized variant:

{
	"dcg": 0.69,
	"ideal_dcg": 0.60,
	"normalized_dcg": 1.14,
	"unrated_docs": 2
}

While this isn't super helpful for plain dcg (the metric value is already reported elsewhere, but the number of unrated documents might be interesting to users or for display in a UI), the IDCG and normalization is somewhat interesting I believe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @cbuescher, the details are useful. I am just wondering if we should report duplicate details (dcg in non-normalized version, and normalized_dcg in normalized version), since they are already reported. But I will leave the decision to you as the main architect of ranking evalution.

Copy link
Member Author

@cbuescher cbuescher Jun 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just wondering if we should report duplicate details (dcg in non-normalized version, and normalized_dcg in normalized version)

I was wondering the same actually, but then things like parsing the metric details on the client side suddenly gets much more complext, because in order to re-create the details object we would have to somehow detect which variant we currently parse, and if e.g. the "dcg" value was left out here, we'd have to reach out to the metrics score field which is parsed on another level, so that gets kind of ugly. I also don't like the redundancy very much but this way the objects stays kind of self contained. That said I'll give it another thought...

builder.field(DCG_FIELD.getPreferredName(), this.dcg);
if (this.idcg != 0) {
builder.field(IDCG_FIELD.getPreferredName(), this.idcg);
builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg);
}
builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs);
return builder;
}

private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
return new Detail((Double) args[0], (Double) args[1] != null ? (Double) args[1] : 0.0d, (Integer) args[2]);
});

static {
PARSER.declareDouble(constructorArg(), DCG_FIELD);
PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD);
PARSER.declareInt(constructorArg(), UNRATED_FIELD);
}

public static Detail fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(this.dcg);
out.writeDouble(this.idcg);
out.writeVInt(this.unratedDocs);
}

@Override
public String getWriteableName() {
return NAME;
}

/**
* @return the discounted cumulative gain
*/
public double getDCG() {
return this.dcg;
}

/**
* @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
*/
public double getIDCG() {
return this.idcg;
}

/**
* @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
*/
public double getNDCG() {
return (this.idcg != 0) ? this.dcg / this.idcg : 0;
}

/**
* @return the number of unrated documents in the search results
*/
public Object getUnratedDocs() {
return this.unratedDocs;
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj;
return (this.dcg == other.dcg &&
this.idcg == other.idcg &&
this.unratedDocs == other.unratedDocs);
}

@Override
public int hashCode() {
return Objects.hash(this.dcg, this.idcg, this.unratedDocs);
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
PrecisionAtK.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
MeanReciprocalRank.Detail::fromXContent));
namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
DiscountedCumulativeGain.Detail::fromXContent));
return namedXContent;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
namedWriteables
.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
return namedWriteables;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.index.rankeval;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.text.Text;
Expand Down Expand Up @@ -254,9 +255,8 @@ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRati

public static DiscountedCumulativeGain createTestItem() {
boolean normalize = randomBoolean();
Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000));

return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 1000)) : null;
return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10));
}

public void testXContentRoundtrip() throws IOException {
Expand All @@ -283,7 +283,25 @@ public void testXContentParsingIsNotLenient() throws IOException {
parser.nextToken();
XContentParseException exception = expectThrows(XContentParseException.class,
() -> DiscountedCumulativeGain.fromXContent(parser));
assertThat(exception.getMessage(), containsString("[dcg_at] unknown field"));
assertThat(exception.getMessage(), containsString("[dcg] unknown field"));
}
}

public void testMetricDetails() {
double dcg = randomDoubleBetween(0, 1, true);
double idcg = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true);
double expectedNdcg = idcg != 0 ? dcg / idcg : 0.0;
int unratedDocs = randomIntBetween(0, 100);
DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(dcg, idcg, unratedDocs);
assertEquals(dcg, detail.getDCG(), 0.0);
assertEquals(idcg, detail.getIDCG(), 0.0);
assertEquals(expectedNdcg, detail.getNDCG(), 0.0);
assertEquals(unratedDocs, detail.getUnratedDocs());
if (idcg != 0) {
assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"ideal_dcg\":" + idcg + ",\"normalized_dcg\":" + expectedNdcg
+ ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
} else {
assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,20 @@ public static EvalQueryQuality randomEvalQueryQuality() {
EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10),
randomDoubleBetween(0.0, 1.0, true));
if (randomBoolean()) {
if (randomBoolean()) {
int metricDetail = randomIntBetween(0, 2);
switch (metricDetail) {
case 0:
evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
} else {
break;
case 1:
evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000)));
break;
case 2:
evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true),
randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt()));
break;
default:
throw new IllegalArgumentException("illegal randomized value in test");
}
}
evalQueryQuality.addHitsAndRatings(ratedHits);
Expand Down