# 10 - Reducing data loss

This notebook:

- documents further assessment of data loss and the performance of the NER component

- demonstrates how improvements can be achieved and

- provides direction for future work.


In the same folder, notebook `03_evaluate_data_loss` evaluates the loss of identifiers in the Synthea-LLM-NER pipeline. Further investigation showed that both the LLM's inclusion of personal identifiers in the generated notes and their subsequent detection by the NER can be improved significantly.

This notebook documents one such successful attempt and assumes you have read the previous notebooks. It covers:

- generating a new dataset from the same Synthea dataset as the previous notebooks, but with a different LLM prompt
- light pre-processing of the Synthea data
- evaluating data loss between the Synthea, LLM and NER steps
- assessing the effect on the privacy risk score


Future work in the repository source code could include adding `age` as a field in the `Record` class and applying light pre-processing to the Synthea data before it is passed to the LLM.

In [None]:
import json
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
from fuzzywuzzy import fuzz

import privacy_fingerprint.extract.aws_comprehend as aws
import privacy_fingerprint.generate.language_model as llm
import privacy_fingerprint.generate.synthea as synthea
from privacy_fingerprint.common import compare_common_records
from privacy_fingerprint.common.config import (
    load_experiment_config,
    load_experiment_config_from_file,
    load_global_config_from_file,
)
from privacy_fingerprint.score import PrivacyRiskScorer, encode, preprocess

## Load configs, change the LLM prompt and create output dir


In [None]:
prompt = (
    "You are an InstructGPT. Describe this patient as if you were a medical doctor and "
    "include in your answer the provided date of birth, age and the "
    "NHS number of the patient."
)

print(prompt)

In [None]:
global_config = load_global_config_from_file("../global_config.yaml")
experiment_config = load_experiment_config_from_file(
    "../experiment_config.yaml"
)

In [None]:
# This is how we change the LLM prompt in a notebook.
# This change could also happen in the config file.
expt_config = load_experiment_config()
expt_config.openai.prompt = prompt
expt_config.synthea.encounter_type = (
    "Encounter for symptom"  # as it had been created in the original dataset
)
load_experiment_config(expt_config.dict())

In [None]:
# The outputs of this notebook will be saved to a directory
output_dir = "../../experiments/10_improve_data_loss"
os.makedirs(output_dir, exist_ok=True)

## Load previously generated Synthea records

Replace `synthea_dir` below with the directory where the Synthea output was saved.

In [None]:
# We use a previously generated set of records, which can be loaded as follows:
synthea_dir = "<...>"

with open(os.path.join(synthea_dir, "synthea_dataset.json")) as fp:
    synthea_records = json.load(fp)

## Processing before feeding the synthea data to the LLM

Here we keep only the date part of the ISO-formatted visit date (format `YYYY-MM-DD`), calculate the patient's age at the time of the visit from the visit date and the date of birth, and remove the `visit type` field, which is needed when creating the Synthea records but adds no information for the LLM.

In [None]:
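for record in synthea_records:
    visit_date = pd.to_datetime(record["visit date"], errors="coerce")
    birth_date = pd.to_datetime(record["date of birth"], errors="coerce")
    # Keep only the date part of the ISO timestamp ("YYYY-MM-DD")
    record["visit date"] = str(visit_date.date())
    # Age in completed years at the time of the visit:
    # subtract one if the birthday had not yet occurred in the visit year
    record["age"] = (
        visit_date.year
        - birth_date.year
        - ((visit_date.month, visit_date.day) < (birth_date.month, birth_date.day))
    )
    # 'visit type' is needed to generate the Synthea records but adds nothing for the LLM
    del record["visit type"]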
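The same pre-processing can be sketched on a single record with the standard library only, which makes the age calculation easy to check by hand. This is a minimal sketch, not the repository's code: `preprocess_record` is a hypothetical helper, and it assumes `visit date` is an ISO 8601 timestamp (possibly ending in `Z`) and `date of birth` is a plain `YYYY-MM-DD` string.

```python
from datetime import date, datetime


def preprocess_record(record):
    """Hypothetical helper: return a copy of a Synthea-style record prepared for the LLM.

    Assumes 'visit date' is an ISO 8601 timestamp and 'date of birth' is YYYY-MM-DD.
    """
    out = dict(record)
    # datetime.fromisoformat before Python 3.11 does not accept a trailing "Z"
    visit = datetime.fromisoformat(out["visit date"].replace("Z", "+00:00")).date()
    birth = date.fromisoformat(out["date of birth"])
    out["visit date"] = visit.isoformat()  # keep only YYYY-MM-DD
    # Age in completed years: subtract one if the birthday has not yet occurred that year
    out["age"] = (
        visit.year - birth.year - ((visit.month, visit.day) < (birth.month, birth.day))
    )
    out.pop("visit type", None)  # not useful to the LLM
    return out


example = {
    "visit date": "2020-03-15T09:30:00Z",
    "date of birth": "1980-06-01",
    "visit type": "Encounter for symptom",
}
print(preprocess_record(example))
```

For this toy record the age is 39, because the visit falls before the June birthday; a plain difference of years would give 40.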

## Create free-text clinical notes with the LLM

We pass the LLM the prompt defined in this notebook, together with the processed Synthea records.

In [None]:
clinical_note_generator = llm.LMGenerator()

In [None]:
llm_results = clinical_note_generator.generate_text(synthea_records)

In [None]:
llm_results = list(llm_results)
print(*llm_results[:5], sep="\n\n------------------")

In [None]:
# code to save the generated LLM results
with open(os.path.join(output_dir, "llm_dataset.json"), "w") as fp:
    json.dump(llm_results, fp)

## Extract data from the unstructured text using the NER

We then perform the "reverse" step, using an NER service (AWS Comprehend Medical) to extract again the information we injected into the unstructured records. This is the most expensive step of the process, so a helper function is provided to estimate the cost, based on AWS prices as of March 10th 2023. Updated prices can be found in the AWS documentation.

In [None]:
print("Estimated cost is $", aws.calculate_ner_cost(llm_results))
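The kind of calculation behind such an estimate can be sketched as follows. This is not the implementation of `aws.calculate_ner_cost`; it assumes the service bills per unit of characters, rounds partial units up and charges a minimum number of units per request. The unit size, the minimum and especially the price are placeholders to be replaced with the figures from the current AWS pricing page.

```python
import math


def estimate_ner_cost(notes, price_per_unit, chars_per_unit=100, min_units=1):
    """Hypothetical estimate of the cost of sending each note to a per-character-billed NER.

    `price_per_unit`, `chars_per_unit` and `min_units` are assumptions, not AWS figures.
    """
    total_units = 0
    for note in notes:
        units = math.ceil(len(note) / chars_per_unit)  # partial units are rounded up
        total_units += max(units, min_units)  # each request is charged at least min_units
    return total_units * price_per_unit


notes = ["a" * 250, "b" * 40]
print(estimate_ner_cost(notes, price_per_unit=0.01))  # illustrative price only
```

With 100-character units, the 250-character note costs 3 units and the 40-character note 1 unit, so the total is 4 units.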

In [None]:
aws_extract = aws.ComprehendExtractor()
ner_records = [
    aws_extract.extract_record(medical_note) for medical_note in llm_results
]

The result is a list of dictionaries, one per clinical note, containing the extracted entities. For each entity, its type, text span and the NER's confidence score can be inspected.
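To inspect the entities across all notes at once, the nested results can be flattened into a table with one row per entity. The sketch below is a hypothetical helper; it assumes each record has an `Entities` list whose items carry `Type`, `Text`, `BeginOffset`, `EndOffset` and `Score` keys, as in Comprehend Medical responses, and the toy record is made up for illustration.

```python
import pandas as pd


def entities_to_frame(ner_records):
    """Hypothetical helper: flatten NER results into one row per extracted entity."""
    rows = []
    for record_idx, record in enumerate(ner_records):
        for entity in record.get("Entities", []):
            rows.append(
                {
                    "record": record_idx,  # position of the note in the input list
                    "type": entity.get("Type"),
                    "text": entity.get("Text"),
                    "begin": entity.get("BeginOffset"),
                    "end": entity.get("EndOffset"),
                    "score": entity.get("Score"),
                }
            )
    return pd.DataFrame(rows, columns=["record", "type", "text", "begin", "end", "score"])


# Made-up NER output for a single note, mimicking the Comprehend Medical response shape
toy = [
    {
        "Entities": [
            {"Type": "AGE", "Text": "45", "BeginOffset": 10, "EndOffset": 12, "Score": 0.99},
            {"Type": "ID", "Text": "123 456 7890", "BeginOffset": 40, "EndOffset": 52, "Score": 0.87},
        ]
    }
]
entities = entities_to_frame(toy)
print(entities)
print(entities.groupby("type")["score"].mean())  # mean confidence per entity type
```

On the real data this would be called as `entities_to_frame(ner_records)`.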

In [None]:
# code to save the extracted ner results
with open(os.path.join(output_dir, "ner_dataset.json"), "w") as fp:
    json.dump(ner_records, fp)

## Convert to common format and compare

In order to compare the Synthea records with the extracted NER records, we need to standardise their format.

In this experiment we also injected the patient's `age`, which is not included in the common format, so we compare it separately.

Further work could include adding `age` to the common format.

In [None]:
common_synthea_results = synthea.prepare_common_records(
    synthea.DEFAULT_IDENTIFIERS, synthea_records
)

In [None]:
common_ner_results = aws.prepare_common_records(
    aws.DEFAULT_IDENTIFIERS, ner_records
)

In [None]:
record_comparison_summary = []
for s, n in zip(common_synthea_results, common_ner_results):
    overall_score, max_score, summary = compare_common_records(s, n)
    record_comparison_summary.append(summary)

record_comparison_summary = pd.DataFrame(record_comparison_summary)
record_comparison_summary.plot.box(rot=90, ylabel="Data recovery (%)")
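The box plot can be complemented with a small table of per-field statistics. This sketch assumes the comparison summary holds numeric scores on the 0-100 scale shown on the y-axis, one column per field; `summarise_recovery` is a hypothetical helper and the toy scores are made up. Missing values count as not fully recovered.

```python
import pandas as pd


def summarise_recovery(summary):
    """Hypothetical helper: per-field mean recovery and share of records scoring 100."""
    return pd.DataFrame(
        {
            "mean_recovery": summary.mean(),
            "fully_recovered": (summary == 100).mean(),
        }
    ).sort_values("mean_recovery")  # worst-recovered fields first


# Made-up scores on the 0-100 scale used in the box plot
toy_summary = pd.DataFrame({"name": [100, 80, 100], "nhs_number": [0, 100, 100]})
print(summarise_recovery(toy_summary))
```

On the real data this would be `summarise_recovery(record_comparison_summary)`.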

### Compare the `age` field

This field is not currently in the common format, so it is compared separately. In the future this field should be included in the common Record format, so this step would be redundant. 

In [None]:
# Extract the age from the NER output: the text of the first AGE entity, if any
def _extract_entity_type(record, entity_type):
    return [i for i in record if i["Type"] == entity_type]


def extract_age(record):
    candidates = _extract_entity_type(record, "AGE")
    if len(candidates) == 0:
        return None
    return candidates[0]["Text"]


# Extract the age from the pre-processed Synthea record
def extract_age_synthea(record):
    return record["age"]

In [None]:
data = []
for s, n in zip(synthea_records, ner_records):
    age_synthea = extract_age_synthea(s)
    age_ner = extract_age(n["Entities"])
    data.append([age_synthea, age_ner])

In [None]:
age_df = pd.DataFrame(
    data,
    columns=["age_synthea", "age_ner"],
)
age_df["age"] = age_df.apply(
    lambda row: fuzz.ratio(str(row["age_synthea"]), str(row["age_ner"])),
    axis=1,
)
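`fuzz.ratio` compares raw strings, so an NER span such as "45 years old" scores poorly against the Synthea value 45 even though the age was recovered. One alternative, sketched here with hypothetical helpers, is to pull the first integer out of each value and score an exact match as 100. It ignores units, so "6 months" would be read as 6.

```python
import re


def normalise_age(value):
    """Hypothetical helper: return the first integer found in an age value, or None."""
    if value is None:
        return None
    match = re.search(r"\d+", str(value))
    return int(match.group()) if match else None


def age_match(expected, extracted):
    """Hypothetical helper: 100 if both ages parse and are equal, otherwise 0."""
    expected_value = normalise_age(expected)
    extracted_value = normalise_age(extracted)
    if expected_value is None or extracted_value is None:
        return 0
    return 100 if expected_value == extracted_value else 0


print(normalise_age("45-year-old"), age_match(45, "45 years old"), age_match(45, None))
```

It could replace the fuzzy score with, for example, `age_df.apply(lambda row: age_match(row["age_synthea"], row["age_ner"]), axis=1)`.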

In [None]:
record_comparison_w_age = record_comparison_summary.copy(deep=True)
record_comparison_w_age["age"] = age_df["age"]

In [None]:
cols = [
    "nhs_number",
    "name",
    "age",
    "date_of_birth",
    "gender",
    "ethnicity",
    "disease",
    "date_of_visit",
    "department",
    "treatment",
    "prescription",
    "provider",
]

In [None]:
record_comparison_w_age[cols].plot.box(rot=90, ylabel="Data recovery (%)")

## Calculate the privacy risk scores of both the Synthea (initial) and NER (extracted) records

In [None]:
def simplify_ethnicity(text):
    text = text.lower()
    if text == "":
        return "unknown"
    mentions = defaultdict(int)
    for ethnicity, label in [
        ("white", "white"),
        ("black", "black"),
        ("african", "black"),
        ("asian", "asian"),
        ("indian", "asian"),
        ("pakistani", "asian"),
        ("chinese", "asian"),
    ]:
        if ethnicity in text:
            mentions[label] += 1
    if len(mentions) > 1:
        return "mixed"
    elif len(mentions) == 1:
        return list(mentions.keys())[0]
    else:
        return "unknown"


def simplify_date_of_birth(date):
    dt = pd.to_datetime(date, errors="coerce")
    if pd.isnull(dt):
        return None
    else:
        return 10 * (dt.year // 10)


transformations = {
    "gender": lambda x: x.lower()
    if x.lower() in ["female", "male"]
    else "unknown",
    "ethnicity": simplify_ethnicity,
    "date_of_birth": simplify_date_of_birth,
}

cols = [
    "date_of_birth",
    "gender",
    "ethnicity",
    "disease",
    "symptoms",
    "treatment",
    "prescriptions",
]

In [None]:
# calculate privacy risk score for ner records
pcm_dataset = preprocess(common_ner_results)

encoded_dataset, lookup = encode(
    pcm_dataset[cols].transform(
        {i: transformations.get(i, lambda x: x) for i in cols}
    )
)
scorer = PrivacyRiskScorer()
population_score = scorer.calculate_population_uniqueness(encoded_dataset)
print(population_score)
scorer.fit(encoded_dataset)
e2e = {
    "population_score": population_score,
    "individual_scores": scorer.predict(encoded_dataset),
}

In [None]:
# calculate privacy risk score for synthea records
synthea_pcm_dataset = preprocess(common_synthea_results)

encoded_dataset, lookup = encode(
    synthea_pcm_dataset[cols].transform(
        {i: transformations.get(i, lambda x: x) for i in cols}
    )
)
scorer = PrivacyRiskScorer()
population_score = scorer.calculate_population_uniqueness(encoded_dataset)
print(population_score)
scorer.fit(encoded_dataset)
initial_records = {
    "population_score": population_score,
    "individual_scores": scorer.predict(encoded_dataset),
}

In [None]:
print(
    "Population uniqueness on initial records",
    initial_records["population_score"],
)
print("Population uniqueness on extracted records", e2e["population_score"])

fig, ax = plt.subplots(1, 1)
ax.plot(initial_records["individual_scores"], e2e["individual_scores"], "k.")
ax.set_xlabel("Initial structured records")
ax.set_ylabel("NER extracted records")
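The population uniqueness reported by `PrivacyRiskScorer` is a model-based estimate. A simpler quantity, shown here only as a rough point of comparison and not as what the scorer computes, is sample uniqueness: the fraction of records whose combination of quasi-identifiers appears exactly once in the dataset. The toy example also shows why coarsening fields, as `simplify_date_of_birth` does, lowers uniqueness. `sample_uniqueness` is a hypothetical helper and assumes hashable column values.

```python
import pandas as pd


def sample_uniqueness(df, columns):
    """Hypothetical helper: fraction of rows whose values in `columns` occur exactly once."""
    group_sizes = df.groupby(columns, dropna=False).size()
    return float((group_sizes == 1).sum() / len(df))


# Made-up records: decade of birth and gender
toy = pd.DataFrame(
    {
        "date_of_birth": [1980, 1980, 1990, 2000],
        "gender": ["female", "female", "male", "male"],
    }
)
print(sample_uniqueness(toy, ["date_of_birth", "gender"]))  # two of four rows are unique
print(sample_uniqueness(toy, ["gender"]))  # dropping date of birth leaves no unique rows
```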

In [None]:
comparison = pd.DataFrame(
    {
        "initial": initial_records["individual_scores"],
        "extract": e2e["individual_scores"],
    }
)
comparison["difference"] = (comparison.initial - comparison.extract).abs()
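Besides the per-record differences, a single number summarising how well the ordering of records by risk survives extraction is the Spearman rank correlation. The sketch computes it as the Pearson correlation of the ranks, which avoids depending on SciPy; `rank_agreement` is a hypothetical helper and the toy scores are made up.

```python
import pandas as pd


def rank_agreement(initial, extracted):
    """Hypothetical helper: Spearman correlation as the Pearson correlation of ranks."""
    initial = pd.Series(initial, dtype=float).reset_index(drop=True)
    extracted = pd.Series(extracted, dtype=float).reset_index(drop=True)
    # Ties receive their average rank, as in the usual Spearman definition
    return initial.rank().corr(extracted.rank())


print(rank_agreement([0.1, 0.4, 0.2, 0.9], [0.2, 0.5, 0.3, 0.8]))  # same ordering
print(rank_agreement([0.1, 0.4, 0.2, 0.9], [0.9, 0.2, 0.4, 0.1]))  # reversed ordering
```

A value close to 1 means the extracted records rank patients by risk much as the initial records do; on the real data this would be `rank_agreement(comparison.initial, comparison.extract)`.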

In [None]:
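# Compare the ordering of records by privacy risk in the Synthea and extracted datasets.
# Records are taken in order of their score in `a`; for each one, count how many of the
# records with a lower `a` score have a `b` rank no greater than this record's `a` rank.
# Perfect agreement between the two orderings gives the line y = x.
def compare_scores(a, b, label, ax=None, color=None):
    assert len(a) == len(b), "Lengths must match"
    if ax is None:
        _, ax = plt.subplots(1, 1)
    c = pd.DataFrame({"a": a, "b": b})
    # Stable sorts keep tied records in a consistent order across both rankings
    c = c.sort_values("b", kind="mergesort")
    c["b_rank"] = range(1, 1 + len(a))
    c = c.sort_values("a", kind="mergesort")
    c["a_rank"] = range(1, 1 + len(a))
    n_agreeing = []
    for i in range(len(a)):
        n_agreeing.append((c.iloc[:i].b_rank <= c.iloc[i].a_rank).sum())
    if color:
        ax.plot(n_agreeing, label=label, color=color)
    else:
        ax.plot(n_agreeing, label=label)
    return ax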


ax = compare_scores(
    comparison.initial.tolist(),
    comparison.initial.tolist(),
    "Identity",
    color="#555555",
)
ax = compare_scores(
    comparison.initial.tolist(),
    comparison.extract.tolist(),
    "Extract",
    ax=ax,
    color="#c10078",
)

ax = compare_scores(
    comparison.initial.tolist(),
    comparison.initial.sample(frac=1).tolist(),
    "Random",
    ax=ax,
    color="#cccccc",
)
ax.legend()
ax.set_xlabel("Ranked scores from Synthea records")
ax.set_ylabel("Agreement following NER extraction")