Setup & imports

In [1]:
import re
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import math
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq, Seq2SeqTrainer
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig
from sklearn.model_selection import train_test_split
from tqdm import tqdm


In [2]:
# Reading CSV file using pandas
# Basic usage:
df = pd.read_csv(r"D:\UT_Austin\sem4\physionet.org\files\labelled-notes-hospital-course\1.1.0\mimic-iv-bhc.csv")

print("CSV file loaded successfully!")
print(f"Shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
df.head()  # Display first 5 rows

CSV file loaded successfully!
Shape: (270033, 5)
Columns: ['note_id', 'input', 'target', 'input_tokens', 'target_tokens']


Unnamed: 0,note_id,input,target,input_tokens,target_tokens
0,10000032-DS-21,<SEX> F <SERVICE> MEDICINE <ALLERGIES> No Know...,"___ HCV cirrhosis c/b ascites, hiv on ART, h/o...",1946,231
1,10000032-DS-22,<SEX> F <SERVICE> MEDICINE <ALLERGIES> Percoce...,"___ with HIV on HAART, HCV cirrhosis with asci...",2183,810
2,10000117-DS-21,<SEX> F <SERVICE> MEDICINE <ALLERGIES> omepraz...,Ms. ___ is a ___ with history of GERD who pres...,1060,172
3,10000117-DS-22,<SEX> F <SERVICE> ORTHOPAEDICS <ALLERGIES> ome...,The patient presented to the emergency departm...,1195,330
4,10000248-DS-10,<SEX> M <SERVICE> MEDICINE <ALLERGIES> No Know...,Mr. ___ is a ___ with history of mild FVIII de...,1961,230


Basic cleaning / normalization functions

In [3]:
def normalize_whitespace(text: str) -> str:
    # collapse multiple spaces / newlines into single
    text = re.sub(r'\r', ' ', text)
    text = re.sub(r'\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def clean_placeholders(text: str) -> str:
    # e.g. ensure consistent casing / spacing around placeholders
    # This is generic; adapt if you have more placeholder variants
    text = re.sub(r'<\s*SEX\s*>', '<SEX>', text, flags=re.IGNORECASE)
    text = re.sub(r'<\s*SERVICE\s*>', '<SERVICE>', text, flags=re.IGNORECASE)
    text = re.sub(r'<\s*ALLERGIES\s*>', '<ALLERGIES>', text, flags=re.IGNORECASE)
    return text

def remove_control_characters(text: str) -> str:
    # remove non-printable / weird control chars
    # e.g. anything in category Cc except \n
    cleaned = ''.join(ch for ch in text if (ch == '\n' or (ord(ch) >= 32 and ord(ch) != 127)))
    return cleaned

def preprocess_text(text: str) -> str:
    if text is None:
        return ""
    text = normalize_whitespace(text)
    text = clean_placeholders(text)
    text = remove_control_characters(text)
    return text


In [4]:
df['input_clean'] = df['input'].map(preprocess_text)
df['target_clean'] = df['target'].map(preprocess_text)

Config

In [5]:
# ---------- Config ----------
MODEL_NAME = "Falconsai/medical_summarization"   # or "Mahalingam/DistilBart-Med-Summary"
MAX_INPUT_LEN = 1024
MAX_CHUNK_LEN = 768            # <= MAX_INPUT_LEN to leave room for prompt tokens
MAX_SUMMARY_LEN = 300
CHUNK_STRIDE = 384             # overlap to avoid boundary losses
TOP_SENT_PER_CHUNK = 3         # H2: how many sentences to keep per chunk before final merge
FINAL_SUMMARY_LEN = 350
INSTR = (
    "Summarize the following hospital discharge note into a concise, factual "
    "Brief Hospital Course. Focus on diagnoses, key events, treatments, and disposition.\n\n"
)

# ---------- Filtering using token columns ----------

In [6]:
def filter_pairs_by_tokens(df, min_in=100, min_out=20, out_longer_than_in_ratio=1.3):
    before = len(df)
    ok = (df["input_tokens"] >= min_in) & (df["target_tokens"] >= min_out) & \
         (df["target_tokens"] < (df["input_tokens"] * out_longer_than_in_ratio))
    df2 = df[ok].copy()
    print(f"[filter] kept {len(df2)}/{before} rows")
    return df2

# ---------- Section extraction (robust headers) ----------

In [7]:
SECTION_HEADERS = [
    r"HOSPITAL COURSE", r"BRIEF HOSPITAL COURSE", r"HISTORY OF PRESENT ILLNESS",
    r"ASSESSMENT", r"ASSESSMENT/PLAN", r"PLAN", r"IMPRESSION"
]

def extract_section(text: str, header_regex: str) -> str:
    # Find section by header, stop at next ALL CAPS header ending with colon or newline
    m = re.search(rf"(?P<h>{header_regex})\s*:?[\r\n]+", text, flags=re.IGNORECASE)
    if not m:
        return ""
    start = m.end()
    # next header: lines with ALL CAPS words and optional slashes followed by ":" or newline
    m2 = re.search(r"\n[A-Z][A-Z/\- ]{2,50}(:|\n)", text[start:])
    end = start + (m2.start() if m2 else len(text) - start)
    return text[start:end].strip()

def best_section_or_full(text: str) -> str:
    # Prefer HOSPITAL COURSE; then HPI; then ASSESSMENT/PLAN; else full text
    for hdr in [r"HOSPITAL COURSE", r"BRIEF HOSPITAL COURSE",
                r"HISTORY OF PRESENT ILLNESS", r"ASSESSMENT/PLAN", r"ASSESSMENT", r"PLAN"]:
        sec = extract_section(text, hdr)
        if len(sec) > 50:   # minimal useful length
            return sec
    return text

# ---------- Summarization helpers ----------

In [8]:
torch.set_grad_enabled(False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to("cuda")
model.config.use_cache = False   # important for training seq2seq with LoRA
def generate_summary(text: str, max_in=MAX_INPUT_LEN, max_out=MAX_SUMMARY_LEN):
    enc = tokenizer(INSTR + text, return_tensors="pt", truncation=True, max_length=max_in).to(model.device)
    out = model.generate(
        **enc,
        max_length=max_out,
        num_beams=4,
        temperature=0.9,
        repetition_penalty=2.0,
        early_stopping=True
    )
    return tokenizer.decode(out[0], skip_special_tokens=True)



# ---------- Your existing chunker: now used for real ----------

In [9]:
def chunk_ids_attention(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                        chunk_size: int, stride: int):
    # same behavior as your chunk_input, but returns easy dicts
    L = input_ids.size(0)
    chunks = []
    start = 0
    while start < L:
        end = min(start + chunk_size, L)
        chunks.append({
            "input_ids": input_ids[start:end],
            "attention_mask": attention_mask[start:end]
        })
        if end == L:
            break
        start += stride
    return chunks

In [10]:
def text_to_token_chunks(text: str, chunk_size=MAX_CHUNK_LEN, stride=CHUNK_STRIDE):
    enc = tokenizer(INSTR + text, return_tensors="pt", truncation=False)
    ids, mask = enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0)
    chunks = chunk_ids_attention(ids, mask, chunk_size, stride)
    # decode chunks back to text so we can summarize each
    texts = []
    for ch in chunks:
        t = tokenizer.decode(ch["input_ids"], skip_special_tokens=True)
        # remove the instruction that got baked into the first chunk decode
        t = t.replace(INSTR.strip(), "").strip()
        texts.append(t)
    return texts

# ---------- H2 weighted merge utilities ----------

In [11]:
MED_KEYWORDS = set("""
cirrhosis ascites encephalopathy copd pneumonia sepsis arf aki ckd chf hf pe dvt mi nstemi stemi stroke cva
hiv art haart hbv hcv hepatitis varices gi bleed gi-bleed hematemesis melena
diuretic furosemide spironolactone lactulose rifaximin ceftriaxone vancomycin zosyn
paracentesis thoracentesis intubation extubation dialysis hemodialysis crrt
na k hgb wbc plt creatinine bun bilirubin inr pt aptt lft ast alt alk phos troponin
""".split())

In [12]:
def sentence_split(text: str):
    # simple splitter; you can swap for spacy if available
    return [s.strip() for s in re.split(r'(?<=[\.\?!])\s+', text) if len(s.strip()) > 0]

def score_sentence(s: str):
    # heuristic score = keyword hits + numbers + medical abbreviations
    tokens = re.findall(r"[A-Za-z0-9\-/+\.]+", s.lower())
    hits = sum(1 for t in tokens if t in MED_KEYWORDS)
    nums = sum(1 for t in tokens if re.fullmatch(r"\d+(\.\d+)?", t))
    labs = len(re.findall(r"\b(wbc|hgb|na|k|cr|bun|inr|ast|alt|troponin|plt)\b", s.lower()))
    dates = len(re.findall(r"\bday\s*\d+|\d{1,2}/\d{1,2}\b", s.lower()))
    length_penalty = 0.0 if len(s) < 40 else (0.1 if len(s) < 150 else -0.2)  # prefer concise
    return hits*2 + labs*2 + nums*0.5 + dates*0.5 + length_penalty

def top_k_sentences(text: str, k=TOP_SENT_PER_CHUNK):
    sents = sentence_split(text)
    scored = sorted(((score_sentence(s), s) for s in sents), key=lambda x: x[0], reverse=True)
    return [s for _, s in scored[:k]]

Pipeline: Hierarchical (chunk -> summarize -> weighted merge -> final summarize)

In [13]:
def summarize_hierarchical_H2(long_text: str) -> str:
    # Stage 0: prefer section text if available
    base_text = best_section_or_full(long_text)

    # Stage 1: split base_text into token chunks and summarize each
    chunk_texts = text_to_token_chunks(base_text, chunk_size=MAX_CHUNK_LEN, stride=CHUNK_STRIDE)
    if len(chunk_texts) == 0:
        return generate_summary(base_text, max_in=MAX_INPUT_LEN, max_out=FINAL_SUMMARY_LEN)

    mini_summaries = []
    for t in chunk_texts:
        mini_summaries.append(generate_summary(t, max_in=MAX_INPUT_LEN, max_out=MAX_SUMMARY_LEN))

    # Stage 2 (H2): collect top sentences from each mini-summary
    candidates = []
    for ms in mini_summaries:
        candidates.extend(top_k_sentences(ms, k=TOP_SENT_PER_CHUNK))

    # Merge candidates into a single text and do the final pass
    merged = " ".join(candidates)
    final = generate_summary(merged, max_in=MAX_INPUT_LEN, max_out=FINAL_SUMMARY_LEN)
    return final

# ---------- Convenience wrappers ----------

In [14]:
def summarize_bhc(text: str, strategy="hierarchical-H2"):
    if strategy == "section-only":
        return generate_summary(best_section_or_full(text), max_in=MAX_INPUT_LEN, max_out=FINAL_SUMMARY_LEN)
    return summarize_hierarchical_H2(text)


# Ensure df exists and has input_clean / target_clean, then:

In [15]:
df = filter_pairs_by_tokens(df)
print(summarize_bhc(df.iloc[0]["input_clean"], strategy="hierarchical-H2"))

Token indices sequence length is longer than the specified maximum sequence length for this model (2251 > 512). Running this sequence through the model will result in indexing errors


[filter] kept 269224/270033 rows
___ edema is a common cause of abd distension and pain, but it has been reported that she has been having worsening abd distension and discomfort over past week.results:in the ED, initial vitals were 98.4 70 106/63 16 97%RA Labs notable for ALT/AST/AP ______: ___, Tbili1.6, WBC 5K, platelet 77, INR 1.6.


# QLoRA Trainer Setup (drop this in AFTER you confirm model + tokenizer load)

In [16]:
# 4bit quant config for 8GB GPU
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# load base model quantized 4bit
base_model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)
base_model.config.use_cache = False  # required for training
base_model.train()
# LoRA config (this is the correct target modules for seq2seq)
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=["q", "v"]   # this is safer for T5/BART style
)

