feat: ICE/PDP explainer #1284
Merged
Changes from 41 commits (42 commits in total)
aa166b5 Initial PDP version. (ezherdeva)
151ef99 Apply suggestions (ezherdeva)
b47d410 Added ICE (ezherdeva)
7d70110 Apply suggestions and fix (ezherdeva)
f5049e3 Added discrete (ezherdeva)
e6e985e Added logic for discrete features (ezherdeva)
a23df5c New logic (without unit tests) (ezherdeva)
43d1648 WIP (memoryz)
9b379e8 WIP (memoryz)
b9cbb7b rebased the main branch (ezherdeva)
5ba0bec small fix (ezherdeva)
c0c9ddf added some unit tests (ezherdeva)
51e3d4f added python code (ezherdeva)
bda7882 Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/… (ezherdeva)
fa0aa6f Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/… (ezherdeva)
adc4301 Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/… (ezherdeva)
058f27b fix1 (ezherdeva)
f234890 Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls… (ezherdeva)
5d3d38e Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/… (ezherdeva)
b630d39 Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls… (ezherdeva)
172a050 Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/… (ezherdeva)
69486ed fix 2 (ezherdeva)
fd7d13b Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls… (ezherdeva)
1d658d5 Fixed comments (ezherdeva)
25ad8fa fix comments (ezherdeva)
8045357 fix comments 2 (ezherdeva)
df4e6c6 Merge branch 'master' into ezherdeva/ice_pdp (ezherdeva)
2c207d3 last fix (ezherdeva)
fa87e5c added copyright to py files (ezherdeva)
bce0a92 Merge branch 'master' into ezherdeva/ice_pdp (ezherdeva)
77b6267 Update src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzin… (ezherdeva)
7c25c57 fix 2 (ezherdeva)
98173d5 Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls… (ezherdeva)
a11c718 fix python issue (ezherdeva)
6483daf fix python issue (small fix) (ezherdeva)
ef2c35e fixed python issue (ezherdeva)
8c3a6dc fixed comments and add more docs (ezherdeva)
5b53fa5 Merge branch 'master' into ezherdeva/ice_pdp (mhamilton723)
e492014 fix comments (ezherdeva)
0043fb6 Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls… (ezherdeva)
61624de fix code style (ezherdeva)
9254b8d Merge branch 'master' into ezherdeva/ice_pdp (mhamilton723)
48 changes: 48 additions & 0 deletions
core/src/main/python/synapse/ml/explainers/ICETransformer.py
```python
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

from synapse.ml.explainers._ICETransformer import _ICETransformer
from pyspark.ml.common import inherit_doc
from typing import List, Dict, Union


@inherit_doc
class ICETransformer(_ICETransformer):
    def setCategoricalFeatures(self, values: List[Union[str, Dict]]):
        """
        Args:
            values: The categorical features to explain: either a list of feature
                names, or a list of dicts with per-feature parameters.
        """
        if len(values) == 0:
            pass
        else:
            list_values = []
            for value in values:
                if isinstance(value, str):
                    # A bare name becomes a spec dict with default parameters.
                    list_values.append({"name": value})
                elif isinstance(value, dict):
                    list_values.append(value)
                else:
                    pass
            self._java_obj.setCategoricalFeaturesPy(list_values)
        return self

    def setNumericFeatures(self, values: List[Union[str, Dict]]):
        """
        Args:
            values: The numeric features to explain: either a list of feature
                names, or a list of dicts with per-feature parameters.
        """
        if len(values) == 0:
            pass
        else:
            list_values = []
            for value in values:
                if isinstance(value, str):
                    # A bare name becomes a spec dict with default parameters.
                    list_values.append({"name": value})
                elif isinstance(value, dict):
                    list_values.append(value)
                else:
                    pass
            self._java_obj.setNumericFeaturesPy(list_values)
        return self
```
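Both setters apply the same normalization before calling into the JVM: bare feature names are wrapped as `{"name": ...}` dicts, dicts pass through unchanged, and anything else is silently dropped. A standalone sketch of that normalization (no Spark involved; `normalize_feature_specs` is a hypothetical helper name, not part of the API):

```python
from typing import Dict, List, Union


def normalize_feature_specs(values: List[Union[str, Dict]]) -> List[Dict]:
    """Mirror of the setter logic: wrap bare names, keep dicts, skip other types."""
    normalized = []
    for value in values:
        if isinstance(value, str):
            normalized.append({"name": value})
        elif isinstance(value, dict):
            normalized.append(value)
        # anything else is silently ignored, as in the setters above
    return normalized


specs = normalize_feature_specs(["age", {"name": "income", "numSplits": 5}])
```

With this input, `specs` holds `{"name": "age"}` followed by the `income` dict unchanged.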
278 changes: 278 additions & 0 deletions
core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/ICEExplainer.scala
```scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.explainers

import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Params, _}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.ml.stat.Summarizer
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import scala.jdk.CollectionConverters.asScalaBufferConverter

trait ICEFeatureParams extends Params with HasNumSamples {

  val averageKind = "average"
  val individualKind = "individual"

  val categoricalFeatures = new TypedArrayParam[ICECategoricalFeature](
    this,
    "categoricalFeatures",
    "The list of categorical features to explain.",
    _.forall(_.validate)
  )

  def setCategoricalFeatures(values: Seq[ICECategoricalFeature]): this.type = this.set(categoricalFeatures, values)
  def getCategoricalFeatures: Seq[ICECategoricalFeature] = $(categoricalFeatures)

  def setCategoricalFeaturesPy(values: java.util.List[java.util.HashMap[String, Any]]): this.type = {
    val features: Seq[ICECategoricalFeature] = values.asScala.map(f => ICECategoricalFeature.fromMap(f))
    this.setCategoricalFeatures(features)
  }

  val numericFeatures = new TypedArrayParam[ICENumericFeature](
    this,
    "numericFeatures",
    "The list of numeric features to explain.",
    _.forall(_.validate)
  )

  def setNumericFeatures(values: Seq[ICENumericFeature]): this.type = this.set(numericFeatures, values)
  def getNumericFeatures: Seq[ICENumericFeature] = $(numericFeatures)

  def setNumericFeaturesPy(values: java.util.List[java.util.HashMap[String, Any]]): this.type = {
    val features: Seq[ICENumericFeature] = values.asScala.map(ICENumericFeature.fromMap)
    this.setNumericFeatures(features)
  }

  val kind = new Param[String](
    this,
    "kind",
    "Whether to return the partial dependence plot (PDP) averaged across all the samples in the " +
      "dataset or individual feature importance (ICE) per sample. " +
      "Allowed values are \"average\" for PDP and \"individual\" for ICE.",
    ParamValidators.inArray(Array(averageKind, individualKind))
  )

  def getKind: String = $(kind)
  def setKind(value: String): this.type = set(kind, value)

  setDefault(kind -> "individual",
    numericFeatures -> Seq.empty[ICENumericFeature],
    categoricalFeatures -> Seq.empty[ICECategoricalFeature])
}
```
```scala
/**
  * ICETransformer displays the model dependence on specified features with the given dataframe
  * as background dataset. It supports 2 types of plots: individual - dependence per instance and
  * average - across all the samples in the dataset.
  * Note: This transformer only supports one-way dependence plots.
  */
@org.apache.spark.annotation.Experimental
class ICETransformer(override val uid: String) extends Transformer
  with HasExplainTarget
  with HasModel
  with ICEFeatureParams
  with Wrappable
  with ComplexParamsWritable {

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("ICETransformer"))

  private def calcDependence(df: DataFrame, idCol: String, targetClassesColumn: String,
                             feature: String, values: Array[_], outputColName: String): DataFrame = {

    val dataType = df.schema(feature).dataType
    val explodeFunc = explode(array(values.map(v => lit(v)): _*).cast(ArrayType(dataType)))

    val predicted = getModel.transform(df.withColumn(feature, explodeFunc))
    val targetCol = DatasetExtensions.findUnusedColumnName("target", predicted)
    val dependenceCol = DatasetExtensions.findUnusedColumnName("feature__dependence", predicted)

    val explainTarget = extractTarget(predicted.schema, targetClassesColumn)
    val result = predicted.withColumn(targetCol, explainTarget)

    getKind.toLowerCase match {
      case `averageKind` =>
        // PDP output schema: 1 row * 1 col (pdp for the given feature: feature_value -> explanations)
        result
          .groupBy(feature)
          .agg(Summarizer.mean(col(targetCol)).alias(dependenceCol))
          .agg(
            map_from_arrays(
              collect_list(feature),
              collect_list(dependenceCol)
            ).alias(outputColName)
          )

      case `individualKind` =>
        // ICE output schema: n rows * 2 cols (idCol + ice for the given feature: map(feature_value -> explanations))
        result
          .groupBy(idCol)
          .agg(
            map_from_arrays(
              collect_list(feature),
              collect_list(targetCol)
            ).alias(outputColName)
          )
    }
  }
```
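Setting Spark aside, the shape of the two `calcDependence` aggregations can be sketched in plain Python: `average` (PDP) collapses the exploded predictions to one mean per feature value, while `individual` (ICE) keeps one feature_value -> prediction map per instance id. The row tuples below are illustrative stand-ins for the exploded, predicted DataFrame, not real model output:

```python
from collections import defaultdict

# Illustrative stand-ins for the exploded rows: (instance id, feature value, prediction)
rows = [
    (0, "a", 0.2), (0, "b", 0.6),
    (1, "a", 0.4), (1, "b", 0.8),
]

# kind = "average" (PDP): one map of feature_value -> mean prediction over all rows
by_value = defaultdict(list)
for _id, feature_value, prediction in rows:
    by_value[feature_value].append(prediction)
pdp = {fv: sum(ps) / len(ps) for fv, ps in by_value.items()}

# kind = "individual" (ICE): one feature_value -> prediction map per instance id
ice = defaultdict(dict)
for _id, feature_value, prediction in rows:
    ice[_id][feature_value] = prediction
```

Here `pdp` is a single map (one "row"), while `ice` keeps one entry per instance, matching the output schemas noted in the comments above.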
```scala
  def transform(ds: Dataset[_]): DataFrame = {
    transformSchema(ds.schema)
    val df = ds.toDF
    val idCol = DatasetExtensions.findUnusedColumnName("idCol", df)
    val targetClasses = DatasetExtensions.findUnusedColumnName("targetClasses", df)
    val dfWithId = df
      .withColumn(idCol, monotonically_increasing_id())
      .withColumn(targetClasses, get(targetClassesCol).map(col).getOrElse(lit(getTargetClasses)))

    // Collect feature values for all features from the original dataset - dfWithId
    val (categoricalFeatures, numericFeatures) = (getCategoricalFeatures, getNumericFeatures)

    // If numSamples is specified, randomly pick numSamples instances from the input dataset
    val sampled: Dataset[Row] = get(numSamples).map(dfWithId.orderBy(rand()).limit).getOrElse(dfWithId).cache

    // Collect values from the input dataframe and create dependence dataframes from them
    val features = categoricalFeatures ++ numericFeatures
    val dependenceDfs = features.map {
      case f: ICECategoricalFeature =>
        (f, collectCategoricalValues(dfWithId, f))
      case f: ICENumericFeature =>
        (f, collectSplits(dfWithId, f))
    }.map {
      case (f, values) =>
        calcDependence(sampled, idCol, targetClasses, f.getName, values, f.getOutputColName)
    }

    // In the case of ICE, return the initial df with a column added per feature to explain.
    // In the case of PDP, return a df with shape (1 row * number of features to explain).
    getKind.toLowerCase match {
      case `individualKind` =>
        dependenceDfs.reduceOption(_.join(_, Seq(idCol), "inner"))
          .map(sampled.join(_, Seq(idCol), "inner").drop(idCol)).get
      case `averageKind` =>
        dependenceDfs.reduce(_ crossJoin _)
    }
  }
```
```scala
  private def collectCategoricalValues[_](df: DataFrame, feature: ICECategoricalFeature): Array[_] = {
    val featureCount = DatasetExtensions.findUnusedColumnName("__feature__count__", df)
    df.groupBy(col(feature.name))
      .agg(count("*").as(featureCount))
      .orderBy(col(featureCount).desc)
      .head(feature.getNumTopValue)
      .map(row => row.get(0))
  }

  private def createNSplits(n: Int)(from: Double, to: Double): Seq[Double] = {
    (0 to n) map {
      i => (to - from) / n * i + from
    }
  }

  private def collectSplits(df: DataFrame, numericFeature: ICENumericFeature): Array[Double] = {
    val (feature, nSplits, rangeMin, rangeMax) = (numericFeature.name, numericFeature.getNumSplits,
      numericFeature.rangeMin, numericFeature.rangeMax)
    val featureCol = df.schema(feature)

    val createSplits = createNSplits(nSplits) _

    val values = if (rangeMin.isDefined && rangeMax.isDefined) {
      val (mi, ma) = (rangeMin.get, rangeMax.get)
      // The ranges are defined
      featureCol.dataType match {
        case _@(ByteType | IntegerType | LongType | ShortType) =>
          if (ma.toLong - mi.toLong <= nSplits) {
            // For integral types, no need to create more splits than needed.
            (mi.toLong to ma.toLong) map (_.toDouble)
          } else {
            createSplits(mi, ma)
          }

        case _ =>
          createSplits(mi, ma)
      }
    } else {
      // The ranges need to be calculated from the background dataset.
      featureCol.dataType match {
        case _@(ByteType | IntegerType | LongType | ShortType) =>
          val Row(minValue: Long, maxValue: Long) = df
            .agg(min(col(feature)).cast(LongType), max(col(feature)).cast(LongType))
            .head

          val mi = rangeMin.map(_.toLong).getOrElse(minValue)
          val ma = rangeMax.map(_.toLong).getOrElse(maxValue)

          if (ma - mi <= nSplits) {
            // For integral types, no need to create more splits than needed.
            (mi to ma) map (_.toDouble)
          } else {
            createSplits(mi, ma)
          }
        case _ =>
          val Row(minValue: Double, maxValue: Double) = df
            .agg(min(col(feature)).cast(DoubleType), max(col(feature)).cast(DoubleType))
            .head

          val mi = rangeMin.getOrElse(minValue)
          val ma = rangeMax.getOrElse(maxValue)
          createSplits(mi, ma)
      }
    }

    values.toArray
  }
```
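The split logic reduces to two small rules: `createNSplits(n)(from, to)` yields n + 1 evenly spaced points from `from` to `to`, and for integral columns a range no wider than `numSplits` is simply enumerated instead of interpolated. A plain-Python sketch mirroring the Scala (hypothetical function names, not part of the API):

```python
from typing import List


def create_n_splits(n: int, lo: float, hi: float) -> List[float]:
    """n + 1 evenly spaced points from lo to hi, mirroring createNSplits."""
    return [(hi - lo) / n * i + lo for i in range(n + 1)]


def integer_splits(lo: int, hi: int, n: int) -> List[float]:
    """Integral-type shortcut: a range no wider than n is enumerated directly."""
    if hi - lo <= n:
        return [float(v) for v in range(lo, hi + 1)]
    return create_n_splits(n, float(lo), float(hi))
```

For example, `create_n_splits(4, 0.0, 1.0)` gives the five points 0.0, 0.25, 0.5, 0.75, 1.0, while `integer_splits(1, 3, 10)` just enumerates 1.0, 2.0, 3.0.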
```scala
  override def copy(extra: ParamMap): Transformer = this.defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    // Check the data type for categorical features
    val (categoricalFeatures, numericFeatures) = (getCategoricalFeatures, getNumericFeatures)
    val allowedCategoricalTypes = Array(StringType, BooleanType, ByteType, ShortType, IntegerType, LongType)
    categoricalFeatures.foreach {
      f =>
        schema(f.name).dataType match {
          case StringType | BooleanType | ByteType | ShortType | IntegerType | LongType =>
          case _ => throw new
              Exception(s"Data type for categorical features" +
                s" must be ${allowedCategoricalTypes.mkString("[", ",", "]")}.")
        }
    }

    val allowedNumericTypes = Array(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
    numericFeatures.foreach {
      f =>
        schema(f.name).dataType match {
          case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | _: DecimalType =>
          case _ => throw new
              Exception(s"Data type for numeric features must be ${allowedNumericTypes.mkString("[", ",", "]")}.")
        }
    }

    // Check if features are specified
    val featureNames = (categoricalFeatures ++ numericFeatures).map(_.getName)
    if (featureNames.isEmpty) {
      throw new Exception("No categorical features or numeric features are set to the explainer. " +
        "Call setCategoricalFeatures or setNumericFeatures to set the features to be explained.")
    }

    // Check for duplicate feature specification
    val duplicateFeatureNames = featureNames.groupBy(identity).mapValues(_.length).filter(_._2 > 1).keys.toArray
    if (duplicateFeatureNames.nonEmpty) {
      throw new Exception(s"Duplicate features specified: ${duplicateFeatureNames.mkString(", ")}")
    }

    validateSchema(schema)
    schema
  }
}

object ICETransformer extends ComplexParamsReadable[ICETransformer]
```
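`transformSchema` enforces two guards that are easy to check in isolation: at least one feature must be specified, and no feature name may appear twice across the categorical and numeric lists. A plain-Python sketch of those guards (hypothetical helper, not part of the library):

```python
from collections import Counter
from typing import List


def check_features(categorical: List[str], numeric: List[str]) -> None:
    """Mirror transformSchema's guards: at least one feature, no duplicates."""
    names = categorical + numeric
    if not names:
        raise ValueError("No categorical or numeric features are set.")
    duplicates = [n for n, c in Counter(names).items() if c > 1]
    if duplicates:
        raise ValueError(f"Duplicate features specified: {', '.join(duplicates)}")
```

A name listed in both the categorical and numeric specs counts as a duplicate, since the Scala code pools both lists before checking.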
Review comment: might want to rename this to featureCount to be consistent with other added columns.