Setup & imports

In [60]:
import re
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer  # or the specific tokenizer you use

In [61]:
# Reading CSV file using pandas
# Basic usage:
df = pd.read_csv(r"D:\UT_Austin\sem4\physionet.org\files\labelled-notes-hospital-course\1.1.0\mimic-iv-bhc.csv")

print("CSV file loaded successfully!")
print(f"Shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
df.head()  # Display first 5 rows

CSV file loaded successfully!
Shape: (270033, 5)
Columns: ['note_id', 'input', 'target', 'input_tokens', 'target_tokens']


Unnamed: 0,note_id,input,target,input_tokens,target_tokens
0,10000032-DS-21,<SEX> F <SERVICE> MEDICINE <ALLERGIES> No Know...,"___ HCV cirrhosis c/b ascites, hiv on ART, h/o...",1946,231
1,10000032-DS-22,<SEX> F <SERVICE> MEDICINE <ALLERGIES> Percoce...,"___ with HIV on HAART, HCV cirrhosis with asci...",2183,810
2,10000117-DS-21,<SEX> F <SERVICE> MEDICINE <ALLERGIES> omepraz...,Ms. ___ is a ___ with history of GERD who pres...,1060,172
3,10000117-DS-22,<SEX> F <SERVICE> ORTHOPAEDICS <ALLERGIES> ome...,The patient presented to the emergency departm...,1195,330
4,10000248-DS-10,<SEX> M <SERVICE> MEDICINE <ALLERGIES> No Know...,Mr. ___ is a ___ with history of mild FVIII de...,1961,230


Basic cleaning / normalization functions

In [62]:
def normalize_whitespace(text: str) -> str:
    # collapse multiple spaces / newlines into single
    text = re.sub(r'\r', ' ', text)
    text = re.sub(r'\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def clean_placeholders(text: str) -> str:
    # e.g. ensure consistent casing / spacing around placeholders
    # This is generic; adapt if you have more placeholder variants
    text = re.sub(r'<\s*SEX\s*>', '<SEX>', text, flags=re.IGNORECASE)
    text = re.sub(r'<\s*SERVICE\s*>', '<SERVICE>', text, flags=re.IGNORECASE)
    text = re.sub(r'<\s*ALLERGIES\s*>', '<ALLERGIES>', text, flags=re.IGNORECASE)
    return text

def remove_control_characters(text: str) -> str:
    # remove non-printable / weird control chars
    # e.g. anything in category Cc except \n
    cleaned = ''.join(ch for ch in text if (ch == '\n' or (ord(ch) >= 32 and ord(ch) != 127)))
    return cleaned

def preprocess_text(text: str) -> str:
    if text is None:
        return ""
    text = normalize_whitespace(text)
    text = clean_placeholders(text)
    text = remove_control_characters(text)
    return text


In [63]:
df['input_clean'] = df['input'].map(preprocess_text)
df['target_clean'] = df['target'].map(preprocess_text)


Filtering / dropping low-quality / edge examples

In [64]:
# Drop examples where input or target is empty / very short
MIN_INPUT_TOKENS = 50
MIN_TARGET_TOKENS = 20

mask_good = (df['input_clean'].map(len) > MIN_INPUT_TOKENS) & \
            (df['target_clean'].map(len) > MIN_TARGET_TOKENS)

df = df[mask_good].copy()

# Optionally: drop outliers where target is longer than input (if that happens)
df = df[df['target_tokens'] < df['input_tokens'] * 1.2]  # example threshold


Tokenization & truncation / handling long inputs

In [65]:
tokenizer = AutoTokenizer.from_pretrained("Falconsai/medical_summarization")  # e.g. ‚Äút5-base‚Äù, ‚Äúllama-‚Ä¶‚Äù, etc.
# You'll also need the model itself
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("Falconsai/medical_summarization")
MAX_INPUT_LEN = 2048  # adjust depending on your model capacity
MAX_TARGET_LEN = 512  # adjust

def encode_example(input_text: str, target_text: str):
    # encode input and target, with truncation if too long
    enc = tokenizer(
        input_text,
        truncation=True,
        max_length=MAX_INPUT_LEN,
        padding=False,
        return_tensors="pt",
    )
    dec = tokenizer(
        target_text,
        truncation=True,
        max_length=MAX_TARGET_LEN,
        padding=False,
        return_tensors="pt",
    )
    return {
        'input_ids': enc.input_ids.squeeze(0),
        'attention_mask': enc.attention_mask.squeeze(0),
        'labels': dec.input_ids.squeeze(0)
    }

# Example application
encoded = encode_example(df.iloc[0]['input_clean'], df.iloc[0]['target_clean'])
print(f"Setup successful! Input shape: {encoded['input_ids'].shape}, Labels shape: {encoded['labels'].shape}")




Setup successful! Input shape: torch.Size([2048]), Labels shape: torch.Size([285])


Creating dataset variants (full, section-filtered, chunked)

4.1 Section filtering

In [66]:
# Example: extract ‚ÄúHOSPITAL COURSE‚Äù section if present
def extract_section(text: str, section_header: str) -> str:
    # naive: split by the header and then stop when next header appears
    # This assumes section headers in uppercase and followed by colon, e.g. ‚ÄúHOSPITAL COURSE:‚Äù
    pattern = rf"{section_header}\s*:"
    parts = re.split(pattern, text, flags=re.IGNORECASE)
    if len(parts) <= 1:
        return ""  # section not found
    rest = parts[1]
    # Now stop when you see another section header (e.g. all caps + colon)
    # e.g. next_header = uppercase letters + colon
    m = re.search(r"\n[A-Z ]{3,50}:\s", rest)
    if m:
        rest = rest[: m.start()]
    return rest.strip()

# Example: create version of input using only ‚ÄúHOSPITAL COURSE‚Äù section
df['section_hcourse'] = df['input_clean'].map(lambda x: extract_section(x, "HOSPITAL COURSE"))

# Use that as an alternative input version if not empty, else fallback to full
df['input_section_or_full'] = df.apply(lambda row: row['section_hcourse'] if len(row['section_hcourse']) > 0 else row['input_clean'], axis=1)


4.2 Chunking / sliding windows

In [67]:
def chunk_input(input_ids: torch.Tensor, attention_mask: torch.Tensor, chunk_size: int, stride: int = None):
    # Returns list of (chunk_input_ids, chunk_masks)
    if stride is None:
        stride = chunk_size  # non overlapping
    chunks = []
    L = input_ids.size(0)
    start = 0
    while start < L:
        end = min(start + chunk_size, L)
        chunk_ids = input_ids[start:end]
        chunk_mask = attention_mask[start:end]
        chunks.append((chunk_ids, chunk_mask))
        if end == L:
            break
        start += stride
    return chunks

# Example: for each example, produce chunks
CHUNK_SIZE = 1024
STRIDE = 512

row = encoded  # from earlier encode_example
input_ids = row['input_ids']
mask = row['attention_mask']
chunks = chunk_input(input_ids, mask, CHUNK_SIZE, STRIDE)
print(f"Number of chunks: {len(chunks)}")


Number of chunks: 3


Custom Dataset & DataLoader

In [68]:
class BHC_Dataset(Dataset):
    def __init__(self, df, tokenizer, max_input_len=2048, max_target_len=512, variant="full"):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len
        self.variant = variant  # e.g. "full", "section_filtered", etc.

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        if self.variant == "section_filtered":
            input_text = row.get('input_section_or_full', row['input_clean'])
        else:
            input_text = row['input_clean']
        target_text = row['target_clean']
        enc = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_input_len,
            padding=False,
            return_tensors="pt",
        )
        dec = self.tokenizer(
            target_text,
            truncation=True,
            max_length=self.max_target_len,
            padding=False,
            return_tensors="pt",
        )
        item = {
            'input_ids': enc.input_ids.squeeze(0),
            'attention_mask': enc.attention_mask.squeeze(0),
            'labels': dec.input_ids.squeeze(0)
        }
        return item

# Example usage
ds = BHC_Dataset(df, tokenizer, max_input_len=2048, max_target_len=512, variant="full")
dl = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=lambda x: x)


