## 1. Setup and Installation

In [1]:
# Install required packages
!pip install torch transformers accelerate hf_xet -q

print("✓ Packages installed successfully!")

✓ Packages installed successfully!


In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Add project path to Python path
import sys
import os

project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"

if project_path not in sys.path:
    sys.path.append(project_path)

print(f"✓ Project path: {project_path}")
print(f"✓ Path exists: {os.path.exists(project_path)}")

if os.path.exists(project_path):
    contents = os.listdir(project_path)
    print(f"✓ Contents: {contents}")

    # Check for the wrapper file
    if "medical_llm_wrapper.py" in contents:
        print("✓ medical_llm_wrapper.py found!")
    else:
        print("\n⚠️  WARNING: medical_llm_wrapper.py NOT FOUND!")
        print("   Please upload medical_llm_wrapper.py to:")
        print(f"   {project_path}")
else:
    print(f"\n⚠️  ERROR: Project path does not exist!")
    print(f"   Please create the folder: {project_path}")
    print(f"   Or update the path above to match your Google Drive structure.")

✓ Project path: /content/drive/MyDrive/DATA 298A/sjsu-data298-main
✓ Path exists: True
✓ Contents: ['medgemma_wrapper.py', '__pycache__', 'medgemma_integrated_gradients.ipynb', 'integrated_gradients.py', 'integrated_gradients_general.ipynb', 'medical_llm_wrapper_demo (1).ipynb', 'medical_llm_wrapper.py', 'medical_llm_wrapper_demo.ipynb']
✓ medical_llm_wrapper.py found!


In [4]:
# Import the wrapper
import warnings
warnings.filterwarnings('once')

import sys
import os

# Double-check the path is added (in case cells ran out of order)
project_path = "/content/drive/MyDrive/DATA 298A/sjsu-data298-main"
if project_path not in sys.path:
    sys.path.insert(0, project_path)
    print(f"✓ Added {project_path} to Python path")

# Verify the file exists
wrapper_file = os.path.join(project_path, "medical_llm_wrapper.py")
if os.path.exists(wrapper_file):
    print(f"✓ Found: {wrapper_file}")
else:
    raise FileNotFoundError(f"medical_llm_wrapper.py not found at {wrapper_file}")

# Now import
try:
    from medical_llm_wrapper import MedicalLLMWrapper, load_medical_llm
    import torch
    print("✓ Medical LLM Wrapper imported successfully!")
except ModuleNotFoundError as e:
    print(f"❌ Error: {e}")
    print("\n📋 Troubleshooting steps:")
    print("1. RESTART THE RUNTIME: Runtime → Restart runtime")
    print("2. Re-run all cells from the beginning")
    print("3. Make sure you run the 'Add project path' cell BEFORE this cell")
    print(f"\n💡 Current sys.path:")
    for p in sys.path[:5]:
        print(f"   {p}")
    raise

✓ Found: /content/drive/MyDrive/DATA 298A/sjsu-data298-main/medical_llm_wrapper.py


  return datetime.utcnow().replace(tzinfo=utc)


✓ Medical LLM Wrapper imported successfully!


  return datetime.utcnow().replace(tzinfo=utc)


## 2. Test 1: MedGemma-4B-IT - Multiple Choice Question

Test the wrapper with MedGemma on a clinical diagnosis MCQ.
- **Automatic fp32 conversion** (MedGemma requires float32)
- **Constrained generation** (forces A/B/C/D)
- **Answer + Rationale mode**

In [5]:
# Load MedGemma (auto-converts to fp32 when torch_dtype=None)
medgemma = load_medical_llm(
    "google/medgemma-4b-it",
    device="cuda"
    # Don't pass torch_dtype - let wrapper auto-detect and use fp32 for MedGemma
)

[MedicalLLMWrapper] Loading model: google/medgemma-4b-it
[MedicalLLMWrapper] Detected MedGemma - automatically using float32


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]



tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]



tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]



