## 1. Setup and Installation

In [None]:
# Install required packages
!pip install torch transformers accelerate hf_xet -q

print("‚úì Packages installed successfully!")

In [None]:
# Mount Google Drive at /content/drive.
# The project folder used by the following cells is expected under
# /content/drive/MyDrive, so this cell has to run before the path checks.
# NOTE(review): `google.colab` is presumably only importable inside Colab -
# confirm before running this notebook elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Register the project folder on sys.path and report what is inside it.
import os
import sys

project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"

# Only append once so re-running the cell does not duplicate the entry.
if sys.path.count(project_path) == 0:
    sys.path.append(project_path)

print("‚úì Project path: {}".format(project_path))
print("‚úì Path exists: {}".format(os.path.exists(project_path)))

if not os.path.exists(project_path):
    # Nothing to list: tell the user how to fix the Drive layout.
    print("\n‚ö†Ô∏è  ERROR: Project path does not exist!")
    print("   Please create the folder: {}".format(project_path))
    print("   Or update the path above to match your Google Drive structure.")
else:
    contents = os.listdir(project_path)
    print("‚úì Contents: {}".format(contents))

    # The wrapper module must sit directly inside the project folder.
    if "medical_llm_wrapper.py" not in contents:
        print("\n‚ö†Ô∏è  WARNING: medical_llm_wrapper.py NOT FOUND!")
        print("   Please upload medical_llm_wrapper.py to:")
        print("   {}".format(project_path))
    else:
        print("‚úì medical_llm_wrapper.py found!")

In [None]:
# Import the wrapper
import warnings
warnings.filterwarnings('once')

import sys
import os

# Re-declare the project path so this cell also works when it is run on its
# own (e.g. cells executed out of order after a runtime restart).
project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"
if project_path not in sys.path:
    sys.path.insert(0, project_path)
    print(f"‚úì Added {project_path} to Python path")

# Fail early with a precise message if the wrapper file is missing.
wrapper_file = os.path.join(project_path, "medical_llm_wrapper.py")
if os.path.exists(wrapper_file):
    print(f"‚úì Found: {wrapper_file}")
else:
    raise FileNotFoundError(f"medical_llm_wrapper.py not found at {wrapper_file}")

# Now import.
# ImportError is caught (not only its subclass ModuleNotFoundError) so the
# troubleshooting hints are also shown when the module is found but does not
# export the expected names, e.g. an outdated copy of the wrapper file.
# The exception is always re-raised, so the cell still fails visibly.
try:
    from medical_llm_wrapper import MedicalLLMWrapper, load_medical_llm
    import torch
    print("‚úì Medical LLM Wrapper imported successfully!")
except ImportError as e:
    print(f"‚ùå Error: {e}")
    print("\nüìã Troubleshooting steps:")
    print("1. RESTART THE RUNTIME: Runtime ‚Üí Restart runtime")
    print("2. Re-run all cells from the beginning")
    print("3. Make sure you run the 'Add project path' cell BEFORE this cell")
    print("\nüí° Current sys.path:")
    # Only the first few entries are shown to keep the output short.
    for p in sys.path[:5]:
        print(f"   {p}")
    raise

## 2. Test 1: MedGemma-4B-IT - Multiple Choice Question

Test the wrapper with MedGemma on a clinical diagnosis MCQ.
- **Automatic fp32 conversion** (MedGemma requires float32)
- **Constrained generation** (forces A/B/C/D)
- **Answer + Rationale mode**

In [None]:
print("=" * 80, "TEST 1: MedGemma-4B-IT - MCQ with Answer + Rationale", "=" * 80, sep="\n")

# torch_dtype is deliberately omitted. Per the original note, the wrapper is
# expected to pick fp32 for MedGemma on its own; that logic lives in
# medical_llm_wrapper.py and is not visible from this notebook.
medgemma = load_medical_llm("google/medgemma-4b-it", device="cuda")

