In [0]:
# %pip install --upgrade pip
# %pip install torch==2.0.1
# %pip install transformers==4.29.2
# %pip install scikit-learn==0.24.2
# %pip install pyspark==3.4.0
# %pip install pandas==1.3.4
# %pip install accelerate==0.20.3
# %pip install pytorch-lightning==2.0.3
# %pip install seqeval==1.2.2
# %pip install datasets==2.14.4
# %pip install tqdm==4.65.0
# %pip install evaluate==0.4.0
# %pip install mlflow==2.10.2
# %pip install mlflow[pipelines]
# %pip install torchvision==0.15.2
# dbutils.library.restartPython()

In [0]:
# import required libraries
import os 
import re

import unicodedata

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, explode_outer, when,  isnull, col, arrays_zip, element_at, lower, trim, count, split, from_json
import pyspark.sql.types as T
import pyspark.sql.functions as F

import pandas as pd
import mlflow

print(spark)



<pyspark.sql.session.SparkSession object at 0x7fa0782928b0>


In [0]:
# create a toy dataframe
dff = spark.createDataFrame([["row123", "Fluffy is taking Dasuquin. She has stopped using Cosequin."]]).toDF("rowkey", "text")
display(dff)

rowkey,text
row123,Fluffy is taking Dasuquin. She has stopped using Cosequin.


Send spark dataframe through NER model and append predictions

In [0]:
# internally trained NER model, URI can be found in model artifacts
logged_model_ner = 'runs:/09c916aca7ed485fa918039471ae853a/ner' 

# load NER model as UDF 
# the virtualenv will reproduce the environment (spark version, python version, pkg versions) used when model was trained
model_udf = mlflow.pyfunc.spark_udf(
    spark, 
    model_uri=logged_model_ner, 
    result_type="string", # json string - will require further parsing
    env_manager="virtualenv")

Downloading artifacts:   0%|          | 0/11 [00:00<?, ?it/s]

2024/03/07 19:41:18 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false
2024/03/07 19:41:31 INFO mlflow.pyfunc: This UDF will use virtualenv to recreate the model's software environment for inference. This may take extra time during execution.


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

2024/03/07 19:41:31 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
2024/03/07 19:41:31 INFO mlflow.utils.virtualenv: Installing python 3.9.5 if it does not exist
2024/03/07 19:41:32 INFO mlflow.utils.virtualenv: Environment /local_disk0/.ephemeral_nfs/repl_tmp_data/ReplId-17ad3-b63bf-abd95-2/mlflow/envs/virtualenv_envs/mlflow-4dd7bf40bb032fc4439df332473663d872ccf4c2 already exists
2024/03/07 19:41:32 INFO mlflow.utils.environment: === Running command '['bash', '-c', 'source /local_disk0/.ephemeral_nfs/repl_tmp_data/ReplId-17ad3-b63bf-abd95-2/mlflow/envs/virtualenv_envs/mlflow-4dd7bf40bb032fc4439df332473663d872ccf4c2/bin/activate && python -c ""']'


In [0]:
# splits text into sentences using punctuation that indicates end of sentence (.?!)
import re
import pyspark.sql.types as T

def split_text(text):
    text_length = len(text)
    finditer_output = re.finditer(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s', text)
    sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)\s", text)
    delimiter_spans = []
    for f in finditer_output: 
        delimiter_spans.append(list(f.span()))
    # create an empty list to store the character index spans of the sentences
    sentence_spans = []
    # this counts the number of delimiters found by counting the number of items in the list
    upper_bound = len(delimiter_spans)
    # loop through the delimiter list, and based on criteria generate the list with the sentence spans
    if upper_bound > 0:
        for i in range(upper_bound+1):
            if i == 0:
                sentence_spans.append(list([1, delimiter_spans[0][0]-1]))
            elif i == upper_bound:
                sentence_spans.append(list([delimiter_spans[upper_bound-1][1]+1, text_length]))
            else:
                sentence_spans.append(list([delimiter_spans[(i-1)][1]+1, delimiter_spans[i][0]-1])) 
    # if there are zero sentences, then append entire note
    else:
        sentences = [text]
        sentence_spans.append(list([1, text_length]))
    # zip sentences and spans into lists eturn output to be stored in the spark table
    sentence_and_span = [list(a) for a in zip(sentences, sentence_spans)]
    # return the zipped sentences and spans
    return sentence_and_span

