The expected input is a PySpark DataFrame with a column containing each image as a base64-encoded string, the format used as input to vision-language models (VLMs).

DICOM images are first scaled down and converted into base64 strings via `dbx.pixels.dicom.dicom_to_base64jpg`.

In [0]:
# Only needed when running over databricks-connect (e.g. from a local IDE).
# Skip this cell when running as a Databricks notebook.
from databricks.connect import DatabricksSession

session_builder = DatabricksSession.builder
spark = session_builder.getOrCreate()

In [0]:
%pip install -e /Workspace/Users/yen.low@databricks.com/pixel/pixels
# Install the local `pixels` package in editable (-e) mode, then restart the
# Python process so the freshly installed package can be imported by later cells.
# NOTE(review): the workspace path is user-specific; adjust it for your workspace.
dbutils.library.restartPython()

In [0]:
%reload_ext autoreload
%autoreload 2
# Enable IPython's autoreload extension in mode 2 (reload modules before running
# code), presumably so edits to the editable `pixels` install are picked up
# without restarting Python.

## Load input dataframe
The required input column must contain ONE of the following:
1. a .dcm file path (e.g. `/Volumes/hls_radiology/2.1.656.0.2.8048482.9.537.165816238/1-1.dcm`)
2. an image file path (e.g. `/Volumes/hls_radiology/2.1.656.0.2.8048482.9.537.165816238/1-1.jpg`)
3. an image encoded as a base64 string, as required by the VLM (e.g. `/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwc...`)

#### Read .dcm paths from Volume

In [0]:
# Volume where .dcm files are stored
phi_folder = "/Volumes/hls_radiology/tcia/downloads/tciaDownload/"

# Recursively list the .dcm files under `phi_folder`.
# `pathGlobFilter` restricts the listing to *.dcm so that any other files in the
# download folder (manifests, metadata, etc.) are not treated as DICOMs.
# NOTE(review): the glob is case-sensitive; widen it if files use a `.DCM` extension.
# The binaryFile source adds a "content" column holding the raw file bytes;
# drop it to avoid storing large binary payloads in the table.
df = (
    spark.read.format("binaryFile")
    .option("recursiveFileLookup", "true")
    .option("pathGlobFilter", "*.dcm")
    .load(phi_folder)
    .drop("content")
)
display(df)

# Persist the DICOM file listing as a Delta table, replacing any previous data and schema.
# NOTE(review): later cells read `jpg_base64str` and `has_phi` from this table, but this
# listing does not produce those columns — confirm they are added before the table is reused.
(
    df.write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("hls_radiology.tcia.midi_b_val_subset")
)

#### If the table of DICOM paths already exists

In [0]:
# Reuse the previously saved table of DICOM paths instead of re-listing the Volume.
source_table = "hls_radiology.tcia.midi_b_val_subset"
df = spark.table(source_table)
display(df)

## Run on a single image

In [0]:
from dbx.pixels.dicom.dicom_vlm_phi_detector import VLMPhiExtractor

# Model-serving endpoint of the VLM that performs PHI extraction.
vlm_endpoint = "databricks-claude-3-7-sonnet"
extractor = VLMPhiExtractor(endpoint=vlm_endpoint)

In [0]:
# Pick the first DICOM file path in `df` and run PHI extraction on it.
# NOTE(review): binaryFile paths are fully qualified URIs (may carry a scheme
# prefix such as `dbfs:`); confirm `extract` accepts them as-is.
first_row = df.select("path").take(1)[-1]
path = first_row["path"]
extractor.extract(path)

In [0]:
# Pick the first base64-encoded JPEG string in `df` and run PHI extraction on it.
# NOTE(review): requires `df` to carry a `jpg_base64str` column (e.g. produced by
# `dbx.pixels.dicom.dicom_to_base64jpg`); a plain binaryFile listing does not have it.
base64_rows = df.select("jpg_base64str").take(1)
jpg_base64str = base64_rows[-1]["jpg_base64str"]
extractor.extract(jpg_base64str, input_type="base64")

## Redact single image

In [0]:
from dbx.pixels.dicom.dicom_easyocr_redactor import ocr_dcm

# Run the OCR-based detector on the single DICOM file selected earlier (`path`).
# display=True presumably renders the detected text regions inline — confirm in ocr_dcm.
# NOTE(review): `bb_redact` presumably holds the bounding boxes to redact; verify its format.
bb_redact = ocr_dcm(path, display=True)

## Bulk run: `VLMTransformer().transform(df)`
where the input column holds an image encoded as a base64 string.<br>
This returns a DataFrame with a new column `response` containing the output of the VLM specified by `endpoint`.

In [0]:
from dbx.pixels.dicom.dicom_vlm_phi_detector import VLMTransformer

