# SPR 2026 - Sentence Transformers

**SBERT: embeddings densos de alta qualidade**

- ✅ paraphrase-multilingual-MiniLM-L12-v2
- ✅ Embeddings 384D pré-treinados
- ✅ Compatível com modelos "Models" (transformers) ou datasets completos
- ✅ Tempo esperado: ~5-10 min

---
**CONFIGURAÇÃO KAGGLE:**
1. Settings → Internet → **OFF**
2. Add Data → **Models** → `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2`
3. **IMPORTANTE:** Execute "Run All" após commit

> O notebook auto-detecta o modelo e constrói a camada de pooling automaticamente se necessário.
---

In [None]:
# =============================================================================
# SPR 2026 - SBERT: SENTENCE TRANSFORMERS + LIGHTGBM
# =============================================================================
# - paraphrase-multilingual-MiniLM-L12-v2 (offline)
# - Embeddings 384D
# - LightGBM classifier
# - Compatível com modelos "Models" (transformers) ou datasets completos
# =============================================================================

import os
import json
import numpy as np
import pandas as pd
import lightgbm as lgb
from sentence_transformers import SentenceTransformer, models
import warnings
warnings.filterwarnings('ignore')

# Seed handed to NumPy's global RNG and to LightGBM's random_state.
SEED = 42
# Directory from which train.csv and test.csv are read.
DATA_DIR = '/kaggle/input/spr-2026-mammography-report-classification'
# Per the original author's note, MiniLM-L12-v2 produces 384-dimensional
# embeddings. NOTE(review): not referenced elsewhere in the visible code.
EMBEDDING_DIM = 384

# =============================================================================
# AUTO-DETECTAR MODELO EM /kaggle/input (BUSCA RECURSIVA)
# =============================================================================
def find_transformer_model(base='/kaggle/input'):
    """Locate a usable embedding-model directory under ``base``.

    The tree is walked depth-first (parents before children, siblings in
    sorted name order, at most 6 levels below ``base``). Two layouts are
    recognised:

    1. A complete SentenceTransformer export: the directory contains
       ``modules.json`` and ``1_Pooling/config.json``. The first such
       directory ends the search immediately and always wins.
    2. A bare transformers checkpoint: the directory contains
       ``config.json`` plus one of ``pytorch_model.bin``,
       ``model.safetensors`` or ``tf_model.h5``. The first one met during
       the walk is kept as a fallback.

    Siblings are visited in sorted order so that the result does not depend
    on the filesystem's listing order when several candidates are mounted.
    Directories that cannot be listed are skipped instead of aborting.

    Args:
        base: Root directory to search.

    Returns:
        tuple: ``(path, model_type)`` where ``model_type`` is ``'sbert'`` or
        ``'transformers'``; ``(None, None)`` when ``base`` does not exist or
        nothing matches.
    """
    if not os.path.exists(base):
        return None, None

    weight_files = ('pytorch_model.bin', 'model.safetensors', 'tf_model.h5')
    max_depth = 6

    def is_sbert_folder(path):
        """Return True when ``path`` holds modules.json and 1_Pooling/config.json."""
        return (
            os.path.exists(os.path.join(path, 'modules.json'))
            and os.path.exists(os.path.join(path, '1_Pooling', 'config.json'))
        )

    def is_transformers_folder(path):
        """Return True when ``path`` holds config.json and a known weight file."""
        if not os.path.exists(os.path.join(path, 'config.json')):
            return False
        return any(
            os.path.exists(os.path.join(path, name)) for name in weight_files
        )

    fallback = None
    # Explicit stack instead of recursion; children are pushed in reverse so
    # they are popped (visited) in ascending name order, i.e. pre-order DFS.
    pending = [(base, 0)]
    while pending:
        path, depth = pending.pop()
        if depth > max_depth or not os.path.isdir(path):
            continue

        # A complete SentenceTransformer export has priority: stop here.
        if is_sbert_folder(path):
            return path, 'sbert'

        # Remember only the first bare checkpoint; keep looking for SBERT.
        if fallback is None and is_transformers_folder(path):
            fallback = path

        try:
            children = sorted(os.listdir(path))
        except OSError:
            # Unreadable directory (permissions, stale mount, ...): skip it.
            continue

        pending.extend(
            (os.path.join(path, child), depth + 1)
            for child in reversed(children)
        )

    if fallback:
        return fallback, 'transformers'
    return None, None

