In [1]:
""" 导入模块 """
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.classification import LogisticRegression

In [2]:
""" 创建Spark会话对象 """
# Build (or reuse, if one is already active) the Spark session for this
# notebook; the application is registered under the name 'log_reg'.
spark = (
    SparkSession.builder
    .appName('log_reg')
    .getOrCreate()
)

In [4]:
""" 读取数据集 """
# Load the CSV into a DataFrame: the first row supplies the column names and
# Spark infers each column's type from the data.
df = spark.read.csv(
    path='Log_Reg_dataset.csv',
    header=True,
    inferSchema=True,
)

In [5]:
""" 查看前20行数据 """
# Preview the first 20 rows (20 is also show()'s default row limit).
df.show(n=20)

+---------+---+--------------+--------+----------------+------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|
+---------+---+--------------+--------+----------------+------+
|    India| 41|             1|   Yahoo|              21|     1|
|   Brazil| 28|             1|   Yahoo|               5|     0|
|   Brazil| 40|             0|  Google|               3|     0|
|Indonesia| 31|             1|    Bing|              15|     1|
| Malaysia| 32|             0|  Google|              15|     1|
|   Brazil| 32|             0|  Google|               3|     0|
|   Brazil| 32|             0|  Google|               6|     0|
|Indonesia| 27|             0|  Google|               9|     0|
|Indonesia| 32|             0|   Yahoo|               2|     0|
|Indonesia| 31|             1|    Bing|              16|     1|
| Malaysia| 27|             1|  Google|              21|     1|
|Indonesia| 29|             1|   Yahoo|               9|     1|
|Indonesia| 33|             1|   Yahoo| 

In [6]:
""" 查看数据形状结构 """
# Shape of the dataset as a (rows, columns) tuple; count() runs a Spark job,
# while the column list comes from the schema.
row_count = df.count()
column_count = len(df.columns)
(row_count, column_count)

(20000, 6)

In [7]:
""" 查看数据字段类型 """
# Print each column's name, inferred type and nullability as a tree.
df.printSchema()

root
 |-- Country: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Repeat_Visitor: integer (nullable = true)
 |-- Platform: string (nullable = true)
 |-- Web_pages_viewed: integer (nullable = true)
 |-- Status: integer (nullable = true)



In [8]:
""" 查看统计指标 """
# Summary statistics (count, mean, stddev, min, max) for every column.
summary_stats = df.describe()
summary_stats.show()

+-------+--------+-----------------+-----------------+--------+-----------------+------------------+
|summary| Country|              Age|   Repeat_Visitor|Platform| Web_pages_viewed|            Status|
+-------+--------+-----------------+-----------------+--------+-----------------+------------------+
|  count|   20000|            20000|            20000|   20000|            20000|             20000|
|   mean|    null|         28.53955|           0.5029|    null|           9.5533|               0.5|
| stddev|    null|7.888912950773227|0.500004090187782|    null|6.073903499824976|0.5000125004687693|
|    min|  Brazil|               17|                0|    Bing|                1|                 0|
|    max|Malaysia|              111|                1|   Yahoo|               29|                 1|
+-------+--------+-----------------+-----------------+--------+-----------------+------------------+



In [9]:
""" 查看Country每个国家出现的次数 """
# Number of rows per country.
country_counts = df.groupBy('Country').count()
country_counts.show()

+---------+-----+
|  Country|count|
+---------+-----+
| Malaysia| 1218|
|    India| 4018|
|Indonesia|12178|
|   Brazil| 2586|
+---------+-----+



In [11]:
""" 查看Platform每个网站门户出现的次数 """
# Number of rows per search platform.
platform_counts = df.groupBy('Platform').count()
platform_counts.show()

+--------+-----+
|Platform|count|
+--------+-----+
|   Yahoo| 9859|
|    Bing| 4360|
|  Google| 5781|
+--------+-----+



In [12]:
""" 查看Status每个类别出现的次数 """
# Number of rows per target class (Status); shows how balanced the label is.
status_counts = df.groupBy('Status').count()
status_counts.show()

+------+-----+
|Status|count|
+------+-----+
|     1|10000|
|     0|10000|
+------+-----+



