## 1. Setup

In [None]:
# Install required packages
!pip install torch transformers accelerate tqdm -q

print("✓ Packages installed successfully!")

In [None]:
# Mount Google Drive so the project folder stored there becomes visible
# on the Colab filesystem under the mount point below.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# Add project path so the project's own modules (medical_llm_wrapper,
# medical_integrated_gradients) can be imported in the next cell.
import sys
import os

project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"

# Prepend so the project's modules win over any same-named installed packages.
if project_path not in sys.path:
    sys.path.insert(0, project_path)

print(f"✓ Project path: {project_path}")
print(f"✓ Path exists: {os.path.exists(project_path)}")

# isdir (not just exists): os.listdir would raise if the path were a regular file.
if os.path.isdir(project_path):
    contents = os.listdir(project_path)
    print(f"✓ Contents: {contents}")

    # The import cell below needs both of these modules.
    required_files = ['medical_llm_wrapper.py', 'medical_integrated_gradients.py']
    for file in required_files:
        if file in contents:
            print(f"✓ {file} found!")
        else:
            print(f"⚠️  WARNING: {file} NOT FOUND!")
else:
    # Report the problem here rather than letting it surface later as an ImportError.
    print("⚠️  WARNING: project path is not a directory - check that Drive is "
          "mounted and that project_path points at the project folder.")

In [None]:
# Import modules: stdlib first, then third-party, then the project's helpers.
import warnings

# Show each distinct warning only the first time it occurs. This is set before
# the heavier imports so warnings raised while importing are filtered as well.
warnings.filterwarnings('once')

import numpy as np
import torch

from medical_llm_wrapper import load_medical_llm
from medical_integrated_gradients import (
    MedicalIntegratedGradients,
    explain_medical_prediction,
    visualize_attributions,
)

print("✓ All modules imported successfully!")

## 2. Example 1: Basic Usage - MedGemma MCQ

Let's explain why MedGemma chose a particular answer for a clinical diagnosis question.

In [None]:
banner = "=" * 80
print(banner)
print("EXAMPLE 1: MedGemma - Clinical Diagnosis MCQ")
print(banner)

# Load MedGemma onto the GPU.
medgemma = load_medical_llm("google/medgemma-4b-it", device="cuda")

# Configure the wrapper for multiple-choice questions in "answer_only" mode
# (presumably the wrapper then returns just the option letter - confirm in
# medical_llm_wrapper).
medgemma.set_task("mcq")
medgemma.set_mode("answer_only")

print("\n✓ Model loaded successfully!")

In [None]:
# Clinical-diagnosis MCQ used throughout Example 1. It ends with "Answer:" so
# the model is prompted to continue with an option letter.
prompt = """A 65-year-old man presents with sudden onset chest pain, dyspnea, and diaphoresis. 
ECG shows ST-segment elevation in leads V1-V4. What is the most likely diagnosis?

A) Unstable angina
B) Anterior myocardial infarction
C) Pulmonary embolism
D) Aortic dissection

Answer:"""

# Run the model once; the wrapper records the chosen answer and its confidence.
print("\n[Getting Model Prediction...]")
response = medgemma.generate(prompt)
print(f"\nModel Answer: {medgemma.last_answer}")

confidence = medgemma.last_confidence
if np.isnan(confidence):
    print("Confidence: NaN")
else:
    print(f"Confidence: {confidence:.4f}")

print("\n" + "=" * 80)

In [None]:
# Explain the prediction with Integrated Gradients.
print("\n[Computing Integrated Gradients...]")
print("This will take ~30 seconds...\n")

# Explain the answer the model actually gave (recorded by generate() in the
# previous cell). Fall back to "B" (the clinically correct option) when no
# valid option letter was recorded.
answer_options = ("A", "B", "C", "D")
explained_answer = medgemma.last_answer if medgemma.last_answer in answer_options else "B"
print(f"Explaining answer: {explained_answer}\n")

result = explain_medical_prediction(
    wrapper=medgemma,
    prompt=prompt,
    target_class=explained_answer,
    n_steps=50,
    visualize=True,
)

print("\n[Key Insights]")
print(f"  Target Probability: {result['target_probability']:.4f}")
print(f"  Convergence Delta: {result['convergence_delta']:.6f} (lower is better)")
print("\n  Top 5 Most Important Tokens:")

# Strip tokenizer word-boundary markers ('▁' SentencePiece, 'Ġ' byte-level BPE).
# Drop tokens that become empty so every listed slot shows real text.
token_scores = [
    (token.replace('▁', ' ').replace('Ġ', ' ').strip(), score)
    for token, score in zip(result['tokens'], result['attributions'])
]
token_scores = [(token, score) for token, score in token_scores if token]

# Rank by magnitude: a strongly negative token matters as much as a strongly positive one.
token_scores.sort(key=lambda x: abs(x[1]), reverse=True)

for i, (clean_token, score) in enumerate(token_scores[:5], 1):
    print(f"    {i}. '{clean_token}': {score:.4f}")

## 3. Example 2: Yes/No Question - Apollo

Let's see how attributions work for binary medical questions.

In [None]:
banner = "=" * 80
print(banner)
print("EXAMPLE 2: Apollo - Yes/No Medical Question")
print(banner)

# Load Apollo onto the GPU in half precision (float16).
apollo = load_medical_llm(
    "FreedomIntelligence/Apollo-2B", device="cuda", torch_dtype=torch.float16
)

# Configure the wrapper for yes/no questions in "answer_only" mode.
apollo.set_task("yn")
apollo.set_mode("answer_only")

print("\n✓ Apollo loaded successfully!")

In [None]:
# True/False statement posed as a two-option question: A) True, B) False.
yn_prompt = """Metformin is contraindicated in patients with severe renal impairment 
(eGFR < 30 mL/min/1.73m²) due to increased risk of lactic acidosis.

A) True
B) False

Answer:"""

# Run Apollo once; the wrapper records the chosen letter and its confidence.
print("\n[Getting Prediction...]")
response = apollo.generate(yn_prompt)

# Translate the option letter back into its meaning for display.
if apollo.last_answer == 'A':
    answer_label = 'True'
else:
    answer_label = 'False'
print(f"\nModel Answer: {apollo.last_answer} ({answer_label})")

confidence = apollo.last_confidence
if np.isnan(confidence):
    print("Confidence: NaN")
else:
    print(f"Confidence: {confidence:.4f}")

In [None]:
# Explain the "True" (option A) answer with Integrated Gradients.
print("\n[Computing Integrated Gradients...]\n")

result = explain_medical_prediction(
    wrapper=apollo,
    prompt=yn_prompt,
    target_class="A",  # Explain "True" answer
    n_steps=50,
    visualize=True,
)

# target_probability only says which way the model leans. Which tokens drive
# that lean comes from the sign of the attributions, so report the two separately.
# NOTE(review): the 0.5 threshold assumes the probability is normalised over
# the two options - confirm in medical_integrated_gradients.
p_true = result['target_probability']
leaning = 'True' if p_true > 0.5 else 'False'

# Strip tokenizer word-boundary markers and drop tokens that become empty.
scored = [
    (token.replace('▁', ' ').replace('Ġ', ' ').strip(), score)
    for token, score in zip(result['tokens'], result['attributions'])
]
scored = [(token, score) for token, score in scored if token]

# Positive attributions push toward "True" (the explained class); negative ones push away.
supporting = sorted((ts for ts in scored if ts[1] > 0), key=lambda ts: ts[1], reverse=True)[:3]
opposing = sorted((ts for ts in scored if ts[1] < 0), key=lambda ts: ts[1])[:3]

print("\n[Analysis]")
print(f"  The model assigned {p_true:.1%} probability to 'True' (leans '{leaning}')")
print("  Tokens pushing toward 'True': "
      + (", ".join(f"'{t}' ({s:+.4f})" for t, s in supporting) or "none"))
print("  Tokens pushing away from 'True': "
      + (", ".join(f"'{t}' ({s:+.4f})" for t, s in opposing) or "none"))

## 4. Example 3: Comparing Explanations for Different Answers

Let's compute attributions toward each answer option, to see which tokens push the model toward (or away from) each one.

In [None]:
banner = "=" * 80
print(banner)
print("EXAMPLE 3: Comparing Attributions Across Answer Choices")
print(banner)

# Diabetes-diagnosis MCQ answered by MedGemma (still configured for "mcq").
comparison_prompt = """A patient presents with polyuria, polydipsia, and weight loss. 
Blood glucose is 350 mg/dL. What is the diagnosis?

A) Type 1 diabetes
B) Type 2 diabetes  
C) Diabetes insipidus
D) Hyperthyroidism

Answer:"""

# Record the model's own answer before comparing attributions for every option.
print("\n[Model Prediction]")
response = medgemma.generate(comparison_prompt)
print(f"Answer: {medgemma.last_answer}")

In [None]:
# Compute attributions toward each answer option in turn, to compare which
# tokens push the model toward each one.
ig = MedicalIntegratedGradients(medgemma, n_steps=50, verbose=False)

print("\n[Computing attributions for each answer choice...]\n")

for choice in ['A', 'B', 'C', 'D']:
    print(f"\nExplaining Answer {choice}:")
    result = ig.attribute(comparison_prompt, choice, return_convergence_delta=True)

    # Clean the tokenizer markers and drop empty tokens *before* taking the top 3.
    # This way up to three real tokens are shown, numbered without gaps.
    token_scores = [
        (token.replace('▁', ' ').replace('Ġ', ' ').strip(), score)
        for token, score in zip(result['tokens'], result['attributions'])
    ]
    token_scores = [(token, score) for token, score in token_scores if token]
    token_scores.sort(key=lambda x: x[1], reverse=True)  # Most positive attribution first

    print(f"  Probability: {result['target_probability']:.4f}")
    # The delta was requested above; show it so each explanation's quality is visible.
    delta = result.get('convergence_delta')
    if delta is not None:
        print(f"  Convergence Delta: {delta:.6f} (lower is better)")
    print("  Top 3 supporting tokens:")
    for i, (clean_token, score) in enumerate(token_scores[:3], 1):
        print(f"    {i}. '{clean_token}': {score:.4f}")

print("\n" + "=" * 80)

## 5. Example 4: Batch Processing Multiple Questions

Explain several predictions with a single `attribute_batch` call.

In [None]:
print("=" * 80)
print("EXAMPLE 4: Batch Explanation")
print("=" * 80)

# Define multiple questions: each pairs a True/False prompt with the option
# letter to explain ('A' = True).
questions = [
    {
        'prompt': """Aspirin works by irreversibly inhibiting cyclooxygenase enzymes.
A) True
B) False
Answer:""",
        'target': 'A'
    },
    {
        'prompt': """Beta-blockers are contraindicated in acute asthma exacerbation.
A) True  
B) False
Answer:""",
        'target': 'A'
    },
    {
        'prompt': """ACE inhibitors can cause hyperkalemia.
A) True
B) False
Answer:""",
        'target': 'A'
    }
]

# Ensure Apollo is in yes/no mode. Example 2 already set this; repeating it
# keeps this cell correct if the task was changed in between.
apollo.set_task("yn")

# Batch attribution (fewer integration steps than the single-question examples).
ig_apollo = MedicalIntegratedGradients(apollo, n_steps=30, verbose=False)

print("\n[Processing batch...]\n")

prompts = [q['prompt'] for q in questions]
targets = [q['target'] for q in questions]

results = ig_apollo.attribute_batch(prompts, targets)

# Display summary
print("\n[Summary]\n")
max_len = 60
for i, (q, result) in enumerate(zip(questions, results), 1):
    first_line = q['prompt'].split('\n')[0]
    # Only add the ellipsis when the line was actually cut.
    if len(first_line) > max_len:
        first_line = first_line[:max_len] + "..."
    print(f"{i}. {first_line}")
    print(f"   Prediction: {result['prediction']} | Target Prob: {result['target_probability']:.3f}")
    print()

print("=" * 80)

## 6. Custom Visualization

Create your own visualizations of token attributions.

In [None]:
print("=" * 80)
print("EXAMPLE 5: Custom Visualization")
print("=" * 80)

import matplotlib.pyplot as plt

# Compute fresh attributions toward option A for a new pneumonia question.
custom_prompt = """Patient has fever, productive cough, and consolidation on chest X-ray.
A) Pneumonia
B) Tuberculosis
C) Lung cancer
D) Heart failure
Answer:"""

medgemma.set_task("mcq")
ig_medgemma = MedicalIntegratedGradients(medgemma, n_steps=50, verbose=True)

result = ig_medgemma.attribute(custom_prompt, "A")

# Strip tokenizer word-boundary markers for readable axis labels.
tokens = [t.replace('▁', '').replace('Ġ', '') for t in result['tokens']]
# Convert to a float array: the fancy indexing below (attributions[top_indices])
# only works on NumPy arrays, not on plain Python lists.
attributions = np.asarray(result['attributions'], dtype=float)

# Indices of the (up to) 15 largest-magnitude attributions, in ascending order.
# barh draws them bottom-up, so the largest bar appears at the top.
top_indices = np.argsort(np.abs(attributions))[-15:]
top_scores = attributions[top_indices]

plt.figure(figsize=(12, 6))
# Colour by sign: green pushes toward the answer, red pushes away (or is zero).
plt.barh(
    range(len(top_indices)),
    top_scores,
    color=['tab:green' if s > 0 else 'tab:red' for s in top_scores],
)
plt.yticks(range(len(top_indices)), [tokens[i] for i in top_indices])
plt.xlabel('Attribution Score')
plt.title(f'Top {len(top_indices)} Token Attributions for Answer {result["target_class"]}')
plt.axvline(x=0, color='black', linestyle='--', linewidth=0.5)
plt.tight_layout()
plt.show()

print(f"\nPrediction: {result['prediction']}")
print(f"Target Probability: {result['target_probability']:.4f}")

## 7. Cleanup

In [None]:
# Clean up memory. Drop every notebook global that may keep a model alive:
# the wrappers, plus the IG explainers built from them (which presumably hold
# a reference to their wrapper) and the attribution results.
# gc.collect() runs *before* empty_cache() so the freed tensors' cached GPU
# blocks can actually be returned.
import gc

for _name in ('ig', 'ig_apollo', 'ig_medgemma', 'result', 'results', 'medgemma', 'apollo'):
    globals().pop(_name, None)  # tolerate names from cells that were never run
del _name

gc.collect()
torch.cuda.empty_cache()

print("✓ Memory cleaned up")
print("\n" + "=" * 80)
print("Medical Integrated Gradients Demo - COMPLETE!")
print("=" * 80)

## Summary

### What we learned:

1. **Basic Usage**: Load wrapper → Create IG explainer → Get attributions
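2. **Interpretation**: Positive attributions support the answer; negative ones oppose it
3. **Convergence**: A smaller convergence delta means the attributions more nearly sum to the change in model output between the baseline and the input (the completeness property), i.e. the path integral was approximated more accurately
4. **Comparison**: Explain why model chose one answer over others
5. **Batch Processing**: Explain multiple predictions with a single call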

### Key Insights:

- **Clinical terms** (symptoms, test results) typically have high attributions
- **Question structure words** ("What is", "diagnosis") have lower importance
- **Answer options** themselves can have significant attributions
- **Model differences**: MedGemma vs Apollo may focus on different tokens

### Next Steps:

- Compare IG with other XAI methods (TokenSHAP, LIME)
- Analyze attribution patterns across medical domains
- Use attributions to improve prompt engineering
- Build trust in model predictions through explanations