# DS/CMPSC 410 Pipeline



### 

In [1]:
#What I’m doing here is basically importing all the PySpark components I need for this project.
#I load Spark SQL, ML, and some feature engineering tools, plus the classifier and evaluator.
#This is just the standard setup before I can run any Spark pipeline.


import pyspark
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql import types as T

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator



In [2]:
# Here I’m just creating the Spark session
spark = (
    SparkSession
    .builder
    .appName("DS410_TimeSeries_Tweets")
    .getOrCreate()
)

# Disable ANSI mode so malformed values become NULL instead of killing the job
spark.conf.set("spark.sql.ansi.enabled", "false")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
25/12/01 13:01:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# here are the file we need for the pipeline and ml

spx_path = "/storage/work/yfl5682/Project/SP500/SPX_full_5min_with_datetime_parts.csv"
tweets_path = "/storage/work/yfl5682/Project/tweets_with_topics_v2_flat.csv"

In [4]:
# ==================================================
# PART A. Load SPX data, build bar_time, compute RV
# ==================================================

# In this cell what I’m basically doing is preparing the SPX 5-minute time series
# and building the volatility labels I’ll later join with Trump tweets.

# I read in the SPX CSV that already has year/month/day/hour/minute/second
# split out as separate columns. Then I cast all of these time parts and the OHLC(Open, High, Low, Close), check SP 500 file 
# price fields into the correct numeric types so Spark can do proper math on them.

# I then reconstruct a full timestamp column called bar_time by stitching
# together the year/month/day/hour/minute/second into a single string and then
# converting it to a real Spark timestamp. This gives me a clean time index for
# each 5-minute bar. if you guys want to use 1 minute bar it will be still very simple. 

# Once I have bar_time, I compute log returns of the SPX close price
# within each trading day using a window ordered by time. then I build
# realized volatility features, 30-minute pre-event RV, using the previous 6 bars;
# and 30-minute post-event RV, using the current bar plus the next 6 bars.
# These two columns: `rv_pre_30m` and `rv_post_30m`, will later serve as my
# “before vs after tweet” volatility measures.



# 2.1 Read SPX CSV (5-min bars with datetime parts)
spx_raw = (
    spark.read
    .option("header", "true")
    .csv(spx_path)
)

# 2.2 Cast numeric fields to proper types
spx = (
    spx_raw
    .withColumn("year",   F.col("year").cast(T.IntegerType()))
    .withColumn("month",  F.col("month").cast(T.IntegerType()))
    .withColumn("day",    F.col("day").cast(T.IntegerType()))
    .withColumn("hour",   F.col("hour").cast(T.IntegerType()))
    .withColumn("minute", F.col("minute").cast(T.IntegerType()))
    .withColumn("second", F.col("second").cast(T.IntegerType()))
    .withColumn("Open",   F.col("Open").cast(T.DoubleType()))
    .withColumn("High",   F.col("High").cast(T.DoubleType()))
    .withColumn("Low",    F.col("Low").cast(T.DoubleType()))
    .withColumn("Close",  F.col("Close").cast(T.DoubleType()))
)

# 2.3 Reconstruct full timestamp for each bar
spx_time_str = F.format_string(
    "%04d-%02d-%02d %02d:%02d:%02d",
    F.col("year"), F.col("month"), F.col("day"),
    F.col("hour"), F.col("minute"), F.col("second")
)

spx = spx.withColumn(
    "bar_time",
    F.to_timestamp(spx_time_str, "yyyy-MM-dd HH:mm:ss")
)

print("SPX rows:", spx.count())
spx.select("bar_time", "Open", "High", "Low", "Close").show(5, truncate=False)

# 2.4 Compute log returns and 30-min pre/post realized volatility
# Use a per-day window to avoid unnecessary full-table windows
w_order = Window.partitionBy(F.to_date("bar_time")).orderBy("bar_time")

spx_lr = spx.withColumn(
    "log_return",
    F.log(F.col("Close") / F.lag("Close").over(w_order))
)

# 6 previous / next bars = 30 minutes for 5-min data
w_pre = w_order.rowsBetween(-6, -1)
w_post = w_order.rowsBetween(0, 6)