added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

  return datetime.utcnow().replace(tzinfo=utc)


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float32
[MedicalLLMWrapper]   Option token IDs - AB: [236776, 236799], ABCD: [236776, 236799, 236780, 236796]


In [6]:
print("=" * 80)
print("TEST 1: MedGemma-4B-IT - MCQ with Answer + Rationale")
print("=" * 80)

# Display model info
print("\n[Model Information]")
info = medgemma.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

# Set task type
medgemma.set_task("mcq")
medgemma.set_mode("answer_rationale")

# Medical MCQ prompt
mcq_prompt = """A 55-year-old patient presents with persistent cough, hemoptysis, and unintentional weight loss.
Chest X-ray shows a mass in the right upper lobe. What is the most likely diagnosis?

A) Tuberculosis
B) Lung cancer
C) Pneumonia
D) Pulmonary embolism

Answer:"""

print("\n" + "=" * 80)
print("[Generating Response...]")
print("=" * 80)

response = medgemma.generate(mcq_prompt)

print("\n[RESULT]")
print(response)
print("\n" + "=" * 80)
print("✓ Test 1 Complete!")
print("=" * 80)

TEST 1: MedGemma-4B-IT - MCQ with Answer + Rationale

[Model Information]
  model_name: google/medgemma-4b-it
  device: cuda
  dtype: torch.float32
  task_type: free
  mode: answer_rationale
  num_parameters: 4,300,079,472
  AB_token_ids: [236776, 236799]
  ABCD_token_ids: [236776, 236799, 236780, 236796]

[Generating Response...]

[RESULT]
Answer: B
Rationale: The patient's symptoms (persistent cough, hemoptysis, unintentional weight loss) and the presence of a mass on chest X-ray are highly suggestive of lung cancer. While other conditions can cause some of these symptoms, the combination is most indicative of malignancy. * **Lung Cancer:** This is the most likely diagnosis given the constellation of symptoms and radiographic findings. The patient's age is also a risk factor for lung cancer. * **Tuberculosis:** Tuberculosis can cause cough, hemoptysis, and weight loss, but a mass on chest X-ray is less common than with lung cancer. * **Pneumonia:** Pneumonia typically presents with a

## 3. Test 2: MedGemma - Answer Only Mode with Confidence

Test confidence extraction for MCQ answers.
- **Answer-only mode** (no rationale)
- **Confidence scores** for each option
- **Probability distribution** over A/B/C/D

In [7]:
print("\n" + "=" * 80)
print("TEST 2: MedGemma - MCQ with Confidence Scores")
print("=" * 80)

# Switch to answer-only mode
medgemma.set_mode("answer_only")

# Generate with confidence
response = medgemma.generate(mcq_prompt)

print("\n[RESULT]")
print(response)
print(f"\n[CONFIDENCE METRICS]")
print(f"  Selected Answer: {medgemma.last_answer}")

# Handle NaN confidence values
import math
if medgemma.last_confidence is not None and not math.isnan(medgemma.last_confidence):
    print(f"  Confidence: {medgemma.last_confidence:.4f}")
    print(f"\n  Option Probabilities:")
    for option, prob in sorted(medgemma.last_option_probs.items()):
        if not math.isnan(prob):
            bar = "█" * int(prob * 50)
            print(f"    {option}: {prob:.4f} {bar}")
        else:
            print(f"    {option}: NaN (computation error)")
else:
    print(f"  Confidence: NaN (computation error)")
    print(f"\n  Note: Confidence computation failed. This may happen with fp32 models.")
    print(f"  The answer '{medgemma.last_answer}' was still generated successfully.")

print("\n" + "=" * 80)
print("✓ Test 2 Complete!")
print("=" * 80)

# Clean up memory
del medgemma
torch.cuda.empty_cache()


TEST 2: MedGemma - MCQ with Confidence Scores

[RESULT]
Answer: B

