In [40]:
# 1. 导入需要的库，初始化SparkSession

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, avg, count, max, min, when, month, year, expr, lit
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.types import DoubleType

spark = SparkSession.builder \
    .appName("ComplexSalesAnalysis") \
    .master("local[*]") \
    .config("spark.driver.memory", "2g") \
    .config("spark.executor.memory", "2g") \
    .getOrCreate()

In [41]:
# 2. 加载数据集
sales_df = spark.read.csv("sales_data.csv", header=True, inferSchema=True)
print(f"数据集行数: {sales_df.count()}")

# 显示数据结构
sales_df.printSchema()

数据集行数: 113036
root
 |-- Date: string (nullable = true)
 |-- Day: integer (nullable = true)
 |-- Month: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Customer_Age: integer (nullable = true)
 |-- Age_Group: string (nullable = true)
 |-- Customer_Gender: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Product_Category: string (nullable = true)
 |-- Sub_Category: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Order_Quantity: integer (nullable = true)
 |-- Unit_Cost: integer (nullable = true)
 |-- Unit_Price: integer (nullable = true)
 |-- Profit: integer (nullable = true)
 |-- Cost: integer (nullable = true)
 |-- Revenue: integer (nullable = true)



In [42]:
# 3. 创建多个表并连接

# 创建产品信息表
print("创建产品信息表...")
product_info = sales_df.select(
    "Product",
    "Product_Category", 
    "Sub_Category",
    "Unit_Cost",
    "Unit_Price"
).distinct()
print(f"产品数量: {product_info.count()}")

# 创建客户信息表
print("创建客户信息表...")
customer_info = sales_df.select(
    "Customer_Age",
    "Age_Group",
    "Customer_Gender",
    "Country",
    "State"
).distinct()
print(f"客户数量: {customer_info.count()}")

# 创建订单信息表
print("创建订单信息表...")
order_info = sales_df.select(
    "Date",
    "Product",
    "Order_Quantity",
    "Profit",
    "Revenue"
)
print(f"订单数量: {order_info.count()}")

# 连接产品信息和订单信息
print("\n连接产品信息和订单信息...")
joined_df = order_info.join(
    product_info, 
    on="Product", 
    how="inner"
)
print(f"连接后数据行数: {joined_df.count()}")
joined_df.show(5)

# 连接三个表
print("\n连接三个表...")
full_join = sales_df.join(
    product_info, 
    on=["Product", "Product_Category", "Sub_Category"],
    how="inner"
).join(
    customer_info,
    on=["Customer_Age", "Age_Group", "Customer_Gender", "Country", "State"],
    how="inner"
)
print(f"完全连接后数据行数: {full_join.count()}")
full_join.select("Date", "Product", "Customer_Age", "Customer_Gender", "Revenue").show(5)

创建产品信息表...
产品数量: 138
创建客户信息表...
客户数量: 2162
创建订单信息表...
订单数量: 113036

连接产品信息和订单信息...
连接后数据行数: 139696
+--------------------+---------+--------------+------+-------+----------------+------------+---------+----------+
|             Product|     Date|Order_Quantity|Profit|Revenue|Product_Category|Sub_Category|Unit_Cost|Unit_Price|
+--------------------+---------+--------------+------+-------+----------------+------------+---------+----------+
|Fender Set - Moun...|2016/2/20|             1|     9|     17|        Clothing|      Gloves|        8|        22|
|Fender Set - Moun...|2014/2/20|             2|    18|     34|        Clothing|      Gloves|        8|        22|
|Fender Set - Moun...|2015/11/1|            16|   147|    275|        Clothing|      Gloves|        8|        22|
|Fender Set - Moun...|2013/11/1|            18|   165|    309|        Clothing|      Gloves|        8|        22|
|Fender Set - Moun...|2015/8/31|            16|   157|    285|        Clothing|      Gloves|        8| 

In [43]:
# 4. 使用Spark SQL执行复杂查询

# 注册临时表
sales_df.createOrReplaceTempView("sales")
product_info.createOrReplaceTempView("products")
customer_info.createOrReplaceTempView("customers")

