## STEPS

* We go through the 5 steps that are required to use a trained model for evaluation against a dataset

## STEP 0: LIBRARIES

In [None]:
import json

In [None]:
from transformers import HfArgumentParser, TrainingArguments

In [None]:
from robust_deid.ner_datasets import DatasetSplitter, DatasetCreator, SpanFixer, SpanValidation
from robust_deid.sequence_tagging import SequenceTagger
from robust_deid.sequence_tagging.arguments import (
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
)

## STEP 1: INITIALIZE

In [None]:
# NOTE: these are absolute paths on the author's machine - change them for your environment.
# Raw test dataset with the spans as originally annotated (read in steps 2 and 3)
test_file_raw = '/home/pk621/projects/data/ehr_deidentification/i2b2/test_unfixed.jsonl'
# Initialize the location where we will store the test data after fixing the spans (written in step 3)
test_file = '/home/pk621/projects/data/ehr_deidentification/i2b2/test.jsonl'
# Initialize the location where the spans for the test data are stored (written in step 2)
test_spans_file = '/home/pk621/projects/data/ehr_deidentification/i2b2/test_spans.jsonl'
# Initialize the location where we will store the sentencized and tokenized test dataset (built from test_file)
ner_test_file = '/home/pk621/projects/data/ehr_deidentification/ner_datasets/i2b2_train/test.jsonl'

# Initialize the model config. This json file contains the various parameters of the model
# and is parsed into the model, data, evaluation and training arguments in step 6.
model_config = './run/i2b2/eval_i2b2.json'

In [None]:
# Initialize the sentencizer and tokenizer
# Both names are passed to the SpanFixer (step 3) and the DatasetCreator (step 5)
sentencizer = 'en_core_sci_sm'
tokenizer = 'clinical'
# Label notation passed to dataset_creator.create when building the NER dataset (step 5)
notation = 'BILOU'

## STEP 2: TEST SPANS
* We write out the test spans and also setup the token level test dataset below
* We do this because we may have different tokenizers and to make a fair comparison, we compare models at the span level. To do this we need the span information at a character level (start of span & end of span in terms of character position). If we have information at a character level, it is easier to compare different tokenizers. Now we not only have token level performance, but also span level performance
* We use this step to write out the annotated spans to do span level evaluation. 
* One of the reasons we do this is that in step 3 we modify the original annotated spans so that we can create an NER dataset. To ensure that we still evaluate on the original annotated dataset, we do this step. Read step 3 to understand why we need to do this to ensure a fair comparison during evaluation.
* This step has the test data (or any dataset that you want to test on) in the original form - spans with the specified start and end position.
* In summary, for doing span level evaluation we need a mapping between note_id and the annotated spans for that note_id 

In [None]:
# Save the originally annotated spans of the raw (unfixed) test data, one JSON value per line.
# Span level evaluation relies on character positions (start & end of each span), which lets
# us compare models that use different tokenizers. Writing the spans out here keeps a copy of
# the original annotations, so we get span level performance on top of token level performance.
with open(test_spans_file, 'w') as file:
    file.writelines(
        json.dumps(span_info) + '\n'
        for span_info in SpanValidation.get_spans(
            input_file=test_file_raw,
            metadata_key='meta',
            note_id_key='note_id',
            spans_key='spans'
        )
    )

## STEP 3: FIX SPANS

* This step is optional and may not be required
* This code may be required if you have spans that don't line up with your tokenizer (e.g. the dataset was annotated at a character level and your tokenizer doesn't split at the same position). This code fixes the spans so that the code below (creating NER datasets) runs without error.
* We experienced the issue above in the step where we create the NER dataset (step 5) - where we need to align the labels with the tokens based on the BILOU/BIO.. notation. Without this step, we would run into alignment issues.
* If you face the same issue, running this step should fix it - changes the label start and end positions of the annotated spans based on your tokenizer and saves the new spans.

In [None]:
# Entity types and, position by position, the priority assigned to each of them.
# When two labels (spans) overlap, the label with the higher priority is preferred.
ner_types = [
    "PATIENT", "STAFF", "AGE", "DATE", "PHONE", "ID",
    "EMAIL", "PATORG", "LOC", "HOSP", "OTHERPHI",
]
ner_priorities = [2, 1, 2, 2, 2, 2, 2, 1, 2, 1, 1]
# Build the span fixer with a {ner_type: priority} mapping
span_fixer = SpanFixer(
    tokenizer=tokenizer,
    sentencizer=sentencizer,
    ner_priorities=dict(zip(ner_types, ner_priorities)),
    verbose=True
)
# Fix the spans of the raw test data and store each fixed note as one JSON line
with open(test_file, 'w') as file:
    file.writelines(
        json.dumps(note) + '\n'
        for note in span_fixer.fix(
            input_file=test_file_raw,
            text_key='text',
            spans_key='spans'
        )
    )

