# Prepare data

This task writes each data split (train and test) to its own Parquet folder in Volumes
so that the version of Ray that ships with Databricks Runtime 16.4 ML can read it natively.

A variety of other data reading and writing mechanisms are also supported.

In [0]:
import os

from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import DataFrame

In [0]:
# Notebook parameters. Each widget defaults to "" so that a parameter the
# caller did not supply is caught by the assertions further down.
dbutils.widgets.text("catalog_name", "")
dbutils.widgets.text("schema_name", "")
dbutils.widgets.text("data_dir", "")

# Read the three widget values in a single pass, in declaration order.
catalog_name, schema_name, data_dir = (
    dbutils.widgets.get(widget)
    for widget in ("catalog_name", "schema_name", "data_dir")
)

# Fail fast on any empty parameter before touching the catalog.
assert catalog_name, "catalog_name is required"
assert schema_name, "schema_name is required"
assert data_dir, "data_dir is required"

# Make the requested catalog and schema current so that the unqualified
# table name below resolves against them.
spark.sql(f"USE CATALOG {catalog_name}")
spark.sql(f"USE SCHEMA {schema_name}")

source_table_name = "yelp_reviews_silver"
splits = ["train", "test"]

# Echo the effective configuration, one setting per line.
print(
    f"catalog_name: {catalog_name}",
    f"schema_name: {schema_name}",
    f"data_dir: {data_dir}",
    f"source_table_name: {source_table_name}",
    f"splits: {splits}",
    sep="\n",
)

In [0]:
def write_parquet(df: DataFrame, split: str, data_dir: str) -> None:
    """Write the rows of ``df`` that belong to ``split`` as Parquet.

    Rows are kept when their ``split`` column equals the ``split`` argument,
    and only the ``input_ids`` and ``label`` columns are written. The output
    goes to ``os.path.join(data_dir, f"{split}.parquet")`` in overwrite mode,
    so anything already stored at that path is replaced.
    """
    target_path = os.path.join(data_dir, f"{split}.parquet")

    belongs_to_split = F.col("split") == split
    selected = df.filter(belongs_to_split).select("input_ids", "label")

    writer = selected.write.format("parquet").mode("overwrite")
    writer.save(target_path)

def read_parquet(data_dir: str, split: str) -> DataFrame:
    """Return a DataFrame over the Parquet data stored for ``split``.

    The data is read from ``os.path.join(data_dir, f"{split}.parquet")``
    using the notebook's global ``spark`` session.
    """
    source_path = os.path.join(data_dir, f"{split}.parquet")
    parquet_reader = spark.read.format("parquet")
    return parquet_reader.load(source_path)

In [0]:
# Reference the source table once; each split is filtered from this same
# DataFrame inside write_parquet.
df = spark.table(source_table_name)

# Write every configured split to its own Parquet location under data_dir.
for split in splits:
    write_parquet(df=df, split=split, data_dir=data_dir)

In [0]:
# Read each split back from data_dir and show up to 10 rows of it.
for split in splits:
    print(f"split: {split}")
    display(read_parquet(data_dir=data_dir, split=split).limit(10))