# Seed NumPy's global RNG up front; this does not interact with the lookup.
np.random.seed(SEED)

# Resolve the mounted model directory and which layout it has; both values
# are None when no model was found.
MODEL_PATH, MODEL_TYPE = find_transformer_model()

print('[1/5] Bibliotecas carregadas!')
print('DATA_DIR ->', DATA_DIR)

# Debug: mostrar estrutura de /kaggle/input
def print_tree(path, prefix='', depth=0, max_depth=4):
    """Print a directory tree rooted at ``path`` as a debugging aid.

    Only the first 10 entries (sorted by name) of each directory are
    considered. Sub-directories are always printed and recursed into; files
    are printed only when they are one of the model marker files
    (modules.json, config.json, model.safetensors, pytorch_model.bin).

    A sub-directory is tagged as a complete SBERT model when it contains
    modules.json and 1_Pooling/config.json, and as a bare transformers model
    when it contains config.json AND a weight file. Requiring the weight
    file keeps folders that merely hold a config.json (for example an SBERT
    model's own 1_Pooling/ directory) from being tagged as models.

    Args:
        path: Directory to print.
        prefix: Text prepended to every line (indentation of parent levels).
        depth: Current recursion depth.
        max_depth: Directories deeper than this are not printed.

    Returns:
        None. Output goes to stdout; unreadable directories are skipped.
    """
    # NOTE(review): the non-ASCII literals below look like UTF-8 text that was
    # decoded with the wrong codec (mojibake). They are kept byte-identical
    # here - confirm the intended glyphs before changing them.
    if depth > max_depth or not os.path.isdir(path):
        return

    marker_files = ('modules.json', 'config.json', 'model.safetensors', 'pytorch_model.bin')
    weight_files = ('pytorch_model.bin', 'model.safetensors', 'tf_model.h5')

    try:
        items = sorted(os.listdir(path))
    except OSError:
        # Best-effort debug output: silently skip what cannot be listed.
        return

    shown = items[:10]
    last_index = len(shown) - 1
    for i, item in enumerate(shown):
        item_path = os.path.join(path, item)
        is_last = (i == last_index)
        connector = '‚îî‚îÄ‚îÄ ' if is_last else '‚îú‚îÄ‚îÄ '

        if os.path.isdir(item_path):
            has_modules = os.path.exists(os.path.join(item_path, 'modules.json'))
            has_pooling = os.path.exists(os.path.join(item_path, '1_Pooling', 'config.json'))
            has_config = os.path.exists(os.path.join(item_path, 'config.json'))
            has_weights = any(
                os.path.exists(os.path.join(item_path, name)) for name in weight_files
            )

            if has_modules and has_pooling:
                label = ' ‚úÖ SBERT COMPLETO!'
            elif has_config and has_weights:
                label = ' üîß Transformers (sem pooling)'
            else:
                label = ''

            print(f'{prefix}{connector}üìÇ {item}/{label}')
            extension = '    ' if is_last else '‚îÇ   '
            print_tree(item_path, prefix + extension, depth + 1, max_depth)
        elif item in marker_files:
            print(f'{prefix}{connector}üìÑ {item}')

    if len(items) > 10:
        print(f'{prefix}... (+{len(items)-10} mais arquivos)')

print('\nüìÅ Estrutura de /kaggle/input:')
base = '/kaggle/input'
if os.path.exists(base):
    print_tree(base)

print()
if MODEL_PATH:
    print(f'‚úÖ MODEL_PATH -> {MODEL_PATH}')
    print(f'   Tipo detectado: {MODEL_TYPE.upper()}')
    print(f'   Arquivos: {os.listdir(MODEL_PATH)[:10]}')
    
    # Verificar exist√™ncia de arquivos cr√≠ticos
    has_modules = os.path.exists(os.path.join(MODEL_PATH, 'modules.json'))
    has_pooling = os.path.exists(os.path.join(MODEL_PATH, '1_Pooling', 'config.json'))
    print(f'   modules.json: {"‚úÖ" if has_modules else "‚ùå"}')
    print(f'   1_Pooling/config.json: {"‚úÖ" if has_pooling else "‚ùå"}')
