In [0]:
%pip install -r ../../requirements.txt
%restart_python

In [0]:
%pip freeze

In [0]:
%load_ext autoreload
%autoreload 2

In [0]:
from src.dbxmetagen.config import MetadataConfig
from src.dbxmetagen.deterministic_pi import *
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine
from presidio_analyzer.dict_analyzer_result import DictAnalyzerResult

import pandas as pd
from pprint import pprint
from typing import Iterator

def format_presidio_batch_results(
    results: Iterator["DictAnalyzerResult"],
    score_threshold: float = 0,
    id_key: str = "doc_id",
    text_key: str = "text",
) -> list:
    """Flatten Presidio ``analyze_dict`` output into one record per detected entity.

    ``results`` must contain two ``DictAnalyzerResult`` items:
    - one for the document-id column, which was skipped during analysis;
    - one for the text column.

    The two items are found by their ``key`` (``id_key`` / ``text_key``). If
    those keys are absent and there are exactly two items, they are taken
    positionally as (ids, texts), as in the original contract.

    Args:
        results: Iterable of ``DictAnalyzerResult``. A generator can only be
            consumed once, so pass a materialised list if this function may be
            called more than once on the same analysis.
        score_threshold: Entities with ``score < score_threshold`` are dropped.
            Hits exactly at the threshold are kept, matching Presidio's own
            threshold filtering.
        id_key: Key of the document-id entry in the analysed dict.
        text_key: Key of the text entry in the analysed dict.

    Returns:
        A list of dicts. Each is ``RecognizerResult.to_dict()`` plus two extra
        keys: ``doc_id`` (the document id) and ``entity`` (the matched
        substring ``text[start:end]``).

    Raises:
        ValueError: If the id/text results cannot be identified, e.g. because
            the results iterator was already exhausted.
    """
    items = list(results)
    by_key = {getattr(item, "key", None): item for item in items}
    if id_key in by_key and text_key in by_key:
        ids_res, texts_res = by_key[id_key], by_key[text_key]
    elif len(items) == 2:
        # Fall back to the original positional contract: (ids, texts).
        ids_res, texts_res = items
    else:
        raise ValueError(
            f"Expected DictAnalyzerResults for '{id_key}' and '{text_key}', "
            f"got {len(items)} item(s); was the results iterator already consumed?"
        )

    doc_ids = ids_res.value
    texts = texts_res.value

    output = []
    # recognizer_results holds one list of entity hits per document, aligned
    # with the id and text lists.
    for doc_id, text, doc_entities in zip(doc_ids, texts, texts_res.recognizer_results):
        for ent in doc_entities:
            if ent.score < score_threshold:
                continue
            record = ent.to_dict()
            record["doc_id"] = doc_id
            record["entity"] = text[ent.start : ent.end]
            output.append(record)
    return output

In [0]:
# Gold-standard annotation table used as ground truth for the evaluation below.
EVAL_TABLE = "dbxmetagen.eval_data.jsl_48docs"

df = spark.table(EVAL_TABLE)
display(df)

In [0]:
# One entry per document. The gold table repeats `text` on each annotation row,
# hence distinct(). Result shape: {"doc_id": [...], "text": [...]}.
text_dict = (
    df.select("doc_id", "text")
    .distinct()
    .toPandas()
    .to_dict(orient="list")
)
# A doc_id appearing with two different texts would duplicate that document and
# skew every downstream count.
assert len(set(text_dict["doc_id"])) == len(text_dict["doc_id"]), "doc_id is not unique per text"
# Compact summary instead of dumping every full document (may contain PHI).
{"n_docs": len(text_dict["doc_id"]), "first_doc_ids": text_dict["doc_id"][:5]}

In [0]:
# Minimum Presidio confidence score for a detection to be kept.
score_threshold = 0.5
analyzer = get_analyzer_engine(add_pci=False, default_score_threshold=score_threshold)
batch_analyzer = BatchAnalyzerEngine(analyzer_engine=analyzer)
# analyze_dict returns a lazy generator. Without list(), the actual analysis
# would run inside the first consumer, and any second consumer (e.g. re-running
# the formatting cell) would see an exhausted iterator. Materialise it once here.
results = list(
    batch_analyzer.analyze_dict(
        text_dict,
        language="en",
        keys_to_skip=["doc_id"],  # ids are passed through, not scanned for PII
        score_threshold=score_threshold,
        batch_size=16,
        n_process=3,
    )
)
# Kept for backward compatibility; this is an alias of the same list, not a copy.
results_copy = results

In [0]:
output = format_presidio_batch_results(results, score_threshold=score_threshold)
# Preview only: the full list holds one dict per detected entity.
len(output), output[:5]

In [0]:
# Tabulate the entity records with pandas, then convert to Spark so they can be
# joined with the gold annotations.
output_pdf = pd.DataFrame(output)
df_results = spark.createDataFrame(output_pdf)
display(df_results)

In [0]:
from pyspark.sql.functions import lower, trim, col, asc_nulls_last

# A prediction matches a gold annotation when all three hold:
# - its text equals the chunk (case- and edge-whitespace-insensitive);
# - it is in the same document;
# - it starts at the same offset.
match_condition = (
    (lower(trim(df_results.entity)) == lower(trim(df.chunk)))
    & (df_results.doc_id == df.doc_id)
    & (df_results.start == df.begin)
)

# The full outer join keeps unmatched gold rows (no `entity`) and unmatched
# predictions (no `chunk`) next to the matched pairs.
df_join = (
    df.join(df_results, match_condition, how="outer")
    .drop("text")
    .orderBy(asc_nulls_last(df.doc_id), asc_nulls_last(df.begin))
)
display(df_join)

In [0]:
import spacy

corpus = "\n".join(text_dict["text"])
# Size of the population used to derive negatives in the confusion matrix.
# len(corpus) would be a *character* count (plus separator newlines), which mixes
# units with the annotation-level positives and inflates TN / specificity / NPV.
# Count tokens with spaCy's rule-based English tokenizer instead; spacy.blank
# needs no trained model.
# NOTE(review): gold chunks may span several tokens, so units still differ
# slightly; a token-level gold labelling would be exact.
_tokenizer = spacy.blank("en").tokenizer
all_tokens = sum(len(_tokenizer(text)) for text in text_dict["text"])

In [0]:
def _ratio(num: float, den: float) -> float:
    """Return num / den, or NaN when den is 0 (e.g. no predictions at all)."""
    return num / den if den else float("nan")


pos_actual = df.count()        # gold annotation rows
pos_pred = df_results.count()  # predicted entities
# Matched pairs: join rows where both the gold chunk and the prediction are present.
tp = df_join.where(col("chunk").isNotNull() & col("entity").isNotNull()).count()
fp = pos_pred - tp
fn = pos_actual - tp

neg_actual = all_tokens - pos_actual
tn = neg_actual - fp
neg_pred = tn + fn

# tp counts join rows. If one gold chunk matched several predictions (or vice
# versa), tp could exceed a side's total and fp/fn would go negative.
assert 0 <= tp <= min(pos_actual, pos_pred), (tp, pos_actual, pos_pred)

recall = _ratio(tp, pos_actual)
precision = _ratio(tp, pos_pred)
specificity = _ratio(tn, neg_actual)
npv = _ratio(tn, neg_pred)

pd.Series(
    {
        "neg_actual": neg_actual,
        "pos_actual": pos_actual,
        "neg_pred": neg_pred,
        "pos_pred": pos_pred,
        "tn": tn,
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "recall": recall,
        "precision": precision,
        "specificity": specificity,
        "npv": npv,
    },
    dtype=object,
)

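**Confusion matrix recorded for this run (rows = predicted, columns = actual).**
In this run, negatives were derived from `len(corpus)`, a *character* count rather than a token count. TN, specificity (≈ 0.996) and NPV (≈ 0.997) are therefore inflated. Recall = 707 / 1479 ≈ 0.478 and precision = 707 / 1597 ≈ 0.443 do not depend on the negative count.

| Predicted \ Actual | Neg_actual | Pos_actual | Total  |
|--------------------|------------|------------|--------|
| Neg_pred           | 249546     | 772        | 250318 |
| Pos_pred           | 890        | 707        | 1597   |
| Total              | 250436     | 1479       | 251915 |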