Logging / Documentation

In [69]:
def log_stats(df_before: pd.DataFrame, df_after: pd.DataFrame, step_name: str):
    print(f"Step: {step_name}")
    print(f"Before: {len(df_before)} examples")
    print(f"After: {len(df_after)} examples")
    dropped = len(df_before) - len(df_after)
    print(f"Dropped: {dropped} examples ({100 * dropped / len(df_before):.1f}%)")


In [70]:
df0 = df.copy()
# filtering step
df = df[mask_good].copy()
log_stats(df0, df, "drop short / empty examples")


Step: drop short / empty examples
Before: 269059 examples
After: 269059 examples
Dropped: 0 examples (0.0%)


  df = df[mask_good].copy()


In [71]:
df.to_csv("bhc_preprocessed_v1.csv", index=False)

In [72]:
def summarize(text, max_input_length=1024, max_summary_length=256):
    inputs = tokenizer(text, return_tensors="pt",
                       truncation=True, max_length=max_input_length, 
                       padding=True)  # or pad to a min batch size
    output = model.generate(**inputs, max_length=max_summary_length, 
                            num_beams=2, early_stopping=True)
    summary = tokenizer.decode(output[0], skip_special_tokens=True)
    return summary

In [73]:
print(df.iloc[0]['input_clean'])

