# T5 Medical Text Simplifier Fine-tuning

This notebook fine-tunes a T5 model for medical text simplification.
Upload a training CSV with two columns: `medical` (the original clinical text) and `simple` (its plain-language version).
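
Before training, it can help to confirm that the CSV header really contains both columns. The following is a minimal sketch, assuming a hypothetical helper `missing_columns` (not part of this notebook) that accepts any iterable of column names, for example `df.columns`:

```python
REQUIRED_COLUMNS = ("medical", "simple")  # column names this notebook expects


def missing_columns(columns, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from `columns`."""
    present = set(columns)
    return [name for name in required if name not in present]


# Example: a header that uses the wrong name for the target column.
print(missing_columns(["medical", "simplified"]))
```

An empty list means both columns are present and the rest of the notebook can use the file as-is.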

In [None]:
# Install required packages (scikit-learn is needed for train_test_split)
!pip install -q transformers datasets torch accelerate wandb evaluate rouge-score sentencepiece scikit-learn
!pip install -q huggingface_hub

In [None]:
# Import libraries
import pandas as pd
import torch
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
import numpy as np
import evaluate
import wandb
import os
from huggingface_hub import notebook_login, HfFolder
import warnings
warnings.filterwarnings('ignore')

## Configuration

In [None]:
# Configuration
MODEL_NAME = "t5-base"  # Base T5 model
OUTPUT_DIR = "./t5-medical-simplifier"  # Local output directory
HF_MODEL_NAME = "your_username/t5-medical-simplifier"  # Replace with your HF username
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128
TRAIN_BATCH_SIZE = 8
EVAL_BATCH_SIZE = 8
LEARNING_RATE = 3e-4
NUM_EPOCHS = 3
WARMUP_STEPS = 500
EVAL_STEPS = 500
SAVE_STEPS = 500
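
The step-based settings above only take effect if training runs for at least that many optimizer steps. A rough sketch of the arithmetic, assuming single-device training without gradient accumulation (`training_steps` is a hypothetical helper, not part of this notebook):

```python
import math


def training_steps(num_examples, batch_size, num_epochs):
    """Return (steps_per_epoch, total_steps) for single-device training."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * num_epochs


# Example: 21 training rows, roughly what the 24-row fallback sample
# leaves after a 10% validation split.
steps_per_epoch, total_steps = training_steps(21, batch_size=8, num_epochs=3)
print(steps_per_epoch, total_steps)
if total_steps < 500:
    print("WARMUP_STEPS/EVAL_STEPS/SAVE_STEPS of 500 exceed the total step count")
```

With a dataset that small, warmup never finishes and no step-based evaluation or checkpoint is triggered, so lower the three values or train on more data.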

## Authentication

In [None]:
# Login to Hugging Face (optional - for saving to hub)
notebook_login()

In [None]:
# Login to Weights & Biases (optional - for experiment tracking)
from getpass import getpass
try:
    wandb_key = getpass("Enter your WandB API key (or press Enter to skip): ")
    if wandb_key:
        wandb.login(key=wandb_key)
        USE_WANDB = True
    else:
        USE_WANDB = False
        print("Skipping WandB logging")
except Exception:
    USE_WANDB = False
    print("WandB setup skipped")

## Load and Prepare Dataset

In [None]:
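# Mount Google Drive to access your dataset (only available inside Google Colab)
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    print("Not running in Colab; skipping Google Drive mount")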

In [None]:
# Load your training data from a CSV file
# (if the file is missing, a small built-in sample dataset is used instead)
DATA_PATH = "/content/drive/MyDrive/medical_simplification_data.csv"  # Update this path

# Load the dataset
try:
    df = pd.read_csv(DATA_PATH)
    print(f"Loaded dataset with {len(df)} examples")
    print(f"Columns: {df.columns.tolist()}")
    print("\nFirst 3 examples:")
    print(df.head(3))
except FileNotFoundError:
    print(f"File not found: {DATA_PATH}")
    print("Please upload your CSV file with 'medical' and 'simple' columns")

    # Create a comprehensive sample dataset for testing with complex medical texts
    sample_data = {
        'medical': [
            'The patient presents with acute exacerbation of chronic obstructive pulmonary disease with dyspnea and productive cough',
            'Minoxidil is a potent direct-acting peripheral vasodilator that relaxes vascular smooth muscle',
            'The patient exhibits signs of acute myocardial infarction with ST-segment elevation in leads II, III, and aVF',
            'Chronic kidney disease stage 4 with estimated glomerular filtration rate of 25 ml/min/1.73m²',
            'The patient underwent laparoscopic cholecystectomy for symptomatic cholelithiasis',
            'Post-operative complications include nosocomial pneumonia and acute respiratory distress syndrome',
            'Hypertensive emergency with blood pressure 220/120 mmHg and signs of end-organ damage',
            'The patient has diabetes mellitus type 2 with diabetic nephropathy and proliferative retinopathy',
            'Acute cerebrovascular accident in the distribution of the middle cerebral artery with contralateral hemiparesis',
            'The electrocardiogram shows atrial fibrillation with rapid ventricular response and occasional premature ventricular contractions',
            'Gastroesophageal reflux disease with Barrett\'s esophagus and moderate dysplasia requiring endoscopic surveillance',
            'The patient underwent percutaneous coronary intervention with drug-eluting stent placement in the left anterior descending artery',
            'Chronic heart failure with reduced ejection fraction secondary to ischemic cardiomyopathy',
            'The computed tomography scan reveals bilateral pulmonary embolism with right heart strain',
            'Metformin is a biguanide antidiabetic agent that decreases hepatic glucose production',
            'The patient developed sepsis secondary to urinary tract infection with multiorgan dysfunction',
            'Idiopathic pulmonary fibrosis with usual interstitial pneumonia pattern on high-resolution computed tomography',
            'The patient exhibits symptoms of major depressive disorder with anhedonia and psychomotor retardation',
            'Acute pancreatitis with elevated serum lipase and characteristic findings on contrast-enhanced computed tomography',
            'The patient has osteoarthritis with joint space narrowing and osteophyte formation on radiographic imaging',
            'Subcutaneous emphysema and pneumomediastinum following esophageal perforation',
            'The patient underwent total hip arthroplasty for avascular necrosis of the femoral head',
            'Acute appendicitis with perforation and localized peritonitis requiring emergent appendectomy',
            'The medication causes dose-dependent hepatotoxicity with elevated transaminases and hyperbilirubinemia'
        ],
        'simple': [
            'The patient has a lung disease flare-up that makes it hard to breathe and causes coughing with mucus',
            'Minoxidil is a medication that widens blood vessels and is used to treat hair loss and high blood pressure',
            'The patient is having a heart attack with specific changes showing on the heart monitor',
            'The patient has advanced kidney disease where the kidneys only work at 25% of normal function',
            'The patient had surgery to remove their gallbladder using small cuts and a camera',
            'After surgery, the patient got a lung infection and had trouble breathing',
            'The patient has dangerously high blood pressure that is damaging other organs',
            'The patient has diabetes that has damaged their kidneys and eyes',
            'The patient had a stroke that affects movement on one side of their body',
            'The heart monitor shows an irregular heartbeat that is too fast with extra heartbeats',
            'The patient has severe acid reflux that has caused changes in the food pipe that need monitoring',
            'The patient had a procedure to open a blocked heart artery and place a small tube to keep it open',
            'The patient has a weak heart muscle that doesn\'t pump blood well due to previous heart damage',
            'The CT scan shows blood clots in both lungs that are straining the heart',
            'Metformin is a diabetes medication that helps reduce the amount of sugar the liver makes',
            'The patient has a serious infection from a urinary tract infection that is affecting multiple organs',
            'The patient has a lung disease where scar tissue makes it hard to breathe',
            'The patient has depression with loss of interest in activities and slowed movements',
            'The patient has inflammation of the pancreas with high enzyme levels and specific CT scan findings',
            'The patient has joint wear and tear with bone changes visible on X-rays',
            'Air has leaked under the skin and around the heart due to a tear in the food pipe',
            'The patient had hip replacement surgery because the hip bone died from lack of blood supply',
            'The patient had a burst appendix with infection that required emergency surgery to remove it',
            'This medication can damage the liver depending on the dose, causing high liver enzymes and yellowing'
        ]
    }
    df = pd.DataFrame(sample_data)
    print(f"Using comprehensive sample dataset with {len(df)} complex medical examples for demonstration")

In [None]:
# Clean and prepare data
df = df.dropna(subset=['medical', 'simple'])
df = df[df['medical'].str.strip() != '']
df = df[df['simple'].str.strip() != '']

# Add prefix to input (T5 expects task prefix)
df['input_text'] = "simplify: " + df['medical'].astype(str)
df['target_text'] = df['simple'].astype(str)

print(f"Cleaned dataset size: {len(df)}")
print("\nSample input-output pair:")
print(f"Input: {df['input_text'].iloc[0]}")
print(f"Target: {df['target_text'].iloc[0]}")

In [None]:
# Split data
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

print(f"Training examples: {len(train_df)}")
print(f"Validation examples: {len(val_df)}")

# Convert to Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df[['input_text', 'target_text']])
val_dataset = Dataset.from_pandas(val_df[['input_text', 'target_text']])

dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

print("\nDataset created successfully!")

## Load Model and Tokenizer

In [None]:
# Load tokenizer and model
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

print(f"Model loaded: {MODEL_NAME}")
print(f"Model parameters: {model.num_parameters():,}")

In [None]:
# Tokenization function
def preprocess_function(examples):
    inputs = examples['input_text']
    targets = examples['target_text']

    # Tokenize inputs; padding is left to the data collator, which pads
    # each batch only to its own longest sequence
    model_inputs = tokenizer(
        inputs,
        max_length=MAX_INPUT_LENGTH,
        truncation=True
    )

    # Tokenize targets; padding them here would add pad-token ids to the
    # labels, and the loss would treat those as real targets
    labels = tokenizer(
        text_target=targets,
        max_length=MAX_TARGET_LENGTH,
        truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply tokenization
tokenized_datasets = dataset.map(preprocess_function, batched=True)
print("Tokenization completed!")
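
Label padding matters because padded positions should not contribute to the loss. `DataCollatorForSeq2Seq` pads labels with `-100` by default, which is the index PyTorch's cross-entropy loss ignores. The plain-Python sketch below illustrates the idea only; `pad_labels` is a hypothetical helper and the token ids are made up:

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def pad_labels(label_ids, max_len, ignore_index=IGNORE_INDEX):
    """Right-pad each label sequence to `max_len` with `ignore_index`."""
    return [ids + [ignore_index] * (max_len - len(ids)) for ids in label_ids]


# Two target sequences of different lengths (made-up token ids).
batch = [[71, 1868, 1], [71, 1]]
print(pad_labels(batch, max_len=4))
```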

## Setup Training

In [None]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Evaluation metrics
rouge = evaluate.load("rouge")

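def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Replace -100 (padding / ignore index) with the pad token id so that
    # both arrays can be decoded
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)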

    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )

    return {
        "rouge1": result["rouge1"],
        "rouge2": result["rouge2"],
        "rougeL": result["rougeL"],
    }
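
To make the reported numbers easier to interpret, here is a simplified sketch of what ROUGE-1 measures: the F1 score of unigram overlap between a prediction and its reference. It assumes lowercased, whitespace-split tokens and no stemming, so its values will not match the `rouge` metric exactly; `rouge1_f1` is a hypothetical helper for illustration only.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 on lowercased, whitespace-split tokens."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(rouge1_f1("the patient had a stroke", "the patient had a heart attack"))
```

A high score only shows word overlap with the reference; it does not confirm that a simplification is medically accurate.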

In [None]:
# Training arguments
# The Seq2Seq variants are required for predict_with_generate and generation_max_length
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    warmup_steps=WARMUP_STEPS,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=100,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_rouge1",
    greater_is_better=True,
    learning_rate=LEARNING_RATE,
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    fp16=torch.cuda.is_available(),  # T5 can overflow in fp16; if the loss becomes NaN, try bf16 instead
    push_to_hub=False,  # Set to True if you want to push to HF Hub
    report_to="wandb" if USE_WANDB else "none",  # "none" disables all logging integrations
    run_name="t5-medical-simplifier",
)

print("Training arguments configured!")

In [None]:
# Initialize trainer (Seq2SeqTrainer generates text during evaluation)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Trainer initialized!")

## Training

In [None]:
# Start training
print("Starting training...")
trainer.train()

In [None]:
# Evaluate the model
print("\nEvaluating model...")
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

## Test the Model

In [None]:
# Test function
def simplify_medical_text(text, max_length=MAX_TARGET_LENGTH):
    input_text = f"simplify: {text}"
    # Move the inputs to the model's device (the GPU after training on one)
    inputs = tokenizer(
        input_text, return_tensors="pt", max_length=MAX_INPUT_LENGTH, truncation=True
    ).to(model.device)

    model.eval()
    with torch.no_grad():
        # Beam search without sampling, so no temperature is set
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            early_stopping=True
        )

    simplified = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return simplified

# Test examples
test_examples = [
    "minoxidil",
    "hypertension",
    "myocardial infarction",
    "The patient presents with acute exacerbation of chronic obstructive pulmonary disease",
]

print("Testing the trained model:\n")
for example in test_examples:
    simplified = simplify_medical_text(example)
    print(f"Input: {example}")
    print(f"Simplified: {simplified}\n")

## Save Model

In [None]:
# Save the model and tokenizer locally
model_save_path = "./final_model"
tokenizer_save_path = "./final_tokenizer"

# Save model
model.save_pretrained(model_save_path, safe_serialization=True)
print(f"Model saved to {model_save_path}")

# Save tokenizer
tokenizer.save_pretrained(tokenizer_save_path)
print(f"Tokenizer saved to {tokenizer_save_path}")

# Create directory structure for your backend
backend_model_path = "./anishbasnet/t5-base-ft-medical-simplifier"
os.makedirs(f"{backend_model_path}/model", exist_ok=True)
os.makedirs(f"{backend_model_path}/tokenizer", exist_ok=True)

# Save in backend format
model.save_pretrained(f"{backend_model_path}/model", safe_serialization=True)
tokenizer.save_pretrained(f"{backend_model_path}/tokenizer")

print(f"\nModel saved in backend format at: {backend_model_path}")
print("You can now download this folder and place it in your backend/models/ directory")

In [None]:
# Compress for easy download
import shutil

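# Create zip file; base_dir keeps the top-level "anishbasnet/" folder inside
# the archive so that extracting it into backend/models/ gives the expected layout
shutil.make_archive(
    "t5-base-ft-medical-simplifier",
    'zip',
    root_dir=".",
    base_dir="anishbasnet"
)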

print("Model compressed as 't5-base-ft-medical-simplifier.zip'")
print("Download this file and extract it to your backend/models/ directory")

## Optional: Push to Hugging Face Hub

In [None]:
# Uncomment and run this cell if you want to push to Hugging Face Hub
# Make sure you're logged in and have set HF_MODEL_NAME

# model.push_to_hub(HF_MODEL_NAME)
# tokenizer.push_to_hub(HF_MODEL_NAME)
# print(f"Model pushed to Hugging Face Hub: {HF_MODEL_NAME}")

## Summary

### Files to Transfer:
1. **Download the compressed file**: `t5-base-ft-medical-simplifier.zip`
2. **Extract to your backend**: `backend/models/anishbasnet/t5-base-ft-medical-simplifier/`
3. **Structure should be**:
   ```
   backend/models/anishbasnet/t5-base-ft-medical-simplifier/
   ├── model/
   │   ├── config.json
   │   ├── generation_config.json
   │   └── model.safetensors
   └── tokenizer/
       ├── special_tokens_map.json
       ├── spiece.model
       ├── tokenizer_config.json
       └── tokenizer.json
   ```
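
A quick way to confirm the transfer worked is to check for the key files in that layout. This is a minimal sketch, assuming the folder structure shown above; `missing_model_files` and `EXPECTED_FILES` are hypothetical names, and only a subset of the files is checked:

```python
from pathlib import Path

# Subset of the files from the layout above that the backend needs to load the model.
EXPECTED_FILES = {
    "model": ["config.json", "model.safetensors"],
    "tokenizer": ["spiece.model", "tokenizer_config.json"],
}


def missing_model_files(root, expected=EXPECTED_FILES):
    """Return relative paths under `root` that are expected but absent."""
    root = Path(root)
    return [
        f"{subdir}/{name}"
        for subdir, names in expected.items()
        for name in names
        if not (root / subdir / name).is_file()
    ]


# Example call; run it from the directory that contains backend/.
print(missing_model_files("backend/models/anishbasnet/t5-base-ft-medical-simplifier"))
```

An empty list means every checked file is in place.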

### After Transfer:
1. Restart your backend server
2. Test with: `python test_minoxidil.py`
3. The backend should now load the fine-tuned model; review its explanations for accuracy before relying on them

### Training Tips:
- Use a larger dataset for better results (combine all 5 datasets from Dataset_creation.ipynb)
- Increase epochs if you have more data
- Monitor the ROUGE scores during training
- Add more diverse medical examples including drug names like minoxidil