spx_rv = (
    spx_lr
    .withColumn(
        "rv_pre_30m",
        F.sqrt(F.sum(F.pow(F.col("log_return"), F.lit(2.0))).over(w_pre))
    )
    .withColumn(
        "rv_post_30m",
        F.sqrt(F.sum(F.pow(F.col("log_return"), F.lit(2.0))).over(w_post))
    )
)

spx_rv.select("bar_time", "Close", "rv_pre_30m", "rv_post_30m").show(5)


SPX rows: 357535
+-------------------+-------+-------+-------+-------+
|bar_time           |Open   |High   |Low    |Close  |
+-------------------+-------+-------+-------+-------+
|2008-01-02 09:30:00|1467.97|1470.14|1467.97|1470.05|
|2008-01-02 09:35:00|1470.17|1470.17|1467.88|1469.49|
|2008-01-02 09:40:00|1469.78|1471.71|1469.39|1471.22|
|2008-01-02 09:45:00|1471.56|1471.77|1470.69|1470.78|
|2008-01-02 09:50:00|1470.28|1471.06|1470.1 |1470.74|
+-------------------+-------+-------+-------+-------+
only showing top 5 rows


[Stage 10:>                                                         (0 + 2) / 2]

+-------------------+-------+--------------------+--------------------+
|           bar_time|  Close|          rv_pre_30m|         rv_post_30m|
+-------------------+-------+--------------------+--------------------+
|2008-01-02 09:30:00|1470.05|                NULL|0.004660684218949346|
|2008-01-02 09:35:00|1469.49|                NULL|0.004714627426470183|
|2008-01-02 09:40:00|1471.22|3.810119996833424E-4|0.004947337820406521|
|2008-01-02 09:45:00|1470.78|0.001236740274360...|0.004863112334274586|
|2008-01-02 09:50:00|1470.74|0.001272398144075946|0.004897381750463964|
+-------------------+-------+--------------------+--------------------+
only showing top 5 rows


                                                                                

In [5]:
# ==================================================
# PART B. Load tweets, build tweet_time, basic features
# ==================================================



# here what im trying to do is to rebuild a clean timestamp for each tweet, and engineer a few simple
# features that I might use later in the ML pipeline.

# I cast all the time-related columns plus favorites and retweets
# into integers, and I drop any rows where the datetime pieces are incomplete.
# I don’t want broken timestamps to mess up the time alignment with SPX.
#
# I reconstruct a full `tweet_time` timestamp column, similar to what
# I did for SPX, by formatting year/month/day/hour/minute/second into a
# single string and converting it to a Spark timestamp.
#
# I clean up the boolean flags `isRetweet` and `isDeleted` by turning
# the original "t"/"f" strings into real Boolean columns. I’m not necessarily
# using them in the main ML model yet, but I keep them in a clean format. Highly doubt that if this is even useful lmao...
#
# I build some super simple text-based features: length of the tweet
# (`text_len`) and how many exclamation marks it has (`num_exclam`). At the
# end I preview the key columns I care about: timing, topic labels, sentiment,
# trading-hours flag, engagement metrics, and the basic text features. you guys could add more interesting features if u want. 


# 3.1 Read flat tweets CSV
tweets_raw = (
    spark.read
    .option("header", "true")
    .csv(tweets_path)
)

# 3.2 Cast integer columns
int_cols = [
    "year", "month", "day", "hour", "minute", "second",
    "favorites", "retweets"
]

tweets_num = tweets_raw
for c in int_cols:
    tweets_num = tweets_num.withColumn(c, F.col(c).cast(T.IntegerType()))

# Drop rows with incomplete datetime fields
tweets_num = tweets_num.filter(
    F.col("year").isNotNull() &
    F.col("month").isNotNull() &
    F.col("day").isNotNull() &
    F.col("hour").isNotNull() &
    F.col("minute").isNotNull() &
    F.col("second").isNotNull()
)

# 3.3 Build full tweet timestamp
tweet_time_str = F.format_string(
    "%04d-%02d-%02d %02d:%02d:%02d",
    F.col("year"), F.col("month"), F.col("day"),
    F.col("hour"), F.col("minute"), F.col("second")
)

tweets = tweets_num.withColumn(
    "tweet_time",
    F.to_timestamp(tweet_time_str, "yyyy-MM-dd HH:mm:ss")
)