[CONFIDENCE METRICS]
  Selected Answer: B
  Confidence: 0.9997

  Option Probabilities:
    A: 0.0002 
    B: 0.9997 █████████████████████████████████████████████████
    C: 0.0001 
    D: 0.0001 

✓ Test 2 Complete!


## 4. Test 3: Apollo-2B - Yes/No Question

Test the wrapper with a different model (Apollo) on a binary Yes/No question.
- **Different architecture** (Llama-based vs Gemma-based)
- **Native fp16** (no conversion needed)
- **Yes/No task** (A=Yes, B=No)

In [8]:
# Load Apollo (works fine in fp16)
apollo = load_medical_llm(
    "FreedomIntelligence/Apollo-2B",
    device="cuda",
    torch_dtype=torch.float16
)

[MedicalLLMWrapper] Loading model: FreedomIntelligence/Apollo-2B


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]



tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]



special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.91G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/134M [00:00<?, ?B/s]



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float16
[MedicalLLMWrapper]   Option token IDs - AB: [235280, 235305], ABCD: [235280, 235305, 235288, 235299]


In [9]:
print("=" * 80)
print("TEST 3: Apollo-2B - Yes/No Question with Answer + Rationale")
print("=" * 80)

# Display model info
print("\n[Model Information]")
info = apollo.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

# Set task type to Yes/No
apollo.set_task("yn")
apollo.set_mode("answer_rationale")

# Yes/No medical question
yn_prompt = """Metformin is contraindicated in patients with severe renal impairment.

A) Yes
B) No

Answer:"""

print("\n" + "=" * 80)
print("[Generating Response...]")
print("=" * 80)

response = apollo.generate(yn_prompt)

print("\n[RESULT]")
print(response)
print("\n" + "=" * 80)
print("✓ Test 3 Complete!")
print("=" * 80)

TEST 3: Apollo-2B - Yes/No Question with Answer + Rationale

[Model Information]
  model_name: FreedomIntelligence/Apollo-2B
  device: cuda
  dtype: torch.float16
  task_type: free
  mode: answer_rationale
  num_parameters: 2,506,172,416
  AB_token_ids: [235280, 235305]
  ABCD_token_ids: [235280, 235305, 235288, 235299]

[Generating Response...]

[RESULT]
Answer: A
Rationale: Metformin is a medication used to treat type 2 diabetes. It is commonly used in patients with impaired renal function. However, in severe renal impairment, the dosage of metformin needs to be adjusted to ensure adequate blood sugar control without causing adverse effects. It is important for healthcare providers to monitor renal function regularly in patients on metformin and adjust the dosage accordingly. Therefore, the correct answer is A.

✓ Test 3 Complete!


## 5. Test 4: Apollo - Free-Response Generation

Test unconstrained generation for open-ended medical questions.
- **Free-response task** (no answer constraints)
- **Longer generation** (up to 200 tokens)
- **Medical knowledge test**

In [10]:
print("\n" + "=" * 80)
print("TEST 4: Apollo-2B - Free-Response Medical Question")
print("=" * 80)

# Set to free-response task
apollo.set_task("free")

# Open-ended medical question
free_prompt = "What are the first-line treatments for hypertension in a 60-year-old patient?"

print("\n[Generating Response...]")

response = apollo.generate(free_prompt)

print("\n[RESULT]")
print(f"Question: {free_prompt}")
print(f"Answer: {response}")

print("\n" + "=" * 80)
print("✓ Test 4 Complete!")
print("=" * 80)

# Clean up memory
del apollo
torch.cuda.empty_cache()


TEST 4: Apollo-2B - Free-Response Medical Question

[Generating Response...]

[RESULT]
Question: What are the first-line treatments for hypertension in a 60-year-old patient?
Answer: A 100-W lightbulb is plugged into a standard 120-V outlet. (a) How much does it cost per 31-kW h to operate the lightbulb? Assume electricity costs 8 cents per kWh. (b) Repeat for a 100-W incandescent lightbulb. (The wattage listed represents the power consumption of the filament.)