In [13]:
"""  """
# Per-country average of every numeric column (avg() is the same aggregation
# as mean()).
(
    df.groupBy('Country')
    .avg()
    .show()
)

+---------+------------------+-------------------+---------------------+--------------------+
|  Country|          avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|         avg(Status)|
+---------+------------------+-------------------+---------------------+--------------------+
| Malaysia|27.792282430213465| 0.5730706075533661|   11.192118226600986|  0.6568144499178982|
|    India|27.976854156296664| 0.5433051269288203|   10.727227476356397|  0.6212045793927327|
|Indonesia| 28.43159796354081| 0.5207751683363442|    9.985711939563148|  0.5422893742814913|
|   Brazil|30.274168600154677|  0.322892498066512|    4.921113689095128|0.038669760247486466|
+---------+------------------+-------------------+---------------------+--------------------+



In [14]:
# Per-platform average of every numeric column (avg() is the same aggregation
# as mean()).
(
    df.groupBy('Platform')
    .avg()
    .show()
)

+--------+------------------+-------------------+---------------------+------------------+
|Platform|          avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|       avg(Status)|
+--------+------------------+-------------------+---------------------+------------------+
|   Yahoo|28.569226087838523| 0.5094837204584644|    9.599655137437875|0.5071508266558474|
|    Bing| 28.68394495412844| 0.4720183486238532|    9.114908256880733|0.4559633027522936|
|  Google|28.380038055699707| 0.5149628092025601|    9.804878048780488|0.5210171250648676|
+--------+------------------+-------------------+---------------------+------------------+



In [15]:
# Average of every numeric column within each target class. The avg(Status)
# column simply repeats the group key, since all rows in a group share it.
(
    df.groupBy('Status')
    .avg()
    .show()
)

+------+--------+-------------------+---------------------+-----------+
|Status|avg(Age)|avg(Repeat_Visitor)|avg(Web_pages_viewed)|avg(Status)|
+------+--------+-------------------+---------------------+-----------+
|     1| 26.5435|             0.7019|              14.5617|        1.0|
|     0| 30.5356|             0.3039|               4.5449|        0.0|
+------+--------+-------------------+---------------------+-----------+



In [16]:
""" 特征工程 """
# Map each Platform string to a numeric index stored in 'Platform_Num'.
platform_indexer = StringIndexer(inputCol='Platform', outputCol='Platform_Num')
search_engine_indexer = platform_indexer.fit(df)
# NOTE: df is rebound to the transformed frame, so re-running this cell without
# reloading the data is expected to fail because 'Platform_Num' already exists.
df = search_engine_indexer.transform(df)

In [19]:
# Show the first 3 rows without truncating cell values to check Platform_Num.
df.show(n=3, truncate=False)

+-------+---+--------------+--------+----------------+------+------------+
|Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_Num|
+-------+---+--------------+--------+----------------+------+------------+
|India  |41 |1             |Yahoo   |21              |1     |0.0         |
|Brazil |28 |1             |Yahoo   |5               |0     |0.0         |
|Brazil |40 |0             |Google  |3               |0     |1.0         |
+-------+---+--------------+--------+----------------+------+------------+
only showing top 3 rows



In [20]:
""" 独热编码 """
# One-hot encode the numeric platform index into the sparse vector column
# 'Platform_Vector'.
search_engine_encoder = OneHotEncoder(inputCol='Platform_Num', outputCol='Platform_Vector')
# On Spark >= 3.0 OneHotEncoder is an Estimator: it must be fitted first and
# only the fitted model exposes transform(). On Spark 2.x it is a plain
# Transformer without fit(). Fitting only when fit() exists supports both.
if hasattr(search_engine_encoder, 'fit'):
    search_engine_encoder = search_engine_encoder.fit(df)
df = search_engine_encoder.transform(df)

In [21]:
# Preview the first 20 rows, now including the one-hot Platform_Vector column.
df.show(n=20)

