In [1]:
pip install transformers datasets

Note: you may need to restart the kernel to use updated packages.


In [24]:
# Clear GPU memory
torch.cuda.empty_cache()

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import torch
print("Visible:", torch.cuda.device_count())

Visible: 1


In [3]:
print("Using device:", torch.cuda.current_device())
print("Device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

Using device: 0
Device name: NVIDIA RTX A6000


In [8]:
!pip install sentencepiece



In [4]:
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
import sentencepiece

# Load the dataset
dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary")

# Load the T5 model and tokenizer
model = T5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Preprocess the dialogues and SOAP notes for tokenization
def preprocess_and_tokenize(example):
    # Tokenize the dialogue (input text)
    input_ids = tokenizer(example["dialogue"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    # Tokenize the SOAP note (target text)
    labels = tokenizer(example["soap"], padding="max_length", truncation=True, max_length=512, return_tensors="pt").input_ids

    # Add labels to the input dictionary
    input_ids["labels"] = labels
    return input_ids

# Apply preprocessing and tokenization to the entire dataset
train_dataset = dataset['train'].map(preprocess_and_tokenize, batched=True)
val_dataset = dataset['validation'].map(preprocess_and_tokenize, batched=True)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="./results/full_parameter",            # Directory for saving the model
    num_train_epochs=3,                               # Number of training epochs
    per_device_train_batch_size=16,                   # Batch size for training
    per_device_eval_batch_size=16,                    # Batch size for evaluation
    warmup_steps=500,                                 # Warmup steps for learning rate scheduler
    weight_decay=0.01,                                # Weight decay to avoid overfitting
    logging_dir="./logs",                             # Directory for logging
    logging_steps=10,                                 # Log every 10 steps
    save_steps=500,                                   # Save the model every 500 steps
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,                               # Keep only the last 2 saved models
    load_best_model_at_end=True,                      # Load the best model when finished
    report_to="none",
    fp16=True,                                        # Enable mixed precision training for faster training
    metric_for_best_model="eval_loss",
)

In [6]:
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
)

In [7]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    callbacks=[early_stopping_callback],
)

# Fine-tune the model
trainer.train()

  trainer = Trainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,1.8873,1.610374
2,1.6862,1.462896
3,1.644,1.431133


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=1737, training_loss=2.241065425576324, metrics={'train_runtime': 493.5631, 'train_samples_per_second': 56.224, 'train_steps_per_second': 3.519, 'total_flos': 3755734990848000.0, 'train_loss': 2.241065425576324, 'epoch': 3.0})

In [8]:
model.save_pretrained('./finetuned_t5_small')
tokenizer.save_pretrained('./finetuned_t5_small')

('./finetuned_t5_small/tokenizer_config.json',
 './finetuned_t5_small/special_tokens_map.json',
 './finetuned_t5_small/spiece.model',
 './finetuned_t5_small/added_tokens.json')

In [None]:
# Extract the first 100 dialogues from the test set
test_dialogues = dataset['test']['dialogue'][:100]

In [10]:
# Check if a GPU is available and use it
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move the model to the appropriate device
model.to(device)

