In [5]:
import pandas as pd
import torch
import transformers
from datasets import Dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)


In [6]:
train_df = pd.read_csv("train_clean.csv")
val_df   = pd.read_csv("val_clean.csv")
test_df  = pd.read_csv("test_clean.csv")

print("Train shape:", train_df.shape)
print("Val shape:",   val_df.shape)
print("Test shape:",  test_df.shape)

print(train_df.head(2))
print(train_df.columns)

Train shape: (21671, 2)
Val shape: (2408, 2)
Test shape: (2676, 2)
                                                body  \
0  <SEX> M <SERVICE> PODIATRY <ALLERGIES> No Know...   
1  <SEX> M <SERVICE> MEDICINE <ALLERGIES> Codeine...   

                                             summary  
0  Mr. ___ was admitted after presenting to the E...  
1  ___ year old with a history of alcoholism, wit...  
Index(['body', 'summary'], dtype='object')


In [7]:
# BART’s Trainer works best with datasets.Dataset objects
train_ds = Dataset.from_pandas(train_df)
val_ds   = Dataset.from_pandas(val_df)
test_ds  = Dataset.from_pandas(test_df)

train_ds, val_ds, test_ds

(Dataset({
     features: ['body', 'summary'],
     num_rows: 21671
 }),
 Dataset({
     features: ['body', 'summary'],
     num_rows: 2408
 }),
 Dataset({
     features: ['body', 'summary'],
     num_rows: 2676
 }))

In [26]:
model_name = "facebook/bart-base"  # <-- base, not large

tokenizer = BartTokenizer.from_pretrained(model_name)
model     = BartForConditionalGeneration.from_pretrained(model_name)

#model_name = "./bart_base_mimic_checkpoints/checkpoint-latest"

#tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#model = BartForConditionalGeneration.from_pretrained(model_name)

In [10]:
summary_lengths = val_df["summary"].str.split().str.len()
print(summary_lengths.describe())

count    2408.000000
mean      387.260382
std       265.220139
min        43.000000
25%       204.000000
50%       318.000000
75%       506.000000
max      2188.000000
Name: summary, dtype: float64


In [27]:
# we’ll use longer max length for the note (body)
max_input_length = 1024     # BART-base max positions
max_target_length = 512     # updated based on your summary stats

def tokenize_batch(batch):
    # encode the input medical note
    model_inputs = tokenizer(
        batch["body"],
        max_length=max_input_length,
        padding="max_length",   # fixed padding (simple for small project)
        truncation=True,
    )

    # encode the target summary
    labels = tokenizer(
        text_target=batch["summary"],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    # Trainer expects labels["input_ids"]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [28]:
train_tok = train_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=train_ds.column_names,
)

val_tok = val_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=val_ds.column_names,
)

test_tok = test_ds.map(
    tokenize_batch,
    batched=True,
    remove_columns=test_ds.column_names,
)

train_tok, val_tok, test_tok

Map:   0%|          | 0/21671 [00:00<?, ? examples/s]

Map:   0%|          | 0/2408 [00:00<?, ? examples/s]

Map:   0%|          | 0/2676 [00:00<?, ? examples/s]

(Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 21671
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 2408
 }),
 Dataset({
     features: ['input_ids', 'attention_mask', 'labels'],
     num_rows: 2676
 }))

In [29]:
# ===== Utility: Copy the latest numbered checkpoint into a stable name =====

import os
import shutil

def copy_latest_checkpoint(output_dir):
    # HuggingFace checkpoints look like: "checkpoint-1000", "checkpoint-2000", ...
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not checkpoints:
        print("No checkpoints found yet.")
        return
    
    # Sort by step number
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    latest = checkpoints[-1]

    src = os.path.join(output_dir, latest)
    dst = os.path.join(output_dir, "checkpoint-latest")

    # Delete old stable checkpoint folder
    if os.path.exists(dst):
        shutil.rmtree(dst)

    shutil.copytree(src, dst)
    print(f"Saved latest checkpoint → {dst}")

In [30]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
)

In [41]:
use_fp16 = False   # M2 cannot use CUDA FP16, so keep this False

training_args = TrainingArguments(
    output_dir="bart_base_mimic_checkpoints",
    save_strategy="epoch",              # save once per epoch
    learning_rate=2e-5,                 # good LR for BART fine-tuning
    per_device_train_batch_size=1,      # safest for MacBook RAM
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    
    # training duration
    num_train_epochs=2,                 # start with 1 to ensure it runs cleanly
                                          # later you may increase to 2 if stable

    #predict_with_generate=True,         # required for seq2seq / summarization
    fp16=use_fp16,                      # stays False on M2
    logging_steps=100,

    save_total_limit=1,                 # keep only latest checkpoint
    report_to="none",                   # avoids wandb warnings
)

In [42]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

