# MedGemma Lab Report Extraction Test

Testing MedGemma 1.5 4B for lab report analysis.

**Goal**: Validate that MedGemma can:
1. Extract lab values from Indian lab report images
2. Identify test names, values, units, and reference ranges
3. Classify values as normal/high/low

In [None]:
# Install dependencies
!pip install transformers torch accelerate Pillow -q

In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
import json

# Report the runtime backends so test results can be tied to the hardware used.
print(
    "\n".join(
        [
            f"PyTorch version: {torch.__version__}",
            f"CUDA available: {torch.cuda.is_available()}",
            f"MPS available: {torch.backends.mps.is_available()}",
        ]
    )
)

## Load MedGemma Model

Uses the 4B instruction-tuned (`-it`) variant, which is trained to follow chat-style instructions like the extraction prompt below.

In [None]:
# Model ID - MedGemma 1.5 4B
MODEL_ID = "google/medgemma-4b-it"

# Detect device and choose a compute dtype.
# MedGemma (Gemma 3 based) is run in bfloat16 in its model card; Gemma 3
# activations are reported to exceed float16's range, which shows up as NaNs
# or garbage text. So prefer bfloat16 where available and fall back to float32
# rather than float16.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
elif torch.backends.mps.is_available():
    device = "mps"
    try:
        # bfloat16 on MPS depends on the PyTorch/macOS version; probe it
        # directly instead of guessing from version numbers.
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        dtype = torch.bfloat16
    except (RuntimeError, TypeError):
        dtype = torch.float32
else:
    device = "cpu"
    dtype = torch.float32

print(f"Using device: {device}")
print(f"Using dtype: {dtype}")

In [None]:
# Load processor and model.
# MedGemma is a Gemma 3 model: its model card loads it with
# AutoModelForImageTextToText (AutoModelForVision2Seq is deprecated in favor
# of it). The architecture ships with transformers, so trust_remote_code is
# not needed; leaving it off avoids executing code downloaded from the Hub.
from transformers import AutoModelForImageTextToText

print("Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID)

print("Loading model (this may take a few minutes)...")
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map="auto",
)

print("Model loaded successfully!")
# device_map="auto" decides placement, so report where the weights actually went.
print(f"Model device: {model.device}, dtype: {model.dtype}")

## Extraction Prompt

Prompt that asks the model to list every test on the report, with its value, unit, reference range and status, as a JSON array.

In [None]:
# Instruction sent together with each lab report image. The reply is parsed by
# taking the text from the first "[" to the last "]", so the model is told to
# answer with the JSON array only.
EXTRACTION_PROMPT = """You are a medical lab report analyzer. Extract all test values from this lab report image.

For each test, provide:
- test_name: Name of the test (e.g., "Hemoglobin", "Fasting Blood Sugar", "TSH")
- value: Numeric value as shown; if the result is text or has a qualifier (e.g., "Negative", "<0.5"), give it as a string exactly as shown
- unit: Unit of measurement (e.g., "g/dL", "mg/dL", "mIU/L")
- reference_range: Normal range as shown on report
- status: "normal", "high", or "low" based on reference range; use "unknown" if there is no reference range or the value cannot be compared to it

Return as a JSON array. Example format:
[
  {
    "test_name": "Hemoglobin",
    "value": 14.2,
    "unit": "g/dL",
    "reference_range": "13.0 - 17.0",
    "status": "normal"
  }
]

Important:
- Extract ALL tests visible in the report
- Use exact values as shown (don't round)
- If reference range is missing, use "N/A"
- Handle common Indian lab formats (Thyrocare, SRL, Dr. Lal PathLabs, Metropolis)
- Respond with ONLY the JSON array: no explanations, headings, or markdown code fences
"""

## Test Function

