In [0]:
!pip install --upgrade openai
dbutils.library.restartPython()

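Selecting a cohort of Penn OBSERVER patients: each patient's problem list is labeled as "Probable MCI" or "Healthy Control" by a GPT-4o serving endpoint.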

In [0]:
import os
import pandas as pd
import re
import tiktoken
import time
from openai import OpenAI

In [0]:
directory = "/Volumes/biomedicalinformatics_analytics/dev_lab_johnson/swimcap/Penn OBSERVER/problem_lists/"

# Read every problem-list file in `directory` into one DataFrame:
#   index          -> file name without its extension
#   "problem_list" -> full text content of the file
idx, pls = [], []
# Sort the listing so the row order is the same on every run
# (os.listdir returns entries in arbitrary order).
for file in sorted(os.listdir(directory)):
    path = os.path.join(directory, file)
    # Skip sub-directories and other non-regular entries; open() would fail on them.
    if not os.path.isfile(path):
        continue
    with open(path, "r") as fp:
        # splitext drops only the final extension. The previous
        # `file.rsplit(".")[0]` split on *every* dot, so "a.b.txt" became "a"
        # and differently named files could collapse onto the same index label.
        idx.append(os.path.splitext(file)[0])
        pls.append(fp.read())

problem_lists = pd.DataFrame(data=pls, index=idx, columns=["problem_list"])
problem_lists

In [0]:
def num_tokens_from_messages(messages, model):
    """Count the prompt tokens consumed by a list of chat messages.

    Args:
        messages: iterable of dicts; every value of every dict is passed to
            the tokenizer's ``encode`` (so values are expected to be strings).
        model: model name. Pinned snapshot names are counted directly; a name
            that merely contains a known family name (e.g. "gpt-4o") prints a
            warning and is counted as that family's pinned snapshot.

    Returns:
        int: per-message overhead + encoded length of every value + one extra
        token per "name" key + 3 tokens that prime the assistant reply.

    Raises:
        NotImplementedError: if `model` is neither a pinned snapshot nor
            contains one of the known family names.
    """
    # Resolve the tokenizer first; unknown names fall back to o200k_base.
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using o200k_base encoding.")
        encoding = tiktoken.get_encoding("o200k_base")

    pinned_snapshots = (
        "gpt-3.5-turbo-0125",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4o-mini-2024-07-18",
        "gpt-4o-2024-08-06",
    )

    if model not in pinned_snapshots:
        # (family substring, snapshot to assume, warning to print).
        # Order matters: "gpt-4o-mini" must be tried before "gpt-4o", and
        # "gpt-4o" before "gpt-4", because each is a substring of the former.
        family_aliases = (
            (
                "gpt-3.5-turbo",
                "gpt-3.5-turbo-0125",
                "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.",
            ),
            (
                "gpt-4o-mini",
                "gpt-4o-mini-2024-07-18",
                "Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.",
            ),
            (
                "gpt-4o",
                "gpt-4o-2024-08-06",
                "Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.",
            ),
            (
                "gpt-4",
                "gpt-4-0613",
                "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.",
            ),
        )
        for family, snapshot, warning in family_aliases:
            if family in model:
                print(warning)
                return num_tokens_from_messages(messages, model=snapshot)
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}."""
        )

    tokens_per_message = 3
    tokens_per_name = 1

    # Start from the 3 tokens that prime every reply: <|start|>assistant<|message|>
    total = 3
    for message in messages:
        total += tokens_per_message
        for field, text in message.items():
            total += len(encoding.encode(text))
            if field == "name":
                total += tokens_per_name
    return total

In [0]:
# "content": f"Here is a patient's problem list:\n\n{row.list}\n\nBased on this, label this patient as either \"Probable Alzheimer's dementia (AD)\" or \"Healthy Control\". Probable AD should be assigned if the problem list includes relevant diagnostic terms or if there are multiple conditions strongly associated with Alzheimer’s dementia. If the label is \"Probable AD\", list the specific problems from the problem list that contributed to this decision. If the label is “Healthy Control”, do not include a list of problems. Format your response as follows:\n\nLabel: <Probable AD or Healthy Control>\n{{If \"Probable AD\", include the following line:}}\nRelevant problems: <comma-separated list of problems from the problem list relevant to the label>"


In [0]:
# Authenticate against the Databricks model-serving endpoints with the API token
# of the current notebook context (nothing is hard-coded or read from disk).
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

client = OpenAI(
    api_key=DATABRICKS_TOKEN,
    base_url="https://adb-2035410508966251.11.azuredatabricks.net/serving-endpoints"
)

model   = "openai_gpt_4o"
# NOTE(review): presumably tokens-per-minute / requests-per-minute limits of the
# endpoint; `tpm` is only referenced by the disabled throttle below and `rpm`
# is not referenced at all in this cell -- confirm before relying on them.
tpm     = 1e6
rpm     = 6150

# Prompt template; the single "{}" placeholder receives one patient's problem list.
prompt  = (
    "Here is a patient's problem list summarizing their active health issues "
    "(e.g., diagnoses, chronic conditions, injuries):\n\n"
    "{}\n\n"
    "Based on this information, label the patient as either:\n\n"
    "- Probable MCI (Mild Cognitive Impairment), or\n"
    "- Healthy Control\n\n"
    "If the problem list includes multiple conditions that are commonly associated "
    "with cognitive decline, MCI, or Alzheimer's dementia (e.g., memory loss, gait abnormality), consider assigning "
    'the \"Probable MCI\" label. Otherwise, assign \"Healthy Control\".\n\n'
    "Format your response as follows:\n\n"
    "- Label: <Probable MCI or Healthy Control>\n"
    "- Reason: <comma-separated list of relevant issues from the problem list, or \"N/A\" if Healthy Control>"
)

# Matches one "Label: ..." or "Reason: ..." line, tolerating an optional bullet
# ("-" or "*") and optional markdown bold markers around the field name.
_RESPONSE_FIELD = re.compile(
    r'^[ \t]*(?:[-*][ \t]+)?\**(label|reason)\**[ \t]*:\**[ \t]*(.*)$',
    re.IGNORECASE | re.MULTILINE,
)


def parse_label_and_reason(output):
    """Extract the label and reason from one model reply.

    The fields are located by name instead of by position, so a preamble line
    that happens to contain a colon cannot shift the wrong text into the label.

    Args:
        output: reply text; ``None`` is treated as an empty reply.

    Returns:
        tuple: ``(label, reason)``. Each element is the text following the
        first matching ``Label:`` / ``Reason:`` line with surrounding
        whitespace removed, or ``None`` when that line is absent.
    """
    fields = {}
    for key, value in _RESPONSE_FIELD.findall(output or ""):
        # First occurrence wins if the reply repeats a field.
        fields.setdefault(key.lower(), value.strip())
    return fields.get("label"), fields.get("reason")


# The endpoint name ("openai_gpt_4o") is turned into the tokenizer's model name
# ("gpt-4o") once, instead of on every loop iteration.
tokenizer_model = model.split("_", maxsplit=1)[1].replace("_", "-")

# Create the result columns up front as object columns so that storing a missing
# value for the first row cannot fix a float dtype for the later string labels.
for column in ("generated_label", "reason"):
    problem_lists[column] = None

tokens_used = 0
t = time.time()
for i, row in problem_lists.iterrows():
    messages = [{"role": "user", "content": prompt.format(row.problem_list)}]
    n_query_tokens = num_tokens_from_messages(messages, tokenizer_model)
    print("N query tokens:", n_query_tokens)

    # Throttling is currently disabled; the draft below is kept for reference.
    # elapsed = time.time() - t
    # if elapsed < 60:
    #     if tokens_used + n_query_tokens > tpm:
    #         # print("sleeping...")
    #         time.sleep(60 - elapsed)
    #         tokens_used += n_query_tokens
    # else:
    #     tokens_used = 0
    #     t = time.time()

    response = client.chat.completions.create(
        model=model,
        messages=messages
    )
    output = response.choices[0].message.content
    n_response_tokens = response.usage.completion_tokens
    print("N response tokens:", n_response_tokens)

    # Running tally of completion tokens only (query tokens are not added here).
    tokens_used += n_response_tokens

    label, reason = parse_label_and_reason(output)
    if label is None or reason is None:
        # A malformed reply no longer aborts the whole run with an IndexError;
        # the row keeps missing values and is reported for manual review.
        print(f"Warning: could not parse label/reason for {i!r}; storing missing values.")
    problem_lists.loc[i, "generated_label"] = label
    problem_lists.loc[i, "reason"] = reason


In [0]:
# Write the problem-list table to an Excel workbook; the relative path is
# resolved against the notebook's current working directory.
labeled_output_path = "visit_problem_lists_labeled.xlsx"
problem_lists.to_excel(labeled_output_path)