# 🏥 Medical Text Classification - BiomedBERT Training (Google Colab)

**Purpose:** Train BiomedBERT model for medical text classification with GPU acceleration

**Model:** `microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext`

**Expected Accuracy:** ~99%

**Training Time:** ~10-15 minutes on Colab GPU

---

## 📋 Instructions:
1. **Enable GPU:** Runtime → Change runtime type → GPU (T4 or better)
2. **Upload Dataset:** Upload `medical_texts.csv` to Colab files
3. **Run All Cells:** Runtime → Run all
4. **Download Model:** After training, download the `biomedbert_model/` folder
5. **Deploy:** Place the model folder in your project's `models/` directory

---

## ⚙️ Hyperparameters (Matching Original Notebook):
- **Learning Rate:** 3e-5 (0.00003)
- **Batch Size:** 16 (train), 8 (eval)
- **Epochs:** 10
- **Max Sequence Length:** 512
- **Optimizer:** AdamW
- **Train/Test Split:** 80/20 (stratified)
- **Keyword Masking:** Enabled for training texts only (removes disease-name words to prevent data leakage)
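Several cells below hardcode these values separately (batch size 16, for example, appears in both the `DataLoader` call and a `print`). A minimal sketch, assuming you prefer a single source of truth, collects them in a frozen dataclass. `TrainingConfig` and its field names are hypothetical and not part of the original notebook; `random_state=42` mirrors the seed the split cell passes to `train_test_split`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfig:
    """Hypothetical container for the hyperparameters listed above."""
    model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
    learning_rate: float = 3e-5
    train_batch_size: int = 16
    eval_batch_size: int = 8
    num_epochs: int = 10
    max_seq_length: int = 512
    test_size: float = 0.2   # 80/20 train/test split
    random_state: int = 42   # seed used by train_test_split in this notebook
    num_classes: int = 5


config = TrainingConfig()
print(config.learning_rate, config.train_batch_size, config.num_epochs)
```

`frozen=True` makes accidental reassignment (e.g. `config.num_epochs = 3`) raise an error instead of silently changing one cell's value.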

## 1️⃣ Setup & Installation

In [None]:
# Install required packages
!pip install -q transformers torch pandas scikit-learn numpy

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import re
import warnings
warnings.filterwarnings('ignore')

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2️⃣ Upload Dataset

**Upload your `medical_texts.csv` file using the file upload button below:**

In [None]:
from google.colab import files

print("📁 Please upload your medical_texts.csv file:")
uploaded = files.upload()

# Verify upload
if 'medical_texts.csv' in uploaded:
    print("✅ File uploaded successfully!")
else:
    print("❌ Error: medical_texts.csv not found. Please upload the file.")

## 3️⃣ Load & Prepare Data (Exact Notebook Approach)

In [None]:
# Load CSV file
df = pd.read_csv('medical_texts.csv')
print(f"📊 Loaded {len(df)} records from CSV")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nFirst 3 rows:")
print(df.head(3))

In [None]:
# Define focus area mapping (EXACTLY from original notebook)
focus_area_map = {
    'Cancers': ['Breast Cancer', 'Prostate Cancer', 'Skin Cancer', 'Colorectal Cancer', 'Lung Cancer', 'Leukemia'],
    'Cardiovascular Diseases': ['Stroke', 'Heart Failure', 'Heart Attack', 'High Blood Cholesterol', 'High Blood Pressure'],
    'Metabolic & Endocrine Disorders': ['Causes of Diabetes', 'Diabetes', 'Diabetic Retinopathy', 'Hemochromatosis', 'Kidney Disease'],
    'Neurological & Cognitive Disorders': ["Alzheimer's Disease", "Parkinson's Disease", 'Balance Problems'],
    'Other Age-Related & Immune Disorders': ['Shingles', 'Osteoporosis', 'Age-related Macular Degeneration', 'Psoriasis', 'Gum (Periodontal) Disease', 'Dry Mouth']
}

# Create reverse mapping (condition -> focus_group)
condition_to_focus_area = {
    condition: focus_area
    for focus_area, conditions in focus_area_map.items()
    for condition in conditions
}

# Map focus_area to focus_group
df['focus_group'] = df['focus_area'].map(condition_to_focus_area)

# Drop rows whose condition is not one of the 25 mapped focus areas, or whose answer is missing
df = df.dropna(subset=['focus_group', 'answer'])

print(f"✅ After filtering to 25 focus areas: {len(df)} records")
print(f"\n📊 Focus group distribution:")
print(df['focus_group'].value_counts())
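The inversion and filtering can be checked on a toy frame. This sketch uses a two-group excerpt of `focus_area_map` and made-up rows (the unmapped condition `Glaucoma` is purely illustrative) to show that `.map()` yields `NaN` for any condition missing from the dictionary, that the lookup is an exact, case-sensitive string match (`stroke` does not match `Stroke`), and that `dropna` then removes those rows along with rows lacking an answer.

```python
import pandas as pd

# Two-group excerpt of focus_area_map, for illustration only
focus_area_map = {
    "Cancers": ["Breast Cancer", "Leukemia"],
    "Cardiovascular Diseases": ["Stroke"],
}

# Invert the mapping: each condition points to its focus group
condition_to_focus_area = {
    condition: group
    for group, conditions in focus_area_map.items()
    for condition in conditions
}

# Made-up rows: one unmapped condition, one case mismatch, one missing answer
df = pd.DataFrame({
    "focus_area": ["Leukemia", "Stroke", "stroke", "Glaucoma", "Breast Cancer"],
    "answer": ["text a", "text b", "text c", "text d", None],
})

# Dictionary lookups are exact; anything not found becomes NaN
df["focus_group"] = df["focus_area"].map(condition_to_focus_area)
df = df.dropna(subset=["focus_group", "answer"])
print(df[["focus_area", "focus_group"]])
```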

In [None]:
# Remove duplicates based on answer column (like original notebook)
df = df.drop_duplicates(subset='answer')
print(f"✅ After deduplication: {len(df)} records")
print(f"   Expected: ~628 records (matching original notebook)")
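`drop_duplicates(subset='answer')` keeps the first row for each distinct answer by default (`keep='first'`), so if the same text is listed under two conditions, only the earlier row survives. A toy illustration with made-up rows:

```python
import pandas as pd

df = pd.DataFrame({
    "focus_area": ["Diabetes", "Causes of Diabetes", "Stroke"],
    "answer": ["same text", "same text", "other text"],
})

# keep="first" (the default): the first row carrying each answer is kept
deduped = df.drop_duplicates(subset="answer")
print(deduped["focus_area"].tolist())
```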

In [None]:
# Encode focus groups to numeric labels (EXACTLY from original notebook)
focus_map = {
    'Neurological & Cognitive Disorders': 0,
    'Cancers': 1,
    'Cardiovascular Diseases': 2,
    'Metabolic & Endocrine Disorders': 3,
    'Other Age-Related & Immune Disorders': 4
}

df['label'] = df['focus_group'].map(focus_map)

print(f"✅ Label encoding complete")
print(f"\n📊 Label distribution:")
print(df['label'].value_counts().sort_index())
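The evaluation cell later hardcodes shortened `label_names` in the same 0 to 4 order as `focus_map`. A small sketch, assuming you would rather derive the names from the mapping so the two cannot drift apart; it also checks that the labels are exactly 0..K-1, so every target falls in the `[0, num_classes)` range that `nn.CrossEntropyLoss` expects for class-index targets.

```python
focus_map = {
    'Neurological & Cognitive Disorders': 0,
    'Cancers': 1,
    'Cardiovascular Diseases': 2,
    'Metabolic & Endocrine Disorders': 3,
    'Other Age-Related & Immune Disorders': 4,
}

# Sort group names by their integer label so label_names[i] always names label i
label_names = [name for name, _ in sorted(focus_map.items(), key=lambda item: item[1])]

# Labels must be 0..K-1 for a K-class CrossEntropyLoss
assert sorted(focus_map.values()) == list(range(len(focus_map)))
print(label_names)
```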

In [None]:
# Extract texts and labels (use 'answer' column like original notebook)
texts = df['answer'].tolist()
labels = df['label'].tolist()

# Train/test split (80/20, stratified by label)
train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)

print(f"✅ Train/Test split complete")
print(f"   Train samples: {len(train_texts)}")
print(f"   Test samples: {len(test_texts)}")
print(f"   Expected: ~484 train, ~121 test")
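`stratify=labels` makes the split preserve each class's share of the data. On a balanced toy set of 100 samples (5 classes × 20, made up for illustration), a 20% test split takes exactly 4 samples per class:

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Toy data: 5 classes with 20 samples each
texts = [f"doc {i}" for i in range(100)]
labels = [i % 5 for i in range(100)]

train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)

print(len(train_texts), len(test_texts))
print(sorted(Counter(test_labels).items()))
```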

## 4️⃣ Text Preprocessing - Keyword Masking (Critical Step!)

**This removes disease keywords from the training texts to prevent data leakage and force the model to learn from symptoms and descriptions. The test texts are left unmasked.**

In [None]:
# Define keywords to remove (EXACTLY from original notebook)
remove_keywords = [
    'Breast Cancer', 'Prostate Cancer', 'Skin Cancer',
    'Colorectal Cancer', 'Lung Cancer', 'Leukemia', 'Stroke', 'Heart Failure', 'Heart Attack',
    'High Blood Cholesterol', 'High Blood Pressure', 'Causes of Diabetes', 'Diabetes', 'Diabetic Retinopathy',
    'Hemochromatosis', 'Kidney Disease', 'Alzheimer\'s Disease', 'Parkinson\'s Disease', 'Balance Problems',
    'Shingles', 'Osteoporosis', 'Age-related Macular Degeneration', 'Psoriasis', 'Gum (Periodontal) Disease', 'Dry Mouth'
]

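# Split all multi-word phrases into individual words.
# Note: this also masks generic words such as 'of' (from 'Causes of Diabetes'),
# 'high', and the standalone 's' left by possessives such as "Alzheimer's".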
words_to_remove = set()
for keyword in remove_keywords:
    for word in re.findall(r'\b\w+\b', keyword):
        words_to_remove.add(word.lower())  # lowercased for case-insensitive match

# Create regex pattern to match any of the words
pattern = re.compile(r'\b(?:' + '|'.join(map(re.escape, words_to_remove)) + r')\b', flags=re.IGNORECASE)

# Remove individual words from each training text
masked_train_texts = [pattern.sub('', text) for text in train_texts]

# Normalize whitespace
masked_train_texts = [re.sub(r'\s+', ' ', text).strip() for text in masked_train_texts]

print(f"✅ Keyword masking complete")
print(f"   Masking {len(words_to_remove)} unique words")
print(f"\n📝 Example (before masking):")
print(train_texts[0][:200])
print(f"\n📝 Example (after masking):")
print(masked_train_texts[0][:200])
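Because the cell splits every phrase into single words, the mask is broader than the disease names themselves. A minimal re-implementation of the same logic (the helper names `build_mask_pattern` and `mask_text` are hypothetical) on two of the original keywords shows that `of` and a standalone `s` (from `Alzheimer's`) end up on the removal list:

```python
import re


def build_mask_pattern(keywords):
    """Collect every word of every keyword and compile one case-insensitive pattern."""
    words = set()
    for keyword in keywords:
        for word in re.findall(r"\b\w+\b", keyword):
            words.add(word.lower())
    pattern = re.compile(
        r"\b(?:" + "|".join(map(re.escape, sorted(words))) + r")\b",
        flags=re.IGNORECASE,
    )
    return words, pattern


def mask_text(text, pattern):
    """Delete matched words, then collapse the leftover whitespace."""
    return re.sub(r"\s+", " ", pattern.sub("", text)).strip()


words, pattern = build_mask_pattern(["Causes of Diabetes", "Alzheimer's Disease"])
print(sorted(words))
# ['alzheimer', 'causes', 'diabetes', 'disease', 'of', 's']
print(mask_text("Most people with diabetes have no symptoms of it early on.", pattern))
# Most people with have no symptoms it early on.
print(mask_text("Alzheimer's disease", pattern))
# '
```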

## 5️⃣ Load BiomedBERT Model & Tokenizer

In [None]:
# Load BiomedBERT tokenizer and model
MODEL_NAME = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'

print(f"📥 Loading BiomedBERT model: {MODEL_NAME}")
bert_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
bert_base = AutoModel.from_pretrained(MODEL_NAME)

print(f"✅ Model loaded successfully!")
print(f"   Model parameters: {sum(p.numel() for p in bert_base.parameters()) / 1e6:.1f}M")

In [None]:
# Create BiomedBERT classifier (BERT base + classification head)
class BiomedBERTClassifier(nn.Module):
    def __init__(self, bert_model, num_classes=5):
        super(BiomedBERTClassifier, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # [CLS] hidden state passed through BERT's dense + tanh pooler
        logits = self.classifier(pooled_output)
        return logits

# Initialize model
bert_model = BiomedBERTClassifier(bert_base, num_classes=5).to(device)

print(f"✅ Classifier created")
print(f"   Total parameters: {sum(p.numel() for p in bert_model.parameters()) / 1e6:.1f}M")
print(f"   Trainable parameters: {sum(p.numel() for p in bert_model.parameters() if p.requires_grad) / 1e6:.1f}M")
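Before downloading the full checkpoint, the classifier head can be shape-checked against a tiny, randomly initialised `BertModel`. This sketch repeats the class from the cell above (using `super().__init__()`) and assumes a toy `BertConfig` (vocabulary 100, hidden size 32, 2 layers) chosen only so it runs on CPU in seconds; the logits are meaningless, only the `(batch, num_classes)` shape matters.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class BiomedBERTClassifier(nn.Module):
    """Same architecture as the notebook cell: BERT encoder + linear head on pooler_output."""

    def __init__(self, bert_model, num_classes=5):
        super().__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(outputs.pooler_output)


# Tiny random BERT: no download, only for checking shapes
tiny_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
)
model = BiomedBERTClassifier(BertModel(tiny_config), num_classes=5)
model.eval()

input_ids = torch.randint(0, 100, (3, 12))  # batch of 3 sequences, 12 tokens each
attention_mask = torch.ones_like(input_ids)
attention_mask[0, 8:] = 0                   # pretend the first sequence is padded

with torch.no_grad():
    logits = model(input_ids, attention_mask)
print(logits.shape)  # torch.Size([3, 5])
```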

## 6️⃣ Tokenization & Data Preparation

In [None]:
# Tokenize training data (EXACTLY from original notebook)
MAX_SEQ_LENGTH = 512

print(f"🔤 Tokenizing training data (max_length={MAX_SEQ_LENGTH})...")
X_train = bert_tokenizer(
    masked_train_texts,  # Use masked texts (keywords removed)
    padding=True,
    truncation=True,
    return_tensors="pt",
    max_length=MAX_SEQ_LENGTH
)
y_train = torch.tensor(train_labels, dtype=torch.long)

# Create training dataset and dataloader
train_dataset = TensorDataset(
    X_train['input_ids'].to(device),
    X_train['attention_mask'].to(device),
    y_train.to(device)
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)  # batch_size=16 from notebook

print(f"✅ Training data prepared")
print(f"   Batches: {len(train_loader)}")
print(f"   Batch size: 16")
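`DataLoader` keeps the final, smaller batch by default (`drop_last=False`), so the batch count is `ceil(N / 16)`. Using the ~484 training samples the split cell expects (zero tensors stand in for the tokenized data), that gives 31 batches, the last holding 4 samples:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

n_train = 484  # approximate train-set size expected by the split cell
dataset = TensorDataset(
    torch.zeros(n_train, 8, dtype=torch.long),  # stand-in for input_ids
    torch.zeros(n_train, dtype=torch.long),     # stand-in for labels
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

sizes = [inputs.shape[0] for inputs, _ in loader]
print(len(loader), math.ceil(n_train / 16))  # 31 31
print(sizes[-1])                             # 4
```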

In [None]:
# Tokenize test data (NO masking for test data - evaluate on real data)
print(f"🔤 Tokenizing test data (max_length={MAX_SEQ_LENGTH})...")
X_test = bert_tokenizer(
    test_texts,  # Use original test texts (no masking)
    padding=True,
    truncation=True,
    return_tensors="pt",
    max_length=MAX_SEQ_LENGTH
)
y_test = torch.tensor(test_labels, dtype=torch.long)

# Create test dataset and dataloader
test_dataset = TensorDataset(
    X_test['input_ids'].to(device),
    X_test['attention_mask'].to(device),
    y_test.to(device)
)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)  # batch_size=8 from notebook

print(f"✅ Test data prepared")
print(f"   Batches: {len(test_loader)}")
print(f"   Batch size: 8")

## 7️⃣ Training Configuration & Training Loop

In [None]:
# Training configuration (EXACTLY from original notebook)
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, bert_model.parameters()),
    lr=0.00003  # 3e-5 from notebook
)
criterion = nn.CrossEntropyLoss()
num_epochs = 10  # From notebook

print(f"⚙️ Training Configuration:")
print(f"   Optimizer: AdamW")
print(f"   Learning Rate: 3e-5 (0.00003)")
print(f"   Epochs: {num_epochs}")
print(f"   Loss Function: CrossEntropyLoss")
print(f"   Device: {device}")

In [None]:
# Training loop (EXACTLY from original notebook)
print(f"\n🚀 Starting training...\n")

for epoch in range(num_epochs):
    bert_model.train()
    total_loss = 0.0
    
    for batch_X, batch_attention_mask, batch_y in train_loader:
        optimizer.zero_grad()
        
        # Forward pass
        outputs = bert_model(batch_X, batch_attention_mask)
        loss = criterion(outputs, batch_y)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average CE Loss: {avg_loss}")

print(f"\n✅ Training complete!")
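The loop above can be smoke-tested on CPU before spending GPU time. This sketch keeps the same `zero_grad` → forward → `backward` → `step` order with `AdamW` and `CrossEntropyLoss`, but swaps in a hypothetical toy `nn.Linear` model on random data and a higher learning rate (0.05) so a drop in loss is visible within 20 full-batch steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the classifier: 8 features in, 5 classes out
model = nn.Linear(8, 5)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

features = torch.randn(64, 8)
targets = torch.randint(0, 5, (64,))

losses = []
for epoch in range(20):
    model.train()
    optimizer.zero_grad()                       # clear gradients from the previous step
    loss = criterion(model(features), targets)  # forward pass + loss
    loss.backward()                             # compute gradients
    optimizer.step()                            # update weights
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```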

## 8️⃣ Model Evaluation

In [None]:
# Evaluate on test set
print(f"📊 Evaluating model on test set...\n")

bert_model.eval()
all_predictions = []
all_labels = []

with torch.no_grad():
    for batch_X, batch_attention_mask, batch_y in test_loader:
        outputs = bert_model(batch_X, batch_attention_mask)
        predictions = torch.argmax(outputs, dim=1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(batch_y.cpu().numpy())

# Calculate accuracy
accuracy = accuracy_score(all_labels, all_predictions)
print(f"🎯 Test Accuracy: {accuracy * 100:.2f}%")
print(f"   Expected: ~99%\n")

# Classification report
label_names = [
    'Neurological & Cognitive',
    'Cancers',
    'Cardiovascular',
    'Metabolic & Endocrine',
    'Other Age-Related & Immune'
]
print(f"📋 Classification Report:\n")
print(classification_report(all_labels, all_predictions, target_names=label_names))

# Confusion matrix
print(f"\n🔢 Confusion Matrix:\n")
print(confusion_matrix(all_labels, all_predictions))
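scikit-learn's `confusion_matrix` puts true labels on rows and predicted labels on columns, so correct predictions sit on the diagonal. In the toy example below (made-up labels), accuracy is the trace divided by the total, and per-class recall is each diagonal entry divided by its row sum. With a test set of about 121 samples, one misclassification moves accuracy by roughly 0.8 percentage points (1/121 ≈ 0.0083).

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 0]

cm = confusion_matrix(y_true, y_pred)
print(cm)  # rows = true labels, columns = predicted labels

# Correct predictions are on the diagonal, so accuracy = trace / total
print(np.trace(cm) / cm.sum(), accuracy_score(y_true, y_pred))

# Per-class recall = diagonal entry / row sum
print(np.diag(cm) / cm.sum(axis=1))
```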

## 9️⃣ Save Model for Production Deployment

In [None]:
# Save model and tokenizer
import os

output_dir = 'biomedbert_model'
os.makedirs(output_dir, exist_ok=True)

# Save the weights (state dict) plus the metadata needed to rebuild the model;
# the BiomedBERTClassifier class itself must be redefined at load time
torch.save({
    'model_state_dict': bert_model.state_dict(),
    'model_config': {
        'num_classes': 5,
        'hidden_size': bert_base.config.hidden_size,
        'model_name': MODEL_NAME
    },
    'label_mapping': focus_map,
    'accuracy': float(accuracy)  # plain float: a NumPy scalar can break torch.load(weights_only=True)
}, f'{output_dir}/model.pt')

# Save tokenizer
bert_tokenizer.save_pretrained(output_dir)

# Save label mapping as JSON
import json
with open(f'{output_dir}/label_mapping.json', 'w') as f:
    json.dump(focus_map, f, indent=2)

# Save reverse mapping (for predictions). JSON turns the integer keys into strings,
# so look labels up with str(label)
reverse_focus_map = {v: k for k, v in focus_map.items()}
with open(f'{output_dir}/reverse_label_mapping.json', 'w') as f:
    json.dump(reverse_focus_map, f, indent=2)

print(f"✅ Model saved to '{output_dir}/' directory")
print(f"\n📦 Files saved:")
print(f"   - model.pt (PyTorch model weights + config)")
print(f"   - tokenizer files (vocab, config, etc.)")
print(f"   - label_mapping.json (focus group → label)")
print(f"   - reverse_label_mapping.json (label → focus group)")
print(f"\n📥 Download the entire '{output_dir}/' folder and place it in your project's 'models/' directory")
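Two details matter when these files are read back. JSON object keys are always strings, so the integer labels in `reverse_label_mapping.json` come back as `"0"` to `"4"` (hence `str(prediction)` in the FastAPI snippet). And recent PyTorch versions default `torch.load` to `weights_only=True`, which accepts tensors and plain Python containers but can reject NumPy scalars. A round-trip sketch, with a toy `nn.Linear` standing in for `bert_model` and a two-entry excerpt of `focus_map`:

```python
import json
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(4, 5)  # toy stand-in for bert_model
focus_map = {"Cancers": 1, "Cardiovascular Diseases": 2}

with tempfile.TemporaryDirectory() as output_dir:
    path = os.path.join(output_dir, "model.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "model_config": {"num_classes": 5},
        "label_mapping": focus_map,
        "accuracy": float(np.float64(0.99)),  # cast NumPy scalar to a plain float
    }, path)

    # weights_only=True limits unpickling to tensors and primitive containers
    checkpoint = torch.load(path, map_location="cpu", weights_only=True)
    restored = nn.Linear(4, 5)
    restored.load_state_dict(checkpoint["model_state_dict"])

    # Integer keys are written as strings by json.dump
    reverse_path = os.path.join(output_dir, "reverse_label_mapping.json")
    with open(reverse_path, "w") as f:
        json.dump({v: k for k, v in focus_map.items()}, f)
    with open(reverse_path) as f:
        reverse_map = json.load(f)

print(reverse_map)
```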

## 🔟 Download Model Files

**Run the cell below to download the trained model as a ZIP file:**

In [None]:
# Create ZIP file for easy download
import shutil

zip_filename = 'biomedbert_model'
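# root_dir + base_dir keeps the biomedbert_model/ folder itself inside the ZIP
shutil.make_archive(zip_filename, 'zip', root_dir='.', base_dir=output_dir)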

print(f"✅ Model packaged as {zip_filename}.zip")
print(f"\n📥 Downloading...")

# Download the ZIP file
files.download(f'{zip_filename}.zip')

print(f"\n🎉 Download started (check your browser's downloads)")
print(f"\n📋 Next Steps:")
print(f"   1. Extract the ZIP file")
print(f"   2. Place the 'biomedbert_model/' folder in your project's 'models/' directory")
print(f"   3. Update your FastAPI app to load the model from 'models/biomedbert_model/'")
print(f"   4. Test predictions with your API")
print(f"\n✨ Test accuracy: {accuracy * 100:.2f}%")
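`shutil.make_archive` archives whatever sits under `root_dir`; passing `base_dir` as well controls whether the folder name appears inside the ZIP. This stdlib sketch builds both variants around a placeholder `model.pt` in a temporary directory and lists their entries, a quick way to confirm that extracting the archive yields a `biomedbert_model/` folder.

```python
import os
import shutil
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as workdir:
    output_dir = os.path.join(workdir, "biomedbert_model")
    os.makedirs(output_dir)
    with open(os.path.join(output_dir, "model.pt"), "wb") as f:
        f.write(b"placeholder")  # stand-in for the real checkpoint

    # root_dir only: the folder's contents sit at the top level of the ZIP
    flat = shutil.make_archive(os.path.join(workdir, "flat"), "zip", root_dir=output_dir)
    # root_dir + base_dir: the biomedbert_model/ folder itself is inside the ZIP
    nested = shutil.make_archive(os.path.join(workdir, "nested"), "zip",
                                 root_dir=workdir, base_dir="biomedbert_model")

    with zipfile.ZipFile(flat) as zf:
        flat_names = zf.namelist()
    with zipfile.ZipFile(nested) as zf:
        nested_names = zf.namelist()

print("flat:", flat_names)
print("nested:", nested_names)
```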

## 📝 How to Use the Model in Your FastAPI App

```python
# In your FastAPI app (e.g., src/main.py or src/model.py)

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import json

# Define the same model architecture
class BiomedBERTClassifier(nn.Module):
    def __init__(self, bert_model, num_classes=5):
        super(BiomedBERTClassifier, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
    
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Load the model
MODEL_DIR = 'models/biomedbert_model'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

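# Load the checkpoint, then the base BiomedBERT weights it was fine-tuned from
# (downloaded from the Hugging Face Hub, then overwritten by the fine-tuned state dict below)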
checkpoint = torch.load(f'{MODEL_DIR}/model.pt', map_location=device)
bert_base = AutoModel.from_pretrained(checkpoint['model_config']['model_name'])

# Create classifier and load weights
model = BiomedBERTClassifier(bert_base, num_classes=checkpoint['model_config']['num_classes']).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Load label mapping
with open(f'{MODEL_DIR}/reverse_label_mapping.json', 'r') as f:
    label_to_focus_group = json.load(f)

# Prediction function
def predict(text: str):
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512
    ).to(device)
    
    with torch.no_grad():
        outputs = model(inputs['input_ids'], inputs['attention_mask'])
        prediction = torch.argmax(outputs, dim=1).item()
        probabilities = torch.softmax(outputs, dim=1)[0].cpu().numpy()
    
    return {
        'focus_group': label_to_focus_group[str(prediction)],
        'confidence': float(probabilities[prediction]),
        'all_probabilities': {label_to_focus_group[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
    }
```