In [0]:
# Imports & load test data

import mlflow
import mlflow.spark

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import functions as F

# Read the gold table and carve out the held-out test split.
gold_table = "nlp_dev.gold.fasttext_gold"
df_gold = spark.read.table(gold_table)

# Only the columns used by the scoring, evaluation and export cells are kept.
# The frame is cached because it is counted here and transformed again later.
df_test = (
    df_gold
    .filter(F.col("split") == "test")
    .select("tfidf_features", "label", "split", "ingest_date")
    .cache()
)

# count() is the first action run against the cached frame.
print("Test rows:", df_test.count())
display(df_test.limit(5))


In [0]:
# Load the registered model from Unity Catalog (UC)

# Point the MLflow client at the Unity Catalog registry explicitly. The model
# is addressed by a three-level name (catalog.schema.model); an MLflow client
# that defaults to the workspace registry may fail to resolve it on a fresh
# kernel. The call is idempotent, so re-running this cell is safe.
mlflow.set_registry_uri("databricks-uc")

model_name = "nlp_dev.ml.fasttext_sentiment_lr"
# Pinned registry version.
# NOTE(review): confirm version 1 is still the version intended for scoring.
model_version = 1

# "models:/<name>/<version>" URI for the pinned model version.
model_uri = f"models:/{model_name}/{model_version}"

# Loaded with the Spark flavor; the next cell calls .transform() on it.
lr_model = mlflow.spark.load_model(model_uri)



In [0]:
# Score the test set with the loaded model

df_test_pred = lr_model.transform(df_test)

# Preview ten rows: true label next to the model's prediction and probability.
df_pred_preview = df_test_pred.select("label", "prediction", "probability").limit(10)
display(df_pred_preview)


In [0]:
# Evaluate test performance (sanity check)

# Both evaluators read the same label/prediction columns; only the metric differs.
_eval_columns = dict(labelCol="label", predictionCol="prediction")

evaluator_acc = MulticlassClassificationEvaluator(metricName="accuracy", **_eval_columns)
evaluator_f1 = MulticlassClassificationEvaluator(metricName="f1", **_eval_columns)

# Each evaluate() call scores the same prediction frame.
test_accuracy = evaluator_acc.evaluate(df_test_pred)
test_f1 = evaluator_f1.evaluate(df_test_pred)

print("Test accuracy:", test_accuracy)
print("Test F1      :", test_f1)



In [0]:
# Persist the scored test rows as a Delta dataset at a fixed ADLS location.
# NOTE(review): the path ends in a hard-coded "ingest_date=2025-10-10" segment
# while the rows carry their own ingest_date column — confirm the two agree.
pred_path = "abfss://nlp@nlplakeadls001.dfs.core.windows.net/ml/fasttext_predictions/ingest_date=2025-10-10"

# Columns exported alongside the model output.
output_columns = ["label", "prediction", "probability", "split", "ingest_date"]
df_out = df_test_pred.select(*output_columns)

# Overwrite mode with overwriteSchema enabled: whatever is already at
# pred_path (data and schema) is replaced on every run.
(
    df_out.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .save(pred_path)
)
print("Wrote predictions to:", pred_path)



In [0]:
%sql
-- Ensure the target schema exists before registering the table.
CREATE SCHEMA IF NOT EXISTS nlp_dev.ml;

-- Register a Delta table over the files at the given storage location.
-- IF NOT EXISTS makes this a no-op when the table is already registered.
CREATE TABLE IF NOT EXISTS nlp_dev.ml.fasttext_predictions
  USING DELTA
  LOCATION 'abfss://nlp@nlplakeadls001.dfs.core.windows.net/ml/fasttext_predictions/ingest_date=2025-10-10';


In [0]:
%sql
-- Row count per predicted class in the registered predictions table.
SELECT
  prediction,
  COUNT(*)
FROM nlp_dev.ml.fasttext_predictions
GROUP BY prediction
ORDER BY prediction;
