In [1]:
from pyspark.sql.types import StructType, StructField, DoubleType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler

In [2]:
# Explicit schema for test.csv: the class label followed by the two numeric
# attributes, all doubles and all declared non-nullable.
schema = StructType(
    [
        StructField(column, DoubleType(), nullable=False)
        for column in ('label', 'a1', 'a2')
    ]
)

In [3]:
# Load test.csv through the CSV data source using the explicit schema
# (no type inference), then preview the first rows.
raw = (
    sqlContext.read
    .schema(schema)
    .format('csv')
    .load('test.csv')
)
raw.show()

+-----+------+------+
|label|    a1|    a2|
+-----+------+------+
|  1.0|2.6487|4.5192|
|  1.0|1.5438|2.4443|
|  1.0| 1.899|4.2409|
|  1.0|2.4711|5.8097|
|  1.0| 3.359|6.4423|
|  1.0|3.2406|5.8097|
|  1.0|3.8128|6.3917|
|  1.0|4.4441|6.8725|
|  1.0|3.6747|6.7966|
|  1.0|4.7401| 8.163|
|  1.0|3.8917|7.4038|
|  1.0| 4.602|7.6316|
|  1.0|5.7265|7.7581|
|  1.0|4.9571|6.5688|
|  1.0|3.9903|5.3543|
|  1.0|3.0236|4.4686|
|  1.0|2.0568|2.9757|
|  1.0|1.2676|2.4443|
|  1.0| 1.169|0.9008|
|  1.0|1.7411|2.1154|
+-----+------+------+
only showing top 20 rows



In [4]:
# Print the column types of the loaded frame. Note that the output below
# reports nullable = true for every column even though the schema handed to
# the reader declared each field with nullable set to False.
raw.printSchema()

root
 |-- label: double (nullable = true)
 |-- a1: double (nullable = true)
 |-- a2: double (nullable = true)



In [5]:
# Reach into the JVM through the underscore-prefixed (private) `_jvm`
# attribute of `sc` so that JVM classes can be referenced by their
# fully-qualified name.
# NOTE(review): `sc` is not defined in this notebook; presumably the
# SparkContext injected by the PySpark kernel/shell — confirm.
jvm = sc._jvm
# NOTE(review): this class is looked up on the JVM side rather than imported
# from pyspark; presumably a custom Scala estimator that must already be on
# the driver classpath — confirm how the jar is supplied.
BinomialLogisticRegression = jvm.org.apache.spark.ml.classification.BinomialLogisticRegression

In [6]:
# Pack the two attribute columns into the single vector column 'features'.
assembler = VectorAssembler(outputCol='features', inputCols=['a1', 'a2'])
# Instantiate the JVM-side estimator; 'test' is its only constructor argument.
blr = BinomialLogisticRegression('test')

In [7]:
# Apply the assembler to the raw frame, which appends the 'features' vector
# column built from a1 and a2, and preview the result.
assembled = assembler.transform(raw)
assembled.show()

+-----+------+------+---------------+
|label|    a1|    a2|       features|
+-----+------+------+---------------+
|  1.0|2.6487|4.5192|[2.6487,4.5192]|
|  1.0|1.5438|2.4443|[1.5438,2.4443]|
|  1.0| 1.899|4.2409| [1.899,4.2409]|
|  1.0|2.4711|5.8097|[2.4711,5.8097]|
|  1.0| 3.359|6.4423| [3.359,6.4423]|
|  1.0|3.2406|5.8097|[3.2406,5.8097]|
|  1.0|3.8128|6.3917|[3.8128,6.3917]|
|  1.0|4.4441|6.8725|[4.4441,6.8725]|
|  1.0|3.6747|6.7966|[3.6747,6.7966]|
|  1.0|4.7401| 8.163| [4.7401,8.163]|
|  1.0|3.8917|7.4038|[3.8917,7.4038]|
|  1.0| 4.602|7.6316| [4.602,7.6316]|
|  1.0|5.7265|7.7581|[5.7265,7.7581]|
|  1.0|4.9571|6.5688|[4.9571,6.5688]|
|  1.0|3.9903|5.3543|[3.9903,5.3543]|
|  1.0|3.0236|4.4686|[3.0236,4.4686]|
|  1.0|2.0568|2.9757|[2.0568,2.9757]|
|  1.0|1.2676|2.4443|[1.2676,2.4443]|
|  1.0| 1.169|0.9008| [1.169,0.9008]|
|  1.0|1.7411|2.1154|[1.7411,2.1154]|
+-----+------+------+---------------+
only showing top 20 rows



In [8]:
# `blr` is a raw JVM object obtained through sc._jvm rather than a PySpark
# estimator, so it is handed the underlying Java DataFrame (`_jdf`) instead of
# the Python DataFrame wrapper.
model = blr.fit(assembled._jdf)

In [9]:
# Score the assembled rows with the JVM-side model. The result is a handle to
# a JVM DataFrame, not a PySpark DataFrame.
scala_df = model.transform(assembled._jdf)
# Register the JVM frame as a temp view and read it back through sqlContext so
# it can be displayed as a regular PySpark DataFrame.
# createOrReplaceTempView keeps this cell re-runnable: plain createTempView
# raises once a view named 'df' already exists in the session.
scala_df.createOrReplaceTempView('df')
sqlContext.table('df').show()

+-----+------+------+---------------+--------------------+--------------------+----------+
|label|    a1|    a2|       features|       rawPrediction|         probability|prediction|
+-----+------+------+---------------+--------------------+--------------------+----------+
|  1.0|2.6487|4.5192|[2.6487,4.5192]| [9.548369209080004]|[0.9999286876136482]|       1.0|
|  1.0|1.5438|2.4443|[1.5438,2.4443]|[6.9351049516733205]| [0.999027924906245]|       1.0|
|  1.0| 1.899|4.2409| [1.899,4.2409]|[13.198495949310274]|[0.9999981466167451]|       1.0|
|  1.0|2.4711|5.8097|[2.4711,5.8097]| [16.92555614256815]|[0.9999999554010797]|       1.0|
|  1.0| 3.359|6.4423| [3.359,6.4423]|  [14.0573200798583]|[0.9999992147948887]|       1.0|
|  1.0|3.2406|5.8097|[3.2406,5.8097]|[11.807562356025276]|[0.9999925520356462]|       1.0|
|  1.0|3.8128|6.3917|[3.8128,6.3917]|[10.796126126997208]|[0.9999795217376068]|       1.0|
|  1.0|4.4441|6.8725|[4.4441,6.8725]| [8.905729838746801]|[0.9998644087129379]|       1.0|