In [2]:
import pandas as pd
from tqdm import tqdm
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [5]:
dataset_name = 'dmacres/mimiciii-hospitalcourse-cossim-pagerank-batched-extractive-summ-v2'
mimiciii_dataset = load_dataset(dataset_name)
mimiciii_dataset

Downloading readme:   0%|          | 0.00/886 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/22.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/22.8M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/107M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating validation split:   0%|          | 0/5356 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5356 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/24993 [00:00<?, ? examples/s]

DatasetDict({
    validation: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes'],
        num_rows: 5356
    })
    test: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes'],
        num_rows: 5356
    })
    train: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes'],
        num_rows: 24993
    })
})

In [6]:
rouge_metric = load_metric("rouge")

  rouge_metric = load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [7]:
model_ckpt = "facebook/bart-large"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

In [8]:
def chunks(list_of_elements, batch_size):
    """Yield successive batch-sized chunks from list_of_elements."""
    for i in range(0, len(list_of_elements), batch_size):
        yield list_of_elements[i : i + batch_size]

def evaluate_summaries_bart(dataset, metric, model, tokenizer,
                               batch_size=16, device=device,
                               note_text="extractive_notes_summ",
                               note_summary="target_text"):
    note_batches = list(chunks(dataset[note_text], batch_size))
    target_batches = list(chunks(dataset[note_summary], batch_size))

    for note_batch, target_batch in tqdm(
        zip(note_batches, target_batches), total=len(note_batches)):

        inputs = tokenizer(note_batch, max_length=1024,  truncation=True,
                        padding="max_length", return_tensors="pt")

        summaries = model.generate(input_ids=inputs["input_ids"].to(device),
                         attention_mask=inputs["attention_mask"].to(device),
                         length_penalty=0.8, num_beams=8, max_length=128)

        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
                                clean_up_tokenization_spaces=True)
               for s in summaries]
        decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]
        metric.add_batch(predictions=decoded_summaries, references=target_batch)

    score = metric.compute()
    return score

In [9]:
#hide_output
def convert_examples_to_features(example_batch):
    input_encodings = tokenizer(example_batch["extractive_notes_summ"], max_length=1024,
                                truncation=True)

    with tokenizer.as_target_tokenizer():
        target_encodings = tokenizer(example_batch["target_text"], max_length=128,
                                     truncation=True)

    return {"input_ids": input_encodings["input_ids"],
            "attention_mask": input_encodings["attention_mask"],
            "labels": target_encodings["input_ids"]}


In [10]:

columns = ["input_ids", "attention_mask", "labels"]

mimiciii_dataset_pt = mimiciii_dataset.map(convert_examples_to_features,
                                       batched=True)

mimiciii_dataset_pt.set_format(type="torch", columns=columns)


Map:   0%|          | 0/5356 [00:00<?, ? examples/s]



Map:   0%|          | 0/5356 [00:00<?, ? examples/s]

Map:   0%|          | 0/24993 [00:00<?, ? examples/s]

In [11]:
mimiciii_dataset_pt

DatasetDict({
    validation: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 5356
    })
    test: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 5356
    })
    train: Dataset({
        features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 24993
    })
})

In [12]:
mimiciii_dataset_pt['train'][0]