# Dump the model metadata; only the parameter count gets thousands separators.
print("\n[Model Information]")
info = medgemma.get_model_info()
for field, field_value in info.items():
    if key_is_count := (field == "num_parameters"):
        print(f"  {field}: {field_value:,}")
    else:
        print(f"  {field}: {field_value}")

# Multiple-choice task, answer followed by a rationale.
medgemma.set_task("mcq")
medgemma.set_mode("answer_rationale")

# Clinical MCQ; reused by the confidence test in the next cell.
mcq_prompt = """A 55-year-old patient presents with persistent cough, hemoptysis, and unintentional weight loss.
Chest X-ray shows a mass in the right upper lobe. What is the most likely diagnosis?

A) Tuberculosis
B) Lung cancer
C) Pneumonia
D) Pulmonary embolism

Answer:"""

print("\n" + "=" * 80, "[Generating Response...]", "=" * 80, sep="\n")

response = medgemma.generate(mcq_prompt)

print("\n[RESULT]", response, sep="\n")
print("\n" + "=" * 80, "‚úì Test 1 Complete!", "=" * 80, sep="\n")

## 3. Test 2: MedGemma - Answer Only Mode with Confidence

Test confidence extraction for MCQ answers.
- **Answer-only mode** (no rationale)
- **Confidence scores** for each option
- **Probability distribution** over A/B/C/D

In [None]:
print("\n" + "=" * 80)
print("TEST 2: MedGemma - MCQ with Confidence Scores")
print("=" * 80)

# Switch to answer-only mode (reuses `medgemma` and `mcq_prompt` from Test 1).
medgemma.set_mode("answer_only")

# Generate with confidence; the wrapper exposes the metrics via last_* attributes.
response = medgemma.generate(mcq_prompt)

print("\n[RESULT]")
print(response)
print("\n[CONFIDENCE METRICS]")
print(f"  Selected Answer: {medgemma.last_answer}")

# Handle missing / NaN confidence values
import math
confidence = medgemma.last_confidence
if confidence is not None and not math.isnan(confidence):
    print(f"  Confidence: {confidence:.4f}")
    print("\n  Option Probabilities:")
    # The probability table may be unset even when a confidence value exists;
    # fall back to an empty mapping instead of failing on `.items()`.
    option_probs = medgemma.last_option_probs or {}
    for option, prob in sorted(option_probs.items()):
        if not math.isnan(prob):
            # Text bar scaled so that probability 1.0 is 50 characters wide.
            bar = "‚ñà" * int(prob * 50)
            print(f"    {option}: {prob:.4f} {bar}")
        else:
            print(f"    {option}: NaN (computation error)")
else:
    print("  Confidence: NaN (computation error)")
    print("\n  Note: Confidence computation failed. This may happen with fp32 models.")
    print(f"  The answer '{medgemma.last_answer}' was still generated successfully.")

print("\n" + "=" * 80)
print("‚úì Test 2 Complete!")
print("=" * 80)

# Clean up memory: drop the model reference before emptying the CUDA cache.
del medgemma
torch.cuda.empty_cache()

## 4. Test 3: Apollo-2B - Yes/No Question

Test the wrapper with a different model (Apollo) on a binary Yes/No question.
- **Different architecture** (Llama-based vs Gemma-based)
- **Native fp16** (no conversion needed)
- **Yes/No task** (A=Yes, B=No)

In [None]:
print("=" * 80, "TEST 3: Apollo-2B - Yes/No Question with Answer + Rationale", "=" * 80, sep="\n")

# Apollo is loaded with an explicit half-precision dtype.
apollo = load_medical_llm(
    "FreedomIntelligence/Apollo-2B", device="cuda", torch_dtype=torch.float16
)

# Print the model metadata; the parameter count is rendered with separators.
print("\n[Model Information]")
info = apollo.get_model_info()
for field, field_value in info.items():
    rendered = f"{field_value:,}" if field == "num_parameters" else f"{field_value}"
    print(f"  {field}: {rendered}")

