To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
# Normally using pip install unsloth is enough

# Temporarily as of Jan 31st 2025, Colab has some issues with Pytorch
# Using pip install unsloth will take 3 minutes, whilst the below takes <1 minute:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

In [1]:
from unsloth import FastLanguageModel
import torch
import re 
from tqdm import tqdm
from datasets import load_dataset
from sklearn.metrics import cohen_kappa_score, confusion_matrix
import numpy as np
import pandas as pd
from datasets import Dataset

import sys
import os
sys.path.append(os.path.abspath(".."))
from importlib import reload
import utils.utils as utils
reload(utils)

max_seq_length = 1024 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 15 trillion tokens model 2x faster!
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # We also uploaded 4bit for 405b!
    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit", # New Mistral 12b 2x faster!
    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",        # Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!
] # More models at https://huggingface.co/unsloth


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-03 22:44:28 [__init__.py:256] Automatically detected platform cuda.


### Unsloth

In [2]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B", ##"qwen3_8B_full_bit_3_epochs_times_10_paraphrase", #'qwen3_8B_full_bit_30_epochs', #"./llama_3_1_8B_ESI_Handbook_10",  #"unsloth/Meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

==((====))==  Unsloth 2025.7.11: Fast Qwen3 patching. Transformers: 4.54.1. vLLM: 0.8.1.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.586 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

We now add LoRA adapters so we only need to update 1 to 10% of all parameters!

In [3]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.7.11 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