<SEX> F <SERVICE> MEDICINE <ALLERGIES> No Known Allergies / Adverse Drug Reactions <ATTENDING> ___ <CHIEF COMPLAINT> Worsening ABD distension and pain <MAJOR SURGICAL OR INVASIVE PROCEDURE> Paracentesis <HISTORY OF PRESENT ILLNESS> ___ HCV cirrhosis c/b ascites, hiv on ART, h/o IVDU, COPD, bioplar, PTSD, presented from OSH ED with worsening abd distension over past week. Pt reports self-discontinuing lasix and spirnolactone ___ weeks ago, because she feels like "they don't do anything" and that she "doesn't want to put more chemicals in her." She does not follow Na-restricted diets. In the past week, she notes that she has been having worsening abd distension and discomfort. She denies ___ edema, or SOB, or orthopnea. She denies f/c/n/v, d/c, dysuria. She had food poisoning a week ago from eating stale cake (n/v 20 min after food ingestion), which resolved the same day. She denies other recent illness or sick contacts. She notes that she has been noticing gum bleeding while brushing he

In [74]:
# Example
instruction = (
    "Summarize the following hospital discharge note into a concise, factual "
    "Brief Hospital Course summary:\n\n"
)
summary = summarize(instruction + df.iloc[0]['input_clean'], max_input_length=1024, max_summary_length=300)
print(summary)

SEX> SERVICE> MEDICINE ALLERGIES> No Known Allergies / Adverse Drug Reactions ATTENDING> ___ CHIEF COMPLAINT> Worsening abd distension and pain MAJOR SURGICAL OR INVASIVE PROCEDURE> Paracentesis HISTORY OF PRESENT ILLNESS> ___ HCV cirrhosis c/b ascites, hiv on ART, h/o IVDU, COPD, bioplar, PTSD, presented from OSH ED with worsening abd distension over past week.


Train / Validation Split

In [75]:
from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)
print(f"Original Training length: {len(train_df)} | Original Validation Length: {len(val_df)}")


Original Training length: 242153 | Original Validation Length: 26906


sample the dataset

In [76]:
# train_df = train_df.sample(2000, random_state=42)
# val_df = val_df.sample(500, random_state=42)
# print(f"Sampled Training length: {len(train_df)} | Sampled Validation Length: {len(val_df)}")

In [77]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Falconsai/medical_summarization"
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)  # This is PyTorch by default




In [78]:
from torch.utils.data import Dataset, DataLoader

class SummarizationDataset(Dataset):
    def __init__(self, df, tokenizer, max_input_len=1024, max_target_len=256):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        inputs = tokenizer(
            row["input_clean"],
            truncation=True,
            padding="max_length",
            max_length=self.max_input_len,
            return_tensors="pt",
        )
        targets = tokenizer(
            row["target_clean"],
            truncation=True,
            padding="max_length",
            max_length=self.max_target_len,
            return_tensors="pt",
        )
        return {
            "input_ids": inputs.input_ids.squeeze(),
            "attention_mask": inputs.attention_mask.squeeze(),
            "labels": targets.input_ids.squeeze(),
        }

train_dataset = SummarizationDataset(train_df, tokenizer)
val_dataset = SummarizationDataset(val_df, tokenizer)
model.gradient_checkpointing_enable()

Fine-Tuning

LORA config adapter

