In [1]:
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration
import torch
import pandas as pd
import glob
from read_csv_gz import read_csv_gz
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
ed_diagnosis_df = read_csv_gz("ed_data/diagnosis.csv.gz")
ed_edstays_df = read_csv_gz("ed_data/edstays.csv.gz")
ed_medrecon_df = read_csv_gz("ed_data/medrecon.csv.gz")
ed_pyxis_df = read_csv_gz("ed_data/pyxis.csv.gz")
ed_triage_df = read_csv_gz("ed_data/triage.csv.gz")
ed_vitalsign_df = read_csv_gz("ed_data/vitalsign.csv.gz")


In [3]:
def text_gen(x):
    return f"Chief Complaint: {x}"

In [4]:
print(ed_triage_df.shape[0])

425087


In [5]:
def converter(df, m):
    ans = []
    temp = []
    for i in range(df.shape[0]):
        text = text_gen(df.iloc[i])
        temp.append(text)
        if (i + 1) % m == 0:
            if len(temp) != 0 :
                ans.append(temp)
            temp = []
    if len(temp) != 0:
        ans.append(temp)
    return ans

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [77]:
def inference_prompt(batch_lst, tokenizer, model, device="cuda" if torch.cuda.is_available() else "cpu"):
    ans = []
    for batch in tqdm(batch_lst):
        bs = len(batch)
        inputs = tokenizer(batch, padding=True, return_tensors="pt", truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            output_sequences = model(**inputs)
            outputs_cls_embeddings = output_sequences.last_hidden_state[:, 0, :].cpu().numpy() 
        for i in range(bs):
            ans.append(outputs_cls_embeddings[i, :])

    return pd.Series(ans)
        
        


In [78]:
lst_prompt = converter(ed_triage_df["chiefcomplaint"].head(234), 10)

In [83]:
model_name = "UFNLP/gatortron-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

config.json:   0%|          | 0.00/534 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/379k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/713M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/713M [00:00<?, ?B/s]

In [84]:
model.to(device)

MegatronBertModel(
  (embeddings): MegatronBertEmbeddings(
    (word_embeddings): Embedding(50176, 1024, padding_idx=0)
    (position_embeddings): Embedding(512, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): MegatronBertEncoder(
    (layer): ModuleList(
      (0-23): 24 x MegatronBertLayer(
        (attention): MegatronBertAttention(
          (ln): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (self): MegatronBertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): MegatronBertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
  

In [85]:
res = inference_prompt(lst_prompt, tokenizer, model, device=device)

100%|██████████| 24/24 [00:04<00:00,  4.85it/s]


In [86]:
print(res.iloc[0].shape)
print(res.shape)

(1024,)
(234,)