{'input_ids': tensor([    0,   970,  1189,  ..., 19258,   620,     2]),
 'attention_mask': tensor([1, 1, 1,  ..., 1, 1, 1]),
 'labels': tensor([    0, 18276,  4843,    21,  2641,     7,     5, 14913, 12557, 21712,
           544,     8, 12796,    41, 20583, 13907,   248,  3977,  1543, 40139,
            19, 20971,     9,    39,  2849,   417,  9799, 36020,  1580,     4,
          1869,  8428, 10481,    37,    21,  7225,     7,     5, 12296,   575,
          1933,  8935,  1792,  1070,     8,  4375,     4,    91,    21,  4925,
            15, 14632,   927,   179,    13, 18587,  1759,  3792,  4360, 33966,
             6,    39,  1925,  1164,    21,  4875,     8,    37,    21, 14316,
            19,   593, 14913,  6240,     4,    91,    21,    67,  4925,    15,
          3349,   282,    23, 18356,    13,  8555,     9,  3766,  8092,  2811,
          3186,    18,  1233,  4835,   750,     4,    83,   618, 23655, 12464,
         14194,  7646,   618, 23655,  1022,     8,    41,  3855,    11,    

In [13]:
from transformers import DataCollatorForSeq2Seq

seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [14]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir='../bart-large-mimiciii-v2', num_train_epochs=3, warmup_steps=500,
    per_device_train_batch_size=1, per_device_eval_batch_size=1,
    weight_decay=0.01, logging_steps=10, push_to_hub=True,
    evaluation_strategy='steps', eval_steps=500, save_steps=1e6,
    gradient_accumulation_steps=16)

In [15]:
# hide_output
trainer = Trainer(model=model, args=training_args,
                  tokenizer=tokenizer, data_collator=seq2seq_data_collator,
                  train_dataset=mimiciii_dataset_pt['train'],
                  eval_dataset=mimiciii_dataset_pt['validation'])

In [16]:
# hide_output
trainer.train()


You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
500,3.0398,2.891853
1000,2.6518,2.557353
1500,2.5785,2.442617
2000,2.4006,2.416251
2500,2.3405,2.353037
3000,2.3185,2.314798
3500,2.2378,2.277536
4000,2.1804,2.252854
4500,2.1945,2.237562


TrainOutput(global_step=4686, training_loss=2.5001476013502706, metrics={'train_runtime': 15915.5156, 'train_samples_per_second': 4.711, 'train_steps_per_second': 0.294, 'total_flos': 1.5417642843957658e+17, 'train_loss': 2.5001476013502706, 'epoch': 3.0})

In [17]:
trainer.push_to_hub("Training complete!")

events.out.tfevents.1700053713.54e729c39d4b.468.0:   0%|          | 0.00/81.5k [00:00<?, ?B/s]

Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/4.60k [00:00<?, ?B/s]

'https://huggingface.co/dmacres/bart-large-mimiciii-v2/tree/main/'

In [18]:
model_ckpt_cust = "../bart-large-mimiciii-v2"
tokenizer_cust = AutoTokenizer.from_pretrained(model_ckpt)
model_cust = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)

test_sample = mimiciii_dataset['test'].shuffle(seed = 42).select([3])
test_sample_text = test_sample['extractive_notes_summ']
print(test_sample_text)
test_sample_target = test_sample['target_text']
print('\n\n\n\n')
print(test_sample_target)

inputs = tokenizer_cust(test_sample_text, max_length=1024,  truncation=True,
                padding="max_length", return_tensors="pt")


summaries = model_cust.generate(input_ids=inputs["input_ids"].to(device),
                 attention_mask=inputs["attention_mask"].to(device),
                 length_penalty=0.8, num_beams=8, max_length=1024)

['There is some patchy opacity at the right lung base -- ? Again seen is opacification of left hemithorax. Sternotomy wires and riht paratracheal/suprahilar sutures are noted. Possible prior inferior myocardial infarction. Premature ventricularcontractions. Opacity at the right cardiophrenic angle could reflect a small effusion. The left hemithorax is opacified, with, as noted, shift of the mediastinum. FINDINGS:  The endotracheal tube, NG tube, right central line and the left pneumonectomy site appear unchanged. The extreme right costophrenic angle is excluded from the film. chest, 1 vw The patient is status post sternotomy. NG tube present, tip extending beneath diaphragm off film. Rotated positioning, which limits assessment of the central line tip. An NG tube is present, tip extending beneath the diaphragm. The right chest shows some atelectasis, but is otherwise grossly clear. A right IJ central line is present, tip probably overlies the SVC, though difficult to confirm due to lef

In [19]:
decoded_summaries = [tokenizer_cust.decode(s, skip_special_tokens=True,
                        clean_up_tokenization_spaces=True)
       for s in summaries]
decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]
decoded_summaries

["There is some patchy opacity at the right lung base. A NG tube is coiled within the esophagus. A right IJ central line is present, tip probably overlies the SVC, though difficult to confirm due to leftward shift of the mediastinum and rotated positioning. A fracture of the left fourth lateral rib is again noted. Also noted is surgical material in the right lower quadrant, possibly representing a right lower lobectomy. IMPRESSION: Compared with one day earlier, no significant interval change is detected. The right chest shows some atelectasis, but is otherwise grossly clear. The extreme right costophrenic angle is excluded from the film. The left hemithorax in the pneumonectomy bed iscompatible with aorta and is unchanged. No previoustracing available for comparison. SINGLE FRONTAL VIEW OF THE CHEST:  71 year old man with hypoxia, transferred from OSH with bowel perf REASON FOR THIS EXAMINATION:  No pneumothorax is detected in the upper chest. Chest, 1 vw The patient is status post st

In [20]:
score = evaluate_summaries_bart(
    mimiciii_dataset['test'], rouge_metric, trainer.model, tokenizer,
    batch_size=2, note_text="extractive_notes_summ", note_summary="target_text")

rouge_methods = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
rouge_dict = dict((rm, score[rm].mid.fmeasure) for rm in rouge_methods)
pd.DataFrame(rouge_dict, index=[f"bart-large"])

100%|██████████| 2678/2678 [2:44:08<00:00,  3.68s/it]


Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
bart-large,0.023803,0.011787,0.018235,0.018248


In [21]:
rouge_methods = ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
rouge_dict = dict((rm, score[rm].mid.fmeasure) for rm in rouge_methods)
pd.DataFrame(rouge_dict, index=[f"bart-large"]).to_csv('bart-large-mimiciii-v2-rogue-metrics.csv', index = False)