Browse files

Custom score script support tests.

  • Loading branch information...
1 parent b388db1 commit 1d7cf753197ba2072c3b7125f6e8e6a1c067243b @adamalix adamalix committed with May 29, 2012
View
9 src/main/scala/com/foursquare/slashem/QueryBuilder.scala
@@ -56,7 +56,7 @@ case class QueryBuilder[M <: Record[M], Ord, Lim, MM <: MinimumMatchType, Y, H <
queryType: Option[String],
fieldsToFetch: List[String],
facetSettings: FacetSettings,
- customScoreScript: Option[String],
+ customScoreScript: Option[(String, Map[String, Any])],
hls: Option[String],
hlFragSize: Option[Int],
creator: Option[(Pair[Map[String,Any],
@@ -170,6 +170,9 @@ case class QueryBuilder[M <: Record[M], Ord, Lim, MM <: MinimumMatchType, Y, H <
this.copy(start=Some(s))
}
+ // Implicits on the following functions guard against the function being
+ // multiple times
+
/** Limit the query to only fetch back l results.
* Can only be applied to a query without an existing limit
* @param l The limit */
@@ -347,6 +350,10 @@ case class QueryBuilder[M <: Record[M], Ord, Lim, MM <: MinimumMatchType, Y, H <
this.copy(boostFields=f(meta)::boostFields)
}
+ def customScore(scriptName: String, params: Map[String, Any]) (implicit ev: ST =:= NoScoreModifiers):
+ QueryBuilder[M, Ord, Lim, MM, Y, H, Q, MinFacetCount, FacetLimit, NativeScoreScript] = {
+ this.copy(customScoreScript = Some((scriptName, params)))
+ }
//Print out some debugging information.
def test(): Unit = {
View
25 src/main/scala/com/foursquare/slashem/Schema.scala
@@ -459,7 +459,6 @@ trait SlashemSchema[M <: Record[M]] extends Record[M] {
})
}
-
def where[F](c: M => Clause[F]): QueryBuilder[M, Unordered, Unlimited, defaultMM, NoSelect, NoHighlighting, NoQualityFilter, NoMinimumFacetCount, Unlimited, NoScoreModifiers] = {
QueryBuilder(self, c(self), filters=Nil, boostQueries=Nil, queryFields=Nil,
phraseBoostFields=Nil, boostFields=Nil, start=None, limit=None,
@@ -614,6 +613,7 @@ trait ElasticSchema[M <: Record[M]] extends SlashemSchema[M] {
Response(this, creator, hitCount, start, docs,
fallOff=fallOff, min=min, fieldFacet))
}
+
def buildElasticQuery[Ord, Lim, MM <: MinimumMatchType, Y, H <: Highlighting, Q <: QualityFilter, FC <: FacetCount, FLim, ST <: ScoreType](qb: QueryBuilder[M, Ord, Lim, MM, Y, H, Q, FC, FLim, ST]): ElasticQueryBuilder = {
val baseQuery: ElasticQueryBuilder= qb.clauses.elasticExtend(qb.queryFields,
qb.phraseBoostFields,
@@ -624,12 +624,21 @@ trait ElasticSchema[M <: Record[M]] extends SlashemSchema[M] {
case _ => filteredQuery(baseQuery,combineFilters(qb.filters.map(_.elasticFilter(qb.queryFields))))
}
//Apply any custom scoring rules (aka emulating Solr's bq/bf)
- val boostedQuery = qb.boostFields match {
- case Nil => fq
- case _ => boostFields(fq, qb.boostFields)
+ val scoredQuery = qb.customScoreScript match {
+ case Some((script, params)) => {
+ scoreWithScript(fq, script, params)
+ }
+ case None => {
+ val boostedQuery = qb.boostFields match {
+ case Nil => fq
+ case _ => boostFields(fq, qb.boostFields)
+ }
+ boostedQuery
+ }
}
- boostedQuery
+ scoredQuery
}
+
def termFacetQuery(facetFields: List[Ast.Field], facetLimit: Option[Int]): List[AbstractFacetBuilder] = {
val fieldNames = facetFields.map(_.boost())
val facetQueries = fieldNames.map(name => {
@@ -660,9 +669,8 @@ trait ElasticSchema[M <: Record[M]] extends SlashemSchema[M] {
new AndFilterBuilder(filters:_*)
}
- def scoreWithScript(query: ElasticQueryBuilder,
- scriptName: String,
- namesAndParams: List[Pair[String, Any]]): ElasticQueryBuilder = {
+ def scoreWithScript(query: ElasticQueryBuilder, scriptName: String,
+ namesAndParams: Map[String, Any]): ElasticQueryBuilder = {
val customScoreQuery = new CustomScoreQueryBuilder(query)
customScoreQuery.script(scriptName).lang("native")
for ((name, param) <- namesAndParams) {
@@ -671,6 +679,7 @@ trait ElasticSchema[M <: Record[M]] extends SlashemSchema[M] {
customScoreQuery
}
}
+
trait SolrSchema[M <: Record[M]] extends SlashemSchema[M] {
self: M with SlashemSchema[M] =>
View
1 src/test/resources/es-plugin.properties
@@ -0,0 +1 @@
+plugin=com.foursquare.elasticsearch.scorer.FourSquareScorePlugin
View
17 src/test/scala/com/foursquare/elasticsearch/scorer/FoursquareScorePlugin.scala
@@ -0,0 +1,17 @@
+package com.foursquare.elasticsearch.scorer;
+
+import org.elasticsearch.plugins.AbstractPlugin;
+import org.elasticsearch.script.ScriptModule;
+
+/**
+ * Provides a fast* score script for our primary use case
+ */
+class FourSquareScorePlugin extends AbstractPlugin {
+ override def name(): String = "foursquare";
+
+ override def description(): String = "foursquare plugin";
+
+ def onModule(module: ScriptModule): Unit = {
+ module.registerScript("distance_score_magic", classOf[ScoreFactory]);
+ }
+}
View
42 ...ala/com/foursquare/elasticsearch/scorer/script/CombinedDistanceDocumentScorerScript.scala
@@ -0,0 +1,42 @@
+package com.foursquare.elasticsearch.scorer;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.xcontent.support.XContentMapValues;
+import org.elasticsearch.index.field.data.NumericDocFieldData;
+import org.elasticsearch.index.mapper.geo.GeoPointDocFieldData;
+import org.elasticsearch.script.AbstractFloatSearchScript;
+import org.elasticsearch.script.ExecutableScript;
+import org.elasticsearch.script.NativeScriptFactory;
+import org.elasticsearch.search.lookup.DocLookup;
+
+import java.util.Map;
+
+/**
+ * Note: assumes that the point field is point
+ */
+case class CombinedDistanceDocumentScorerSearchScript(val lat: Double,
+ val lon: Double,
+ val weight1: Float,
+ val weight2: Float) extends AbstractFloatSearchScript {
+
+ override def runAsFloat(): Float = {
+ val myDoc: DocLookup = doc();
+ val point: GeoPointDocFieldData = myDoc.get("point").asInstanceOf[GeoPointDocFieldData];
+ val popularity: Double = myDoc.numeric("decayedPopularity1").asInstanceOf[NumericDocFieldData[_]].getDoubleValue()
+ // up to you to remove score from here or not..., also, possibly, add more weights options
+ val myScore: Float = (score() *
+ (1 + weight1 * math.pow(((1.0 * (math.pow(point.distanceInKm(lat, lon), 2.0))) + 1.0), -1.0)
+ + popularity * weight2)).toFloat;
+ myScore
+ }
+}
+
+class ScoreFactory extends NativeScriptFactory {
+ def newScript(@Nullable params: Map[String, Object]): ExecutableScript = {
+ val lat: Double = if (params == null) 1 else XContentMapValues.nodeDoubleValue(params.get("lat"), 0);
+ val lon: Double = if (params == null) 1 else XContentMapValues.nodeDoubleValue(params.get("lon"), 0);
+ val weight1: Float = if(params == null) 1 else XContentMapValues.nodeFloatValue(params.get("weight1"), 5000.0f);
+ val weight2: Float = if(params == null) 1 else XContentMapValues.nodeFloatValue(params.get("weight2"), 0.05f);
+ return new CombinedDistanceDocumentScorerSearchScript(lat, lon, weight1, weight2);
+ }
+}
View
38 src/test/scala/com/foursquare/slashem/ElasticQueryTest.scala
@@ -1,4 +1,5 @@
package com.foursquare.slashem
+import com.foursquare.elasticsearch.scorer.FourSquareScorePlugin
import com.foursquare.slashem._
import com.twitter.util.Duration
@@ -14,6 +15,7 @@ import org.scalacheck.Arbitrary.arbitrary
import org.specs.SpecsMatchers
import org.specs.matcher.ScalaCheckMatchers
+//import org.elasticsearch.common.settings.ImmutableSettings
import org.elasticsearch.node.NodeBuilder._
import org.elasticsearch.node.Node
import org.elasticsearch.client.Requests;
@@ -67,7 +69,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
def testRecipGeoBoostTimeout {
val geoLat = 74
val geoLong = -31
- val r = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.pos recipSqeGeoDistance(geoLat, geoLong, 1, 5000, 1)) fetch(Duration(0,TimeUnit.MILLISECONDS))
+ val r = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.point recipSqeGeoDistance(geoLat, geoLong, 1, 5000, 1)) fetch(Duration(0,TimeUnit.MILLISECONDS))
}
@Test
@@ -183,7 +185,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
}
@Test
def geoOrderDesc {
- var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderDesc(_.pos sqeGeoDistance(74.0,-31.0)) fetch()
+ var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderDesc(_.point sqeGeoDistance(74.0,-31.0)) fetch()
Assert.assertEquals(2,r.response.results.length)
val doc0 = r.response.oidScorePair.apply(0)
val doc1= r.response.oidScorePair.apply(1)
@@ -192,7 +194,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
}
@Test
def geoOrderAsc {
- var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderAsc(_.pos sqeGeoDistance(74.0,-31.0)) fetch()
+ var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderAsc(_.point sqeGeoDistance(74.0,-31.0)) fetch()
Assert.assertEquals(2,r.response.results.length)
val doc0 = r.response.oidScorePair.apply(0)
val doc1= r.response.oidScorePair.apply(1)
@@ -201,7 +203,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
}
@Test
def geoOrderIntAsc {
- var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderAsc(_.pos sqeGeoDistance(74,-31)) fetch()
+ var r = ESimpleGeoPanda where (_.name contains "ordertest") complexOrderAsc(_.point sqeGeoDistance(74,-31)) fetch()
Assert.assertEquals(2,r.response.results.length)
val doc0 = r.response.oidScorePair.apply(0)
val doc1= r.response.oidScorePair.apply(1)
@@ -277,7 +279,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
val geoLat = 74
val geoLong = -31
val r1 = ESimpleGeoPanda where (_.name contains "lolerskates") fetch()
- val r2 = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.pos sqeGeoDistance(geoLat, geoLong)) fetch()
+ val r2 = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.point sqeGeoDistance(geoLat, geoLong)) fetch()
Assert.assertEquals(r1.response.results.length,2)
Assert.assertEquals(r2.response.results.length,2)
Assert.assertTrue(r2.response.results.apply(0).score.value > r1.response.results.apply(0).score.value)
@@ -287,17 +289,17 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
//Test GeoBoosting. Note will actually make further away document come up first
val geoLat = 74
val geoLong = -31
- val r = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.pos sqeGeoDistance(geoLat, geoLong)) fetch()
+ val r = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.point sqeGeoDistance(geoLat, geoLong)) fetch()
Assert.assertEquals(r.response.results.length,2)
- Assert.assertEquals(r.response.results.apply(0).pos.value._1,74.0,0.9)
+ Assert.assertEquals(r.response.results.apply(0).point.value._1,74.0,0.9)
}
@Test
def testRecipGeoBoost {
val geoLat = 74
val geoLong = -31
val r1 = ESimpleGeoPanda where (_.name contains "lolerskates") fetch()
- val r2 = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.pos recipSqeGeoDistance(geoLat, geoLong, 1, 5000, 1)) fetch()
+ val r2 = ESimpleGeoPanda where (_.name contains "lolerskates") scoreBoostField(_.point recipSqeGeoDistance(geoLat, geoLong, 1, 5000, 1)) fetch()
Assert.assertEquals(r1.response.results.length,2)
Assert.assertEquals(r2.response.results.length,2)
Assert.assertTrue(r2.response.results.apply(0).score.value > r1.response.results.apply(0).score.value)
@@ -404,13 +406,17 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
Assert.assertEquals(res1.response.results.length, 1)
}
- @Test
def testFilters {
// grab 2 results, filter to 1
val res1 = ESimplePanda where (_.hugenums contains 1L) filter(_.nicknamesString in List("jerry")) fetch()
Assert.assertEquals(res1.response.results.length, 1)
}
+ def testCustomScoreScripts {
+ val params: Map[String, Any] = Map("lat" -> -31.1, "lon" -> 74.0, "weight" -> 2000, "weight2" -> 0.03)
+ val response1 = ESimpleGeoPanda where(_.name contains "lolerskates") customScore("distance_score_magic", params) fetch()
+ Assert.assertEquals(response1.response.results.length, 2)
+ }
@Before
def hoboPrepIndex() {
@@ -437,6 +443,8 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
println("Error creating the regular index, may allready exist ("+e+")")
}
}
+ val plugin = new FourSquareScorePlugin()
+
//Set up the geo panda index
val geoClient = ESimpleGeoPanda.meta.client
try {
@@ -447,7 +455,7 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
val mapping = """
{ "slashemdoc" :{
"properties" : {
- "pos" : { type: "geo_point" }
+ "point" : { type: "geo_point" }
}
}}"""
val mappingReq = Requests.putMappingRequest(ESimpleGeoPanda.meta.indexName).source(mapping).`type`("slashemdoc")
@@ -459,16 +467,18 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
val geodoc1 = geoClient.prepareIndex(ESimpleGeoPanda.meta.indexName,ESimpleGeoPanda.meta.docType,"4c809f4251ada1cdc3790b10").setSource(jsonBuilder()
.startObject()
.field("name","lolerskates")
- .field("pos",74.0,-31.1)
+ .field("point",74.0,-31.1)
.field("id","4c809f4251ada1cdc3790b10")
+ .field("decayedPopularity1", .5)
.endObject()
).execute()
.actionGet();
val geodoc2 = geoClient.prepareIndex(ESimpleGeoPanda.meta.indexName,ESimpleGeoPanda.meta.docType,"4c809f4251ada1cdc3790b11").setSource(jsonBuilder()
.startObject()
.field("name","lolerskates")
.field("id","4c809f4251ada1cdc3790b11")
- .field("pos",74.0,-31.0)
+ .field("point",74.0,-31.0)
+ .field("decayedPopularity1", 21.2)
.endObject()
).execute()
.actionGet();
@@ -569,15 +579,15 @@ class ElasticQueryTest extends SpecsMatchers with ScalaCheckMatchers {
.startObject()
.field("name","ordertest")
.field("id","4c809f4251ada1cdc3790b16")
- .field("pos",74.0,-32.0)
+ .field("point",74.0,-32.0)
.endObject()
).execute()
.actionGet();
val geoOrderdoc2 = geoClient.prepareIndex(ESimpleGeoPanda.meta.indexName,ESimpleGeoPanda.meta.docType,"4c809f4251ada1cdc3790b17").setSource(jsonBuilder()
.startObject()
.field("name","ordertest")
.field("id","4c809f4251ada1cdc3790b17")
- .field("pos",74.0,-31.0)
+ .field("point",74.0,-31.0)
.endObject()
).execute()
.actionGet();
View
3 src/test/scala/com/foursquare/slashem/ElasticTest.scala
@@ -40,5 +40,6 @@ class ESimpleGeoPanda extends ElasticSchema[ESimpleGeoPanda] {
object id extends SlashemObjectIdField(this)
object name extends SlashemStringField(this)
object score extends SlashemDoubleField(this)
- object pos extends SlashemPointField(this)
+ object point extends SlashemPointField(this)
+ object decayedPopularity1 extends SlashemDoubleField(this)
}

0 comments on commit 1d7cf75

Please sign in to comment.