# Fine Tuning T5-base to understand Medical Domain

In [1]:
# setup env
# !pip install transformers datasets evaluate rouge_score accelerate -q

In [2]:
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from huggingface_hub import notebook_login

from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

import evaluate

In [26]:
# set variable & parameters
MODEL_CHECKPOINT = "t5-small"  # t5-base
MODEL_REPO = "medical_diagnostic_summarizer"
PREFIX = "summarize: "
MAX_INPUT_LENGTH = 1024
MAX_TARGET_LENGTH = 128
BATCH_SIZE = 16


In [17]:
notebook_login()
# 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

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Load Dataset

### Dataset Summary
In response to the COVID-19 pandemic, the White House and a coalition of leading research groups have prepared the COVID-19 Open Research Dataset (CORD-19). CORD-19 is a resource of over 1,000,000 scholarly articles, including over 400,000 with full text, about COVID-19, SARS-CoV-2, and related coronaviruses. This freely available dataset is provided to the global research community to apply recent advances in natural language processing and other AI techniques to generate new insights in support of the ongoing fight against this infectious disease. This is a processed version of the dataset, where we removed some empty entries and formated it to be compatible with the alpaca training. For more details on the data, please refer to the original publicatio.

In [5]:
split = 'train[0:50000]'
billsum = load_dataset("medalpaca/medical_meadow_cord19", split=split)
billsum = billsum.train_test_split(test_size=0.2)
billsum["train"][0]