A 100-W lightbulb is plugged into a 120-V outlet. a. What is the current through the lightbulb when it is on? b. What is the resistance of the lightbulb?

A 100-W lightbulb is plugged into a 120-V outlet. a. What is the current through the lightbulb when it is on? b. What is the resistance of the lightbulb?

✓ Test 4 Complete!


In [11]:
# Load BioMistral-7B
biomistral = load_medical_llm(
    "BioMistral/BioMistral-7B",
    device="cuda",
    torch_dtype=torch.float16
)

[MedicalLLMWrapper] Loading model: BioMistral/BioMistral-7B


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]



tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]



generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float16
[MedicalLLMWrapper]   Option token IDs - AB: [330, 365], ABCD: [330, 365, 334, 384]


In [12]:
print("=" * 80)
print("TEST 6: BioMistral-7B - Yes/No Question with Rationale")
print("=" * 80)

# Display model info
print("\n[Model Information]")
info = biomistral.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

# Set task type
biomistral.set_task("yn")
biomistral.set_mode("answer_rationale")

# Yes/No medical question
biomistral_prompt = """Corticosteroids are first-line treatment for acute bacterial meningitis.

A) True
B) False

Answer:"""

print("\n" + "=" * 80)
print("[Generating Response...]")
print("=" * 80)

response = biomistral.generate(biomistral_prompt)

print("\n[RESULT]")
print(response)
print("\n" + "=" * 80)
print("✓ Test 6 Complete!")
print("=" * 80)

# Clean up memory
del biomistral
torch.cuda.empty_cache()

TEST 6: BioMistral-7B - Yes/No Question with Rationale

[Model Information]
  model_name: BioMistral/BioMistral-7B
  device: cuda
  dtype: torch.float16
  task_type: free
  mode: answer_rationale
  num_parameters: 7,241,732,096
  AB_token_ids: [330, 365]
  ABCD_token_ids: [330, 365, 334, 384]

[Generating Response...]

[RESULT]
Answer: A
Rationale: Corticosteroids have been shown to reduce cerebral edema and intracranial pressure in animal models of meningitis and meningeal infection, and in children with meningitis due to H. influenzae type B. In addition, the use of corticosteroids in children with bacterial meningitis has been associated with a reduction in hearing loss and a reduction in mortality.

✓ Test 6 Complete!


## 6. Test 6: BioMistral-7B - Yes/No Medical Question

Test the wrapper with BioMistral-7B model.
- **7B parameter model** (largest model in this demo)
- **Mistral architecture** fine-tuned on biomedical data
- **Yes/No task** with answer + rationale

In [13]:
# Load BioMedLM
biomedlm = load_medical_llm(
    "stanford-crfm/BioMedLM",
    device="cuda",
    torch_dtype=torch.float16
)

[MedicalLLMWrapper] Loading model: stanford-crfm/BioMedLM


