# Load documents

Read the documents from an external source into Databricks.

In this example we simply read from a Hugging Face dataset, but in practice
you might read from a cloud storage bucket directly, via Databricks Volumes,
or from Structured Streaming or Lakeflow Pipelines sources.

In [0]:
from delta import DeltaTable

from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import SparkSession, DataFrame, Window, Row

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

import ray.train.huggingface.transformers

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import Trainer

from datasets import load_dataset

import numpy as np
import evaluate

import datasets
import os

In [0]:
dbutils.widgets.text("catalog_name", "")
dbutils.widgets.text("schema_name", "")
dbutils.widgets.text("hf_datasets_cache", "")

catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
hf_datasets_cache = dbutils.widgets.get("hf_datasets_cache")

# Validate with explicit raises rather than `assert`: assert statements are
# stripped when Python runs with -O, and every later cell depends on these.
if not catalog_name:
    raise ValueError("catalog_name is required")
if not schema_name:
    raise ValueError("schema_name is required")
if not hf_datasets_cache:
    raise ValueError("hf_datasets_cache is required")


def _quote_identifier(name: str) -> str:
    """Return *name* as a backtick-quoted SQL identifier.

    Embedded backticks are doubled so free-form widget text cannot break out
    of the identifier, and names containing characters such as '-' still parse.
    """
    return "`" + name.replace("`", "``") + "`"


spark.sql(f"USE CATALOG {_quote_identifier(catalog_name)}")
spark.sql(f"USE SCHEMA {_quote_identifier(schema_name)}")
# NOTE(review): `datasets` is imported in an earlier cell, and it reads
# HF_DATASETS_CACHE when it is imported, so this assignment alone may not
# change where load_dataset caches in this process. Pass the path explicitly
# as `cache_dir` when loading to be sure.
os.environ["HF_DATASETS_CACHE"] = hf_datasets_cache

target_table_name = "yelp_reviews_bronze"

datasets.utils.logging.disable_progress_bar()

print(f"catalog_name: {catalog_name}")
print(f"schema_name: {schema_name}")
print(f"hf_datasets_cache: {hf_datasets_cache}")
print(f"target_table_name: {target_table_name}")

In [0]:
def ensure_target_table_exists(target_table_name: str) -> None:
    """Create the bronze target table if it does not already exist.

    The table has columns ``id INT``, ``text STRING``, ``label INT`` and
    ``split STRING`` (all NOT NULL) plus a PRIMARY KEY constraint on ``id``.
    Because of ``CREATE TABLE IF NOT EXISTS``, an existing table is left
    untouched and its schema is not checked.

    Args:
        target_table_name: Table name, bare (resolved against the session's
            current catalog and schema) or qualified (``catalog.schema.table``).
    """
    # A constraint name must be a single identifier, so build it from the last
    # segment of a possibly-qualified table name ("cat.sch.t" -> "pk_t").
    # For a bare name this yields the same "pk_<table>" as before.
    unqualified_name = target_table_name.rsplit(".", 1)[-1].strip("`")
    constraint_name = f"pk_{unqualified_name}"
    spark.sql(
        f"""
        CREATE TABLE IF NOT EXISTS {target_table_name} (
            id INT NOT NULL,
            text STRING NOT NULL,
            label INT NOT NULL,
            split STRING NOT NULL,
            CONSTRAINT {constraint_name} PRIMARY KEY (id)
        )
        """
    )


def get_split_as_dataframe(
    dataset: datasets.DatasetDict,
    split: str,
    start_id: int = 0
) -> DataFrame:
    """Convert one split of a Hugging Face dataset dict into a Spark DataFrame.

    Every row is assigned a sequential integer ``id`` starting at *start_id*,
    and a constant ``split`` column holding the split name is appended.

    Args:
        dataset: Mapping from split name to ``datasets.Dataset`` (for example
            the result of ``load_dataset(...)``).
        split: Name of the split to convert; must be a key of *dataset*.
        start_id: First id to assign. Give each split a disjoint range so ids
            stay unique across splits in the shared target table.

    Returns:
        A DataFrame whose schema declares ``id INT``, ``text STRING`` and
        ``label INT``, followed by the literal ``split STRING`` column.

    Raises:
        ValueError: If *split* is not a key of *dataset*.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if split not in dataset:
        raise ValueError(
            f"split {split!r} not found in dataset; "
            f"available splits: {list(dataset.keys())}"
        )
    split_data = dataset[split]
    # Contiguous ids [start_id, start_id + num_rows).
    id_assignment = np.arange(split_data.num_rows) + start_id
    return (
        spark.createDataFrame(
            split_data.add_column("id", id_assignment),
            schema=T.StructType([
                T.StructField("id", T.IntegerType()),
                T.StructField("text", T.StringType()),
                T.StructField("label", T.IntegerType()),
            ])
        )
        .withColumn("split", F.lit(split))
    )


def merge_append_table(
    spark: SparkSession,
    source_df: DataFrame,
    target_table_name: str
) -> None:
    """Append rows from *source_df* whose ``id`` is not yet in the target.

    Runs a Delta MERGE keyed on ``source.id = target.id`` with only a
    not-matched-insert clause. Rows whose id already exists in the target are
    skipped, so re-running with the same source inserts nothing new for those
    ids.

    Args:
        spark: Spark session used to resolve the Delta table by name.
        source_df: Rows to insert; ``whenNotMatchedInsertAll`` maps its
            columns onto target columns of the same name.
        target_table_name: Name of an existing Delta table.
    """
    (
        DeltaTable.forName(spark, target_table_name)
        .alias("target")
        .merge(source_df.alias("source"), "source.id = target.id")
        .whenNotMatchedInsertAll()
        .execute()
    )


In [0]:
# Pass cache_dir explicitly: `datasets` reads HF_DATASETS_CACHE from the
# environment when it is first imported, so setting os.environ after the
# import does not redirect where load_dataset caches its files.
dataset = load_dataset("yelp_review_full", cache_dir=hf_datasets_cache)
# Train ids occupy [0, n_train) and test ids start at n_train, so both splits
# can share the target table's `id` key without collisions.
df_train = get_split_as_dataframe(dataset, "train", start_id=0)
df_test = get_split_as_dataframe(dataset, "test", start_id=dataset["train"].num_rows)

In [0]:
# Create the bronze table on first run, then MERGE each split into it (train
# first, then test). The merge only inserts ids not already present.
ensure_target_table_exists(target_table_name)
for split_df in (df_train, df_test):
    merge_append_table(spark, split_df, target_table_name)

bronze_df = spark.table(target_table_name)
row_count = bronze_df.count()
print(f"row_count: {row_count}")
display(bronze_df)