In [79]:
# Enable optimizations
torch.backends.cudnn.benchmark = True     # Good for fixed input sizes
torch.backends.cuda.matmul.allow_tf32 = True  # Good for newer GPUs
# Safe GPU optimizations
try:
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        if hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
            torch.backends.cuda.matmul.allow_tf32 = True
        print("‚úÖ GPU optimizations enabled")
    else:
        print("‚ö†Ô∏è CUDA not available, skipping GPU optimizations")
except Exception as e:
    print(f"‚ö†Ô∏è Could not enable optimizations: {e}")

‚úÖ GPU optimizations enabled


In [80]:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(model, lora_config)

In [81]:
from transformers import Trainer, TrainingArguments, TrainerCallback
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    warmup_steps=100,                      # üî• Add warmup for stability
    # max_steps=300,                     # üî• Move max_steps here
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=5e-5,
    weight_decay=0.01,
    fp16=True,                     # üî• enable mixed precision
    save_total_limit=3,
    report_to="none",
    dataloader_num_workers=0,      # fewer threads on Windows
    logging_steps=200,
    resume_from_checkpoint=True,  # üî• Add this
    dataloader_pin_memory=False,           # üî• Reduce memory usage
    remove_unused_columns=True,            # üî• Clean up unused data
    gradient_checkpointing=True,           # üî• Trade compute for memory
    optim="adamw_torch",  # üî• Use non-deprecated AdamW

)
# Enhanced callback for longer training
class EnhancedCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            # Add timestamp and step info
            import datetime
            timestamp = datetime.datetime.now().strftime("%H:%M:%S")
            if 'loss' in logs:
                print(f"[{timestamp}] Step {state.global_step}: Loss = {logs['loss']:.4f}")
            if 'eval_loss' in logs:
                print(f"[{timestamp}] Step {state.global_step}: Eval Loss = {logs['eval_loss']:.4f}")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    callbacks=[EnhancedCallback()],
)
# -------------------------------
# 8. Train (start small)
# -------------------------------
trainer.train()

  0%|          | 0/15134 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  return fn(*args, **kwargs)


[08:52:27] Step 200: Loss = 4.6693
{'loss': 4.6693, 'learning_rate': 4.967407210323268e-05, 'epoch': 0.01}
[08:58:08] Step 400: Loss = 3.9221
{'loss': 3.9221, 'learning_rate': 4.9008913130238126e-05, 'epoch': 0.03}
[09:03:40] Step 600: Loss = 3.7591
{'loss': 3.7591, 'learning_rate': 4.834375415724358e-05, 'epoch': 0.04}
[09:09:13] Step 800: Loss = 3.6854
{'loss': 3.6854, 'learning_rate': 4.768192097911401e-05, 'epoch': 0.05}
[09:14:51] Step 1000: Loss = 3.6661
{'loss': 3.6661, 'learning_rate': 4.7016762006119464e-05, 'epoch': 0.07}


  0%|          | 0/6727 [00:00<?, ?it/s]

[09:27:59] Step 1000: Eval Loss = 3.4089
{'eval_loss': 3.4088950157165527, 'eval_runtime': 787.5448, 'eval_samples_per_second': 34.164, 'eval_steps_per_second': 8.542, 'epoch': 0.07}


  return fn(*args, **kwargs)


[09:33:32] Step 1200: Loss = 3.6195
{'loss': 3.6195, 'learning_rate': 4.635160303312492e-05, 'epoch': 0.08}
[09:39:10] Step 1400: Loss = 3.5908
{'loss': 3.5908, 'learning_rate': 4.568644406013037e-05, 'epoch': 0.09}
[09:44:47] Step 1600: Loss = 3.5882
{'loss': 3.5882, 'learning_rate': 4.50246108820008e-05, 'epoch': 0.11}
[09:50:24] Step 1800: Loss = 3.5732
{'loss': 3.5732, 'learning_rate': 4.4359451909006256e-05, 'epoch': 0.12}
[09:56:01] Step 2000: Loss = 3.5446
{'loss': 3.5446, 'learning_rate': 4.369429293601171e-05, 'epoch': 0.13}


  0%|          | 0/6727 [00:00<?, ?it/s]

[10:09:04] Step 2000: Eval Loss = 3.3187
{'eval_loss': 3.3187360763549805, 'eval_runtime': 783.3279, 'eval_samples_per_second': 34.348, 'eval_steps_per_second': 8.588, 'epoch': 0.13}


  return fn(*args, **kwargs)


