# SPR 2026 - MedGemma BI-RADS Instruction

**Approach:** A medical LLM given detailed instructions on BI-RADS classification

**Key points:**
- **MedGemma**: a Google model trained on medical data
- A system prompt that explains the BI-RADS system
- Full radiological context to guide the classification

**Model:** MedGemma-4B-IT or MedGemma-27B-IT

---
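## KAGGLE SETUP:
1. **Add Input** → **Models** → `medgemma-4b-it` (or 27b if the GPU allows)
2. **Add Input** → **Competition** → `spr-2026-mammography-report-classification`
3. **Settings** → Internet → **OFF**, GPU → **T4 x2** (note: the 27B variant in bfloat16 needs roughly 54 GB for weights alone, more than a 16 GB P100 provides, so it would require quantization or a larger accelerator)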
---

In [None]:
# ===== MEDGEMMA BI-RADS INSTRUCTION =====

import os
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

print("="*60)
print("SPR 2026 - MedGemma BI-RADS Instruction")
print("="*60)

# ===== CONFIG =====
SEED = 42
BATCH_SIZE = 1
MAX_NEW_TOKENS = 10

DATA_DIR = '/kaggle/input/competitions/spr-2026-mammography-report-classification'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(SEED)
np.random.seed(SEED)

# Auto-detect the MedGemma model: first directory under /kaggle/input that contains a config.json
def find_model_path():
    base = '/kaggle/input'
    def search_dir(directory, depth=0, max_depth=10):
        if depth > max_depth: return None
        try:
            for item in os.listdir(directory):
                path = os.path.join(directory, item)
                if os.path.isdir(path) and os.path.exists(os.path.join(path, 'config.json')):
                    return path
                result = search_dir(path, depth + 1, max_depth) if os.path.isdir(path) else None
                if result: return result
        except OSError: pass  # unreadable directory: skip it
        return None
    return search_dir(base)

MODEL_PATH = find_model_path()
print(f"Device: {device}")
print(f"Model: {MODEL_PATH}")

In [None]:
# ===== BI-RADS INSTRUCTION PROMPT (OPTIMIZED FOR MEDGEMMA) =====

SYSTEM_PROMPT = """You are a senior breast radiologist with extensive expertise in the BI-RADS (Breast Imaging Reporting and Data System) classification.

## BI-RADS Classification System

The BI-RADS system is the standardized reporting framework for mammography developed by the American College of Radiology:

### Category 0 - Incomplete
- Examination is inconclusive, requires additional evaluation
- Need for comparison with prior studies or additional imaging
- Keywords: "incomplete", "needs comparison", "additional evaluation", "complementary imaging"

### Category 1 - Negative
- Completely normal mammogram
- No findings to report
- Symmetric breasts, no nodules, calcifications, or distortions
- Keywords: "normal", "negative", "no abnormalities", "unremarkable"

### Category 2 - Benign Finding
- Definitely benign findings
- Includes: benign calcifications, intramammary lymph nodes, calcified fibroadenomas
- Malignancy risk: 0%
- Keywords: "benign", "benign calcification", "fibroadenoma", "simple cyst"

### Category 3 - Probably Benign
- Finding with high probability of being benign
- Malignancy risk: ≤2%
- Short-interval follow-up recommended (6 months)
- Keywords: "probably benign", "6-month follow-up", "short-interval control"

### Category 4 - Suspicious
- Suspicious finding for malignancy
- Subdivided into 4A (low), 4B (moderate), 4C (high suspicion)
- Malignancy risk: 2-95%
- Biopsy recommended
- Keywords: "suspicious", "biopsy", "FNAB", "core biopsy", "atypical"

### Category 5 - Highly Suggestive of Malignancy
- Classic malignant finding
- Malignancy risk: ≥95%
- Appropriate action should be taken
- Keywords: "highly suspicious", "malignant", "cancer", "neoplasm"

### Category 6 - Known Biopsy-Proven Malignancy
- Malignancy already proven by prior biopsy
- Awaiting definitive treatment
- Keywords: "confirmed carcinoma", "positive biopsy", "preoperative"

## Your Task
Analyze the provided mammography report and classify it into ONE of the categories above (0 to 6).
Respond with ONLY a single number from 0 to 6, without any explanation."""

# MedGemma handles Portuguese well, but an English prompt may be more robust
USER_TEMPLATE = """Mammography Report (Portuguese):
{report}

BI-RADS Classification (respond with only a number from 0 to 6):"""

In [None]:
# ===== LOAD MODEL =====
print("Loading MedGemma...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    local_files_only=True,
    torch_dtype=torch.bfloat16,  # MedGemma prefers bfloat16
    device_map="auto",
    low_cpu_mem_usage=True
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Model loaded: {model.config.architectures}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# ===== LOAD DATA =====
train_df = pd.read_csv(f'{DATA_DIR}/train.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test.csv')

print(f"Train: {len(train_df)}, Test: {len(test_df)}")
print("\nClass distribution (train):")
print(train_df['target'].value_counts().sort_index())

In [None]:
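# ===== CLASSIFICATION FUNCTION =====
def classify_report(report, model, tokenizer):
    """Classify one report with MedGemma using the BI-RADS instruction prompt."""

    # Single user turn: the BI-RADS instructions followed by the report
    messages = [
        {"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{USER_TEMPLATE.format(report=report)}"}
    ]

    # Apply the tokenizer's chat template (Gemma style)
    if hasattr(tokenizer, 'apply_chat_template'):
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Gemma chat templates already emit <bos>; avoid adding a second one below
        add_special_tokens = False
    else:
        text = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{USER_TEMPLATE.format(report=report)}<end_of_turn>\n<start_of_turn>model\n"
        add_special_tokens = True

    # Note: truncation cuts from the right by default, so a report that pushes the
    # prompt past max_length would lose the trailing generation prompt
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096,
                       add_special_tokens=add_special_tokens)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}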
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=None,
            top_p=None,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the newly generated tokens
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

    # Take the first digit between 0 and 6 in the response
    for char in response.strip():
        if char in '0123456':
            return int(char)

    # Fallback: class 2, assumed to be the most common class
    return 2

# Test on a single sample
sample = train_df.iloc[0]
pred = classify_report(sample['report'], model, tokenizer)
print("Example:")
print(f"  Report: {sample['report'][:100]}...")
print(f"  True: {sample['target']}, Predicted: {pred}")
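
The character scan above accepts the first digit it meets, so a reply such as "10" would be read as class 1. A slightly stricter parser could look like the sketch below. The names `extract_birads` and `FALLBACK_CLASS` are hypothetical and not part of the notebook; the fallback value simply mirrors the one used in `classify_report`.

```python
import re

FALLBACK_CLASS = 2  # assumption: same fallback value as classify_report


def extract_birads(response: str, fallback: int = FALLBACK_CLASS) -> int:
    """Return the first standalone digit 0-6 in the reply, else the fallback."""
    # The lookarounds reject digits that are part of a longer number such as "10"
    match = re.search(r"(?<!\d)([0-6])(?!\d)", response)
    return int(match.group(1)) if match else fallback


print(extract_birads("4"))           # plain answer
print(extract_birads("BI-RADS 4A"))  # subcategory letter is ignored
print(extract_birads("10"))          # multi-digit number falls back
```

Counting how often the fallback is hit during validation would also show whether the model is following the "single number" instruction.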

In [None]:
# ===== VALIDATE ON A SAMPLE =====
from sklearn.metrics import f1_score, classification_report

# Stratified sample for quick validation: up to 20 reports per class.
# Each group is sampled explicitly because groupby.apply on the grouping
# column is deprecated in recent pandas versions.
val_sample = pd.concat(
    [group.sample(min(20, len(group)), random_state=SEED)
     for _, group in train_df.groupby('target')]
)

print(f"Validating on {len(val_sample)} samples...")

val_preds = []
val_labels = val_sample['target'].values

for _, row in tqdm(val_sample.iterrows(), total=len(val_sample)):
    pred = classify_report(row['report'], model, tokenizer)
    val_preds.append(pred)

val_preds = np.array(val_preds)
f1 = f1_score(val_labels, val_preds, average='macro')

print(f"\nF1-Macro (validation): {f1:.5f}")
print("\nClassification Report:")
print(classification_report(val_labels, val_preds))
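
To make explicit what `average='macro'` measures, here is a minimal pure-Python sketch of the same computation: one F1 score per class, then an unweighted mean. The helper name `macro_f1` is hypothetical, and the toy labels are made up for illustration only.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every class seen in labels or predictions."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # A class with no true positives, false positives, or false negatives scores 0
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy example (illustrative labels, not competition data)
print(round(macro_f1([0, 0, 1, 1, 2, 2], [0, 0, 1, 2, 2, 2]), 4))
```

Because every class carries the same weight, a handful of mistakes on rare categories such as 5 or 6 can move the score as much as many mistakes on a common one.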

In [None]:
# ===== ERROR ANALYSIS =====
from collections import Counter

errors = []
for true, pred in zip(val_labels, val_preds):
    if true != pred:
        errors.append((true, pred))

print("Most common errors (true -> predicted):")
for (true, pred), count in Counter(errors).most_common(10):
    print(f"  {true} -> {pred}: {count}x")

# Label vs. prediction counts (a summary, not a full confusion matrix)
print("\nPrediction distribution:")
print(f"  True:      {dict(Counter(val_labels.tolist()))}")
print(f"  Predicted: {dict(Counter(val_preds.tolist()))}")
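
The cell above prints the most frequent error pairs and the two marginal distributions. A full confusion table can be built from the same pairs; the sketch below uses only the standard library. The helper `confusion_table` is a hypothetical name, and the default of seven labels assumes the BI-RADS classes 0 to 6 from the prompt.

```python
from collections import Counter


def confusion_table(y_true, y_pred, labels=range(7)):
    """Rows are true classes, columns are predicted classes."""
    # int() normalizes NumPy integers so they match the plain-int labels
    counts = Counter((int(t), int(p)) for t, p in zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


# Toy example with three classes (illustrative labels only)
for true_class, row in enumerate(confusion_table([0, 0, 1, 2], [0, 1, 1, 2], labels=range(3))):
    print(true_class, row)
```

In the notebook this could be called as `confusion_table(val_labels, val_preds)`; off-diagonal cells between neighbouring categories (for example 3 versus 4) are usually the most informative.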

In [None]:
# ===== GENERATE SUBMISSION =====
print("\nGenerating test predictions...")

test_preds = []
for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    pred = classify_report(row['report'], model, tokenizer)
    test_preds.append(pred)

submission = pd.DataFrame({
    'ID': test_df['ID'],
    'target': test_preds
})

submission.to_csv('submission.csv', index=False)
print("\nSubmission saved!")
print(submission['target'].value_counts().sort_index())
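
Before submitting, it may be worth checking the file that was just written. The sketch below is one possible sanity check using only the standard library. The function name `check_submission` is hypothetical; the `ID` and `target` column names come from the cell above, and the valid label set 0 to 6 is an assumption based on the prompt.

```python
import csv
import io

VALID_CLASSES = {str(c) for c in range(7)}  # assumption: BI-RADS 0-6


def check_submission(csv_text, expected_ids):
    """Return a list of problems found in a submission CSV (empty list means none found)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    if reader.fieldnames != ["ID", "target"]:
        return [f"unexpected columns: {reader.fieldnames}"]
    rows = list(reader)
    problems = []
    if sorted(row["ID"] for row in rows) != sorted(str(i) for i in expected_ids):
        problems.append("IDs do not match the test set")
    bad = [row["ID"] for row in rows if row["target"] not in VALID_CLASSES]
    if bad:
        problems.append(f"invalid target for IDs: {bad}")
    return problems


print(check_submission("ID,target\na,1\nb,6\n", ["a", "b"]))
```

In the notebook the call would be along the lines of `check_submission(open('submission.csv').read(), test_df['ID'])`.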

## About MedGemma

**MedGemma** is a family of Google models trained specifically for medical tasks:

- **MedGemma-4B-IT**: lightweight version, fits on a T4
- **MedGemma-27B-IT**: full version, needs a larger GPU
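
A rough estimate of weight memory shows why the two variants differ so much in hardware needs. The sketch below assumes nominal parameter counts of 4e9 and 27e9 taken from the model names (real checkpoints differ somewhat) and 2 bytes per parameter for bfloat16. It counts weights only, not activations or the KV cache. The helper `weight_memory_gib` is a hypothetical name.

```python
def weight_memory_gib(n_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory for the weights alone, in GiB."""
    return n_params * bytes_per_param / 1024**3


# Nominal sizes (assumption: parameter counts inferred from the model names)
for name, n_params in [("MedGemma-4B-IT", 4e9), ("MedGemma-27B-IT", 27e9)]:
    print(f"{name}: ~{weight_memory_gib(n_params):.1f} GiB in bfloat16")
```

Under these assumptions the 4B weights fit within a single 16 GB card, while the 27B weights exceed even two such cards combined unless the model is quantized.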

### Advantages for BI-RADS:
1. Pre-trained on medical literature
2. Understands radiological terminology
3. Knowledge of clinical guidelines

### Next Steps:
1. If F1 is low: test a Portuguese prompt
2. Add examples (medical few-shot)
3. Compare with Qwen BI-RADS Instruction
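
For the few-shot idea, one possible way to assemble the chat messages is sketched below. It assumes the tokenizer's chat template accepts alternating `user` and `assistant` turns. The function `build_few_shot_messages` and the placeholder report strings are hypothetical; in practice the examples would be labelled reports drawn from `train_df` and excluded from validation.

```python
def build_few_shot_messages(instructions, examples, report):
    """Worked examples as alternating user/assistant turns, then the report to classify."""
    def user_turn(text, with_instructions):
        prefix = f"{instructions}\n\n" if with_instructions else ""
        return {
            "role": "user",
            "content": f"{prefix}Mammography Report (Portuguese):\n{text}\n\n"
                       "BI-RADS Classification (respond with only a number from 0 to 6):",
        }

    messages = []
    for i, (example_report, label) in enumerate(examples):
        # The instructions appear once, in the first user turn
        messages.append(user_turn(example_report, with_instructions=(i == 0)))
        messages.append({"role": "assistant", "content": str(label)})
    messages.append(user_turn(report, with_instructions=not examples))
    return messages


# Placeholder texts only; real examples would come from the training set
demo = build_few_shot_messages(
    "<BI-RADS instructions>",
    [("<example report A>", 1), ("<example report B>", 4)],
    "<report to classify>",
)
print([m["role"] for m in demo])
```

The resulting list can be passed to `tokenizer.apply_chat_template` in place of the single-turn `messages` used in `classify_report`. Each added example lengthens the prompt, so the 4096-token limit set there becomes easier to hit.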