In [80]:
#!/usr/bin/env python3

"""
- Script name: evaluate_model.
- Author: Dan Bright, cosmoid@tuta.io.
- Description: A script to evaluate performance of NER models
- Version: 0.1
"""

# declare imports
import os, re, json, html
import pandas as pd
from pathlib import Path
from pprint import pprint
from collections import OrderedDict

In [81]:
class CleanData:
    """
    Class that cleans and prepares text data for training & evaluation.

    Calling the class runs the enabled filters immediately and returns the
    cleaned list (NOT a CleanData instance).

    Consumes:
        - input_data: list[tuple] = list of input data to clean, in form [(record_id:int, text:str)]
        - separate_slashes: bool = whether to surround every slash with whitespace [True|False]
        - remove_linebreaks: bool = whether to remove linebreaks & join by a space [True|False]
        - remove_non_alphanum: bool = whether to remove all non-alphanumeric characters
          (except whitespace, dots, forward slashes & underscores), collapse runs of
          whitespace to one space and trim both ends [True|False]
        - ensure_encoding: bool = decode HTML entities, e.g. "&amp;" -> "&" [True|False]
    Produces:
        - list of cleaned data, in form [(record_id:int, text:str)]
    Notes:
        - enabled filters run in this order: ensure_encoding, separate_slashes,
          remove_linebreaks, remove_non_alphanum.
    """

    def __new__(
        cls,
        input_data: "list[tuple] | None" = None,  # None -> no records
        separate_slashes: bool = True,  # Default to True
        remove_linebreaks: bool = True,  # Default to True
        remove_non_alphanum: bool = True,  # Default to True
        ensure_encoding: bool = True,  # Default to True
    ) -> list[tuple]:
        obj = super().__new__(cls)
        return obj._run_filters(
            docs=input_data if input_data is not None else [],
            separate_slashes=separate_slashes,
            remove_linebreaks=remove_linebreaks,
            remove_non_alphanum=remove_non_alphanum,
            ensure_encoding=ensure_encoding,
        )

    def _run_filters(
        self,
        docs,
        separate_slashes,
        remove_linebreaks,
        remove_non_alphanum,
        ensure_encoding,
    ) -> list[tuple]:
        """
        Method to iterate the data & run the enabled filters on each record's text.

        HTML entities are decoded FIRST. Otherwise the non-alphanumeric filter
        strips the '&' and ';' delimiters, leaving an entity such as '&nbsp;' as
        the literal text 'nbsp' that can no longer be decoded.

        Returns: list of cleaned data in form [(record_id:int, text:str)]
        """
        pipeline = (
            (ensure_encoding, self._ensure_encoding),
            (separate_slashes, self._separate_slashes),
            (remove_linebreaks, self._remove_linebreaks),
            (remove_non_alphanum, self._remove_non_alphanum),
        )
        active_filters = [func for enabled, func in pipeline if enabled]
        filtered_docs: list[tuple] = []
        for record in docs:
            record_txt: str = record[1]
            for func in active_filters:
                record_txt = func(record_txt)
            filtered_docs.append((record[0], record_txt))
        return filtered_docs

    @staticmethod
    def _ensure_encoding(input: str) -> str:
        """
        Method to ensure characters are encoded correctly
        (i.e., HTML entities such as "&amp;" are decoded to their characters).
        """
        return html.unescape(input)

    @staticmethod
    def _separate_slashes(input: str) -> str:
        """
        Method to ensure all slashes within strings are surrounded by
        whitespace, e.g. "2015/05/03" -> "2015 / 05 / 03" and "a /b" -> "a / b".
        Whitespace already present next to a slash is left as is.
        """
        # add a space before any slash not already preceded by whitespace ...
        spaced = re.sub(r"(?<!\s)/", " /", input)
        # ... then after any slash not already followed by whitespace
        return re.sub(r"/(?!\s)", "/ ", spaced)

    @staticmethod
    def _remove_linebreaks(input: str) -> str:
        """
        Method to remove line/paragraph breaks, joining the lines with a space.
        """
        return " ".join(input.splitlines())

    @staticmethod
    def _remove_non_alphanum(input: str) -> str:
        """
        Method to remove all non-alphanumeric characters, except:
          - whitespaces
          - dots
          - forward slashes
          - underscores (part of regex \\w)
        then collapse whitespace runs to a single space and trim both ends, so
        tokens compare equal regardless of stray surrounding punctuation/spaces.
        """
        stripped = re.sub(r"[^\w\s\.\/]+", "", input)
        return re.sub(r"\s+", " ", stripped).strip()