+---------+---+--------------+--------+----------------+------+------------+---------------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_Num|Platform_Vector|
+---------+---+--------------+--------+----------------+------+------------+---------------+
|    India| 41|             1|   Yahoo|              21|     1|         0.0|  (2,[0],[1.0])|
|   Brazil| 28|             1|   Yahoo|               5|     0|         0.0|  (2,[0],[1.0])|
|   Brazil| 40|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|
|Indonesia| 31|             1|    Bing|              15|     1|         2.0|      (2,[],[])|
| Malaysia| 32|             0|  Google|              15|     1|         1.0|  (2,[1],[1.0])|
|   Brazil| 32|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|
|   Brazil| 32|             0|  Google|               6|     0|         1.0|  (2,[1],[1.0])|
|Indonesia| 27|             0|  Google|               9|     0|       

In [22]:
# Map each Country string to a numeric index stored in 'Country_Num'.
country_indexer_stage = StringIndexer(inputCol='Country', outputCol='Country_Num')
country_indexer = country_indexer_stage.fit(df)
# NOTE: df is rebound to the transformed frame, so re-running this cell without
# reloading the data is expected to fail because 'Country_Num' already exists.
df = country_indexer.transform(df)

In [23]:
# Preview the first 20 rows, now including the Country_Num column.
df.show(n=20)

+---------+---+--------------+--------+----------------+------+------------+---------------+-----------+
|  Country|Age|Repeat_Visitor|Platform|Web_pages_viewed|Status|Platform_Num|Platform_Vector|Country_Num|
+---------+---+--------------+--------+----------------+------+------------+---------------+-----------+
|    India| 41|             1|   Yahoo|              21|     1|         0.0|  (2,[0],[1.0])|        1.0|
|   Brazil| 28|             1|   Yahoo|               5|     0|         0.0|  (2,[0],[1.0])|        2.0|
|   Brazil| 40|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|        2.0|
|Indonesia| 31|             1|    Bing|              15|     1|         2.0|      (2,[],[])|        0.0|
| Malaysia| 32|             0|  Google|              15|     1|         1.0|  (2,[1],[1.0])|        3.0|
|   Brazil| 32|             0|  Google|               3|     0|         1.0|  (2,[1],[1.0])|        2.0|
|   Brazil| 32|             0|  Google|               6

In [24]:
# One-hot encode the numeric country index into the sparse vector column
# 'Country_Vector'.
country_encoder = OneHotEncoder(inputCol='Country_Num', outputCol='Country_Vector')
# On Spark >= 3.0 OneHotEncoder is an Estimator: it must be fitted first and
# only the fitted model exposes transform(). On Spark 2.x it is a plain
# Transformer without fit(). Fitting only when fit() exists supports both.
if hasattr(country_encoder, 'fit'):
    country_encoder = country_encoder.fit(df)
df = country_encoder.transform(df)

In [25]:
# Concatenate the encoded categorical vectors and the numeric columns into the
# single 'features' vector column, then confirm the resulting schema.
feature_columns = [
    'Platform_Vector',
    'Country_Vector',
    'Age',
    'Repeat_Visitor',
    'Web_pages_viewed',
]
df_assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
df = df_assembler.transform(df)
df.printSchema()

root
 |-- Country: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Repeat_Visitor: integer (nullable = true)
 |-- Platform: string (nullable = true)
 |-- Web_pages_viewed: integer (nullable = true)
 |-- Status: integer (nullable = true)
 |-- Platform_Num: double (nullable = false)
 |-- Platform_Vector: vector (nullable = true)
 |-- Country_Num: double (nullable = false)
 |-- Country_Vector: vector (nullable = true)
 |-- features: vector (nullable = true)



In [26]:
# Keep only what the model needs: the assembled feature vector and the label.
model_df = df.select('features', 'Status')

In [28]:
# Show the first 3 rows without truncation to inspect the feature vectors.
model_df.show(n=3, truncate=False)

+-----------------------------------+------+
|features                           |Status|
+-----------------------------------+------+
|[1.0,0.0,0.0,1.0,0.0,41.0,1.0,21.0]|1     |
|[1.0,0.0,0.0,0.0,1.0,28.0,1.0,5.0] |0     |
|(8,[1,4,5,7],[1.0,1.0,40.0,3.0])   |0     |
+-----------------------------------+------+
only showing top 3 rows



