In [1]:
import os
import sys

# Repository root: two directories above the kernel's working directory. It is
# appended to the import path so the local `datapipeline` package can be imported
# in the next cell without installing it.
project_root = os.path.abspath("../..")
if project_root not in sys.path:
    sys.path.append(project_root)

# Point both PySpark's Python workers and its driver at this kernel's interpreter.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)

In [2]:
from datapipeline.utils.spark_session import get_spark_session

spark = get_spark_session("ML_Clustering")

# Session-level overrides, applied in this order:
#  - enable Delta automatic schema evolution (schema autoMerge);
#  - disable Spark's vectorized Parquet reader. NOTE(review): the reason is not
#    recorded here (possibly the array<float> `embedding` column) — confirm.
_session_conf = {
    "spark.databricks.delta.schema.autoMerge.enabled": "true",
    "spark.sql.parquet.enableVectorizedReader": "false",
}
for _conf_key, _conf_value in _session_conf.items():
    spark.conf.set(_conf_key, _conf_value)

In [3]:
import mlflow

# Log runs to a file-based store at the repository root (path relative to the
# kernel's working directory), grouped under the "News_Clustering" experiment.
# The experiment object returned by set_experiment is the cell's displayed output.
mlflow.set_tracking_uri(uri="file:../../mlruns")
mlflow.set_experiment(experiment_name="News_Clustering")

  return FileStore(store_uri, store_uri)


<Experiment: artifact_location='file:///c:/Users/Echelon/Desktop/re/sa-news/datapipeline/notebooks/../../mlruns/167325228610632209', creation_time=1771327003642, experiment_id='167325228610632209', last_update_time=1771327003642, lifecycle_stage='active', name='News_Clustering', tags={}>

In [4]:
# Gold-layer Delta table of processed articles, relative to the working directory.
_storage_root = "../../sanewsstorage"
gold_path = f"{_storage_root}/gold/articles_final"

In [5]:
from pyspark.sql.functions import col

# Load the gold articles used for clustering. Rows are dropped when:
#  - `embedding` is null: there is nothing to cluster;
#  - `bronze_hash` is null: it is the key later used to MERGE cluster assignments.
#    A null key never matches in that MERGE, so the row would be re-inserted on
#    every run. A nullable long column that contains nulls also comes back from
#    toPandas() as float64, which cannot represent every 64-bit hash exactly.
gold_df = (
    spark.read.format("delta").load(gold_path)
    .select(
        col("bronze_hash"),
        col("clean_text"),
        col("embedding"),
        col("published_at"),
        col("topic"),
    )
    .filter(col("embedding").isNotNull() & col("bronze_hash").isNotNull())
)

In [9]:
# Sanity-check the loaded columns/types, then show how many rows will be clustered.
gold_df.printSchema()
gold_row_count = gold_df.count()
gold_row_count

root
 |-- bronze_hash: long (nullable = true)
 |-- clean_text: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)
 |-- published_at: timestamp (nullable = true)
 |-- topic: string (nullable = true)



27663

In [10]:
import pandas as pd
import numpy as np

# Collect the filtered gold rows to the driver; scikit-learn needs an in-memory matrix.
pdf = gold_df.toPandas()

# One row per article, in the same order as `pdf`, so cluster labels can later be
# assigned back to `pdf` positionally.
embeddings = np.array(pdf["embedding"].tolist())

# Fail fast with a readable message instead of an opaque KMeans error. The schema
# allows null elements inside `embedding` (containsNull = true) and does not fix the
# vector length, so an object-dtype or non-2-D array is possible here.
assert len(embeddings) > 0, "no articles with embeddings to cluster"
assert embeddings.ndim == 2, (
    f"embeddings must form a 2-D (n_articles, dim) matrix, got shape {embeddings.shape}"
)
assert np.issubdtype(embeddings.dtype, np.floating) and np.isfinite(embeddings).all(), (
    "embeddings contain null/NaN/inf elements"
)

