In [2]:
import sys
import os

# Resolve the repository root (two directories above the notebook's working
# directory) and make it importable exactly once.
project_root = os.path.abspath("../..")
if project_root not in sys.path:
    sys.path.append(project_root)

# Point both the PySpark driver and its workers at this kernel's interpreter.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)

In [3]:
from datapipeline.utils.spark_session import get_spark_session

# Obtain the Spark session from the project helper; the argument is
# presumably used as the Spark application name.
spark = get_spark_session(
    "Trend_Feature_Engineering"
)

In [4]:
# Delta table locations, all resolved under the repository root.
# Read by this notebook:
clusters_path = os.path.join(project_root, "sanewsstorage/ml/clusters_labeled")
articles_path = os.path.join(project_root, "sanewsstorage/gold/articles_final")

# Written by this notebook:
trend_timeseries_path = os.path.join(project_root, "sanewsstorage/ml/cluster_time_series")
trending_clusters_path = os.path.join(project_root, "sanewsstorage/ml/trending_clusters")

In [5]:
from pyspark.sql.functions import col

# Cluster assignments keyed by article hash (`bronze_hash`), with the
# cluster's id and its label.
clusters_df = spark.read.load(clusters_path, format="delta").select(
    "bronze_hash", "cluster_id", "cluster_label"
)

In [6]:
# Publication time of each article, keyed by the same `bronze_hash`.
articles_df = spark.read.load(articles_path, format="delta").select(
    "bronze_hash", "published_at"
)

In [7]:
# Attach publication times to the cluster assignments. The left join keeps
# every clustered article; one without a gold-table row gets published_at NULL.
df = clusters_df.join(articles_df, "bronze_hash", "left")

In [8]:
from pyspark.sql.functions import to_date

# Reduce the publication time to a calendar day for the daily aggregation.
df = df.withColumn("date", to_date("published_at"))

In [9]:
from pyspark.sql.functions import count

# Number of articles per cluster per day; cluster_label is carried along as
# an extra grouping key.
cluster_time_series = df.groupBy("cluster_id", "cluster_label", "date").agg(
    count("*").alias("article_count")
)

In [10]:
from delta.tables import DeltaTable

