In [1]:
import os
import sys

# Make the repository root (two levels above the current working directory,
# normally this notebook's folder) importable, so the project's own
# `datapipeline` package resolves in later cells.
project_root = os.path.abspath("../..")
if project_root not in sys.path:
    sys.path.append(project_root)

# Run Spark's Python driver and its Python workers with this kernel's
# interpreter, so both see the same installed packages.
os.environ.update(
    PYSPARK_PYTHON=sys.executable,
    PYSPARK_DRIVER_PYTHON=sys.executable,
)

In [2]:
import os

# Runtime switches for the ML libraries, set before the Spark session (and its
# Python workers) are started:
#  - CUDA_VISIBLE_DEVICES=-1 hides every GPU, forcing CPU inference.
#  - TRANSFORMERS_NO_TORCHVISION=1: presumably keeps `transformers` from
#    importing torchvision — TODO confirm it is still needed.
#  - TOKENIZERS_PARALLELISM=false turns off the HF tokenizers thread pool,
#    which otherwise warns once the process forks.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "-1",
        "TRANSFORMERS_NO_TORCHVISION": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

In [3]:
from datapipeline.utils.spark_session import get_spark_session

# Session options applied immediately after the session is created:
#  - Delta schema autoMerge lets the MERGE at the end of the notebook add the
#    new `entities` / `embedding` columns to an existing target table.
#  - Vectorized Parquet reader off. NOTE(review): the reason is not recorded
#    here; presumably a workaround when reading nested columns — confirm.
SESSION_CONF = {
    "spark.databricks.delta.schema.autoMerge.enabled": "true",
    "spark.sql.parquet.enableVectorizedReader": "false",
}

spark = get_spark_session("Gold_NER_Embedding")
for conf_key, conf_value in SESSION_CONF.items():
    spark.conf.set(conf_key, conf_value)

In [4]:
# Input: the language-tagged gold articles Delta table (path is relative to the
# working directory).
gold_lang_path = "../../sanewsstorage/gold/articles_lang"
gold_df = spark.read.load(gold_lang_path, format="delta")

In [5]:
# Partition count for the UDF stage below: each partition is processed by a
# separate Spark task running the Python NER / embedding UDFs.
# NOTE(review): 4 looks tuned to a local machine — revisit on a cluster.
UDF_PARTITIONS = 4
gold_df = gold_df.repartition(UDF_PARTITIONS)

In [6]:
import pandas as pd
import spacy

from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import ArrayType, StructType, StructField, StringType

In [7]:
# Return type of `ner_udf`: one array element per entity found in a text.
ner_schema = ArrayType(
    StructType()
    .add("entity", StringType(), nullable=True)  # entity span text
    .add("label", StringType(), nullable=True)   # spaCy entity label
)

In [8]:
from typing import Iterator


