# Preprocess documents

Apply any transformations, such as parsing or cleansing, to the loaded documents.

Here we pre-tokenize the text and count the tokens as an example of a preprocessing step you may
want to run before either training or inference. Other transformations could include
filtering invalid documents, deduplicating documents, enriching with additional metadata,
and any others you need to meet your business goals.

In [0]:
from typing import (
    Iterator
)

import pandas as pd

from delta import DeltaTable

from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import SparkSession, DataFrame, Window, Row

from transformers import AutoTokenizer

In [0]:
# Declare the notebook parameters as text widgets (empty default) so they can
# be supplied by a job or filled in interactively.
dbutils.widgets.text("catalog_name", "")
dbutils.widgets.text("schema_name", "")
dbutils.widgets.text("hugging_face_id", "")

# Strip surrounding whitespace so a value such as " " is treated as missing
# here instead of failing later inside the SQL statement or the tokenizer load.
catalog_name = dbutils.widgets.get("catalog_name").strip()
schema_name = dbutils.widgets.get("schema_name").strip()
hugging_face_id = dbutils.widgets.get("hugging_face_id").strip()

# Validate with explicit exceptions rather than `assert`: assert statements
# are removed when Python runs with optimizations enabled (-O), which would
# silently skip the validation.
if not catalog_name:
    raise ValueError("catalog_name is required")
if not schema_name:
    raise ValueError("schema_name is required")
if not hugging_face_id:
    raise ValueError("hugging_face_id is required")

# Make the catalog and schema the session defaults so the unqualified table
# names below resolve inside them.
spark.sql(f"USE CATALOG {catalog_name}")
spark.sql(f"USE SCHEMA {schema_name}")

source_table_name = "yelp_reviews_bronze"
target_table_name = "yelp_reviews_silver"

print(f"catalog_name: {catalog_name}")
print(f"schema_name: {schema_name}")
print(f"hugging_face_id: {hugging_face_id}")

In [0]:
def ensure_target_table_exists(target_table_name: str) -> None:
    """Create the target table with the expected schema if it is missing.

    The statement is `CREATE TABLE IF NOT EXISTS`, so calling this when the
    table already exists leaves it untouched. The table holds the source
    columns (`id`, `text`, `label`, `split`) plus the tokenization outputs
    (`input_ids`, `full_token_count`, `truncated_token_count`), all declared
    NOT NULL, with a primary key constraint on `id`.

    Uses the notebook-level `spark` session rather than taking one as an
    argument.

    Args:
        target_table_name: Name of the table to create. It is also embedded
            in the constraint name (`pk_<name>`).
            NOTE(review): a dotted, fully-qualified name would presumably
            yield an invalid constraint identifier - confirm callers only
            pass an unqualified table name.
    """
    create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {target_table_name} (
            id INT NOT NULL,
            text STRING NOT NULL,
            label INT NOT NULL,
            split STRING NOT NULL,
            input_ids ARRAY<INT> NOT NULL,
            full_token_count INT NOT NULL,
            truncated_token_count INT NOT NULL,
            CONSTRAINT pk_{target_table_name} PRIMARY KEY (id)
        )
        """
    spark.sql(create_table_sql)


def merge_append_table(
    spark: SparkSession,
    source_df: DataFrame,
    target_table_name: str
) -> None:
    """Insert the rows of `source_df` whose `id` is not yet in the target table.

    Rows are matched on `source.id = target.id`. Only a "not matched" insert
    clause is configured, so target rows that already exist are left as they
    are, and re-running with the same source adds nothing new.

    Args:
        spark: Session used to look up the target table by name.
        source_df: Rows to append; inserted with `whenNotMatchedInsertAll`.
        target_table_name: Name of the existing table to merge into.
    """
    delta_target = DeltaTable.forName(spark, target_table_name).alias("target")
    incoming = source_df.alias("source")

    # Insert-only merge: no whenMatched clause, so matches are ignored.
    insert_missing = delta_target.merge(
        incoming, "source.id = target.id"
    ).whenNotMatchedInsertAll()
    insert_missing.execute()


@F.pandas_udf(returnType=T.IntegerType())
def count_tokens_udf(
    text_series_iterator: Iterator[pd.Series],
) -> Iterator[pd.Series]:
    """Count the tokens of each text without truncating.

    The full text is tokenized so we can analyze how many tokens it would
    contain if we weren't truncating. Simply chopping a tokenized sequence
    is not a substitute, because special tokens might not be handled
    properly.

    The tokenizer named by the notebook-level `hugging_face_id` is loaded
    once, before the batches are consumed, and reused for every batch.

    Args:
        text_series_iterator: Batches of text values.

    Yields:
        One series per input batch holding the length of each text's
        `input_ids`.
    """
    tokenizer = AutoTokenizer.from_pretrained(hugging_face_id)
    for batch in text_series_iterator:
        # No truncation argument here: we want the untruncated lengths.
        encoded_ids = tokenizer(batch.to_list())["input_ids"]
        yield pd.Series([len(ids) for ids in encoded_ids])


@F.pandas_udf(returnType=T.ArrayType(T.IntegerType()))
def tokenize_and_truncate_udf(
    text_series_iterator: Iterator[pd.Series],
) -> Iterator[pd.Series]:
    """Tokenize each text with truncation enabled and return its token ids.

    The tokenizer named by the notebook-level `hugging_face_id` is loaded
    once, before the batches are consumed, and reused for every batch.

    Args:
        text_series_iterator: Batches of text values.

    Yields:
        One series per input batch holding each text's `input_ids` list.
    """
    tokenizer = AutoTokenizer.from_pretrained(hugging_face_id)
    for batch in text_series_iterator:
        # truncation=True with no explicit max_length: the cut-off is
        # presumably the tokenizer's own configured maximum - TODO confirm.
        encoded = tokenizer(batch.to_list(), truncation=True)
        yield pd.Series(encoded["input_ids"])


In [0]:
# Start from the bronze table and add the tokenization columns one at a time.
df_silver = spark.table(source_table_name)

# Truncated token ids for each document.
df_silver = df_silver.withColumn(
    "input_ids", tokenize_and_truncate_udf(F.col("text"))
)

# Token count of the full, untruncated text.
df_silver = df_silver.withColumn(
    "full_token_count", count_tokens_udf(F.col("text"))
)

# Token count after truncation, derived from the ids computed above.
df_silver = df_silver.withColumn(
    "truncated_token_count", F.size(F.col("input_ids"))
)

In [0]:
# Create the silver table on first run, then insert any rows whose id is not
# already present in it.
ensure_target_table_exists(target_table_name=target_table_name)
merge_append_table(
    spark=spark,
    source_df=df_silver,
    target_table_name=target_table_name,
)

# Report the size of the whole target table after the merge (not just the
# rows inserted by this run) and show its contents.
row_count = spark.table(target_table_name).count()
print(f"row_count: {row_count}")
display(spark.table(target_table_name))

In [0]:
num_bins = 20  # Choose appropriate number of bins

# Build a histogram of the untruncated token counts over the silver table.
result = spark.table(target_table_name)
result = result.select(
    F.histogram_numeric("full_token_count", F.lit(num_bins)).alias("hist")
)
# Expand the array of histogram structs into one row per bin.
result = result.select(F.inline("hist"))
display(result)

Databricks visualization. Run in Databricks to view.