# Binary task, answer followed by a rationale.
apollo.set_task("yn")
apollo.set_mode("answer_rationale")

# Statement to be judged, with the two options spelled out.
yn_prompt = """Metformin is contraindicated in patients with severe renal impairment.

A) Yes
B) No

Answer:"""

print("\n" + "=" * 80, "[Generating Response...]", "=" * 80, sep="\n")

response = apollo.generate(yn_prompt)

print("\n[RESULT]", response, sep="\n")
print("\n" + "=" * 80, "‚úì Test 3 Complete!", "=" * 80, sep="\n")

## 5. Test 4: Apollo - Free-Response Generation

Test unconstrained generation for open-ended medical questions.
- **Free-response task** (no answer constraints)
- **Longer generation** (up to 200 tokens)
- **Medical knowledge test**

In [None]:
print("\n" + "=" * 80, "TEST 4: Apollo-2B - Free-Response Medical Question", "=" * 80, sep="\n")

# Reuse the `apollo` wrapper loaded in Test 3, now in free-response mode.
apollo.set_task("free")

# Open-ended question with no answer options.
free_prompt = "What are the first-line treatments for hypertension in a 60-year-old patient?"

print("\n[Generating Response...]")

response = apollo.generate(free_prompt)

print("\n[RESULT]")
print("Question: {}".format(free_prompt))
print("Answer: {}".format(response))

print("\n" + "=" * 80, "‚úì Test 4 Complete!", "=" * 80, sep="\n")

# Release the model: drop the reference, then empty the CUDA cache.
del apollo
torch.cuda.empty_cache()

In [None]:
print("=" * 80, "TEST 6: BioMistral-7B - Yes/No Question with Rationale", "=" * 80, sep="\n")

# BioMistral-7B is loaded with an explicit half-precision dtype.
biomistral = load_medical_llm(
    "BioMistral/BioMistral-7B", device="cuda", torch_dtype=torch.float16
)

# Print the model metadata; the parameter count is rendered with separators.
print("\n[Model Information]")
info = biomistral.get_model_info()
for field, field_value in info.items():
    if field == "num_parameters":
        print(f"  {field}: {field_value:,}")
        continue
    print(f"  {field}: {field_value}")

# Binary task, answer followed by a rationale.
biomistral.set_task("yn")
biomistral.set_mode("answer_rationale")

# Statement to be judged; the options are labelled True/False in this prompt.
biomistral_prompt = """Corticosteroids are first-line treatment for acute bacterial meningitis.

A) True
B) False

Answer:"""

print("\n" + "=" * 80, "[Generating Response...]", "=" * 80, sep="\n")

response = biomistral.generate(biomistral_prompt)

print("\n[RESULT]", response, sep="\n")
print("\n" + "=" * 80, "‚úì Test 6 Complete!", "=" * 80, sep="\n")

# Release the model: drop the reference, then empty the CUDA cache.
del biomistral
torch.cuda.empty_cache()

## 6. Test 6: BioMistral-7B - Yes/No Medical Question

Test the wrapper with the BioMistral-7B model (the code cell for this test is the one directly above this section).
- **7B parameter model** (largest model in this demo)
- **Mistral architecture** fine-tuned on biomedical data
- **Yes/No task** with answer + rationale

In [None]:
print("=" * 80)
print("TEST 5: BioMedLM - Medical MCQ")
print("=" * 80)

# Load BioMedLM with an explicit half-precision dtype.
biomedlm = load_medical_llm(
    "stanford-crfm/BioMedLM",
    device="cuda",
    torch_dtype=torch.float16
)

# Display model info; only the parameter count gets thousands separators.
print("\n[Model Information]")
info = biomedlm.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

# Multiple-choice task, answer letter only (confidence metrics are read below).
biomedlm.set_task("mcq")
biomedlm.set_mode("answer_only")

