# One vs. Rest Classifier
[Source link](https://learning.oreilly.com/library/view/apache-spark-2-x/9781783551606/0b079c89-7dee-4091-8692-43b8955cf6b1.xhtml), [data](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale)

### Import packages

In [1]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

### Load data

In [2]:
data_path = '##PATH-TO-DATA##'  # placeholder: directory containing iris.scale.txt, including the trailing separator

In [3]:
data = spark.read.format('libsvm').option('numFeatures', '4').load(data_path + 'iris.scale.txt')

In [4]:
data.show(10, False)

+-----+--------------------------------------------------------+
|label|features                                                |
+-----+--------------------------------------------------------+
|1.0  |(4,[0,1,2,3],[-0.555556,0.25,-0.864407,-0.916667])      |
|1.0  |(4,[0,1,2,3],[-0.666667,-0.166667,-0.864407,-0.916667]) |
|1.0  |(4,[0,2,3],[-0.777778,-0.898305,-0.916667])             |
|1.0  |(4,[0,1,2,3],[-0.833333,-0.0833334,-0.830508,-0.916667])|
|1.0  |(4,[0,1,2,3],[-0.611111,0.333333,-0.864407,-0.916667])  |
|1.0  |(4,[0,1,2,3],[-0.388889,0.583333,-0.762712,-0.75])      |
|1.0  |(4,[0,1,2,3],[-0.833333,0.166667,-0.864407,-0.833333])  |
|1.0  |(4,[0,1,2,3],[-0.611111,0.166667,-0.830508,-0.916667])  |
|1.0  |(4,[0,1,2,3],[-0.944444,-0.25,-0.864407,-0.916667])     |
|1.0  |(4,[0,1,2,3],[-0.666667,-0.0833334,-0.830508,-1.0])     |
+-----+--------------------------------------------------------+
only showing top 10 rows
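
The `features` column is printed in Spark's sparse notation `(size, [indices], [values])`, which is why the third row lists only three indices: its second feature is zero and is left out. The sketch below is plain Python with no Spark dependency. `parse_libsvm_line` and `to_dense` are hypothetical helper names, and the sample line is reconstructed from the third row shown above rather than copied from the file. It assumes the LIBSVM text format of a label followed by one-based `index:value` pairs, which become zero-based indices after loading.

```python
def parse_libsvm_line(line, num_features):
    """Parse 'label idx:val idx:val ...' into (label, size, indices, values)."""
    parts = line.split()
    label = float(parts[0])
    indices, values = [], []
    for item in parts[1:]:
        idx, val = item.split(":")
        indices.append(int(idx) - 1)  # LIBSVM indices are one-based
        values.append(float(val))
    return label, num_features, indices, values


def to_dense(size, indices, values):
    """Expand a sparse (size, indices, values) triple into a dense list."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


# Reconstructed from the third row above; feature 2 (one-based) is absent.
sample = "1 1:-0.777778 3:-0.898305 4:-0.916667"
label, size, indices, values = parse_libsvm_line(sample, num_features=4)
dense = to_dense(size, indices, values)
print(label, indices, dense)
```

The omitted entry comes back as `0.0` in the dense list, so both representations describe the same four-feature vector.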

### Initialize Logistic Regression

In [5]:
lr = LogisticRegression().setLabelCol('label').setFeaturesCol('features')
lr.setMaxIter(15)
lr.setTol(1e-3)
lr.setFitIntercept(True)

LogisticRegression_9079be1a3100

### Initialize One-vs-Rest with Logistic Regression

In [6]:
ovr = OneVsRest(classifier=lr)

In [7]:
ovr.setPredictionCol('prediction')

OneVsRest_db57606a0f1a
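
One-vs-rest turns a multiclass problem into one binary problem per class ("this class" against "everything else") and predicts the class whose binary model is most confident. The sketch below illustrates that reduction in plain Python. It is not Spark's implementation: `train_centroid_scorer` is a hypothetical toy stand-in for binary logistic regression, and the data points are made up.

```python
def train_centroid_scorer(rows, binary_labels):
    """Toy binary 'model': score = -squared distance to the positive centroid.

    A stand-in for logistic regression, used only to keep the sketch
    dependency-free.
    """
    positives = [x for x, y in zip(rows, binary_labels) if y == 1]
    dim = len(rows[0])
    centroid = [sum(x[d] for x in positives) / len(positives) for d in range(dim)]
    return lambda x: -sum((a - b) ** 2 for a, b in zip(x, centroid))


def fit_one_vs_rest(rows, labels):
    """Train one binary scorer per class: that class vs. all others."""
    models = {}
    for cls in sorted(set(labels)):
        binary = [1 if y == cls else 0 for y in labels]
        models[cls] = train_centroid_scorer(rows, binary)
    return models


def predict(models, x):
    """Pick the class whose binary scorer is most confident."""
    return max(models, key=lambda cls: models[cls](x))


# Made-up two-feature points for three classes (not the iris data).
rows = [[0.0, 0.1], [0.1, 0.0], [1.0, 1.1], [1.1, 1.0], [2.0, 2.1], [2.1, 2.0]]
labels = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
models = fit_one_vs_rest(rows, labels)
print(predict(models, [1.05, 1.0]))
```

Any binary classifier that produces a confidence score can take the place of the toy scorer, which is the role `lr` plays in the cell above.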

### Train/Test split & fit model

In [8]:
train, test = data.randomSplit([0.8, 0.2], seed=20)

In [9]:
model = ovr.fit(train)

23/03/02 04:25:54 WARN org.apache.spark.ml.util.Instrumentation: [ca0e767a] All labels are the same value and fitIntercept=true, so the coefficients will be zeros. Training is not needed.
23/03/02 04:26:00 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
23/03/02 04:26:00 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
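
The first warning is probably harmless. A likely explanation, assuming the number of classes is inferred as the largest label plus one: the labels here run from 1.0 to 3.0, so a fourth binary problem is built for a class 0.0 that never occurs, and every relabeled target in it is identical. The sketch below reproduces that relabeling in plain Python; `binary_targets` is a hypothetical helper and the label list is a shortened stand-in for the training labels.

```python
def binary_targets(labels, cls):
    """Relabel for one-vs-rest: 1.0 where the label equals cls, else 0.0."""
    return [1.0 if y == cls else 0.0 for y in labels]


# The iris labels used above are 1.0, 2.0 and 3.0; there is no class 0.0.
labels = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]

# Assumption: the class count is inferred as max(label) + 1.
num_classes = int(max(labels)) + 1

# Classes whose binary problem has a single distinct target value.
degenerate = [
    cls for cls in range(num_classes)
    if len(set(binary_targets(labels, float(cls)))) == 1
]
print(num_classes, degenerate)
```

Under that assumption only class 0 is degenerate, which matches a single "All labels are the same value" message in the log.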


### Generate predictions for the test data

In [10]:
predictions = model.transform(test).select('label', 'prediction')

In [11]:
predictions.show(predictions.count())

+-----+----------+
|label|prediction|
+-----+----------+
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  1.0|       1.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  2.0|       2.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       2.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
|  3.0|       3.0|
+-----+----------+

### Evaluate

Accuracy is only one of the metrics the evaluator supports; the full list is in the [`metricName` documentation](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.evaluation.MulticlassClassificationEvaluator.html#pyspark.ml.evaluation.MulticlassClassificationEvaluator.metricName).

In [12]:
evaluator = MulticlassClassificationEvaluator()
evaluator.setMetricName('accuracy')
print(f'Accuracy: {evaluator.evaluate(predictions)}')

Accuracy: 0.9666666666666667
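
Accuracy is the fraction of test rows whose prediction equals the label. As a cross-check, the sketch below recomputes it in plain Python from the predictions table above: 30 rows, with one class 3.0 row predicted as 2.0. The per-class recall is an extra illustration and is not part of the original notebook.

```python
from collections import Counter

# (label, prediction) pairs transcribed from the predictions table above.
pairs = (
    [(1.0, 1.0)] * 10
    + [(2.0, 2.0)] * 7
    + [(3.0, 3.0)] * 12
    + [(3.0, 2.0)]
)

correct = sum(1 for label, pred in pairs if label == pred)
accuracy = correct / len(pairs)

# Count each (label, prediction) combination, then derive per-class recall.
confusion = Counter(pairs)
per_class_total = Counter(label for label, _ in pairs)
recall = {
    cls: confusion[(cls, cls)] / total
    for cls, total in sorted(per_class_total.items())
}

print(f"Accuracy: {accuracy}")
print(recall)
```

With 29 of 30 rows correct, the result agrees with the evaluator's output, and the recall breakdown shows that the single error falls in class 3.0.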