-
Notifications
You must be signed in to change notification settings - Fork 821
/
SummarizeData.scala
236 lines (186 loc) · 8.46 KB
/
SummarizeData.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.
package com.microsoft.azure.synapse.ml.stages
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, DoubleParam, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.storage.StorageLevel
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
trait SummarizeDataParams extends Wrappable with DefaultParamsWritable {

  /** Compute count statistics (non-missing / unique / missing counts). Default is true.
    * @group param
    */
  final val counts: BooleanParam = new BooleanParam(this, "counts", "Compute count statistics")
  // Consistency fix: use the same setDefault(param, value) overload as every other
  // param in this trait (was the ParamPair form `setDefault(counts -> true)`).
  setDefault(counts, true)

  /** @group getParam */
  final def getCounts: Boolean = $(counts)

  /** @group setParam */
  def setCounts(value: Boolean): this.type = set(counts, value)

  /** Compute basic statistics (min / quartiles / median / max). Default is true.
    * @group param
    */
  final val basic: BooleanParam = new BooleanParam(this, "basic", "Compute basic statistics")
  setDefault(basic, true)

  /** @group getParam */
  final def getBasic: Boolean = $(basic)

  /** @group setParam */
  def setBasic(value: Boolean): this.type = set(basic, value)

  /** Compute sample statistics (variance / stddev / skewness / kurtosis). Default is true.
    * @group param
    */
  final val sample: BooleanParam = new BooleanParam(this, "sample", "Compute sample statistics")
  setDefault(sample, true)

  /** @group getParam */
  final def getSample: Boolean = $(sample)

  /** @group setParam */
  def setSample(value: Boolean): this.type = set(sample, value)

  /** Compute percentiles. Default is true.
    * @group param
    */
  final val percentiles: BooleanParam = new BooleanParam(this, "percentiles", "Compute percentiles")
  setDefault(percentiles, true)

  /** @group getParam */
  final def getPercentiles: Boolean = $(percentiles)

  /** @group setParam */
  def setPercentiles(value: Boolean): this.type = set(percentiles, value)

  /** Error threshold for approximate quantile computation - 0 means exact (more expensive).
    * @group param
    */
  final val errorThreshold: DoubleParam =
    new DoubleParam(this, "errorThreshold", "Threshold for quantiles - 0 is exact")
  setDefault(errorThreshold, 0.0)

  /** @group getParam */
  final def getErrorThreshold: Double = $(errorThreshold)

  /** @group setParam */
  def setErrorThreshold(value: Double): this.type = set(errorThreshold, value)

  /** Builds the output schema: the feature-name column followed by one group of
    * statistic fields per enabled statistic family, in a fixed order.
    */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    val columns = ListBuffer(SummarizeData.FeatureColumn)
    if ($(counts)) columns ++= SummarizeData.CountFields
    if ($(basic)) columns ++= SummarizeData.BasicFields
    if ($(sample)) columns ++= SummarizeData.SampleFields
    if ($(percentiles)) columns ++= SummarizeData.PercentilesFields
    StructType(columns)
  }
}
// UID should be overridden by driver for controlled identification at the DAG level
/** Compute summary statistics for the dataset. The following statistics are computed:
 * - counts
 * - basic
 * - sample
 * - percentiles
 * - errorThreshold - error threshold for quantiles
 *
 * The output contains one row per input column (keyed by the "Feature" column) and
 * one output column per enabled statistic. Basic/sample/percentile statistics are
 * computed only for numeric columns; other columns get NaN in those positions.
 * @param uid The id of the module
 */