In [29]:
""" 划分数据集 """
# Split into roughly 75% training / 25% test rows. The seed is fixed so that
# the split, and every metric computed from it, is repeatable across kernel
# restarts instead of changing on each run.
train_df, test_df = model_df.randomSplit([0.75, 0.25], seed=42)

In [30]:
""" 构建和训练逻辑回归模型 """
# Fit a logistic regression with 'Status' as the label; the inputs are read
# from the estimator's default features column, 'features'.
log_reg_estimator = LogisticRegression(labelCol='Status')
log_reg = log_reg_estimator.fit(train_df)

In [32]:
""" 训练结果 """
# Score the training set, then list 10 true positives (label 1, predicted 1)
# together with their class-probability vectors.
train_results = log_reg.evaluate(train_df).predictions
true_positive_rows = train_results.filter(
    (train_results['Status'] == 1) & (train_results['prediction'] == 1)
)
true_positive_rows.select('Status', 'prediction', 'probability').show(n=10, truncate=False)

+------+----------+----------------------------------------+
|Status|prediction|probability                             |
+------+----------+----------------------------------------+
|1     |1.0       |[0.3029767325575794,0.6970232674424207] |
|1     |1.0       |[0.3029767325575794,0.6970232674424207] |
|1     |1.0       |[0.16933486681801524,0.8306651331819848]|
|1     |1.0       |[0.16933486681801524,0.8306651331819848]|
|1     |1.0       |[0.16933486681801524,0.8306651331819848]|
|1     |1.0       |[0.16933486681801524,0.8306651331819848]|
|1     |1.0       |[0.0872619606177328,0.9127380393822673] |
|1     |1.0       |[0.0872619606177328,0.9127380393822673] |
|1     |1.0       |[0.0872619606177328,0.9127380393822673] |
|1     |1.0       |[0.0872619606177328,0.9127380393822673] |
+------+----------+----------------------------------------+
only showing top 10 rows



In [33]:
# Score the held-out test set and inspect the columns the model appended
# (rawPrediction, probability, prediction).
test_summary = log_reg.evaluate(test_df)
results = test_summary.predictions
results.printSchema()

root
 |-- features: vector (nullable = true)
 |-- Status: integer (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [34]:
# Compare the true label with the predicted label for the first 10 test rows.
results.select('Status', 'prediction').show(n=10, truncate=False)

+------+----------+
|Status|prediction|
+------+----------+
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
|0     |0.0       |
+------+----------+
only showing top 10 rows



In [35]:
""" 混淆矩阵 """
# Confusion-matrix cells on the test predictions. Each count() is a separate
# Spark job over `results`.
actual = results['Status']
predicted = results['prediction']
tp = results.filter((actual == 1) & (predicted == 1)).count()  # true positives
tn = results.filter((actual == 0) & (predicted == 0)).count()  # true negatives
fp = results.filter((actual == 0) & (predicted == 1)).count()  # false positives
fn = results.filter((actual == 1) & (predicted == 0)).count()  # false negatives

In [36]:
""" 准确率 """
# Accuracy = (TP + TN) / all test rows. Convert to float *before* dividing,
# as the recall and precision cells do, so the ratio cannot be truncated by
# integer division on Python 2.
accuracy = float(tp + tn) / results.count()
print(accuracy)

0.9382863990413421


In [37]:
""" 召回率 """
# Recall = TP / (TP + FN): the share of actual positives the model recovered.
actual_positives = tp + fn
recall = float(tp) / actual_positives
print(recall)

0.9395129615082483


In [38]:
""" 精度 """
# Precision = TP / (TP + FP): the share of predicted positives that are correct.
predicted_positives = tp + fp
precision = float(tp) / predicted_positives
print(precision)

0.939144091087554


In [39]:
""" F1分数 """
# F1 = 2 * P * R / (P + R): the harmonic mean of precision and recall.
product_over_sum = (precision * recall) / (precision + recall)
f1_score = 2 * product_over_sum
print(f1_score)

0.9393284900844295


In [None]:
""" AUC """
# Area under the ROC curve on the test set, read from the evaluation summary
# of the fitted binary logistic regression model.
auc = log_reg.evaluate(test_df).areaUnderROC
print(auc)
