In [15]:
# !pip install transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
import glob
from read_csv_gz import read_csv_gz
import matplotlib.pyplot as plt
from tqdm import tqdm

In [16]:
import sys
print(sys.executable)


/opt/miniconda3/envs/erdos_spring_2025/bin/python


In [17]:
ed_diagnosis_df = read_csv_gz("ed_data/diagnosis.csv.gz")
ed_edstays_df = read_csv_gz("ed_data/edstays.csv.gz")
ed_medrecon_df = read_csv_gz("ed_data/medrecon.csv.gz")
ed_pyxis_df = read_csv_gz("ed_data/pyxis.csv.gz")
ed_triage_df = read_csv_gz("ed_data/triage.csv.gz")
ed_vitalsign_df = read_csv_gz("ed_data/vitalsign.csv.gz")

In [18]:
def text_gen(x):
    return f"You are a medical expert assisting in an Emergency Department (ED). Your task is to assess the **diagnostic complexity** of the following chief complaint on a scale of 1-10:\n- **1-3 (Low Complexity):** Easily diagnosable, requires minimal tests.\n- **4-6 (Moderate Complexity):** Needs some testing, single specialty.\n- **7-10 (High Complexity):** Requires multiple tests, possible admission, specialist consults.\nChief Complaint: **{x}**\nPlease only return an integer of Complexity Score (1-10):"

In [19]:
print(ed_triage_df.shape[0])

425087


In [20]:
def converter(df, m):
    ans = []
    temp = []
    for i in range(df.shape[0]):
        text = text_gen(df.iloc[i])
        temp.append(text)
        if (i + 1) % m == 0:
            if len(temp) != 0 :
                ans.append(temp)
            temp = []
    if len(temp) != 0:
        ans.append(temp)
    return ans

In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [22]:
def indefence_prompt(batch_lst, tokenizer, model, device="cuda" if torch.cuda.is_available() else "cpu"):
    ans = []
    for batch in tqdm(batch_lst):
        bs = len(batch)
        inputs = tokenizer(batch, padding=True, padding_side='left', return_tensors="pt").to(device)
        output_sequences = model.generate(**inputs, max_length=200)
        for i in range(bs):
            generated_text = tokenizer.decode(output_sequences[i, :], skip_special_tokens=True)
            ans.append(generated_text)
    return pd.Series(ans)
        
        


In [23]:
lst_prompt = converter(ed_triage_df["chiefcomplaint"].head(234), 10)

In [24]:
import torch
print(torch.__version__)
print(torch.cuda.is_available())  # Should print True if CUDA is installed correctly


2.6.0
False


In [25]:
model_name = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
tokenizer.pad_token = tokenizer.eos_token

In [26]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise

In [27]:
res = indefence_prompt(lst_prompt, tokenizer, model, device="cuda" if torch.cuda.is_available() else "cpu")

  0%|          | 0/24 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
  4%|▍         | 1/24 [00:01<00:41,  1.80s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
  8%|▊         | 2/24 [00:03<00:41,  1.89s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 12%|█▎        | 3/24 [00:05<00:34,  1.66s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 17%|█▋        | 4/24 [00:06<00:34,  1.72s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 21%|██        | 5/24 [00:08<00:32,  1.71s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 25%|██▌       | 6/24 [00:09<00:28,  1.56s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 29%|██▉       | 7/24 [00:11<00:26,  1.55s/it]Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
 33%|███▎      | 8/24 [00:13<00:24,  1.55s/it]Setting `pad_token_id` to `eos_token_id`:0 for ope

In [28]:
print(res)

0      You are a medical expert assisting in an Emerg...
1      You are a medical expert assisting in an Emerg...
2      You are a medical expert assisting in an Emerg...
3      You are a medical expert assisting in an Emerg...
4      You are a medical expert assisting in an Emerg...
                             ...                        
229    You are a medical expert assisting in an Emerg...
230    You are a medical expert assisting in an Emerg...
231    You are a medical expert assisting in an Emerg...
232    You are a medical expert assisting in an Emerg...
233    You are a medical expert assisting in an Emerg...
Length: 234, dtype: object


In [29]:
print(res[0])

You are a medical expert assisting in an Emergency Department (ED). Your task is to assess the **diagnostic complexity** of the following chief complaint on a scale of 1-10:
- **1-3 (Low Complexity):** Easily diagnosable, requires minimal tests.
- **4-6 (Moderate Complexity):** Needs some testing, single specialty.
- **7-10 (High Complexity):** Requires multiple tests, possible admission, specialist consults.
Chief Complaint: **Hypotension**
Please only return an integer of Complexity Score (1-10):
- **1-3 (Low Complexity):** Easily diagnosable, requires minimal tests.
- **4-6 (Moderate Complexity):** Needs some testing, single specialty.
- **7-10 (High Complexity):** Requires multiple tests, possible