# Medical MCQ prompt
biomedlm_prompt = """Which class of antibiotics inhibits bacterial cell wall synthesis?
A) Fluoroquinolones
B) Tetracyclines
C) Beta-lactams
D) Aminoglycosides

Answer:"""

print("\n" + "=" * 80)
print("[Generating Response...]")
print("=" * 80)

response = biomedlm.generate(biomedlm_prompt)

print("\n[RESULT]")
print(response)
print("\n[CONFIDENCE METRICS]")
print(f"  Selected Answer: {biomedlm.last_answer}")

# Handle missing / NaN confidence values
import math
confidence = biomedlm.last_confidence
if confidence is not None and not math.isnan(confidence):
    print(f"  Confidence: {confidence:.4f}")
    print("\n  Option Probabilities:")
    # The probability table may be unset even when a confidence value exists;
    # fall back to an empty mapping instead of failing on `.items()`.
    option_probs = biomedlm.last_option_probs or {}
    for option, prob in sorted(option_probs.items()):
        if not math.isnan(prob):
            # Text bar scaled so that probability 1.0 is 50 characters wide.
            bar = "‚ñà" * int(prob * 50)
            print(f"    {option}: {prob:.4f} {bar}")
        else:
            print(f"    {option}: NaN (computation error)")
else:
    print("  Confidence: NaN (computation error)")
    print("\n  Note: Confidence computation failed for this model.")
    print(f"  The answer '{biomedlm.last_answer}' was still generated successfully.")

print("\n" + "=" * 80)
print("‚úì Test 5 Complete!")
print("=" * 80)

# Clean up memory: drop the model reference before emptying the CUDA cache.
del biomedlm
torch.cuda.empty_cache()

## 5. Test 5: BioMedLM - Medical MCQ

Test the wrapper with Stanford's BioMedLM model (the code cell for this test is the one directly above this section).
- **2.7B parameter model** trained on biomedical literature
- **GPT-2 architecture** (different from Gemma/Llama)
- **MCQ task** with confidence scoring

## 6. Test 5: Batch Processing - Multiple Prompts

Test batch generation functionality with multiple prompts.
- **Batch processing** of multiple questions
- **Progress tracking**
- **Efficient memory management**

In [None]:
print("=" * 80, "TEST 5: MedGemma - Batch Processing Multiple MCQs", "=" * 80, sep="\n")

# MedGemma was deleted after Test 2, so it is loaded again for the batch run.
medgemma = load_medical_llm("google/medgemma-4b-it", device="cuda")

# Multiple-choice task, answer letter only.
medgemma.set_task("mcq")
medgemma.set_mode("answer_only")

# Three independent MCQs; the first line of each prompt is the question text.
batch_prompts = [
    """Which vitamin deficiency causes scurvy?
A) Vitamin A
B) Vitamin B12
C) Vitamin C
D) Vitamin D

Answer:""",

    """What is the normal range for fasting blood glucose?
A) 50-70 mg/dL
B) 70-100 mg/dL
C) 100-125 mg/dL
D) 125-150 mg/dL

Answer:""",

    """Which organ produces insulin?
A) Liver
B) Pancreas
C) Kidney
D) Spleen

Answer:"""
]

print("\n[Batch Generation Starting...]", "=" * 80, sep="\n")

results = medgemma.batch_generate(batch_prompts, show_progress=True)

print("\n[RESULTS]")
for position, pair in enumerate(zip(batch_prompts, results), start=1):
    prompt_text, answer = pair
    # Everything before the first newline is the question itself.
    first_line, _, _ = prompt_text.partition("\n")
    print("\n{}. {}".format(position, first_line))
    print("   {}".format(answer))

print("\n" + "=" * 80, "‚úì Test 5 Complete!", "=" * 80, sep="\n")

# Release the model: drop the reference, then empty the CUDA cache.
del medgemma
torch.cuda.empty_cache()

