In [1]:
import os
import sys

# Repository root, resolved against the kernel's current working directory
# (assumes the notebook is started two directory levels below the root).
project_root = os.path.abspath(os.path.join(os.pardir, os.pardir))

# Make project packages (e.g. `datapipeline`) importable without installing them.
if project_root not in sys.path:
    sys.path.append(project_root)

# Point Spark's Python worker and driver at this kernel's interpreter so both
# run in the same Python environment.
for _env_var in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON"):
    os.environ[_env_var] = sys.executable

In [2]:
from datapipeline.utils.spark_session import get_spark_session

# Application name handed to the project's session helper (presumably used as
# the Spark app name — confirm in datapipeline.utils.spark_session).
APP_NAME = "ML_Event_Detection"
spark = get_spark_session(APP_NAME)

In [3]:
from pyspark.sql.functions import (
    col, to_date, count,
    avg, stddev, when,
    current_date, date_sub,
    datediff, lit,
)
from pyspark.sql.window import Window
from delta.tables import DeltaTable

In [4]:
# Delta table locations, all relative to the repository root:
#   inputs  -> clusters, cluster labels, gold articles
#   output  -> detected events
clusters_path, labels_path, articles_path, event_path = (
    os.path.join(project_root, relative)
    for relative in (
        "sanewsstorage/ml/clusters",
        "sanewsstorage/ml/clusters_labeled",
        "sanewsstorage/gold/articles_final",
        "sanewsstorage/ml/events",
    )
)

In [5]:
# Cluster assignments (per bronze_hash) and their human-readable labels.
clusters_df = spark.read.load(clusters_path, format="delta")
labels_df = spark.read.load(labels_path, format="delta")

# Only the join key and publication timestamp are needed from the gold articles.
articles_df = spark.read.load(articles_path, format="delta").select(
    "bronze_hash", "published_at"
)

In [6]:
# One row per clustered article: its cluster, cluster label and publication day.
df = (
    clusters_df
    .join(articles_df, on="bronze_hash", how="left")
    .join(labels_df, on="cluster_id", how="left")
    .select(
        "bronze_hash",
        "cluster_id",
        "cluster_label",
        "published_at"
    )
    .withColumn("date", to_date(col("published_at")))
    # A clustered article with no match in articles_final (left join above), or
    # with a NULL / unparsable published_at, ends up with date = NULL. Those rows
    # would form a bogus NULL "day" per cluster downstream, and because
    # NULL = NULL is never true in the events MERGE condition, they would be
    # inserted again on every run. Drop them here.
    .filter(col("date").isNotNull())
)

In [7]:
# Number of articles per cluster per calendar day (only days that have at
# least one article produce a row).
group_keys = ["cluster_id", "cluster_label", "date"]
daily_counts = df.groupBy(*group_keys).agg(
    count(col("bronze_hash")).alias("article_count")
)

In [8]:
# Baseline for each (cluster, day): mean and sample std-dev of that cluster's
# daily article counts over the BASELINE_DAYS calendar days strictly before it.
BASELINE_DAYS = 7

# rangeBetween needs a numeric ordering key in which an offset of 1 equals one
# calendar day, so order by days-since-epoch. The previous rowsBetween(-7, -1)
# counted *rows*, and daily_counts has no row for days without articles, so a
# "7-day" frame could silently reach back weeks for a sparse cluster.
# NOTE(review): days with zero articles still contribute nothing (rather than 0)
# to the baseline; zero-fill daily_counts if that upward bias matters.
day_number = datediff(col("date"), to_date(lit("1970-01-01")))

window_spec = (
    Window
    .partitionBy("cluster_id")
    .orderBy(day_number)
    .rangeBetween(-BASELINE_DAYS, -1)
)

baseline_df = (
    daily_counts
    .withColumn("baseline_avg", avg("article_count").over(window_spec))
    .withColumn("baseline_std", stddev("article_count").over(window_spec))
)

In [9]:
# z-score thresholds, shared by the event flag and the intensity label.
EVENT_Z = 2.0  # at/above: day is an event ("Trending")
VIRAL_Z = 4.0  # at/above: event is labelled "Viral"

# No usable spread in the baseline -> z_score 0.
#  * baseline_std is NULL with fewer than 2 prior days in the window.
#  * baseline_std is 0 when every prior day had the same count; dividing by it
#    yields NULL (or raises DIVIDE_BY_ZERO under spark.sql.ansi.enabled), which
#    previously left z_score NULL and the row silently "Normal".
# NOTE(review): a spike over a perfectly flat baseline is therefore not flagged.
no_spread = col("baseline_std").isNull() | (col("baseline_std") == 0)

event_df = (
    baseline_df
    .withColumn(
        "z_score",
        when(no_spread, 0.0)
        .otherwise(
            (col("article_count") - col("baseline_avg")) /
            col("baseline_std")
        )
    )
    # 1 if the day's count is at least EVENT_Z std-devs above the baseline, else 0.
    .withColumn(
        "is_event",
        when(col("z_score") >= EVENT_Z, 1).otherwise(0)
    )
    .withColumn(
        "event_intensity",
        when(col("z_score") >= VIRAL_Z, "Viral")
        .when(col("z_score") >= EVENT_Z, "Trending")
        .otherwise("Normal")
    )
)

In [10]:

if not DeltaTable.isDeltaTable(spark, event_path):
    # No Delta table at event_path yet: write event_df out as the initial table.
    event_df.write.format("delta").mode("overwrite").save(event_path)
else:
    # Upsert keyed on (cluster_id, date): matching rows are overwritten with the
    # recomputed values, unseen (cluster_id, date) pairs are inserted.
    delta_table = DeltaTable.forPath(spark, event_path)
    merge_condition = "t.cluster_id = s.cluster_id AND t.date = s.date"
    merge_builder = delta_table.alias("t").merge(event_df.alias("s"), merge_condition)
    merge_builder.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()

In [11]:
# Preview a handful of scored (cluster, day) rows.
event_df.show(n=5)

+----------+--------------------+----------+-------------+------------------+------------------+-------------------+--------+---------------+
|cluster_id|       cluster_label|      date|article_count|      baseline_avg|      baseline_std|            z_score|is_event|event_intensity|
+----------+--------------------+----------+-------------+------------------+------------------+-------------------+--------+---------------+
|         0|2026, available, ...|2026-02-02|           58|              NULL|              NULL|                0.0|       0|         Normal|
|         0|2026, available, ...|2026-02-03|          292|              58.0|              NULL|                0.0|       0|         Normal|
|         0|2026, available, ...|2026-02-04|          386|             175.0|165.46298679765212| 1.2752096652167653|       0|         Normal|
|         0|2026, available, ...|2026-02-05|          212|245.33333333333334|168.90628565371193|-0.1973480927860354|       0|         Normal|
|     

In [12]:
# Number of (cluster, day) rows in event_df. event_df is not cached, so this
# re-executes the whole read/join/window lineage.
event_row_count = event_df.count()
event_row_count

156

In [13]:
# Shut down the SparkSession created at the top of the notebook; re-run from
# the first cell to use Spark again.
spark.stop()