@pandas_udf(ner_schema)
def ner_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Extract named entities from each text with spaCy's ``xx_ent_wiki_sm`` model.

    This is an iterator-of-Series pandas UDF, so the spaCy model is loaded
    once per Spark task and reused for every Arrow batch in it. A plain
    Series-to-Series UDF runs its whole body per batch, which reloaded the
    model from disk for every batch. Usage is unchanged:
    ``ner_udf(col("clean_text"))``.

    Args:
        batches: Batches of the input text column. Values may be null, empty
            or whitespace-only.

    Yields:
        For each input batch, a Series of the same length. Each element is a
        list of ``(entity_text, entity_label)`` tuples matching ``ner_schema``.
        Null / blank texts map to an empty list.
    """
    nlp = spacy.load("xx_ent_wiki_sm")

    for texts in batches:
        text_list = texts.tolist()
        results = [[] for _ in text_list]

        # Only feed real text to spaCy; everything else keeps its [] placeholder.
        # The isinstance check also covers NaN, should pandas use it for nulls.
        positions = [
            i for i, text in enumerate(text_list)
            if isinstance(text, str) and text.strip()
        ]

        # nlp.pipe processes the texts as a stream and yields Docs in input order.
        docs = nlp.pipe(text_list[i] for i in positions)
        for i, doc in zip(positions, docs):
            results[i] = [(ent.text, ent.label_) for ent in doc.ents]

        yield pd.Series(results)

In [9]:
# Lazily attach the NER output; nothing runs until an action (show/count/write).
gold_df = gold_df.withColumn("entities", ner_udf("clean_text"))

In [10]:
from sentence_transformers import SentenceTransformer
from pyspark.sql.types import ArrayType, FloatType

In [11]:
# Return type of `embedding_udf`: one float vector per text (nullable elements,
# the ArrayType default).
embedding_schema = ArrayType(FloatType(), containsNull=True)

In [12]:
from typing import Iterator


@pandas_udf(embedding_schema)
def embedding_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Encode each text with the multilingual MiniLM sentence-transformers model.

    This is an iterator-of-Series pandas UDF, so the SentenceTransformer model
    is instantiated once per Spark task and reused for every Arrow batch. A
    Series-to-Series UDF re-created the model for each batch. Usage is
    unchanged: ``embedding_udf(col("clean_text"))``.

    Args:
        batches: Batches of the input text column; values may be null.

    Yields:
        For each input batch, a Series of the same length whose elements are
        the embedding vectors as plain Python float lists (``embedding_schema``).
    """
    model = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    )

    for texts in batches:
        if texts.empty:
            # Nothing to encode; avoid calling model.encode on an empty list.
            yield pd.Series([], dtype=object)
            continue

        # NOTE(review): nulls are encoded as "" and so still get a vector,
        # whereas the NER UDF returns [] for them — confirm this is intended.
        text_list = texts.fillna("").tolist()

        vectors = model.encode(
            text_list,
            batch_size=32,
            show_progress_bar=False,
        )

        # Convert each numpy row to a list of Python floats for ArrayType(FloatType()).
        yield pd.Series([vec.tolist() for vec in vectors])

In [13]:
# Lazily attach the sentence embedding; computed only when an action runs.
gold_df = gold_df.withColumn("embedding", embedding_udf("clean_text"))

In [15]:
# Spot-check a few enriched rows; this executes both UDFs on a small sample.
gold_df.show(n=5)

+--------------------+--------------------+--------------------+--------------------+-------------------+--------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+----------------+--------------------+--------------------+-----------------+--------------+--------------------+--------------------+
|          article_id|               title|         description|             content|       published_at|language|                 url|            keywords|          categories|             creator|           source_id|        source_name|ingestion_source|         bronze_hash|          clean_text|language_detected|language_final|            entities|           embedding|
+--------------------+--------------------+--------------------+--------------------+-------------------+--------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+-------------

In [14]:
# Row count of the enriched DataFrame (displayed as the cell's output).
gold_df.count()

27663

In [16]:
from delta.tables import DeltaTable

gold_ml_path = "../../sanewsstorage/gold/articles_enriched"

# `bronze_hash` is the upsert key. Delta MERGE fails when more than one source
# row matches the same target row. Duplicate keys written by the first-run
# overwrite would also make every later MERGE fail. So keep one row per key.
# NOTE(review): dropDuplicates keeps an arbitrary row per key; if duplicates
# can differ, pick a deterministic winner (e.g. latest `published_at`).
enriched_df = gold_df.dropDuplicates(["bronze_hash"])

if DeltaTable.isDeltaTable(spark, gold_ml_path):

    # Upsert: update every column of rows whose key already exists and insert
    # the rest. The session's schema autoMerge setting lets new columns in.
    delta_table = DeltaTable.forPath(spark, gold_ml_path)

    (
        delta_table.alias("t")
        .merge(
            enriched_df.alias("s"),
            "t.bronze_hash = s.bronze_hash"
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )

else:

    # First run: no Delta table at the path yet, so create it.
    (
        enriched_df.write
        .format("delta")
        .mode("overwrite")
        .save(gold_ml_path)
    )
