# Medical LIME Adapter Demo

This notebook demonstrates how to use `MedicalLIME` from `medical_lime_adapter.py` to explain predictions from a medical LLM.

**Model:** `google/medgemma-4b-it`  
**Method:** LIME (Local Interpretable Model-Agnostic Explanations)  
**Task:** Multiple Choice Question (MCQ)

LIME works by randomly masking words in the prompt, querying the model on each perturbed input, and fitting a local linear model whose coefficients serve as word-level attribution scores.
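
The procedure described above can be sketched in a few lines of NumPy. This is only an illustration of the general LIME recipe, not the code inside `medical_lime_adapter.py`: `toy_model` and `lime_sketch` are hypothetical names, the toy model stands in for the LLM, and the distance measure (fraction of words masked) and exponential kernel are assumptions.

```python
import numpy as np

def toy_model(words):
    """Hypothetical stand-in for an LLM: a score that rises when key words are present."""
    score = 0.1
    if "hemoptysis" in words:
        score += 0.5
    if "mass" in words:
        score += 0.3
    return score

def lime_sketch(text, predict, n_samples=200, kernel_width=0.75, seed=0):
    """Return (word, coefficient) pairs from a locally weighted linear fit."""
    words = text.split()
    rng = np.random.default_rng(seed)
    # Binary masks: 1 keeps a word, 0 drops it. Row 0 is the unperturbed input.
    masks = rng.integers(0, 2, size=(n_samples, len(words)))
    masks[0] = 1
    # Query the model once per perturbed input.
    y = np.array([predict([w for w, keep in zip(words, m) if keep]) for m in masks])
    # Assumed distance: fraction of words removed. Closer samples get larger weights.
    dist = 1.0 - masks.mean(axis=1)
    weights = np.exp(-(dist ** 2) / kernel_width ** 2)
    # Weighted least squares with an intercept column.
    X = np.hstack([np.ones((n_samples, 1)), masks])
    sw = np.sqrt(weights)
    coef, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return list(zip(words, coef[1:]))

for word, score in lime_sketch("cough hemoptysis weight loss mass", toy_model):
    print(f"{word:>12}: {score:+.3f}")
```

Because the toy model is exactly linear in word presence, the fitted coefficients match its built-in word effects. A real LLM is not linear, which is why the quality of the local fit (R²) is reported alongside the attributions.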

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install torch transformers accelerate scikit-learn tqdm hf_xet -q

print("✓ Package installation step finished (check the pip output above for errors)")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Add project path to Python path
import sys
import os

project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"

if project_path not in sys.path:
    sys.path.append(project_path)

print(f"✓ Project path: {project_path}")
print(f"✓ Path exists: {os.path.exists(project_path)}")

if os.path.exists(project_path):
    contents = os.listdir(project_path)
    for fname in ["medical_llm_wrapper.py", "medical_lime_adapter.py"]:
        status = "✓" if fname in contents else "⚠️  NOT FOUND"
        print(f"  {status}  {fname}")

In [None]:
# Import wrapper and LIME adapter
import warnings
warnings.filterwarnings('once')

import torch
from medical_llm_wrapper import load_medical_llm
from medical_lime_adapter import MedicalLIME, visualize_lime_attributions, to_dataframe

print("✓ Imports successful!")

## 2. Load MedGemma-4B-IT

In [None]:
# Load model — MedGemma auto-converts to float32
wrapper = load_medical_llm(
    "google/medgemma-4b-it",
    device="cuda"
)
wrapper.set_task("mcq")

print("\n[Model Information]")
info = wrapper.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

## 3. Define a Medical MCQ Prompt

In [None]:
prompt = """A 55-year-old patient presents with persistent cough, hemoptysis, and unintentional weight loss.
Chest X-ray shows a mass in the right upper lobe. What is the most likely diagnosis?

A) Tuberculosis
B) Lung cancer
C) Pneumonia
D) Pulmonary embolism

Answer:"""

print(prompt)

## 4. Initialize MedicalLIME

In [None]:
lime = MedicalLIME(
    wrapper,
    n_samples=300,      # number of perturbed inputs (more samples give a more stable fit but take longer)
    kernel_width=0.75,  # locality of the linear fit (smaller values favor samples closer to the original prompt)
    verbose=True
)

print("✓ MedicalLIME initialized")
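
To build intuition for `kernel_width`, the sketch below prints sample weights for a few widths. It assumes an exponential kernel of the kind commonly used in LIME and a distance between 0 (nothing masked) and 1 (everything masked). The helper `kernel_weight` is hypothetical, and whether `MedicalLIME` uses exactly this kernel or distance is not confirmed here.

```python
import math

def kernel_weight(distance, kernel_width=0.75):
    """Assumed exponential kernel: weight decays as a sample moves away from the original."""
    return math.exp(-(distance ** 2) / (kernel_width ** 2))

# Compare how quickly weights fall off for narrow, default, and wide kernels.
for width in (0.25, 0.75, 2.0):
    weights = [round(kernel_weight(d, width), 3) for d in (0.0, 0.5, 1.0)]
    print(f"kernel_width={width}: weights at distance 0.0, 0.5, 1.0 -> {weights}")
```

A narrow kernel makes the fit depend almost entirely on lightly perturbed prompts, while a wide kernel treats all perturbations nearly equally.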

## 5. Explain with Auto-Detected Target Class

When no `target_class` is passed, `lime.analyze(prompt)` uses the model's predicted class as the target.
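
Conceptually, auto-detection amounts to taking the option with the highest probability. The snippet below is a sketch of that idea only: `pick_target_class` is a hypothetical helper, the probabilities are made up, and it assumes option probabilities are available as a dict keyed by option letter, which the adapter may or may not do internally.

```python
def pick_target_class(option_probs):
    """Return the option letter with the highest probability."""
    return max(option_probs, key=option_probs.get)

# Made-up probabilities, for illustration only.
example_probs = {"A": 0.12, "B": 0.71, "C": 0.10, "D": 0.07}
print(pick_target_class(example_probs))
```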

In [None]:
result = lime.analyze(prompt)

print("\n" + "=" * 60)
print("[LIME Results — Auto Target]")
print("=" * 60)
print(f"  Prediction        : {result['prediction']}")
print(f"  Target class      : {result['target_class']}")
print(f"  P(target)         : {result['target_probability']:.4f}")
print(f"  All option probs  : {result['all_option_probs']}")
print(f"  R² (local fit)    : {result['r_squared']:.4f}")

print("\n  Top 5 most influential words (by |attribution|):")
for word, score in result['top_words'][:5]:
    direction = "+" if score > 0 else "-"
    print(f"    [{direction}] '{word}': {score:.4f}")

print("\n  Top positive words (support the predicted class):")
for word, score in result['top_positive_words'][:3]:
    print(f"    '{word}': {score:.4f}")

print("\n  Top negative words (work against the predicted class):")
for word, score in result['top_negative_words'][:3]:
    print(f"    '{word}': {score:.4f}")
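
The R² value printed above indicates how well the local linear model reproduces the LLM's outputs on the perturbed prompts; a low value means the attributions should be read with caution. One common way to compute it is a sample-weighted R², sketched below. The function name is hypothetical, and whether `MedicalLIME` reports a weighted or unweighted R² is not stated in this notebook.

```python
def weighted_r_squared(y_true, y_pred, weights):
    """Weighted coefficient of determination: 1 - SS_res / SS_tot."""
    total_w = sum(weights)
    mean = sum(w * y for w, y in zip(weights, y_true)) / total_w
    ss_res = sum(w * (y - p) ** 2 for w, y, p in zip(weights, y_true, y_pred))
    ss_tot = sum(w * (y - mean) ** 2 for w, y in zip(weights, y_true))
    if ss_tot == 0:
        # The model output never changed across samples, so R² is undefined.
        return float("nan")
    return 1.0 - ss_res / ss_tot

# Made-up numbers: a close but imperfect fit.
print(weighted_r_squared([0.2, 0.5, 0.9], [0.25, 0.5, 0.85], [1.0, 0.8, 0.6]))
```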

## 6. Explain with Explicit Target Class + Visualization

Pass `visualize=True` for a color-coded terminal view.  
**Red** = word supports the target class | **Blue** = word works against it.
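
A color-coded terminal view of this kind can be produced with ANSI escape codes. The sketch below shows the idea with made-up scores. `colorize` and its `threshold` parameter are hypothetical, and the actual rendering done by `visualize_lime_attributions` may differ.

```python
RED, BLUE, RESET = "\033[91m", "\033[94m", "\033[0m"

def colorize(word_scores, threshold=0.0):
    """Wrap each word in red (supports target) or blue (works against it)."""
    parts = []
    for word, score in word_scores:
        if score > threshold:
            parts.append(f"{RED}{word}{RESET}")
        elif score < -threshold:
            parts.append(f"{BLUE}{word}{RESET}")
        else:
            # Scores within the threshold are left uncolored.
            parts.append(word)
    return " ".join(parts)

# Made-up attribution scores, for illustration only.
example = [("persistent", 0.01), ("hemoptysis", 0.21), ("Tuberculosis", -0.08)]
print(colorize(example, threshold=0.05))
```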

In [None]:
result_b = lime.analyze(prompt, target_class="B", visualize=True)

print(f"\nP(B) on original prompt: {result_b['target_probability']:.4f}")
print(f"R² of local linear fit : {result_b['r_squared']:.4f}")

## 7. Inspect Attributions as a DataFrame

`to_dataframe()` converts the result into a sorted pandas DataFrame for further analysis.

In [None]:
df = to_dataframe(result)
print("Top 10 most influential words:")
print(df.head(10).to_string(index=False))

## 8. Batch Explanation — Multiple Prompts

`analyze_batch()` runs LIME over a list of prompts with optional per-prompt target classes.

In [None]:
prompts = [
    """Which vitamin deficiency causes scurvy?
A) Vitamin A
B) Vitamin B12
C) Vitamin C
D) Vitamin D

Answer:""",

    """Which organ produces insulin?
A) Liver
B) Pancreas
C) Kidney
D) Spleen

Answer:"""
]

batch_results = lime.analyze_batch(prompts)

print("\n" + "=" * 60)
print("[Batch LIME Results]")
print("=" * 60)
for i, (p, r) in enumerate(zip(prompts, batch_results), 1):
    question = p.strip().splitlines()[0]
    print(f"\nQ{i}: {question}")
    print(f"  Prediction : {r['prediction']}")
    print(f"  P(target)  : {r['target_probability']:.4f}")
    print(f"  R²         : {r['r_squared']:.4f}")
    if r['top_words']:  # skip if no attributions were returned for this prompt
        top_word, top_score = r['top_words'][0]
        print(f"  Top word   : '{top_word}' ({top_score:.4f})")

## 9. Cleanup

In [None]:
import gc

del wrapper, lime
gc.collect()
torch.cuda.empty_cache()

print("✓ Memory cleaned up")
print("\n" + "=" * 60)
print("Medical LIME Demo — Complete!")
print("=" * 60)