# 查询每个产品类别的总收入和平均利润
print("每个产品类别的总收入和平均利润:")
category_stats = spark.sql("""
    SELECT 
        Product_Category,
        COUNT(*) as Order_Count,
        SUM(Revenue) as Total_Revenue,
        AVG(Profit) as Avg_Profit,
        SUM(Profit) as Total_Profit,
        ROUND((SUM(Profit) / SUM(Revenue)) * 100, 2) as Profit_Margin_Percent
    FROM sales
    GROUP BY Product_Category
    ORDER BY Total_Revenue DESC
""")
category_stats.show()

# 查询每个国家的销售情况
print("每个国家的销售情况:")
country_stats = spark.sql("""
    SELECT 
        Country,
        COUNT(*) as Order_Count,
        SUM(Revenue) as Total_Revenue,
        SUM(Profit) as Total_Profit,
        AVG(Customer_Age) as Avg_Customer_Age
    FROM sales
    GROUP BY Country
    ORDER BY Total_Revenue DESC
""")
country_stats.show()

# 查询每月销售趋势
print("每月销售趋势:")
monthly_trend = spark.sql("""
    SELECT 
        Year,
        Month,
        COUNT(*) as Order_Count,
        SUM(Revenue) as Monthly_Revenue,
        SUM(Profit) as Monthly_Profit,
        AVG(Order_Quantity) as Avg_Order_Quantity
    FROM sales
    GROUP BY Year, Month
    ORDER BY Year, 
        CASE Month
            WHEN 'January' THEN 1
            WHEN 'February' THEN 2
            WHEN 'March' THEN 3
            WHEN 'April' THEN 4
            WHEN 'May' THEN 5
            WHEN 'June' THEN 6
            WHEN 'July' THEN 7
            WHEN 'August' THEN 8
            WHEN 'September' THEN 9
            WHEN 'October' THEN 10
            WHEN 'November' THEN 11
            WHEN 'December' THEN 12
        END
""")
monthly_trend.show()

# 使用窗口函数：查询每个类别的前3名产品
print("每个产品类别的前3名畅销产品:")
product_ranking = spark.sql("""
    SELECT * FROM (
        SELECT 
            Product_Category,
            Product,
            SUM(Revenue) as Total_Revenue,
            RANK() OVER (PARTITION BY Product_Category ORDER BY SUM(Revenue) DESC) as Revenue_Rank
        FROM sales
        GROUP BY Product_Category, Product
    ) ranked
    WHERE Revenue_Rank <= 3
    ORDER BY Product_Category, Revenue_Rank
""")
product_ranking.show()

# 复杂条件查询：高价值客户识别
print("高价值客户识别:")
high_value_customers = spark.sql("""
    SELECT 
        Customer_Age,
        Age_Group,
        Customer_Gender,
        Country,
        COUNT(*) as Order_Count,
        SUM(Revenue) as Total_Spent,
        AVG(Revenue) as Avg_Order_Value,
        MAX(Revenue) as Max_Order_Value
    FROM sales
    GROUP BY Customer_Age, Age_Group, Customer_Gender, Country
    HAVING SUM(Revenue) > 1000  -- 只显示总消费超过1000的客户
    ORDER BY Total_Spent DESC
    LIMIT 10
""")
high_value_customers.show()

每个产品类别的总收入和平均利润:
+----------------+-----------+-------------+------------------+------------+---------------------+
|Product_Category|Order_Count|Total_Revenue|        Avg_Profit|Total_Profit|Profit_Margin_Percent|
+----------------+-----------+-------------+------------------+------------+---------------------+
|           Bikes|      25982|     61782134| 789.7496728504349|    20519276|                33.21|
|     Accessories|      70120|     15117992|126.38871933827724|     8862377|                58.62|
|        Clothing|      16934|      8370882|167.67727648517774|     2839447|                33.92|
+----------------+-----------+-------------+------------------+------------+---------------------+

每个国家的销售情况:
+--------------+-----------+-------------+------------+------------------+
|       Country|Order_Count|Total_Revenue|Total_Profit|  Avg_Customer_Age|
+--------------+-----------+-------------+------------+------------------+
| United States|      39206|     27975547|    1107364

