In [13]:
# Multi-task NLP Pipeline using Apache Spark (PySpark) + Full EDA + Advanced ML Model

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Row
from pyspark.sql.functions import col, when, array, length, size, lit, count, udf, concat_ws, trim, split, desc
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StringType
from pyspark.ml.feature import StringIndexer, Tokenizer, StopWordsRemover, CountVectorizer, IDF, IndexToString, StringIndexerModel
from pyspark.ml.classification import RandomForestClassifier, OneVsRest
from pyspark import StorageLevel
from pyspark.ml import Pipeline
from pyspark.sql import DataFrame
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import nltk
from wordcloud import WordCloud
from sklearn.feature_extraction.text import CountVectorizer as SklearnCountVectorizer
from collections import Counter
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import BertSentenceEmbeddings, ClassifierDLApproach
from pyspark.ml import Pipeline
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType
from pyspark.ml.feature import StringIndexer

nltk.download('stopwords')
from nltk.corpus import stopwords
stop_words = set(stopwords.words('english'))
from nltk.util import ngrams
nltk.download("punkt")

[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [14]:
# Start Spark Session with memory optimizations
spark = SparkSession.builder \
    .appName("") \
    .master("local[*]") \
    .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:5.5.3") \
    .getOrCreate()

In [15]:
full_df = spark.read.csv(
    "/kaggle/input/merged-data/merged_dataset_clean.csv",
    header=True,
    inferSchema=True,
    multiLine=True,
    escape="\""
)


In [16]:
full_df.head(5)

[Row(id='eew5j0j', text='That game hurt.', label='sadness'),
 Row(id='ed2mah1', text="You do right, if you don't care then fuck 'em!", label='neutral'),
 Row(id='eeibobj', text='Man I love reddit.', label='love'),
 Row(id='eda6yn6', text='[NAME] was nowhere near them, he was by the Falcon.', label='neutral'),
 Row(id='eespn2i', text='Right? Considering it’s such an important document, I should know the damned thing backwards and forwards... thanks again for the help!', label='gratitude')]

In [30]:
from pyspark.ml import Pipeline
import random

# Prepare base stages
document = DocumentAssembler().setInputCol("text").setOutputCol("document")
bert = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128", "en") \
    .setInputCols(["document"]).setOutputCol("sentence_embeddings")

# Shuffle and retrain loop
num_rounds = 5
best_model = None
best_accuracy = 0.0

for i in range(num_rounds):
    print(f"\n⚙️ Training round {i+1}/{num_rounds}")
    
    # Shuffle and split
    # Sample e.g., 30% of data, shuffle
    sample_df = full_df.sample(withReplacement=False, fraction=0.3, seed=42)
    shuffled_df = sample_df.orderBy(F.rand(seed=random.randint(0, 9999)))
    train_df = shuffled_df.limit(int(shuffled_df.count() * 0.8))
    
    classifier = ClassifierDLApproach() \
        .setInputCols(["sentence_embeddings"]) \
        .setOutputCol("category") \
        .setLabelColumn("label") \
        .setMaxEpochs(5) \
        .setEnableOutputLogs(True) \
        .setBatchSize(32) \
        .setLr(0.003) \
        .setValidationSplit(0.2) \
        .setRandomSeed(random.randint(0, 9999))
    
    pipeline = Pipeline(stages=[document, bert, classifier])
    model = pipeline.fit(train_df)

sent_small_bert_L2_128 download started this may take some time.


25/05/06 00:21:44 WARN S3AbortableInputStream: Not all bytes were read from the S3ObjectInputStream, aborting HTTP connection. This is likely an error and may result in sub-optimal behavior. Request only the bytes you need via a ranged GET or drain the input stream after use.


Approximate size to download 16.1 MB
[OK!]

⚙️ Training round 1/5


                                                                                                    

Training started - epochs: 5 - learning_rate: 0.003 - batch_size: 32 - training_examples: 18813 - classes: 31
Epoch 1/5 - 4.61s - loss: 1921.7938 - acc: 0.20962082 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.48s
Macro-average	 prec: 0.006721859, rec: 0.032258064, f1: 0.011125428
Micro-average	 prec: 0.20837763, recall: 0.20837763, f1: 0.20837763
Epoch 2/5 - 4.25s - loss: 1921.8977 - acc: 0.20978053 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.17s
Macro-average	 prec: 0.006721859, rec: 0.032258064, f1: 0.011125428
Micro-average	 prec: 0.20837763, recall: 0.20837763, f1: 0.20837763
Epoch 3/5 - 4.61s - loss: 1921.8977 - acc: 0.20978053 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.16s
Macro-average	 prec: 0.006721859, rec: 0.032258064, f1: 0.011125428
Micro-average	 prec: 0.20837763, recall: 0.208

                                                                                                    

Training started - epochs: 5 - learning_rate: 0.003 - batch_size: 32 - training_examples: 18813 - classes: 31
Epoch 1/5 - 4.71s - loss: 1924.7842 - acc: 0.20923713 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.63s
Macro-average	 prec: 0.0067904494, rec: 0.032258064, f1: 0.011219211
Micro-average	 prec: 0.21050394, recall: 0.21050394, f1: 0.21050394
Epoch 2/5 - 4.47s - loss: 1924.8977 - acc: 0.21008892 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.14s
Macro-average	 prec: 0.0067904494, rec: 0.032258064, f1: 0.011219211
Micro-average	 prec: 0.21050394, recall: 0.21050394, f1: 0.21050394
Epoch 3/5 - 4.26s - loss: 1924.8977 - acc: 0.21008892 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.16s
Macro-average	 prec: 0.0067904494, rec: 0.032258064, f1: 0.011219211
Micro-average	 prec: 0.21050394, recall: 0.

                                                                                                    

Training started - epochs: 5 - learning_rate: 0.003 - batch_size: 32 - training_examples: 18813 - classes: 31
Epoch 1/5 - 5.63s - loss: 1954.8182 - acc: 0.16334702 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.53s
Macro-average	 prec: 0.005329474, rec: 0.032258064, f1: 0.009147633
Micro-average	 prec: 0.16521369, recall: 0.16521369, f1: 0.16521369
Epoch 2/5 - 4.34s - loss: 1954.8977 - acc: 0.16377291 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.17s
Macro-average	 prec: 0.005329474, rec: 0.032258064, f1: 0.009147633
Micro-average	 prec: 0.16521369, recall: 0.16521369, f1: 0.16521369
Epoch 3/5 - 4.37s - loss: 1954.8977 - acc: 0.16377291 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.17s
Macro-average	 prec: 0.005329474, rec: 0.032258064, f1: 0.009147633
Micro-average	 prec: 0.16521369, recall: 0.165

                                                                                                    

Training started - epochs: 5 - learning_rate: 0.003 - batch_size: 32 - training_examples: 18813 - classes: 31
Epoch 1/5 - 4.81s - loss: 1951.7585 - acc: 0.20892873 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.53s
Macro-average	 prec: 0.006687564, rec: 0.032258064, f1: 0.011078413
Micro-average	 prec: 0.20731448, recall: 0.20731448, f1: 0.20731448
Epoch 2/5 - 4.46s - loss: 1951.8977 - acc: 0.20956758 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.17s
Macro-average	 prec: 0.006687564, rec: 0.032258064, f1: 0.011078413
Micro-average	 prec: 0.20731448, recall: 0.20731448, f1: 0.20731448
Epoch 3/5 - 4.54s - loss: 1951.8977 - acc: 0.20956758 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.17s
Macro-average	 prec: 0.006687564, rec: 0.032258064, f1: 0.011078413
Micro-average	 prec: 0.20731448, recall: 0.207

                                                                                                    

Training started - epochs: 5 - learning_rate: 0.003 - batch_size: 32 - training_examples: 18813 - classes: 31
Epoch 1/5 - 4.70s - loss: 1931.7916 - acc: 0.20983376 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.49s
Macro-average	 prec: 0.006742436, rec: 0.032258064, f1: 0.011153597
Micro-average	 prec: 0.20901552, recall: 0.20901552, f1: 0.2090155
Epoch 2/5 - 4.30s - loss: 1930.8977 - acc: 0.21020642 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.16s
Macro-average	 prec: 0.006742436, rec: 0.032258064, f1: 0.011153597
Micro-average	 prec: 0.20901552, recall: 0.20901552, f1: 0.2090155
Epoch 3/5 - 4.39s - loss: 1930.8977 - acc: 0.21020642 - batches: 588
Quality on validation dataset (20.0%), validation examples = 4703
time to finish evaluation: 0.14s
Macro-average	 prec: 0.006742436, rec: 0.032258064, f1: 0.011153597
Micro-average	 prec: 0.20901552, recall: 0.20901

In [31]:
# Step 1: Predict on test data
predictions = model.transform(test_df)

# Step 2: Extract predicted label from category.result (which is an array)
extract_pred = udf(lambda x: x[0] if isinstance(x, list) and len(x) > 0 else "", StringType())
predictions = predictions.withColumn("predicted_label", extract_pred(col("category.result")))

# Step 3: Fit StringIndexers on both true and predicted labels
label_indexer = StringIndexer(inputCol="label", outputCol="label_index").fit(predictions)
pred_indexer = StringIndexer(inputCol="predicted_label", outputCol="predicted_label_index").fit(predictions)

# Step 4: Transform data to get numeric labels for evaluation
predictions = label_indexer.transform(predictions)
predictions = pred_indexer.transform(predictions)


                                                                                                    

In [32]:
evaluator = MulticlassClassificationEvaluator(
    labelCol="label_index",
    predictionCol="predicted_label_index",
    metricName="accuracy"
)

accuracy = evaluator.evaluate(predictions)
print("BERT Model Accuracy:", accuracy)


[Stage 108:>                                                                            (0 + 1) / 1]

BERT Model Accuracy: 0.19848145128414188


                                                                                                    

In [33]:
# Create manual test samples for all 5 labels
manual_df = spark.createDataFrame([
    ("I don't want to live anymore",),                    
    ("I need to talk to a therapist",),                  
    ("I feel so alone lately",),                         
    ("I'm constantly worried and can't relax",),         
    ("Nothing excites me anymore, I'm always sad",)      
], ["text"])

# Run through trained pipeline model
manual_preds = model.transform(manual_df)

# Show the results
manual_preds.select("text", "category.result").show(truncate=False)


[Stage 111:>                                                                            (0 + 3) / 3]

+------------------------------------------+------------+
|text                                      |result      |
+------------------------------------------+------------+
|I don't want to live anymore              |[depression]|
|I need to talk to a therapist             |[depression]|
|I feel so alone lately                    |[depression]|
|I'm constantly worried and can't relax    |[depression]|
|Nothing excites me anymore, I'm always sad|[depression]|
+------------------------------------------+------------+



                                                                                                    

In [34]:
spark.stop

<bound method SparkSession.stop of <pyspark.sql.session.SparkSession object at 0x79678355b290>>