# 3.4 Boolean flags (we do not use them in ML, but keep them clean)
tweets = (
    tweets
    .withColumn("isRetweet", (F.col("isRetweet") == F.lit("t")).cast(T.BooleanType()))
    .withColumn("isDeleted", (F.col("isDeleted") == F.lit("t")).cast(T.BooleanType()))
)

# 3.5 Simple text features
tweets = (
    tweets
    .withColumn("text_len",   F.length("text"))
    .withColumn("num_exclam", F.size(F.split(F.col("text"), "!")) - F.lit(1))
)

tweets.select(
    "id", "tweet_time", "category", "blue_category",
    "sentiment", "intensity", "during_trading_hours",
    "favorites", "retweets", "text_len", "num_exclam"
).show(5, truncate=100)


+-------------------+-------------------+---------------------------------------+-----------------------------------------------------------------+---------+---------+--------------------+---------+--------+--------+----------+
|                 id|         tweet_time|                               category|                                                    blue_category|sentiment|intensity|during_trading_hours|favorites|retweets|text_len|num_exclam|
+-------------------+-------------------+---------------------------------------+-----------------------------------------------------------------+---------+---------+--------------------+---------+--------+--------+----------+
|  98454970654916608|2011-08-02 18:07:48|     Macroeconomics & Monetary Policies|                                          Market / Economy / Jobs| Negative|   Medium|               False|       49|     255|      66|         0|
|1234653427789070336|2020-03-03 01:34:50|   Campaign / Rally / Election Politics|       

In [6]:
# ==================================================
# PART C. Match tweets to nearest SPX bar within ±10 minutes
# ==================================================


# In this cell what I’m basically doing is aligning each tweet with the closest
# SPX 5-minute bar within a +-10 minute window.
#
# I rename some overlapping columns on both the tweets side and the SPX side to avoid name collisions after the join
#
# I join tweets and SPX bars on the same calendar day using the date
# part of `tweet_time` and `bar_time`. At this stage, each tweet is matched
# with *all* SPX bars from that day.
#
# Step 3: For every tweet–bar pair, I compute the absolute time difference
# in seconds between `tweet_time` and `bar_time`, and I only keep those pairs
# where the difference is within ±10 minutes (<= 600 seconds). This trims
# out bars that are clearly too far from the tweet.
#
# Step 4: Among the remaining candidates, I use a window partitioned by tweet
# `id` and ordered by `time_diff_sec`, then keep only the row with rank 1.
# This gives me a single nearest SPX bar for each tweet inside the time window.
#
# Step 5: I rename `bar_time` to `event_time` for clarity, since this is the
# timestamp I’ll treat as the “event bar” in the later modeling. Finally, I
# count how many tweets survived this matching step and throw an error if
# nothing matched, so I don’t accidentally train on an empty dataset.







# To avoid column name collisions, rename year/hour on both sides
tweets_for_join = (
    tweets
    .withColumnRenamed("year", "tweet_year")
    .withColumnRenamed("hour", "tweet_hour")
)

spx_for_join = (
    spx_rv
    .withColumnRenamed("year", "spx_year")
    .withColumnRenamed("hour", "spx_hour")
)

# 4.1 Join tweets with SPX on the same calendar day
joined = (
    tweets_for_join
    .join(
        spx_for_join,
        F.to_date("tweet_time") == F.to_date("bar_time"),
        "inner"
    )
)

# 4.2 Compute absolute time difference (in seconds) between tweet_time and bar_time
joined = joined.withColumn(
    "time_diff_sec",
    F.abs(F.unix_timestamp("bar_time") - F.unix_timestamp("tweet_time"))
)

# 4.3 Keep only bars within ±10 minutes of the tweet
window_secs = 600  # ±10 minutes
joined_window = joined.filter(F.col("time_diff_sec") <= window_secs)

# 4.4 For each tweet, keep only the single nearest bar
w_nearest = Window.partitionBy("id").orderBy("time_diff_sec")

events = (
    joined_window
    .withColumn("rn", F.row_number().over(w_nearest))
    .filter(F.col("rn") == 1)
    .drop("rn", "time_diff_sec")
)

