In [1]:
import re
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments

In [2]:
# Load the dataset
dataset = load_dataset("omi-health/medical-dialogue-to-soap-summary")
print("Original Training Split Size:", len(dataset["train"]))
print("Original Validation Split Size:", len(dataset["validation"]))
print("Original Test Split Size:", len(dataset["test"]))

# Preprocessing function
def preprocess_dialogue(example):
    dialogue = example["dialogue"]
    soap = example["soap"]

    # Clean text
    dialogue = re.sub(r'[^A-Za-z0-9\s.,:?-]', '', dialogue).lower()
    soap = re.sub(r'[^A-Za-z0-9\s.,:?-]', '', soap).lower()

    # Add role tags
    dialogue = dialogue.replace("Doctor:", "[Doctor]:")
    dialogue = dialogue.replace("Patient:", "[Patient]:")

    return {"dialogue": dialogue, "soap": soap}

# Apply preprocessing
processed_dataset = dataset.map(preprocess_dialogue)

Original Training Split Size: 9250
Original Validation Split Size: 500
Original Test Split Size: 250


In [3]:
# Load BART tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

# Tokenization function for dialogue-to-SOAP task
def tokenize_function(example):
    # Tokenize dialogue (input)
    model_inputs = tokenizer(
        example["dialogue"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )

    # Tokenize SOAP (target)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["soap"],
            truncation=True,
            padding="max_length",
            max_length=256,  # Shorter for SOAP as it's more concise
            return_tensors="pt"
        )

    model_inputs["labels"] = labels["input_ids"].squeeze()
    return model_inputs

# Apply tokenization
tokenized_dataset = processed_dataset.map(tokenize_function, batched=True)

# Remove unnecessary columns and set format
tokenized_dataset = tokenized_dataset.remove_columns(["dialogue", "soap"])
tokenized_dataset.set_format("torch")

# Split into train and eval datasets
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]

# Debug: Verify dataset
print("Sample from train_dataset:", train_dataset[0].keys())
print("Sample 'input_ids' shape:", train_dataset[0]["input_ids"].shape)
print("Sample 'labels' shape:", train_dataset[0]["labels"].shape)
print("Sample from eval_dataset:", eval_dataset[0].keys())
print("Sample 'labels' shape (eval):", eval_dataset[0]["labels"].shape)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]



Sample from train_dataset: dict_keys(['prompt', 'messages', 'messages_nosystem', 'input_ids', 'attention_mask', 'labels'])
Sample 'input_ids' shape: torch.Size([512])
Sample 'labels' shape: torch.Size([256])
Sample from eval_dataset: dict_keys(['prompt', 'messages', 'messages_nosystem', 'input_ids', 'attention_mask', 'labels'])
Sample 'labels' shape (eval): torch.Size([256])


In [4]:
from transformers import Trainer, TrainingArguments, AutoModelForSeq2SeqLM, AutoTokenizer

# Load BART model
model= AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large", use_safetensors=True)

# Freeze BART's encoder for feature extraction
for param in model.model.encoder.parameters():
    param.requires_grad = False

# Optionally, freeze all but the last decoder layer
for layer in model.model.decoder.layers[:-1]:
    for param in layer.parameters():
        param.requires_grad = False

