In [1]:
import os
import sys

# Make the repository root (two directories above this notebook's working
# directory) importable so the `datapipeline` package resolves. Guarded so
# re-running the cell does not append duplicate entries.
project_root = os.path.abspath("../..")
if project_root not in sys.path:
    sys.path.append(project_root)

# Point PySpark's driver and worker Python at this kernel's interpreter.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)

In [2]:
from datapipeline.utils.spark_session import get_spark_session

spark = get_spark_session("ML_Clustering")

# Session-level overrides applied for this notebook only.
#  - Delta schema auto-merge: allows schema evolution during Delta writes/merges.
#  - Parquet vectorized reader disabled. NOTE(review): the reason is not recorded
#    here (presumably a workaround for a read issue on this data); confirm.
for conf_key, conf_value in (
    ("spark.databricks.delta.schema.autoMerge.enabled", "true"),
    ("spark.sql.parquet.enableVectorizedReader", "false"),
):
    spark.conf.set(conf_key, conf_value)

In [3]:
import mlflow
import mlflow.sklearn
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from pyspark.sql.functions import col
from delta.tables import DeltaTable

In [12]:
# Input/output locations, relative to the notebook's working directory.
gold_path       = "../../sanewsstorage/gold/articles_final"   # gold articles (Delta)
ml_cluster_path = "../../sanewsstorage/ml/clusters"            # bronze_hash -> cluster_id (Delta)
model_path      = "../../models/kmeans_model"                   # persisted MLflow sklearn model

# File-based MLflow tracking store; training runs are grouped under one experiment.
mlflow.set_tracking_uri(uri="file:../../mlruns")
mlflow.set_experiment(experiment_name="News_Clustering")

<Experiment: artifact_location='file:///c:/Users/Echelon/Desktop/re/sa-news/datapipeline/notebooks/../../mlruns/167325228610632209', creation_time=1771327003642, experiment_id='167325228610632209', last_update_time=1771327003642, lifecycle_stage='active', name='News_Clustering', tags={}>

In [6]:
# Only the join key and the embedding vector are used downstream; rows with no
# embedding cannot be clustered, so they are dropped up front.
gold_df = (
    spark.read
    .format("delta")
    .load(gold_path)
    .select("bronze_hash", "embedding")
    .where(col("embedding").isNotNull())
)

In [7]:
# Sanity-check the schema, then show how many rows have an embedding.
gold_df.printSchema()
gold_row_count = gold_df.count()
gold_row_count

