In [0]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, StructType, StructField, StringType, IntegerType
from pyspark.sql.window import Window

In [0]:
# Bind the session explicitly so this cell does not depend on a pre-injected
# `spark` global (it would raise NameError on a fresh kernel without one).
# getOrCreate() hands back the existing session when there already is one.
spark = SparkSession.builder.getOrCreate()

# Select the current catalog and schema for the SQL statements that follow.
spark.sql("USE CATALOG workspace")
spark.sql("USE SCHEMA med")

In [0]:
# Obtain the Spark session (reusing an existing one if present).
spark = SparkSession.builder.getOrCreate()

# Load the cleaned documents table by its fully qualified name.
docs_clean_df = spark.read.table("workspace.med.docs_clean")

In [0]:
# Size budget used later when bucketing sentences into chunks.
CHUNK_SIZE = 800

# Produce one row per sentence: posexplode emits each element of `sentences`
# together with its position in the array. Rows whose sentence is the empty
# string are discarded afterwards.
expanded_sentence = (
    docs_clean_df
    .select(
        F.col("doc_id"),
        F.col("source"),
        F.col("category"),
        F.col("title"),
        F.col("synonyms"),
        F.col("url"),
        F.col("snapshot_ts"),
        F.posexplode(F.col("sentences")).alias("sentence_index", "sentence"),
    )
    .where(F.col("sentence") != F.lit(""))
)

In [0]:
# Running window per document: all rows from the start of the doc up to and
# including the current sentence, ordered by sentence_index.
w = (
    Window
    .partitionBy("doc_id")
    .orderBy("sentence_index")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

expanded_sentence = (
    expanded_sentence
    # length of each individual sentence
    .withColumn("sentence_len", F.length(F.col("sentence")))
    # running total of (sentence length + 1) within the document
    .withColumn("cum_len", F.sum(F.col("sentence_len") + 1).over(w))
)

In [0]:
# Bucket sentences into chunks: divide the running length (minus one) by
# CHUNK_SIZE and truncate to an integer bucket number.
expanded_sentence = expanded_sentence.withColumn(
    "chunk_group",
    ((F.col("cum_len") - F.lit(1)) / CHUNK_SIZE).cast(IntegerType()),
)

In [0]:
# Build one row per (document, chunk_group).
#
# collect_list() makes no promise about element order once rows have been
# shuffled, so collecting bare sentences could join them out of document order.
# Each sentence is therefore collected together with its sentence_index, the
# array is sorted (structs compare field by field, so sentence_index decides),
# and only then is the sentence text extracted.
chunks_df = (
    expanded_sentence
    .groupBy(
        "doc_id",
        "source",
        "category",
        "title",
        "synonyms",
        "url",
        "snapshot_ts",
        "chunk_group"
    )
    .agg(
        F.sort_array(
            F.collect_list(F.struct("sentence_index", "sentence"))
        ).alias("indexed_sentences")
    )
    # Field access on an array of structs yields an array of that field,
    # preserving the (now sorted) array order.
    .withColumn(
        "sentences_in_chunk",
        F.col("indexed_sentences").getField("sentence")
    )
    .drop("indexed_sentences")
)

# join sentences into one string per chunk
chunks_df = chunks_df.withColumn(
    "chunk_text",
    F.concat_ws(" ", F.col("sentences_in_chunk"))
)

# use chunk_group as chunk_index
# NOTE(review): chunk_group comes from integer division of a cumulative length,
# so indices increase within a doc but may skip values when a single sentence
# spans more than one CHUNK_SIZE bucket -- confirm consumers tolerate gaps.
chunks_df = chunks_df.withColumn(
    "chunk_index",
    F.col("chunk_group")
)

# create chunk_id as doc_id_{chunk_index}
chunks_df = chunks_df.withColumn(
    "chunk_id",
    F.concat(F.col("doc_id"), F.lit("_"), F.col("chunk_index").cast("string"))
)

In [0]:
# Project the chunk rows down to the columns stored in the output table,
# in their final order.
doc_chunks_df = chunks_df.select(
    [
        "doc_id",
        "source",
        "chunk_id",
        "chunk_index",
        "chunk_text",
        "title",
        "category",
        "url",
        "synonyms",
        "snapshot_ts",
    ]
)

# Preview up to 10 rows.
display(doc_chunks_df.limit(10))

In [0]:
# Persist the chunks as a Delta table, replacing any existing data and
# allowing the stored schema to be overwritten as well.
doc_chunks_df.write.saveAsTable(
    "workspace.med.doc_chunks",
    format="delta",
    mode="overwrite",
    overwriteSchema="true",
)