In [1]:
import sys
import os

# Resolve the directory two levels above the current working directory and
# make it importable, so project-local packages can be imported below.
project_root = os.path.abspath("../..")
if sys.path.count(project_root) == 0:
    sys.path.append(project_root)

# Point both PySpark interpreter variables at the interpreter that is
# running this kernel.
for env_key in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON"):
    os.environ[env_key] = sys.executable

In [2]:
from datapipeline.utils.spark_session import get_spark_session
# Session object used for every read and write below; obtained from the
# project helper under a fixed application name.
spark_app_name = "Trend_Feature_Engineering"
spark = get_spark_session(spark_app_name)

In [3]:
from pyspark.sql.functions import (
    col, to_date, count,
    current_date, date_sub,
    expr
)
from delta.tables import DeltaTable

In [4]:
def _storage_path(relative_path):
    """Return `relative_path` joined onto `project_root`."""
    return os.path.join(project_root, relative_path)


# Locations this notebook reads from.
clusters_path = _storage_path("sanewsstorage/ml/clusters")
labels_path = _storage_path("sanewsstorage/ml/clusters_labeled")
articles_path = _storage_path("sanewsstorage/gold/articles_final")

# Locations this notebook writes to.
trend_timeseries_path = _storage_path("sanewsstorage/ml/trend_time_series")
trending_clusters_path = _storage_path("sanewsstorage/ml/trending_clusters")

In [5]:
def _load_delta(path):
    """Read the Delta table stored at `path` into a DataFrame."""
    return spark.read.format("delta").load(path)


clusters_df = _load_delta(clusters_path)
labels_df = _load_delta(labels_path)

# Only the join key and the publication time are kept from the articles table.
articles_df = _load_delta(articles_path).select("bronze_hash", "published_at")

In [6]:
# Attach `published_at` to every cluster row via `bronze_hash`. The left join
# keeps cluster rows that have no matching article; their `published_at`, and
# therefore `date`, is NULL.
clusters_with_articles = clusters_df.join(
    articles_df, on="bronze_hash", how="left"
)

# Calendar date derived from the publication time.
df = clusters_with_articles.withColumn("date", to_date(col("published_at")))

In [7]:
# Row count for every (cluster_id, date) pair.
daily_counts = df.groupBy("cluster_id", "date").agg(
    count("*").alias("article_count")
)

# Attach the cluster label; the left join keeps clusters that have no label.
label_columns = labels_df.select("cluster_id", "cluster_label")
cluster_time_series = daily_counts.join(
    label_columns, on="cluster_id", how="left"
)

In [8]:
# Upsert the per-day counts into the time-series table, keyed on
# (cluster_id, date). If the path does not hold a Delta table yet, the
# current result is written there instead.
if DeltaTable.isDeltaTable(spark, trend_timeseries_path):

    delta_table = DeltaTable.forPath(spark, trend_timeseries_path)

    (
        delta_table.alias("t")
        .merge(
            cluster_time_series.alias("s"),
            # `date` is NULL for cluster rows whose left join found no
            # article. Plain `=` never matches NULL, so those rows would be
            # re-inserted as duplicates on every run; the null-safe `<=>`
            # lets them match and be updated in place.
            "t.cluster_id = s.cluster_id AND t.date <=> s.date"
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

else:

    (
        cluster_time_series.write
        .format("delta")
        .mode("overwrite")
        .save(trend_timeseries_path)
    )

In [9]:
# Print a five-row preview of the aggregated series.
cluster_time_series.show(n=5)

+----------+----------+-------------+--------------------+
|cluster_id|      date|article_count|       cluster_label|
+----------+----------+-------------+--------------------+
|        13|2026-02-04|          128|paid, plans, avai...|
|         3|2026-02-05|           81|news, earnings, f...|
|        19|2026-02-05|          259|available, plans,...|
|        16|2026-02-03|          283|   chars, bowl, news|
|        18|2026-02-08|          136|chars, available,...|
+----------+----------+-------------+--------------------+
only showing top 5 rows



In [10]:
# Total number of rows in the aggregated series (displayed as the cell output).
cluster_time_series.count()

156

In [11]:
# Two adjacent rolling 24-hour windows, both measured back from the query's
# current timestamp. These feed the `last_24h` / `prev_24h` counts computed
# downstream.
#
# The earlier calendar-date filter compared unequal spans: "last" covered
# yesterday plus today so far (24-48 hours) while "prev" covered exactly one
# day, which inflated the growth rate. Timestamp windows are equal in length.
#
# The cast is a no-op if `published_at` is already a timestamp.
# NOTE(review): assumes `published_at` carries a time of day and is
# comparable with the session time zone - confirm against the articles table.
published_ts = col("published_at").cast("timestamp")
last_window_start = expr("current_timestamp() - INTERVAL 24 HOURS")
prev_window_start = expr("current_timestamp() - INTERVAL 48 HOURS")

last_1d = df.filter(published_ts >= last_window_start)

prev_1d = df.filter(
    (published_ts < last_window_start) &
    (published_ts >= prev_window_start)
)

In [12]:
def _rows_per_cluster(frame, count_column):
    """Count the rows of `frame` per `cluster_id`, naming the count `count_column`."""
    return frame.groupBy("cluster_id").agg(count("*").alias(count_column))


last_counts = _rows_per_cluster(last_1d, "last_24h")
prev_counts = _rows_per_cluster(prev_1d, "prev_24h")

In [13]:
# Percentage change from the previous window to the latest one. A cluster
# with no rows in the previous window gets a fixed value of 100.
growth_rate_sql = """
        CASE
            WHEN prev_24h = 0 THEN 100
            ELSE ((last_24h - prev_24h) / prev_24h) * 100
        END
    """

# Start from the latest-window counts, so only clusters present there are
# kept. A missing previous-window count becomes 0 through fillna.
counts_both_windows = last_counts.join(prev_counts, on="cluster_id", how="left")
growth_df = counts_both_windows.fillna(0).withColumn(
    "growth_rate", expr(growth_rate_sql)
)

In [14]:
# Attach the cluster label; the left join keeps clusters that have no label.
label_lookup = labels_df.select("cluster_id", "cluster_label")
labeled_growth = growth_df.join(label_lookup, on="cluster_id", how="left")

# Score = 0.7 * latest-window volume + 0.3 * growth rate.
volume_term = col("last_24h") * 0.7
growth_term = col("growth_rate") * 0.3
trending_df = labeled_growth.withColumn("trend_score", volume_term + growth_term)

In [15]:
# Persist the current trending snapshot, replacing the previous one.
#
# `trending_df` only contains clusters that have rows in the latest window
# (it is built from `last_counts`). The earlier upsert keyed on `cluster_id`
# therefore never touched rows for clusters that had gone quiet, and their
# old `last_24h` / `growth_rate` / `trend_score` values stayed in the table.
# Overwriting keeps the table consistent with this run. It also covers the
# first run, when no Delta table exists at the path yet.
(
    trending_df.write
    .format("delta")
    .mode("overwrite")
    .save(trending_clusters_path)
)

In [17]:
# Print up to 20 clusters, highest trend score first, without truncating
# long cell values.
ranked_trending = trending_df.orderBy(col("trend_score").desc())
ranked_trending.show(n=20, truncate=False)

+----------+--------+--------+-----------+-------------+-----------+
|cluster_id|last_24h|prev_24h|growth_rate|cluster_label|trend_score|
+----------+--------+--------+-----------+-------------+-----------+
+----------+--------+--------+-----------+-------------+-----------+



In [18]:
# Stop the session now that all reads and writes above are done.
spark.stop()