# Social Determinants of Health for FE5

This notebook uses John Snow Labs models to extract social determinants of health (SDOH) concepts from clinical text.

To avoid exceeding memory limits, the notebook processes the input `batch_size` records at a time.

## Prerequisites

The following was run on a Windows VM, so some configuration and setup steps will differ on other operating systems.

Windows prerequisites:
* Download the `bin` directory for the relevant Hadoop version from https://github.com/cdarlint/winutils to, e.g., `C:\hadoop\bin`
* Environment variables (these are also set in the code below):
    * `HADOOP_HOME`= `C:\hadoop`
    * `PATH` += `%HADOOP_HOME%\bin`

Prerequisites:
* Java (JDK 1.8/Java 8)
    * Set `JAVA_HOME` to install path, and add to `PATH`
* A Python version compatible with Spark NLP
    * Python packages: see `my.johnsnowlabs.com/docs` > `Install Locally on Python` and follow the steps
        * Some of these installs take quite a while because they download the relevant JAR files.
* Download license keys from `https://my.johnsnowlabs.com/subscriptions`
    * Save json file to, e.g., `C:\spark_jsl_54.json`
    * NB: Different versions of Spark NLP (e.g., 5.3 vs. 5.4) require different license keys.
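
Before starting Spark, it can help to confirm that these variables point at real locations. The sketch below uses a hypothetical helper, `check_environment` (not part of Spark NLP or `johnsnowlabs`); it assumes the winutils layout `%HADOOP_HOME%\bin\winutils.exe` and only checks that paths exist, not that the installed Java version is 8.

```python
import os
from pathlib import Path


def check_environment(env=None):
    """Return a list of problems with the Java/Hadoop setup (hypothetical helper).

    `env` defaults to `os.environ`; passing a dict makes the check testable.
    Assumes the winutils layout HADOOP_HOME\\bin\\winutils.exe.
    """
    env = os.environ if env is None else env
    problems = []
    java_home = env.get('JAVA_HOME')
    if not java_home:
        problems.append('JAVA_HOME is not set')
    elif not Path(java_home).is_dir():
        problems.append(f'JAVA_HOME does not exist: {java_home}')
    hadoop_home = env.get('HADOOP_HOME')
    if not hadoop_home:
        problems.append('HADOOP_HOME is not set')
    elif not (Path(hadoop_home) / 'bin' / 'winutils.exe').is_file():
        problems.append(f'winutils.exe not found under {hadoop_home}\\bin')
    return problems


# Print any problems found in the current process environment.
for problem in check_environment():
    print('WARNING:', problem)
```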

## Input Data

The input data should be either a CSV file or a JSONL file.

Either format should have the following columns (other names can be used by setting `note_id_col` and `note_text_col`, but these are preferred):
* `note_id`: an identifier for the note; it must be unique across the dataset, because row order cannot be relied on under concurrent processing
* `note_text`: text to be processed; null values will be dropped
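
For reference, a minimal sketch of both accepted layouts, written with the standard library. The two notes and the file names `example.jsonl` and `example.csv` are made up for illustration. `spark.read.json` expects JSON Lines (one object per line) by default, and Python's `csv` module quotes any field that contains commas, quotes, or line breaks.

```python
import csv
import json

# Invented example notes; replace with real data.
notes = [
    {'note_id': '1', 'note_text': 'Patient lives alone and has no car.'},
    {'note_id': '2', 'note_text': 'Denies tobacco use. Works as a teacher.'},
]

# JSONL: one JSON object per line.
with open('example.jsonl', 'w', encoding='utf8') as fh:
    for note in notes:
        fh.write(json.dumps(note) + '\n')

# CSV: a header row, then one row per note.
with open('example.csv', 'w', newline='', encoding='utf8') as fh:
    writer = csv.DictWriter(fh, fieldnames=['note_id', 'note_text'])
    writer.writeheader()
    writer.writerows(notes)
```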

## Output Data

This notebook will create a CSV file with the following variables:
* `note_id`: note_id from input data 
* `confidence`: float in the range (0.0, 1.0]; a cutoff for downstream use still needs to be decided
* `entity`: SDOH entity label from the `ner_sdoh` model, including the following:
    * 'Access_To_Care'
    * 'Chidhood_Event'
    * 'Community_Safety'
    * 'Disability'
    * 'Eating_Disorder'
    * 'Education'
    * 'Environmental_Condition'
    * 'Exercise'
    * 'Family_Member'
    * 'Financial_Status'
    * 'Food_Insecurity'
    * 'Geographic_Entity'
    * 'Healthcare_Institution'
    * 'Housing'
    * 'Income'
    * 'Insurance_Status'
    * 'Legal_Issues'
    * 'Mental_Health'
    * 'Other_SDoH_Keywords'
    * 'Population_Group'
    * 'Quality_Of_Life'
    * 'Social_Exclusion'
    * 'Social_Support'
    * 'Spiritual_Beliefs'
    * 'Substance_Duration'
    * 'Substance_Frequency'
    * 'Substance_Quantity'
    * 'Substance_Use'
    * 'Transportation'
    * 'Violence_Or_Abuse'
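* `assertion`: assertion status of the SDOH concept
    * `Present`: currently present
    * `Past`: described as in the past
    * `Someone_Else`: related to a different person
    * `Possible`: described as possible
    * `Absent`: negated
    * `Family_History`: related to family history
    * Other labels may occur (e.g., `Planned` is also white-listed in the pipeline); see the `assertion_sdoh_wip` model card for the full list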
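
Since the confidence cutoff is left open, here is a sketch of applying one after the run. `filter_by_confidence` is a hypothetical helper and `0.8` is an arbitrary placeholder, not a recommended value; the rows are invented to show the shape of the output. Values read back from the CSV are strings, so `confidence` is converted with `float`.

```python
def filter_by_confidence(rows, threshold=0.8):
    """Yield output rows whose confidence is at least `threshold` (placeholder value)."""
    for row in rows:
        # csv.DictReader returns strings, so convert before comparing
        if float(row['confidence']) >= threshold:
            yield row


# Invented rows with the same columns as the output CSV.
rows = [
    {'note_id': '1', 'confidence': '0.9981', 'entity': 'Housing', 'assertion': 'Present'},
    {'note_id': '1', 'confidence': '0.42', 'entity': 'Income', 'assertion': 'Absent'},
]
kept = list(filter_by_confidence(rows))
```

On the real output, a `csv.DictReader` over the file at `out_dataset_path` can be passed in place of `rows`.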


In [None]:
import csv
import json
import os
from pathlib import Path
import sys
import time

from johnsnowlabs import nlp, medical
from loguru import logger
from pyspark.sql.functions import row_number
from pyspark.sql.window import Window
import sparknlp
import sparknlp_jsl

In [None]:
# Load license keys, prepare configurations
with open(r'C:\spark_jsl_54.json') as f:
    license_keys = json.load(f)

locals().update(license_keys)
os.environ.update(license_keys)
hadoop_home = r'C:\hadoop'
os.environ['HADOOP_HOME'] = hadoop_home
os.environ['PATH'] += fr';{hadoop_home}\bin;{hadoop_home}'
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

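# NB: in local mode the executor runs inside the driver JVM, so
# `spark.driver.memory` is what actually limits memory there
params = {'spark.driver.memory': '4G',
          'spark.executor.memory': '19G',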
          'spark.kryoserializer.buffer.max': '2000M',
          'spark.driver.maxResultSize': '2000M',
          }

spark = sparknlp_jsl.start(license_keys['SECRET'], gpu=False, params=params)

logger.info(f'Spark NLP Version: {sparknlp.version()}')
logger.info(f'Spark NLP_JSL Version: {sparknlp_jsl.version()}')
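
Because license keys are tied to a specific release (see the NB under Prerequisites), a mismatch between the license file and the installed packages is worth catching early. The sketch below assumes the license JSON contains `SECRET`, `JSL_VERSION` and `PUBLIC_VERSION` keys (an assumption about the file layout; adjust to what your file actually holds) and compares only the major.minor part of each version. `check_license_versions` is a hypothetical helper.

```python
def check_license_versions(license_keys, jsl_installed, public_installed):
    """Return mismatches between a license dict and installed versions (hypothetical helper).

    Assumes the keys SECRET, JSL_VERSION and PUBLIC_VERSION; only the
    major.minor part (e.g. '5.4') of each version is compared.
    """
    def major_minor(version):
        return '.'.join(str(version).split('.')[:2])

    problems = []
    if 'SECRET' not in license_keys:
        problems.append('license file has no SECRET key')
    for key, installed in (('JSL_VERSION', jsl_installed), ('PUBLIC_VERSION', public_installed)):
        expected = license_keys.get(key)
        # Skip keys that are absent; compare major.minor otherwise.
        if expected is not None and major_minor(expected) != major_minor(installed):
            problems.append(f'{key} is {expected} but {installed} is installed')
    return problems
```

It could be called as `check_license_versions(license_keys, sparknlp_jsl.version(), sparknlp.version())` after the keys are loaded and before `sparknlp_jsl.start`.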

In [None]:
# configurations
note_id_col = 'note_id'
note_text_col = 'note_text'
in_dataset_path = Path(r'path/to/example.jsonl')
out_dataset_path = Path(r'path/to/out.csv')
test_top_n = None  # set to an integer (e.g., 100) to run on only N notes for testing purposes
batch_size = 1000  # to determine optimal batch size, consider running `sdoh.ipynb` with varying `test_top_n`
# logger.add('logfile')  # uncomment this to log to file

In [None]:
# the pipeline
documentAssembler = (
    nlp.DocumentAssembler()
    .setInputCol('note_text')
    .setIdCol('note_id')
    .setOutputCol('document')
)

sentenceDetector = (
    nlp.SentenceDetectorDLModel.pretrained('sentence_detector_dl_healthcare', 'en', 'clinical/models')
    .setInputCols(['document'])
    .setOutputCol('sentence')
    .setCustomBounds([r'\|'])  # raw string: same value as '\|' without an invalid-escape warning
)

tokenizer = nlp.Tokenizer() \
    .setInputCols(['sentence']) \
    .setOutputCol('token')

clinical_embeddings = nlp.WordEmbeddingsModel.pretrained('embeddings_clinical', 'en', 'clinical/models') \
    .setInputCols(['sentence', 'token']) \
    .setOutputCol('embeddings')

ner_model = medical.NerModel.pretrained('ner_sdoh', 'en', 'clinical/models') \
    .setInputCols(['sentence', 'token', 'embeddings']) \
    .setOutputCol('ner')

ner_conv = medical.NerConverterInternal() \
    .setInputCols(['sentence', 'ner', 'token']) \
    .setOutputCol('chunk_main')

assertion = medical.AssertionDLModel.pretrained('assertion_sdoh_wip', 'en', 'clinical/models') \
    .setInputCols(['sentence', 'token', 'embeddings', 'chunk_main']) \
    .setOutputCol('assertion')

assertion_filterer_hypo = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered') \
    .setCriteria('assertion') \
    .setWhiteList(['Present', 'Possible', 'Absent', 'Family_History', 'Someone_Else', 'Past'])

assertion_filterer_present = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered_present') \
    .setCriteria('assertion') \
    .setWhiteList(['Present', 'Planned'])

assertion_filterer_possible = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered_possible') \
    .setCriteria('assertion') \
    .setWhiteList(['Possible'])

assertion_filterer_absent = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered_absent') \
    .setCriteria('assertion') \
    .setWhiteList(['Absent'])

assertion_filterer_hist = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered_hist') \
    .setCriteria('assertion') \
    .setWhiteList(['Family_History', 'Someone_Else'])

assertion_filterer_past = medical.AssertionFilterer() \
    .setInputCols(['sentence', 'chunk_main', 'assertion']) \
    .setOutputCol('filtered_past') \
    .setCriteria('assertion') \
    .setWhiteList(['Past'])

jsl_single_pipeline = nlp.Pipeline(
    stages=[
        documentAssembler,
        sentenceDetector,
        tokenizer,
        clinical_embeddings,
        ner_model,
        ner_conv,
        assertion,
        assertion_filterer_hypo,
        assertion_filterer_present,
        assertion_filterer_possible,
        assertion_filterer_absent,
        assertion_filterer_hist,
        assertion_filterer_past
    ]
)
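
`annotate` further down reads only the `filtered` column, so the other `filtered_*` columns are computed but never written out. Because the output CSV keeps the `assertion` label, the same grouping can be recovered afterwards in plain Python. A sketch, with hypothetical bucket names taken from the output column suffixes; labels are compared case-insensitively.

```python
from collections import defaultdict

# Label sets copied from the white lists of the filterer stages, lower-cased.
# Bucket names are hypothetical (taken from the `filtered_*` column suffixes).
ASSERTION_BUCKETS = {
    'present': {'present', 'planned'},
    'possible': {'possible'},
    'absent': {'absent'},
    'hist': {'family_history', 'someone_else'},
    'past': {'past'},
}


def bucket_by_assertion(rows):
    """Group output rows by assertion label, mirroring the filterer stages."""
    buckets = defaultdict(list)
    for row in rows:
        label = row['assertion'].lower()
        for name, labels in ASSERTION_BUCKETS.items():
            if label in labels:
                buckets[name].append(row)
    return dict(buckets)
```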

In [None]:
# read dataset
if in_dataset_path.suffix == '.jsonl':
    df = spark.read.json(str(in_dataset_path)).select(note_id_col, note_text_col)
elif in_dataset_path.suffix == '.csv':
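    # multiLine/escape: allow quoted note text to contain line breaks and doubled quotes ("")
    df = spark.read.csv(str(in_dataset_path), header=True, multiLine=True, escape='"').select(note_id_col, note_text_col)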
else:
    raise ValueError(f'Unrecognized extension: {in_dataset_path.suffix}')

# prepare dataset for processing
df = df.na.drop()  # drop rows with a missing note_id or note_text
df = df.toDF('note_id', 'note_text')  # rename columns
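
Since rows are numbered by ordering on `note_id`, duplicate ids leave the order among tied rows undefined, so a note may end up processed twice or skipped across batches; the ids are also the only link back to the source notes. A stdlib pre-check on the input file is sketched below (`read_note_ids` and `find_duplicate_ids` are hypothetical helpers). Inside Spark, `df.groupBy('note_id').count().filter('count > 1')` answers the same question.

```python
import csv
import json
from collections import Counter
from pathlib import Path


def read_note_ids(path, id_col='note_id'):
    """Read only the id column from a JSONL or CSV input file."""
    path = Path(path)
    if path.suffix not in ('.jsonl', '.csv'):
        raise ValueError(f'Unrecognized extension: {path.suffix}')
    with open(path, newline='', encoding='utf8') as fh:
        if path.suffix == '.jsonl':
            return [json.loads(line)[id_col] for line in fh if line.strip()]
        return [row[id_col] for row in csv.DictReader(fh)]


def find_duplicate_ids(ids):
    """Return ids that occur more than once, mapped to their counts."""
    return {note_id: n for note_id, n in Counter(ids).items() if n > 1}
```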

In [None]:
# experiment with just a few records
if test_top_n:
    df = df.limit(test_top_n)

In [None]:
# prepare pipelines
p_model = jsl_single_pipeline.fit(df)
l_model = nlp.LightPipeline(p_model)

In [None]:
def annotate(row):
    """Run the LightPipeline on one note and yield one dict per filtered SDOH chunk."""
    note_id = row['note_id']
    data = l_model.fullAnnotate(row['note_text'])  # list of dicts: output column -> annotations
    for result in data:
        for chunk in result['filtered']:
            meta = chunk.metadata
            yield {
                'note_id': note_id,
                'confidence': meta['confidence'],
                'assertion': meta['assertion'],
                'entity': meta['entity'],
            }
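
The flattening in `annotate` only touches `l_model` through `fullAnnotate`, so it can be checked without Spark by feeding the same loop a stub result. In this sketch, `flatten` is a hypothetical copy of that loop and `SimpleNamespace` objects stand in for Spark NLP annotations; only their `.metadata` dict is used, and the values are invented.

```python
from types import SimpleNamespace

# Stand-in for `l_model.fullAnnotate(text)`: a list of dicts of output columns,
# each holding objects with a `.metadata` dict. Values are made up.
fake_result = [{
    'filtered': [
        SimpleNamespace(metadata={'entity': 'Housing', 'assertion': 'Present', 'confidence': '0.95'}),
        SimpleNamespace(metadata={'entity': 'Transportation', 'assertion': 'Absent', 'confidence': '0.88'}),
    ],
}]


def flatten(note_id, data):
    """Same flattening as `annotate`, applied to an already-annotated result."""
    for result in data:
        for chunk in result['filtered']:
            meta = chunk.metadata
            yield {
                'note_id': note_id,
                'confidence': meta['confidence'],
                'assertion': meta['assertion'],
                'entity': meta['entity'],
            }


rows = list(flatten('1', fake_result))
```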

In [None]:
# prepare batches for processing the dataframe
window_spec = Window.orderBy('note_id')  # number rows by note_id; ids should be unique so numbering is stable
df = df.withColumn('row_num', row_number().over(window_spec))
num_batches = (df.count() + batch_size - 1) // batch_size
logger.info(f'Number of batches: {num_batches}')
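
The batching in the next cell relies on `row_number()` being 1-based and on ceiling division for the batch count. A pure-Python sketch of the same arithmetic (`batch_ranges` is a hypothetical helper) shows the inclusive row ranges each batch covers. The cell itself does not clamp the last `end_row`, which is harmless because the `<=` filter simply matches fewer rows.

```python
def batch_ranges(n_rows, batch_size):
    """Yield inclusive (start_row, end_row) pairs for 1-based row numbers."""
    num_batches = (n_rows + batch_size - 1) // batch_size  # ceiling division
    for batch_num in range(num_batches):
        start_row = batch_num * batch_size + 1
        end_row = min((batch_num + 1) * batch_size, n_rows)
        yield start_row, end_row


print(list(batch_ranges(2500, 1000)))
```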

In [None]:
# process notes and write output to CSV
before_time = time.time()
n_entities = 0
n_notes = 0
fieldnames = ['note_id', 'confidence', 'entity', 'assertion']
with open(out_dataset_path, 'w', newline='', encoding='utf8') as fh:
    writer = csv.DictWriter(fh, fieldnames)
    writer.writeheader()
    for batch_num in range(num_batches):
        logger.info(f'Starting batch #{batch_num + 1} of {num_batches}')
        # get the current batch dataframe (row_num is 1-based)
        start_row = batch_num * batch_size + 1
        end_row = (batch_num + 1) * batch_size
        batch_df = df.filter((df.row_num >= start_row) & (df.row_num <= end_row))
        batch_rows = batch_df.collect()  # collect once and reuse instead of a second count() job
        for i, row in enumerate(batch_rows, start=n_notes + 1):
            for result in annotate(row):
                # result: {'note_id': '1', 'confidence': '0.9981', 'assertion': 'Present', 'entity': 'Employment'}
                writer.writerow(result)
                n_entities += 1
            if i % 10_000 == 0:
                after_time = time.time()
                logger.info(f'Processed {i} notes and wrote {n_entities} entities to file; {(after_time - before_time) / i:.3f} s/note')
        n_notes += len(batch_rows)
        after_time = time.time()
        logger.info(f'Finished batch {batch_num + 1}! Processed {n_notes} notes and wrote {n_entities} entities to file; {(after_time - before_time) / max(n_notes, 1):.3f} s/note')
after_time = time.time()
logger.info(f'Done! Processed {n_notes} notes and wrote {n_entities} entities to file; {(after_time - before_time) / max(n_notes, 1):.3f} s/note')
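
Once the run finishes, a quick tally of the output helps when choosing a confidence cutoff or spotting labels that never appear. The sketch below (`summarize` is a hypothetical helper; the rows are invented) counts rows per `(entity, assertion)` pair. It only sees notes with at least one entity, since notes without any produce no rows in the output.

```python
from collections import Counter


def summarize(rows):
    """Return counts per (entity, assertion) pair and the number of distinct notes."""
    rows = list(rows)  # materialize so the rows can be iterated twice (e.g., a csv.DictReader)
    pair_counts = Counter((row['entity'], row['assertion']) for row in rows)
    n_notes_with_entities = len({row['note_id'] for row in rows})
    return pair_counts, n_notes_with_entities


rows = [
    {'note_id': '1', 'entity': 'Housing', 'assertion': 'Present'},
    {'note_id': '1', 'entity': 'Housing', 'assertion': 'Present'},
    {'note_id': '2', 'entity': 'Income', 'assertion': 'Absent'},
]
pair_counts, n_notes_with_entities = summarize(rows)
print(pair_counts.most_common())
```

For the real file, pass a `csv.DictReader` over the file at `out_dataset_path` instead of `rows`.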