In [44]:
# 5. 使用DataFrame API进行复杂转换

# 添加派生列
print("添加派生列:")
enriched_df = sales_df.withColumn(
    "Profit_Margin", 
    expr("ROUND((Profit / Revenue) * 100, 2)")
).withColumn(
    "Unit_Profit",
    expr("Unit_Price - Unit_Cost")
).withColumn(
    "Customer_Segment",
    when(col("Customer_Age") < 25, "Youth")
    .when((col("Customer_Age") >= 25) & (col("Customer_Age") < 40), "Young_Adult")
    .when((col("Customer_Age") >= 40) & (col("Customer_Age") < 60), "Middle_Aged")
    .otherwise("Senior")
).withColumn(
    "Revenue_Category",
    when(col("Revenue") < 500, "Low")
    .when((col("Revenue") >= 500) & (col("Revenue") < 1000), "Medium")
    .otherwise("High")
)
enriched_df.select("Product", "Customer_Age", "Customer_Segment", "Revenue", "Revenue_Category", "Profit_Margin").show(5)

# 使用窗口函数计算移动平均
print("计算每个产品的移动平均收入:")
window_spec = Window.partitionBy("Product").orderBy("Date").rowsBetween(-2, 0)
moving_avg_df = enriched_df.withColumn(
    "Moving_Avg_Revenue",
    avg("Revenue").over(window_spec)
)
moving_avg_df.select("Date", "Product", "Revenue", "Moving_Avg_Revenue").show(5)

# 数据透视
print("数据透视：按产品类别和客户性别统计收入:")
pivot_df = sales_df.groupBy("Product_Category") \
    .pivot("Customer_Gender") \
    .agg(
        sum("Revenue").alias("Total_Revenue"),
        count(lit(1)).alias("Order_Count")
    )
pivot_df.show()

添加派生列:
+-------------------+------------+----------------+-------+----------------+-------------+
|            Product|Customer_Age|Customer_Segment|Revenue|Revenue_Category|Profit_Margin|
+-------------------+------------+----------------+-------+----------------+-------------+
|Hitch Rack - 4-Bike|          19|           Youth|    950|          Medium|        62.11|
|Hitch Rack - 4-Bike|          19|           Youth|    950|          Medium|        62.11|
|Hitch Rack - 4-Bike|          49|     Middle_Aged|   2401|            High|        56.89|
|Hitch Rack - 4-Bike|          49|     Middle_Aged|   2088|            High|         56.9|
|Hitch Rack - 4-Bike|          47|     Middle_Aged|    418|             Low|        56.94|
+-------------------+------------+----------------+-------+----------------+-------------+
only showing top 5 rows
计算每个产品的移动平均收入:
+---------+------------+-------+------------------+
|     Date|     Product|Revenue|Moving_Avg_Revenue|
+---------+------------+-------

In [45]:
# 6. 使用MLlib进行机器学习任务
# 任务：预测收入类别（Revenue_Category）
# 这是一个分类任务，基于产品、客户、订单等特征预测订单收入水平（Low/Medium/High）

print("准备特征工程...")

# 创建收入类别特征
ml_df = sales_df.withColumn(
    "Revenue_Category",
    when(col("Revenue") < 500, "Low")
    .when((col("Revenue") >= 500) & (col("Revenue") < 1000), "Medium")
    .otherwise("High")
)

# 选择特征和标签（移除Order_Quantity、Unit_Price、Unit_Cost，避免数据泄露）
# Revenue = Order_Quantity × Unit_Price，所以这些特征会导致数据泄露
ml_df = ml_df.select(
    col("Product_Category").alias("Product_Category"),
    col("Sub_Category").alias("Sub_Category"),
    col("Country").alias("Country"),
    col("State").alias("State"),
    col("Customer_Gender").alias("Customer_Gender"),
    col("Customer_Age").cast(DoubleType()).alias("Customer_Age"),
    col("Age_Group").alias("Age_Group"),
    col("Year").cast(DoubleType()).alias("Year"),
    col("Month").alias("Month"),
    col("Revenue_Category").alias("label")
).fillna(0)

