// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.train.{TrainClassifier, TrainRegressor}
import com.microsoft.azure.synapse.ml.core.schema.{DatasetExtensions, SchemaConstants}
import com.microsoft.azure.synapse.ml.core.utils.StopWatch
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.stages.DropColumns
import jdk.jfr.Experimental
import org.apache.commons.math3.stat.descriptive.rank.Percentile
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model, Pipeline}
import org.apache.spark.ml.classification.ProbabilisticClassifier
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, Regressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap}
import org.apache.spark.ml.param.shared.{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

import scala.concurrent.Future

/** Double ML estimators. The estimator follows a two-stage process,
 * where a set of nuisance functions is estimated in the first stage in a cross-fitting manner
 * and a final stage estimates the average treatment effect (ATE) model.
 * Our goal is to estimate the constant marginal ATE Theta(X).
 *
 * In this estimator, the ATE is estimated by using the following estimating equation:
 * .. math ::
 *   Y - \\E[Y | X, W] = \\Theta(X) \\cdot (T - \\E[T | X, W]) + \\epsilon
 *
 * Thus if we estimate the nuisance functions :math:`q(X, W) = \\E[Y | X, W]` and
 * :math:`f(X, W) = \\E[T | X, W]` in the first stage, we can estimate the final stage ATE for
 * each treatment t by running a regression that minimizes the residual-on-residual square loss:
 *
 * .. math ::
 *   \\hat{\\Theta} = \\arg\\min_{\\Theta}
 *   \\E_n\\left[ (\\tilde{Y} - \\Theta(X) \\cdot \\tilde{T})^2 \\right]
 *
 * where :math:`\\tilde{Y} = Y - \\E[Y | X, W]` and :math:`\\tilde{T} = T - \\E[T | X, W]` denote
 * the residual outcome and residual treatment. Estimating Theta(X) is therefore a final
 * regression problem, regressing :math:`\\tilde{Y}` on X and :math:`\\tilde{T}`.
 *
 * The nuisance function :math:`q` is a simple machine learning problem; users can call
 * setOutcomeModel to set an arbitrary SparkML model that is used internally to solve it.
 *
 * The problem of estimating the nuisance function :math:`f` is also a machine learning problem;
 * users can call setTreatmentModel to set an arbitrary SparkML model that is used internally
 * to solve it.
 *
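 * A minimal usage sketch (the column names, models, and iteration count are illustrative,
 * not defaults; setters correspond to the params referenced in this file):
 * {{{
 *   import org.apache.spark.ml.classification.LogisticRegression
 *   import org.apache.spark.ml.regression.LinearRegression
 *
 *   val dml = new DoubleMLEstimator()
 *     .setTreatmentModel(new LogisticRegression())
 *     .setTreatmentCol("treatment")
 *     .setOutcomeModel(new LinearRegression())
 *     .setOutcomeCol("outcome")
 *     .setMaxIter(20)
 *
 *   val model = dml.fit(df)
 *   val ate = model.getAvgTreatmentEffect
 *   val Array(lower, upper) = model.getConfidenceInterval
 * }}}
 */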
//noinspection ScalaDocParserErrorInspection,ScalaDocUnclosedTagWithoutParser
class DoubleMLEstimator(override val uid: String)
  extends Estimator[DoubleMLModel] with ComplexParamsWritable
    with DoubleMLParams with SynapseMLLogging with Wrappable {

  logClass()

  def this() = this(Identifiable.randomUID("DoubleMLEstimator"))

  /** Fits the DoubleML model.
   *
   * @param dataset The input dataset to train.
   * @return The trained DoubleML model, from which you can get the ATE and CI values.
   */
  override def fit(dataset: Dataset[_]): DoubleMLModel = {
    logFit({
      require(getMaxIter > 0, "maxIter should be larger than 0!")
      if (get(weightCol).isDefined) {
        getTreatmentModel match {
          case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
          case _ => throw new Exception(
            """The selected treatment model does not support sample weight,
              |but the weightCol parameter was set for the DoubleMLEstimator.
              |Please select a treatment model that supports sample weight.""".stripMargin)
        }
        getOutcomeModel match {
          case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
          case _ => throw new Exception(
            """The selected outcome model does not support sample weight,
              |but the weightCol parameter was set for the DoubleMLEstimator.
              |Please select an outcome model that supports sample weight.""".stripMargin)
        }
      }

      // Sample with replacement to redraw the data and compute a treatment effect (TE) value.
      // Run this multiple times in parallel to obtain a collection of TE values, then use their
      // average as the ATE value and the 2.5% low end, 97.5% high end as the CI value.
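      // For example (a sketch with hypothetical settings): maxIter = 100 and parallelism = 10
      // launches 100 bootstrap fits, at most 10 running concurrently; each redraws roughly
      // |dataset| rows with replacement (fraction = 1).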
      // Create the execution context based on $(parallelism)
      log.info(s"Parallelism: $getParallelism")
      val executionContext = getExecutionContextProxy

      val ateFutures = (1 to getMaxIter).toArray.map { index =>
        Future[Option[Double]] {
          log.info(s"Executing ATE calculation on iteration: $index")
          // If the algorithm runs for a single iteration, do not bootstrap from the dataset;
          // otherwise, redraw a sample with replacement.
          val redrewDF = if (getMaxIter == 1) dataset else dataset.sample(withReplacement = true, fraction = 1)
          val ate: Option[Double] =
            try {
              val totalTime = new StopWatch
              val oneAte = totalTime.measure {
                trainInternal(redrewDF)
              }
              log.info(s"Completed ATE calculation on iteration $index and got ATE value: $oneAte, " +
                s"time elapsed: ${totalTime.elapsed() / 6e10} minutes")
              Some(oneAte)
            } catch {
              case ex: Throwable =>
                log.warn(s"ATE calculation got an exception on iteration $index with the redrawn sample data. " +
                  s"Exception details: $ex")
                None
            }
          ate
        }(executionContext)
      }

      val ates = awaitFutures(ateFutures).flatten
      if (ates.isEmpty) {
        throw new Exception("ATE calculation failed on all iterations. Please check the log for details.")
      }
      val dmlModel = this.copyValues(new DoubleMLModel(uid)).setRawTreatmentEffects(ates.toArray)
      dmlModel
    })
  }
  //scalastyle:off method.length
  private def trainInternal(dataset: Dataset[_]): Double = {

    def getModel(model: Estimator[_ <: Model[_]], labelColName: String) = {
      model match {
        case classifier: ProbabilisticClassifier[_, _, _] => (
          new TrainClassifier()
            .setModel(model)
            .setLabelCol(labelColName),
          classifier.getProbabilityCol
        )
        case regressor: Regressor[_, _, _] => (
          new TrainRegressor()
            .setModel(model)
            .setLabelCol(labelColName),
          regressor.getPredictionCol
        )
      }
    }

    def getPredictedCols(model: Estimator[_ <: Model[_]]): Array[String] = {
      val rawPredictionCol = model match {
        case rp: HasRawPredictionCol => Some(rp.getRawPredictionCol)
        case _ => None
      }
      val predictionCol = model match {
        case p: HasPredictionCol => Some(p.getPredictionCol)
        case _ => None
      }
      val probabilityCol = model match {
        case pr: HasProbabilityCol => Some(pr.getProbabilityCol)
        case _ => None
      }
      (rawPredictionCol :: predictionCol :: probabilityCol :: Nil).flatten.toArray
    }

    val (treatmentEstimator, treatmentResidualPredictionColName) = getModel(
      getTreatmentModel.copy(getTreatmentModel.extractParamMap()),
      getTreatmentCol
    )
    val treatmentPredictionColsToDrop = getPredictedCols(getTreatmentModel)

    val (outcomeEstimator, outcomeResidualPredictionColName) = getModel(
      getOutcomeModel.copy(getOutcomeModel.extractParamMap()),
      getOutcomeCol
    )
    val outcomePredictionColsToDrop = getPredictedCols(getOutcomeModel)

    val treatmentResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.TreatmentResidualColumn, dataset)
    val outcomeResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.OutcomeResidualColumn, dataset)
    val treatmentResidualVecCol = DatasetExtensions.findUnusedColumnName("treatmentResidualVec", dataset)
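    // Fits the treatment and outcome nuisance models on `train`, then computes residuals
    // (observed minus predicted) on `test`, assembling the treatment residual into a
    // vector column for the final-stage regressor.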
    def calculateResiduals(train: Dataset[_], test: Dataset[_]): DataFrame = {
      val treatmentModel = treatmentEstimator.setInputCols(train.columns.filterNot(_ == getOutcomeCol)).fit(train)
      val outcomeModel = outcomeEstimator.setInputCols(train.columns.filterNot(_ == getTreatmentCol)).fit(train)

      val treatmentResidual =
        new ResidualTransformer()
          .setObservedCol(getTreatmentCol)
          .setPredictedCol(treatmentResidualPredictionColName)
          .setOutputCol(treatmentResidualCol)
      val dropTreatmentPredictedColumns = new DropColumns().setCols(treatmentPredictionColsToDrop)

      val outcomeResidual =
        new ResidualTransformer()
          .setObservedCol(getOutcomeCol)
          .setPredictedCol(outcomeResidualPredictionColName)
          .setOutputCol(outcomeResidualCol)
      val dropOutcomePredictedColumns = new DropColumns().setCols(outcomePredictionColsToDrop)

      val treatmentResidualVA =
        new VectorAssembler()
          .setInputCols(Array(treatmentResidualCol))
          .setOutputCol(treatmentResidualVecCol)
          .setHandleInvalid("skip")

      val pipeline = new Pipeline().setStages(Array(
        treatmentModel, treatmentResidual, dropTreatmentPredictedColumns,
        outcomeModel, outcomeResidual, dropOutcomePredictedColumns,
        treatmentResidualVA))

      pipeline.fit(test).transform(test)
    }

    // Note: we perform the following steps to get the ATE.
    /*
      1. Split the sample, e.g. 50/50.
      2. Use the first split to fit the treatment model and the outcome model.
      3. Use the two models to fit a residual model on the second split.
      4. Cross-fit: treatment and outcome models with the second split, residual model with the first split.
      5. Average the slopes from the two residual models.
    */
    val splits = dataset.randomSplit(getSampleSplitRatio)
    val (train, test) = (splits(0).cache, splits(1).cache)
    val residualsDF1 = calculateResiduals(train, test)
    val residualsDF2 = calculateResiduals(test, train)

    // Average slopes from the two residual models.
    val regressor = new GeneralizedLinearRegression()
      .setLabelCol(outcomeResidualCol)
      .setFeaturesCol(treatmentResidualVecCol)
      .setFamily("gaussian")
      .setLink("identity")
      .setFitIntercept(false)

    val coefficients = Array(residualsDF1, residualsDF2).map(regressor.fit).map(_.coefficients(0))
    val ate = coefficients.sum / coefficients.length

    Seq(train, test).foreach(_.unpersist)
    ate
  }
  override def copy(extra: ParamMap): Estimator[DoubleMLModel] = {
    defaultCopy(extra)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    DoubleMLEstimator.validateTransformSchema(schema)
  }
}

object DoubleMLEstimator extends ComplexParamsReadable[DoubleMLEstimator] {
  def validateTransformSchema(schema: StructType): StructType = {
    StructType(schema.fields)
  }
}
/** Model produced by [[DoubleMLEstimator]]. */
class DoubleMLModel(val uid: String)
  extends Model[DoubleMLModel] with DoubleMLParams with ComplexParamsWritable with Wrappable with SynapseMLLogging {
  logClass()

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("DoubleMLModel"))

  val rawTreatmentEffects = new DoubleArrayParam(
    this,
    "rawTreatmentEffects",
    "raw treatment effect results for all iterations")

  def getRawTreatmentEffects: Array[Double] = $(rawTreatmentEffects)

  def setRawTreatmentEffects(v: Array[Double]): this.type = set(rawTreatmentEffects, v)

  def getAvgTreatmentEffect: Double = {
    val finalAte = $(rawTreatmentEffects).sum / $(rawTreatmentEffects).length
    finalAte
  }
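  /** Returns the bootstrap confidence interval of the ATE as Array(lower, upper).
   * A worked sketch (assuming getConfidenceLevel = 0.975): the bounds are the
   * 100 * (1 - 0.975) = 2.5th and 0.975 * 100 = 97.5th percentiles of the raw
   * treatment effects gathered across iterations.
   */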
  def getConfidenceInterval: Array[Double] = {
    val ciLowerBound = percentile($(rawTreatmentEffects), 100 * (1 - getConfidenceLevel))
    val ciUpperBound = percentile($(rawTreatmentEffects), getConfidenceLevel * 100)
    Array(ciLowerBound, ciUpperBound)
  }
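  // Note: commons-math3 Percentile expects the quantile p in (0, 100] and does not require
  // pre-sorted input, so the sort below is merely harmless.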
  private def percentile(values: Seq[Double], quantile: Double): Double = {
    val sortedValues = values.sorted
    val percentile = new Percentile()
    percentile.setData(sortedValues.toArray)
    percentile.evaluate(quantile)
  }

  override def copy(extra: ParamMap): DoubleMLModel = defaultCopy(extra)

  /**
   * :: Experimental ::
   * The DoubleMLModel transform function is still experimental, and its behavior could change in the future.
   */
  @Experimental
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      dataset.toDF()
    })
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields)
}

object DoubleMLModel extends ComplexParamsReadable[DoubleMLModel]