Skip to content

Commit

Permalink
fix: Fix issue in tabular lime sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
mhamilton723 committed May 21, 2021
1 parent 663d965 commit 74d4721
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/main/scala/com/microsoft/ml/spark/lime/LIME.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ class TabularLIME(val uid: String) extends Estimator[TabularLIMEModel]
extractParamMap().toSeq.foldLeft(new TabularLIMEModel()) { case (m, pp) =>
m.set(m.getParam(pp.param.name), pp.value)
}
.setColumnMeans(fitScaler.mean.toArray)
.setColumnSTDs(fitScaler.std.toArray)
})
}
Expand All @@ -207,22 +206,16 @@ class TabularLIMEModel(val uid: String) extends Model[TabularLIMEModel]

def this() = this(Identifiable.randomUID("TabularLIMEModel"))

val columnMeans = new DoubleArrayParam(this, "columnMeans", "the means of each of the columns for perturbation")

def getColumnMeans: Array[Double] = $(columnMeans)

def setColumnMeans(v: Array[Double]): this.type = set(columnMeans, v)

val columnSTDs = new DoubleArrayParam(this, "columnSTDs",
"the standard deviations of each of the columns for perturbation")

def getColumnSTDs: Array[Double] = $(columnSTDs)

def setColumnSTDs(v: Array[Double]): this.type = set(columnSTDs, v)

private def perturbedDenseVectors(v: DenseVector): Seq[DenseVector] = {
private def perturbedDenseVectors(dv: DenseVector): Seq[DenseVector] = {
Seq.fill(getNSamples) {
val perturbed = BDV.rand(v.size, Rand.gaussian) * BDV(getColumnSTDs) + BDV(getColumnMeans)
val perturbed = BDV.rand(dv.size, Rand.gaussian) * BDV(getColumnSTDs) + BDV(dv.values)
new DenseVector(perturbed.toArray)
}
}
Expand Down

0 comments on commit 74d4721

Please sign in to comment.