In [86]:
class EvaluateModel:
    """
    Class that evaluates performance of NER models.

    Consumes:
        - jupyter: bool = whether script is being run as a Jupyter notebook (True|False);
          stored on the instance but not used by the evaluation itself
        - debug_output: bool = whether to write verbose processing output to STDOUT for debug (True|False)
        - output_json_path: str = path to output JSON; required when export_json is True
        - output_xlsx_path: str = path to output XLSX; required when export_xlsx is True
        - ner_record_id_key: str = key of record ID field in NER JSON
        - anno_record_id_key: str = key of record ID field in annotation JSON
        - ner_json_path: str = path to JSON file containing NER results (must exist)
        - anno_json_path: str = path to JSON file containing the annotated data (must exist)
        - export_json: bool = write results output to JSON file? (True|False)
        - export_xlsx: bool = write results output to XLSX file? (True|False)
    Produces:
        - results: OrderedDict = metrics for the input corpus, comprising of:
            - true positives
            - false positives
            - false negatives
            - total number of extracted entities
            - precision
            - recall
            - f1-score
    Raises:
        - FileNotFoundError if either input JSON path does not exist
        - KeyError if a record lacks its record ID key
        - ValueError if an export is requested but its output path is empty
    Notes:
        - input JSON MUST be in the form as per this example:
          [{"RECORD_ID": 280, "ICDT_DATE": ["2015/05/03", "May 3 2015"], "ICDT_TIME": ["0330Z"]}]
        - NER and annotated records are paired by record ID, not by list position.
          Records present on only one side are skipped; annotated IDs with no NER
          record are kept in ``self._unprocessed``. If an NER record ID repeats, the
          last occurrence is used.
        - Every entity class present on EITHER side is scored, so annotated entities
          of a class the model never emitted count as false negatives.
        - The token "NULL" (after cleaning/upper-casing) is a "no entity" placeholder
          and is ignored on both sides.
        - XLSX export goes through pandas ``DataFrame.to_excel``, which needs an
          Excel writer engine (e.g. openpyxl) installed.
    """

    # placeholder value meaning "no entity" in NER output and annotations
    _NULL_TOKEN: str = "NULL"

    def __new__(
        cls,
        jupyter: bool = True,  # default True
        debug_output: bool = True,  # default True
        output_json_path: str = "",
        output_xlsx_path: str = "",
        ner_record_id_key: str = "",
        anno_record_id_key: str = "",
        ner_json_path: str = "",
        anno_json_path: str = "",
        export_json: bool = False,  # default False
        export_xlsx: bool = False,  # default False
    ) -> dict:
        obj = super().__new__(cls)
        # define variables
        obj._jupyter = jupyter
        obj._debug_output = debug_output
        # empty output path -> None, so an export without a path fails loudly
        obj._output_json_path = (
            Path(output_json_path).resolve(strict=False) if output_json_path else None
        )
        obj._output_xlsx_path = (
            Path(output_xlsx_path).resolve(strict=False) if output_xlsx_path else None
        )
        obj._ner_record_id_key = ner_record_id_key
        obj._anno_record_id_key = anno_record_id_key
        obj._ner_json_path = Path(ner_json_path).resolve(strict=True)
        obj._anno_json_path = Path(anno_json_path).resolve(strict=True)
        obj._export_json = export_json
        obj._export_xlsx = export_xlsx
        obj._data_ner = []
        obj._data_anno = []
        obj._analytics_data = {}
        obj._results = OrderedDict()
        obj._unprocessed = []
        # run methods [note: do not change running order]
        obj._load_data()
        obj._analyze_data()
        obj._calculate_stats()
        obj._export_results()
        return obj._get_results()

    def _pdb(self, string: str) -> None:
        """Print ``string`` to STDOUT only when debug output was requested."""
        if self._debug_output:
            print(string)

    def _load_data(self) -> None:
        """Load both JSON files and sort each list of records by its record ID key."""
        with open(self._ner_json_path, "r", encoding="utf-8") as file:
            data_ner = json.load(file)
        with open(self._anno_json_path, "r", encoding="utf-8") as file:
            data_anno = json.load(file)
        self._data_anno = sorted(
            data_anno, key=lambda rec: rec[self._anno_record_id_key]
        )
        # must land in ``_data_ner`` (the attribute the analysis reads)
        self._data_ner = sorted(data_ner, key=lambda rec: rec[self._ner_record_id_key])

    @staticmethod
    def _remove_one_occurrence(lst: list, value: str) -> list:
        """Return a copy of ``lst`` without the first occurrence of ``value``."""
        if value in lst:
            index = lst.index(value)
            return lst[:index] + lst[index + 1 :]
        return lst

    @staticmethod
    def _check_phrase(phrase: str, tokens: list[str]) -> bool:
        """Return True if ``phrase`` exactly equals one of ``tokens``."""
        return phrase in tokens

    @classmethod
    def _entity_tokens(cls, record_id, values) -> list[str]:
        """
        Turn one entity-class value from the JSON into cleaned, upper-cased
        tokens, dropping "NULL" placeholders.

        ``values`` is expected to be a list of strings; a bare string is treated
        as a one-element list and ``None`` (class absent / JSON null) as empty.
        """
        if values is None:
            return []
        if isinstance(values, str):
            values = [values]
        raw = [(record_id, str(value)) for value in values if value is not None]
        # Clean BEFORE upper-casing: HTML entity names are case-sensitive
        # (html.unescape decodes "&nbsp;" but not "&NBSP;").
        tokens = [text.upper() for _, text in CleanData(raw)]
        return [tok for tok in tokens if tok != cls._NULL_TOKEN]

    def _score_class(
        self, ent_cls: str, e_toks: list[str], a_toks: list[str]
    ) -> tuple[int, int, int]:
        """
        Match extracted tokens against annotated tokens of one entity class.
        Each annotated token can be matched at most once.
        Returns: (true_positive, false_positive, false_negative) counts.
        """
        tp = fp = 0
        remaining = list(a_toks)
        for e_tok in e_toks:
            if self._check_phrase(e_tok, remaining):
                self._pdb(f"TRUE POS: '{e_tok}' is in {remaining}")
                remaining = self._remove_one_occurrence(remaining, e_tok)
                tp += 1
            else:
                self._pdb(f"FALSE POS: '{e_tok}' not in {remaining}")
                fp += 1
        if remaining:  # annotated tokens the NER never produced
            self._pdb(
                f"FALSE NEG: These tokens for {ent_cls} were not retrieved by NER: {remaining}"
            )
        return tp, fp, len(remaining)

    def _analyze_data(self) -> None:
        """
        Compare NER results to annotations, record by record, and store the
        aggregate counts in ``self._analytics_data``.
        """
        true_positive = 0  # retrieved for entity class & present in annotation
        false_positive = 0  # retrieved for entity class but not in annotation
        false_negative = 0  # in annotation for entity class but not retrieved
        total_extracted = 0  # total (non-NULL) entities extracted

        # Pair records by ID: positional pairing misaligns every record after
        # the first ID that is missing on one side.
        ner_by_id = {rec[self._ner_record_id_key]: rec for rec in self._data_ner}
        anno_ids = {rec[self._anno_record_id_key] for rec in self._data_anno}
        self._unprocessed = sorted(anno_ids.difference(ner_by_id))
        for orphan_id in sorted(set(ner_by_id).difference(anno_ids)):
            self._pdb(f"Skipping extracted record {orphan_id}: no matching annotated record.")

        for anno_rec in self._data_anno:
            rec_id = anno_rec[self._anno_record_id_key]
            ner_rec = ner_by_id.get(rec_id)
            if ner_rec is None:
                self._pdb(f"Skipping annotated record {rec_id}: no matching extracted record.")
                continue
            self._pdb(f"\n{self._ner_record_id_key}: {rec_id}")
            # copy without the ID field (instead of pop) so loaded data stays intact
            extr_ents = {k: v for k, v in ner_rec.items() if k != self._ner_record_id_key}
            anno_ents = {k: v for k, v in anno_rec.items() if k != self._anno_record_id_key}
            # every class seen on either side, extracted classes first
            ent_classes = list(extr_ents) + [c for c in anno_ents if c not in extr_ents]
            for ent_cls in ent_classes:
                e_toks = self._entity_tokens(rec_id, extr_ents.get(ent_cls))
                a_toks = self._entity_tokens(rec_id, anno_ents.get(ent_cls))
                total_extracted += len(e_toks)
                if ent_cls not in anno_ents:
                    if e_toks:
                        self._pdb(
                            f"FALSE POS: Extracted token(s) '{e_toks}' for {ent_cls} do not exist in annotation."
                        )
                    false_positive += len(e_toks)
                    continue
                tp, fp, fn = self._score_class(ent_cls, e_toks, a_toks)
                true_positive += tp
                false_positive += fp
                false_negative += fn
        self._analytics_data.update(
            {
                "true_positive": true_positive,
                "false_positive": false_positive,
                "false_negative": false_negative,
                "total_extracted": total_extracted,
            }
        )

    def _calculate_stats(self) -> None:
        """
        Compute precision, recall and F1 from ``self._analytics_data`` and
        store them (with the raw counts) in ``self._results``.
        An evaluation with nothing extracted and nothing missed scores 1.0.
        """
        tp: int = self._analytics_data["true_positive"]
        fp: int = self._analytics_data["false_positive"]
        fn: int = self._analytics_data["false_negative"]
        te: int = self._analytics_data["total_extracted"]
        nothing_to_score = fn == 0 and te == 0
        # precision uses tp + fp (== te when counts are consistent): it stays
        # <= 1.0 and the denominator is non-zero whenever tp > 0
        precision: float = tp / (tp + fp) if tp > 0 else (1.0 if nothing_to_score else 0.0)
        recall: float = tp / (tp + fn) if tp > 0 else (1.0 if nothing_to_score else 0.0)
        if precision > 0 and recall > 0:
            f1_score: float = (2 * precision * recall) / (precision + recall)
        elif tp == 0 and nothing_to_score:
            f1_score = 1.0
        else:
            f1_score = 0.0
        self._results.update(
            {
                "true_positive": tp,
                "false_positive": fp,
                "false_negative": fn,
                "total_extracted": te,
                "precision": round(precision, 2),
                "recall": round(recall, 2),
                "f1_score": round(f1_score, 2),
            }
        )

    def _export_results(self) -> None:
        """Write ``self._results`` to JSON and/or XLSX when the export flags are set."""
        if self._export_json:
            json_path = self._prepare_output_path(self._output_json_path, "output_json_path")
            with open(json_path, "w", encoding="utf-8") as file:
                json.dump(self._results, file, indent=4)
        if self._export_xlsx:
            xlsx_path = self._prepare_output_path(self._output_xlsx_path, "output_xlsx_path")
            # one-row sheet, one column per metric
            pd.DataFrame([dict(self._results)]).to_excel(xlsx_path, index=False)

    @staticmethod
    def _prepare_output_path(path, param_name: str) -> Path:
        """Return ``path`` after creating its parent directory; raise if it was not given."""
        if path is None:
            raise ValueError(f"{param_name} must be set when the matching export flag is True")
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    def _get_results(self) -> dict:
        """Return the computed metrics."""
        return self._results

