# Exercise 2: Text Processing and Classification using Spark

In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import Window

# Reuse the Spark session that is already active in this process, or start a
# new one with the default configuration when none exists yet.
spark: SparkSession = SparkSession.builder.getOrCreate()

In [None]:
# Cut-off applied to the per-category ranking of terms further down.
K = 75

# Location of the review data. The development sample is active; enable the
# second assignment instead to run on the full dataset.
FILE_PATH = 'hdfs:///user/dic24_shared/amazon-reviews/full/reviews_devset.json'
# FILE_PATH = 'hdfs:///user/dic24_shared/amazon-reviews/full/reviewscombined.json'

# Load the reviews into a DataFrame using the JSON data source.
df = spark.read.format('json').load(FILE_PATH)
# df.head()

In [None]:
# Read the stopword list (one entry per line) on the driver.
with open('../data/stopwords.txt', encoding='utf-8') as f:
    stopwords = set(f.read().splitlines())

# Build the token column expression step by step:
#   1. lowercase the review text,
#   2. split on every run of characters outside the kept character class,
#   3. drop tokens of length <= 1 and stopwords,
#   4. keep each token at most once per review,
#   5. emit one output row per remaining token.
tokens = F.lower(F.col('reviewText'))
tokens = F.split(tokens, r'[^a-zA-Z<>^|]+')
tokens = F.filter(
    tokens,
    # The lambda receives a Column, so the predicate has to be a Column
    # expression: combine with `&` / `~` and test membership with `isin`
    # (Python's `and` / `not in` cannot be applied to Columns).
    lambda token: (F.length(token) > 1) & ~token.isin(sorted(stopwords)),
)
tokens = F.array_distinct(tokens)
tokens = F.explode(tokens)
df = df.withColumn('token', tokens)
# Only the category label and the token are needed from here on.
df = df.select('category', 'token')
# df.head()

In [None]:
# Review-level totals cannot be derived from df: it holds one row per
# (review, distinct token) pair, so counting its rows counts tokens, not
# reviews. Read the category column again with a fixed schema, which skips
# schema inference and ignores every other field.
documents = spark.read.schema('category STRING').json(FILE_PATH)
docs_per_category = documents.groupBy('category').agg(F.count(F.lit(1)).alias('n_c'))

# n_c_t: number of df rows for the (category, token) pair. Aggregating (rather
# than windowing) leaves exactly one row per pair for the following steps.
counts = df.groupBy('category', 'token').agg(F.count(F.lit(1)).alias('n_c_t'))
# n_t: the pair counts of a token summed over all categories.
counts = counts.withColumn('n_t', F.sum('n_c_t').over(Window.partitionBy('token')))
# n_c: number of reviews in the category (small lookup table, inner join).
counts = counts.join(F.broadcast(docs_per_category), on='category')
# n: total number of reviews.
counts = counts.withColumn('n', F.lit(documents.count()))
# counts.head()

In [None]:
# Cells of the 2x2 contingency table for each (category, token) row, derived
# from the columns of `counts`:
#   a = n_c_t,  b = n_c - a,  c = n_t - a,  d = n - a - b - c
# Everything is converted to double first: the inputs are integer counts and
# the product of the four marginals in the denominator can exceed the range
# of a 64-bit integer.
chisq = counts.withColumn('n', F.col('n').cast('double'))
chisq = chisq.withColumn('a', F.col('n_c_t').cast('double'))
chisq = chisq.withColumn('b', F.col('n_c') - F.col('a'))
chisq = chisq.withColumn('c', F.col('n_t') - F.col('a'))
chisq = chisq.withColumn('d', F.col('n') - F.col('a') - F.col('b') - F.col('c'))

# chi^2 = n * (a*d - b*c)^2 / ((a+b) * (c+d) * (a+c) * (b+d))
chisq = chisq.withColumn(
    'chi_squared',
    F.col('n') * ((F.col('a') * F.col('d') - F.col('b') * F.col('c')) ** 2)
    / (
        (F.col('a') + F.col('b'))
        * (F.col('c') + F.col('d'))
        * (F.col('a') + F.col('c'))
        * (F.col('b') + F.col('d'))
    ),
)
# Keep only the key columns and the score.
chisq = chisq.select('category', 'token', 'chi_squared')
# chisq.head()

In [None]:
# Rank every token inside its category by descending chi-squared. The token
# itself breaks ties so the selection is deterministic, and row_number()
# (unlike rank()) never hands the same position to several rows, so at most
# K tokens survive per category.
ranking = Window.partitionBy('category').orderBy(F.desc('chi_squared'), F.asc('token'))

# chisq may contain the same (category, token) row more than once; rank each
# pair only once.
topk = chisq.dropDuplicates(['category', 'token'])
topk = topk.withColumn('rank', F.row_number().over(ranking))
topk = topk.filter(F.col('rank') <= K)

# collect_list() gives no ordering guarantee after the shuffle, so carry the
# rank inside each collected struct and sort on it (the struct's first field).
topk = topk.groupBy('category').agg(
    F.sort_array(
        F.collect_list(
            F.struct(
                'rank',
                'token',
                # Array elements must share one type; cast the score to
                # string explicitly instead of relying on implicit coercion.
                F.col('chi_squared').cast('string').alias('chi_squared'),
            )
        )
    ).alias('ranked')
)
# Expose each entry as [token, chi_squared], best-ranked first.
topk = topk.withColumn(
    'topk',
    F.transform('ranked', lambda entry: F.array(entry['token'], entry['chi_squared'])),
)
topk = topk.select('category', 'topk').sort('category')
# topk.head()

In [None]:
# Write one line per category, "<category> token:chi_squared ...", followed
# by a final line with the alphabetically sorted union of all listed tokens.
with open('output.txt', 'w', encoding='utf-8') as f:
    tokens = set()

    for row in topk.collect():
        # Every entry of row['topk'] is indexable as [token, chi_squared].
        entries = [(entry[0], entry[1]) for entry in row['topk']]
        tokens.update(term for term, _ in entries)
        # NOTE(review): the term is written before its score; the previous
        # code emitted "score:term" — confirm against the required format.
        value_strings = [f'{term}:{score}' for term, score in entries]
        print(' '.join([f'<{row["category"]}>'] + value_strings), file=f)

    print(' '.join(sorted(tokens)), file=f)