[10:14:40] Step 2200: Loss = 3.5374
{'loss': 3.5374, 'learning_rate': 4.302913396301716e-05, 'epoch': 0.15}
[10:20:15] Step 2400: Loss = 3.5109
{'loss': 3.5109, 'learning_rate': 4.236397499002262e-05, 'epoch': 0.16}
[10:25:51] Step 2600: Loss = 3.4881
{'loss': 3.4881, 'learning_rate': 4.169881601702807e-05, 'epoch': 0.17}
[10:31:26] Step 2800: Loss = 3.4950
{'loss': 3.495, 'learning_rate': 4.1033657044033525e-05, 'epoch': 0.19}
[10:36:53] Step 3000: Loss = 3.4866
{'loss': 3.4866, 'learning_rate': 4.036849807103898e-05, 'epoch': 0.2}


  0%|          | 0/6727 [00:00<?, ?it/s]

[10:49:52] Step 3000: Eval Loss = 3.2640
{'eval_loss': 3.2639641761779785, 'eval_runtime': 779.4557, 'eval_samples_per_second': 34.519, 'eval_steps_per_second': 8.63, 'epoch': 0.2}


  return fn(*args, **kwargs)


[10:55:28] Step 3200: Loss = 3.4777
{'loss': 3.4777, 'learning_rate': 3.970333909804443e-05, 'epoch': 0.21}
[11:01:03] Step 3400: Loss = 3.4725
{'loss': 3.4725, 'learning_rate': 3.903818012504989e-05, 'epoch': 0.22}
[11:06:38] Step 3600: Loss = 3.4624
{'loss': 3.4624, 'learning_rate': 3.837302115205535e-05, 'epoch': 0.24}
[11:12:15] Step 3800: Loss = 3.4404
{'loss': 3.4404, 'learning_rate': 3.77078621790608e-05, 'epoch': 0.25}
[11:17:51] Step 4000: Loss = 3.4249
{'loss': 3.4249, 'learning_rate': 3.7042703206066255e-05, 'epoch': 0.26}


  0%|          | 0/6727 [00:00<?, ?it/s]

[11:30:52] Step 4000: Eval Loss = 3.2266
{'eval_loss': 3.2266054153442383, 'eval_runtime': 780.4214, 'eval_samples_per_second': 34.476, 'eval_steps_per_second': 8.62, 'epoch': 0.26}


  return fn(*args, **kwargs)


[11:36:27] Step 4200: Loss = 3.4097
{'loss': 3.4097, 'learning_rate': 3.63775442330717e-05, 'epoch': 0.28}
[11:42:04] Step 4400: Loss = 3.4158
{'loss': 3.4158, 'learning_rate': 3.5712385260077156e-05, 'epoch': 0.29}
[11:47:40] Step 4600: Loss = 3.4237
{'loss': 3.4237, 'learning_rate': 3.504722628708261e-05, 'epoch': 0.3}
[11:53:15] Step 4800: Loss = 3.4155
{'loss': 3.4155, 'learning_rate': 3.438206731408807e-05, 'epoch': 0.32}
[11:58:52] Step 5000: Loss = 3.3958
{'loss': 3.3958, 'learning_rate': 3.3716908341093525e-05, 'epoch': 0.33}


  0%|          | 0/6727 [00:00<?, ?it/s]

[12:11:51] Step 5000: Eval Loss = 3.1967
{'eval_loss': 3.1966910362243652, 'eval_runtime': 779.546, 'eval_samples_per_second': 34.515, 'eval_steps_per_second': 8.629, 'epoch': 0.33}


  return fn(*args, **kwargs)


[12:17:23] Step 5200: Loss = 3.4009
{'loss': 3.4009, 'learning_rate': 3.305174936809898e-05, 'epoch': 0.34}
[12:22:50] Step 5400: Loss = 3.3680
{'loss': 3.368, 'learning_rate': 3.238659039510443e-05, 'epoch': 0.36}
[12:28:24] Step 5600: Loss = 3.4034
{'loss': 3.4034, 'learning_rate': 3.1721431422109886e-05, 'epoch': 0.37}
[12:33:58] Step 5800: Loss = 3.4047
{'loss': 3.4047, 'learning_rate': 3.105627244911534e-05, 'epoch': 0.38}
[12:39:32] Step 6000: Loss = 3.3736
{'loss': 3.3736, 'learning_rate': 3.039111347612079e-05, 'epoch': 0.4}


  0%|          | 0/6727 [00:00<?, ?it/s]