# Define training arguments
training_args = TrainingArguments(
    output_dir="./bart-soap-finetuned",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    per_device_train_batch_size=8,  # Increased from 4 to 8
    per_device_eval_batch_size=8,   # Increased from 4 to 8
    num_train_epochs=3,
    weight_decay=0.01,
    learning_rate=1e-4,  # Increased learning rate to match larger batch size
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    gradient_accumulation_steps=2,  # Added to effectively increase batch size further
    warmup_steps=500,  # Added warmup steps for better training stability
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

  trainer = Trainer(


In [5]:
# Train the model
trainer.train()



Epoch,Training Loss,Validation Loss
1,2.0482,1.500591
2,1.6816,1.319817
3,1.5397,1.262581


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=870, training_loss=1.799115788251504, metrics={'train_runtime': 2181.7035, 'train_samples_per_second': 12.719, 'train_steps_per_second': 0.399, 'total_flos': 3.0068576354304e+16, 'train_loss': 1.799115788251504, 'epoch': 3.0})

In [6]:
# Save the fine-tuned model
model.save_pretrained("./bart-soap-finetuned-final")
tokenizer.save_pretrained("./bart-soap-finetuned-final")

('./bart-soap-finetuned-final/tokenizer_config.json',
 './bart-soap-finetuned-final/special_tokens_map.json',
 './bart-soap-finetuned-final/vocab.json',
 './bart-soap-finetuned-final/merges.txt',
 './bart-soap-finetuned-final/added_tokens.json',
 './bart-soap-finetuned-final/tokenizer.json')

In [7]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load the fine-tuned model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("./bart-soap-finetuned-final")
tokenizer = AutoTokenizer.from_pretrained("./bart-soap-finetuned-final")

# Inference function
def generate_soap(dialogue, model, tokenizer, device="cuda:0" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    inputs = tokenizer(
        dialogue,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512
    ).to(device)
    
    # Generate SOAP report
    outputs = model.generate(
        inputs["input_ids"],
        max_length=256,
        num_beams=4,
        early_stopping=True
    )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test inference on a sample
sample_dialogue = processed_dataset["test"][2]["dialogue"]
sample_soap = processed_dataset["test"][2]["soap"]
print("\nSample Dialogue:", sample_dialogue)
print("\n\nGenerated SOAP:", generate_soap(sample_dialogue, model, tokenizer))
print("\n\nReference SOAP:", sample_soap)


Sample Dialogue: doctor: hello, weve received your results from the ultrasound we performed in april 2017. it seems that there is a single thyroid nodule present in your left lobe, measuring 1 cm in its largest diameter. we also performed a complete biochemical screening, including tests for tsh, autoantibodies, and calcitonin.
patient: hmm, what did the screening results show, doctor?
doctor: your calcitonin level was found to be slightly elevated at 40 ngml, which is above the normal range of 1-4.8 ngml. to further investigate, we performed a stimulation test with intravenous calcium.
patient: and what did the stimulation test show?
doctor: after the stimulation, your calcitonin levels peaked at 1420 ngml, which indicates that surgical treatment is necessary. as a result, you underwent a total thyroidectomy and central neck dissection on the side of the tumor.
patient: yes, i remember that. how was my recovery after the surgery?
doctor: your postoperative course was uneventful, with

In [8]:
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load the fine-tuned model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("./bart-soap-finetuned-final")
tokenizer = AutoTokenizer.from_pretrained("./bart-soap-finetuned-final")

# Inference function
def generate_soap(dialogue, model, tokenizer, device="cuda:0" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    model.eval()
    inputs = tokenizer(
        dialogue,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512
    ).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            max_length=256,
            num_beams=4,
            early_stopping=True
        )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Generate and collect results for the first 100 samples
device = "cuda:0" if torch.cuda.is_available() else "cpu"

results = []

for idx in range(100):
    sample_dialogue = processed_dataset["test"][idx]["dialogue"]
    reference_soap = processed_dataset["test"][idx]["soap"]
    
    generated_soap = generate_soap(sample_dialogue, model, tokenizer, device=device)
    
    results.append({
        "Dialogue": sample_dialogue,
        "Reference SOAP": reference_soap,
        "Generated SOAP": generated_soap
    })

# Save the results to a CSV file
df = pd.DataFrame(results)
df.to_csv("transfer-learning-results.csv", index=False)

print("Results saved to 'transfer-learning-results.csv' ✅")

Results saved to 'transfer-learning-results.csv' ✅


In [9]:
import evaluate
import pandas as pd

# Load your saved CSV
df = pd.read_csv("transfer-learning-results.csv")

# Extract generated and reference lists
generated_soap_list = df["Generated SOAP"].tolist()
reference_soap_list = df["Reference SOAP"].tolist()

# Load evaluation metrics
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")

# Compute ROUGE
rouge_result = rouge.compute(predictions=generated_soap_list, references=reference_soap_list)
print("\nROUGE Results:")
print(rouge_result)

# Compute BLEU
# BLEU expects references as a list of lists
bleu_result = bleu.compute(predictions=generated_soap_list, references=[[ref] for ref in reference_soap_list])
print("\nBLEU Results:")
print(bleu_result)

# Compute BERTScore
bertscore_result = bertscore.compute(predictions=generated_soap_list, references=reference_soap_list, lang="en")
print("\nBERTScore Results:")
print({
    "precision": sum(bertscore_result["precision"]) / len(bertscore_result["precision"]),
    "recall": sum(bertscore_result["recall"]) / len(bertscore_result["recall"]),
    "f1": sum(bertscore_result["f1"]) / len(bertscore_result["f1"])
})


ROUGE Results:
{'rouge1': np.float64(0.532743784942171), 'rouge2': np.float64(0.305151374383853), 'rougeL': np.float64(0.3494364611273041), 'rougeLsum': np.float64(0.44977932903076123)}

BLEU Results:
{'bleu': 0.20922846183832733, 'precisions': [0.7170381987229753, 0.43804213135068154, 0.29455081001472755, 0.20753104705480233], 'brevity_penalty': 0.5620764554238586, 'length_ratio': 0.6344705046197584, 'translation_length': 17854, 'reference_length': 28140}


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



BERTScore Results:
{'precision': 0.9102424734830856, 'recall': 0.8778711241483689, 'f1': 0.8936530894041061}