# Rename bar_time to event_time for clarity
events = events.withColumnRenamed("bar_time", "event_time")

print("Events rows (tweets used in pipeline):", events.count())
events.select(
    "id", "tweet_time", "event_time",
    "Close", "rv_pre_30m", "rv_post_30m",
    "tweet_year", "tweet_hour", "spx_year", "spx_hour"
).show(10, truncate=False)

if events.count() == 0:
    raise ValueError("No matched events within ±10 minutes. Check data ranges or window size.")



                                                                                

Events rows (tweets used in pipeline): 10870


                                                                                

+-------------------+-------------------+-------------------+-------+---------------------+---------------------+----------+----------+--------+--------+
|id                 |tweet_time         |event_time         |Close  |rv_pre_30m           |rv_post_30m          |tweet_year|tweet_hour|spx_year|spx_hour|
+-------------------+-------------------+-------------------+-------+---------------------+---------------------+----------+----------+--------+--------+
|1001404640796336128|2018-05-29 10:07:26|2018-05-29 10:05:00|2709.29|0.0019230080247938965|0.0018386673426389981|2018      |10        |2018    |10      |
|1001410457092218880|2018-05-29 10:30:32|2018-05-29 10:30:00|2704.19|0.0012550933708381099|0.0023081077350491125|2018      |10        |2018    |10      |
|1001415199516254208|2018-05-29 10:49:23|2018-05-29 10:50:00|2705.22|0.001656133860550514 |0.0022112435484160416|2018      |10        |2018    |10      |
|1001417880116891650|2018-05-29 11:00:02|2018-05-29 11:00:00|2698.87|0.00192

                                                                                

In [7]:
# ==================================================
# PART D. Build shock label (is_shock) and ML dataframe
# ==================================================

# 5.1 Only consider rows where rv_post_30m exists
non_null_events = events.where(F.col("rv_post_30m").isNotNull())
if non_null_events.count() == 0:
    raise ValueError("No non-null rv_post_30m values; cannot build shock label.")

# 5.2 Use 90% quantile of rv_post_30m as shock threshold (more shocks than 95%)
shock_quantile = 0.90
quantiles = non_null_events.approxQuantile("rv_post_30m", [shock_quantile], 0.0)
threshold = float(quantiles[0])
print(f"Shock threshold ({int(shock_quantile*100)}% quantile of rv_post_30m):", threshold)

# 5.3 Define binary label: is_shock = 1 if rv_post_30m > threshold
labeled = events.withColumn(
    "is_shock",
    (F.col("rv_post_30m") > F.lit(threshold)).cast(T.IntegerType())
)

# 5.4 Categorical and numeric feature columns
cat_cols = [
    "category",
    "blue_category",
    "sentiment",
    "intensity",
    "has_market_action_keywords",
    "during_trading_hours",
    "towards_ceo_or_company",
]

num_cols = [
    "favorites",
    "retweets",
    "text_len",
    "num_exclam",
    "rv_pre_30m",
    "Close",
    "tweet_hour",
    "doy",
]

# Day-of-week as a simple time feature
labeled = labeled.withColumn("doy", F.dayofweek("tweet_time"))

# 5.5 Build final ML dataframe
ml_df = labeled.select(
    "is_shock", "event_time", "tweet_year", *cat_cols, *num_cols
).withColumnRenamed("tweet_year", "year")

print("Shock class distribution:")
ml_df.groupBy("is_shock").count().show()

ml_df.show(5, truncate=120)

                                                                                

Shock threshold (90% quantile of rv_post_30m): 0.0028990257946355995
Shock class distribution:


                                                                                

+--------+-----+
|is_shock|count|
+--------+-----+
|       1| 1087|
|       0| 9783|
+--------+-----+





+--------+-------------------+----+-----------------------------------------+----------------------------------------------------+---------+---------+--------------------------+--------------------+----------------------+---------+--------+--------+----------+---------------------+-------+----------+---+
|is_shock|         event_time|year|                                 category|                                       blue_category|sentiment|intensity|has_market_action_keywords|during_trading_hours|towards_ceo_or_company|favorites|retweets|text_len|num_exclam|           rv_pre_30m|  Close|tweet_hour|doy|
+--------+-------------------+----+-----------------------------------------+----------------------------------------------------+---------+---------+--------------------------+--------------------+----------------------+---------+--------+--------+----------+---------------------+-------+----------+---+
|       0|2018-05-29 10:05:00|2018|     Campaign / Rally / Election Politics|     

                                                                                