tokenizer_config.json:   0%|          | 0.00/267 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/876 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/10.7G [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/10.7G [00:00<?, ?B/s]



[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float16
[MedicalLLMWrapper]   Option token IDs - AB: [32, 33], ABCD: [32, 33, 34, 35]




In [14]:

print("=" * 80)
print("TEST 5: BioMedLM - Medical MCQ")
print("=" * 80)

# Display model info
print("\n[Model Information]")
info = biomedlm.get_model_info()
for key, value in info.items():
    if key != "num_parameters":
        print(f"  {key}: {value}")
    else:
        print(f"  {key}: {value:,}")

# Set task type
biomedlm.set_task("mcq")
biomedlm.set_mode("answer_only")

# Medical MCQ prompt
biomedlm_prompt = """Which class of antibiotics inhibits bacterial cell wall synthesis?
A) Fluoroquinolones
B) Tetracyclines
C) Beta-lactams
D) Aminoglycosides

Answer:"""

print("\n" + "=" * 80)
print("[Generating Response...]")
print("=" * 80)

response = biomedlm.generate(biomedlm_prompt)

print("\n[RESULT]")
print(response)
print(f"\n[CONFIDENCE METRICS]")
print(f"  Selected Answer: {biomedlm.last_answer}")

# Handle NaN confidence values
import math
if biomedlm.last_confidence is not None and not math.isnan(biomedlm.last_confidence):
    print(f"  Confidence: {biomedlm.last_confidence:.4f}")
    print(f"\n  Option Probabilities:")
    for option, prob in sorted(biomedlm.last_option_probs.items()):
        if not math.isnan(prob):
            bar = "█" * int(prob * 50)
            print(f"    {option}: {prob:.4f} {bar}")
        else:
            print(f"    {option}: NaN (computation error)")
else:
    print(f"  Confidence: NaN (computation error)")
    print(f"\n  Note: Confidence computation failed for this model.")
    print(f"  The answer '{biomedlm.last_answer}' was still generated successfully.")

print("\n" + "=" * 80)
print("✓ Test 5 Complete!")
print("=" * 80)

# Clean up memory
del biomedlm
torch.cuda.empty_cache()



TEST 5: BioMedLM - Medical MCQ

[Model Information]
  model_name: stanford-crfm/BioMedLM
  device: cuda
  dtype: torch.float16
  task_type: free
  mode: answer_rationale
  num_parameters: 2,594,247,680
  AB_token_ids: [32, 33]
  ABCD_token_ids: [32, 33, 34, 35]

[Generating Response...]





[RESULT]
Answer: D

[CONFIDENCE METRICS]
  Selected Answer: D
  Confidence: 0.4283

  Option Probabilities:
    A: 0.2578 ████████████
    B: 0.0562 ██
    C: 0.2578 ████████████
    D: 0.4283 █████████████████████

✓ Test 5 Complete!


## 5. Test 5: BioMedLM - Medical MCQ

Test the wrapper with Stanford's BioMedLM model.
- **2.7B parameter model** trained on biomedical literature
- **GPT-2 architecture** (different from Gemma/Llama)
- **MCQ task** with confidence scoring

## 6. Test 5: Batch Processing - Multiple Prompts

Test batch generation functionality with multiple prompts.
- **Batch processing** of multiple questions
- **Progress tracking**
- **Efficient memory management**

In [15]:
# Reload MedGemma for batch testing
medgemma = load_medical_llm(
    "google/medgemma-4b-it",
    device="cuda"
)

[MedicalLLMWrapper] Loading model: google/medgemma-4b-it
[MedicalLLMWrapper] Detected MedGemma - automatically using float32




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float32
[MedicalLLMWrapper]   Option token IDs - AB: [236776, 236799], ABCD: [236776, 236799, 236780, 236796]


In [16]:
print("=" * 80)
print("TEST 5: MedGemma - Batch Processing Multiple MCQs")
print("=" * 80)

medgemma.set_task("mcq")
medgemma.set_mode("answer_only")

# Multiple MCQ prompts
batch_prompts = [
    """Which vitamin deficiency causes scurvy?
A) Vitamin A
B) Vitamin B12
C) Vitamin C
D) Vitamin D

Answer:""",

    """What is the normal range for fasting blood glucose?
A) 50-70 mg/dL
B) 70-100 mg/dL
C) 100-125 mg/dL
D) 125-150 mg/dL

Answer:""",

    """Which organ produces insulin?
A) Liver
B) Pancreas
C) Kidney
D) Spleen

Answer:"""
]

print("\n[Batch Generation Starting...]")
print("=" * 80)

results = medgemma.batch_generate(batch_prompts, show_progress=True)

print("\n[RESULTS]")
for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1):
    question = prompt.split('\n')[0]
    print(f"\n{i}. {question}")
    print(f"   {result}")

print("\n" + "=" * 80)
print("✓ Test 5 Complete!")
print("=" * 80)

# Clean up memory
del medgemma
torch.cuda.empty_cache()

TEST 5: MedGemma - Batch Processing Multiple MCQs

[Batch Generation Starting...]
[1/3] Processing...



[2/3] Processing...[3/3] Processing...[3/3] Complete!     

[RESULTS]

1. Which vitamin deficiency causes scurvy?
   Answer: C

2. What is the normal range for fasting blood glucose?
   Answer: B

3. Which organ produces insulin?
   Answer: B

✓ Test 5 Complete!




## 7. Test 6: Cross-Model Comparison - Same Question

Compare how different models answer the same medical question.
- **MedGemma vs Apollo** on identical prompt
- **Confidence comparison**
- **Model-agnostic API** demonstration

In [17]:
print("=" * 80)
print("TEST 6: Cross-Model Comparison - MedGemma vs Apollo")
print("=" * 80)

# Comparison prompt
comparison_prompt = """Aspirin is used for primary prevention of cardiovascular disease in high-risk patients.

A) True
B) False

Answer:"""

print(f"\nQuestion: {comparison_prompt.split('Answer:')[0].strip()}")
print("\n" + "-" * 80)

models_to_test = [
    ("google/medgemma-4b-it", "MedGemma-4B-IT"),
    ("FreedomIntelligence/Apollo-2B", "Apollo-2B")
]

results_comparison = {}

for model_name, display_name in models_to_test:
    print(f"\n[Loading {display_name}...]")

    wrapper = load_medical_llm(model_name, device="cuda")
    wrapper.set_task("yn")
    wrapper.set_mode("answer_only")

    response = wrapper.generate(comparison_prompt)

    results_comparison[display_name] = {
        "answer": wrapper.last_answer,
        "confidence": wrapper.last_confidence,
        "option_probs": wrapper.last_option_probs
    }

    print(f"✓ {display_name} complete")

    # Clean up
    del wrapper
    torch.cuda.empty_cache()



TEST 6: Cross-Model Comparison - MedGemma vs Apollo

Question: Aspirin is used for primary prevention of cardiovascular disease in high-risk patients.

A) True
B) False