# 对分类特征进行编码
print("编码分类特征...")
categorical_cols = ["Product_Category", "Sub_Category", "Country", "State", 
                   "Customer_Gender", "Age_Group", "Month"]
indexers = [StringIndexer(inputCol=col, outputCol=col+"_index", handleInvalid="keep") 
            for col in categorical_cols]

indexed_df = ml_df
for indexer in indexers:
    indexer_model = indexer.fit(indexed_df)
    indexed_df = indexer_model.transform(indexed_df)

# 对标签进行编码
label_indexer = StringIndexer(inputCol="label", outputCol="label_index", handleInvalid="keep")
label_indexer_model = label_indexer.fit(indexed_df)
indexed_df = label_indexer_model.transform(indexed_df)

# 创建特征向量（只使用业务特征，避免数据泄露）
numeric_features = ["Customer_Age", "Year"]
index_features = [col+"_index" for col in categorical_cols]
feature_columns = numeric_features + index_features

assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol="features_raw"
)

ml_data = assembler.transform(indexed_df)

# 特征标准化
print("特征标准化...")
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features",
    withStd=True,
    withMean=True
)

scaler_model = scaler.fit(ml_data)
ml_data = scaler_model.transform(ml_data)

# 拆分训练集和测试集
(train_data, test_data) = ml_data.randomSplit([0.7, 0.3], seed=42)

# 使用随机森林分类器
print("训练随机森林分类模型...")
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label_index",
    numTrees=100,
    maxDepth=10,
    seed=42
)

rf_model = rf.fit(train_data)

# 评估模型
print("评估模型性能...")
predictions = rf_model.transform(test_data)

# 将预测的索引转换回原始标签
from pyspark.ml.feature import IndexToString
label_converter = IndexToString(inputCol="prediction", outputCol="predicted_category", 
                                labels=label_indexer_model.labels)
predictions = label_converter.transform(predictions)

# 评估指标
evaluator_accuracy = MulticlassClassificationEvaluator(
    labelCol="label_index",
    predictionCol="prediction",
    metricName="accuracy"
)
evaluator_f1 = MulticlassClassificationEvaluator(
    labelCol="label_index",
    predictionCol="prediction",
    metricName="f1"
)

accuracy = evaluator_accuracy.evaluate(predictions)
f1 = evaluator_f1.evaluate(predictions)

print(f"\n模型性能评估:")
print(f"  准确率（Accuracy）: {accuracy:.4f}")
print(f"  F1分数: {f1:.4f}")

# 显示混淆矩阵（按类别统计）
print(f"\n各类别预测统计:")
predictions.groupBy("label", "predicted_category").count().orderBy("label", "predicted_category").show()

# 特征重要性分析
feature_importance = rf_model.featureImportances
importance_dict = {name: float(imp) for name, imp in zip(feature_columns, feature_importance)}
sorted_importance = sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)

print(f"\n特征重要性（前10名）:")
for i, (name, imp) in enumerate(sorted_importance[:10], 1):
    print(f"  {i:2d}. {name:30s}: {imp:.4f}")

准备特征工程...
编码分类特征...
特征标准化...
训练随机森林分类模型...
评估模型性能...

模型性能评估:
  准确率（Accuracy）: 0.7952
  F1分数: 0.7337

各类别预测统计:
+------+------------------+-----+
| label|predicted_category|count|
+------+------------------+-----+
|  High|              High| 5866|
|  High|               Low| 1057|
|  High|            Medium|   57|
|   Low|              High|  347|
|   Low|               Low|20755|
|   Low|            Medium|  135|
|Medium|              High| 1532|
|Medium|               Low| 3774|
|Medium|            Medium|  182|
+------+------------------+-----+


特征重要性（前10名）:
   1. Product_Category_index        : 0.4502
   2. Sub_Category_index            : 0.4348
   3. Year                          : 0.0632
   4. Customer_Age                  : 0.0169
   5. State_index                   : 0.0130
   6. Month_index                   : 0.0083
   7. Country_index                 : 0.0070
   8. Age_Group_index               : 0.0034
   9. Customer_Gender_index         : 0.0032
