In [0]:
%load_ext autoreload
%autoreload 1
%aimport data.adress

In [0]:
import pickle
import sys
from pprint import pprint

import pandas as pd

sys.path.append("..")

### Validation Datasets

In [0]:
import data.adress as ADReSS

# Load the ADReSS CHAT transcripts and keep only the columns used by the detectors below.
data = ADReSS.load_CHAT_transcripts().loc[
    :, ["Speaker", "Transcript", "Transcript_clean", "Utterance", "Vague speech"]
]
data.head(10)

## LLM-Based Detector

In [0]:
from openai import OpenAI

# The API token is taken from the running Databricks notebook context rather than hardcoded.
DATABRICKS_TOKEN = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .apiToken()
    .get()
)
# OpenAI-compatible base URL of the workspace's model-serving endpoints.
DB_ENDPOINT_URL = "https://adb-2035410508966251.11.azuredatabricks.net/serving-endpoints"

client = OpenAI(base_url=DB_ENDPOINT_URL, api_key=DATABRICKS_TOKEN)

In [0]:
def run(prompt, model, utt_per_query, transcripts=None, llm_client=None):
    """Query an LLM once per non-provider utterance, with a sliding window of context.

    For every (train_test, patient_id) transcript, each utterance whose ``Speaker`` is not
    ``"Provider"`` is sent to the chat-completions endpoint. The request includes up to
    ``utt_per_query - 1`` preceding utterances (of either speaker). The utterances are joined
    by newlines and substituted into ``prompt``.

    Args:
        prompt: Format string with a single positional ``{}`` placeholder for the window text.
        model: Model / serving-endpoint name passed to ``chat.completions.create``.
        utt_per_query: Window size in utterances, including the target utterance; must be >= 1.
        transcripts: DataFrame whose index has ``train_test`` and ``patient_id`` levels and
            which has ``Speaker`` and ``Utterance`` columns. Defaults to the notebook-level
            ``data``.
        llm_client: OpenAI-compatible client. Defaults to the notebook-level ``client``.

    Returns:
        List of response strings, one per non-provider utterance. They are ordered by
        (sorted) group key, then by row order within each transcript.

    Raises:
        ValueError: If ``utt_per_query`` is less than 1.
    """
    if utt_per_query < 1:
        raise ValueError(f"utt_per_query must be >= 1, got {utt_per_query}")
    if transcripts is None:
        transcripts = data
    if llm_client is None:
        llm_client = client

    outputs = []
    for _, transcript in transcripts.groupby(level=["train_test", "patient_id"]):
        speakers = transcript["Speaker"].to_list()
        utterances = transcript["Utterance"].to_list()
        for end, speaker in enumerate(speakers):
            # Skip provider turns (they still appear as context inside later windows).
            if speaker == "Provider":
                continue

            # Positional sliding window. Unlike label-based .loc lookups, this does not
            # assume the utterance index level is 0..n-1, and it does not require a
            # lexsorted MultiIndex for slicing.
            start = max(0, end - utt_per_query + 1)
            text = "\n".join(utterances[start:end + 1])

            response = llm_client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt.format(text)}],
            )
            outputs.append(response.choices[0].message.content)

    return outputs

### Prompts
Best version: v1

In [0]:
# Prompt v1: `run` fills `{}` with a window of utterances; only the last line is classified.
prompt_v1 = "Identify all instances of vague speech in the last line of the following transcript:\n\n{}\n\nList each instance of vague speech as a bullet point in the order that they are spoken. Do not include any explanations. If no vague speech is found, return \"None\"."

# Alternative (unused) prompt wording, kept for reference. It has no `{}` placeholder,
# so it cannot be passed to `run` as-is:
# "Identify all vague words or phrases that indicate possible cognitive impairment (e.g., \"you know\" or \"that thing\") from the following transcript:"

In [0]:
outputs = run(prompt_v1, model="openai_gpt_4o", utt_per_query=1)  # one utterance per query, no extra context

In [0]:
# Persist the raw LLM responses so the expensive query loop need not be re-run.
with open("outputs_vague_speech.pkl", mode="wb") as out_file:
    out_file.write(pickle.dumps(outputs))

### DEMO

## BERT-based detector

### Experiments to Run:
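1. Iteratively mask each word in the utterance and have BERT predict the masked word. If the actual word is not among BERT's top-k predictions, count it as a miss; utterances with misses (heuristic 1 scores each utterance by the fraction of missed words) are flagged as vague speech.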

In [0]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

model_name = "google-bert/bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModelForMaskedLM.from_pretrained(model_name)

# Add the speaker tags that prefix each utterance ("[PATIENT] ...", "[PROVIDER] ...") to the
# vocabulary. resize_token_embeddings creates new embedding rows for them, which are not pretrained.
# NOTE(review): the tokenizer is uncased; confirm the upper-case tags are actually matched
# as single tokens (e.g. via tokenizer.tokenize("[PATIENT] hi")).
n_added_toks = tokenizer.add_tokens(["[PATIENT]", "[PROVIDER]"])
bert_model.resize_token_embeddings(len(tokenizer))

# Use the first GPU in half precision when one is available. Otherwise fall back to CPU and
# float32 instead of failing on a machine without CUDA.
_use_cuda = torch.cuda.is_available()

# `bert_model` is rebound from the raw model to the fill-mask pipeline; `do` calls the pipeline.
bert_model = pipeline(
    task="fill-mask",
    model=bert_model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
    device=0 if _use_cuda else -1,
)

In [0]:
import re

def mask_each_word(text):
    """Build one masked copy of ``text`` per maskable word.

    Words are whitespace-separated. Speaker tags (words starting with ``[PATIENT]`` or
    ``[PROVIDER]``) and words made only of non-word characters are never masked.

    Returns:
        ``(inputs, labels)``. ``inputs[j]`` is the text with one word replaced by ``[MASK]``
        and re-joined with single spaces. ``labels[j]`` is the word that was masked.
    """
    tokens = text.split()

    masked_inputs = []
    masked_labels = []
    for idx, token in enumerate(tokens):
        is_speaker_tag = re.match(r"\[PATIENT\]|\[PROVIDER\]", token) is not None
        is_punctuation = re.fullmatch(r"\W+", token) is not None
        if is_speaker_tag or is_punctuation:
            continue

        masked_labels.append(token)
        masked_inputs.append(" ".join(tokens[:idx] + ["[MASK]"] + tokens[idx + 1:]))

    return masked_inputs, masked_labels

def heuristic1(labels, outputs, top_k=5):
    """Fraction of masked words that BERT fails to recover among its top-k predictions.

    Args:
        labels: The original words, one per masked input (as returned by ``mask_each_word``).
        outputs: The fill-mask pipeline result for those inputs. This is a list holding one
            list of prediction dicts (each with a ``"token_str"`` key) per input. The
            fill-mask pipeline unwraps a single-input batch into a flat list of dicts;
            that shape is also accepted.
        top_k: Number of leading predictions per word to compare against. If fewer
            predictions are available, all of them are used.

    Returns:
        ``misses / len(labels)`` as a float in [0, 1], or ``nan`` when ``labels`` is empty.
    """
    import string

    if not labels:
        return float("nan")

    # One-element batch: re-wrap the flat prediction list so zip pairs the single label
    # with its whole list of predictions (previously this returned a magic 999).
    if outputs and isinstance(outputs[0], dict):
        outputs = [outputs]

    def _norm(word):
        # The notebook's model is uncased. Compare without case and surrounding punctuation
        # so "The." matches a predicted "the".
        return word.strip().strip(string.punctuation).lower()

    misses = 0
    for actual_word, pred_words in zip(labels, outputs):
        target = _norm(actual_word)
        if all(_norm(pred["token_str"]) != target for pred in pred_words[:top_k]):
            misses += 1

    return misses / len(labels)

def do(text):
    """Score one speaker-tagged utterance with the BERT masked-word heuristic.

    Args:
        text: Utterance prefixed with a speaker tag, e.g. ``"[PATIENT] ..."`` or
            ``"[PROVIDER] ..."``.

    Returns:
        ``pd.NA`` for provider turns, for empty text, or when no word is maskable.
        Otherwise, the ``heuristic1`` score computed from the ``bert_model`` fill-mask predictions.
    """
    parts = text.split(maxsplit=1)
    # Provider turns are not scored; empty text would otherwise raise IndexError here.
    if not parts or parts[0] == "[PROVIDER]":
        return pd.NA

    inputs, labels = mask_each_word(text)
    if not inputs:
        # Only a speaker tag / punctuation: skip the model call on an empty batch
        # and avoid a division by zero in the heuristic.
        return pd.NA

    outputs = bert_model(inputs)

    # Heuristics
    return heuristic1(labels, outputs)

In [0]:
# Score every utterance with the BERT heuristic. The original referenced the undefined
# `adress_data`. `data` (loaded above) already holds Speaker / Transcript_clean, and the result
# goes into a new frame so the LLM input `data` is not clobbered and this cell is re-runnable.
bert_data = data.loc[:, ["Speaker", "Transcript_clean"]].assign(
    text=lambda df: "[" + df["Speaker"].str.upper() + "] " + df["Transcript_clean"].astype(str)
)
bert_data = bert_data.assign(h1=bert_data["text"].apply(do))
bert_data.head(10)