# Medical Text Extraction with Small Language Models

This tutorial demonstrates how to fine-tune a small language model (a few billion parameters; the example uses the 2.7B-parameter Phi-2) to extract structured information from medical text, focusing on echocardiography reports from MIMIC-III.

## Overview

1. **Setup & Data Connection**
   - Environment setup with Unsloth for T4 GPU optimization
   - MIMIC-III data access and echo report extraction
   - BigQuery connection setup

2. **Data Preprocessing**
   - Text cleaning and standardization
   - Structured data formatting
   - Training data preparation

3. **Model Training**
   - Small LLM loading and configuration
   - LoRA setup for efficient fine-tuning
   - Training pipeline implementation

4. **Evaluation & Testing**
   - Model evaluation on test set
   - Metric computation and analysis
   - Error analysis and quality assessment

> **Note**: This notebook requires a GPU runtime. Please select Runtime > Change runtime type > GPU before proceeding.

## 1. Setup & Data Connection

First, let's set up our environment with the necessary packages. We'll use Unsloth for optimized training on T4 GPUs and install other required dependencies.

In [None]:
# Install required packages
!pip install -q torch==2.1.2 accelerate==0.27.0 bitsandbytes==0.41.3
!pip install -q unsloth
!pip install -q google-cloud-bigquery pandas numpy tqdm scikit-learn
!pip install -q transformers==4.38.2 datasets==2.16.1
!pip install -q wandb  # For experiment tracking

from IPython.display import clear_output
clear_output()
print("✅ Packages installed successfully!")

In [None]:
# Import required libraries
import os
import json
import numpy as np
import pandas as pd
from google.cloud import bigquery
from google.colab import auth
import torch
# Unsloth recommends importing it before transformers so its patches are applied
from unsloth import FastLanguageModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig
)
from datasets import Dataset
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

# Set random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Check GPU availability
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")

### Connect to BigQuery and Extract MIMIC-III Data

Next, we connect to BigQuery and extract echocardiography reports from MIMIC-III. You need credentialed access to MIMIC-III through PhysioNet, with BigQuery access granted to the Google account you authenticate with below.

In [None]:
# Authenticate with Google Cloud
auth.authenticate_user()

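# Initialize BigQuery client
# PROJECT_ID is a placeholder: set it to your own Google Cloud project, which is billed for the query
PROJECT_ID = "your-gcp-project-id"
client = bigquery.Client(project=PROJECT_ID)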

# SQL query to extract echo reports
QUERY = """
SELECT 
    n.row_id,
    n.subject_id,
    n.hadm_id,
    n.text
FROM `physionet-data.mimiciii_notes.noteevents` n
WHERE n.category = 'Echo'
    AND n.text IS NOT NULL
    AND CHAR_LENGTH(n.text) > 100
ORDER BY n.row_id  # Keeps the LIMIT sample the same across runs
LIMIT 1000  # Adjust based on your needs
"""

# Execute query and load into DataFrame
df_echo = client.query(QUERY).to_dataframe()
print(f"Retrieved {len(df_echo)} echo reports")

# Preview a sample report
sample_report = df_echo.iloc[0]['text']
print("\nSample Echo Report (truncated):")
print("="*80)
print(sample_report[:500], "...")

## 2. Data Preprocessing

Now we'll clean and prepare the echo reports for training. We'll:
1. Clean and standardize the text
2. Extract key measurements and findings
3. Create a structured JSON schema
4. Format the data for instruction fine-tuning
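
As a minimal sketch of step 3, the target schema can be written down as a dataclass before any extraction code exists. The class name `EchoMeasurements` and the helper `to_target_json` are hypothetical; the field names mirror the dictionary keys used in the next cell, and the example values are made up for illustration.

```python
import json
from dataclasses import dataclass, field, asdict
from typing import Dict, Optional


@dataclass
class EchoMeasurements:
    """Hypothetical target schema; field names follow the tutorial's dict keys."""
    ef: Optional[int] = None           # ejection fraction in percent
    lv_size: Optional[str] = None      # free-text label, e.g. "normal" or "dilated"
    rv_size: Optional[str] = None
    valve_status: Dict[str, str] = field(default_factory=dict)


def to_target_json(m: EchoMeasurements) -> str:
    """Serialize with a fixed key order so every training target has the same shape."""
    return json.dumps(asdict(m), indent=2)


# Illustrative values only, not taken from a real report
example = EchoMeasurements(ef=55, valve_status={"mitral": "mild regurgitation"})
print(to_target_json(example))
```

Fixing the schema in one place makes it easier to keep the training targets and the evaluation code in agreement as fields are added.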

In [None]:
def clean_echo_text(text):
    """Clean and standardize echo report text."""
    # Collapse runs of whitespace (including newlines) into single spaces
    text = ' '.join(text.split())
    
    # Lowercase first so the replacements below match any capitalization
    text = text.lower()
    
    # Abbreviate common terms
    text = text.replace('ejection fraction', 'ef')
    text = text.replace('left ventricular', 'lv')
    text = text.replace('right ventricular', 'rv')
    
    return text

def extract_measurements(text):
    """Extract key measurements from a cleaned (lowercased) echo report."""
    import re
    
    measurements = {
        'ef': None,
        'lv_size': None,
        'rv_size': None,
        'valve_status': {},
    }
    
    # Simple pattern matching for illustration (in practice, use more robust NLP).
    # The word boundaries keep "ef" from matching inside words such as "left";
    # for a range such as "55-60%", only the lower bound is captured.
    ef_match = re.search(r'\b(?:lv)?ef\b[:\s]+(\d{1,2})\s*[-%]', text)
    if ef_match:
        measurements['ef'] = int(ef_match.group(1))
    
    return measurements

# Process the echo reports
processed_data = []
for _, row in tqdm(df_echo.iterrows(), total=len(df_echo)):
    clean_text = clean_echo_text(row['text'])
    measurements = extract_measurements(clean_text)
    
    processed_data.append({
        'report_id': row['row_id'],
        'text': clean_text,
        'measurements': measurements
    })

# Convert to DataFrame
df_processed = pd.DataFrame(processed_data)
print(f"\nProcessed {len(df_processed)} reports")
print("\nSample processed data:")
# default=str guards against NumPy integer types that json cannot serialize
print(json.dumps(df_processed.iloc[0].to_dict(), indent=2, default=str))
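
Echo reports often state the ejection fraction as a range or an inequality rather than a single number. The sketch below shows one way to widen the pattern under that assumption. The function name `parse_ef`, the list of phrasings it targets, and the choice to return the floored midpoint of a range are all illustrative, not part of the pipeline above.

```python
import re
from typing import Optional

# Assumed phrasing variants: "ef 55%", "ef: 55-60%", "lvef >55%", "ef of 30 to 35 %".
EF_PATTERN = re.compile(
    r"\b(?:lv)?ef\b[^\d%]{0,15}?(\d{1,2})(?:\s*(?:-|to)\s*(\d{1,2}))?\s*%"
)


def parse_ef(text: str) -> Optional[int]:
    """Return EF in percent; for a range, return the floored midpoint."""
    match = EF_PATTERN.search(text.lower())
    if match is None:
        return None
    low = int(match.group(1))
    high = int(match.group(2)) if match.group(2) else low
    return (low + high) // 2


for snippet in ["lv ef: 55%", "ef 55-60%", "lvef >55%", "the left ventricle is normal"]:
    print(repr(snippet), "->", parse_ef(snippet))
```

Whether a range should map to its midpoint, its lower bound, or a separate pair of fields is a labeling decision that should be settled before training targets are generated.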

In [None]:
# Format data for instruction fine-tuning
def create_instruction_format(row):
    """Create instruction-input-output format for training."""
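    # Adjacent string literals are joined, which avoids embedding a newline and indentation in the prompt
    instruction = (
        "Extract key measurements and findings from the following echocardiogram report. "
        "Return the results in JSON format including: EF, LV size, RV size, and valve status."
    )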
    
    input_text = row['text']
    
    # Format output as a clean JSON string
    output = json.dumps(row['measurements'], indent=2)
    
    return {
        'instruction': instruction,
        'input': input_text,
        'output': output
    }

# Create training data
training_data = [create_instruction_format(row) for _, row in df_processed.iterrows()]

# Split into train, validation, and test sets
train_data, temp_data = train_test_split(training_data, test_size=0.3, random_state=SEED)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=SEED)

print(f"Dataset splits:")
print(f"Training: {len(train_data)}")
print(f"Validation: {len(val_data)}")
print(f"Test: {len(test_data)}")

# Create HuggingFace datasets
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)
test_dataset = Dataset.from_list(test_data)

# Show a sample for reference
print("\nSample training instance:")
print(json.dumps(train_data[0], indent=2))
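
The split above is made per report. A patient can have several echo reports, so a per-report split may place reports from the same `subject_id` in both the training and test sets. The sketch below shows a patient-level split under that assumption. `split_by_group` is a hypothetical helper, the fractions apply to groups rather than rows, and it presumes the `subject_id` values are carried alongside the processed rows, which the cell above does not do.

```python
import random
from typing import Dict, Hashable, List, Sequence


def split_by_group(
    groups: Sequence[Hashable],
    fractions=(0.7, 0.15, 0.15),
    seed: int = 42,
) -> Dict[str, List[int]]:
    """Assign row indices to train/val/test so that no group spans two splits."""
    unique = sorted(set(groups))
    random.Random(seed).shuffle(unique)
    n_train = round(len(unique) * fractions[0])
    n_val = round(len(unique) * fractions[1])

    # Map each group to exactly one split; the remainder goes to test
    split_of = {}
    for i, group in enumerate(unique):
        if i < n_train:
            split_of[group] = "train"
        elif i < n_train + n_val:
            split_of[group] = "val"
        else:
            split_of[group] = "test"

    out: Dict[str, List[int]] = {"train": [], "val": [], "test": []}
    for idx, group in enumerate(groups):
        out[split_of[group]].append(idx)
    return out


# Toy subject IDs, one per report (illustrative only)
subject_ids = [1, 1, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10]
splits = split_by_group(subject_ids)
print({name: len(idx) for name, idx in splits.items()})
```

The returned index lists can be used to select rows from `training_data` in place of the two `train_test_split` calls.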

## 3. Model Training

Now we'll set up and train our model using Unsloth's optimized training pipeline. We'll use a small language model (Phi-2) and configure it for efficient fine-tuning with LoRA.
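
To make "efficient fine-tuning" concrete: LoRA freezes each adapted weight matrix and trains two low-rank factors beside it, so the number of added parameters depends only on the matrix dimensions and the rank `r`. The sketch below works through that arithmetic. The helper names are hypothetical, and the 2560-wide square projection is an illustrative dimension rather than a value read from the loaded model.

```python
def lora_added_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) next to each adapted weight matrix."""
    return r * (d_in + d_out)


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Standard LoRA scales the low-rank update by lora_alpha / r."""
    return lora_alpha / r


# Illustrative dimensions: one square 2560 x 2560 projection adapted at rank 64
full = 2560 * 2560
added = lora_added_params(2560, 2560, r=64)
print(f"Frozen weights: {full:,}")    # 6,553,600
print(f"Trainable LoRA: {added:,}")   # 327,680
print(f"Update scaling: {lora_scaling(16, 64)}")  # 0.25
```

With `r=64` and `lora_alpha=16`, the update is scaled by 0.25. Raising `lora_alpha` or lowering `r` increases the effective step taken by the adapter for the same learning rate.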

In [None]:
# Model configuration
MODEL_NAME = "microsoft/phi-2"  # 2.7B parameter model
OUTPUT_DIR = "medical_extraction_model"

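# 4-bit NF4 quantization settings, shown for reference only: bnb_config is not
# passed below, because load_in_4bit=True makes Unsloth set up 4-bit loading itself
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load model and tokenizer with Unsloth optimization
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=2048,  # Unsloth's keyword is max_seq_length
    dtype=None,  # None lets Unsloth pick the dtype (float16 on a T4)
    load_in_4bit=True
)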

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    fp16=True,
    logging_steps=10,
    eval_steps=50,
    save_steps=100,
    warmup_steps=100,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to="wandb"  # Enable W&B logging
)

# Configure LoRA
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha=16,
    lora_dropout=0.1
)

print("Model and training configuration complete!")
print(f"Model parameters: {model.num_parameters():,}")
print(f"Training on device: {model.device}")
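
The step-based settings above (`warmup_steps`, `eval_steps`, `save_steps`) only make sense relative to the total number of optimizer steps, so it helps to estimate that number first. The sketch below approximates how the Hugging Face `Trainer` derives it for a single GPU. The helper name is hypothetical, the exact count can differ slightly between library versions, and the 700 examples assume the 1000-report query with the 70% training split.

```python
import math


def estimate_optimizer_steps(
    n_examples: int,
    per_device_batch_size: int,
    grad_accum_steps: int,
    epochs: int,
    n_devices: int = 1,
) -> int:
    """Approximate the number of optimizer updates for a full training run."""
    batches_per_epoch = math.ceil(n_examples / (per_device_batch_size * n_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return math.ceil(epochs * updates_per_epoch)


total_steps = estimate_optimizer_steps(
    n_examples=700, per_device_batch_size=4, grad_accum_steps=4, epochs=3
)
print(total_steps)  # 129
print(f"Warmup share with warmup_steps=100: {100 / total_steps:.0%}")  # 78%
```

Under these assumptions, about three quarters of training would be spent in warmup and only two evaluations would run. A smaller `warmup_steps` (or a `warmup_ratio`) and a smaller `eval_steps` may suit a dataset of this size better.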

In [None]:
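# Training function
def train_model():
    # Initialize Weights & Biases
    import wandb
    # trl provides SFTTrainer, the trainer used in Unsloth's examples;
    # install it with pip if it is missing
    from trl import SFTTrainer
    wandb.init(project="medical-extraction", name="phi2-echo-extraction")
    
    def formatting_func(batch):
        # Join prompt and target in the same layout used at inference time,
        # ending with EOS so the model learns where the JSON stops
        return [
            f"{instruction}\n\n{input_text}\n\n{output}{tokenizer.eos_token}"
            for instruction, input_text, output in zip(
                batch['instruction'], batch['input'], batch['output']
            )
        ]
    
    # Create the trainer (argument names follow older trl releases
    # and may differ in newer ones)
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        formatting_func=formatting_func,
        max_seq_length=2048
    )
    
    # Train the model
    trainer.train()
    
    # Save the final model (LoRA adapter weights)
    trainer.save_model(OUTPUT_DIR)
    
    return trainer

# Start training
trainer = train_model()
print("Training complete!")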

## 4. Evaluation & Testing

Now we'll evaluate our model's performance on the test set and analyze its accuracy in extracting medical information.

In [None]:
def evaluate_extraction(true_json, pred_json):
    """Evaluate the accuracy of extracted fields."""
    metrics = {
        'ef_accuracy': 0,
        'size_accuracy': 0,
        'valve_accuracy': 0,
        'overall_accuracy': 0
    }
    
    try:
        true_data = json.loads(true_json) if isinstance(true_json, str) else true_json
        pred_data = json.loads(pred_json) if isinstance(pred_json, str) else pred_json
        
        # Check EF accuracy
        if true_data['ef'] == pred_data['ef']:
            metrics['ef_accuracy'] = 1
            
        # Check size measurements
        size_correct = (true_data['lv_size'] == pred_data['lv_size'] and
                       true_data['rv_size'] == pred_data['rv_size'])
        metrics['size_accuracy'] = 1 if size_correct else 0
        
        # Check valve status
        valve_correct = true_data['valve_status'] == pred_data['valve_status']
        metrics['valve_accuracy'] = 1 if valve_correct else 0
        
        # Calculate overall accuracy
        metrics['overall_accuracy'] = sum([
            metrics['ef_accuracy'],
            metrics['size_accuracy'],
            metrics['valve_accuracy']
        ]) / 3
        
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        print(f"Error in evaluation: {e}")
    
    return metrics

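# Generate predictions on test set
def generate_predictions(model, test_samples):
    predictions = []
    model.eval()
    
    for sample in tqdm(test_samples, desc="Generating predictions"):
        input_text = f"{sample['instruction']}\n\n{sample['input']}"
        
        # Tokenize, leaving room for 512 generated tokens in the 2048-token window
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True,
                           max_length=2048 - 512)
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,
                do_sample=False  # Greedy decoding; temperature only applies when sampling
            )
        
        # Decode only the newly generated tokens; outputs[0] begins with the prompt
        new_tokens = outputs[0][input_ids.shape[1]:]
        prediction = tokenizer.decode(new_tokens, skip_special_tokens=True)
        predictions.append(prediction.strip())
    
    return predictions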

# Evaluate model on test set
predictions = generate_predictions(model, test_data)

# Calculate metrics
all_metrics = []
for test_sample, pred in zip(test_data, predictions):
    metrics = evaluate_extraction(test_sample['output'], pred)
    all_metrics.append(metrics)

# Calculate average metrics
avg_metrics = {
    key: np.mean([m[key] for m in all_metrics])
    for key in all_metrics[0].keys()
}

print("\nTest Set Metrics:")
for metric, value in avg_metrics.items():
    print(f"{metric}: {value:.2%}")

# Show sample predictions
print("\nSample Predictions:")
for i in range(min(3, len(test_data))):
    print(f"\nExample {i+1}:")
    print("Input:", test_data[i]['input'][:200], "...")
    print("\nTrue Output:", test_data[i]['output'])
    print("\nPredicted Output:", predictions[i])
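
Generated text rarely consists of a JSON object and nothing else. The model may add a preamble, trailing commentary, or a second object, and any of these makes `json.loads` fail on the raw string. The sketch below shows one way to pull out the first decodable object before scoring. The helper name `extract_first_json_object` is hypothetical.

```python
import json
from typing import Optional


def extract_first_json_object(text: str) -> Optional[dict]:
    """Return the first decodable JSON object embedded in text, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores what follows
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        start = text.find("{", start + 1)
    return None


raw = 'Here is the result:\n{"ef": 55, "valve_status": {}}\nDone.'
print(extract_first_json_object(raw))
```

Because `evaluate_extraction` accepts an already-parsed object, the result can be passed to it directly. A `None` result raises a `TypeError` inside the function, which is caught and scored as zero.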

## Conclusion

This tutorial demonstrated how to:
1. Set up a medical text extraction pipeline using a small language model
2. Process and prepare MIMIC-III echocardiography data
3. Fine-tune the model efficiently using LoRA and Unsloth
4. Evaluate the model's performance on structured information extraction

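The trained model can be used to extract key measurements and findings from echocardiography reports, helping to standardize and structure medical information. Because the training labels in this tutorial come from simple pattern matching, its output is best treated as a baseline until it has been checked against clinician-annotated reports.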

### Next Steps

1. Experiment with different model architectures (e.g., Llama-2-7B, MPT-7B)
2. Improve the extraction patterns for better accuracy
3. Add more structured fields to extract
4. Implement clinical validation metrics
5. Deploy the model in a clinical setting with proper validation