model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at Falconsai/medical_summarization and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 294,912 || all params: 60,801,536 || trainable%: 0.4850403779272945


# Dataset Wrapper for Trainer

In [17]:
class BHCTrainDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        base_text = best_section_or_full(row["input_clean"])
        enc = tokenizer(INSTR + base_text,
                        truncation=True, max_length=MAX_INPUT_LEN,
                        padding=True,
                        return_tensors="pt")
        dec = tokenizer(row["target_clean"],
                        truncation=True, max_length=MAX_SUMMARY_LEN,
                        padding=True,
                        return_tensors="pt")
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": dec["input_ids"].squeeze(0),
        }


In [18]:
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)
print(f"Original Training length: {len(train_df)} | Original Validation Length: {len(val_df)}")


Original Training length: 242301 | Original Validation Length: 26923


In [19]:
train_df = train_df.sample(n=2000, random_state=42)
val_df   = val_df.sample(n=200, random_state=42)
print(f"Subset Training length: {len(train_df)} | Subset Validation Length: {len(val_df)}")

Subset Training length: 2000 | Subset Validation Length: 200


In [20]:
torch.set_grad_enabled(True)
train_args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,   # fits 8GB GPU
    learning_rate=2e-4,
    warmup_steps=500,                # this is important
    lr_scheduler_type="cosine",
    fp16=True,
    logging_steps=50,
    save_steps=1000,
    evaluation_strategy="epoch",
    save_total_limit=2,
    report_to="none",
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=BHCTrainDataset(train_df),
    eval_dataset=BHCTrainDataset(val_df),
    data_collator=data_collator
)