#trainer.train() # for first epoch trained, then it trains from the checkpoint
trainer.train(resume_from_checkpoint="./bart_base_mimic_checkpoints/checkpoint-latest")
copy_latest_checkpoint("bart_base_mimic_checkpoints")

  trainer = Trainer(
There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


Step,Training Loss
21700,1.9215
21800,2.152
21900,2.1533
22000,2.1742
22100,2.1761
22200,2.114
22300,2.0543
22400,2.0017
22500,2.0682
22600,1.9453




Saved latest checkpoint → bart_base_mimic_checkpoints/checkpoint-latest


In [43]:
test_metrics = trainer.evaluate(test_tok)
print(test_metrics)



{'eval_loss': 1.8558242321014404, 'eval_runtime': 400.5347, 'eval_samples_per_second': 6.681, 'eval_steps_per_second': 6.681, 'epoch': 2.0}


In [45]:
import evaluate

rouge = evaluate.load("rouge")

def compute_rouge(trainer, dataset, tokenizer, max_samples=200):
    preds = []
    refs = []

    # Force everything to run on CPU (fixes your MPS error)
    device = torch.device("cpu")
    trainer.model.to(device)

    # Limit sample count for speed
    n = min(len(dataset), max_samples)

    for i in range(n):
        item = dataset[i]

        # ---- Convert model inputs to CPU tensors ----
        model_inputs = {
            "input_ids": torch.tensor(item["input_ids"]).unsqueeze(0).to(device),
            "attention_mask": torch.tensor(item["attention_mask"]).unsqueeze(0).to(device),
        }

        # ---- Generate BART summary ----
        with torch.no_grad():
            generated_ids = trainer.model.generate(
                **model_inputs,
                max_length=max_target_length,    # your notebook already defines this
                num_beams=4,
                early_stopping=True,
            )

        pred_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        preds.append(pred_text)

        # ---- Decode reference summary ----
        label_ids = [x for x in item["labels"] if x != -100]
        ref_text = tokenizer.decode(label_ids, skip_special_tokens=True)
        refs.append(ref_text)

    # ---- Compute ROUGE ----
    scores = rouge.compute(predictions=preds, references=refs)
    # convert numpy float64 → Python float
    scores = {k: float(v) for k, v in scores.items()}

    return scores


# Run ROUGE
rouge_scores = compute_rouge(trainer, test_tok, tokenizer, max_samples=200)
# rouge_scores = compute_rouge(trainer, test_tok, tokenizer, max_samples=1)
print(rouge_scores)



{'rouge1': 0.4060627937166115, 'rouge2': 0.1597496056586239, 'rougeL': 0.23632356076996036, 'rougeLsum': 0.23537538350054155}


In [46]:
# pick one example from the test set
sample_idx  = 0  # change this to inspect different notes
sample_body = test_df.iloc[sample_idx]["body"]
sample_ref  = test_df.iloc[sample_idx]["summary"]

print("=== Original note (body) ===")
print(sample_body[:1500], "...")   # truncate for display
print("\n=== Reference summary ===")
print(sample_ref)
print("\n=== Model summary (BART-base) ===")

# tokenize the input note for generation
inputs = tokenizer(
    sample_body,
    max_length=max_input_length,
    truncation=True,
    return_tensors="pt",
)

# move to GPU if available
model = model.to("cpu")
inputs = {k: v.to("cpu") for k, v in inputs.items()}


# generate summary
with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_length=max_target_length,
        num_beams=4,
        length_penalty=1.0,
        early_stopping=True,
    )

generated_summary = tokenizer.decode(
    generated_ids[0],
    skip_special_tokens=True,
)

print(generated_summary)

=== Original note (body) ===
<SEX> M <SERVICE> SURGERY <ALLERGIES> hydromorphone <ATTENDING> ___. <CHIEF COMPLAINT> Ventral hernia <MAJOR SURGICAL OR INVASIVE PROCEDURE> Ventral hernia repair with component separation and mesh onlay (over posterior sheath) <HISTORY OF PRESENT ILLNESS> ___ man with prior liver resection who now presents for elective repair of asymptomatic incisional hernia. <PAST MEDICAL HISTORY> hypertension, hyperlipidemia, bowel obstructions and gastric ulcer disease. Prior surgeries include sphincterotomy for anal fissure and bilateral inguinal hernia repairs. <SOCIAL HISTORY> ___ <FAMILY HISTORY> Both parents have coronary artery disease. <PERTINENT RESULTS> ___ 03: 04PM BLOOD WBC-7.9# RBC-3.84* Hgb-12.8* Hct-35.8* MCV-93 MCH-33.3* MCHC-35.8 RDW-14.2 RDWSD-48.1* Plt ___ ___ 05: 55AM BLOOD WBC-10.0 RBC-3.73* Hgb-12.3* Hct-36.0* MCV-97 MCH-33.0* MCHC-34.2 RDW-14.5 RDWSD-51.7* Plt ___ ___ 03: 04PM BLOOD Glucose-129* UreaN-19 Creat-0.9 Na-141 K-5.5* Cl-105 HCO3-19* AnG