[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
https://issues.apache.org/jira/browse/SPARK-6091

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#6011 from yanboliang/spark-6091 and squashes the following commits:

bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
yanboliang authored and nemccarthy committed Jun 19, 2015
1 parent 0d87b58 commit 7006eae
Showing 2 changed files with 137 additions and 0 deletions.
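For context, here is a short, illustrative usage sketch of the new PySpark API (not part of the change itself); it assumes an active SparkContext named sc and uses arbitrary sample pairs, mirroring the doctest in the Python diff below.

from pyspark.mllib.evaluation import MulticlassMetrics

# Illustrative (prediction, label) pairs; any RDD of pairs of doubles works.
predictionAndLabels = sc.parallelize(
    [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
metrics = MulticlassMetrics(predictionAndLabels)

print(metrics.precision())          # overall (micro-averaged) precision
print(metrics.recall(1.0))          # recall for label 1.0
print(metrics.fMeasure(1.0, 0.5))   # F-measure for label 1.0 with beta = 0.5
print(metrics.weightedPrecision)    # precision averaged by class frequency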
8 changes: 8 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

/**
* ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {

/**
* An auxiliary constructor taking a DataFrame.
* @param predictionAndLabels a DataFrame with two double columns: prediction and label
*/
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))

private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
129 changes: 129 additions & 0 deletions python/pyspark/mllib/evaluation.py
@@ -141,6 +141,135 @@ def r2(self):
return self.call("r2")


class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
>>> metrics.falsePositiveRate(0.0)
0.2...
>>> metrics.precision(1.0)
0.75...
>>> metrics.recall(2.0)
1.0...
>>> metrics.fMeasure(0.0, 2.0)
0.52...
>>> metrics.precision()
0.66...
>>> metrics.recall()
0.66...
>>> metrics.weightedFalsePositiveRate
0.19...
>>> metrics.weightedPrecision
0.68...
>>> metrics.weightedRecall
0.66...
>>> metrics.weightedFMeasure()
0.66...
>>> metrics.weightedFMeasure(2.0)
0.65...
"""

def __init__(self, predictionAndLabels):
"""
:param predictionAndLabels: an RDD of (prediction, label) pairs.
"""
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
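# Bridge to the JVM implementation: convert the (prediction, label) RDD into a
# DataFrame with two double columns and hand it to the Scala MulticlassMetrics
# constructor (the DataFrame auxiliary constructor added above) via Py4J.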
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
StructField("prediction", DoubleType(), nullable=False),
StructField("label", DoubleType(), nullable=False)]))
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
java_model = java_class(df._jdf)
super(MulticlassMetrics, self).__init__(java_model)

def truePositiveRate(self, label):
"""
Returns true positive rate for a given label (category).
"""
return self.call("truePositiveRate", label)

def falsePositiveRate(self, label):
"""
Returns false positive rate for a given label (category).
"""
return self.call("falsePositiveRate", label)

def precision(self, label=None):
"""
Returns the overall precision, or the precision for a given label (category) if specified.
"""
if label is None:
return self.call("precision")
else:
return self.call("precision", float(label))

def recall(self, label=None):
"""
Returns the overall recall, or the recall for a given label (category) if specified.
"""
if label is None:
return self.call("recall")
else:
return self.call("recall", float(label))

def fMeasure(self, label=None, beta=None):
"""
Returns the overall f-measure, or the f-measure for a given label (category) if specified.
"""
if beta is None:
if label is None:
return self.call("fMeasure")
else:
return self.call("fMeasure", float(label))
else:
if label is None:
raise Exception("If the beta parameter is specified, label cannot be None")
else:
return self.call("fMeasure", float(label), float(beta))

@property
def weightedTruePositiveRate(self):
"""
Returns weighted true positive rate.
(equal to the overall precision, recall and f-measure)
"""
return self.call("weightedTruePositiveRate")

@property
def weightedFalsePositiveRate(self):
"""
Returns weighted false positive rate.
"""
return self.call("weightedFalsePositiveRate")

@property
def weightedRecall(self):
"""
Returns weighted average recall.
(equal to the overall precision, recall and f-measure)
"""
return self.call("weightedRecall")

@property
def weightedPrecision(self):
"""
Returns weighted average precision.
"""
return self.call("weightedPrecision")

def weightedFMeasure(self, beta=None):
"""
Returns weighted average f-measure.
"""
if beta is None:
return self.call("weightedFMeasure")
else:
return self.call("weightedFMeasure", beta)


def _test():
import doctest
from pyspark import SparkContext

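As a sanity check on the doctest values above, here is a minimal pure-Python sketch (illustrative only, not part of the change) that recomputes the per-label and weighted metrics from the same (prediction, label) pairs: the weighted variants average the per-label values by class frequency, and the F-measure is F_beta = (1 + beta^2) * P * R / (beta^2 * P + R).

from collections import Counter

pairs = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
         (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]

n = float(len(pairs))                                   # 9 examples
label_count = Counter(label for _, label in pairs)      # true count per class
pred_count = Counter(pred for pred, _ in pairs)         # predicted count per class
tp = Counter(pred for pred, label in pairs if pred == label)  # true positives per class

def precision(l):
    return tp[l] / float(pred_count[l])                 # precision(1.0) = 3/4 = 0.75

def recall(l):
    return tp[l] / float(label_count[l])                # recall(2.0) = 1/1 = 1.0

def f_measure(l, beta=1.0):
    p, r = precision(l), recall(l)
    return (1 + beta ** 2) * p * r / (beta ** 2 * p + r)

weighted_precision = sum(precision(l) * c / n for l, c in label_count.items())
weighted_recall = sum(recall(l) * c / n for l, c in label_count.items())

print(f_measure(0.0, 2.0))   # 0.526..., matching the 0.52... doctest
print(weighted_precision)    # 0.685..., matching the 0.68... doctest
print(weighted_recall)       # 0.666..., matching the 0.66... doctest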