# Tokenize and move inputs to the same device as the model
inputs = tokenizer(test_dialogues, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = {key: value.to(device) for key, value in inputs.items()}  # Move input tensors to the same device as the model

# Generate SOAP notes for these dialogues
outputs = model.generate(inputs['input_ids'], max_length=256, num_beams=5, early_stopping=True)

# Decode the generated token IDs into readable SOAP notes
generated_soap_notes = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

# Print the first few results to check
print(generated_soap_notes[:1])

["S: The patient reports experiencing painless blurry vision in the right eye for a week, intermittent fevers, headache, body aches, and a nonpruritic maculopapular rash on the lower legs for the past 6 months. The patient denies any past medical history including neck stiffness, nausea, vomiting, Raynaud's phenomenon, oral ulcerations, chest pain, shortness of breath, abdominal pain, or photosensitivity. O: Vital signs were normal. Physical examination revealed bilateral papilledema and optic nerve erythema in the right eye, right inferior nasal quadrant visual field defect, right afferent pupillary defect, and sensation to light touch, pinprick, vibration, and proprioception intact. A: The primary diagnosis is microcytic anemia with a hemoglobin of 11.6 gm/dL, hematocrit 35.3%, mean corpuscular volume of 76.9 fL, hyponatremia with a sodium level of 133 mmol/L, C-reactive protein (CRP) elevated at 33 mm/hr, and"]


In [11]:
import csv

# Save the generated SOAP notes and dialogues to a CSV file
with open("full-paramater-generated-results.csv", mode="w", newline="") as file:
    writer = csv.writer(file)
    writer.writerow(["Dialogue", "Generated SOAP Note"])  # Column headers
    for dialogue, soap_note in zip(test_dialogues, generated_soap_notes):
        writer.writerow([dialogue, soap_note])

In [12]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the fine-tuned model and tokenizer
model = T5ForConditionalGeneration.from_pretrained('./finetuned_t5_small')
tokenizer = T5Tokenizer.from_pretrained('./finetuned_t5_small')

In [14]:
import pandas as pd

# Load the CSV file containing the generated SOAP notes
df = pd.read_csv('full-paramater-generated-results.csv')

# Extract the generated SOAP notes
generated_soap_notes = df['Generated SOAP Note'].tolist()

# Print the first few SOAP notes to check
print(generated_soap_notes[:2])

["S: The patient reports experiencing painless blurry vision in the right eye for a week, intermittent fevers, headache, body aches, and a nonpruritic maculopapular rash on the lower legs for the past 6 months. The patient denies any past medical history including neck stiffness, nausea, vomiting, Raynaud's phenomenon, oral ulcerations, chest pain, shortness of breath, abdominal pain, or photosensitivity. O: Vital signs were normal. Physical examination revealed bilateral papilledema and optic nerve erythema in the right eye, right inferior nasal quadrant visual field defect, right afferent pupillary defect, and sensation to light touch, pinprick, vibration, and proprioception intact. A: The primary diagnosis is microcytic anemia with a hemoglobin of 11.6 gm/dL, hematocrit 35.3%, mean corpuscular volume of 76.9 fL, hyponatremia with a sodium level of 133 mmol/L, C-reactive protein (CRP) elevated at 33 mm/hr, and", "S: The patient is a 7-year-old boy with congenital bilateral sensorineu

In [15]:
import torch
import evaluate
from datasets import load_dataset

# Load the ROUGE metric
rouge_metric = evaluate.load("rouge")


dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary")

# Extract the ground truth SOAP notes from the test set
ground_truth_soap_notes = dataset['test']['soap'][:100]

# Evaluate the ROUGE score by comparing generated SOAP notes with ground truth
results = rouge_metric.compute(predictions=generated_soap_notes, references=ground_truth_soap_notes)

# Print the ROUGE results
print(results)

{'rouge1': np.float64(0.47853501109237784), 'rouge2': np.float64(0.28087619526826557), 'rougeL': np.float64(0.35051591508349156), 'rougeLsum': np.float64(0.40778387667135596)}


In [16]:
# Accessing specific ROUGE scores
rouge_1 = results["rouge1"]
rouge_2 = results["rouge2"]
rouge_l = results["rougeL"]

# Print each of the scores
print(f"ROUGE-1: {rouge_1}")
print(f"ROUGE-2: {rouge_2}")
print(f"ROUGE-L: {rouge_l}")

ROUGE-1: 0.47853501109237784
ROUGE-2: 0.28087619526826557
ROUGE-L: 0.35051591508349156


In [1]:
import torch
import evaluate
from datasets import load_dataset

# Load the ROUGE metric
rouge_metric = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
bleu = evaluate.load("bleu")

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Matplotlib is building the font cache; this may take a moment.


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

In [2]:
dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary")

README.md: 0.00B [00:00, ?B/s]

train.json:   0%|          | 0.00/154M [00:00<?, ?B/s]

validation.json: 0.00B [00:00, ?B/s]

test.json: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/9250 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/250 [00:00<?, ? examples/s]

In [3]:
import pandas as pd

# Load the CSV file containing the generated SOAP notes
df = pd.read_csv('full-paramater-generated-results.csv')

# Extract the generated SOAP notes
generated_soap_notes = df['Generated SOAP Note'].tolist()

# Print the first few SOAP notes to check
print(generated_soap_notes[:2])

["S: The patient reports experiencing painless blurry vision in the right eye for a week, intermittent fevers, headache, body aches, and a nonpruritic maculopapular rash on the lower legs for the past 6 months. The patient denies any past medical history including neck stiffness, nausea, vomiting, Raynaud's phenomenon, oral ulcerations, chest pain, shortness of breath, abdominal pain, or photosensitivity. O: Vital signs were normal. Physical examination revealed bilateral papilledema and optic nerve erythema in the right eye, right inferior nasal quadrant visual field defect, right afferent pupillary defect, and sensation to light touch, pinprick, vibration, and proprioception intact. A: The primary diagnosis is microcytic anemia with a hemoglobin of 11.6 gm/dL, hematocrit 35.3%, mean corpuscular volume of 76.9 fL, hyponatremia with a sodium level of 133 mmol/L, C-reactive protein (CRP) elevated at 33 mm/hr, and", "S: The patient is a 7-year-old boy with congenital bilateral sensorineu

In [5]:
# Extract the ground truth SOAP notes from the test set
ground_truth_soap_notes = dataset['test']['soap'][:100]

# Evaluate the ROUGE score by comparing generated SOAP notes with ground truth
rouge_results = rouge_metric.compute(predictions=generated_soap_notes, references=ground_truth_soap_notes)
bertscore_results = bertscore.compute(predictions=generated_soap_notes, references=ground_truth_soap_notes, lang="en")
# Prepare data for BLEU (requires tokenized inputs)
bleu_references = [[ref] for ref in ground_truth_soap_notes]
bleu_predictions = generated_soap_notes

bleu_result = bleu.compute(predictions=bleu_predictions, references=bleu_references)

In [8]:
# Access individual metrics
rouge_1 = rouge_results["rouge1"]
rouge_2 = rouge_results["rouge2"]
rouge_l = rouge_results["rougeL"]
rouge_lsum = rouge_results["rougeLsum"]
bertscore_f1 = sum(bertscore_results["f1"]) / len(bertscore_results["f1"])  # Average F1
bleu_score = bleu_result["bleu"]

# Print all scores
print(f"ROUGE-1: {rouge_1:.4f}")
print(f"ROUGE-2: {rouge_2:.4f}")
print(f"ROUGE-L: {rouge_l:.4f}")
print(f"ROUGE-Lsum: {rouge_lsum:.4f}")
print(f"BERTScore-F1: {bertscore_f1:.4f}")
print(f"BLEU: {bleu_score:.4f}")

ROUGE-1: 0.4789
ROUGE-2: 0.2808
ROUGE-L: 0.3507
ROUGE-Lsum: 0.4076
BERTScore-F1: 0.8802
BLEU: 0.1760