In [8]:
# ==================================================
# PART E. ML pipeline: StringIndexer + OneHotEncoder + GBT
# ==================================================

from pyspark.ml.classification import GBTClassifier, LogisticRegression

# 6.0 Clean NULLs in numeric features to avoid VectorAssembler errors
# We simply fill numeric NULLs with 0.0 (you can choose a different strategy if desired).
ml_df_clean = ml_df
for c in num_cols:
    ml_df_clean = ml_df_clean.withColumn(c, F.coalesce(F.col(c), F.lit(0.0)))

print("After NULL cleaning, any NULL in numeric cols?")
for c in num_cols:
    null_cnt = ml_df_clean.filter(F.col(c).isNull()).count()
    print(f"{c}: {null_cnt} NULLs")

# 6.1 Train/test split (random split; could also split by year)
train, test = ml_df_clean.randomSplit([0.7, 0.3], seed=42)
if test.count() == 0:
    test = train

print("Train rows:", train.count())
print("Test rows:",  test.count())

stages_common = []

# 6.2 Encode categorical columns (shared by both models)
for c in cat_cols:
    idx_col = c + "_idx"
    oh_col = c + "_oh"

    indexer = StringIndexer(
        inputCol=c,
        outputCol=idx_col,
        handleInvalid="keep"   # unseen or null categories go to a special bucket
    )

    encoder = OneHotEncoder(
        inputCol=idx_col,
        outputCol=oh_col
    )

    stages_common += [indexer, encoder]

# 6.3 Assemble all features into a single vector
feature_cols = [c + "_oh" for c in cat_cols] + num_cols

assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="keep"   # if any unexpected NULL sneaks in, keep instead of error
)
stages_common.append(assembler)

# ==================================================
# 6.4 Model 1: Gradient-Boosted Trees (GBT)
# ==================================================
gbt = GBTClassifier(
    labelCol="is_shock",
    featuresCol="features",
    maxDepth=3,
    maxIter=20,
)

gbt_pipeline = Pipeline(stages=stages_common + [gbt])

gbt_model = gbt_pipeline.fit(train)
gbt_pred = gbt_model.transform(test)

print("=== GBT sample predictions ===")
gbt_pred.select("event_time", "is_shock", "probability", "prediction") \
        .show(10, truncate=False)

# ==================================================
# 6.5 Model 2: Logistic Regression (baseline linear model)
# ==================================================

# We can reuse the same preprocessing stages (indexer + OHE + assembler),
# but need a fresh pipeline object to attach LogisticRegression at the end.
lr = LogisticRegression(
    labelCol="is_shock",
    featuresCol="features",
    maxIter=50,
    regParam=0.0,      # you can tune regularization if needed
    elasticNetParam=0.0
)

lr_pipeline = Pipeline(stages=stages_common + [lr])

lr_model = lr_pipeline.fit(train)
lr_pred = lr_model.transform(test)

print("=== Logistic Regression sample predictions ===")
lr_pred.select("event_time", "is_shock", "probability", "prediction") \
       .show(10, truncate=False)

After NULL cleaning, any NULL in numeric cols?
favorites: 0 NULLs
retweets: 0 NULLs
text_len: 0 NULLs
num_exclam: 0 NULLs
rv_pre_30m: 0 NULLs
Close: 0 NULLs
tweet_hour: 0 NULLs
doy: 0 NULLs


                                                                                

Train rows: 7704


                                                                                

Test rows: 3166


25/12/01 13:03:08 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

=== GBT sample predictions ===


25/12/01 13:03:27 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS


