Загрузка PySpark

In [1]:
import findspark
findspark.init()

location = findspark.find()

import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = location

In [3]:
from pyspark.sql import SparkSession


spark = SparkSession.builder.getOrCreate()


sc = spark.sparkContext

In [4]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

In [5]:
df=spark.read.csv('./iris.csv', inferSchema=True, header=True)

In [6]:
df.describe().show()

+-------+------------------+-------------------+------------------+------------------+---------+
|summary|      sepal.length|        sepal.width|      petal.length|       petal.width|  variety|
+-------+------------------+-------------------+------------------+------------------+---------+
|  count|               150|                150|               150|               150|      150|
|   mean| 5.843333333333335|  3.057333333333334|3.7580000000000027| 1.199333333333334|     null|
| stddev|0.8280661279778637|0.43586628493669793|1.7652982332594662|0.7622376689603467|     null|
|    min|               4.3|                2.0|               1.0|               0.1|   Setosa|
|    max|               7.9|                4.4|               6.9|               2.5|Virginica|
+-------+------------------+-------------------+------------------+------------------+---------+



In [7]:
df1 = df.withColumnRenamed('sepal.length','sl') \
        .withColumnRenamed('sepal.width','sw') \
        .withColumnRenamed('petal.length','pl') \
        .withColumnRenamed('petal.width','pw')

In [8]:
df1.take(5)

[Row(sl=5.1, sw=3.5, pl=1.4, pw=0.2, variety='Setosa'),
 Row(sl=4.9, sw=3.0, pl=1.4, pw=0.2, variety='Setosa'),
 Row(sl=4.7, sw=3.2, pl=1.3, pw=0.2, variety='Setosa'),
 Row(sl=4.6, sw=3.1, pl=1.5, pw=0.2, variety='Setosa'),
 Row(sl=5.0, sw=3.6, pl=1.4, pw=0.2, variety='Setosa')]

In [23]:
df1.groupBy('variety').max().show()

+----------+-------+-------+-------+-------+
|   variety|max(sl)|max(sw)|max(pl)|max(pw)|
+----------+-------+-------+-------+-------+
| Virginica|    7.9|    3.8|    6.9|    2.5|
|    Setosa|    5.8|    4.4|    1.9|    0.6|
|Versicolor|    7.0|    3.4|    5.1|    1.8|
+----------+-------+-------+-------+-------+



In [48]:
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [29]:
pipeline = Pipeline(stages = [
                                StringIndexer(inputCol='variety', outputCol='variety_ind'),
                                VectorAssembler(inputCols = ['sl', 'sw', 'pl', 'pw'], outputCol='features')
                            ])

In [30]:
pipeline_trained = pipeline.fit(df1)

In [31]:
df_pipe = pipeline_trained.transform(df1)

In [32]:
train, test = df_pipe.randomSplit([0.8, 0.2], seed=12345)

In [33]:
model = LogisticRegression(featuresCol='features', labelCol='variety_ind')

In [35]:
model_fit = model.fit(train)

In [38]:
prediction = model_fit.transform(test)

In [39]:
prediction.show()