trainer.train()
model.save_pretrained("./finetuned_bhc_test3")
tokenizer.save_pretrained("./finetuned_bhc_test3")




  0%|          | 0/1500 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 4.4447, 'learning_rate': 1.9600000000000002e-05, 'epoch': 0.1}
{'loss': 4.5322, 'learning_rate': 3.9200000000000004e-05, 'epoch': 0.2}
{'loss': 4.3045, 'learning_rate': 5.92e-05, 'epoch': 0.3}
{'loss': 4.1221, 'learning_rate': 7.920000000000001e-05, 'epoch': 0.4}
{'loss': 4.0134, 'learning_rate': 9.92e-05, 'epoch': 0.5}
{'loss': 3.9899, 'learning_rate': 0.0001192, 'epoch': 0.6}
{'loss': 3.9281, 'learning_rate': 0.0001392, 'epoch': 0.7}
{'loss': 3.8872, 'learning_rate': 0.00015920000000000002, 'epoch': 0.8}
{'loss': 3.8405, 'learning_rate': 0.00017920000000000002, 'epoch': 0.9}
{'loss': 3.8524, 'learning_rate': 0.00019920000000000002, 'epoch': 1.0}


  0%|          | 0/200 [00:00<?, ?it/s]

{'eval_loss': 3.692786455154419, 'eval_runtime': 9.872, 'eval_samples_per_second': 20.259, 'eval_steps_per_second': 20.259, 'epoch': 1.0}
{'loss': 3.7897, 'learning_rate': 0.0001988651744737914, 'epoch': 1.1}
{'loss': 3.8541, 'learning_rate': 0.00019529793415172192, 'epoch': 1.2}
{'loss': 3.7792, 'learning_rate': 0.0001893841424151264, 'epoch': 1.3}
{'loss': 3.8064, 'learning_rate': 0.0001812694164433094, 'epoch': 1.4}
{'loss': 3.8657, 'learning_rate': 0.00017115356772092854, 'epoch': 1.5}
{'loss': 3.8237, 'learning_rate': 0.00015928568201610595, 'epoch': 1.6}
{'loss': 3.7684, 'learning_rate': 0.00014595798606214882, 'epoch': 1.7}
{'loss': 3.7859, 'learning_rate': 0.0001314986519655305, 'epoch': 1.8}
{'loss': 3.7354, 'learning_rate': 0.00011626371651948838, 'epoch': 1.9}
{'loss': 3.7149, 'learning_rate': 0.00010062831439655591, 'epoch': 2.0}


  0%|          | 0/200 [00:00<?, ?it/s]

