In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import pandas as pd
from scipy.stats import entropy
import warnings
warnings.filterwarnings('ignore')

# Report the PyTorch build and GPU availability so the recorded outputs below
# can be matched to the environment that produced them.
print(f"PyTorch version: {torch.__version__}")
print(f"GPU available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.9.1+cpu
GPU available: False


In [2]:
# Model paths (identifiers passed to `from_pretrained` by the classifier classes)
BASELINE_MODEL_PATH = "notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITHOUT-DROPOUT-42"
ENHANCED_MODEL_PATH = "notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITH-DROPOUT-42"

# Configuration
SEED = 42  # fixed RNG seed so the stochastic MC-dropout outputs reproduce on Restart & Run All
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
ALLOWED_DISEASES = {"Dengue", "Pneumonia", "Typhoid", "Impetigo"}
N_MC_SAMPLES = 50  # Number of stochastic forward passes
INFERENCE_DROPOUT_RATE = 0.10  # Dropout rate during inference
MAX_LEN = 512

# MC dropout draws its masks from PyTorch's global RNG; without a seed every
# run of the notebook prints different confidences/entropies. Successive calls
# still differ from each other, so the stability test below stays meaningful.
torch.manual_seed(SEED)

print(f"Device: {DEVICE}")
print(f"\nBaseline (no dropout): {BASELINE_MODEL_PATH}")
print(f"Enhanced (MC dropout): {ENHANCED_MODEL_PATH}")
print(f"\nMC Samples: {N_MC_SAMPLES}")
print(f"Dropout Rate: {INFERENCE_DROPOUT_RATE}")

Device: cpu

Baseline (no dropout): notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITHOUT-DROPOUT-42
Enhanced (MC dropout): notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITH-DROPOUT-42

MC Samples: 50
Dropout Rate: 0.1


In [3]:
class BaselineClassifier:
    """Single-pass softmax classifier used as the deterministic reference.

    The model is put in eval mode once at construction time and every call to
    `predict` performs exactly one forward pass under `torch.no_grad()`.
    """

    def __init__(self, model_path: str, device: str = DEVICE):
        """Load tokenizer and sequence-classification model from `model_path`."""
        print(f"Loading baseline model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        # eval() first, then move the weights to the requested device.
        self.model.eval()
        self.model.to(device)
        self.device = device
        print("‚úì Baseline model loaded")

    def predict(self, text: str):
        """Classify `text` and return the top label plus class probabilities.

        Returns a dict with:
            predicted_label: label with the highest probability over all classes.
            confidence: probability of that label.
            probabilities: full probability vector as a list.
            top_allowed: (label, probability) pairs restricted to
                ALLOWED_DISEASES, sorted by descending probability.
        """
        encoded = self.tokenizer(
            text,
            max_length=MAX_LEN,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # [0]: a single text is passed, so only the first row is used.
            class_probs = F.softmax(self.model(**encoded).logits, dim=-1)[0]
        confidences = class_probs.detach().cpu().numpy()

        id2label = self.model.config.id2label
        top_idx = int(confidences.argmax())

        # Keep only the diseases of interest, most probable first.
        ranked_allowed = sorted(
            (
                (name, float(confidences[class_id]))
                for class_id, name in id2label.items()
                if name in ALLOWED_DISEASES
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )

        return {
            "predicted_label": id2label[top_idx],
            "confidence": float(confidences[top_idx]),
            "probabilities": confidences.tolist(),
            "top_allowed": ranked_allowed,
        }

    def count_dropout_layers(self):
        """Return how many sub-modules have "Dropout" in their class name."""
        return sum(
            1
            for module in self.model.modules()
            if "Dropout" in module.__class__.__name__
        )

In [4]:
class MCDropoutClassifier:
    """MC Dropout inference with uncertainty quantification.

    Runs `n_iterations` stochastic forward passes with dropout active and
    summarises them as a mean distribution, a per-class standard deviation,
    the predictive entropy and the mutual information.
    """

    def __init__(self, model_path: str, device: str = DEVICE,
                 n_iterations: int = N_MC_SAMPLES,
                 inference_dropout_rate: float = INFERENCE_DROPOUT_RATE):
        """Load tokenizer and model and store the sampling settings.

        Args:
            model_path: identifier passed to `from_pretrained`.
            device: device the model and inputs are moved to.
            n_iterations: number of stochastic forward passes per prediction.
            inference_dropout_rate: value assigned to `p` of every dropout
                module while sampling.

        Raises:
            ValueError: if `n_iterations` is below 1 or the dropout rate is
                outside [0, 1]. Previously a bad `n_iterations` only failed
                later, inside `np.stack`, with an unrelated message.
        """
        if n_iterations < 1:
            raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")
        if not 0.0 <= inference_dropout_rate <= 1.0:
            raise ValueError(
                f"inference_dropout_rate must be in [0, 1], got {inference_dropout_rate}"
            )

        print(f"Loading enhanced model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()
        self.device = device
        self.n_iterations = n_iterations
        self.inference_dropout_rate = inference_dropout_rate
        print("‚úì Enhanced model loaded")

    def enable_mc_dropout(self):
        """Enable dropout for MC sampling.

        Modules whose class name starts with "Dropout" get
        `p = inference_dropout_rate` and are switched to train mode; modules
        with "Norm" in their class name are forced to eval mode.
        """
        for module in self.model.modules():
            if module.__class__.__name__.startswith("Dropout"):
                module.p = self.inference_dropout_rate
                module.train()
            elif "Norm" in module.__class__.__name__:
                module.eval()

    def _snapshot_module_state(self):
        """Record the state of every module `enable_mc_dropout` will modify."""
        snapshot = []
        for module in self.model.modules():
            class_name = module.__class__.__name__
            if class_name.startswith("Dropout"):
                has_p = hasattr(module, "p")
                snapshot.append((module, module.training, has_p, module.p if has_p else None))
            elif "Norm" in class_name:
                snapshot.append((module, module.training, False, None))
        return snapshot

    @staticmethod
    def _restore_module_state(snapshot):
        """Undo `enable_mc_dropout` using a snapshot from `_snapshot_module_state`."""
        for module, was_training, has_p, previous_p in snapshot:
            if has_p:
                module.p = previous_p
            module.train(was_training)

    def compute_mutual_information(self, predictions: np.ndarray) -> np.ndarray:
        """MI = H(mean probs) - E[H(sample probs)].

        Args:
            predictions: array whose first axis indexes MC samples and whose
                last axis indexes classes.

        Returns:
            Mutual information with the sample axis and class axis reduced.
            The difference is non-negative in exact arithmetic; it is clipped
            at 0 so floating-point round-off cannot yield tiny negative values.
        """
        expected_entropy = entropy(predictions, axis=-1).mean(axis=0)
        entropy_of_expected = entropy(predictions.mean(axis=0), axis=-1)
        return np.clip(entropy_of_expected - expected_entropy, 0.0, None)

    def predict_with_uncertainty(self, text: str):
        """Classify `text` with MC dropout and return uncertainty estimates.

        Returns a dict with:
            predicted_label: label with the highest mean probability.
            confidence: that mean probability.
            mean_probs / std_probs: per-class mean and standard deviation
                across the MC samples.
            predictive_entropy: entropy of the mean distribution.
            mutual_information: see `compute_mutual_information`.
            top_allowed: (label, mean probability) pairs restricted to
                ALLOWED_DISEASES, sorted by descending probability.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=MAX_LEN
        ).to(self.device)
        inputs["input_ids"] = inputs["input_ids"].to(torch.long)
        inputs["attention_mask"] = inputs["attention_mask"].to(torch.long)

        # Dropout is switched on only for the duration of the sampling loop.
        # Restoring afterwards keeps the model deterministic for any other use
        # and stops the overridden `p` values from leaking out of this call.
        snapshot = self._snapshot_module_state()
        self.enable_mc_dropout()
        all_predictions = []
        try:
            with torch.no_grad():
                for _ in range(self.n_iterations):
                    probs = torch.softmax(self.model(**inputs).logits, dim=-1)
                    all_predictions.append(probs.cpu().numpy())
        finally:
            self._restore_module_state(snapshot)

        all_predictions = np.stack(all_predictions)  # (n_iter, batch, classes)
        mean_probs = all_predictions.mean(axis=0)
        std_probs = all_predictions.std(axis=0)
        pred_idx = mean_probs.argmax(axis=-1)
        confidence = mean_probs.max(axis=-1)
        predictive_entropy = entropy(mean_probs, axis=-1)
        mutual_information = self.compute_mutual_information(all_predictions)

        # Filter to allowed diseases (single input text -> batch row 0)
        probs_vec = mean_probs[0]
        allowed = [
            (label, float(probs_vec[idx]))
            for idx, label in self.model.config.id2label.items()
            if label in ALLOWED_DISEASES
        ]
        allowed.sort(key=lambda x: x[1], reverse=True)

        return {
            "predicted_label": self.model.config.id2label[int(pred_idx[0])],
            "confidence": float(confidence[0]),
            "mean_probs": mean_probs[0],
            "std_probs": std_probs[0],
            "predictive_entropy": float(predictive_entropy[0]),
            "mutual_information": float(mutual_information[0]),
            "top_allowed": allowed,
        }

    def count_dropout_layers(self):
        """Count modules with "Dropout" anywhere in their class name.

        Note: this substring match is broader than the prefix match used by
        `enable_mc_dropout`, so the count is an upper bound on the number of
        layers that are actually made stochastic.
        """
        count = 0
        for module in self.model.modules():
            if "Dropout" in module.__class__.__name__:
                count += 1
        return count

## Load Models

In [5]:
# Deterministic reference model.
baseline_model = BaselineClassifier(model_path=BASELINE_MODEL_PATH, device=DEVICE)
print(f"Dropout layers: {baseline_model.count_dropout_layers()}")

print()

# MC-dropout model; the sampling settings come from the configuration cell.
mc_dropout_kwargs = {
    "device": DEVICE,
    "n_iterations": N_MC_SAMPLES,
    "inference_dropout_rate": INFERENCE_DROPOUT_RATE,
}
mc_model = MCDropoutClassifier(ENHANCED_MODEL_PATH, **mc_dropout_kwargs)
print(f"Dropout layers: {mc_model.count_dropout_layers()}")

Loading baseline model from notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITHOUT-DROPOUT-42...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


‚úì Baseline model loaded
Dropout layers: 24

Loading enhanced model from notlath/BioClinical-ModernBERT-base-Symptom2Disease_WITH-DROPOUT-42...
‚úì Enhanced model loaded
Dropout layers: 46


## Architecture Inspection

In [6]:
# Side-by-side description of the two loaded models. Only the dropout-layer
# counts and the baseline's `training` flag are read from the live objects;
# the remaining lines are fixed descriptive text.
rule = "=" * 70

print(rule)
print("ARCHITECTURE COMPARISON")
print(rule)

baseline_report = [
    "\n[BASELINE MODEL]",
    f"  Model: {BASELINE_MODEL_PATH.rsplit('/', 1)[-1]}",
    f"  Dropout layers: {baseline_model.count_dropout_layers()}",
    "  Mode: eval()",
    f"  Training flag: {baseline_model.model.training}",
    "  Behavior: Deterministic predictions (same input ‚Üí same output)",
]
for report_line in baseline_report:
    print(report_line)

enhanced_report = [
    "\n[ENHANCED MODEL]",
    f"  Model: {ENHANCED_MODEL_PATH.rsplit('/', 1)[-1]}",
    f"  Dropout layers: {mc_model.count_dropout_layers()}",
    f"  Inference dropout rate: {INFERENCE_DROPOUT_RATE}",
    f"  MC samples: {N_MC_SAMPLES}",
    "  Mode: eval() initially, then dropout layers set to train()",
    "  Behavior: Stochastic predictions (same input ‚Üí distribution of outputs)",
]
for report_line in enhanced_report:
    print(report_line)

print("\n" + rule)

ARCHITECTURE COMPARISON

[BASELINE MODEL]
  Model: BioClinical-ModernBERT-base-Symptom2Disease_WITHOUT-DROPOUT-42
  Dropout layers: 24
  Mode: eval()
  Training flag: False
  Behavior: Deterministic predictions (same input ‚Üí same output)

[ENHANCED MODEL]
  Model: BioClinical-ModernBERT-base-Symptom2Disease_WITH-DROPOUT-42
  Dropout layers: 46
  Inference dropout rate: 0.1
  MC samples: 50
  Mode: eval() initially, then dropout layers set to train()
  Behavior: Stochastic predictions (same input ‚Üí distribution of outputs)



## Prediction Comparison on Test Cases

In [7]:
# Hand-written inputs covering a clear, an ambiguous and a moderately
# specific presentation.
test_cases = [
    {
        "text": "High fever, severe headache, muscle pain, joint pain, rash on chest and limbs",
        "clarity": "CLEAR",
        "description": "Classic dengue presentation"
    },
    {
        "text": "Fever and cough",
        "clarity": "AMBIGUOUS",
        "description": "Non-specific symptoms"
    },
    {
        "text": "Sudden high fever, chills, rose spots on abdomen, sustained fever pattern",
        "clarity": "MODERATE",
        "description": "Suggestive of typhoid"
    },
]

TEXT_PREVIEW_CHARS = 40  # max characters of the input shown in the summary table


def _preview_text(text: str, limit: int = TEXT_PREVIEW_CHARS) -> str:
    """Return `text` cut to `limit` characters, adding "..." only if it was cut.

    The previous `text[:40] + "..."` appended an ellipsis to short inputs too,
    making "Fever and cough" look truncated in the summary table.
    """
    return text if len(text) <= limit else text[:limit] + "..."


def _format_top3(ranked):
    """Format the first three (label, probability) pairs for printing."""
    return [(d, f'{p:.4f}') for d, p in ranked[:3]]


results = []  # one summary row per test case; consumed by the summary-table cell

for idx, test in enumerate(test_cases, 1):
    text = test["text"]
    print(f"\n{'='*70}")
    print(f"TEST {idx}: {test['clarity']} - {test['description']}")
    print(f"{'='*70}")
    print(f"Input: \"{text}\"")

    # Baseline prediction
    b = baseline_model.predict(text)
    print(f"\n[BASELINE - Deterministic]")
    print(f"  Prediction: {b['predicted_label']}")
    print(f"  Confidence: {b['confidence']:.4f}")
    print(f"  Top 3 diseases: {_format_top3(b['top_allowed'])}")

    # Enhanced prediction
    m = mc_model.predict_with_uncertainty(text)
    # Largest per-class std across the MC samples (not necessarily the
    # std of the predicted class).
    max_class_std = m['std_probs'].max()
    print(f"\n[ENHANCED - MC Dropout]")
    print(f"  Prediction: {m['predicted_label']}")
    print(f"  Mean confidence: {m['confidence']:.4f}")
    print(f"  Confidence std: {max_class_std:.4f}")
    print(f"  Predictive entropy: {m['predictive_entropy']:.4f}")
    print(f"  Mutual information: {m['mutual_information']:.4f}")
    print(f"  Top 3 diseases: {_format_top3(m['top_allowed'])}")

    results.append({
        "clarity": test['clarity'],
        "text": _preview_text(text),
        "baseline_pred": b['predicted_label'],
        "baseline_conf": b['confidence'],
        "enhanced_pred": m['predicted_label'],
        "enhanced_conf": m['confidence'],
        "enhanced_std": max_class_std,
        "enhanced_entropy": m['predictive_entropy'],
        "enhanced_mi": m['mutual_information'],
    })

print(f"\n{'='*70}")


TEST 1: CLEAR - Classic dengue presentation
Input: "High fever, severe headache, muscle pain, joint pain, rash on chest and limbs"

[BASELINE - Deterministic]
  Prediction: Dengue
  Confidence: 0.5255
  Top 3 diseases: [('Dengue', '0.5255'), ('Typhoid', '0.2354'), ('Pneumonia', '0.0459')]

[ENHANCED - MC Dropout]
  Prediction: Dengue
  Mean confidence: 0.6875
  Confidence std: 0.1424
  Predictive entropy: 1.1002
  Mutual information: 0.0659
  Top 3 diseases: [('Dengue', '0.6875'), ('Typhoid', '0.1049'), ('Pneumonia', '0.0593')]

TEST 2: AMBIGUOUS - Non-specific symptoms
Input: "Fever and cough"

[BASELINE - Deterministic]
  Prediction: Influenza
  Confidence: 0.5023
  Top 3 diseases: [('Pneumonia', '0.1125'), ('Typhoid', '0.1124'), ('Dengue', '0.0460')]

[ENHANCED - MC Dropout]
  Prediction: Pneumonia
  Mean confidence: 0.3971
  Confidence std: 0.1793
  Predictive entropy: 1.5266
  Mutual information: 0.1647
  Top 3 diseases: [('Pneumonia', '0.3971'), ('Dengue', '0.2287'), ('Typhoid',

## Summary Table

In [8]:
# One row per test case. conf_diff > 0 means the MC-dropout mean confidence is
# higher than the baseline's single-pass confidence.
df = pd.DataFrame(results).assign(
    conf_diff=lambda frame: frame['enhanced_conf'] - frame['baseline_conf']
)
print(df.to_string(index=False))

key_observations = [
    "1. CLEAR cases: Enhanced should show LOW entropy/MI (model is certain)",
    "2. AMBIGUOUS cases: Enhanced should show HIGH entropy/MI (model is uncertain)",
    "3. Baseline provides NO uncertainty quantification",
    "4. Enhanced provides principled uncertainty measures via MC sampling",
]

print("\n" + "="*70)
print("KEY OBSERVATIONS")
print("="*70)
for observation in key_observations:
    print(observation)

  clarity                                        text baseline_pred  baseline_conf enhanced_pred  enhanced_conf  enhanced_std  enhanced_entropy  enhanced_mi  conf_diff
    CLEAR High fever, severe headache, muscle pain...        Dengue       0.525517        Dengue       0.687472      0.142410          1.100218     0.065918   0.161955
AMBIGUOUS                          Fever and cough...     Influenza       0.502250     Pneumonia       0.397115      0.179302          1.526615     0.164663  -0.105135
 MODERATE Sudden high fever, chills, rose spots on...       Typhoid       0.818974       Typhoid       0.815679      0.110270          0.767987     0.048126  -0.003295

KEY OBSERVATIONS
1. CLEAR cases: Enhanced should show LOW entropy/MI (model is certain)
2. AMBIGUOUS cases: Enhanced should show HIGH entropy/MI (model is uncertain)
3. Baseline provides NO uncertainty quantification
4. Enhanced provides principled uncertainty measures via MC sampling


## Stability Test: Repeated Predictions

In [9]:
test_text = "Fever and cough"  # Ambiguous case

print(f"Testing prediction stability for: \"{test_text}\"")
print(f"Running 10 predictions with each model...\n")

# --- Baseline: one forward pass per call, expected to repeat exactly ---
print("[BASELINE] - Expected: All predictions identical (deterministic)")
baseline_preds = []
for i in range(10):
    pred = baseline_model.predict(test_text)
    baseline_preds.append(pred['confidence'])
    # Show runs 1-3 and 10; collapse the middle runs into a single "..." line.
    if i == 3:
        print(f"  ...")
    elif i < 3 or i == 9:
        print(f"  Run {i+1}: {pred['predicted_label']:12s} | Confidence: {pred['confidence']:.4f}")

baseline_conf_var = np.var(baseline_preds)
baseline_all_identical = len(set(baseline_preds)) == 1
print(f"\n  Confidence variance: {baseline_conf_var:.8f}")
print(f"  All identical: {baseline_all_identical}")

# --- Enhanced: each call averages fresh MC-dropout samples, so results move ---
print(f"\n[ENHANCED] - Expected: Predictions vary (stochastic ensemble)")
enhanced_preds = []
enhanced_entropies = []
for i in range(10):
    pred = mc_model.predict_with_uncertainty(test_text)
    enhanced_preds.append(pred['confidence'])
    enhanced_entropies.append(pred['predictive_entropy'])
    if i == 3:
        print(f"  ...")
    elif i < 3 or i == 9:
        print(f"  Run {i+1}: {pred['predicted_label']:12s} | Conf: {pred['confidence']:.4f} | Entropy: {pred['predictive_entropy']:.4f}")

enhanced_conf_var = np.var(enhanced_preds)
enhanced_entropy_var = np.var(enhanced_entropies)
print(f"\n  Confidence variance: {enhanced_conf_var:.8f}")
print(f"  Entropy variance: {enhanced_entropy_var:.8f}")

print(f"\n{'='*70}")
print("INTERPRETATION")
print("="*70)
print(f"Baseline variance ‚âà 0: Deterministic (no randomness)")
print(f"Enhanced variance > 0: Stochastic ensemble captures epistemic uncertainty")
print(f"\nThe variance in Enhanced predictions is NOT a bug - it's the feature!")
print(f"MC Dropout creates an ensemble of models to quantify uncertainty.")

Testing prediction stability for: "Fever and cough"
Running 10 predictions with each model...

[BASELINE] - Expected: All predictions identical (deterministic)
  Run 1: Influenza    | Confidence: 0.5023
  Run 2: Influenza    | Confidence: 0.5023
  Run 3: Influenza    | Confidence: 0.5023
  ...
  Run 10: Influenza    | Confidence: 0.5023

  Confidence variance: 0.00000000
  All identical: True

[ENHANCED] - Expected: Predictions vary (stochastic ensemble)
  Run 1: Pneumonia    | Conf: 0.4063 | Entropy: 1.5297
  Run 2: Pneumonia    | Conf: 0.3381 | Entropy: 1.5586
  Run 3: Pneumonia    | Conf: 0.3867 | Entropy: 1.5214
  ...
  Run 10: Pneumonia    | Conf: 0.3652 | Entropy: 1.5641

  Confidence variance: 0.00123924
  Entropy variance: 0.00129848

INTERPRETATION
Baseline variance ‚âà 0: Deterministic (no randomness)
Enhanced variance > 0: Stochastic ensemble captures epistemic uncertainty

The variance in Enhanced predictions is NOT a bug - it's the feature!
MC Dropout creates an ensemble o

## Key Findings

### 1. **Architectural Difference**
- **Baseline**: Trained WITHOUT dropout → Deterministic inference
- **Enhanced**: Trained WITH dropout → Stochastic inference via MC sampling

### 2. **Uncertainty Quantification**
- **Baseline**: Single point estimate, no uncertainty measure
- **Enhanced**: Distributional predictions with:
  - **Predictive Entropy**: Overall uncertainty in prediction
  - **Mutual Information**: Model's epistemic uncertainty (knowledge gaps)
  - **Confidence Std**: Largest per-class standard deviation of the predicted probabilities across MC samples

### 3. **Clinical Implications**
- **Clear cases**: Both models confident, Enhanced provides additional safety bounds
- **Ambiguous cases**: 
  - Baseline: False confidence (no way to know it's uncertain)
  - Enhanced: Quantified uncertainty (can trigger human review)

### 4. **Recommended Usage**
- Use **entropy/MI thresholds** to flag uncertain cases
- Route high-uncertainty predictions to clinician review
- Trust high-confidence, low-entropy predictions for automated triage

### 5. **Trade-offs**
- **Computational**: Enhanced requires 50x more inference time (50 forward passes)
- **Benefit**: Principled uncertainty quantification for safety-critical applications
- **Solution**: Hybrid approach - use baseline for screening, enhanced for borderline cases

## BONUS: Baseline vs Enhanced with SHAP Explainability

**Research Question**: Does adding SHAP explanations improve clinical interpretability without changing predictions?

**Comparison Framework**:
- **Baseline**: Softmax output only (prediction + confidence)
- **Enhanced**: Softmax + SHAP (prediction + confidence + token attributions)

**Key Insight**: Both wrappers load the same checkpoint, so they produce identical predictions - SHAP adds *post-hoc* explainability to understand *why* the model made that prediction.

In [10]:
# Import SHAP dependencies
from captum.attr import GradientShap

class BaselineClassifierSoftmaxOnly:
    """Prediction-only wrapper: label and confidence, no token attributions."""

    def __init__(self, model_path: str, device: str = DEVICE):
        """Load tokenizer and model from `model_path` and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()
        self.model.to(device)
        self.device = device

    def predict(self, text: str):
        """Return the top label, its softmax probability and a placeholder explanation."""
        encoded = self.tokenizer(
            text,
            max_length=MAX_LEN,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # Single input text, so only the first row of the batch is used.
            class_probs = F.softmax(self.model(**encoded).logits, dim=-1)[0]

        confidences = class_probs.detach().cpu().numpy()
        top_idx = int(confidences.argmax())

        return {
            "predicted_label": self.model.config.id2label[top_idx],
            "confidence": float(confidences[top_idx]),
            "explanation": "‚ö†Ô∏è No explanation available (softmax only)"
        }


class EnhancedClassifierWithSHAP:
    """Enhanced: Returns predictions + SHAP explanations.

    Predictions come from a single softmax pass; explanations are per-token
    attributions computed with Captum's GradientShap on the input embeddings.
    """

    def __init__(self, model_path: str, device: str = DEVICE):
        """Load tokenizer and model from `model_path` and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval().to(device)
        self.device = device

    def predict_with_explanation(self, text: str):
        """Classify `text` and attribute the predicted class to input tokens.

        Returns a dict with:
            predicted_label: label with the highest softmax probability.
            confidence: that probability.
            explanation: string listing the five highest-scoring non-special tokens.
            all_attributions: (token, score) pairs for every input token.
                Scores are min-max scaled to [0, 1] over the sequence, so they
                rank tokens but no longer carry the attribution's sign.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1)[0]

        confidences = probs.detach().cpu().numpy()
        pred_idx = int(confidences.argmax())
        pred_label = self.model.config.id2label[pred_idx]

        # Generate SHAP explanations on the embedding layer's output, since
        # gradients cannot flow through integer token ids.
        embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
        attention_mask = inputs["attention_mask"]

        def forward_func(embeds):
            # Captum evaluates a whole batch of perturbed samples at once;
            # expand the single-row mask so its batch size matches `embeds`
            # instead of relying on implicit broadcasting inside the model.
            batch_mask = attention_mask.expand(embeds.shape[0], -1)
            outputs = self.model(inputs_embeds=embeds, attention_mask=batch_mask)
            return F.softmax(outputs.logits, dim=-1)[:, pred_idx]

        baseline_embeds = torch.zeros_like(embeddings).repeat(5, 1, 1)
        gradient_shap = GradientShap(forward_func)
        attributions, _ = gradient_shap.attribute(
            embeddings,
            baselines=baseline_embeds,
            n_samples=25,
            stdevs=0.01,
            return_convergence_delta=True
        )

        # Sum over the embedding dimension to get one score per token. Index
        # the batch row explicitly: squeeze() would also drop the sequence
        # axis for a one-token input and break the zip() below.
        token_attrib = attributions.sum(dim=-1)[0].detach().cpu()
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].detach().cpu().numpy())

        # Normalize attributions to [0,1]; the epsilon avoids division by zero
        # when every token has the same score.
        token_attrib = (token_attrib - token_attrib.min()) / (token_attrib.max() - token_attrib.min() + 1e-8)
        token_explanations = list(zip(tokens, token_attrib.numpy().tolist()))

        # Get top contributing tokens (excluding special tokens). Ask the
        # tokenizer for its special tokens rather than hard-coding BERT's, so
        # the filter also works for tokenizers using e.g. <s> / </s>.
        special_tokens = set(self.tokenizer.all_special_tokens)
        top_tokens = sorted(
            [(tok, attr) for tok, attr in token_explanations if tok not in special_tokens],
            key=lambda x: x[1],
            reverse=True
        )[:5]
        top_summary = ', '.join(f'{tok}({attr:.3f})' for tok, attr in top_tokens)

        return {
            "predicted_label": pred_label,
            "confidence": float(confidences[pred_idx]),
            "explanation": f"‚úì Top contributing tokens: {top_summary}",
            "all_attributions": token_explanations
        }

print("‚úì SHAP comparison classes loaded")

‚úì SHAP comparison classes loaded


In [11]:
# Both wrappers are built from the same checkpoint so that any difference in
# the comparison below comes from the output format, not from the weights.
print("Loading models for SHAP comparison...")
baseline_softmax_only = BaselineClassifierSoftmaxOnly(
    model_path=ENHANCED_MODEL_PATH,
    device=DEVICE,
)
enhanced_with_shap = EnhancedClassifierWithSHAP(
    model_path=ENHANCED_MODEL_PATH,
    device=DEVICE,
)
print("‚úì Both models loaded (same weights, different output formats)")

Loading models for SHAP comparison...
‚úì Both models loaded (same weights, different output formats)


### Side-by-Side Comparison: Softmax Only vs Softmax + SHAP

In [12]:
# Inputs paired with the symptoms a clinician would expect the attribution
# method to highlight.
explainability_test_cases = [
    {
        "text": "High fever, severe headache, muscle pain, joint pain, rash on chest",
        "expected_focus": "Classic dengue symptoms - should highlight 'fever', 'rash', 'joint pain'"
    },
    {
        "text": "Persistent cough with chest pain and difficulty breathing",
        "expected_focus": "Respiratory symptoms - should highlight 'cough', 'chest', 'breathing'"
    },
    {
        "text": "Sudden high fever with rose spots on abdomen and sustained fever pattern",
        "expected_focus": "Typhoid indicators - should highlight 'rose spots', 'sustained fever'"
    }
]


def _print_explained_result(header: str, result: dict) -> None:
    """Print the header followed by the prediction, confidence and explanation."""
    print(header)
    print(f"  Prediction: {result['predicted_label']}")
    print(f"  Confidence: {result['confidence']:.4f}")
    print(f"  Explanation: {result['explanation']}")


for idx, case in enumerate(explainability_test_cases, 1):
    text = case["text"]
    print(f"\n{'='*80}")
    print(f"TEST CASE {idx}")
    print(f"{'='*80}")
    print(f"Input: \"{text}\"")
    print(f"Expected: {case['expected_focus']}")
    print(f"\n{'-'*80}")

    # Baseline: Softmax only
    baseline_result = baseline_softmax_only.predict(text)
    _print_explained_result(f"\n[BASELINE - Softmax Only]", baseline_result)

    # Enhanced: Softmax + SHAP
    enhanced_result = enhanced_with_shap.predict_with_explanation(text)
    _print_explained_result(f"\n[ENHANCED - Softmax + SHAP]", enhanced_result)

    print(f"\n{'-'*80}")
    print("üìä INTERPRETATION:")
    print(f"  ‚Ä¢ Baseline: Gives prediction but no insight into reasoning")
    print(f"  ‚Ä¢ Enhanced: Shows which tokens (symptoms) drove the decision")
    print(f"  ‚Ä¢ Clinical Value: Clinician can validate if model focused on correct symptoms")

print(f"\n{'='*80}")


TEST CASE 1
Input: "High fever, severe headache, muscle pain, joint pain, rash on chest"
Expected: Classic dengue symptoms - should highlight 'fever', 'rash', 'joint pain'

--------------------------------------------------------------------------------

[BASELINE - Softmax Only]
  Prediction: Dengue
  Confidence: 0.6455
  Explanation: ‚ö†Ô∏è No explanation available (softmax only)

[ENHANCED - Softmax + SHAP]
  Prediction: Dengue
  Confidence: 0.6455
  Explanation: ‚úì Top contributing tokens: ƒ†muscle(1.000), ƒ†fever(0.598), ƒ†joint(0.504), ƒ†headache(0.414), ƒ†severe(0.271)

--------------------------------------------------------------------------------
üìä INTERPRETATION:
  ‚Ä¢ Baseline: Gives prediction but no insight into reasoning
  ‚Ä¢ Enhanced: Shows which tokens (symptoms) drove the decision
  ‚Ä¢ Clinical Value: Clinician can validate if model focused on correct symptoms

TEST CASE 2
Input: "Persistent cough with chest pain and difficulty breathing"
Expected: Respiratory 