[12:52:30] Step 6000: Eval Loss = 3.1707
{'eval_loss': 3.1706550121307373, 'eval_runtime': 777.8587, 'eval_samples_per_second': 34.59, 'eval_steps_per_second': 8.648, 'epoch': 0.4}


  return fn(*args, **kwargs)


[12:58:05] Step 6200: Loss = 3.3642
{'loss': 3.3642, 'learning_rate': 2.9725954503126248e-05, 'epoch': 0.41}
[13:03:40] Step 6400: Loss = 3.3659
{'loss': 3.3659, 'learning_rate': 2.9060795530131702e-05, 'epoch': 0.42}
[13:09:15] Step 6600: Loss = 3.3771
{'loss': 3.3771, 'learning_rate': 2.8395636557137156e-05, 'epoch': 0.44}
[13:14:49] Step 6800: Loss = 3.3611
{'loss': 3.3611, 'learning_rate': 2.773047758414261e-05, 'epoch': 0.45}
[13:20:24] Step 7000: Loss = 3.3717
{'loss': 3.3717, 'learning_rate': 2.7065318611148067e-05, 'epoch': 0.46}


  0%|          | 0/6727 [00:00<?, ?it/s]

[13:33:20] Step 7000: Eval Loss = 3.1493
{'eval_loss': 3.1492719650268555, 'eval_runtime': 775.308, 'eval_samples_per_second': 34.704, 'eval_steps_per_second': 8.677, 'epoch': 0.46}


  return fn(*args, **kwargs)


[13:38:55] Step 7200: Loss = 3.3637
{'loss': 3.3637, 'learning_rate': 2.640015963815352e-05, 'epoch': 0.48}
[13:44:30] Step 7400: Loss = 3.3744
{'loss': 3.3744, 'learning_rate': 2.5735000665158975e-05, 'epoch': 0.49}
[13:50:06] Step 7600: Loss = 3.3739
{'loss': 3.3739, 'learning_rate': 2.5069841692164432e-05, 'epoch': 0.5}
[13:55:42] Step 7800: Loss = 3.3527
{'loss': 3.3527, 'learning_rate': 2.4408008514034856e-05, 'epoch': 0.52}
[14:01:20] Step 8000: Loss = 3.3541
{'loss': 3.3541, 'learning_rate': 2.374284954104031e-05, 'epoch': 0.53}


  0%|          | 0/6727 [00:00<?, ?it/s]

[14:14:19] Step 8000: Eval Loss = 3.1320
{'eval_loss': 3.1320409774780273, 'eval_runtime': 779.3039, 'eval_samples_per_second': 34.526, 'eval_steps_per_second': 8.632, 'epoch': 0.53}


  return fn(*args, **kwargs)


[14:19:55] Step 8200: Loss = 3.3391
{'loss': 3.3391, 'learning_rate': 2.3077690568045767e-05, 'epoch': 0.54}
[14:25:32] Step 8400: Loss = 3.3444
{'loss': 3.3444, 'learning_rate': 2.2412531595051217e-05, 'epoch': 0.56}
[14:31:09] Step 8600: Loss = 3.3400
{'loss': 3.34, 'learning_rate': 2.1750698416921644e-05, 'epoch': 0.57}
[14:36:45] Step 8800: Loss = 3.3704
{'loss': 3.3704, 'learning_rate': 2.1085539443927098e-05, 'epoch': 0.58}
[14:42:22] Step 9000: Loss = 3.3378
{'loss': 3.3378, 'learning_rate': 2.0420380470932555e-05, 'epoch': 0.59}


  0%|          | 0/6727 [00:00<?, ?it/s]

[14:55:20] Step 9000: Eval Loss = 3.1172
{'eval_loss': 3.117156505584717, 'eval_runtime': 778.0574, 'eval_samples_per_second': 34.581, 'eval_steps_per_second': 8.646, 'epoch': 0.59}


  return fn(*args, **kwargs)