{'eval_loss': 3.587766170501709, 'eval_runtime': 9.8582, 'eval_samples_per_second': 20.288, 'eval_steps_per_second': 20.288, 'epoch': 2.0}
{'loss': 3.7388, 'learning_rate': 8.497744108792429e-05, 'epoch': 2.1}
{'loss': 3.7325, 'learning_rate': 6.969647303672262e-05, 'epoch': 2.2}
{'loss': 3.7512, 'learning_rate': 5.5161678390996796e-05, 'epoch': 2.3}
{'loss': 3.724, 'learning_rate': 4.173095203314241e-05, 'epoch': 2.4}
{'loss': 3.7478, 'learning_rate': 2.9958884921619367e-05, 'epoch': 2.5}
{'loss': 3.7526, 'learning_rate': 1.965585998878724e-05, 'epoch': 2.6}
{'loss': 3.7, 'learning_rate': 1.1331174429944347e-05, 'epoch': 2.7}
{'loss': 3.7602, 'learning_rate': 5.189809631596798e-06, 'epoch': 2.8}
{'loss': 3.7581, 'learning_rate': 1.3829863771011253e-06, 'epoch': 2.9}
{'loss': 3.7137, 'learning_rate': 4.4412891050171765e-09, 'epoch': 3.0}


  0%|          | 0/200 [00:00<?, ?it/s]