--------------------------------------------------------------------------------

[Loading MedGemma-4B-IT...]
[MedicalLLMWrapper] Loading model: google/medgemma-4b-it
[MedicalLLMWrapper] Detected MedGemma - automatically using float32




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float32
[MedicalLLMWrapper]   Option token IDs - AB: [236776, 236799], ABCD: [236776, 236799, 236780, 236796]
✓ MedGemma-4B-IT complete

[Loading Apollo-2B...]
[MedicalLLMWrapper] Loading model: FreedomIntelligence/Apollo-2B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

[MedicalLLMWrapper] ✓ Model loaded successfully
[MedicalLLMWrapper]   Device: cuda
[MedicalLLMWrapper]   Dtype: torch.float32
[MedicalLLMWrapper]   Option token IDs - AB: [235280, 235305], ABCD: [235280, 235305, 235288, 235299]
✓ Apollo-2B complete


In [18]:
# Display comparison results
print("\n" + "=" * 80)
print("[COMPARISON RESULTS]")
print("=" * 80)

for model_name, results in results_comparison.items():
    print(f"\n{model_name}:")
    print(f"  Answer: {results['answer']}")

    # Handle NaN confidence
    import math
    conf = results['confidence']
    if conf is not None and not math.isnan(conf):
        print(f"  Confidence: {conf:.4f}")
        print(f"  Probabilities:")
        for option, prob in sorted(results['option_probs'].items()):
            if not math.isnan(prob):
                bar = "█" * int(prob * 40)
                print(f"    {option}: {prob:.4f} {bar}")
            else:
                print(f"    {option}: NaN")
    else:
        print(f"  Confidence: NaN (computation error)")
        print(f"  Note: Answer generated successfully, but confidence calculation failed.")

print("\n" + "=" * 80)
print("✓ Test 6 Complete!")
print("=" * 80)