# Upsert the daily per-cluster article counts into the time-series Delta table.
# On the first run the table is created; afterwards existing
# (cluster_id, date) rows are updated and new ones inserted.
#
# The date comparison uses Spark's null-safe equality (<=>). Rows whose `date`
# is NULL (no matching article in the left join, or an unparseable
# `published_at`) never satisfy a plain `=` (NULL = NULL is NULL). With `=`,
# those rows were re-inserted as duplicates on every run.
# NOTE(review): the merge key omits cluster_label. If one cluster_id can carry
# several labels, several source rows would match one target row and Delta
# would reject the merge. Confirm that labels are unique per cluster.
if DeltaTable.isDeltaTable(spark, trend_timeseries_path):
    delta_table = DeltaTable.forPath(spark, trend_timeseries_path)

    (
        delta_table.alias("t")
        .merge(
            cluster_time_series.alias("s"),
            "t.cluster_id = s.cluster_id AND t.date <=> s.date",
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )
else:
    (
        cluster_time_series.write
        .format("delta")
        .mode("overwrite")
        .save(trend_timeseries_path)
    )

In [20]:
# Peek at a handful of rows of the daily series.
cluster_time_series.show(n=5)

+----------+--------------------+----------+-------------+
|cluster_id|       cluster_label|      date|article_count|
+----------+--------------------+----------+-------------+
|         5|chars, available,...|2026-02-03|          205|
|         5|chars, available,...|2026-02-05|          181|
|        19|available, plans,...|2026-02-03|          309|
|         4|    chars, news, new|2026-02-03|          459|
|         8|chars, available,...|2026-02-08|          171|
+----------+--------------------+----------+-------------+
only showing top 5 rows



In [12]:
from pyspark.sql.functions import col, count, current_timestamp, expr, lit, to_timestamp

# Split articles into two equal, adjacent 24-hour windows ending at
# `window_end`:
#   last window: [window_end - 24h, window_end]  -> column `last_24h`
#   prev window: [window_end - 48h, window_end - 24h)  -> column `prev_24h`
#
# The earlier version filtered the calendar `date` column with
# `date >= current_date - 1`. That spans two calendar days (yesterday and
# today), while the comparison window covered one, so growth was inflated.
# Comparing full timestamps keeps both windows exactly 24h long.

# Anchor for the windows. None -> wall-clock now (the original behaviour).
# Set a timestamp string such as "2026-02-08 23:59:59" to replay historical
# data; with a wall-clock anchor, old data yields empty windows.
TREND_WINDOW_END = None

window_end = (
    current_timestamp()
    if TREND_WINDOW_END is None
    else lit(TREND_WINDOW_END).cast("timestamp")
)
last_start = window_end - expr("INTERVAL 24 HOURS")
prev_start = window_end - expr("INTERVAL 48 HOURS")

published_ts = to_timestamp(col("published_at"))

# Upper bound so that articles after the anchor are not counted when replaying.
last_1d = df.filter(
    (published_ts >= last_start) &
    (published_ts <= window_end)
)

prev_1d = df.filter(
    (published_ts < last_start) &
    (published_ts >= prev_start)
)

last_counts = (
    last_1d.groupBy("cluster_id")
    .agg(count("*").alias("last_24h"))
)

prev_counts = (
    prev_1d.groupBy("cluster_id")
    .agg(count("*").alias("prev_24h"))
)

In [None]:
from pyspark.sql.functions import expr

# Per-cluster growth between the two 24h windows. `last_counts` is the base of
# the left join, so only clusters with at least one article in the most recent
# window appear; a cluster missing from `prev_counts` gets prev_24h = 0.
growth_df = (
    last_counts.join(
        prev_counts,
        on="cluster_id",
        how="left"
    )
    # Fill only the column the left join can leave NULL. A blanket fillna(0)
    # would also turn a NULL cluster_id (if any, and if the column is numeric)
    # into 0, silently merging it with a real cluster 0.
    .fillna(0, subset=["prev_24h"])
)

# growth_rate is a percentage change versus the previous window. A cluster with
# no articles in the previous window gets a flat 100, regardless of volume.
growth_df = growth_df.withColumn(
    "growth_rate",
    expr(
        """
        CASE
            WHEN prev_24h = 0 THEN 100
            ELSE ((last_24h - prev_24h) / prev_24h) * 100
        END
        """
    )
)

In [14]:
# Distinct (cluster_id, cluster_label) pairs, used to put a readable label
# next to each scored cluster.
labels_df = clusters_df.select("cluster_id", "cluster_label").distinct()

trending_df = growth_df.join(labels_df, "cluster_id", "left")

In [15]:
from pyspark.sql.functions import col

# Blend current volume with momentum: weight 0.7 on the recent article count
# and 0.3 on the growth percentage.
# NOTE(review): the two terms are on different scales (raw count vs. percent),
# so the weights do not express their relative influence directly.
trending_df = trending_df.withColumn(
    "trend_score",
    0.7 * col("last_24h") + 0.3 * col("growth_rate"),
)

In [16]:
# Persist the current trending result.
#
# trending_df contains only clusters with articles in the latest window. The
# previous MERGE keyed on cluster_id never touched clusters that dropped out,
# so their old last_24h / trend_score rows stayed in the table and kept
# looking "trending" indefinitely, even after a run that produced no rows.
# A table keyed on cluster_id keeps no history anyway, so replacing it with
# this run's result gives the intended "currently trending" view.
(
    trending_df.write
    .format("delta")
    .mode("overwrite")
    .save(trending_clusters_path)
)


In [19]:
# Number of scored rows. 0 means `last_counts` was empty, i.e. no article fell
# into the most recent window.
trending_df.select("cluster_id").count()

0

In [17]:
# Top 20 clusters by trend score, highest first, with labels untruncated.
trending_df.orderBy("trend_score", ascending=False).show(n=20, truncate=False)

+----------+--------+--------+-----------+-------------+-----------+
|cluster_id|last_24h|prev_24h|growth_rate|cluster_label|trend_score|
+----------+--------+--------+-----------+-------------+-----------+
+----------+--------+--------+-----------+-------------+-----------+