[15:00:59] Step 9200: Loss = 3.3499
{'loss': 3.3499, 'learning_rate': 1.975522149793801e-05, 'epoch': 0.61}
[15:06:36] Step 9400: Loss = 3.3240
{'loss': 3.324, 'learning_rate': 1.909006252494346e-05, 'epoch': 0.62}
[15:12:15] Step 9600: Loss = 3.3560
{'loss': 3.356, 'learning_rate': 1.8424903551948917e-05, 'epoch': 0.63}
[15:17:52] Step 9800: Loss = 3.3163
{'loss': 3.3163, 'learning_rate': 1.775974457895437e-05, 'epoch': 0.65}
[15:23:30] Step 10000: Loss = 3.2980
{'loss': 3.298, 'learning_rate': 1.7094585605959825e-05, 'epoch': 0.66}


  0%|          | 0/6727 [00:00<?, ?it/s]

[15:36:30] Step 10000: Eval Loss = 3.1070
{'eval_loss': 3.1070311069488525, 'eval_runtime': 779.4766, 'eval_samples_per_second': 34.518, 'eval_steps_per_second': 8.63, 'epoch': 0.66}


  return fn(*args, **kwargs)


[15:42:06] Step 10200: Loss = 3.3195
{'loss': 3.3195, 'learning_rate': 1.6429426632965282e-05, 'epoch': 0.67}
[15:47:45] Step 10400: Loss = 3.3145
{'loss': 3.3145, 'learning_rate': 1.5764267659970732e-05, 'epoch': 0.69}
[15:53:22] Step 10600: Loss = 3.3288
{'loss': 3.3288, 'learning_rate': 1.5099108686976188e-05, 'epoch': 0.7}
[15:58:59] Step 10800: Loss = 3.3164
{'loss': 3.3164, 'learning_rate': 1.4433949713981642e-05, 'epoch': 0.71}
[16:04:36] Step 11000: Loss = 3.3090
{'loss': 3.309, 'learning_rate': 1.3768790740987098e-05, 'epoch': 0.73}


  0%|          | 0/6727 [00:00<?, ?it/s]

[16:17:38] Step 11000: Eval Loss = 3.0974
{'eval_loss': 3.097398042678833, 'eval_runtime': 781.8101, 'eval_samples_per_second': 34.415, 'eval_steps_per_second': 8.604, 'epoch': 0.73}


  return fn(*args, **kwargs)


[16:23:14] Step 11200: Loss = 3.3256
{'loss': 3.3256, 'learning_rate': 1.310363176799255e-05, 'epoch': 0.74}
[16:28:53] Step 11400: Loss = 3.3268
{'loss': 3.3268, 'learning_rate': 1.2438472794998005e-05, 'epoch': 0.75}
[16:34:31] Step 11600: Loss = 3.3116
{'loss': 3.3116, 'learning_rate': 1.177331382200346e-05, 'epoch': 0.77}
[16:40:10] Step 11800: Loss = 3.3169
{'loss': 3.3169, 'learning_rate': 1.1108154849008915e-05, 'epoch': 0.78}
[16:45:50] Step 12000: Loss = 3.3098
{'loss': 3.3098, 'learning_rate': 1.0442995876014367e-05, 'epoch': 0.79}


  0%|          | 0/6727 [00:00<?, ?it/s]

[16:58:53] Step 12000: Eval Loss = 3.0911
{'eval_loss': 3.0910980701446533, 'eval_runtime': 783.7654, 'eval_samples_per_second': 34.329, 'eval_steps_per_second': 8.583, 'epoch': 0.79}


  return fn(*args, **kwargs)


[17:04:30] Step 12200: Loss = 3.2782
{'loss': 3.2782, 'learning_rate': 9.777836903019823e-06, 'epoch': 0.81}
[17:10:00] Step 12400: Loss = 3.3264
{'loss': 3.3264, 'learning_rate': 9.112677930025277e-06, 'epoch': 0.82}
[17:15:34] Step 12600: Loss = 3.3101
{'loss': 3.3101, 'learning_rate': 8.44751895703073e-06, 'epoch': 0.83}
[17:21:15] Step 12800: Loss = 3.2932
{'loss': 3.2932, 'learning_rate': 7.782359984036184e-06, 'epoch': 0.85}
[17:26:56] Step 13000: Loss = 3.2950
{'loss': 3.295, 'learning_rate': 7.11720101104164e-06, 'epoch': 0.86}


  0%|          | 0/6727 [00:00<?, ?it/s]

