In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
        .master("local") \
        .appName("Spark ML 2") \
        .config("spark.ui.port", "4050") \
        .getOrCreate()

sc = spark.sparkContext

22/10/27 17:12:53 WARN Utils: Your hostname, orange resolves to a loopback address: 127.0.1.1; using 166.104.246.51 instead (on interface enp15s0)
22/10/27 17:12:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/10/27 17:12:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [8]:
df = spark.read.format('csv')\
               .option('header', 'true')\
               .option("quote", "\"")\
               .option("escape", "\"")\
               .load('../../data/imdb-review-sentiment.csv')
df.show(5)

+--------------------+-----+
|                text|label|
+--------------------+-----+
|I grew up (b. 196...|    0|
|When I put this m...|    0|
|Why do people who...|    0|
|Even though I hav...|    0|
|Im a die hard Dad...|    1|
+--------------------+-----+
only showing top 5 rows



In [3]:
from pyspark.ml.feature import Tokenizer
tokenizer = Tokenizer(inputCol='text', outputCol='words')
df = tokenizer.transform(df)
df.show(5)

+--------------------+-----+--------------------+
|                text|label|               words|
+--------------------+-----+--------------------+
|I grew up (b. 196...|    0|[i, grew, up, (b....|
|When I put this m...|    0|[when, i, put, th...|
|Why do people who...|    0|[why, do, people,...|
|Even though I hav...|    0|[even, though, i,...|
|Im a die hard Dad...|    1|[im, a, die, hard...|
+--------------------+-----+--------------------+
only showing top 5 rows



In [4]:
from pyspark.ml.feature import HashingTF
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol='features')
df = hashingTF.transform(df)
df.show(5)

+--------------------+-----+--------------------+--------------------+
|                text|label|               words|            features|
+--------------------+-----+--------------------+--------------------+
|I grew up (b. 196...|    0|[i, grew, up, (b....|(262144,[2101,371...|
|When I put this m...|    0|[when, i, put, th...|(262144,[1109,270...|
|Why do people who...|    0|[why, do, people,...|(262144,[2306,592...|
|Even though I hav...|    0|[even, though, i,...|(262144,[5923,721...|
|Im a die hard Dad...|    1|[im, a, die, hard...|(262144,[1109,243...|
+--------------------+-----+--------------------+--------------------+
only showing top 5 rows



In [5]:
from pyspark.sql.types import ArrayType, StringType
from pyspark.sql.functions import udf

stopwords = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', "you're"]

def f(s):
    return [''.join(e for e in token if e.isalnum()) for token in s if token not in stopwords]

func = udf(f, ArrayType(StringType()))
df = df.withColumn('clean_words', func(df['words']))
df.select('clean_words').show(5, truncate=100)

[Stage 4:>                                                          (0 + 1) / 1]

+----------------------------------------------------------------------------------------------------+
|                                                                                         clean_words|
+----------------------------------------------------------------------------------------------------+
|[grew, up, b, 1965, watching, and, loving, the, thunderbirds, all, mates, at, school, watched, pl...|
|[when, put, this, movie, in, dvd, player, and, sat, down, with, a, coke, and, some, chips, had, s...|
|[why, do, people, who, do, not, know, what, a, particular, time, in, the, past, was, like, feel, ...|
|[even, though, have, great, interest, in, biblical, movies, was, bored, to, death, every, minute,...|
|[im, a, die, hard, dads, army, fan, and, nothing, will, ever, change, that, got, all, the, tapes,...|
+----------------------------------------------------------------------------------------------------+
only showing top 5 rows



                                                                                

In [6]:
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param

class RemoveStopWordsAndSpecialCharacters(Transformer, HasInputCol, HasOutputCol):
    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, stopwords=None) -> None:
        super().__init__()
        self.stopwords = Param(self, 'stopwords', '')
        self._setDefault(stopwords=set())
        kwargs = self._input_kwargs
        self._set(**kwargs)

    def setStopwords(self, value):
        self._paramMap[self.stopwords] = value
        return self
    
    def getStopwords(self):
        return self.getOrDefault(self.stopwords)

    def _transform(self, dataset):
        stopwords = self.getStopwords()
        def f(s):
            return [''.join(e for e in token if e.isalnum()) 
                    for token in s if token not in stopwords]
        t = ArrayType(StringType())
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))

In [9]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, HashingTF, Word2Vec, VectorAssembler
from pyspark.ml.classification import LinearSVC

tokenizer = Tokenizer(inputCol='text', outputCol='words')
cleaning = RemoveStopWordsAndSpecialCharacters(inputCol='words', outputCol='clean_words', stopwords=stopwords)
hashingTF = HashingTF(inputCol='clean_words', outputCol='tf')
w2v = Word2Vec(vectorSize=2, inputCol='clean_words', outputCol='w2v', minCount=1, maxIter=10)
asm = VectorAssembler(inputCols=[hashingTF.getOutputCol(), w2v.getOutputCol()], outputCol='features')
svm = LinearSVC(labelCol='label')
mypipeline = Pipeline(stages=[tokenizer, cleaning, hashingTF, w2v, asm, svm])
df = mypipeline.fit(df).transform(df)
df.select('prediction').show()

                                                                                

IllegalArgumentException: requirement failed: Column label must be of type numeric but was actually of type string.

In [None]:
sc.stop()