In [None]:
def _load_report_image(image_path: str, max_size: int = 1568) -> Image.Image:
    """Open an image as RGB, apply its EXIF rotation, and cap its longest side at max_size."""
    from PIL import ImageOps

    # The context manager closes the file handle. exif_transpose()/convert()
    # return fully loaded copies, so the image stays usable after closing.
    with Image.open(image_path) as img:
        # Phone photos of paper reports often record rotation only in EXIF.
        image = ImageOps.exif_transpose(img).convert("RGB")

    # Downscale very large scans before handing them to the processor.
    # NOTE(review): the 1568 px cap is carried over from the original; the
    # processor applies its own resizing afterwards.
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        new_size = (
            max(1, int(image.size[0] * ratio)),
            max(1, int(image.size[1] * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    return image


def _generate_extraction(image: Image.Image) -> str:
    """Run the model on image + EXTRACTION_PROMPT and return only the newly generated text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": EXTRACTION_PROMPT},
            ],
        }
    ]
    # The chat template inserts the image placeholder token the processor
    # needs to pair text with the image, and wraps the prompt in the
    # instruction-tuned turn format.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)  # dtype only casts float tensors (pixel_values)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
        )

    # Decode only the generated tokens: the prompt itself contains an example
    # JSON array, which the bracket-based parser would otherwise pick up.
    return processor.decode(outputs[0][input_len:], skip_special_tokens=True)


def _parse_json_array(response: str) -> list:
    """Parse the span from the first '[' to the last ']' of response as JSON; return [] on failure."""
    start_idx = response.find("[")
    end_idx = response.rfind("]") + 1
    if start_idx == -1 or end_idx <= start_idx:
        print("No JSON array found in model response.")
        print(f"Raw response: {response}")
        return []

    try:
        # A valid JSON document that starts with "[" always decodes to a list.
        return json.loads(response[start_idx:end_idx])
    except json.JSONDecodeError as e:
        print(f"JSON parsing error: {e}")
        print(f"Raw response: {response}")
        return []


def extract_lab_values(image_path: str) -> list:
    """Extract lab values from a lab report image.

    Args:
        image_path: Path to a lab report image in any format Pillow can open.

    Returns:
        The records parsed from the model's JSON reply, each presumably shaped
        like the example in EXTRACTION_PROMPT (not validated here), or an
        empty list when the reply contains no parseable JSON array.
    """
    image = _load_report_image(image_path)
    print(f"Image size: {image.size}")

    print("Extracting lab values...")
    response = _generate_extraction(image)

    return _parse_json_array(response)

## Test with Sample Lab Reports

Point `TEST_IMAGE` at a lab report image on disk to test extraction.

In [None]:
# Test with a sample image
# Replace with your lab report image path
TEST_IMAGE = "/path/to/your/lab_report.jpg"

# Run extraction only when the path points at a real file, so running all
# cells with the placeholder path above does not crash.
from pathlib import Path

if Path(TEST_IMAGE).is_file():
    results = extract_lab_values(TEST_IMAGE)
    # ensure_ascii=False keeps units such as "µIU/mL" readable.
    print(json.dumps(results, indent=2, ensure_ascii=False))
else:
    print(f"Set TEST_IMAGE to a lab report image to run extraction (not found: {TEST_IMAGE})")

## Explanation Generation

In [None]:
# Template for a patient-facing explanation of one lab value. It is filled in
# with str.format(), so any literal braces added here must be doubled.
EXPLANATION_PROMPT = """You are a friendly medical educator helping a patient understand their lab results.

Test Information:
- Test Name: {test_name}
- Your Value: {value} {unit}
- Normal Range: {reference_range}
- Status: {status}

Explain this in simple terms (under 100 words):
1. What it measures
2. What your result means
3. One actionable tip

Use everyday language, avoid medical jargon.
Do not diagnose a condition or recommend medicines or doses.
If the status is "unknown" or the normal range is "N/A", say the result cannot be judged from this report alone.
End by suggesting the patient discuss the result with their doctor.
"""

def explain_lab_value(test_name, value, unit, reference_range, status):
    """Generate a plain-language explanation for a single lab value.

    Args:
        test_name: Name of the test, e.g. "Hemoglobin".
        value: The measured value, as extracted from the report.
        unit: Unit of measurement, e.g. "g/dL".
        reference_range: Normal range text, e.g. "13.0 - 17.0" or "N/A".
        status: Classification such as "normal", "high" or "low".

    Returns:
        The model's explanation, without the prompt and with surrounding
        whitespace stripped. Sampling is enabled, so output varies between calls.
    """
    prompt = EXPLANATION_PROMPT.format(
        test_name=test_name,
        value=value,
        unit=unit,
        reference_range=reference_range,
        status=status,
    )

    # Wrap the prompt in the chat format the instruction-tuned model expects.
    messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
        )

    # Slice off the prompt tokens rather than string-replacing the prompt text,
    # which silently fails whenever decoding does not reproduce it exactly.
    return processor.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

In [None]:
# Test explanation generation.
# Set RUN_EXPLANATION_DEMO = True to call the model with a sample value.
RUN_EXPLANATION_DEMO = False

if RUN_EXPLANATION_DEMO:
    explanation = explain_lab_value(
        test_name="Hemoglobin",
        value=12.5,
        unit="g/dL",
        reference_range="13.0 - 17.0",
        status="low",
    )
    print(explanation)

## Next Steps

1. Test with 3+ real lab report images
2. Measure extraction accuracy
3. Tune prompts if needed
4. Integrate into FastAPI backend