# 🧠 Medical Sentiment Analysis - BERT Fine-Tuning

**Train a transformer model for healthcare sentiment classification**

This notebook fine-tunes DistilBERT to classify patient sentiment:
- **Anxious**: Worried, scared, concerned
- **Neutral**: Factual, informational
- **Reassured**: Relieved, grateful, positive

**Author:** Himanshu Sharma  
**For:** Emitrr AI Engineer Intern Assignment

---

## 1. Setup & Installation

In [None]:
# Install dependencies
!pip install -q transformers datasets accelerate scikit-learn matplotlib seaborn

import torch
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
if device == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

print("\n✅ Setup complete!")

## 2. Prepare Training Data

In [None]:
# Medical Sentiment Training Data
# Labels: 0 = anxious, 1 = neutral, 2 = reassured

TRAINING_DATA = [
    # ANXIOUS (label = 0)
    {"text": "I've been having these terrible headaches for the past two weeks. They're really worrying me.", "label": 0},
    {"text": "I'm scared it might be something serious.", "label": 0},
    {"text": "But what if it's a brain tumor or something?", "label": 0},
    {"text": "Yes, I've been waiting anxiously. What did you find?", "label": 0},
    {"text": "Oh my. That's a lot to take in.", "label": 0},
    {"text": "This is overwhelming. How do we deal with all of this?", "label": 0},
    {"text": "That's a lot of medications. Will they interact with each other?", "label": 0},
    {"text": "Not great, doctor. I've been having more attacks lately, especially at night.", "label": 0},
    {"text": "I was afraid you'd say that. What do you recommend?", "label": 0},
    {"text": "I just have no energy. I drag myself through the day and still feel exhausted.", "label": 0},
    {"text": "Hypothyroidism? Is that serious?", "label": 0},
    {"text": "Is that dangerous? Could it be something worse?", "label": 0},
    {"text": "What do you think it is, doctor?", "label": 0},
    {"text": "Rheumatoid arthritis? That sounds serious. My aunt has that and she can barely use her hands now.", "label": 0},
    {"text": "What happens if it is RA?", "label": 0},
    {"text": "Am I going to get a shot?", "label": 0},
    {"text": "I'm worried about him getting dehydrated.", "label": 0},
    {"text": "Oh no, is something wrong? I've been worrying since I got the call.", "label": 0},
    {"text": "A mass? Does that mean I have cancer?", "label": 0},
    {"text": "What do we do now? I'm really scared.", "label": 0},
    {"text": "A biopsy? That sounds painful.", "label": 0},
    {"text": "What are the chances it's cancer?", "label": 0},
    {"text": "Doctor, it's been terrifying. The room just starts spinning out of nowhere.", "label": 0},
    {"text": "Is that serious? The word vertigo sounds scary.", "label": 0},
    {"text": "The pain is unbearable. I can't sleep at night.", "label": 0},
    {"text": "Will I ever be able to walk normally again?", "label": 0},
    {"text": "What if the medication doesn't work?", "label": 0},
    {"text": "The side effects are really bothering me.", "label": 0},
    {"text": "I've been feeling anxious about the upcoming surgery.", "label": 0},
    {"text": "I don't understand why this keeps happening to me.", "label": 0},
    {"text": "Is this going to affect my ability to work?", "label": 0},
    {"text": "My family is worried about me.", "label": 0},
    {"text": "I'm concerned about the long-term effects.", "label": 0},
    {"text": "I'm afraid the symptoms might come back.", "label": 0},
    {"text": "I can't stop thinking about what might be wrong.", "label": 0},
    {"text": "The numbness in my hands is getting worse.", "label": 0},
    {"text": "What are the possible complications?", "label": 0},
    {"text": "Should I be worried about this symptom?", "label": 0},
    {"text": "I'm scared to do the procedure.", "label": 0},
    {"text": "I've been losing weight without trying.", "label": 0},
    {"text": "I've been having trouble remembering things lately.", "label": 0},
    {"text": "What happens if the condition gets worse?", "label": 0},
    {"text": "These symptoms are really affecting my daily life.", "label": 0},
    {"text": "I'm worried I might have passed this on to my children.", "label": 0},
    {"text": "My blood pressure readings have been very high.", "label": 0},
    {"text": "I've been experiencing chest pains when I exercise.", "label": 0},
    {"text": "What if this is hereditary?", "label": 0},
    {"text": "I've been having panic attacks more frequently.", "label": 0},
    {"text": "I keep thinking the worst is going to happen.", "label": 0},
    
    # NEUTRAL (label = 1)
    {"text": "I'm doing better, but I still have some discomfort now and then.", "label": 1},
    {"text": "Should I get rid of the cat?", "label": 1},
    {"text": "What does that mean for treatment?", "label": 1},
    {"text": "How long will I need to take it?", "label": 1},
    {"text": "Thank you, doctor. I hope this helps because it's really affecting my quality of life.", "label": 1},
    {"text": "My hands and wrists hurt a lot, especially in the morning.", "label": 1},
    {"text": "Okay, I'll do whatever it takes. I want to catch this early.", "label": 1},
    {"text": "Yeah, my tummy hurts a lot. And I've been pooping a lot.", "label": 1},
    {"text": "When can I go back to school?", "label": 1},
    {"text": "Okay. I'll try to stay positive. Thank you for being honest with me.", "label": 1},
    {"text": "Oh! Yes, that's it! The spinning started when you turned my head to the right.", "label": 1},
    {"text": "You can fix it now? That would be amazing.", "label": 1},
    {"text": "For the past three months, I've had this burning feeling in my chest, especially after eating.", "label": 1},
    {"text": "I've been taking my medication as prescribed.", "label": 1},
    {"text": "When will I know if the treatment is working?", "label": 1},
    {"text": "Will my insurance cover this treatment?", "label": 1},
    {"text": "I've been doing the exercises you recommended.", "label": 1},
    {"text": "How often do I need to come back for check-ups?", "label": 1},
    {"text": "Is it okay if I get a second opinion?", "label": 1},
    {"text": "I understand the diagnosis now. Thank you for explaining.", "label": 1},
    {"text": "The doctor said I'm making good progress.", "label": 1},
    
    # REASSURED (label = 2)
    {"text": "That's a relief! Thank you, doctor.", "label": 2},
    {"text": "Thank you, doctor. I'll try not to worry so much.", "label": 2},
    {"text": "Thank you for explaining everything, doctor. I feel more confident about managing this now.", "label": 2},
    {"text": "Thank you, doctor. I appreciate the help.", "label": 2},
    {"text": "That's a relief to have an answer. I was worried something was really wrong with me.", "label": 2},
    {"text": "Thank you, doctor. We were really worried.", "label": 2},
    {"text": "The spinning... it's gone! I can't believe it.", "label": 2},
    {"text": "I'm feeling much better after the treatment.", "label": 2},
    {"text": "My symptoms have improved significantly.", "label": 2},
    {"text": "I'm so glad the test results came back normal.", "label": 2},
    {"text": "That's exactly what I was hoping to hear.", "label": 2},
    {"text": "I can finally do things I couldn't do before.", "label": 2},
    {"text": "The recovery has been faster than I expected.", "label": 2},
    {"text": "I feel like myself again.", "label": 2},
    {"text": "The swelling has gone down considerably.", "label": 2},
    {"text": "Everything looks normal on the scan.", "label": 2},
    {"text": "I've noticed a big improvement since starting the new medication.", "label": 2},
    {"text": "I'm really grateful for your help, doctor.", "label": 2},
    {"text": "The physical therapy has been very helpful.", "label": 2},
    {"text": "My energy levels are back to normal.", "label": 2},
    {"text": "The test came back negative, which is great news.", "label": 2},
    {"text": "I can move my arm freely now without any pain.", "label": 2},
    {"text": "I'm feeling hopeful about the treatment plan.", "label": 2},
    {"text": "I've been sleeping much better since starting the treatment.", "label": 2},
    {"text": "I can finally eat without discomfort.", "label": 2},
    {"text": "The wound is healing nicely.", "label": 2},
    {"text": "I'm happy with how the treatment is going.", "label": 2},
    {"text": "The rash has completely cleared up.", "label": 2},
    {"text": "My cholesterol levels are much better now.", "label": 2},
    {"text": "I'm back to playing sports like before.", "label": 2},
]

# Create DataFrame
df = pd.DataFrame(TRAINING_DATA)

# Label mapping
LABEL_MAP = {0: "anxious", 1: "neutral", 2: "reassured"}
LABEL2ID = {"anxious": 0, "neutral": 1, "reassured": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}

print(f"📊 Dataset Statistics:")
print(f"   Total samples: {len(df)}")
print(f"\n   Class distribution:")
for label_id, label_name in LABEL_MAP.items():
    count = (df['label'] == label_id).sum()
    print(f"   - {label_name}: {count} ({count/len(df)*100:.1f}%)")
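
The distribution printout shows that the classes are imbalanced: anxious is the largest class and neutral the smallest. One common mitigation, not used in this notebook, is to weight the loss by inverse class frequency. The sketch below computes such weights with the "balanced" formula that scikit-learn's `compute_class_weight` uses, `n_samples / (n_classes * count)`. The helper name `balanced_class_weights` is hypothetical. Feeding the weights into training would additionally require a custom loss (for example by subclassing `Trainer` and overriding `compute_loss`), which is not shown.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical helper: inverse-frequency weights, n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in sorted(counts.items())}


# Class counts from TRAINING_DATA: 49 anxious (0), 21 neutral (1), 30 reassured (2).
# In the notebook you could pass df['label'].tolist() instead.
example_labels = [0] * 49 + [1] * 21 + [2] * 30
weights = balanced_class_weights(example_labels)
for label_id, weight in weights.items():
    print(f"label {label_id}: weight {weight:.3f}")
```

Rarer classes get weights above 1 and the majority class gets a weight below 1, so each class contributes the same total weight to the loss.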

In [None]:
# Split into train/validation sets
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)

print(f"📈 Train/Val Split:")
print(f"   Training: {len(train_df)} samples")
print(f"   Validation: {len(val_df)} samples")

# Convert to HuggingFace Dataset
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
val_dataset = Dataset.from_pandas(val_df.reset_index(drop=True))

dataset = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset
})

print(f"\n✅ Datasets created!")
print(dataset)

## 3. Tokenization

In [None]:
# Load DistilBERT tokenizer
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(examples):
    """Tokenize text for DistilBERT; padding is left to DataCollatorWithPadding at batch time."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=128
    )

# Tokenize datasets
tokenized_datasets = dataset.map(tokenize_function, batched=True)

print(f"✅ Tokenization complete!")
print(f"\nSample tokenized entry:")
sample = tokenized_datasets['train'][0]
print(f"   Text: {sample['text'][:50]}...")  # the first *training* example, not df's first row
print(f"   Input IDs length: {len(sample['input_ids'])}")
print(f"   Label: {ID2LABEL[sample['label']]}")
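
`DataCollatorWithPadding` pads each batch only to the length of its longest sequence (dynamic padding) rather than to a fixed length such as 128, which saves computation on short patient statements. The plain-Python sketch below mimics that behavior on lists of token IDs. `pad_batch` is a hypothetical helper, and the pad ID of 0 is an assumption that matches `[PAD]` in the `distilbert-base-uncased` vocabulary.

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Hypothetical helper: pad a batch to its longest sequence and build attention masks."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Toy token IDs: 101 = [CLS] and 102 = [SEP] in the BERT uncased vocabulary; middle IDs are illustrative
batch = [[101, 1045, 2572, 102], [101, 2003, 102]]
padded = pad_batch(batch)
print(padded["input_ids"])       # [[101, 1045, 2572, 102], [101, 2003, 102, 0]]
print(padded["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```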

## 4. Model Setup

In [None]:
# Load pre-trained DistilBERT with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=3,
    id2label=ID2LABEL,
    label2id=LABEL2ID
)

# Move to GPU if available
model = model.to(device)

print(f"✅ Model loaded: {MODEL_NAME}")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"   Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 5. Training

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall, F1."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0  # avoid warnings when a class is never predicted
    )
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

print("✅ Metrics function defined")
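
`compute_metrics` reports support-weighted averages, which lean toward the majority class. The plain-Python sketch below computes per-class F1 from true-positive, false-positive and false-negative counts, then compares the weighted average with the unweighted (macro) average. The labels are invented for illustration; in the notebook the equivalent scikit-learn call would use `average='macro'` in `precision_recall_fscore_support`.

```python
def per_class_f1(y_true, y_pred, labels):
    """Return {label: (f1, support)} computed from TP, FP and FN counts."""
    scores = {}
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != label and p == label)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        support = sum(1 for t in y_true if t == label)
        scores[label] = (f1, support)
    return scores


# Toy example: the majority class (0) is always right, the minority class (1) is always missed
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
scores = per_class_f1(y_true, y_pred, labels=[0, 1])

macro_f1 = sum(f1 for f1, _ in scores.values()) / len(scores)
weighted_f1 = sum(f1 * support for f1, support in scores.values()) / len(y_true)
print(f"macro F1:    {macro_f1:.3f}")
print(f"weighted F1: {weighted_f1:.3f}")
```

Here the weighted F1 (about 0.71) looks far healthier than the macro F1 (about 0.44), even though the model never recognizes the minority class.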

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./medical_sentiment_model",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    push_to_hub=False,
    logging_steps=10,
    warmup_steps=5,  # ~10% of the 50 total steps (80 samples / batch 8 x 5 epochs); 50 would never leave warmup
)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument (transformers >= 4.46)
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("🚀 Starting training...\n")
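
With 80 training samples, a batch size of 8 and 5 epochs, the run has 50 optimizer steps (assuming one device and no gradient accumulation). The Trainer's default `lr_scheduler_type` is `"linear"`: the learning rate ramps up over the warmup steps, then decays linearly to zero. The sketch below reimplements that rule in plain Python. It follows the formula of `transformers.get_linear_schedule_with_warmup`, but `linear_lr` is a hypothetical name. It shows that a warmup as long as the whole run means the learning rate peaks on the last step and never decays.

```python
import math


def linear_lr(step, base_lr, warmup_steps, total_steps):
    """Hypothetical helper: linear warmup to base_lr, then linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# 80 training samples, batch size 8, 5 epochs, one device, no gradient accumulation
total_steps = math.ceil(80 / 8) * 5  # 50 optimizer steps

for warmup in (50, 5):
    lrs = [linear_lr(s, 2e-5, warmup, total_steps) for s in range(total_steps)]
    print(f"warmup={warmup:>2}: peak lr reached at step {lrs.index(max(lrs))}, "
          f"lr at last step {lrs[-1]:.2e}")
```

Setting warmup to roughly 10% of the total steps is a common heuristic rather than a rule; on a dataset this small, the exact value matters less than not letting warmup consume the whole run.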

In [None]:
# Train the model
train_result = trainer.train()

print("\n✅ Training complete!")
print(f"\nTraining metrics:")
for key, value in train_result.metrics.items():
    print(f"   {key}: {value:.4f}")

## 6. Evaluation

In [None]:
# Evaluate on validation set
eval_results = trainer.evaluate()

print("📊 Validation Results:")
print(f"   Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"   Precision: {eval_results['eval_precision']:.4f}")
print(f"   Recall: {eval_results['eval_recall']:.4f}")
print(f"   F1 Score: {eval_results['eval_f1']:.4f}")

In [None]:
# Generate predictions for confusion matrix
predictions = trainer.predict(tokenized_datasets["validation"])
preds = np.argmax(predictions.predictions, axis=-1)
labels = predictions.label_ids

# Create confusion matrix
cm = confusion_matrix(labels, preds, labels=[0, 1, 2])  # fix row/column order to match the tick labels

# Plot
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm, 
    annot=True, 
    fmt='d', 
    cmap='Blues',
    xticklabels=['Anxious', 'Neutral', 'Reassured'],
    yticklabels=['Anxious', 'Neutral', 'Reassured']
)
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Sentiment Classification Confusion Matrix')
plt.tight_layout()
plt.show()
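
scikit-learn's `confusion_matrix` infers its rows and columns from the labels it sees. If a class appeared in neither the true labels nor the predictions of a small validation split, the matrix would shrink and stop lining up with the three hard-coded tick labels. Passing `labels=[0, 1, 2]` pins the shape. The plain-Python sketch below builds the same kind of fixed-shape matrix and derives per-class recall from it. `confusion` is a hypothetical helper and the label lists are toy data.

```python
def confusion(y_true, y_pred, labels):
    """Rows are actual classes, columns are predicted classes, in the order of `labels`."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


LABEL_ORDER = [0, 1, 2]  # anxious, neutral, reassured
toy_true = [0, 0, 0, 2, 2]  # toy data: no neutral examples at all
toy_pred = [0, 2, 0, 2, 2]
toy_cm = confusion(toy_true, toy_pred, LABEL_ORDER)
for row in toy_cm:
    print(row)

# Per-class recall = diagonal / row sum (None when a class has no true examples)
recall = [row[i] / sum(row) if sum(row) else None for i, row in enumerate(toy_cm)]
print(recall)
```

The neutral row stays in place as all zeros instead of disappearing, so the heatmap axes remain correct.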

## 7. Test Inference

In [None]:
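def predict_sentiment(text):
    """Predict sentiment for a single text; returns (label, softmax confidence)."""
    model.eval()  # make sure dropout is disabled for inference
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    # Send inputs to wherever the model lives (the Trainer may have placed it on CUDA or MPS)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}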
    
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        pred = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][pred].item()
    
    return ID2LABEL[pred], confidence

# Test sentences
TEST_SENTENCES = [
    "I'm really worried about these symptoms.",
    "The medication seems to be working well.",
    "That's such a relief to hear, thank you!",
    "Could this be something serious?",
    "I've been taking my pills every day as prescribed.",
    "I feel so much better now, doctor.",
    "Is this going to affect my job?",
    "The pain has completely gone away.",
]

print("🔮 Testing Model Predictions:\n")
print(f"{'Text':<60} {'Sentiment':<12} {'Confidence'}")
print("-" * 85)

for text in TEST_SENTENCES:
    sentiment, confidence = predict_sentiment(text)
    display_text = text[:57] + "..." if len(text) > 60 else text
    print(f"{display_text:<60} {sentiment:<12} {confidence:.3f}")
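
`predict_sentiment` reports the largest softmax probability as its confidence. The sketch below spells out that computation in plain Python, subtracting the maximum logit before exponentiating for numerical stability (the shift does not change the result). The logits are made-up numbers. A softmax score from a model fine-tuned on 80 examples is not a calibrated probability, so a confidence of 0.9 should not be read as "correct 90% of the time".

```python
import math


def softmax(logits):
    """Numerically stable softmax: shift by the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


id2label = {0: "anxious", 1: "neutral", 2: "reassured"}
logits = [2.0, 0.5, -1.0]  # made-up logits for one sentence
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # argmax
print(id2label[pred], round(probs[pred], 3))
```

Without the shift, logits in the hundreds would overflow `math.exp`; with it, the largest exponent is always `exp(0) = 1`.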

## 8. Save the Model

In [None]:
# Save the fine-tuned model
MODEL_PATH = "./medical_sentiment_bert"

model.save_pretrained(MODEL_PATH)
tokenizer.save_pretrained(MODEL_PATH)

print(f"✅ Model saved to: {MODEL_PATH}")

# Verify loading
loaded_model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("✅ Model loaded successfully!")

In [None]:
# Zip for download
!zip -r medical_sentiment_bert.zip medical_sentiment_bert/

try:
    from google.colab import files
    files.download('medical_sentiment_bert.zip')
    print("📥 Model downloaded!")
except ImportError:  # not running in Google Colab
    print("Model saved as medical_sentiment_bert.zip")

## Summary

### What We Built
- Fine-tuned DistilBERT for medical sentiment classification
- 3 classes: Anxious, Neutral, Reassured
- Trained on 80 of 100 labeled patient statements (the remaining 20 are held out for validation)

### Model Performance
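- Expected validation accuracy is roughly 80-90%, but with only 20 validation samples each misclassification shifts accuracy by 5 percentage points, so a single run is a noisy estimate
- The reported precision, recall and F1 are support-weighted averages and lean toward the largest class (anxious, 49 of 100 samples); check the confusion matrix or macro-averaged scores to judge whether performance is balanced across classes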

### Next Steps for Production
1. **More Data**: Expand to 1000+ examples per class
2. **Domain Models**: Try ClinicalBERT or BioBERT as base
3. **Data Augmentation**: Back-translation, synonym replacement
4. **Hyperparameter Tuning**: Grid search for optimal learning rate
5. **Cross-Validation**: K-fold for robust evaluation
6. **Intent Detection**: Add multi-task learning for intent
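
For the cross-validation step, stratified k-fold keeps each fold's class mix close to that of the full dataset, which matters when a class is as small as the neutral one (21 examples). The sketch below assigns examples to folds by dealing each class's indices round-robin, without shuffling. `stratified_kfold_indices` is a hypothetical helper; scikit-learn's `StratifiedKFold` provides the same idea with optional shuffling. Each fold would need its own fresh fine-tuning run.

```python
from collections import defaultdict


def stratified_kfold_indices(labels, k):
    """Hypothetical helper: yield (train_idx, val_idx) pairs using per-class round-robin fold assignment."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    fold_of = {}
    for indices in by_class.values():
        for position, i in enumerate(indices):
            fold_of[i] = position % k  # deal this class's examples across folds evenly
    for fold in range(k):
        val_idx = [i for i in range(len(labels)) if fold_of[i] == fold]
        train_idx = [i for i in range(len(labels)) if fold_of[i] != fold]
        yield train_idx, val_idx


# Same class counts as TRAINING_DATA: 49 anxious, 21 neutral, 30 reassured
all_labels = [0] * 49 + [1] * 21 + [2] * 30
for fold, (train_idx, val_idx) in enumerate(stratified_kfold_indices(all_labels, k=5)):
    val_counts = [sum(1 for i in val_idx if all_labels[i] == c) for c in (0, 1, 2)]
    print(f"fold {fold}: train={len(train_idx)} val={len(val_idx)} val class counts={val_counts}")
```

Every example lands in exactly one validation fold, and each fold holds 4 or 5 neutral examples instead of leaving their number to chance.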

### Production Integration
```python
from transformers import pipeline

classifier = pipeline(
    "sentiment-analysis",
    model="./medical_sentiment_bert",
    tokenizer="./medical_sentiment_bert"
)

result = classifier("I'm worried about my symptoms")
print(result)  # e.g. [{'label': 'anxious', 'score': ...}]
```