# Transformation: Use Spark to generate joined word embeddings

In this notebook, we pick up the output of the last Airflow operator in `download_and_ocr_images`, which writes a partitioned table to Google BigQuery for each search key we queried Bing for. We combine all the `_partitioned` tables and compute word embeddings for their tokens; these can later be used to assign features to the graph nodes (currently out of scope). The combined table also makes it possible to build dashboards, e.g. based on word counts per search key.

Using `pyspark`, the notebook covers the following steps:

* Find all tables in the BigQuery dataset whose names end with `_partitioned`
* Load them and union them into a single Spark DataFrame, keeping track of the search key (stored as `documentType`)
* Preprocess tokens: remove punctuation, lowercase
* Compute word embeddings using spaCy
* Write the joined `embedding` table back to the BigQuery dataset

In [None]:
import os
import dotenv

import pyspark
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from udfs import (
    spacy_word2vec_grouped_udf,
    remove_punctuation_udf,
    lowercase_udf
)

# Load environment variables from .env file for more flexibility
dotenv.load_dotenv(dotenv.find_dotenv(".env"), override=True)

home = os.path.expanduser("~")
project_id = os.environ["GCP_PROJECT_ID"]
bigquery_dataset = os.environ["BIGQUERY_DATASET"]
bucket = os.environ["GCP_GCS_BUCKET"]

Computing word embeddings for the OCRed tokens is the most expensive step. To split up the work, the tokens are later divided into `n_partitions_word2vec` groups (partitions) that are processed in parallel.

In [None]:
n_partitions_word2vec = 20
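The grouping later in the notebook assigns each row the group number `floor(rand() * n_partitions_word2vec)`, so every group receives roughly the same number of rows. A minimal pure-Python sketch of that bucketing (the `assign_groups` helper is hypothetical, and the fixed seed only makes the run reproducible):

```python
import math
import random
from collections import Counter


def assign_groups(n_rows, n_groups, seed=0):
    """Mimic Spark's F.floor(F.rand() * n_groups): one random group id per row."""
    rng = random.Random(seed)
    # random() returns a float in [0, 1), so ids fall in 0 .. n_groups - 1
    return [math.floor(rng.random() * n_groups) for _ in range(n_rows)]


groups = assign_groups(n_rows=10_000, n_groups=20)
sizes = Counter(groups)
print(f"{len(sizes)} groups, sizes between {min(sizes.values())} and {max(sizes.values())}")
```

A larger `n_partitions_word2vec` gives smaller groups and more parallelism, at the cost of more calls to the embedding UDF.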

Initialize the Spark session. See `README.md` for how to download the additional jars it needs.

In [None]:
jars_dir = f"{home}/bin/spark-3.0.3-bin-hadoop3.2/jars"

spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("test")
    # spark.jars takes one comma-separated list; a second .config("spark.jars", ...)
    # call would overwrite the first and drop the BigQuery connector
    .config("spark.jars", f"{jars_dir}/spark-bigquery-latest_2.12.jar,{jars_dir}/gcs-connector-hadoop2-2.1.1.jar")
    .getOrCreate()
)
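`spark.jars` expects a single comma-separated list of jar paths, and calling `.config("spark.jars", ...)` a second time replaces the earlier value instead of appending to it. A hypothetical helper like the one below (its name is an assumption, not part of the project) builds that value and fails early if one of the jars from the README download step is missing:

```python
import os


def spark_jars_value(jar_paths):
    """Return the comma-separated string for spark.jars, checking each jar exists."""
    missing = [path for path in jar_paths if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(f"Missing jar(s): {', '.join(missing)}")
    return ",".join(jar_paths)
```

It could then be passed as `.config("spark.jars", spark_jars_value([...]))` with the two jar paths used above.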

Get all tables in the dataset whose names end with `_partitioned`. These are the tables created in the last Airflow step.

In [None]:
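# Reading from a SQL query requires viewsEnabled and a dataset where the
# connector can materialize the query result
spark.conf.set("materializationDataset", bigquery_dataset)

tables = (
    spark.read.format("bigquery")
    .option("project", project_id)
    .option("viewsEnabled", "true")
    # Backticks keep project IDs that contain hyphens valid in BigQuery SQL
    .load(f"select table_name from `{project_id}.{bigquery_dataset}.INFORMATION_SCHEMA.TABLES`")
).toPandas()["table_name"]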

tables = tables[tables.str.endswith("_partitioned")].tolist()
tables
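The filter above and the `documentType` derivation in the next cell can be expressed together as a mapping from table name to search key. A small sketch with made-up table names; note that it strips only the trailing suffix, whereas `table.split("_partitioned")[0]` cuts at the first occurrence of `_partitioned` anywhere in the name:

```python
SUFFIX = "_partitioned"


def partitioned_tables(table_names, suffix=SUFFIX):
    """Map each '<key>_partitioned' table name to its document type '<key>'."""
    return {
        name: name[: -len(suffix)]
        for name in table_names
        if name.endswith(suffix)
    }


# Hypothetical table names, for illustration only
print(partitioned_tables(["cats_partitioned", "dogs_partitioned", "embedding"]))
# {'cats_partitioned': 'cats', 'dogs_partitioned': 'dogs'}
```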

Load the tables and union them into a single table, adding a `documentType` column derived from each table name.

In [None]:
queries = []

for table in tables:
    
    documentType = table.split("_partitioned")[0]

    tab = (
        spark.read.format("bigquery")
        .option("project", project_id)
        .option("table", f"{bigquery_dataset}.{table}")
        .load()
    )
    # createOrReplaceTempView replaces the deprecated registerTempTable
    tab.createOrReplaceTempView(table)
    queries.append(
        f"""SELECT *, '{documentType}' AS documentType FROM {table}"""
    )
    
query = " UNION ALL ".join(queries)

ocr = spark.sql(query)
ocr.count()
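`UNION ALL` in Spark SQL matches columns by position, so the query above assumes every `_partitioned` table has the same column order. A sketch of the query construction as a pure function (the name `build_union_query` is hypothetical); if the schemas might differ in column order, chaining `DataFrame.unionByName` over the loaded DataFrames matches columns by name instead:

```python
def build_union_query(table_to_type):
    """Build one UNION ALL query that tags every row with its documentType.

    Assumes document types contain no single quotes, since they are inserted
    into the SQL as string literals without escaping.
    """
    selects = [
        # Backticks quote the temp view name as a Spark SQL identifier
        f"SELECT *, '{doc_type}' AS documentType FROM `{table}`"
        for table, doc_type in table_to_type.items()
    ]
    return " UNION ALL ".join(selects)


print(build_union_query({"cats_partitioned": "cats", "dogs_partitioned": "dogs"}))
```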

Normalize the tokens before vectorizing them:

* Remove punctuation
* Convert to lowercase
* More steps to come ...

In [None]:
ocr = (
    ocr
    .withColumn("text", 
        remove_punctuation_udf(F.col("text"))
    )
    .withColumn("text", 
        lowercase_udf(F.col("text"))
    )
)
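The `udfs` module is not shown in this notebook, so the following is only a guess at the plain-Python logic that `remove_punctuation_udf` and `lowercase_udf` might wrap (the function names below are hypothetical). For these two steps, the built-in `F.lower` and `F.regexp_replace` from `pyspark.sql.functions` would also work and avoid the overhead of moving data between the JVM and Python.

```python
import string

# Translation table that deletes every ASCII punctuation character
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def remove_punctuation(text):
    """Delete ASCII punctuation; None stays None (Spark passes NULL to UDFs as None)."""
    return None if text is None else text.translate(_PUNCTUATION_TABLE)


def lowercase(text):
    """Lowercase the text; None stays None."""
    return None if text is None else text.lower()


print(lowercase(remove_punctuation("Hello, World!")))  # hello world
```

Only ASCII punctuation is removed here; typographic quotes such as `“` would survive and need an explicit pattern.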

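Now, reorganize the data and compute word embeddings using spaCy. Rows are randomly assigned to one of `n_partitions_word2vec` groups, each group is collected into a single list, and `spacy_word2vec_grouped_udf` processes a whole group per call; the results are then exploded back to one row per token. This seems somewhat hacky, but it is the only working method we found. The approach is described in https://towardsdatascience.com/a-couple-tricks-for-using-spacy-at-scale-54affd8326cf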

In [None]:
embedding = (
    ocr
    .select("id", "documentId", "documentType", "text")
    # Assign each row to one of n_partitions_word2vec random groups
    .groupBy(F.floor(F.rand() * n_partitions_word2vec).alias("groupNumber"))
    # Collect each group's rows into a single list column
    .agg(F.collect_list(F.struct("id", "documentId", "documentType", "text")).alias("documentGroup"))
    .repartition("groupNumber")
    # Run the spaCy UDF once per group, then explode back to one row per token
    .select(F.explode(spacy_word2vec_grouped_udf(F.col("documentGroup"))).alias("results"))
    .select(F.col("results.*"))
    .select("id", "documentId", "documentType", "text", "vector")
    # Re-attach the layout and confidence columns from the OCR output
    .join(
        ocr.select("id", "documentId", "block_num", "line_num", "left", "top", "width", "height", "conf"),
        on=["id", "documentId"],
    )
    .withColumnRenamed("id", "textId")
    # Sort last: Spark does not guarantee row order after a join
    .sort("documentId", "textId")
)
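The grouped pattern above can be simulated without Spark or spaCy to show the data flow: random group assignment, one batched call per group, then one output row per input row again. In this sketch `fake_vectorize` is a hypothetical stand-in for the spaCy model; a real grouped UDF would presumably load the model once and compute something like `[doc.vector for doc in nlp.pipe(texts)]` over the whole group, which is where the batching pays off.

```python
import math
import random
from collections import defaultdict


def fake_vectorize(texts):
    """Hypothetical stand-in for a spaCy model: a tiny 2-d 'vector' per text."""
    return [[float(len(t)), float(sum(map(ord, t)) % 97)] for t in texts]


def embed_in_groups(rows, n_groups, seed=0):
    """Simulate groupBy(floor(rand()*n)) -> collect_list -> batched UDF -> explode."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for row in rows:
        # Same random bucketing as F.floor(F.rand() * n_groups)
        groups[math.floor(rng.random() * n_groups)].append(row)

    results, n_calls = [], 0
    for group_rows in groups.values():
        # One model call per group instead of one per row
        vectors = fake_vectorize([row["text"] for row in group_rows])
        n_calls += 1
        # "Explode": emit one output row per input row, with its vector attached
        for row, vector in zip(group_rows, vectors):
            results.append({**row, "vector": vector})
    return results, n_calls


rows = [{"id": i, "documentId": i // 3, "text": f"word{i}"} for i in range(30)]
results, n_calls = embed_in_groups(rows, n_groups=4)
print(len(results), n_calls)
```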

In [None]:
embedding.show()

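This may take a while. Writing to BigQuery appears to be the heaviest part, most likely because Spark evaluates lazily: the write is the action that runs the whole pipeline, including the spaCy UDF. Since `embedding` is not cached, the `show()` call above probably computed most of it already and the write computes it again; calling `embedding.cache()` before `show()` would let the write reuse that work.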

In [None]:
# The connector stages the data in temporaryGcsBucket before loading it into BigQuery
(
    embedding.write.format("bigquery")
    .option("project", project_id)
    .option("temporaryGcsBucket", bucket)
    .mode("overwrite")
    .save(f"{bigquery_dataset}.embedding")
)