/
BinsTransformerSpec.scala
67 lines (50 loc) · 2.21 KB
/
BinsTransformerSpec.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
package com.collective.modelmatrix.transform
import com.collective.modelmatrix.{ModelFeature, ModelMatrixAccess, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FlatSpec
import scala.util.Random
import scalaz.syntax.either._
import scalaz.{-\/, \/-}
/**
 * Spec for [[BinsTransformer]]: validates and transforms features configured with a `Bins`
 * transform against a small in-memory DataFrame with one string and one double column.
 */
class BinsTransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrixAccess.sqlContext(sc)

  // A string column (rejected by the Bins transformer, see the last test) and a double column.
  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("pct_click", DoubleType)
  ))

  // Fixed seed so the generated input data is identical on every run; an unseeded
  // generator would make any failure impossible to reproduce.
  val rnd = new Random(42L)

  /** Returns `p` plus a small random offset in [0, 0.0099], in steps of 0.0001. */
  def rand(p: Double): Double = {
    p + rnd.nextInt(100).toDouble / 10000
  }

  // 120 rows: 20 per site, with each site's pct_click clustered just above a distinct base value.
  val input =
    Seq.fill(20)(Row("cnn.com", rand(0.5))) ++
    Seq.fill(20)(Row("bbc.com", rand(0.6))) ++
    Seq.fill(20)(Row("hbo.com", rand(0.7))) ++
    Seq.fill(20)(Row("mashable.com", rand(0.8))) ++
    Seq.fill(20)(Row("reddit.com", rand(0.9))) ++
    Seq.fill(20)(Row("ycombinator.com", rand(1.0)))

  val isActive = true

  // NOTE(review): Bins(3, 0, 0) presumably requests 3 bins with no minimum-size constraints — confirm.
  val adSite = ModelFeature(isActive, "Ad", "ad_site", "adv_site", Bins(3, 0, 0))
  val sitePerformance = ModelFeature(isActive, "Site", "site_performance", "pct_click", Bins(3, 0, 0))

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)

  // Abort spec construction immediately if the features cannot be extracted from the DataFrame.
  val transformer = new BinsTransformer(Transformer.extractFeatures(df, Seq(adSite, sitePerformance)) match {
    case -\/(err) => sys.error(s"Can't extract features: $err")
    case \/-(suc) => suc
  })

  "Bins Transformer" should "support double typed model feature" in {
    val valid = transformer.validate(sitePerformance)
    assert(valid == TypedModelFeature(sitePerformance, DoubleType).right)

    // Unwrap explicitly so a failed validation reports its error rather than a bare
    // NoSuchElementException from Option.get.
    val typed = valid match {
      case \/-(feature) => feature
      case -\/(err) => fail(s"Expected feature to validate, got: $err")
    }

    val columns = transformer.transform(typed)
    // With Bins(3, ...) the transformer is expected to produce three output columns.
    assert(columns.size == 3)
  }

  it should "fail if feature column doesn't exist" in {
    val failed = transformer.validate(sitePerformance.copy(feature = "site_clicks"))
    assert(failed == FeatureTransformationError.FeatureColumnNotFound("site_clicks").left)
  }

  it should "fail if column type is not supported" in {
    val failed = transformer.validate(adSite)
    assert(failed == FeatureTransformationError.UnsupportedTransformDataType("ad_site", StringType, Bins(3, 0, 0)).left)
  }
}