+---+---+---+---+----------+-----------+-----------------+--------------------+--------------------+----------+
| sl| sw| pl| pw|   variety|variety_ind|         features|       rawPrediction|         probability|prediction|
+---+---+---+---+----------+-----------+-----------------+--------------------+--------------------+----------+
|4.6|3.2|1.4|0.2|    Setosa|        0.0|[4.6,3.2,1.4,0.2]|[77.9244494508551...|[1.0,1.1613253875...|       0.0|
|5.0|3.0|1.6|0.2|    Setosa|        0.0|[5.0,3.0,1.6,0.2]|[57.9160733534977...|[1.0,4.1576902534...|       0.0|
|5.0|3.2|1.2|0.2|    Setosa|        0.0|[5.0,3.2,1.2,0.2]|[75.3193543970114...|[1.0,2.0755277215...|       0.0|
|5.0|3.5|1.3|0.3|    Setosa|        0.0|[5.0,3.5,1.3,0.3]|[86.2425022041546...|[1.0,1.2215769953...|       0.0|
|5.1|3.5|1.4|0.3|    Setosa|        0.0|[5.1,3.5,1.4,0.3]|[83.1350309559163...|[1.0,9.4206143888...|       0.0|
|5.4|3.4|1.5|0.4|    Setosa|        0.0|[5.4,3.4,1.5,0.4]|[67.9413178196848...|[1.0,1.9889286309...|    

In [49]:
ev = MulticlassClassificationEvaluator(labelCol='variety_ind')

In [50]:
ev.evaluate(prediction)

1.0

# Предскажем тип цветка на на случайно полученных данных, используя нашу модель

In [95]:
sl = np.random.uniform(3.0, 7.0, 15)

In [96]:
sw = np.random.uniform(1.5, 4.6, 15)

In [97]:
pl = np.random.uniform(1.0, 7.0, 15)

In [98]:
pw = np.random.uniform(0.0, 4.0, 15)

In [99]:
pdf_test = pd.DataFrame({
    'sl': sl,
    'sw': sw,
    'pl': pl,
    'pw': pw
})

In [100]:
df_test = spark.createDataFrame(pdf_test)

In [101]:
df_test.show()

+------------------+------------------+------------------+-------------------+
|                sl|                sw|                pl|                 pw|
+------------------+------------------+------------------+-------------------+
| 5.952638554701469| 3.391935077592834|5.6331131307015205| 2.0021376205324546|
| 4.273048741132579|1.5819270353644987|  3.50990092277556| 1.3228392767742232|
| 4.608715801467209|1.7680092708669597|  4.99404449754477| 0.5462739767305473|
| 3.145316237594371|3.7100617228146215| 6.575223388094388|   0.69459697222841|
| 4.721985695263712| 3.175898238253458|3.5918556463198716|0.04400457110761957|
| 4.085668349202013| 1.532522251817288| 4.668077049254033| 1.9461043549783712|
|  3.34467498040212|3.3750956546196873| 6.652471789734088| 3.6988872810234708|
| 5.424661035856896| 2.461736662464517| 1.962250604862513| 0.4016035609098907|
| 4.883399311202359|3.6417229619092675|3.5539991791349577| 2.8052097723301834|
| 5.577540670094567| 4.123396239701686|1.38769795252

In [102]:
pipeline1 = Pipeline(stages = [VectorAssembler(inputCols = ['sl', 'sw', 'pl', 'pw'], outputCol='features')])

In [104]:
pipeline_test = pipeline1.fit(df_test)

In [165]:
model_fit_test = model.fit(df_pipe)

In [105]:
df_test_pipe = pipeline_test.transform(df_test)

In [106]:
df_test_pipe.show()

+------------------+------------------+------------------+-------------------+--------------------+
|                sl|                sw|                pl|                 pw|            features|
+------------------+------------------+------------------+-------------------+--------------------+
| 5.952638554701469| 3.391935077592834|5.6331131307015205| 2.0021376205324546|[5.95263855470146...|
| 4.273048741132579|1.5819270353644987|  3.50990092277556| 1.3228392767742232|[4.27304874113257...|
| 4.608715801467209|1.7680092708669597|  4.99404449754477| 0.5462739767305473|[4.60871580146720...|
| 3.145316237594371|3.7100617228146215| 6.575223388094388|   0.69459697222841|[3.14531623759437...|
| 4.721985695263712| 3.175898238253458|3.5918556463198716|0.04400457110761957|[4.72198569526371...|
| 4.085668349202013| 1.532522251817288| 4.668077049254033| 1.9461043549783712|[4.08566834920201...|
|  3.34467498040212|3.3750956546196873| 6.652471789734088| 3.6988872810234708|[3.34467498040212...|


In [166]:
prediction1 = model_fit_test.transform(df_test_pipe)

In [167]:
prediction1.select('prediction').show()

+----------+
|prediction|
+----------+
|       2.0|
|       1.0|
|       1.0|
|       0.0|
|       0.0|
|       2.0|
|       2.0|
|       0.0|
|       2.0|
|       0.0|
|       2.0|
|       1.0|
|       0.0|
|       1.0|
|       2.0|
+----------+



In [168]:
from pyspark.sql.functions import when
prediction1 = prediction1.withColumn("predicted_variety", 
                                     when(prediction1.prediction == 0 ,'Setosa') \
                                    .when(prediction1.prediction == 1 ,'Versicolor') \
                                    .when(prediction1.prediction == 2 ,'Virginica'))

In [169]:
prediction1.select('prediction', 'predicted_variety').show()

+----------+-----------------+
|prediction|predicted_variety|
+----------+-----------------+
|       2.0|        Virginica|
|       1.0|       Versicolor|
|       1.0|       Versicolor|
|       0.0|           Setosa|
|       0.0|           Setosa|
|       2.0|        Virginica|
|       2.0|        Virginica|
|       0.0|           Setosa|
|       2.0|        Virginica|
|       0.0|           Setosa|
|       2.0|        Virginica|
|       1.0|       Versicolor|
|       0.0|           Setosa|
|       1.0|       Versicolor|
|       2.0|        Virginica|
+----------+-----------------+



In [171]:
sc.stop()