root
 |-- bronze_hash: long (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



27663

In [8]:
# Detect first run vs. incremental run from what previous runs left behind.
clusters_exist = DeltaTable.isDeltaTable(spark, ml_cluster_path)

# An MLflow model directory is identified by its `MLmodel` descriptor file.
# A bare or empty directory at model_path is not a loadable model, and treating
# it as one would make mlflow.sklearn.load_model fail in the training cell.
model_exists = os.path.isfile(os.path.join(model_path, "MLmodel"))

In [9]:
# On incremental runs, keep only articles whose bronze_hash has not been
# clustered yet (left anti-join against the stored assignments).
if clusters_exist:
    already_clustered = (
        spark.read.format("delta").load(ml_cluster_path).select("bronze_hash")
    )
    gold_df = gold_df.join(already_clustered, on="bronze_hash", how="left_anti")

# Nothing new to process: release Spark and abort the rest of the run.
# `limit(1)` avoids a full count just to test for emptiness.
nothing_new = gold_df.limit(1).count() == 0
if nothing_new:
    print("No new articles to cluster.")
    spark.stop()
    raise SystemExit

In [10]:
# Bring the articles to be clustered onto the driver; KMeans runs in-process.
pdf = gold_df.toPandas()

# Stack the per-row embedding arrays into an (n_articles, dim) matrix.
# The dtype is pinned to float64 because the element type returned by toPandas()
# may differ between conversion paths (e.g. Arrow on/off). The saved model
# should always be fit and queried with the same dtype.
embeddings = np.array(pdf["embedding"].tolist(), dtype=np.float64)

# Fail early with a specific message instead of an opaque sklearn error.
if embeddings.ndim != 2:
    raise ValueError(
        f"Expected a 2-D embedding matrix, got shape {embeddings.shape}"
    )
# The schema allows null elements inside an embedding array; they become NaN here.
nan_rows = np.isnan(embeddings).any(axis=1)
if nan_rows.any():
    raise ValueError(
        f"{int(nan_rows.sum())} embeddings contain null/NaN elements; "
        "clean them upstream before clustering"
    )

In [13]:
if model_exists:
    # Reuse the persisted model, so new articles are labelled in the same
    # cluster-id space as the rows already stored in ml_cluster_path.
    model = mlflow.sklearn.load_model(model_path)

else:
    if clusters_exist:
        # At this point `embeddings` holds only the new articles (see the
        # anti-join above). KMeans label ids are arbitrary per fit, so a model
        # retrained here would write ids unrelated to the existing rows.
        raise RuntimeError(
            f"Cluster table exists at {ml_cluster_path} but no saved model was "
            f"found at {model_path}; refusing to retrain on new articles only. "
            "Restore the model or remove the cluster table to re-cluster everything."
        )

    k = 20  # number of clusters for the initial fit

    with mlflow.start_run():
        model = KMeans(
            n_clusters=k,
            random_state=42,   # reproducible centroid initialisation
            n_init="auto"
        )

        model.fit(embeddings)

        mlflow.log_param("n_clusters", k)
        mlflow.log_param("n_samples", len(embeddings))
        mlflow.log_metric("inertia", model.inertia_)
        mlflow.sklearn.log_model(model, "model")

    # Also persist to a fixed local path so later runs take the load branch.
    mlflow.sklearn.save_model(model, model_path)

  flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
  mlflow.sklearn.save_model(model, model_path)


In [14]:
# Nearest-centroid label for every new article, aligned row-for-row with pdf.
clusters = model.predict(embeddings)
pdf = pdf.assign(cluster_id=clusters)

In [15]:
# Only the key and the label go back to Spark; the embeddings are not persisted.
key_and_label = pdf.loc[:, ["bronze_hash", "cluster_id"]]
cluster_df = spark.createDataFrame(key_and_label)

In [16]:
if not clusters_exist:
    # First run: create the Delta table from this batch.
    cluster_df.write.format("delta").mode("overwrite").save(ml_cluster_path)
else:
    # Incremental run: insert only bronze_hash values not already present, so
    # re-executing this cell does not duplicate rows.
    (
        DeltaTable.forPath(spark, ml_cluster_path)
        .alias("t")
        .merge(cluster_df.alias("s"), "t.bronze_hash = s.bronze_hash")
        .whenNotMatchedInsertAll()
        .execute()
    )

In [17]:
# Rows produced by this run (new articles only, not the whole stored table).
new_cluster_count = cluster_df.count()
new_cluster_count

27663

In [18]:
# Preview the first 20 assignments from this run.
cluster_df.show(n=20)

+--------------------+----------+
|         bronze_hash|cluster_id|
+--------------------+----------+
|-1889541934639726581|         1|
|-1802607308808159383|        15|
|-7957235805070414591|        14|
|  754431215262548956|        18|
|-5991696855120567499|         6|
| 6908858088476876150|         0|
|-2905717159464896182|        18|
|-6717650960232191711|        15|
| 6079411292996756375|        15|
|  209973296890248629|         8|
|  160806983567694115|        10|
|-7700644663983271196|        15|
| 7122320868235704437|         9|
| -239444442816393325|         2|
|-2418879148395033879|         2|
|-7546881819110556056|         7|
|  984629088657411331|         9|
|-6427301616982749127|         9|
|-5360023006758077006|         6|
| -461260544960610769|         1|
+--------------------+----------+
only showing top 20 rows



In [None]:
# Cluster sizes across the whole stored table (all runs), largest first.
(
    spark.read.format("delta").load(ml_cluster_path)
    .groupBy("cluster_id")
    .count()
    .orderBy(col("count").desc())
    .show()
)

# A few of this run's assignments for spot-checking.
cluster_df.show(n=5)

+--------------------+----------+
|         bronze_hash|cluster_id|
+--------------------+----------+
|-1889541934639726581|         1|
|-1802607308808159383|        15|
|-7957235805070414591|        14|
|  754431215262548956|        18|
|-5991696855120567499|         6|
+--------------------+----------+
only showing top 5 rows

