## STEPS

* We go through the 5 steps that are required to use a trained model and optimize it for a desired level of recall

## STEP 0: LIBRARIES

In [None]:
import json

In [None]:
from transformers import HfArgumentParser, TrainingArguments

In [None]:
from robust_deid.ner_datasets import DatasetSplitter, DatasetCreator, SpanFixer, SpanValidation
from robust_deid.sequence_tagging import SequenceTagger, RecallThresholder
from robust_deid.sequence_tagging.arguments import (
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
)

## STEP 1: INITIALIZE

In [None]:
# Initialize the location where we will store the validation data
validation_file_raw = '/home/pk621/projects/data/ehr_deidentification/i2b2/validation_unfixed.jsonl'
# Initialize the location where we will store the validation data after fixing the spans
validation_file = '/home/pk621/projects/data/ehr_deidentification/i2b2/validation.jsonl'
# Initialize the location where the spans for the validation data are stored
validation_spans_file = '/home/pk621/projects/data/ehr_deidentification/i2b2/validation_spans.jsonl'
# Initialize the location where we will store the sentencized and tokenized validation dataset (validation_file)
ner_validation_file = '/home/pk621/projects/data/ehr_deidentification/ner_datasets/i2b2_train/validation.jsonl'
# Initialize the location where we will store the model logits (predictions_file)
# Verify this file location - Ensure it's the same location that you will pass in the json file
# to the sequence tagger model. i.e. output_predictions_file in the json file should have the same
# value as below
logits_file = '/home/pk621/projects/data/ehr_deidentification/model_predictions/i2b2_train/logits.jsonl'
# Initialize the model config. This config file contains the various parameters of the model.
model_config = './run/i2b2/logits_i2b2.json'
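
The comments above ask that `output_predictions_file` in the JSON config matches `logits_file`. A minimal sketch of such a check follows. The helper name `config_matches` and the temporary demo paths are hypothetical; only the `output_predictions_file` key comes from the comment above, and the real config will contain many other parameters.

```python
import json
import os
import tempfile


def config_matches(config_path, expected_predictions_file):
    """Return True if the config's output_predictions_file equals the expected path."""
    with open(config_path) as file:
        config = json.load(file)
    configured = config.get('output_predictions_file')
    if configured is None:
        return False
    # Compare absolute paths so that relative and absolute spellings are treated alike
    return os.path.abspath(configured) == os.path.abspath(expected_predictions_file)


# Demonstration with a throwaway config file (all paths here are placeholders)
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_logits_file = os.path.join(tmp_dir, 'logits.jsonl')
    demo_config = os.path.join(tmp_dir, 'logits_demo.json')
    with open(demo_config, 'w') as file:
        json.dump({'output_predictions_file': demo_logits_file}, file)
    same_path = config_matches(demo_config, demo_logits_file)
    different_path = config_matches(demo_config, os.path.join(tmp_dir, 'other.jsonl'))
```

In the notebook itself, the equivalent call would be `config_matches(model_config, logits_file)`.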

In [None]:
# Initialize the sentencizer and tokenizer
sentencizer = 'en_core_sci_sm'
tokenizer = 'clinical'
notation = 'BILOU'
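
For readers unfamiliar with the BILOU notation chosen above, here is a small illustrative sketch of how token-level tags are usually laid out (Beginning, Inside, Last, Outside, Unit). The helper `bilou_tags` and the `B-STAFF` style of tag string are assumptions made for illustration; the exact label strings produced by the library may differ.

```python
def bilou_tags(num_tokens, entities):
    """Build BILOU tags for a token sequence.

    entities: list of (start_token, end_token_inclusive, ner_type) tuples.
    """
    tags = ['O'] * num_tokens
    for start, end, ner_type in entities:
        if start == end:
            # Single-token entity
            tags[start] = 'U-' + ner_type
        else:
            tags[start] = 'B-' + ner_type
            for index in range(start + 1, end):
                tags[index] = 'I-' + ner_type
            tags[end] = 'L-' + ner_type
    return tags


demo_tokens = ['Seen', 'by', 'Dr.', 'John', 'Q', 'Smith', 'on', '2019-01-02']
demo_tags = bilou_tags(len(demo_tokens), [(3, 5, 'STAFF'), (7, 7, 'DATE')])
```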

## STEP 2: FIX SPANS

* This step is optional and may not be required
* This code may be required if you have spans that don't line up with your tokenizer (e.g. the dataset was annotated at the character level and your tokenizer doesn't split at the same positions). This code fixes the spans so that the code below (creating NER datasets) runs without error.
* We experienced the issue above in the step where we create the NER dataset (step 3), where we need to align the labels with the tokens based on the BILOU/BIO notation. Without this step, we would run into alignment issues.
* If you face the same issue, running this step should fix it: it changes the start and end positions of the annotated spans based on your tokenizer and saves the new spans.
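
To make the misalignment concrete, the sketch below shows one simple way a character-level span could be moved onto token boundaries: expand it to cover every token it overlaps. This is an illustration only, not the `SpanFixer` implementation (which may resolve mismatches differently); the whitespace tokenizer and the helper names `token_offsets` and `snap_span` are hypothetical.

```python
import re


def token_offsets(text):
    """Hypothetical tokenizer: each run of non-whitespace characters is a token."""
    return [(match.start(), match.end()) for match in re.finditer(r'\S+', text)]


def snap_span(start, end, offsets):
    """Expand the character span (start, end) to cover every token it overlaps."""
    overlapping = [(s, e) for s, e in offsets if s < end and e > start]
    if not overlapping:
        return start, end
    return overlapping[0][0], overlapping[-1][1]


demo_text = 'Admitted 01/02/2019by Dr. Smith'
demo_offsets = token_offsets(demo_text)
# The annotation covers only the date (characters 9 to 19), but the tokenizer
# keeps '01/02/2019by' together as a single token.
annotated = (9, 19)
fixed = snap_span(annotated[0], annotated[1], demo_offsets)
```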

In [None]:
ner_types = ["PATIENT", "STAFF", "AGE", "DATE", "PHONE", "ID", "EMAIL", "PATORG", "LOC", "HOSP", "OTHERPHI"]
# Sometimes there may be some label (span) overlap - the priority list assigns a priority to each label.
# Higher preference is given to labels with higher priority when resolving label overlap
ner_priorities = [2, 1, 2, 2, 2, 2, 2, 1, 2, 1, 1]
## Initialize the span fixer object
span_fixer = SpanFixer(
    tokenizer=tokenizer,
    sentencizer=sentencizer,
    ner_priorities={ner_type: priority for ner_type, priority in zip(ner_types, ner_priorities)},
    verbose=True
)
## Write the dataset with the fixed validation spans to a file
with open(validation_file, 'w') as file:
    for note in span_fixer.fix(
        input_file=validation_file_raw,
        text_key='text',
        spans_key='spans'
    ):
        file.write(json.dumps(note) + '\n')
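
The priority list above decides which label wins when spans overlap. A minimal sketch of that idea, assuming spans are dictionaries with `start`, `end` and `label` keys (an assumption about the record layout) and using a hypothetical helper `resolve_overlaps`; the library's actual tie-breaking rules are not shown in this notebook and may differ.

```python
def resolve_overlaps(spans, priorities):
    """Keep the higher-priority span whenever two spans overlap."""
    # Visit spans from highest to lowest priority; among equal priorities,
    # prefer longer spans, then earlier start positions.
    ordered = sorted(
        spans,
        key=lambda span: (-priorities[span['label']], -(span['end'] - span['start']), span['start'])
    )
    kept = []
    for span in ordered:
        # Accept the span only if it does not overlap anything already kept
        if all(span['end'] <= other['start'] or span['start'] >= other['end'] for other in kept):
            kept.append(span)
    return sorted(kept, key=lambda span: span['start'])


demo_priorities = {'PATIENT': 2, 'STAFF': 1, 'DATE': 2}
demo_spans = [
    {'start': 0, 'end': 10, 'label': 'STAFF'},
    {'start': 5, 'end': 12, 'label': 'PATIENT'},
    {'start': 20, 'end': 30, 'label': 'DATE'},
]
resolved = resolve_overlaps(demo_spans, demo_priorities)
```

Here the `STAFF` span overlaps the higher-priority `PATIENT` span and is dropped, while the non-overlapping `DATE` span is kept.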

## STEP 3: NER DATASET
* Sentencize and tokenize the raw text. We used sentences of length 128, which includes an additional 32 context tokens on either side of the sentence. These 32 tokens (from the previous & next sentence) serve as additional context to the current sentence.
* We used the en_core_sci_sm sentencizer and a custom tokenizer (can be found in the preprocessing module)
* The dataset stored in ner_validation_file will be used as input to the sequence tagger model

In [None]:
# Create the dataset creator object
dataset_creator = DatasetCreator(
    sentencizer=sentencizer,
    tokenizer=tokenizer,
    max_tokens=128,
    max_prev_sentence_token=32,
    max_next_sentence_token=32,
    default_chunk_size=32,
    ignore_label='NA'
)
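
The `max_prev_sentence_token` and `max_next_sentence_token` settings above control how many neighbouring tokens are attached to each sentence as context. A rough sketch of that windowing, using a hypothetical helper `sentence_with_context` and small limits so the result is easy to read; it deliberately omits the splitting of long sentences to respect `max_tokens`, and it is not the `DatasetCreator` implementation.

```python
def sentence_with_context(sentences, index, max_prev=32, max_next=32):
    """Return (previous context, current sentence, next context) as token lists."""
    flat_previous = [token for sentence in sentences[:index] for token in sentence]
    flat_next = [token for sentence in sentences[index + 1:] for token in sentence]
    # Take the tokens closest to the current sentence on either side
    previous_context = flat_previous[-max_prev:] if max_prev > 0 else []
    next_context = flat_next[:max_next]
    return previous_context, sentences[index], next_context


demo_sentences = [
    ['Pt', 'seen', 'today', '.'],
    ['BP', 'stable', '.'],
    ['Follow', 'up', 'in', 'May', '.'],
]
demo_context = sentence_with_context(demo_sentences, index=1, max_prev=2, max_next=3)
```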

In [None]:
# This function call sentencizes and tokenizes the dataset
# It returns a generator that iterates through the sequences.
# We write the output to ner_validation_file (in JSON Lines format)
# Validation split
ner_notes_test = dataset_creator.create(
    input_file=validation_file,
    mode='train',
    notation=notation,
    token_text_key='text',
    metadata_key='meta',
    note_id_key='note_id',
    label_key='label',
    span_text_key='spans'
)

In [None]:
# Write the validation NER split to file
with open(ner_validation_file, 'w') as file:
    for ner_sentence in ner_notes_test:
        file.write(json.dumps(ner_sentence) + '\n')

## STEP 4: SEQUENCE TAGGING
* Run the trained sequence model on the validation dataset to obtain the logits - specify the parameters of the sequence model in the config file (model_config). The model will be run with the specified parameters. For more information on these parameters, please refer to the Hugging Face documentation (or use the docs provided).
* You can manually pass in the parameters instead of using the config file. The config file option is recommended. In our example we are passing the parameters through a config file. If you do not want to use the config file, skip the next code block and manually enter the values in the following code blocks. You will still need to read in the training args using huggingface and change values in the training args according to your needs.

In [None]:
parser = HfArgumentParser((
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
    TrainingArguments
))
# Parse the model, data, evaluation and training arguments from the JSON config file
model_args, data_args, evaluation_args, training_args = parser.parse_json_file(json_file=model_config)

In [None]:
# Initialize the sequence tagger
sequence_tagger = SequenceTagger(
    task_name=data_args.task_name,
    notation=data_args.notation,
    ner_types=data_args.ner_types,
    model_name_or_path=model_args.model_name_or_path,
    config_name=model_args.config_name,
    tokenizer_name=model_args.tokenizer_name,
    post_process=model_args.post_process,
    cache_dir=model_args.cache_dir,
    model_revision=model_args.model_revision,
    use_auth_token=model_args.use_auth_token,
    threshold=model_args.threshold,
    do_lower_case=data_args.do_lower_case,
    fp16=training_args.fp16,
    seed=training_args.seed,
    local_rank=training_args.local_rank
)
# Load the required functions of the sequence tagger
sequence_tagger.load()

In [None]:
# Set the required data and predictions of the sequence tagger
# data_args.test_file (set in the config file) should point to the same file as ner_validation_file
sequence_tagger.set_predict(
    test_file=data_args.test_file,
    max_test_samples=data_args.max_predict_samples,
    preprocessing_num_workers=data_args.preprocessing_num_workers,
    overwrite_cache=data_args.overwrite_cache
)

In [None]:
# Initialize the huggingface trainer
sequence_tagger.setup_trainer(training_args=training_args)

In [None]:
# Run the forward pass to get the model predictions (logits)
predictions = sequence_tagger.predict()
# Write predictions to a file
with open(logits_file, 'w') as file:
    for prediction in predictions:
        file.write(json.dumps(prediction) + '\n')

## STEP 5: RECALL THRESHOLDING
* The objective is to modify the classification thresholds, i.e. instead of choosing the class with the highest probability as the prediction for a token (optimize F1), we modify the classification thresholds to optimize recall.
* The code below is to get these thresholds such that we get the desired level of recall. We use a validation dataset to optimize the threshold and level of recall.
* We get the thresholds by reformulating the NER task as a binary classification task: PHI vs. non-PHI. We have two methods to do this: MAX and SUM.
* MAX: 
    - probability of PHI class = maximum SoftMax probability over all the PHI classes
    - probability of non-PHI class
* SUM: 
    - probability of PHI class = sum of SoftMax probabilities over all the PHI classes
    - probability of non-PHI class
* A brief explanation of how we use these thresholds is given below
* Feel free to test out different recall thresholds. The thresholds are computed against a validation dataset. You would then use these thresholds to run the evaluation or forward pass against a test dataset
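
The PHI scores described in the MAX and SUM bullets can be sketched for a single token as follows. This is an illustration under stated assumptions, not the `RecallThresholder` code: the helper names are hypothetical, and the non-PHI class is assumed to carry the label `'O'`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exponentials = [math.exp(value - peak) for value in logits]
    total = sum(exponentials)
    return [value / total for value in exponentials]


def phi_probability(logits, labels, mode):
    """Collapse per-class probabilities into one PHI score for a token."""
    probabilities = softmax(logits)
    # Assumption: every label other than 'O' is a PHI class
    phi = [p for p, label in zip(probabilities, labels) if label != 'O']
    if mode == 'max':
        return max(phi)
    if mode == 'sum':
        return sum(phi)
    raise ValueError('mode must be "max" or "sum"')


demo_labels = ['O', 'B-DATE', 'U-DATE']
demo_logits = [2.0, 1.0, 1.0]
phi_max = phi_probability(demo_logits, demo_labels, mode='max')
phi_sum = phi_probability(demo_logits, demo_labels, mode='sum')
```

With several PHI classes sharing the probability mass, the SUM score is always at least as large as the MAX score for the same token.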

In [None]:
recall_thresholder = RecallThresholder(notation=data_args.notation, ner_types=data_args.ner_types)

In [None]:
# Threshold mode - max
# This means that an input token is tagged with the non-PHI class only if the 
# maximum probability over all PHI classes was less than the chosen threshold.
# We tag the token with the PHI class that has the highest probability
precision, recall, threshold = recall_thresholder.get_precision_recall_threshold(
    logits_file=logits_file,
    recall_cutoff=99.87/100,
    threshold_mode='max',
    predictions_key='predictions',
    labels_key='labels'
)

In [None]:
print('Threshold Mode: ' + 'MAX')
print('At threshold: ', threshold)
print('Precision is: ', precision * 100)
print('Recall is: ', recall * 100)
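
For intuition about what a recall cutoff means, the sketch below picks the largest threshold on a binary PHI score that still reaches a target recall, then reports precision at that threshold. It is a generic illustration with a hypothetical helper `threshold_for_recall`; the internals of `get_precision_recall_threshold` are not shown in this notebook and may differ.

```python
import math


def threshold_for_recall(scores, is_phi, recall_cutoff):
    """Return (precision, recall, threshold) for the largest threshold meeting the cutoff.

    A token is predicted as PHI when its score is >= threshold.
    """
    positives = sorted((s for s, y in zip(scores, is_phi) if y), reverse=True)
    if not positives:
        raise ValueError('need at least one PHI token')
    # Number of PHI tokens that must be recovered to reach the cutoff
    needed = max(math.ceil(recall_cutoff * len(positives)), 1)
    threshold = positives[needed - 1]
    predicted = [s >= threshold for s in scores]
    true_positives = sum(1 for p, y in zip(predicted, is_phi) if p and y)
    false_positives = sum(1 for p, y in zip(predicted, is_phi) if p and not y)
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / len(positives)
    return precision, recall, threshold


demo_scores = [0.9, 0.8, 0.6, 0.3, 0.7, 0.2]
demo_is_phi = [True, True, True, True, False, False]
demo_precision, demo_recall, demo_threshold = threshold_for_recall(
    demo_scores, demo_is_phi, recall_cutoff=0.75
)
```

Raising the cutoff towards 1.0 pushes the threshold down, which recovers more PHI tokens at the cost of more false positives.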

In [None]:
# Threshold mode - sum
# This means that an input token is tagged with the PHI class only if the sum of 
# probabilities over all PHI classes is greater than the chosen threshold.
# We tag the token with the PHI class that has the highest probability
precision, recall, threshold = recall_thresholder.get_precision_recall_threshold(
    logits_file=logits_file,
    recall_cutoff=99.87/100,
    threshold_mode='sum',
    predictions_key='predictions',
    labels_key='labels'
)

In [None]:
print('Threshold Mode: ' + 'SUM')
print('At threshold: ', threshold)
print('Precision is: ', precision * 100)
print('Recall is: ', recall * 100)

## Using the thresholds
* Once we have the thresholds - we can use these to evaluate against a dataset and/or run the forward pass against a dataset to aggressively remove PHI
* We have 4 config files included in the run folder, two for the evaluation (max & sum) and two for the forward pass (max & sum).
* You can make use of these config files and run the scripts/notebooks in the evaluation and forward pass folders.
* Replace the config files (argmax) with the ones shown above (recall optimized) and the evaluation and forward pass will be run with the recall optimized models.
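
As a rough picture of what the recall-optimized configs change at prediction time, the sketch below tags a single token using a threshold instead of a plain argmax, following the rule described in the threshold-mode comments of step 5. The helper `tag_token`, the `'O'` label for non-PHI, and the handling of a score exactly equal to the threshold are assumptions for illustration.

```python
def tag_token(probabilities, labels, threshold, mode):
    """Tag a token as non-PHI only if its PHI score falls below the threshold."""
    # Assumption: every label other than 'O' is a PHI class
    phi = [(p, label) for p, label in zip(probabilities, labels) if label != 'O']
    phi_probabilities = [p for p, _ in phi]
    score = max(phi_probabilities) if mode == 'max' else sum(phi_probabilities)
    if score < threshold:
        return 'O'
    # Otherwise use the PHI class with the highest probability
    return max(phi, key=lambda pair: pair[0])[1]


demo_labels = ['O', 'U-DATE', 'U-AGE']
demo_probabilities = [0.6, 0.25, 0.15]
# Plain argmax would pick 'O' here because 0.6 is the largest probability
max_mode_tag = tag_token(demo_probabilities, demo_labels, threshold=0.3, mode='max')
sum_mode_tag = tag_token(demo_probabilities, demo_labels, threshold=0.3, mode='sum')
```

In this example the same threshold keeps the token as non-PHI in MAX mode but flags it as PHI in SUM mode, which is why each mode gets its own threshold and config file.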