# Detecting Substitution Errors

In [0]:
%load_ext autoreload
%autoreload 1
%aimport data.adress
%aimport utils

In [0]:
import sys
sys.path.append("..")
import numpy as np
import pandas as pd
import re
from pprint import pprint
# Evaluation
from utils import evaluate
import seaborn as sns
from sklearn.metrics import roc_auc_score
from scipy.stats import ttest_ind, mannwhitneyu
# Generative AI
from openai import OpenAI
from utils import llm_call
import mlflow
from mlflow.genai.scorers import Safety, scorer
from mlflow.entities import Feedback

## Load the data

In [0]:
from data.adress import load_transcripts

In [0]:
adress_trans = load_transcripts()
adress_trans = adress_trans[["Speaker", "Transcript", "Transcript_clean", "Phonological Error", "Semantic Error", "Neologistic Error", "Morphological Error", "Dysfluency", "Substitution Error"]]

# Count annotated paraphasia codes per utterance: one column per error type plus a total.
# NOTE(review): `\[*` means "zero or more '['", not the literal two-character prefix "[*", so
# these patterns effectively match " <code>]" anywhere in the transcript. If the annotation
# prefix is "[*", the intended pattern is probably `\[\*\s...` — confirm against the data.
ERROR_CODE_PATTERNS = {
    "num_word_errors": r"\[*\s[a-z:]+\]",
    "num_phonological": r"\[*\sp[:a-z]+\]",
    "num_semantic": r"\[*\ss[:a-z]+\]",
    "num_neologistic": r"\[*\sn[:a-z]+\]",
    "num_morphological": r"\[*\sm[:a-z]+\]",
    "num_dysfluency": r"\[*\sd[:a-z]+\]",
}
# `str.count(pattern)` equals len(re.findall(pattern, text)) but yields NaN instead of raising
# on missing transcripts; `assign` returns a new frame rather than inserting into the
# column-subset selection above.
adress_trans = adress_trans.assign(
    **{col: adress_trans["Transcript"].str.count(pattern) for col, pattern in ERROR_CODE_PATTERNS.items()}
)

adress_trans.head()

In [0]:
# Boolean row masks over adress_trans: patient (not provider) utterances of each split.
trn_pt_utt_idx, dev_pt_utt_idx, tst_pt_utt_idx = (
    (adress_trans.index.get_level_values("split") == split_name) & (adress_trans["Speaker"] == "Patient")
    for split_name in ("train", "dev", "test")
)

In [0]:
# Unique patient IDs in each split, in order of first appearance.
trn_pts, dev_pts, tst_pts = (
    adress_trans.loc[split_name].index.get_level_values("ID").unique().values
    for split_name in ("train", "dev", "test")
)

## Baseline: MLM-Based Detector

In [0]:
import json
import os

In [0]:
# Load the MLM detector outputs: one JSON file per threshold, named "<prefix>_<percentile>.json".
# Each file holds a list of detection dicts (read below: "type", "ID", "utt_num"). Patient
# utterances with at least one "paraphasia" detection get {"detections": [...]} in column
# f"mlm_det_{percentile}"; all other rows stay NA.

# Patient ID -> split, built once instead of a linear search per detection. setdefault keeps
# the first split listing an ID (train, then dev, then test), as the linear search did.
id_to_split = {}
for split_name, pt_list in (("train", trn_pts), ("dev", dev_pts), ("test", tst_pts)):
    for pt_id in pt_list:
        id_to_split.setdefault(pt_id, split_name)

# sorted(): os.listdir order is arbitrary, which made the column order nondeterministic.
for output_json in sorted(os.listdir("mlm_outputs")):
    stem, ext = os.path.splitext(output_json)
    if ext != ".json":
        continue  # skip stray entries (e.g. .ipynb_checkpoints) that are not detector outputs
    # splitext removes exactly the extension; str.strip(".json") strips a *character set*.
    percentile = stem.split("_")[-1]
    col = f"mlm_det_{percentile}"

    with open(os.path.join("mlm_outputs", output_json), "r") as f:
        detections = json.load(f)

    adress_trans.loc[:, col] = pd.NA

    for det in detections:
        if det["type"] != "paraphasia":
            continue
        key = (id_to_split[det["ID"]], det["ID"], det["utt_num"])
        if adress_trans.loc[key, "Speaker"] != "Patient":
            continue  # detections on provider speech are ignored

        cur_dets = adress_trans.loc[key, col]
        if pd.isna(cur_dets):
            cur_dets = {"detections": []}
        cur_dets["detections"].append(det)

        adress_trans.at[key, col] = cur_dets

Utterance-level performance on the training split for each MLM percentile threshold

In [0]:
# Utterance-level metrics on the training split for each MLM threshold
# (925 presumably denotes the 92.5th percentile). An utterance is predicted positive when
# it carries at least one detection.
for percentile in (90, 925, 95, 98, 99):
    true = adress_trans.loc[trn_pt_utt_idx, "Substitution Error"]
    pred = (
        adress_trans.loc[trn_pt_utt_idx, f"mlm_det_{percentile}"]
        .apply(lambda dets: False if pd.isna(dets) else len(dets["detections"]) > 0)
        .astype(int)
    )
    print(percentile, "&", evaluate(true, pred, return_latex=True))

Evaluate the best-performing threshold (90th percentile) on the test split

In [0]:
# Held-out test metrics for the threshold chosen on train (mlm_det_90).
true = adress_trans.loc[tst_pt_utt_idx, "Substitution Error"]
pred = (
    adress_trans.loc[tst_pt_utt_idx, "mlm_det_90"]
    .apply(lambda dets: False if pd.isna(dets) else len(dets["detections"]) > 0)
    .astype(int)
)
print(evaluate(true, pred, return_latex=True))

## LLM-Based Detector

In [0]:
import json

In [0]:
# OpenAI-compatible client for the Databricks model-serving endpoints, authenticated with
# the Databricks host credentials obtained through MLflow (no token stored in the notebook).
mlflow_creds = mlflow.utils.databricks_utils.get_databricks_host_creds()

client = OpenAI(
    base_url=f"{mlflow_creds.host}/serving-endpoints",
    api_key=mlflow_creds.token,
)

In [0]:
# One MLflow evaluation record per (split, patient). The input is the whole conversation as
# "<utt_num>: [<Speaker>] <clean text>" lines. The expectations map each *patient* utterance
# number to its ground-truth "Substitution Error" label; provider turns are context only.
datasets = {
    "train": [],
    "dev": [],
    "test": []
}

for (split, pt_id), grp in adress_trans.groupby(level=["split", "ID"]):
    transcript = "\n".join((grp.index.get_level_values("utt_num").astype(str) + ": [" + grp["Speaker"] + "] " + grp["Transcript_clean"]).values)
    datasets[split].append({
        "split": split,  # was hard-coded "train", mislabelling every dev/test record
        "pt_id": pt_id,
        "inputs": {"text": transcript},
        "expectations": {"has_error": {utt_num: row["Substitution Error"] for (_, _, utt_num), row in grp.iterrows() if row["Speaker"] == "Patient"}}
    })

# pprint(datasets["train"])

In [0]:
@scorer
def correct(expectations, outputs):
    """Utterance-level accuracy (0-100) of the LLM's detections for one patient.

    Args:
        expectations: record expectations; ``expectations["has_error"]`` maps each patient
            utterance number to its ground-truth label (truthy = has a substitution error).
        outputs: parsed model output with a ``"detections"`` list whose items carry ``"utt_num"``.

    Returns:
        Feedback whose value is the percentage of labelled utterances whose detected /
        not-detected status matches the label.
    """
    det_err = expectations["has_error"]
    if not det_err:
        # No patient utterances to score: nothing can be wrong (avoids ZeroDivisionError).
        return Feedback(value=100.0)

    # Compare utterance numbers as strings: the LLM may emit 3 or "3", and the label keys may
    # be ints or strings depending on serialization. A raw int-vs-str mismatch previously
    # counted correct detections as misses.
    det_utts = {str(det["utt_num"]) for det in outputs["detections"]}

    n_correct = sum(1 for utt_id, has_err in det_err.items() if (str(utt_id) in det_utts) == bool(has_err))
    return Feedback(value=100 * n_correct / len(det_err))

In [0]:
def process_mlflow_outputs(result):
    """Extract model responses and ground-truth labels from an MLflow GenAI evaluation result.

    Args:
        result: return value of ``mlflow.genai.evaluate``. Its ``tables["eval_results"]`` frame
            must have ``response`` and ``assessments`` columns; each ``assessments`` entry is a
            list of dicts with ``name``/``value`` keys, one of them named ``"has_error"`` whose
            value is a JSON string.

    Returns:
        A new DataFrame with ``response``, ``assessments`` and ``labels`` (the JSON-decoded
        ``has_error`` value, i.e. a dict utterance number -> label).

    Raises:
        IndexError: if a row has no ``has_error`` assessment.
    """
    # .copy(): without it, adding "labels" to this column subset of the MLflow table is a
    # chained assignment on a slice of the caller's table (SettingWithCopyWarning).
    outputs = result.tables["eval_results"][["response", "assessments"]].copy()
    outputs["labels"] = outputs["assessments"].apply(
        lambda assessments: json.loads([a["value"] for a in assessments if a["name"] == "has_error"][0])
    )
    return outputs

def _utt_num_as_int(value):
    """Return ``value`` as an int utterance number, or None when it cannot be parsed."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _detected_utt_nums(response):
    """Unique utterance numbers (as ints) reported in one LLM response's detections."""
    nums = (_utt_num_as_int(det["utt_num"]) for det in response["detections"])
    return np.unique([n for n in nums if n is not None]).astype(int)


def extract_true_pred_from_ouputs(outputs):
    """Flatten per-patient labels and LLM detections into utterance-level true/pred lists.

    Args:
        outputs: DataFrame as returned by ``process_mlflow_outputs``. It needs a ``response``
            column (dict with a ``"detections"`` list whose items carry ``"utt_num"``) and a
            ``labels`` column (dict utterance number -> 0/1 ground truth).

    Returns:
        ``(true, pred)`` lists aligned utterance by utterance; ``pred`` is 1 when the model
        reported at least one detection in that utterance.

    Side effect: adds/overwrites ``outputs["utts_with_dets"]`` with the detected utterance numbers.
    """
    # Utterance numbers are normalised to int before matching. The LLM may emit 3 or "3"; a
    # string (or mixed) list made np.unique return strings, so `int in array` silently failed.
    outputs["utts_with_dets"] = outputs["response"].apply(_detected_utt_nums)

    true, pred = [], []
    for _, row in outputs.iterrows():
        detected = set(row["utts_with_dets"].tolist())
        for utt_num, label in row["labels"].items():
            true.append(label)
            pred.append(1 if int(utt_num) in detected else 0)

    return true, pred


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
extract_true_pred_from_outputs = extract_true_pred_from_ouputs

In [0]:
# Print, untruncated, the raw transcripts of dev-split utterances annotated with a phonological error.
with pd.option_context("display.max_colwidth", None):
    print(
        adress_trans.loc[
            (adress_trans["Phonological Error"] == 1)
            & (adress_trans.index.get_level_values(0) == "dev"),
            ["Transcript"],
        ]
    )

In [0]:
# Print, untruncated, the full conversation of dev patient S082 (speaker + cleaned text per utterance).
with pd.option_context("display.max_colwidth", None):
    print(adress_trans.xs(("dev", "S082")).loc[:, ["Speaker", "Transcript_clean"]])

#### Explore different prompts

In [0]:
# version = "1_3"
# prompt = '''# INSTRUCTIONS
# You are a neurologist analyzing a patient's speech sample for signs of cognitive impairment. 

# Your task is to identify all substitution errors in a patient's speech provided in the input below.

# ### Definition
# Substitution errors occur when a person involuntarily replaces their intended word with an unintended word while speaking. Focus on detecting the following five substitution errors types:
# - Phonemic paraphasias, where sounds within the intended word are added, dropped, substituted, or rearranged (e.g., saying ``papple'' for ``apple''). This does **not** include words with dropped sounds if the usage represents a common, informal speaking style rather than a clinical error (e.g., "gettin" for "getting").
# - Semantic paraphasias, where the intended word is substituted entirely with another real word (e.g., saying ``cat'' for ``dog'').
# - Neologisms, where the entire intended word is substituted with a non-word (e.g., saying "foundament" for "foundation").
# - Morphological errors, where the intended word is used in the incorrect form, such as the wrong number (e.g., saying "child" for "children") or tense (e.g., saying "walked" for "walk").
# - Intra-word dysfluencies, where the production of the intended word is disrupted by an inserted sound (e.g., saying "beuhcause" for "because").

# Only flag single words that are clinically significant substitution errors and use the surrounding utterances to better understand the context of any given word. 

# ### Output Format
# Your output must be a single JSON object with a single key "detections" whose value is an array of JSON objects. Each object in the array represents one detected substitution error and must have the following keys-value pairs:
# - "type": "substitution error".
# - "utt_num": The number of the utterance in which the error occurs.
# - "text": The verbatim text of the substition error.
# - "span": The character span for the "text" in the "utt_num"-th utterance.
# - "justification": A brief explanation of why the "text" is a substitution error within its specific context.

# # INPUT
# {input_text}
# '''

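# NOTE(review): the few-shot examples in this disabled prompt should be fixed before reuse:
# the "justification" strings are never closed, `"span": [,]` and the trailing commas are not
# valid JSON, the second example reports utt_num 1 although "imarriage" is in utterance 2,
# and the "1. Patient: ..." line format differs from the "1: [Patient] ..." format fed to the model.
# version = "1_5_fewshot"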
# fs_prompt = '''# INSTRUCTIONS
# You are a neurologist analyzing a patient's speech sample for signs of cognitive impairment. 

# Your task is to identify all substitution errors in a patient's speech provided in the input below.

# ### Definition
# Substitution errors occur when a person involuntarily replaces their intended word with an unintended word while speaking. Focus on detecting the following five substitution errors types:
# - Phonemic paraphasias, where sounds within the intended word are added, dropped, substituted, or rearranged (e.g., saying ``papple'' for ``apple''). 
# - Semantic paraphasias, where the intended word is substituted entirely with another real word (e.g., saying ``cat'' for ``dog'').
# - Neologisms, where the entire intended word is substituted with a non-word (e.g., saying "foundament" for "foundation").
# - Morphological errors, where the intended word is used in the incorrect form, such as the wrong number (e.g., saying "child" for "children") or tense (e.g., saying "walked" for "walk").
# - Intra-word dysfluencies, where the production of the intended word is disrupted by an inserted sound (e.g., saying "beuhcause" for "because").

# Only flag single words that are clinically significant substitution errors and use the surrounding utterances to better understand the context of any given word. **Focus on errors that seem unlikely to be caused by common, informal speaking patterns.**

# ### Output Format
# Your output must be a single JSON object with a single key "detections" whose value is an array of JSON objects. Each object in the array represents one detected substitution error and must have the following keys-value pairs:
# - "type": "substitution error".
# - "utt_num": The number of the utterance in which the error occurs.
# - "text": The verbatim text of the substition error.
# - "span": The character span for the "text" in the "utt_num"-th utterance.
# - "justification": A brief explanation of why the "text" is a substitution error within its specific context.

# # EXAMPLES
# **Input**
# 1. Patient: well there's a mother standing there uh uh washing the dishes an the sink is overspilling .
# 2. Patient: an uh the window's open .
# 3. Patient: and outside the window there's a walk with a c curved walk with a garden .

# **Correct Output**
# {{
#     "detections": [
#         {{"type": "substitution error", "utt_num": 1, "text": "overspilling", "span": [,], "justification": "The word 'overspilling' is a semantic paraphasia where the intended word was 'overflowing'.}},
#     ]
# }}

# **Input**
# 1. Provider: [laughs] well is there anything else that you can think of ?
# 2. Patient: but mostly uh is I I have uh not not so much trouble uh in I d it uh uh looking at a thing at it uh as um s um am an imarriage but not [silence] but not getting anything that you'll want s want [inaudible] .

# **Correct Output**
# {{
#     "detections": [
#         {{"type": "substitution error", "utt_num": 1, "text": "imarriage", "span": [,], "justification": "The word 'imarriage' is a phonemic paraphasia where the intended word was 'image'.}},
#     ]
# }}

# # INPUT
# {input_text}
# '''


# Active prompt (v1_6). It is the disabled v1_3 prompt above without the phonemic-paraphasia
# exclusion for informal speech (e.g. "gettin"). `version` is embedded in the MLflow run
# names and output pickle filenames below. `{input_text}` is filled via str.format, so any
# literal braces added to the prompt must be doubled.
# NOTE(review): the text has typos ("substition", "errors types", "keys-value"). They are kept
# verbatim on purpose: editing the string changes the model input, so any fix should ship
# as a new `version` to keep logged results reproducible.
version = "1_6"
prompt = '''# INSTRUCTIONS
You are a neurologist analyzing a patient's speech sample for signs of cognitive impairment. 

Your task is to identify all substitution errors in a patient's speech provided in the input below.

### Definition
Substitution errors occur when a person involuntarily replaces their intended word with an unintended word while speaking. Focus on detecting the following five substitution errors types:
- Phonemic paraphasias, where sounds within the intended word are added, dropped, substituted, or rearranged (e.g., saying ``papple'' for ``apple''). 
- Semantic paraphasias, where the intended word is substituted entirely with another real word (e.g., saying ``cat'' for ``dog'').
- Neologisms, where the entire intended word is substituted with a non-word (e.g., saying "foundament" for "foundation").
- Morphological errors, where the intended word is used in the incorrect form, such as the wrong number (e.g., saying "child" for "children") or tense (e.g., saying "walked" for "walk").
- Intra-word dysfluencies, where the production of the intended word is disrupted by an inserted sound (e.g., saying "beuhcause" for "because").

Only flag single words that are clinically significant substitution errors and use the surrounding utterances to better understand the context of any given word.

### Output Format
Your output must be a single JSON object with a single key "detections" whose value is an array of JSON objects. Each object in the array represents one detected substitution error and must have the following keys-value pairs:
- "type": "substitution error".
- "utt_num": The number of the utterance in which the error occurs.
- "text": The verbatim text of the substition error.
- "span": The character span for the "text" in the "utt_num"-th utterance.
- "justification": A brief explanation of why the "text" is a substitution error within its specific context.

# INPUT
{input_text}
'''

In [0]:
def fn(text):
    """MLflow predict_fn: ask GPT-4o (JSON mode) for substitution errors in one transcript."""
    return llm_call(client, "openai_gpt_4o", None, prompt.format(input_text=text), {"type": "json_object"})


# Score the current prompt version on the training patients.
with mlflow.start_run(run_name=f"gpt_trn_pV{version}"):
    gpt_dets_trn = mlflow.genai.evaluate(
        data=datasets["train"],
        scorers=[correct],
        predict_fn=fn,
    )

In [0]:
# Keep the raw GPT responses and decoded labels for the train split (pickle named by prompt
# version), then print utterance-level metrics as a LaTeX table row.
outputs = process_mlflow_outputs(gpt_dets_trn)
outputs.to_pickle(f"llm_outputs/sub_err_trn_gpt_pV{version}")

true, pred = extract_true_pred_from_ouputs(outputs)
print(evaluate(true, pred, return_latex=True))

0.112 & 0.714 & 0.194 & 0.790 & 0.754 \\
0.118 & 0.657 & 0.200 & 0.814 & 0.739 \\
0.121 & 0.600 & 0.201 & 0.831 & 0.720 \\
0.152 & 0.714 & 0.251 & 0.850 & 0.784 \\
0.152 & 0.714 & 0.251 & 0.850 & 0.784 \\
0.117 & 0.600 & 0.196 & 0.826 & 0.717 \\


##### Evaluate the best configuration

In [0]:
def fn(text):
    """MLflow predict_fn: ask GPT-4o (JSON mode) for substitution errors in one transcript."""
    return llm_call(client, "openai_gpt_4o", None, prompt.format(input_text=text), {"type": "json_object"})


# Final held-out evaluation of the chosen prompt version on the test patients.
with mlflow.start_run(run_name=f"gpt_eval_tst_pV{version}") as run:
    gpt_dets_tst = mlflow.genai.evaluate(
        data=datasets["test"],
        scorers=[correct],
        predict_fn=fn,
    )

In [0]:
# Keep the raw GPT responses and decoded labels for the test split (pickle named by prompt
# version), then print utterance-level metrics as a LaTeX table row.
outputs = process_mlflow_outputs(gpt_dets_tst)
outputs.to_pickle(f"llm_outputs/sub_err_tst_gpt_pV{version}")

true, pred = extract_true_pred_from_ouputs(outputs)
print(evaluate(true, pred, return_latex=True))

## Summary metrics for paraphasia detections

In [0]:
from data.adress import load_outcomes
from utils import create_custom_nlp

In [0]:
# Per-patient outcome table: has a "split" index level and the AD_dx diagnosis column
# (1 = AD, 0 = control, as used below).
outcomes = load_outcomes()
outcomes.head()

In [0]:
# Index tuples of training patients: all, AD (AD_dx == 1) and controls (AD_dx == 0).
# `split` in the query refers to the index level of that name.
trn_pts = outcomes.query("split == 'train'").index.values
trn_ad_pts = outcomes.query("split == 'train' and AD_dx == 1").index.values
trn_cn_pts = outcomes.query("split == 'train' and AD_dx == 0").index.values

In [0]:
# Index tuples of test patients: all, AD (AD_dx == 1) and controls (AD_dx == 0).
# `split` in the query refers to the index level of that name.
tst_pts = outcomes.query("split == 'test'").index.values
tst_ad_pts = outcomes.query("split == 'test' and AD_dx == 1").index.values
tst_cn_pts = outcomes.query("split == 'test' and AD_dx == 0").index.values

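**Paraphasia rate** (per patient) $= 100 \times \dfrac{\text{number of detected substitution errors}}{\text{number of words spoken by the patient}}$, where words exclude punctuation, whitespace, and silence/inaudible/event tags.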

In [0]:
from functools import lru_cache

nlp = create_custom_nlp()


@lru_cache(maxsize=None)
def _count_spoken_words(text):
    """Number of spaCy tokens in ``text`` that are words: not punctuation, whitespace, or
    the custom silence/inaudible/event tags. Cached per text; the cache is rebuilt whenever
    this cell (and thus ``nlp``) is re-run."""
    return sum(
        1 for token in nlp(text)
        if not (token.is_punct or token.is_space or token._.is_silence_tag or token._.is_inaudible_tag or token._.is_event_tag)
    )


def compute_paraphasia_rate(outputs):
    """Per-patient substitution-error rate: 100 * #detections / #patient words.

    Args:
        outputs: Series aligned with ``adress_trans`` (index levels split/ID/utt_num) holding
            either NA or a ``{"detections": [...]}`` dict per utterance.

    Returns:
        Series indexed by (split, ID). A patient with zero countable words gets inf (or NaN
        when they also have zero detections) from the division.
    """
    num = outputs.apply(lambda x: len(x["detections"]) if not pd.isna(x) else 0).groupby(level=("split", "ID")).sum()
    # The denominator does not depend on `outputs`; caching word counts per transcript avoids
    # re-running the spaCy pipeline over every patient utterance on each call.
    den = adress_trans.apply(
        lambda x: _count_spoken_words(x["Transcript_clean"]) if x["Speaker"] == "Patient" else 0, axis=1
    ).groupby(level=("split", "ID")).sum()
    return 100 * num / den

In [0]:
# Ground-truth (annotation-based) rates, currently disabled.
# NOTE(review): these reference adress_trans["num_paraphasias"], which is never created in this
# notebook; the all-types count column built after loading is "num_word_errors".
# outcomes["gt_paraphasia_rate"] = compute_paraphasia_rate(adress_trans["num_paraphasias"].apply(lambda x: {"detections": [0] * x}))
# outcomes["gt_phonological_rate"] = compute_paraphasia_rate(adress_trans["num_phonological"].apply(lambda x: {"detections": [0] * x}))
# outcomes["gt_semantic_rate"] = compute_paraphasia_rate(adress_trans["num_semantic"].apply(lambda x: {"detections": [0] * x}))
# outcomes["gt_neologistic_rate"] = compute_paraphasia_rate(adress_trans["num_neologistic"].apply(lambda x: {"detections": [0] * x}))
# outcomes["gt_morphological_rate"] = compute_paraphasia_rate(adress_trans["num_morphological"].apply(lambda x: {"detections": [0] * x}))
# outcomes["gt_dysfluency_rate"] = compute_paraphasia_rate(adress_trans["num_dysfluency"].apply(lambda x: {"detections": [0] * x}))
outcomes = outcomes.assign(
    # substitution errors per 100 patient words, from the 90th-percentile MLM detections
    mlm_sub_error_rate=compute_paraphasia_rate(adress_trans["mlm_det_90"]),
)

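**Inter-Paraphasia Distance** (ISED): for every patient utterance with at least two detected substitution errors, the number of words strictly between consecutive errors; each patient is summarized by the mean and standard deviation of these distances.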

In [0]:
from sklearn.preprocessing import MinMaxScaler
import pickle

In [0]:
def _utterance_error_word_idxs(row, output_name):
    """Sorted, de-duplicated word positions of one utterance's detections.

    Words are spaCy tokens of ``row["Transcript_clean"]`` excluding punctuation, whitespace and
    silence/inaudible/event tags. Each detection is located by its exact
    (text, span[0], span[1]) character span; a detection matching no word raises ValueError.
    """
    doc = nlp(row["Transcript_clean"])
    word_spans = [
        (token.text, token.idx, token.idx + len(token.text))
        for token in doc
        if not (token.is_punct or token.is_space or token._.is_silence_tag or token._.is_inaudible_tag or token._.is_event_tag)
    ]
    idxs = {word_spans.index((det["text"], det["span"][0], det["span"][1])) for det in row[output_name]["detections"]}
    # Detections are not guaranteed to be listed in reading order; unsorted positions gave
    # negative "distances", and a word detected twice gave -1.
    return sorted(idxs)


def inter_sub_error_dist(output_name):
    """Compute per-patient inter-substitution-error distance (ISED) features.

    For every patient utterance with at least two detections in ``adress_trans[output_name]``,
    the number of words strictly between consecutive detected errors is collected. Distances
    are never computed across utterances.

    Args:
        output_name: column of ``adress_trans`` holding NA or ``{"detections": [...]}`` per
            utterance; each detection has ``"text"`` and ``"span"`` ([start, end] character
            offsets into ``Transcript_clean``).

    Returns:
        DataFrame indexed like ``outcomes`` with columns:
        - ``mean_ISED`` / ``std_ISED`` (population std): NaN when the patient has no distance.
        - ``mean_ISED_norm``: min-max scaled with a scaler fitted on the train split only.
        - ``mean_ISED_imputed``: ``mean_ISED_norm`` with NaN -> 1.0, the top of the [0, 1] range.
        - ``std_ISED_imputed``: ``std_ISED`` with NaN -> 0.

    Side effects:
        Writes the fitted scaler to ``ISED_scaler.pkl`` (overwritten on every call).

    Raises:
        ValueError: if a detection's (text, start, end) matches no word token.
    """
    ISED_metrics = pd.DataFrame(index=outcomes.index, columns=["mean_ISED", "std_ISED", "mean_ISED_norm", "mean_ISED_imputed", "std_ISED_imputed"], dtype=float)

    for split, pt_id in ISED_metrics.index:
        ISEDs = []
        for _, row in adress_trans.loc[(split, pt_id)].iterrows():
            if row["Speaker"] != "Patient":
                continue  # skip provider speech
            dets = row[output_name]
            if pd.isna(dets) or len(dets["detections"]) < 2:
                continue  # a distance needs at least two errors in the utterance
            idxs = _utterance_error_word_idxs(row, output_name)
            ISEDs.extend(b - a - 1 for a, b in zip(idxs, idxs[1:]))

        # With no distances the cells stay NaN (as np.mean([]) would give, minus its
        # RuntimeWarning); they are imputed below.
        if ISEDs:
            ISED_metrics.loc[(split, pt_id), "mean_ISED"] = np.mean(ISEDs)
            ISED_metrics.loc[(split, pt_id), "std_ISED"] = np.std(ISEDs)

    # Fit on training patients only so dev/test are scaled with train statistics.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(ISED_metrics.loc["train", ["mean_ISED"]])
    with open("ISED_scaler.pkl", "wb") as f:
        pickle.dump(scaler, f)

    ISED_metrics["mean_ISED_norm"] = scaler.transform(ISED_metrics[["mean_ISED"]])
    ISED_metrics["mean_ISED_imputed"] = ISED_metrics["mean_ISED_norm"].fillna(1.0)
    ISED_metrics["std_ISED_imputed"] = ISED_metrics["std_ISED"].fillna(0)

    return ISED_metrics

In [0]:
# Attach the MLM-based ISED features (raw mean/std plus imputed versions) to the outcomes table.
# Columns are matched by position: mean_ISED -> mlm_mean_ISED, ..., std_ISED_imputed -> mlm_std_ISED_imp.
outcomes[["mlm_mean_ISED", "mlm_std_ISED", "mlm_mean_ISED_imp", "mlm_std_ISED_imp"]] = (
    inter_sub_error_dist("mlm_det_90")
    .loc[:, ["mean_ISED", "std_ISED", "mean_ISED_imputed", "std_ISED_imputed"]]
)

Test for statistically significant differences in feature values between the AD and non-AD groups.

In [0]:
# Generates the rows of Table 10 and 11
def _format_pvalue(p):
    """Format a p-value for the table: 3 decimals, "$<$0.001" below 0.001, "nan" if undefined."""
    if pd.isna(p):
        return "nan"  # NaN >= 0.001 is False, so NaN would otherwise print as "$<$0.001"
    return str(round(p, 3)) if p >= 0.001 else "$<$0.001"


def analysis(split_idx, split_ad, split_cn, metrics, data=None):
    """Print one LaTeX table row per metric comparing AD and control (CN) patients.

    Row layout: metric & AD mean (sd) & CN mean (sd) & t (p) & Mann-Whitney U (p) & ROC AUC.

    Args:
        split_idx: index labels of all patients in the split (used for the AUC).
        split_ad: index labels of the AD patients in the split.
        split_cn: index labels of the control patients in the split.
        metrics: column names to report.
        data: per-patient table with an ``AD_dx`` column; defaults to the global ``outcomes``.

    Means/SDs use the raw column. For ``<a>_<b>_ISED`` metrics, the tests and AUC use the
    imputed ``<a>_<b>_ISED_imp`` column, and that name is what the row prints.
    """
    if data is None:
        data = outcomes

    for m in metrics:
        ## score averages (raw feature values)
        mean_ad = data.loc[split_ad, m].mean()
        std_ad = data.loc[split_ad, m].std()
        mean_cn = data.loc[split_cn, m].mean()
        std_cn = data.loc[split_cn, m].std()

        ## group-difference tests
        match = re.search(r"(\w+)_(\w+)_ISED", m)
        if match:
            # use imputed feature values for statistical tests
            m = "%s_%s_ISED_imp" % match.groups()

        # Two-sample tests of the feature between the diagnostic groups. Passing the AD_dx
        # label vector as one of the samples, as before, does not test the group difference.
        ad_vals = data.loc[split_ad, m]
        cn_vals = data.loc[split_cn, m]
        res_ttest = ttest_ind(ad_vals, cn_vals)
        res_mannw = mannwhitneyu(ad_vals, cn_vals)
        res_auc = roc_auc_score(data.loc[split_idx, "AD_dx"], data.loc[split_idx, m])

        print("%s & %.2f (%.2f) & %.2f (%.2f) & %.2f (%s) & %.2f (%s) & %.2f \\\\" % (
            m,
            mean_ad, std_ad,
            mean_cn, std_cn,
            res_ttest.statistic, _format_pvalue(res_ttest.pvalue),
            res_mannw.statistic, _format_pvalue(res_mannw.pvalue),  # previously printed the t-test p-value
            res_auc,
        ))

In [0]:
# Per-patient features compared between AD and control groups.
mets = [
    "mlm_sub_error_rate",  # substitution errors per 100 patient words (mlm_det_90 detections)
    "mlm_mean_ISED",       # mean inter-substitution-error distance
    "mlm_std_ISED",        # standard deviation of that distance
]
# Disabled ground-truth metrics: "gt_paraphasia_rate", "gt_phonological_rate",
# "gt_semantic_rate", "gt_neologistic_rate", "gt_morphological_rate", "gt_dysfluency_rate"

In [0]:
# Table rows for the training split (the test split groups tst_* are built above but not analysed here).
analysis(split_idx=trn_pts, split_ad=trn_ad_pts, split_cn=trn_cn_pts, metrics=mets)

In [0]:
# Export the per-patient substitution-error features (ISED columns in their imputed form).
outcomes.loc[:, ["mlm_sub_error_rate", "mlm_mean_ISED_imp", "mlm_std_ISED_imp"]].to_csv(
    "feature_data/sub_error_feats.csv"
)