<a name="Data"></a>
### Data Prep
We now use the Alpaca dataset from [yahma](https://huggingface.co/datasets/yahma/alpaca-cleaned), which is a filtered version of 52K of the original [Alpaca dataset](https://crfm.stanford.edu/2023/03/13/alpaca.html). You can replace this code section with your own data prep.

**[NOTE]** To train only on completions (ignoring the user's input) read TRL's docs [here](https://huggingface.co/docs/trl/sft_trainer#train-on-completions-only).

**[NOTE]** Remember to add the **EOS_TOKEN** to the tokenized output!! Otherwise you'll get infinite generations!

If you want to use the `llama-3` template for ShareGPT datasets, try our conversational [notebook](https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing).

For text completions like novel writing, try this [notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing).

In [4]:
reload(utils)
# alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# ### Instruction:
# {}

# ### Input:
# {}

# ### Response:
# {}"""

# EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
# def formatting_prompts_func(examples):
#     instructions = examples["instruction"]
#     inputs       = examples["input"]
#     outputs      = examples["output"]
#     texts = []
#     for instruction, input, output in zip(instructions, inputs, outputs):
#         # Must add EOS_TOKEN, otherwise your generation will go on forever!
#         text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
#         texts.append(text)
#     return { "text" : texts, }
# pass

# from datasets import load_dataset
# dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
# dataset = dataset.map(formatting_prompts_func, batched = True,)

EOS_TOKEN = tokenizer.eos_token  # Ensure EOS_TOKEN is defined

# ---- 1. Define helper functions ----

# Function to convert a CSV row into instruction, input, and output strings.
def format_example(row,dataset='triage-ktas'):
    if dataset == 'triage-ESI-train':
        return {"input": row['Clinical Vignettes'], "output": row['Rationale']}
    if dataset == 'triage-ESI-test':
        return {"input": row['Clinical Vignettes'], "output": row['Acuity']}
    # Create a natural language description of the patient.
    patient_description = utils.format_row(row, dataset=dataset)
    
    # Define the instruction for the model.
    if dataset=='triage-mimic':
        instruction = "Based on their clinical presentation, determine the Emergency Severity Index (ESI) acuity for the following patient."
    elif dataset=='triage-ktas':
        instruction = "Based on their clinical presentation, determine the KTAS acuity for the following patient."

    # The input is the patient description.
    input_text = patient_description
    
    # The expected output is a formatted acuity statement.
    if dataset == 'triage-ktas':
        output_text = f"The KTAS acuity for this patient is {row['KTAS_expert']}."
    else:
        output_text = f"The ESI acuity for this patient is {row['acuity']}."
    
    # Return a dict that contains our three keys (and optionally the label for evaluation).
    return {"instruction": instruction, "input": input_text, "output": output_text}

# ---- 2. Load the CSV and convert to a Dataset ----

# Read your CSV data.
#train_df = pd.read_csv("../data/mimic-iv-private/anchor_year_group_datasets/2014_-_2016/small_train_dataset.csv")
#train_df = pd.read_csv("../data/kaggle/train.csv")
train_df = pd.read_csv("../data/ESI-Handbook/train.csv")

# Convert the pandas DataFrame into a Hugging Face Dataset.
dataset = Dataset.from_pandas(train_df)

# Map our serialization function over each example.
dataset = dataset.map(lambda x: format_example(x,dataset='triage-ESI-train'))

# ---- 3. Format the examples into a single prompt string ----

# Define the Alpaca-style prompt template.
alpaca_prompt = """### Instruction: {}

### Input: {}

### Response: {}"""

lima_prompt = """{} Based on their clinical presentation, what Emergency Severity Index (ESI) acuity should the patient be assigned?

{}
"""

# Function to wrap our instruction, input, and output into a single text string.
def formatting_prompts_func(examples):
    # instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for input_text, output_text in zip(inputs, outputs):
        # Append EOS_TOKEN to avoid infinite generation.
        text = lima_prompt.format(input_text, output_text) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}

def formatting_prompts_func_test(examples):
    # instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for input_text, output_text in zip(inputs, outputs):
        # Append EOS_TOKEN to avoid infinite generation.
        text = lima_prompt.format(input_text, "")
        texts.append(text)
    return {"text": texts}


# Map the formatting function over the dataset (batched for efficiency).
dataset = dataset.map(formatting_prompts_func, batched=True)

# ---- At this point, your dataset is in the format expected by your code. ----
# Each example in the dataset now has a "text" field containing your full prompt.
print(dataset[1])



Map:   0%|          | 0/89 [00:00<?, ? examples/s]

Map:   0%|          | 0/89 [00:00<?, ? examples/s]

{'Clinical Vignettes': 'A 41-year-old male involved in a bicycle accident walks into the emergency department with his right arm in a sling. He tells you that he fell off his bike and landed on his right arm. His is complaining of pain in the wrist area and has a 2-centimeter laceration on his left elbow. “My helmet saved me,” he tells you.', 'Rationale': 'ESI level 3: Two or more resources. At a minimum, this patient will require an x ray of his right arm and suturing of his left elbow laceration.', 'acuity': 3, 'input': 'A 41-year-old male involved in a bicycle accident walks into the emergency department with his right arm in a sling. He tells you that he fell off his bike and landed on his right arm. His is complaining of pain in the wrist area and has a 2-centimeter laceration on his left elbow. “My helmet saved me,” he tells you.', 'output': 'ESI level 3: Two or more resources. At a minimum, this patient will require an x ray of his right arm and suturing of his left elbow lacera

In [5]:
print(dataset[29])

{'Clinical Vignettes': '76-year-old male requests to see a doctor because his toenails are hard. Upon further questioning, the triage nurse ascertains that the patient is unable to cut his own toenails. He denies any breaks in the skin or signs of infection. He has a history of chronic obstructive pulmonary disease and uses several metered-dose inhalers. His vital signs are normal for his age.', 'Rationale': 'ESI level 5: No resources. This elderly gentleman has such brittle toenails that he is no longer able to clip them himself. He requires a brief exam and an outpatient referral to a podiatrist.', 'acuity': 5, 'input': '76-year-old male requests to see a doctor because his toenails are hard. Upon further questioning, the triage nurse ascertains that the patient is unable to cut his own toenails. He denies any breaks in the skin or signs of infection. He has a history of chronic obstructive pulmonary disease and uses several metered-dose inhalers. His vital signs are normal for his a

In [5]:
#test_df = pd.read_csv("../data/mimic-iv-private/anchor_year_group_datasets/2017_-_2019/test_dataset.csv")
test_df = pd.read_csv("../data/ESI-Handbook/test.csv")

test_dataset = Dataset.from_pandas(test_df)
test_dataset = test_dataset.map(lambda x: format_example(x,'triage-ESI-test'))
test_dataset = test_dataset.map(formatting_prompts_func_test, batched=True)

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

### Continued-Pretraining

In [9]:
# Load the cleaned text file
with open("../data/cleaned_handbook.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Further clean the text
def remove_newline_between_chars(text):
    return re.sub(r"(\S)\n(\S)", r"\1 \2", text)

# Example usage
cleaned_text = remove_newline_between_chars(text)

# Tokenize and chunk text
def chunk_text(text, max_length):
    """Splits text into chunks while maintaining sentence boundaries."""
    import re
    sentences = re.split(r"(?<=[.!?]) +", text)  # Split by sentence
    chunks, current_chunk = [], ""
    
    for sentence in sentences:
        if len(tokenizer.encode(current_chunk + sentence)) < max_length:
            current_chunk += " " + sentence
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

# Create chunks
chunks = chunk_text(cleaned_text, max_seq_length)

# Add EOS token at the end of each chunk
EOS_TOKEN = tokenizer.eos_token
formatted_chunks = [{"text": chunk} for chunk in chunks]

# Convert to Hugging Face dataset
handbook = Dataset.from_list(formatted_chunks)

# Function for formatting
def formatting_prompts_func(examples):
    return {"text": [example for example in examples["text"]]}

# Apply formatting function
handbook = handbook.map(formatting_prompts_func, batched=True)

Map: 100%|██████████| 27/27 [00:00<00:00, 18215.57 examples/s]


In [10]:
handbook[0]

{'text': "# Emergency Severity Index, Version 4: Implementation Handbook\n\n### Note from the Director\n\nThe Agency for Healthcare Research and Quality is pleased to bring you the Emergency Severity _Index, Version 4: Implementation Handbook. This manual covers all details of the Emergency_ Severity Index (ESI)—a five-level emergency department triage algorithm that provides clinically relevant stratification of patients into five groups from 1 (most urgent) to 5 (least urgent) on the basis of acuity and resource needs.\n\nAfter emergency physicians Richard Wuerz and David Eitel developed the ESI in 1998 and pilot testing yielded favorable results, the ESI Triage Group was formed. Further work on the initial development of ESI was carried out under an AHRQ grant. The ESI Triage Group, which consisted of medical clinicians, managers, educators, and researchers, further refined the algorithm to what it is today.\n\nIn keeping with our mission to improve the quality, safety, efficiency, 

In [11]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
from unsloth import UnslothTrainer, UnslothTrainingArguments

trainer = UnslothTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = handbook,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 8,

    args = UnslothTrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 8,

        warmup_ratio = 0.1,
        num_train_epochs = 10,

        learning_rate = 5e-5,
        embedding_learning_rate = 5e-6,

        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.00,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Map (num_proc=8): 100%|██████████| 27/27 [00:01<00:00, 18.38 examples/s]


In [12]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 27 | Num Epochs = 10
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 8
\        /    Total batch size = 16 | Total steps = 10
 "-____-"     Number of trainable parameters = 83,886,080


Step,Training Loss
1,2.0421
2,4.0493
3,4.0637
4,4.0577
5,4.0473
6,4.0705
7,3.9566
8,4.0535
9,4.1799
10,3.9547


<a name="Train"></a>
### Train the model
Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!

In [6]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = True, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        run_name = "qwen3_8B_full_bit_raw_+_ESI_Case_Examples_25", #"qwen3_8B_full_bit_ESI_Handbook_3_with_10_paragraphse_+_ESI_Case_Examples_25",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 25, # Set this for 1 full training run.
        # max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "qwen3_8B_full_bit_raw_+_ESI_Case_Examples_25",
        report_to = "wandb", # Use this for WandB etc
    ),
)

Unsloth: Tokenizing ["text"]:   0%|          | 0/89 [00:00<?, ? examples/s]

In [7]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 89 | Num Epochs = 25 | Total steps = 300
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 87,293,952 of 8,278,029,312 (1.05% trained)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjiosephlee[0m ([33mupenn-ml[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,2.5835
2,2.6708
3,2.5342
4,2.618
5,2.4939
6,2.5007
7,2.2506
8,2.2249
9,2.0915
10,1.9421


In [13]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA GeForce RTX 3090. Max memory = 23.586 GB.
16.727 GB of memory reserved.


In [14]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

482.7513 seconds used for training.
8.05 minutes used for training.
Peak reserved memory = 16.727 GB.
Peak reserved memory for training = 0.0 GB.
Peak reserved memory % of max memory = 70.919 %.
Peak reserved memory for training % of max memory = 0.0 %.


<a name="Inference"></a>
### Inference
Let's run the model! You can change the instruction and input - leave the output blank!

**[NEW] Try 2x faster inference in a free Colab for Llama-3.1 8b Instruct [here](https://colab.research.google.com/drive/1T-YBVfnphoVc8E2E854qF3jdia2Ll2W2?usp=sharing)**

In [17]:
# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    test_dataset[14]['text']
    # alpaca_prompt.format(
    #     "Continue the fibonnaci sequence.", # instruction
    #     "1, 1, 2, 3, 5, 8", # input
    #     "", # output - leave this blank for generation!
    # )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 250, use_cache = True)
tokenizer.batch_decode(outputs)

['<|begin_of_text|>He has had diarrhea for 2 days, and he just started throwing up this morning. This has been going around the family, and he seems to have it the worst. He has been drinking before today, but now he doesn’t want anything to drink,” reports the mother of a 19-month-old. The toddler is awake and alert but quiet in the mother’s arms, and you notice his lips are dry and cracked. Vital signs: T 99°F, RR 30, HR 130, SpO2 100%. Based on their clinical presentation, what Emergency Severity Index (ESI) acuity should the patient be assigned?\n\n\nESI level 2: High-risk patient. The patient is 19 months old, and he is showing signs of dehydration and possible intoxication. He needs to be assessed for the need of IV fluids and possibly IV medications. His clinical presentation is consistent with a high-risk patient.\n```\n\n\n\nKey Concept: The Emergency Severity Index (ESI) is a five-level acuity-based triage tool that is designed to identify patients who need immediate lifesavi

In [None]:
from sklearn.metrics import cohen_kappa_score, mean_squared_error

FastLanguageModel.for_inference(model) # Enable native 2x faster inference

def extract_response(text, dataset = 'triage'):
    if dataset == 'triage-handbook':
        match = re.search(r"ESI\s+level\s+(\d+)", text, re.DOTALL)
        return int(match.group(1)) if match else None
    # This pattern looks for "Response:" and then non-greedily skips any characters until it finds a sequence of digits.
    match = re.search(r"Response:.*?(\d+)", text, re.DOTALL)
    return int(match.group(1)) if match else None

# Initialize tracking variables
correct = 0
wrong = 0
y_true = []
y_pred = []
undertriage = 0
overtriage = 0

def generate_response(input_text):
    inputs = tokenizer([input_text], return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=65, use_cache=True)
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    #print(decoded_output)
    return extract_response(decoded_output, dataset='triage-handbook')

# Iterate through test dataset
for i, sample in tqdm(enumerate(test_dataset)):
    input_text = sample['text']
    true_acuity = int(sample['Acuity'])
    predicted_acuity = generate_response(input_text)
    
    if predicted_acuity is not None:
        #print(f"Sample {i}: True Acuity: {true_acuity}, Predicted: {predicted_acuity}")
        y_true.append(true_acuity)
        y_pred.append(predicted_acuity)
        if predicted_acuity == true_acuity:
            correct += 1
        else:
            wrong += 1
            
            # Track undertriage and overtriage
            if predicted_acuity < true_acuity:
                undertriage += 1
            elif predicted_acuity > true_acuity:
                overtriage += 1
    else:
        print(f"Sample {i}: No valid response extracted.")
        wrong += 1

# Calculate Undertriage & Overtriage Rates
total_samples = len(y_true)
undertriage_rate = undertriage / total_samples * 100
overtriage_rate = overtriage / total_samples * 100
# Print accuracy
accuracy = correct / (correct + wrong) * 100
qwk_score = cohen_kappa_score(y_true, y_pred, weights='quadratic')
mse = mean_squared_error(y_true, y_pred)
print(f"Model Accuracy: {accuracy:.2f}%")
print(f"Quadratic Weighted Kappa (QWK): {qwk_score:.4f}")
print(f"Mean Squared Error (MSE): {mse:.4f}")
print(f"Undertriage Rate: {undertriage_rate:.2f}%")
print(f"Overtriage Rate: {overtriage_rate:.2f}%")

# Save predictions and ground truth in a CSV file using pandas
# You can also include an index column if needed.
df = pd.DataFrame({
    "Index": range(len(y_true)),
    "True_Acuity": y_true,
    "Predicted_Acuity": y_pred
})
df.to_csv("../results/Triage-Handbook/Qwen3-8B-full_bit-raw_Tuned_25_predictions.csv", index=False)
print("Predictions and ground truth saved.")

209it [06:49,  1.96s/it]

Model Accuracy: 70.81%
Quadratic Weighted Kappa (QWK): 0.8814
Mean Squared Error (MSE): 0.4258
Undertriage Rate: 21.05%
Overtriage Rate: 8.13%
Predictions and ground truth saved.





: 

In [23]:
old_df = pd.read_csv("../results/Triage-Handbook/Pretrained_10_Tuned_10_predictions.csv")


In [8]:
from sklearn.metrics import f1_score
f1_weighted = f1_score(old_df['True_Acuity'], old_df['Predicted_Acuity'], average='weighted')

print(f"F1 Score (Weighted): {f1_weighted:.4f}")

NameError: name 'old_df' is not defined

In [9]:
from sklearn.metrics import f1_score
f1_weighted = f1_score(y_true, y_pred, average='weighted')

print(f"F1 Score (Weighted): {f1_weighted:.4f}")

F1 Score (Weighted): 0.7372


<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [9]:
model.save_pretrained("llama_3_1_8B_ESI_Handbook_10_+_ESI_Case_Examples_25")  # Local saving


In [22]:
model.save_pretrained("lora_model")  # Local saving
tokenizer.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# tokenizer.push_to_hub("your_name/lora_model", token = "...") # Online saving

SyntaxError: invalid decimal literal (2453640556.py, line 1)

Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:

In [5]:
if True:
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "outputs/checkpoint-12500", # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )
    FastLanguageModel.for_inference(model) # Enable native 2x faster inference

# alpaca_prompt = You MUST copy from above!

# inputs = tokenizer(
# [
#     alpaca_prompt.format(
#         "What is a famous tall tower in Paris?", # instruction
#         "", # input
#         "", # output - leave this blank for generation!
#     )
# ], return_tensors = "pt").to("cuda")

# from transformers import TextStreamer
# text_streamer = TextStreamer(tokenizer)
# _ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)

==((====))==  Unsloth 2025.1.8: Fast Llama patching. Transformers: 4.48.2.
   \\   /|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.586 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.34it/s]
Unsloth 2025.1.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


NameError: name 'tqdm' is not defined

In [15]:

def quadratic_weighted_kappa(y_true, y_pred, min_rating=1, max_rating=5):
    hist_rater_a = np.histogram(y_true, bins=np.arange(min_rating, max_rating + 2))[0]
    hist_rater_b = np.histogram(y_pred, bins=np.arange(min_rating, max_rating + 2))[0]
    
    confusion = confusion_matrix(y_true, y_pred, labels=np.arange(min_rating, max_rating + 1))
    num_ratings = len(hist_rater_a)
    weights = np.array([[((i - j) ** 2) / ((num_ratings - 1) ** 2) for j in range(num_ratings)] for i in range(num_ratings)])
    expected = np.outer(hist_rater_a, hist_rater_b) / np.sum(hist_rater_a)
    kappa = 1.0 - (np.sum(weights * confusion) / np.sum(weights * expected))
    return kappa

def extract_response(text):
    match = re.search(r"Response:\s*(\d+)", text)
    return int(match.group(1)) if match else None

# Initialize tracking variables
correct = 0
wrong = 0
y_true = []
y_pred = []

def generate_response(input_text):
    print(input_text)
    inputs = tokenizer([input_text], return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(decoded_output)
    return extract_response(decoded_output)

# Iterate through test dataset
for i, sample in tqdm(enumerate(test_dataset)):
    input_text = sample['text']
    true_acuity = sample['acuity']
    predicted_acuity = generate_response(input_text)
    
    if predicted_acuity is not None:
        # print(f"Sample {i}: True Acuity: {true_acuity}, Predicted: {predicted_acuity}")
        y_true.append(true_acuity)
        y_pred.append(predicted_acuity)
        if predicted_acuity == true_acuity:
            correct += 1
        else:
            wrong += 1
    else:
        print(f"Sample {i}: No valid response extracted.")
        wrong += 1

# Print accuracy
accuracy = correct / (correct + wrong) * 100
qwk_score = quadratic_weighted_kappa(y_true, y_pred)
print(f"Model Accuracy: {accuracy:.2f}%")
print(f"Quadratic Weighted Kappa (QWK): {qwk_score:.4f}")

0it [00:00, ?it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 61.0-year-old man arrives at the emergency department via ambulance. He has a temperature of 98.6°F, a heart rate of 76.0 bpm, a respiratory rate of 20.0 breaths per minute, oxygen saturation at 99.0%, systolic blood pressure of 151.0 mmHg, diastolic blood pressure of 90.0 mmHg, pain level reported as 13, and chief complaint described as "SHORTNESS OF BREATH".

### Response: 


1it [00:00,  1.82it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 61.0-year-old man arrives at the emergency department via ambulance. He has a temperature of 98.6°F, a heart rate of 76.0 bpm, a respiratory rate of 20.0 breaths per minute, oxygen saturation at 99.0%, systolic blood pressure of 151.0 mmHg, diastolic blood pressure of 90.0 mmHg, pain level reported as 13, and chief complaint described as "SHORTNESS OF BREATH".

### Response:  The estimated ESI acuity for this patient is 2.0.
Sample 0: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white - other european, 62.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 97.7°F, a heart rate of 82.0 bpm, a respiratory rate of 24.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 155.0 mmHg, diastolic blood press

2it [00:01,  1.97it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white - other european, 62.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 97.7°F, a heart rate of 82.0 bpm, a respiratory rate of 24.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 155.0 mmHg, diastolic blood pressure of 83.0 mmHg, pain level reported as 3, and chief complaint described as "RLQ abdominal pain".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 1: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 56.0-year-old man arrives at the emergency department via ambulance. He has a temperature of 98.7°F, a heart rate of 88.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 97.0%, systolic blood pressure of 140.0 mmHg, diastolic blood pressure of 84.0 mmHg, p

3it [00:01,  2.05it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 56.0-year-old man arrives at the emergency department via ambulance. He has a temperature of 98.7°F, a heart rate of 88.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 97.0%, systolic blood pressure of 140.0 mmHg, diastolic blood pressure of 84.0 mmHg, pain level reported as 5, and chief complaint described as "R Flank pain, Right sided abdominal pain".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 2: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 63.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.1°F, a heart rate of 80.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 180.0 mmHg, diastolic blood pr

4it [00:01,  2.10it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 63.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.1°F, a heart rate of 80.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 180.0 mmHg, diastolic blood pressure of 90.0 mmHg, pain level reported as 3, and chief complaint described as "Neck pain, Back pain, MVC".

### Response:  The estimated ESI acuity for this patient is 4.0.
Sample 3: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 55.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.6°F, a heart rate of 96.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 96.0%, systolic blood pressure of 164.0 mmHg, diastolic blood pressure of 80

5it [00:02,  2.13it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 55.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.6°F, a heart rate of 96.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 96.0%, systolic blood pressure of 164.0 mmHg, diastolic blood pressure of 80.0 mmHg, pain level reported as 3, and chief complaint described as "L Leg pain".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 4: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 78.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 98.0°F, a heart rate of 72.0 bpm, a respiratory rate of 18.0 breaths per minute, oxygen saturation at 97.0%, systolic blood pressure of 162.0 mmHg, diastolic blood pressure of 65.0 mmHg, pain level reported as 6, 

6it [00:02,  2.14it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A white, 78.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 98.0°F, a heart rate of 72.0 bpm, a respiratory rate of 18.0 breaths per minute, oxygen saturation at 97.0%, systolic blood pressure of 162.0 mmHg, diastolic blood pressure of 65.0 mmHg, pain level reported as 6, and chief complaint described as "LUQ abdominal pain".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 5: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 52.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.4°F, a heart rate of 79.0 bpm, a respiratory rate of 18.0 breaths per minute, oxygen saturation at 99.0%, systolic blood pressure of 117.0 mmHg, diastolic blood pressure of 66.0 mmHg, p

7it [00:03,  2.16it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african american, 52.0-year-old woman arrives at the emergency department via ambulance. She has a temperature of 98.4°F, a heart rate of 79.0 bpm, a respiratory rate of 18.0 breaths per minute, oxygen saturation at 99.0%, systolic blood pressure of 117.0 mmHg, diastolic blood pressure of 66.0 mmHg, pain level reported as 8, and chief complaint described as "R LEG PAIN".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 6: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african, 28.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 99.7°F, a heart rate of 100.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 124.0 mmHg, diastolic blood pressure of 62.0 mmHg,

8it [00:03,  2.16it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/african, 28.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 99.7°F, a heart rate of 100.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 124.0 mmHg, diastolic blood pressure of 62.0 mmHg, pain level reported as 10, and chief complaint described as "SORE THROAT".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 7: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A hispanic/latino - dominican, 68.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 99.3°F, a heart rate of 79.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 142.0 mmHg, diastolic blood pressure of 82.0

9it [00:04,  2.17it/s]

### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A hispanic/latino - dominican, 68.0-year-old woman arrives at the emergency department via walk-in. She has a temperature of 99.3°F, a heart rate of 79.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 100.0%, systolic blood pressure of 142.0 mmHg, diastolic blood pressure of 82.0 mmHg, pain level reported as 10, and chief complaint described as "Finger swelling".

### Response:  The estimated ESI acuity for this patient is 3.0.
Sample 8: No valid response extracted.
### Instruction: Estimate the Emergency Severity Index (ESI) acuity for the following patient.

### Input: A black/cape verdean, 23.0-year-old man arrives at the emergency department via ambulance. He has a temperature of 97.3°F, a heart rate of 75.0 bpm, a respiratory rate of 16.0 breaths per minute, oxygen saturation at 98.0%, systolic blood pressure of 108.0 mmHg, diastolic blood pressure o

9it [00:04,  1.99it/s]


KeyboardInterrupt: 

You can also use Hugging Face's `AutoModelForPeftCausalLM`. Only use this if you do not have `unsloth` installed. It can be hopelessly slow, since `4bit` model downloading is not supported, and Unsloth's **inference is 2x faster**.

In [None]:
if False:
    # I highly do NOT suggest - use Unsloth if possible
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer
    model = AutoPeftModelForCausalLM.from_pretrained(
        "lora_model", # YOUR MODEL YOU USED FOR TRAINING
        load_in_4bit = load_in_4bit,
    )
    tokenizer = AutoTokenizer.from_pretrained("lora_model")

### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)

And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!

Some other links:
1. Llama 3.2 Conversational notebook. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
</div>
