In [None]:
import math

import pandas as pd
import torch
import yaml
from openai import AzureOpenAI
from openai.types.chat.chat_completion_token_logprob import TopLogprob
from scipy.special import softmax as scipy_softmax
from torch import softmax

In [30]:
# Parse the YAML secrets file, then pick out the settings of the GPT-4o deployment.
with open("../secrets/azure_keys.yaml", "r") as secrets_file:
    secrets = yaml.safe_load(secrets_file)

gpt4o_config = secrets["azure_gpt-4o_deployment"]

In [31]:
# Name of the model deployment; passed as `model` when issuing chat requests.
deployment = gpt4o_config["deployment"]

# Azure OpenAI client configured from the endpoint, key and API version in the secrets file.
client = AzureOpenAI(
    azure_endpoint=gpt4o_config["endpoint_uri"],
    api_key=gpt4o_config["api_key"],
    api_version=gpt4o_config["api_version"],
)

In [70]:
# Instruction prompt: frames the model as an ICU clinician and shows the
# three-field JSON reply (diagnosis, classification, explanation) it must return.
# Each entry is one line of the prompt; the two trailing empty entries produce
# the final blank line.
_system_prompt_lines = [
    "You are a helpful assistant and medical professional that analyzes ICU time-series "
    "data and determines the most likely diagnosis.",
    "",
    "Be specific and check the values against reference values.",
    "Return the result strictly in this JSON format.",
    "Example:",
    "{",
    '  "diagnosis": "<aki or not-aki>",',
    '  "classification": "<a float value between 0 and 1 representing the probability of your diagnosis>",',
    '  "explanation": "<a brief explanation for the prediction. state reference values and check against provided features>"',
    "}",
    "",
    "",
]
system_message = "\n".join(_system_prompt_lines)

In [71]:
# One entry per line of the patient record: a demographics line followed by
# min / max / mean summaries for each measured feature.
_patient_record_lines = (
    "Patient Info — index: 0, age: 91.0, sex: Female, height: 165.1, weight: 79.4",
    "Albumin (unit: g/dL): min=3.070, max=3.070, mean=3.070",
    "Alkaline Phosphatase (unit: U/L): min=116.680, max=116.680, mean=116.680",
    "Alanine Aminotransferase (ALT) (unit: U/L): min=157.960, max=157.960, mean=157.960",
    "Aspartate Aminotransferase (AST) (unit: U/L): min=216.220, max=216.220, mean=216.220",
    "Base Excess (unit: mmol/L): min=-0.540, max=-0.540, mean=-0.540",
    "Bicarbonate (unit: mmol/L): min=24.030, max=24.030, mean=24.030",
    "Total Bilirubin (unit: mg/dL): min=1.950, max=2.790, mean=2.370",
    "Band Neutrophils (unit: %): min=5.000, max=5.000, mean=5.000",
    "Blood Urea Nitrogen (BUN) (unit: mg/dL): min=22.600, max=22.600, mean=22.600",
    "Calcium (unit: mg/dL): min=8.330, max=8.330, mean=8.330",
    "Ionized Calcium (unit: mmol/L): min=1.130, max=1.130, mean=1.130",
    "Creatine Kinase (CK) (unit: U/L): min=1367.040, max=1367.040, mean=1367.040",
    "Creatine Kinase-MB (CK-MB) (unit: ng/mL): min=23.420, max=23.420, mean=23.420",
    "Chloride (unit: mmol/L): min=104.510, max=104.510, mean=104.510",
    "Creatinine (unit: mg/dL): min=1.040, max=1.040, mean=1.040",
    "C-Reactive Protein (CRP) (unit: mg/L): min=79.480, max=79.480, mean=79.480",
    "Diastolic Blood Pressure (unit: mmHg): min=40.000, max=54.000, mean=46.833",
    "Fibrinogen (unit: mg/dL): min=277.650, max=277.650, mean=277.650",
    "Fraction of Inspired Oxygen (FiO2) (unit: %): min=35.000, max=35.000, mean=35.000",
    "Glucose (unit: mg/dL): min=140.370, max=140.370, mean=140.370",
    "Hemoglobin (unit: g/dL): min=10.250, max=10.250, mean=10.250",
    "Heart Rate (unit: bpm): min=69.000, max=81.000, mean=72.000",
    "inr (unit: ): min=1.470, max=1.470, mean=1.470",
    "Potassium (unit: mmol/L): min=4.090, max=4.090, mean=4.090",
    "Lactate (unit: mmol/L): min=2.420, max=2.420, mean=2.420",
    "Lymphocytes (unit: %): min=13.370, max=13.370, mean=13.370",
    "Mean Arterial Pressure (MAP) (unit: mmHg): min=58.000, max=72.500, mean=64.583",
    "Mean Corpuscular Hemoglobin (MCH) (unit: pg): min=30.030, max=30.030, mean=30.030",
    "Mean Corpuscular Hemoglobin Concentration (MCHC) (unit: g/dL): min=33.080, max=33.080, mean=33.080",
    "Mean Corpuscular Volume (MCV) (unit: fL): min=90.860, max=90.860, mean=90.860",
    "Methemoglobin (unit: %): min=1.310, max=1.310, mean=1.310",
    "Magnesium (unit: mg/dL): min=2.050, max=2.050, mean=2.050",
    "Sodium (unit: mmol/L): min=138.880, max=138.880, mean=138.880",
    "Neutrophils (unit: %): min=76.840, max=76.840, mean=76.840",
    "Oxygen Saturation (unit: %): min=91.500, max=97.000, mean=94.667",
    "Partial Pressure of Carbon Dioxide (PaCO2) (unit: mmHg): min=42.440, max=42.440, mean=42.440",
    "pH Level (unit: /): min=7.380, max=7.380, mean=7.380",
    "Phosphate (unit: mg/dL): min=3.300, max=3.300, mean=3.300",
    "Platelets (unit: 1000/µL): min=196.570, max=196.570, mean=196.570",
    "Partial Pressure of Oxygen (PaO2) (unit: mmHg): min=150.250, max=150.250, mean=150.250",
    "Partial Thromboplastin Time (PTT) (unit: sec): min=40.540, max=40.540, mean=40.540",
    "Respiratory Rate (unit: breaths/min): min=14.000, max=22.000, mean=17.750",
    "Systolic Blood Pressure (unit: mmHg): min=103.000, max=140.000, mean=123.333",
    "Temperature (unit: °C): min=36.222, max=36.278, mean=36.269",
    "Troponin T (unit: ng/mL): min=0.930, max=0.930, mean=0.930",
    "Urine Output (unit: mL/h): min=80.000, max=200.000, mean=116.667",
    "White Blood Cell Count (WBC) (unit: 1000/µL): min=11.890, max=11.890, mean=11.890",
)

