# Detecting Repetitive Speech

In [0]:
%load_ext autoreload
%autoreload 1
%aimport data.adress

In [0]:
import sys
sys.path.append("..")
import numpy as np
import pickle
from pprint import pprint

### Validation Data

In [0]:
from data.adress import load_CHAT_transcripts

# Load the CHAT transcripts and keep only the four columns listed below.
# NOTE(review): run() further down reads an "Utterance" column, which is not
# part of this selection — confirm which text column it is meant to use.
data = load_CHAT_transcripts()
data = data[["Speaker", "Transcript", "Transcript_clean", "Repetitive speech"]]
data.head()

## LLM-Based Detector

In [0]:
from openai import OpenAI

# Read the API token from the current notebook context and point the OpenAI
# client at the workspace's serving-endpoints URL.
DATABRICKS_TOKEN = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .apiToken()
    .get()
)
DB_ENDPOINT_URL = "https://adb-2035410508966251.11.azuredatabricks.net/serving-endpoints"

client = OpenAI(base_url=DB_ENDPOINT_URL, api_key=DATABRICKS_TOKEN)

In [0]:
def run(prompt, model, utt_per_query):
    """Query the chat model once per non-provider turn and collect the replies.

    The module-level ``data`` frame is grouped by its ``train_test`` and
    ``patient_id`` index levels. For every row position in a group whose
    ``Speaker`` is not ``"Provider"``, a window of up to ``utt_per_query``
    rows ending at that position is joined with newlines, substituted into
    ``prompt`` and sent as a single user message through the module-level
    ``client``.

    Args:
        prompt: Template with one positional ``{}`` placeholder for the text.
        model: Name of the serving endpoint / model passed to the client.
        utt_per_query: Maximum number of rows included in each window.

    Returns:
        List with the message content of the first choice of every response,
        in query order.
    """
    # NOTE(review): assumes the third index level numbers the turns of each
    # transcript 0..n-1 — confirm against load_CHAT_transcripts.
    # NOTE(review): the "Utterance" column is not among the columns selected
    # when `data` is built earlier in this notebook — confirm the column name.
    replies = []
    conversations = data.groupby(level=["train_test", "patient_id"])

    for (split, patient), turns in conversations:
        n_turns = turns.shape[0]
        for last in range(n_turns):
            speaker = turns.loc[(split, patient, last), "Speaker"]
            if speaker != "Provider":
                # Label-based slice, so both window ends are inclusive.
                first = max(0, last - utt_per_query + 1)
                window = (slice(None), slice(None), slice(first, last))
                lines = turns.loc[window, "Utterance"].to_list()
                text = "\n".join(lines)

                message = {"role": "user", "content": prompt.format(text)}
                completion = client.chat.completions.create(
                    model=model,
                    messages=[message],
                )
                replies.append(completion.choices[0].message.content)

    return replies

### Prompts
Best version: v1

In [0]:
# Prompt template v1; the single "{}" placeholder is filled with the transcript
# window via str.format in run().
prompt_v1 = "Identify all instances where the patient repeats the same sounds, words, or phrases--either consecutively or non-consecutively--in a way that indicates cognitive impairment in the following utterance:\n\n{}\n\nReturn a bullet point list, where each bullet contains a complete quote of the full phrase in which the repetition occurs, exactly as spoken by the patient. Start the quote at the first repeated sound, word, or phrase and end the quote after the last repeated sound, word, or phrase. Do not include any explanations. Do not include repetition prompted by the provider. If no repetitions are found, return \"None\"."

In [0]:
# One utterance per query: each non-provider turn is sent on its own.
outputs = run(prompt=prompt_v1, model="openai_gpt_4o", utt_per_query=1)

In [0]:
# Persist the raw model replies so the queries do not have to be re-run.
with open("outputs_repetitive_speech.pkl", mode="wb") as f:
    pickle.dump(outputs, file=f)

## Comparative Approach: n-gram overlap

In [0]:
import spacy
from collections import deque
from itertools import islice
# from thefuzz import fuzz
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

In [0]:
!python -m spacy download en_core_web_sm

