# Validation: MedGemma - Google's Medical LLM

**MedGemma - BI-RADS Instructions**

## 📊 Variants
- MedGemma 4B (base)
- MedGemma 1.5 4B IT (instruction-tuned)
- MedGemma 27B (largest)

## 🎯 Objective
Test MedGemma on BI-RADS classification with different prompts.

## 📝 Advantages
- Trained on medical data
- Familiar with radiological terminology
- Accepts instructions in Portuguese

---

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

if os.path.exists('/kaggle/input'):
    DATA_DIR = '/kaggle/input/spr-2026-mammography-report-classification'

    def find_model_path(keyword='medgemma', base='/kaggle/input'):
        """Return the first directory holding a config.json inside an input whose name contains `keyword`."""
        for item in os.listdir(base):
            if keyword.lower() not in item.lower():
                continue
            # Kaggle Models are mounted several levels deep
            # (<model>/<framework>/<variant>/<version>), so search recursively
            for root, _dirs, files in os.walk(os.path.join(base, item)):
                if 'config.json' in files:
                    return root
        return None

    MODEL_PATH = find_model_path()
    if MODEL_PATH is None:
        raise FileNotFoundError("No MedGemma checkpoint (config.json) found under /kaggle/input")
else:
    DATA_DIR = '../data'
    MODEL_PATH = 'google/medgemma-4b-it'  # Hugging Face Hub ID; access may require accepting the model's terms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
print(f'Model: {MODEL_PATH}')

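Kaggle Models are typically mounted several levels deep (for example `/kaggle/input/<model>/transformers/<variant>/<version>`), so a scan that only looks one level below each input can miss the checkpoint. A minimal sketch of a recursive search, run against a throwaway directory whose layout is invented for illustration; `find_config_dir` is a hypothetical helper, not part of any library.

```python
import os
import tempfile

def find_config_dir(base, keyword='medgemma'):
    """Return the first directory under a `keyword`-matching entry of `base` that contains config.json."""
    for item in sorted(os.listdir(base)):
        if keyword.lower() not in item.lower():
            continue
        # os.walk descends through every level, so deep version folders are found
        for root, _dirs, files in os.walk(os.path.join(base, item)):
            if 'config.json' in files:
                return root
    return None

# Fake input tree mimicking a nested model mount (the layout is an assumption)
base = tempfile.mkdtemp()
model_dir = os.path.join(base, 'medgemma', 'transformers', '4b-it', '1')
os.makedirs(model_dir)
open(os.path.join(model_dir, 'config.json'), 'w').close()
os.makedirs(os.path.join(base, 'other-dataset'))

found = find_config_dir(base)
print(found == model_dir)
```

Sorting the top-level entries makes the choice deterministic when several matching inputs are attached.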
In [None]:
# ===== DATA =====
train_df = pd.read_csv(f'{DATA_DIR}/train.csv')

# Small per-class sample for validation (LLM inference is slow)
train_sample = train_df.groupby('target', group_keys=False).apply(
    lambda x: x.sample(min(25, len(x)), random_state=SEED)
).reset_index(drop=True)

train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_sample['report'].tolist(),
    train_sample['target'].tolist(),
    test_size=0.3,
    stratify=train_sample['target'],
    random_state=SEED
)

print(f'Train: {len(train_texts)}, Val: {len(val_texts)}')

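Recent pandas versions (2.2+) emit a DeprecationWarning when `DataFrameGroupBy.apply` operates on the grouping column, which the sampling lambda above does (the notebook hides it with `warnings.filterwarnings`). A minimal alternative sketch, on synthetic labels invented for illustration: shuffle once, then keep the first 25 rows of each class with `groupby().head()`. The draws differ from the original cell, but each class is still capped at 25 randomly chosen reports.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

seed = 42
# Synthetic, imbalanced labels standing in for train.csv (invented for illustration)
counts = {0: 40, 1: 300, 2: 200, 3: 30, 4: 60, 5: 12, 6: 8}
demo_df = pd.DataFrame({
    'report': [f'report {i}' for i in range(sum(counts.values()))],
    'target': np.repeat(list(counts), list(counts.values())),
})

# Shuffle once, then keep at most 25 rows per class; rare classes keep all their rows
demo_sample = (demo_df.sample(frac=1, random_state=seed)
               .groupby('target').head(25)
               .reset_index(drop=True))
print(demo_sample['target'].value_counts().sort_index().to_dict())

tr_texts, va_texts, tr_labels, va_labels = train_test_split(
    demo_sample['report'].tolist(), demo_sample['target'].tolist(),
    test_size=0.3, stratify=demo_sample['target'], random_state=seed,
)
print(len(tr_texts), len(va_texts))
```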
In [None]:
# ===== SPECIALIZED PROMPTS =====

# Detailed medical prompt
MEDICAL_SYSTEM_PROMPT = """You are a specialized breast radiologist expert in BI-RADS classification.

## BI-RADS Categories (Breast Imaging-Reporting and Data System):

**Category 0 - Incomplete:** 
Need additional imaging evaluation. Cannot be fully assessed.

**Category 1 - Negative:**
Normal mammogram. No masses, calcifications, or asymmetries.

**Category 2 - Benign:** 
Definitely benign findings. Calcified fibroadenomas, intramammary lymph nodes, breast implants.

**Category 3 - Probably Benign:**
≤2% probability of malignancy. Short-interval follow-up recommended (6 months).
Well-circumscribed masses, focal asymmetries, clustered round calcifications.

**Category 4 - Suspicious:**
>2% to <95% probability of malignancy. Biopsy recommended.
- 4A: Low suspicion (>2% to ≤10%)
- 4B: Moderate suspicion (>10% to ≤50%)
- 4C: High suspicion (>50% to <95%)
Irregular margins, spiculated lesions, pleomorphic calcifications.

**Category 5 - Highly Suggestive of Malignancy:**
≥95% probability of malignancy.
Spiculated masses, irregular calcifications, architectural distortion with mass.

**Category 6 - Known Biopsy-Proven Malignancy:**
Already confirmed by histopathology.

---
Analyze the mammography report below and classify it into ONE BI-RADS category (0-6).
Respond with ONLY the category number."""

# Chain-of-Thought prompt
COT_SYSTEM_PROMPT = """You are a breast radiologist. Analyze this mammography report step by step:

1. Identify key findings (masses, calcifications, asymmetries, architectural distortion)
2. Assess morphology (shape, margins, density)
3. Consider associated features
4. Determine BI-RADS category (0-6)

Categories: 0=Incomplete, 1=Negative, 2=Benign, 3=Probably Benign, 4=Suspicious, 5=Highly Suspicious, 6=Known Malignancy

End your response with: BIRADS: [number]"""

# Simple prompt
SIMPLE_PROMPT = """Classify this mammography report using BI-RADS (0-6). Answer with only the number.

0=Incomplete, 1=Negative, 2=Benign, 3=Probably Benign, 4=Suspicious, 5=Highly Suspicious, 6=Known Malignancy"""

USER_TEMPLATE = """Mammography Report:
{report}

BI-RADS Category:"""

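To our understanding, Gemma-family chat templates have no dedicated system turn (Gemma 3's template merges a system message into the first user turn), so the notebook prepends its instructions to the user message. A tokenizer-free sketch of that assembly; `build_messages` is a hypothetical helper, `TEMPLATE` copies `USER_TEMPLATE`, and `SYSTEM` is a shortened stand-in for the prompts above.

```python
# Same text as the notebook's USER_TEMPLATE; SYSTEM is a shortened stand-in prompt
TEMPLATE = """Mammography Report:
{report}

BI-RADS Category:"""
SYSTEM = "Classify this mammography report using BI-RADS (0-6). Answer with only the number."

def build_messages(system_prompt, report):
    """Fold the instructions into a single user turn (no separate system message)."""
    content = f"{system_prompt}\n\n{TEMPLATE.format(report=report)}"
    return [{"role": "user", "content": content}]

msgs = build_messages(SYSTEM, "Example report text.")
print(msgs[0]["content"])
```

Ending the user turn with `BI-RADS Category:` nudges the model toward a bare number, which suits the direct prompts but works against the step-by-step CoT instruction.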
In [None]:
# ===== LOAD MODEL =====
print("Loading MedGemma...")

# A local directory (Kaggle) is loaded offline; a Hub ID such as
# 'google/medgemma-4b-it' must be allowed to download (or come from the local cache)
LOCAL_ONLY = os.path.isdir(MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=LOCAL_ONLY)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, local_files_only=LOCAL_ONLY,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"MedGemma loaded: {model.dtype}")

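Loading in `torch.bfloat16` stores each parameter in 2 bytes, half of float32. A back-of-the-envelope sketch of weight memory alone (activations and the KV cache used during generation come on top), using nominal parameter counts taken from the model names; the real counts differ somewhat.

```python
def weight_memory_gib(n_params, bytes_per_param=2):
    """Approximate weight memory in GiB; bfloat16 uses 2 bytes per parameter."""
    return n_params * bytes_per_param / 2**30

# Nominal parameter counts from the model names (an approximation)
for name, n in [('MedGemma 4B', 4e9), ('MedGemma 27B', 27e9)]:
    print(f'{name}: ~{weight_memory_gib(n):.1f} GiB of bf16 weights')
```

This is one concrete reason the 27B variant is harder to run in a standard notebook than the 4B one.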
In [None]:
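# ===== CLASSIFICATION FUNCTION =====
import re

# Matches "BIRADS: 4", "BI-RADS 4", "BI-RADS Category: 4", "BIRADS: [4]" (case-insensitive)
BIRADS_RE = re.compile(r'BI-?RADS\s*(?:CATEGORY)?\s*:?\s*\[?\s*([0-6])', re.IGNORECASE)

def classify_report(report, system_prompt, max_tokens=50, use_cot=False):
    """Classify one report; returns (predicted BI-RADS category, raw response)."""
    prompt = f"{system_prompt}\n\n{USER_TEMPLATE.format(report=report)}"
    # Gemma has no separate system turn, so the instructions go into the user message
    messages = [{"role": "user", "content": prompt}]

    if getattr(tokenizer, 'chat_template', None):
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        add_special_tokens = False  # Gemma's chat template already inserts <bos>
    else:
        text = prompt
        add_special_tokens = True

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096,
                       add_special_tokens=add_special_tokens)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        # Stop tokens come from the model's generation_config, which for Gemma
        # instruction-tuned checkpoints should also include <end_of_turn>
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    # CoT: use the LAST "BIRADS: n" tag; earlier mentions may be part of the reasoning
    if use_cot:
        tags = BIRADS_RE.findall(response)
        if tags:
            return int(tags[-1]), response

    # Direct answer (or CoT without a tag): first digit 0-6 in the response
    match = re.search(r'[0-6]', response)
    if match:
        return int(match.group()), response

    return 2, response  # Default when nothing can be parsed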

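Two parsing details matter here. The CoT prompt asks for `BIRADS: n`, but models often write `BI-RADS: n`; a plain substring test for `BIRADS:` then misses the tag and falls through to the first digit in the text, which in step-by-step output is often a step number. Also, the hard-coded default of 2 silently counts unparseable answers as Benign. A sketch of a standalone parser (the hypothetical `parse_birads`) that accepts spelling variants, uses the last tag, and reports whether parsing succeeded so failures can be counted separately.

```python
import re

# Accepts "BIRADS: 4", "BI-RADS: 4", "BI-RADS Category: 4", "birads 4" (case-insensitive)
BIRADS_TAG = re.compile(r'BI-?RADS\s*(?:CATEGORY)?\s*:?\s*\[?\s*([0-6])', re.IGNORECASE)

def parse_birads(response, use_cot=False, default=2):
    """Return (category, parsed_ok) from a raw model response."""
    if use_cot:
        tags = BIRADS_TAG.findall(response)
        if tags:
            # The final tag is the model's conclusion; earlier ones may be reasoning
            return int(tags[-1]), True
    digit = re.search(r'[0-6]', response)
    if digit:
        return int(digit.group()), True
    return default, False

print(parse_birads("4"))
print(parse_birads("Step 1: no masses... BI-RADS 2 is unlikely.\nBIRADS: 4", use_cot=True))
print(parse_birads("Cannot determine"))
```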
In [None]:
# ===== TEST 1: DETAILED MEDICAL PROMPT =====
print("\n" + "="*50)
print("MedGemma - Medical Prompt")
print("="*50)

medical_preds = []
medical_responses = []
for text in tqdm(val_texts, desc='Medical Prompt'):
    pred, resp = classify_report(text, MEDICAL_SYSTEM_PROMPT, max_tokens=10)
    medical_preds.append(pred)
    medical_responses.append(resp)

medical_f1 = f1_score(val_labels, medical_preds, average='macro')
print(f'F1-Macro: {medical_f1:.5f}')
print(classification_report(val_labels, medical_preds))

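Macro-F1 weights all seven categories equally, so a rare class such as 0 or 6 moves the score as much as class 2. A minimal from-scratch sketch (the hypothetical `macro_f1`, on toy labels invented for illustration) that should agree with `f1_score(..., average='macro')`, which by default also averages over the union of labels found in `y_true` and `y_pred`.

```python
import numpy as np
from sklearn.metrics import f1_score

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in y_true or y_pred."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for c in np.union1d(y_true, y_pred):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return float(np.mean(scores))

# Toy labels (invented): a model that answers "2" for every report
y_true = [0, 1, 2, 2, 3, 4, 5, 6]
y_pred = [2] * 8
print(macro_f1(y_true, y_pred), f1_score(y_true, y_pred, average='macro'))
```

Always answering the majority class scores well on accuracy but poorly here: only class 2 earns a nonzero F1, and it is divided by seven.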
In [None]:
# ===== TEST 2: CHAIN-OF-THOUGHT =====
print("\n" + "="*50)
print("MedGemma - Chain-of-Thought")
print("="*50)

cot_preds = []
cot_responses = []
for text in tqdm(val_texts, desc='CoT Prompt'):
    pred, resp = classify_report(text, COT_SYSTEM_PROMPT, max_tokens=200, use_cot=True)
    cot_preds.append(pred)
    cot_responses.append(resp)

cot_f1 = f1_score(val_labels, cot_preds, average='macro')
print(f'F1-Macro: {cot_f1:.5f}')
print(classification_report(val_labels, cot_preds))

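With `max_tokens=200`, a long chain of reasoning can be cut off before the final `BIRADS:` line, in which case the parser falls back to the first digit in the text. Measuring how often the requested tag actually appears helps separate formatting failures from classification errors. A sketch with a hypothetical `tag_compliance` helper and invented responses; in the notebook it would be called as `tag_compliance(cot_responses)`.

```python
import re

# A final-answer tag such as "BIRADS: 2" or "BI-RADS: [4]"
FINAL_TAG = re.compile(r'BI-?RADS\s*:\s*\[?\s*[0-6]', re.IGNORECASE)

def tag_compliance(responses):
    """Fraction of responses that contain the requested 'BIRADS: n' tag."""
    if not responses:
        return 0.0
    return sum(bool(FINAL_TAG.search(r)) for r in responses) / len(responses)

# Invented responses: the second one was truncated before reaching the tag
examples = ["Findings: benign calcifications.\nBIRADS: 2",
            "1. Key findings: irregular mass with spiculated",
            "Assessment complete. BI-RADS: 4"]
print(tag_compliance(examples))
```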
In [None]:
# ===== TEST 3: SIMPLE PROMPT =====
print("\n" + "="*50)
print("MedGemma - Simple Prompt")
print("="*50)

simple_preds = []
for text in tqdm(val_texts, desc='Simple Prompt'):
    pred, _ = classify_report(text, SIMPLE_PROMPT, max_tokens=5)
    simple_preds.append(pred)

simple_f1 = f1_score(val_labels, simple_preds, average='macro')
print(f'F1-Macro: {simple_f1:.5f}')
print(classification_report(val_labels, simple_preds))

In [None]:
# ===== CoT ANALYSIS =====
print("\n" + "="*50)
print("Reasoning Examples (CoT)")
print("="*50)

for i in range(min(3, len(val_texts))):
    print(f"\n--- Example {i+1} ---")
    print(f"Report: {val_texts[i][:150]}...")
    print(f"True: BI-RADS {val_labels[i]}")
    print(f"Pred: BI-RADS {cot_preds[i]}")
    print(f"Reasoning: {cot_responses[i][:300]}...")

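Printing the first three validation reports shows whatever happens to come first, often correctly classified cases. Filtering to the errors is usually more informative for error analysis. A sketch with a hypothetical `misclassified` helper on invented inputs; in the notebook the arguments would be `val_texts`, `val_labels` and `cot_preds`.

```python
def misclassified(texts, labels, preds, limit=3, width=150):
    """Return up to `limit` (index, true, pred, snippet) tuples where pred != true."""
    errors = [(i, t, p, texts[i][:width])
              for i, (t, p) in enumerate(zip(labels, preds)) if t != p]
    return errors[:limit]

# Invented inputs for illustration
texts = ['report A', 'report B', 'report C', 'report D']
for i, true, pred, snippet in misclassified(texts, [1, 2, 4, 5], [1, 3, 4, 4]):
    print(f'#{i}: true BI-RADS {true}, pred BI-RADS {pred} | {snippet}')
```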
In [None]:
# ===== CONFUSION MATRIX =====
print("\n" + "="*50)
print("Confusion Matrix (Medical Prompt)")
print("="*50)

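# Fix the label set to 0-6 so the matrix is always 7x7, even when a class
# is missing from this small sample or never predicted
labels = list(range(7))
cm = confusion_matrix(val_labels, medical_preds, labels=labels)
print(pd.DataFrame(cm,
    index=[f'True_{i}' for i in labels],
    columns=[f'Pred_{i}' for i in labels]))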

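The insights at the end flag adjacent categories (2 vs 3, 4 vs 5) as a common error, and the confusion matrix can quantify that. A sketch with a hypothetical `adjacent_error_share` helper, assuming rows and columns are ordered by category (as when `labels=range(7)` is passed) and using an invented 3x3 matrix. BI-RADS 0 (incomplete) is not really on the same ordinal scale as 1-6, so counting 0-vs-1 as "adjacent" is a simplification.

```python
import numpy as np

def adjacent_error_share(cm):
    """Share of misclassifications that land one category away from the truth (|pred - true| == 1)."""
    cm = np.asarray(cm)
    idx = np.arange(cm.shape[0])
    dist = np.abs(idx[:, None] - idx[None, :])  # category distance for each cell
    errors = cm[dist > 0].sum()
    return cm[dist == 1].sum() / errors if errors else 0.0

# Invented confusion matrix (rows = true category, columns = predicted)
toy_cm = [[5, 2, 1],
          [1, 6, 0],
          [0, 3, 4]]
print(adjacent_error_share(toy_cm))
```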
In [None]:
# ===== SUMMARY =====
print("\n" + "="*60)
print("📊 SUMMARY - MedGemma Validation")
print("="*60)

results = [
    ('Medical Prompt', medical_f1),
    ('Chain-of-Thought', cot_f1),
    ('Simple Prompt', simple_f1),
]

print(f"{'Strategy':<20} {'F1-Macro':>10}")
print("-"*35)
for name, f1 in sorted(results, key=lambda x: -x[1]):
    print(f"{name:<20} {f1:>10.5f}")

print("\n📝 Reference (TF-IDF): 0.77885")
print("📝 Reference (BERTimbau v4): 0.82073")

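With at most 25 reports per class, the validation split holds only about 50 examples, so macro-F1 differences of a few hundredths between prompts (or against the TF-IDF and BERTimbau references) may be within sampling noise. A percentile bootstrap gives a rough interval. This is a sketch with a hypothetical `bootstrap_macro_f1` helper on invented labels; in the notebook it would be called as `bootstrap_macro_f1(val_labels, medical_preds)`.

```python
import numpy as np
from sklearn.metrics import f1_score

def bootstrap_macro_f1(y_true, y_pred, n_boot=1000, seed=42, alpha=0.05):
    """Percentile bootstrap interval for macro-F1, resampling examples with replacement."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        scores.append(f1_score(y_true[idx], y_pred[idx], average='macro', zero_division=0))
    lo, hi = np.quantile(scores, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)

# Invented labels, roughly the size of this validation split
rng = np.random.default_rng(0)
y_true = rng.integers(0, 7, size=50)
y_pred = np.where(rng.random(50) < 0.7, y_true, rng.integers(0, 7, size=50))
print(bootstrap_macro_f1(y_true, y_pred))
```

If the intervals of two prompts overlap heavily, the ranking printed above should be read as tentative.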
In [None]:
# ===== INSIGHTS =====
print("""
📝 INSIGHTS - MedGemma
=======================

1. **Medical knowledge:**
   - MedGemma knows BI-RADS terminology
   - It can distinguish radiological findings

2. **Prompt engineering:**
   - Detailed medical prompt vs. simple prompt: [TO FILL IN]
   - Does CoT help or hurt: [TO FILL IN]

3. **Common errors:**
   - Adjacent classes (2 vs 3, 4 vs 5)
   - Class 0 (incomplete) is hard

4. **Variants:**
   - 4B IT is probably better than the base model
   - 27B may be too slow

5. **Recommended use:**
   - Ensemble with fine-tuned transformers
   - For edge cases
   - Explainability
""")