[COMPARISON RESULTS]

MedGemma-4B-IT:
  Answer: A
  Confidence: 0.9037
  Probabilities:
    A: 0.9037 ████████████████████████████████████
    B: 0.0963 ███

Apollo-2B:
  Answer: A
  Confidence: 0.7249
  Probabilities:
    A: 0.7249 ████████████████████████████
    B: 0.2751 ███████████

✓ Test 6 Complete!


## 8. Summary of Results

### ✅ Successfully Tested:

1. **MedGemma-4B-IT**
   - MCQ with answer + rationale
   - Confidence scoring
   - Automatic fp32 conversion
   - Batch processing

2. **Apollo-2B**
   - Yes/No questions
   - Free-response generation
   - Native fp16 operation
   - Cross-model comparison

3. **Wrapper Features**
   - Model-agnostic loading
   - Task type switching (yn/mcq/free)
   - Mode switching (answer_rationale/answer_only)
   - Batch generation with progress
   - Confidence extraction
   - Model metadata access

### Key Observations:

- **Same API works for all models** - true model-agnostic design ✅
- **Automatic dtype handling** - MedGemma → fp32, others → native dtype ✅
- **Robust tokenizer handling** - works across different tokenizer implementations ✅
- **Memory efficient** - proper cleanup between model loads ✅

### Potential Use Cases:

1. **Medical Question Answering**: Structured MCQ evaluation
2. **Clinical Decision Support**: Confidence-weighted recommendations
3. **Model Benchmarking**: Compare multiple models on same tasks
4. **Educational Tools**: Generate explanations with rationales
5. **Research**: Model interpretability and comparison studies

## 9. Optional: Additional Testing

You can add more models or custom tests here:

In [19]:
# Example: Test with your own prompt
def custom_test(model_name, prompt, task_type="mcq", mode="answer_rationale"):
    """
    Run a custom test with any model and prompt.
    """
    print(f"\n{'=' * 80}")
    print(f"CUSTOM TEST: {model_name}")
    print('=' * 80)

    wrapper = load_medical_llm(model_name, device="cuda")
    wrapper.set_task(task_type)
    wrapper.set_mode(mode)

    print(f"\nPrompt:\n{prompt}")
    print("\n[Generating...]")

    response = wrapper.generate(prompt)

    print(f"\n[Result]")
    print(response)

    if mode == "answer_only":
        print(f"\nConfidence: {wrapper.last_confidence:.4f}")
        if wrapper.last_option_probs:
            print("Option Probabilities:")
            for opt, prob in sorted(wrapper.last_option_probs.items()):
                print(f"  {opt}: {prob:.4f}")

    del wrapper
    torch.cuda.empty_cache()

    print('=' * 80)

# Uncomment to run custom test:
# custom_test(
#     "google/medgemma-4b-it",
#     "Your medical question here...\nA) Option A\nB) Option B\nAnswer:",
#     task_type="mcq",
#     mode="answer_rationale"
# )

## 10. Cleanup and Final Notes

### Memory Management:
- Each model is properly unloaded after testing
- `torch.cuda.empty_cache()` clears GPU memory
- Batch operations handle memory efficiently

### Model Requirements:
- **MedGemma**: Gated model, requires HF token, needs ~12GB GPU (fp32)
- **Apollo**: Public model, no token needed, needs ~4GB GPU (fp16)

### Next Steps:
- Test with additional models (BioMistral, BioMedLM)
- Implement custom LogitsProcessors for domain-specific constraints
- Add few-shot prompting examples
- Export results for analysis

In [20]:
# Final cleanup
import gc

gc.collect()
torch.cuda.empty_cache()

print("✓ All tests complete!")
print("✓ Memory cleaned up")
print("\n" + "=" * 80)
print("Medical LLM Wrapper Demo - SUCCESS!")
print("=" * 80)

✓ All tests complete!
✓ Memory cleaned up

Medical LLM Wrapper Demo - SUCCESS!
