In [26]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import pandas as pd
import json
from utils import *

In [23]:
config = load_config()
PROJECT_PATH = config.project_path
DATA_PATH = PROJECT_PATH.joinpath("data/processed")
MODEL_PATH = config.model_path("llama3.2-3B")

In [24]:
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

def load_datasets() :

    data = load_dataset("json", data_files={"train" : DATA_PATH.joinpath("train_conversation.jsonl").as_posix(),
                                     "test" : DATA_PATH.joinpath("test_conversation.jsonl").as_posix()})
    return data

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.28s/it]


In [12]:
data = load_datasets()

dataset = data.map(lambda x : {"formatted_chat": tokenizer.apply_chat_template(x["messages"], tokenize=False, add_generation_prompt=False)})

Map: 100%|██████████| 14231/14231 [00:04<00:00, 2895.03 examples/s]
Map: 100%|██████████| 3558/3558 [00:01<00:00, 2918.75 examples/s]


In [20]:
dataset['train']['messages'][1]


[{'content': "You are a helpful assistant trained for healthcare. Here is the patient's discharge note. \n\n  \nName:  ___             Unit No:   ___\n \nAdmission Date:  ___              Discharge Date:   ___\n \nDate of Birth:  ___             Sex:   F\n \nService: MEDICINE\n \nAllergies: \nNo Known Allergies / Adverse Drug Reactions\n \nAttending: ___.\n \nChief Complaint:\nEncephalopathy\n \nMajor Surgical or Invasive Procedure:\nintubation\n \nHistory of Present Illness:\nMs. ___ is a ___ year old female, with past history of \naltered mental status, glioblastoma multiforme, who presented \nminimally responsive to an OSH, transferred for ___ for \nfurther management.\n \nThe patient was diagnosed with GBM approximately 8 weeks prior \nand is currently receiving care at ___. She is currently \nundergoing treatment with XRT and avastin. She has a baseline \nright sided hemiparesis.\n\nThe day prior to admission she had a R sided port place \nrequiring anesthesia due to patient anxie

In [8]:
# Define conversation categories
# 1. Return to the ED/Hospital indications (c1)
# 2. Medication Info (c2)
# 3. Diagnosis (c3)
# 4. Postdischarge treatment (c4)
# 5. ED tests and treatments (c5)
# 6. Follow-up (c6)

# 지금은 일단 naiive한 방법으로 분석한다.

In [62]:
prompt = """
### Instruction :
Here are some types of conversation categories between a physician and the patient. 
please classify each given sentence based on these criteria. 

### Define conversation categories
1. Return to the ED/Hospital indications (c1)
2. Medication Info (c2)
3. Diagnosis (c3)
4. Postdischarge treatment (c4)
5. ED tests and treatments (c5)
6. Follow-up (c6)

For example, if the sentence encompasses "Return to the ED/Hospital indications" then respond "c1"
If there are no match respond "NA"

### Sentence :
{}

### Classification : 
"""

In [51]:
model = pipeline("text-generation", model=MODEL_PATH)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.10s/it]
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [63]:
sample = dataset['train']['messages'][0]

In [66]:
outs = []
for s in sample :
    if s['role'] == "system" :
        continue
    text = s['content']
    out = model(prompt.format(text), return_full_text=False, max_new_tokens=3)
    outs.append(out)
    

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [83]:
outs[1][0]['generated_text']

'c4 \n\n'

In [45]:
out

[{'generated_text': ' \n\n*   Return to the ED/Hospital indications'}]