Skip to content

Commit

Permalink
fix: LIME sometimes return nan weights (#1112)
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz committed Jul 3, 2021
1 parent 85f089d commit 94f04a8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,9 @@ abstract class LIMEBase(override val uid: String)
(input, output, weight)
}.toSeq.unzip3

val inputsBV = BDM(inputs: _*)
val outputsBV = BDM(outputs: _*)
val weightsBV = BDV(weights: _*)

val lassoResults = outputsBV(::, *).toIndexedSeq.map {
new LassoRegression(regularization).fit(inputsBV, _, weightsBV, fitIntercept = true)
val (inputsBM, outputsBM, weightsBV) = (BDM(inputs: _*), BDM(outputs: _*), BDV(weights: _*))
val lassoResults = outputsBM(::, *).toIndexedSeq.map {
new LassoRegression(regularization).fit(inputsBM, _, weightsBV, fitIntercept = true)
}

val coefficientsMatrix = lassoResults.map(_.coefficients.toSpark)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,21 @@ private[explainers] class LIMETabularSampler(val instance: Row, val featureStats
(newSample, Vectors.dense(newStates.toArray), distance)
}

/**
* Create a sample that's identical to the instance, with states set to 1 for categorical vars
* and original value for numerical vars. Distance is set to 0.
*/
def sampleIdentity: (Row, Vector, Double) = {
val (identityRow, identityState) = featureStats.zipWithIndex.map {
case (_: DiscreteFeatureStats[Any], i) =>
(instance.get(i), 1d)
case (_: ContinuousFeatureStats, i) =>
(instance.getAsDouble(i), instance.getAsDouble(i))
}.unzip

(Row.fromSeq(identityRow), Vectors.dense(identityState.toArray), 0d)
}

override def nextState: Vector = {
val states = featureStats.zipWithIndex.map {
case (feature: DiscreteFeatureStats[Any], i) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ package com.microsoft.ml.spark.explainers
import breeze.stats.distributions.RandBasis
import com.microsoft.ml.spark.core.schema.DatasetExtensions
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.ComplexParamsReadable
import org.apache.spark.ml.param.StringArrayParam
import org.apache.spark.ml.param.shared.HasInputCols
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, linalg}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
Expand Down Expand Up @@ -56,10 +56,14 @@ class TabularLIME(override val uid: String)
row: Row =>
implicit val randBasis: RandBasis = RandBasis.mt0
val sampler = new LIMETabularSampler(row, featureStats)
(1 to numSamples).map {
_ =>
val (sample, feature, distance) = sampler.sample
(sample, feature, distance)

// Adding identity sample to avoid all zero states in the sample space for categorical variables.
sampler.sampleIdentity +: {
(1 to numSamples).map {
_ =>
val (sample: Row, feature: linalg.Vector, distance: Double) = sampler.sample
(sample, feature, distance)
}
}
},
getSampleSchema(sampleType)
Expand Down

0 comments on commit 94f04a8

Please sign in to comment.