# Dementia Bank Datasets
- [Dementia Bank](https://talkbank.org/dementia/)
- [CHAT Transcription Manual](https://talkbank.org/0info/manuals/CHAT.pdf)

In [0]:
%load_ext autoreload
%autoreload 1
%aimport data.adress
%aimport data.adresso

In [0]:
import sys
sys.path.append("..")
import numpy as np
import pandas as pd
import re
import json
import matplotlib.pyplot as plt
from functools import partial
from tqdm import tqdm

## ADReSS

In [0]:
import data.adress as ADReSS

In [0]:
# Load the annotated ADReSS transcripts and preview the first rows.
adress_trans = ADReSS.load_transcripts()
adress_trans.head(n=5)

In [0]:
# Load the ADReSS outcome labels (the cells below use AD_dx, mmse, age and gender).
adress_lbls = ADReSS.load_outcomes()
adress_lbls.head(n=5)

In [0]:
# Table 1: one LaTeX row per split giving the number of participants and, for the
# AD (AD_dx == 1) and control (AD_dx == 0) groups, the percentage with gender == 1
# (NOTE(review): confirm which sex code 1 denotes) and age mean ± std.
def table1_row(labels, split):
    """Format the Table 1 LaTeX row for ``split`` of ``labels``.

    Percentages use ``%.0f`` so they are rounded rather than truncated (``%d``
    turns e.g. 66.7 into 66), and an empty group prints ``nan`` instead of
    raising when NaN would be converted to an integer.
    """
    in_split = labels.index.get_level_values("split") == split
    cells = [split, labels.loc[split].shape[0]]
    for dx in (1, 0):  # AD group first, then controls, matching the column order
        group = labels.loc[in_split & (labels["AD_dx"] == dx)]
        cells += [
            100 * (group["gender"] == 1).mean(),
            group["age"].mean(),
            group["age"].std(),
        ]
    return "%s & %d & %.0f\\%% & %.1f $\\pm$ %.1f & %.0f\\%% & %.1f $\\pm$ %.1f \\\\" % tuple(cells)


for split in ["train", "dev", "test"]:
    print(table1_row(adress_lbls, split))

Counts of MCI indicators across all ADReSS utterances (column sums of the indicator annotations).

In [0]:
# Total of each MCI-indicator column over every ADReSS utterance.
mci_indicator_columns = [
    "Filler", "Repetition", "Revision",
    "Short pause", "Medium pause", "Long pause", "Speech delays",
    "Vague",
    "Phonological Paraphasia", "Semantic Paraphasia", "Neologistic Paraphasia", "Morphological Paraphasia",
    "Dysfluency", "Paraphasia",
]
adress_trans.loc[:, mci_indicator_columns].sum(axis=0)

Distribution of MMSE scores.

In [0]:
# MMSE scores are integers 0-30, so use one bin per score centred on the integer.
# Non-integer bin widths (e.g. np.linspace(0, 30, 15), width ~2.14) make some bins
# cover three scores and others two, which distorts the shape of the histogram.
bins = np.arange(-0.5, 31.5, 1)

fig, ax = plt.subplots()
ax.hist(
    [adress_lbls.loc["train", "mmse"].dropna(), adress_lbls.loc["test", "mmse"].dropna()],
    bins=bins,
    label=["Train", "Test"],
    edgecolor="k",
    zorder=3,  # draw bars above the grid lines
)
ax.set_xlabel("MMSE")
ax.set_ylabel("Frequency")
ax.set_title("ADReSS: MMSE distribution by split")
ax.legend(loc="upper left")
ax.set_xlim([-1, 31])
ax.grid(zorder=0)
plt.show()

#### Generating character spans for relevant annotations using an LLM

In [0]:
import mlflow
import mlflow.genai
from mlflow.genai.scorers import RelevanceToQuery, Safety, Guidelines, scorer
from mlflow.entities import Feedback
from openai import OpenAI

In [0]:
# Point an OpenAI client at the workspace's model-serving endpoints, reusing the
# host and token that MLflow resolves for the current Databricks workspace.
mlflow_creds = mlflow.utils.databricks_utils.get_databricks_host_creds()
serving_base_url = f"{mlflow_creds.host}/serving-endpoints"
client = OpenAI(base_url=serving_base_url, api_key=mlflow_creds.token)

In [0]:
@mlflow.trace
def get_span_labels(input_json: str, dev_prompt: str):
    """Ask the LLM to fill in the span lists for a batch of utterances.

    Args:
        input_json: JSON string mapping utterance indices to utterance objects whose
            (empty) span list the model should populate.
        dev_prompt: Task instructions, sent as the "developer" message.

    Returns:
        The raw message content returned by the model. ``response_format`` requests
        a JSON object, but the string is not parsed here; callers that need a dict
        must ``json.loads`` it themselves.
    """
    conversation = [
        {"role": "developer", "content": dev_prompt},
        {"role": "user", "content": input_json},
    ]
    completion = client.chat.completions.create(
        model="openai_gpt_4o",
        messages=conversation,
        response_format={"type": "json_object"},
        # temperature 0 for (near-)deterministic labels across runs
        temperature=0.0,
        top_p=0.95,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [0]:
# Developer prompts for LLM span labelling. The model receives a JSON object of
# utterances and must fill in one span list per utterance. Spans are requested as
# JSON arrays ([start, end]) because the reply is a JSON object; JSON has no tuples.
# The filler prompt excludes "&=" codes to agree with the filler scorer's regex.
# TODO: state explicitly whether `end` is inclusive or exclusive before using these
# spans as ground truth.
filler_prompt = '''# IDENTITY
You are an expert in CHAT transcription format.

# INSTRUCTIONS
You will be provided with a JSON object where keys are indices and values are objects, each object representing an utterance. Each utterance object contains the "Annotated Utterance", "Clean Utterance", and an empty "Filler Spans" list.

Your task is to complete the provided JSON by populating the empty "Filler Spans" list for each utterance object.

The CHAT format for annotating a filler word is to prefix the word with an ampersand (&) symbol. Codes beginning with "&=" (e.g. &=laughs) mark non-speech events, not fillers; do not label them.

To do this, for each utterance object:
1. Identify all words annotated as filler in the "Annotated Utterance".
2. For each filler word, determine the [start, end] character spans relative to the "Clean Utterance".

# OUTPUT FORMAT
Return the completed JSON object, with the "Filler Spans" lists filled in with a list of [start, end] character spans for fillers.
'''

repetition_prompt = '''# IDENTITY
You are an expert in CHAT transcription format.

# INSTRUCTIONS
You will be provided with a JSON object where keys are indices and values are objects, each object representing an utterance. Each utterance object contains the "Annotated Utterance", "Clean Utterance", and an empty "Repetition Spans" list.

Your task is to complete the provided JSON by populating the empty "Repetition Spans" list for each utterance object.

The CHAT format uses two different methods for marking repetitions, depending on the number of words repeated:
- For multi-word repetitions, the repeated phrase is enclosed in angle brackets (<...>), followed by the [/] marker.
- For single-word repetitions, the [/] symbol by itself indicates that the single word immediately before it is being repeated.

To do this, for each utterance object:
1. Identify all words annotated as repetition in the "Annotated Utterance".
2. For each repetition event, determine the pair of character spans [[start1, end1], [start2, end2]] relative to the "Clean Utterance".
    - [start1, end1] is the span of the first occurrence (the phrase inside the <...> for a multi-word repetition, or the single word immediately before the [/] marker).
    - [start2, end2] is the span of the second occurrence (the repeated word(s) following the [/] marker).

# OUTPUT FORMAT
Return the completed JSON object, with the "Repetition Spans" lists filled in with a list of pairs of [start, end] character spans for repetitions. Each pair must correspond to a single annotated repetition event and be structured as [[start_occurrence_1, end_occurrence_1], [start_occurrence_2, end_occurrence_2]].
'''

In [0]:
def filler_spans(input_json):
    """Label filler spans for the utterances in ``input_json`` using ``filler_prompt``."""
    return get_span_labels(input_json=input_json, dev_prompt=filler_prompt)


def repetition_spans(input_json):
    """Label repetition spans for the utterances in ``input_json`` using ``repetition_prompt``."""
    return get_span_labels(input_json=input_json, dev_prompt=repetition_prompt)

In [0]:
def make_span_eval_dataset(transcripts, span_column="Filler Spans"):
    """Build an mlflow.genai evaluation dataset with one record per (split, ID) group.

    Each record's ``inputs["input_json"]`` is the group's rows serialized with
    ``DataFrame.to_json(orient="index")``. Every row object holds the
    "Annotated Utterance", the "Clean Utterance" and an empty ``span_column`` list
    for the LLM to fill in.

    Args:
        transcripts: Transcript table whose index has "split" and "ID" levels and
            which has "Transcript" and "Transcript_clean" columns.
        span_column: Name of the empty span list. It should match the prompt used
            by the predict function, e.g. "Filler Spans" or "Repetition Spans".

    Returns:
        A list of ``{"inputs": {"input_json": str}}`` records.
    """
    records = []
    for _, group in transcripts.groupby(level=("split", "ID")):
        utterances = group[["Transcript", "Transcript_clean"]].rename(
            columns={"Transcript": "Annotated Utterance", "Transcript_clean": "Clean Utterance"}
        )
        # One independent empty list per row (not a shared list object).
        utterances[span_column] = [[] for _ in range(len(utterances))]
        records.append({"inputs": {"input_json": utterances.to_json(orient="index")}})
    return records


eval_dataset = make_span_eval_dataset(adress_trans)

In [0]:
# TODO add scorer to check that they didnt change the utterances

def correct_num_spans(inputs, outputs, regex_pattern, span_key="Filler Spans"):
    """Check that each utterance has one span entry per annotated CHAT event.

    The expected count for an utterance is the number of ``regex_pattern`` matches
    in its "Annotated Utterance". The observed count is the length of its
    ``span_key`` list.

    Args:
        inputs: The evaluation record's inputs (unused; part of the scorer signature).
        outputs: The model output. This is either the JSON string returned by
            ``get_span_labels`` or an already-parsed dict mapping utterance indices
            to utterance objects.
        regex_pattern: CHAT marker pattern whose matches define the expected count.
        span_key: Span list to check. Defaults to "Filler Spans", the original
            behaviour.

    Returns:
        Feedback whose value is True only if every utterance has a ``span_key`` list
        of the expected length. The rationale lists every offending utterance.
    """
    data = json.loads(outputs) if isinstance(outputs, str) else outputs

    value = True
    rationale = ""
    for idx, output in data.items():
        n_expected = len(re.findall(regex_pattern, output["Annotated Utterance"]))
        spans = output.get(span_key)
        if spans is None:
            # Report rather than raise KeyError when the model returned a different
            # span list (e.g. filler outputs scored by the repetition scorer).
            value = False
            rationale += f"For {idx} no '{span_key}' list was returned. "
            continue
        n_spans = len(spans)
        if n_spans != n_expected:
            value = False
            rationale += f"For {idx} we found {n_spans} spans, expected {n_expected}. "

    return Feedback(value=value, rationale=rationale)

@scorer
def correct_num_filler_spans(inputs, outputs):
    """Fillers: CHAT ``&word`` markers, excluding ``&=`` event codes."""
    return correct_num_spans(inputs, outputs, regex_pattern=r'&(?!=)\w+', span_key="Filler Spans")

@scorer
def correct_num_repetition_spans(inputs, outputs):
    """Repetitions: one CHAT ``[/]`` marker per repetition event."""
    return correct_num_spans(inputs, outputs, regex_pattern=r'\[/\]', span_key="Repetition Spans")

@scorer
def correct_num_revision_spans(inputs, outputs):
    """Revisions: one CHAT ``[//]`` marker per revision event."""
    # NOTE(review): assumes a future revision prompt names its list "Revision Spans".
    return correct_num_spans(inputs, outputs, regex_pattern=r'\[/\/\]', span_key="Revision Spans")

scorers = [
    RelevanceToQuery(),
    Safety(),
    # The span scorers accept either a JSON string or a parsed dict from get_span_labels.
    # correct_num_filler_spans,
    correct_num_repetition_spans,
    # correct_num_revision_spans,
]

In [0]:
# NOTE(review): this run is named "repetition_v0" and is scored with
# correct_num_repetition_spans, but it predicts with filler_spans on eval_dataset,
# whose utterances carry "Filler Spans" lists. TODO: use predict_fn=repetition_spans
# with a dataset built around a "Repetition Spans" list, or rename the run.
with mlflow.start_run(run_name="repetition_v0"):
    eval_results = mlflow.genai.evaluate(
        predict_fn=filler_spans,
        data=eval_dataset,
        scorers=scorers,
    )

## ADReSSo

In [0]:
import data.adresso as ADReSSo

In [0]:
# Load the ADReSSo labels (the cells below use AD_dx and mmse) and preview them.
adresso_lbls = ADReSSo.load_labels()
adresso_lbls.head(n=5)

In [0]:
# ADReSSo summary: missing MMSE scores, share of AD diagnoses overall and per split,
# and per-split share of participants with MMSE < 24 but no dementia diagnosis.
print("Num NA MMSE scores:", adresso_lbls["mmse"].isna().sum())
print()

print("ALL Pct AD: %.3f" % (adresso_lbls["AD_dx"].sum() / adresso_lbls.shape[0]))
for fmt, split_name in [("TRAIN Pct AD: %.3f", "train"), ("TEST Pct AD: %.3f", "test")]:
    split_rows = adresso_lbls.loc[split_name]
    print(fmt % (split_rows["AD_dx"].sum() / split_rows.shape[0]))
print()

# A NaN MMSE compares False with 24, so such participants never count as low-MMSE,
# but they remain in the denominator.
low_mmse_no_dx = (adresso_lbls["mmse"] < 24).values & (adresso_lbls["AD_dx"] == 0).values
split_level = adresso_lbls.index.get_level_values("split")
for fmt, split_name in [
    ("TRN Pct MMSE < 24 and No Dementia Dx: %.3f", "train"),
    ("TST Pct MMSE < 24 and No Dementia Dx: %.3f", "test"),
]:
    n_flagged = adresso_lbls.loc[(split_level == split_name) & low_mmse_no_dx].shape[0]
    print(fmt % (n_flagged / adresso_lbls.loc[split_name].shape[0]))

Distribution of MMSE scores.

In [0]:
# MMSE scores are integers 0-30, so use one bin per score centred on the integer.
# Non-integer bin widths (e.g. np.linspace(0, 30, 15), width ~2.14) make some bins
# cover three scores and others two, which distorts the shape of the histogram.
bins = np.arange(-0.5, 31.5, 1)

fig, ax = plt.subplots()
ax.hist(
    # ADReSSo has missing MMSE scores (counted above); drop them explicitly.
    [adresso_lbls.loc["train", "mmse"].dropna(), adresso_lbls.loc["test", "mmse"].dropna()],
    bins=bins,
    label=["Train", "Test"],
    edgecolor="k",
    zorder=3,  # draw bars above the grid lines
)
ax.set_xlabel("MMSE")
ax.set_ylabel("Frequency")
ax.set_title("ADReSSo: MMSE distribution by split")
ax.legend(loc="upper left")
ax.set_xlim([-1, 31])
ax.grid(zorder=0)
plt.show()