## 7. Test 6: Cross-Model Comparison - Same Question

Compare how different models answer the same medical question.
- **MedGemma vs Apollo** on identical prompt
- **Confidence comparison**
- **Model-agnostic API** demonstration

In [None]:
# Imported once at the top of the cell rather than on every loop iteration.
import math

print("=" * 80)
print("TEST 6: Cross-Model Comparison - MedGemma vs Apollo")
print("=" * 80)

# Comparison prompt: the same statement is put to every model.
comparison_prompt = """Aspirin is used for primary prevention of cardiovascular disease in high-risk patients.

A) True
B) False

Answer:"""

# Show only the question part (everything before the "Answer:" cue).
print(f"\nQuestion: {comparison_prompt.split('Answer:')[0].strip()}")
print("\n" + "-" * 80)

# (hub model id, human-readable name used as key in the results)
models_to_test = [
    ("google/medgemma-4b-it", "MedGemma-4B-IT"),
    ("FreedomIntelligence/Apollo-2B", "Apollo-2B")
]

results_comparison = {}

for model_name, display_name in models_to_test:
    print(f"\n[Loading {display_name}...]")

    wrapper = load_medical_llm(model_name, device="cuda")
    try:
        wrapper.set_task("yn")
        wrapper.set_mode("answer_only")

        # The returned text is not needed here; the metrics are read from the
        # wrapper's last_* attributes after generation.
        wrapper.generate(comparison_prompt)

        results_comparison[display_name] = {
            "answer": wrapper.last_answer,
            "confidence": wrapper.last_confidence,
            "option_probs": wrapper.last_option_probs
        }

        print(f"‚úì {display_name} complete")
    finally:
        # Always release the model, even if generation failed, so the next
        # model in the list (or a re-run) does not find the GPU still occupied.
        del wrapper
        torch.cuda.empty_cache()

# Display comparison results
print("\n" + "=" * 80)
print("[COMPARISON RESULTS]")
print("=" * 80)

# Distinct loop names: `model_name` and the batch-test `results` are not rebound.
for shown_name, summary in results_comparison.items():
    print(f"\n{shown_name}:")
    print(f"  Answer: {summary['answer']}")

    # Handle missing / NaN confidence
    conf = summary['confidence']
    if conf is not None and not math.isnan(conf):
        print(f"  Confidence: {conf:.4f}")
        print("  Probabilities:")
        # The probability table may be unset; treat that as "no options".
        for option, prob in sorted((summary['option_probs'] or {}).items()):
            if not math.isnan(prob):
                # Text bar scaled so that probability 1.0 is 40 characters wide.
                bar = "‚ñà" * int(prob * 40)
                print(f"    {option}: {prob:.4f} {bar}")
            else:
                print(f"    {option}: NaN")
    else:
        print("  Confidence: NaN (computation error)")
        print("  Note: Answer generated successfully, but confidence calculation failed.")

print("\n" + "=" * 80)
print("‚úì Test 6 Complete!")
print("=" * 80)

## 8. Summary of Results

### ✅ Successfully Tested:

1. **MedGemma-4B-IT**
   - MCQ with answer + rationale
   - Confidence scoring
   - Automatic fp32 conversion
   - Batch processing

2. **Apollo-2B**
   - Yes/No questions
   - Free-response generation
   - Native fp16 operation
   - Cross-model comparison

3. **Wrapper Features**
   - Model-agnostic loading
   - Task type switching (yn/mcq/free)
   - Mode switching (answer_rationale/answer_only)
   - Batch generation with progress
   - Confidence extraction
   - Model metadata access

### Key Observations:

- **Same API works for all models** - true model-agnostic design ✅
- **Automatic dtype handling** - MedGemma → fp32, others → native dtype ✅
- **Robust tokenizer handling** - works across different tokenizer implementations ✅
- **Memory efficient** - proper cleanup between model loads ✅