In [None]:
# Run the evaluation with the parameters defined in this cell.

# --- input data ---
NER_RESULT_JSON: str = "../../data/output/ner/test_result_gpt_1.json"  # JSON file holding the NER (model) results
ANNOTATIONS_JSON: str = "../../data/output/prepared_anno/test_sample_1_annotated.json"  # JSON file holding the annotated data
NER_RECORD_ID_KEY: str = "RECORD_ID"  # record ID field name in the NER JSON
ANNO_RECORD_ID_KEY: str = "RECORD_ID"  # record ID field name in the annotation JSON

# --- outputs ---
EXPORT_JSON: bool = True  # write results output to JSON file? (True|False)
EXPORT_XLSX: bool = True  # write results output to XLSX file? (True|False)
OUTPUT_JSON_PATH: str = "../../data/output/eval/test_eval_gpt_1.json"  # destination of the JSON results
OUTPUT_XLSX_PATH: str = "../../data/output/eval/test_eval_gpt_1.xlsx"  # destination of the XLSX results

# --- runtime ---
JUPYTER: bool = True  # running on Jupyter notebook? (True|False)
DEBUG_OUTPUT: bool = False  # print verbose per-record debug lines to STDOUT? (True|False)

if __name__ == "__main__":
    results = EvaluateModel(
        ner_json_path=NER_RESULT_JSON,
        anno_json_path=ANNOTATIONS_JSON,
        ner_record_id_key=NER_RECORD_ID_KEY,
        anno_record_id_key=ANNO_RECORD_ID_KEY,
        export_json=EXPORT_JSON,
        export_xlsx=EXPORT_XLSX,
        output_json_path=OUTPUT_JSON_PATH,
        output_xlsx_path=OUTPUT_XLSX_PATH,
        jupyter=JUPYTER,
        debug_output=DEBUG_OUTPUT,
    )
    pprint(results)