### 数据准备

In [1]:
from pyspark.sql import SparkSession, DataFrame as SparkDataFrame
from pyspark.sql.types import *

TYPE_MAP = {
    # Maps the type names used in the schema file to pyspark.sql.types classes.
    # Values are the classes themselves (not instances); the schema-loading
    # cell instantiates them, e.g. TYPE_MAP["int"]().
    "int": IntegerType,      # 32-bit signed integer
    "bigint": LongType,      # Spark SQL BIGINT is 64-bit; IntegerType cannot hold values beyond 2**31 - 1
    "float": FloatType,
    "double": DoubleType,
    "string": StringType,
    # NOTE(review): DecimalType() defaults to precision=10, scale=0, i.e. no
    # fractional digits -- confirm that is intended for "decimal" columns.
    "decimal": DecimalType,
    "bool": BooleanType,
    "date": DateType,
    "datetime": TimestampType,
}

In [2]:
# Initialise the Spark session: run locally with 4 worker threads (reusing an
# existing session if getOrCreate finds one), then report the Spark version.
spark_session = (
    SparkSession.builder
    .master("local[4]")
    .appName("FE:bank-loan")
    .getOrCreate()
)
print(spark_session.version)

2.4.4


In [4]:
### About the schema
# Besides declaring the schema up front, Spark can also:
#   1. infer it itself (pass inferSchema="true" to the reader), which adds
#      noticeable compute overhead;
#   2. read every column as text and cast afterwards (safer, but makes the
#      later processing more tedious).
#
# Load the predefined schema file. Each non-blank line is
# "<column name>,<type name>", where <type name> must be a key of TYPE_MAP.
# Columns typed "string" are collected as discrete features, every other type
# as continuous; "id" is kept in both lists as the row key.
# TODO: make this absolute local path configurable.
schema_path = "D:/python_projects/spark_learn/data/bank-full-schema.txt"
fields = []
discrete_fields = ["id"]
continuous_fields = ["id"]
with open(schema_path, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. an extra newline at end of file
        parts = [part.strip() for part in line.split(",")]
        if len(parts) != 2:
            raise ValueError(
                "Malformed schema line %d: %r (expected 'name,type')" % (line_no, line)
            )
        name, data_type = parts
        if data_type not in TYPE_MAP:
            # Without this check TYPE_MAP.get() returns None and calling it
            # fails with an unhelpful "'NoneType' object is not callable".
            raise ValueError(
                "Unknown type %r for column %r on schema line %d; expected one of %s"
                % (data_type, name, line_no, sorted(TYPE_MAP))
            )
        fields.append(StructField(name, TYPE_MAP[data_type](), nullable=True))
        if name != "id":
            if data_type == "string":
                discrete_fields.append(name)
            else:
                continuous_fields.append(name)
schema = StructType(fields)

In [6]:
### Load the dataset
# The CSV file has no header row, so it is read with the explicit schema
# built in the schema cell.
filepath = "D:/python_projects/spark_learn/data/bank-full.csv"
bank_reader = spark_session.read
df = bank_reader.csv(
    filepath,
    schema=schema,
    sep=",",
    # Options can also be supplied Scala-style through .options(**opts).
    # PySpark lowercases Python bools itself, so header=False would work too.
    header="false",
)
# Preview the first 10 rows.
df.show(10)
# Register a session-scoped temporary view named "bank_main" so the data can
# also be queried with Spark SQL. This only names the DataFrame; it does not
# cache or persist it.
df.createOrReplaceTempView("bank_main")

+---+---+------------+--------+---------+-------+-------+-------+----+-------+---+-----+--------+--------+-----+--------+--------+---+
| id|age|         job| marital|education|default|balance|housing|loan|contact|day|month|duration|campaign|pdays|previous|poutcome|  y|
+---+---+------------+--------+---------+-------+-------+-------+----+-------+---+-----+--------+--------+-----+--------+--------+---+
|  1| 58|  management| married| tertiary|     no| 2143.0|    yes|  no|unknown|  5|  may|   261.0|       1|   -1|       0| unknown| no|
|  2| 44|  technician|  single|secondary|     no|   29.0|    yes|  no|unknown|  5|  may|   151.0|       1|   -1|       0| unknown| no|
|  3| 33|entrepreneur| married|secondary|     no|    2.0|    yes| yes|unknown|  5|  may|    76.0|       1|   -1|       0| unknown| no|
|  4| 47| blue-collar| married|  unknown|     no| 1506.0|    yes|  no|unknown|  5|  may|    92.0|       1|   -1|       0| unknown| no|
|  5| 33|     unknown|  single|  unknown|     no|    1.

### 数据处理

In [7]:
from pyspark.ml.feature import StringIndexer

In [13]:
# Label-encode every discrete (string) feature with a StringIndexer and place
# the resulting "<col>_code" columns after the continuous features.
#
# Each fitted indexer is applied to the same DataFrame in turn. An earlier
# version fitted on a per-column slice and inner-joined it back on "id". That
# cost one shuffle join per column and silently duplicated or dropped rows if
# "id" was not unique.
df_discrete = df.select(*discrete_fields)  # discrete columns plus "id"
transformed_cols = []
df_encoded = df
for col in discrete_fields:
    if col == "id":
        continue  # the row key is not a feature
    out_col = "%s_code" % col
    label_encoder = StringIndexer(inputCol=col, outputCol=out_col)
    df_encoded = label_encoder.fit(df_encoded).transform(df_encoded)
    transformed_cols.append(out_col)
# Same column order as before: continuous features (incl. "id"), then the codes.
df_numerics = df_encoded.select(*continuous_fields, *transformed_cols)
df_numerics.show()

+---+---+-------+---+--------+--------+-----+--------+--------+------------+--------------+------------+------------+---------+------------+----------+-------------+------+
| id|age|balance|day|duration|campaign|pdays|previous|job_code|marital_code|education_code|default_code|housing_code|loan_code|contact_code|month_code|poutcome_code|y_code|
+---+---+-------+---+--------+--------+-----+--------+--------+------------+--------------+------------+------------+---------+------------+----------+-------------+------+
|  1| 58| 2143.0|  5|   261.0|       1|   -1|       0|     1.0|         0.0|           1.0|         0.0|         0.0|      0.0|         1.0|       0.0|          0.0|   0.0|
|  2| 44|   29.0|  5|   151.0|       1|   -1|       0|     2.0|         1.0|           0.0|         0.0|         0.0|      0.0|         1.0|       0.0|          0.0|   0.0|
|  3| 33|    2.0|  5|    76.0|       1|   -1|       0|     7.0|         0.0|           0.0|         0.0|         0.0|      1.0|        

In [15]:
### Build a decision-tree model
# Work in progress (2020-02-14)
# NOTE(review): the TypeError recorded for this cell comes from an earlier run
# that passed a Python list as featuresCol. That parameter expects a single
# column name (a string). Presumably the feature columns must first be combined
# into one vector column (e.g. with pyspark.ml.feature.VectorAssembler) and that
# column's name passed instead -- TODO confirm and clear the stale error output.
from pyspark.ml.classification import DecisionTreeClassifier

TypeError: Invalid param value given for param "featuresCol". Could not convert <class 'list'> to string type

In [None]:
# Shut down the Spark session (and its underlying SparkContext) now that the
# notebook is finished.
spark_session.stop()