# User prompt: the task sentence, a "Patient data:" header followed by a blank
# line, then the record lines, each terminated by a newline.
user_message = (
    "Suggest a diagnosis of aki for the following patient data.\n"
    "Patient data:\n\n"
    + "\n".join(_patient_record_lines)
    + "\n"
)

In [72]:
# Conversation sent to the model: the instruction prompt under the "developer"
# role, followed by the patient record as the user turn.
chat_messages = [
    {"role": "developer", "content": system_message},
    {"role": "user", "content": user_message},
]

# Ask for per-token log-probabilities together with the 5 most likely
# alternatives at every position, so the sampled answer can be inspected later.
response = client.chat.completions.create(
    model=deployment,
    messages=chat_messages,
    max_tokens=10000,
    temperature=1.0,
    top_p=1.0,
    logprobs=True,
    top_logprobs=5,
)

print(response.choices[0].message.content)

```json
{
  "diagnosis": "not-aki",
  "classification": 0.2,
  "explanation": "The patient has a creatinine level of 1.040 mg/dL, which is within the normal range for adult females (0.6–1.1 mg/dL). Urine output is 116.667 mL/h on average, which does not meet the criteria for oliguria (<0.5 mL/kg/h for more than six hours). Blood Urea Nitrogen (BUN) is elevated at 22.600 mg/dL (normal range: 6–20 mg/dL), but this alone does not meet the diagnostic criteria for AKI without significantly elevated creatinine or markedly reduced urine output. Other laboratory values such as electrolytes (Potassium, Sodium) and pH are within acceptable ranges, with no evidence of acid-base imbalance. Thus, the likelihood of acute kidney injury (AKI) is low."
}
```


In [None]:
def _token_record(entry):
    """Return the token, its logprob and its top alternatives as one table row."""
    return {
        "token": entry.token,
        "logprob": entry.logprob,
        "top_logprobs": entry.top_logprobs,
    }


# Per-token logprob entries of the first (only) choice in the response.
logprobs_content = response.choices[0].logprobs.content
logprobs_data = list(map(_token_record, logprobs_content))

# One row per generated token.
logprobs_df = pd.DataFrame(logprobs_data)
logprobs_df.head(20)

