
Commit

chore: change udf vec2array to pyspark.ml.functions.vector_to_array (#2131)

Co-authored-by: Mark Hamilton <mhamilton723@gmail.com>
memoryz and mhamilton723 committed Nov 16, 2023
1 parent 46a1ef8 commit 28cd6db
Showing 3 changed files with 87 additions and 70 deletions.
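The substance of the change: each notebook defined a row-at-a-time Python UDF, vec2array, to turn Spark ML vectors into arrays, and that UDF is replaced by the built-in pyspark.ml.functions.vector_to_array (available since Spark 3.0), which performs the conversion on the JVM without shipping every vector to a Python worker. A minimal before/after sketch; the DataFrame and the "features" column here are illustrative assumptions, not taken from the diff:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, FloatType
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors

spark = SparkSession.builder.getOrCreate()
# Illustrative one-column DataFrame of ML vectors.
df = spark.createDataFrame([(Vectors.dense([1.0, 2.0, 3.0]),)], ["features"])

# Before: a Python UDF; every vector is serialized out to a Python worker.
vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))
df.withColumn("arr", vec2array(col("features"))).show()

# After: the built-in converter stays on the JVM.
df.withColumn("arr", vector_to_array(col("features"))).show()

One behavioral nuance: the old UDF returned array&lt;float&gt;, while vector_to_array defaults to array&lt;double&gt;; pass dtype="float32" if the narrower element type matters.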
126 changes: 63 additions & 63 deletions docs/Explore Algorithms/Responsible AI/Explanation Dashboard.ipynb
@@ -2,61 +2,64 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Interpretability - Explanation Dashboard\n",
"\n",
"In this example, similar to the \"Interpretability - Tabular SHAP explainer\" notebook, we use Kernel SHAP to explain a tabular classification model built from the Adults Census dataset and then visualize the explanation in the ExplanationDashboard from https://github.com/microsoft/responsible-ai-widgets.\n",
"\n",
"First we import the packages and define some UDFs we will need later."
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%pip install raiwidgets itsdangerous==2.0.1 interpret-community"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from IPython.terminal.interactiveshell import TerminalInteractiveShell\n",
"from synapse.ml.explainers import *\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.classification import LogisticRegression\n",
"from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler\n",
"from pyspark.ml.functions import vector_to_array\n",
"from pyspark.sql.types import *\n",
"from pyspark.sql.functions import *\n",
"import pandas as pd\n",
"\n",
"vec_access = udf(lambda v, i: float(v[i]), FloatType())\n",
"vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))"
],
"metadata": {
"collapsed": false
}
"vec_access = udf(lambda v, i: float(v[i]), FloatType())"
]
},
{
"cell_type": "markdown",
"source": [
"Now let's read the data and train a simple binary classification model."
],
"metadata": {
"collapsed": false
- }
+ },
+ "source": [
+ "Now let's read the data and train a simple binary classification model."
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"df = spark.read.parquet(\n",
@@ -102,46 +105,46 @@
"lr = LogisticRegression(featuresCol=\"features\", labelCol=\"label\", weightCol=\"fnlwgt\")\n",
"pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])\n",
"model = pipeline.fit(training)"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
"source": [
"After the model is trained, we randomly select some observations to be explained."
],
"metadata": {
"collapsed": false
- }
+ },
+ "source": [
+ "After the model is trained, we randomly select some observations to be explained."
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"explain_instances = (\n",
" model.transform(training).orderBy(rand()).limit(5).repartition(200).cache()\n",
")\n",
"display(explain_instances)"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
"source": [
"We create a TabularSHAP explainer, set the input columns to all the features the model takes, specify the model and the target output column we are trying to explain. In this case, we are trying to explain the \"probability\" output which is a vector of length 2, and we are only looking at class 1 probability. Specify targetClasses to `[0, 1]` if you want to explain class 0 and 1 probability at the same time. Finally we sample 100 rows from the training data for background data, which is used for integrating out features in Kernel SHAP."
],
"metadata": {
"collapsed": false
- }
+ },
+ "source": [
+ "We create a TabularSHAP explainer, set the input columns to all the features the model takes, specify the model and the target output column we are trying to explain. In this case, we are trying to explain the \"probability\" output which is a vector of length 2, and we are only looking at class 1 probability. Specify targetClasses to `[0, 1]` if you want to explain class 0 and 1 probability at the same time. Finally we sample 100 rows from the training data for background data, which is used for integrating out features in Kernel SHAP."
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"shap = TabularSHAP(\n",
@@ -155,29 +158,29 @@
")\n",
"\n",
"shap_df = shap.transform(explain_instances)"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Once we have the resulting dataframe, we extract the class 1 probability of the model output, the SHAP values for the target class, the original features and the true label. Then we convert it to a pandas dataframe for visualization.\n",
"For each observation, the first element in the SHAP values vector is the base value (the mean output of the background dataset), and each of the following element is the SHAP values for each feature."
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"shaps = (\n",
" shap_df.withColumn(\"probability\", vec_access(col(\"probability\"), lit(1)))\n",
" .withColumn(\"shapValues\", vec2array(col(\"shapValues\").getItem(0)))\n",
" .withColumn(\"shapValues\", vector_to_array(col(\"shapValues\").getItem(0)))\n",
" .select(\n",
" [\"shapValues\", \"probability\", \"label\"] + categorical_features + numeric_features\n",
" )\n",
@@ -187,23 +190,23 @@
"shaps_local.sort_values(\"probability\", ascending=False, inplace=True, ignore_index=True)\n",
"pd.set_option(\"display.max_colwidth\", None)\n",
"shaps_local"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
"source": [
"We can visualize the explanation in the [interpret-community format](https://github.com/interpretml/interpret-community) in the ExplanationDashboard from https://github.com/microsoft/responsible-ai-widgets/"
],
"metadata": {
"collapsed": false
- }
+ },
+ "source": [
+ "We can visualize the explanation in the [interpret-community format](https://github.com/interpretml/interpret-community) in the ExplanationDashboard from https://github.com/microsoft/responsible-ai-widgets/"
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
@@ -216,14 +219,14 @@
"local_importance_values = shaps_local[[\"shapValues\"]]\n",
"eval_data = shaps_local[features]\n",
"true_y = np.array(shaps_local[[\"label\"]])"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"list_local_importance_values = local_importance_values.values.tolist()\n",
@@ -236,19 +239,16 @@
" # remove the bias from local importance values\n",
" del converted_list[0]\n",
" converted_importance_values.append(converted_list)"
- ],
- "metadata": {
- "collapsed": false
- }
+ ]
},
{
"cell_type": "markdown",
"source": [
"When running Synapse Analytics, please follow instructions here [Package management - Azure Synapse Analytics | Microsoft Docs](https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/apache-spark-azure-portal-add-libraries) to install [\"raiwidgets\"](https://pypi.org/project/raiwidgets/) and [\"interpret-community\"](https://pypi.org/project/interpret-community/) packages."
],
"metadata": {
"collapsed": false
- }
+ },
+ "source": [
+ "When running Synapse Analytics, please follow instructions here [Package management - Azure Synapse Analytics | Microsoft Docs](https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/apache-spark-azure-portal-add-libraries) to install [\"raiwidgets\"](https://pypi.org/project/raiwidgets/) and [\"interpret-community\"](https://pypi.org/project/interpret-community/) packages."
+ ]
},
{
"cell_type": "code",
@@ -36,14 +36,13 @@
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.classification import LogisticRegression\n",
"from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler\n",
"from pyspark.ml.functions import vector_to_array\n",
"from pyspark.sql.types import *\n",
"from pyspark.sql.functions import *\n",
"import pandas as pd\n",
"from synapse.ml.core.platform import *\n",
"\n",
"\n",
"vec_access = udf(lambda v, i: float(v[i]), FloatType())\n",
"vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))"
"vec_access = udf(lambda v, i: float(v[i]), FloatType())"
]
},
{
@@ -225,7 +224,7 @@
"source": [
"shaps = (\n",
" shap_df.withColumn(\"probability\", vec_access(col(\"probability\"), lit(1)))\n",
" .withColumn(\"shapValues\", vec2array(col(\"shapValues\").getItem(0)))\n",
" .withColumn(\"shapValues\", vector_to_array(col(\"shapValues\").getItem(0)))\n",
" .select(\n",
" [\"shapValues\", \"probability\", \"label\"] + categorical_features + numeric_features\n",
" )\n",
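Both tabular notebooks apply the same substitution pattern: the explainer's shapValues column holds one explanation vector per entry in targetClasses, so the code first selects the first (here, the only) class with getItem(0) and then converts the vector. A sketch of that pattern, assuming a shap_df shaped like the one the notebooks build:

from pyspark.sql.functions import col
from pyspark.ml.functions import vector_to_array

# shapValues: array<vector>, one explanation vector per target class.
shaps = shap_df.withColumn(
    "shapValues", vector_to_array(col("shapValues").getItem(0))
)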
24 changes: 21 additions & 3 deletions docs/Explore Algorithms/Responsible AI/Text Explainers.ipynb
@@ -27,6 +27,9 @@
"nuid": "a2689fb5-2425-430d-8261-6e39598b6505",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -35,11 +38,11 @@
"from pyspark.sql.types import *\n",
"from pyspark.ml import Pipeline\n",
"from pyspark.ml.classification import LogisticRegression\n",
"from pyspark.ml.functions import vector_to_array\n",
"from synapse.ml.explainers import *\n",
"from synapse.ml.featurize.text import TextFeaturizer\n",
"from synapse.ml.core.platform import *\n",
"\n",
"vec2array = udf(lambda vec: vec.toArray().tolist(), ArrayType(FloatType()))\n",
"vec_access = udf(lambda v, i: float(v[i]), FloatType())"
]
},
@@ -66,6 +69,9 @@
"nuid": "a02806b1-e0ba-4b6f-93bf-5d3eb635e43e",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -103,6 +109,9 @@
"nuid": "9a2fb867-194d-4660-b655-6373ec7272bf",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -139,6 +148,9 @@
"nuid": "3a9fbdc8-9660-4337-b3eb-7c717aabf0cc",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -181,6 +193,9 @@
"nuid": "63623d84-8d6d-4f5b-8e2b-83e21866fb26",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -200,7 +215,7 @@
" lime.transform(explain_instances)\n",
" .select(\"tokens\", \"weights\", \"r2\", \"probability\", \"text\")\n",
" .withColumn(\"probability\", vec_access(\"probability\", lit(1)))\n",
" .withColumn(\"weights\", vec2array(col(\"weights\").getItem(0)))\n",
" .withColumn(\"weights\", vector_to_array(col(\"weights\").getItem(0)))\n",
" .withColumn(\"r2\", vec_access(\"r2\", lit(0)))\n",
" .withColumn(\"tokens_weights\", arrays_zip(\"tokens\", \"weights\"))\n",
")\n",
@@ -233,6 +248,9 @@
"nuid": "9d3fd01d-f140-465e-ae53-d3b25f246e4d",
"showTitle": false,
"title": ""
+ },
+ "vscode": {
+ "languageId": "python"
}
},
"outputs": [],
@@ -251,7 +269,7 @@
" shap.transform(explain_instances)\n",
" .select(\"tokens\", \"shaps\", \"r2\", \"probability\", \"text\")\n",
" .withColumn(\"probability\", vec_access(\"probability\", lit(1)))\n",
" .withColumn(\"shaps\", vec2array(col(\"shaps\").getItem(0)))\n",
" .withColumn(\"shaps\", vector_to_array(col(\"shaps\").getItem(0)))\n",
" .withColumn(\"shaps\", slice(col(\"shaps\"), lit(2), size(col(\"shaps\"))))\n",
" .withColumn(\"r2\", vec_access(\"r2\", lit(0)))\n",
" .withColumn(\"tokens_shaps\", arrays_zip(\"tokens\", \"shaps\"))\n",
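The Text Explainers notebook keeps one extra step after the conversion: the first element of each SHAP explanation is the base value (the mean model output over the background data, as noted in the Explanation Dashboard notebook), so it is dropped with slice before the values are zipped with tokens. A sketch under the same assumptions, with explained standing in for the transformed DataFrame:

from pyspark.sql.functions import col, lit, size, slice

# Element 1 of "shaps" is the base value; keep elements 2..N.
trimmed = explained.withColumn(
    "shaps", slice(col("shaps"), lit(2), size(col("shaps")))
)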
