In [1]:
## Run this file on Google Colab!

In [1]:
# !apt-get update # Update apt-get repository.
# !apt-get install openjdk-8-jdk-headless -qq > /dev/null # Install Java.
# !wget -q http://archive.apache.org/dist/spark/spark-3.1.1/spark-3.1.1-bin-hadoop3.2.tgz # Download Apache Spark.
# !tar xf spark-3.1.1-bin-hadoop3.2.tgz # Unzip the tgz file.
# !pip install -q findspark # Install findspark. Adds PySpark to the System path during runtime.

In [2]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.1.1-bin-hadoop3.2"

import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, substring, split, size, log, count, \
                                  countDistinct, monotonically_increasing_id
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType
from string import punctuation

In [3]:
spark = SparkSession.builder.master("local[*]").getOrCreate()

### Read text from disk

In [4]:
text_file_path = "news.txt"
text_df = spark.read.text(text_file_path)

The following cell defines a function called `remove_punctuation` that takes a list of words and removes punctuation from each one. It uses `str.maketrans` to build a translation table that maps every character in `string.punctuation` to `None`, then applies that table to each word in a list comprehension.

The function also includes a second list comprehension to filter out any empty strings that might result from the removal of punctuation. The cleaned words are returned as a list.

The code then registers this Python function as a User Defined Function (UDF) in Spark under the name `remove_punctuation_udf` using `spark.udf.register`. The UDF can then be called from Spark SQL queries to apply the punctuation removal logic to DataFrame columns. `ArrayType(StringType())` declares the UDF's return type as an array of strings. `spark.udf.register` returns the registered function, which is why the cell prints a function object as its output.

In [5]:
def remove_punctuation(words):
    translator = str.maketrans("", "", punctuation)
    cleaned_words = [word.translate(translator) for word in words]
    cleaned_words = [word for word in cleaned_words if word.strip()]
    return cleaned_words

spark.udf.register("remove_punctuation_udf", remove_punctuation, ArrayType(StringType()))


<function __main__.remove_punctuation(words)>
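To see what the function does without starting Spark, it can be called on a plain Python list. The sketch below repeats the definition so that it runs on its own; the sample tokens are made up for illustration and do not come from `news.txt`.

```python
from string import punctuation

def remove_punctuation(words):
    # Translation table that deletes every character in string.punctuation.
    translator = str.maketrans("", "", punctuation)
    cleaned_words = [word.translate(translator) for word in words]
    # Drop tokens that are empty (or whitespace-only) after stripping.
    return [word for word in cleaned_words if word.strip()]

# Hypothetical sample tokens, including two that should disappear entirely.
sample = ["hello,", "world!", "--", "u.s.", ""]
cleaned = remove_punctuation(sample)
print(cleaned)  # ['hello', 'world', 'us']
```

Tokens made up only of punctuation, such as `"--"`, vanish, and internal punctuation is removed as well, so `"u.s."` becomes `"us"`.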

The following cell processes a DataFrame (`text_df`) containing text data, one row per line of the input file.

1. **Splitting into Words:**
   Splits the text in the "value" column of `text_df` into words, and the result is stored in a new DataFrame (`row_df`) with a column named "words."

2. **Lowercasing:**
   Converts every word to lowercase by applying `lower` to each array element with the SQL `transform` function, producing a new DataFrame (`lowercase_lines_df`).

3. **Exploding Words:**
   Uses the `explode` function to turn the array of words into one row per word. The result is a DataFrame (`exploded_lines_df`) with a column named "word." Note that `exploded_lines_df` is not used by the later steps.

4. **Creating a Temporary View:**
   Creates a temporary view named "lowercase_words_view" from the DataFrame `lowercase_lines_df`. This allows you to use Spark SQL queries on this view.

5. **Applying Punctuation Removal UDF:**
   Uses a Spark SQL query to apply the previously registered UDF (`remove_punctuation_udf`) to remove punctuation from the words in the "lowercase_words_view" view. The result is a DataFrame (`clean_lines_df`) with a column named "words" containing cleaned words.

6. **Assigning Document IDs:**
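   Adds a new column "doc_id" to `clean_lines_df` using the `monotonically_increasing_id` function and stores the result in `lines_df`. Each row (one line of the input file) gets a unique, increasing identifier that serves as its document ID. The values are guaranteed to be unique but not necessarily consecutive.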
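
The Spark documentation describes the current implementation of `monotonically_increasing_id` as placing the partition ID in the upper 31 bits and the record number within the partition in the lower 33 bits. The pure-Python sketch below only mimics that layout to show why IDs can jump between partitions; `sketch_monotonic_id` is a hypothetical helper, not a Spark function.

```python
def sketch_monotonic_id(partition_index, row_in_partition):
    # Hypothetical helper mirroring the documented bit layout:
    # partition ID in the upper bits, record number in the lower 33 bits.
    return (partition_index << 33) + row_in_partition

# Two partitions with three rows each.
ids = [sketch_monotonic_id(p, r) for p in range(2) for r in range(3)]
print(ids)  # [0, 1, 2, 8589934592, 8589934593, 8589934594]
```

The small `doc_id` values in the outputs further down are what a read into a single partition would produce; with more partitions the IDs would still be unique, only larger.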


In [6]:
row_df = text_df.select(split(text_df.value, " ").alias("words"))
lowercase_lines_df = row_df.selectExpr("transform(words, word -> lower(word)) as words")
exploded_lines_df = lowercase_lines_df.select(explode(lowercase_lines_df.words).alias("word"))
lowercase_lines_df.createOrReplaceTempView("lowercase_words_view")
clean_lines_df = spark.sql("SELECT remove_punctuation_udf(words) as words FROM lowercase_words_view")
lines_df = clean_lines_df.withColumn("doc_id", monotonically_increasing_id())
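
As a rough mental model, the same preprocessing can be written in plain Python. This is only a sketch of the logic, not how Spark executes it; the two input lines are invented, and `enumerate` stands in for `monotonically_increasing_id`.

```python
from string import punctuation

# Hypothetical input lines; the first contains a double space on purpose.
lines = ["Japan's gas imports  rose.", "Market still tight!"]
translator = str.maketrans("", "", punctuation)

docs = []
for doc_id, line in enumerate(lines):
    words = line.split(" ")                           # like split(value, " ")
    words = [w.lower() for w in words]                # like transform(words, word -> lower(word))
    words = [w.translate(translator) for w in words]  # like remove_punctuation_udf
    words = [w for w in words if w.strip()]           # drop empty tokens
    docs.append((doc_id, words))

print(docs)
# [(0, ['japans', 'gas', 'imports', 'rose']), (1, ['market', 'still', 'tight'])]
```

Splitting on a single space leaves an empty string wherever two spaces are adjacent, which is why the filter inside `remove_punctuation` matters.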

The following cell calculates the TF-IDF (Term Frequency-Inverse Document Frequency) value of each word in each document of the collection.

1. **Calculate Term Frequency (TF):**
   Explodes the array of words into one row per word, groups by document ID and word, and counts how often each word occurs in each document (the raw count, stored as `word_count`).

2. **Calculate Total Word Count per Document:**
   Calculates the total word count per document by summing the word counts. The result is stored in a DataFrame (`total_word_count_df`).

3. **Calculate TF (Normalized Term Frequency):**
   Joins the TF DataFrame with the total word count DataFrame, calculates the normalized TF, and drops unnecessary columns.

4. **Calculate Document Frequency (DF):**
   Groups by word and calculates the document frequency (number of documents where each word appears).

5. **Calculate Inverse Document Frequency (IDF):**
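   Calculates the inverse document frequency (IDF) for each word as `log(total_docs / (document_frequency + 1))`, where `total_docs` is the number of distinct `doc_id` values and `log` is the natural logarithm.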

6. **Calculate TF-IDF:**
   Joins the TF and IDF DataFrames, selects relevant columns, and calculates the TF-IDF values for each word in each document.

7. **Show Results:**
   Displays the first 10 rows of the resulting TF-IDF DataFrame.
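
The steps above can be traced on a toy corpus in plain Python. This is a minimal sketch using the same formulas as the Spark cell; the four documents are invented and the variable names are illustrative.

```python
import math
from collections import Counter

# Hypothetical toy corpus: doc_id -> cleaned words.
docs = {
    0: ["gas", "japan", "gas", "import"],
    1: ["market", "japan"],
    2: ["market", "still"],
    3: ["tonnes"],
}
total_docs = len(docs)

# Steps 1-3: raw counts per document, normalized by document length.
tf = {}
for doc_id, words in docs.items():
    counts = Counter(words)
    total_word_count = sum(counts.values())
    for word, word_count in counts.items():
        tf[(doc_id, word)] = word_count / total_word_count

# Step 4: number of documents each word appears in.
document_frequency = Counter(word for (_, word) in tf)

# Step 5: same smoothing as the notebook, log(N / (df + 1)).
idf = {w: math.log(total_docs / (n + 1)) for w, n in document_frequency.items()}

# Step 6: TF-IDF per (document, word) pair.
tf_idf = {(d, w): tf[(d, w)] * idf[w] for (d, w) in tf}

print(tf[(0, "gas")], idf["gas"], tf_idf[(0, "gas")])
```

With this variant of the formula, a word that appears in every document gets `log(N / (N + 1))`, which is slightly negative.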


In [7]:
tf_df = lines_df.select("doc_id", explode("words").alias("word")) \
    .groupBy("doc_id", "word").agg(count("*").alias("word_count"))

total_word_count_df = tf_df.groupBy("doc_id") \
    .agg(F.sum("word_count").alias("total_word_count"))

tf_df = tf_df.join(total_word_count_df, "doc_id") \
    .withColumn("tf", col("word_count") / col("total_word_count")) \
    .drop("word_count", "total_word_count")

df_df = tf_df.groupBy("word").agg(countDistinct("doc_id").alias("document_frequency"))

total_docs = lines_df.select("doc_id").distinct().count()
idf_df = df_df.withColumn("idf", log(total_docs / (col("document_frequency") + 1)))

tf_idf_df = tf_df.join(idf_df, "word").select("doc_id", "word", "tf", "idf") \
      .withColumn("tf_idf", col("tf") * col("idf"))
tf_idf_df.show(10, truncate=False)


+------+--------+---------------------+------------------+---------------------+
|doc_id|word    |tf                   |idf               |tf_idf               |
+------+--------+---------------------+------------------+---------------------+
|6     |priority|0.0035714285714285713|1.791759469228055 |0.006399140961528767 |
|3     |some    |0.002531645569620253 |1.791759469228055 |0.004536099922096342 |
|3     |still   |0.002531645569620253 |1.0986122886681096|0.002781296933336986 |
|2     |still   |0.004878048780487805 |1.0986122886681096|0.005359084334966388 |
|9     |still   |0.005025125628140704 |1.0986122886681096|0.00552066476717643  |
|11    |tonnes  |0.005263157894736842 |1.3862943611198906|0.0072962861111573185|
|3     |tonnes  |0.005063291139240506 |1.3862943611198906|0.007019211955037421 |
|5     |import  |0.010752688172043012 |1.0986122886681096|0.011813035362022684 |
|1     |import  |0.006289308176100629 |1.0986122886681096|0.006909511249484966 |
|6     |import  |0.003571428

### Results

In [8]:
result = tf_idf_df.filter(col("word") == "gas")
result.show(truncate=False)

+------+----+--------------------+------------------+--------------------+
|doc_id|word|tf                  |idf               |tf_idf              |
+------+----+--------------------+------------------+--------------------+
|0     |gas |0.0111731843575419  |1.3862943611198906|0.015489322470613303|
|4     |gas |0.009259259259259259|1.3862943611198906|0.012836058899258245|
+------+----+--------------------+------------------+--------------------+



In [9]:
result = tf_idf_df.filter(col("word") == "japan")
result.show(truncate=False)

+------+-----+--------------------+------------------+--------------------+
|doc_id|word |tf                  |idf               |tf_idf              |
+------+-----+--------------------+------------------+--------------------+
|6     |japan|0.03214285714285714 |0.8754687373538999|0.02814006655780392 |
|0     |japan|0.0111731843575419  |0.8754687373538999|0.009781773601719551|
|7     |japan|0.021739130434782608|0.8754687373538999|0.019031929072910864|
|5     |japan|0.03225806451612903 |0.8754687373538999|0.028240927011416124|
+------+-----+--------------------+------------------+--------------------+



In [10]:
result = tf_idf_df.filter(col("word") == "market")
result.show(truncate=False)

+------+------+--------------------+------------------+--------------------+
|doc_id|word  |tf                  |idf               |tf_idf              |
+------+------+--------------------+------------------+--------------------+
|10    |market|0.00423728813559322 |0.8754687373538999|0.003709613293872457|
|11    |market|0.010526315789473684|0.8754687373538999|0.009215460393198946|
|7     |market|0.021739130434782608|0.8754687373538999|0.019031929072910864|
|6     |market|0.010714285714285714|0.8754687373538999|0.009380022185934641|
+------+------+--------------------+------------------+--------------------+
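
The `idf` values in the filtered outputs above can be reproduced by hand. The sketch below assumes the corpus contains 12 documents; that number is inferred from the printed values and is never shown in the notebook. The document frequencies are read off the tables: "gas" appears in 2 documents, "japan" and "market" in 4 each.

```python
import math

TOTAL_DOCS = 12  # assumption: inferred from the idf column, not printed anywhere

def idf(document_frequency, total_docs=TOTAL_DOCS):
    # Same formula as the notebook: natural log of N / (df + 1).
    return math.log(total_docs / (document_frequency + 1))

gas_idf = idf(2)      # "gas" appears in 2 documents
japan_idf = idf(4)    # "japan" and "market" each appear in 4 documents

# tf of "gas" in doc 0, copied from the output table.
gas_tf_idf_doc0 = 0.0111731843575419 * gas_idf

print(gas_idf, japan_idf, gas_tf_idf_doc0)
```

Under that assumption the results agree with the printed `idf` column (about 1.3863 for "gas" and 0.8755 for "japan" and "market") and with the `tf_idf` value of "gas" in document 0.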

