In [2]:
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt

vocabulary_size = 500
global_truncate = True

#Create a SparkSession
spark = SparkSession.builder \
    .appName('app_name') \
    .master('local[*]') \
    .config('spark.sql.execution.arrow.pyspark.enabled', True) \
    .config('spark.sql.session.timeZone', 'UTC') \
    .config('spark.driver.memory','16g') \
    .config("spark.executor.memory", "16g") \
    .config('spark.ui.showConsoleProgress', True) \
    .config('spark.sql.repl.eagerEval.enabled', True) \
    .getOrCreate()

23/07/26 00:55:44 WARN Utils: Your hostname, lab01 resolves to a loopback address: 127.0.1.1; using 10.0.1.132 instead (on interface eth0)
23/07/26 00:55:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/07/26 00:55:46 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/07/26 00:55:49 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/07/26 00:55:49 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [3]:
from pyspark.sql import functions as func

data = spark.read.csv(
    'train.csv', 
    header=True,
    inferSchema=True,
    escape='"'
)

data = data.withColumn("text", func.regexp_replace("text", '"', ''))

data.createOrReplaceTempView("spooky_sentences")

result_df = spark.sql(
    '''
    SELECT 
        text AS sentence,
        author,
        size(split(text, ' ')) AS word_count
    FROM
        spooky_sentences
    ORDER BY 
        word_count DESC
    
    '''
)

print(data.show())


                                                                                

+-------+--------------------+------+
|     id|                text|author|
+-------+--------------------+------+
|id26305|This process, how...|   EAP|
|id17569|It never once occ...|   HPL|
|id11008|In his left hand ...|   EAP|
|id27763|How lovely is spr...|   MWS|
|id12958|Finding nothing e...|   HPL|
|id22965|A youth passed in...|   MWS|
|id09674|The astronomer, p...|   EAP|
|id13515|The surcingle hun...|   EAP|
|id19322|I knew that you c...|   EAP|
|id00912|I confess that ne...|   MWS|
|id16737|He shall find tha...|   MWS|
|id16607|Here we barricade...|   EAP|
|id19764|Herbert West need...|   HPL|
|id18886|The farm like gro...|   HPL|
|id17189|But a glance will...|   EAP|
|id12799|He had escaped me...|   MWS|
|id08441|To these speeches...|   EAP|
|id13117|Her native sprigh...|   MWS|
|id14862|I even went so fa...|   EAP|
|id20836|His facial aspect...|   HPL|
+-------+--------------------+------+
only showing top 20 rows

None


In [4]:
from pyspark.ml.feature import SQLTransformer
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
import re

def remove_symbols(text):
    cleaned_text = re.sub(r"[^\w\s]", "", text)
    lowercase_text = cleaned_text.lower()
    return lowercase_text

# Register the UDF
udf_remove_symbols = udf(remove_symbols, StringType())
spark.udf.register("remove_symbols", udf_remove_symbols)


In [8]:
# Pipelining
from pyspark.ml import Pipeline
from pyspark.ml.feature import StopWordsRemover, Tokenizer, CountVectorizer, IDF, StringIndexer

# Define the SQLTransformer stage to apply the UDF
sql_transformer = SQLTransformer(
    statement="SELECT *, remove_symbols(sentence) AS cleaned_text FROM __THIS__"
)

# Step 1: Tokenization
tokenizer = Tokenizer(inputCol="cleaned_text", outputCol="tokens")

# Step 2: Stop word removal
stopwords_remover = StopWordsRemover(inputCol="tokens", outputCol="filtered_tokens")

# Step 3: TF-IDF calculation
vectorizer = CountVectorizer(inputCol="filtered_tokens", outputCol="vectorized_tokens")
idf = IDF(inputCol="vectorized_tokens", outputCol="tfidf")

string_indexer = StringIndexer(inputCol='author', outputCol='label')

processed_pipeline = Pipeline(stages=[sql_transformer, tokenizer, stopwords_remover, vectorizer, idf, string_indexer])

processed_train_data = processed_pipeline.fit(result_df).transform(result_df)

processed_train_data.show(10)

                                                                                

Py4JJavaError: An error occurred while calling o125.fit.
: org.apache.spark.SparkException: Input column author does not exist.
	at org.apache.spark.ml.feature.StringIndexerBase.$anonfun$validateAndTransformSchema$2(StringIndexer.scala:128)
	at scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:293)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.flatMap(TraversableLike.scala:293)
	at scala.collection.TraversableLike.flatMap$(TraversableLike.scala:290)
	at scala.collection.mutable.ArrayOps$ofRef.flatMap(ArrayOps.scala:198)
	at org.apache.spark.ml.feature.StringIndexerBase.validateAndTransformSchema(StringIndexer.scala:123)
	at org.apache.spark.ml.feature.StringIndexerBase.validateAndTransformSchema$(StringIndexer.scala:115)
	at org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema(StringIndexer.scala:145)
	at org.apache.spark.ml.feature.StringIndexer.transformSchema(StringIndexer.scala:252)
	at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:71)
	at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:237)
	at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:145)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