+-------------------+--------+----------------------------------------+----------+
|event_time         |is_shock|probability                             |prediction|
+-------------------+--------+----------------------------------------+----------+
|2009-06-30 13:35:00|0       |[0.9044880790851014,0.09551192091489857]|0.0       |
|2009-07-28 15:50:00|0       |[0.810774500929421,0.18922549907057895] |0.0       |
|2009-08-11 14:50:00|0       |[0.9321156042164586,0.06788439578354144]|0.0       |
|2009-08-14 14:35:00|0       |[0.8519157424649585,0.1480842575350415] |0.0       |
|2009-09-14 15:50:00|0       |[0.852829543817516,0.14717045618248403] |0.0       |
|2009-10-05 14:40:00|0       |[0.7534886013839786,0.24651139861602145]|0.0       |
|2009-10-14 14:15:00|0       |[0.7521372489758473,0.24786275102415267]|0.0       |
|2010-03-31 13:40:00|0       |[0.9415470637678413,0.05845293623215875]|0.0       |
|2010-04-06 14:35:00|0       |[0.9129256092935111,0.08707439070648892]|0.0       |
|201

                                                                                

=== Logistic Regression sample predictions ===


                                                                                

+-------------------+--------+-----------------------------------------+----------+
|event_time         |is_shock|probability                              |prediction|
+-------------------+--------+-----------------------------------------+----------+
|2009-06-30 13:35:00|0       |[0.9621662237900255,0.03783377620997452] |0.0       |
|2009-07-28 15:50:00|0       |[0.8714417149827358,0.12855828501726418] |0.0       |
|2009-08-11 14:50:00|0       |[0.9539380340084452,0.046061965991554765]|0.0       |
|2009-08-14 14:35:00|0       |[0.9142697932132346,0.08573020678676535] |0.0       |
|2009-09-14 15:50:00|0       |[0.9538459128418048,0.046154087158195245]|0.0       |
|2009-10-05 14:40:00|0       |[0.9669801641980873,0.0330198358019127]  |0.0       |
|2009-10-14 14:15:00|0       |[0.9560894130978992,0.04391058690210081] |0.0       |
|2010-03-31 13:40:00|0       |[0.9779721582785278,0.022027841721472208]|0.0       |
|2010-04-06 14:35:00|0       |[0.9603001258033964,0.03969987419660359] |0.0 

In [9]:
# ==================================================
# PART F. Evaluation metrics (PR-AUC and ROC-AUC)
# ==================================================

if test.select("is_shock").distinct().count() > 1:
    evaluator_pr = BinaryClassificationEvaluator(
        labelCol="is_shock",
        rawPredictionCol="rawPrediction",
        metricName="areaUnderPR"
    )
    evaluator_roc = BinaryClassificationEvaluator(
        labelCol="is_shock",
        rawPredictionCol="rawPrediction",
        metricName="areaUnderROC"
    )

    # Evaluate GBT
    gbt_pr_auc = evaluator_pr.evaluate(gbt_pred)
    gbt_roc_auc = evaluator_roc.evaluate(gbt_pred)
    print(f"[GBT] PR-AUC:  {gbt_pr_auc:.4f}")
    print(f"[GBT] ROC-AUC: {gbt_roc_auc:.4f}")

    # Evaluate Logistic Regression
    lr_pr_auc = evaluator_pr.evaluate(lr_pred)
    lr_roc_auc = evaluator_roc.evaluate(lr_pred)
    print(f"[LR ] PR-AUC:  {lr_pr_auc:.4f}")
    print(f"[LR ] ROC-AUC: {lr_roc_auc:.4f}")
else:
    print("Test set has only one class label; PR/ROC evaluation is not meaningful.")

                                                                                

[GBT] PR-AUC:  0.7138
[GBT] ROC-AUC: 0.9254


                                                                                

[LR ] PR-AUC:  0.6498
[LR ] ROC-AUC: 0.9066


In [10]:
# ==================================================
# Save models
# ==================================================

model_dir = "/storage/work/yfl5682/Project/models"

# Save GBT model
gbt_save_path = f"{model_dir}/gbt_model"
gbt_model.write().overwrite().save(gbt_save_path)
print(f"GBT model saved to: {gbt_save_path}")

# Save Logistic Regression model
lr_save_path  = f"{model_dir}/lr_model"
lr_model.write().overwrite().save(lr_save_path)
print(f"Logistic Regression model saved to: {lr_save_path}")

                                                                                

GBT model saved to: /storage/work/yfl5682/Project/models/gbt_model
Logistic Regression model saved to: /storage/work/yfl5682/Project/models/lr_model