## STEP 5: NER DATASET
* Sentencize and tokenize the raw text. We used sentences of length 128, which includes an additional 32 context tokens on either side of the sentence. These 32 tokens (from the previous & next sentence) serve as additional context to the current sentence.
* We used the en_core_sci_sm sentencizer and a custom tokenizer (can be found in the preprocessing module)
* The dataset stored in the ner_test_file will be used as input to the sequence tagger model

In [None]:
# Create the dataset creator object
# It uses the sentencizer and tokenizer initialized in step 1.
# max_tokens, max_prev_sentence_token & max_next_sentence_token: we use sequences of length 128,
# which includes 32 context tokens from the previous & next sentences on either side.
# NOTE(review): the meaning of default_chunk_size and ignore_label is not visible in this
# notebook - check the DatasetCreator docs before changing them.
dataset_creator = DatasetCreator(
    sentencizer=sentencizer,
    tokenizer=tokenizer,
    max_tokens=128,
    max_prev_sentence_token=32,
    max_next_sentence_token=32,
    default_chunk_size=32,
    ignore_label='NA'
)

In [None]:
# This function call sentencizes and tokenizes the dataset
# It returns a generator that iterates through the sequences.
# The sequences are written to the ner_test_file (in json format) in the next cell
# Test split - built from the test data with the fixed spans (test_file)
# NOTE(review): mode='train' is used even though this is test data - presumably so that the
# labels from the annotated spans are kept for evaluation; confirm against the DatasetCreator docs.
ner_notes_test = dataset_creator.create(
    input_file=test_file,
    mode='train',
    notation=notation,
    token_text_key='text',
    metadata_key='meta',
    note_id_key='note_id',
    label_key='label',
    span_text_key='spans'
)

In [None]:
# Store the test ner split: every sequence becomes one JSON line in ner_test_file
with open(ner_test_file, 'w') as file:
    file.writelines(json.dumps(ner_sentence) + '\n' for ner_sentence in ner_notes_test)

## STEP 6: SEQUENCE TAGGING
* Evaluate the sequence model - specify parameters to the sequence model in the config file (model_config). The model will be evaluated with the specified parameters. For more information on these parameters, please refer to huggingface (or use the docs provided).
* You can manually pass in the parameters instead of using the config file. The config file option is recommended. In our example we are passing the parameters through a config file. If you do not want to use the config file, skip the next code block and manually enter the values in the following code blocks. You will still need to read in the training args using huggingface and change values in the training args according to your needs.

In [None]:
# Build a parser for the four argument groups used below: model, data, evaluation
# and the huggingface training arguments
parser = HfArgumentParser((
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
    TrainingArguments
))
# Read all the arguments from the json config file (model_config). The values are
# returned in the same order in which the argument classes were given to the parser.
model_args, data_args, evaluation_args, training_args = parser.parse_json_file(json_file=model_config)

In [None]:
# Build the sequence tagger from the arguments parsed out of the config file
sequence_tagger = SequenceTagger(
    # Values read from the model arguments
    model_name_or_path=model_args.model_name_or_path,
    config_name=model_args.config_name,
    tokenizer_name=model_args.tokenizer_name,
    cache_dir=model_args.cache_dir,
    model_revision=model_args.model_revision,
    use_auth_token=model_args.use_auth_token,
    post_process=model_args.post_process,
    threshold=model_args.threshold,
    # Values read from the data arguments
    task_name=data_args.task_name,
    notation=data_args.notation,
    ner_types=data_args.ner_types,
    do_lower_case=data_args.do_lower_case,
    # Values read from the huggingface training arguments
    fp16=training_args.fp16,
    seed=training_args.seed,
    local_rank=training_args.local_rank
)
# The tagger needs to be loaded before the evaluation data is set in the following cells
sequence_tagger.load()

In [None]:
# Set the required data for the evaluation of the sequence tagger
# All the values come from the data arguments parsed from the config file
sequence_tagger.set_eval(
    validation_file=data_args.validation_file,
    max_val_samples=data_args.max_eval_samples,
    preprocessing_num_workers=data_args.preprocessing_num_workers,
    overwrite_cache=data_args.overwrite_cache
)
# Set up the evaluation metrics using the spans file, evaluation script, ner type maps
# and evaluation mode given in the evaluation arguments
# NOTE(review): the keyword is ner_types_maps while the attribute is ner_type_maps -
# confirm both names against the library.
sequence_tagger.set_eval_metrics(
    validation_spans_file=evaluation_args.validation_spans_file,
    model_eval_script=evaluation_args.model_eval_script,
    ner_types_maps=evaluation_args.ner_type_maps,
    evaluation_mode=evaluation_args.evaluation_mode
)

In [None]:
# Initialize the huggingface trainer with the training arguments parsed from the config file
sequence_tagger.setup_trainer(training_args=training_args)

In [None]:
# Run the evaluation and keep the returned metrics so that we can print them below
metrics = sequence_tagger.evaluate()

In [None]:
# Print the metrics as indented json
print(json.dumps(metrics, indent=2))