{'eval_loss': 3.5670135021209717, 'eval_runtime': 9.192, 'eval_samples_per_second': 21.758, 'eval_steps_per_second': 21.758, 'epoch': 3.0}
{'train_runtime': 695.4945, 'train_samples_per_second': 8.627, 'train_steps_per_second': 2.157, 'train_loss': 3.8739136454264322, 'epoch': 3.0}


('./finetuned_bhc_test3\\tokenizer_config.json',
 './finetuned_bhc_test3\\special_tokens_map.json',
 './finetuned_bhc_test3\\spiece.model',
 './finetuned_bhc_test3\\added_tokens.json',
 './finetuned_bhc_test3\\tokenizer.json')

# Build eval function over your validation subset

In [21]:

def evaluate_model(df, strategy="section-only"):
    preds = []
    refs = []
    for i in tqdm(range(len(df))):
        text = df.iloc[i]["input_clean"]
        ref  = df.iloc[i]["target_clean"]
        summary = summarize_bhc(text, strategy=strategy)   # this uses the hierarchical/section pipelines we built
        preds.append(summary)
        refs.append(ref)
    return preds, refs


## Run predictions on val subset for both strategies

In [22]:
subset_val_df_small = val_df.sample(n=5, random_state=42)

In [23]:
preds_section, refs = evaluate_model(subset_val_df_small, strategy="section-only")
preds_h2, _ = evaluate_model(subset_val_df_small, strategy="hierarchical-H2")

100%|██████████| 5/5 [00:27<00:00,  5.56s/it]
100%|██████████| 5/5 [03:51<00:00, 46.20s/it]


# Compute ROUGE

In [24]:
from evaluate import load as load_metric
rouge = load_metric("rouge")

scores_section = rouge.compute(predictions=preds_section, references=refs)
scores_h2 = rouge.compute(predictions=preds_h2, references=refs)

print("SECTION ONLY:", scores_section)
print("H2 HIERARCHICAL:", scores_h2)


SECTION ONLY: {'rouge1': 0.28760386315999636, 'rouge2': 0.07946909492894719, 'rougeL': 0.15914617549270282, 'rougeLsum': 0.15914617549270282}
H2 HIERARCHICAL: {'rouge1': 0.26445205235529406, 'rouge2': 0.06886098276222131, 'rougeL': 0.15482794352618695, 'rougeLsum': 0.15482794352618695}


# Compute BERTScore

In [25]:
bertscore = load_metric("bertscore")

berts_section = bertscore.compute(predictions=preds_section, references=refs, lang="en")
berts_h2 = bertscore.compute(predictions=preds_h2, references=refs, lang="en")

print("BERTScore SECTION ONLY:", {
   "precision": sum(berts_section["precision"])/len(berts_section["precision"]),
   "recall": sum(berts_section["recall"])/len(berts_section["recall"]),
   "f1": sum(berts_section["f1"])/len(berts_section["f1"]),
})

print("BERTScore H2 HIERARCHICAL:", {
   "precision": sum(berts_h2["precision"])/len(berts_h2["precision"]),
   "recall": sum(berts_h2["recall"])/len(berts_h2["recall"]),
   "f1": sum(berts_h2["f1"])/len(berts_h2["f1"]),
})


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BERTScore SECTION ONLY: {'precision': 0.844977867603302, 'recall': 0.8066615104675293, 'f1': 0.8252794861793518}
BERTScore H2 HIERARCHICAL: {'precision': 0.8270078182220459, 'recall': 0.798820161819458, 'f1': 0.8125460743904114}