# Run the VLM over the first 10 rows of `df`; results land in the `response` column.
# `input_type="base64"` declares that `inputCol` holds base64-encoded images, so
# `inputCol` must point at the base64 image column (`jpg_base64str`), not at the
# raw DICOM `path` column — the two settings previously contradicted each other.
# NOTE(review): to run directly on DICOM files instead, use inputCol="path" together
# with the DICOM `input_type` value accepted by VLMTransformer.
vlm_transformer = VLMTransformer(
    endpoint="databricks-claude-3-7-sonnet",
    temperature=0.0,
    num_output_tokens=200,
    inputCol="jpg_base64str",
    outputCol="response",
    input_type="base64",
    max_width=768,
)
out_df = vlm_transformer.transform(df.limit(10))
display(out_df)

## Extract PHI entities and evaluate against ground truth (`has_phi`)

In [0]:
from pyspark.sql.functions import split, col, when, size

# A row counts as "PHI detected" when the VLM response lists more than one entity;
# any other count (including a null count) falls through to False.
# NOTE(review): with `> 1`, a response containing exactly one entity is NOT flagged;
# confirm this matches the response format (use `>= 1` if each entry is one PHI item).
entity_count = size(col("entities"))
phi_flag = when(entity_count > 1, True).otherwise(False)

extracted_df = (
    out_df.withColumn("entities", col("response.content"))
    .withColumn("phi_detected", phi_flag)
    # Drop the base64 image columns before display/evaluation.
    .drop("jpg_base64str", "jpg_base64str_masked")
)
display(extracted_df)

In [0]:
# Bring only the ground-truth label and the prediction to the driver as a pandas
# DataFrame, for the scikit-learn metric computation that follows.
label_cols = ["has_phi", "phi_detected"]
extracted_pdf = extracted_df.select(*label_cols).toPandas()
display(extracted_pdf)

In [0]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Ground truth (`has_phi`) vs. VLM prediction (`phi_detected`).
y_true = extracted_pdf.has_phi
y_pred = extracted_pdf.phi_detected

precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
accuracy = accuracy_score(y_true, y_pred)

precision, recall, f1, accuracy

## Bulk redaction: wrap the redaction functions `ocr_dcm` and `dicom_to_array` in pandas UDFs

In [0]:
# Imported here so this cell runs on its own: `pd` and `pandas_udf` were never
# imported anywhere in the notebook, and `ocr_dcm` previously came from another cell.
import pandas as pd
from pyspark.sql.functions import col, pandas_udf

from dbx.pixels.dicom.dicom_easyocr_redactor import ocr_dcm
from dbx.pixels.dicom.dicom_utils import dicom_to_array


def _to_nested_list(value):
    """Convert an array-like result (e.g. a NumPy array) into nested Python lists.

    Arrow cannot serialize multi-dimensional NumPy arrays into an
    ``array<array<int>>`` column, so values exposing ``tolist`` are converted;
    plain lists and ``None`` pass through unchanged.
    """
    return value.tolist() if hasattr(value, "tolist") else value


@pandas_udf("array<array<int>>")
def dicom_to_array_udf(path: pd.Series) -> pd.Series:
    """Apply ``dicom_to_array`` to each DICOM file path in the batch."""
    # NOTE(review): the declared schema assumes a 2-D (single-channel) image;
    # multi-channel images would need a deeper array type.
    return path.apply(lambda p: _to_nested_list(dicom_to_array(p)))


@pandas_udf("array<array<int>>")
def ocr_dcm_udf(path: pd.Series) -> pd.Series:
    """Apply ``ocr_dcm`` to each DICOM file path in the batch (presumably text bounding boxes)."""
    return path.apply(lambda p: _to_nested_list(ocr_dcm(p)))


@pandas_udf("array<array<int>>")
def fill_bb_udf(path: pd.Series) -> pd.Series:
    """Apply ``fill_bounding_boxes`` to each value in the batch."""
    # NOTE(review): `fill_bounding_boxes` is neither defined nor imported in this
    # notebook, and this UDF is not used below; import it before calling this UDF.
    return path.apply(lambda p: _to_nested_list(fill_bounding_boxes(p)))


# Redact only the rows where the VLM flagged PHI. Both UDFs read the DICOM file
# path: the bounding boxes ("bb") and the image array ("image_array") are derived
# from the same file. Previously "image_array" was computed from the not-yet-existing
# "bb" column (an AnalysisException), and then recomputed from "bb" instead of "path".
extracted_df = (
    extracted_df.where(col("phi_detected"))
    .withColumn("bb", ocr_dcm_udf(col("path")))
    .withColumn("image_array", dicom_to_array_udf(col("path")))
)
display(extracted_df)