
# Article Chunking with `unstructured`

We will use [unstructured](https://unstructured.io/) as our primary chunking strategy for the article body content. We expect to change the arguments passed to the unstructured [partitioning](https://docs.unstructured.io/open-source/core-functionality/partitioning) functions in future iterations, as we improve our dataset curation for pre-training or fine-tuning and refine the chunking strategy for our vector search (VS) index.

**NOTE**: Since we are working with XML data, we are going to use the [partition-xml](https://docs.unstructured.io/open-source/core-functionality/partitioning#partition-xml) function. Many libraries can make use of the XML tags we left in our body column, and the tags can be excluded easily with a regex or an open-source XML parsing library. We left the XML in the body so that we can discover new or different parsing strategies in the future; chunking strategies are becoming increasingly relevant to improving RAG performance.
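To illustrate the note above, here are both routes for dropping the tags. This is a minimal standard-library sketch; `SAMPLE_BODY`, `strip_tags_regex` and `strip_tags_etree` are hypothetical names. The parser route also decodes entities such as `&amp;`, which the regex leaves untouched, but real bodies that use prefixes such as `xlink:` need that namespace declared on the wrapper root before they will parse.

```python
import re
import xml.etree.ElementTree as ET

# Hypothetical body fragment with sibling elements, like the body column
SAMPLE_BODY = "<sec><title>Methods</title><p>Cells were grown in DMEM &amp; 10% FBS.</p></sec>"


def strip_tags_regex(xml_text: str) -> str:
    """Replace anything tag-shaped with a space; entities such as &amp; stay encoded."""
    return " ".join(re.sub(r"<[^>]+>", " ", xml_text).split())


def strip_tags_etree(xml_text: str) -> str:
    """Parse under a synthetic root, then join all text nodes (entities decoded)."""
    # The wrapper lets a fragment with several top-level elements parse as one document
    root = ET.fromstring(f"<root>{xml_text}</root>")
    return " ".join(t.strip() for t in root.itertext() if t.strip())
```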

**NOTE**: YES, we could have used the [partition-xml](https://docs.unstructured.io/open-source/core-functionality/partitioning#partition-xml) function to parse from file instead of from the `curated_articles` Delta table. As with the note above, we chose the table to make future iterative improvements faster, since reading text from files in blob storage carries a much larger I/O performance cost. This was a deliberate architecture decision for future enhancements, not just a way to conform to a [Medallion Architecture](https://www.databricks.com/glossary/medallion-architecture)... although we are doing that as well.

In [0]:
dbutils.widgets.dropdown(name="FILE_TYPE", defaultValue="xml", choices=["xml", "text"])
FILE_TYPE = dbutils.widgets.get("FILE_TYPE")
dbutils.widgets.dropdown(name="INSPECT_CONTENT", defaultValue="true", choices=["true", "false"])

In [0]:
%pip install unstructured
dbutils.library.restartPython()

In [0]:
%run ./_resources/pubmed_pipeline_config $RESET_ALL_DATA=false $DISPLAY_CONFIGS=true

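In [0]:
# Create a UDF that will chunk our article bodies
#TODO: check if we have multi-language sources
#TODO: evaluate using pandas UDF

from unstructured.partition.xml import partition_xml
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
from xml.sax.saxutils import quoteattr

# Chunks shorter than this many characters are discarded
MIN_CHUNK_CHARS = 110

def chunk_xml_body(body: str, attrs: dict):
    # Wrap the body in a synthetic root carrying the article attributes
    # (e.g., namespace declarations). The wrapper is built as a string because
    # assigning the body to an ElementTree element's .text would escape its tags
    # and hide the section structure from partition_xml.
    attr_str = "".join(f" {k}={quoteattr(str(v))}" for k, v in (attrs or {}).items())
    xml_body = f"<root{attr_str}>{body}</root>"
    body_elements = partition_xml(text=xml_body,
                                  xml_keep_tags=False,
                                  encoding='utf-8',
                                  include_metadata=False,
                                  languages=['eng'],
                                  date_from_file_object=None,
                                  chunking_strategy='by_title',
                                  multipage_sections=True,
                                  combine_text_under_n_chars=300,
                                  # new_after_n_chars (soft max) is above max_characters (hard max),
                                  # so chunks are effectively capped at 1250 characters
                                  new_after_n_chars=1400,
                                  max_characters=1250)
    return [be.text for be in body_elements if len(be.text) >= MIN_CHUNK_CHARS]

chunk_xml_body_udf = udf(chunk_xml_body, ArrayType(StringType()))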
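The size arguments interact roughly as follows, per the unstructured chunking documentation: `max_characters` is a hard cap on chunk length, `new_after_n_chars` is a soft cap after which a new chunk is started, and `combine_text_under_n_chars` merges consecutive small sections. The pure-Python sketch below is a simplified approximation of `by_title` behavior, not unstructured's implementation; `Element`, `sections_by_title`, `combine_small`, `pack` and `chunk_by_title` are hypothetical names, and the real chunker is more careful about where it splits oversized text.

```python
from typing import List, Tuple

# Hypothetical stand-in for unstructured elements: (is_title, text)
Element = Tuple[bool, str]


def sections_by_title(elements: List[Element]) -> List[List[str]]:
    """Start a new section at every title element."""
    sections: List[List[str]] = []
    for is_title, text in elements:
        if is_title or not sections:
            sections.append([])
        sections[-1].append(text)
    return sections


def combine_small(sections: List[List[str]], combine_under: int, hard_max: int) -> List[List[str]]:
    """Fold the next section into the previous one while the previous one is short."""
    merged: List[List[str]] = []
    for section in sections:
        if merged:
            prev_len = len(" ".join(merged[-1]))
            if prev_len < combine_under and prev_len + 1 + len(" ".join(section)) <= hard_max:
                merged[-1].extend(section)
                continue
        merged.append(list(section))
    return merged


def pack(section: List[str], soft_max: int, hard_max: int) -> List[str]:
    """Greedily pack a section's texts into chunks bounded by the soft and hard maxima."""
    chunks: List[str] = []
    current = ""
    for text in section:
        # An element longer than hard_max is cut into hard_max-sized pieces
        for start in range(0, len(text), hard_max):
            piece = text[start:start + hard_max]
            candidate = f"{current} {piece}" if current else piece
            if current and (len(current) >= soft_max or len(candidate) > hard_max):
                chunks.append(current)  # close the chunk before it overflows
                candidate = piece
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def chunk_by_title(elements: List[Element], combine_under: int = 300, soft_max: int = 1400,
                   hard_max: int = 1250, min_len: int = 110) -> List[str]:
    chunks: List[str] = []
    for section in combine_small(sections_by_title(elements), combine_under, hard_max):
        chunks.extend(pack(section, soft_max, hard_max))
    # Same post-filter as the UDF: drop chunks shorter than min_len
    return [c for c in chunks if len(c) >= min_len]
```

With the notebook's settings (`soft_max=1400`, `hard_max=1250`) the soft cap never triggers before the hard cap, which is why only `max_characters` effectively bounds chunk size.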


The proposed schema for our target table is:
  TODO: include DDL from CREATE_TABLE_processed_articles_content.sql
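Until the DDL from `CREATE_TABLE_processed_articles_content.sql` is included, here is a sketch of what it might look like, inferred only from the columns written by the `INSERT OVERWRITE` cell. Every type is an assumption (STRING throughout, since `xpath_string` and `concat` return strings), `create_table_ddl` is a hypothetical helper, and the real file may differ.

```python
# Column order mirrors the SELECT in the INSERT OVERWRITE cell; INSERT ... SELECT
# matches columns by position, so the order matters. Types are assumed.
PROCESSED_ARTICLES_CONTENT_COLUMNS = [
    ("id", "STRING"),        # "<pmid>-<chunk position>", one row per chunk
    ("pmid", "STRING"),      # AccessionID of the source article
    ("journal", "STRING"),
    ("title", "STRING"),
    ("year", "STRING"),      # xpath_string yields a string, not an INT
    ("citation", "STRING"),  # placeholder until a citation format is chosen
    ("content", "STRING"),   # one chunk of body text
]


def create_table_ddl(table_name: str) -> str:
    """Render a CREATE TABLE statement for a Delta table (hypothetical helper)."""
    cols = ",\n  ".join(f"{name} {dtype}" for name, dtype in PROCESSED_ARTICLES_CONTENT_COLUMNS)
    return f"CREATE TABLE IF NOT EXISTS {table_name} (\n  {cols}\n) USING DELTA"
```

In the notebook this could be run with `spark.sql(create_table_ddl(pubmed.processed_articles_content.name))`.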

In [0]:
from pyspark.sql.functions import col, lit, concat
from pyspark.sql.functions import xpath_string, explode, posexplode

# Full workload: no row limit is applied here (the DISCOVERY cells below use .limit(100))
content_src = pubmed.curated_articles.df \
                    .withColumn('contents', chunk_xml_body_udf('body', 'attrs')) \
                    .select(col('AccessionID').alias('pmid'),
                            xpath_string(col('front'),lit('front/article-meta/title-group/article-title')).alias('title'),
                            xpath_string(col('front'),lit('front/journal-meta/journal-title-group/journal-title')).alias('journal'),
                            lit('NEED DESIRED CITATION FORMAT').alias('citation'),
                            xpath_string(col('front'),lit('front/article-meta/pub-date/year')).alias('year'),
                            posexplode('contents').alias('content_pos', 'content')) \
                    .withColumn('id', concat(col('pmid'), lit('-'), col('content_pos'))) \
                    .drop('content_pos') \
                    .alias('src')
content_src.createOrReplaceTempView('content_src')
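`posexplode` emits one row per array element together with its 0-based position, which is what makes `id` unique per chunk. Articles whose chunk array is empty (for example, every chunk fell under the length filter) produce no rows at all; `posexplode_outer` would keep them with nulls. A plain-Python analogue of that step, using the hypothetical helper `explode_chunks`:

```python
from typing import Dict, List


def explode_chunks(pmid: str, chunks: List[str]) -> List[Dict[str, str]]:
    """Plain-Python analogue of posexplode('contents') followed by concat(pmid, '-', pos)."""
    # enumerate, like posexplode, numbers elements from 0; an empty list yields no rows
    return [
        {"id": f"{pmid}-{pos}", "pmid": pmid, "content": chunk}
        for pos, chunk in enumerate(chunks)
    ]
```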

#display(content_src)

In [0]:
#TODO: make syntax cleaner
# TODO: convert to a MERGE; this isn't a simple merge, so it will need to be written as a separate effort

sql_insert_overwrite = f"""
INSERT OVERWRITE {pubmed.processed_articles_content.name}
SELECT 
    id,
    pmid,
    journal,
    title,
    year,
    citation,
    content
FROM content_src"""

spark.sql(sql_insert_overwrite)

In [0]:
display(pubmed.processed_articles_content.df)


# DISCOVERY

In [0]:
from unstructured.partition.xml import partition_xml
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Capture the bodies that make partition_xml raise (originally aimed at math-related exceptions)
# so they can be inspected locally. Successful rows return [None]; failing rows return the body.

def chunk_xml_body_err_cap(body: str, attrs: dict = None):
    # attrs is accepted for parity with chunk_xml_body but unused: xlink is declared directly
    try:
        xml_body = '<root xmlns:xlink="http://www.w3.org/1999/xlink">' + body + '</root>'
        # Run partitioning + chunking only to see whether it raises; the chunks are discarded
        partition_xml(text=xml_body,
                      xml_keep_tags=False,
                      encoding='utf-8',
                      include_metadata=False,
                      languages=['eng'],
                      date_from_file_object=None,
                      chunking_strategy='by_title',
                      multipage_sections=True,
                      combine_text_under_n_chars=300,
                      new_after_n_chars=1400,
                      max_characters=1250)
        return [None]
    except Exception:
        return [str(body)]

chunk_xml_body_err_cap_udf = udf(chunk_xml_body_err_cap, ArrayType(StringType()))

from pyspark.sql.functions import col, lit, concat
from pyspark.sql.functions import xpath_string, explode, posexplode

# Limited to 100 rows for discussion; the real workload does not apply a limit
content_cap = pubmed.curated_articles.df.limit(100) \
                    .withColumn('contents', chunk_xml_body_err_cap_udf('body')) \
                    .select(col('AccessionID').alias('pmid'),
                            xpath_string(col('front'),lit('front/article-meta/title-group/article-title')).alias('title'),
                            xpath_string(col('front'),lit('front/journal-meta/journal-title-group/journal-title')).alias('journal'),
                            lit('NEED DESIRED CITATION FORMAT').alias('citation'),
                            xpath_string(col('front'),lit('front/article-meta/pub-date/year')).alias('year'),
                            posexplode('contents').alias('content_pos', 'content')) \
                    .withColumn('id', concat(col('pmid'), lit('-'), col('content_pos'))) \
                    .drop('content_pos') \
                    .alias('cap')


display(content_cap)
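A variation on this error-capture idea is to record the exception type and message instead of the raw body, so failing articles can be grouped by cause. Sketched with the standard library only: `capture_failures` is a hypothetical helper and `ET.fromstring` stands in for `partition_xml`. In Spark, the same idea would mean having the UDF return the error string instead of `str(body)`.

```python
import xml.etree.ElementTree as ET
from typing import Callable, List, Tuple


def capture_failures(parse: Callable[[str], object],
                     docs: List[Tuple[str, str]]) -> List[Tuple[str, str, str]]:
    """Run parse on each (article_id, body); return (article_id, error type, message) per failure."""
    failures: List[Tuple[str, str, str]] = []
    for article_id, body in docs:
        try:
            parse(body)
        except Exception as exc:  # keep going so every failing article is recorded
            failures.append((article_id, type(exc).__name__, str(exc)))
    return failures


# Usage with ElementTree standing in for partition_xml
sample_docs = [("ok", "<root><p>fine</p></root>"), ("bad", "<root><p>unclosed</root>")]
sample_failures = capture_failures(ET.fromstring, sample_docs)
```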

In [0]:
dat = spark.sql('select * from `pubmed-pipeline`.curated.articles_xml WHERE AccessionID = "PMC11098454"').collect()[0]
attrs = dat.attrs
body = dat.body
processing_metadata = dat.processing_metadata

In [0]:
processing_metadata

In [0]:
from unstructured.partition.xml import partition_xml
from pyspark.sql.types import ArrayType, StringType
import xml.etree.ElementTree as ET

def chunk_xml_body(body: str, attrs: dict):
    root = ET.Element('root', attrib=attrs)
    root.text = body
    body_elements = partition_xml(text=str(ET.tostring(root, encoding='utf-8'), 'UTF-8'),
                                  xml_keep_tags = False,
                                  encoding='utf-8',
                                  include_metadata=False,
                                  languages=['eng',],
                                  date_from_file_object=None,
                                  chunking_strategy='by_title',
                                  multipage_sections=True,
                                  combine_text_under_n_chars=300,
                                  new_after_n_chars=1400,
                                  max_characters=1250)
    body_chunks = [be.text for be in body_elements if len(be.text) >= 110]
    return body_chunks

chunk_xml_body(body, attrs)

In [0]:
body

In [0]:
import xml.etree.ElementTree as ET
root = ET.Element('root', attrib=attrs)
root.text = body
xml_body = str(ET.tostring(root, encoding='utf-8'), 'UTF-8')


In [0]:
xml_body
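Inspecting `xml_body` here is worthwhile because `ElementTree` treats `.text` as character data: markup assigned to it is escaped on serialization, so a downstream parser sees one text node rather than nested elements. A minimal standard-library check; `demo_fragment`, `demo_attrs` and `demo_root` are hypothetical sample values, named so they do not overwrite the notebook's `root` and `attrs`.

```python
import xml.etree.ElementTree as ET
from xml.sax.saxutils import quoteattr

# Hypothetical sample values (not taken from the curated table)
demo_fragment = "<sec><title>Results</title><p>Text</p></sec>"
demo_attrs = {"article-type": "research-article"}

# Approach 1 (as in the cell above): markup assigned to .text is escaped on output
demo_root = ET.Element("root", attrib=demo_attrs)
demo_root.text = demo_fragment
escaped = ET.tostring(demo_root, encoding="unicode")

# Approach 2: build the wrapper as a string so the fragment stays markup
attr_str = "".join(f" {k}={quoteattr(v)}" for k, v in demo_attrs.items())
wrapped = f"<root{attr_str}>{demo_fragment}</root>"

# The escaped version parses to a root with no children; the wrapped one keeps <sec>
print(len(ET.fromstring(escaped)), len(ET.fromstring(wrapped)))
```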

In [0]:

# Reuse the `root` element created two cells above (it still holds the article body as text)

# Create a child element
child1 = ET.SubElement(root, 'element1')
child1.text = 'Content of element1'

# Create another child element
child2 = ET.SubElement(root, 'element2')
child2.text = 'Content of element2'

# Create the ElementTree object
tree = ET.ElementTree(root)

# Write to a file
tree.write('output.xml', xml_declaration=True, encoding='UTF-8')