In [2]:
from pyspark.sql import SparkSession
import pandas as pd
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

In [3]:
spark = SparkSession\
     .builder\
     .appName("DecisionTree")\
     .master("local[*]") \
     .enableHiveSupport()\
     .getOrCreate()
data=spark.sql("select * from  ml.adult")
#OneHotEncoder不能处理空字符串。所以我们需要将数据集中的空字符串提前处理一下
df=data.na.replace('','NA')
cols = df.columns #和pandas一样看列名
df.printSchema()

root
 |-- age: string (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: string (nullable = true)
 |-- education: string (nullable = true)
 |-- education_num: string (nullable = true)
 |-- marital_status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital_gain: string (nullable = true)
 |-- capital_loss: string (nullable = true)
 |-- hours_per_week: string (nullable = true)
 |-- native_country: string (nullable = true)
 |-- salary: string (nullable = true)



In [4]:
#找到所有的string类型的变量
#dtypes用来看数据变量类型
cat_features = [item[0] for item in df.dtypes if item[1]=='string']
# 需要删除 salary列，否则标签泄露
cat_features.remove('salary')
#找到所有数字变量
num_features = [item[0] for item in df.dtypes if item[1]!='string']

In [5]:
stages = []
for col in cat_features:
    # 字符串转成索引
    string_index = StringIndexer(inputCol = col, outputCol = col + 'Index')
    # 转换为OneHot编码
    encoder = OneHotEncoder(inputCol=string_index.getOutputCol(), outputCol=col + "_one_hot")
    # 将每个字段的转换方式 放到stages中
    stages += [string_index, encoder]

In [6]:
# 将salary转换为索引
label_string_index = StringIndexer(inputCol = 'salary', outputCol = 'label')
# 添加到stages中
stages += [label_string_index]

In [7]:
# 类别变量 + 数值变量
assembler_cols = [c + "_one_hot" for c in cat_features] + num_features
assembler = VectorAssembler(inputCols=assembler_cols, outputCol="features")
stages += [assembler]


# 使用pipeline完成数据处理
pipeline = Pipeline(stages=stages)
pipeline_model = pipeline.fit(df)
df = pipeline_model.transform(df)
selected_cols = ["label", "features"] + cols
df = df.select(selected_cols)

In [8]:
pd.DataFrame(df.take(2), columns = df.columns)

Unnamed: 0,label,features,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,salary
0,1.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",49,Private,101320,Assoc-acdm,12.0,Married-civ-spouse,,Wife,White,Female,0,1902,40,United-States,>=50k
1,1.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",44,Private,236746,Masters,14.0,Divorced,Exec-managerial,Not-in-family,White,Male,10520,0,45,United-States,>=50k


In [9]:
train, test = df.randomSplit([0.7, 0.3], seed=2021)
print(train.count())
print(test.count())

22777
9784


In [10]:
from pyspark.ml.classification import DecisionTreeClassifier

# 创建决策树模型
dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', maxDepth = 3)
dt_model = dt.fit(train)

#查看决策树结构
print(dt_model._call_java('toDebugString'))

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_add798222934) of depth 3 with 15 nodes
  If (feature 21758 in {0.0})
   If (feature 21789 in {0.0})
    If (feature 21798 in {1.0})
     Predict: 1.0
    Else (feature 21798 not in {1.0})
     Predict: 0.0
   Else (feature 21789 not in {0.0})
    If (feature 21920 in {1.0})
     Predict: 1.0
    Else (feature 21920 not in {1.0})
     Predict: 0.0
  Else (feature 21758 not in {0.0})
   If (feature 21789 in {0.0})
    If (feature 21797 in {1.0})
     Predict: 0.0
    Else (feature 21797 not in {1.0})
     Predict: 1.0
   Else (feature 21789 not in {0.0})
    If (feature 21729 in {1.0})
     Predict: 1.0
    Else (feature 21729 not in {1.0})
     Predict: 0.0



In [25]:
predictions = dt_model.transform(test)
predictions.printSchema()
predictions.select("label")

root
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- age: string (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: string (nullable = true)
 |-- education: string (nullable = true)
 |-- education_num: string (nullable = true)
 |-- marital_status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital_gain: string (nullable = true)
 |-- capital_loss: string (nullable = true)
 |-- hours_per_week: string (nullable = true)
 |-- native_country: string (nullable = true)
 |-- salary: string (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



DataFrame[label: double]

In [13]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator()
evaluator.evaluate(predictions)

0.7116663558561737