class SummarizeData(override val uid: String)
extends Transformer
with SummarizeDataParams with SynapseMLLogging {
logClass(FeatureNames.Core)
def this() = this(Identifiable.randomUID("SummarizeData"))
/** Computes each enabled statistic family and joins the per-family results into a
 * single summary DataFrame keyed by feature (column) name. */
override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
val df = dataset.toDF()
// Cache the input: every enabled family below re-scans the full dataset
// (count / first / approxQuantile actions), so avoid recomputing its lineage.
df.persist(StorageLevel.MEMORY_ONLY)
val subFrames = ListBuffer[DataFrame]()
// Each helper runs Spark actions eagerly and returns a small DataFrame built from
// locally collected rows (one row per input column) — see computeColumnStats.
if ($(counts)) subFrames += computeCounts(df)
if ($(basic)) subFrames += curriedBasic(df)
if ($(sample)) subFrames += sampleStats(df)
if ($(percentiles)) subFrames += curriedPerc(df)
// Non-blocking unpersist is safe here: all statistics were materialized above,
// so the joins below no longer read from df's cached data.
df.unpersist(false)
// Base frame holds only the feature-name column; fold each family onto it by name.
val base = createJoinBase(df)
subFrames.foldLeft(base) { (z, dfi) => z.join(dfi, SummarizeData.FeatureColumnName) }
}, dataset.columns.length)
}
def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}
def copy(extra: ParamMap): SummarizeData = defaultCopy(extra)
// Count statistics apply to columns of every data type, not just numeric ones.
private def computeCounts = computeOnAll(computeCountsImpl, SummarizeData.CountFields)
// Returns (non-missing count, approximate distinct count, missing count) for one column.
private def computeCountsImpl(col: String, df: DataFrame): Array[Double] = {
val column = df.col(col)
val dataType = df.schema(col).dataType
// isnan does not accept BooleanType, so booleans are cast to double for the NaN check.
val mExpr = isnull(column) || (if (dataType.equals(BooleanType)) isnan(column.cast(DoubleType)) else isnan(column))
val countMissings = df.where(mExpr).count().toDouble
// approx_count_distinct returns a Long, which is wider than Double's exact-integer range.
val dExpr = approx_count_distinct(column)
val distinctCount = df.select(dExpr).first.getLong(0).toDouble
// Order must match SummarizeData.CountFields: count, unique count, missing count.
Array(df.count() - countMissings, distinctCount, countMissings)
}
private def sampleStats = computeOnNumeric(sampleStatsImpl, SummarizeData.SampleFields)
// Returns (variance, stddev, skewness, kurtosis) for one numeric column,
// matching the order of SummarizeData.SampleFields.
private def sampleStatsImpl(col: String, df: DataFrame): Array[Double] = {
val column = df.col(col)
val k = kurtosis(column)
val sk = skewness(column)
val v = variance(column)
val sd = stddev(column)
df.select(v, sd, sk, k).first.toSeq.map(_.asInstanceOf[Double]).toArray
}
// Basic statistics are the five-number summary, computed as approximate quantiles.
private def curriedBasic = {
val quants = SummarizeData.BasicQuantiles
computeOnNumeric(quantStub(quants, $(errorThreshold)), SummarizeData.BasicFields)
}
private def curriedPerc = {
val quants = SummarizeData.PercentilesQuantiles
computeOnNumeric(quantStub(quants, $(errorThreshold)), SummarizeData.PercentilesFields)
}
// Builds a stat function that computes the given quantiles for a single column.
private def quantStub(vals: Array[Double], err: Double) =
(cn: String, df: DataFrame) => df.stat.approxQuantile(cn, vals, err)
// Partial applications of computeColumnStats: restrict to numeric columns, or accept all.
private def computeOnNumeric = computeColumnStats(sf => sf.dataType.isInstanceOf[NumericType]) _
private def computeOnAll = computeColumnStats(sf => true) _
// Filler row for columns rejected by the predicate (e.g. non-numeric columns).
private def allNaNs(l: Int): Array[Double] = Array.fill(l)(Double.NaN)
// Predicate `false` with no stat columns yields a frame containing only the
// feature-name column, one row per input column — the join base for transform.
private def createJoinBase(df: DataFrame) = computeColumnStats(sf => false)((cn, df) => Array(), List())(df)
/** Runs statFunc on every column of df matching predicate p and assembles the results
 * into a DataFrame with one row per input column: (feature name, stats...).
 * Columns failing the predicate receive NaN for every stat field. */
private def computeColumnStats
(p: StructField => Boolean)
(statFunc: (String, DataFrame) => Array[Double], newColumns: Seq[StructField])
(df: DataFrame): DataFrame = {
val emptyRow = allNaNs(newColumns.length)
val outList = df.schema.map(col => (col.name, if (p(col)) statFunc(col.name, df) else emptyRow))
val rows = outList.map { case (n, r) => Row.fromSeq(n +: r) }
val schema = SummarizeData.FeatureColumn +: newColumns
df.sparkSession.createDataFrame(rows.asJava, StructType(schema))
}
}
/** Companion holding load support and the output-schema constants for [[SummarizeData]]. */
object SummarizeData extends DefaultParamsReadable[SummarizeData] {

  /** The families of statistics the transformer can toggle on or off. */
  object Statistic extends Enumeration {
    type Statistic = Value
    val Counts, Basic, Sample, Percentiles = Value
  }

  /** Name of the output column carrying each input column's name. */
  final val FeatureColumnName = "Feature"

  /** Schema entry for the feature-name column (never null). */
  final val FeatureColumn = StructField(FeatureColumnName, StringType, false)

  // Builds a nullable double field for a statistic output column.
  private def statField(name: String): StructField = StructField(name, DoubleType, true)

  /** Quantile probabilities reported by the percentile statistics. */
  final val PercentilesQuantiles = Array(0.005, 0.01, 0.05, 0.95, 0.99, 0.995)

  /** Output fields for the percentile statistics; order matches [[PercentilesQuantiles]]. */
  final val PercentilesFields =
    List("P0_5", "P1", "P5", "P95", "P99", "P99_5").map(statField)

  /** Output fields for the sample statistics. */
  final val SampleFields =
    List("Sample_Variance", "Sample_Standard_Deviation", "Sample_Skewness", "Sample_Kurtosis")
      .map(statField)

  /** Quantile probabilities behind the basic five-number summary. */
  final val BasicQuantiles = Array(0.0, 0.25, 0.5, 0.75, 1.0)

  /** Output fields for the basic statistics; order matches [[BasicQuantiles]]. */
  final val BasicFields =
    List("Min", "1st_Quartile", "Median", "3rd_Quartile", "Max").map(statField)
  //TODO: StructField("Range", DoubleType, true),
  //TODO: StructField("Mean", DoubleType, true),
  //TODO: StructField("Mean Deviation", DoubleType, true),
  // Mode is JSON Array of modes - needs a little special treatment
  //TODO: StructField("Mode", StringType, true))

  /** Output fields for the count statistics (always computed, hence non-nullable). */
  final val CountFields = List(
    StructField("Count", DoubleType, false),
    StructField("Unique_Value_Count", DoubleType, false),
    StructField("Missing_Value_Count", DoubleType, false))
}