feat: ICE/PDP explainer #1284

Merged Dec 20, 2021 (42 commits)

Commits
aa166b5
Initial PDP version.
ezherdeva Aug 21, 2021
151ef99
Apply suggestions
ezherdeva Sep 21, 2021
b47d410
Added ICE
ezherdeva Sep 23, 2021
7d70110
Apply suggestions and fix
ezherdeva Oct 4, 2021
f5049e3
Added discrete
ezherdeva Oct 16, 2021
e6e985e
Added logic for discrete features
ezherdeva Oct 19, 2021
a23df5c
New logic (without unit tests)
ezherdeva Oct 20, 2021
43d1648
WIP
memoryz Oct 20, 2021
9b379e8
WIP
memoryz Oct 21, 2021
b9cbb7b
rebased the main branch
ezherdeva Oct 21, 2021
5ba0bec
small fix
ezherdeva Oct 22, 2021
c0c9ddf
added some unit tests
ezherdeva Nov 6, 2021
51e3d4f
added python code
ezherdeva Nov 18, 2021
bda7882
Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/…
ezherdeva Nov 20, 2021
fa0aa6f
Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/…
ezherdeva Nov 20, 2021
adc4301
Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/…
ezherdeva Nov 20, 2021
058f27b
fix1
ezherdeva Nov 20, 2021
f234890
Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls…
ezherdeva Nov 20, 2021
5d3d38e
Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/…
ezherdeva Nov 20, 2021
b630d39
Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls…
ezherdeva Nov 20, 2021
172a050
Update core/src/main/scala/com/microsoft/azure/synapse/ml/explainers/…
ezherdeva Nov 20, 2021
69486ed
fix 2
ezherdeva Nov 20, 2021
fd7d13b
Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls…
ezherdeva Nov 20, 2021
1d658d5
Fixed comments
ezherdeva Nov 20, 2021
25ad8fa
fix comments
ezherdeva Nov 29, 2021
8045357
fix comments 2
ezherdeva Nov 29, 2021
df4e6c6
Merge branch 'master' into ezherdeva/ice_pdp
ezherdeva Nov 29, 2021
2c207d3
last fix
ezherdeva Dec 3, 2021
fa87e5c
added copyright to py files
ezherdeva Dec 3, 2021
bce0a92
Merge branch 'master' into ezherdeva/ice_pdp
ezherdeva Dec 3, 2021
77b6267
Update src/test/scala/com/microsoft/azure/synapse/ml/core/test/fuzzin…
ezherdeva Dec 3, 2021
7c25c57
fix 2
ezherdeva Dec 7, 2021
98173d5
Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls…
ezherdeva Dec 7, 2021
a11c718
fix python issue
ezherdeva Dec 9, 2021
6483daf
fix python issue (small fix)
ezherdeva Dec 9, 2021
ef2c35e
fixed python issue
ezherdeva Dec 10, 2021
8c3a6dc
fixed comments and add more docs
ezherdeva Dec 11, 2021
5b53fa5
Merge branch 'master' into ezherdeva/ice_pdp
mhamilton723 Dec 14, 2021
e492014
fix comments
ezherdeva Dec 15, 2021
0043fb6
Merge branch 'ezherdeva/ice_pdp' of https://github.com/ezherdeva/mmls…
ezherdeva Dec 15, 2021
61624de
fix code style
ezherdeva Dec 15, 2021
9254b8d
Merge branch 'master' into ezherdeva/ice_pdp
mhamilton723 Dec 20, 2021
48 changes: 48 additions & 0 deletions core/src/main/python/synapse/ml/explainers/ICETransformer.py
@@ -0,0 +1,48 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

from synapse.ml.explainers._ICETransformer import _ICETransformer
from pyspark.ml.common import inherit_doc
from typing import List, Dict, Union


@inherit_doc
class ICETransformer(_ICETransformer):
    def setCategoricalFeatures(self, values: List[Union[str, Dict]]):
        """
        Args:
            values: The categorical features to explain: either a list of
                feature names (str), or a list of dicts carrying per-feature
                parameters.
        """
        if len(values) == 0:
            pass
        else:
            list_values = []
            for value in values:
                if isinstance(value, str):
                    # A bare name is wrapped into a single-entry parameter dict.
                    list_values.append({"name": value})
                elif isinstance(value, dict):
                    list_values.append(value)
                else:
                    # Values of any other type are silently ignored.
                    pass
            self._java_obj.setCategoricalFeaturesPy(list_values)
        return self

    def setNumericFeatures(self, values: List[Union[str, Dict]]):
        """
        Args:
            values: The numeric features to explain: either a list of
                feature names (str), or a list of dicts carrying per-feature
                parameters.
        """
        if len(values) == 0:
            pass
        else:
            list_values = []
            for value in values:
                if isinstance(value, str):
                    list_values.append({"name": value})
                elif isinstance(value, dict):
                    list_values.append(value)
                else:
                    pass
            self._java_obj.setNumericFeaturesPy(list_values)
        return self
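A note on usage: the wrapper accepts plain feature names and parameter dicts interchangeably in one list. Below is a minimal, hypothetical sketch, not part of the diff: it assumes a fitted Spark ML model `model`, a DataFrame `df` containing the listed columns, and that the dict keys mirror the Scala-side feature parameters shown further down (numTopValues, numSplits, rangeMin, rangeMax); setTargetCol is assumed to come from HasExplainTarget.

from synapse.ml.explainers import ICETransformer

ice = (
    ICETransformer()
    .setModel(model)              # assumed: a fitted Spark ML model
    .setTargetCol("probability")  # assumed: the model output column to explain
    # a plain name and a parameter dict can be mixed in one list
    .setCategoricalFeatures(["workclass", {"name": "education", "numTopValues": 5}])
    .setNumericFeatures([{"name": "age", "numSplits": 20, "rangeMin": 18.0, "rangeMax": 80.0}])
)
explained = ice.transform(df)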
@@ -0,0 +1,278 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.explainers

import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{ParamMap, ParamValidators, Params, _}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.ml.stat.Summarizer
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import scala.jdk.CollectionConverters.asScalaBufferConverter


trait ICEFeatureParams extends Params with HasNumSamples {

  val averageKind = "average"
  val individualKind = "individual"

  val categoricalFeatures = new TypedArrayParam[ICECategoricalFeature](
    this,
    "categoricalFeatures",
    "The list of categorical features to explain.",
    _.forall(_.validate)
  )

  def setCategoricalFeatures(values: Seq[ICECategoricalFeature]): this.type = this.set(categoricalFeatures, values)
  def getCategoricalFeatures: Seq[ICECategoricalFeature] = $(categoricalFeatures)

  def setCategoricalFeaturesPy(values: java.util.List[java.util.HashMap[String, Any]]): this.type = {
    val features: Seq[ICECategoricalFeature] = values.asScala.map(f => ICECategoricalFeature.fromMap(f))
    this.setCategoricalFeatures(features)
  }

  val numericFeatures = new TypedArrayParam[ICENumericFeature](
    this,
    "numericFeatures",
    "The list of numeric features to explain.",
    _.forall(_.validate)
  )

  def setNumericFeatures(values: Seq[ICENumericFeature]): this.type = this.set(numericFeatures, values)
  def getNumericFeatures: Seq[ICENumericFeature] = $(numericFeatures)

  def setNumericFeaturesPy(values: java.util.List[java.util.HashMap[String, Any]]): this.type = {
    val features: Seq[ICENumericFeature] = values.asScala.map(ICENumericFeature.fromMap)
    this.setNumericFeatures(features)
  }

  val kind = new Param[String](
    this,
    "kind",
    "Whether to return the partial dependence plot (PDP) averaged across all the samples in the " +
      "dataset or individual feature importance (ICE) per sample. " +
      "Allowed values are \"average\" for PDP and \"individual\" for ICE.",
    ParamValidators.inArray(Array(averageKind, individualKind))
  )

  def getKind: String = $(kind)
  def setKind(value: String): this.type = set(kind, value)

  setDefault(
    kind -> "individual",
    numericFeatures -> Seq.empty[ICENumericFeature],
    categoricalFeatures -> Seq.empty[ICECategoricalFeature]
  )
}

/**
  * ICETransformer displays the model's dependence on the specified features, using the given
  * dataframe as a background dataset. It supports two kinds of plots: "individual" (ICE), which
  * reports the dependence per instance, and "average" (PDP), which averages it across all the
  * samples in the dataset.
  * Note: this transformer only supports one-way dependence plots.
  */
@org.apache.spark.annotation.Experimental
class ICETransformer(override val uid: String) extends Transformer
  with HasExplainTarget
  with HasModel
  with ICEFeatureParams
  with Wrappable
  with ComplexParamsWritable {

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("ICETransformer"))

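  // Scores the dataset once per candidate value of `feature` (via explode), extracts the explain
  // target from the predictions, and aggregates it either per feature value (PDP) or per row id (ICE).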
  private def calcDependence(df: DataFrame, idCol: String, targetClassesColumn: String,
                             feature: String, values: Array[_], outputColName: String): DataFrame = {

    val dataType = df.schema(feature).dataType
    val explodeFunc = explode(array(values.map(v => lit(v)): _*).cast(ArrayType(dataType)))

    val predicted = getModel.transform(df.withColumn(feature, explodeFunc))
    val targetCol = DatasetExtensions.findUnusedColumnName("target", predicted)
    val dependenceCol = DatasetExtensions.findUnusedColumnName("feature__dependence", predicted)

    val explainTarget = extractTarget(predicted.schema, targetClassesColumn)
    val result = predicted.withColumn(targetCol, explainTarget)

    getKind.toLowerCase match {
      case `averageKind` =>
        // PDP output schema: 1 row * 1 col (pdp for the given feature: feature_value -> explanations)
        result
          .groupBy(feature)
          .agg(Summarizer.mean(col(targetCol)).alias(dependenceCol))
          .agg(
            map_from_arrays(
              collect_list(feature),
              collect_list(dependenceCol)
            ).alias(outputColName)
          )

      case `individualKind` =>
        // ICE output schema: n rows * 2 cols (idCol + ice for the given feature: map(feature_value -> explanations))
        result
          .groupBy(idCol)
          .agg(
            map_from_arrays(
              collect_list(feature),
              collect_list(targetCol)
            ).alias(outputColName)
          )
    }
  }

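  // Output shape depends on `kind`: for "individual" the sampled rows are returned with one extra
  // map column (feature value -> explanation) per explained feature; for "average" a single row
  // with one map column per explained feature is returned.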
  def transform(ds: Dataset[_]): DataFrame = {
    transformSchema(ds.schema)
    val df = ds.toDF
    val idCol = DatasetExtensions.findUnusedColumnName("idCol", df)
    val targetClasses = DatasetExtensions.findUnusedColumnName("targetClasses", df)
    val dfWithId = df
      .withColumn(idCol, monotonically_increasing_id())
      .withColumn(targetClasses, get(targetClassesCol).map(col).getOrElse(lit(getTargetClasses)))

    val (categoricalFeatures, numericFeatures) = (getCategoricalFeatures, getNumericFeatures)

    // If numSamples is specified, randomly pick numSamples instances from the input dataset
    val sampled: Dataset[Row] = get(numSamples).map(dfWithId.orderBy(rand()).limit).getOrElse(dfWithId).cache

    // Collect the candidate values for each feature from the full input dataframe (dfWithId),
    // then compute the per-feature dependence on the sampled subset.
    val features = categoricalFeatures ++ numericFeatures
    val dependenceDfs = features.map {
      case f: ICECategoricalFeature =>
        (f, collectCategoricalValues(dfWithId, f))
      case f: ICENumericFeature =>
        (f, collectSplits(dfWithId, f))
    }.map {
      case (f, values) =>
        calcDependence(sampled, idCol, targetClasses, f.getName, values, f.getOutputColName)
    }

    // ICE: join the per-feature results back onto the sampled rows, one output column per feature.
    // PDP: cross-join the single-row per-feature results into one row (1 row * number of features).
    getKind.toLowerCase match {
      case `individualKind` =>
        dependenceDfs.reduceOption(_.join(_, Seq(idCol), "inner"))
          .map(sampled.join(_, Seq(idCol), "inner").drop(idCol)).get
      case `averageKind` =>
        dependenceDfs.reduce(_ crossJoin _)
    }
  }

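  // Returns the top-N most frequent values of a categorical feature in the background dataset,
  // where N is the feature's numTopValues setting.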
  private def collectCategoricalValues[_](df: DataFrame, feature: ICECategoricalFeature): Array[_] = {
    val featureCount = DatasetExtensions.findUnusedColumnName("__feature__count__", df)
[Review comment from a Collaborator] might want to rename this to featureCount to be consistent with other added columns

    df.groupBy(col(feature.name))
      .agg(count("*").as(featureCount))
      .orderBy(col(featureCount).desc)
      .head(feature.getNumTopValue)
      .map(row => row.get(0))
  }

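  // Produces n + 1 equidistant points covering [from, to] inclusively.
  // Example: createNSplits(4)(0.0, 1.0) == Seq(0.0, 0.25, 0.5, 0.75, 1.0).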
  private def createNSplits(n: Int)(from: Double, to: Double): Seq[Double] = {
    (0 to n) map {
      i => (to - from) / n * i + from
    }
  }

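  // Computes the candidate values for a numeric feature. Integral features whose range does not
  // exceed numSplits are enumerated exactly; otherwise numSplits + 1 equidistant points are
  // generated between rangeMin and rangeMax, falling back to the observed min/max for missing bounds.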
  private def collectSplits(df: DataFrame, numericFeature: ICENumericFeature): Array[Double] = {
    val (feature, nSplits, rangeMin, rangeMax) = (numericFeature.name, numericFeature.getNumSplits,
      numericFeature.rangeMin, numericFeature.rangeMax)
    val featureCol = df.schema(feature)

    val createSplits = createNSplits(nSplits) _

    val values = if (rangeMin.isDefined && rangeMax.isDefined) {
      val (mi, ma) = (rangeMin.get, rangeMax.get)
      // Both range endpoints are defined.
      featureCol.dataType match {
        case _@(ByteType | IntegerType | LongType | ShortType) =>
          if (ma.toLong - mi.toLong <= nSplits) {
            // For integral types, no need to create more splits than needed.
            (mi.toLong to ma.toLong) map (_.toDouble)
          } else {
            createSplits(mi, ma)
          }

        case _ =>
          createSplits(mi, ma)
      }
    } else {
      // At least one endpoint is missing; fill it in from the background dataset's min/max.
      featureCol.dataType match {
        case _@(ByteType | IntegerType | LongType | ShortType) =>
          val Row(minValue: Long, maxValue: Long) = df
            .agg(min(col(feature)).cast(LongType), max(col(feature)).cast(LongType))
            .head

          val mi = rangeMin.map(_.toLong).getOrElse(minValue)
          val ma = rangeMax.map(_.toLong).getOrElse(maxValue)

          if (ma - mi <= nSplits) {
            // For integral types, no need to create more splits than needed.
            (mi to ma) map (_.toDouble)
          } else {
            createSplits(mi, ma)
          }
        case _ =>
          val Row(minValue: Double, maxValue: Double) = df
            .agg(min(col(feature)).cast(DoubleType), max(col(feature)).cast(DoubleType))
            .head

          val mi = rangeMin.getOrElse(minValue)
          val ma = rangeMax.getOrElse(maxValue)
          createSplits(mi, ma)
      }
    }

    values.toArray
  }

  override def copy(extra: ParamMap): Transformer = this.defaultCopy(extra)

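  // Validates that every explained feature has a supported data type, that at least one feature
  // is specified, and that no feature is specified twice.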
  override def transformSchema(schema: StructType): StructType = {
    // Check the data type for categorical features
    val (categoricalFeatures, numericFeatures) = (getCategoricalFeatures, getNumericFeatures)
    val allowedCategoricalTypes = Array(StringType, BooleanType, ByteType, ShortType, IntegerType, LongType)
    categoricalFeatures.foreach {
      f =>
        schema(f.name).dataType match {
          case StringType | BooleanType | ByteType | ShortType | IntegerType | LongType =>
          case _ => throw new Exception(s"Data type for categorical features" +
            s" must be ${allowedCategoricalTypes.mkString("[", ",", "]")}.")
        }
    }

    val allowedNumericTypes = Array(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
    numericFeatures.foreach {
      f =>
        schema(f.name).dataType match {
          case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | _: DecimalType =>
          case _ => throw new Exception(s"Data type for numeric features must be " +
            s"${allowedNumericTypes.mkString("[", ",", "]")}.")
        }
    }

    // Check if features are specified
    val featureNames = (categoricalFeatures ++ numericFeatures).map(_.getName)
    if (featureNames.isEmpty) {
      throw new Exception("No categorical features or numeric features are set to the explainer. " +
        "Call setCategoricalFeatures or setNumericFeatures to set the features to be explained.")
    }

    // Check for duplicate feature specification
    val duplicateFeatureNames = featureNames.groupBy(identity).mapValues(_.length).filter(_._2 > 1).keys.toArray
    if (duplicateFeatureNames.nonEmpty) {
      throw new Exception(s"Duplicate features specified: ${duplicateFeatureNames.mkString(", ")}")
    }

    validateSchema(schema)
    schema
  }
}

object ICETransformer extends ComplexParamsReadable[ICETransformer]
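
For completeness, a hedged sketch contrasting the two kinds from the Python side (hypothetical names; setTargetCol is assumed from HasExplainTarget and setNumSamples from HasNumSamples, both presumed to be generated setters):

# ICE: one row per sampled instance, one map column per explained feature
ice_df = (
    ICETransformer()
    .setModel(model)
    .setTargetCol("probability")
    .setNumericFeatures(["age"])
    .setKind("individual")  # the default
    .setNumSamples(25)      # optional: sample 25 background rows
    .transform(df)
)

# PDP: a single row, one map column (feature value -> averaged explanation) per feature
pdp_df = (
    ICETransformer()
    .setModel(model)
    .setTargetCol("probability")
    .setNumericFeatures(["age"])
    .setKind("average")
    .transform(df)
)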