In [7]:
import findspark
findspark.init("/home/chenekv/spark-3.5.4-bin-hadoop3-scala2.13")

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, when
from pyspark.sql.types import DateType, Row
import pyspark.sql.functions as F
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline
from datetime import datetime, timedelta

# Start (or reuse) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("PurchaseRedemptionForecast_Optimized")
    .getOrCreate()
)

# Load the per-user balance records; the header row names the columns and
# column types are inferred from the data.
user_balance_path = "/home/chenekv/code/jupyterlab/Purchase Redemption Data/user_balance_table.csv"
user_balance_df = spark.read.csv(user_balance_path, header=True, inferSchema=True)

# Drop the columns the daily totals do not use: the user id and the four
# category columns.
unused_cols = ["user_id", "category1", "category2", "category3", "category4"]
user_balance_df = user_balance_df.drop(*unused_cols)

# Sum the purchase and redemption amounts for each report_date.
daily_totals = [
    F.sum("total_purchase_amt").alias("total_purchase_amt_sum"),
    F.sum("total_redeem_amt").alias("total_redeem_amt_sum"),
]
daily_balance_df = user_balance_df.groupBy("report_date").agg(*daily_totals)

24/12/24 16:18:08 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
                                                                                

In [8]:
# Converter used by the UDF that turns integer dates into DateType values.
def int_to_date(date_int):
    """Parse a ``yyyymmdd`` value into a ``datetime``.

    Args:
        date_int: Date encoded as an integer (or string) such as ``20140901``,
            or ``None`` for a missing value.

    Returns:
        A ``datetime`` at midnight of that day, or ``None`` when ``date_int``
        is ``None``.

    Raises:
        ValueError: If the value does not match the ``%Y%m%d`` format.
    """
    if date_int is None:
        # A null cell arrives as None; str(None) would be "None" and make
        # strptime raise, so pass the null through instead.
        return None
    return datetime.strptime(str(date_int), "%Y%m%d")

int_to_date_udf = udf(int_to_date, DateType())

# 转换日期列为DateType
daily_balance_df = daily_balance_df.withColumn("report_date", int_to_date_udf("report_date"))

# 按日期排序
sorted_df = daily_balance_df.orderBy("report_date")

# 提取日期特征（年、月、日、星期几、是否周末）
extracted_df = sorted_df.withColumn("year", F.year("report_date")) \
                         .withColumn("month", F.month("report_date")) \
                         .withColumn("day", F.dayofmonth("report_date")) \
                         .withColumn("weekday", F.dayofweek("report_date")) \
                         .withColumn("is_weekend", when(F.dayofweek("report_date").isin([6,7]), 1).otherwise(0))

# 组装特征向量，包括新添加的特征
feature_cols = ["year", "month", "day", "weekday", "is_weekend"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features_unscaled")
features_unscaled_df = assembler.transform(extracted_df)

# 标准化特征
scaler = StandardScaler(inputCol="features_unscaled", outputCol="features", withStd=True, withMean=True)
scaler_model = scaler.fit(features_unscaled_df)
final_df = scaler_model.transform(features_unscaled_df)

In [9]:
# 构建训练集，筛选2013-07-01至2014-08-31的数据
train_df = final_df.filter(
    (F.col("report_date") >= "2013-07-01") & (F.col("report_date") <= "2014-08-31")
)

# 训练申购线性回归模型，添加正则化参数regParam=0
lr_purchase = LinearRegression(featuresCol="features", labelCol="total_purchase_amt_sum", regParam=0)
purchase_model = lr_purchase.fit(train_df)

# 训练赎回线性回归模型，添加正则化参数regParam=1.0
lr_redeem = LinearRegression(featuresCol="features", labelCol="total_redeem_amt_sum", regParam=1.0)
redeem_model = lr_redeem.fit(train_df)

# 构建预测集test_df
start_date = datetime.strptime("2014-09-01", "%Y-%m-%d")
end_date = datetime.strptime("2014-09-30", "%Y-%m-%d")
delta = end_date - start_date

date_list = [start_date + timedelta(days=i) for i in range(delta.days + 1)]
date_rows = [Row(report_date=date) for date in date_list]
test_dates_df = spark.createDataFrame(date_rows)

test_features_df = test_dates_df.withColumn("year", F.year("report_date")) \
    .withColumn("month", F.month("report_date")) \
    .withColumn("day", F.dayofmonth("report_date")) \
    .withColumn("weekday", F.dayofweek("report_date")) \
    .withColumn("is_weekend", when(F.dayofweek("report_date").isin([6,7]), 1).otherwise(0))

final_test_df = pipeline_model.transform(test_features_df)

# 使用申购模型进行预测
predictions_purchase = purchase_model.transform(final_test_df)
# 使用赎回模型进行预测
predictions_redeem = redeem_model.transform(final_test_df)

24/12/24 16:30:36 WARN Instrumentation: [cfd82a21] regParam is zero, which might cause numerical instability and overfitting.


In [10]:
# Keep the date and the renamed prediction from each model, then join the two
# results on report_date.
purchase_by_date = predictions_purchase.select(
    col("report_date"),
    col("prediction").alias("predicted_total_purchase_amt_sum"),
)
redeem_by_date = predictions_redeem.select(
    col("report_date"),
    col("prediction").alias("predicted_total_redeem_amt_sum"),
)
predictions = purchase_by_date.join(redeem_by_date, on="report_date")

# Date-ordered result shared by the preview and the export below.
ordered_predictions = predictions.select(
    "report_date", "predicted_total_purchase_amt_sum", "predicted_total_redeem_amt_sum"
).orderBy("report_date")

# Preview up to 30 rows without truncating the values.
ordered_predictions.show(30, truncate=False)

# Save the result as CSV with a header, from a single partition, overwriting
# any existing output at the path.
output_path = "/home/chenekv/code/jupyterlab/Purchase Redemption Data/predictions_september_2014_optimized.csv"
ordered_predictions.coalesce(1).write.csv(output_path, header=True, mode="overwrite")

# Shut down the Spark session.
spark.stop()

+-------------------+--------------------------------+------------------------------+
|report_date        |predicted_total_purchase_amt_sum|predicted_total_redeem_amt_sum|
+-------------------+--------------------------------+------------------------------+
|2014-09-01 00:00:00|2.7383887306066704E8            |2.9113740342730874E8          |
|2014-09-02 00:00:00|2.8609955332742065E8            |2.965170968220188E8           |
|2014-09-03 00:00:00|2.9836023359417427E8            |3.018967902167289E8           |
|2014-09-04 00:00:00|3.106209138609279E8             |3.07276483611439E8            |
|2014-09-05 00:00:00|2.2009421260618907E8            |2.558074010115124E8           |
|2014-09-06 00:00:00|2.323548928729427E8             |2.6118709440622252E8          |
|2014-09-07 00:00:00|2.6526555958516014E8            |2.993297245801239E8           |
|2014-09-08 00:00:00|2.7752623985191375E8            |3.0470941797483397E8          |
|2014-09-09 00:00:00|2.8978692011866736E8            |