else:
    print('‚ùå Modelo n√£o encontrado!')
    print('\nüîß COMO RESOLVER:')
    print('   1. Add Data ‚Üí Models')
    print('   2. Procure: "paraphrase-multilingual-MiniLM-L12-v2"')
    print('   3. Escolha: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    print('   4. Save Version ‚Üí Save & Run All')

# =============================================================================
# LOAD DATA
# =============================================================================
# Read both competition splits from DATA_DIR in one pass.
train, test = (
    pd.read_csv(f'{DATA_DIR}/{csv_name}')
    for csv_name in ('train.csv', 'test.csv')
)
print(f'[2/5] Train: {train.shape} | Test: {test.shape}')

# =============================================================================
# LOAD THE SENTENCE TRANSFORMER (FALLING BACK TO A BARE TRANSFORMERS CHECKPOINT)
# =============================================================================
# Fail fast, with setup instructions, when detection found no model directory.
if MODEL_PATH is None:
    raise FileNotFoundError(
        "Modelo n√£o encontrado em /kaggle/input.\n\n"
        "COMO RESOLVER:\n"
        "  1. Add Data ‚Üí Models\n"
        "  2. Procure: paraphrase-multilingual-MiniLM-L12-v2\n"
        "  3. Selecione: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
        "  4. Save Version ‚Üí Save & Run All"
    )

if MODEL_TYPE != 'sbert':
    # Bare transformers checkpoint: assemble the SentenceTransformer by hand
    # from the token-level encoder plus a mean-pooling layer.
    print('Modelo transformers detectado (sem pooling) - construindo manualmente...')

    word_embedding_model = models.Transformer(MODEL_PATH)

    # Mean pooling only; CLS-token and max pooling are explicitly disabled.
    pooling_model = models.Pooling(
        word_embedding_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    print('   ‚úÖ Pooling layer adicionado manualmente (mean pooling)')
else:
    # Complete SentenceTransformer export: load it directly from the folder.
    print('Carregando modelo SentenceTransformer completo (com pooling)...')
    model = SentenceTransformer(MODEL_PATH)

print(f'[3/5] Modelo carregado! Embedding dim: {model.get_sentence_embedding_dimension()}')

# =============================================================================
# GENERATE EMBEDDINGS
# =============================================================================
# Both splits are encoded with the same settings.
encode_options = dict(show_progress_bar=True, batch_size=32)

print('Gerando embeddings do treino...')
train_reports = train['report'].tolist()
X_train = model.encode(train_reports, **encode_options)
y_train = train['target'].values

print('Gerando embeddings do teste...')
test_reports = test['report'].tolist()
X_test = model.encode(test_reports, **encode_options)
print(f'[4/5] Embeddings: X_train {X_train.shape} | X_test {X_test.shape}')

# =============================================================================
# TRAIN LIGHTGBM
# =============================================================================
# Hyper-parameters gathered in one mapping so they can be read at a glance.
lgbm_params = {
    'n_estimators': 200,
    'max_depth': 10,
    'learning_rate': 0.05,
    'class_weight': 'balanced',
    'random_state': SEED,
    'verbose': -1,
}
clf = lgb.LGBMClassifier(**lgbm_params)

# Fit on the sentence embeddings of the training reports.
clf.fit(X_train, y_train)
print('[5/5] LightGBM treinado!')

# =============================================================================
# SUBMISSION
# =============================================================================
# Predict a class for every test embedding and pair it with the test IDs.
predictions = clf.predict(X_test)
submission = pd.DataFrame({'ID': test['ID'], 'target': predictions})

# Written to the working directory without the index column.
submission.to_csv('submission.csv', index=False)

print('‚úÖ CONCLU√çDO: submission.csv')
# Show how many rows were assigned to each predicted class, ordered by label.
class_counts = submission['target'].value_counts().sort_index()
print(class_counts)