### Potential Use Cases:

1. **Medical Question Answering**: Structured MCQ evaluation
2. **Clinical Decision Support**: Confidence-weighted recommendations
3. **Model Benchmarking**: Compare multiple models on same tasks
4. **Educational Tools**: Generate explanations with rationales
5. **Research**: Model interpretability and comparison studies

## 9. Optional: Additional Testing

You can add more models or custom tests here:

In [None]:
# Example: Test with your own prompt
def custom_test(model_name, prompt, task_type="mcq", mode="answer_rationale"):
    """
    Run a custom test with any model and prompt.

    Loads the model through ``load_medical_llm`` on ``"cuda"``, configures the
    task and mode, generates a response for ``prompt`` and prints it. In
    ``"answer_only"`` mode the wrapper's confidence and per-option
    probabilities are printed as well.

    Args:
        model_name: Model identifier passed to ``load_medical_llm``.
        prompt: Prompt text passed to ``wrapper.generate``.
        task_type: Value passed to ``wrapper.set_task`` (default ``"mcq"``).
        mode: Value passed to ``wrapper.set_mode``
            (default ``"answer_rationale"``).

    Returns:
        None. All results are printed.

    Raises:
        Whatever the wrapper raises while configuring or generating; the
        wrapper reference is dropped and the CUDA cache emptied before the
        exception propagates.
    """
    # Local import so the function does not rely on another cell having
    # imported math first.
    import math

    print(f"\n{'=' * 80}")
    print(f"CUSTOM TEST: {model_name}")
    print('=' * 80)

    wrapper = load_medical_llm(model_name, device="cuda")
    try:
        wrapper.set_task(task_type)
        wrapper.set_mode(mode)

        print(f"\nPrompt:\n{prompt}")
        print("\n[Generating...]")

        response = wrapper.generate(prompt)

        print("\n[Result]")
        print(response)

        if mode == "answer_only":
            # Confidence may be None or NaN when the computation failed;
            # formatting None with ':.4f' would raise TypeError.
            confidence = wrapper.last_confidence
            if confidence is not None and not math.isnan(confidence):
                print(f"\nConfidence: {confidence:.4f}")
            else:
                print("\nConfidence: NaN (computation error)")
            if wrapper.last_option_probs:
                print("Option Probabilities:")
                for opt, prob in sorted(wrapper.last_option_probs.items()):
                    print(f"  {opt}: {prob:.4f}")
    finally:
        # Release the model even when generation fails, so a retry does not
        # find the GPU still occupied by this wrapper.
        del wrapper
        torch.cuda.empty_cache()

    print('=' * 80)

# Uncomment to run custom test:
# custom_test(
#     "google/medgemma-4b-it",
#     "Your medical question here...\nA) Option A\nB) Option B\nAnswer:",
#     task_type="mcq",
#     mode="answer_rationale"
# )

## 10. Cleanup and Final Notes

### Memory Management:
- Each model is properly unloaded after testing
- `torch.cuda.empty_cache()` clears GPU memory
- Batch operations handle memory efficiently

### Model Requirements:
- **MedGemma**: Gated model, requires HF token, needs ~12GB GPU (fp32)
- **Apollo**: Public model, no token needed, needs ~4GB GPU (fp16)

### Next Steps:
- Test with additional models beyond those exercised above (MedGemma, Apollo, BioMistral, BioMedLM)
- Implement custom LogitsProcessors for domain-specific constraints
- Add few-shot prompting examples
- Export results for analysis

In [None]:
# Final cleanup: collect unreachable Python objects first, then empty the
# CUDA cache so memory held by already-deleted models can be released.
import gc

gc.collect()
torch.cuda.empty_cache()

print("‚úì All tests complete!", "‚úì Memory cleaned up", sep="\n")
print("\n" + "=" * 80, "Medical LLM Wrapper Demo - SUCCESS!", "=" * 80, sep="\n")