In [1]:
# 创建SparkSession对象
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('random_forest').getOrCreate()

In [2]:
# 读取数据集
path = r"G:\LKM\PySark机器学习、自然语言处理与推荐系统\随机森林\affairs.csv"
df = spark.read.csv(path, inferSchema=True, header=True)

In [3]:
# 检查数据集的形状结构
(df.count(), len(df.columns))

(6366, 6)

In [4]:
# 验证输入值的数据类型
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [5]:
# 查看前五条数据
df.show(5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [6]:
# 使用describe函数来检查数据集的统计指标
df.describe().select('summary', 'rate_marriage', 'age', 'yrs_married', 'children', 'religious').show()

+-------+------------------+------------------+-----------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|
+-------+------------------+------------------+-----------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785|
|    min|                 1|              17.5|              0.5|               0.0|                 1|
|    max|                 5|              42.0|             23.0|               5.5|                 4|
+-------+------------------+------------------+-----------------+------------------+------------------+



In [7]:
# 查看类别出现次数
df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [8]:
# 查看婚姻评分人数情况
df.groupBy('rate_marriage').count().show()

+-------------+-----+
|rate_marriage|count|
+-------------+-----+
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
+-------------+-----+



In [9]:
# 查看婚姻评分与婚外恋的关系
df.groupBy('rate_marriage', 'affairs').count().orderBy('rate_marriage', 'affairs', 'count', ascending=True).show()

+-------------+-------+-----+
|rate_marriage|affairs|count|
+-------------+-------+-----+
|            1|      0|   25|
|            1|      1|   74|
|            2|      0|  127|
|            2|      1|  221|
|            3|      0|  446|
|            3|      1|  547|
|            4|      0| 1518|
|            4|      1|  724|
|            5|      0| 2197|
|            5|      1|  487|
+-------------+-------+-----+



In [10]:
# 查看宗教与婚外恋的关系
df.groupBy('religious', 'affairs').count().orderBy('religious', 'affairs', 'count', ascending=True).show()

+---------+-------+-----+
|religious|affairs|count|
+---------+-------+-----+
|        1|      0|  613|
|        1|      1|  408|
|        2|      0| 1448|
|        2|      1|  819|
|        3|      0| 1715|
|        3|      1|  707|
|        4|      0|  537|
|        4|      1|  119|
+---------+-------+-----+



In [11]:
# 子女数量与婚外恋的关系
df.groupBy('children', 'affairs').count().orderBy('children', 'affairs', 'count', ascending=True).show()

+--------+-------+-----+
|children|affairs|count|
+--------+-------+-----+
|     0.0|      0| 1912|
|     0.0|      1|  502|
|     1.0|      0|  747|
|     1.0|      1|  412|
|     2.0|      0|  873|
|     2.0|      1|  608|
|     3.0|      0|  460|
|     3.0|      1|  321|
|     4.0|      0|  197|
|     4.0|      1|  131|
|     5.5|      0|  124|
|     5.5|      1|   79|
+--------+-------+-----+



In [12]:
# 婚外恋的情况
df.groupBy('affairs').mean().show()

+-------+------------------+------------------+------------------+------------------+------------------+------------+
|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|
+-------+------------------+------------------+------------------+------------------+------------------+------------+
|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|
|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|
+-------+------------------+------------------+------------------+------------------+------------------+------------+



In [13]:
# 使用Spark的VectorAssembler来创建合并所有输入特征的单个向量
from pyspark.ml.feature import VectorAssembler

# 将所有的列组装成单个向量，该向量会充当模型的输入特征
df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol='features')
df = df_assembler.transform(df)
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)



In [14]:
# 查看特征工程后的数据
df.select(['features', 'affairs']).show(10, False)

+-----------------------+-------+
|features               |affairs|
+-----------------------+-------+
|[5.0,32.0,6.0,1.0,3.0] |0      |
|[4.0,22.0,2.5,0.0,2.0] |0      |
|[3.0,32.0,9.0,3.0,3.0] |1      |
|[3.0,27.0,13.0,3.0,1.0]|1      |
|[4.0,22.0,2.5,0.0,1.0] |1      |
|[4.0,37.0,16.5,4.0,3.0]|1      |
|[5.0,27.0,9.0,1.0,1.0] |1      |
|[4.0,27.0,9.0,0.0,2.0] |1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
+-----------------------+-------+
only showing top 10 rows



In [15]:
# 截取训练模型的数据集
model_df = df.select(['features', 'affairs'])

In [16]:
# 划分数据集
train_df, test_df = model_df.randomSplit([0.75, 0.25])

# 查看训练集的分布
train_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1573|
|      0| 3268|
+-------+-----+



In [17]:
# 查看训练集的分布
test_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1|  480|
|      0| 1045|
+-------+-----+



In [18]:
# 构建和训练随机森林模型
from pyspark.ml.classification import RandomForestClassifier

model = RandomForestClassifier(labelCol='affairs', numTrees=50).fit(train_df)

In [19]:
# 基于测试数据进项评估
model_predict = model.transform(test_df)
model_predict.show()

+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,27.0,2.5,0.0...|      1|[21.5524713378311...|[0.43104942675662...|       1.0|
|[1.0,27.0,6.0,0.0...|      0|[18.5396362370975...|[0.37079272474195...|       1.0|
|[1.0,27.0,6.0,1.0...|      1|[19.2769148479397...|[0.38553829695879...|       1.0|
|[1.0,27.0,6.0,2.0...|      1|[18.4123409983753...|[0.36824681996750...|       1.0|
|[1.0,27.0,9.0,2.0...|      1|[15.9422892279239...|[0.31884578455847...|       1.0|
|[1.0,27.0,13.0,2....|      1|[15.8192193765475...|[0.31638438753095...|       1.0|
|[1.0,32.0,9.0,3.0...|      1|[15.9567556143388...|[0.31913511228677...|       1.0|
|[1.0,32.0,13.0,2....|      1|[15.8336857629624...|[0.31667371525924...|       1.0|
|[1.0,32.0,16.5,2....|      1|[16.1003696511144...|[0.32200739302228...|    

In [20]:
# 查看预测结果分布
model_predict.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0| 1248|
|       1.0|  277|
+----------+-----+



In [21]:
# 导入评估模块
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator


In [22]:
# 计算准确率
accuracy = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(model_predict)
accuracy

0.7239344262295082

In [23]:
# 计算精度
precision = MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision').evaluate(model_predict)
precision

0.7048322187370539

In [24]:
# 计算AUC曲线下的面积
auc = BinaryClassificationEvaluator(labelCol='affairs').evaluate(model_predict)
auc

0.7294946172248804

In [25]:
# 计算特征的重要性
model.featureImportances

SparseVector(5, {0: 0.5912, 1: 0.0224, 2: 0.2394, 3: 0.063, 4: 0.0841})

In [26]:
# 将特征映射数值
df.schema['features'].metadata['ml_attr']['attrs']

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

In [29]:
# 保存模型
from pyspark.ml.classification import RandomForestClassificationModel

model.save(r"G:\LKM\PySark机器学习、自然语言处理与推荐系统\随机森林")

Py4JJavaError: An error occurred while calling o2927.save.
: java.io.IOException: Path G:\LKM\PySark机器学习、自然语言处理与推荐系统\随机森林 already exists. To overwrite it, please use write.overwrite().save(path) for Scala and use write().overwrite().save(path) for Java and Python.
	at org.apache.spark.ml.util.FileSystemOverwrite.handleOverwrite(ReadWrite.scala:683)
	at org.apache.spark.ml.util.MLWriter.save(ReadWrite.scala:167)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