Unnamed: 0,token,logprob,top_logprobs
0,```,-0.04864641,"[TopLogprob(token='```', bytes=[96, 96, 96], l..."
1,json,-4.3202e-07,"[TopLogprob(token='json', bytes=[106, 115, 111..."
2,\n,-5.512237e-07,"[TopLogprob(token='\n', bytes=[10], logprob=-5..."
3,{\n,0.0,"[TopLogprob(token='{\n', bytes=[123, 10], logp..."
4,,-3.292908e-06,"[TopLogprob(token=' ', bytes=[32], logprob=-3...."
5,"""",0.0,"[TopLogprob(token=' ""', bytes=[32, 34], logpro..."
6,diagn,0.0,"[TopLogprob(token='diagn', bytes=[100, 105, 97..."
7,osis,0.0,"[TopLogprob(token='osis', bytes=[111, 115, 105..."
8,""":",-1.936127e-07,"[TopLogprob(token='"":', bytes=[34, 58], logpro..."
9,"""",0.0,"[TopLogprob(token=' ""', bytes=[32, 34], logpro..."


In [None]:
# Probability of each sampled token.
# The `logprob` values are natural-log probabilities, so exponentiating
# recovers the probability. (A softmax over a single value is always 1.0 and
# would discard the information.)
logprobs_df["probability"] = logprobs_df["logprob"].map(math.exp)

# Flatten each token's top alternatives into a plain list of logprob floats.
# The column is named `top_logprobs_lists` because that is the name the softmax
# step reads. It is assigned as a whole column: element-wise chained assignment
# (`df[col][i] = ...`) warns and is a no-op under pandas Copy-on-Write.
logprobs_df["top_logprobs_lists"] = [
    [alternative.logprob for alternative in alternatives]
    for alternatives in logprobs_df["top_logprobs"]
]

In [69]:
# Show floats with 8 decimals so near-zero probabilities stay visible.
pd.options.display.float_format = "{:.8f}".format

# Softmax over each token's list of top-k logprobs.
# NOTE: this renormalises over only the returned alternatives, so each list
# sums to 1 and the values are slightly larger than exp(logprob) whenever
# probability mass lies outside the top-k.
# scipy's softmax is used under an alias so it does not rebind the `softmax`
# name imported from torch.
logprobs_df["top_logprobs_probs"] = [
    scipy_softmax(candidate_logprobs).tolist()
    for candidate_logprobs in logprobs_df["top_logprobs_lists"]
]
logprobs_df

Unnamed: 0,token,logprob,top_logprobs,top_logprobs_lists,top_logprobs_probs
0,```,-0.04864641,"[TopLogprob(token='```', bytes=[96, 96, 96], l...","[-0.048646413, -3.0486465, -10.548646, -10.798...","[0.9525263055646706, 0.04742348817149348, 2.62..."
1,json,-0.00000043,"[TopLogprob(token='json', bytes=[106, 115, 111...","[-4.3202e-07, -14.75, -20.0, -20.25, -21.125]","[0.9999996028780911, 3.927863682555673e-07, 2...."
2,\n,-0.00000055,"[TopLogprob(token='\n', bytes=[10], logprob=-5...","[-5.5122365e-07, -15.000001, -15.875001, -19.1...","[0.9999995598155925, 3.0590204856675866e-07, 1..."
3,{\n,0.00000000,"[TopLogprob(token='{\n', bytes=[123, 10], logp...","[0.0, -17.25, -22.125, -22.875, -24.875]","[0.9999999673799445, 3.2241866320835826e-08, 2..."
4,,-0.00000329,"[TopLogprob(token=' ', bytes=[32], logprob=-3....","[-3.2929079e-06, -12.875003, -14.375003, -15.7...","[0.9999966695647478, 2.561280990954953e-06, 5...."
...,...,...,...,...,...
178,kidney,-2.34647100,"[TopLogprob(token=' AK', bytes=[32, 65, 75], l...","[-0.22147098, -2.346471, -3.221471, -3.721471,...","[0.8132628209299555, 0.09713039074202238, 0.04..."
179,injury,-0.44135475,"[TopLogprob(token=' injury', bytes=[32, 105, 1...","[-0.44135475, -1.1913548, -3.6913548, -3.69135...","[0.6442873739747746, 0.30433979060203536, 0.02..."
180,".""\n",-0.29979340,"[TopLogprob(token='.""\n', bytes=[46, 34, 10], ...","[-0.2997934, -1.5497934, -4.0497932, -4.299793...","[0.743769984879595, 0.21309366842863503, 0.017..."
181,}\n,0.00000000,"[TopLogprob(token='}\n', bytes=[125, 10], logp...","[0.0, -16.875, -20.0, -20.375, -21.875]","[0.9999999492945119, 4.6911637839666563e-08, 2..."


In [None]:
# Reference sample of one token's `top_logprobs` list: the five candidate
# tokens with their byte encodings and log-probabilities.
# NOTE(review): the candidate tokens ("no", "No", ...) look pasted from a
# different run than the response printed earlier in this notebook — confirm.
# `TopLogprob` must be imported from openai's chat types for this cell to run.
example_top_logprobs = [
    TopLogprob(token="no", bytes=[110, 111], logprob=0.0),
    TopLogprob(token="No", bytes=[78, 111], logprob=-24.75),
    TopLogprob(token="not", bytes=[110, 111, 116], logprob=-25.75),
    TopLogprob(token="NO", bytes=[78, 79], logprob=-26.125),
    TopLogprob(token="non", bytes=[110, 111, 110], logprob=-26.5),
]
example_top_logprobs