# Detect and Redact PHI in DICOM Images
- Recommended cluster: Databricks Runtime 16.4 LTS ML

## Set up the `pixels` package (if you haven't already)
1. [Create a Git folder](https://docs.databricks.com/aws/en/repos/git-operations-with-repos) that clones the [pixels](https://github.com/databricks-industry-solutions/pixels) repository
2. Run the [`config/setup.py`]($./config/setup) script from the repo folder

In [0]:
%run ./config/setup

In [0]:
import pandas as pd

In [0]:
output_dir = "/Volumes/hls_radiology/tcia/redacted"

## Load input dataframe
`VLMPhiExtractor` requires the input to be ONE of the following:
1. a .dcm file path (e.g. `/Volumes/<catalog>/<schema>/2.1.656.0.2.8048482.9.537.165816238/1-1.dcm`)
2. an image file path (e.g. `/Volumes/<catalog>/<schema>/2.1.656.0.2.8048482.9.537.165816238/1-1.jpg`)
3. an image encoded as a base64 string, the format the VLM expects (e.g. `/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwc...`)
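For the third option, here is a minimal sketch of producing such a string with Python's standard `base64` module. It assumes the VLM wants the raw image file bytes encoded without a `data:` URI prefix; the sample above starts with `/9j/`, which is exactly what the JPEG file signature `FF D8 FF` encodes to. The helper names are hypothetical and not part of `pixels`.

```python
import base64
from pathlib import Path


def bytes_to_base64(data: bytes) -> str:
    """Encode raw bytes as an ASCII base64 string (no data-URI prefix)."""
    return base64.b64encode(data).decode("ascii")


def image_file_to_base64(path: str) -> str:
    """Hypothetical helper: read an image file (e.g. a .jpg) and base64-encode its bytes."""
    return bytes_to_base64(Path(path).read_bytes())


# The first bytes of a JFIF JPEG encode to the same prefix as the sample string above
jpeg_header = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01"
print(bytes_to_base64(jpeg_header))  # /9j/4AAQSkZJRgAB
```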

#### A table of DICOM paths already exists

In [0]:
df = (
  spark.table("hls_radiology.tcia.midi_b_val_subset")
  .drop('jpg_base64str', 'jpg_base64str_masked')
)
display(df)

## Bulk run: `VLMTransformer().transform(df)`
Returns a DataFrame with a new column, `response` (set via `outputCol`), containing the output of the VLM specified by `endpoint`.

`VLMTransformer` wraps `VLMPhiExtractor` for Spark DataFrame transformations.<br>
It accepts the same 3 input types (see the earlier [cell](https://e2-demo-field-eng.cloud.databricks.com/editor/notebooks/372649807139118?o=1444828305810485#command/3070146489819417)), selected with `input_type`:
1. `input_type="dicom"` for a .dcm file path
2. `input_type="image"` for an image file path
3. `input_type="base64"` for an image encoded as a base64 string
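A minimal sketch of how the three input forms map onto `input_type`, using a hypothetical `infer_input_type` helper that is not part of `pixels`. The extension list is an assumption; in practice you set `input_type` explicitly when constructing the transformer.

```python
import base64

# Assumed extension list; adjust to the image formats in your volume
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def infer_input_type(value: str) -> str:
    """Hypothetical helper: guess the VLMTransformer input_type for one input value."""
    lowered = value.lower()
    if lowered.endswith(".dcm"):
        return "dicom"
    if lowered.endswith(IMAGE_EXTENSIONS):
        return "image"
    try:
        # validate=True rejects characters outside the base64 alphabet
        base64.b64decode(value, validate=True)
    except ValueError as exc:  # binascii.Error is a subclass of ValueError
        raise ValueError(f"Unrecognized input: {value[:40]!r}") from exc
    return "base64"


print(infer_input_type("/Volumes/<catalog>/<schema>/2.1.656.0.2.8048482.9.537.165816238/1-1.dcm"))  # dicom
print(infer_input_type("/9j/4AAQSkZJRgAB"))  # base64
```

The extension checks come first because a bare path such as `/Volumes/abc` happens to be valid base64 text, so base64 validation alone cannot tell the forms apart.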

### Run VLM transformer for PHI detection
- 1 partition: 1.89 s/image (132 s for 70 images)
- 8 partitions: 0.80 s/image (56 s for 70 images)
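The timings above work out as follows in a small arithmetic sketch (helper names are illustrative). The efficiency figure assumes all 8 partitions actually ran concurrently, which depends on the cluster's available cores.

```python
def per_image_seconds(total_seconds: float, n_images: int) -> float:
    """Average wall-clock seconds per image."""
    return total_seconds / n_images


def speedup(baseline_seconds: float, parallel_seconds: float) -> float:
    """How many times faster the parallel run finished than the baseline."""
    return baseline_seconds / parallel_seconds


one_partition = per_image_seconds(132, 70)
eight_partitions = per_image_seconds(56, 70)
gain = speedup(132, 56)
print(f"{one_partition:.2f} vs {eight_partitions:.2f} s/image")  # 1.89 vs 0.80 s/image
print(f"{gain:.2f}x speedup, {gain / 8:.0%} parallel efficiency")  # 2.36x speedup, 29% parallel efficiency
```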


In [0]:
from dbx.pixels.dicom.dicom_vlm_phi_detector import VLMTransformer

# For .dcm path input (inputCol="path", input_type="dicom")
vlm_transformer = VLMTransformer(endpoint="databricks-claude-3-7-sonnet", input_type="dicom")

df = df.repartition(8) # for parallelism
out_df = vlm_transformer.transform(df)
display(out_df)

## Extract PHI entities and evaluate against ground truth (`has_phi`)

In [0]:
from pyspark.sql.functions import col, when, size

extracted_df = (out_df
    # Copy the entities extracted by the VLM into their own column
    .withColumn("entities", col("response.content"))
    # Flag a row as PHI-positive when more than one entity was extracted; a null entity list yields False
    .withColumn("phi_detected", when(size(col("entities")) > 1, True).otherwise(False))
)
display(extracted_df)
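Note that the `phi_detected` rule counts a row as PHI-positive only when more than one entity is returned, so a single extracted entity evaluates to `False`; the notebook does not say why the threshold is above one. A plain-Python mirror of that rule, assuming `response.content` is an array of entity strings (the helper name is hypothetical):

```python
from typing import Optional, Sequence


def phi_detected(entities: Optional[Sequence[str]]) -> bool:
    """Hypothetical mirror of when(size(col("entities")) > 1, True).otherwise(False)."""
    # A null entity list never satisfies the condition, so it falls through to False
    if entities is None:
        return False
    return len(entities) > 1


print([phi_detected(e) for e in (None, [], ["name"], ["name", "date"])])  # [False, False, False, True]
```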

In [0]:
# Extract selected columns and convert to pandas for subsequent sklearn metrics computation
extracted_pdf = extracted_df.select("has_phi", "phi_detected").toPandas()
display(extracted_pdf)

In [0]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

precision = precision_score(extracted_pdf.has_phi, extracted_pdf.phi_detected)
recall = recall_score(extracted_pdf.has_phi, extracted_pdf.phi_detected)
f1 = f1_score(extracted_pdf.has_phi, extracted_pdf.phi_detected)
accuracy = accuracy_score(extracted_pdf.has_phi, extracted_pdf.phi_detected)

precision, recall, f1, accuracy
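For reference, the four sklearn metrics reduce to confusion-matrix counts. A stdlib sketch, assuming boolean labels with `True` as the positive (PHI) class; it returns 0.0 on a zero denominator, which matches sklearn's default result apart from the warning sklearn emits. `binary_metrics` is a hypothetical name.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[bool], y_pred: Sequence[bool]) -> Dict[str, float]:
    """Precision, recall, F1 and accuracy with True as the positive (PHI) class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)          # PHI present and detected
    fp = sum(1 for t, p in pairs if not t and p)      # flagged but no PHI
    fn = sum(1 for t, p in pairs if t and not p)      # PHI missed
    tn = sum(1 for t, p in pairs if not t and not p)  # correctly left unflagged

    def safe_div(num: float, den: float) -> float:
        # Return 0.0 instead of raising on a zero denominator
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    return {
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "accuracy": safe_div(tp + tn, len(pairs)),
    }


print(binary_metrics([True, True, False, False, True], [True, False, False, True, True]))
```

With real data, the same call would be `binary_metrics(extracted_pdf.has_phi, extracted_pdf.phi_detected)`.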

## Bulk redaction with Pandas UDF

In [0]:
from pyspark.sql.functions import pandas_udf
from dbx.pixels.dicom.dicom_utils import array_to_image
from dbx.pixels.dicom.dicom_easyocr_redactor import ocr_dcm

@pandas_udf("map<string, string>")
def ocr2redactarr_udf(paths: pd.Series) -> pd.Series:
    def ocr2redactarr(path: str) -> dict:
        # Find text bounding boxes with OCR and apply a fill mask over them
        redacted_array = ocr_dcm(path, display=False, gpu=True)
        # Save the redacted image as a .jpg in output_dir, named <parent folder>_<file name>
        suffix = "_".join(path.split("/")[-2:])
        output_path = f'{output_dir}/{suffix.replace(".dcm", ".jpg")}'
        array_to_image(redacted_array,
                       output_path=output_path,
                       return_type=None)
        # Return the source and redacted paths as a map<string, string>
        return {"source": path, "redacted": output_path}
    return paths.apply(ocr2redactarr)
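The UDF's first step finds text bounding boxes and fills them. A minimal NumPy sketch of that idea, assuming each detected text box is given as four `(x, y)` corner points, the layout EasyOCR's `readtext` returns; `ocr_dcm`'s actual internals may differ, and `fill_text_boxes` is a hypothetical name.

```python
from typing import Sequence, Tuple

import numpy as np

Corner = Tuple[float, float]


def fill_text_boxes(image: np.ndarray, boxes: Sequence[Sequence[Corner]], fill_value: int = 0) -> np.ndarray:
    """Hypothetical sketch: return a copy of `image` with each box's bounding rectangle filled."""
    redacted = image.copy()
    height, width = redacted.shape[:2]
    for corners in boxes:
        xs = [int(round(x)) for x, _ in corners]
        ys = [int(round(y)) for _, y in corners]
        # Clip the axis-aligned rectangle to the image; end indices are exclusive
        x0, x1 = max(min(xs), 0), min(max(xs) + 1, width)
        y0, y1 = max(min(ys), 0), min(max(ys) + 1, height)
        if x0 < x1 and y0 < y1:
            redacted[y0:y1, x0:x1] = fill_value  # also covers every channel of an RGB array
    return redacted


image = np.full((10, 10), 255, dtype=np.uint8)
box = [(2, 3), (5, 3), (5, 6), (2, 6)]  # top-left, top-right, bottom-right, bottom-left
print(int((fill_text_boxes(image, [box]) == 0).sum()))  # 16
```

Filling the axis-aligned bounding rectangle rather than the exact quadrilateral errs on the side of over-redaction for slightly rotated text.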

### Run OCR Redactor
- 1 partition: 65.7 s/image (591 s for 9 images)
- 8 partitions: 43.4 s/image (391 s for 9 images)

In [0]:
# Redact only the images flagged as containing PHI
redact_df = extracted_df.where(col("phi_detected"))
redact_df = redact_df.repartition(8)

# Bulk redaction with path as input
display(redact_df.select(ocr2redactarr_udf(col("path"))))

View the redacted `.jpg` files in [`output_dir`](https://e2-demo-field-eng.cloud.databricks.com/explore/data/volumes/hls_radiology/tcia/redacted).
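Output files are named `<parent folder>_<file name>` with `.dcm` replaced by `.jpg`, following the `suffix` logic inside the UDF. Pulled out as a standalone function (hypothetical name) so the mapping can be checked without Spark:

```python
def redacted_output_path(dcm_path: str, output_dir: str) -> str:
    """Hypothetical standalone copy of the naming logic inside ocr2redactarr."""
    # Keep the parent folder (e.g. the UID-named folder) and the file name, joined by "_"
    suffix = "_".join(dcm_path.split("/")[-2:])
    return f'{output_dir}/{suffix.replace(".dcm", ".jpg")}'


print(redacted_output_path(
    "/Volumes/<catalog>/<schema>/2.1.656.0.2.8048482.9.537.165816238/1-1.dcm",
    "/Volumes/hls_radiology/tcia/redacted",
))
# /Volumes/hls_radiology/tcia/redacted/2.1.656.0.2.8048482.9.537.165816238_1-1.jpg
```

Because `str.replace` swaps every `.dcm` occurrence, a folder name containing `.dcm` would also be altered; `PurePosixPath(suffix).with_suffix(".jpg")` is one alternative if that matters.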