[17:39:55] Step 13000: Eval Loss = 3.0867
{'eval_loss': 3.0867319107055664, 'eval_runtime': 779.4742, 'eval_samples_per_second': 34.518, 'eval_steps_per_second': 8.63, 'epoch': 0.86}


  return fn(*args, **kwargs)


[17:45:34] Step 13200: Loss = 3.3264
{'loss': 3.3264, 'learning_rate': 6.452042038047093e-06, 'epoch': 0.87}
[17:51:11] Step 13400: Loss = 3.3157
{'loss': 3.3157, 'learning_rate': 5.790208859917521e-06, 'epoch': 0.89}
[17:56:49] Step 13600: Loss = 3.3280
{'loss': 3.328, 'learning_rate': 5.125049886922975e-06, 'epoch': 0.9}
[18:02:25] Step 13800: Loss = 3.3108
{'loss': 3.3108, 'learning_rate': 4.459890913928429e-06, 'epoch': 0.91}
[18:08:02] Step 14000: Loss = 3.3200
{'loss': 3.32, 'learning_rate': 3.7947319409338836e-06, 'epoch': 0.93}


  0%|          | 0/6727 [00:00<?, ?it/s]

[18:21:01] Step 14000: Eval Loss = 3.0839
{'eval_loss': 3.0838568210601807, 'eval_runtime': 778.9819, 'eval_samples_per_second': 34.54, 'eval_steps_per_second': 8.636, 'epoch': 0.93}


  return fn(*args, **kwargs)


[18:26:38] Step 14200: Loss = 3.3151
{'loss': 3.3151, 'learning_rate': 3.129572967939338e-06, 'epoch': 0.94}
[18:32:16] Step 14400: Loss = 3.2784
{'loss': 3.2784, 'learning_rate': 2.4644139949447918e-06, 'epoch': 0.95}
[18:37:54] Step 14600: Loss = 3.2953
{'loss': 3.2953, 'learning_rate': 1.799255021950246e-06, 'epoch': 0.96}
[18:43:32] Step 14800: Loss = 3.3097
{'loss': 3.3097, 'learning_rate': 1.1340960489557006e-06, 'epoch': 0.98}
[18:49:08] Step 15000: Loss = 3.3244
{'loss': 3.3244, 'learning_rate': 4.6893707596115475e-07, 'epoch': 0.99}


  0%|          | 0/6727 [00:00<?, ?it/s]

[19:02:12] Step 15000: Eval Loss = 3.0827
{'eval_loss': 3.0826780796051025, 'eval_runtime': 783.5047, 'eval_samples_per_second': 34.341, 'eval_steps_per_second': 8.586, 'epoch': 0.99}


  return fn(*args, **kwargs)


{'train_runtime': 37175.6398, 'train_samples_per_second': 6.514, 'train_steps_per_second': 0.407, 'train_loss': 3.4158053352131117, 'epoch': 1.0}


TrainOutput(global_step=15134, training_loss=3.4158053352131117, metrics={'train_runtime': 37175.6398, 'train_samples_per_second': 6.514, 'train_steps_per_second': 0.407, 'train_loss': 3.4158053352131117, 'epoch': 1.0})

In [82]:
model.save_pretrained("./medical_summarizer_lora_v1")
tokenizer.save_pretrained("./medical_summarizer_lora_v1")
print("‚úÖ Training complete and model saved.")

‚úÖ Training complete and model saved.


EVALUATION: rouge or bleu
Common metrics available: "bleu", "rouge", "meteor", "bertscore", "exact_match", etc.

In [83]:
import evaluate
from tqdm import tqdm

metric = evaluate.load("rouge")

def evaluate_model(dataset, num_samples=50):
    model.eval()
    preds, refs = [], []
    for i in tqdm(range(num_samples)):
        text = dataset.df.iloc[i]["input_clean"]
        reference = dataset.df.iloc[i]["target_clean"]
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=256)
        pred = tokenizer.decode(output[0], skip_special_tokens=True)
        preds.append(pred)
        refs.append(reference)
    return metric.compute(predictions=preds, references=refs)

results = evaluate_model(val_dataset, num_samples=30)
print(results)


Downloading builder script: 0.00B [00:00, ?B/s]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 30/30 [02:25<00:00,  4.84s/it]


{'rouge1': 0.24322987122995235, 'rouge2': 0.061644801138736, 'rougeL': 0.15954732842055513, 'rougeLsum': 0.15995710914879385}
