In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import regexp_replace, col
from sparknlp.base import DocumentAssembler, Pipeline, EmbeddingsFinisher
from sparknlp.annotator import Tokenizer, WordEmbeddingsModel, SentenceEmbeddings
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pyspark.sql.functions as F
from pyspark.ml.classification import LogisticRegression
from pyspark.sql.window import Window
import matplotlib.pyplot as plt
from IPython.display import clear_output
from pyspark.sql.types import StructType, StructField, StringType

# Note: spark.jars.packages makes Spark download Spark NLP and its many JAR
# dependencies when the session starts, so this cell may take a while.

# Session settings, applied to the builder in this order.
# maxResultSize "0" lifts the driver's result-size limit.
SPARK_CONF = {
    "spark.driver.memory": "8G",
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.kryoserializer.buffer.max": "2000M",
    "spark.driver.maxResultSize": "0",
    "spark.jars.packages": "com.johnsnowlabs.nlp:spark-nlp_2.12:5.5.3",
}

spark_builder = SparkSession.builder.appName("Spark-Text-Classification").master(
    "local[*]"
)
for conf_key, conf_value in SPARK_CONF.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()

# Read IMDB Dataset

In [None]:
# Source: https://huggingface.co/datasets/stanfordnlp/imdb

# Sampling parameters. SAMPLES_PER_CLASS caps the number of rows kept per label,
# which keeps the demo fast and the classes balanced; for production you might
# want to skip the sampling step and use the full dataset.
# SAMPLE_SEED makes the random sample repeatable across runs (for the same
# input files and Spark configuration).
SAMPLES_PER_CLASS = 3000
SAMPLE_SEED = 42

# Remove every character that is not an ASCII letter, digit or whitespace.
imdb_raw = spark.read.parquet("data/train-imdb.parquet").withColumn(
    "text", regexp_replace(col("text"), "[^a-zA-Z0-9\\s]", "")
)

# Create a window partitioned by label and ordered randomly (seeded).
windowSpec = Window.partitionBy("label").orderBy(F.rand(seed=SAMPLE_SEED))

# Number the rows within each label in that random order and keep only the
# first SAMPLES_PER_CLASS of them, giving a balanced random sample per class.
imdb_dataset = (
    imdb_raw.withColumn("row_num", F.row_number().over(windowSpec))
    .filter(F.col("row_num") <= SAMPLES_PER_CLASS)
    .drop("row_num")
)

# Show the sampled dataset
imdb_dataset.show(5, truncate=50)

print(
    "Number of classes in the sampled dataset: ",
    imdb_dataset.select("label").distinct().count(),
    "total number of rows: ",
    imdb_dataset.count(),
)

# Pipeline for Tokenizing and Embedding Creation

In [None]:
# Define Spark NLP pipeline stages. Together they turn each review's raw text
# into one averaged GloVe sentence embedding usable by pyspark.ml estimators.

# 1. DocumentAssembler wraps the raw "text" column in a document annotation,
#    the annotation format the downstream Spark NLP stages consume.
document_assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")

# 2. Tokenizer splits each document into tokens.
tokenizer = Tokenizer().setInputCols(["document"]).setOutputCol("token")

# 3. Load pre-trained GloVe word embeddings from a local copy of the model
#    (data/glove_100d) and attach an embedding to the document's tokens.
word_embeddings = (
    WordEmbeddingsModel.load("data/glove_100d")
    .setInputCols(["document", "token"])
    .setOutputCol("embeddings")
)

# 4. Create one sentence-level embedding per document by averaging its word
#    embeddings.
sentence_embeddings = (
    SentenceEmbeddings()
    .setInputCols(["document", "embeddings"])
    .setOutputCol("sentence_embeddings")
    .setPoolingStrategy("AVERAGE")
)

# 5. EmbeddingsFinisher converts the sentence-embedding annotations into a
#    "features" column holding an ARRAY of Spark ML vectors (setOutputAsVector),
#    one vector per annotation; the next cell explodes it into one vector per
#    row. setCleanAnnotations drops the intermediate annotation columns.
finisher = (
    EmbeddingsFinisher()
    .setInputCols("sentence_embeddings")
    .setOutputCols("features")
    .setOutputAsVector(True)
    .setCleanAnnotations(True)
)

# Build the pipeline (stages run in this order).
nlp_pipeline = Pipeline(
    stages=[
        document_assembler,
        tokenizer,
        word_embeddings,
        sentence_embeddings,
        finisher,
    ]
)

In [None]:
# Fit and transform the data.
# fit() yields a PipelineModel (fitted_pipeline) that is also reused later to
# embed the streaming micro-batches.
fitted_pipeline = nlp_pipeline.fit(imdb_dataset)

# The finisher emits an array of vectors per row. With one document per row
# there is one sentence embedding per review, so explode() produces one
# (label, features) row per review; rows whose array is empty are dropped.
# cache() keeps the embeddings around: final_data is reused by show(),
# randomSplit(), model fitting and evaluation, and recomputing the NLP
# pipeline for each of those actions is expensive.
final_data = (
    fitted_pipeline.transform(imdb_dataset)
    .selectExpr("label", "explode(features) as features")
    .cache()
)

final_data.show(5, truncate=50)

# Train logistic regression model

In [None]:
# Split data into training (80%) and testing (20%) sets (seeded for
# repeatability).
train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=42)

# Configure Spark MLlib's LogisticRegression (pyspark.ml.classification).
# "features" holds the sentence embeddings and "label" the numeric sentiment
# class (presumably 0 = negative, 1 = positive, matching process_batch below).
lr = LogisticRegression(featuresCol="features", labelCol="label", maxIter=50)

