# Penn Data

In [0]:
%load_ext autoreload
%autoreload 1
%aimport data.observer

In [0]:
import sys
sys.path.append("..")
import data.observer as OBSERVER
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

## Load the Data

#### Cognitive Impairment Outcomes

In [0]:
# Read the cognitive-impairment label sheet and trim stray whitespace from the patient ids
# so they compare equal to ids used elsewhere in the notebook.
labels = (
    pd.read_excel("/Volumes/biomedicalinformatics_analytics/dev_lab_johnson/watch/penn_CI_labels.xlsx")
    .assign(patient_id=lambda frame: frame["patient_id"].str.strip())
)
labels.head()

In [0]:
# Dimensions of the label table as (n_rows, n_columns).
labels.shape

In [0]:
# Print one LaTeX table row: N, percent female, and age mean +/- SD for the
# cognitively impaired cohort (CI == 1), followed by the same four values for CI == 0.
row_values = []
for ci_flag in (1, 0):
    in_cohort = labels["CI"] == ci_flag
    cohort_size = in_cohort.sum()
    cohort_ages = labels.loc[in_cohort, "Age (in Years)"]
    row_values.extend([
        cohort_size,
        100 * (in_cohort & (labels["Gender"] == "Female")).sum() / cohort_size,
        cohort_ages.mean(),
        cohort_ages.std(),
    ])

print("%d & %d\\%% & %.1f $\\pm$ %.1f & %d & %d\\%% & %.1f $\\pm$ %.1f \\\\" % tuple(row_values))

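Downsample the cognitively normal (CN) cohort to the same size as the cognitively impaired (CI) cohort. Of 10,000 random CN subsamples, keep the one whose mean age and percent female are closest to those of the CI cohort.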

In [0]:
# Seed the legacy global NumPy RNG so the same candidate subsamples are drawn on every run.
np.random.seed(1234567)

# Targets to match: mean age and percent female of the cognitively impaired (CI == 1) cohort.
is_ci = labels["CI"] == 1
is_female = labels["Gender"] == "Female"
n_ci = is_ci.sum()
ci_mean_age = labels.loc[is_ci, "Age (in Years)"].mean()
ci_pct_female = 100 * (is_ci & is_female).sum() / n_ci

# Pool of cognitively normal (CI == 0) patient ids to draw from.
cn_patients = labels.loc[labels["CI"] == 0, "patient_id"]

# Draw 10,000 candidate subsamples, each as large as the CI cohort, and record
# the demographics of every draw.
results = []
for draw in range(10000):
    sample = np.random.choice(cn_patients, size=n_ci, replace=False)
    in_sample = labels["patient_id"].isin(sample)
    results.append({
        "pt_ids": sample,
        "avg_age": labels.loc[in_sample, "Age (in Years)"].mean(),
        "pct_female": 100 * (in_sample & is_female).sum() / sample.shape[0],
    })

# Score each candidate by its absolute distance from the CI cohort; the composite score is
# the plain sum of the age gap (years) and the percent-female gap (percentage points).
candidate_subsamples = pd.DataFrame(results).assign(
    mean_age_diff=lambda frame: (ci_mean_age - frame["avg_age"]).abs(),
    pct_female_diff=lambda frame: (ci_pct_female - frame["pct_female"]).abs(),
    composite_diff=lambda frame: frame["mean_age_diff"] + frame["pct_female_diff"],
)

# Keep the patient ids of the closest-matching candidate.
best_candidate = candidate_subsamples["composite_diff"].idxmin()
pt_subsample = candidate_subsamples.loc[best_candidate, "pt_ids"]
pt_subsample

'PT025', 'PT099', 'PT075', 'PT016', 'PT044', 'PT085', 'PT035', 'PT103', 'PT036', 'PT058', 'PT002', 'PT095'

In [0]:
# Balanced cohort: every cognitively impaired patient plus the selected
# cognitively normal subsample.
keep_row = (labels["CI"] == 1) | labels["patient_id"].isin(pt_subsample)
labels_ds = labels.loc[keep_row]

# Print one LaTeX table row for the balanced cohort: N, percent female, and age
# mean +/- SD for CI == 1, followed by the same four values for CI == 0.
ds_row_values = []
for ci_flag in (1, 0):
    in_cohort = labels_ds["CI"] == ci_flag
    cohort_size = in_cohort.sum()
    cohort_ages = labels_ds.loc[in_cohort, "Age (in Years)"]
    ds_row_values.extend([
        cohort_size,
        100 * (in_cohort & (labels_ds["Gender"] == "Female")).sum() / cohort_size,
        cohort_ages.mean(),
        cohort_ages.std(),
    ])

print("%d & %d\\%% & %.1f $\\pm$ %.1f & %d & %d\\%% & %.1f $\\pm$ %.1f \\\\" % tuple(ds_row_values))

#### Cognitive Test Scores

In [0]:
# Load the cognitive test score table through the project-local data.observer module
# (imported above as OBSERVER); the histogram cells below read its "MMSE" and "FRS" columns.
lbls = OBSERVER.load_labels()
# Preview the first rows.
lbls.head()

In [0]:
# Histogram of MMSE scores.
# NOTE(review): 15 edges over [0, 30] give 14 bins of width ~2.14, so the edges are not
# integers — confirm this binning is intended.
bins = np.linspace(0, 30, 15)

fig, ax = plt.subplots()
# zorder=3 draws the bars above the grid lines (zorder=0).
ax.hist(lbls["MMSE"], bins=bins, edgecolor="k", zorder=3)
ax.set_xlabel("MMSE")
ax.set_ylabel("Frequency")
ax.set_xlim(-1, 31)
ax.grid(zorder=0)
plt.show()

In [0]:
# Histogram of FRS scores, using 14 equal-width bins spanning 0 to 30.
bins = np.linspace(0, 30, 15)

fig, ax = plt.subplots()
# zorder=3 draws the bars above the grid lines (zorder=0).
ax.hist(lbls["FRS"], bins=bins, edgecolor="k", zorder=3)
ax.set_xlabel("FRS")
ax.set_ylabel("Frequency")
ax.set_xlim(-1, 31)
ax.grid(zorder=0)
plt.show()

## Labeling Probable MCI using GPT

In [0]:
!pip install --upgrade openai
dbutils.library.restartPython()

In [0]:
import os
import pandas as pd
import re
import tiktoken
import time
from openai import OpenAI

Load the problem-list text files into a DataFrame, one row per file, indexed by file name without its extension.

In [0]:
# Read every problem-list text file in `directory` into a one-column DataFrame
# indexed by the file name without its extension.
directory = "/Volumes/biomedicalinformatics_analytics/dev_lab_johnson/swimcap/Penn OBSERVER/problem_lists/"

idx, pls = [], []
# os.listdir returns entries in arbitrary order; sort so the row order is the same on every run.
for file in sorted(os.listdir(directory)):
    path = os.path.join(directory, file)
    # Skip sub-directories: open() would raise on them.
    if not os.path.isfile(path):
        continue
    with open(path, "r") as fp:
        text = fp.read()
    # Strip only the final extension. The previous `file.rsplit(".")[0]` split on every dot,
    # so "a.b.txt" became "a" and distinct files could collapse onto the same index value.
    idx.append(os.path.splitext(file)[0])
    pls.append(text)

problem_lists = pd.DataFrame(data=pls, index=idx, columns=["problem_list"])
problem_lists

In [0]:
def num_tokens_from_messages(messages, model):
    """Count the prompt tokens consumed by a list of chat messages.

    Args:
        messages: Iterable of dicts mapping field names (e.g. "role", "content",
            "name") to strings; every value is passed to the tokenizer.
        model: Model name used to pick the tiktoken encoding and the per-message
            overhead. Un-pinned family names are resolved to a pinned snapshot
            (with a printed warning) and counted as that snapshot.

    Returns:
        int: 3 tokens per message, plus the encoded length of every value, plus
        1 for each "name" field, plus 3 tokens that prime the assistant reply.

    Raises:
        NotImplementedError: If `model` is neither a known snapshot nor a member
            of a known model family.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using o200k_base encoding.")
        encoding = tiktoken.get_encoding("o200k_base")

    pinned_snapshots = (
        "gpt-3.5-turbo-0125",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4o-mini-2024-07-18",
        "gpt-4o-2024-08-06",
    )

    if model not in pinned_snapshots:
        # (family substring, snapshot to assume, warning to print). Order matters: the
        # more specific family must be tried first ("gpt-4o-mini" before "gpt-4o" before "gpt-4").
        family_fallbacks = (
            (
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-0125",
                "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.",
            ),
            (
                "gpt-4o-mini",
                "gpt-4o-mini-2024-07-18",
                "Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.",
            ),
            (
                "gpt-4o",
                "gpt-4o-2024-08-06",
                "Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.",
            ),
            (
                "gpt-4",
                "gpt-4-0613",
                "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.",
            ),
        )
        for family, snapshot, warning in family_fallbacks:
            if family in model:
                print(warning)
                return num_tokens_from_messages(messages, model=snapshot)
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}."""
        )

    tokens_per_message = 3
    tokens_per_name = 1

    total = 0
    for message in messages:
        total += tokens_per_message
        for field, text in message.items():
            total += len(encoding.encode(text))
            if field == "name":
                total += tokens_per_name
    # every reply is primed with <|start|>assistant<|message|>
    return total + 3

In [0]:
# Label every problem list as "Probable MCI" or "Healthy Control" with a GPT-4o serving endpoint.

# Authenticate with the notebook's own API token, fetched from the notebook context at run
# time so that no credential is stored in the notebook.
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

client = OpenAI(
    api_key=DATABRICKS_TOKEN,
    base_url="https://adb-2035410508966251.11.azuredatabricks.net/serving-endpoints"
)

model = "openai_gpt_4o"
# NOTE(review): tpm / rpm look like tokens-per-minute and requests-per-minute limits, but
# nothing in this cell enforces them (tokens_used and t are only recorded) — confirm intent.
tpm = 1e6
rpm = 6150

prompt  = (
    "Here is a patient's problem list summarizing their active health issues "
    "(e.g., diagnoses, chronic conditions, injuries):\n\n"
    "{}\n\n"
    "Based on this information, label the patient as either:\n\n"
    "- Probable MCI (Mild Cognitive Impairment), or\n"
    "- Healthy Control\n\n"
    "If the problem list includes multiple conditions that are commonly associated "
    "with cognitive decline, MCI, or Alzheimer's dementia (e.g., memory loss, gait abnormality), consider assigning "
    'the \"Probable MCI\" label. Otherwise, assign \"Healthy Control\".\n\n'
    "Format your response as follows:\n\n"
    "- Label: <Probable MCI or Healthy Control>\n"
    "- Reason: <comma-separated list of relevant issues from the problem list, or \"N/A\" if Healthy Control>"
)

# The token counter is given the model name derived from the endpoint name
# ("openai_gpt_4o" -> "gpt-4o"); computed once because it does not change per row.
tokenizer_model = model.split("_", maxsplit=1)[1].replace("_", "-")

# Match the "Label:" and "Reason:" lines by name instead of by position, so a preamble line
# that happens to contain a colon cannot be mistaken for the label.
field_pattern = re.compile(r'^\s*[^:\n]*?(Label|Reason)[^:\n]*:\s*(.*)', re.MULTILINE | re.IGNORECASE)

tokens_used = 0  # running total of completion tokens only
t = time.time()
for i, row in problem_lists.iterrows():
    messages = [{"role": "user", "content": prompt.format(row.problem_list)}]
    n_query_tokens = num_tokens_from_messages(messages, tokenizer_model)
    print("N query tokens:", n_query_tokens)

    response = client.chat.completions.create(
        model=model,
        messages=messages
    )
    output = response.choices[0].message.content
    n_response_tokens = response.usage.completion_tokens
    print("N response tokens:", n_response_tokens)

    tokens_used += n_response_tokens

    # Keep the first occurrence of each field; `output or ""` tolerates an empty reply.
    fields = {}
    for field_name, field_value in field_pattern.findall(output or ""):
        fields.setdefault(field_name.lower(), field_value.rstrip())

    # A reply that ignores the requested format is stored as missing rather than raising
    # IndexError and discarding the labels already collected.
    if "label" not in fields or "reason" not in fields:
        print("Warning: could not parse label and reason for %r; storing missing values." % (i,))

    problem_lists.loc[i, "generated_label"] = fields.get("label")
    problem_lists.loc[i, "reason"] = fields.get("reason")


In [0]:
# Save the problem lists, including the generated_label and reason columns added by the
# labelling loop, to an Excel workbook. The path is relative, so the file lands in the
# notebook's current working directory.
problem_lists.to_excel("visit_problem_lists_labeled.xlsx")