# wrap the function above as a spark udf
split_text_udf = udf(split_text, T.ArrayType(T.StringType()))

In [0]:
# send dataframe through sentence parser. Retain character indexes of begin / end of sentences
# "1" based indexing (not 0)
sample_note_parsed_sentences = dff\
    .select(dff.rowkey,
            dff.text,
            split_text_udf(dff.text).alias("sentences_and_spans"))\
    .withColumn("exploded_sentences_spans", explode_outer(F.col("sentences_and_spans")))\
    .withColumn("list_sentence_spans", split(F.col("exploded_sentences_spans"), ",\s+(?=(\[\d+, \d+\]]$))" , -1))\
    .withColumn("sentence", F.col("list_sentence_spans")[0])\
    .withColumn("span", F.col("list_sentence_spans")[1])\
    .withColumn("sentence", F.expr("substring(sentence, 2, length(sentence)-1)"))\
    .withColumn("span", F.expr("substring(span, 1, length(span)-1)"))\
    .withColumn("sentence_span_list", split(F.col("span"), ", ", -1) )\
    .withColumn("sentence_span_start", F.col("sentence_span_list")[0])\
    .withColumn("sentence_span_end", F.col("sentence_span_list")[1])\
    .withColumn("sentence_span_start", F.expr("substring(sentence_span_start, 2, length(sentence_span_start)-1)").cast(T.IntegerType()))\
    .withColumn("sentence_span_end", F.expr("substring(sentence_span_end, 1, length(sentence_span_end)-1)").cast(T.IntegerType()))

display(sample_note_parsed_sentences)

rowkey,text,sentences_and_spans,exploded_sentences_spans,list_sentence_spans,sentence,span,sentence_span_list,sentence_span_start,sentence_span_end
row123,Fluffy is taking Dasuquin. She has stopped using Cosequin.,"List([Fluffy is taking Dasuquin., [1, 25]], [She has stopped using Cosequin., [28, 58]])","[Fluffy is taking Dasuquin., [1, 25]]","List([Fluffy is taking Dasuquin., [1, 25]])",Fluffy is taking Dasuquin.,"[1, 25]","List([1, 25])",1,25
row123,Fluffy is taking Dasuquin. She has stopped using Cosequin.,"List([Fluffy is taking Dasuquin., [1, 25]], [She has stopped using Cosequin., [28, 58]])","[She has stopped using Cosequin., [28, 58]]","List([She has stopped using Cosequin., [28, 58]])",She has stopped using Cosequin.,"[28, 58]","List([28, 58])",28,58


In [0]:
# send sentences through model
# if you let empty sentences (or ones that just contain \n) into the model it will run forever then bomb
with_ner_predictions = sample_note_parsed_sentences\
    .select(F.col("rowkey"), F.col("sentence"), F.col("sentence_span_start"), F.col("sentence_span_end"))\
    .withColumnRenamed("sentence", "text")\
    .filter((F.col("sentence_span_end") - F.col("sentence_span_start"))>0)\
    .withColumn("prediction", model_udf(F.col("text")))\
    .withColumn("prediction_array", 
                from_json(col("prediction"), 
                          'ARRAY<STRUCT<entity_group: STRING, score: FLOAT, word: STRING, start: INT, end: INT>>')
                )\
    .withColumn("individual_predictions", explode("prediction_array"))\
    .drop("prediction", "prediction_array")\
    .withColumn("entity_group", F.col("individual_predictions.entity_group"))\
    .withColumn("score", F.col("individual_predictions.score"))\
    .withColumn("word", F.col("individual_predictions.word"))\
    .withColumn("start", F.col("individual_predictions.start"))\
    .withColumn("end", F.col("individual_predictions.end"))\
    .drop("individual_predictions")

display(with_ner_predictions)

rowkey,text,sentence_span_start,sentence_span_end,entity_group,score,word,start,end
row123,Fluffy is taking Dasuquin.,1,25,Supplement,0.99635386,dasuquin,17,25
row123,She has stopped using Cosequin.,28,58,Supplement,0.99432075,cosequin,22,30


In [0]:
# with_ner_predictions.write.mode("overwrite").parquet("s3a://FAKE/FAKE")