In [0]:
class RepetitionNgramAnalysis:
    """Flag tokens that repeat a recently seen n-gram.

    Every n-gram that has been searched is remembered in a bounded deque, so
    the context carries over from one ``detect`` call to the next on the same
    instance. Call ``reset`` to start from an empty context.

    Only unigram matching with the ``"exact_match"`` comparator is
    implemented; the bigram/trigram contexts and the SBERT model are set up
    but not used yet.
    """

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.sbert = SentenceTransformer("all-MiniLM-L6-v2")

        # Most recent n-grams seen so far; old entries fall off automatically.
        self.context_uni = deque(maxlen=5)
        self.context_bi  = deque(maxlen=10)
        self.context_tri = deque(maxlen=15)

    def reset(self):
        """Forget every remembered n-gram."""
        self.context_uni.clear()
        self.context_bi.clear()
        self.context_tri.clear()

    def ngrams(self, doc, n):
        """Return all contiguous ``n``-token tuples of ``doc`` as a list."""
        return list(zip(*(islice(doc, i, None) for i in range(n))))

    @staticmethod
    def _phrase(ngram):
        """Return the space-joined, lower-cased lemmas of an n-gram."""
        return " ".join(tok.lemma_.lower() for tok in ngram)

    def search(self, ngrams, doc, context, comparator="exact_match", sim_threshold=0.95, verbose=False):
        """Compare each n-gram with the remembered context and record repeats.

        Each n-gram is compared with the n-grams currently in ``context`` and
        is appended to ``context`` afterwards, so later n-grams of the same
        call can match earlier ones.

        Args:
            ngrams: Iterable of token tuples; tokens need ``lemma_`` and ``i``.
            doc: The document the n-grams come from (currently unused; kept
                for interface compatibility).
            context: Deque of previously seen n-grams; mutated in place.
            comparator: Matching strategy; only ``"exact_match"`` is supported.
            sim_threshold: Reserved for similarity comparators (unused).
            verbose: Print every comparison and each detected repetition.

        Returns:
            Sorted list of the ``i`` indices of all tokens that belong to an
            n-gram matching an entry of ``context``.

        Raises:
            NotImplementedError: If ``comparator`` is not ``"exact_match"``.
        """
        # Fail fast instead of only when the first comparison happens.
        if comparator != "exact_match":
            raise NotImplementedError(f"Comparator {comparator} not implemented.")

        repeated = set()
        for ng in ngrams:
            phrase = self._phrase(ng)  # hoisted: does not depend on ctxt
            for ctxt in context:
                previous = self._phrase(ctxt)
                if verbose:
                    print(ng, ctxt)
                    print(phrase, previous)

                if phrase == previous:
                    if verbose:
                        print("REPETITION")
                    repeated.update(tok.i for tok in ng)

            context.append(ng)

        return sorted(repeated)

    def detect(self, text, verbose=False):
        """Tokenize ``text`` and flag tokens that repeat a recent unigram.

        Args:
            text: Utterance to analyse.
            verbose: Forwarded to ``search`` to print its comparisons.

        Returns:
            Float array with one entry per token: 1.0 where the token's lemma
            matches a unigram in the current context, 0.0 elsewhere.
        """
        doc = self.nlp(text)    # tokenize text

        output = np.zeros(len(doc))

        unigrams = self.ngrams(doc, 1)
        repeated = self.search(unigrams, doc, self.context_uni, "exact_match", verbose=verbose)
        if repeated:
            output[repeated] = 1

        # TODO: extend to bigrams/trigrams using context_bi / context_tri.
        return output

In [0]:
# Run the n-gram detector over every patient utterance, in dataset order.
# The same detector instance is used throughout, so its context carries over
# from one utterance to the next.
d = RepetitionNgramAnalysis()

patient_text = data["Transcript_clean"][data["Speaker"] == "Patient"]
outputs = [d.detect(text) for text in patient_text]

In [0]:
# Try the detector on a single utterance, using a fresh instance so that its
# context deques start out empty.
d = RepetitionNgramAnalysis()
d.detect("and I said the uh little sister's  uh reach Johnny's  Johnny he is  uh he is up on the ladder uh get gettin some cookies and the little sister reachin up  reach reaching up after some .")

In [0]:
# Show both text columns for every row labelled with repetitive speech.
# A list (not a tuple) selects several columns; tuples are MultiIndex keys.
data.loc[data["Repetitive speech"] == 1, ["Transcript", "Transcript_clean"]].to_numpy()