In [11]:
import mlflow.sklearn  # explicit import of the flavor module used by log_model below
from sklearn.cluster import KMeans

k = 20             # number of clusters
random_seed = 42   # seeds centroid initialisation so reruns on the same data agree
n_init = "auto"    # number of k-means initialisations (scikit-learn's "auto" rule)

with mlflow.start_run():

    kmeans = KMeans(
        n_clusters=k,
        random_state=random_seed,
        n_init=n_init,
    )

    # fit_predict returns one label per input row, in input order. `embeddings` was
    # built from pdf["embedding"], so the labels line up with the rows of `pdf`.
    clusters = kmeans.fit_predict(embeddings)
    pdf["cluster_id"] = clusters

    # Log every setting that influences the assignment, not just k, so a run can
    # be reproduced from its MLflow record alone.
    mlflow.log_params({
        "n_clusters": k,
        "random_state": random_seed,
        "n_init": n_init,
        "n_samples": len(embeddings),
    })
    mlflow.log_metric("inertia", kmeans.inertia_)

    mlflow.sklearn.log_model(
        kmeans,
        "kmeans_model"
    )


  flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)


In [12]:
# Convert only the key and the assigned label back to Spark. Shipping the whole
# pandas frame (embedding vectors, article text) back to the JVM is wasted work,
# because only these two columns are kept downstream.
cluster_df = spark.createDataFrame(pdf[["bronze_hash", "cluster_id"]])

In [13]:
# Keep only the article key and its assigned cluster label: the columns written
# to the clusters table.
cluster_df = cluster_df.select(["bronze_hash", "cluster_id"])

In [14]:
# Delta table holding per-article cluster assignments, relative to the working directory.
_storage_root = "../../sanewsstorage"
ml_cluster_path = f"{_storage_root}/ml/clusters"

In [22]:
from delta.tables import DeltaTable

# Exactly one non-null key per article in the MERGE source:
#  - a MERGE that updates matched rows fails if several source rows match the
#    same target row;
#  - a null key never matches, so it would be inserted again on every run.
cluster_upserts = (
    cluster_df
    .where("bronze_hash IS NOT NULL")
    .dropDuplicates(["bronze_hash"])
)

if DeltaTable.isDeltaTable(spark, ml_cluster_path):

    delta_table = DeltaTable.forPath(spark, ml_cluster_path)

    # KMeans is refit on all articles each run, and its label ids are only
    # meaningful within one fit. Existing rows are therefore updated as well as
    # new ones inserted. Otherwise the table would mix labels from different
    # models under the same cluster_id values.
    (
        delta_table.alias("t")
        .merge(
            cluster_upserts.alias("s"),
            "t.bronze_hash = s.bronze_hash"
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

else:

    # First run: create the table from this run's assignments.
    (
        cluster_upserts.write
        .format("delta")
        .mode("overwrite")
        .save(ml_cluster_path)
    )


In [23]:
# Cluster sizes, largest first: a quick check for degenerate or tiny clusters.
cluster_sizes = cluster_df.groupBy("cluster_id").count()
cluster_sizes.sort(cluster_sizes["count"].desc()).show()

+----------+-----+
|cluster_id|count|
+----------+-----+
|        11| 2221|
|         2| 2048|
|         1| 2034|
|         4| 1961|
|         8| 1936|
|        14| 1924|
|         9| 1729|
|        18| 1715|
|        16| 1711|
|        10| 1585|
|         6| 1389|
|        19| 1333|
|         0| 1299|
|        15| 1066|
|         5|  942|
|         7|  763|
|        13|  572|
|        17|  505|
|         3|  480|
|        12|  450|
+----------+-----+



In [17]:
# Peek at a few (article key, cluster label) rows.
cluster_df.show(n=5)

+--------------------+----------+
|         bronze_hash|cluster_id|
+--------------------+----------+
|-1889541934639726581|         1|
|-1802607308808159383|        15|
|-7957235805070414591|        14|
|  754431215262548956|        18|
|-5991696855120567499|         6|
+--------------------+----------+
only showing top 5 rows