# Fit the model using the training data
model = lr.fit(train_data)

# Make predictions on the test data
predictions = model.transform(test_data)

# Evaluate test accuracy: the fraction of rows whose "prediction" equals the
# numeric "label" column.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy: {accuracy:.2f}")

In [None]:
# Peek at a handful of test-set predictions (all columns, cell text truncated
# to 50 characters).
predictions.show(n=5, truncate=50)

# Spark Structured Streaming for Text Classification
## Plotting and Embedding Pipeline

In [None]:
# Running history of per-batch prediction counts; process_batch appends to
# these lists and plot_sentiment draws them as the trend chart.
batch_pos_counts, batch_neg_counts, batch_ids = [], [], []


def plot_sentiment(pos_count, neg_count, batch_id):
    """Redraw the sentiment dashboard for the latest streaming micro-batch.

    Left panel: positive/negative counts of every batch recorded so far, read
    from the module-level ``batch_ids``, ``batch_pos_counts`` and
    ``batch_neg_counts`` lists. Right panel: share of positive vs. negative
    predictions in the current batch.

    Args:
        pos_count: Number of rows predicted positive in this batch.
        neg_count: Number of rows predicted negative in this batch.
        batch_id: Micro-batch id, shown in the pie-chart title.
    """
    fig, (ax_trend, ax_pie) = plt.subplots(1, 2, figsize=(16, 6))

    ax_trend.plot(batch_ids, batch_pos_counts, "g-o", label="Positive")
    ax_trend.plot(batch_ids, batch_neg_counts, "r-o", label="Negative")
    ax_trend.set_xlabel("Batch")
    ax_trend.set_ylabel("Count")
    ax_trend.set_title("Sentiment Analysis Results by Batch")
    ax_trend.legend()
    ax_trend.grid(True)
    ax_trend.set_ylim(bottom=0)

    if pos_count + neg_count > 0:
        ax_pie.pie(
            [pos_count, neg_count],
            labels=["Positive", "Negative"],
            colors=["green", "red"],
            autopct="%1.1f%%",
            startangle=90,
        )
        ax_pie.axis("equal")
    else:
        # pie() sizes wedges as x / sum(x), which is undefined when every
        # count is zero, so show a placeholder instead.
        ax_pie.text(0.5, 0.5, "No predictions", ha="center", va="center")
        ax_pie.axis("off")
    ax_pie.set_title(f"Batch {batch_id} Sentiment Distribution")

    fig.tight_layout()
    # Replace the previous batch's figure in the cell output instead of
    # stacking a new one below it.
    clear_output(wait=True)
    plt.show()
    # Release the figure so one new figure per batch does not accumulate.
    plt.close(fig)


def process_batch(df, batch_id):
    """Classify one streaming micro-batch of reviews and update the dashboard.

    Used as the ``foreachBatch`` callback of the streaming query. It applies
    the same text cleaning as the training data, embeds the text with
    ``fitted_pipeline``, predicts sentiment with ``model``, appends the batch's
    positive/negative counts to the module-level history lists and redraws the
    plots with ``plot_sentiment``.

    Args:
        df: Micro-batch DataFrame with a ``text`` column.
        batch_id: Micro-batch id supplied by Structured Streaming.
    """
    # head(1) stops after the first row; count() would scan the whole batch
    # just to test for emptiness.
    if not df.head(1):
        print(f"Batch {batch_id}: Empty batch")
        return

    # Must match the cleaning applied to the training data.
    df = df.withColumn("text", regexp_replace(col("text"), "[^a-zA-Z0-9\\s]", ""))

    # Apply the Spark NLP pipeline to get one sentence embedding per row
    embeddings_df = fitted_pipeline.transform(df).selectExpr(
        "explode(features) as features"
    )

    # Predict using the trained Spark ML model
    batch_predictions = model.transform(embeddings_df)

    # Only the per-class counts (one row per predicted class) reach the driver.
    # "prediction" holds doubles, so look up 1.0 / 0.0; a class absent from
    # this batch counts as 0.
    counts = {
        row["prediction"]: row["count"]
        for row in batch_predictions.groupBy("prediction").count().collect()
    }
    pos_count = counts.get(1.0, 0)
    neg_count = counts.get(0.0, 0)

    batch_pos_counts.append(pos_count)
    batch_neg_counts.append(neg_count)
    batch_ids.append(batch_id)
    plot_sentiment(pos_count, neg_count, batch_id)

## Spark Structured Streaming

In [None]:
# Stream the Yelp reviews from parquet files, one file per micro-batch.
# File sources need an explicit schema unless streaming schema inference is
# enabled; only the "text" column is read.
yelp_stream = (
    spark.readStream.format("parquet")
    .schema(StructType([StructField("text", StringType(), True)]))
    .option("maxFilesPerTrigger", 1)
    .load("data/yelp/")
)

# How long to let the demo stream run before shutting it down.
STREAM_TIMEOUT_SECONDS = 600

query = (
    yelp_stream.writeStream
    .foreachBatch(process_batch)
    .outputMode("append")
    .start()
)

try:
    # With a timeout, awaitTermination returns True if the query ended in time
    # and False if the timeout elapsed while it was still running.
    if query.awaitTermination(timeout=STREAM_TIMEOUT_SECONDS):
        print("Query terminated before timeout")
    else:
        print("Query terminated after timeout")
except Exception as e:
    print(f"Query terminated due to: {e}")
finally:
    # awaitTermination only waits. Without stop() the query would keep running
    # in the background after a timeout, or after the cell is interrupted.
    query.stop()