{'output': 'From Personalized Medicine to Population Health: A Survey of mHealth Sensing Techniques',
 'instruction': 'Please summerize the given abstract to a title',
 'input': 'Mobile Sensing Apps have been widely used as a practical approach to collect behavioral and health-related information from individuals and provide timely intervention to promote health and well-beings, such as mental health and chronic cares. As the objectives of mobile sensing could be either \\emph{(a) personalized medicine for individuals} or \\emph{(b) public health for populations}, in this work we review the design of these mobile sensing apps, and propose to categorize the design of these apps/systems in two paradigms -- \\emph{(i) Personal Sensing} and \\emph{(ii) Crowd Sensing} paradigms. While both sensing paradigms might incorporate with common ubiquitous sensing technologies, such as wearable sensors, mobility monitoring, mobile data offloading, and/or cloud-based data analytics to collect and pro

In [6]:
billsum

DatasetDict({
    train: Dataset({
        features: ['output', 'instruction', 'input'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['output', 'instruction', 'input'],
        num_rows: 10000
    })
})

## Preprocces dataset

In [7]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [8]:
def preprocess_function(examples):
    inputs = [PREFIX + doc for doc in examples["input"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)

    labels = tokenizer(text_target=examples["output"], max_length=MAX_TARGET_LENGTH, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [9]:
tokenized_dataset = billsum.map(preprocess_function, batched=True)

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [10]:
# Load Model
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

In [11]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

## Compute Metrics

In [12]:
metrics = evaluate.load("rouge")

In [13]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = metrics.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

## Define Hyperparameter

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir=MODEL_REPO,
    evaluation_strategy="epoch",
    learning_rate=1e-3,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=True,
)

In [19]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

## Training

In [20]:
trainer.train()

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.1658,1.97035,0.411,0.2134,0.3502,0.3502,17.6057
2,1.9441,1.883011,0.4155,0.2172,0.355,0.3551,17.6832




Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,2.1658,1.97035,0.411,0.2134,0.3502,0.3502,17.6057
2,1.9441,1.883011,0.4155,0.2172,0.355,0.3551,17.6832
3,1.7621,1.867043,0.4177,0.2184,0.3563,0.3564,17.6943


TrainOutput(global_step=7500, training_loss=2.010565665690104, metrics={'train_runtime': 5510.0195, 'train_samples_per_second': 21.779, 'train_steps_per_second': 1.361, 'total_flos': 2.116325792494387e+16, 'train_loss': 2.010565665690104, 'epoch': 3.0})

In [21]:
trainer.push_to_hub()

'https://huggingface.co/fahmiaziz/medical_diagnostic_summarizer/tree/main/'

## Evaluate

In [31]:
# text = """summarize:
# About acne
# Acne is a common skin condition that affects most people at some point.
#  It causes spots, oily skin and sometimes skin that's hot or painful to touch.

# Acne most commonly develops on the:

# face – this affects almost everyone with acne
# back – this affects more than half of people with acne
# chest – this affects about 15% of people with acne
# Types of spots
# There are 6 main types of spot caused by acne:

# blackheads – small black or yellowish bumps that develop on the skin; they're not filled with dirt, but are black because the inner lining of the hair follicle produces pigmentation (colouring)
# whiteheads – have a similar appearance to blackheads, but may be firmer and won't empty when squeezed
# papules – small red bumps that may feel tender or sore
# pustules – similar to papules, but have a white tip in the centre, caused by a build-up of pus
# nodules – large hard lumps that build up beneath the surface of the skin and can be painful
# cysts – the most severe type of spot caused by acne; they're large pus-filled lumps that look similar to boils and carry the greatest risk of causing permanent scarring
# """


# text = """summarize:
# COURSE WHILE IN HOSPITAL
# Relevant Complaint(s) and Concerns:
# 1. Upon arrival: Patient presented with five days of increased urinary frequency, urgency and dysuria as well as
# 48 hours of fever and rigors. He was hypotensive and tachycardic upon arrival to the emergency department.
# The internal medicine service was consulted. The following issues were addressed during the hospitalization:
# Summary Course in Hospital (Issues Addressed):
# 2. Fever and urinary symptoms: A preliminary diagnosis of pyelonephritis was established. Other causes of fever
# were possible but less likely. The patient was hypotensive on initial assessment with a blood pressure of
# 80/40. Serum lactate was elevated at 6.1. A bolus of IV fluid was administered (1.5L) but the patient remained
# hypotensive. Our colleagues from ICU were consulted. An arterial line was inserted for hemodynamic
# monitoring. Hemodynamics were supported with levophed and crystalloids. Piptazo was started after blood
# and urine cultures were drawn. After 12 hours serum lactate had normalized and hemodynamics had
# stabilized. Blood cultures were positive for E.Coli that was sensitive to all antibiotics. The patient was stepped
# down to oral ciprofloxacin to complete a total 14 day course of antibiotics.
# On further review it was learned that the patient has been experiencing symptoms of prostatism for the last
# year. An abdominal ultrasound performed for elevated liver enzymes and acute kidney injury confirmed a
# """


text = """summarize:
DIAGNOSIS:
A. SKIN, RIGHT ARM, SHAVE BIOPSY:
COMPATIBLE WITH PERFORATING DISORDER WITH FEATURES OF
ELASTOSIS PERFORANS SERPIGINOSUM.
B. SKIN, LEFT NECK, SHAVE BIOPSY:
1. COMPATIBLE WITH PERFORATING DISORDER WITH FEATURES
OF ELASTOSIS PERFORANS SERPIGINOSUM.
2. ASSOCIATED SPONGIOTIC DERMATITIS WITH OCCASIONAL
EOSINOPHILS (SEE NOTE).
"""

In [23]:
def summarize_text(text: str, model: str):
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForSeq2SeqLM.from_pretrained(model)
    inputs = tokenizer(text, return_tensors="pt").input_ids
    outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

### T5-small

In [32]:
summarize_text(text, model="t5-small")

'RIGHT ARM, SHAVE BIOPSY: COMPATIBLE WITH PERFORATING DISORDER WITH FEATURES OF ELASTOSIS PERFORANS SERPIGINOSUM. ASSOCIATED SPONGIOTIC DERMATITIS WITH OCCASIONAL EOSINOPHILS (SEE NOTE).'

### Fine Tuning Model

In [33]:
summarize_text(text, model=MODEL_REPO)

'COMPATIBLE WITH PERFORATING DISORDER WITH FEATURES OF ELASTOSIS PERFORANS SERPIGINOSUM'