/
IndexTransformerSpec.scala
109 lines (87 loc) · 4.32 KB
/
IndexTransformerSpec.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package com.collective.modelmatrix.transform
import com.collective.modelmatrix.CategoricalColumn.{AllOther, CategoricalValue}
import com.collective.modelmatrix.{ModelMatrixEncoding, ModelMatrixAccess, ModelFeature, TestSparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.scalatest.FlatSpec
import scodec.bits.ByteVector
import scala.util.Random
import scalaz.{\/-, -\/}
import scalaz.syntax.either._
/**
 * Spec for `IndexTransformer` applied to a single string-typed model feature.
 *
 * The fixture contains 10,000 non-null rows spread over eight sites
 * (5000 / 4000 / 400 / 200 / 100 / 100 / 100 / 100) plus 100 null rows.
 * The expected counts asserted below add up to 10,000, i.e. the null rows
 * are expected to be left out of every count.
 */
class IndexTransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrixAccess.sqlContext(sc)

  // Single nullable string column holding the raw site name.
  val schema = StructType(Seq(
    StructField("adv_site", StringType)
  ))

  // Rows are shuffled (unseeded), so the input order differs from run to run;
  // every expectation below is stated in terms of counts only.
  val input = new Random().shuffle(
    Seq.fill(5000)(Row("cnn.com")) ++
      Seq.fill(4000)(Row("bbc.com")) ++
      Seq.fill(400)(Row("hbo.com")) ++
      // the three sites above each cover more than 2% of the non-null rows;
      // mashable.com covers exactly 2% (200 of 10,000) and the specs below
      // expect it to be folded into "all other" together with the 100-row sites
      Seq.fill(200)(Row("mashable.com")) ++
      Seq.fill(100)(Row("gizmodo.com")) ++
      Seq.fill(100)(Row("reddit.com")) ++
      Seq.fill(100)(Row("amc.com")) ++
      Seq.fill(100)(Row("msnbc.com")) ++
      // null columns should be skipped
      Seq.fill(100)(Row(null))
  )

  val isActive = true
  val withAllOther = true

  // Feature under test: index transformation with a support of 2 (percent) and an "all other" bucket.
  val adSite = ModelFeature(isActive, "Ad", "ad_site", "adv_site", Index(2.0, allOther = withAllOther))

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)

  // Feature extraction must succeed for the spec to make sense, so abort construction otherwise.
  val transformer = new IndexTransformer(Transformer.extractFeatures(df, Seq(adSite)) match {
    case -\/(err) => sys.error(s"Can't extract features: $err")
    case \/-(suc) => suc
  })

  /**
   * Validates `feature` with the transformer under test.
   *
   * Fails the running test with the reported validation error instead of
   * throwing an uninformative `NoSuchElementException`.
   */
  private def validated(feature: ModelFeature) =
    transformer.validate(feature) match {
      case \/-(typed) => typed
      case -\/(err)   => fail(s"Feature validation failed: $err")
    }

  /**
   * Asserts that `column` is a `CategoricalValue` describing `site`.
   *
   * @param column          produced column to inspect
   * @param columnId        expected column id
   * @param site            expected source name; its encoded form is the expected source value
   * @param count           expected number of rows for this site
   * @param cumulativeCount expected running total up to and including this column
   */
  private def assertCategoricalValue(
    column: Any,
    columnId: Int,
    site: String,
    count: Long,
    cumulativeCount: Long
  ): Unit = column match {
    case value: CategoricalValue =>
      assert(value.columnId == columnId)
      assert(value.sourceName == site)
      assert(value.sourceValue == ModelMatrixEncoding.encode(site))
      assert(value.count == count)
      assert(value.cumulativeCount == cumulativeCount)
    case other =>
      fail(s"Expected a CategoricalValue for '$site' but got: $other")
  }

  "Index Transformer" should "support string typed model feature" in {
    val valid = transformer.validate(adSite)
    assert(valid == TypedModelFeature(adSite, StringType).right)
  }

  // NOTE(review): the feature is renamed to the raw input column name here; presumably the
  // extracted data only exposes the original feature name - confirm against Transformer.extractFeatures.
  it should "fail if feature column doesn't exists" in {
    val failed = transformer.validate(adSite.copy(feature = "adv_site"))
    assert(failed == FeatureTransformationError.FeatureColumnNotFound("adv_site").left)
  }

  it should "calculate correct categorical columns with all other" in {
    val columns = transformer.transform(validated(adSite))

    assert(columns.size == 4)

    assertCategoricalValue(columns(0), columnId = 1, site = "cnn.com", count = 5000, cumulativeCount = 5000)
    assertCategoricalValue(columns(1), columnId = 2, site = "bbc.com", count = 4000, cumulativeCount = 9000)
    assertCategoricalValue(columns(2), columnId = 3, site = "hbo.com", count = 400, cumulativeCount = 9400)

    // Everything that is not one of the three sites above: 200 + 4 * 100 rows.
    columns(3) match {
      case rest: AllOther =>
        assert(rest.columnId == 4)
        assert(rest.count == 600)
        assert(rest.cumulativeCount == 10000)
      case other =>
        fail(s"Expected an AllOther column but got: $other")
    }
  }

  it should "remove all other column" in {
    val typed = validated(adSite.copy(transform = Index(2.0, allOther = false)))
    val columns = transformer.transform(typed)
    assert(columns.size == 3)
  }

  it should "return less columns with higher support factor" in {
    val typed = validated(adSite.copy(transform = Index(10.0, allOther = false)))
    val columns = transformer.transform(typed)
    assert(columns.size == 2)
  }

  it should "return all columns with low support factor" in {
    val typed = validated(adSite.copy(transform = Index(0.0001, allOther = withAllOther)))
    val columns = transformer.transform(typed)
    // All eight sites get their own column; no extra column is expected even with "all other" enabled.
    assert(columns.size == 8)
  }
}